a
    h*e                     @   s,  d Z ddlZddlmZ ddlmZmZ ddlZddl	Zddlm
Z
 ddlmZmZmZ ddlmZmZ dd	lmZ dd
lmZmZ ddlmZmZ ddlmZ eeZeeddG dd deZd;ej e!e"ej dddZ#G dd de
j$Z%G dd de
j$Z&G dd de
j$Z'G dd de
j$Z(G dd  d e
j$Z)G d!d" d"e
j$Z*G d#d$ d$e
j$Z+G d%d& d&e
j$Z,G d'd( d(e
j$Z-G d)d* d*e
j$Z.G d+d, d,e
j$Z/G d-d. d.e
j$Z0G d/d0 d0e
j$Z1G d1d2 d2e
j$Z2eG d3d4 d4eZ3eG d5d6 d6e3Z4ed7dG d8d9 d9e3Z5g d:Z6dS )<zPyTorch CvT model.    N)	dataclass)OptionalUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )$ImageClassifierOutputWithNoAttentionModelOutput)PreTrainedModel) find_pruneable_heads_and_indicesprune_linear_layer)auto_docstringlogging   )	CvtConfigzV
    Base class for model's outputs, with potential hidden states and attentions.
    )Zcustom_introc                   @   sP   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeejdf  ed< dS )BaseModelOutputWithCLSTokenz
    cls_token_value (`torch.FloatTensor` of shape `(batch_size, 1, hidden_size)`):
        Classification token at the output of the last layer of the model.
    Nlast_hidden_statecls_token_value.hidden_states)__name__
__module____qualname____doc__r   r   torchZFloatTensor__annotations__r   r   tuple r   r   `/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/cvt/modeling_cvt.pyr   $   s   
r           F)input	drop_probtrainingreturnc                 C   sd   |dks|s| S d| }| j d fd| jd   }|tj|| j| jd }|  | || }|S )aF  
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    r    r   r   )r   )dtypedevice)shapendimr   Zrandr%   r&   Zfloor_div)r!   r"   r#   Z	keep_probr'   Zrandom_tensoroutputr   r   r   	drop_path6   s    
r+   c                       sP   e Zd ZdZdee dd fddZejejdddZ	e
d	d
dZ  ZS )CvtDropPathzXDrop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).N)r"   r$   c                    s   t    || _d S N)super__init__r"   )selfr"   	__class__r   r   r/   N   s    
zCvtDropPath.__init__)r   r$   c                 C   s   t || j| jS r-   )r+   r"   r#   )r0   r   r   r   r   forwardR   s    zCvtDropPath.forward)r$   c                 C   s   d| j  S )Nzp=r"   )r0   r   r   r   
extra_reprU   s    zCvtDropPath.extra_repr)N)r   r   r   r   r   floatr/   r   Tensorr3   strr5   __classcell__r   r   r1   r   r,   K   s   r,   c                       s(   e Zd ZdZ fddZdd Z  ZS )CvtEmbeddingsz'
    Construct the CvT embeddings.
    c                    s.   t    t|||||d| _t|| _d S )N)
patch_sizenum_channels	embed_dimstridepadding)r.   r/   CvtConvEmbeddingsconvolution_embeddingsr   Dropoutdropout)r0   r;   r<   r=   r>   r?   dropout_rater1   r   r   r/   ^   s
    

zCvtEmbeddings.__init__c                 C   s   |  |}| |}|S r-   )rA   rC   )r0   pixel_valueshidden_stater   r   r   r3   e   s    

zCvtEmbeddings.forwardr   r   r   r   r/   r3   r9   r   r   r1   r   r:   Y   s   r:   c                       s(   e Zd ZdZ fddZdd Z  ZS )r@   z"
    Image to Conv Embedding.
    c                    sP   t    t|tjjr|n||f}|| _tj|||||d| _	t
|| _d S )N)kernel_sizer>   r?   )r.   r/   
isinstancecollectionsabcIterabler;   r   Conv2d
projection	LayerNormnormalization)r0   r;   r<   r=   r>   r?   r1   r   r   r/   p   s
    
zCvtConvEmbeddings.__init__c                 C   sf   |  |}|j\}}}}|| }||||ddd}| jrH| |}|ddd||||}|S Nr      r   )rN   r'   viewpermuterP   )r0   rE   
batch_sizer<   heightwidthhidden_sizer   r   r   r3   w   s    

zCvtConvEmbeddings.forwardrG   r   r   r1   r   r@   k   s   r@   c                       s$   e Zd Z fddZdd Z  ZS )CvtSelfAttentionConvProjectionc              	      s4   t    tj|||||d|d| _t|| _d S )NF)rH   r?   r>   biasgroups)r.   r/   r   rM   convolutionZBatchNorm2drP   )r0   r=   rH   r?   r>   r1   r   r   r/      s    
	z'CvtSelfAttentionConvProjection.__init__c                 C   s   |  |}| |}|S r-   )r\   rP   r0   rF   r   r   r   r3      s    

z&CvtSelfAttentionConvProjection.forwardr   r   r   r/   r3   r9   r   r   r1   r   rY      s   rY   c                   @   s   e Zd Zdd ZdS ) CvtSelfAttentionLinearProjectionc                 C   s2   |j \}}}}|| }||||ddd}|S rQ   )r'   rS   rT   )r0   rF   rU   r<   rV   rW   rX   r   r   r   r3      s    z(CvtSelfAttentionLinearProjection.forwardN)r   r   r   r3   r   r   r   r   r_      s   r_   c                       s&   e Zd Zd fdd	Zdd Z  ZS )CvtSelfAttentionProjectiondw_bnc                    s.   t    |dkr"t||||| _t | _d S )Nra   )r.   r/   rY   convolution_projectionr_   linear_projection)r0   r=   rH   r?   r>   projection_methodr1   r   r   r/      s    
z#CvtSelfAttentionProjection.__init__c                 C   s   |  |}| |}|S r-   )rb   rc   r]   r   r   r   r3      s    

z"CvtSelfAttentionProjection.forward)ra   r^   r   r   r1   r   r`      s   r`   c                       s.   e Zd Zd fdd	Zdd Zdd Z  ZS )	CvtSelfAttentionTc                    s   t    |d | _|| _|| _|| _t|||||dkr<dn|d| _t|||||d| _t|||||d| _	t
j|||	d| _t
j|||	d| _t
j|||	d| _t
|
| _d S )Ng      ZavgZlinear)rd   )rZ   )r.   r/   scalewith_cls_tokenr=   	num_headsr`   convolution_projection_queryconvolution_projection_keyconvolution_projection_valuer   Linearprojection_queryprojection_keyprojection_valuerB   rC   )r0   rh   r=   rH   	padding_q
padding_kvstride_q	stride_kvqkv_projection_methodqkv_biasattention_drop_raterg   kwargsr1   r   r   r/      s,    



