"""PyTorch GIT model."""

import math
from dataclasses import dataclass
from typing import Callable, Optional, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPast,
    BaseModelOutputWithPooling,
    CausalLMOutputWithPast,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
from ...utils.deprecation import deprecate_kwarg
from .configuration_git import GitConfig, GitVisionConfig


logger = logging.get_logger(__name__)


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
    """
)
class GitVisionModelOutput(ModelOutput):
    r"""
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
        The image embeddings obtained by applying the projection layer to the pooler_output.
    Nimage_embedslast_hidden_state.hidden_states
attentions)__name__
__module____qualname____doc__r!   r   torchFloatTensor__annotations__r"   r#   tupler$    r-   r-   `/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/git/modeling_git.pyr    5   s
   


class GitEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        if inputs_embeds is None:
            embeddings = self.word_embeddings(input_ids)
        else:
            embeddings = inputs_embeds

        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
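
# A minimal usage sketch (not part of the original file; the config and values below are
# illustrative assumptions): position ids are offset by `past_key_values_length`, so during cached
# generation only the embedding of the newly generated token needs to be computed.
#
#   config = GitConfig()                                      # hypothetical default config
#   embeddings = GitEmbeddings(config)
#   token_ids = torch.tensor([[101, 2023, 2003]])             # (batch_size=1, seq_len=3)
#   full = embeddings(input_ids=token_ids)                    # positions 0..2 -> (1, 3, hidden_size)
#   step = embeddings(input_ids=token_ids[:, -1:], past_key_values_length=2)   # position 2 only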


class GitSelfAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None, layer_idx=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a "
                "`layer_idx` when creating this class."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.image_patch_tokens = int((config.vision_config.image_size / config.vision_config.patch_size) ** 2 + 1)
        if config.num_image_with_embedding is not None:
            self.image_patch_tokens *= config.num_image_with_embedding

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        pixel_values_present: Optional[bool] = False,
    ) -> tuple[torch.Tensor]:
        batch_size, seq_length, _ = hidden_states.shape
        query_layer = (
            self.query(hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )

        # The image patch tokens at the front of the sequence are never cached; only the text part
        # of the key/value states is written to and read from `past_key_values`.
        cutoff = self.image_patch_tokens if pixel_values_present else 0
        key_layer = (
            self.key(hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )
        value_layer = (
            self.value(hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )
        if past_key_values is not None:
            key_layer_past, value_layer_past = past_key_values.update(
                key_layer[:, :, cutoff:, :], value_layer[:, :, cutoff:, :], self.layer_idx
            )
            key_layer = torch.cat([key_layer[:, :, :cutoff, :], key_layer_past], dim=2)
            value_layer = torch.cat([value_layer[:, :, :cutoff, :], value_layer_past], dim=2)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if past_key_values is not None:
                position_ids_l = torch.tensor(
                    key_length - 1, dtype=torch.long, device=hidden_states.device
                ).view(-1, 1)
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in GitModel.forward)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This drops out entire tokens to attend to, following the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if requested
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        return context_layer, attention_probs


class GitSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


GIT_SELF_ATTENTION_CLASSES = {
    "eager": GitSelfAttention,
}


class GitAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None, layer_idx=None):
        super().__init__()
        self.self = GIT_SELF_ATTENTION_CLASSES[config._attn_implementation](
            config, position_embedding_type=position_embedding_type, layer_idx=layer_idx
        )
        self.output = GitSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        pixel_values_present: Optional[bool] = False,
    ) -> tuple[torch.Tensor]:
        attn_output, self_attn_weights = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            past_key_values,
            output_attentions,
            pixel_values_present,
        )
        attention_output = self.output(attn_output, hidden_states)
        return attention_output, self_attn_weights


class GitIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class GitOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class GitLayer(GradientCheckpointingLayer):
    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = GitAttention(config, layer_idx=layer_idx)
        self.intermediate = GitIntermediate(config)
        self.output = GitOutput(config)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        pixel_values_present: Optional[bool] = False,
    ) -> tuple[torch.Tensor]:
        attention_output, self_attention_weights = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_values=past_key_values,
            pixel_values_present=pixel_values_present,
        )

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )

        return layer_output, self_attention_weights

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(attention_output, intermediate_output)
        return layer_output


class GitEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([GitLayer(config, i) for i in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Union[Cache, tuple[tuple[torch.FloatTensor]]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        pixel_values_present: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPast]:
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            layer_outputs = layer_module(
                hidden_states,
                attention_mask,
                layer_head_mask,
                past_key_values,
                output_attentions,
                pixel_values_present,
            )

            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
zGitEncoder.forward)NNNNFFFT)r%   r&   r'   r;   r)   rZ   r   r*   r   r	   r,   r   r   rX   r[   r-   r-   rN   r.   r   ~  s*   	        r   c                   @   s&   e Zd ZU eed< dZdZdd ZdS )GitPreTrainedModelrM   gitTc                 C   s   t |trRtjj|jd| jjd tjj|jj	| jjd tjj|j
j	| jjd t |tjr|j	jjd| jjd |jdur|jj  nft |tjr|j	jjd| jjd |jdur|j	j|j   n&t |tjr|jj  |j	jd dS )zInitialize the weights        )meanstd)r   Ng      ?)r   GitVisionEmbeddingsr   initZnormal_class_embeddingrM   Zinitializer_rangepatch_embeddingweightposition_embeddingrn   databiasZzero_r<   r0   rB   Zfill_)rL   moduler-   r-   r.   _init_weights  s    


z GitPreTrainedModel._init_weightsN)r%   r&   r'   r   r+   Zbase_model_prefixZsupports_gradient_checkpointingr   r-   r-   r-   r.   r     s   
r   c                       sP   e Zd Zed fddZejeeejdddZdej	ejdd	d
Z
  ZS )r   r   c                    s   t    || _|j| _|j| _|j| _tt	
| j| _tj|j| j| j| jdd| _| j| j d | _| jd | _t| j| j| _| jdt	| jddd d S )NF)Zin_channelsZout_channelsZkernel_sizeZstrider   r^   r   r5   r6   r8   )r:   r;   rM   r>   	embed_dimrj   rk   r   	Parameterr)   Zrandnr   ZConv2dZnum_channelsr   num_patchesnum_positionsr<   r   rH   rI   rJ   rK   rN   r-   r.   r;     s"    
zGitVisionEmbeddings.__init__)rW   heightwidthrS   c                 C   s  |j d d }| jjd}|j d d }tj sP||krP||krP| | jS |ddddf }|ddddf }|j d }	|| j }
|| j }t	|d }|
d|||	}|dddd}tjj||
|fdd	d
}|dddddd|	}tj||fddS )a   
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """
        num_patches = embeddings.shape[1] - 1
        position_embedding = self.position_embedding.weight.unsqueeze(0)
        num_positions = position_embedding.shape[1] - 1

        # always interpolate when tracing so the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embedding(self.position_ids)

        class_pos_embed = position_embedding[:, :1]
        patch_pos_embed = position_embedding[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
        batch_size, _, height, width = pixel_values.shape
        if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
            )
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embedding(self.position_ids)
        return embeddings


class GitVisionMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


class GitVisionAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""
        batch_size, seq_length, embed_dim = hidden_states.shape

        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states)
        values = self.v_proj(hidden_states)

        queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

        # Merge the padding mask and the causal mask, except for flash attention,
        # which relies on `is_causal` instead of an explicit mask.
        if self.config._attn_implementation == "flash_attention_2":
            self.is_causal = causal_attention_mask is not None
        else:
            if attention_mask is not None and causal_attention_mask is not None:
                attention_mask = attention_mask + causal_attention_mask
            elif causal_attention_mask is not None:
                attention_mask = causal_attention_mask

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. "
                    "Falling back to eager attention. This warning can be removed using the argument "
                    '`attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            dropout=0.0 if not self.training else self.dropout,
        )

        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights


class GitVisionEncoderLayer(nn.Module):
    def __init__(self, config: GitVisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = GitVisionAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = GitVisionMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
zGitVisionEncoderLayer.__init__Fr  c                 C   sd   |}|  |}| j||||d\}}|| }|}| |}| |}|| }|f}|r`||f7 }|S )aI  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
                `(config.encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class GitVisionEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`GitVisionEncoderLayer`].

    Args:
        config: GitVisionConfig
    """

    def __init__(self, config: GitVisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([GitVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Causal mask for the text model. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                causal_attention_mask,
                output_attentions=output_attentions,
            )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
zGitVisionEncoder.forward)NNNNN)r%   r&   r'   r(   r   r;   r   r   r)   rZ   r   r   r,   r   rX   r[   r-   r-   rN   r.   r    s         
r  c                
       s^   e Zd Zed fddZed	eej ee	 ee	 ee	 ee	 e
eef dddZ  ZS )
GitVisionTransformerr   c                    sR   t    || _|j}t|| _tj||jd| _	t
|| _tj||jd| _d S r   )r:   r;   rM   r>   r   rW   r   rB   rC   pre_layrnormr  encoderpost_layernorm)rL   rM   r   rN   r-   r.   r;   <  s    


zGitVisionTransformer.__init__NFr   r}   r   r   r   rS   c           	      C   s   |d ur|n| j j}|d ur |n| j j}|d ur4|n| j j}|d u rLtd| j||d}| |}| j||||d}|d }| |}|s|f|dd   S t	||j
|jdS )Nz You have to specify pixel_valuesr   )rQ   r}   r   r   r   r   r  )rM   r}   r   r  rc   rW   r  r  r  r   r#   r$   )	rL   r   r}   r   r   r   r#   encoder_outputsr"   r-   r-   r.   rX   F  s.    	

zGitVisionTransformer.forward)NNNFN)r%   r&   r'   r   r;   r   r   r)   r*   r   r   r,   r   rX   r[   r-   r-   rN   r.   r  :  s   
     
r  zY
    The vision model from CLIP, used in GIT, without any head or projection on top.
    c                
       sx   e Zd ZU eed< dZed fddZejdddZ	e
deej ee ee eee eeef dddZ  ZS )GitVisionModelrM   r   r   c                    s"   t  | t|| _|   d S r   )r:   r;   r  vision_model	post_initrK   rN   r-   r.   r;   z  s    
zGitVisionModel.__init__)rS   c                 C   s
   | j jjS r   )r  rW   r   rL   r-   r-   r.   get_input_embeddings  s    z#GitVisionModel.get_input_embeddingsNFr  c                 C   s(   |dur|n| j j}| j|||||dS )a{  
        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, GitVisionModel

        >>> processor = AutoProcessor.from_pretrained("microsoft/git-base")
        >>> model = GitVisionModel.from_pretrained("microsoft/git-base")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )


class GitProjection(nn.Module):
    def __init__(self, config: GitConfig):
        super().__init__()
        self.config = config
        self.visual_projection = nn.Sequential(
            nn.Linear(config.vision_config.hidden_size, config.hidden_size),
            nn.LayerNorm(config.hidden_size, eps=config.vision_config.layer_norm_eps),
        )

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        return self.visual_projection(embeddings)


@auto_docstring(
    custom_intro="""
    The bare GIT Model transformer consisting of a CLIP image encoder and text decoder outputting raw hidden-states.
    """
)
class GitModel(GitPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embeddings = GitEmbeddings(config)
        self.image_encoder = GitVisionModel(config.vision_config)
        self.encoder = GitEncoder(config)

        self.visual_projection = GitProjection(config)

        if config.num_image_with_embedding is not None:
            self.img_temperal_embedding = nn.ParameterList(
                nn.Parameter(torch.zeros(1, 1, config.vision_config.hidden_size))
                for _ in range(config.num_image_with_embedding)
            )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def _generate_future_mask(self, size: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
        # Default mask is for forward direction. Flip for backward direction.
        mask = torch.triu(torch.ones(size, size, device=device, dtype=dtype), diagonal=1)
        mask = mask.masked_fill(mask == 1, float("-inf"))
        return mask

    def create_attention_mask(self, tgt, memory, tgt_mask, past_key_values_length, memory_key_padding_mask=None):
        num_tgt = tgt.shape[1]
        num_memory = memory.shape[1]
        device = tgt.device
        dtype = tgt.dtype
        top_left = torch.zeros((num_memory, num_memory), device=device, dtype=dtype)
        top_right = torch.full(
            (num_memory, num_tgt + past_key_values_length),
            float("-inf"),
            device=device,
            dtype=dtype,
        )
        bottom_left = torch.zeros((num_tgt, num_memory), dtype=dtype, device=device)

        if past_key_values_length > 0:
            tgt_mask = torch.zeros(
                (tgt_mask.shape[0], tgt_mask.shape[0] + past_key_values_length),
                dtype=dtype,
                device=device,
            )

        left = torch.cat((top_left, bottom_left), dim=0)
        right = torch.cat((top_right, tgt_mask.to(dtype)), dim=0)

        full_attention_mask = torch.cat((left, right), dim=1)[None, :]

        if memory_key_padding_mask is None:
            memory_key_padding_mask = torch.full((memory.shape[0], memory.shape[1]), fill_value=False, device=device)
        # if it is False, it means valid token according to the transformers attention_mask convention
        if memory_key_padding_mask.dtype != torch.bool:
            raise ValueError("Memory key padding mask must be a boolean tensor.")
        zero_negative_infinity = torch.zeros_like(memory_key_padding_mask, dtype=tgt.dtype)
        zero_negative_infinity[memory_key_padding_mask] = float("-inf")
        full_attention_mask = full_attention_mask.expand(
            (memory_key_padding_mask.shape[0], num_memory + num_tgt, num_memory + past_key_values_length + num_tgt)
        )
        full_attention_mask = full_attention_mask.clone()
        origin_left = full_attention_mask[:, :, :num_memory]
        update = zero_negative_infinity[:, None, :]
        full_attention_mask[:, :, :num_memory] = origin_left + update

        # add axis for the attention heads
        full_attention_mask = full_attention_mask[:, None, :, :]

        return full_attention_mask

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPast]:
        r"""
        Examples:

        ```python
        >>> from transformers import AutoProcessor, AutoModel
        >>> import requests
        >>> from PIL import Image

        >>> processor = AutoProcessor.from_pretrained("microsoft/git-base")
        >>> model = AutoModel.from_pretrained("microsoft/git-base")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> text = "this is an image of two cats"

        >>> inputs = processor(images=image, text=text, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        seq_length = input_shape[1]

        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values.get_seq_length()

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        projected_visual_features = None
        if pixel_values is not None:
            if pixel_values.ndim == 4:
                # here we assume pixel_values is of shape (batch_size, num_channels, height, width)
                visual_features = self.image_encoder(
                    pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
                ).last_hidden_state
            elif pixel_values.ndim == 5:
                # here we assume pixel_values is of shape (batch_size, num_frames, num_channels, height, width)
                visual_features = []
                for frame_idx in range(pixel_values.shape[1]):
                    visual_features_frame = self.image_encoder(
                        pixel_values[:, frame_idx, :, :], interpolate_pos_encoding=interpolate_pos_encoding
                    ).last_hidden_state
                    visual_features_frame += self.img_temperal_embedding[frame_idx]
                    visual_features.append(visual_features_frame)

                # finally, concatenate all features along the sequence dimension
                visual_features = torch.cat(visual_features, dim=1)
            else:
                raise ValueError("pixel_values must be of rank 4 or 5")

            projected_visual_features = self.visual_projection(visual_features)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )

        if projected_visual_features is None:
            projected_visual_features = torch.zeros(
                (embedding_output.shape[0], 0, embedding_output.shape[2]),
                dtype=embedding_output.dtype,
                device=embedding_output.device,
            )

        # Repeat visual features to match embedding batch size.
        projected_visual_features = projected_visual_features.repeat(
            embedding_output.size(0) // projected_visual_features.size(0), 1, 1
        )

        # concatenate patch token and text token embeddings
        hidden_states = torch.cat((projected_visual_features, embedding_output), dim=1)

        # By default, an additive causal mask is created for masking the future (one direction).
        tgt_mask = self._generate_future_mask(seq_length, embedding_output.dtype, embedding_output.device)

        # Create an attention mask of shape (batch_size, 1, tgt_seq_len, src_seq_len)
        combined_attention_mask = self.create_attention_mask(
            tgt=embedding_output,
            memory=projected_visual_features,
            tgt_mask=tgt_mask,
            past_key_values_length=past_key_values_length,
        )

        if attention_mask is not None:
            # if the user provides an attention mask, we add it to the default one
            expanded_attn_mask = _prepare_4d_attention_mask(
                attention_mask, embedding_output.dtype, tgt_len=input_shape[-1]
            ).to(embedding_output.device)
            if past_key_values_length > 0:
                expanded_attn_mask = expanded_attn_mask[:, :, -past_key_values_length:, :]
            else:
                combined_attention_mask[:, :, -input_shape[1] :, -input_shape[1] :] += expanded_attn_mask

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=combined_attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            pixel_values_present=pixel_values is not None,
        )
        sequence_output = encoder_outputs[0]

        if not return_dict:
            return (sequence_output,) + encoder_outputs[1:]

        return BaseModelOutputWithPast(
            last_hidden_state=sequence_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
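
# Sketch (illustrative, not part of the original file) of the block structure produced by
# `GitModel.create_attention_mask`: image tokens attend bidirectionally among themselves and never
# to the text, while text tokens attend to all image tokens plus a causal prefix of the text:
#
#   full_attention_mask = [[ image->image: 0   | image->text: -inf        ],
#                          [ text ->image: 0   | text ->text: `tgt_mask`  ]]
#
# where `tgt_mask` is the upper-triangular -inf mask from `_generate_future_mask`.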


@auto_docstring(
    custom_intro="""
    GIT Model with a `language modeling` head on top for autoregressive language modeling.
    """
)
class GitForCausalLM(GitPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["output.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.git = GitModel(config)
        self.output = nn.Linear(config.hidden_size, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.output

    def set_output_embeddings(self, new_embeddings):
        self.output = new_embeddings

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, list[torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple[torch.Tensor], CausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`

        Examples:

        Image captioning example:

        ```python
        >>> from transformers import AutoProcessor, AutoModelForCausalLM
        >>> import requests
        >>> from PIL import Image

        >>> processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
        >>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> pixel_values = processor(images=image, return_tensors="pt").pixel_values

        >>> generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
        >>> generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        >>> print(generated_caption)
        two cats sleeping on a pink blanket next to remotes.
        ```

        Visual question answering (VQA) example:

        ```python
        >>> from transformers import AutoProcessor, AutoModelForCausalLM
        >>> from huggingface_hub import hf_hub_download
        >>> from PIL import Image

        >>> processor = AutoProcessor.from_pretrained("microsoft/git-base-textvqa")
        >>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-textvqa")

        >>> file_path = hf_hub_download(repo_id="nielsr/textvqa-sample", filename="bus.png", repo_type="dataset")
        >>> image = Image.open(file_path).convert("RGB")

        >>> pixel_values = processor(images=image, return_tensors="pt").pixel_values

        >>> question = "what does the front of the bus say at the top?"

        >>> input_ids = processor(text=question, add_special_tokens=False).input_ids
        >>> input_ids = [processor.tokenizer.cls_token_id] + input_ids
        >>> input_ids = torch.tensor(input_ids).unsqueeze(0)

        >>> generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
        >>> print(processor.batch_decode(generated_ids, skip_special_tokens=True))
        ['what does the front of the bus say at the top? special']
        ```

        Video captioning example:

        ```python
        >>> import av
        >>> import numpy as np
        >>> from PIL import Image
        >>> from huggingface_hub import hf_hub_download
        >>> from transformers import AutoProcessor, AutoModelForCausalLM

        >>> processor = AutoProcessor.from_pretrained("microsoft/git-base-vatex")
        >>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vatex")

        >>> # set seed for reproducibility
        >>> np.random.seed(45)


        >>> def read_video_pyav(container, indices):
        ...     '''
        ...     Decode the video with PyAV decoder.
        ...     Args:
        ...         container (`av.container.input.InputContainer`): PyAV container.
        ...         indices (`list[int]`): List of frame indices to decode.
        ...     Returns:
        ...         result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
        ...     '''
        ...     frames = []
        ...     container.seek(0)
        ...     start_index = indices[0]
        ...     end_index = indices[-1]
        ...     for i, frame in enumerate(container.decode(video=0)):
        ...         if i > end_index:
        ...             break
        ...         if i >= start_index and i in indices:
        ...             frames.append(frame)
        ...     return np.stack([x.to_ndarray(format="rgb24") for x in frames])


        >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
        ...     '''
        ...     Sample a given number of frame indices from the video.
        ...     Args:
        ...         clip_len (`int`): Total number of frames to sample.
        ...         frame_sample_rate (`int`): Sample every n-th frame.
        ...         seg_len (`int`): Maximum allowed index of sample's last frame.
        ...     Returns:
        ...         indices (`list[int]`): List of sampled frame indices
        ...     '''
        ...     converted_len = int(clip_len * frame_sample_rate)
        ...     end_idx = np.random.randint(converted_len, seg_len)
        ...     start_idx = end_idx - converted_len
        ...     indices = np.linspace(start_idx, end_idx, num=clip_len)
        ...     indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
        ...     return indices


        >>> # load video
        >>> file_path = hf_hub_download(
        ...     repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
        ... )
        >>> container = av.open(file_path)

        >>> # sample frames
        >>> num_frames = model.config.num_image_with_embedding
        >>> indices = sample_frame_indices(
        ...     clip_len=num_frames, frame_sample_rate=4, seg_len=container.streams.video[0].frames
        ... )
        >>> frames = read_video_pyav(container, indices)

        >>> pixel_values = processor(images=list(frames), return_tensors="pt").pixel_values

        >>> generated_ids = model.generate(pixel_values=pixel_values, max_length=50)

        >>> print("Generated caption:", processor.batch_decode(generated_ids, skip_special_tokens=True))
        Generated caption: ['a woman is sitting at a table and she is talking about the food she is holding.']
        ```
        NF)r{   r5   r   r|   rQ   ru   r   r}   r   r   r   r   r7   r   r=   )losslogitsru   r#   r$   )rM   r  r   r   r  r   r   rL   rl   r   Zloss_functionr   r=   r   ru   r#   r$   )rL   rP   r{   r5   r   r|   rQ   rF  ru   r   r}   r   r   r   r   r  r@  rH  rG  Znum_image_tokensZshifted_logitsr   r-   r-   r.   rX     sV     
zGitForCausalLM.forwardc           	      K   st   |d urF|  }|jd |kr$|}n|jd d }|d d |d f }|j}|d u r^||}|||d||dS )Nr   r   )rP   r{   r   ru   r   )r<  r   Znew_onesget)	rL   rP   ru   r{   r   r   Zpast_lengthZremove_prefix_lengthrU   r-   r-   r.   prepare_inputs_for_generation  s    
z,GitForCausalLM.prepare_inputs_for_generation)NNNNNNNNNNNFN)NNN)r%   r&   r'   Z_tied_weights_keysr;   rD  rE  r   r   r)   rZ   r   r	   rA  r   r,   r   rX   rJ  r[   r-   r-   rN   r.   rB    sJ   	              E rB  )rB  r"  r   r  )r   )Jr(   r   dataclassesr   typingr   r   r   r)   Ztorch.utils.checkpointr   Zactivationsr   Zcache_utilsr	   r
   Z
generationr   Zmodeling_attn_mask_utilsr   Zmodeling_layersr   Zmodeling_outputsr   r   r   r   Zmodeling_utilsr   r   Zpytorch_utilsr   r   r   utilsr   r   r   r   r   Zutils.deprecationr   Zconfiguration_gitr   r   Z
get_loggerr%   re   r    r  r/   r\   r   r   r   r   r   r   r   r   r   r   rZ   r,  r   r   r  r  r  r  r   r"  rB  __all__r-   r-   r-   r.   <module>   s   
0z3(LS P3W65   w