import math
from dataclasses import dataclass
from typing import Callable, Optional, Union

import torch
from torch import nn

from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import ModelOutput, auto_docstring, can_return_tuple
from ..auto import AutoModel
from .configuration_ovis2 import Ovis2Config, Ovis2VisionConfig


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Ovis2 outputs, with hidden states and attentions.
    """
)
class Ovis2ModelOutputWithPast(BaseModelOutputWithPast):
    r"""
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    Nimage_hidden_states)	__name__
__module____qualname____doc__r   r   torchFloatTensor__annotations__ r!   r!   d/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/ovis2/modeling_ovis2.pyr   (   s   
r   zQ
    Base class for Ovis2 causal language model (or autoregressive) outputs.
    """
)
class Ovis2CausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size * num_patches, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    Nlosslogitspast_key_valueshidden_states
attentionsr   )r   r   r   r   r$   r   r   r   r    r%   r&   listr'   tupler(   r   r!   r!   r!   r"   r#   >   s   
r#   ZRMSNormc                       s.   e Zd Zd fdd	Zdd Zdd Z  ZS )	Ovis2RMSNormư>c                    s&   t    tt|| _|| _dS )z;
        Ovis2RMSNorm is equivalent to T5LayerNorm
        N)super__init__r   	Parameterr   Zonesweightvariance_epsilon)selfhidden_sizeeps	__class__r!   r"   r.   _   s    
zOvis2RMSNorm.__init__c                 C   sJ   |j }|tj}|djddd}|t|| j  }| j|| S )Nr   TZkeepdim)	dtypetor   float32powmeanZrsqrtr1   r0   )r2   r'   Zinput_dtypeZvariancer!   r!   r"   forwardg   s
    zOvis2RMSNorm.forwardc                 C   s   t | jj d| j S )Nz, eps=)r*   r0   shaper1   r2   r!   r!   r"   
extra_reprn   s    zOvis2RMSNorm.extra_repr)r,   )r   r   r   r.   r>   rA   __classcell__r!   r!   r5   r"   r+   ]   s   r+   c                       s$   e Zd Z fddZdd Z  ZS )Ovis2VisionMLPc                    sx   t    || _|j| _|j| _tj| j| j|jd| _tj| j| j|jd| _	tj| j| j|jd| _
t|j | _d S NZbiasr-   r.   configr3   Zintermediate_sizer   LinearZmlp_bias	gate_projup_proj	down_projr   Z
hidden_actact_fnr2   rG   r5   r!   r"   r.   s   s    
zOvis2VisionMLP.__init__c                 C   s$   |  | | || | }|S NrK   rL   rI   rJ   r2   xrK   r!   r!   r"   r>   }   s     zOvis2VisionMLP.forwardr   r   r   r.   r>   rB   r!   r!   r5   r"   rC   r   s   
rC   c                       s6   e Zd Zed fddZejejdddZ  Z	S )Ovis2VisionEmbeddingsrG   c                    s   t    || _|j| _|j| _|j| _tj|j	| j| j| jdd| _
| j| j d | _| j| _t| j| j| _| jdt| jddd t|j|j| _d S )NZvalid)Zin_channelsZout_channelsZkernel_sizeZstridepaddingr   position_ids)r   r7   F)
persistent)r-   r.   rG   r3   	embed_dimZ
image_sizeZ
patch_sizer   ZConv2dZnum_channelspatch_embeddingZnum_patchesZnum_positions	Embeddingposition_embeddingZregister_bufferr   arangeexpandr+   rms_norm_epsrms_normrM   r5   r!   r"   r.      s"    
zOvis2VisionEmbeddings.__init__pixel_valuesreturnc                 C   sL   | j jj}|  |j|d}|ddd}| |}|| | j }|S )Nr9   r   r   )	rY   r0   r9   r:   flatten	transposer_   r[   rV   )r2   ra   Ztarget_dtypeZpatch_embeds
embeddingsr!   r!   r"   r>      s    

zOvis2VisionEmbeddings.forward)
r   r   r   r   r.   r   r   Tensorr>   rB   r!   r!   r5   r"   rS      s   rS           )modulequerykeyvalueattention_maskscalingdropoutc           
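# A small sketch of the shape bookkeeping above (the 448px image size, 14px patch
# size, and 1024-dim embedding are assumed values for illustration, not read from
# a released checkpoint):
#
#   >>> image_size, patch_size, embed_dim = 448, 14, 1024
#   >>> num_patches = (image_size // patch_size) ** 2   # 32 * 32 = 1024 positions
#   >>> # Conv2d patchify: (B, 3, 448, 448) -> (B, embed_dim, 32, 32)
#   >>> # flatten(2) + transpose(1, 2):     -> (B, 1024, embed_dim)
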
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
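# Minimal usage sketch for `eager_attention_forward` (shapes only; the module
# argument is used just for its `training` flag, so any nn.Module stands in):
#
#   >>> mod = nn.Linear(1, 1)
#   >>> q = k = v = torch.randn(1, 8, 16, 64)   # (batch, num_heads, seq_len, head_dim)
#   >>> out, weights = eager_attention_forward(mod, q, k, v, None, scaling=64**-0.5)
#   >>> out.shape      # (1, 16, 8, 64): transposed back to (batch, seq_len, heads, head_dim)
#   >>> weights.shape  # (1, 8, 16, 16)
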
S )	Ovis2VisionAttention=Multi-headed attention from 'Attention Is All You Need' paperc                    s   t    || _|j| _|j| _| j| j | _| j| j | jkrZtd| j d| j d| jd | _	|j
| _d| _tj| j| j|jd| _tj| j| j|jd| _tj| j| j|jd| _tj| j| j|jd| _d S Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).g      FrE   r-   r.   rG   r3   rX   Znum_attention_heads	num_headshead_dim
ValueErrorscaleZattention_dropoutro   	is_causalr   rH   Zqkv_biask_projv_projq_projout_projrM   r5   r!   r"   r.      s$    

zOvis2VisionAttention.__init__Nr'   rm   rb   c              
   K   s   |j \}}}| |}| |}| |}	|||| j| jdd}|||| j| jdd}|	||| j| jdd}	t}
| j	j
dkrt| j	j
 }
|
| |||	|| j| j| jsdn| jd\}}|||| }| |}||fS z#Input shape: Batch x Time x Channelr   r   eagerrh   )r   rn   ro   r?   r   r   r   viewr   r   re   r{   rG   Z_attn_implementationr   r   r   rs   ro   reshaperw   r   r2   r'   rm   rx   
batch_sizeZ
seq_lengthrX   ZquerieskeysvaluesZattention_interfacerz   ry   r!   r!   r"   r>      s.    




zOvis2VisionAttention.forward)Nr   r   r   r   r.   r   rg   r   r*   r>   rB   r!   r!   r5   r"   r|      s    r|   c                       s$   e Zd Z fddZdd Z  ZS )Ovis2MLPc                    sx   t    || _|j| _|j| _tj| j| j|jd| _tj| j| j|jd| _	tj| j| j|jd| _
t|j | _d S rD   rF   rM   r5   r!   r"   r.      s    
zOvis2MLP.__init__c                 C   s$   |  | | || | }|S rN   rO   rP   r!   r!   r"   r>     s     zOvis2MLP.forwardrR   r!   r!   r5   r"   r      s   
class Ovis2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""
        batch_size, seq_length, embed_dim = hidden_states.shape

        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states)
        values = self.v_proj(hidden_states)

        queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            dropout=0.0 if not self.training else self.dropout,
        )

        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)
        return attn_output, attn_weights
  ZS )
Ovis2VisionEncoderLayerrT   c                    sB   t    t|| _t|| _t|j|j| _	t|j|j| _
d S rN   )r-   r.   r   	attentionr   ffnr+   r3   r^   	rms_norm1	rms_norm2rM   r5   r!   r"   r.   E  s
    


z Ovis2VisionEncoderLayer.__init__NF)r'   rm   output_attentionsrb   c                 C   sT   |  |}| j||d\}}|| }| |}| |}|| }|rL||fS |d fS )N)r'   rm   )r   r   r   r   )r2   r'   rm   r   Znorm_hidden_statesrz   ry   Z
mlp_outputr!   r!   r"   r>   L  s    


zOvis2VisionEncoderLayer.forward)NF)r   r   r   r   r.   r   rg   r   boolr*   r>   rB   r!   r!   r5   r"   r   D  s   
  r   c                       sN   e Zd ZdZed fddZed	eej	 ee
 ee
 edddZ  ZS )
Ovis2VisionEncoderz
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`Ovis2VisionEncoderLayer`].

    Args:
        config: Ovis2VisionConfig
    """

    def __init__(self, config: Ovis2VisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([Ovis2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    @can_return_tuple
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> BaseModelOutput:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                output_attentions=output_attentions,
            )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )


class Ovis2VisionTransformer(nn.Module):
    def __init__(self, config: Ovis2VisionConfig):
        super().__init__()
        self.config = config
        self.embeddings = Ovis2VisionEmbeddings(config)
        self.encoder = Ovis2VisionEncoder(config)
        self.rms_norm = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.gradient_checkpointing = False

    def forward(
        self,
        pixel_values,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> BaseModelOutput:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        hidden_states = self.embeddings(pixel_values)
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        last_hidden_state = encoder_outputs.last_hidden_state
        last_hidden_state = self.rms_norm(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class Ovis2VisualEmbeddingTable(nn.Embedding):
    def forward(self, visual_tokens: torch.Tensor) -> torch.Tensor:
        # Integer ids use the regular embedding lookup; soft (float) token
        # distributions are projected through the embedding matrix instead.
        if visual_tokens.dtype in [torch.int8, torch.int16, torch.int32, torch.int64, torch.long]:
            return super().forward(visual_tokens)
        return torch.matmul(visual_tokens, self.weight)
dZdZdZdZdS )Ovis2PreTrainedModelrG   modelTr|   r&   N)r   r   r   r   r    Zbase_model_prefixZsupports_gradient_checkpointingZ_no_split_modulesZ_skip_keys_device_placementZ_supports_cache_classZ_supports_flash_attnZ_supports_flex_attnZ_supports_sdpaZ_can_compile_fullgraphZ_supports_attention_backendr!   r!   r!   r"   r     s   
r   )r%   rq   c                 C   sJ   |  |}|j|ddd }tj| tjd||d}||  | }|S )NTr8   r   )Zmemory_formatg      ?)rv   maxr   Z
zeros_likeZlegacy_contiguous_formatZscatter_detach)r%   rq   Zy_softindexZy_hardretr!   r!   r"   hard_softmax  s
    
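# `hard_softmax` is a straight-through estimator: the forward value is one-hot
# (argmax), while gradients flow through the softmax. A quick sketch:
#
#   >>> logits = torch.randn(2, 5, requires_grad=True)
#   >>> tokens = hard_softmax(logits, dim=-1)           # exactly one 1.0 per row
#   >>> (tokens * torch.arange(5.0)).sum().backward()   # gradients reach `logits` via `y_soft`
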
class Ovis2VisionModel(Ovis2PreTrainedModel):
    config: Ovis2VisionConfig

    def __init__(self, config: Ovis2VisionConfig):
        super().__init__(config)
        self.config = config
        self.transformer = Ovis2VisionTransformer(config)
        self.num_visual_indicator_tokens = config.num_visual_indicator_tokens
        self.vocab_size = config.vocab_size
        self.head_linear = nn.Linear(
            config.hidden_size * config.hidden_stride * config.hidden_stride,
            self.vocab_size - self.num_visual_indicator_tokens,
            bias=False,
        )
        self.head_norm = nn.LayerNorm(self.vocab_size - self.num_visual_indicator_tokens)

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        outputs = self.transformer(pixel_values)
        last_hidden_state = outputs.last_hidden_state

        if self.config.hidden_stride > 1:
            # Merge each `hidden_stride` x `hidden_stride` block of neighboring
            # patches into a single visual token before the tokenization head.
            num_images, seq_len, hidden_dim = last_hidden_state.shape
            hidden_stride = self.config.hidden_stride
            sqrt_l = int(math.sqrt(seq_len))
            if sqrt_l**2 != seq_len:
                raise ValueError("Token sequence length must be a perfect square.")

            pad_size = (hidden_stride - (sqrt_l % hidden_stride)) % hidden_stride
            last_hidden_state = last_hidden_state.reshape(num_images, sqrt_l, sqrt_l, hidden_dim)
            last_hidden_state = nn.functional.pad(last_hidden_state, (0, 0, 0, pad_size, 0, pad_size), "constant", 0)
            sqrt_l += pad_size

            last_hidden_state = last_hidden_state.reshape(
                num_images, sqrt_l // hidden_stride, hidden_stride, sqrt_l // hidden_stride, hidden_stride, hidden_dim
            )
            last_hidden_state = last_hidden_state.permute(0, 1, 3, 2, 4, 5)
            last_hidden_state = last_hidden_state.reshape(num_images, -1, hidden_stride * hidden_stride * hidden_dim)

        logits = self.head_linear(last_hidden_state)
        logits = self.head_norm(logits)

        if self.config.tokenize_function == "gumbel_argmax":
            prob_token = nn.functional.gumbel_softmax(logits, dim=-1, hard=True)
        elif self.config.tokenize_function == "st_argmax":
            prob_token = hard_softmax(logits, dim=-1)
        elif self.config.tokenize_function == "softmax":
            prob_token = nn.functional.softmax(logits, dim=-1)
        else:
            raise ValueError(f"Invalid `tokenize_function`: {self.config.tokenize_function}")

        return prob_token
@auto_docstring(
    custom_intro="""
    The Ovis2 model which consists of a vision backbone and a language model, without a language modeling head.
    """
)
class Ovis2Model(Ovis2PreTrainedModel):
    _checkpoint_conversion_mapping = {}

    def __init__(self, config: Ovis2Config):
        super().__init__(config)
        self.vision_tower = Ovis2VisionModel(config.vision_config)
        self.language_model = AutoModel.from_config(config.text_config)
        self.visual_embeddings_table = Ovis2VisualEmbeddingTable(config.vision_config.vocab_size, config.hidden_size)
        self.visual_vocab_size = config.vision_config.vocab_size
        self.visual_indicator_token_ids = config.visual_indicator_token_ids

        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_decoder(self, decoder):
        self.language_model = decoder

    def get_decoder(self):
        return self.language_model

    def get_image_features(self, pixel_values: torch.FloatTensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Obtains image last hidden states from the vision tower and applies the visual embedding table.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
                The tensors corresponding to the input images.
        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
            visual_indicator_features (`torch.Tensor`): Embeddings of the visual indicator tokens.
        """
        image_features = self.vision_tower(pixel_values)
        num_images, img_seq_len, _ = image_features.shape
        # The tokenizer head does not emit the indicator tokens, so pad the
        # distribution with zero probability for them before the table lookup.
        padding_tensor = torch.zeros(
            (num_images, img_seq_len, self.vision_tower.num_visual_indicator_tokens),
            dtype=image_features.dtype,
            device=image_features.device,
            layout=image_features.layout,
            requires_grad=False,
        )
        image_features = torch.cat([image_features, padding_tensor], dim=-1)
        image_features = self.visual_embeddings_table(image_features)

        visual_indicator = torch.arange(
            self.visual_vocab_size - self.vision_tower.num_visual_indicator_tokens,
            self.visual_vocab_size,
            dtype=torch.long,
        ).to(image_features.device)
        visual_indicator_features = self.visual_embeddings_table(visual_indicator)

        return image_features, visual_indicator_features

    def get_placeholder_mask(
        self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
    ):
        """
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
        equal to the length of multimodal features. If the lengths are different, an error is raised.
        Nr9   r   r7   r   r   z6Image features and image tokens do not match: tokens: z, features )r   r   tensorrG   Zimage_token_idr   r   allsumZ	unsqueeze	expand_asr:   r?   Znumelr   )r2   r   r   r   special_image_maskZn_image_tokensZn_image_featuresr!   r!   r"   get_placeholder_maskt  s    zOvis2Model.get_placeholder_maskNr   r   ra   rm   rV   r&   r   labels	use_cacher   r   r   cache_positionlogits_to_keeprb   c                 K   sZ  |	d ur|	n| j j}	|
d ur |
n| j j}
|d u |d uA r@td|d u rT|  |}|d ur| j|d\}}| j|||d}|||}t| j	D ]v\}}|d u r||  t
j|t
j|jdk}|d}n||k|j}| r|| || |j|j||< q| jf ||||||	|
d||d
|}t|j|j|j|j|d urR|nd dS )	Nz:You must specify exactly one of input_ids or inputs_embedsra   )r   r   r   r7   T)
rm   rV   r&   r   r   r   r   r   r   r   )r   r&   r'   r(   r   )rG   r   r   r   r   r   r   Zmasked_scatter	enumerater   r   r   r   r   r   r:   anyr   r9   r   r   r   r&   r'   r(   )r2   r   ra   rm   rV   r&   r   r   r   r   r   r   r   r   rx   r   r   r   iZvisual_indicator_idmaskr   r!   r!   r"   r>     sd    
zOvis2Model.forward)NNNNNNNNNNNNr   )r   r   r   _checkpoint_conversion_mappingr   r.   r   r   r   r   r   r   r   
LongTensorr   r   r   r   rg   r)   r   r   r   r*   r   r>   rB   r!   r!   r5   r"   r   ,  sT   *             
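# Sketch of the placeholder merge performed in `Ovis2Model.forward` above
# (standalone toy sizes; `32` stands in for `config.image_token_id`):
#
#   >>> inputs_embeds = torch.zeros(1, 6, 8)
#   >>> input_ids = torch.tensor([[1, 32, 32, 32, 32, 2]])   # four image placeholders
#   >>> image_features = torch.ones(1, 4, 8)                 # one image, four visual tokens
#   >>> mask = (input_ids == 32).unsqueeze(-1).expand_as(inputs_embeds)
#   >>> merged = inputs_embeds.masked_scatter(mask, image_features)
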
@auto_docstring
class Ovis2ForConditionalGeneration(Ovis2PreTrainedModel, GenerationMixin):
    _checkpoint_conversion_mapping = {}
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: Ovis2Config):
        super().__init__(config)
        self.model = Ovis2Model(config)
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head

    def set_decoder(self, decoder):
        self.model.set_decoder(decoder)

    def get_decoder(self):
        return self.model.get_decoder()

    def get_image_features(self, pixel_values: torch.FloatTensor):
        return self.model.get_image_features(pixel_values=pixel_values)

    # Expose the submodules of `self.model` for backward compatibility.
    @property
    def language_model(self):
        return self.model.language_model

    @property
    def vision_tower(self):
        return self.model.vision_tower

    @property
    def multi_modal_projector(self):
        raise AttributeError("Not needed for Ovis2")

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> Union[tuple, Ovis2CausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Ovis2ForConditionalGeneration

        >>> model = Ovis2ForConditionalGeneration.from_pretrained("thisisiron/Ovis2-2B-hf")
        >>> processor = AutoProcessor.from_pretrained("thisisiron/Ovis2-2B-hf")

        >>> prompt = "<|im_start|>user\n<image>\nDescribe the image.<|im_end|>\n<|im_start|>assistant\n"
        >>> url = "http://images.cocodataset.org/val2014/COCO_val2014_000000537955.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True)[0]
        "user\n\nDescribe the image.\nassistant\nThe image features a brown dog standing on a wooden floor, looking up with"
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        # Only compute the necessary logits; do not upcast them to float when the loss is not computed.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
            )

        return Ovis2CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=outputs.image_hidden_states,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        attention_mask=None,
        cache_position=None,
        logits_to_keep=None,
        **kwargs,
    ):
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )

        if cache_position[0] == 0:
            # `pixel_values` should only be forwarded on the first (prefill) step;
            # in the cached decoding stage, input ids no longer contain image tokens.
            model_inputs["pixel_values"] = pixel_values

        return model_inputs


__all__ = ["Ovis2ForConditionalGeneration", "Ovis2Model", "Ovis2PreTrainedModel"]