"""PyTorch Swin Transformer model."""
# NOTE: This module was recovered from a compiled (.pyc) artifact. The docstrings, class names,
# and signatures below follow the recovered constants; the method bodies are a sketch of the
# standard Swin Transformer formulation and may differ cosmetically from the original source.

import collections.abc
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils import ModelOutput, auto_docstring, logging, torch_int
from ...utils.backbone_utils import BackboneMixin
from .configuration_swin import SwinConfig


logger = logging.get_logger(__name__)


@dataclass
@auto_docstring(
    custom_intro="""
    Swin encoder's outputs, with potential hidden states and attentions.
    """
)
class SwinEncoderOutput(ModelOutput):
    r"""
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None


@dataclass
@auto_docstring(
    custom_intro="""
    Swin model's outputs that also contains a pooling of the last hidden states.
    """
)
class SwinModelOutput(ModelOutput):
    r"""
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
        Average pooling of the last layer hidden-state.
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    pooler_output: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None


@dataclass
@auto_docstring(
    custom_intro="""
    Swin masked image model outputs.
    """
)
class SwinMaskedImageModelingOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
        Masked image modeling (MLM) loss.
    reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
        Reconstructed pixel values.
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    loss: Optional[torch.FloatTensor] = None
    reconstruction: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None

    @property
    def logits(self):
        warnings.warn(
            "logits attribute is deprecated and will be removed in version 5 of Transformers."
            " Please use the reconstruction attribute to retrieve the final output instead.",
            FutureWarning,
        )
        return self.reconstruction


@dataclass
@auto_docstring(
    custom_intro="""
    Swin outputs for image classification.
    """
)
class SwinImageClassifierOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification (or regression if config.num_labels==1) loss.
    logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Classification (or regression if config.num_labels==1) scores (before SoftMax).
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None


def window_partition(input_feature, window_size):
    """
    Partitions the given input into windows.
    """
    batch_size, height, width, num_channels = input_feature.shape
    input_feature = input_feature.view(
        batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
    )
    windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
    return windows


def window_reverse(windows, window_size, height, width):
    """
    Merges windows to produce higher resolution features.
    """
    num_channels = windows.shape[-1]
    windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
    windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
    return windows


