a
    h                     @   s  d Z ddlZddlZddlZddlmZ ddlmZm	Z	 ddl
Z
ddlZ
ddl
mZmZ ddlmZ ddlmZ dd	lmZ dd
lmZ ddlmZmZmZ ddlmZmZmZmZ ddlm Z  ddl!m"Z" e#e$Z%eeddG dd deZ&eeddG dd deZ'eeddG dd deZ(eeddG dd deZ)dd Z*dd  Z+dLe
je,e-e
jd#d$d%Z.G d&d' d'ej/Z0G d(d) d)ej/Z1G d*d+ d+ej/Z2G d,d- d-ej/Z3G d.d/ d/ej/Z4G d0d1 d1ej/Z5G d2d3 d3ej/Z6G d4d5 d5ej/Z7G d6d7 d7ej/Z8G d8d9 d9ej/Z9G d:d; d;eZ:G d<d= d=ej/Z;eG d>d? d?eZ<eG d@dA dAe<Z=edBdG dCdD dDe<Z>edEdG dFdG dGe<Z?edHdG dIdJ dJe<e Z@g dKZAdS )Mz!PyTorch Swinv2 Transformer model.    N)	dataclass)OptionalUnion)Tensornn   )ACT2FN)GradientCheckpointingLayer)BackboneOutput)PreTrainedModel) find_pruneable_heads_and_indicesmeshgridprune_linear_layer)ModelOutputauto_docstringlogging	torch_int)BackboneMixin   )Swinv2ConfigzP
    Swinv2 encoder's outputs, with potential hidden states and attentions.
    )Zcustom_introc                   @   sr   e Zd ZU dZdZeej ed< dZ	ee
ejdf  ed< dZee
ejdf  ed< dZee
ejdf  ed< dS )Swinv2EncoderOutputa  
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    Nlast_hidden_state.hidden_states
attentionsreshaped_hidden_states)__name__
__module____qualname____doc__r   r   torchFloatTensor__annotations__r   tupler   r    r#   r#   f/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/swinv2/modeling_swinv2.pyr   +   s
   
	r   zX
    Swinv2 model's outputs that also contains a pooling of the last hidden states.
    c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeejdf  ed< dZeeejdf  ed< dZeeejdf  ed< dS )	Swinv2ModelOutputa  
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
        Average pooling of the last layer hidden-state.
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    Nr   pooler_output.r   r   r   )r   r   r   r   r   r   r   r    r!   r&   r   r"   r   r   r#   r#   r#   r$   r%   B   s   
r%   z,
    Swinv2 masked image model outputs.
    c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeejdf  ed< dZeeejdf  ed< dZeeejdf  ed< ed	d
 ZdS )Swinv2MaskedImageModelingOutputa  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
        Masked image modeling (MLM) loss.
    reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
        Reconstructed pixel values.
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    Nlossreconstruction.r   r   r   c                 C   s   t dt | jS )Nzlogits attribute is deprecated and will be removed in version 5 of Transformers. Please use the reconstruction attribute to retrieve the final output instead.)warningswarnFutureWarningr)   selfr#   r#   r$   logitsw   s
    z&Swinv2MaskedImageModelingOutput.logits)r   r   r   r   r(   r   r   r    r!   r)   r   r"   r   r   propertyr/   r#   r#   r#   r$   r'   \   s   
r'   z2
    Swinv2 outputs for image classification.
    c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeejdf  ed< dZeeejdf  ed< dZeeejdf  ed< dS )	Swinv2ImageClassifierOutputa7  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification (or regression if config.num_labels==1) loss.
    logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Classification (or regression if config.num_labels==1) scores (before SoftMax).
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    Nr(   r/   .r   r   r   )r   r   r   r   r(   r   r   r    r!   r/   r   r"   r   r   r#   r#   r#   r$   r1      s   
r1   c                 C   sR   | j \}}}}| ||| ||| ||} | dddddd d|||}|S )z2
    Partitions the given input into windows.
    r   r   r            shapeviewpermute
contiguous)input_featurewindow_size
batch_sizeheightwidthnum_channelswindowsr#   r#   r$   window_partition   s    $rB   c                 C   sN   | j d }| d|| || |||} | dddddd d|||} | S )z?
    Merges windows to produce higher resolution features.
    r5   r   r   r   r2   r3   r4   r6   )rA   r<   r>   r?   r@   r#   r#   r$   window_reverse   s    