zCvtSelfAttention.__init__c                 C   s6   |j \}}}| j| j }|||| j|ddddS )Nr   rR   r   r	   )r'   r=   rh   rS   rT   )r0   rF   rU   rX   _head_dimr   r   r   "rearrange_for_multi_head_attention   s    z3CvtSelfAttention.rearrange_for_multi_head_attentionc                 C   sT  | j r t|d|| gd\}}|j\}}}|ddd||||}| |}| |}	| |}
| j rtj	||	fdd}	tj	||fdd}tj	||
fdd}
| j
| j }| | |	}	| | |}| | |
}
td|	|g| j }tjjj|dd}| |}td||
g}|j\}}}}|dddd ||| j| }|S )	Nr   r   rR   dimzbhlk,bhtk->bhltzbhlt,bhtv->bhlvr	   )rg   r   splitr'   rT   rS   rj   ri   rk   catr=   rh   rz   rm   rn   ro   Zeinsumrf   r   Z
functionalZsoftmaxrC   
contiguous)r0   rF   rV   rW   	cls_tokenrU   rX   r<   keyqueryvaluery   Zattention_scoreZattention_probscontextrx   r   r   r   r3      s,    



$zCvtSelfAttention.forward)T)r   r   r   r/   rz   r3   r9   r   r   r1   r   re      s    )re   c                       s(   e Zd ZdZ fddZdd Z  ZS )CvtSelfOutputz
    The residual connection is defined in CvtLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    c                    s(   t    t||| _t|| _d S r-   )r.   r/   r   rl   denserB   rC   )r0   r=   	drop_rater1   r   r   r/     s    
zCvtSelfOutput.__init__c                 C   s   |  |}| |}|S r-   r   rC   r0   rF   Zinput_tensorr   r   r   r3   
  s    

