a
    h                     @   s  d Z ddlZddlZddlZddlmZ ddlmZm	Z	 ddl
Z
ddlZ
ddl
mZmZ ddlmZmZmZ ddlmZ dd	lmZ dd
lmZmZmZmZmZmZ ddlmZ ddlm Z m!Z!m"Z" ddl#m$Z$m%Z%m&Z& ddl'm(Z( ddl)m*Z* e%+e,Z-ee$ddG dd deZ.dOe
je/e0e
jdddZ1G dd dej2Z3G dd dej2Z4G dd dej2Z5G d d! d!ej2Z6G d"d# d#e6Z7G d$d% d%ej2Z8e6e7d&Z9G d'd( d(ej2Z:G d)d* d*ej2Z;G d+d, d,ej2Z<G d-d. d.eZ=G d/d0 d0ej2Z>G d1d2 d2ej2Z?e$G d3d4 d4eZ@e$G d5d6 d6e@ZAG d7d8 d8ej2ZBe$d9dG d:d; d;e@ZCe$d<dG d=d> d>e@ZDG d?d@ d@ej2ZEG dAdB dBej2ZFG dCdD dDej2ZGG dEdF dFej2ZHG dGdH dHej2ZIe$G dIdJ dJe@ZJe$dKdG dLdM dMe@e(ZKg dNZLdS )PzPyTorch BEiT model.    N)	dataclass)OptionalUnion)Tensornn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)GradientCheckpointingLayer)BackboneOutputBaseModelOutputBaseModelOutputWithPoolingImageClassifierOutputMaskedLMOutputSemanticSegmenterOutput)PreTrainedModel)#compile_compatible_method_lru_cache find_pruneable_heads_and_indicesprune_linear_layer)auto_docstringlogging	torch_int)BackboneMixin   )
BeitConfigz-
    Class for outputs of [`BeitModel`].
    )Zcustom_introc                   @   s   e Zd ZdZdS )BeitModelOutputWithPoolingaF  
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
        Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
        *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token
        will be returned.
    N)__name__
__module____qualname____doc__ r"   r"   b/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/beit/modeling_beit.pyr   0   s   r           F)input	drop_probtrainingreturnc                 C   sd   |dks|s| S d| }| j d fd| jd   }|tj|| j| jd }|  | || }|S )aF  
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    r$   r   r   r   )dtypedevice)shapendimtorchZrandr*   r+   Zfloor_div)r%   r&   r'   Z	keep_probr,   Zrandom_tensoroutputr"   r"   r#   	drop_path?   s    
r1   c                       sP   e Zd ZdZdee dd fddZejejdddZ	e
d	d
dZ  ZS )BeitDropPathzXDrop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).N)r&   r(   c                    s   t    || _d S N)super__init__r&   )selfr&   	__class__r"   r#   r5   V   s    
zBeitDropPath.__init__hidden_statesr(   c                 C   s   t || j| jS r3   )r1   r&   r'   r6   r:   r"   r"   r#   forwardZ   s    zBeitDropPath.forward)r(   c                 C   s   d| j  S )Nzp=)r&   r6   r"   r"   r#   
extra_repr]   s    zBeitDropPath.extra_repr)N)r   r   r    r!   r   floatr5   r.   r   r<   strr>   __classcell__r"   r"   r7   r#   r2   S   s   r2   c                       sd   e Zd ZdZedd fddZejeeejdddZ	deje
ej e
e ejd	d
dZ  ZS )BeitEmbeddingszc
    Construct the CLS token, position and patch embeddings. Optionally, also the mask token.

    Nconfigr(   c                    s   t    ttdd|j| _|jrBttdd|j| _	nd | _	t
|| _|j| _t|jtjjrp|jn
|j|jf| _| jj}|jrttd|d |j| _nd | _t|j| _d S )Nr   )r4   r5   r   	Parameterr.   zeroshidden_size	cls_tokenZuse_mask_token
mask_tokenBeitPatchEmbeddingspatch_embeddings
patch_size
isinstance
image_sizecollectionsabcIterablenum_patchesZ use_absolute_position_embeddingsposition_embeddingsDropouthidden_dropout_probdropout)r6   rD   rR   r7   r"   r#   r5   i   s     


zBeitEmbeddings.__init__)
embeddingsheightwidthr(   c                 C   s   |j d d }| jj d d }tj s>||kr>||kr>| jS | jddddf }| jddddf }|j d }|| j }	|| j }
t|d }|d|||}|dddd}t	j
j||	|
fdd	d
}|dddddd|}tj||fddS )a   
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        r   Ng      ?r   r
      ZbicubicFsizemodealign_cornersdim)r,   rS   r.   Zjit
is_tracingrL   r   reshapepermuter   
functionalinterpolateviewcat)r6   rW   rX   rY   rR   Znum_positionsZclass_pos_embedZpatch_pos_embedra   
new_height	new_widthZsqrt_num_positionsr"   r"   r#   interpolate_pos_encoding   s(    



