a
    h                     @   sP  d Z ddlZddlmZmZ ddlZddlZddlmZ ddlm	Z	m
Z
mZ ddlmZ ddlmZ dd	lmZmZmZmZ dd
lmZ ddlmZmZ ddlmZmZmZ ddlmZ e e!Z"d=e#e#ee# e#dddZ$G dd dej%Z&G dd dej%Z'G dd dej%Z(G dd dej%Z)G dd dej%Z*G dd dej%Z+G dd  d ej%Z,G d!d" d"ej%Z-G d#d$ d$ej%Z.G d%d& d&ej%Z/G d'd( d(eZ0G d)d* d*ej%Z1eG d+d, d,eZ2eG d-d. d.e2Z3ed/d0G d1d2 d2e2Z4G d3d4 d4ej%Z5G d5d6 d6ej%Z6G d7d8 d8ej%Z7ed9d0G d:d; d;e2Z8g d<Z9dS )>zPyTorch MobileViT model.    N)OptionalUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)GradientCheckpointingLayer)BaseModelOutputWithNoAttention(BaseModelOutputWithPoolingAndNoAttention$ImageClassifierOutputWithNoAttentionSemanticSegmenterOutput)PreTrainedModel) find_pruneable_heads_and_indicesprune_linear_layer)auto_docstringlogging	torch_int   )MobileViTConfig   )valuedivisor	min_valuereturnc                 C   sF   |du r|}t |t| |d  | | }|d|  k r>||7 }t|S )a  
    Ensure that all layers have a channel count that is divisible by `divisor`. This function is taken from the
    original TensorFlow repo. It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    N   g?)maxint)r   r   r   	new_value r    l/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/mobilevit/modeling_mobilevit.pymake_divisible,   s    r"   c                       sT   e Zd Zdeeeeeeeeeeeef dd fddZe	j
e	j
dd	d
Z  ZS )MobileViTConvLayerr   FTN)configin_channelsout_channelskernel_sizestridegroupsbiasdilationuse_normalizationuse_activationr   c                    s   t    t|d d | }|| dkr@td| d| d|| dkrbtd| d| dtj||||||||dd		| _|	rtj|d
dddd| _nd | _|
rt	|
t
rt|
 | _qt	|jt
rt|j | _q|j| _nd | _d S )Nr   r   r   zInput channels (z) are not divisible by z groups.zOutput channels (Zzeros)	r%   r&   r'   r(   paddingr+   r)   r*   Zpadding_modegh㈵>g?T)Znum_featuresepsZmomentumZaffineZtrack_running_stats)super__init__r   
ValueErrorr   Conv2dconvolutionBatchNorm2dnormalization
isinstancestrr	   
activation
hidden_act)selfr$   r%   r&   r'   r(   r)   r*   r+   r,   r-   r.   	__class__r    r!   r1   <   sB    



zMobileViTConvLayer.__init__featuresr   c                 C   s6   |  |}| jd ur| |}| jd ur2| |}|S N)r4   r6   r9   )r;   r?   r    r    r!   forwardr   s    




zMobileViTConvLayer.forward)r   r   Fr   TT)__name__
__module____qualname__r   r   boolr   r8   r1   torchTensorrA   __classcell__r    r    r<   r!   r#   ;   s(         
6r#   c                       sF   e Zd ZdZd
eeeeedd fddZejejddd	Z	  Z
S )MobileViTInvertedResidualzY
    Inverted residual block (MobileNetv2): https://huggingface.co/papers/1801.04381
    r   N)r$   r%   r&   r(   r+   r   c              	      s   t    ttt||j d}|dvr:td| d|dkoH||k| _t|||dd| _	t|||d|||d| _
