a
    hF                     @   s  d Z ddlZddlZddlmZ ddlmZmZ ddl	Z	ddl
Z	ddl	mZ ddlmZ ddlmZ dd	lmZmZ dd
lmZ ddlmZmZmZ ddlmZmZmZ ddlmZ ee Z!eeddG dd deZ"dd Z#dd Z$dHe	j%e&e'e	j%dddZ(G dd dej)Z*G dd dej)Z+G d d! d!ej)Z,G d"d# d#ej)Z-G d$d% d%ej)Z.G d&d' d'ej)Z/G d(d) d)ej)Z0G d*d+ d+ej)Z1G d,d- d-ej)Z2G d.d/ d/ej)Z3G d0d1 d1ej)Z4G d2d3 d3eZ5G d4d5 d5ej)Z6eG d6d7 d7eZ7eG d8d9 d9e7Z8G d:d; d;ej)Z9G d<d= d=ej)Z:G d>d? d?ej)Z;G d@dA dAej)Z<G dBdC dCej)Z=edDdG dEdF dFe7Z>g dGZ?dS )Iz"PyTorch Swin2SR Transformer model.    N)	dataclass)OptionalUnion)nn   )ACT2FN)GradientCheckpointingLayer)BaseModelOutputImageSuperResolutionOutput)PreTrainedModel) find_pruneable_heads_and_indicesmeshgridprune_linear_layer)ModelOutputauto_docstringlogging   )Swin2SRConfigzQ
    Swin2SR encoder's outputs, with potential hidden states and attentions.
    )Zcustom_introc                   @   sL   e Zd ZU dZeej ed< dZee	ej  ed< dZ
ee	ej  ed< dS )Swin2SREncoderOutputNlast_hidden_statehidden_states
attentions)__name__
__module____qualname__r   r   torchFloatTensor__annotations__r   tupler    r   r   h/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/swin2sr/modeling_swin2sr.pyr   &   s   
r   c                 C   sR   | j \}}}}| ||| ||| ||} | dddddd d|||}|S )z2
    Partitions the given input into windows.
    r   r   r            shapeviewpermute
contiguous)input_featurewindow_size
batch_sizeheightwidthnum_channelswindowsr   r   r    window_partition3   s    $r1   c                 C   sN   | j d }| d|| || |||} | dddddd d|||} | S )z?
    Merges windows to produce higher resolution features.
    r$   r   r   r   r!   r"   r#   r%   )r0   r+   r-   r.   r/   r   r   r    window_reverse@   s    
$r2           F)input	drop_probtrainingreturnc                 C   sd   |dks|s| S d| }| j d fd| jd   }|tj|| j| jd }|  | || }|S )aF  
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    r3   r   r   )r   )dtypedevice)r&   ndimr   Zrandr8   r9   Zfloor_div)r4   r5   r6   Z	keep_probr&   Zrandom_tensoroutputr   r   r    	drop_pathK   s    
r=   c                       sP   e Zd ZdZdee dd fddZejejdddZ	e
d	d
dZ  ZS )Swin2SRDropPathzXDrop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).N)r5   r7   c                    s   t    || _d S N)super__init__r5   )selfr5   	__class__r   r    rA   c   s    
zSwin2SRDropPath.__init__r   r7   c                 C   s   t || j| jS r?   )r=   r5   r6   rB   r   r   r   r    forwardg   s    zSwin2SRDropPath.forwardr7   c                 C   s   d| j  S )Nzp=)r5   rB   r   r   r    
extra_reprj   s    zSwin2SRDropPath.extra_repr)N)r   r   r   __doc__r   floatrA   r   TensorrG   strrJ   __classcell__r   r   rC   r    r>   `   s   r>   c                       s<   e Zd ZdZ fddZeej eej	 dddZ
  ZS )Swin2SREmbeddingsz?
    Construct the patch and optional position embeddings.
    c                    s`   t    t|| _| jj}|jr@tt	d|d |j