class SwinEmbeddings(nn.Module):
    """
    Construct the patch and position embeddings. Optionally, also the mask token.
    """

    def __init__(self, config, use_mask_token=False):
        super().__init__()

        self.patch_embeddings = SwinPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.patch_grid = self.patch_embeddings.grid_size
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None

        if config.use_absolute_embeddings:
            self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
        else:
            self.position_embeddings = None

        self.norm = nn.LayerNorm(config.embed_dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.patch_size = config.patch_size
        self.config = config

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """
        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1

        # always interpolate when tracing so the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embeddings

        class_pos_embed = self.position_embeddings[:, :1]
        patch_pos_embed = self.position_embeddings[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed, size=(new_height, new_width), mode="bicubic", align_corners=False
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor],
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> tuple[torch.Tensor]:
        _, num_channels, height, width = pixel_values.shape
        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
        embeddings = self.norm(embeddings)
        batch_size, seq_len, _ = embeddings.size()

        if bool_masked_pos is not None:
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        if self.position_embeddings is not None:
            if interpolate_pos_encoding:
                embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
            else:
                embeddings = embeddings + self.position_embeddings

        embeddings = self.dropout(embeddings)

        return embeddings, output_dimensions


class SwinPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.embed_dim
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches
        self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def maybe_pad(self, pixel_values, height, width):
        if width % self.patch_size[1] != 0:
            pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
            pixel_values = nn.functional.pad(pixel_values, pad_values)
        if height % self.patch_size[0] != 0:
            pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
            pixel_values = nn.functional.pad(pixel_values, pad_values)
        return pixel_values

    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> tuple[torch.Tensor, tuple[int]]:
        _, num_channels, height, width = pixel_values.shape
        # pad the input to be divisible by self.patch_size, if needed
        pixel_values = self.maybe_pad(pixel_values, height, width)
        embeddings = self.projection(pixel_values)
        _, _, height, width = embeddings.shape
        output_dimensions = (height, width)
        embeddings = embeddings.flatten(2).transpose(1, 2)

        return embeddings, output_dimensions


class SwinPatchMerging(nn.Module):
    """
    Patch Merging Layer.

    Args:
        input_resolution (`tuple[int]`):
            Resolution of input feature.
        dim (`int`):
            Number of input channels.
        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
            Normalization layer class.
    """

    def __init__(self, input_resolution: tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def maybe_pad(self, input_feature, height, width):
        should_pad = (height % 2 == 1) or (width % 2 == 1)
        if should_pad:
            pad_values = (0, 0, 0, width % 2, 0, height % 2)
            input_feature = nn.functional.pad(input_feature, pad_values)

        return input_feature

    def forward(self, input_feature: torch.Tensor, input_dimensions: tuple[int, int]) -> torch.Tensor:
        height, width = input_dimensions
        # `dim` is height * width
        batch_size, dim, num_channels = input_feature.shape

        input_feature = input_feature.view(batch_size, height, width, num_channels)
        # pad input to be divisible by width and height, if needed
        input_feature = self.maybe_pad(input_feature, height, width)
        # each slice below has shape [batch_size, height/2, width/2, num_channels]
        input_feature_0 = input_feature[:, 0::2, 0::2, :]
        input_feature_1 = input_feature[:, 1::2, 0::2, :]
        input_feature_2 = input_feature[:, 0::2, 1::2, :]
        input_feature_3 = input_feature[:, 1::2, 1::2, :]
        # [batch_size, height/2, width/2, 4*num_channels]
        input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
        input_feature = input_feature.view(batch_size, -1, 4 * num_channels)  # [batch_size, height/2*width/2, 4*C]

        input_feature = self.norm(input_feature)
        input_feature = self.reduction(input_feature)

        return input_feature


def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize
    output = input.div(keep_prob) * random_tensor
    return output


class SwinDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f"p={self.drop_prob}"


class SwinSelfAttention(nn.Module):
    def __init__(self, config, dim, num_heads, window_size):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
            )

        self.num_attention_heads = num_heads
        self.attention_head_size = int(dim / num_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.window_size = (
            window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
        )

        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
        )

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)

        self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.Tensor]:
        batch_size, dim, num_channels = hidden_states.shape
        hidden_shape = (batch_size, dim, -1, self.attention_head_size)

        query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
        key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
        value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        relative_position_bias = relative_position_bias.view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
        )
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attention_scores = attention_scores + relative_position_bias.unsqueeze(0)

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in SwinModel's forward)
            mask_shape = attention_mask.shape[0]
            attention_scores = attention_scores.view(
                batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
            )
            attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0)
            attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


class SwinSelfOutput(nn.Module):
    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class SwinAttention(nn.Module):
    def __init__(self, config, dim, num_heads, window_size):
        super().__init__()
        self.self = SwinSelfAttention(config, dim, num_heads, window_size)
        self.output = SwinSelfOutput(config, dim)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.Tensor]:
        self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class SwinIntermediate(nn.Module):
    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class SwinOutput(nn.Module):
    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class SwinLayer(GradientCheckpointingLayer):
    def __init__(self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.shift_size = shift_size
        self.window_size = config.window_size
        self.input_resolution = input_resolution
        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.attention = SwinAttention(config, dim, num_heads, window_size=self.window_size)
        self.drop_path = SwinDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.intermediate = SwinIntermediate(config, dim)
        self.output = SwinOutput(config, dim)

    def set_shift_and_window_size(self, input_resolution):
        if min(input_resolution) <= self.window_size:
            # if the window size is larger than the input resolution, we don't partition windows
            self.shift_size = torch_int(0)
            self.window_size = (
                torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
            )

    def get_attn_mask(self, height, width, dtype, device):
        if self.shift_size > 0:
            # calculate the attention mask for shifted-window MSA
            img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
            height_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            width_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            count = 0
            for height_slice in height_slices:
                for width_slice in width_slices:
                    img_mask[:, height_slice, width_slice, :] = count
                    count += 1

            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        return attn_mask

    def maybe_pad(self, hidden_states, height, width):
        pad_right = (self.window_size - width % self.window_size) % self.window_size
        pad_bottom = (self.window_size - height % self.window_size) % self.window_size
        pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
        hidden_states = nn.functional.pad(hidden_states, pad_values)
        return hidden_states, pad_values

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        always_partition: Optional[bool] = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if not always_partition:
            self.set_shift_and_window_size(input_dimensions)
        height, width = input_dimensions
        batch_size, _, channels = hidden_states.size()
        shortcut = hidden_states

        hidden_states = self.layernorm_before(hidden_states)
        hidden_states = hidden_states.view(batch_size, height, width, channels)

        # pad hidden_states to multiples of the window size
        hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
        _, height_pad, width_pad, _ = hidden_states.shape

        # cyclic shift
        if self.shift_size > 0:
            shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_hidden_states = hidden_states

        # partition windows
        hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
        hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
        attn_mask = self.get_attn_mask(
            height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device
        )

        attention_outputs = self.attention(
            hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
        )

        attention_output = attention_outputs[0]

        attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
        shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)

        # reverse cyclic shift
        if self.shift_size > 0:
            attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            attention_windows = shifted_windows

        was_padded = pad_values[3] > 0 or pad_values[5] > 0
        if was_padded:
            attention_windows = attention_windows[:, :height, :width, :].contiguous()

        attention_windows = attention_windows.view(batch_size, height * width, channels)

        hidden_states = shortcut + self.drop_path(attention_windows)

        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)
        layer_output = hidden_states + self.output(layer_output)

        layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
        return layer_outputs


class SwinStage(nn.Module):
    def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
        super().__init__()
        self.config = config
        self.dim = dim
        self.blocks = nn.ModuleList(
            [
                SwinLayer(
                    config=config,
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    drop_path_rate=drop_path[i],
                    shift_size=0 if (i % 2 == 0) else config.window_size // 2,
                )
                for i in range(depth)
            ]
        )

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)
        else:
            self.downsample = None

        self.pointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        always_partition: Optional[bool] = False,
    ) -> tuple[torch.Tensor]:
        height, width = input_dimensions
        for i, layer_module in enumerate(self.blocks):
            layer_head_mask = head_mask[i] if head_mask is not None else None

            layer_outputs = layer_module(
                hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
            )

            hidden_states = layer_outputs[0]

        hidden_states_before_downsampling = hidden_states
        if self.downsample is not None:
            height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
            output_dimensions = (height, width, height_downsampled, width_downsampled)
            hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
        else:
            output_dimensions = (height, width, height, width)

        stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)

        if output_attentions:
            stage_outputs += layer_outputs[1:]
        return stage_outputs


class SwinEncoder(nn.Module):
    def __init__(self, config, grid_size):
        super().__init__()
        self.num_layers = len(config.depths)
        self.config = config
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
        self.layers = nn.ModuleList(
            [
                SwinStage(
                    config=config,
                    dim=int(config.embed_dim * 2**i_layer),
                    input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
                    depth=config.depths[i_layer],
                    num_heads=config.num_heads[i_layer],
                    drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
                    downsample=SwinPatchMerging if (i_layer < self.num_layers - 1) else None,
                )
                for i_layer in range(self.num_layers)
            ]
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        output_hidden_states_before_downsampling: Optional[bool] = False,
        always_partition: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[tuple, SwinEncoderOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if output_hidden_states:
            batch_size, _, hidden_size = hidden_states.shape
            # rearrange b (h w) c -> b c h w
            reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
            reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
            all_hidden_states += (hidden_states,)
            all_reshaped_hidden_states += (reshaped_hidden_state,)

        for i, layer_module in enumerate(self.layers):
            layer_head_mask = head_mask[i] if head_mask is not None else None

            layer_outputs = layer_module(
                hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
            )

            hidden_states = layer_outputs[0]
            hidden_states_before_downsampling = layer_outputs[1]
            output_dimensions = layer_outputs[2]

            input_dimensions = (output_dimensions[-2], output_dimensions[-1])

            if output_hidden_states and output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states_before_downsampling.shape
                # rearrange b (h w) c -> b c h w, using the original (not downsampled) height and width
                reshaped_hidden_state = hidden_states_before_downsampling.view(
                    batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
                )
                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
                all_hidden_states += (hidden_states_before_downsampling,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)
            elif output_hidden_states and not output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states.shape
                # rearrange b (h w) c -> b c h w
                reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
                all_hidden_states += (hidden_states,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)

            if output_attentions:
                all_self_attentions += layer_outputs[3:]

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)

        return SwinEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            reshaped_hidden_states=all_reshaped_hidden_states,
        )


@auto_docstring
class SwinPreTrainedModel(PreTrainedModel):
    config: SwinConfig
    base_model_prefix = "swin"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["SwinStage"]

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, SwinEmbeddings):
            if module.mask_token is not None:
                module.mask_token.data.zero_()
            if module.position_embeddings is not None:
                module.position_embeddings.data.zero_()
        elif isinstance(module, SwinSelfAttention):
            module.relative_position_bias_table.data.zero_()


@auto_docstring
class SwinModel(SwinPreTrainedModel):
    def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
        r"""
        add_pooling_layer (`bool`, *optional*, defaults to `True`):
            Whether or not to apply pooling layer.
        use_mask_token (`bool`, *optional*, defaults to `False`):
            Whether or not to create and apply mask tokens in the embedding layer.
        """
        super().__init__(config)
        self.config = config
        self.num_layers = len(config.depths)
        self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))

        self.embeddings = SwinEmbeddings(config, use_mask_token=use_mask_token)
        self.encoder = SwinEncoder(config, self.embeddings.patch_grid)

        self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, SwinModelOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare the head mask if needed; 1.0 in head_mask indicates we keep the head
        head_mask = self.get_head_mask(head_mask, len(self.config.depths))

        embedding_output, input_dimensions = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
        )

        encoder_outputs = self.encoder(
            embedding_output,
            input_dimensions,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)

        pooled_output = None
        if self.pooler is not None:
            pooled_output = self.pooler(sequence_output.transpose(1, 2))
            pooled_output = torch.flatten(pooled_output, 1)

        if not return_dict:
            output = (sequence_output, pooled_output) + encoder_outputs[1:]
            return output

        return SwinModelOutput(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
        )


@auto_docstring(
    custom_intro="""
    Swin Model with a decoder on top for masked image modeling, as proposed in [SimMIM](https://huggingface.co/papers/2111.09886).

    <Tip>

    Note that we provide a script to pre-train this model on custom data in our [examples
    directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).

    </Tip>
    """
)
class SwinForMaskedImageModeling(SwinPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.swin = SwinModel(config, add_pooling_layer=False, use_mask_token=True)

        num_features = int(config.embed_dim * 2 ** (len(config.depths) - 1))
        self.decoder = nn.Sequential(
            nn.Conv2d(
                in_channels=num_features, out_channels=config.encoder_stride**2 * config.num_channels, kernel_size=1
            ),
            nn.PixelShuffle(config.encoder_stride),
        )

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, SwinMaskedImageModelingOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, SwinForMaskedImageModeling
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/swin-base-simmim-window6-192")
        >>> model = SwinForMaskedImageModeling.from_pretrained("microsoft/swin-base-simmim-window6-192")

        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
        >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
        >>> # create random boolean mask of shape (batch_size, num_patches)
        >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
        >>> list(reconstructed_pixel_values.shape)
        [1, 3, 192, 192]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.swin(
            pixel_values,
            bool_masked_pos=bool_masked_pos,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        # Reshape to (batch_size, num_channels, height, width)
        sequence_output = sequence_output.transpose(1, 2)
        batch_size, num_channels, sequence_length = sequence_output.shape
        height = width = math.floor(sequence_length**0.5)
        sequence_output = sequence_output.reshape(batch_size, num_channels, height, width)

        # Reconstruct pixel values
        reconstructed_pixel_values = self.decoder(sequence_output)

        masked_im_loss = None
        if bool_masked_pos is not None:
            size = self.config.image_size // self.config.patch_size
            bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
            mask = (
                bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
                .repeat_interleave(self.config.patch_size, 2)
                .unsqueeze(1)
                .contiguous()
            )
            reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
            masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels

        if not return_dict:
            output = (reconstructed_pixel_values,) + outputs[2:]
            return ((masked_im_loss,) + output) if masked_im_loss is not None else output

        return SwinMaskedImageModelingOutput(
            loss=masked_im_loss,
            reconstruction=reconstructed_pixel_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            reshaped_hidden_states=outputs.reshaped_hidden_states,
        )


@auto_docstring(
    custom_intro="""
    Swin Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
    the [CLS] token) e.g. for ImageNet.

    <Tip>

        Note that it's possible to fine-tune Swin on higher resolution images than the ones it has been trained on, by
        setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
        position embeddings to the higher resolution.

    </Tip>
    """
)
class SwinForImageClassification(SwinPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.swin = SwinModel(config)

        # Classifier head
        self.classifier = (
            nn.Linear(self.swin.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, SwinImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.swin(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=logits, config=self.config)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SwinImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            reshaped_hidden_states=outputs.reshaped_hidden_states,
        )


@auto_docstring(
    custom_intro="""
    Swin backbone, to be used with frameworks like DETR and MaskFormer.
    """
)
class SwinBackbone(SwinPreTrainedModel, BackboneMixin):
    def __init__(self, config: SwinConfig):
        super().__init__(config)
        super()._init_backbone(config)

        self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
        self.embeddings = SwinEmbeddings(config)
        self.encoder = SwinEncoder(config, self.embeddings.patch_grid)

        # Add layer norms to the hidden states of out_features
        hidden_states_norms = {}
        for stage, num_channels in zip(self._out_features, self.channels):
            hidden_states_norms[stage] = nn.LayerNorm(num_channels)
        self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BackboneOutput:
        """
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("shi-labs/nat-mini-in1k-224")
        >>> model = AutoBackbone.from_pretrained(
        ...     "microsoft/swin-tiny-patch4-window7-224", out_features=["stage1", "stage2", "stage3", "stage4"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 768, 7, 7]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        embedding_output, input_dimensions = self.embeddings(pixel_values)

        outputs = self.encoder(
            embedding_output,
            input_dimensions,
            head_mask=None,
            output_attentions=output_attentions,
            output_hidden_states=True,
            output_hidden_states_before_downsampling=True,
            always_partition=True,
            return_dict=True,
        )

        hidden_states = outputs.reshaped_hidden_states

        feature_maps = ()
        for stage, hidden_state in zip(self.stage_names, hidden_states):
            if stage in self.out_features:
                batch_size, num_channels, height, width = hidden_state.shape
                hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
                hidden_state = hidden_state.view(batch_size, height * width, num_channels)
                hidden_state = self.hidden_states_norms[stage](hidden_state)
                hidden_state = hidden_state.view(batch_size, height, width, num_channels)
                hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
                feature_maps += (hidden_state,)

        if not return_dict:
            output = (feature_maps,)
            if output_hidden_states:
                output += (outputs.hidden_states,)
            return output

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )


__all__ = [
    "SwinForImageClassification",
    "SwinForMaskedImageModeling",
    "SwinModel",
    "SwinPreTrainedModel",
    "SwinBackbone",
]