a
    hl                     @   sj  d Z ddlmZ ddlmZmZ ddlZddlZddlmZ ddl	m
Z
mZ ddlmZ dd	lmZ dd
lmZ ddlmZ ddlmZ ddlmZmZmZmZmZ ddlmZ ddlmZ e e!Z"eeddG dd deZ#eeddG dd deZ$G dd dej%Z&eG dd deZ'eddG dd de'Z(ed dG d!d" d"e'eZ)g d#Z*dS )$zPyTorch PaliGemmamodel.    )	dataclass)OptionalUnionN)nn   )CacheStaticCache)GenerationMixin)FlashAttentionKwargs)BaseModelOutputWithPast)PreTrainedModel)Unpack)ModelOutputTransformersKwargsauto_docstringcan_return_tuplelogging   )	AutoModel   )PaliGemmaConfigzN
    Base class for Paligemma outputs, with hidden states and attentions.
    )Zcustom_introc                   @   s$   e Zd ZU dZdZeej ed< dS )PaligemmaModelOutputWithPasta  
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    Nimage_hidden_states)	__name__
__module____qualname____doc__r   r   torchFloatTensor__annotations__ r    r    l/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/paligemma/modeling_paligemma.pyr   ,   s   
r   zU
    Base class for PaliGemma causal language model (or autoregressive) outputs.
    c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeeej ef  ed< dZeeej  ed< dZeeej  ed< dZeej ed< dS )	PaliGemmaCausalLMOutputWithPastaa  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
    Nlosslogitspast_key_valueshidden_states
attentionsr   )r   r   r   r   r#   r   r   r   r   r$   r%   r   listr   r&   tupler'   r   r    r    r    r!   r"   B   s   
r"   c                       s*   e Zd Zed fddZdd Z  ZS )PaliGemmaMultiModalProjectorconfigc                    s(   t    tj|jj|jjdd| _d S )NTbias)super__init__r   Linearvision_confighidden_sizeZprojection_dimlinearselfr,   	__class__r    r!   r0   b   s    
z%PaliGemmaMultiModalProjector.__init__c                 C   s   |  |}|S N)r4   )r6   image_featuresr&   r    r    r!   forwardf   s    
z$PaliGemmaMultiModalProjector.forward)r   r   r   r   r0   r;   __classcell__r    r    r7   r!   r*   a   s   r*   c                   @   sD   e Zd ZU eed< dZdZdgZdZdZ	dZ
dZdZdZdd Zd	S )
PaliGemmaPreTrainedModelr,    Tr*   r%   Fc                 C   sN   t | jd| j j}t|tjrJ|jjj	d|d |j
d urJ|j
j  d S )Ninitializer_range        )meanstd)getattrr,   Zget_text_configr?   
isinstancer   r1   weightdataZnormal_r.   Zzero_)r6   modulerB   r    r    r!   _init_weightsz   s
    
z&PaliGemmaPreTrainedModel._init_weightsN)r   r   r   r   r   Zbase_model_prefixZsupports_gradient_checkpointingZ_no_split_modulesZ_skip_keys_device_placementZ_can_compile_fullgraphZ_supports_flash_attnZ_supports_sdpaZ_supports_flex_attnZ_supports_attention_backendrH   r    r    r    r!   r=   l   s   
r=   z{
    The Base Paligemma model which consists of a vision backbone and a language model withou language modeling head.,
    c                       s  e Zd ZddiZdZed fddZdd Zd	d
 Zdd Z	dd Z
dee dddZejdddZejejejdddZeedejejeej eej eeeej ef  eej eej eej eej ee ee ee ee ee eeef dddZ  ZS )PaliGemmaModelzlanguage_model.modellanguage_modelFr+   c                    sj   t  | tj|jd| _t|| _|jj	| _	tj|jd}|| _
| jjd urX| jjnd| _|   d S )Nr+   )r/   r0   r   from_configr2   vision_towerr*   multi_modal_projectortext_config
vocab_sizerJ   r,   Zpad_token_id	post_init)r6   r,   rJ   r7   r    r!   r0      s    

zPaliGemmaModel.__init__c                 C   s
   | j  S r9   )rJ   get_input_embeddingsr6   r    r    r!   rR      s    z#PaliGemmaModel.get_input_embeddingsc                 C   s   | j | d S r9   )rJ   set_input_embeddingsr6   valuer    r    r!   rT      s    z#PaliGemmaModel.set_input_embeddingsc                 C   s
   || _ d S r9   rJ   r6   decoderr    r    r!   set_decoder   s    zPaliGemmaModel.set_decoderc                 C   s   | j S r9   rW   rS   r    r    r!   get_decoder   s    zPaliGemmaModel.get_decoderN)is_trainingc                 C   sr  | j jjdkr&|d ur"d|v r"|S d S |d ur2|n| j}t|t}t| jj	}|d u r\|}|j