$rC           F)input	drop_probtrainingreturnc                 C   sd   |dks|s| S d| }| j d fd| jd   }|tj|| j| jd }|  | || }|S )aF  
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    rD   r   r   )r   )dtypedevice)r7   ndimr   ZrandrI   rJ   Zfloor_div)rE   rF   rG   Z	keep_probr7   Zrandom_tensoroutputr#   r#   r$   	drop_path   s    
rN   c                       sP   e Zd ZdZdee dd fddZejejdddZ	e
d	d
dZ  ZS )Swinv2DropPathzXDrop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).N)rF   rH   c                    s   t    || _d S N)super__init__rF   )r.   rF   	__class__r#   r$   rR      s    
zSwinv2DropPath.__init__r   rH   c                 C   s   t || j| jS rP   )rN   rF   rG   r.   r   r#   r#   r$   forward   s    zSwinv2DropPath.forwardrH   c                 C   s   d| j  S )Nzp=)rF   r-   r#   r#   r$   
extra_repr   s    zSwinv2DropPath.extra_repr)N)r   r   r   r   r   floatrR   r   r   rW   strrY   __classcell__r#   r#   rS   r$   rO      s   rO   c                       sb   e Zd ZdZd fdd	ZejeeejdddZde	ej
 e	ej eeej d	d
dZ  ZS )Swinv2EmbeddingszW
    Construct the patch and position embeddings. Optionally, also the mask token.
    Fc                    s   t    t|| _| jj}| jj| _|r@tt	
dd|jnd | _|jrjtt	
d|d |j| _nd | _t|j| _t|j| _|j| _|| _d S )Nr   )rQ   rR   Swinv2PatchEmbeddingspatch_embeddingsnum_patches	grid_size
patch_gridr   	Parameterr   zeros	embed_dim
mask_tokenZuse_absolute_embeddingsposition_embeddings	LayerNormnormDropouthidden_dropout_probdropout
patch_sizeconfig)r.   rn   use_mask_tokenr`   rS   r#   r$   rR      s    


 zSwinv2Embeddings.__init__)
embeddingsr>   r?   rH   c                 C   s   |j d d }| jj d d }tj s>||kr>||kr>| jS | jddddf }| jddddf }|j d }|| j }	|| j }
t|d }|d|||}|dddd}t	j
j||	|
fdd	d
}|dddddd|}tj||fddS )a   
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        r   Nr5         ?r   r   r2   ZbicubicF)sizemodeZalign_cornersdim)r7   rg   r   Zjit
is_tracingrm   r   reshaper9   r   
functionalZinterpolater8   cat)r.   rp   r>   r?   r`   Znum_positionsZclass_pos_embedZpatch_pos_embedru   Z
new_heightZ	new_widthZsqrt_num_positionsr#   r#   r$   interpolate_pos_encoding   s(    



z)Swinv2Embeddings.interpolate_pos_encodingN)pixel_valuesbool_masked_posrz   rH   c                 C   s   |j \}}}}| |\}}	| |}| \}
}}|d urp| j|
|d}|d|}|d|  ||  }| jd ur|r|| 	||| }n
|| j }| 
|}||	fS )Nr5         ?)r7   r_   ri   rr   rf   expand	unsqueezeZtype_asrg   rz   rl   )r.   r{   r|   rz   _r@   r>   r?   rp   output_dimensionsr=   Zseq_lenZmask_tokensmaskr#   r#   r$   rW     s    



zSwinv2Embeddings.forward)F)NF)r   r   r   r   rR   r   r   intrz   r   r    
BoolTensorboolr"   rW   r\   r#   r#   rS   r$   r]      s   +  r]   c                       sL   e Zd ZdZ fddZdd Zeej e	ej
e	e f dddZ  ZS )	r^   z
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    c                    s   t    |j|j }}|j|j }}t|tjj	r8|n||f}t|tjj	rR|n||f}|d |d  |d |d   }|| _|| _|| _|| _
|d |d  |d |d  f| _tj||||d| _d S )Nr   r   )kernel_sizeZstride)rQ   rR   
image_sizerm   r@   re   
isinstancecollectionsabcIterabler`   ra   r   Conv2d
projection)r.   rn   r   rm   r@   hidden_sizer`   rS   r#   r$   rR   >  s    
 "zSwinv2PatchEmbeddings.__init__c                 C   s   || j d  dkr<d| j d || j d   f}tj||}|| j d  dkr|ddd| j d || j d   f}tj||}|S )Nr   r   )rm   r   rx   pad)r.   r{   r>   r?   
