a
    h                  	   @   s  d Z ddlZddlZddlmZ ddlmZmZmZm	Z	 ddl
ZddlZddlmZ ddlmZmZmZ ddlmZ dd	lmZ dd
lmZ ddlmZ ddlmZmZmZ ddlmZm Z  ddl!m"Z"m#Z#m$Z$m%Z% ddl&m'Z'm(Z(m)Z) dd Z*dNej+e,e,e,e,ej+dddZ-dOddZ.dd Z/d d! Z0ee#d"d#G d$d% d%e"Z1ee#d&d#G d'd( d(e"Z2ee#G d)d* d*e"Z3G d+d, d,ej4Z5G d-d. d.ej4Z6dPej4ej+ej+ej+eej+ e,e,d/d0d1Z7G d2d3 d3ej4Z8G d4d5 d5ej4Z9G d6d7 d7eZ:e#G d8d9 d9e Z;G d:d; d;ej4Z<G d<d= d=ej4Z=e#d>d#G d?d@ d@e;Z>G dAdB dBej4Z?G dCdD dDej4Z@e#dEd#G dFdG dGe;ZAe#G dHdI dIe;ZBe#dJd#G dKdL dLe;ZCg dMZDdS )QzPyTorch Siglip model.    N)	dataclass)AnyCallableOptionalUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss)_calculate_fan_in_and_fan_out   )ACT2FN)_prepare_4d_attention_mask)GradientCheckpointingLayer)BaseModelOutputBaseModelOutputWithPoolingImageClassifierOutput)ALL_ATTENTION_FUNCTIONSPreTrainedModel)ModelOutputauto_docstringcan_return_tuple	torch_int   )SiglipConfigSiglipTextConfigSiglipVisionConfigc                 C   s   dd }||d|  k s(||d|  kr6t jddd ||| | }||| | }| d| d d| d  |   | |td  | | | j||d d S )	Nc                 S   s   dt | t d  d S )N      ?       @)matherfsqrt)x r#   f/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/siglip/modeling_siglip.pynorm_cdf(   s    z _trunc_normal_.<locals>.norm_cdf   zjmean is more than 2 std from [a, b] in nn.init.trunc_normal_. The distribution of values may be incorrect.)
stacklevelr   r   )minmax)	warningswarnuniform_Zerfinv_mul_r   r!   add_Zclamp_)tensormeanstdabr%   lur#   r#   r$   _trunc_normal_%   s     	
r6           r          r   )r/   r0   r1   r2   r3   returnc                 C   sL   t  0 t| dd|| | || W d   n1 s>0    Y  dS )an  Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(	ext{mean}, 	ext{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq 	ext{mean} \leq b`.

    NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
    and the result is subsequently scaled and shifted by the mean and std args.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    r   r   N)torchno_gradr6   r-   r.   )r/   r0   r1   r2   r3   r#   r#   r$   trunc_normal_tf_I   s    
r<   fan_innormalc           	      C   s  t | \}}|dkr|}n"|dkr(|}n|dkr<|| d }|| }|dkrdt| t|d d n|dkrt " | jt|d W d    n1 s0    Y  n\|d	krtd
| }t  | | | W d    n1 s0    Y  ntd| d S )Nr=   fan_outZfan_avgr&   truncated_normalg۶%?r1   r>   uniformr   zinvalid distribution )	r   r<   r   r!   r:   r;   normal_r,   
ValueError)	r/   scalemodedistributionr=   r?   denomZvarianceboundr#   r#   r$   variance_scaling_c   s$    
2
.rJ   c                 C   s   t | ddd d S )Nr=   r@   rF   rG   rJ   r/   r#   r#   r$   lecun_normal_|   s    rN   c                 C   s   t | ddd d S )Nr=   r>   rK   rL   rM   r#   r#   r$   default_flax_embed_init   s    rO   z}
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
    )Zcustom_introc                   @   sj   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeejdf  ed< dZeeejdf  ed< dS )SiglipVisionModelOutputz
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
        The image embeddings obtained by applying the projection layer to the pooler_output.
    Nimage_embedslast_hidden_state.hidden_states
attentions)__name__
__module____qualname____doc__rQ   r   r:   FloatTensor__annotations__rR   rS   tuplerT   r#   r#   r#   r$   rP      s
   
rP   ze
    Base class for text model's outputs that also contains a pooling of the last hidden states.
    c                   @   sj   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeejdf  ed< dZeeejdf  ed< dS )SiglipTextModelOutputz
    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
        The text embeddings obtained by applying the projection layer to the pooler_output.
    Ntext_embedsrR   .rS   rT   )rU   rV   rW   rX   r]   r   r:   rY   rZ   rR   rS   r[   rT   r#   r#   r#   r$   r\      s
   
r\   c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZeej ed< dZeej ed< dZeed< dZeed	< ee d
ddZdS )SiglipOutputa  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
        Contrastive loss for image-text similarity.
    logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
        The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
        similarity scores.
    logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
        The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
        similarity scores.
    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
        The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
        The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
    text_model_output (`BaseModelOutputWithPooling`):
        The output of the [`SiglipTextModel`].
    vision_model_output (`BaseModelOutputWithPooling`):
        The output of the [`SiglipVisionModel`].
    Nlosslogits_per_imagelogits_per_textr]   rQ   text_model_outputvision_model_outputr9   c                    s   t  fdd  D S )Nc                 3   s,   | ]$}|d vr | nt  | V  qdS ))rb   rc   N)getattrto_tuple).0kselfr#   r$   	<genexpr>   s   z(SiglipOutput.to_tuple.<locals>.<genexpr>)r[   keysri   r#   ri   r$   rf      s    zSiglipOutput.to_tuple)rU   rV   rW   rX   r_   r   r:   rY   rZ   r`   ra   r]   rQ   rb   r   rc   r[   r   rf   r#   r#   r#   r$   r^      s   
r^   c                       sP   e Zd Zed fddZejeeejdddZdej	ejdd	d
Z
  ZS )SiglipVisionEmbeddingsconfigc                    s   t    || _|j| _|j| _|j| _tj|j	| j| j| jdd| _
| j| j d | _| j| _t| j| j| _| jdt| jddd d S )NZvalid)Zin_channelsZout_channelsZkernel_sizeZstridepaddingr&   position_idsr   F
persistent)super__init__ro   hidden_size	embed_dimZ
image_size
patch_sizer   Conv2dZnum_channelspatch_embeddingnum_patchesnum_positions	Embeddingposition_embeddingregister_bufferr:   arangeexpandrj   ro   	__class__r#   r$   rw      s     
zSiglipVisionEmbeddings.__init__)
embeddingsheightwidthr9   c                 C   s   |j d }| jjj d }tj s>||kr>||kr>| | jS | jjd}|j d }|| j }|| j }	t	|d }
|
d|
|
|}|dddd}tjj|||	fddd	}|dddddd|}|S )
a  
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing and no class embeddings.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        r   r   rs   g      ?r   r&   ZbicubicF)sizerF   Zalign_corners)shaper   weightr:   Zjit
is_tracingrq   Z	unsqueezerz   r   reshapeZpermuter   
functionalZinterpolateview)rj   r   r   r   r}   r~   Zpatch_pos_embeddimZ
new_heightZ	new_widthZsqrt_num_positionsr#   r#   r$   interpolate_pos_encoding   s&    