d d \}	}
|r|| }n&t|tjr|j
d n|d |
 d }|d ur| dkr|S tj|
|f|| j|jd}|
dkr|rtj|dd	}nd|d d d |
f< |tj||jd
|ddk9 }|d d d d d d f |	ddd}|d urn| }|j
d }|r|d u r~td|d d d d d d d |f |d d d d d d f |jdkd|d d d d d d d |f< |d d d d d d d |f |d d d d d d f |j }|dk}|d d d d d d d |f |||d d d d d d d |f< |S )NZflash_attention_2r@   r   rK   r   r      Z
fill_valuedtypedeviceZdiagonalr`   z/Token type ids must be provided during training)r,   rO   Z_attn_implementationZtrainingrD   r   r   finfor_   minshapeZget_max_cache_shapeTensordimfullr`   triuarangereshapeexpandclone
ValueErrormasked_fillto)r6   attention_masktoken_type_idsr%   cache_positioninput_tensorr\   Zusing_static_cache	min_dtypeZinputs_lead_dimsequence_lengthtarget_lengthcausal_maskmask_lengthpadding_maskr    r    r!   _update_causal_mask   sT    	



 $


 $ @  z"PaliGemmaModel._update_causal_mask)pixel_valuesc                 C   s0   |  |}|j}| |}|| jjjd  }|S )a  
        Obtains image last hidden states from the vision tower and apply multimodal projection.

        Args:
            pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
               The tensors corresponding to the input images.
        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
        g      ?)rM   last_hidden_staterN   r,   rO   r3   )r6   r|   Zimage_outputsZselected_image_featurer:   r    r    r!   get_image_features   s
    


z!PaliGemmaModel.get_image_features)	input_idsinputs_embedsr:   c                 C   s   |du r8||   tj| jjtj|jdk}|d}n|| jjk}| }|	d
||j}|jd |jd  }||  | krtd| d| |S )z
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
        equal to the length of multimodal features. If the lengths are different, an error is raised.
        N)r_   r`   rK   r   r   z6Image features and image tokens do not match: tokens: z, features )rR   r   Ztensorr,   image_token_idlongr`   allsum	unsqueezeZ	expand_asrp   re   Znumelrn   )r6   r   r   r:   special_image_maskZn_image_tokensZn_image_featuresr    r    r!   get_placeholder_mask   s    z#PaliGemmaModel.get_placeholder_mask)r   r|   rq   position_idsr%   rr   rs   r   labels	use_cacheoutput_attentionsoutput_hidden_statesreturn_dictkwargsreturnc                 K   s  |du |duA rt d|dur$|n| jj}|dur8|n| jj}|durL|n| jj}|duob|	du}|dur| jj| jkr|| jjk}| }d||< n|}|du r|  |}|du r|dur|	 nd}t
j|||jd  |jd}|du r|dd }|dur>| |}||j|j}| j|||d}|||}| ||||||}| jf |||||
||d|d	|}t|j|j|j|j|dur|ndd	S )
i  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

        >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224")
        >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")

        >>> prompt = "Where is the cat standing?"
        >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, text=prompt,  return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs,)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Where is the cat standing?\nsnow"
        ```Nz:You must specify exactly one of input_ids or inputs_embedsr   r   rb   )r   r:   T)	rq   r   r%   r   r   r   r   r   rs   )r}   r%   r&   r'   r   )rn   r,   r   r   use_return_dictr   rP   rm   rR   Zget_seq_lengthr   rj   re   r`   r   r~   rp   r_   r   Zmasked_scatterr{   rJ   r   r}   r%   r&   r'   )r6   r   r|   rq   r   r%   rr   rs   r   r   r   r   r   r   r   r\   r   Zllm_input_idsZpast_seen_tokensr:   rx   outputsr    r    r!   r;     sf    /



zPaliGemmaModel.forward)NNNNN)NNNNNNNNNNNNN)r   r   r   _checkpoint_conversion_mappingZaccepts_loss_kwargsr   r0   rR   rT   rZ   r[   r   boolr{   r   r   r~   
LongTensorr   r   r   rf   r   r(   r   r   r
   r)   r   r;   r<   r    r    r7   r!   rI      sd        B             
rI   z|
    The Base Paligemma model which consists of a vision backbone and a language model without language modeling head.,
    c                       sH  e Zd ZdddddZdgZed fdd	Zd
d Zdd Zdd Z	dd Z
dd Zedd Zedd Zedd Zeed%ejejeej eej eeeej ef  eej eej eej eej ee ee ee ee eeejf ee eeef dddZ d& fd d!	Z!e"ejeeej#ejed"d#d$Z$  Z%S )'!PaliGemmaForConditionalGenerationzmodel.language_modelzmodel.vision_towerzmodel.multi_modal_projectorlm_head)z^language_model.modelz^vision_towerz^multi_modal_projectorz^language_model.lm_headzlm_head.weightr+   c                    s<   t  | t|| _tj|jj|jjdd| _	| 
  d S )NFr-   )r/   r0   rI   modelr   r1   rO   r3   rP   r   rQ   r5   r7   r    r!   r0     s    