| _nd | _t|j| _|j| _d S )Nr   )r@   rA   Swin2SRPatchEmbeddingspatch_embeddingsnum_patchesZuse_absolute_embeddingsr   	Parameterr   zeros	embed_dimposition_embeddingsDropouthidden_dropout_probdropoutr+   )rB   configrS   rC   r   r    rA   s   s    

zSwin2SREmbeddings.__init__)pixel_valuesr7   c                 C   s4   |  |\}}| jd ur"|| j }| |}||fS r?   )rR   rW   rZ   )rB   r\   
embeddingsoutput_dimensionsr   r   r    rG      s
    


zSwin2SREmbeddings.forward)r   r   r   rK   rA   r   r   r   r   rM   rG   rO   r   r   rC   r    rP   n   s   rP   c                       sB   e Zd Zd fdd	Zeej eejee	 f dddZ
  ZS )rQ   Tc                    s   t    |j}|j|j }}t|tjjr0|n||f}t|tjjrJ|n||f}|d |d  |d |d  g}|| _	|d |d  | _
tj||j||d| _|rt|jnd | _d S )Nr   r   )Zkernel_sizeZstride)r@   rA   rV   
image_size
patch_size
isinstancecollectionsabcIterablepatches_resolutionrS   r   Conv2d
projection	LayerNorm	layernorm)rB   r[   normalize_patchesr/   r_   r`   re   rC   r   r    rA      s    
 zSwin2SRPatchEmbeddings.__init__)r]   r7   c                 C   sN   |  |}|j\}}}}||f}|ddd}| jd urF| |}||fS )Nr!   r   )rg   r&   flatten	transposeri   )rB   r]   _r-   r.   r^   r   r   r    rG      s    


zSwin2SRPatchEmbeddings.forward)T)r   r   r   rA   r   r   r   r   rM   intrG   rO   r   r   rC   r    rQ      s   rQ   c                       s(   e Zd ZdZ fddZdd Z  ZS )Swin2SRPatchUnEmbeddingszImage to Patch Unembeddingc                    s   t    |j| _d S r?   )r@   rA   rV   )rB   r[   rC   r   r    rA      s    
z!Swin2SRPatchUnEmbeddings.__init__c                 C   s2   |j \}}}|dd|| j|d |d }|S )Nr   r!   r   )r&   rl   r'   rV   )rB   r]   Zx_sizer,   Zheight_widthr/   r   r   r    rG      s    "z Swin2SRPatchUnEmbeddings.forwardr   r   r   rK   rA   rG   rO   r   r   rC   r    ro      s   ro   c                       s^   e Zd ZdZejfee eejdd fddZ	dd Z
ejeeef ejdd	d
Z  ZS )Swin2SRPatchMerginga'  
    Patch Merging Layer.

    Args:
        input_resolution (`tuple[int]`):
            Resolution of input feature.
        dim (`int`):
            Number of input channels.
        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
            Normalization layer class.
    N)input_resolutiondim
norm_layerr7   c                    sB   t    || _|| _tjd| d| dd| _|d| | _d S )Nr"   r!   Fbias)r@   rA   rr   rs   r   Linear	reductionnorm)rB   rr   rs   rt   rC   r   r    rA      s
    
zSwin2SRPatchMerging.__init__c                 C   sF   |d dkp|d dk}|rBddd|d d|d f}t j||}|S )Nr!   r   r   )r   
functionalpad)rB   r*   r-   r.   Z
should_pad
pad_valuesr   r   r    	maybe_pad   s
    zSwin2SRPatchMerging.maybe_pad)r*   input_dimensionsr7   c                 C   s   |\}}|j \}}}|||||}| |||}|d d dd ddd dd d f }|d d dd ddd dd d f }	|d d dd ddd dd d f }
|d d dd ddd dd d f }t||	|
|gd}||dd| }| |}| |}|S )Nr   r!   r   r$   r"   )r&   r'   r}   r   catrx   ry   )rB   r*   r~   r-   r.   r,   rs   r/   Zinput_feature_0Zinput_feature_1Zinput_feature_2Zinput_feature_3r   r   r    rG      s    $$$$

zSwin2SRPatchMerging.forward)r   r   r   rK   r   rh   r   rn   ModulerA   r}   r   rM   rG   rO   r   r   rC   r    rq      s   $rq   c                       sT   e Zd Zddgf fdd	Zd	ejeej eej ee e	ej dddZ
  ZS )
Swin2SRSelfAttentionr   c              
      s  t    || dkr,td| d| d|| _t|| | _| j| j | _t|tj	j
r`|n||f| _|| _ttdt|ddf | _ttjddd	d
tjd	dtjd|dd
| _tj| jd d  | jd tjd }tj| jd d  | jd tjd }tt||gddddd d}|d dkr|d d d d d d df  |d d   < |d d d d d d df  |d d   < nf|dkr
|d d d d d d df  | jd d   < |d d d d d d df  | jd d   < |d9 }t|t t!|d  t" d }|#t$| j% j&}| j'd|dd t| jd }	t| jd }
tt|	|
gdd}t(|d}|d d d d d f |d d d d d f  }|ddd }|d d d d df  | jd d 7  < |d d d d df  | jd d 7  < |d d d d df  d| jd  d 9  < |)d}| j'd|dd tj| j| j|j*d
| _+tj| j| jdd
| _,tj| j| j|j*d
| _-t.|j/| _0d S )Nr   zThe hidden size (z6) is not a multiple of the number of attention heads ()
   r   r!   i   Tru   inplaceFr8   Zij)Zindexing         ?relative_coords_table
