"""PyTorch ALIGN model."""

import math
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithNoAttention,
    BaseModelOutputWithPooling,
    BaseModelOutputWithPoolingAndNoAttention,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
from .configuration_align import AlignConfig, AlignTextConfig, AlignVisionConfig


logger = logging.get_logger(__name__)

@dataclass
@auto_docstring(
    custom_intro="""
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
    """
)
class AlignVisionModelOutput(ModelOutput):
    r"""
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
        The image embeddings obtained by applying the projection layer to the pooler_output.
    """

    image_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None

@dataclass
@auto_docstring(
    custom_intro="""
    Base class for text model's outputs that also contains a pooling of the last hidden states.
    """
)
class AlignTextModelOutput(ModelOutput):
    r"""
    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
        The text embeddings obtained by applying the projection layer to the pooler_output.
    """

    text_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None


@dataclass
@auto_docstring
class AlignOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
        Contrastive loss for image-text similarity.
    logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
        The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
        similarity scores.
    logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
        The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
        similarity scores.
    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
        The text embeddings obtained by applying the projection layer to the pooled output of [`AlignTextModel`].
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
        The output of [`AlignVisionModel`].
    text_model_output (`BaseModelOutputWithPooling`):
        The output of the [`AlignTextModel`].
    vision_model_output (`BaseModelOutputWithPoolingAndNoAttention`):
        The output of the [`AlignVisionModel`].
    Nlosslogits_per_imagelogits_per_textr+   r   text_model_outputvision_model_outputreturnc                    s   t  fdd  D S )Nc                 3   s,   | ]$}|d vr | nt  | V  qdS ))r1   r2   N)getattrto_tuple).0kselfr(   r)   	<genexpr>l   s   z'AlignOutput.to_tuple.<locals>.<genexpr>)r'   keysr9   r(   r9   r)   r6   k   s    zAlignOutput.to_tuple)r    r!   r"   r#   r.   r   r$   r%   r&   r/   r0   r+   r   r1   r   r2   r   r'   r   r6   r(   r(   r(   r)   r-   M   s   
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device), label_smoothing=0.1)


def align_loss(similarity: torch.Tensor) -> torch.Tensor:
    caption_loss = contrastive_loss(similarity)
    image_loss = contrastive_loss(similarity.t())
    return (caption_loss + image_loss) / 2.0

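
# Intuition sketch (illustrative only, not executed as part of the module):
# for a batch of N matched text-image pairs, row i of the similarity matrix
# should score highest at column i, so each row's target class is its own
# index, and `align_loss` averages the two directional cross-entropy terms:
#
#     similarity = torch.randn(4, 4)  # stand-in for text_embeds @ image_embeds.t()
#     loss = align_loss(similarity)   # scalar: (caption_loss + image_loss) / 2.0
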
def round_filters(config: AlignVisionConfig, num_channels: int):
    r"""
    Round number of filters based on depth multiplier.
    """
    divisor = config.depth_divisor
    num_channels *= config.width_coefficient
    new_dim = max(divisor, int(num_channels + divisor / 2) // divisor * divisor)

    # Make sure that rounding down does not reduce the channel count by more than 10%.
    if new_dim < 0.9 * num_channels:
        new_dim += divisor

    return int(new_dim)

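
# Worked example (hypothetical config values): with depth_divisor=8 and
# width_coefficient=1.4, num_channels=32 scales to 44.8; rounding to the
# nearest multiple of 8 gives 48, and since 48 >= 0.9 * 44.8 no extra divisor
# is added, so round_filters returns 48.
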
def correct_pad(kernel_size: Union[int, tuple], adjust: bool = True):
    r"""
    Utility function to get the tuple padding value for the depthwise convolution.

    Args:
        kernel_size (`int` or `tuple`):
            Kernel size of the convolution layers.
        adjust (`bool`, *optional*, defaults to `True`):
            Adjusts padding value to apply to right and bottom sides of the input.
    """
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)

    correct = (kernel_size[0] // 2, kernel_size[1] // 2)

    if adjust:
        return (correct[1] - 1, correct[1], correct[0] - 1, correct[0])
    else:
        return (correct[1], correct[1], correct[0], correct[0])

class AlignVisionEmbeddings(nn.Module):
    r"""
    A module that corresponds to the stem module of the original work.
    """

    def __init__(self, config: AlignVisionConfig):
        super().__init__()

        self.out_dim = round_filters(config, 32)
        self.padding = nn.ZeroPad2d(padding=(0, 1, 0, 1))
        self.convolution = nn.Conv2d(
            config.num_channels, self.out_dim, kernel_size=3, stride=2, padding="valid", bias=False
        )
        self.batchnorm = nn.BatchNorm2d(self.out_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum)
        self.activation = ACT2FN[config.hidden_act]

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        features = self.padding(pixel_values)
        features = self.convolution(features)
        features = self.batchnorm(features)
        features = self.activation(features)

        return features


class AlignVisionDepthwiseConv2d(nn.Conv2d):
    def __init__(
        self,
        in_channels,
        depth_multiplier=1,
        kernel_size=3,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
        padding_mode="zeros",
    ):
        out_channels = in_channels * depth_multiplier
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias,
            padding_mode=padding_mode,
        )

class AlignVisionExpansionLayer(nn.Module):
    r"""
    This corresponds to the expansion phase of each block in the original implementation.
    """

    def __init__(self, config: AlignVisionConfig, in_dim: int, out_dim: int, stride: int):
        super().__init__()
        self.expand_conv = nn.Conv2d(
            in_channels=in_dim,
            out_channels=out_dim,
            kernel_size=1,
            padding="same",
            bias=False,
        )
        self.expand_bn = nn.BatchNorm2d(num_features=out_dim, eps=config.batch_norm_eps)
        self.expand_act = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        # Expand phase
        hidden_states = self.expand_conv(hidden_states)
        hidden_states = self.expand_bn(hidden_states)
        hidden_states = self.expand_act(hidden_states)

        return hidden_states

class AlignVisionDepthwiseLayer(nn.Module):
    r"""
    This corresponds to the depthwise convolution phase of each block in the original implementation.
    """

    def __init__(
        self,
        config: AlignVisionConfig,
        in_dim: int,
        stride: int,
        kernel_size: int,
        adjust_padding: bool,
    ):
        super().__init__()
        self.stride = stride
        conv_pad = "valid" if self.stride == 2 else "same"
        padding = correct_pad(kernel_size, adjust=adjust_padding)

        self.depthwise_conv_pad = nn.ZeroPad2d(padding=padding)
        self.depthwise_conv = AlignVisionDepthwiseConv2d(
            in_dim, kernel_size=kernel_size, stride=stride, padding=conv_pad, bias=False
        )
        self.depthwise_norm = nn.BatchNorm2d(
            num_features=in_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum
        )
        self.depthwise_act = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        # Depthwise convolution
        if self.stride == 2:
            hidden_states = self.depthwise_conv_pad(hidden_states)

        hidden_states = self.depthwise_conv(hidden_states)
        hidden_states = self.depthwise_norm(hidden_states)
        hidden_states = self.depthwise_act(hidden_states)

        return hidden_states

class AlignVisionSqueezeExciteLayer(nn.Module):
    r"""
    This corresponds to the Squeeze and Excitement phase of each block in the original implementation.
    """

    def __init__(self, config: AlignVisionConfig, in_dim: int, expand_dim: int, expand: bool = False):
        super().__init__()
        self.dim = expand_dim if expand else in_dim
        self.dim_se = max(1, int(in_dim * config.squeeze_expansion_ratio))

        self.squeeze = nn.AdaptiveAvgPool2d(output_size=1)
        self.reduce = nn.Conv2d(
            in_channels=self.dim,
            out_channels=self.dim_se,
            kernel_size=1,
            padding="same",
        )
        self.expand = nn.Conv2d(
            in_channels=self.dim_se,
            out_channels=self.dim,
            kernel_size=1,
            padding="same",
        )
        self.act_reduce = ACT2FN[config.hidden_act]
        self.act_expand = nn.Sigmoid()

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        inputs = hidden_states
        hidden_states = self.squeeze(hidden_states)
        hidden_states = self.reduce(hidden_states)
        hidden_states = self.act_reduce(hidden_states)

        hidden_states = self.expand(hidden_states)
        hidden_states = self.act_expand(hidden_states)
        hidden_states = torch.mul(inputs, hidden_states)

        return hidden_states

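
# Shape walk-through (sketch): for hidden_states of shape (batch, dim, h, w),
# squeeze -> (batch, dim, 1, 1), reduce -> (batch, dim_se, 1, 1), expand back
# to (batch, dim, 1, 1); torch.mul then broadcasts this per-channel gate over
# the full (h, w) spatial extent of the block input.
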
class AlignVisionFinalBlockLayer(nn.Module):
    r"""
    This corresponds to the final phase of each block in the original implementation.
    """

    def __init__(
        self, config: AlignVisionConfig, in_dim: int, out_dim: int, stride: int, drop_rate: float, id_skip: bool
    ):
        super().__init__()
        self.apply_dropout = stride == 1 and not id_skip
        self.project_conv = nn.Conv2d(
            in_channels=in_dim,
            out_channels=out_dim,
            kernel_size=1,
            padding="same",
            bias=False,
        )
        self.project_bn = nn.BatchNorm2d(
            num_features=out_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum
        )
        self.dropout = nn.Dropout(p=drop_rate)

    def forward(self, embeddings: torch.FloatTensor, hidden_states: torch.FloatTensor) -> torch.Tensor:
        hidden_states = self.project_conv(hidden_states)
        hidden_states = self.project_bn(hidden_states)

        if self.apply_dropout:
            hidden_states = self.dropout(hidden_states)
            hidden_states = hidden_states + embeddings

        return hidden_states

class AlignVisionBlock(nn.Module):
    r"""
    This corresponds to the block module of the original EfficientNet vision encoder implementation.

    Args:
        config ([`AlignVisionConfig`]):
            Model configuration class.
        in_dim (`int`):
            Number of input channels.
        out_dim (`int`):
            Number of output channels.
        stride (`int`):
            Stride size to be used in convolution layers.
        expand_ratio (`int`):
            Expand ratio to set the output dimensions for the expansion and squeeze-excite layers.
        kernel_size (`int`):
            Kernel size for the depthwise convolution layer.
        drop_rate (`float`):
            Dropout rate to be used in the final phase of each block.
        id_skip (`bool`):
            Whether to apply dropout and sum the final hidden states with the input embeddings during the final phase
            of each block. Set to `True` for the first block of each stage.
        adjust_padding (`bool`):
            Whether to apply padding to only right and bottom side of the input kernel before the depthwise convolution
            operation, set to `True` for inputs with odd input sizes.
    """

    def __init__(
        self,
        config: AlignVisionConfig,
        in_dim: int,
        out_dim: int,
        stride: int,
        expand_ratio: int,
        kernel_size: int,
        drop_rate: float,
        id_skip: bool,
        adjust_padding: bool,
    ):
        super().__init__()
        self.expand_ratio = expand_ratio
        self.expand = self.expand_ratio != 1
        expand_in_dim = in_dim * expand_ratio

        if self.expand:
            self.expansion = AlignVisionExpansionLayer(
                config=config, in_dim=in_dim, out_dim=expand_in_dim, stride=stride
            )

        self.depthwise_conv = AlignVisionDepthwiseLayer(
            config=config,
            in_dim=expand_in_dim if self.expand else in_dim,
            stride=stride,
            kernel_size=kernel_size,
            adjust_padding=adjust_padding,
        )
        self.squeeze_excite = AlignVisionSqueezeExciteLayer(
            config=config, in_dim=in_dim, expand_dim=expand_in_dim, expand=self.expand
        )
        self.projection = AlignVisionFinalBlockLayer(
            config=config,
            in_dim=expand_in_dim if self.expand else in_dim,
            out_dim=out_dim,
            stride=stride,
            drop_rate=drop_rate,
            id_skip=id_skip,
        )

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        embeddings = hidden_states
        # Expansion and depthwise convolution phase
        if self.expand_ratio != 1:
            hidden_states = self.expansion(hidden_states)
        hidden_states = self.depthwise_conv(hidden_states)

        # Squeeze and excite phase
        hidden_states = self.squeeze_excite(hidden_states)
        hidden_states = self.projection(embeddings, hidden_states)
        return hidden_states

class AlignVisionEncoder(nn.Module):
    r"""
    Forward propagates the embeddings through each vision encoder (EfficientNet) block.

    Args:
        config ([`AlignVisionConfig`]):
            Model configuration class.
    """

    def __init__(self, config: AlignVisionConfig):
        super().__init__()
        self.depth_coefficient = config.depth_coefficient

        def round_repeats(repeats):
            # Round number of block repeats based on depth multiplier.
            return int(math.ceil(self.depth_coefficient * repeats))

        num_base_blocks = len(config.in_channels)
        num_blocks = sum(round_repeats(n) for n in config.num_block_repeats)

        curr_block_num = 0
        blocks = []
        for i in range(num_base_blocks):
            in_dim = round_filters(config, config.in_channels[i])
            out_dim = round_filters(config, config.out_channels[i])
            stride = config.strides[i]
            kernel_size = config.kernel_sizes[i]
            expand_ratio = config.expand_ratios[i]

            for j in range(round_repeats(config.num_block_repeats[i])):
                id_skip = j == 0
                stride = 1 if j > 0 else stride
                in_dim = out_dim if j > 0 else in_dim
                adjust_padding = curr_block_num not in config.depthwise_padding
                drop_rate = config.drop_connect_rate * curr_block_num / num_blocks

                block = AlignVisionBlock(
                    config=config,
                    in_dim=in_dim,
                    out_dim=out_dim,
                    stride=stride,
                    kernel_size=kernel_size,
                    expand_ratio=expand_ratio,
                    drop_rate=drop_rate,
                    id_skip=id_skip,
                    adjust_padding=adjust_padding,
                )
                blocks.append(block)
                curr_block_num += 1

        self.blocks = nn.ModuleList(blocks)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> BaseModelOutputWithNoAttention:
        all_hidden_states = (hidden_states,) if output_hidden_states else None

        for block in self.blocks:
            hidden_states = block(hidden_states)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)

        return BaseModelOutputWithNoAttention(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
        )

  ZS )	AlignTextEmbeddingszGConstruct the embeddings from word, position and token_type embeddings.c                    s   t    tj|j|j|jd| _t|j|j| _	t|j
|j| _tj|j|jd| _t|j| _t|dd| _| jdt|jddd | jd	tj| j tjd
dd d S )N)padding_idxrZ   position_embedding_typeabsoluteposition_ids)r   F)
persistenttoken_type_ids)dtype)r\   r]   r   	EmbeddingZ
vocab_sizehidden_sizeZpad_token_idword_embeddingsZmax_position_embeddingsposition_embeddingsZtype_vocab_sizetoken_type_embeddings	LayerNormlayer_norm_epsr   hidden_dropout_probr   r5   r   Zregister_bufferr$   rA   r   rr   r   sizelongrh   ri   r(   r)   r]     s    
zAlignTextEmbeddings.__init__N)	input_idsr   r   inputs_embedsr4   c                 C   s   |d ur|  }n|  d d }|d }|d u rH| jd d d |f }|d u rt| dr| jd d d |f }||d |}|}ntj|tj| jjd}|d u r| 	|}| 
|}	||	 }
| jdkr| |}|
|7 }
| |
}
| |
}
|
S )Nr   r   r   r   r   r?   r   )r   r   hasattrr   r   r$   rr   r   r?   r   r   r   r   r   r   )r:   r   r   r   r   input_shape
seq_lengthbuffered_token_type_ids buffered_token_type_ids_expandedr   r   r   r(   r(   r)   rn     s,    







zAlignTextEmbeddings.forward)NNNN)r    r!   r"   r#   r]   r   r$   
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
):
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    if head_mask is not None:
        attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights

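
# Shape contract (sketch): query/key/value arrive as (batch, num_heads,
# seq_len, head_size), attn_weights is (batch, num_heads, seq_len, seq_len),
# and the additive attention_mask (large negative values at padding positions)
# is broadcast against it. The returned attn_output is transposed back to
# (batch, seq_len, num_heads, head_size) so the caller can flatten the heads.
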
  ZS )	AlignTextSelfAttentionc                    s   t    |j|j dkr>t|ds>td|j d|j d|| _|j| _t|j|j | _| j| j | _	t
|j| j	| _t
|j| j	| _t
|j| j	| _t
|j| _|j| _| jd | _d S )Nr   Zembedding_sizezThe hidden size (z6) is not a multiple of the number of attention heads ()g      )r\   r]   r   num_attention_headsr   
ValueErrorrG   rK   attention_head_sizeall_head_sizer   Linearr   r   r   r   Zattention_probs_dropout_probr   attention_dropoutr   rh   ri   r(   r)   r]   Z  s"    

zAlignTextSelfAttention.__init__NFr   r   r   output_attentionsr4   c                 K   s   |j d d }g |d| jR }| ||dd}| ||dd}	| ||dd}
t}| jj	dkrt
| jj	 }|| ||	|
|f| jsdn| j| j|d|\}}|jg |dR   }|r||fn|f}|S )Nr   r   rI   eagerr   )r   r   r   )r   r   r   r   r   r   r   r   rG   Z_attn_implementationr   r   r   r   reshaper   )r:   r   r   r   r   r   r   Zhidden_shapeZquery_statesZ
key_statesZvalue_statesZattention_interfacer   r   outputsr(   r(   r)   rn   o  s0    	
zAlignTextSelfAttention.forward)NNF)r    r!   r"   r]   r$   ro   r   r%   r   r'   rn   rp   r(   r(   ri   r)   r   Y  s      r   c                       s4   e Zd Z fddZejejejdddZ  ZS )AlignTextSelfOutputc                    sB   t    t|j|j| _tj|j|jd| _t|j	| _
d S Nr   )r\   r]   r   r   r   denser   r   r   r   r   rh   ri   r(   r)   r]     s    
zAlignTextSelfOutput.__init__r   input_tensorr4   c                 C   s&   |  |}| |}| || }|S rl   r   r   r   r:   r   r  r(   r(   r)   rn     s    

zAlignTextSelfOutput.forwardr    r!   r"   r]   r$   ro   rn   rp   r(   r(   ri   r)   r     s   r   c                       sT   e Zd Z fddZdd Zd
ejeej eej ee	 e
ej ddd	Z  ZS )AlignTextAttentionc                    s*   t    t|| _t|| _t | _d S rl   )r\   r]   r   r:   r   outputsetpruned_headsrh   ri   r(   r)   r]     s    


zAlignTextAttention.__init__c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   )r   )rB   r   r:   r   r   r  r   r   r   r   r  r   r   union)r:   Zheadsindexr(   r(   r)   prune_heads  s    zAlignTextAttention.prune_headsNFr   c           	      K   s@   | j |f|||d|}| |d |}|f|dd   }|S N)r   r   r   r   r   )r:   r  )	r:   r   r   r   r   r   Zself_outputsattention_outputr   r(   r(   r)   rn     s    zAlignTextAttention.forward)NNF)r    r!   r"   r]   r  r$   ro   r   r%   r   r'   rn   rp   r(   r(   ri   r)   r    s      r  c                       s0   e Zd Z fddZejejdddZ  ZS )AlignTextIntermediatec                    sB   t    t|j|j| _t|jt	r6t
|j | _n|j| _d S rl   )r\   r]   r   r   r   intermediate_sizer   rO   rf   strr	   intermediate_act_fnrh   ri   r(   r)   r]     s
    
zAlignTextIntermediate.__init__r   c                 C   s   |  |}| |}|S rl   )r   r  r   r(   r(   r)   rn     s    

zAlignTextIntermediate.forwardr  r(   r(   ri   r)   r    s   r  c                       s4   e Zd Z fddZejejejdddZ  ZS )AlignTextOutputc                    sB   t    t|j|j| _tj|j|jd| _t	|j
| _d S r   )r\   r]   r   r   r  r   r   r   r   r   r   r   rh   ri   r(   r)   r]     s    
zAlignTextOutput.__init__r   c                 C   s&   |  |}| |}| || }|S rl   r  r  r(   r(   r)   rn     s    

zAlignTextOutput.forwardr  r(   r(   ri   r)   r    s   r  c                       sT   e Zd Z fddZd
ejeej eej ee e	ej dddZ
dd	 Z  ZS )AlignTextLayerc                    s:   t    |j| _d| _t|| _t|| _t|| _	d S r   )
r\   r]   chunk_size_feed_forwardseq_len_dimr  	attentionr  intermediater  r  rh   ri   r(   r)   r]     s    


zAlignTextLayer.__init__NFr   c           
      K   sP   | j |f|||d|}|d }|dd  }t| j| j| j|}	|	f| }|S r  )r  r   feed_forward_chunkr  r  )
r:   r   r   r   r   r   Zself_attention_outputsr  r   layer_outputr(   r(   r)   rn     s     
zAlignTextLayer.forwardc                 C   s   |  |}| ||}|S rl   )r  r  )r:   r  Zintermediate_outputr  r(   r(   r)   r    s    
z!AlignTextLayer.feed_forward_chunk)NNF)r    r!   r"   r]   r$   ro   r   r%   r   r'   rn   r  rp   r(   r(   ri   r)   r    s      r  c                       sd   e Zd Z fddZed	ejeej eej ee	 ee	 ee	 e
eej ef dddZ  ZS )
AlignTextEncoderc                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r(   )r  )r7   r   rR   r(   r)   
<listcomp>  r   z-AlignTextEncoder.__init__.<locals>.<listcomp>F)	r\   r]   rG   r   r   r   num_hidden_layerslayerZgradient_checkpointingrh   ri   rR   r)   r]     s    
 zAlignTextEncoder.__init__NFT)r   r   r   r   r   r   r4   c                 K   s   |rdnd }|rdnd }	t | jD ]\\}
}|r8||f }|d urH||
 nd }|f ||||d|}|d }|r"|	|d f }	q"|r||f }t|||	dS )Nr(   )r   r   r   r   r   r   )r   r   r,   )	enumerater  r   )r:   r   r   r   r   r   r   r   r   Zall_self_attentionsr   Zlayer_moduleZlayer_head_maskZlayer_outputsr(   r(   r)   rn     s0    

zAlignTextEncoder.forward)NNFFT)r    r!   r"   r]   r   r$   ro   r   r%   r   r   r'   r   rn   rp   r(   r(   ri   r)   r    s         r  c                       s0   e Zd Z fddZejejdddZ  ZS )AlignTextPoolerc                    s*   t    t|j|j| _t | _d S rl   )r\   r]   r   r   r   r   ZTanhrg   rh   ri   r(   r)   r]   I  s    
zAlignTextPooler.__init__r   c                 C   s(   |d d df }|  |}| |}|S )Nr   )r   rg   )r:   r   Zfirst_token_tensorpooled_outputr(   r(   r)   rn   N  s    

zAlignTextPooler.forwardr  r(   r(   ri   r)   r  H  s   r  c                   @   s.   e Zd ZU eed< dZdZejdddZ	dS )AlignPreTrainedModelrG   alignT)r   c                 C   s   | j j}t|tjtjfrD|jjjd|d |j	dur|j	j
  nvt|trtj|jj |jj	j
  |jj| j j n:t|tjr|jjjd|d |jdur|jj|j 
  t|tjtjfr|j	j
  |jjd dS )zInitialize the weightsr   )meanstdNg      ?)rG   Zinitializer_rangerO   r   r   r`   weightdataZnormal_rY   Zzero_
AlignModelinitZxavier_uniform_text_projectiontemperatureZfill_temperature_init_valuer   r   r   rb   )r:   r   r$  r(   r(   r)   _init_weights]  s     


z"AlignPreTrainedModel._init_weightsN)
r    r!   r"   r   r&   Zbase_model_prefixsupports_gradient_checkpointingr   Moduler,  r(   r(   r(   r)   r!  W  s   
r!  zJ
@auto_docstring(
    custom_intro="""
    The text model from ALIGN without any head or projection on top.
    """
)
class AlignTextModel(AlignPreTrainedModel):
    config: AlignTextConfig
    _no_split_modules = ["AlignTextEmbeddings"]

    def __init__(self, config: AlignTextConfig, add_pooling_layer: bool = True):
        r"""
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        """
        super().__init__(config)
        self.config = config

        self.embeddings = AlignTextEmbeddings(config)
        self.encoder = AlignTextEncoder(config)

        self.pooler = AlignTextPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        r"""
        Examples:

        ```python
        >>> from transformers import AutoTokenizer, AlignTextModel

        >>> model = AlignTextModel.from_pretrained("kakaobrain/align-base")
        >>> tokenizer = AutoTokenizer.from_pretrained("kakaobrain/align-base")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (CLS token) states
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length), device=device)

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # Expand the 2D attention mask to the broadcastable additive form expected by the attention layers.
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

@auto_docstring(
    custom_intro="""
    The vision model from ALIGN without any head or projection on top.
    """
)
class AlignVisionModel(AlignPreTrainedModel):
    config: AlignVisionConfig
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = False

    def __init__(self, config: AlignVisionConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = AlignVisionEmbeddings(config)
        self.encoder = AlignVisionEncoder(config)

        # Final pooling layer
        if config.pooling_type == "mean":
            self.pooler = nn.AvgPool2d(config.hidden_dim, ceil_mode=True)
        elif config.pooling_type == "max":
            self.pooler = nn.MaxPool2d(config.hidden_dim, ceil_mode=True)
        else:
            raise ValueError(f"config.pooling must be one of ['mean', 'max'] got {config.pooling}")

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.convolution

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
        r"""
        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AlignVisionModel

        >>> model = AlignVisionModel.from_pretrained("kakaobrain/align-base")
        >>> processor = AutoProcessor.from_pretrained("kakaobrain/align-base")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled states
        ```"""
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output = self.embeddings(pixel_values)
        encoder_outputs = self.encoder(
            embedding_output,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        # Apply pooling
        last_hidden_state = encoder_outputs[0]
        pooled_output = self.pooler(last_hidden_state)
        # Reshape (batch_size, projection_dim, 1, 1) -> (batch_size, projection_dim)
        pooled_output = pooled_output.reshape(pooled_output.shape[:2])

        return BaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
        )


@auto_docstring
class AlignModel(AlignPreTrainedModel):
    config: AlignConfig

    def __init__(self, config: AlignConfig):
        super().__init__(config)

        if not isinstance(config.text_config, AlignTextConfig):
            raise TypeError(
                "config.text_config is expected to be of type AlignTextConfig but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(config.vision_config, AlignVisionConfig):
            raise TypeError(
                "config.vision_config is expected to be of type AlignVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        text_config = config.text_config
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.hidden_size

        self.text_model = AlignTextModel(text_config)
        self.vision_model = AlignVisionModel(vision_config)

        self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim)
        self.temperature = nn.Parameter(torch.tensor(self.config.temperature_init_value))

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`AlignTextModel`].

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, AlignModel

        >>> model = AlignModel.from_pretrained("kakaobrain/align-base")
        >>> tokenizer = AutoTokenizer.from_pretrained("kakaobrain/align-base")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
        >>> text_features = model.get_text_features(**inputs)
        ```"""
        # Use ALIGN model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = text_outputs[0][:, 0, :]
        text_features = self.text_projection(last_hidden_state)

        return text_features

    @auto_docstring
    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`AlignVisionModel`].

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AlignModel

        >>> model = AlignModel.from_pretrained("kakaobrain/align-base")
        >>> processor = AutoProcessor.from_pretrained("kakaobrain/align-base")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> image_features = model.get_image_features(**inputs)
        ```Nrk   r   r   r   )rG   r   r9  r=  )r:   rk   r   r   vision_outputsZimage_featuresr(   r(   r)   get_image_features  s    zAlignModel.get_image_features)r   rk   r   r   r   r   r   return_lossr   r   r   r4   c                 C   s   |	dur|	n| j j}	|
dur |
n| j j}
|dur4|n| j j}| j||
dd}| j|||||||	|
dd	}|d }|d dddddf }| |}||jdddd	 }||jdddd	 }t	||
 | j }|
 }d}|rt|}t|||||||d
S )a  
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AlignModel

        >>> model = AlignModel.from_pretrained("kakaobrain/align-base")
        >>> processor = AutoProcessor.from_pretrained("kakaobrain/align-base")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(
        ...     images=image, text=["a photo of a cat", "a photo of a dog"], return_tensors="pt", padding=True
        ... )

        >>> outputs = model(**inputs)
        >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
        >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
        ```"""
        # Use ALIGN model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        image_embeds = vision_outputs[1]
        text_embeds = text_outputs[0][:, 0, :]
        text_embeds = self.text_projection(text_embeds)

        # normalized features
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity as logits
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) / self.temperature
        logits_per_image = logits_per_text.t()

        loss = None
        if return_loss:
            loss = align_loss(logits_per_text)

        return AlignOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )


__all__ = ["AlignPreTrainedModel", "AlignTextModel", "AlignVisionModel", "AlignModel"]