z'BeitEmbeddings.interpolate_pos_encoding)pixel_valuesbool_masked_posrk   r(   c                 C   s   | j d ur|d urtd |j\}}}}| |\}\}}	| \}
}}|d ur| j|
|d}|d	|}|d|  ||  }| j
|
dd}tj||fdd}| j d ur|| ||| }| |}|||	ffS )Nz`interpolate_pos_encoding` argument has no effect for BEiTEmbeddings, embeddings are always interpolated to the input image size. The argument will be removed in transformers v4.51.0.rZ   r   r`   )rS   warningswarnr,   rK   r]   rI   expand	unsqueezeZtype_asrH   r.   rh   rk   rV   )r6   rl   rm   rk   _rX   rY   rW   patch_heightpatch_width
batch_sizeZseq_lenZmask_tokenswZ
cls_tokensr"   r"   r#   r<      s"    

zBeitEmbeddings.forward)NN)r   r   r    r!   r   r5   r.   r   intrk   r   
BoolTensorboolr<   rA   r"   r"   r7   r#   rB   c   s   +  rB   c                       s4   e Zd ZdZ fddZejejdddZ  ZS )rJ   z
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    c                    s   t    |j|j }}|j|j }}t|tjj	r8|n||f}t|tjj	rR|n||f}|d |d  |d |d   }|d |d  |d |d  f}|| _|| _|| _|| _
|| _tj||||d| _d S )Nr   r   kernel_sizeZstride)r4   r5   rN   rL   num_channelsrG   rM   rO   rP   rQ   rR   patch_shaper   Conv2d
projection)r6   rD   rN   rL   r|   rG   rR   r}   r7   r"   r#   r5      s    
  zBeitPatchEmbeddings.__init__)rl   r(   c           	      C   s^   |j \}}}}|| jkr td| |}|j d |j d  }}|ddd}|||ffS )NzeMake sure that the channel dimension of the pixel values match with the one set in the configuration.r[   r
   r   )r,   r|   
ValueErrorr   flatten	transpose)	r6   rl   ru   r|   rX   rY   rW   rs   rt   r"   r"   r#   r<      s    

zBeitPatchEmbeddings.forward)	r   r   r    r!   r5   r.   r   r<   rA   r"   r"   r7   r#   rJ      s   rJ   c                       sx   e Zd Zd	eee dd fddZd
ejeej e	eej e	eee
  eeej eejejf f dddZ  ZS )BeitSelfAttentionNrD   window_sizer(   c                    s   t    || _|j|j dkrDt|dsDtd|j d|j d|j| _t|j|j | _| j| j | _	t
|j| j	| _t
j|j| j	dd| _t
|j| j	| _t
|j| _t|| _| jrt||d| _d S )	Nr   Zembedding_sizezThe hidden size z4 is not a multiple of the number of attention heads .F)biasr   )r4   r5   rD   rG   num_attention_headshasattrr   rw   attention_head_sizeall_head_sizer   LinearquerykeyvaluerT   attention_probs_dropout_probrV   ry   has_relative_position_biasBeitRelativePositionBiasrelative_position_biasr6   rD   r   r7   r"   r#   r5      s$    


zBeitSelfAttention.__init__Fr:   	head_maskoutput_attentionsr   rk   
resolutionr(   c                 C   sl  |j \}}}	| ||d| j| jdd}
| ||d| j| jdd}| ||d| j| jdd}t	|
|dd}|t
| j }| jr|\}}|| jj || jj f}|| j|||j d d }|d ur|| }tjj|dd}| |}|d ur|| }t	||}|dddd }| d d | jf }|j| }|rb||fn|f}|S )	NrZ   r   r[   dim_sizer`   r   r
   )r,   r   rg   r   r   r   r   r   r.   matmulmathsqrtr   rD   rL   r   r   re   ZsoftmaxrV   rd   
contiguousr]   r   )r6   r:   r   r   r   rk   r   ru   
seq_lengthrr   query_layer	key_layervalue_layerZattention_scoresrX   rY   r   Zattention_probscontext_layernew_context_layer_shapeoutputsr"   r"   r#   r<     sN    	





zBeitSelfAttention.forward)N)NFNFN)r   r   r    r   r   tupler5   r.   r   ry   rw   r   r<   rA   r"   r"   r7   r#   r      s        
r   c                       s`   e Zd Zdejeej eeej eeee  e	eej eejejf f d fddZ
  ZS )BeitSdpaSelfAttentionNFr   c              	      sx  |s|d ur.t d t j||||||dS |j\}}}	| ||d| j| j	dd}
| 
||d| j| j	dd}| ||d| j| j	dd}d }| jr|\}}|| jj || jj f}| j|||jd d}|d ur|d u r|}n||7 }dt| j }tjjj|
|||| jr.| jjndd|d	}|d
ddd }| d d | jf }|j| }|d fS )Na  `BeitSdpaSelfAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.)r:   r   r   r   rk   r   rZ   r   r[   r   r$   F)Z	attn_maskZ	dropout_pZ	is_causalscaler   r
   r   )loggerZwarning_oncer4   r<   r,   r   rg   r   r   r   r   r   r   rD   rL   r   r   r   r.   r   re   Zscaled_dot_product_attentionr'   r   rd   r   r]   r   )r6   r:   r   r   r   rk   r   ru   r   rr   r   r   r   Z	attn_biasrX   rY   r   Zscalingr   r   r7   r"   r#   r<   I  sp    		


	