persistentr$   relative_position_index)1r@   rA   
ValueErrornum_attention_headsrn   attention_head_sizeall_head_sizera   rb   rc   rd   r+   pretrained_window_sizer   rT   r   logZoneslogit_scale
Sequentialrw   ZReLUcontinuous_position_bias_mlpZarangeZint64rL   stackr   r(   r)   	unsqueezesignlog2absmathtonext
parametersr8   register_bufferrk   sumZqkv_biasquerykeyvaluerX   attention_probs_dropout_probrZ   )rB   r[   rs   	num_headsr+   r   Zrelative_coords_hZrelative_coords_wr   Zcoords_hZcoords_wZcoordsZcoords_flattenZrelative_coordsr   rC   r   r    rA      sb    
"&((,.
..&,((,
zSwin2SRSelfAttention.__init__NFr   attention_mask	head_maskoutput_attentionsr7   c                 C   s"  |j \}}}| ||d| j| jdd}| ||d| j| jdd}	| ||d| j| jdd}
tj	j
|ddtj	j
|	dddd }tj| jtdd }|| }| | jd| j}|| jd | jd | jd  | jd | jd  d}|ddd }d	t| }||d }|d ur|j d }||| || j|||dd }||dd }|d| j||}tj	j|dd}| |}|d ur|| }t||
}|dddd
 }| d d | jf }||}|r||fn|f}|S )Nr$   r   r!   rs   g      Y@)maxr      r   )r&   r   r'   r   r   rl   r   r   r   rz   	normalizer   clampr   r   r   expr   r   r   r+   r(   r)   Zsigmoidr   ZsoftmaxrZ   matmulsizer   )rB   r   r   r   r   r,   rs   r/   Zquery_layerZ	key_layerZvalue_layerZattention_scoresr   Zrelative_position_bias_tableZrelative_position_biasZ
mask_shapeZattention_probsZcontext_layerZnew_context_layer_shapeoutputsr   r   r    rG   +  sl    


&




zSwin2SRSelfAttention.forward)NNF)r   r   r   rA   r   rM   r   r   boolr   rG   rO   r   r   rC   r    r      s   @   r   c                       s4   e Zd Z fddZejejejdddZ  ZS )Swin2SRSelfOutputc                    s*   t    t||| _t|j| _d S r?   )r@   rA   r   rw   denserX   r   rZ   rB   r[   rs   rC   r   r    rA   u  s    
zSwin2SRSelfOutput.__init__)r   input_tensorr7   c                 C   s   |  |}| |}|S r?   r   rZ   )rB   r   r   r   r   r    rG   z  s    

