"""PyTorch KOSMOS-2 model."""

import math
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPooling,
    CausalLMOutputWithCrossAttentions,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
    ModelOutput,
    TransformersKwargs,
    auto_docstring,
    can_return_tuple,
    logging,
    torch_int,
)
from ...utils.deprecation import deprecate_kwarg
from .configuration_kosmos2 import Kosmos2Config, Kosmos2TextConfig, Kosmos2VisionConfig


logger = logging.get_logger(__name__)


def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    """
    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        x: torch.Tensor x:

    Returns: torch.Tensor
    """
    # The series of casts and type-conversions here are carefully balanced to work with ONNX export and XLA.
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
    return incremental_indices.long() + padding_idx


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for text model's outputs that also contains a pooling of the last hidden states.
    """
)
class Kosmos2ModelOutput(ModelOutput):
    r"""
    past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
        `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
        encoder_sequence_length, embed_size_per_head)`.

        Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
        `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
        input) to speed up sequential decoding.
    image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
        Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
    projection_attentions (`tuple(torch.FloatTensor)`, *optional*):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attentions weights given by `Kosmos2ImageToTextProjection`, after the attention softmax, used to compute
        the weighted average in the self-attention heads.
    vision_model_output (`BaseModelOutputWithPooling`, *optional*):
        The output of the [`Kosmos2VisionModel`].
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    image_embeds: Optional[torch.FloatTensor] = None
    projection_attentions: Optional[tuple[torch.FloatTensor]] = None
    vision_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> tuple[Any]:
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )


@dataclass
@auto_docstring(
    custom_intro="""
    Model output class for `Kosmos2ForConditionalGeneration`.
    """
)
class Kosmos2ForConditionalGenerationModelOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
        `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
        encoder_sequence_length, embed_size_per_head)`.

        Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
        `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
        input) to speed up sequential decoding.
    image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
        Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
    projection_attentions (`tuple(torch.FloatTensor)`, *optional*):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attentions weights given by `Kosmos2ImageToTextProjection`, after the attention softmax, used to compute
        the weighted average in the self-attention heads.
    vision_model_output (`BaseModelOutputWithPooling`, *optional*):
        The output of the [`Kosmos2VisionModel`].
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    image_embeds: Optional[torch.FloatTensor] = None
    projection_attentions: Optional[tuple[torch.FloatTensor]] = None
    vision_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> tuple[Any]:
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )


class Kosmos2VisionEmbeddings(nn.Module):
    def __init__(self, config: Kosmos2VisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """
        num_patches = embeddings.shape[1] - 1
        position_embedding = self.position_embedding.weight.unsqueeze(0)
        num_positions = position_embedding.shape[1] - 1

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embedding(self.position_ids)

        class_pos_embed = position_embedding[:, :1]
        patch_pos_embed = position_embedding[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
        batch_size, _, height, width = pixel_values.shape
        if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
            )
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embedding(self.position_ids)
        return embeddings


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


class Kosmos2VisionAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""

        batch_size, seq_length, embed_dim = hidden_states.shape

        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states)
        values = self.v_proj(hidden_states)

        queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

        # in case both masks are provided they are summed up; with flash attention the causal flag is used instead
        if self.config._attn_implementation != "flash_attention_2":
            if attention_mask is not None and causal_attention_mask is not None:
                attention_mask = attention_mask + causal_attention_mask
            elif causal_attention_mask is not None:
                attention_mask = causal_attention_mask
        else:
            self.is_causal = causal_attention_mask is not None

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. "
                    "Falling back to eager attention. This warning can be removed using the argument "
                    '`attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            dropout=0.0 if not self.training else self.dropout,
        )

        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights


class Kosmos2VisionMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class Kosmos2VisionEncoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: Kosmos2VisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = Kosmos2VisionAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = Kosmos2VisionMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
                `(config.encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class Kosmos2VisionEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`Kosmos2VisionEncoderLayer`].

    Args:
        config: Kosmos2VisionConfig
    """

    def __init__(self, config: Kosmos2VisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([Kosmos2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    @can_return_tuple
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Causal mask for the text model. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)

            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                causal_attention_mask,
                output_attentions=output_attentions,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )


class Kosmos2VisionTransformer(nn.Module):
    def __init__(self, config: Kosmos2VisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = Kosmos2VisionEmbeddings(config)
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.encoder = Kosmos2VisionEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
        hidden_states = self.pre_layrnorm(hidden_states)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.post_layernorm(pooled_output)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class Kosmos2TextSinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length."""

    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
        super().__init__()
        self.offset = 2
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)

    def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
        emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
        if hasattr(self, "weights"):
            # in forward, put the weights on the correct dtype and device of the parameter
            emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
        self.register_buffer("weights", emb_weights, persistent=False)

    @staticmethod
    def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
        """
        Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
        "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0

        return emb.to(torch.get_default_dtype())

    @torch.no_grad()
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values_length: int = 0,
        position_ids: Optional[torch.Tensor] = None,
    ):
        if input_ids is not None:
            bsz, seq_len = input_ids.size()
            if position_ids is None:
                # Create the position ids from the input token ids. Any padded tokens remain padded.
                position_ids = create_position_ids_from_input_ids(
                    input_ids, self.padding_idx, past_key_values_length
                ).to(input_ids.device)
        else:
            bsz, seq_len = inputs_embeds.size()[:-1]
            if position_ids is None:
                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length)

        # expand embeddings if needed
        max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
        if max_pos > self.weights.size(0):
            self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)

        return (
            self.weights.index_select(0, position_ids.view(-1))
            .view(bsz, seq_len, self.weights.shape[-1])
            .detach()
        )

    def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_length):
        """
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
        """
        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]

        position_ids = torch.arange(
            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
        )
        return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length


class KosmosTextAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        config,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        add_inner_attn_layernorm: bool = False,
        bias: bool = True,
        layer_idx: Optional[int] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        if self.head_dim * num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.layer_idx = layer_idx

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.inner_attn_ln = None
        if add_inner_attn_layernorm:
            self.inner_attn_ln = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""

        # if encoder_hidden_states are provided this layer is used as a cross-attention layer for the decoder
        is_cross_attention = encoder_hidden_states is not None
        batch_size, seq_length = hidden_states.shape[:2]

        query_states = self.q_proj(hidden_states)
        query_states = query_states.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

        is_updated = False
        if past_key_values is not None:
            if isinstance(past_key_values, EncoderDecoderCache):
                is_updated = past_key_values.is_updated.get(self.layer_idx)
                if is_cross_attention:
                    curr_past_key_value = past_key_values.cross_attention_cache
                else:
                    curr_past_key_value = past_key_values.self_attention_cache
            else:
                curr_past_key_value = past_key_values

        current_states = encoder_hidden_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_values is not None and is_updated:
            # reuse key/value states of the cross-attention from cache
            key_states = curr_past_key_value.layers[self.layer_idx].keys
            value_states = curr_past_key_value.layers[self.layer_idx].values
        else:
            key_states = self.k_proj(current_states)
            value_states = self.v_proj(current_states)
            key_states = key_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
            value_states = value_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

            if past_key_values is not None:
                # save all key/value states to cache to be re-used for fast auto-regressive generation
                cache_position = cache_position if not is_cross_attention else None
                key_states, value_states = curr_past_key_value.update(
                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                )
                # flag that the cross-attention cache for this layer is filled so it can be re-used
                if is_cross_attention:
                    past_key_values.is_updated[self.layer_idx] = True

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. "
                    "Falling back to eager attention. This warning can be removed using the argument "
                    '`attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(batch_size, seq_length, -1)

        if self.inner_attn_ln is not None:
            attn_output = self.inner_attn_ln(attn_output)

        attn_output = self.out_proj(attn_output)
        return attn_output, attn_weights