t|||dd	d
| _d S )Nr   )r   r   zInvalid stride .r   r%   r&   r'   r   )r%   r&   r'   r(   r)   r+   Fr%   r&   r'   r-   )r0   r1   r"   r   roundZexpand_ratior2   use_residualr#   
expand_1x1conv_3x3
reduce_1x1)r;   r$   r%   r&   r(   r+   Zexpanded_channelsr<   r    r!   r1      s0    

z"MobileViTInvertedResidual.__init__r>   c                 C   s4   |}|  |}| |}| |}| jr0|| S |S r@   )rO   rP   rQ   rN   )r;   r?   residualr    r    r!   rA      s
    


z!MobileViTInvertedResidual.forward)r   )rB   rC   rD   __doc__r   r   r1   rF   rG   rA   rH   r    r    r<   r!   rI   {   s    
!rI   c                       sB   e Zd Zd	eeeeedd fddZejejdddZ  Z	S )
MobileViTMobileNetLayerr   N)r$   r%   r&   r(   
num_stagesr   c                    sR   t    t | _t|D ]0}t||||dkr4|ndd}| j| |}qd S )Nr   r   )r%   r&   r(   )r0   r1   r   
ModuleListlayerrangerI   append)r;   r$   r%   r&   r(   rU   irW   r<   r    r!   r1      s    

z MobileViTMobileNetLayer.__init__r>   c                 C   s   | j D ]}||}q|S r@   rW   )r;   r?   layer_moduler    r    r!   rA      s    

zMobileViTMobileNetLayer.forward)r   r   
rB   rC   rD   r   r   r1   rF   rG   rA   rH   r    r    r<   r!   rT      s    
rT   c                       s:   e Zd Zeedd fddZejejdddZ  Z	S )MobileViTSelfAttentionNr$   hidden_sizer   c                    s   t    ||j dkr0td| d|j d|j| _t||j | _| j| j | _tj|| j|j	d| _
tj|| j|j	d| _tj|| j|j	d| _t|j| _d S )Nr   zThe hidden size z4 is not a multiple of the number of attention heads rJ   )r*   )r0   r1   num_attention_headsr2   r   attention_head_sizeall_head_sizer   LinearZqkv_biasquerykeyr   DropoutZattention_probs_dropout_probdropoutr;   r$   r`   r<   r    r!   r1      s    
zMobileViTSelfAttention.__init__hidden_statesr   c                 C   s   |j \}}}| ||d| j| jdd}| ||d| j| jdd}| ||d| j| jdd}t	||dd}|t
| j }tjj|dd}	| |	}	t	|	|}
|
dddd }
|
 d d | jf }|
j| }
|
S )Nr   r   dimr   r   )shapere   viewra   rb   	transposerf   r   rF   matmulmathsqrtr   
functionalZsoftmaxrh   Zpermute
contiguoussizerc   )r;   rk   
batch_sizeZ
seq_length_Zquery_layerZ	key_layerZvalue_layerZattention_scoresZattention_probsZcontext_layerZnew_context_layer_shaper    r    r!   rA      s8    




zMobileViTSelfAttention.forwardr]   r    r    r<   r!   r^      s   r^   c                       s:   e Zd Zeedd fddZejejdddZ  Z	S )MobileViTSelfOutputNr_   c                    s*   t    t||| _t|j| _d S r@   r0   r1   r   rd   denserg   Zhidden_dropout_probrh   ri   r<   r    r!   r1      s    
zMobileViTSelfOutput.__init__rj   c                 C   s   |  |}| |}|S r@   r}   rh   r;   rk   r    r    r!   rA     s    

zMobileViTSelfOutput.forwardr]   r    r    r<   r!   r{      s   r{   c                       sN   e Zd Zeedd fddZee ddddZej	ej	dd	d
Z
  ZS )MobileViTAttentionNr_   c                    s.   t    t||| _t||| _t | _d S r@   )r0   r1   r^   	attentionr{   outputsetpruned_headsri   r<   r    r!   r1     s    
zMobileViTAttention.__init__)headsr   c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   rn   )lenr   r   ra   rb   r   r   re   rf   r   r   r}   rc   union)r;   r   indexr    r    r!   prune_heads  s    zMobileViTAttention.prune_headsrj   c                 C   s   |  |}| |}|S r@   )r   r   )r;   rk   Zself_outputsattention_outputr    r    r!   rA      s    

zMobileViTAttention.forward)rB   rC   rD   r   r   r1   r   r   rF   rG   rA   rH   r    r    r<   r!   r     s   r   c                       s<   e Zd Zeeedd fddZejejdddZ  Z	S )MobileViTIntermediateNr$   r`   intermediate_sizer   c                    s>   t    t||| _t|jtr2t|j | _	n|j| _	d S r@   )
r0   r1   r   rd   r}   r7   r:   r8   r	   intermediate_act_fnr;   r$   r`   r   r<   r    r!   r1   '  s
    
zMobileViTIntermediate.__init__rj   c                 C   s   |  |}| |}|S r@   )r}   r   r   r    r    r!   rA   /  s    

zMobileViTIntermediate.forwardr]   r    r    r<   r!   r   &  s   r   c                       s@   e Zd Zeeedd fddZejejejdddZ  Z	S )MobileViTOutputNr   c                    s*   t    t||| _t|j| _d S r@   r|   r   r<   r    r!   r1   6  s    
zMobileViTOutput.__init__)rk   input_tensorr   c                 C   s    |  |}| |}|| }|S r@   r~   )r;   rk   r   r    r    r!   rA   ;  s    

zMobileViTOutput.forwardr]   r    r    r<   r!   r   5  s   r   c                       s<   e Zd Zeeedd fddZejejdddZ  Z	S )MobileViTTransformerLayerNr   c                    sZ   t    t||| _t|||| _t|||| _tj	||j
d| _tj	||j
d| _d S )Nr/   )r0   r1   r   r   r   intermediater   r   r   	LayerNormlayer_norm_epslayernorm_beforelayernorm_afterr   r<   r    r!   r1   C  s    
z"MobileViTTransformerLayer.__init__rj   c                 C   s<   |  | |}|| }| |}| |}| ||}|S r@   )r   r   r   r   r   )r;   rk   r   Zlayer_outputr    r    r!   rA   K  s    

z!MobileViTTransformerLayer.forwardr]   r    r    r<   r!   r   B  s   r   c                       s<   e Zd Zeeedd fddZejejdddZ  Z	S )MobileViTTransformerN)r$   r`   rU   r   c                    sJ   t    t | _t|D ](}t||t||j d}| j	| qd S )N)r`   r   )
r0   r1   r   rV   rW   rX   r   r   Z	mlp_ratiorY   )r;   r$   r`   rU   rz   transformer_layerr<   r    r!   r1   V  s    