zBeitSdpaSelfAttention.forward)NFNFN)r   r   r    r.   r   r   ry   r   rw   r   r<   rA   r"   r"   r7   r#   r   H  s        
r   c                       sB   e Zd ZdZedd fddZd	ejejejdddZ  Z	S )
BeitSelfOutputz
    The residual connection is defined in BeitLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    NrC   c                    s.   t    t|j|j| _t|j| _d S r3   )	r4   r5   r   r   rG   denserT   rU   rV   r6   rD   r7   r"   r#   r5     s    
zBeitSelfOutput.__init__)r:   input_tensorr(   c                 C   s   |  |}| |}|S r3   r   rV   )r6   r:   r   gammar"   r"   r#   r<     s    

zBeitSelfOutput.forward)N)
r   r   r    r!   r   r5   r.   r   r<   rA   r"   r"   r7   r#   r     s   r   )eagerZsdpac                       s   e Zd Zdeee dd fddZdd Zdej	eej	 e
eej	 e
eee  eeej	 eej	ej	f f dd	d
Z  ZS )BeitAttentionNr   c                    s4   t    t|j ||d| _t|| _t | _d S )Nr   )	r4   r5   BEIT_SELF_ATTENTION_CLASSESZ_attn_implementation	attentionr   r0   setpruned_headsr   r7   r"   r#   r5     s    