z/SiglipVisionEmbeddings.interpolate_pos_encodingF)pixel_valuesr9   c           	      C   sh   |j \}}}}| jjj}| |j|d}|ddd}|rT|| ||| }n|| | j	 }|S )N)dtyper&   r   )
r   r|   r   r   toflatten	transposer   r   rq   )	rj   r   r   _r   r   Ztarget_dtypeZpatch_embedsr   r#   r#   r$   forward  s    
zSiglipVisionEmbeddings.forward)F)rU   rV   rW   r   rw   r:   Tensorintr   rY   r   __classcell__r#   r#   r   r$   rm      s   &rm   c                       sL   e Zd Zed fddZdeej eej eej ej	dddZ
  ZS )	SiglipTextEmbeddingsrn   c                    sR   t    |j}t|j|| _t|j|| _| j	dt
|jddd d S )Nrq   rr   Frt   )rv   rw   rx   r   r   Z
vocab_sizetoken_embeddingZmax_position_embeddingsr   r   r:   r   r   rj   ro   ry   r   r#   r$   rw     s    
zSiglipTextEmbeddings.__init__N)	input_idsrq   inputs_embedsr9   c                 C   s   |d ur|j d n|j d }| jjj d }||krFtd| d| |d u rd| jd d d |f }|d u rv| |}| |}|| }|S )Nrs   r   zRSequence length must be less than max_position_embeddings (got `sequence length`: z and max_position_embeddings: )r   r   r   rD   rq   r   )rj   r   rq   r   
seq_lengthZmax_position_embeddingZposition_embeddingsr   r#   r#   r$   r   &  s"    