pad_valuesr#   r#   r$   	maybe_padM  s     zSwinv2PatchEmbeddings.maybe_pad)r{   rH   c                 C   sV   |j \}}}}| |||}| |}|j \}}}}||f}|ddd}||fS )Nr2   r   )r7   r   r   flatten	transpose)r.   r{   r   r@   r>   r?   rp   r   r#   r#   r$   rW   V  s    
zSwinv2PatchEmbeddings.forward)r   r   r   r   rR   r   r   r   r    r"   r   r   rW   r\   r#   r#   rS   r$   r^   7  s   	r^   c                       s^   e Zd ZdZejfee eejdd fddZ	dd Z
ejeeef ejdd	d
Z  ZS )Swinv2PatchMerginga'  
    Patch Merging Layer.

    Args:
        input_resolution (`tuple[int]`):
            Resolution of input feature.
        dim (`int`):
            Number of input channels.
        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
            Normalization layer class.
    N)input_resolutionru   
norm_layerrH   c                    sB   t    || _|| _tjd| d| dd| _|d| | _d S )Nr3   r2   Fbias)rQ   rR   r   ru   r   Linear	reductionri   )r.   r   ru   r   rS   r#   r$   rR   o  s
    
zSwinv2PatchMerging.__init__c                 C   sF   |d dkp|d dk}|rBddd|d d|d f}t j||}|S )Nr2   r   r   )r   rx   r   )r.   r;   r>   r?   Z
should_padr   r#   r#   r$   r   v  s
    zSwinv2PatchMerging.maybe_pad)r;   input_dimensionsrH   c                 C   s   |\}}|j \}}}|||||}| |||}|d d dd ddd dd d f }|d d dd ddd dd d f }	|d d dd ddd dd d f }
|d d dd ddd dd d f }t||	|
|gd}||dd| }| |}| |}|S )Nr   r2   r   r5   r3   )r7   r8   r   r   ry   r   ri   )r.   r;   r   r>   r?   r=   ru   r@   Zinput_feature_0Zinput_feature_1Zinput_feature_2Zinput_feature_3r#   r#   r$   rW   ~  s    $$$$

zSwinv2PatchMerging.forward)r   r   r   r   r   rh   r"   r   ModulerR   r   r   r   rW   r\   r#   r#   rS   r$   r   b  s   $r   c                       sT   e Zd Zddgf fdd	Zd	ejeej eej ee e	ej dddZ
  ZS )
Swinv2SelfAttentionr   c              
      s  t    || dkr,td| d| d|| _t|| | _| j| j | _t|tj	j
