a
    h                    @   s  d Z ddlZddlmZmZ ddlZddlZddlmZ ddlm	Z	 ddl
mZmZmZ ddlmZ dd	lmZ dd
lmZ ddlmZmZmZmZmZ ddlmZ ddlmZmZmZm Z m!Z!m"Z"m#Z# ddl$m%Z% ddl&m'Z'm(Z(m)Z) e  rddl*m+Z+ ddl,m-Z- e#.e/Z0G dd dej1Z2zddl3m4Z4 e4Z2e05d W n0 e6yV   Y n e7yr   e08d Y n0 G dd dej1Z9G dd dej1Z:G dd dej1Z;G dd deZ<G d d! d!ej1Z=eG d"d# d#eZ>eG d$d% d%e>Z?G d&d' d'ej1Z@G d(d) d)ej1ZAG d*d+ d+ej1ZBG d,d- d-ej1ZCG d.d/ d/ej1ZDG d0d1 d1eZEed2d3G d4d5 d5e>ZFed6d3G d7d8 d8e>eZGg d9ZHdS ):zPix2Struct modeling file    N)OptionalUnion)nn   )ACT2FN)CacheDynamicCacheEncoderDecoderCache)GenerationMixin)AttentionMaskConverter)GradientCheckpointingLayer)BaseModelOutputBaseModelOutputWithPooling!CausalLMOutputWithCrossAttentionsSeq2SeqLMOutputSeq2SeqModelOutput)PreTrainedModel)DUMMY_INPUTS
DUMMY_MASKauto_docstringis_torch_flex_attn_availableis_torch_fx_proxyis_torchdynamo_compilinglogging)deprecate_kwarg   )Pix2StructConfigPix2StructTextConfigPix2StructVisionConfig)	BlockMask)make_flex_block_causal_maskc                       s&   e Zd Zd fdd	Zdd Z  ZS )Pix2StructLayerNormư>c                    s&   t    tt|| _|| _dS )zc
        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
        N)super__init__r   	Parametertorchonesweightvariance_epsilon)selfhidden_sizeeps	__class__ n/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/pix2struct/modeling_pix2struct.pyr$   ?   s    
zPix2StructLayerNorm.__init__c                 C   s\   | tjdjddd}|t|| j  }| jjtj	tj
fv rR| | jj}| j| S )N   T)Zkeepdim)tor&   float32powmeanZrsqrtr)   r(   dtypefloat16Zbfloat16)r*   hidden_statesZvariancer/   r/   r0   forwardG   s
    zPix2StructLayerNorm.forward)r"   __name__
__module____qualname__r$   r:   __classcell__r/   r/   r-   r0   r!   >   s   r!   )FusedRMSNormzWDiscovered apex.normalization.FusedRMSNorm - will use it instead of Pix2StructLayerNormzJDiscovered apex but it failed to load, falling back to Pix2StructLayerNormc                       s<   e Zd ZdZedd fddZejejdddZ  Z	S )	Pix2StructVisionEmbeddingsa-  
    Construct the embeddings from patch. In `Pix2Struct` the input is different from classic Vision-transformer models.
    Here the input is a sequence of `seq_len` flattened patches that also combines padding patches (tokens). Each patch
    is represented by a vector of `hidden_size` values.
    Nconfigreturnc                    sR   t    t|j|j| _t|j|j| _	t|j|j| _
t|j| _d S N)r#   r$   r   LinearZpatch_embed_hidden_sizer+   patch_projection	EmbeddingZseq_lenrow_embeddercolumn_embedderDropoutdropout_ratedropoutr*   rC   r-   r/   r0   r$   l   s
    
z#Pix2StructVisionEmbeddings.__init__)flattened_patchesrD   c                 C   s   |d d d d df   }|d d d d df   }|d d d d dd f }| |}| |}| |}|| | }| |}|S )Nr   r   r1   )longrG   rI   rJ   rM   )r*   rO   Zrow_indicesZcol_indices
embeddingsZrow_embeddingsZcol_embeddingsr/   r/   r0   r:   u   s    