zMobileViTTransformer.__init__rj   c                 C   s   | j D ]}||}q|S r@   r[   )r;   rk   r\   r    r    r!   rA   b  s    

zMobileViTTransformer.forwardr]   r    r    r<   r!   r   U  s   r   c                
       s|   e Zd ZdZdeeeeeeedd fddZeje	eje
f ddd	Zeje
ejd
ddZejejdddZ  ZS )MobileViTLayerzC
    MobileViT block: https://huggingface.co/papers/2110.02178
    r   N)r$   r%   r&   r(   r`   rU   r+   r   c                    s   t    |j| _|j| _|dkrXt||||dkr6|nd|dkrH|d ndd| _|}nd | _t||||jd| _	t|||dddd| _
t|||d| _tj||jd| _t|||dd| _t|d| ||jd| _d S )	Nr   r   )r%   r&   r(   r+   rK   F)r%   r&   r'   r,   r-   )r`   rU   r   )r0   r1   Z
patch_sizepatch_widthpatch_heightrI   downsampling_layerr#   Zconv_kernel_sizeconv_kxkconv_1x1r   transformerr   r   r   	layernormconv_projectionfusion)r;   r$   r%   r&   r(   r`   rU   r+   r<   r    r!   r1   m  sN    

	zMobileViTLayer.__init__r>   c                 C   sN  | j | j }}t|| }|j\}}}}tj rHtt|| | ntt	|| | }	tj r~tt|| | ntt	|| | }
d}|
|ks|	|krt
jj||	|
fddd}d}|
| }|	| }|| }||| | |||}|dd}|||||}|dd}||| |d}||f||||||d	}||fS )
NFbilinearrx   modeZalign_cornersTr   r   r   rl   )	orig_sizery   channelsinterpolatenum_patchesnum_patches_widthnum_patches_height)r   r   r   rp   rF   Zjit
is_tracingr   ceilrt   r   rv   r   reshaperr   )r;   r?   r   r   
patch_areary   r   Zorig_heightZ
orig_widthZ
new_heightZ	new_widthr   num_patch_widthnum_patch_heightr   patches	info_dictr    r    r!   	unfolding  sH    	zMobileViTLayer.unfolding)r   r   r   c                 C   s   | j | j }}t|| }|d }|d }|d }|d }	|d }
| |||d}|dd}||| |	 |
||}|dd	}||||	| |
| }|d
 rtjj	||d ddd}|S )Nry   r   r   r   r   rl   r   r   r   r   r   r   Fr   )