r`|n||f| _|| _ttdt|ddf | _ttjddd	d
tjd	dtjd|dd
| _tj| jd d  | jd tjd }tj| jd d  | jd tjd }tt||gddddd d}|d dkr|d d d d d d df  |d d   < |d d d d d d df  |d d   < nf|dkr
|d d d d d d df  | jd d   < |d d d d d d df  | jd d   < |d9 }t|t t!|d  t" d }|#t$| j% j&}| j'd|dd t| jd }	t| jd }
tt|	|
gdd}t(|d}|d d d d d f |d d d d d f  }|ddd }|d d d d df  | jd d 7  < |d d d d df  | jd d 7  < |d d d d df  d| jd  d 9  < |)d}| j'd|dd tj| j| j|j*d
| _+tj| j| jdd
| _,tj| j| j|j*d
| _-t.|j/| _0d S )Nr   zThe hidden size (z6) is not a multiple of the number of attention heads ()
   r   r2   i   Tr   )ZinplaceFrI   Zij)Zindexing   r}   relative_coords_table)
persistentr5   relative_position_index)1rQ   rR   
ValueErrornum_attention_headsr   attention_head_sizeall_head_sizer   r   r   r   r<   pretrained_window_sizer   rc   r   logZoneslogit_scale
Sequentialr   ZReLUcontinuous_position_bias_mlpZarangeZint64rZ   stackr   r9   r:   r   signlog2absmathtonext
parametersrI   Zregister_bufferr   sumZqkv_biasquerykeyvaluerj   attention_probs_dropout_probrl   )r.   rn   ru   	num_headsr<   r   Zrelative_coords_hZrelative_coords_wr   Zcoords_hZcoords_wZcoordsZcoords_flattenZrelative_coordsr   rS   r#   r$   rR     sb    
"&((,.
..&,((,
zSwinv2SelfAttention.__init__NFr   attention_mask	head_maskoutput_attentionsrH   c                 C   s"  |j \}}}| ||d| j| jdd}| ||d| j| jdd}	| ||d| j| jdd}
tj	j
|ddtj	j
|	dddd }tj| jtdd }|| }| | jd| j}|| jd | jd | jd  | jd | jd  d}|ddd }d	t| }||d }|d ur|j d }||| || j|||dd }||dd }|d| j||}tj	j|dd}| |}|d ur|| }t||
}|dddd
 }| d d | jf }||}|r||fn|f}|S )Nr5   r   r2   rt   g      Y@)maxr      r   )r7   r   r8   r   r   r   r   r   r   rx   	normalizer   clampr   r   r   expr   r   r   r<   r9   r:   Zsigmoidr   Zsoftmaxrl   matmulrr   r   )r.   r   r   r   r   r=   ru   r@   Zquery_layerZ	key_layerZvalue_layerZattention_scoresr   Zrelative_position_bias_tableZrelative_position_biasZ
mask_shapeZattention_probsZcontext_layerZnew_context_layer_shapeoutputsr#   r#   r$   rW     sl    


&




zSwinv2SelfAttention.forward)NNF)r   r   r   rR   r   r   r   r    r   r"   rW   r\   r#   r#   rS   r$   r     s   @   r   c                       s4   e Zd Z fddZejejejdddZ  ZS )Swinv2SelfOutputc                    s*   t    t||| _t|j| _d S rP   )rQ   rR   r   r   denserj   r   rl   r.   rn   ru   rS   r#   r$   rR      s    
zSwinv2SelfOutput.__init__)r   input_tensorrH   c                 C   s   |  |}| |}|S rP   r   rl   )r.   r   r   r#   r#   r$   rW   %  s    

zSwinv2SelfOutput.forwardr   r   r   rR   r   r   rW   r\   r#   r#   rS   r$   r     s   r   c                       sV   e Zd Zd fdd	Zdd Zdejeej eej ee	 e
ej dd	d
Z  ZS )Swinv2Attentionr   c                    sL   t    t||||t|tjjr&|n||fd| _t||| _	t
 | _d S )Nrn   ru   r   r<   r   )rQ   rR   r   r   r   r   r   r.   r   rM   setpruned_heads)r.   rn   ru   r   r<   r   rS   r#   r$   rR   -  s    
	zSwinv2Attention.__init__c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   rt   )lenr   r.   r   r   r   r   r   r   r   rM   r   r   union)r.   headsindexr#   r#   r$   prune_heads;  s    zSwinv2Attention.prune_headsNFr   c                 C   s6   |  ||||}| |d |}|f|dd   }|S )Nr   r   )r.   rM   )r.   r   r   r   r   Zself_outputsattention_outputr   r#   r#   r$   rW   M  s    zSwinv2Attention.forward)r   )NNF)r   r   r   rR   r   r   r   r   r    r   r"   rW   r\   r#   r#   rS   r$   r   ,  s      r   c                       s0   e Zd Z fddZejejdddZ  ZS )Swinv2Intermediatec                    sH   t    t|t|j| | _t|jt	r<t
|j | _n|j| _d S rP   )rQ   rR   r   r   r   	mlp_ratior   r   Z
hidden_actr[   r   intermediate_act_fnr   rS   r#   r$   rR   \  s
    
zSwinv2Intermediate.__init__rU   c                 C   s   |  |}| |}|S rP   )r   r   rV   r#   r#   r$   rW   d  s    

zSwinv2Intermediate.forwardr   r#   r#   rS   r$   r   [  s   r   c                       s0   e Zd Z fddZejejdddZ  ZS )Swinv2Outputc                    s4   t    tt|j| || _t|j| _	d S rP   )
rQ   rR   r   r   r   r   r   rj   rk   rl   r   rS   r#   r$   rR   l  s    
zSwinv2Output.__init__rU   c                 C   s   |  |}| |}|S rP   r   rV   r#   r#   r$   rW   q  s    

zSwinv2Output.forwardr   r#   r#   rS   r$   r   k  s   r   c                       s   e Zd Zd fdd	Zeeeef eeef f dddZdd	 Zd
d Zde	j
eeef ee	j ee ee	j
e	j
f dddZ  ZS )Swinv2LayerrD   r   c           	         s   t    || _| |j|jf||f\}}|d | _|d | _t|||| jt|tj	j
r^|n||fd| _tj||jd| _|dkrt|nt | _t||| _t||| _tj||jd| _d S )Nr   r   epsrD   )rQ   rR   r   _compute_window_shiftr<   
shift_sizer   r   r   r   r   	attentionr   rh   layer_norm_epslayernorm_beforerO   IdentityrN   r   intermediater   rM   layernorm_after)	r.   rn   ru   r   r   drop_path_rater   r   r<   rS   r#   r$   rR   x  s*    


	zSwinv2Layer.__init__rX   c                 C   s6   dd t | j|D }dd t | j||D }||fS )Nc                 S   s    g | ]\}}||kr|n|qS r#   r#   ).0rwr#   r#   r$   
<listcomp>      z5Swinv2Layer._compute_window_shift.<locals>.<listcomp>c                 S   s"   g | ]\}}}||krd n|qS )r   r#   )r   r   r   sr#   r#   r$   r     r   )zipr   )r.   Ztarget_window_sizeZtarget_shift_sizer<   r   r#   r#   r$   r     s    z!Swinv2Layer._compute_window_shiftc              	   C   s  | j dkrtjd||df|d}td| j t| j | j  t| j  d f}td| j t| j | j  t| j  d f}d}|D ].}|D ]$}	||d d ||	d d f< |d7 }qqt|| j}
|
d| j| j }
|
d|
d }||dkd|dkd}nd }|S )Nr   r   r   r5   r2   g      YrD   )	r   r   rd   slicer<   rB   r8   r   Zmasked_fill)r.   r>   r?   rI   Zimg_maskZheight_slicesZwidth_slicescountZheight_sliceZwidth_sliceZmask_windows	attn_maskr#   r#   r$   get_attn_mask  s*    zSwinv2Layer.get_attn_maskc                 C   sR   | j || j   | j  }| j || j   | j  }ddd|d|f}tj||}||fS Nr   )r<   r   rx   r   )r.   r   r>   r?   	pad_rightZ
pad_bottomr   r#   r#   r$   r     s
    zSwinv2Layer.maybe_padNFr   r   r   r   rH   c                 C   s  |\}}|  \}}}	|}
|||||	}| |||\}}|j\}}}}| jdkrrtj|| j | j fdd}n|}t|| j}|d| j| j |	}| j	|||j
d}|d ur||j}| j||||d}|d }|d| j| j|	}t|| j||}| jdkr"tj|| j| jfdd}n|}|d dkp>|d dk}|rj|d d d |d |d d f  }|||| |	}| |}|
| | }| |}| |}|| | | }|r||d	 fn|f}|S )
Nr   )r   r2   )Zshiftsdimsr5   r   )r   r   r4   r   )rr   r8   r   r7   r   r   ZrollrB   r<   r   rI   r   rJ   r   rC   r:   r   rN   r   rM   r   )r.   r   r   r   r   r>   r?   r=   r   ZchannelsZshortcutr   Z
height_padZ	width_padZshifted_hidden_statesZhidden_states_windowsr   Zattention_outputsr   Zattention_windowsZshifted_windowsZ
was_paddedZlayer_outputlayer_outputsr#   r#   r$   rW     sD    
$


zSwinv2Layer.forward)rD   r   r   )NF)r   r   r   rR   r"   r   r   r   r   r   r   r   r    r   rW   r\   r#   r#   rS   r$   r   w  s    &  
r   c                       sP   e Zd Zd	 fdd	Zd
ejeeef eej	 ee
 eej dddZ  ZS )Swinv2Stager   c	              
      s   t    || _|| _g }	t|D ]>}
t||||||
 |
d dkrFdn|jd |d}|	| q"t	|	| _
|d ur|||tjd| _nd | _d| _d S )Nr2   r   )rn   ru   r   r   r   r   r   )ru   r   F)rQ   rR   rn   ru   ranger   r<   appendr   
ModuleListblocksrh   
downsampleZpointing)r.   rn   ru   r   depthr   rN   r   r   r   iblockrS   r#   r$   rR     s(    
	zSwinv2Stage.__init__NFr   c                 C   s   |\}}t | jD ]2\}}|d ur*|| nd }	||||	|}
|
d }q|}| jd ur|d d |d d  }}||||f}| ||}n||||f}|||f}|r||
dd  7 }|S )Nr   r   r2   )	enumerater   r   )r.   r   r   r   r   r>   r?   r   layer_modulelayer_head_maskr   !hidden_states_before_downsamplingZheight_downsampledZwidth_downsampledr   Zstage_outputsr#   r#   r$   rW     s(    


zSwinv2Stage.forward)r   )NF)r   r   r   rR   r   r   r"   r   r   r    r   rW   r\   r#   r#   rS   r$   r     s       
r   c                       sd   e Zd Zd
 fdd	Zdejeeef eej	 ee
 ee
 ee
 ee
 eeef ddd	Z  ZS )Swinv2Encoderr   r   r   r   c                    s  t    t|j| _|| _| jjd ur.|j}dd tjd|j	t
|jddD }g }t| jD ]}t|t|jd|  |d d|  |d d|  f|j| |j| |t
|jd | t
|jd |d   || jd k rtnd || d}|| q`t|| _d	| _d S )
Nc                 S   s   g | ]}|  qS r#   )item)r   xr#   r#   r$   r   ;  r   z*Swinv2Encoder.__init__.<locals>.<listcomp>r   cpu)rJ   r2   r   )rn   ru   r   r   r   rN   r   r   F)rQ   rR   r   depths
num_layersrn   pretrained_window_sizesr   Zlinspacer   r   r   r   r   re   r   r   r   r   r   layersZgradient_checkpointing)r.   rn   ra   r  Zdprr  Zi_layerstagerS   r#   r$   rR   5  s*    
$*
zSwinv2Encoder.__init__NFT)r   r   r   r   output_hidden_states(output_hidden_states_before_downsamplingreturn_dictrH   c                 C   s  |rdnd }|rdnd }	|r dnd }
|rn|j \}}}|j|g||R  }|dddd}||f7 }|	|f7 }	t| jD ]\}}|d ur|| nd }|||||}|d }|d }|d }|d |d f}|r,|r,|j \}}}|j|g|d |d f|R  }|dddd}||f7 }|	|f7 }	nR|r~|s~|j \}}}|j|g||R  }|dddd}||f7 }|	|f7 }	|rx|
|dd  7 }
qx|stdd	 |||
|	fD S t|||
|	d
S )Nr#   r   r   r   r2   r   r5   c                 s   s   | ]}|d ur|V  qd S rP   r#   )r   vr#   r#   r$   	<genexpr>  s   z(Swinv2Encoder.forward.<locals>.<genexpr>)r   r   r   r   )r7   r8   r9   r  r  r"   r   )r.   r   r   r   r   r  r  r  Zall_hidden_statesZall_reshaped_hidden_statesZall_self_attentionsr=   r   r   Zreshaped_hidden_stater   r  r  r   r  r   r#   r#   r$   rW   N  sf    