z"Pix2StructVisionEmbeddings.forward)
r<   r=   r>   __doc__r   r$   r&   Tensorr:   r?   r/   r/   r-   r0   rA   e   s   	rA   c                       s&   e Zd Z fddZdddZ  ZS )Pix2StructVisionAttentionc                    s   t    |j| _|j| _|j| _|j| _| j| j | _	t
j| j| j	dd| _t
j| j| j	dd| _t
j| j| j	dd| _t
j| j	| jdd| _d| _d S NFbias)r#   r$   r+   d_kvkey_value_proj_dimZnum_attention_headsn_headsZattention_dropoutrM   	inner_dimr   rF   querykeyvalueoutputgradient_checkpointingrN   r-   r/   r0   r$      s    
z"Pix2StructVisionAttention.__init__NFc                    s  |j dd \ } fdd}||}||}	||}
t||	dd}|du rtjdj||f|j	|j
d}jrjrd|_| dkr||ddddddf |j	 }nJ|dur|||j	 }n0t stj |f|j	|j
d}|||j	 }d| }||dkt|j
j}||7 }t|tt|j
j}tjj|d	tjd
|}tjj|jjd}|dur|| }t||
}|dd  d	j} |}|f|f }|r||f }|S )z&
        Self-attention block
        Nr1   c                    s    |    djjddS )Z
projectionr2   r   r1   )
contiguousviewrZ   rY   	transpose)Zstates
batch_sizer*   r/   r0   to_projection_shape   s    z>Pix2StructVisionAttention.forward.<locals>.to_projection_shaper   r   devicer7   Tr2   )dimr7   ptraining)!shaper\   r]   r^   r&   matmulrc   zerosrZ   rh   r7   r`   rl   requires_gradri   r3   r   r'   masked_fillfinfominmaxtensorr   
functionalsoftmaxr4   type_asrM   ra   rb   r[   r_   )r*   r9   attention_maskposition_biaslayer_head_maskoutput_attentions
seq_lengthrf   query_states
key_statesvalue_statesscoresposition_bias_maskedattn_weightsattn_outputoutputsr/   rd   r0   r:      sH    
&


z!Pix2StructVisionAttention.forward)NNNFr;   r/   r/   r-   r0   rT      s       rT   c                       s*   e Zd Zed fddZdd Z  ZS )Pix2StructVisionMlprC   c                    sj   t    tj|j|jdd| _tj|j|jdd| _tj|j|jdd| _t	|j
| _t|j | _d S rU   r#   r$   r   rF   r+   d_ffwi_0wi_1worK   rL   rM   r   Zdense_act_fnactrN   r-   r/   r0   r$      s    
zPix2StructVisionMlp.__init__c                 C   sz   |  | |}| |}|| }| |}t| jjtjrl|j	| jjj	krl| jjj	tj
krl|| jjj	}| |}|S rE   r   r   r   rM   
isinstancer   r(   r&   rS   r7   Zint8r3   r*   r9   Zhidden_geluZhidden_linearr/   r/   r0   r:      s    


zPix2StructVisionMlp.forward)r<   r=   r>   r   r$   r:   r?   r/   r/   r-   r0   r      s   r   c                	       sd   e Zd Zedd fddZd	ejeej eej ee	e
ejejf e
ej f dddZ  ZS )
Pix2StructVisionLayerNrB   c                    sT   t    |j| _d| _t|| _t|| _t|j	|j
d| _t|j	|j
d| _d S )Nr   r,   )r#   r$   Zchunk_size_feed_forwardZseq_len_dimrT   	attentionr   mlpr!   r+   layer_norm_epspre_mlp_layer_normpre_attention_layer_normrN   r-   r/   r0   r$   	  s    


zPix2StructVisionLayer.__init__F)r9   ry   	head_maskr|   rD   c           
      C   sb   |}|  |}| j||||d}|d }|dd  }|| }| |}	| |	| }	|	f| }|S )N)ry   r{   r|   r   r   )r   r   r   r   )
r*   r9   ry   r   r|   Zresidualself_attention_outputsattention_outputr   layer_outputr/   r/   r0   r:     s    


zPix2StructVisionLayer.forward)NNF)r<   r=   r>   r   r$   r&   rS   r   boolr   tupler:   r?   r/   r/   r-   r0   r     s      r   c                
       sV   e Zd Zedd fddZd
ejeej eej eeee	e
ef ddd	Z  ZS )Pix2StructVisionEncoderNrB   c                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r/   )r   ).0_r   r/   r0   
<listcomp>7      z4Pix2StructVisionEncoder.__init__.<locals>.<listcomp>F)	r#   r$   rC   r   
ModuleListrangenum_hidden_layerslayerr`   rN   r-   r   r0   r$   4  s    
 z Pix2StructVisionEncoder.__init__FT)r9   ry   r   r|   output_hidden_statesreturn_dictrD   c                 C   s   |rdnd }|rdnd }t | jD ]R\}	}