zSiglipTextEmbeddings.forward)NNN)rU   rV   rW   r   rw   r   r:   
LongTensorrY   r   r   r   r#   r#   r   r$   r     s      r   )modulequerykeyvalueattention_maskscalingdropoutc           
      K   s|   t ||dd| }|d ur(|| }tjj|dt jd|j}tjj	||| j
d}t ||}	|	dd }	|	|fS )Nrs   r   )r   r   )ptrainingr   r&   )r:   matmulr   r   r   ZsoftmaxZfloat32r   r   r   r   
contiguous)
r   r   r   r   r   r   r   kwargsattn_weightsattn_outputr#   r#   r$   eager_attention_forwardA  s    
r   c                       sL   e Zd ZdZ fddZdejeej eejeej f dddZ	  Z
S )	SiglipAttentionz=Multi-headed attention from 'Attention Is All You Need' paperc                    s   t    || _|j| _|j| _| j| j | _| j| j | jkrZtd| j d| j d| jd | _	|j
| _d| _t| j| j| _t| j| j| _t| j| j| _t| j| j| _d S )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).      F)rv   rw   ro   rx   ry   num_attention_heads	num_headshead_dimrD   rE   Zattention_dropoutr   	is_causalr   Lineark_projv_projq_projout_projr   r   r#   r$   rw   [  s$    

zSiglipAttention.__init__N)rS   r   r9   c              
   K   s   |j \}}}| |}| |}| |}	|||| j| jdd}|||| j| jdd}|	||| j| jdd}	t}
| j	j
dkrt| j	j
 }
|
| |||	|| j| j| jsdn| jd\}}|||| }| |}||fS )z#Input shape: Batch x Time x Channelr   r&   eagerr7   )r   r   r   )r   r   r   r   r   r   r   r   r   ro   _attn_implementationr   r   rE   r   r   r   r   r   )rj   rS   r   r   
batch_sizer   ry   Zqueriesrl   valuesZattention_interfacer   r   r#   r#   r$   r   o  s.    




zSiglipAttention.forward)N)rU   rV   rW   rX   rw   r:   r   r   r[   r   r   r#   r#   r   r$   r   X  s    r   c                       s0   e Zd Z fddZejejdddZ  ZS )	SiglipMLPc                    sD   t    || _t|j | _t|j|j	| _
t|j	|j| _d S N)rv   rw   ro   r   Z
hidden_actactivation_fnr   r   rx   Zintermediate_sizefc1fc2r   r   r#   r$   rw     s
    
zSiglipMLP.__init__)rS   r9   c                 C   s"   |  |}| |}| |}|S r   )r   r   r   )rj   rS   r#   r#   r$   r     s    