zSwinv2Encoder.forward)r  )NFFFT)r   r   r   rR   r   r   r"   r   r   r    r   r   r   rW   r\   r#   r#   rS   r$   r  4  s         

r  c                   @   s0   e Zd ZU eed< dZdZdZdgZdd Z	dS )	Swinv2PreTrainedModelrn   swinv2r{   Tr   c                 C   s   t |tjtjfr@|jjjd| jjd |j	dur|j	j
  n~t |tjrh|j	j
  |jjd nVt |tr|jdur|jj
  |jdur|jj
  nt |tr|jjtd dS )zInitialize the weightsrD   )meanZstdNr}   r   )r   r   r   r   weightdataZnormal_rn   Zinitializer_ranger   Zzero_rh   Zfill_r]   rf   rg   r   r   r   r   )r.   moduler#   r#   r$   _init_weights  s    




z#Swinv2PreTrainedModel._init_weightsN)
r   r   r   r   r!   Zbase_model_prefixZmain_input_nameZsupports_gradient_checkpointingZ_no_split_modulesr  r#   r#   r#   r$   r    s   
r  c                       sv   e Zd Zd fdd	Zdd Zdd Zedeej	 eej
 eej	 ee ee eee eeef d
ddZ  ZS )Swinv2ModelTFc                    s   t  | || _t|j| _t|jd| jd   | _t	||d| _
t|| j
j| _tj| j|jd| _|rxtdnd| _|   dS )a  
        add_pooling_layer (`bool`, *optional*, defaults to `True`):
            Whether or not to apply pooling layer.
        use_mask_token (`bool`, *optional*, defaults to `False`):
            Whether or not to create and apply mask tokens in the embedding layer.
        r2   r   )ro   r   N)rQ   rR   rn   r   r
  r  r   re   num_featuresr]   rp   r  rb   encoderr   rh   r   	layernormZAdaptiveAvgPool1dpooler	post_init)r.   rn   add_pooling_layerro   rS   r#   r$   rR     s    zSwinv2Model.__init__c                 C   s   | j jS rP   rp   r_   r-   r#   r#   r$   get_input_embeddings  s    z Swinv2Model.get_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr  layerr   r   )r.   Zheads_to_pruner%  r   r#   r#   r$   _prune_heads  s    zSwinv2Model._prune_headsNr{   r|   r   r   r  rz   r  rH   c                 C   s   |dur|n| j j}|dur |n| j j}|dur4|n| j j}|du rLtd| |t| j j}| j|||d\}}	| j	||	||||d}