|r8||f }|d urH||	 nd }|
||||}|d }|r"||d f }q"|r||f }|stdd |||fD S t|||dS )Nr/   r   r   c                 s   s   | ]}|d ur|V  qd S rE   r/   r   vr/   r/   r0   	<genexpr>W  r   z2Pix2StructVisionEncoder.forward.<locals>.<genexpr>last_hidden_stater9   
attentions)	enumerater   r   r   )r*   r9   ry   r   r|   r   r   all_hidden_statesZall_self_attentionsilayer_moduler{   layer_outputsr/   r/   r0   r:   :  s&    	

zPix2StructVisionEncoder.forward)NNFFT)r<   r=   r>   r   r$   r&   rS   r   r   r   r   r   r:   r?   r/   r/   r-   r0   r   3  s   	     
r   c                   @   s6   e Zd ZU eed< dZedd Zdd Zdd Z	d	S )
Pix2StructPreTrainedModelrC   Fc                 C   s$   t t}t t}|||d}|S )N)decoder_input_ids	input_idsdecoder_attention_mask)r&   ru   r   r   )r*   r   Z
input_maskdummy_inputsr/   r/   r0   r   e  s    

z&Pix2StructPreTrainedModel.dummy_inputsc                 C   s  | j j}t|tr(|jj|d  nt|tr>t| j trJ| j j	j
n| j j
}t| j trh| j j	jn| j j}|jjjjd||d  d t|jdr|jjdur|jjj  |jjjjd||d  d t|jdr|jjdur|jjj  |jjjjd||d  d t|jdr|jjdur|jjj  nt|trJt| j trb| j j	j
n| j j
}t| j tr| j j	jn| j j
}t| j tr| j j	jn| j j}|jjjjd||| d  d |jjjjd||d  d |jjjjd||d  d |jjjjd||| d  d |jr|jjjjd||d  d n|t|tjrt| j trp| j j	j
n| j j
}|jjjd||d  d |jdur|jj|j   nt|trt| j tr| j j	j
n| j j
}|j jjjd||d  d nt|tj!tj"fr\tj#j$|jj%t&j'd| j j(d%|jj)|j_|jdur|jj  njt|tr|jdur|jjd nBt|tjr|jjjd| j j(d |jdur|jj|j   dS )zInitialize the weights      ?        g      )r6   ZstdrW   N)*rC   Zinitializer_factorr   r!   r(   dataZfill_ Pix2StructTextDenseGatedActDenser   text_configr+   r   r   Znormal_hasattrrW   Zzero_r   r   Pix2StructTextAttentionrX   	num_headsr\   r]   r^   r_   has_relative_attention_biasrelative_attention_biasr   rH   Zpadding_idxPix2StructTextModellm_headrF   ZConv2dinitZtrunc_normal_r3   r&   r4   Zinitializer_ranger7   )r*   modulefactorr+   r   rY   rZ   r/   r/   r0   _init_weightsp  s    



   

z'Pix2StructPreTrainedModel._init_weightsc                 C   s   | j j}| j j}|d u r tdt|rbt|jd d d |}tj||dd df gdd}n4|	|j}|dd df 
 |ddd f< ||d< |d u rtd||d	k| |S )
Nzself.model.config.decoder_start_token_id has to be defined. In Pix2Struct it is usually set to the pad_token_id. See Pix2Struct docs for more information.r2   )r   .ri   r   ).r   z1self.model.config.pad_token_id has to be defined.)rC   decoder_start_token_idpad_token_id
ValueErrorr   r&   fullrm   catZ	new_zeroscloneZmasked_fill_)r*   r   r   r   Zshifted_input_idsr/   r/   r0   _shift_right  s       z&Pix2StructPreTrainedModel._shift_rightN)
r<   r=   r>   r   __annotations__Z_can_compile_fullgraphpropertyr   r   r   r/   r/   r/   r0   r   _  s   


Pr   c                       s   e Zd ZU eed< dZdZdgZed fddZdd	 Z	e
eee f d
dddZedeej eej eej ee ee ee eeef dddZ  ZS )Pix2StructVisionModelrC   rO   Tr   r   c                    sD   t  | || _t|| _t|| _t|j|j	d| _
|   d S Nr   )r#   r$   rC   rA   rQ   r   encoderr!   r+   r   	layernorm	post_initrN   r-   r/   r0   r$     s    

