a
    h{                     @   s  d Z ddlZddlmZ ddlmZmZ ddlZddlZddlm	Z	 ddl
mZmZmZ ddlmZ dd	lmZ dd
lmZ ddlmZmZ ddlmZmZmZmZmZmZ ddlmZ ddl m!Z! e rddl"m#Z#m$Z$ ndd Z$dd Z#e%e&Z'eeddG dd deZ(eeddG dd deZ)eeddG dd deZ*G dd  d e	j+Z,G d!d" d"e	j+Z-G d#d$ d$e	j+Z.dGej/e0e1ej/d'd(d)Z2G d*d+ d+e	j+Z3G d,d- d-e	j+Z4G d.d/ d/e	j+Z5G d0d1 d1e	j+Z6G d2d3 d3e	j+Z7G d4d5 d5e	j+Z8G d6d7 d7e	j+Z9G d8d9 d9e	j+Z:G d:d; d;e	j+Z;eG d<d= d=eZ<eG d>d? d?e<Z=ed@dG dAdB dBe<Z>edCdG dDdE dEe<eZ?g dFZ@dS )Hz9PyTorch Dilated Neighborhood Attention Transformer model.    N)	dataclass)OptionalUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)BackboneOutput)PreTrainedModel) find_pruneable_heads_and_indicesprune_linear_layer)ModelOutputOptionalDependencyNotAvailableauto_docstringis_natten_availableloggingrequires_backends)BackboneMixin   )DinatConfig)
natten2davnatten2dqkrpbc                  O   s
   t  d S Nr   argskwargs r   d/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/dinat/modeling_dinat.pyr   .   s    r   c                  O   s
   t  d S r   r   r   r   r   r    r   1   s    r   zO
    Dinat encoder's outputs, with potential hidden states and attentions.
    )Zcustom_introc                   @   sr   e Zd ZU dZdZeej ed< dZ	ee
ejdf  ed< dZee
ejdf  ed< dZee
ejdf  ed< dS )DinatEncoderOutputa  
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    Nlast_hidden_state.hidden_states
attentionsreshaped_hidden_states)__name__
__module____qualname____doc__r"   r   torchFloatTensor__annotations__r#   tupler$   r%   r   r   r   r    r!   ;   s
   
	r!   zW
    Dinat model's outputs that also contains a pooling of the last hidden states.
    c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeejdf  ed< dZeeejdf  ed< dZeeejdf  ed< dS )	DinatModelOutputa  
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
        Average pooling of the last layer hidden-state.
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    Nr"   pooler_output.r#   r$   r%   )r&   r'   r(   r)   r"   r   r*   r+   r,   r/   r#   r-   r$   r%   r   r   r   r    r.   Q   s   
r.   z1
    Dinat outputs for image classification.
    c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeejdf  ed< dZeeejdf  ed< dZeeejdf  ed< dS )	DinatImageClassifierOutputa7  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification (or regression if config.num_labels==1) loss.
    logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Classification (or regression if config.num_labels==1) scores (before SoftMax).
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    Nlosslogits.r#   r$   r%   )r&   r'   r(   r)   r1   r   r*   r+   r,   r2   r#   r-   r$   r%   r   r   r   r    r0   j   s   
r0   c                       s<   e Zd ZdZ fddZeej eej	 dddZ
  ZS )DinatEmbeddingsz6
    Construct the patch and position embeddings.
    c                    s4   t    t|| _t|j| _t|j	| _
d S r   )super__init__DinatPatchEmbeddingspatch_embeddingsr   	LayerNorm	embed_dimnormDropouthidden_dropout_probdropoutselfconfig	__class__r   r    r5      s    

zDinatEmbeddings.__init__pixel_valuesreturnc                 C   s"   |  |}| |}| |}|S r   )r7   r:   r=   )r?   rD   
embeddingsr   r   r    forward   s    


