import math
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn.init import _calculate_fan_in_and_fan_out

from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import ModelOutput, auto_docstring, can_return_tuple
from .configuration_siglip2 import Siglip2Config, Siglip2TextConfig, Siglip2VisionConfig


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
    """
)
class Siglip2VisionOutput(ModelOutput):
    r"""
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
        The image embeddings obtained by applying the projection layer to the pooler_output.
    """

    image_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for text model's outputs that also contains a pooling of the last hidden states.
    """
)
class Siglip2TextOutput(ModelOutput):
    r"""
    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
        The text embeddings obtained by applying the projection layer to the pooler_output.
    """

    text_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None


@dataclass
@auto_docstring
class Siglip2Output(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
        Contrastive loss for image-text similarity.
    logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
        The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
        similarity scores.
    logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
        The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
        similarity scores.
    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
        The text embeddings obtained by applying the projection layer to the pooled output of [`Siglip2TextModel`].
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
        The image embeddings obtained by applying the projection layer to the pooled output of [`Siglip2VisionModel`].
    text_model_output (`BaseModelOutputWithPooling`):
        The output of the [`Siglip2TextModel`].
    vision_model_output (`BaseModelOutputWithPooling`):
        The output of the [`Siglip2VisionModel`].
    """

    loss: Optional[torch.FloatTensor] = None
    logits_per_image: Optional[torch.FloatTensor] = None
    logits_per_text: Optional[torch.FloatTensor] = None
    text_embeds: Optional[torch.FloatTensor] = None
    image_embeds: Optional[torch.FloatTensor] = None
    text_model_output: BaseModelOutputWithPooling = None
    vision_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> tuple[Any]:
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )


class Siglip2VisionEmbeddings(nn.Module):
    def __init__(self, config: Siglip2VisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Linear(
            in_features=config.num_channels * self.patch_size * self.patch_size,
            out_features=self.embed_dim,
        )

        self.num_patches = config.num_patches
        self.position_embedding_size = int(self.num_patches**0.5)
        self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)

    @staticmethod
    def resize_positional_embeddings(
        positional_embeddings: torch.Tensor,
        spatial_shapes: torch.LongTensor,
        max_length: int,
    ) -> torch.Tensor:
        """
        Resize positional embeddings to image-specific size and pad to a fixed size.

        Args:
            positional_embeddings (`torch.Tensor`):
                Position embeddings of shape (height, width, embed_dim)
            spatial_shapes (`torch.LongTensor`):
                Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
            max_length (`int`):
                Maximum length of the positional embeddings to pad resized positional embeddings to

        Returns:
            `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
        """
        batch_size = spatial_shapes.shape[0]
        embed_dim = positional_embeddings.shape[-1]
        source_dtype = positional_embeddings.dtype

        resulted_positional_embeddings = torch.empty(
            (batch_size, max_length, embed_dim),
            device=positional_embeddings.device,
            dtype=source_dtype,
        )

        # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation
        positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0)

        # Upcast to float32 on CPU because antialias interpolation is not supported for half precision there
        if positional_embeddings.device.type == "cpu":
            positional_embeddings = positional_embeddings.to(torch.float32)

        for i in range(batch_size):
            # (1, dim, height, width) -> (1, dim, target_height, target_width)
            height, width = spatial_shapes[i]
            resized_embeddings = F.interpolate(
                positional_embeddings,
                size=(height, width),
                mode="bilinear",
                align_corners=False,
                antialias=True,
            )

            # (1, dim, target_height, target_width) -> (target_height * target_width, dim)
            resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1)

            # Cast back to the original dtype
            resized_embeddings = resized_embeddings.to(source_dtype)

            resulted_positional_embeddings[i, : height * width] = resized_embeddings
            resulted_positional_embeddings[i, height * width :] = resized_embeddings[0]

        return resulted_positional_embeddings

    def forward(self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTensor) -> torch.Tensor:
        """
        Args:
            pixel_values (`torch.FloatTensor`):
                Pixel values of shape (batch_size, max_num_patches, num_channels * patch_size * patch_size)
            spatial_shapes (`torch.LongTensor`):
                Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
        """
        # Apply patch embeddings to the already patchified pixel values
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))

        # Resize (and pad) the positional embeddings to each image's spatial shape
        positional_embeddings = self.position_embedding.weight.reshape(
            self.position_embedding_size, self.position_embedding_size, -1
        )
        resized_positional_embeddings = self.resize_positional_embeddings(
            positional_embeddings, spatial_shapes, max_length=pixel_values.shape[1]
        )

        embeddings = patch_embeds + resized_positional_embeddings
        return embeddings


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


class Siglip2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""
        batch_size, seq_length, embed_dim = hidden_states.shape

        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states)
        values = self.v_proj(hidden_states)

        queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            dropout=0.0 if not self.training else self.dropout,
        )

        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights


class Siglip2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class Siglip2EncoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: Union[Siglip2VisionConfig, Siglip2TextConfig]):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.self_attn = Siglip2Attention(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = Siglip2MLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`):
                Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class Siglip2Encoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`Siglip2EncoderLayer`].

    Args:
        config: Siglip2Config
    """

    def __init__(self, config: Siglip2Config):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    @can_return_tuple
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> BaseModelOutput:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)

            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                output_attentions=output_attentions,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )


class Siglip2VisionTransformer(nn.Module):
    def __init__(self, config: Siglip2VisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = Siglip2VisionEmbeddings(config)
        self.encoder = Siglip2Encoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head
        if self.use_head:
            self.head = Siglip2MultiheadAttentionPoolingHead(config)
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        attention_mask: torch.Tensor,
        spatial_shapes: torch.LongTensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> BaseModelOutputWithPooling:
        r"""
        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
            Tensor containing the spatial dimensions (height, width) of the input images.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        hidden_states = self.embeddings(pixel_values, spatial_shapes)

        if attention_mask is not None and not self._use_flash_attention_2:
            # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
            encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
        else:
            encoder_attention_mask = attention_mask

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        last_hidden_state = encoder_outputs.last_hidden_state
        last_hidden_state = self.post_layernorm(last_hidden_state)

        pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooler_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