zBeitAttention.__init__c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   r`   )lenr   r   r   r   r   r   r   r   r   r0   r   r   union)r6   headsindexr"   r"   r#   prune_heads  s    zBeitAttention.prune_headsFr   c           
      C   s:   |  ||||||}| |d |}|f|dd   }	|	S )Nr   r   )r   r0   )
r6   r:   r   r   r   rk   r   Zself_outputsattention_outputr   r"   r"   r#   r<     s    	zBeitAttention.forward)N)NFNFN)r   r   r    r   r   r   r5   r   r.   r   ry   rw   r   r<   rA   r"   r"   r7   r#   r     s         
r   c                       s8   e Zd Zedd fddZejejdddZ  ZS )BeitIntermediateNrC   c                    sB   t    t|j|j| _t|jt	r6t
|j | _n|j| _d S r3   )r4   r5   r   r   rG   intermediate_sizer   rM   Z
hidden_actr@   r   intermediate_act_fnr   r7   r"   r#   r5     s
    
zBeitIntermediate.__init__r9   c                 C   s   |  |}| |}|S r3   )r   r   r;   r"   r"   r#   r<     s    

zBeitIntermediate.forward	r   r   r    r   r5   r.   r   r<   rA   r"   r"   r7   r#   r     s   r   c                       s8   e Zd Zedd fddZejejdddZ  ZS )
BeitOutputNrC   c                    s.   t    t|j|j| _t|j| _	d S r3   )
r4   r5   r   r   r   rG   r   rT   rU   rV   r   r7   r"   r#   r5     s    
zBeitOutput.__init__r9   c                 C   s   |  |}| |}|S r3   r   r;   r"   r"   r#   r<     s    

zBeitOutput.forwardr   r"   r"   r7   r#   r     s   r   c                       s   e Zd ZdZdeee edd fddZde	j
ee	j
 eee	j
 eeeeef  eee	j
 ee	j
e	j
f f dd	d
Z  ZS )	BeitLayerz?This corresponds to the Block class in the timm implementation.Nr$   )rD   r   drop_path_rater(   c                    s   t    |j| _d| _t||d| _t|| _t|| _	t
j|j|jd| _|dkr^t|nt
 | _t
j|j|jd| _|j}|dkrt
j|t|j dd| _t
j|t|j dd| _nd\| _| _d S )	Nr   r   epsr$   r   T)Zrequires_grad)NN)r4   r5   Zchunk_size_feed_forwardZseq_len_dimr   r   r   intermediater   r0   r   	LayerNormrG   layer_norm_epslayernorm_beforer2   Identityr1   layernorm_afterlayer_scale_init_valuerE   r.   Zoneslambda_1lambda_2)r6   rD   r   r   Zinit_valuesr7   r"   r#   r5     s    


zBeitLayer.__init__Fr   c                 C   s   | j | ||||||d}|d }|dd  }	| jd urD| j| }| || }| |}
| |
}
| |
}
| jd ur| j|
 }
| |
| }
|
f|	 }	|	S )N)r   r   rk   r   r   r   )r   r   r   r1   r   r   r0   r   )r6   r:   r   r   r   rk   r   Zself_attention_outputsr   r   Zlayer_outputr"   r"   r#   r<     s*    	







zBeitLayer.forward)Nr$   )NFNFN)r   r   r    r!   r   r   r   r?   r5   r.   r   ry   rw   r   r<   rA   r"   r"   r7   r#   r     s         r   c                       s^   e Zd Zeedd fddZeddeeef ej	ddd	Z
deej	dddZ  ZS )r   Nr   c                    sR   t    || _d|d  d d|d  d  d | _tt| j|j| _	d S )Nr[   r   r   r
   )
r4   r5   r   num_relative_distancer   rE   r.   rF   r   relative_position_bias_tabler   r7   r"   r#   r5   5  s    
&z!BeitRelativePositionBias.__init__
   )maxsize)r   r(   c           	      C   s  d|d  d d|d  d  d }|d |d  }t jt |d t |d dd}t |}t |d}|dddddf |dddddf  }|ddd }|dddddf  |d d 7  < |dddddf  |d d 7  < |dddddf  d|d  d 9  < t j|d fd |jd}|	d	|ddddf< |d |dddf< |d |dddf< |d |d
< |S )z
        This method creates the relative position index, modified to support arbitrary window sizes,
        as introduced in [MiDaS v3.1](https://huggingface.co/papers/2307.14460).
        r[   r   r   r
   Zij)ZindexingN)r]   r*   rZ   )r   r   )
r.   ZmeshgridZarangestackr   rd   r   rF   r*   sum)	r6   r   r   Zwindow_areagridZcoordsZcoords_flattenZrelative_coordsrelative_position_indexr"   r"   r#    generate_relative_position_index>  s     $$
,&&*z9BeitRelativePositionBias.generate_relative_position_indexF)rk   r(   c                 C   sh  d| j d  d }d| j d  d }d|d  d }d|d  d }| j}| j}	|| d }
|d|	d  }|d||ddddd}tjj|t|t|fdd}|dddd|
d d}t	
|||	d d g}| |}||d }||d |d  d |d |d  d d}|ddd }|r^tjj|d||fdd	d
d}|dS )zu
        Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes.
        r[   r   r   r
   NrZ   bilinear)r]   r^   Fr\   )r   r   r   rc   rd   r   re   rf   r   r.   rh   r   rg   r   rq   squeeze)r6   r   rk   r   Z
old_heightZ	old_widthri   rj   Z old_relative_position_bias_tableZold_num_relative_distanceZnew_num_relative_distanceZold_sub_tableZnew_sub_tableZ new_relative_position_bias_tabler   r   r"   r"   r#   r<   W  s@    
&z BeitRelativePositionBias.forward)FN)r   r   r    r   r   r5   r   rw   r.   r   r   ry   r<   rA   r"   r"   r7   r#   r   4  s   	r   c                       sf   e Zd Zd
eee dd fddZdejeej e	e	e	eee
e
f  e	eeef ddd	Z  ZS )BeitEncoderNr   c                    sz   t     | _ j| _| jr,t d| _dd tjd j	 j
ddD t fddt j
D | _d| _d S )	Nr   c                 S   s   g | ]}|  qS r"   )item.0xr"   r"   r#   
<listcomp>      z(BeitEncoder.__init__.<locals>.<listcomp>r   cpu)r+   c                    s(   g | ] }t   jrnd | dqS )N)r   r   )r   Zuse_relative_position_biasr   irD   Zdprr   r"   r#   r     s   F)r4   r5   rD   Z!use_shared_relative_position_biasr   r   r   r.   Zlinspacer   num_hidden_layersr   
ModuleListrangelayerZgradient_checkpointingr   r7   r   r#   r5     s    
 
zBeitEncoder.__init__FT)r:   r   r   output_hidden_statesrk   r   return_dictr(   c              	   C   s   |rdnd }|rdnd }	t | jD ]\}
}|r8||f }| jrv|\}}|| jj || jj f}| j|||jd d}nd }|d ur||
 nd }|||||||d}|d }|r"|	|d f }	q"|r||f }|stdd |||	fD S t|||	dS )	Nr"   r   )rk   r   )r   r   r   rk   r   r   c                 s   s   | ]}|d ur|V  qd S r3   r"   )r   vr"   r"   r#   	<genexpr>  r   z&BeitEncoder.forward.<locals>.<genexpr>)last_hidden_stater:   
attentions)		enumerater   r   rD   rL   r   r,   r   r   )r6   r:   r   r   r   rk   r   r   Zall_hidden_statesZall_self_attentionsr   Zlayer_modulerX   rY   r   r   Zlayer_head_maskZlayer_outputsr"   r"   r#   r<     sB    

	
zBeitEncoder.forward)N)NFFFNT)r   r   r    r   r   r   r5   r.   r   ry   rw   r   r   r<   rA   r"   r"   r7   r#   r     s"         
r   c                   @   s:   e Zd ZU eed< dZdZdZdgZdgZ	dZ
dd Zd	S )
BeitPreTrainedModelrD   beitrl   Tr   z.*relative_position_index.*c                 C   sL  t |tjtjtjfrF|jjjd| jj	d |j
durB|j
j  nt |tjr|jjjd| jj	d |jdur|jj|j   nt |tjr|j
j  |jjd nt |tr|jj  |jdur|jj  |jdur|jj  nVt |tr|jj  n<t |trH|jdurH|jj| jj |jj| jj dS )zInitialize the weightsr$   )meanZstdNg      ?)rM   r   r   r~   ConvTranspose2dweightdataZnormal_rD   Zinitializer_ranger   Zzero_Z	EmbeddingZpadding_idxr   Zfill_rB   rH   rI   rS   r   r   r   r   r   r   )r6   moduler"   r"   r#   _init_weights  s.    




z!BeitPreTrainedModel._init_weightsN)r   r   r    r   __annotations__Zbase_model_prefixZmain_input_nameZsupports_gradient_checkpointingZ_no_split_modulesZ"_keys_to_ignore_on_load_unexpectedZ_supports_sdpar   r"   r"   r"   r#   r     s   
r   c                       s|   e Zd Zdeedd fddZdd Zdd	 Zede	j
ee	j ee	j
 ee ee eee eeef dddZ  ZS )	BeitModelTN)rD   add_pooling_layerr(   c                    sp   t  | || _t|| _t|| jjjd| _|j	r>t
 nt
j|j|jd| _|r^t|nd| _|   dS )zv
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        r   r   N)r4   r5   rD   rB   rW   r   rK   r}   encoderuse_mean_poolingr   r   r   rG   r   	layernorm
BeitPoolerpooler	post_init)r6   rD   r   r7   r"   r#   r5     s    
zBeitModel.__init__c                 C   s   | j jS r3   rW   rK   r=   r"   r"   r#   get_input_embeddings  s    zBeitModel.get_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr   r   r   r   )r6   Zheads_to_pruner   r   r"   r"   r#   _prune_heads  s    zBeitModel._prune_headsF)rl   rm   r   r   r   rk   r   r(   c              	   C   s   |dur|n| j j}|dur |n| j j}|dur4|n| j j}| || j j}| j||d\}}	|jdd }
| j|||||
||d}|d }| 	|}| j
dur| 
|nd}|s|dur||fn|f}||dd  S t|||j|jdS )z
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        N)rm   r[   )r   r   r   r   r   rk   r   r   )r   pooler_outputr:   r   )rD   r   r   use_return_dictZget_head_maskr   rW   r,   r   r   r   r   r:   r   )r6   rl   rm   r   r   r   rk   r   embedding_outputrr   r   Zencoder_outputssequence_outputpooled_outputZhead_outputsr"   r"   r#   r<     s8    	
zBeitModel.forward)T)NNNNFN)r   r   r    r   ry   r5   r  r  r   r.   r   r   rx   r   r   r   r<   rA   r"   r"   r7   r#   r     s(         
r   c                       s8   e Zd Zedd fddZejejdddZ  ZS )r   NrC   c                    s,   t    |jr"tj|j|jdnd | _d S )Nr   )r4   r5   r   r   r   rG   r   r   r   r7   r"   r#   r5   S  s    
zBeitPooler.__init__r9   c                 C   sJ   | j d ur6|d d dd d d f }|  |d}n|d d df }|S )Nr   r   )r   r   )r6   r:   Zpatch_tokensr
  r"   r"   r#   r<   Y  s
    
zBeitPooler.forwardr   r"   r"   r7   r#   r   R  s   r   a  
    Beit Model transformer with a 'language' modeling head on top. BEiT does masked image modeling by predicting
    visual tokens of a Vector-Quantize Variational Autoencoder (VQ-VAE), whereas other vision models like ViT and DeiT
    predict RGB pixel values. As a result, this class is incompatible with [`AutoModelForMaskedImageModeling`], so you
    will need to use [`BeitForMaskedImageModeling`] directly if you wish to do masked image modeling with BEiT.
    c                       s|   e Zd Zedd fddZdd Zedeej	 eej
 eej	 eej	 ee ee eee eeef d	d	d
Z  ZS )BeitForMaskedImageModelingNrC   c                    sT   t  | |j| _t|dd| _tj|j|jd| _	t
|j|j| _|   d S )NFr   r   )r4   r5   
num_labelsr   r   r   r   rG   r   r   r   Z
vocab_sizelm_headr  r   r7   r"   r#   r5   n  s    z#BeitForMaskedImageModeling.__init__c                 C   s   d S r3   r"   r=   r"   r"   r#   get_output_embeddings{  s    z0BeitForMaskedImageModeling.get_output_embeddingsF)	rl   rm   r   labelsr   r   rk   r   r(   c	              	   C   s   |dur|n| j j}| j|||||||d}	|	d }
| |
}
| |
ddddf }d}|durxt }||| |}|s|f|	dd  }|dur|f| S |S t|||	j|	jdS )a  
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, BeitForMaskedImageModeling
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224-pt22k")
        >>> model = BeitForMaskedImageModeling.from_pretrained("microsoft/beit-base-patch16-224-pt22k")

        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
        >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
        >>> # create random boolean mask of shape (batch_size, num_patches)
        >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss, logits = outputs.loss, outputs.logits
        >>> list(logits.shape)
        [1, 196, 8192]
        ```N)rm   r   r   r   rk   r   r   r   losslogitsr:   r   )	rD   r  r   r   r  r   r   r:   r   )r6   rl   rm   r   r  r   r   rk   r   r   r	  Zprediction_scoresZmasked_lm_lossloss_fctr0   r"   r"   r#   r<   ~  s4    ,