zSwin2SRSelfOutput.forwardr   r   r   rA   r   rM   rG   rO   r   r   rC   r    r   t  s   r   c                       sV   e Zd Zd fdd	Zdd Zdejeej eej ee	 e
ej dd	d
Z  ZS )Swin2SRAttentionr   c                    sL   t    t||||t|tjjr&|n||fd| _t||| _	t
 | _d S )Nr[   rs   r   r+   r   )r@   rA   r   ra   rb   rc   rd   rB   r   r<   setpruned_heads)rB   r[   rs   r   r+   r   rC   r   r    rA     s    
	zSwin2SRAttention.__init__c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   r   )lenr   rB   r   r   r   r   r   r   r   r<   r   r   union)rB   headsindexr   r   r    prune_heads  s    zSwin2SRAttention.prune_headsNFr   c                 C   s6   |  ||||}| |d |}|f|dd   }|S Nr   r   )rB   r<   )rB   r   r   r   r   Zself_outputsattention_outputr   r   r   r    rG     s    zSwin2SRAttention.forward)r   )NNF)r   r   r   rA   r   r   rM   r   r   r   r   rG   rO   r   r   rC   r    r     s      r   c                       s0   e Zd Z fddZejejdddZ  ZS )Swin2SRIntermediatec                    sH   t    t|t|j| | _t|jt	r<t
|j | _n|j| _d S r?   )r@   rA   r   rw   rn   	mlp_ratior   ra   Z
hidden_actrN   r   intermediate_act_fnr   rC   r   r    rA     s
    
zSwin2SRIntermediate.__init__rE   c                 C   s   |  |}| |}|S r?   )r   r   rF   r   r   r    rG     s    

zSwin2SRIntermediate.forwardr   r   r   rC   r    r     s   r   c                       s0   e Zd Z fddZejejdddZ  ZS )Swin2SROutputc                    s4   t    tt|j| || _t|j| _	d S r?   )
r@   rA   r   rw   rn   r   r   rX   rY   rZ   r   rC   r   r    rA     s    
zSwin2SROutput.__init__rE   c                 C   s   |  |}| |}|S r?   r   rF   r   r   r    rG     s    

zSwin2SROutput.forwardr   r   r   rC   r    r     s   r   c                       s   e Zd Zd fdd	Zeeeef eeef f dddZdd	 Zd
d Zde	j
eeef ee	j ee ee	j
e	j
f dddZ  ZS )Swin2SRLayerr3   r   c           	         s   t    || _| |j|jf||f\}}|d | _|d | _t|||| jt|tj	j
r^|n||fd| _tj||jd| _|dkrt|nt | _t||| _t||| _tj||jd| _d S )Nr   r   epsr3   )r@   rA   rr   _compute_window_shiftr+   
shift_sizer   ra   rb   rc   rd   	attentionr   rh   layer_norm_epslayernorm_beforer>   ZIdentityr=   r   intermediater   r<   layernorm_after)	rB   r[   rs   rr   r   drop_path_rater   r   r+   rC   r   r    rA     s*    


	zSwin2SRLayer.__init__rH   c                 C   s6   dd t | j|D }dd t | j||D }||fS )Nc                 S   s    g | ]\}}||kr|n|qS r   r   ).0rwr   r   r    
<listcomp>      z6Swin2SRLayer._compute_window_shift.<locals>.<listcomp>c                 S   s"   g | ]\}}}||krd n|qS )r   r   )r   r   r   sr   r   r    r     r   )ziprr   )rB   Ztarget_window_sizeZtarget_shift_sizer+   r   r   r   r    r     s    z"Swin2SRLayer._compute_window_shiftc              	   C   s  | j dkrtjd||df|d}td| j t| j | j  t| j  d f}td| j t| j | j  t| j  d f}d}|D ].}|D ]$}	||d d ||	d d f< |d7 }qqt|| j}
|
d| j| j }
|
d|
d }||dkd|dkd}nd }|S )Nr   r   r   r$   r!   g      Yr3   )	r   r   rU   slicer+   r1   r'   r   Zmasked_fill)rB   r-   r.   r8   Zimg_maskZheight_slicesZwidth_slicescountZheight_sliceZwidth_sliceZmask_windows	attn_maskr   r   r    get_attn_mask  s*    zSwin2SRLayer.get_attn_maskc                 C   sR   | j || j   | j  }| j || j   | j  }ddd|d|f}tj||}||fS )Nr   )r+   r   rz   r{   )rB   r   r-   r.   	pad_rightZ
pad_bottomr|   r   r   r    r}   	  s
    zSwin2SRLayer.maybe_padNFr   r~   r   r   r7   c                 C   s  |\}}|  \}}}	|}