z*PaliGemmaForConditionalGeneration.__init__c                 C   s
   | j  S r9   )r   rR   rS   r    r    r!   rR     s    z6PaliGemmaForConditionalGeneration.get_input_embeddingsc                 C   s   | j | d S r9   )r   rT   rU   r    r    r!   rT     s    z6PaliGemmaForConditionalGeneration.set_input_embeddingsc                 C   s   | j | d S r9   )r   rZ   rX   r    r    r!   rZ     s    z-PaliGemmaForConditionalGeneration.set_decoderc                 C   s
   | j  S r9   )r   r[   rS   r    r    r!   r[     s    z-PaliGemmaForConditionalGeneration.get_decoderc                 C   s   | j |S r9   )r   r~   )r6   r|   r    r    r!   r~     s    z4PaliGemmaForConditionalGeneration.get_image_featuresc                 C   s   | j jS r9   )r   rJ   rS   r    r    r!   rJ     s    z0PaliGemmaForConditionalGeneration.language_modelc                 C   s   | j jS r9   )r   rM   rS   r    r    r!   rM     s    z.PaliGemmaForConditionalGeneration.vision_towerc                 C   s   | j jS r9   )r   rN   rS   r    r    r!   rN     s    z7PaliGemmaForConditionalGeneration.multi_modal_projectorNr   )r   r|   rq   r   r%   rr   rs   r   r   r   r   r   r   logits_to_keepr   r   c                 K   s   |dur|n| j j}|dur |n| j j}|dur4|n| j j}| jf ||||||||
|	||d|d|}|d }t|trt| dn|}| |dd|ddf }d}|	dur| j	f ||	| j j
jd|}t|||j|j|j|jdS )r   NT)r   r|   rr   rq   r   r%   r   r   r   r   r   r   rs   r   )r$   r   rP   )r#   r$   r%   r&   r'   r   )r,   r   r   r   r   rD   intslicer   Zloss_functionrO   rP   r"   r%   r&   r'   r   )r6   r   r|   rq   r   r%   rr   rs   r   r   r   r   r   r   r   r   r   r&   Zslice_indicesr$   r#   r    r    r!   r;     sN    /z)PaliGemmaForConditionalGeneration.forwardTc                    s   t  j|f||||||	|
|d|}|dd urD|d  d7  < |d dkrX||d< |d uof|d u}t|tozt|j}|d dkr|r|d ur|n|}| j||||||}||d< |S )N)r%   r   rq   r   rs   r   r   rr   r   r   r   r|   rq   )	r/   prepare_inputs_for_generationgetrD   r   anyZ
is_slidingr   r{   )r6   r   r%   r   rs   r   r|   rq   rr   r   r   r   r   Zmodel_inputsr\   Zis_static_hybrid_cachert   rx   r7   r    r!   r     s6    
z?PaliGemmaForConditionalGeneration.prepare_inputs_for_generation)rq   rv   rw   r_   rs   
batch_sizec                 K   sF  | dur|   dkr| }n&t|j}tj||f|||jd}|dkrVtj|dd}|tj||jd|ddk9 }|ddddddf 	|ddd}| durB|
 }| jd }	|ddddddd|	f | ddddddf |j }
|
dk}
|ddddddd|	f |
||ddddddd|	f< |S )	aM  
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
        Nr]   r^   r   ra   rb   rK   r   )rg   r   rc   rd   rh   r`   ri   rj   rk   rl   rm   re   rp   ro   )rq   rv   rw   r_   rs   r   r   rx   ru   ry   rz   r    r    r!   5_prepare_4d_causal_attention_mask_with_cache_position<  s*     $

6  zWPaliGemmaForConditionalGeneration._prepare_4d_causal_attention_mask_with_cache_position)NNNNNNNNNNNNNr   )
NNNNNNNTNN)&r   r   r   r   Z_tied_weights_keysr   r0   rR   rT   rZ   r[   r~   propertyrJ   rM   rN   r   r   r   r   r   r   rf   r   r(   r   r   r   r   r   r)   r"   r;   r   staticmethodr_   r   r<   r    r    r7   r!   r     s   


              
[          /r   )r   r=   rI   )+r   dataclassesr   typingr   r   r   Ztorch.utils.checkpointr   Zcache_utilsr   r   Z
generationr	   Zmodeling_flash_attention_utilsr
   Zmodeling_outputsr   Zmodeling_utilsr   Zprocessing_utilsr   utilsr   r   r   r   r   autor   Zconfiguration_paligemmar   Z
get_loggerr   loggerr   r"   Moduler*   r=   rI   r   __all__r    r    r    r!   <module>   sN   
 z n