def _trunc_normal_(tensor, mean, std, a, b):
    # Computes the standard normal cumulative distribution function
    def norm_cdf(x):
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    # Values are generated by using a truncated uniform distribution and
    # then using the inverse CDF for the normal distribution.
    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)

    # Uniformly fill tensor with values from [l, u], then translate to [2l-1, 2u-1]
    tensor.uniform_(2 * l - 1, 2 * u - 1)

    # Use inverse cdf transform for the normal distribution to get the truncated standard normal
    tensor.erfinv_()

    # Transform to proper mean, std
    tensor.mul_(std * math.sqrt(2.0))
    tensor.add_(mean)

    # Clamp to ensure it's in the proper range
    tensor.clamp_(min=a, max=b)


def trunc_normal_tf_(
    tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
) -> torch.Tensor:
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
    and the result is subsequently scaled and shifted by the mean and std args.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    """
    with torch.no_grad():
        _trunc_normal_(tensor, 0, 1.0, a, b)
        tensor.mul_(std).add_(mean)


def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")


def lecun_normal_(tensor):
    variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")


def default_flax_embed_init(tensor):
    variance_scaling_(tensor, mode="fan_in", distribution="normal")


@auto_docstring
class Siglip2PreTrainedModel(PreTrainedModel):
    config: Siglip2Config
    base_model_prefix = "siglip2"
    supports_gradient_checkpointing = True

    _no_split_modules = [
        "Siglip2TextEmbeddings",
        "Siglip2VisionEmbeddings",
        "Siglip2EncoderLayer",
        "Siglip2MultiheadAttentionPoolingHead",
    ]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, Siglip2VisionEmbeddings):
            width = (
                self.config.vision_config.hidden_size
                if isinstance(self.config, Siglip2Config)
                else self.config.hidden_size
            )
            nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
        elif isinstance(module, nn.Embedding):
            default_flax_embed_init(module.weight)
        elif isinstance(module, Siglip2Attention):
            nn.init.xavier_uniform_(module.q_proj.weight)
            nn.init.xavier_uniform_(module.k_proj.weight)
            nn.init.xavier_uniform_(module.v_proj.weight)
            nn.init.xavier_uniform_(module.out_proj.weight)
            nn.init.zeros_(module.q_proj.bias)
            nn.init.zeros_(module.k_proj.bias)
            nn.init.zeros_(module.v_proj.bias)
            nn.init.zeros_(module.out_proj.bias)
        elif isinstance(module, Siglip2MLP):
            nn.init.xavier_uniform_(module.fc1.weight)
            nn.init.xavier_uniform_(module.fc2.weight)
            nn.init.normal_(module.fc1.bias, std=1e-6)
            nn.init.normal_(module.fc2.bias, std=1e-6)
        elif isinstance(module, Siglip2MultiheadAttentionPoolingHead):
            nn.init.xavier_uniform_(module.probe.data)
            nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
            nn.init.zeros_(module.attention.in_proj_bias.data)
        elif isinstance(module, Siglip2Model):
            logit_scale_init = torch.log(torch.tensor(1.0))
            module.logit_scale.data.fill_(logit_scale_init)
            module.logit_bias.data.zero_()
        elif isinstance(module, Siglip2ForImageClassification):
            nn.init.normal_(
                module.classifier.weight,
                std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor,
            )
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            lecun_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class Siglip2TextEmbeddings(nn.Module):
    def __init__(self, config: Siglip2TextConfig):
        super().__init__()
        embed_dim = config.hidden_size

        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
        max_position_embedding = self.position_embedding.weight.shape[0]

        if seq_length > max_position_embedding:
            raise ValueError(
                f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
                f"{seq_length} and max_position_embeddings: {max_position_embedding}"
            )

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        position_embeddings = self.position_embedding(position_ids)
        embeddings = inputs_embeds + position_embeddings

        return embeddings


class Siglip2TextTransformer(nn.Module):
    def __init__(self, config: Siglip2TextConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = Siglip2TextEmbeddings(config)
        self.encoder = Siglip2Encoder(config)
        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.head = nn.Linear(embed_dim, config.projection_size)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> BaseModelOutputWithPooling:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        # Note: Siglip2's text model does not use a causal mask.
        uses_flash_attention = "flash" in self.config._attn_implementation
        if uses_flash_attention:
            attention_mask = None
        elif attention_mask is not None and not uses_flash_attention:
            # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
            attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        last_hidden_state = encoder_outputs.last_hidden_state
        last_hidden_state = self.final_layer_norm(last_hidden_state)

        # The model uses the last token's representation, which is always the EOS token with padding="max_length"
        pooled_output = last_hidden_state[:, -1, :]
        pooled_output = self.head(pooled_output)

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


@auto_docstring(
    custom_intro="""
    The text model from Siglip2 without any head or projection on top.
    """
)
class Siglip2TextModel(Siglip2PreTrainedModel):
    config: Siglip2TextConfig

    def __init__(self, config: Siglip2TextConfig):
        super().__init__(config)
        self.text_model = Siglip2TextTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value):
        self.text_model.embeddings.token_embedding = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> BaseModelOutputWithPooling:
        r"""
        Examples:

        ```python
        >>> from transformers import AutoTokenizer, Siglip2TextModel

        >>> model = Siglip2TextModel.from_pretrained("google/siglip2-base-patch16-224")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip2-base-patch16-224")

        >>> # important: make sure to set padding="max_length" as that's how the model was trained
        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
        ```"""
        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )


class Siglip2MultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling."""

    def __init__(self, config: Siglip2VisionConfig):
        super().__init__()

        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = Siglip2MLP(config)
        self.num_heads = config.num_attention_heads

    def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size = hidden_state.shape[0]
        probe = self.probe.repeat(batch_size, 1, 1)

        if attention_mask is not None:
            target_len, source_len = probe.shape[1], hidden_state.shape[1]
            attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_state.dtype, target_len)
            attention_mask = attention_mask.repeat(1, self.num_heads, target_len, 1)
            attention_mask = attention_mask.reshape(-1, target_len, source_len)

        hidden_state = self.attention(probe, hidden_state, hidden_state, attn_mask=attention_mask)[0]

        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = residual + self.mlp(hidden_state)

        return hidden_state[:, 0]