zPix2StructVisionModel.__init__c                 C   s   | j jS rE   )rQ   rG   r*   r/   r/   r0   get_input_embeddings  s    z*Pix2StructVisionModel.get_input_embeddingsN)heads_to_prunerD   c                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr   r   r   Zprune_heads)r*   r   r   Zheadsr/   r/   r0   _prune_heads  s    z"Pix2StructVisionModel._prune_heads)rO   ry   r   r|   r   r   rD   c                 C   s   |dur|n| j j}|dur |n| j j}|dur4|n| j j}|du rLtd|du rh|jdddk }| || j j}| 	|}| j
||||||d}|d }	| |	}	|s|	f}
|
|dd  S t|	|j|jdS )	a  
        flattened_patches (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_channels x patch_height x patch_width)`):
            Flattened and padded pixel values. These values can be obtained using [`AutoImageProcessor`]. See
            [`Pix2StructVisionImageProcessor.__call__`] for details. Check the [original
            paper](https://huggingface.co/papers/2210.03347) (figure 5) for more details.

        Example:

        ```python
        >>> import requests
        >>> from PIL import Image
        >>> from transformers import AutoProcessor, Pix2StructVisionModel

        >>> image_processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
        >>> model = Pix2StructVisionModel.from_pretrained("google/pix2struct-textcaps-base")

        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> last_hidden_states = outputs.last_hidden_state
        >>> list(last_hidden_states.shape)
        [1, 2048, 768]
        ```
        Nz%You have to specify flattened_patchesr2   r   r   )ry   r   r|   r   r   r   r   )rC   r|   r   use_return_dictr   sumfloatget_head_maskr   rQ   r   r   r   r9   r   )r*   rO   ry   r   r|   r   r   Zembedding_outputencoder_outputsZsequence_outputZhead_outputsr/   r/   r0   r:     s8    &

zPix2StructVisionModel.forward)NNNNNN)r<   r=   r>   r   r   main_input_namesupports_gradient_checkpointing_no_split_modulesr$   r   dictintlistr   r   r   r&   rS   r   r   r   r   r:   r?   r/   r/   r-   r0   r     s.   
      
r   c                       s*   e Zd Zed fddZdd Z  ZS )r   r   c                    sj   t    tj|j|jdd| _tj|j|jdd| _tj|j|jdd| _t	|j
| _t|j | _d S rU   r   rN   r-   r/   r0   r$   N  s    
z)Pix2StructTextDenseGatedActDense.__init__c                 C   sz   |  | |}| |}|| }| |}t| jjtjrl|j	| jjj	krl| jjj	tj
krl|| jjj	}| |}|S rE   r   r   r/   r/   r0   r:   V  s    


z(Pix2StructTextDenseGatedActDense.forwardr<   r=   r>   r   r$   r:   r?   r/   r/   r-   r0   r   M  s   r   c                       s*   e Zd Zed fddZdd Z  ZS )Pix2StructTextLayerFFr   c                    s8   t    t|| _t|j|jd| _t	|j
| _d S r   )r#   r$   r   DenseReluDenser!   r+   layer_norm_epsilon
layer_normr   rK   rL   rM   rN   r-   r/   r0   r$   k  s    

zPix2StructTextLayerFF.__init__c                 C   s&   |  |}| |}|| | }|S rE   )r   r   rM   )r*   r9   Zforwarded_statesr/   r/   r0   r:   s  s    

zPix2StructTextLayerFF.forwardr   r/   r/   r-   r0   r   j  s   r   c                
       sZ   e Zd Zdeee d fddZedd	d
ZdddZ	e
dddddddZ  ZS )r   FN)rC   	layer_idxc                    s   t    || _|j| _|j| _|j| _|j| _|j| _	|j
| _| j	| j | _|| _|d u rrtd| jj d tj| j| jdd| _tj| j| jdd| _tj| j| jdd| _tj| j| jdd| _| jrt| j| j	| _t | _d| _d S )NzInstantiating a decoder z without passing `layer_idx` is not recommended and will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` when creating this class.FrV   )r#   r$   r   relative_attention_num_bucketsrelative_attention_max_distancer+   rX   rY   r   rZ   rL   rM   r[   r   loggerZwarning_oncer.   r<   r   rF   r\   r]   r^   r_   rH   r   setpruned_headsr`   r*   rC   r   r   r-   r/   r0   r$   {  s,    
z Pix2StructTextAttention.__init__T       c                 C   s   d}|r4|d }|| dk tj| 7 }t| } nt| t|  } |d }| |k }|t|  | t||  ||   tj }t|t	||d }|t
|| |7 }|S )a  
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        r   r1   r   )r3   r&   rP   absrs   Z
zeros_likelogr   mathZ	full_likewhere)relative_positionbidirectionalnum_bucketsmax_distanceZrelative_bucketsZ	max_exactZis_smallZrelative_position_if_larger/   r/   r0   _relative_position_bucket  s,    z1Pix2StructTextAttention._relative_position_bucketc           
      C   s   |du r| j jj}|du r:tj|tj|ddddf }n|dddf |}tj|tj|ddddf }|| }| j|d| j| j	d}|  |}	|	