zCvtSelfOutput.forwardrG   r   r   r1   r   r      s   r   c                       s.   e Zd Zd fdd	Zdd Zdd Z  ZS )	CvtAttentionTc                    s@   t    t|||||||||	|
|| _t||| _t | _d S r-   )r.   r/   re   	attentionr   r*   setpruned_heads)r0   rh   r=   rH   rp   rq   rr   rs   rt   ru   rv   r   rg   r1   r   r   r/     s     
zCvtAttention.__init__c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   r{   )lenr   r   Znum_attention_headsZattention_head_sizer   r   r   r   r   r*   r   Zall_head_sizeunion)r0   headsindexr   r   r   prune_heads1  s    zCvtAttention.prune_headsc                 C   s   |  |||}| ||}|S r-   )r   r*   )r0   rF   rV   rW   Zself_outputattention_outputr   r   r   r3   C  s    zCvtAttention.forward)T)r   r   r   r/   r   r3   r9   r   r   r1   r   r     s     r   c                       s$   e Zd Z fddZdd Z  ZS )CvtIntermediatec                    s.   t    t|t|| | _t | _d S r-   )r.   r/   r   rl   intr   ZGELU
activation)r0   r=   	mlp_ratior1   r   r   r/   J  s    
zCvtIntermediate.__init__c                 C   s   |  |}| |}|S r-   )r   r   r]   r   r   r   r3   O  s    

zCvtIntermediate.forwardr^   r   r   r1   r   r   I  s   r   c                       s$   e Zd Z fddZdd Z  ZS )	CvtOutputc                    s0   t    tt|| || _t|| _d S r-   )r.   r/   r   rl   r   r   rB   rC   )r0   r=   r   r   r1   r   r   r/   V  s    
zCvtOutput.__init__c                 C   s    |  |}| |}|| }|S r-   r   r   r   r   r   r3   [  s    

zCvtOutput.forwardr^   r   r   r1   r   r   U  s   r   c                       s*   e Zd ZdZd fdd	Zdd Z  ZS )CvtLayerzb
    CvtLayer composed by attention layers, normalization and multi-layer perceptrons (mlps).
    Tc                    s|   t    t|||||||||	|
||| _t||| _t|||| _|dkrVt|dnt	
 | _t	|| _t	|| _d S )Nr    r4   )r.   r/   r   r   r   intermediater   r*   r,   r   Identityr+   rO   layernorm_beforelayernorm_after)r0   rh   r=   rH   rp   rq   rr   rs   rt   ru   rv   r   r   drop_path_raterg   r1   r   r   r/   g  s(    
zCvtLayer.__init__c                 C   sX   |  | |||}|}| |}|| }| |}| |}| ||}| |}|S r-   )r   r   r+   r   r   r*   )r0   rF   rV   rW   Zself_attention_outputr   Zlayer_outputr   r   r   r3     s    



zCvtLayer.forward)TrG   r   r   r1   r   r   b  s    'r   c                       s$   e Zd Z fddZdd Z  ZS )CvtStagec                    s   t     _|_jjj rBttddjj	d _t
 jj  jj jdkrh jn j	jd   j	j  jj  jj d_dd tjd jj  j| ddD tj fd	dt jj D  _d S )
Nr   r}   r   )r;   r>   r<   r=   r?   rD   c                 S   s   g | ]}|  qS r   )item).0xr   r   r   
<listcomp>  s   z%CvtStage.__init__.<locals>.<listcomp>cpu)r&   c                    s   g | ]}t  jj  jj  jj  jj  jj  jj  jj  j	j  j
j  jj  jj j  jj  jj d qS ))rh   r=   rH   rp   rq   rs   rr   rt   ru   rv   r   r   r   rg   )r   rh   stager=   Z
kernel_qkvrp   rq   rs   rr   rt   ru   rv   r   r   r   )r   rx   configZdrop_path_ratesr0   r   r   r     s"   












)r.   r/   r   r   r   r   	Parameterr   Zrandnr=   r:   Zpatch_sizesZpatch_strider<   Zpatch_paddingr   	embeddingZlinspacer   depthZ
Sequentialrangelayers)r0   r   r   r1   r   r   r/     s*    





	zCvtStage.__init__c           	      C   s   d }|  |}|j\}}}}||||| ddd}| jj| j rh| j|dd}tj	||fdd}| j
D ]}||||}|}qn| jj| j rt|d|| gd\}}|ddd||||}||fS )Nr   rR   r   r}   r{   )r   r'   rS   rT   r   r   r   expandr   r   r   r~   )	r0   rF   r   rU   r<   rV   rW   layerZlayer_outputsr   r   r   r3     s    

zCvtStage.forwardr^   r   r   r1   r   r     s   *r   c                       s&   e Zd Z fddZdddZ  ZS )
CvtEncoderc                    sF   t    || _tg | _tt|jD ]}| j	t
|| q*d S r-   )r.   r/   r   r   Z
ModuleListstagesr   r   r   appendr   )r0   r   Z	stage_idxr1   r   r   r/     s
    