r   r   r   rw   rq   rr   r   r   rv   r   )r;   r   r   r   r   r   ry   r   r   r   r   r?   r    r    r!   folding  s*    zMobileViTLayer.foldingc                 C   s|   | j r|  |}|}| |}| |}| |\}}| |}| |}| ||}| |}| t	j
||fdd}|S Nr   rn   )r   r   r   r   r   r   r   r   r   rF   cat)r;   r?   rR   r   r   r    r    r!   rA     s    





zMobileViTLayer.forward)r   )rB   rC   rD   rS   r   r   r1   rF   rG   tupledictr   r   rA   rH   r    r    r<   r!   r   h  s    :3r   c                       sD   e Zd Zedd fddZd
ejeeee	e
f ddd	Z  ZS )MobileViTEncoderNr$   r   c           
   	      sZ  t    || _t | _d| _d }}|jdkr<d}d}n|jdkrJd}d}t||j	d |j	d ddd}| j
| t||j	d |j	d dd	d}| j
| t||j	d |j	d	 d|jd dd
}| j
| |r|d9 }t||j	d	 |j	d d|jd d|d}| j
| |r"|d9 }t||j	d |j	d d|jd d	|d}	| j
|	 d S )NFr   T   r   r   )r%   r&   r(   rU   r   r   )r%   r&   r(   r`   rU      )r%   r&   r(   r`   rU   r+      )r0   r1   r$   r   rV   rW   Zgradient_checkpointingZoutput_striderT   neck_hidden_sizesrY   r   Zhidden_sizes)
r;   r$   Zdilate_layer_4Zdilate_layer_5r+   Zlayer_1Zlayer_2Zlayer_3Zlayer_4Zlayer_5r<   r    r!   r1     sx    



		zMobileViTEncoder.__init__FT)rk   output_hidden_statesreturn_dictr   c                 C   s\   |rdnd }t | jD ]\}}||}|r||f }q|sPtdd ||fD S t||dS )Nr    c                 s   s   | ]}|d ur|V  qd S r@   r    ).0vr    r    r!   	<genexpr>k      z+MobileViTEncoder.forward.<locals>.<genexpr>)last_hidden_staterk   )	enumeraterW   r   r   )r;   rk   r   r   Zall_hidden_statesrZ   r\   r    r    r!   rA   \  s    zMobileViTEncoder.forward)FT)rB   rC   rD   r   r1   rF   rG   rE   r   r   r   rA   rH   r    r    r<   r!   r     s   M  
r   c                   @   s:   e Zd ZU eed< dZdZdZdgZe	j
dddd	ZdS )
MobileViTPreTrainedModelr$   	mobilevitpixel_valuesTr   N)moduler   c                 C   sn   t |tjtjtjfrD|jjjd| jj	d |j
durj|j
j  n&t |tjrj|j
j  |jjd dS )zInitialize the weightsg        )meanZstdNg      ?)r7   r   rd   r3   r5   weightdataZnormal_r$   Zinitializer_ranger*   Zzero_r   Zfill_)r;   r   r    r    r!   _init_weightsx  s    
z&MobileViTPreTrainedModel._init_weights)rB   rC   rD   r   __annotations__Zbase_model_prefixZmain_input_nameZsupports_gradient_checkpointingZ_no_split_modulesr   Moduler   r    r    r    r!   r   p  s   
r   c                       s^   e Zd Zdeed fddZdd Zedee	j
 ee ee eeef dd	d
Z  ZS )MobileViTModelT)r$   expand_outputc                    sn   t  | || _|| _t||j|jd ddd| _t|| _	| jrbt||jd |jd dd| _
|   d	S )
aE  
        expand_output (`bool`, *optional*, defaults to `True`):
            Whether to expand the output of the model using a 1x1 convolution. If `True`, the model will apply an additional
            1x1 convolution to expand the output channels from `config.neck_hidden_sizes[5]` to `config.neck_hidden_sizes[6]`.
        r   r   r   )r%   r&   r'   r(   r      r   rK   N)r0   r1   r$   r   r#   Znum_channelsr   	conv_stemr   encoderconv_1x1_exp	post_init)r;   r$   r   r<   r    r!   r1     s&    