class Kosmos2TextFFN(nn.Module):
    def __init__(self, config: Kosmos2TextConfig):
        super().__init__()
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout

        self.fc1 = nn.Linear(config.embed_dim, config.ffn_dim)
        self.fc2 = nn.Linear(config.ffn_dim, config.embed_dim)

        self.ffn_layernorm = nn.LayerNorm(config.ffn_dim, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.ffn_layernorm(hidden_states)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        return hidden_states


class Kosmos2TextBlock(GradientCheckpointingLayer):
    def __init__(self, config: Kosmos2TextConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.embed_dim = config.embed_dim

        self.self_attn = KosmosTextAttention(
            config,
            embed_dim=self.embed_dim,
            num_heads=config.attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            add_inner_attn_layernorm=True,
            layer_idx=layer_idx,
        )
        self.dropout = config.dropout
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

        if config.add_cross_attention:
            self.encoder_attn = KosmosTextAttention(
                config,
                embed_dim=self.embed_dim,
                num_heads=config.attention_heads,
                dropout=config.attention_dropout,
                is_decoder=True,
                add_inner_attn_layernorm=False,
                layer_idx=layer_idx,
            )
            self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

        self.ffn = Kosmos2TextFFN(config)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states

        # Self Attention
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        # Cross-Attention Block
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            if not hasattr(self, "encoder_attn"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention"
                    " layers by setting `config.add_cross_attention=True`"
                )

            residual = hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)
            hidden_states, cross_attn_weights = self.encoder_attn(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                cache_position=cache_position,
                **kwargs,
            )
            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
            hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.ffn(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        return outputs


class Kosmos2TextTransformer(nn.Module):
    """
    Transformer decoder consisting of `config.layers` layers. Each layer is a [`Kosmos2TextBlock`].

    Args:
        config: Kosmos2TextConfig
    """

    def __init__(self, config: Kosmos2TextConfig):
        super().__init__()
        self.config = config
        self.dropout = config.dropout
        self.layerdrop = config.layerdrop

        self.embed_scale = math.sqrt(config.embed_dim) if config.scale_embedding else 1.0
        self.embed_tokens = nn.Embedding(config.vocab_size, config.embed_dim, padding_idx=config.pad_token_id)

        self.embed_positions = Kosmos2TextSinusoidalPositionalEmbedding(
            num_positions=config.max_position_embeddings,
            embedding_dim=config.embed_dim,
            padding_idx=config.pad_token_id,
        )

        self.layers = nn.ModuleList([Kosmos2TextBlock(config, layer_idx=i) for i in range(config.layers)])
        self.layer_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps)

        self.gradient_checkpointing = False

    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        # create causal mask: [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    def forward_embedding(
        self,
        input_ids,
        inputs_embeds: Optional[torch.Tensor] = None,
        image_embeds: Optional[torch.Tensor] = None,
        img_input_mask: Optional[torch.Tensor] = None,
        past_key_values_length: int = 0,
        position_ids: Optional[torch.Tensor] = None,
    ):
        # The argument `inputs_embeds` should be the one without being multiplied by `self.embed_scale`.
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if image_embeds is not None:
            # scatter the image features into the slots flagged by `img_input_mask`
            inputs_embeds[img_input_mask.to(dtype=torch.bool)] = image_embeds.to(inputs_embeds.device).view(
                -1, image_embeds.size(-1)
            )

        inputs_embeds = inputs_embeds * self.embed_scale

        # embed positions
        positions = self.embed_positions(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
            position_ids=position_ids,
        )
        positions = positions.to(inputs_embeds.device)

        hidden_states = inputs_embeds + positions
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        return hidden_states

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_embeds: Optional[torch.Tensor] = None,
        image_embeds_position_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.shape
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        if use_cache and past_key_values is None:
            past_key_values = (
                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
                if encoder_hidden_states is not None
                else DynamicCache(config=self.config)
            )
        elif use_cache and isinstance(past_key_values, tuple):
            logger.warning_once(
                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
                "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
                "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
            )
            past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)

        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0

        # we don't need image information when generating beyond the first step
        if past_key_values_length > 0:
            image_embeds = None
            image_embeds_position_mask = None

        hidden_states = self.forward_embedding(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            image_embeds=image_embeds,
            img_input_mask=image_embeds_position_mask,
            past_key_values_length=past_key_values_length,
            position_ids=position_ids,
        )

        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, input_shape, hidden_states, past_key_values_length
        )

        # expand encoder attention mask: [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            encoder_attention_mask = _expand_mask(
                encoder_attention_mask, hidden_states.dtype, tgt_len=input_shape[-1]
            )

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
            if attn_mask is not None:
                if attn_mask.size()[0] != len(self.layers):
                    raise ValueError(
                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                        f" {attn_mask.size()[0]}."
                    )

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # add final layer norm
        hidden_states = self.layer_norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )


@auto_docstring
class Kosmos2PreTrainedModel(PreTrainedModel):
    config: Kosmos2Config
    supports_gradient_checkpointing = True
    _no_split_modules = ["Kosmos2VisionEncoderLayer", "Kosmos2TextBlock"]
    _supports_attention_backend = True
    _supports_flash_attn = True
    _supports_sdpa = True

    def _init_weights(self, module: nn.Module):
        """Initialize the weights"""
        if isinstance(self, Kosmos2VisionModel):
            factor = self.config.initializer_factor
        elif isinstance(self, (Kosmos2Model, Kosmos2ForConditionalGeneration)):
            factor = self.config.vision_config.initializer_factor

        if isinstance(self, (Kosmos2TextModel, Kosmos2TextForCausalLM)):
            std = self.config.init_std
        elif isinstance(self, (Kosmos2Model, Kosmos2ForConditionalGeneration)):
            std = self.config.text_config.init_std

        if isinstance(module, Kosmos2VisionEmbeddings):
            nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
            nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
            nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
        elif isinstance(module, Kosmos2VisionAttention):
            in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            out_proj_std = (module.embed_dim**-0.5) * factor
            nn.init.normal_(module.q_proj.weight, std=in_proj_std)
            nn.init.normal_(module.k_proj.weight, std=in_proj_std)
            nn.init.normal_(module.v_proj.weight, std=in_proj_std)
            nn.init.normal_(module.out_proj.weight, std=out_proj_std)
        elif isinstance(module, Kosmos2VisionMLP):
            in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
            nn.init.normal_(module.fc1.weight, std=fc_std)
            nn.init.normal_(module.fc2.weight, std=in_proj_std)
        elif isinstance(module, KosmosTextAttention):
            nn.init.normal_(module.q_proj.weight, std=std)
            nn.init.normal_(module.k_proj.weight, std=std)
            nn.init.normal_(module.v_proj.weight, std=std)
            nn.init.normal_(module.out_proj.weight, std=std)
        elif isinstance(module, Kosmos2TextFFN):
            nn.init.normal_(module.fc1.weight, std=std)
            nn.init.normal_(module.fc2.weight, std=std)
        elif isinstance(module, Kosmos2TextForCausalLM):
            nn.init.normal_(module.lm_head.weight, std=std)
        elif isinstance(module, Kosmos2ImageToTextProjection):
            nn.init.normal_(module.dense.weight, std=std)
            nn.init.normal_(module.latent_query)
        elif isinstance(module, Kosmos2TextTransformer):
            module.embed_tokens.weight.data.normal_(mean=0.0, std=std)
            if module.embed_tokens.padding_idx is not None:
                module.embed_tokens.weight.data[module.embed_tokens.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


class Kosmos2VisionModel(Kosmos2PreTrainedModel):
    config: Kosmos2VisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: Kosmos2VisionConfig):
        super().__init__(config)
        self.model = Kosmos2VisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.model.embeddings.patch_embedding

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        return self.model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )


class Kosmos2TextModel(Kosmos2PreTrainedModel):
    config: Kosmos2TextConfig

    def __init__(self, config: Kosmos2TextConfig):
        super().__init__(config)
        self.model = Kosmos2TextTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.model.embed_tokens

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_embeds: Optional[torch.Tensor] = None,
        image_embeds_position_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
        r"""
        image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
        image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to indicate the location in a sequence to insert the image features . Mask values selected in `[0,
            1]`:

            - 1 for places where to put the image features,
            - 0 for places that are not for image features (i.e. for text tokens).
        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        """
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            image_embeds=image_embeds,
            image_embeds_position_mask=image_embeds_position_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )


@auto_docstring(
    custom_intro="""
    The text model from KOSMOS-2 with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """
)
class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin):
    config: Kosmos2TextConfig
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: Kosmos2TextConfig):
        super().__init__(config)
        self.model = Kosmos2TextTransformer(config)
        self.lm_head = nn.Linear(in_features=config.embed_dim, out_features=config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.model.embed_tokens

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_embeds: Optional[torch.Tensor] = None,
        image_embeds_position_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
        image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to indicate the location in a sequence to insert the image features . Mask values selected in `[0,
            1]`:

            - 1 for places where to put the image features,
            - 0 for places that are not for image features (i.e. for text tokens).
        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        if labels is not None:
            if use_cache:
                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
            use_cache = False

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            image_embeds=image_embeds,
            image_embeds_position_mask=image_embeds_position_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )
        lm_logits = self.lm_head(outputs[0])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=lm_logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        image_embeds=None,
        image_embeds_position_mask=None,
        past_key_values=None,
        attention_mask=None,
        use_cache=None,
        cache_position=None,
        position_ids=None,
        **model_kwargs,
    ):
        if cache_position[0] != 0:
            # image information is only needed for the first forward pass
            image_embeds = None
            image_embeds_position_mask = None
        elif image_embeds_position_mask is not None:
            # append `False` to `image_embeds_position_mask` (because `input_ids` grows during generation)
            batch_size, seq_len = (
                input_ids.size() if attention_mask is None else attention_mask.size()
            )
            mask_len = image_embeds_position_mask.size()[-1]
            image_embeds_position_mask = torch.cat(
                (
                    image_embeds_position_mask,
                    torch.zeros(size=(batch_size, seq_len - mask_len), dtype=torch.bool, device=input_ids.device),
                ),
                dim=1,
            )

        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            image_embeds=image_embeds,
            image_embeds_position_mask=image_embeds_position_mask,
            use_cache=use_cache,
            cache_position=cache_position,
            position_ids=position_ids,
            **model_kwargs,
        )
        model_inputs.pop("inputs_embeds", None)

        return model_inputs


class Kosmos2ImageToTextProjection(nn.Module):
    """The layer that transforms the image model's output to part of the text model's input (namely, image features)"""

    def __init__(self, config: Kosmos2Config):
        super().__init__()
        self.dense = nn.Linear(config.vision_config.hidden_size, config.text_config.embed_dim)
        self.latent_query = nn.Parameter(torch.randn(config.latent_query_num, config.text_config.embed_dim))

        self.x_attn = KosmosTextAttention(
            config.text_config,
            config.text_config.embed_dim,
            config.text_config.attention_heads,
            dropout=config.text_config.attention_dropout,
            is_decoder=False,
            add_inner_attn_layernorm=False,
        )

    def forward(self, features):
        hidden_states = self.dense(features)

        # shape = [batch, latent_query_num, h_dim]
        latent_query = self.latent_query.unsqueeze(0).expand(hidden_states.size(0), -1, -1)
        key_value_states = torch.cat([hidden_states, latent_query], dim=1)

        hidden_states, attn_weights = self.x_attn(
            hidden_states=latent_query,
            encoder_hidden_states=key_value_states,
            past_key_values=None,
            attention_mask=None,
            output_attentions=None,
        )

        return hidden_states, attn_weights