g dd}	|	S )z%Compute binned relative position biasN)r7   rh   F)r   r   r   )r1   r   r   r   )r   r(   rh   r&   arangerP   r3   r   r   r   ZpermuteZ	unsqueeze)
r*   query_length
key_lengthrh   cache_positionZcontext_positionZmemory_positionr   Zrelative_position_bucketvaluesr/   r/   r0   compute_bias  s     
 
z$Pix2StructTextAttention.compute_biaspast_key_valuepast_key_values4.58new_nameversionc                 C   s  |j dd \}}|du}| |}||d| j| jdd}|durtt|trt|j	| j
}|rl|j}qx|j}n|}|r|n|}|r|r|r|j| j
 j}|j| j
 j}n| |}| |}||d| j| jdd}||d| j| jdd}|dur>|s|
nd}
|||| j
d|
i\}}|r>d|j| j
< t||dd}|du r,|j d }|durt|n
|
d d }| jstjd| j||f|j|jd	}| jr| jrd|_n6| j|||j|
d
}|dddd| dddf }|dur,|ddddddd|j d f }|| }| jrht|j d }d|t| j< |dd|  f }n|}||7 }t!j"j#|$ dd%|}t!j"j&|| j&| jd}|dur|| }t||}|dd' }||d| j(}| )|}||f}|	r||f }|S )z
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        Nr1   r2   r   r  Tr   rg   )rh   r  r   r   rj   )*rm   r\   rb   rZ   rY   rc   r   r	   
is_updatedgetr   Zcross_attention_cacheself_attention_cacheZlayerskeysr  r]   r^   updater&   rn   r   ro   rh   r7   r`   rl   rp   r  r   r'   r   r   r   rv   rw   r   rx   rM   ra   r[   r_   )r*   r9   maskkey_value_statesrz   r  r{   r  	use_cacher|   r  re   r}   Zis_cross_attentionr~   r  Zcurr_past_key_valueZcurrent_statesr   r   r   r  Zreal_seq_lengthcausal_maskr   r   r   r   r/   r/   r0   r:     sx    






"
&


zPix2StructTextAttention.forward)FN)Tr   r   )NN)	NNNNNNFFN)r<   r=   r>   r   r   r   r$   staticmethodr   r  r   r:   r?   r/   r/   r-   r0   r   z  s$    0
         r   c                       s@   e Zd Zdee d fddZedddd	dd
dZ  ZS ) Pix2StructTextLayerSelfAttentionFNr   c                    s>   t    t|||d| _t|j|jd| _t	|j
| _d S )Nr   r   r   r#   r$   r   r   r!   r+   r   r   r   rK   rL   rM   r   r-   r/   r0   r$   S  s    
z)Pix2StructTextLayerSelfAttention.__init__r  r  r  r	  c	              
   C   sL   |  |}	| j|	|||||||d}
|| |
d  }|f|
dd   }|S )N)r  rz   r{   r  r  r|   r  r   r   r   r   rM   )r*   r9   ry   rz   r{   r  r  r|   r  normed_hidden_statesr   r   r/   r/   r0   r:   [  s    

z(Pix2StructTextLayerSelfAttention.forward)FN)NNNNFFN	r<   r=   r>   r   r   r$   r   r:   r?   r/   r/   r-   r0   r  R  s          r  c                	       s@   e Zd Zdee d fddZedddddd
dZ  ZS )!Pix2StructTextLayerCrossAttentionNr  c                    s>   t    t|d|d| _t|j|jd| _t	|j
| _d S )NFr  r   r  )r*   rC   r   r-   r/   r0   r$   y  s    
z*Pix2StructTextLayerCrossAttention.__init__r  r  r  r	  Fc                 C   sP   |  |}| j|||||||||	|
d
}|| |d  }|f|dd   }|S )N)	r  r  rz   r{   r  r  r  r|   r  r   r   r  )r*   r9   r  ry   rz   r{   r  r  r  r|   r  r  r   r   r   r/   r/   r0   r:     s     
z)Pix2StructTextLayerCrossAttention.forward)N)NNNNFNFNr  r/   r/   r-   r0   r  x  s           r  c                       s@   e Zd Zdee d fddZedddd	dddZ  ZS )Pix2StructTextBlockFNr  c                    s6   t    t|||d| _t||d| _t|| _d S )Nr  r  )r#   r$   r  self_attentionr  encoder_decoder_attentionr   r   r   r-   r/   r0   r$     s    
zPix2StructTextBlock.__init__r  r  r  r	  Tc                 C   sL  | j |||||	|
||d}|d }|dd  }|jtjkrlt| rlt|jjd }tj|| |d}|d u}|r| j	||||||	|d d |
|d	}|d }|jtjkrt| rt|jjd }tj|| |d}||dd   }| 
|}|jtjkr>t| r>t|jjd }tj|| |d}|f}|| S )N)ry   rz   r{   r  r  r|   r  r   r   i  )rs   rt   r2   )r  ry   rz   r{   r  r  r  r|   )r   r7   r&   r8   isinfanyrr   rt   clampr!  r   )r*   r9   ry   rz   encoder_hidden_statesencoder_attention_maskencoder_decoder_position_biasr{   cross_attn_layer_head_maskr  r  r|   r   r  r   Zattention_outputsZclamp_valueZdo_cross_attentionZcross_attention_outputsr   r/   r/   r0   r:     sN    


zPix2StructTextBlock.forward)FN)NNNNNNNNFFTNr  r/   r/   r-   r0   r    s               r  z3
    The standalone text decoder of Pix2Struct
    )Zcustom_introc                       s  e Zd ZU eed< dgZdgZdZ fddZdd Z	e
deej eej eej eej eej eej eej ee ee ee ee eej ee eej eeejd
f ef dddZdeejdf ejejeedddZeejeeejejedddZ  ZS )r   rC   r  zlm_head.weightTc                    s   t    t j j| _t fddt j	D | _
t j jd| _t j| _tj j jdd| _|   d| _d S )Nc                    s"   g | ]}t  t|d k|dqS )r   r  )r  r   )r   r   r   r/   r0   r     s   z0Pix2StructTextModel.__init__.<locals>.<listcomp>r   FrV   )r#   r$   r   rH   Z
vocab_sizer+   embed_tokensr   r   
num_layersr   r!   r   final_layer_normrK   rL   rM   rF   r   r   r`   rN   r-   r   r0   r$     s    
zPix2StructTextModel.__init__c                 C   s
   || _ d S rE   )r)  r*   Znew_embeddingsr/   r/   r0   set_input_embeddings  s    z(Pix2StructTextModel.set_input_embeddingsN.)r   ry   r%  r&  inputs_embedsr   cross_attn_head_maskr  r  r|   r   labelsr   r  rD   c           )      K   s.  |	dur|	n| j j}	|