zMobileViTModel.__init__c                 C   sF   |  D ]8\}}| jj| }t|tr|jjD ]}|j| q.qdS )zPrunes heads of the model.
        heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel
        N)itemsr   rW   r7   r   r   r   r   )r;   Zheads_to_pruneZlayer_indexr   Zmobilevit_layerr   r    r    r!   _prune_heads  s
    
zMobileViTModel._prune_headsN)r   r   r   r   c           	      C   s   |d ur|n| j j}|d ur |n| j j}|d u r8td| |}| j|||d}| jr|| |d }tj	|ddgdd}n|d }d }|s|d ur||fn|f}||dd   S t
|||jd	S )
Nz You have to specify pixel_valuesr   r   r   rm   rl   F)ro   Zkeepdimr   )r   pooler_outputrk   )r$   r   use_return_dictr2   r   r   r   r   rF   r   r   rk   )	r;   r   r   r   Zembedding_outputZencoder_outputsr   pooled_outputr   r    r    r!   rA     s0    
zMobileViTModel.forward)T)NNN)rB   rC   rD   r   rE   r1   r   r   r   rF   rG   r   r   r   rA   rH   r    r    r<   r!   r     s   
   
r   z
    MobileViT model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    )Zcustom_introc                	       s\   e Zd Zedd fddZedeej ee	 eej ee	 e
eef dddZ  ZS )	MobileViTForImageClassificationNr   c                    sd   t  | |j| _t|| _tj|jdd| _|jdkrNt	|j
d |jnt | _|   d S )NT)Zinplacer   rl   )r0   r1   
num_labelsr   r   r   rg   classifier_dropout_probrh   rd   r   ZIdentity
classifierr   r;   r$   r<   r    r!   r1     s    
$z(MobileViTForImageClassification.__init__)r   r   labelsr   r   c                 C   sr  |dur|n| j j}| j|||d}|r.|jn|d }| | |}d}|dur2| j jdu r| jdkrtd| j _n4| jdkr|jt	j
ks|jt	jkrd| j _nd| j _| j jdkrt }	| jdkr|	| | }n
|	||}nN| j jdkrt }	|	|d| j|d}n| j jdkr2t }	|	||}|sb|f|dd  }
|dur^|f|
 S |
S t|||jd	S )
a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Nr   r   Z
regressionZsingle_label_classificationZmulti_label_classificationrl   r   )losslogitsrk   )r$   r   r   r   r   rh   Zproblem_typer   ZdtyperF   longr   r   Zsqueezer   rq   r   r   rk   )r;   r   r   r   r   outputsr   r   r   loss_fctr   r    r    r!   rA     s>    


"


z'MobileViTForImageClassification.forward)NNNN)rB   rC   rD   r   r1   r   r   rF   rG   rE   r   r   r   rA   rH   r    r    r<   r!   r     s       
r   c                       s<   e Zd Zeeedd fddZejejdddZ  Z	S )MobileViTASPPPoolingN)r$   r%   r&   r   c              	      s4   t    tjdd| _t|||ddddd| _d S )Nr   )Zoutput_sizeTrelu)r%   r&   r'   r(   r,   r-   )r0   r1   r   ZAdaptiveAvgPool2dglobal_poolr#   r   )r;   r$   r%   r&   r<   r    r!   r1   *  s    
zMobileViTASPPPooling.__init__r>   c                 C   s:   |j dd  }| |}| |}tjj||ddd}|S )Nrm   r   Fr   )rp   r   r   r   rv   r   )r;   r?   Zspatial_sizer    r    r!   rA   9  s
    

zMobileViTASPPPooling.forwardr]   r    r    r<   r!   r   )  s   r   c                       s<   e Zd ZdZedd fddZejejdddZ  Z	S )	MobileViTASPPz
    ASPP module defined in DeepLab papers: https://huggingface.co/papers/1606.00915, https://huggingface.co/papers/1706.05587
    Nr   c                    s   t     jd  jt jdkr0tdt | _	t
 ddd}| j	| | j	 fdd jD  t }| j	| t
 d	 ddd| _tj jd