|||||	}| |||\}}|j\}}}}| jdkrrtj|| j | j fdd}n|}t|| j}|d| j| j |	}| j	|||j
d}|d ur||j}| j||||d}|d }|d| j| j|	}t|| j||}| jdkr"tj|| j| jfdd}n|}|d dkp>|d dk}|rj|d d d |d |d d f  }|||| |	}| |}|
| | }| |}| |}|| | | }|r||d	 fn|f}|S )
Nr   )r   r!   )Zshiftsdimsr$   r   )r   r   r#   r   )r   r'   r}   r&   r   r   Zrollr1   r+   r   r8   r   r9   r   r2   r)   r   r=   r   r<   r   )rB   r   r~   r   r   r-   r.   r,   rm   ZchannelsZshortcutr|   Z
height_padZ	width_padZshifted_hidden_statesZhidden_states_windowsr   Zattention_outputsr   Zattention_windowsZshifted_windowsZ
was_paddedZlayer_outputlayer_outputsr   r   r    rG     sD    
$


zSwin2SRLayer.forward)r3   r   r   )NF)r   r   r   rA   r   rn   r   r   r}   r   rM   r   r   r   rG   rO   r   r   rC   r    r     s    &  
r   c                       sT   e Zd ZdZd
 fdd	Zdejeeef e	ej
 e	e eej ddd	Z  ZS )Swin2SRStagezh
    This corresponds to the Residual Swin Transformer Block (RSTB) in the original implementation.
    r   c                    s   t     | _| _t fddt|D | _ jdkr\t	ddd| _
nl jdkrtt	d dddtjdd	d
t	d d dddtjdd	d
t	d ddd| _
t dd| _t | _d S )Nc              
      s6   g | ].}t  |d  dkr"dn jd  dqS )r!   r   )r[   rs   rr   r   r   r   )r   r+   )r   ir[   rs   rr   r   r   r   r    r   U  s   	z)Swin2SRStage.__init__.<locals>.<listcomp>Z1convr   r   Z3convr"   皙?TZnegative_sloper   r   F)rj   )r@   rA   r[   rs   r   
ModuleListrangelayersZresi_connectionrf   convr   	LeakyReLUrQ   patch_embedro   patch_unembed)rB   r[   rs   rr   depthr   r=   r   rC   r   r    rA   P  s(    
	

zSwin2SRStage.__init__NFr   c                 C   s   |}|\}}t | jD ]2\}}	|d ur.|| nd }
|	|||
|}|d }q||||f}| ||}| |}| |\}}|| }||f}|r||dd  7 }|S r   )	enumerater   r   r   r   )rB   r   r~   r   r   Zresidualr-   r.   r   Zlayer_modulelayer_head_maskr   r^   rm   Zstage_outputsr   r   r    rG   r  s    

zSwin2SRStage.forward)r   )NF)r   r   r   rK   rA   r   rM   r   rn   r   r   r   rG   rO   r   r   rC   r    r   K  s   &  
r   c                
       s\   e Zd Z fddZd	ejeeef eej	 ee
 ee
 ee
 eeef dddZ  ZS )