zSiglipMLP.forward)rU   rV   rW   rw   r:   r   r   r   r#   r#   r   r$   r     s   r   c                       sN   e Zd Zeeef d fddZdejeje	e
 eej dddZ  ZS )	SiglipEncoderLayerrn   c                    sR   t    |j| _tj| j|jd| _t|| _	tj| j|jd| _
t|| _d S Neps)rv   rw   rx   ry   r   	LayerNormlayer_norm_epslayer_norm1r   	self_attnlayer_norm2r   mlpr   r   r#   r$   rw     s    

zSiglipEncoderLayer.__init__F)rS   r   output_attentionsr9   c                 C   sb   |}|  |}| j|||d\}}|| }|}| |}| |}|| }|f}|r^||f7 }|S )a=  
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`):
                Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        )rS   r   r   )r   r   r   r   )rj   rS   r   r   residualr   outputsr#   r#   r$   r     s     




zSiglipEncoderLayer.forward)F)rU   rV   rW   r   r   r   rw   r:   r   r   boolr[   rY   r   r   r#   r#   r   r$   r     s    r   c                   @   s>   e Zd ZU eed< dZdZg dZdZdZ	dZ
dZdd ZdS )SiglipPreTrainedModelro   ZsiglipT)r   rm   r   #SiglipMultiheadAttentionPoolingHeadc                 C   s^  t |trJt | jtr | jjjn| jj}tjj|j	j
dt| d nt |tjrdt|j
 nt |trtj|jj
 tj|jj
 tj|jj
 tj|jj
 tj|jj tj|jj tj|jj tj|jj nht |trJtj|jj
 tj|jj
 tjj|jjdd tjj|jjdd nt |trtj|jj tj|jjj tj|jjj nt |t rt!"t!#d}|j$j%| |j&j'  nt |t(rtjj|j)j
| jjjd | jj* d nbt |tj+tj,fr2t-|j
 |jdurZtj|j n(t |tj.rZ|jj'  |j
j%d dS )zInitialize the weightsr   rA   gư>r   r   N)/
isinstancerm   ro   r   vision_configrx   r   initrC   r   r   npr!   r   rO   r   Zxavier_uniform_r   r   r   r   Zzeros_Zbiasr   r   r   r   probedata	attentionZin_proj_weightZin_proj_biasSiglipModelr:   logr/   logit_scaleZfill_
logit_biasZzero_SiglipForImageClassification
classifierZinitializer_factorr   r{   rN   r   )rj   r   r   Zlogit_scale_initr#   r#   r$   _init_weights  sT    

"

z#SiglipPreTrainedModel._init_weightsN)rU   rV   rW   r   rZ   Zbase_model_prefixZsupports_gradient_checkpointingZ_no_split_modulesZ_supports_flash_attnZ_supports_sdpaZ_supports_flex_attnZ_supports_attention_backendr   r#   r#   r#   r$   r     s   
r   c                       sN   e Zd ZdZed fddZed	eej	 ee
 ee
 edddZ  ZS )
SiglipEncoderz
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`SiglipEncoderLayer`].

    Args:
        config: SiglipConfig
    rn   c                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r#   )r   )rg   r   rn   r#   r$   
<listcomp>#      z*SiglipEncoder.__init__.<locals>.<listcomp>F)	rv   rw   ro   r   Z
ModuleListrangeZnum_hidden_layerslayersZgradient_checkpointingr   r   rn   r$   rw      s    
 zSiglipEncoder.__init__N)r   r   output_hidden_statesr9   c           
      C   s   |dur|n| j j}|dur |n| j j}|r0dnd}|r<dnd}|}| jD ]:}|r\||f }||||d}	|	d }|rJ||	d f }qJ|r||f }t|||dS )ad  
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        Nr#   )r   r   r   )rR   rS   rT   )ro   r   r   r   r   )
rj   r   r   r   r   Zencoder_statesZall_attentionsrS   Zencoder_layerZlayer_outputsr#   r#   r$   r   '  s0    


zSiglipEncoder.forward)NNN)rU   rV   rW   rX   r   rw   r   r   r:   r   r   r   r   r   r#   r#   r   r$   r     s      r   c                
       s^   e Zd Zed fddZeedeej	 eej	 eej	 ee
 ee
 edddZ  ZS )	SiglipTextTransformerrn   c                    sP   t    || _|j}t|| _t|| _tj	||j
d| _t||j| _d S r   )rv   rw   ro   rx   r   r   r   encoderr   r   r   final_layer_normr   Zprojection_sizeheadr   r   r#   r$   rw   h  s    


zSiglipTextTransformer.__init__Nr   r   rq   r   r   r9   c                 C   s   |d ur|n| j j}|d ur |n| j j}|d u r8td| }|d|d }| j||d}d| j jv }|rtd }n|d ur|st||j	}| j
||||d}	|	j}
| |
}
|
d d dd d f }| |}t|
||	j|	jdS )NzYou have to specify input_idsrs   )r   rq   Zflash)r   r   r   r   rR   pooler_outputrS   rT   )ro   r   r   rD   r   r   r   r   r   r   r   rR   r   r   r   rS   rT   )rj   r   r   rq   r   r   Zinput_shaperS   Zuses_flash_attentionencoder_outputsrR   pooled_outputr#   r#   r$   r   r  s:    


zSiglipTextTransformer.forward)NNNNN)rU   rV   rW   r   rw   r   r   r   r:   r   r   r   r   r   r#   r#   r   r$   r   g  s    
     r   zK
    The text model from SigLIP without any head or projection on top.
    c                
       s   e Zd ZU eed< ed fddZejdddZdd	 Z	e
edeej eej eej ee ee edddZ  ZS )SiglipTextModelro   rn   c                    s"   t  | t|| _|   d S r   )rv   rw   r   
text_model	post_initr   r   r#   r$   rw     s    
zSiglipTextModel.__init__rd   c                 C   s
   | j jjS r   r   r   r   ri   r#   r#   r$   get_input_embeddings  s    z$SiglipTextModel.get_input_embeddingsc                 C   s   || j j_d S r   r   )rj   r   r#   r#   r$   set_input_embeddings  s    z$SiglipTextModel.set_input_embeddingsNr   c                 C   s   | j |||||dS )a  
        Examples:

        ```python
        >>> from transformers import AutoTokenizer, SiglipTextModel

        >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")

        >>> # important: make sure to set padding="max_length" as that's how the model was trained
        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
        ```r   r   rq   r   r   )r   )rj   r   r   rq   r   r   r#   r#   r$   r     s    zSiglipTextModel.forward)NNNNN)rU   rV   rW   r   rZ   rw   r   Moduler   r   r   r   r   r:   r   r   r   r   r   r#   r#   r   r$   r     s&   
     r   c                       sL   e Zd Zed fddZeed	ee ee ee e	dddZ
  ZS )
SiglipVisionTransformerrn   c                    sf   t    || _|j}t|| _t|| _tj	||j
d| _t|dsJdn|j| _| jrbt|| _d S )Nr   vision_use_headT)rv   rw   ro   rx   rm   r   r   r   r   r   r   post_layernormhasattrr   use_headr   r   r   r   r#   r$   rw     s    


z SiglipVisionTransformer.__init__NFr   r   r   r9   c           	      C   s~   |d ur|n| j j}|d ur |n| j j}| j||d}| j|||d}|j}| |}| jrf| |nd }t	|||j
|jdS )N)r   )r   r   r   r   )ro   r   r   r   r   rR   r   r  r   r   rS   rT   )	rj   r   r   r   r   rS   r   rR   r   r#   r#   r$   r     s$    	
zSiglipVisionTransformer.forward)NNF)rU   rV   rW   r   rw   r   r   r   r   r   r   r   r#   r#   r   r$   r     s      r   c                       s.   e Zd ZdZed fddZdd Z  ZS )r   zMultihead Attention Pooling.rn   c                    s\   t    ttdd|j| _tjj|j|j	dd| _
tj|j|jd| _t|| _d S )Nr   T)Zbatch_firstr   )rv   rw   r   	Parameterr:   randnrx   r   ZMultiheadAttentionr   r   r   r   	layernormr   r   r   r   r#   r$   rw     s
    
z,SiglipMultiheadAttentionPoolingHead.__init__c                 C   sX   |j d }| j|dd}| |||d }|}| |}|| | }|d d df S )Nr   r   )r   r   repeatr   r  r   )rj   Zhidden_stater   r   r   r#   r#   r$   r     s    

z+SiglipMultiheadAttentionPoolingHead.forward)rU   rV   rW   rX   r   rw   r   r   r#   r#   r   r$   r     s   r   zM
    The vision model from SigLIP without any head or projection on top.
    c                       sf   e Zd ZU eed< dZed fddZejdddZ	e
edee ee eedddZ  ZS )SiglipVisionModelro   r   rn   c                    s"   t  | t|| _|   d S r   )rv   rw   r   vision_modelr   r   r   r#   r$   rw   1  s    
zSiglipVisionModel.__init__rd   c                 C   s
   | j jjS r   )r  r   r|   ri   r#   r#   r$   r   9  s    z&SiglipVisionModel.get_input_embeddingsNFr  c                 C   s   | j ||||dS )a  
        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, SiglipVisionModel

        >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled features
        ```r   r   r   r   )r  )rj   r   r   r   r   r#   r#   r$   r   <  s    zSiglipVisionModel.forward)NNF)rU   rV   rW   r   rZ   main_input_namerw   r   r   r   r   r   r   r   r   r   r   r#   r#   r   r$   r  (  s   
   r  c                       s   e Zd ZU eed< ed fddZedeej	 eej	 eej	 ee
 ee
 ejdddZedeej ee
 ee
 e