| _d S )Nrm   r   z"Expected 3 values for atrous_ratesr   r   rL   c              
      s    g | ]}t  d |ddqS )r   r   )r%   r&   r'   r+   r-   )r#   )r   Zrater$   r%   r&   r    r!   
<listcomp>[  s   	z*MobileViTASPP.__init__.<locals>.<listcomp>r   )p)r0   r1   r   aspp_out_channelsr   Zatrous_ratesr2   r   rV   convsr#   rY   extendr   projectrg   Zaspp_dropout_probrh   )r;   r$   Zin_projectionZ
pool_layerr<   r   r!   r1   F  s2    


	zMobileViTASPP.__init__r>   c                 C   sD   g }| j D ]}||| q
tj|dd}| |}| |}|S r   )r   rY   rF   r   r   rh   )r;   r?   ZpyramidconvZpooled_featuresr    r    r!   rA   q  s    


zMobileViTASPP.forward
rB   rC   rD   rS   r   r1   rF   rG   rA   rH   r    r    r<   r!   r   A  s   +r   c                       s<   e Zd ZdZedd fddZejejdddZ  Z	S )	MobileViTDeepLabV3zJ
    DeepLabv3 architecture: https://huggingface.co/papers/1706.05587
    Nr   c              	      sB   t    t|| _t|j| _t||j	|j
ddddd| _d S )Nr   FT)r%   r&   r'   r,   r-   r*   )r0   r1   r   asppr   Z	Dropout2dr   rh   r#   r   r   r   r   r<   r    r!   r1     s    

zMobileViTDeepLabV3.__init__rj   c                 C   s&   |  |d }| |}| |}|S )Nrl   )r   rh   r   )r;   rk   r?   r    r    r!   rA     s    

zMobileViTDeepLabV3.forwardr   r    r    r<   r!   r   |  s   r   zX
    MobileViT model with a semantic segmentation head on top, e.g. for Pascal VOC.
    c                	       s\   e Zd Zedd fddZedeej eej ee	 ee	 e
eef dddZ  ZS )	 MobileViTForSemanticSegmentationNr   c                    s8   t  | |j| _t|dd| _t|| _|   d S )NF)r   )r0   r1   r   r   r   r   segmentation_headr   r   r<   r    r!   r1     s
    
z)MobileViTForSemanticSegmentation.__init__)r   r   r   r   r   c                 C   s  |dur|n| j j}|dur |n| j j}|durD| j jdkrDtd| j|d|d}|r^|jn|d }| |}d}|durtj	j
||jdd ddd	}	t| j jd
}
|
|	|}|s|r|f|dd  }n|f|dd  }|dur|f| S |S t|||r|jndddS )a{  
        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> import requests
        >>> import torch
        >>> from PIL import Image
        >>> from transformers import AutoImageProcessor, MobileViTForSemanticSegmentation

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("apple/deeplabv3-mobilevit-small")
        >>> model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-small")

        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> # logits are of shape (batch_size, num_labels, height, width)
        >>> logits = outputs.logits
        ```Nr   z/The number of labels should be greater than oneTr   rm   r   Fr   )Zignore_indexr   )r   r   rk   Z
attentions)r$   r   r   r   r2   r   rk   r   r   rv   r   rp   r   Zsemantic_loss_ignore_indexr   )r;   r   r   r   r   r   Zencoder_hidden_statesr   r   Zupsampled_logitsr   r   r    r    r!   rA     s<    $

z(MobileViTForSemanticSegmentation.forward)NNNN)rB   rC   rD   r   r1   r   r   rF   rG   rE   r   r   r   rA   rH   r    r    r<   r!   r     s   
    
r   )r   r   r   r   )r   N):rS   rt   typingr   r   rF   Ztorch.utils.checkpointr   Ztorch.nnr   r   r   Zactivationsr	   Zmodeling_layersr
   Zmodeling_outputsr   r   r   r   Zmodeling_utilsr   Zpytorch_utilsr   r   utilsr   r   r   Zconfiguration_mobilevitr   Z
get_loggerrB   loggerr   r"   r   r#   rI   rT   r^   r{   r   r   r   r   r   r   r   r   r   r   r   r   r   r   __all__r    r    r    r!   <module>   sX   
@09 *_UH;X