dur |
n| j j}
|dur4|n| j j}|durH|n| j j}| jrn| jrn|	rntd d}	|dur|durt	dnD|dur|
 }|d|d }n"|dur|
 dd }nt	d|du r| jdusJ d| |}|\}}|	r:|du r:| j jr.tt| j dt| j d}nt| j d}d	}|durR|d	 }n|durd| }|du rtj||| |jd
}|du r|dur| | n|}tj|||jd
}| j jr| |||t|tr|jn||
}n<|ddddddf }|j|jd}d| t|jj }|durl|
 \}}}||f}|du r`tj||jd
}| |}nd}| || j j}| || j j}|rdnd}|
rdnd}|
rdnd}d}d}|  |} t!| j"D ]\}!}"||! }#||! }$|r|| f }|"| ||||||#|$||	|
|d}%|%d	 } |%d }|durJ|%|
rDdnd }|
r||%d f }|dur||%d f }q| #| } |  | } | $| }&|r|| f }d}'|dur||&j}t%j&ddd}(|(|&' d|&
d|' d}'|st(dd |'|&||||fD S t)|'|&||||dS )a  
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Pix2StructText is a model with relative position
            embeddings so you should be able to pad the inputs on both the right and the left.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for detail.

            [What are input IDs?](../glossary#input-ids)

            To know more on how to prepare `input_ids` for pretraining take a look a [Pix2StructText
            Training](./t5#training).
        cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
            `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        Example:

        ```python
        >>> from transformers import AutoProcessor, Pix2StructTextModel

        >>> processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
        >>> model = Pix2StructTextModel.from_pretrained("google/pix2struct-textcaps-base")

        >>> inputs = processor(text="Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> loss = outputs.loss
        ```
        NzZ`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...FzTYou cannot specify both decoder_input_ids and decoder_inputs_embeds at the same timer2   zEYou have to specify either decoder_input_ids or decoder_inputs_embedsz<You have to initialize the model with valid token embeddingsr   r   rh   )r7   r   r/   )r{   r(  r  r  r|   r  r   r   r1      r   r6   )Zignore_indexZ	reductionc                 s   s   | ]}|d ur|V  qd S rE   r/   r   r/   r/   r0   r     s   z.Pix2StructTextModel.forward.<locals>.<genexpr>)losslogitsr  r9   r   cross_attentions)*rC   r  r|   r   r   r`   rl   r   warningr   sizerb   r)  Zis_encoder_decoderr	   r   get_seq_lengthr&   r   rh   r'   Z