zDinatEmbeddings.forward)r&   r'   r(   r)   r5   r   r*   r+   r-   TensorrG   __classcell__r   r   rA   r    r3      s   r3   c                       s8   e Zd ZdZ fddZeej ejdddZ	  Z
S )r6   z
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, height, width, hidden_size)` to be consumed by a
    Transformer.
    c              
      sr   t    |j}|j|j }}|| _|dkr.ntdttj| j|d ddddtj|d |dddd| _	d S )N   z2Dinat only supports patch size of 4 at the moment.   r	   r	   rK   rK   r   r   )kernel_sizestridepadding)
r4   r5   
patch_sizenum_channelsr9   
ValueErrorr   Z
SequentialConv2d
projection)r?   r@   rR   rS   Zhidden_sizerA   r   r    r5      s    
zDinatPatchEmbeddings.__init__rC   c                 C   s>   |j \}}}}|| jkr td| |}|dddd}|S )NzeMake sure that the channel dimension of the pixel values match with the one set in the configuration.r   rK   r	   r   )shaperS   rT   rV   permute)r?   rD   _rS   heightwidthrF   r   r   r    rG      s    

zDinatPatchEmbeddings.forward)r&   r'   r(   r)   r5   r   r*   r+   rH   rG   rI   r   r   rA   r    r6      s   r6   c                       sF   e Zd ZdZejfeejdd fddZe	j
e	j
dddZ  ZS )	DinatDownsamplerz
    Convolutional Downsampling Layer.

    Args:
        dim (`int`):
            Number of input channels.
        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
            Normalization layer class.
    N)dim
norm_layerrE   c                    s>   t    || _tj|d| ddddd| _|d| | _d S )NrK   rL   rM   rN   F)rO   rP   rQ   bias)r4   r5   r]   r   rU   	reductionr:   )r?   r]   r^   rA   r   r    r5      s    
zDinatDownsampler.__init__)input_featurerE   c                 C   s0   |  |dddddddd}| |}|S )Nr   r	   r   rK   )r`   rX   r:   )r?   ra   r   r   r    rG      s    "
zDinatDownsampler.forward)r&   r'   r(   r)   r   r8   intModuler5   r*   rH   rG   rI   r   r   rA   r    r\      s   
r\           F)input	drop_probtrainingrE   c                 C   sd   |dks|s| S d| }| j d fd| jd   }|tj|| j| jd }|  | || }|S )aF  
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    rd   r   r   )r   )dtypedevice)rW   ndimr*   Zrandrh   ri   Zfloor_div)re   rf   rg   Z	keep_probrW   Zrandom_tensoroutputr   r   r    	drop_path   s    
rm   c                       sP   e Zd ZdZdee dd fddZejejdddZ	e
d	d
dZ  ZS )DinatDropPathzXDrop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).N)rf   rE   c                    s   t    || _d S r   )r4   r5   rf   )r?   rf   rA   r   r    r5      s    
zDinatDropPath.__init__r#   rE   c                 C   s   t || j| jS r   )rm   rf   rg   r?   r#   r   r   r    rG      s    zDinatDropPath.forward)rE   c                 C   s   d| j  S )Nzp=)rf   r?   r   r   r    
extra_repr   s    zDinatDropPath.extra_repr)N)r&   r'   r(   r)   r   floatr5   r*   rH   rG   strrr   rI   r   r   rA   r    rn      s   rn   c                       s<   e Zd Z fddZdejee eej dddZ	  Z
S )NeighborhoodAttentionc                    s   t    || dkr,td| d| d|| _t|| | _| j| j | _|| _|| _t	
t|d| j d d| j d | _t	j| j| j|jd| _t	j| j| j|jd| _t	j| j| j|jd| _t	|j| _d S )Nr   zThe hidden size (z6) is not a multiple of the number of attention heads ()rK   r   )r_   )r4   r5   rT   num_attention_headsrb   attention_head_sizeall_head_sizerO   dilationr   	Parameterr*   ZzerosrpbLinearZqkv_biasquerykeyvaluer;   attention_probs_dropout_probr=   r?   r@   r]   	num_headsrO   rz   rA   r   r    r5      s    
*zNeighborhoodAttention.__init__Fr#   output_attentionsrE   c                 C   s  |j \}}}| ||d| j| jdd}| ||d| j| jdd}| ||d| j| jdd}|t	| j }t
||| j| j| j}	tjj|	dd}
| |
}
t|
|| j| j}|ddddd }| d d | jf }||}|r
||
fn|f}|S )	Nr   rK   r]   r   r	   rJ   )rW   r~   viewrw   rx   	transposer   r   mathsqrtr   r|   rO   rz   r   
functionalZsoftmaxr=   r   rX   
contiguoussizery   )r?   r#   r   
batch_sizeZ
seq_lengthrY   Zquery_layerZ	key_layerZvalue_layerZattention_scoresZattention_probsZcontext_layerZnew_context_layer_shapeoutputsr   r   r    rG     s:    


	