ejd
ddZeedeej eej eej	 eej ee
 ee
 ee
 e
ed	ddZ  ZS )r   ro   rn   c                    s   t  | t|jts.tdt|j dt|jtsPtdt|j d|j}|j}t	
|}t
|}|j| _|j| _ttd| _ttd| _|   d S )NzMconfig.text_config is expected to be of type SiglipTextConfig but is of type .zQconfig.vision_config is expected to be of type SiglipVisionConfig but is of type r   )rv   rw   r   text_configr   	TypeErrortyper   r   r   _from_configr  r   r  r   r  r:   r  r   r   r   )rj   ro   r  r   r   r  r   r#   r$   rw   f  s,    

zSiglipModel.__init__Nr   c                 C   sF   |dur|n| j j}|dur |n| j j}| j|||||d}|j}|S )aJ  
        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`SiglipTextModel`].

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")

        >>> # important: make sure to set padding="max_length" as that's how the model was trained
        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
        >>> with torch.no_grad():
        ...     text_features = model.get_text_features(**inputs)
        ```Nr   )ro   r   r   r   r   )rj   r   r   rq   r   r   text_outputsr   r#   r#   r$   get_text_features  s    zSiglipModel.get_text_featuresF)r   r   r   r   r9   c                 C   sD   |dur|n| j j}|dur |n| j j}| j||||d}|j}|S )a  
        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`SiglipVisionModel`].

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     image_features = model.get_image_features(**inputs)
        ```Nr	  )ro   r   r   r  r   )rj   r   r   r   r   vision_outputsr   r#   r#   r$   get_image_features  s    !zSiglipModel.get_image_features)	r   r   r   rq   return_lossr   r   r   r9   c	              	   C   sD  |dur|n| j j}|dur |n| j j}| j||||d}	| j|||||d}