is_decoder_update_causal_maskr   r  r3   r7   rr   rs   Zinvert_attention_maskr   r*  rM   r   r   r+  r   r   ZCrossEntropyLossra   r   r   ))r*   r   ry   r%  r&  r.  r   r/  r  r  r|   r   r0  r   r  kwargsZinput_shapere   r}   past_key_values_lengthZmask_seq_lengthr  Zencoder_batch_sizeZencoder_sequence_lengthr   Zencoder_hidden_shapeZencoder_extended_attention_maskr   Zall_attentionsZall_cross_attentionsrz   r'  r9   r   r   r{   r(  r   r4  r3  Zloss_fctr/   r/   r0   r:     s    3





















&zPix2StructTextModel.forwardFr   )ry   input_tensorr  r  r|   c                 C   sB  | j jdkr(|d ur$|dk r$|S d S | j jdkrLt|tjrHt|}|S |d ur\| nd}|d urn|jnd}| j jdkr|s|st	j
|||| jdrd S |j}|jd }	|r| }
n"t|tjr|jd	 n
||	 d }
| j||	|
|||jd d
}| j jdkr>|d ur>|jjdv r>|s>t|j}t	||}|S )NZflash_attention_2r   Zflex_attentionr   FZsdpa)r.  r;  Zis_trainingr   r2   )sequence_lengthtarget_lengthr7   r  re   )cudaZxpuZnpu)rC   Z_attn_implementationr#  r   r&   rS   r    r8  Zis_compileabler   Z_ignore_causal_mask_sdparl   r7   rm   Zget_max_cache_shape5_prepare_4d_causal_attention_mask_with_cache_positionrh   typerr   rs   Z_unmask_unattended)r*   ry   r<  r  r  r|   Zpast_seen_tokensZusing_compilable_cacher7   r=  r>  r  	min_dtyper/   r/   r0   r9    sZ    






	z'Pix2StructTextModel._update_causal_mask)ry   r=  r>  r7   r  re   c                 K   sF  | dur|   dkr| }n&t|j}tj||f|||jd}|dkrVtj|dd}|tj||jd|ddk9 }|ddddddf 	|ddd}| durB|
 }| jd }	|ddddddd|	f | ddddddf |j }
|
dk}
|ddddddd|	f |
||ddddddd|	f< |S )	aM  
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
        Nr2  )Z
fill_valuer7   rh   r   )Zdiagonalr1  r2   r   )ri   r&   rr   rs   r   rh   Ztriur   Zreshapeexpandr   rm   r3   rq   )ry   r=  r>  r7   r  re   r:  r  rB  Zmask_lengthZpadding_maskr/   r/   r0   r@  6  s*     $