zNeighborhoodAttention.forward)Fr&   r'   r(   r5   r*   rH   r   boolr-   rG   rI   r   r   rA   r    ru      s    ru   c                       s4   e Zd Z fddZejejejdddZ  ZS )NeighborhoodAttentionOutputc                    s*   t    t||| _t|j| _d S r   )r4   r5   r   r}   denser;   r   r=   r?   r@   r]   rA   r   r    r5   A  s    
z$NeighborhoodAttentionOutput.__init__)r#   input_tensorrE   c                 C   s   |  |}| |}|S r   r   r=   )r?   r#   r   r   r   r    rG   F  s    

z#NeighborhoodAttentionOutput.forwardr&   r'   r(   r5   r*   rH   rG   rI   r   r   rA   r    r   @  s   r   c                       sD   e Zd Z fddZdd Zd	ejee e	ej dddZ
  ZS )
NeighborhoodAttentionModulec                    s4   t    t|||||| _t||| _t | _d S r   )r4   r5   ru   r?   r   rl   setpruned_headsr   rA   r   r    r5   N  s    
z$NeighborhoodAttentionModule.__init__c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   r   )lenr   r?   rw   rx   r   r   r~   r   r   rl   r   ry   union)r?   headsindexr   r   r    prune_headsT  s    z'NeighborhoodAttentionModule.prune_headsFr   c                 C   s2   |  ||}| |d |}|f|dd   }|S Nr   r   )r?   rl   )r?   r#   r   Zself_outputsattention_outputr   r   r   r    rG   f  s    z#NeighborhoodAttentionModule.forward)F)r&   r'   r(   r5   r   r*   rH   r   r   r-   rG   rI   r   r   rA   r    r   M  s    r   c                       s0   e Zd Z fddZejejdddZ  ZS )DinatIntermediatec                    sH   t    t|t|j| | _t|jt	r<t
|j | _n|j| _d S r   )r4   r5   r   r}   rb   	mlp_ratior   
isinstanceZ
hidden_actrt   r
   intermediate_act_fnr   rA   r   r    r5   r  s
    
zDinatIntermediate.__init__ro   c                 C   s   |  |}| |}|S r   )r   r   rp   r   r   r    rG   z  s    

zDinatIntermediate.forwardr   r   r   rA   r    r   q  s   r   c                       s0   e Zd Z fddZejejdddZ  ZS )DinatOutputc                    s4   t    tt|j| || _t|j| _	d S r   )
r4   r5   r   r}   rb   r   r   r;   r<   r=   r   rA   r   r    r5     s    
zDinatOutput.__init__ro   c                 C   s   |  |}| |}|S r   r   rp   r   r   r    rG     s    

zDinatOutput.forwardr   r   r   rA   r    r     s   r   c                       sL   e Zd Zd
 fdd	Zdd Zdejee e	ejejf ddd	Z
  ZS )