zCvtEncoder.__init__FTc           	      C   sl   |rdnd }|}d }t | jD ]"\}}||\}}|r||f }q|s^tdd |||fD S t|||dS )Nr   c                 s   s   | ]}|d ur|V  qd S r-   r   )r   vr   r   r   	<genexpr>      z%CvtEncoder.forward.<locals>.<genexpr>r   r   r   )	enumerater   r   r   )	r0   rE   output_hidden_statesreturn_dictZall_hidden_statesrF   r   rx   Zstage_moduler   r   r   r3     s    zCvtEncoder.forward)FTr^   r   r   r1   r   r     s   r   c                   @   s,   e Zd ZU eed< dZdZdgZdd ZdS )CvtPreTrainedModelr   cvtrE   r   c                 C   s   t |tjtjfrHtjj|jjd| jj	d|j_|j
dur|j
j  n^t |tjrp|j
j  |jjd n6t |tr| jj|j rtjj|jjd| jj	d|j_dS )zInitialize the weightsr    )meanZstdNg      ?)rI   r   rl   rM   initZtrunc_normal_weightdatar   Zinitializer_rangerZ   Zzero_rO   Zfill_r   r   r   )r0   moduler   r   r   _init_weights  s    

z CvtPreTrainedModel._init_weightsN)	r   r   r   r   r   Zbase_model_prefixZmain_input_nameZ_no_split_modulesr   r   r   r   r   r     s
   
r   c                       sV   e Zd Zd
 fdd	Zdd Zedeej ee	 ee	 e
eef ddd	Z  ZS )CvtModelTc                    s(   t  | || _t|| _|   dS )zv
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        N)r.   r/   r   r   encoder	post_init)r0   r   add_pooling_layerr1   r   r   r/     s    
zCvtModel.__init__c                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr   r   r   r   )r0   Zheads_to_pruner   r   r   r   r   _prune_heads"  s    zCvtModel._prune_headsN)rE   r   r   r$   c                 C   sx   |d ur|n| j j}|d ur |n| j j}|d u r8td| j|||d}|d }|sf|f|dd   S t||j|jdS )Nz You have to specify pixel_valuesr   r   r   r   r   )r   r   use_return_dict
ValueErrorr   r   r   r   )r0   rE   r   r   Zencoder_outputssequence_outputr   r   r   r3   *  s$    zCvtModel.forward)T)NNN)r   r   r   r/   r   r   r   r   r7   boolr   r   r   r3   r9   r   r   r1   r   r     s   
   
r   z
    Cvt Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
    the [CLS] token) e.g. for ImageNet.
    c                	       sT   e Zd Z fddZedeej eej ee ee e	e
ef dddZ  ZS )CvtForImageClassificationc                    sh   t  | |j| _t|dd| _t|jd | _|jdkrRt	|jd |jnt
 | _|   d S )NF)r   r}   r   )r.   r/   
num_labelsr   r   r   rO   r=   	layernormrl   r   
classifierr   )r0   r   r1   r   r   r/   Q  s    $z"CvtForImageClassification.__init__N)rE   labelsr   r   r$   c                 C   s  |dur|n| j j}| j|||d}|d }|d }| j jd rL| |}n4|j\}}	}
}|||	|
| ddd}| |}|jdd}| 	|}d}|dur| j j
du r| j jdkrd| j _
n6| j jdkr|jtjks|jtjkrd	| j _
nd
| j _
| j j
dkr>t }| j jdkr2|| | }n
|||}nP| j j
d	krpt }||d| j j|d}n| j j
d
krt }|||}|s|f|dd  }|dur|f| S |S t|||jdS )a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Nr   r   r   r}   rR   r{   Z
regressionZsingle_label_classificationZmulti_label_classification)losslogitsr   )r   r   r   r   r   r'   rS   rT   r   r   Zproblem_typer   r%   r   longr   r   Zsqueezer   r   r
   r   )r0   rE   r   r   r   outputsr   r   rU   r<   rV   rW   Zsequence_output_meanr   r   Zloss_fctr*   r   r   r   r3   _  sL    



$

z!CvtForImageClassification.forward)NNNN)r   r   r   r/   r   r   r   r7   r   r   r   r
   r3   r9   r   r   r1   r   r   J  s       
r   )r   r   r   )r    F)7r   collections.abcrJ   dataclassesr   typingr   r   r   Ztorch.utils.checkpointr   Ztorch.nnr   r   r   Zmodeling_outputsr
   r   Zmodeling_utilsr   Zpytorch_utilsr   r   utilsr   r   Zconfiguration_cvtr   Z
get_loggerr   loggerr   r7   r6   r   r+   Moduler,   r:   r@   rY   r_   r`   re   r   r   r   r   r   r   r   r   r   r   __all__r   r   r   r   <module>   sT   
	Q9B?3O