z"BeitForMaskedImageModeling.forward)NNNNNNFN)r   r   r    r   r5   r  r   r   r.   r   rx   ry   r   r   r   r<   rA   r"   r"   r7   r#   r  e  s,   	        
r  z
    Beit Model transformer with an image classification head on top (a linear layer on top of the average of the final
    hidden states of the patch tokens) e.g. for ImageNet.
    c                       sl   e Zd Zedd fddZed	eej eej eej ee	 ee	 e	ee	 e
eef dddZ  ZS )
BeitForImageClassificationNrC   c                    sR   t  | |j| _t|dd| _|jdkr<t|j|jnt | _	| 
  d S )NTr  r   )r4   r5   r  r   r   r   r   rG   r   
classifierr  r   r7   r"   r#   r5     s
    $z#BeitForImageClassification.__init__Frl   r   r  r   r   rk   r   r(   c                 C   sv  |dur|n| j j}| j||||||d}|r4|jn|d }	| |	}
d}|dur2| j jdu r| jdkrtd| j _n4| jdkr|jtj	ks|jtj
krd| j _nd| j _| j jdkrt }| jdkr||
 | }n
||
|}nN| j jdkrt }||
d| j|d}n| j jdkr2t }||
|}|sb|
f|dd  }|dur^|f| S |S t||
|j|jd	S )
a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Nr   r   r   rk   r   r   Z
regressionZsingle_label_classificationZmulti_label_classificationrZ   r[   r  )rD   r  r   r  r  Zproblem_typer  r*   r.   longrw   r	   r   r   rg   r   r   r:   r   )r6   rl   r   r  r   r   rk   r   r   r
  r  r  r  r0   r"   r"   r#   r<     sN    	