|	j}|
j}||jdddd }||jdddd }t||	 
|j}| j
|j| j
|j }}||  | }|	 }d}|r.tj|d|jd	}t| d|  }tjj|| }tj|dd
 }| }t||||||
|	dS )a  
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
        >>> # important: we pass `padding=max_length` since the model was trained with this
        >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> logits_per_image = outputs.logits_per_image
        >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
        >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
        31.9% that image 0 is 'a photo of 2 cats'
        ```Nr	  r   r&   rs   T)r   r   Zkeepdimr   )devicer   )r_   r`   ra   r]   rQ   rb   rc   )ro   r   r   r  r   r   Znormr:   r   tr   r  r   r   expeyer   Z	ones_liker   r   Z
logsigmoidsumr0   r^   )rj   r   r   r   rq   r  r   r   r   r  r  rQ   r]   ra   r   r   r`   r_   r  Zm1_diag1ZloglikZnllr#   r#   r$   r     sP    ,zSiglipModel.forward)NNNNN)NNNF)NNNNNNNF)rU   rV   rW   r   rZ   rw   r   r   r:   r   r   rY   r  r  r   r   r^   r   r   r#   r#   r   r$   r   b  sb   
      -    0        r   z
    SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
    the patch tokens) e.g. for ImageNet.
    c                
       s^   e Zd ZdZedd fddZeed
ee	j
 ee	j
 ee ee eeddd	Z  ZS )r   r   N)ro   r9   c                    sZ   t  | |j| _t|j}|j| _|jdkrDt|jj	|jnt
 | _|   d S )Nr   )rv   rw   
num_labelsr  r  r   r  r   r   rx   ZIdentityr   r   )rj   ro   r  r   r#   r$   rw   Q  s    "z%SiglipForImageClassification.__init__F)r   labelsr   r   r   r9   c                 C   sf  |dur|n| j j}|dur |n| j j}| j||||d}|j}tj|dd}| |}d}	|durR||j	}| j j
du r| jdkrd| j _
n4| jdkr|jtjks|jtjkrd| j _
nd| j _
| j j
dkrt }
| jdkr|
| | }	n
|
||}	nN| j j
dkr4t }
|
|d| j|d}	n| j j
dkrRt }
|
||}	t|	||j|jd	S )
a$  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, SiglipForImageClassification
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> torch.manual_seed(3)  # doctest: +IGNORE_RESULT
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> # note: we are loading a `SiglipModel` from the hub here,
        >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above.
        >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")
        >>> model = SiglipForImageClassification.from_pretrained("google/siglip-base-patch16-224")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        >>> # model predicts one of the two classes
        >>> predicted_class_idx = logits.argmax(-1).item()
        >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
        Predicted class: LABEL_1
        ```N)r   r   r   r   r  Z