Swin2SREncoderc                    sn   t    t j| _ | _dd tjd jt	 jddD t
 fddt| jD | _d| _d S )Nc                 S   s   g | ]}|  qS r   )item)r   xr   r   r    r     r   z+Swin2SREncoder.__init__.<locals>.<listcomp>r   cpu)r9   c                    sd   g | ]\}t   jd  d f j|  j| t jd| t jd|d   d dqS )r   r   N)r[   rs   rr   r   r   r=   r   )r   rV   depthsr   r   )r   Z	stage_idxr[   Zdpr	grid_sizer   r    r     s   
*F)r@   rA   r   r   Z
num_stagesr[   r   Zlinspacer   r   r   r   r   stagesZgradient_checkpointing)rB   r[   r   rC   r   r    rA     s    
$
zSwin2SREncoder.__init__NFT)r   r~   r   r   output_hidden_statesreturn_dictr7   c                 C   s   d}|rdnd }|rdnd }	|r*||f7 }t | jD ]v\}
}|d urL||
 nd }|||||}|d }|d }|d |d f}||f7 }|r||f7 }|r4|	|dd  7 }	q4|stdd |||	fD S t|||	d	S )
Nr   r   r   r   r$   r!   c                 s   s   | ]}|d ur|V  qd S r?   r   )r   vr   r   r    	<genexpr>  r   z)Swin2SREncoder.forward.<locals>.<genexpr>r   r   r   )r   r   r   r   )rB   r   r~   r   r   r   r   Zall_input_dimensionsZall_hidden_statesZall_self_attentionsr   Zstage_moduler   r   r^   r   r   r    rG     s.    	


zSwin2SREncoder.forward)NFFT)r   r   r   rA   r   rM   r   rn   r   r   r   r   r   rG   rO   r   r   rC   r    r     s       

r   c                   @   s*   e Zd ZU eed< dZdZdZdd ZdS )Swin2SRPreTrainedModelr[   swin2srr\   Tc                 C   sn   t |tjtjfrDtjjj|jj| j	j
d |jdurj|jj  n&t |tjrj|jj  |jjd dS )zInitialize the weights)ZstdNr   )ra   r   rw   rf   r   initZtrunc_normal_weightdatar[   Zinitializer_rangerv   Zzero_rh   Zfill_)rB   moduler   r   r    _init_weights  s    
z$Swin2SRPreTrainedModel._init_weightsN)	r   r   r   r   r   Zbase_model_prefixZmain_input_nameZsupports_gradient_checkpointingr  r   r   r   r    r     s
   
r   c                
       sn   e Zd Z fddZdd Zdd Zdd Zedej	e
ej	 e
e e
e e
e eeef d
ddZ  ZS )Swin2SRModelc                    s   t  | || _|jdkrB|jdkrBtg ddddd}ntdddd}| j	d|dd |j
| _
t|j|jddd| _t|| _t|| jjjd| _tj|j|jd| _t|| _t|j|jddd| _|   d S )	Nr   )gw#?g8EGr?gB`"?r   meanFr   )r   r   )r@   rA   r[   r/   num_channels_outr   Ztensorr'   rU   r   	img_ranger   rf   rV   first_convolutionrP   r]   r   rR   re   encoderrh   r   ri   ro   r   conv_after_body	post_init)rB   r[   r  rC   r   r    rA     s    

zSwin2SRModel.__init__c                 C   s   | j jS r?   )r]   rR   rI   r   r   r    get_input_embeddings  s    z!Swin2SRModel.get_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr  layerr   r   )rB   Zheads_to_pruner  r   r   r   r    _prune_heads  s    zSwin2SRModel._prune_headsc           	      C   sn   |  \}}}}| jj}|||  | }|||  | }tj|d|d|fd}| j|}|| | j }|S )Nr   Zreflect)	r   r[   r+   r   rz   r{   r  Ztype_asr
  )	rB   r\   rm   r-   r.   r+   Zmodulo_pad_heightZmodulo_pad_widthr  r   r   r    pad_and_normalize
  s    zSwin2SRModel.pad_and_normalizeN)r\   r   r   r   r   r7   c                 C   s   |d ur|n| j j}|d ur |n| j j}|d ur4|n| j j}| |t| j j}|j\}}}}| |}| 	|}	| 