"


z"BeitForImageClassification.forward)NNNNNFN)r   r   r    r   r5   r   r   r.   r   ry   r   r   r   r<   rA   r"   r"   r7   r#   r    s&          
r  c                       sz   e Zd ZdZdeeeeeeef f eeeeef ef eeeeeef f dd fddZ	e
je
jd	d
dZ  ZS )BeitConvModuleaD  
    A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution
    layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).

    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    r   Fr   N)in_channelsout_channelsr{   paddingr   dilationr(   c                    s<   t    tj||||||d| _t|| _t | _d S )N)r  r  r{   r  r   r  )	r4   r5   r   r~   convBatchNorm2dbnZReLU
activation)r6   r  r  r{   r  r   r  r7   r"   r#   r5   '  s    	
zBeitConvModule.__init__r%   r(   c                 C   s"   |  |}| |}| |}|S r3   )r  r!  r"  )r6   r%   r0   r"   r"   r#   r<   <  s    


zBeitConvModule.forward)r   Fr   )r   r   r    r!   rw   r   r   r@   ry   r5   r.   r   r<   rA   r"   r"   r7   r#   r    s      r  c                       s<   e Zd Zeeedd fddZejejdddZ  ZS )BeitPyramidPoolingBlockN)
pool_scaler  channelsr(   c                    sL   t    t|t||ddg| _t| jD ]\}}| t|| q.d S )Nr   r{   )	r4   r5   r   ZAdaptiveAvgPool2dr  layersr   
add_moduler@   )r6   r%  r  r&  r   r   r7   r"   r#   r5   E  s    
z BeitPyramidPoolingBlock.__init__r#  c                 C   s   |}| j D ]}||}q
|S r3   )r(  )r6   r%   hidden_stater   r"   r"   r#   r<   N  s    

zBeitPyramidPoolingBlock.forward)	r   r   r    rw   r5   r.   r   r<   rA   r"   r"   r7   r#   r$  D  s   	r$  c                       sN   e Zd ZdZeedf eeedd fddZej	e
ej	 ddd	Z  ZS )
BeitPyramidPoolingModulea  
    Pyramid Pooling Module (PPM) used in PSPNet.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        align_corners (bool): align_corners argument of F.interpolate.

    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    .N)pool_scalesr  r&  r_   r(   c                    sh   t    || _|| _|| _|| _g | _t|D ]2\}}t|||d}| j	| | 
t|| q0d S )N)r%  r  r&  )r4   r5   r,  r_   r  r&  blocksr   r$  appendr)  r@   )r6   r,  r  r&  r_   r   r%  blockr7   r"   r#   r5   c  s    
z!BeitPyramidPoolingModule.__init__)r   r(   c                 C   sH   g }| j D ]8}||}tjj|| dd  d| jd}|| q
|S )Nr[   r   r\   )r-  r   re   rf   r]   r_   r.  )r6   r   Zppm_outsppmZppm_outZupsampled_ppm_outr"   r"   r#   r<   o  s    
z BeitPyramidPoolingModule.forward)r   r   r    r!   r   rw   ry   r5   r.   r   listr<   rA   r"   r"   r7   r#   r+  U  s   "r+  c                       sD   e Zd ZdZedd fddZdd Zejejdd	d
Z	  Z
S )BeitUperHeadz
    Unified Perceptual Parsing for Scene Understanding. This head is the implementation of
    [UPerNet](https://huggingface.co/papers/1807.10221).

    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    NrC   c                    s  t    |j| _|jgd | _|j| _d| _tj| j|j	dd| _
t| j| jd | j| jd| _t| jd t| j| j  | jddd| _t | _t | _| jd d D ]@}t|| jdd}t| j| jddd}| j| | j| qtt| j| j | jddd| _d S )	N   Fr   r'  rZ   )r_   r
   r{   r  )r4   r5   r,  rG   r  r&  r_   r   r~   r  r  r+  psp_modulesr  r   