|
d }| 
|}d}| jdur| |dd}t|d}|s||f|
dd  }|S t|||
j|
j|
jdS )	z
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        Nz You have to specify pixel_values)r|   rz   )r   r   r  r  r   r   r2   )r   r&   r   r   r   )rn   r   r  use_return_dictr   Zget_head_maskr   r
  rp   r  r  r  r   r   r   r%   r   r   r   )r.   r{   r|   r   r   r  rz   r  embedding_outputr   Zencoder_outputssequence_outputpooled_outputrM   r#   r#   r$   rW     sD    
	

zSwinv2Model.forward)TF)NNNNNFN)r   r   r   rR   r#  r&  r   r   r   r    r   r   r   r"   r%   rW   r\   r#   r#   rS   r$   r    s*          
r  a~  
        Swinv2 Model with a decoder on top for masked image modeling, as proposed in
    [SimMIM](https://huggingface.co/papers/2111.09886).

        <Tip>

        Note that we provide a script to pre-train this model on custom data in our [examples
        directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).

        </Tip>
    c                       sd   e Zd Z fddZedeej eej eej ee	 ee	 e	ee	 e
eef dddZ  ZS )	Swinv2ForMaskedImageModelingc                    sn   t  | t|ddd| _t|jd|jd   }ttj	||j
d |j ddt|j
| _|   d S )NFT)r!  ro   r2   r   )Zin_channelsZout_channelsr   )rQ   rR   r  r  r   re   r  r   r   r   Zencoder_strider@   ZPixelShuffledecoderr   )r.   rn   r  rS   r#   r$   rR   (  s    
z%Swinv2ForMaskedImageModeling.__init__NFr'  c              	   C   sB  |dur|n| j j}| j|||||||d}|d }	|	dd}	|	j\}
}}t|d  }}|	|
|||}	| |	}d}|dur| j j	| j j
 }|d||}|| j j