DinatLayerrd   c                    s   t    |j| _|j| _|| _| j| j | _tj||jd| _	t
|||| j| jd| _|dkrht|nt | _tj||jd| _t||| _t||| _|jdkrtj|jtd|f ddnd | _d S )Neps)rO   rz   rd   r   rK   T)Zrequires_grad)r4   r5   Zchunk_size_feed_forwardrO   rz   window_sizer   r8   layer_norm_epslayernorm_beforer   	attentionrn   Identityrm   layernorm_afterr   intermediater   rl   Zlayer_scale_init_valuer{   r*   Zoneslayer_scale_parameters)r?   r@   r]   r   rz   drop_path_raterA   r   r    r5     s"    
 zDinatLayer.__init__c           
      C   sd   | j }d}||k s||k r\d }}td|| }td|| }	dd||||	f}tj||}||fS )N)r   r   r   r   r   r   r   )r   maxr   r   pad)
r?   r#   rZ   r[   r   
pad_valuesZpad_lZpad_tZpad_rZpad_br   r   r    	maybe_pad  s    zDinatLayer.maybe_padFr   c                 C   s  |  \}}}}|}| |}| |||\}}|j\}	}
}}	| j||d}|d }|d dkpj|d dk}|r|d d d |d |d d f  }| jd ur| jd | }|| | }| |}| 	| 
|}| jd ur| jd | }|| | }|r||d fn|f}|S )N)r   r   r	      r   )r   r   r   rW   r   r   r   rm   r   rl   r   )r?   r#   r   r   rZ   r[   channelsZshortcutr   rY   Z
height_padZ	width_padZattention_outputsr   Z
was_paddedZlayer_outputlayer_outputsr   r   r    rG     s(    
$


zDinatLayer.forward)rd   )F)r&   r'   r(   r5   r   r*   rH   r   r   r-   rG   rI   r   r   rA   r    r     s    r   c                       s<   e Zd Z fddZdejee eej dddZ	  Z
S )
DinatStagec                    sf   t     | _| _t fddt|D | _|d urV|tjd| _	nd | _	d| _
d S )Nc              	      s&   g | ]}t  | | d qS ))r@   r]   r   rz   r   )r   .0ir@   	dilationsr]   r   r   r   r    
<listcomp>  s   z'DinatStage.__init__.<locals>.<listcomp>)r]   r^   F)r4   r5   r@   r]   r   
ModuleListrangelayersr8   
downsampleZpointing)r?   r@   r]   depthr   r   r   r   rA   r   r    r5     s    
zDinatStage.__init__Fr   c                 C   sn   |  \}}}}t| jD ]\}}|||}|d }q|}	| jd urN| |	}||	f}
|rj|
|dd  7 }
|
S r   )r   	enumerater   r   )r?   r#   r   rY   rZ   r[   r   layer_moduler   !hidden_states_before_downsamplingZstage_outputsr   r   r    rG     s    



zDinatStage.forward)Fr   r   r   rA   r    r     s    r   c                	       sP   e Zd Z fddZdejee ee ee ee ee	e
f dddZ  ZS )	DinatEncoderc                    sh   t    t j_ _dd tjd jt	 jddD t
 fddtjD _d S )Nc                 S   s   g | ]}|  qS r   )item)r   xr   r   r    r         z)DinatEncoder.__init__.<locals>.<listcomp>r   cpu)ri   c                    s|   g | ]t}t  t jd |   j|  j|  j| t jd| t jd|d   |jd k rptnddqS )rK   Nr   )r@   r]   r   r   r   r   r   )	r   rb   r9   depthsr   r   sum