@auto_docstring(
    custom_intro="""
    The vision model from Siglip2 without any head or projection on top.
    """
)
class Siglip2VisionModel(Siglip2PreTrainedModel):
    config: Siglip2VisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: Siglip2VisionConfig):
        super().__init__(config)

        self.vision_model = Siglip2VisionTransformer(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        pixel_attention_mask: torch.Tensor,
        spatial_shapes: torch.LongTensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> BaseModelOutputWithPooling:
        r"""
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
            Mask to avoid performing attention on padding pixel indices.
        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
            Tensor containing the spatial dimensions (height, width) of the input images.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Siglip2VisionModel

        >>> model = Siglip2VisionModel.from_pretrained("google/siglip2-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled features
        ```"""
        return self.vision_model(
            pixel_values=pixel_values,
            attention_mask=pixel_attention_mask,
            spatial_shapes=spatial_shapes,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )


@auto_docstring
class Siglip2Model(Siglip2PreTrainedModel):
    config: Siglip2Config

    def __init__(self, config: Siglip2Config):
        super().__init__(config)

        if not isinstance(config.text_config, Siglip2TextConfig):
            raise TypeError(
                "config.text_config is expected to be of type Siglip2TextConfig but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(config.vision_config, Siglip2VisionConfig):
            raise TypeError(
                "config.vision_config is expected to be of type Siglip2VisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        text_config = config.text_config
        vision_config = config.vision_config

        # First, initialize the text and vision models with the proper attention implementation
        text_model = Siglip2TextModel._from_config(text_config)
        vision_model = Siglip2VisionModel._from_config(vision_config)

        self.text_model = text_model.text_model
        self.vision_model = vision_model.vision_model

        self.logit_scale = nn.Parameter(torch.randn(1))
        self.logit_bias = nn.Parameter(torch.randn(1))

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`Siglip2TextModel`].

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip2-base-patch16-224")

        >>> # important: make sure to set padding="max_length" as that's how the model was trained
        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
        >>> with torch.no_grad():
        ...     text_features = model.get_text_features(**inputs)
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        text_outputs: BaseModelOutputWithPooling = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        pooled_output = text_outputs.pooler_output

        return pooled_output

    @auto_docstring
    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_attention_mask: Optional[torch.Tensor] = None,
        spatial_shapes: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
            Mask to avoid performing attention on padding pixel indices.
        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
            Tensor containing the spatial dimensions (height, width) of the input images.

        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`Siglip2VisionModel`].

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     image_features = model.get_image_features(**inputs)
        ```
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        vision_outputs: BaseModelOutputWithPooling = self.vision_model(
            pixel_values=pixel_values,
            attention_mask=pixel_attention_mask,
            spatial_shapes=spatial_shapes,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        pooled_output = vision_outputs.pooler_output

        return pooled_output

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_attention_mask: Optional[torch.Tensor] = None,
        spatial_shapes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> Siglip2Output:
        r"""
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
            Mask to avoid performing attention on padding pixel indices.
        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
            Tensor containing the spatial dimensions (height, width) of the input images.
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
        >>> # important: we pass `padding=max_length` since the model was trained with this
        >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> logits_per_image = outputs.logits_per_image
        >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
        >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
        31.9% that image 0 is 'a photo of 2 cats'
        ```
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        vision_outputs: BaseModelOutputWithPooling = self.vision_model(
            pixel_values=pixel_values,
            attention_mask=pixel_attention_mask,
            spatial_shapes=spatial_shapes,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        text_outputs: BaseModelOutputWithPooling = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        image_embeds = vision_outputs.pooler_output
        text_embeds = text_outputs.pooler_output

        # normalized features
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity as logits
        logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))

        logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device)
        logits_per_text = logits_per_text * logit_scale.exp() + logit_bias

        logits_per_image = logits_per_text.t()

        loss = None
        if return_loss:
            # sigmoid (pairwise) loss: +1 on the diagonal, -1 everywhere else
            eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
            m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
            loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
            nll = -torch.sum(loglik, dim=-1)
            loss = nll.mean()

        return Siglip2Output(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )


@auto_docstring(
    custom_intro="""
    Siglip2 vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
    the patch tokens) e.g. for ImageNet.
    """
)
class Siglip2ForImageClassification(Siglip2PreTrainedModel):
    main_input_name = "pixel_values"

    def __init__(self, config: Siglip2Config) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels

        # Create the vision model with the proper attention implementation
        vision_model = Siglip2VisionModel._from_config(config.vision_config)
        self.vision_model = vision_model.vision_model

        # Classifier head
        self.classifier = (
            nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_attention_mask: Optional[torch.Tensor] = None,
        spatial_shapes: Optional[torch.LongTensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> ImageClassifierOutput:
        r"""
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
            Mask to avoid performing attention on padding pixel indices.
        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
            Tensor containing the spatial dimensions (height, width) of the input images.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, Siglip2ForImageClassification
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> torch.manual_seed(3)  # doctest: +IGNORE_RESULT
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> # note: we are loading a `Siglip2Model` from the hub here,
        >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above.
        >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip2-base-patch16-224")
        >>> model = Siglip2ForImageClassification.from_pretrained("google/siglip2-base-patch16-224")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        >>> # model predicts one of the two classes
        >>> predicted_class_idx = logits.argmax(-1).item()
        >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
        Predicted class: LABEL_1
        ```
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs: BaseModelOutputWithPooling = self.vision_model(
            pixel_values,
            attention_mask=pixel_attention_mask,
            spatial_shapes=spatial_shapes,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        sequence_output = outputs.last_hidden_state

        # average pool the patch tokens, ignoring padded patches
        if pixel_attention_mask is not None:
            pool_mask = pixel_attention_mask[..., None].to(sequence_output.device)
            sequence_output = torch.sum(sequence_output * pool_mask, dim=1) / torch.sum(pool_mask, dim=1)
        else:
            sequence_output = torch.mean(sequence_output, dim=1)

        # apply classifier
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # move labels to the correct device to enable model parallelism
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = [
    "Siglip2Model",
    "Siglip2PreTrainedModel",
    "Siglip2TextModel",
    "Siglip2VisionModel",
    "Siglip2ForImageClassification",
]