"""PyTorch MobileViTV2 model."""

from typing import Optional, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
    BaseModelOutputWithNoAttention,
    BaseModelOutputWithPoolingAndNoAttention,
    ImageClassifierOutputWithNoAttention,
    SemanticSegmenterOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, logging
from .configuration_mobilevitv2 import MobileViTV2Config


logger = logging.get_logger(__name__)


def make_divisible(value: int, divisor: int = 8, min_value: Optional[int] = None) -> int:
    """
    Ensure that all layers have a channel count that is divisible by `divisor`. This function is taken from the
    original TensorFlow repo. It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if min_value is None:
        min_value = divisor
    new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_value < 0.9 * value:
        new_value += divisor
    return int(new_value)


def clip(value: float, min_val: float = float("-inf"), max_val: float = float("inf")) -> float:
    return max(min_val, min(max_val, value))
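# Illustrative examples for the two helpers above (added comment, not part of the
# original module): `make_divisible` rounds a channel count to the nearest multiple
# of `divisor`, then bumps the result up whenever rounding would shrink the value
# by more than 10%; `clip` simply bounds a value to a closed interval.
#
#     make_divisible(57, divisor=8)        # -> 56 (nearest multiple of 8)
#     make_divisible(10, divisor=8)        # -> 16 (8 would shrink 10 by more than 10%)
#     clip(1.5, min_val=0.0, max_val=1.0)  # -> 1.0
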
class MobileViTV2ConvLayer(nn.Module):
    def __init__(
        self,
        config: MobileViTV2Config,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        groups: int = 1,
        bias: bool = False,
        dilation: int = 1,
        use_normalization: bool = True,
        use_activation: Union[bool, str] = True,
    ) -> None:
        super().__init__()
        padding = int((kernel_size - 1) / 2) * dilation

        if in_channels % groups != 0:
            raise ValueError(f"Input channels ({in_channels}) are not divisible by {groups} groups.")
        if out_channels % groups != 0:
            raise ValueError(f"Output channels ({out_channels}) are not divisible by {groups} groups.")

        self.convolution = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode="zeros",
        )

        if use_normalization:
            self.normalization = nn.BatchNorm2d(
                num_features=out_channels,
                eps=1e-5,
                momentum=0.1,
                affine=True,
                track_running_stats=True,
            )
        else:
            self.normalization = None

        if use_activation:
            if isinstance(use_activation, str):
                self.activation = ACT2FN[use_activation]
            elif isinstance(config.hidden_act, str):
                self.activation = ACT2FN[config.hidden_act]
            else:
                self.activation = config.hidden_act
        else:
            self.activation = None

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        features = self.convolution(features)
        if self.normalization is not None:
            features = self.normalization(features)
        if self.activation is not None:
            features = self.activation(features)
        return features

class MobileViTV2InvertedResidual(nn.Module):
    """
    Inverted residual block (MobileNetv2): https://huggingface.co/papers/1801.04381
    """

    def __init__(
        self, config: MobileViTV2Config, in_channels: int, out_channels: int, stride: int, dilation: int = 1
    ) -> None:
        super().__init__()
        expanded_channels = make_divisible(int(round(in_channels * config.expand_ratio)), 8)

        if stride not in [1, 2]:
            raise ValueError(f"Invalid stride {stride}.")

        self.use_residual = (stride == 1) and (in_channels == out_channels)

        self.expand_1x1 = MobileViTV2ConvLayer(
            config, in_channels=in_channels, out_channels=expanded_channels, kernel_size=1
        )

        self.conv_3x3 = MobileViTV2ConvLayer(
            config,
            in_channels=expanded_channels,
            out_channels=expanded_channels,
            kernel_size=3,
            stride=stride,
            groups=expanded_channels,
            dilation=dilation,
        )

        self.reduce_1x1 = MobileViTV2ConvLayer(
            config,
            in_channels=expanded_channels,
            out_channels=out_channels,
            kernel_size=1,
            use_activation=False,
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        residual = features

        features = self.expand_1x1(features)
        features = self.conv_3x3(features)
        features = self.reduce_1x1(features)

        return residual + features if self.use_residual else features

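# A small usage note for the block above (added for clarity; the channel sizes are
# made up): the residual shortcut is only taken when the block preserves the tensor
# shape, i.e. stride == 1 and in_channels == out_channels.
#
#     block = MobileViTV2InvertedResidual(config, in_channels=64, out_channels=64, stride=1)
#     block.use_residual  # True  -> forward() returns residual + features
#     block = MobileViTV2InvertedResidual(config, in_channels=64, out_channels=128, stride=2)
#     block.use_residual  # False -> forward() returns only the transformed features
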
class MobileViTV2MobileNetLayer(nn.Module):
    def __init__(
        self, config: MobileViTV2Config, in_channels: int, out_channels: int, stride: int = 1, num_stages: int = 1
    ) -> None:
        super().__init__()
        self.layer = nn.ModuleList()
        for i in range(num_stages):
            layer = MobileViTV2InvertedResidual(
                config,
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride if i == 0 else 1,
            )
            self.layer.append(layer)
            in_channels = out_channels

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        for layer_module in self.layer:
            features = layer_module(features)
        return features

class MobileViTV2LinearSelfAttention(nn.Module):
    """
    This layer applies a self-attention with linear complexity, as described in MobileViTV2 paper:
    https://huggingface.co/papers/2206.02680

    Args:
        config (`MobileViTV2Config`):
            Model configuration object
        embed_dim (`int`):
            `input_channels` from an expected input of size :math:`(batch_size, input_channels, height, width)`
    """

    def __init__(self, config: MobileViTV2Config, embed_dim: int) -> None:
        super().__init__()

        self.qkv_proj = MobileViTV2ConvLayer(
            config=config,
            in_channels=embed_dim,
            out_channels=1 + (2 * embed_dim),
            bias=True,
            kernel_size=1,
            use_normalization=False,
            use_activation=False,
        )
        self.attn_dropout = nn.Dropout(p=config.attn_dropout)
        self.out_proj = MobileViTV2ConvLayer(
            config=config,
            in_channels=embed_dim,
            out_channels=embed_dim,
            bias=True,
            kernel_size=1,
            use_normalization=False,
            use_activation=False,
        )
        self.embed_dim = embed_dim

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # project into a 1-channel query and embed_dim-channel key and value
        qkv = self.qkv_proj(hidden_states)
        query, key, value = torch.split(qkv, split_size_or_sections=[1, self.embed_dim, self.embed_dim], dim=1)

        # apply softmax along the num_patches dimension
        context_scores = torch.nn.functional.softmax(query, dim=-1)
        context_scores = self.attn_dropout(context_scores)

        # compute the context vector by weighting the keys with the context scores
        context_vector = key * context_scores
        context_vector = torch.sum(context_vector, dim=-1, keepdim=True)

        # combine the context vector with the values
        out = torch.nn.functional.relu(value) * context_vector.expand_as(value)
        out = self.out_proj(out)
        return out

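# Shape walk-through for the linear attention above (added comment; the sizes are
# illustrative, `config` is any `MobileViTV2Config` instance): the input is
# (batch_size, embed_dim, num_pixels_in_patch, num_patches); `qkv_proj` emits
# 1 + 2 * embed_dim channels that split into a one-channel query and
# embed_dim-channel key/value; the softmax runs over the patch axis, so the cost
# grows linearly with the number of patches instead of quadratically as in
# standard self-attention.
#
#     attention = MobileViTV2LinearSelfAttention(config, embed_dim=64)
#     x = torch.randn(2, 64, 4, 256)  # (batch, embed_dim, pixels per patch, patches)
#     assert attention(x).shape == x.shape
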
class MobileViTV2FFN(nn.Module):
    def __init__(
        self,
        config: MobileViTV2Config,
        embed_dim: int,
        ffn_latent_dim: int,
        ffn_dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.conv1 = MobileViTV2ConvLayer(
            config=config,
            in_channels=embed_dim,
            out_channels=ffn_latent_dim,
            kernel_size=1,
            stride=1,
            bias=True,
            use_normalization=False,
            use_activation=True,
        )
        self.dropout1 = nn.Dropout(ffn_dropout)

        self.conv2 = MobileViTV2ConvLayer(
            config=config,
            in_channels=ffn_latent_dim,
            out_channels=embed_dim,
            kernel_size=1,
            stride=1,
            bias=True,
            use_normalization=False,
            use_activation=False,
        )
        self.dropout2 = nn.Dropout(ffn_dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.dropout1(hidden_states)
        hidden_states = self.conv2(hidden_states)
        hidden_states = self.dropout2(hidden_states)
        return hidden_states

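# Note on normalization (added comment): the transformer layer below normalizes with
# nn.GroupNorm(num_groups=1, ...), which standardizes each sample jointly over all
# channels and positions. This works directly on the convolutional
# (batch, channels, pixels, patches) layout, avoiding the transposes a standard
# nn.LayerNorm over the last dimension would require.
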
class MobileViTV2TransformerLayer(nn.Module):
    def __init__(
        self,
        config: MobileViTV2Config,
        embed_dim: int,
        ffn_latent_dim: int,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.layernorm_before = nn.GroupNorm(num_groups=1, num_channels=embed_dim, eps=config.layer_norm_eps)
        self.attention = MobileViTV2LinearSelfAttention(config, embed_dim)
        self.dropout1 = nn.Dropout(p=dropout)
        self.layernorm_after = nn.GroupNorm(num_groups=1, num_channels=embed_dim, eps=config.layer_norm_eps)
        self.ffn = MobileViTV2FFN(config, embed_dim, ffn_latent_dim, config.ffn_dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        layernorm_1_out = self.layernorm_before(hidden_states)
        attention_output = self.attention(layernorm_1_out)
        hidden_states = attention_output + hidden_states

        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.ffn(layer_output)

        layer_output = layer_output + hidden_states
        return layer_output


class MobileViTV2Transformer(nn.Module):
    def __init__(self, config: MobileViTV2Config, n_layers: int, d_model: int) -> None:
        super().__init__()

        ffn_multiplier = config.ffn_multiplier
        ffn_dims = [ffn_multiplier * d_model] * n_layers

        # ensure that the feed-forward dims are multiples of 16
        ffn_dims = [int((d // 16) * 16) for d in ffn_dims]

        self.layer = nn.ModuleList()
        for block_idx in range(n_layers):
            transformer_layer = MobileViTV2TransformerLayer(
                config, embed_dim=d_model, ffn_latent_dim=ffn_dims[block_idx]
            )
            self.layer.append(transformer_layer)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states)
        return hidden_states

class MobileViTV2Layer(GradientCheckpointingLayer):
    """
    MobileViTV2 layer: https://huggingface.co/papers/2206.02680
    """

    def __init__(
        self,
        config: MobileViTV2Config,
        in_channels: int,
        out_channels: int,
        attn_unit_dim: int,
        n_attn_blocks: int = 2,
        dilation: int = 1,
        stride: int = 2,
    ) -> None:
        super().__init__()
        self.patch_width = config.patch_size
        self.patch_height = config.patch_size

        cnn_out_dim = attn_unit_dim

        if stride == 2:
            self.downsampling_layer = MobileViTV2InvertedResidual(
                config,
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride if dilation == 1 else 1,
                dilation=dilation // 2 if dilation > 1 else 1,
            )
            in_channels = out_channels
        else:
            self.downsampling_layer = None

        # local representations
        self.conv_kxk = MobileViTV2ConvLayer(
            config,
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=config.conv_kernel_size,
            groups=in_channels,
        )
        self.conv_1x1 = MobileViTV2ConvLayer(
            config,
            in_channels=in_channels,
            out_channels=cnn_out_dim,
            kernel_size=1,
            use_normalization=False,
            use_activation=False,
        )

        # global representations
        self.transformer = MobileViTV2Transformer(config, d_model=attn_unit_dim, n_layers=n_attn_blocks)

        self.layernorm = nn.GroupNorm(num_groups=1, num_channels=attn_unit_dim, eps=config.layer_norm_eps)

        # fusion
        self.conv_projection = MobileViTV2ConvLayer(
            config,
            in_channels=cnn_out_dim,
            out_channels=in_channels,
            kernel_size=1,
            use_normalization=True,
            use_activation=False,
        )

    def unfolding(self, feature_map: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int]]:
        batch_size, in_channels, img_height, img_width = feature_map.shape
        patches = nn.functional.unfold(
            feature_map,
            kernel_size=(self.patch_height, self.patch_width),
            stride=(self.patch_height, self.patch_width),
        )
        patches = patches.reshape(batch_size, in_channels, self.patch_height * self.patch_width, -1)

        return patches, (img_height, img_width)

    def folding(self, patches: torch.Tensor, output_size: tuple[int, int]) -> torch.Tensor:
        batch_size, in_dim, patch_size, n_patches = patches.shape
        patches = patches.reshape(batch_size, in_dim * patch_size, n_patches)

        feature_map = nn.functional.fold(
            patches,
            output_size=output_size,
            kernel_size=(self.patch_height, self.patch_width),
            stride=(self.patch_height, self.patch_width),
        )

        return feature_map

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # reduce spatial dimensions if needed
        if self.downsampling_layer:
            features = self.downsampling_layer(features)

        # local representation using convolutions
        features = self.conv_kxk(features)
        features = self.conv_1x1(features)

        # convert feature map to patches
        patches, output_size = self.unfolding(features)

        # learn global representations
        patches = self.transformer(patches)
        patches = self.layernorm(patches)

        # convert patches back to feature maps
        features = self.folding(patches, output_size)

        features = self.conv_projection(features)
        return features


class MobileViTV2Encoder(nn.Module):
    def __init__(self, config: MobileViTV2Config) -> None:
        super().__init__()
        self.config = config

        self.layer = nn.ModuleList()
        self.gradient_checkpointing = False

        # segmentation architectures like DeepLab and PSPNet modify the strides
        # of the classification backbones
        dilate_layer_4 = dilate_layer_5 = False
        if config.output_stride == 8:
            dilate_layer_4 = True
            dilate_layer_5 = True
        elif config.output_stride == 16:
            dilate_layer_5 = True

        dilation = 1

        layer_0_dim = make_divisible(
            clip(value=32 * config.width_multiplier, min_val=16, max_val=64), divisor=8, min_value=16
        )

        layer_1_dim = make_divisible(64 * config.width_multiplier, divisor=16)
        layer_2_dim = make_divisible(128 * config.width_multiplier, divisor=8)
        layer_3_dim = make_divisible(256 * config.width_multiplier, divisor=8)
        layer_4_dim = make_divisible(384 * config.width_multiplier, divisor=8)
        layer_5_dim = make_divisible(512 * config.width_multiplier, divisor=8)

        layer_1 = MobileViTV2MobileNetLayer(
            config,
            in_channels=layer_0_dim,
            out_channels=layer_1_dim,
            stride=1,
            num_stages=1,
        )
        self.layer.append(layer_1)

        layer_2 = MobileViTV2MobileNetLayer(
            config,
            in_channels=layer_1_dim,
            out_channels=layer_2_dim,
            stride=2,
            num_stages=2,
        )
        self.layer.append(layer_2)

        layer_3 = MobileViTV2Layer(
            config,
            in_channels=layer_2_dim,
            out_channels=layer_3_dim,
            attn_unit_dim=make_divisible(config.base_attn_unit_dims[0] * config.width_multiplier, divisor=8),
            n_attn_blocks=config.n_attn_blocks[0],
        )
        self.layer.append(layer_3)

        if dilate_layer_4:
            dilation *= 2

        layer_4 = MobileViTV2Layer(
            config,
            in_channels=layer_3_dim,
            out_channels=layer_4_dim,
            attn_unit_dim=make_divisible(config.base_attn_unit_dims[1] * config.width_multiplier, divisor=8),
            n_attn_blocks=config.n_attn_blocks[1],
            dilation=dilation,
        )
        self.layer.append(layer_4)

        if dilate_layer_5:
            dilation *= 2

        layer_5 = MobileViTV2Layer(
            config,
            in_channels=layer_4_dim,
            out_channels=layer_5_dim,
            attn_unit_dim=make_divisible(config.base_attn_unit_dims[2] * config.width_multiplier, divisor=8),
            n_attn_blocks=config.n_attn_blocks[2],
            dilation=dilation,
        )
        self.layer.append(layer_5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutputWithNoAttention]:
        all_hidden_states = () if output_hidden_states else None

        for i, layer_module in enumerate(self.layer):
            hidden_states = layer_module(hidden_states)

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)

        return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states, hidden_states=all_hidden_states)

@auto_docstring
class MobileViTV2PreTrainedModel(PreTrainedModel):
    config: MobileViTV2Config
    base_model_prefix = "mobilevitv2"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["MobileViTV2Layer"]

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.GroupNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


@auto_docstring
class MobileViTV2Model(MobileViTV2PreTrainedModel):
    def __init__(self, config: MobileViTV2Config, expand_output: bool = True):
        r"""
        expand_output (`bool`, *optional*, defaults to `True`):
            Whether to expand the output of the model. If `True`, the model will output pooled features in addition to
            hidden states. If `False`, only the hidden states will be returned.
        """
        super().__init__(config)
        self.config = config
        self.expand_output = expand_output

        layer_0_dim = make_divisible(
            clip(value=32 * config.width_multiplier, min_val=16, max_val=64), divisor=8, min_value=16
        )

        self.conv_stem = MobileViTV2ConvLayer(
            config,
            in_channels=config.num_channels,
            out_channels=layer_0_dim,
            kernel_size=3,
            stride=2,
            use_normalization=True,
            use_activation=True,
        )
        self.encoder = MobileViTV2Encoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def _prune_heads(self, heads_to_prune):
        """Prunes heads of the model.
        heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel
        """
        for layer_index, heads in heads_to_prune.items():
            mobilevitv2_layer = self.encoder.layer[layer_index]
            if isinstance(mobilevitv2_layer, MobileViTV2Layer):
                for transformer_layer in mobilevitv2_layer.transformer.layer:
                    transformer_layer.attention.prune_heads(heads)

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output = self.conv_stem(pixel_values)

        encoder_outputs = self.encoder(
            embedding_output,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.expand_output:
            last_hidden_state = encoder_outputs[0]

            # global average pooling: (batch_size, channels, height, width) -> (batch_size, channels)
            pooled_output = torch.mean(last_hidden_state, dim=[-2, -1], keepdim=False)
        else:
            last_hidden_state = encoder_outputs[0]
            pooled_output = None

        if not return_dict:
            output = (last_hidden_state, pooled_output) if pooled_output is not None else (last_hidden_state,)
            return output + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
        )

@auto_docstring(
    custom_intro="""
    MobileViTV2 model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    """
)
class MobileViTV2ForImageClassification(MobileViTV2PreTrainedModel):
    def __init__(self, config: MobileViTV2Config) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.mobilevitv2 = MobileViTV2Model(config)

        out_channels = make_divisible(512 * config.width_multiplier, divisor=8)  # layer 5 output dimension
        # Classifier head
        self.classifier = (
            nn.Linear(in_features=out_channels, out_features=config.num_labels)
            if config.num_labels > 0
            else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.mobilevitv2(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)

        pooled_output = outputs.pooler_output if return_dict else outputs[1]

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )


class MobileViTV2ASPPPooling(nn.Module):
    def __init__(self, config: MobileViTV2Config, in_channels: int, out_channels: int) -> None:
        super().__init__()

        self.global_pool = nn.AdaptiveAvgPool2d(output_size=1)

        self.conv_1x1 = MobileViTV2ConvLayer(
            config,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            use_normalization=True,
            use_activation="relu",
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        spatial_size = features.shape[-2:]
        features = self.global_pool(features)
        features = self.conv_1x1(features)
        features = nn.functional.interpolate(features, size=spatial_size, mode="bilinear", align_corners=False)
        return features

class MobileViTV2ASPP(nn.Module):
    """
    ASPP module defined in DeepLab papers: https://huggingface.co/papers/1606.00915, https://huggingface.co/papers/1706.05587
    """

    def __init__(self, config: MobileViTV2Config) -> None:
        super().__init__()

        encoder_out_channels = make_divisible(512 * config.width_multiplier, divisor=8)  # layer 5 output dimension
        in_channels = encoder_out_channels
        out_channels = config.aspp_out_channels

        if len(config.atrous_rates) != 3:
            raise ValueError("Expected 3 values for atrous_rates")

        self.convs = nn.ModuleList()

        in_projection = MobileViTV2ConvLayer(
            config,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            use_activation="relu",
        )
        self.convs.append(in_projection)

        self.convs.extend(
            [
                MobileViTV2ConvLayer(
                    config,
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    dilation=rate,
                    use_activation="relu",
                )
                for rate in config.atrous_rates
            ]
        )

        pool_layer = MobileViTV2ASPPPooling(config, in_channels, out_channels)
        self.convs.append(pool_layer)

        self.project = MobileViTV2ConvLayer(
            config, in_channels=5 * out_channels, out_channels=out_channels, kernel_size=1, use_activation="relu"
        )

        self.dropout = nn.Dropout(p=config.aspp_dropout_prob)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        pyramid = []
        for conv in self.convs:
            pyramid.append(conv(features))
        pyramid = torch.cat(pyramid, dim=1)

        pooled_features = self.project(pyramid)
        pooled_features = self.dropout(pooled_features)
        return pooled_features

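# The ASPP head above pools context at multiple scales (added comment): it runs five
# parallel branches over the layer-5 feature map -- a 1x1 projection, three 3x3
# convolutions with the dilation rates from `config.atrous_rates`, and a global
# average-pooling branch -- concatenates them along the channel axis
# (5 * aspp_out_channels) and projects back down to `config.aspp_out_channels`.
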
class MobileViTV2DeepLabV3(nn.Module):
    """
    DeepLabv3 architecture: https://huggingface.co/papers/1706.05587
    """

    def __init__(self, config: MobileViTV2Config) -> None:
        super().__init__()
        self.aspp = MobileViTV2ASPP(config)

        self.dropout = nn.Dropout2d(config.classifier_dropout_prob)

        self.classifier = MobileViTV2ConvLayer(
            config,
            in_channels=config.aspp_out_channels,
            out_channels=config.num_labels,
            kernel_size=1,
            use_normalization=False,
            use_activation=False,
            bias=True,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        features = self.aspp(hidden_states[-1])
        features = self.dropout(features)
        features = self.classifier(features)
        return features

@auto_docstring(
    custom_intro="""
    MobileViTV2 model with a semantic segmentation head on top, e.g. for Pascal VOC.
    """
)
class MobileViTV2ForSemanticSegmentation(MobileViTV2PreTrainedModel):
    def __init__(self, config: MobileViTV2Config) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.mobilevitv2 = MobileViTV2Model(config, expand_output=False)
        self.segmentation_head = MobileViTV2DeepLabV3(config)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, SemanticSegmenterOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> import requests
        >>> import torch
        >>> from PIL import Image
        >>> from transformers import AutoImageProcessor, MobileViTV2ForSemanticSegmentation

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("apple/mobilevitv2-1.0-imagenet1k-256")
        >>> model = MobileViTV2ForSemanticSegmentation.from_pretrained("apple/mobilevitv2-1.0-imagenet1k-256")

        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> # logits are of shape (batch_size, num_labels, height, width)
        >>> logits = outputs.logits
        ```"""
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None and self.config.num_labels == 1:
            raise ValueError("The number of labels should be greater than one")

        outputs = self.mobilevitv2(
            pixel_values,
            output_hidden_states=True,  # we need the intermediate hidden states
            return_dict=return_dict,
        )

        encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]

        logits = self.segmentation_head(encoder_hidden_states)

        loss = None
        if labels is not None:
            # upsample logits to the images' original size
            upsampled_logits = nn.functional.interpolate(
                logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
            )
            loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
            loss = loss_fct(upsampled_logits, labels)

        if not return_dict:
            if output_hidden_states:
                output = (logits,) + outputs[1:]
            else:
                output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=None,
        )


__all__ = [
    "MobileViTV2ForImageClassification",
    "MobileViTV2ForSemanticSegmentation",
    "MobileViTV2Model",
    "MobileViTV2PreTrainedModel",
]