num_levelsr\   )r   Zi_layerr@   Zdprr?   r   r    r   	  s   
*)r4   r5   r   r   r   r@   r*   Zlinspacer   r   r   r   r   levelsr>   rA   r   r    r5     s    
$
zDinatEncoder.__init__FT)r#   r   output_hidden_states(output_hidden_states_before_downsamplingreturn_dictrE   c                 C   s  |rdnd }|rdnd }|r dnd }|rL| dddd}	||f7 }||	f7 }t| jD ]\}
}|||}|d }|d }|r|r| dddd}	||f7 }||	f7 }n,|r|s| dddd}	||f7 }||	f7 }|rV||dd  7 }qV|stdd |||fD S t||||dS )	Nr   r   r	   r   rK   c                 s   s   | ]}|d ur|V  qd S r   r   )r   vr   r   r    	<genexpr>>  r   z'DinatEncoder.forward.<locals>.<genexpr>)r"   r#   r$   r%   )rX   r   r   r-   r!   )r?   r#   r   r   r   r   Zall_hidden_statesZall_reshaped_hidden_statesZall_self_attentionsZreshaped_hidden_stater   r   r   r   r   r   r    rG     s:    





zDinatEncoder.forward)FFFT)r&   r'   r(   r5   r*   rH   r   r   r   r-   r!   rG   rI   r   r   rA   r    r     s       
r   c                   @   s&   e Zd ZU eed< dZdZdd ZdS )DinatPreTrainedModelr@   dinatrD   c                 C   sj   t |tjtjfr@|jjjd| jjd |j	durf|j	j
  n&t |tjrf|j	j
  |jjd dS )zInitialize the weightsrd   )meanZstdNg      ?)r   r   r}   rU   weightdataZnormal_r@   Zinitializer_ranger_   Zzero_r8   Zfill_)r?   moduler   r   r    _init_weightsN  s    
z"DinatPreTrainedModel._init_weightsN)r&   r'   r(   r   r,   Zbase_model_prefixZmain_input_namer   r   r   r   r    r   H  s   
r   c                	       sd   e Zd Zd fdd	Zdd Zdd Zedeej	 ee
 ee
 ee
 eeef d	d
dZ  ZS )
DinatModelTc                    s   t  | t| dg || _t|j| _t|jd| jd   | _	t
|| _t|| _tj| j	|jd| _|rztdnd| _|   dS )zv
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        nattenrK   r   r   N)r4   r5   r   r@   r   r   r   rb   r9   num_featuresr3   rF   r   encoderr   r8   r   	layernormZAdaptiveAvgPool1dpooler	post_init)r?   r@   Zadd_pooling_layerrA   r   r    r5   ]  s    

zDinatModel.__init__c                 C   s   | j jS r   rF   r7   rq   r   r   r    get_input_embeddingss  s    zDinatModel.get_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr   layerr   r   )r?   Zheads_to_pruner   r   r   r   r    _prune_headsv  s    zDinatModel._prune_headsN)rD   r   r   r   rE   c           
      C   s   |d ur|n| j j}|d ur |n| j j}|d ur4|n| j j}|d u rLtd| |}| j||||d}|d }| |}d }| jd ur| |	dd
dd}t	|d}|s||f|dd   }	|	S t|||j|j|jdS )Nz You have to specify pixel_valuesr   r   r   r   r   rK   )r"   r/   r#   r$   r%   )r@   r   r   use_return_dictrT   rF   r   r   r   flattenr   r*   r.   r#   r$   r%   )
r?   rD   r   r   r   embedding_outputZencoder_outputsZsequence_outputpooled_outputrl   r   r   r    rG   ~  s:    


zDinatModel.forward)T)NNNN)r&   r'   r(   r5   r   r   r   r   r*   r+   r   r   r-   r.   rG   rI   r   r   rA   r    r   [  s       
r   z
    Dinat Model transformer with an image classification head on top (a linear layer on top of the final hidden state
    of the [CLS] token) e.g. for ImageNet.
    c                
       sZ   e Zd Z fddZedeej eej ee	 ee	 ee	 e
eef dddZ  ZS )DinatForImageClassificationc                    s\   t  | t| dg |j| _t|| _|jdkrFt| jj|jnt	 | _
|   d S )Nr   r   )r4   r5   r   
num_labelsr   r   r   r}   r   r   
classifierr   r>   rA   r   r    r5     s    
"z$DinatForImageClassification.__init__N)rD   labelsr   r   r   rE   c                 C   sl  |dur|n| j j}| j||||d}|d }| |}d}	|dur$| j jdu r| jdkrfd| j _n4| jdkr|jtjks|jtj	krd| j _nd| j _| j jdkrt
 }