bottleneckr   lateral_convs	fpn_convsr.  fpn_bottleneck)r6   rD   r  Zl_convZfpn_convr7   r"   r#   r5     s>    


zBeitUperHead.__init__c                 C   s:   |d }|g}| | | tj|dd}| |}|S )NrZ   r   r`   )extendr5  r.   rh   r6  )r6   inputsr   Zpsp_outsr0   r"   r"   r#   psp_forward  s    
zBeitUperHead.psp_forwardencoder_hidden_statesr(   c                    s   fddt jD   t}t|d ddD ]H}|d  jdd  }|d  tjj	| |dj
d |d < q@fd	dt|d D }|d  t|d ddD ]0}tjj	|| |d jdd  dj
d||< qtj|dd
}|}|}|S )Nc                    s   g | ]\}}| | qS r"   r"   )r   r   Zlateral_conv)r>  r"   r#   r     r   z(BeitUperHead.forward.<locals>.<listcomp>r   r   rZ   r[   r   r\   c                    s   g | ]}j |  | qS r"   )r8  r   )lateralsr6   r"   r#   r     r   r`   )r   r7  r.  r<  r   r   r,   r   re   rf   r_   r.   rh   r9  r  )r6   r>  Zused_backbone_levelsr   Z
prev_shapeZfpn_outsr0   r"   )r>  r?  r6   r#   r<     s$    

zBeitUperHead.forward)r   r   r    r!   r   r5   r<  r.   r   r<   rA   r"   r"   r7   r#   r2  z  s   &	r2  c                	       sT   e Zd ZdZdeeeeeeeef f dd fddZe	j
e	j
d	d
dZ  ZS )BeitFCNHeada  
    Fully Convolution Networks for Semantic Segmentation. This head is implemented of
    [FCNNet](https://huggingface.co/papers/1411.4038>).

    Args:
        config (BeitConfig): Configuration.
        in_channels
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        dilation (int): The dilation rate for convs in the head. Default: 1.


    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    r[   r
   r   N)rD   in_indexr{   r  r(   c              
      s   t    |j| _|j| _|j| _|j| _	|| _
|d | }g }|t| j| j|||d t| jd D ] }|t| j| j|||d qj| jdkrt | _ntj| | _| j	rt| j| j | j||d d| _tj| j|jdd| _d S )Nr[   )r{   r  r  r   r   r4  r'  )r4   r5   rG   r  Zauxiliary_channelsr&  Zauxiliary_num_convsZ	num_convsZauxiliary_concat_inputconcat_inputrA  r.  r  r   r   r   convs
Sequentialconv_catr~   r  r  )r6   rD   rA  r{   r  Zconv_paddingrC  r   r7   r"   r#   r5     s6    

zBeitFCNHead.__init__r=  c                 C   s@   || j  }| |}| jr2| tj||gdd}| |}|S )Nr   r`   )rA  rC  rB  rE  r.   rh   r  )r6   r>  r:   r0   r"   r"   r#   r<      s    


zBeitFCNHead.forward)r[   r
   r   )r   r   r    r!   r   rw   r   r   r5   r.   r   r<   rA   r"   r"   r7   r#   r@    s    "r@  c                       st   e Zd Zedd fddZdd Zedeej	 eej	 eej	 ee
 ee
 e
ee
 eeef dd	d
Z  ZS )BeitForSemanticSegmentationNrC   c                    s   t  | |j| _t|dd| _t| jjdkr:tdt	
t	j|j|jdddt	|jt	 t	j|j|jddd| _t	
t	j|j|jddd| _t	 | _t	jddd| _t|| _|jrt|nd | _|   d S )NFr  r3  zBeitForSemanticSegmentation requires config.out_indices to be a list of 4 integers, specifying which features to use from the backbone. One can use [3, 5, 7, 11] in case of a base-sized architecture.r[   rz   )r4   r5   r  r   r   r   rD   out_indicesr   r   rD  r   rG   r   GELUfpn1fpn2r   fpn3	MaxPool2dfpn4r2  decode_headZuse_auxiliary_headr@  auxiliary_headr  r   r7   r"   r#   r5     s*    


z$BeitForSemanticSegmentation.__init__c           
      C   s   t jj||jdd  ddd}|d urDt jj||jdd  ddd}t| jjd}|||}|}|d ur|||}	|| jj|	 7 }|S )Nr   r   Fr\   )Zignore_index)r   re   rf   r,   r   rD   Zsemantic_loss_ignore_indexZauxiliary_loss_weight)
r6   r  auxiliary_logitsr  Zupsampled_logitsZupsampled_auxiliary_logitsr  Z	main_lossr  Zauxiliary_lossr"   r"   r#   compute_loss,  s    

z(BeitForSemanticSegmentation.compute_lossFr  c                    s  |dur|nj j}|dur |nj j}|durDj jdkrDtdj|||d||d}|rd|jn|d }	fddt|	D }
|jd  j j	j j
  fd	d|
D }
jjjjg}tt|
D ]}|| |
| |
|< qЈ|
}d}jdur|
}d}|dur*|||}|st|rJ|f|dd  }n|f|d
d  }|durp|f| S |S t|||r|jnd|jdS )aD  
        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, BeitForSemanticSegmentation
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-finetuned-ade-640-640")
        >>> model = BeitForSemanticSegmentation.from_pretrained("microsoft/beit-base-finetuned-ade-640-640")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> # logits are of shape (batch_size, num_labels, height, width)
        >>> logits = outputs.logits
        ```Nr   z/The number of labels should be greater than oneTr  c                    s$   g | ]\}}|d   j jv r|qS r)   )rD   rG  )r   idxfeaturer=   r"   r#   r   v  r   z7BeitForSemanticSegmentation.forward.<locals>.<listcomp>r   c                    s<   g | ]4}|d d dd d d f  ddd dqS )Nr   r   r[   rZ   )rd   rc   r   )ru   patch_resolutionr"   r#   r   y  s   r[   r  )rD   r  r   r  r   r   r:   r   r,   rN   rL   rI  rJ  rK  rM  r   r   rN  rO  rQ  r   r   )r6   rl   r   r  r   r   rk   r   r   r>  featuresopsr   r  rP  r  r0   r"   )ru   rT  r6   r#   r<   ?  sR    "	



