a
    hd                     @   s  d Z ddlZddlZddlmZ ddlmZmZ ddlZddl	m
  mZ ddlZddlm
Z
 ddlmZmZmZ ddlmZ dd	lmZmZ dd
lmZ ddlmZmZ ddlmZmZ ddlm Z  e!e"Z#d-ej$e%e&ej$dddZ'G dd de
j(Z)G dd de
j(Z*G dd de
j(Z+G dd de
j(Z,G dd de
j(Z-G dd de
j(Z.G d d! d!e
j(Z/G d"d# d#e
j(Z0eG d$d% d%eZ1eG d&d' d'e1Z2ed(d)G d*d+ d+e1Z3g d,Z4dS ).zPyTorch PVT model.    N)Iterable)OptionalUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)BaseModelOutputImageClassifierOutput)PreTrainedModel) find_pruneable_heads_and_indicesprune_linear_layer)auto_docstringlogging   )	PvtConfig        F)input	drop_probtrainingreturnc                 C   sd   |dks|s| S d| }| j d fd| jd   }|tj|| j| jd }|  | || }|S )aF  
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    r   r   r   )r   )dtypedevice)shapendimtorchZrandr   r   Zfloor_div)r   r   r   Z	keep_probr   Zrandom_tensoroutput r    `/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/pvt/modeling_pvt.py	drop_path*   s    
r"   c                       sP   e Zd ZdZdee dd fddZejejdddZ	e
d	d
dZ  ZS )PvtDropPathzXDrop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).N)r   r   c                    s   t    || _d S N)super__init__r   )selfr   	__class__r    r!   r&   B   s    
zPvtDropPath.__init__hidden_statesr   c                 C   s   t || j| jS r$   )r"   r   r   r'   r+   r    r    r!   forwardF   s    zPvtDropPath.forward)r   c                 C   s   d| j  S )Nzp=)r   )r'   r    r    r!   
extra_reprI   s    zPvtDropPath.extra_repr)N)__name__
__module____qualname____doc__r   floatr&   r   Tensorr-   strr.   __classcell__r    r    r(   r!   r#   ?   s   r#   c                	       s   e Zd ZdZdeeeee f eeee f eeeed fddZ	e
jeee
jdddZe
jee
jeef d	d
dZ  ZS )PvtPatchEmbeddingsz
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    Fconfig
image_size
patch_sizestridenum_channelshidden_size	cls_tokenc           	         s   t    || _t|tjjr"|n||f}t|tjjr<|n||f}|d |d  |d |d   }|| _|| _|| _	|| _
ttd|r|d n||| _|rttdd|nd | _tj||||d| _tj||jd| _tj|jd| _d S )Nr   r   Zkernel_sizer<   eps)p)r%   r&   r9   
isinstancecollectionsabcr   r:   r;   r=   num_patchesr   	Parameterr   Zrandnposition_embeddingsZzerosr?   Conv2d
projection	LayerNormlayer_norm_eps
layer_normDropouthidden_dropout_probdropout)	r'   r9   r:   r;   r<   r=   r>   r?   rG   r(   r    r!   r&   T   s     

 zPvtPatchEmbeddings.__init__)
embeddingsheightwidthr   c                 C   s|   || }t j s,|| jj| jj kr,| jS |d||ddddd}tj	|||fdd}|dd|| ddd}|S )Nr   r   r	      Zbilinear)sizemode)
r   Zjit
is_tracingr9   r:   rI   reshapepermuteFZinterpolate)r'   rR   rS   rT   rG   Zinterpolated_embeddingsr    r    r!   interpolate_pos_encodingp   s    z+PvtPatchEmbeddings.interpolate_pos_encoding)pixel_valuesr   c                 C   s   |j \}}}}|| jkr td| |}|j ^ }}}|ddd}| |}| jd ur| j|dd}	t	j
|	|fdd}| | jd d dd f ||}
t	j
| jd d d df |
fdd}
n| | j||}
| ||
 }|||fS )NzeMake sure that the channel dimension of the pixel values match with the one set in the configuration.rV   r   rU   dim)r   r=   
ValueErrorrK   flatten	transposerN   r?   expandr   catr]   rI   rQ   )r'   r^   
batch_sizer=   rS   rT   Zpatch_embed_rR   r?   rI   r    r    r!   r-   {   s"    



 &zPvtPatchEmbeddings.forward)F)r/   r0   r1   r2   r   r   intr   boolr&   r   r4   r]   tupler-   r6   r    r    r(   r!   r7   M   s    r7   c                       s8   e Zd Zeed fddZejejdddZ  Z	S )PvtSelfOutput)r9   r>   c                    s*   t    t||| _t|j| _d S r$   )r%   r&   r   LineardenserO   rP   rQ   )r'   r9   r>   r(   r    r!   r&      s    
zPvtSelfOutput.__init__r*   c                 C   s   |  |}| |}|S r$   )rm   rQ   r,   r    r    r!   r-      s    

zPvtSelfOutput.forward)
r/   r0   r1   r   rh   r&   r   r4   r-   r6   r    r    r(   r!   rk      s   rk   c                       s^   e Zd ZdZeeeed fddZeej	dddZ
dej	eeeeej	 d	d
dZ  ZS )PvtEfficientSelfAttentionzxEfficient self-attention mechanism with reduction of the sequence [PvT paper](https://huggingface.co/papers/2102.12122).r9   r>   num_attention_headssequences_reduction_ratioc                    s   t    || _|| _| j| j dkr@td| j d| j dt| j| j | _| j| j | _tj	| j| j|j
d| _tj	| j| j|j
d| _tj	| j| j|j
d| _t|j| _|| _|dkrtj||||d| _tj||jd| _d S )	Nr   zThe hidden size (z6) is not a multiple of the number of attention heads ())biasr   r@   rA   )r%   r&   r>   rp   ra   rh   attention_head_sizeall_head_sizer   rl   Zqkv_biasquerykeyvaluerO   Zattention_probs_dropout_probrQ   rq   rJ   sequence_reductionrL   rM   rN   r'   r9   r>   rp   rq   r(   r    r!   r&      s*    

z"PvtEfficientSelfAttention.__init__r*   c                 C   s6   |  d d | j| jf }||}|ddddS )NrU   r   rV   r   r	   )rW   rp   rt   viewr[   )r'   r+   Z	new_shaper    r    r!   transpose_for_scores   s    
z.PvtEfficientSelfAttention.transpose_for_scoresFr+   rS   rT   output_attentionsr   c                 C   s$  |  | |}| jdkrl|j\}}}|ddd||||}| |}|||dddd}| |}|  | |}	|  | 	|}
t
||	dd}|t| j }tjj|dd}| |}t
||
}|dddd }| d d | jf }||}|r||fn|f}|S )Nr   r   rV   rU   r_   r	   )r|   rv   rq   r   r[   rZ   ry   rN   rw   rx   r   matmulrc   mathsqrtrt   r   
functionalZsoftmaxrQ   
contiguousrW   ru   r{   )r'   r+   rS   rT   r~   Zquery_layerrf   Zseq_lenr=   Z	key_layerZvalue_layerZattention_scoresZattention_probsZcontext_layerZnew_context_layer_shapeoutputsr    r    r!   r-      s&    




z!PvtEfficientSelfAttention.forward)F)r/   r0   r1   r2   r   rh   r3   r&   r   r4   r|   ri   rj   r-   r6   r    r    r(   r!   rn      s   
 rn   c                       sP   e Zd Zeeeed fddZdd Zd
ej	eee
eej	 ddd	Z  ZS )PvtAttentionro   c                    s6   t    t||||d| _t||d| _t | _d S )N)r>   rp   rq   )r>   )r%   r&   rn   r'   rk   r   setpruned_headsrz   r(   r    r!   r&      s    
zPvtAttention.__init__c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   r_   )lenr   r'   rp   rt   r   r   rv   rw   rx   r   rm   ru   union)r'   headsindexr    r    r!   prune_heads   s    zPvtAttention.prune_headsFr}   c                 C   s4   |  ||||}| |d }|f|dd   }|S )Nr   r   )r'   r   )r'   r+   rS   rT   r~   Zself_outputsattention_outputr   r    r    r!   r-     s    zPvtAttention.forward)F)r/   r0   r1   r   rh   r3   r&   r   r   r4   ri   rj   r-   r6   r    r    r(   r!   r      s    
r   c                       sF   e Zd Zdeeee ee d fddZejejdddZ	  Z
S )	PvtFFNN)r9   in_featureshidden_featuresout_featuresc                    sj   t    |d ur|n|}t||| _t|jtrBt|j | _	n|j| _	t||| _
t|j| _d S r$   )r%   r&   r   rl   dense1rD   Z
hidden_actr5   r
   intermediate_act_fndense2rO   rP   rQ   )r'   r9   r   r   r   r(   r    r!   r&     s    
zPvtFFN.__init__r*   c                 C   s6   |  |}| |}| |}| |}| |}|S r$   )r   r   rQ   r   r,   r    r    r!   r-   +  s    




zPvtFFN.forward)NN)r/   r0   r1   r   rh   r   r&   r   r4   r-   r6   r    r    r(   r!   r     s     r   c                       sD   e Zd Zeeeeeed fddZdejeee	dddZ
  ZS )	PvtLayerr9   r>   rp   r"   rq   	mlp_ratioc                    sz   t    tj||jd| _t||||d| _|dkr>t|nt	 | _
tj||jd| _t|| }t|||d| _d S )NrA   ro   r   )r9   r   r   )r%   r&   r   rL   rM   layer_norm_1r   	attentionr#   Identityr"   layer_norm_2rh   r   mlp)r'   r9   r>   rp   r"   rq   r   Zmlp_hidden_sizer(   r    r!   r&   5  s    	
zPvtLayer.__init__Fr+   rS   rT   r~   c           
      C   sn   | j | ||||d}|d }|dd  }| |}|| }| | |}| |}|| }	|	f| }|S )Nr   r   r   )r   r   r"   r   r   )
r'   r+   rS   rT   r~   Zself_attention_outputsr   r   Z
mlp_outputZlayer_outputr    r    r!   r-   K  s    


zPvtLayer.forward)F)r/   r0   r1   r   rh   r3   r&   r   r4   ri   r-   r6   r    r    r(   r!   r   4  s   r   c                       sP   e Zd Zed fddZd	ejee ee ee e	e
ef dddZ  ZS )

PvtEncoderr9   c           	         sx  t    || _tjd|jt|jdd }g }t	|j
D ]r}|t||dkrV|jn| jjd|d   |j| |j| |dkr|jn|j|d  |j| ||j
d kd q<t|| _g }d}t	|j
D ]}g }|dkr||j|d  7 }t	|j| D ]>}|t||j| |j| |||  |j| |j| d q|t| qt|| _tj|jd |jd	| _d S )
Nr   cpu)r   rV   r   r8   r   rU   rA   )r%   r&   r9   r   ZlinspaceZdrop_path_ratesumZdepthstolistrangeZnum_encoder_blocksappendr7   r:   Zpatch_sizesstridesr=   hidden_sizesr   Z
ModuleListpatch_embeddingsr   rp   Zsequence_reduction_ratiosZ
mlp_ratiosblockrL   rM   rN   )	r'   r9   Zdrop_path_decaysrR   iblockscurZlayersjr(   r    r!   r&   c  sJ    
 

zPvtEncoder.__init__FTr^   r~   output_hidden_statesreturn_dictr   c                 C   s  |rdnd }|rdnd }|j d }t| j}|}	tt| j| jD ]\}
\}}||	\}	}}|D ]:}||	|||}|d }	|r||d f }|r`||	f }q`|
|d krB|	|||ddddd }	qB| 	|	}	|r||	f }|st
dd |	||fD S t|	||d	S )
Nr    r   r   rU   r	   rV   c                 s   s   | ]}|d ur|V  qd S r$   r    ).0vr    r    r!   	<genexpr>      z%PvtEncoder.forward.<locals>.<genexpr>Zlast_hidden_stater+   
attentions)r   r   r   	enumeratezipr   rZ   r[   r   rN   rj   r   )r'   r^   r~   r   r   Zall_hidden_statesZall_self_attentionsrf   Z
num_blocksr+   idxZembedding_layerZblock_layerrS   rT   r   Zlayer_outputsr    r    r!   r-     s4    

"

zPvtEncoder.forward)FFT)r/   r0   r1   r   r&   r   FloatTensorr   ri   r   rj   r   r-   r6   r    r    r(   r!   r   b  s   5   
r   c                   @   s4   e Zd ZU eed< dZdZg Zej	ddddZ
dS )PvtPreTrainedModelr9   pvtr^   N)moduler   c                 C   s   | j j}t|tjtjfrHtjj|jj	d|d |j
dur|j
j	  npt|tjrp|j
j	  |jj	d nHt|trtjj|jj	d|d|j_	|jdurtjj|jj	d|d|j_	dS )zInitialize the weightsr   )meanstdNg      ?)r9   Zinitializer_rangerD   r   rl   rJ   initZtrunc_normal_weightdatars   Zzero_rL   Zfill_r7   rI   r?   )r'   r   r   r    r    r!   _init_weights  s(    



z PvtPreTrainedModel._init_weights)r/   r0   r1   r   __annotations__Zbase_model_prefixZmain_input_nameZ_no_split_modulesr   Moduler   r    r    r    r!   r     s
   
r   c                	       s\   e Zd Zed fddZdd Zed
eje	e
 e	e
 e	e
 eeef ddd	Z  ZS )PvtModelr   c                    s(   t  | || _t|| _|   d S r$   )r%   r&   r9   r   encoder	post_initr'   r9   r(   r    r!   r&     s    
zPvtModel.__init__c                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr   layerr   r   )r'   Zheads_to_pruner   r   r    r    r!   _prune_heads  s    zPvtModel._prune_headsNr   c                 C   s~   |d ur|n| j j}|d ur |n| j j}|d ur4|n| j j}| j||||d}|d }|sl|f|dd   S t||j|jdS )Nr^   r~   r   r   r   r   r   )r9   r~   r   use_return_dictr   r   r+   r   )r'   r^   r~   r   r   Zencoder_outputssequence_outputr    r    r!   r-     s$    zPvtModel.forward)NNN)r/   r0   r1   r   r&   r   r   r   r   r   ri   r   rj   r   r-   r6   r    r    r(   r!   r     s   
   
r   z
    Pvt Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
    the [CLS] token) e.g. for ImageNet.
    )Zcustom_introc                
       sb   e Zd Zedd fddZedeej eej ee	 ee	 ee	 e
eef dddZ  ZS )	PvtForImageClassificationN)r9   r   c                    sR   t  | |j| _t|| _|jdkr<t|jd |jnt | _	| 
  d S )Nr   rU   )r%   r&   
num_labelsr   r   r   rl   r   r   
classifierr   r   r(   r    r!   r&     s    
$z"PvtForImageClassification.__init__)r^   labelsr~   r   r   r   c                 C   sz  |dur|n| j j}| j||||d}|d }| |dddddf }d}	|dur6| j jdu r| jdkrxd| j _n4| jdkr|jtjks|jtj	krd| j _nd| j _| j jdkrt
 }
| jdkr|
| | }	n
|
||}	nN| j jdkrt }
|
|d| j|d}	n| j jdkr6t }
|
||}	|sf|f|dd  }|	durb|	f| S |S t|	||j|jd	S )
a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Nr   r   r   Z
regressionZsingle_label_classificationZmulti_label_classificationrU   )losslogitsr+   r   )r9   r   r   r   Zproblem_typer   r   r   longrh   r   Zsqueezer   r{   r   r   r+   r   )r'   r^   r   r~   r   r   r   r   r   r   Zloss_fctr   r    r    r!   r-   %  sJ    


"


z!PvtForImageClassification.forward)NNNN)r/   r0   r1   r   r&   r   r   r   r4   ri   r   rj   r   r-   r6   r    r    r(   r!   r     s       
r   )r   r   r   )r   F)5r2   rE   r   collections.abcr   typingr   r   r   Ztorch.nn.functionalr   r   r\   Ztorch.utils.checkpointZtorch.nnr   r   r   Zactivationsr
   Zmodeling_outputsr   r   Zmodeling_utilsr   Zpytorch_utilsr   r   utilsr   r   Zconfiguration_pvtr   Z
get_loggerr/   loggerr4   r3   ri   r"   r   r#   r7   rk   rn   r   r   r   r   r   r   r   __all__r    r    r    r!   <module>   sD   
DR*.Y 3N