|	\}
}| j|
|||||d}|d }| |}| |||f}| ||	 }|s|f|dd   }|S t||j|jdS )Nr   r   r   r   r   r   r   )r[   r   r   use_return_dictZget_head_maskr   r   r&   r  r  r]   r  ri   r   r  r	   r   r   )rB   r\   r   r   r   r   rm   r-   r.   r]   Zembedding_outputr~   Zencoder_outputssequence_outputr<   r   r   r    rG     s:    	

	
zSwin2SRModel.forward)NNNN)r   r   r   rA   r  r  r  r   r   r   r   r   r   r   r	   rG   rO   r   r   rC   r    r    s"       
r  c                       s(   e Zd ZdZ fddZdd Z  ZS )UpsamplezUpsample module.

    Args:
        scale (`int`):
            Scale factor. Supported scales: 2^n and 3.
        num_features (`int`):
            Channel number of intermediate features.
    c                    s   t    || _||d @ dkrxttt|dD ]@}| d| t	|d| ddd | d| t
d q4n>|dkrt	|d| ddd| _t
d| _ntd	| d
d S )Nr   r   r!   convolution_r"   r   pixelshuffle_	   zScale z/ is not supported. Supported scales: 2^n and 3.)r@   rA   scaler   rn   r   r   Z
add_moduler   rf   PixelShuffleconvolutionpixelshuffler   )rB   r  num_featuresr   rC   r   r    rA   \  s    
$zUpsample.__init__c                 C   s|   | j | j d @ dkrZttt| j dD ],}| d| |}| d| |}q*n| j dkrx| |}| |}|S )Nr   r   r!   r  r  r   )r  r   rn   r   r   __getattr__r  r  )rB   Zhidden_stater   r   r   r    rG   k  s    


zUpsample.forwardrp   r   r   rC   r    r  R  s   	r  c                       s(   e Zd ZdZ fddZdd Z  ZS )UpsampleOneStepa  UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)

    Used in lightweight SR to save parameters.

    Args:
        scale (int):
            Scale factor. Supported scales: 2^n and 3.
        in_channels (int):
            Channel number of intermediate features.
        out_channels (int):
            Channel number of output features.
    c                    s6   t    t||d | ddd| _t|| _d S )Nr!   r   r   )r@   rA   r   rf   r   r  pixel_shuffle)rB   r  Zin_channelsZout_channelsrC   r   r    rA     s    
zUpsampleOneStep.__init__c                 C   s   |  |}| |}|S r?   )r   r"  )rB   r   r   r   r    rG     s    

zUpsampleOneStep.forwardrp   r   r   rC   r    r!  x  s   r!  c                       s$   e Zd Z fddZdd Z  ZS )PixelShuffleUpsamplerc                    sV   t    t|j|ddd| _tjdd| _t|j	|| _
t||jddd| _d S Nr   r   Tr   )r@   rA   r   rf   rV   conv_before_upsampler   
activationr  upscaleupsampler	  final_convolutionrB   r[   r  rC   r   r    rA     s
    
zPixelShuffleUpsampler.__init__c                 C   s,   |  |}| |}| |}| |}|S r?   )r%  r&  r(  r)  )rB   r  r   r   r   r    rG     s
    



zPixelShuffleUpsampler.forwardr   r   r   rA   rG   rO   r   r   rC   r    r#    s   r#  c                       s$   e Zd Z fddZdd Z  ZS )NearestConvUpsamplerc                    s   t    |jdkrtdt|j|ddd| _tjdd| _	t||ddd| _
t||ddd| _t||ddd| _t||jddd| _tjddd| _d S )	Nr"   zNThe nearest+conv upsampler only supports an upscale factor of 4 at the moment.r   r   Tr   r   r   )r@   rA   r'  r   r   rf   rV   r%  r   r&  conv_up1conv_up2conv_hrr	  r)  lrelur*  rC   r   r    rA     s    

zNearestConvUpsampler.__init__c              	   C   sn   |  |}| |}| | tjjj|ddd}| | tjjj|ddd}| 	| | 
|}|S )Nr!   Znearest)Zscale_factormode)r%  r&  r0  r-  r   r   rz   interpolater.  r)  r/  )rB   r  reconstructionr   r   r    rG     s    