regressionZsingle_label_classificationZmulti_label_classificationrs   )r_   logitsrS   rT   )ro   r   r   r  rR   r:   r0   r   r   r  Zproblem_typer  r   longr   r
   Zsqueezer	   r   r   r   rS   rT   )rj   r   r  r   r   r   r   Zsequence_outputr  r_   Zloss_fctr#   r#   r$   r   c  sL    )



"


z$SiglipForImageClassification.forward)NNNNF)rU   rV   rW   r
  r   rw   r   r   r   r:   r   r   r   r   r   r#   r#   r   r$   r   H  s"        r   )r   r   r   r  r   )r7   r   r8   r   )r   r=   r>   )r7   )ErX   r   r*   dataclassesr   typingr   r   r   r   numpyr   r:   r   Ztorch.nnr   r	   r
   Ztorch.nn.initr   Zactivationsr   Zmodeling_attn_mask_utilsr   Zmodeling_layersr   Zmodeling_outputsr   r   r   Zmodeling_utilsr   r   utilsr   r   r   r   Zconfiguration_siglipr   r   r   r6   r   floatr<   rJ   rN   rO   rP   r\   r^   r   rm   r   r   r   r   r   r   r   r   r   r   r   r  r   r   __all__r#   r#   r#   r$   <module>   s   % 
#I/ ?0@PA305 fr