import collections.abc
from dataclasses import dataclass
from typing import Callable, Optional, Union

import torch
import torch.nn as nn

from ...activations import ACT2FN
from ...cache_utils import Cache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, torch_int
from ..auto import AutoModel
from .configuration_internvl import InternVLConfig, InternVLVisionConfig


@use_kernel_forward_from_hub("RMSNorm")
class InternVLVisionRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        InternVLVisionRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    key_states = key
    value_states = value

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
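# Shape contract of `eager_attention_forward` above (descriptive note, sizes assumed):
# `query`/`key`/`value` enter as (batch, num_heads, seq_len, head_dim); the
# intermediate `attn_weights` are (batch, num_heads, seq_len, seq_len); the
# returned `attn_output` is transposed back to (batch, seq_len, num_heads, head_dim)
# so the caller can reshape it to (batch, seq_len, embed_dim).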
&rS   c                       sN   e Zd ZdZed fddZd	ejeej eej e	e
 dddZ  ZS )
InternVLVisionAttentionz+Attention Class for InternVL Vision Encoderconfigc                    s@  t    || _|j| _|j| _| j| j | _| j| j | jkrZtd| j d| j d| jd | _	|j
| _
|j}|j}d| _tj| j| j| j |jd| _tj| j| j| j |jd| _tj| j| j| j |jd| _t| j| j| _|dkrt|nt | _|rt| jnt | _|r2t| jnt | _d S )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).g      FZbiasr   )r   r    rV   r(   	embed_dimZnum_attention_heads	num_headshead_dim
ValueErrorscaleattention_dropoutprojection_dropoutZuse_qk_norm	is_causalr!   LinearZattention_biasq_projk_projv_projprojection_layerDropoutIdentityr   q_normk_norm)r'   rV   Zproj_dropoutZqk_normr*   r,   r-   r    _   s.    

z InternVLVisionAttention.__init__N)r3   rB   output_attentionsrN   c                 K   s   |  \}}}| |}| |}	| |}
| |}| |	}	|||| j| j	dd}|	||| j| j	dd}	|

||| j| j	dd}
t}| jjdkrt| jj }|| ||	|
|f| jsdn| j| jdd|\}}|||| j}| |}| |}|r||fn|d f}|S )Nr   r   eagerr=   F)rD   rC   r_   )sizera   rb   rc   rg   rh   reshaperY   rZ   rK   viewrS   rV   Z_attn_implementationr   rI   r]   r\   rX   rd   r^   )r'   r3   rB   ri   rN   
batch_sizeseq_len_Zquery_statesrO   rP   Zattention_interfacerR   rQ   outputoutputsr,   r,   r-   r4   {   s<    




	


zInternVLVisionAttention.forward)NN)r9   r:   r;   __doc__r   r    r#   Tensorr   r   r   r4   r<   r,   r,   r*   r-   rT   \   s     rT   c                       sH   e Zd ZU eed< dZdZdZdgZdZ	dZ
dZdZ fddZ  ZS )InternVLVisionPreTrainedModelrV   Zinternvl_visionpixel_valuesTInternVLVisionLayerc                    s   t  | t|trP|jj  |jdur8|jj  |jdur~|jj  n.t|t	r~|j
j| jj |jj| jj dS )zInitialize the weightsN)r   _init_weights
isinstanceInternVLVisionEmbeddings	cls_tokendataZzero_
mask_tokenposition_embeddingsrw   lambda_1Zfill_rV   layer_scale_init_valuelambda_2)r'   r>   r*   r,   r-   rx      s    



z+InternVLVisionPreTrainedModel._init_weights)r9   r:   r;   r   __annotations__base_model_prefixZmain_input_namesupports_gradient_checkpointingZ_no_split_modules_supports_sdpa_supports_flash_attn_supports_flex_attn_supports_attention_backendrx   r<   r,   r,   r*   r-   ru      s   
ru   z7
    Class for outputs of [`InternVLVisionModel`].
    )Zcustom_introc                   @   s   e Zd ZdZdS )$InternVLVisionModelOutputWithPoolingaF  
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
        Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
        *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token
        will be returned.
    N)r9   r:   r;   rs   r,   r,   r,   r-   r      s   r   c                       s4   e Zd ZdZ fddZejejdddZ  ZS )InternVLVisionPatchEmbeddingsz
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    c                    s   t    |j|j }}|j|j }}|d |d  |d |d   }|d |d  |d |d  f}|| _|| _|| _|| _|| _tj	||||d| _
d S )Nr   r   )Zkernel_sizeZstride)r   r    
image_size
patch_sizenum_channelsr(   num_patchespatch_shaper!   ZConv2d
projection)r'   rV   r   r   r   r(   r   r   r*   r,   r-   r       s    
  z&InternVLVisionPatchEmbeddings.__init__)rv   returnc           	      C   s^   |j \}}}}|| jkr td| |}|j d |j d  }}|ddd}|||ffS )NzeMake sure that the channel dimension of the pixel values match with the one set in the configuration.r   r   r   )r6   r   r[   r   flattenrK   )	r'   rv   rn   r   heightwidth
embeddingspatch_heightpatch_widthr,   r,   r-   r4      s    

z%InternVLVisionPatchEmbeddings.forward)	r9   r:   r;   rs   r    r#   rt   r4   r<   r,   r,   r*   r-   r      s   r   c                       s^   e Zd ZdZedd fddZejeeejdddZ	deje
ej ejd	d
dZ  ZS )rz   zc
    Construct the CLS token, position and patch embeddings. Optionally, also the mask token.

    NrV   r   c                    s   t    ttdd|j| _|jrBttdd|j| _	nd | _	t
|| _|j| _t|jtjjrp|jn
|j|jf| _| jj}|jrttd|d |j| _nd | _t|j| _d S )Nr   )r   r    r!   r"   r#   Zzerosr(   r{   Zuse_mask_tokenr}   r   patch_embeddingsr   ry   r   collectionsabcIterabler   Z use_absolute_position_embeddingsr~   re   hidden_dropout_probrD   )r'   rV   r   r*   r,   r-   r       s     


z!InternVLVisionEmbeddings.__init__)r   r   r   r   c                 C   s   |j d d }| jj d d }tj s>||kr>||kr>| jS | jddddf }| jddddf }|j d }|| jd  }	|| jd  }
t|d }|d|||}|dddd}t	j
j||	|
fdd	d
}|dddddd|}tj||fddS )a   
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        r   Nr.   r         ?r   r   ZbicubicF)rk   modeZalign_cornersrF   )r6   r~   r#   Zjit
is_tracingr   r   rl   permuter!   rL   Zinterpolaterm   cat)r'   r   r   r   r   Znum_positionsZclass_pos_embedZpatch_pos_embedrG   Z
new_heightZ	new_widthZsqrt_num_positionsr,   r,   r-   interpolate_pos_encoding  s(    

z1InternVLVisionEmbeddings.interpolate_pos_encoding)rv   bool_masked_posr   c                 C   s   |j \}}}}| |\}\}}| \}	}
}|d urj| j|	|
d}|d|}|d|  ||  }| j|	dd}tj	||fdd}| j
d ur|| ||| }| |}|||ffS )Nr.   r   rF   )r6   r   rk   r}   expand	unsqueezeZtype_asr{   r#   r   r~   r   rD   )r'   rv   r   rp   r   r   r   r   r   rn   ro   Zmask_tokenswZ
cls_tokensr,   r,   r-   r4   :  s    

z InternVLVisionEmbeddings.forward)N)r9   r:   r;   rs   r   r    r#   rt   intr   r   
BoolTensorr4   r<   r,   r,   r*   r-   rz      s   + rz   c                       s0   e Zd Z fddZejejdddZ  ZS )InternVLVisionMLPc                    sD   t    || _t|j | _t|j|j	| _
t|j	|j| _d S N)r   r    rV   r   Z
hidden_actactivation_fnr!   r`   r(   Zintermediate_sizefc1fc2r'   rV   r*   r,   r-   r    U  s
    
zInternVLVisionMLP.__init__)r3   r   c                 C   s"   |  |}| |}| |}|S r   )r   r   r   )r'   r3   r,   r,   r-   r4   \  s    


zInternVLVisionMLP.forward)r9   r:   r;   r    r#   rt   r4   r<   r,   r,   r*   r-   r   T  s   r   )
layer_normZrms_normc                       sX   e Zd ZdZedd fddZd
ejee	e
ej e
ejejf f ddd	Z  ZS )rw   z?This corresponds to the Block class in the timm implementation.Nr   c                    s   t    |j| _d| _t|| _t|| _t|j	 |j
|jd| _t|j	 |j
|jd| _|j}tj|t|j
 dd| _tj|t|j
 dd| _t|j| _d S )Nr   r)   T)Zrequires_grad)r   r    Zchunk_size_feed_forwardZseq_len_dimrT   	attentionr   mlpNORM2FNZ	norm_typer(   layer_norm_epslayernorm_beforelayernorm_afterr   r!   r"   r#   r$   r   r   re   r   rD   )r'   rV   Zinit_valuesr*   r,   r-   r    i  s    


zInternVLVisionLayer.__init__F)r3   ri   r   c                 C   sl   | j | ||d\}}| j| }|| }| |}| |}| |}| jd ur\| j| }|| }||fS )N)ri   )r   r   r   r   r   rD   r   )r'   r3   ri   Zattention_outputZattention_weightsZlayer_outputr,   r,   r-   r4   x  s    






zInternVLVisionLayer.forward)F)r9   r:   r;   rs   r   r    r#   rt   boolr   r5   r4   r<   r,   r,   r*   r-   rw   f  s    rw   c                       sH   e Zd Zedd fddZed	ejeee	e
ef dddZ  ZS )
InternVLVisionEncoderNr   c                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r,   )rw   ).0irU   r,   r-   
<listcomp>      z2InternVLVisionEncoder.__init__.<locals>.<listcomp>F)	r   r    rV   r!   Z
ModuleListrangeZnum_hidden_layerslayerZgradient_checkpointingr   r*   rU   r-   r      s    
 zInternVLVisionEncoder.__init__F)r3   ri   output_hidden_statesr   c           	      C   sz   |rdnd }|rdnd }t | jD ]:\}}|r8||f }|||}|d }|r"||d f }q"|rl||f }t|||dS )Nr,   r   r   last_hidden_stater3   
attentions)	enumerater   r   )	r'   r3   ri   r   Zall_hidden_statesZall_self_attentionsr   Zlayer_moduleZlayer_outputsr,   r,   r-   r4     s     


zInternVLVisionEncoder.forward)FF)r9   r:   r;   r   r    r   r#   rt   r   r   r5   r   r4   r<   r,   r,   r*   r-   r     s     
r   c                
       sd   e Zd Zedd fddZdd Zeed
ej	e
ej e
e e
e eeef ddd	Z  ZS )InternVLVisionModelNr   c                    sT   t  | || _t|| _t|| _|jr4t	 ntj
|j|jd| _|   d S )Nr   )r   r    rV   rz   r   r   encoderZuse_mean_poolingr!   rf   	LayerNormr(   r   	layernorm	post_initr   r*   r,   r-   r      s    

zInternVLVisionModel.__init__c                 C   s   | j jS r   )r   r   r7   r,   r,   r-   get_input_embeddings  s    z(InternVLVisionModel.get_input_embeddings)rv   r   ri   r   r   c           	      C   sn   |dur|n| j j}|dur |n| j j}| j||d\}}| j|||d}|d }| |}t||j|jdS )z
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        N)r   )ri   r   r   r   )	rV   ri   r   r   r   r   r   r3   r   )	r'   rv   r   ri   r   Zembedding_outputrp   Zencoder_outputsZsequence_outputr,   r,   r-   r4     s     
zInternVLVisionModel.forward)NNN)r9   r:   r;   r   r    r   r   r   r#   rt   r   r   r   r   r5   r   r4   r<   r,   r,   r*   r-   r     s      
r   c                   @   s6   e Zd ZU eed< dZdZdZdZdZ	dZ
dZdZdS )InternVLPreTrainedModelrV    Tpast_key_valuesN)r9   r:   r;   r   r   r   r   Z_skip_keys_device_placementr   r   Z_can_compile_fullgraphr   r   r,   r,   r,   r-   r     s   
r   c                       s*   e Zd Zed fddZdd Z  ZS )InternVLMultiModalProjectorrU   c                    sz   t    t|jjtd|j d  | _t	|jjtd|j d  |j
j| _t|j | _t	|j
j|j
j| _d S )Nr   r   )r   r    r!   r   vision_configr(   r   downsample_ratior   r`   text_configlinear_1r   Zprojector_hidden_actactlinear_2r   r*   r,   r-   r      s    
"z$InternVLMultiModalProjector.__init__c                 C   s,   |  |}| |}| |}| |}|S r   )r   r   r   r   )r'   image_featuresr3   r,   r,   r-   r4     s
    



z#InternVLMultiModalProjector.forward)r9   r:   r;   r   r    r4   r<   r,   r,   r*   r-   r     s   	r   zM
    Base class for InternVL outputs, with hidden states and attentions.
    c                   @   s$   e Zd ZU dZdZeej ed< dS )InternVLModelOutputWithPasta  
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    Nimage_hidden_states)	r9   r:   r;   rs   r   r   r#   FloatTensorr   r,   r,   r,   r-   r     s   
r   zx
    The InternVL model which consists of a vision backbone and a language model, without a language modeling head.
    c                       s$  e Zd ZddiZed fddZdd Zdd	 Zd
d Zdd Z	de
jeeeee f  ee dddZe
je
je
jdddZeede
je
jee
j ee
j ee ee
j eeeee f  ee ee ee ee ee ee
j ee eeef dddZde
jedddZ  Z S )InternVLModelzlanguage_model.modellanguage_modelrU   c                    s>   t  | t|j| _t|| _t|j| _	| 
  d S r   )r   r    r   from_configr   vision_towerr   multi_modal_projectorr   r   r   r   r*   r,   r-   r    1  s
    
zInternVLModel.__init__c                 C   s
   | j  S r   )r   r   r7   r,   r,   r-   r   9  s    z"InternVLModel.get_input_embeddingsc                 C   s   | j | d S r   )r   set_input_embeddingsr'   rA   r,   r,   r-   r   <  s    z"InternVLModel.set_input_embeddingsc                 C   s
   || _ d S r   r   r'   decoderr,   r,   r-   set_decoder?  s    zInternVLModel.set_decoderc                 C   s   | j S r   r   r7   r,   r,   r-   get_decoderB  s    zInternVLModel.get_decoderNrv   vision_feature_layervision_feature_select_strategyc           
      K   s   |dur|n| j j}|dur |n| j j}|j| jd}| j j}|dkrV| j|dj}n| j|dj	| }|dkr|ddddddf }|j
d }t|d }|j
d }	||	||d}| j||d	}||	d|j
d }| |}|S )
a%  
        Obtains image last hidden states from the vision tower and apply multimodal projection.

        Args:
            pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
               The tensors corresponding to the input images.
            vision_feature_layer (`int` or `list[int]`):
                Layer index or list of layer indices to extract features from.
        Returns:
            vision_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
        N)r/   r.   )rv   defaultr   r   r   )scale_factor)rV   r   r   r0   r/   r   r   r   Zvision_modelr3   r6   r   rl   pixel_shuffler   )
r'   rv   r   r   rN   r   vision_featureschannelsZfeature_sizern   r,   r,   r-   get_image_featuresE  s*    


z InternVLModel.get_image_features)	input_idsinputs_embedsr   c                 C   s   |du r8||   tj| jjtj|jdk}|d}n|| jjk}| }|	d
||j}|jd |jd  }||  | krtd| d| |S )z
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
        equal to the length of multimodal features. If the lengths are different, an error is raised.
        N)r/   devicer.   r   r   z6Image features and image tokens do not match: tokens: z, features )r   r#   ZtensorrV   Zimage_token_idlongr   allsumr   Z	expand_asr0   r6   Znumelr[   )r'   r   r   r   special_image_maskZn_image_tokensZn_image_featuresr,   r,   r-   get_placeholder_mask{  s    z"InternVLModel.get_placeholder_mask)r   rv   rB   position_idsr   r   r   r   	use_cacheri   r   return_dictcache_positionrN   r   c                 K   s   |
d ur|
n| j j}
|d ur |n| j j}|d ur4|n| j j}|d urH|n| j j}|d ur\|n| j j}|d u |d uA r|td|d u r|  |}|d ur| j|||d}|	|j
|j}| j|||d}|||}| jf |||||	|
|d|d	|}t|j|j|j|j|d ur|nd dS )Nz:You must specify exactly one of input_ids or inputs_embedsr   )r   r   T)	rB   r   r   r   r   ri   r   r   r   )r   r   r3   r   r   )rV   ri   r   use_return_dictr   r   r[   r   r   r0   r   r/   r   Zmasked_scatterr   r   r   r   r3   r   )r'   r   rv   rB   r   r   r   r   r   r   ri   r   r   r   rN   r   r   rr   r,   r,   r-   r4     sZ    
zInternVLModel.forwardr   )r   r   c              	   C   s   |  \}}}}|| dks(|| dkr0td|||t|| t|| }|dddd }||t|| t|| t||d  }|dddd }|S )a&  Perform pixel shuffle downsampling on vision features.

        Args:
            vision_features (`torch.Tensor`):
                Input tensor of shape (batch_size, width, height, channels).
            scale_factor (`float`, *optional*, defaults to `0.5`):
                Factor by which to downsample. Default is 0.5, which halves the dimensions.

        Returns:
            vision_features (`torch.Tensor`):
                Downsampled tensor of shape (batch_size, height*scale_factor, width*scale_factor, channels/(scale_factor^2)).
        r   zKHeight and width must be divisible by scale_factor for proper downsampling.r   r   r   )rk   r[   rm   r   r   rM   )r'   r   r   rn   r   r   r   r,   r,   r-   r     s    $zInternVLModel.pixel_shuffle)NN)NNNNNNNNNNNNN)r   )!r9   r:   r;   _checkpoint_conversion_mappingr   r    r   r   r   r   r#   r   r   r   r   liststrr   
LongTensorr   r   r   rt   r   r   r   r   r5   r   r4   floatr   r<   r,   r,   r*   r-   r   )  s`     7             
Fr   zT
    Base class for InternVL causal language model (or autoregressive) outputs.
    c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeej  ed< dZeeej  ed< dZeeej  ed< dZeej ed< dS )	InternVLCausalLMOutputWithPasta]  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    Nlosslogitsr   r3   r   r   )r9   r:   r;   rs   r  r   r#   r   r   r  r   r   r3   r5   r   r   r,   r,   r,   r-   r     s   
r   zV
    The INTERNVL model which consists of a vision backbone and a language model.
    c                       sb  e Zd ZdddddZdgZed fdd	Zd
d Zdd Ze	j
dddZdd Zdd Zd%ejeeeee f  ee dddZedd Zedd Zedd Zeed&ejejeej eej ee eej eeeee f  ee eej ee ee ee ee eej eeejf eej e e! ee"e#f d d!d"Z$d' fd#d$	Z%  Z&S )( InternVLForConditionalGenerationzmodel.language_modelzmodel.vision_towerzmodel.multi_modal_projectorlm_head)z^language_model.modelz^vision_towerz^multi_modal_projectorz^language_model.lm_headzlm_head.weightrU   c                    s<   t  | t|| _tj|jj|jjdd| _	| 
  d S )NFrW   )r   r    r   modelr!   r`   r   r(   
vocab_sizer  r   r   r*   r,   r-   r    ,  s    
z)InternVLForConditionalGeneration.__init__c                 C   s
   | j  S r   )r  r   r7   r,   r,   r-   r   2  s    z5InternVLForConditionalGeneration.get_input_embeddingsc                 C   s   | j | d S r   )r  r   r   r,   r,   r-   r   5  s    z5InternVLForConditionalGeneration.set_input_embeddings)r   c                 C   s   | j S r   )r  r7   r,   r,   r-   get_output_embeddings8  s    z6InternVLForConditionalGeneration.get_output_embeddingsc                 C   s   | j | d S r   )r  r   r   r,   r,   r-   r   ;  s    z,InternVLForConditionalGeneration.set_decoderc                 C   s
   | j  S r   )r  r   r7   r,   r,   r-   r   >  s    z,InternVLForConditionalGeneration.get_decoderNr   c                 K   s   | j jf |||d|S )Nr   )r  r   )r'   rv   r   r   rN   r,   r,   r-   r   A  s    z3InternVLForConditionalGeneration.get_image_featuresc                 C   s   | j jS r   )r  r   r7   r,   r,   r-   r   P  s    z/InternVLForConditionalGeneration.language_modelc                 C   s   | j jS r   )r  r   r7   r,   r,   r-   r   T  s    z-InternVLForConditionalGeneration.vision_towerc                 C   s   | j jS r   )r  r   r7   r,   r,   r-   r   X  s    z6InternVLForConditionalGeneration.multi_modal_projectorr   )r   rv   rB   r   r   r   r   r   labelsr   ri   r   r   r   logits_to_keepimage_sizesrN   r   c                 K   s  |dur|n| j j}|dur |n| j j}|dur4|n| j j}|durH|n| j j}|dur\|n| j j}| jf |||||||||
||d||d|}|d }t|trt	| dn|}| 
|dd|ddf }d}|	dur| jf ||	| j jjd|}t|||j|j|j|jdS )ac  
        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoProcessor, AutoModelForImageTextToText

        >>> torch_device = "cuda"
        >>> processor = AutoProcessor.from_pretrained("OpenGVLab/InternVL3-1B-hf")
        >>> model = AutoModelForImageTextToText.from_pretrained(
        ...     "OpenGVLab/InternVL3-1B-hf", dtype=torch.bfloat16, device_map=torch_device
        ... )

        >>> messages = [
        ...     {
        ...         "role": "user",
        ...         "content": [
        ...             {
        ...                 "type": "image",
        ...                 "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
        ...             },
        ...             {
        ...                 "type": "image",
        ...                 "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
        ...             },
        ...             {"type": "text", "text": "These images depict two different landmarks. Can you identify them?"},
        ...         ],
        ...     },
        ... ]

        >>> inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(torch_device)
        >>> generate_ids = model.generate(**inputs, max_new_tokens=200)
        >>> print(processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True))
        The images depict the Statue of Liberty and the Golden Gate Bridge.
        ```NT)r   rv   rB   r   r   r   r   r   r   ri   r   r   r   r
  r   )r  r  r  )r  r  r   r3   r   r   )rV   ri   r   r   r   r   r  ry   r   slicer  Zloss_functionr   r  r   r   r3   r   r   )r'   r   rv   rB   r   r   r   r   r   r  r   ri   r   r   r   r	  r
  rN   rr   r3   Zslice_indicesr  r  r,   r,   r-   r4   \  s\    9z(InternVLForConditionalGeneration.forwardc           
         s8   t  j|f|||||d|}	|d dkr4||	d< |	S )N)r   r   rB   r   r	  r   rv   )r   prepare_inputs_for_generation)
r'   r   r   r   rv   rB   r   r	  rN   Zmodel_inputsr*   r,   r-   r    s    
z>InternVLForConditionalGeneration.prepare_inputs_for_generation)NN)NNNNNNNNNNNNNNr   N)NNNNNN)'r9   r:   r;   r   Z_tied_weights_keysr   r    r   r   r!   Moduler  r   r   r#   r   r   r   r   r   r   r   propertyr   r   r   r   r   r   rt   r   r   r   r   r5   r   r4   r  r<   r,   r,   r*   r-   r    s     


                
n      r  )ru   r   r   r   r  )r=   )Bcollections.abcr   dataclassesr   typingr   r   r   r#   Ztorch.nnr!   Zactivationsr   Zcache_utilsr   Z
generationr	   Zintegrationsr
   Zmodeling_flash_attention_utilsr   Zmodeling_layersr   Zmodeling_outputsr   r   r   Zmodeling_utilsr   r   Zprocessing_utilsr   utilsr   r   r   r   r   autor   Zconfiguration_internvlr   r   r  r   rt   r   rS   rT   ru   r   r   rz   r   r   r   rw   r   r   r   r   r   r   r   r  __all__r,   r,   r,   r-   <module>   s    K	&^0&5 R F