z#BeitForSemanticSegmentation.forward)NNNNNFN)r   r   r    r   r5   rQ  r   r   r.   r   ry   r   r   r   r<   rA   r"   r"   r7   r#   rF  
  s(           
rF  zM
    BEiT backbone, to be used with frameworks like DETR and MaskFormer.
    c                       sL   e Zd Z fddZdd Zed	eee ee ee e	dddZ
  ZS )
BeitBackbonec                    s   t    t     fddt jd D | _t | _t | jj	j
d| _ jrt| jjdkrrtd j}ttj||dddtj| jd	t tj||ddd| _ttj||ddd| _t | _tjddd| _|   d S )
Nc                    s   g | ]
} j qS r"   )rG   )r   rr   rD   r"   r#   r     r   z)BeitBackbone.__init__.<locals>.<listcomp>r   r   r3  zBeitBackbone requires config.out_indices to be a list of 4 integers, specifying which features to use from the backbone. One can use [3, 5, 7, 11] in case of a base-sized architecture.r[   rz   r   )r4   r5   Z_init_backboner   r   Znum_featuresrB   rW   r   rK   r}   r   add_fpnr   rD   rG  r   rG   r   rD  r   r   Zbatch_norm_epsrH  rI  rJ  r   rK  rL  rM  r  )r6   rD   rG   r7   rX  r#   r5     s*    

zBeitBackbone.__init__c                 C   s   | j jS r3   r  r=   r"   r"   r#   r    s    z!BeitBackbone.get_input_embeddingsN)rl   r   r   r   r(   c                 C   s  |dur|n| j j}|dur |n| j j}|dur4|n| j j}|jd }| |\}\}}|jdd }	| j|d||	|d}
|r|
jn|
d }d}t| j	|D ]\\}}|| j
v r| j jr|ddddddf }|ddd}||d||}||f7 }q| j jr@| |d | |d | |d | |d	 g}t|}|sv|r`|f|
dd  }n|f|
dd  }|S t||r|
jnd|
jd
S )a:  
        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224")
        >>> model = AutoBackbone.from_pretrained(
        ...     "microsoft/beit-base-patch16-224", out_features=["stage1", "stage2", "stage3", "stage4"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 768, 14, 14]
        ```Nr   r[   T)r   r   r   r   r   r"   rZ   r
   )feature_mapsr:   r   )rD   r  r   r   r,   rW   r   r:   zipZstage_namesZout_featuresZreshape_hidden_statesrd   rc   rY  rI  rJ  rK  rM  r   r   r   )r6   rl   r   r   r   ru   r  rs   rt   r   r   r:   rZ  Zstager*  r0   r"   r"   r#   r<     sP     


zBeitBackbone.forward)NNN)r   r   r    r5   r  r   r   r   ry   r   r<   rA   r"   r"   r7   r#   rW    s      rW  )r  r  rF  r   r   rW  )r$   F)Mr!   collections.abcrO   r   rn   dataclassesr   typingr   r   r.   Ztorch.utils.checkpointr   r   Ztorch.nnr   r   r	   Zactivationsr   Zmodeling_layersr   Zmodeling_outputsr   r   r   r   r   r   Zmodeling_utilsr   Zpytorch_utilsr   r   r   utilsr   r   r   Zutils.backbone_utilsr   Zconfiguration_beitr   Z
get_loggerr   r   r   r?   ry   r1   Moduler2   rB   rJ   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r  r  r$  r+  r2  r@  rF  rW  __all__r"   r"   r"   r#   <module>   s    
	f&YJ,ASL&W^N%%U; w