@auto_docstring(
    custom_intro="""
    KOSMOS-2 Model for generating text and image features. The model consists of a vision encoder and a language model.
    """
)
class Kosmos2Model(Kosmos2PreTrainedModel):
    config: Kosmos2Config
    main_input_name = "pixel_values"

    def __init__(self, config: Kosmos2Config):
        super().__init__(config)
        self.text_model = Kosmos2TextModel(config.text_config)
        self.vision_model = Kosmos2VisionModel(config.vision_config)
        self.image_to_text_projection = Kosmos2ImageToTextProjection(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.text_model.model.embed_tokens

    def set_input_embeddings(self, value):
        self.text_model.model.embed_tokens = value

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        return_attentions: Optional[bool] = False,
        interpolate_pos_encoding: Optional[bool] = False,
    ):
        """
        Encodes images into continuous embeddings that can be forwarded to the language model.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
                The tensors corresponding to the input images.
            return_attentions (`bool`, *optional*, defaults to `False`):
                Whether to return `projection_attentions` or not.
            interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
                Whether to interpolate positional embeddings or not.
        """
        vision_model_output = self.vision_model(
            pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
        )
        # The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`
        image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0])
        # normalized features
        image_embeds = nn.functional.normalize(image_embeds, dim=-1)
        image_embeds, projection_attentions = self.image_to_text_projection(image_embeds)
        if return_attentions:
            return image_embeds, projection_attentions
        return image_embeds

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        image_embeds_position_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        image_embeds: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, Kosmos2ModelOutput]:
        r"""
        image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to indicate the location in a sequence to insert the image features . Mask values selected in `[0,
            1]`:

            - 1 for places where to put the image features,
            - 0 for places that are not for image features (i.e. for text tokens).
        image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Kosmos2Model

        >>> model = Kosmos2Model.from_pretrained("microsoft/kosmos-2-patch14-224")
        >>> processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")

        >>> url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> text = (
        ...     "<grounding> An image of<phrase> a snowman</phrase><object><patch_index_0044><patch_index_0863>"
        ...     "</object> warming himself by<phrase> a fire</phrase><object><patch_index_0005><patch_index_0911>"
        ...     "</object>"
        ... )

        >>> inputs = processor(text=text, images=image, return_tensors="pt", add_eos_token=True)

        >>> last_hidden_state = model(
        ...     pixel_values=inputs["pixel_values"],
        ...     input_ids=inputs["input_ids"],
        ...     attention_mask=inputs["attention_mask"],
        ...     image_embeds_position_mask=inputs["image_embeds_position_mask"],
        ... ).last_hidden_state
        >>> list(last_hidden_state.shape)
        [1, 91, 2048]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_model_output = None
        projection_attentions = None
        if image_embeds is None:
            if pixel_values is None:
                raise ValueError("You have to specify either `pixel_values` or `image_embeds`.")
            image_embeds, projection_attentions = self.get_image_features(
                pixel_values, return_attentions=True, interpolate_pos_encoding=interpolate_pos_encoding
            )

        outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            image_embeds=image_embeds,
            image_embeds_position_mask=image_embeds_position_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            **kwargs,
        )

        return Kosmos2ModelOutput(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_embeds=image_embeds,
            projection_attentions=projection_attentions,
            vision_model_output=vision_model_output,
        )


@auto_docstring(
    custom_intro="""
    KOSMOS-2 Model for generating text and bounding boxes given an image. The model consists of a vision encoder and a
    language model.
    """
)
class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin):
    config: Kosmos2Config
    main_input_name = "pixel_values"
    _tied_weights_keys = ["text_model.lm_head.weight"]

    def __init__(self, config: Kosmos2Config):
        super().__init__(config)
        self.text_model = Kosmos2TextForCausalLM(config.text_config)
        self.vision_model = Kosmos2VisionModel(config.vision_config)
        self.image_to_text_projection = Kosmos2ImageToTextProjection(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.text_model.model.embed_tokens

    def set_input_embeddings(self, value):
        self.text_model.model.embed_tokens = value

    def get_output_embeddings(self) -> nn.Module:
        return self.text_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.text_model.set_output_embeddings(new_embeddings)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        image_embeds_position_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        image_embeds: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, Kosmos2ForConditionalGenerationModelOutput]:
        r"""
        image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to indicate the location in a sequence to insert the image features . Mask values selected in `[0,
            1]`:

            - 1 for places where to put the image features,
            - 0 for places that are not for image features (i.e. for text tokens).
        image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Kosmos2ForConditionalGeneration

        >>> model = Kosmos2ForConditionalGeneration.from_pretrained("microsoft/kosmos-2-patch14-224")
        >>> processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")

        >>> url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> prompt = "<grounding> An image of"

        >>> inputs = processor(text=prompt, images=image, return_tensors="pt")

        >>> generated_ids = model.generate(
        ...     pixel_values=inputs["pixel_values"],
        ...     input_ids=inputs["input_ids"],
        ...     attention_mask=inputs["attention_mask"],
        ...     image_embeds=None,
        ...     image_embeds_position_mask=inputs["image_embeds_position_mask"],
        ...     use_cache=True,
        ...     max_new_tokens=64,
        ... )
        >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        >>> processed_text = processor.post_process_generation(generated_text, cleanup_and_extract=False)
        >>> processed_text
        '<grounding> An image of<phrase> a snowman</phrase><object><patch_index_0044><patch_index_0863></object> warming himself by<phrase> a fire</phrase><object><patch_index_0005><patch_index_0911></object>.'

        >>> caption, entities = processor.post_process_generation(generated_text)
        >>> caption
        'An image of a snowman warming himself by a fire.'

        >>> entities
        [('a snowman', (12, 21), [(0.390625, 0.046875, 0.984375, 0.828125)]), ('a fire', (41, 47), [(0.171875, 0.015625, 0.484375, 0.890625)])]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        vision_model_output = None
        projection_attentions = None
        if image_embeds is None:
            if pixel_values is None:
                raise ValueError("You have to specify either `pixel_values` or `image_embeds`.")
            vision_model_output = self.vision_model(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            # The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`
            image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0])
            # normalized features
            image_embeds = nn.functional.normalize(image_embeds, dim=-1)
            image_embeds, projection_attentions = self.image_to_text_projection(image_embeds)

        lm_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            image_embeds=image_embeds,
            image_embeds_position_mask=image_embeds_position_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            **kwargs,
        )

        return Kosmos2ForConditionalGenerationModelOutput(
            loss=lm_outputs.loss,
            logits=lm_outputs.logits,
            past_key_values=lm_outputs.past_key_values,
            hidden_states=lm_outputs.hidden_states,
            attentions=lm_outputs.attentions,
            image_embeds=image_embeds,
            projection_attentions=projection_attentions,
            vision_model_output=vision_model_output,
        )

    @torch.no_grad()
    def generate(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        image_embeds_position_mask: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        # in order to allow the `inputs` argument (as in `GenerationMixin`)
        inputs = kwargs.pop("inputs", None)
        if pixel_values is not None and inputs is not None:
            raise ValueError(
                f"`inputs`: {inputs} were passed alongside `pixel_values` which is not allowed."
                f"Make sure to either pass `inputs` or pixel_values=..."
            )
        if pixel_values is None and inputs is not None:
            pixel_values = inputs

        if image_embeds is None:
            vision_model_output = self.vision_model(pixel_values)
            # The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`
            image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0])
            # normalized features
            image_embeds = nn.functional.normalize(image_embeds, dim=-1)
            image_embeds, projection_attentions = self.image_to_text_projection(image_embeds)

        output = self.text_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            image_embeds=image_embeds,
            image_embeds_position_mask=image_embeds_position_mask,
            **kwargs,
        )

        return output


__all__ = ["Kosmos2ForConditionalGeneration", "Kosmos2Model", "Kosmos2PreTrainedModel"]