6  zIPix2StructTextModel._prepare_4d_causal_attention_mask_with_cache_position)NNNNNNNNNNNNNN)F)r<   r=   r>   r   r   r   _tied_weights_keysr   r$   r-  r   r   r&   
LongTensorFloatTensorrS   r   r   r   r   r   r:   r9  r  r   r7   r@  r?   r/   r/   r-   r0   r     sn   
               ` Dr   zr
    A conditional generation model with a language modeling head. Can be used for sequence generation tasks.
    c                       s  e Zd ZU eed< dZdgZed fddZdd Zd	d
 Z	e
jdddZdd Zdd Zedeej eej eej eej eej eej eej eeeej   ee eej eej ee ee ee ee eej eeej ef dddZ  ZS )"Pix2StructForConditionalGenerationrC   rO   zdecoder.lm_head.weightr   c                    s8   t  | t|j| _t|j| _|j| _| 	  d S rE   )
r#   r$   r   Zvision_configr   r   r   decoderZis_vqar   rN   r-   r/   r0   r$   y  s
    z+Pix2StructForConditionalGeneration.__init__c                 C   s
   | j  S rE   )rH  r   r   r/   r/   r0   r     s    z7Pix2StructForConditionalGeneration.get_input_embeddingsc                 C   s   | j | d S rE   )rH  r-  r,  r/   r/   r0   r-    s    z7Pix2StructForConditionalGeneration.set_input_embeddings)rD   c                 C   s
   | j  S rE   )rH  get_output_embeddingsr   r/   r/   r0   rI    s    z8Pix2StructForConditionalGeneration.get_output_embeddingsc                 C   s   | j | d S rE   )rH  set_output_embeddingsr,  r/   r/   r0   rJ    s    z8Pix2StructForConditionalGeneration.set_output_embeddingsc                 C   s   | j S rE   )r   r   r/   r/   r0   get_encoder  s    z.Pix2StructForConditionalGeneration.get_encoderN)rO   ry   r   r   r   decoder_head_maskr/  r   r  r0  decoder_inputs_embedsr  r|   r   r   r  rD   c                 C   sJ  |dur|n| j jj}|dur"|n| j j}|du rJ| j||||||d}nH|rt|tst|d t|dkrt|d ndt|dkr|d ndd}|d }|
dur|du r|du r| |
}|dur|n|	| j j
 }d|dddf< | j||||	||||||||
||d}|s|| S t|j|j|j|j|j|j|j|j|jd	S )	a  
        flattened_patches (`torch.FloatTensor` of shape `(batch_size, seq_length, hidden_size)`):
            Flattened pixel patches. the `hidden_size` is obtained by the following formula: `hidden_size` =
            `num_channels` * `patch_size` * `patch_size`

            The process of flattening the pixel patches is done by `Pix2StructProcessor`.
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            Pix2StructText uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            To know more on how to prepare `decoder_input_ids` for pretraining take a look at [Pix2StructText
            Training](./t5#training).
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
            1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
            `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss for the decoder.

        Example:

        Inference:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Pix2StructForConditionalGeneration

        >>> processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
        >>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")

        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> # autoregressive generation
        >>> generated_ids = model.generate(**inputs, max_new_tokens=50)
        >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        >>> print(generated_text)
        A stop sign is on a street corner.

        >>> # conditional generation
        >>> text = "A picture of"
        >>> inputs = processor(text=text, images=image, return_tensors="pt", add_special_tokens=False)

        >>> generated_ids = model.generate(**inputs, max_new_tokens=50)
        >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        >>> print(generated_text)
        A picture of a stop sign with a red stop sign
        ```

        Training:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Pix2StructForConditionalGeneration

        >>> processor = AutoProcessor.from_pretrained("google/pix2struct-base")
        >>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-base")

        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> text = "A stop sign is on the street corner."

        >>> inputs = processor(images=image, return_tensors="pt")
        >>> labels = processor(text=text, return_tensors="pt").input_ids

        >>> # forward pass
        >>> outputs = model(**inputs, labels=labels)
        >>> loss = outputs.loss
        >>> print(f"{loss.item():.5f}")
        5.94282
        ```N)rO   ry   r   r|   r   r   r   r   r1   r   )r   ry   r.  r  r%  r&  r   r/  r  r|   r   r0  r   r  )	r3  r4  r  Zdecoder_hidden_statesZdecoder_attentionsr5  Zencoder_last_hidden_stater%  Zencoder_attentions)rC   r   r  r   r   r   r   lenr   ner   r   rH  r   r3  r4  r  r9   r   r5  r   )r*   rO   ry   r   r   r   rL  r/  r   r  r0  rM  r  r|   r   r   r  r9   Zdecoder_outputsr/   r/   r0   r:     sl    r
z*Pix2StructForConditionalGeneration.forward)NNNNNNNNNNNNNNNN)r<   r=   r>   r   r   r   rD  r$   r   r-  r   ModulerI  rJ  rK  r   r   r&   rF  rE  Z
BoolTensorrS   r   r   r   r   r   r:   r?   r/   r/   r-   r0   rG  o  sZ   
                rG  )r   rG  r   r   )IrR   r   typingr   r   r&   Ztorch.utils.checkpointr   Zactivationsr   Zcache_utilsr   r   r	   Z
generationr
   Zmodeling_attn_mask_utilsr   Zmodeling_layersr   Zmodeling_outputsr   r   r   r   r   Zmodeling_utilsr   utilsr   r   r   r   r   r   r   Zutils.deprecationr   Zconfiguration_pix2structr   r   r   Z!torch.nn.attention.flex_attentionr   Zintegrations.flex_attentionr    Z
get_loggerr<   r   rP  r!   Zapex.normalizationr@   infoImportError	Exceptionr6  rA   rT   r   r   r   r   r   r   r   r   r  r  r  r   rG  __all__r/   r/   r/   r0   <module>   sr   $	

$b+,|p Y&'X  u U