zNearestConvUpsampler.forwardr+  r   r   rC   r    r,    s   r,  c                       s$   e Zd Z fddZdd Z  ZS )PixelShuffleAuxUpsamplerc              	      s   t    |j| _t|j|ddd| _t|j|ddd| _tj	dd| _
t||jddd| _ttd|dddtj	dd| _t|j|| _t||jddd| _d S r$  )r@   rA   r'  r   rf   r/   conv_bicubicrV   r%  r   r&  conv_auxr   conv_after_auxr  r(  r	  r)  r*  rC   r   r    rA     s    
$z!PixelShuffleAuxUpsampler.__init__c                 C   s   |  |}| |}| |}| |}| |}| |d d d d d || j d || j f |d d d d d || j d || j f  }| |}||fS r?   )r5  r%  r&  r6  r7  r(  r'  r)  )rB   r  bicubicr-   r.   auxr3  r   r   r    rG     s    




0*
z PixelShuffleAuxUpsampler.forwardr+  r   r   rC   r    r4    s   r4  zm
    Swin2SR Model transformer with an upsampler head on top for image super resolution and restoration.
    c                       sb   e Zd Z fddZedeej eej eej ee	 ee	 ee	 e
eef dddZ  ZS )Swin2SRForImageSuperResolutionc                    s   t  | t|| _|j| _|j| _d}| jdkrBt||| _nh| jdkrZt||| _nP| jdkrzt	|j|j
|j| _n0| jdkrt||| _nt|j
|jddd| _|   d S )N@   r  pixelshuffle_auxpixelshuffledirectnearest+convr   r   )r@   rA   r  r  	upsamplerr'  r#  r(  r4  r!  rV   r	  r,  r   rf   r)  r  r*  rC   r   r    rA     s    




z'Swin2SRForImageSuperResolution.__init__N)r\   r   labelsr   r   r   r7   c                 C   sb  |dur|n| j j}d}|dur(td|jdd \}}	| j jdkrjtjj||| j |	| j fddd}
| j	|||||d}|d	 }| jd
v r| 
|}nB| jdkr| 
||
||	\}}|| j	j | j	j }n|| | }|| j	j | j	j }|ddddd|| j d|	| j f }|sN|f|dd  }|durJ|f| S |S t|||j|jdS )a  
        Example:
         ```python
         >>> import torch
         >>> import numpy as np
         >>> from PIL import Image
         >>> import requests

         >>> from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution

         >>> processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
         >>> model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x2-64")

         >>> url = "https://huggingface.co/spaces/jjourney1125/swin2sr/resolve/main/samples/butterfly.jpg"
         >>> image = Image.open(requests.get(url, stream=True).raw)
         >>> # prepare image for the model
         >>> inputs = processor(image, return_tensors="pt")

         >>> # forward pass
         >>> with torch.no_grad():
         ...     outputs = model(**inputs)

         >>> output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
         >>> output = np.moveaxis(output, source=0, destination=-1)
         >>> output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8
         >>> # you can visualize `output` with `Image.fromarray`
         ```Nz'Training is not supported at the momentr!   r<  r8  F)r   r1  Zalign_cornersr  r   )r  r=  r>  r   )lossr3  r   r   )r[   r  NotImplementedErrorr&   r?  r   rz   r2  r'  r  r(  r
  r  r)  r
   r   r   )rB   r\   r   r@  r   r   r   rA  r-   r.   r8  r   r  r3  r9  r<   r   r   r    rG     sJ    %

,z&Swin2SRForImageSuperResolution.forward)NNNNNN)r   r   r   rA   r   r   r   r   Z
LongTensorr   r   r   r
   rG   rO   r   r   rC   r    r:    s"         
r:  )r:  r  r   )r3   F)@rK   collections.abcrb   r   dataclassesr   typingr   r   r   Ztorch.utils.checkpointr   Zactivationsr   Zmodeling_layersr   Zmodeling_outputsr	   r
   Zmodeling_utilsr   Zpytorch_utilsr   r   r   utilsr   r   r   Zconfiguration_swin2srr   Z
get_loggerr   loggerr   r1   r2   rM   rL   r   r=   r   r>   rP   rQ   ro   rq   r   r   r   r   r   r   r   r   r   r  r  r!  r#  r,  r4  r:  __all__r   r   r   r    <module>   sf   
7 /}GBk&q