| jdkr|
| | }	n
|
||}	nN| j jdkrt }
|
|d| j|d}	n| j jdkr$t }
|
||}	|sT|f|dd  }|	durP|	f| S |S t|	||j|j|jd	S )
a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Nr   r   Z
regressionZsingle_label_classificationZmulti_label_classificationr   rK   )r1   r2   r#   r$   r%   )r@   r   r   r   Zproblem_typer   rh   r*   longrb   r   Zsqueezer   r   r   r0   r#   r$   r%   )r?   rD   r   r   r   r   r   r   r2   r1   Zloss_fctrl   r   r   r    rG     sL    



"


z#DinatForImageClassification.forward)NNNNN)r&   r'   r(   r5   r   r   r*   r+   Z
LongTensorr   r   r-   r0   rG   rI   r   r   rA   r    r     s        
r   zL
    NAT backbone, to be used with frameworks like DETR and MaskFormer.
    c                       sN   e Zd Z fddZdd Zed	ejee	 ee	 ee	 e
dddZ  ZS )
DinatBackbonec                    s   t    t    t| dg t | _t | _ jg fddt	t
 jD  | _i }t| j| jD ]\}}t|||< qpt|| _|   d S )Nr   c                    s   g | ]}t  jd |  qS )rK   )rb   r9   r   r@   r   r    r     r   z*DinatBackbone.__init__.<locals>.<listcomp>)r4   r5   Z_init_backboner   r3   rF   r   r   r9   r   r   r   r   zipZ_out_featuresr   r   r8   Z
ModuleDicthidden_states_normsr   )r?   r@   r   stagerS   rA   r   r    r5     s    

&zDinatBackbone.__init__c                 C   s   | j jS r   r   rq   r   r   r    r     s    z"DinatBackbone.get_input_embeddingsN)rD   r   r   r   rE   c                 C   s2  |dur|n| j j}|dur |n| j j}|dur4|n| j j}| |}| j||dddd}|j}d}t| j|D ]\}	}
|	| j	v rp|
j
\}}}}|
dddd }
|
||| |}
| j|	 |
}
|
||||}
|
dddd }
||
f7 }qp|s|f}|r||jf7 }|S t||r&|jnd|jd	S )
a/  
        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("shi-labs/nat-mini-in1k-224")
        >>> model = AutoBackbone.from_pretrained(
        ...     "shi-labs/nat-mini-in1k-224", out_features=["stage1", "stage2", "stage3", "stage4"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")

        >>> outputs = model(**inputs)

        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 512, 7, 7]
        ```NT)r   r   r   r   r   r   rK   r	   r   )feature_mapsr#   r$   )r@   r   r   r   rF   r   r%   r   Zstage_namesZout_featuresrW   rX   r   r   r   r#   r   r$   )r?   rD   r   r   r   r   r   r#   r   r   Zhidden_stater   rS   rZ   r[   rl   r   r   r    rG   !  sB    !

zDinatBackbone.forward)NNN)r&   r'   r(   r5   r   r   r*   rH   r   r   r   rG   rI   r   r   rA   r    r     s      r   )r   r   r   r   )rd   F)Ar)   r   dataclassesr   typingr   r   r*   Ztorch.utils.checkpointr   Ztorch.nnr   r   r   Zactivationsr
   Zmodeling_outputsr   Zmodeling_utilsr   Zpytorch_utilsr   r   utilsr   r   r   r   r   r   Zutils.backbone_utilsr   Zconfiguration_dinatr   Znatten.functionalr   r   Z
get_loggerr&   loggerr!   r.   r0   rc   r3   r6   r\   rH   rs   r   rm   rn   ru   r   r   r   r   r   r   r   r   r   r   r   __all__r   r   r   r    <module>   sx    
$F$G/FRQb