d| j j
dd }tjj||dd	}||  | d
  | j j }|s*|f|dd  }|dur&|f| S |S t|||j|j|jdS )a?  
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, Swinv2ForMaskedImageModeling
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
        >>> model = Swinv2ForMaskedImageModeling.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")

        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
        >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
        >>> # create random boolean mask of shape (batch_size, num_patches)
        >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
        >>> list(reconstructed_pixel_values.shape)
        [1, 3, 256, 256]
        ```N)r|   r   r   r  rz   r  r   r   r2   rq   r5   none)r   gh㈵>)r(   r)   r   r   r   )rn   r(  r  r   r7   r   floorrw   r-  r   rm   Zrepeat_interleaver   r:   r   rx   Zl1_lossr   r@   r'   r   r   r   )r.   r{   r|   r   r   r  rz   r  r   r*  r=   r@   Zsequence_lengthr>   r?   Zreconstructed_pixel_valuesZmasked_im_lossrr   r   Zreconstruction_lossrM   r#   r#   r$   rW   8  sL    &

 z$Swinv2ForMaskedImageModeling.forward)NNNNNFN)r   r   r   rR   r   r   r   r    r   r   r   r"   r'   rW   r\   r#   r#   rS   r$   r,    s&          
r,  a  
    Swinv2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
    of the [CLS] token) e.g. for ImageNet.

    <Tip>

        Note that it's possible to fine-tune SwinV2 on higher resolution images than the ones it has been trained on, by
        setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
        position embeddings to the higher resolution.

    </Tip>
    c                       sd   e Zd Z fddZedeej eej eej ee	 ee	 e	ee	 e
eef dddZ  ZS )	Swinv2ForImageClassificationc                    sP   t  | |j| _t|| _|jdkr:t| jj|jnt | _	| 
  d S r   )rQ   rR   Z
num_labelsr  r  r   r   r  r   
classifierr   r.   rn   rS   r#   r$   rR     s    
"z%Swinv2ForImageClassification.__init__NF)r{   r   labelsr   r  rz   r  rH   c                 C   s   |dur|n| j j}| j||||||d}|d }	| |	}
d}|dur\| j|
||
| j d}|s|
f|dd  }|dur|f| S |S t||
|j|j|jdS )a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        N)r   r   r  rz   r  r   )r/   r3  Zpooled_logitsrn   r2   )r(   r/   r   r   r   )	rn   r(  r  r1  Zloss_functionr1   r   r   r   )r.   r{   r   r3  r   r  rz   r  r   r+  r/   r(   rM   r#   r#   r$   rW     s0    	
z$Swinv2ForImageClassification.forward)NNNNNFN)r   r   r   rR   r   r   r   r    Z
LongTensorr   r   r"   r1   rW   r\   r#   r#   rS   r$   r0    s&          
r0  zO
    Swinv2 backbone, to be used with frameworks like DETR and MaskFormer.
    c                       sL   e Zd Z fddZdd Zed	eee ee ee e	dddZ
  ZS )
Swinv2Backbonec                    sd   t    t     jg fddtt jD  | _t | _	t
 | j	j| _|   d S )Nc                    s   g | ]}t  jd |  qS )r2   )r   re   )r   r   rn   r#   r$   r     r   z+Swinv2Backbone.__init__.<locals>.<listcomp>)rQ   rR   Z_init_backbonere   r   r   r
  r  r]   rp   r  rb   r  r   r2  rS   r5  r$   rR     s    &
zSwinv2Backbone.__init__c                 C   s   | j jS rP   r"  r-   r#   r#   r$   r#    s    z#Swinv2Backbone.get_input_embeddingsN)r{   r   r  r  rH   c              	   C   s   |dur|n| j j}|dur |n| j j}|dur4|n| j j}| |\}}| j||d|dd|d}|rl|jn|d }d}	t| j|D ]\}
}|
| j	v r|	|f7 }	q|s|	f}|r||d f7 }|r||d f7 }|S t
|	|r|jnd|jdS )	aK  
        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
        >>> model = AutoBackbone.from_pretrained(
        ...     "microsoft/swinv2-tiny-patch4-window8-256", out_features=["stage1", "stage2", "stage3", "stage4"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 2048, 7, 7]
        ```NT)r   r   r  r  r  r5   r#   r   r2   )feature_mapsr   r   )rn   r(  r  r   rp   r  r   r   Zstage_namesZout_featuresr
   r   r   )r.   r{   r   r  r  r)  r   r   r   r6  r  Zhidden_staterM   r#   r#   r$   rW     s>     

zSwinv2Backbone.forward)NNN)r   r   r   rR   r#  r   r   r   r   r
   rW   r\   r#   r#   rS   r$   r4    s      r4  )r0  r,  r  r  r4  )rD   F)Br   collections.abcr   r   r*   dataclassesr   typingr   r   r   Ztorch.utils.checkpointr   r   Zactivationsr   Zmodeling_layersr	   Zmodeling_outputsr
   Zmodeling_utilsr   Zpytorch_utilsr   r   r   utilsr   r   r   r   Zutils.backbone_utilsr   Zconfiguration_swinv2r   Z
get_loggerr   loggerr   r%   r'   r1   rB   rC   rZ   r   rN   r   rO   r]   r^   r   r   r   r   r   r   r   r   r  r  r  r,  r0  r4  __all__r#   r#   r#   r$   <module>   s   
]+6 /}@dcg@W