"""PyTorch GLPN model."""

import math
from typing import Optional, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import auto_docstring, logging
from .configuration_glpn import GLPNConfig


logger = logging.get_logger(__name__)


def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    # broadcastable shape (batch, 1, 1, ...) so this works for tensors of any rank
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize
    output = input.div(keep_prob) * random_tensor
    return output


class GLPNDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f"p={self.drop_prob}"
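
# Illustration only (not part of the model): how GLPNDropPath behaves. In training
# mode each sample's residual branch is zeroed with probability `drop_prob` and the
# survivors are rescaled by 1 / keep_prob, so the expected activation matches eval
# mode, where the module is a no-op.
#
#   x = torch.ones(4, 3)
#   module = GLPNDropPath(drop_prob=0.5).train()
#   out = module(x)               # each row is either all 0.0 or all 2.0
#   module.eval()(x).equal(x)     # True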


class GLPNOverlapPatchEmbeddings(nn.Module):
    """Construct the overlapping patch embeddings."""

    def __init__(self, patch_size, stride, num_channels, hidden_size):
        super().__init__()
        self.proj = nn.Conv2d(
            num_channels,
            hidden_size,
            kernel_size=patch_size,
            stride=stride,
            padding=patch_size // 2,
        )
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, pixel_values):
        embeddings = self.proj(pixel_values)
        _, _, height, width = embeddings.shape
        # (batch_size, num_channels, height, width) -> (batch_size, height*width, num_channels)
        # so the embeddings can be fed to the Transformer layers
        embeddings = embeddings.flatten(2).transpose(1, 2)
        embeddings = self.layer_norm(embeddings)
        return embeddings, height, width
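
# Shape sketch for the overlapping patch embedding (the input size here is an
# assumption for illustration): with patch_size=7 and stride=4, a (1, 3, 480, 640)
# image becomes a (1, hidden_size, 120, 160) map, flattened to
# (1, 120 * 160, hidden_size) tokens; padding of patch_size // 2 is what makes
# neighboring patches overlap.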
z"GLPNOverlapPatchEmbeddings.forwardr*   r+   r,   r-   r"   r(   r1   r   r   r$   r   r2   H   s   r2   c                       s*   e Zd ZdZ fddZdddZ  ZS )GLPNEfficientSelfAttentionzSegFormer's efficient self-attention mechanism. Employs the sequence reduction process introduced in the [PvT
    paper](https://huggingface.co/papers/2102.12122).c                    s   t    || _|| _| j| j dkr@td| j d| j dt| j| j | _| j| j | _t	| j| j| _
t	| j| j| _t	| j| j| _t|j| _|| _|dkrtj||||d| _t|| _d S )Nr   zThe hidden size (z6) is not a multiple of the number of attention heads ()r   )r5   r6   )r!   r"   r>   num_attention_heads
ValueErrorintattention_head_sizeall_head_sizer   LinearquerykeyvalueDropoutZattention_probs_dropout_probdropoutsr_ratior8   srr:   r;   r#   configr>   rI   sequence_reduction_ratior$   r   r   r"   f   s*    

z#GLPNEfficientSelfAttention.__init__Fc                 C   sf  |j \}}}| ||d| j| jdd}| jdkr|j \}}	}
|ddd||
||}| 	|}|||
dddd}| 
|}| ||d| j| jdd}| ||d| j| jdd}t||dd}|t| j }tjj|dd}| |}t||}|dddd }| d d | jf }||}|r\||fn|f}|S )Nr   r3   r   dimr   )r   rO   viewrI   rL   r@   rT   permutereshaperU   r;   rP   rQ   r   matmulmathsqrtr   Z
functionalZsoftmaxrS   
contiguoussizerM   )r#   r'   rD   rE   output_attentions
batch_sizeZ
seq_lengthrC   Zquery_layerseq_lenr=   Z	key_layerZvalue_layerZattention_scoresZattention_probsZcontext_layerZnew_context_layer_shapeoutputsr   r   r   r(      sF    







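
# Cost sketch for the sequence reduction above (sizes assumed for illustration):
# with height = width = 64 and sr_ratio = 8, keys/values shrink from 64 * 64 = 4096
# positions to (64 / 8) * (64 / 8) = 64, so the attention score matrix is
# 4096 x 64 instead of 4096 x 4096.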
z"GLPNEfficientSelfAttention.forward)FrF   r   r   r$   r   rG   b   s     rG   c                       s$   e Zd Z fddZdd Z  ZS )GLPNSelfOutputc                    s*   t    t||| _t|j| _d S r    )r!   r"   r   rN   denserR   hidden_dropout_probrS   )r#   rW   r>   r$   r   r   r"      s    
zGLPNSelfOutput.__init__c                 C   s   |  |}| |}|S r    )rj   rS   )r#   r'   Zinput_tensorr   r   r   r(      s    

zGLPNSelfOutput.forwardr*   r+   r,   r"   r(   r1   r   r   r$   r   ri      s   ri   c                       s.   e Zd Z fddZdd ZdddZ  ZS )	GLPNAttentionc                    s6   t    t||||d| _t||d| _t | _d S )N)rW   r>   rI   rX   )r>   )r!   r"   rG   r#   ri   r   setpruned_headsrV   r$   r   r   r"      s    
zGLPNAttention.__init__c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   r[   )lenr
   r#   rI   rL   ro   r   rO   rP   rQ   r   rj   rM   union)r#   headsindexr   r   r   prune_heads   s    zGLPNAttention.prune_headsFc                 C   s6   |  ||||}| |d |}|f|dd   }|S )Nr   r   )r#   r   )r#   r'   rD   rE   re   Zself_outputsattention_outputrh   r   r   r   r(      s    zGLPNAttention.forward)F)r*   r+   r,   r"   rt   r(   r1   r   r   r$   r   rm      s   rm   c                       s&   e Zd Zd fdd	Zdd Z  ZS )


class GLPNDWConv(nn.Module):
    def __init__(self, dim=768):
        super().__init__()
        # depthwise: groups=dim gives one 3x3 filter per channel
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, hidden_states, height, width):
        batch_size, seq_len, num_channels = hidden_states.shape
        hidden_states = hidden_states.transpose(1, 2).view(batch_size, num_channels, height, width)
        hidden_states = self.dwconv(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)
        return hidden_states


class GLPNMixFFN(nn.Module):
    def __init__(self, config, in_features, hidden_features=None, out_features=None):
        super().__init__()
        out_features = out_features or in_features
        self.dense1 = nn.Linear(in_features, hidden_features)
        self.dwconv = GLPNDWConv(hidden_features)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act
        self.dense2 = nn.Linear(hidden_features, out_features)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, height, width):
        hidden_states = self.dense1(hidden_states)
        hidden_states = self.dwconv(hidden_states, height, width)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.dense2(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states
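
# Design note, following the SegFormer lineage of this block: the depthwise 3x3
# convolution between the two linear layers leaks local positional information into
# the FFN, which is why this encoder gets away without explicit position embeddings.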
zGLPNMixFFN.forward)NNrl   r   r   r$   r   r{      s   r{   c                       s*   e Zd ZdZ fddZdddZ  ZS )	GLPNLayerzCThis corresponds to the Block class in the original implementation.c                    sn   t    t|| _t||||d| _|dkr8t|nt | _	t|| _
t|| }t|||d| _d S )N)r>   rI   rX   r   )r   r   )r!   r"   r   r:   layer_norm_1rm   	attentionr   Identityr   layer_norm_2rK   r{   mlp)r#   rW   r>   rI   r   rX   	mlp_ratioZmlp_hidden_sizer$   r   r   r"     s    
zGLPNLayer.__init__Fc           
      C   sr   | j | ||||d}|d }|dd  }| |}|| }| | |||}| |}|| }	|	f| }|S )N)re   r   r   )r   r   r   r   r   )
r#   r'   rD   rE   re   Zself_attention_outputsru   rh   Z
mlp_outputZlayer_outputr   r   r   r(   (  s    


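
# Schematic of GLPNLayer (comment-only sketch): a pre-norm Transformer block with
# stochastic depth applied to both residual branches:
#   hidden_states = hidden_states + drop_path(attention(layer_norm_1(hidden_states)))
#   hidden_states = hidden_states + drop_path(mlp(layer_norm_2(hidden_states)))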
zGLPNLayer.forward)FrF   r   r   r$   r   r     s   r   c                       s&   e Zd Z fddZdddZ  ZS )GLPNEncoderc           	         sX  t     | _dd tjd jt jddD }g }t j	D ]D}|
t j|  j| |dkrj jn j|d   j| d qBt|| _g }d}t j	D ]}g }|dkr| j|d  7 }t j| D ]>}|
t  j|  j| |||   j|  j| d q|
t| qt|| _t fd	dt j	D | _d S )
Nc                 S   s   g | ]}|  qS r   )item).0xr   r   r   
<listcomp>H      z(GLPNEncoder.__init__.<locals>.<listcomp>r   cpu)r   r   )r<   r6   r=   r>   )r>   rI   r   rX   r   c                    s   g | ]}t  j| qS r   )r   r:   hidden_sizes)r   irW   r   r   r   p  r   )r!   r"   rW   r   ZlinspaceZdrop_path_ratesumZdepthsrangeZnum_encoder_blocksappendr2   Zpatch_sizesstridesr=   r   r   
ModuleListpatch_embeddingsr   rI   Z	sr_ratiosZ
mlp_ratiosblockr;   )	r#   rW   ZdprrB   r   blockscurZlayersjr$   r   r   r"   C  sH    
$

zGLPNEncoder.__init__FTc                 C   s   |rdnd }|rdnd }|j d }|}tt| j| j| jD ]\}	}
|
\}}}||\}}}t|D ]0\}}|||||}|d }|rd||d f }qd||}||||ddddd }|r<||f }q<|st	dd |||fD S t
|||d	S )
Nr   r   r   rY   r   r3   c                 s   s   | ]}|d ur|V  qd S r    r   )r   vr   r   r   	<genexpr>  r   z&GLPNEncoder.forward.<locals>.<genexpr>Zlast_hidden_stater'   
attentions)r   	enumeratezipr   r   r;   r_   r^   rc   tupler   )r#   rA   re   output_hidden_statesreturn_dictZall_hidden_statesZall_self_attentionsrf   r'   idxr   Zembedding_layerZblock_layerZ
norm_layerrD   rE   r   ZblkZlayer_outputsr   r   r   r(   s  s.    

 zGLPNEncoder.forward)FFTrl   r   r   r$   r   r   B  s
   3   r   c                   @   s*   e Zd ZU eed< dZdZg Zdd ZdS )GLPNPreTrainedModelrW   glpnrA   c                 C   s   t |tjtjfr@|jjjd| jjd |j	dur|j	j
  nlt |tjr|jjjd| jjd |jdur|jj|j 
  n,t |tjtjfr|j	j
  |jjd dS )zInitialize the weightsr   )meanZstdNg      ?)r}   r   rN   r8   weightdataZnormal_rW   Zinitializer_rangerx   Zzero_Z	EmbeddingZpadding_idxr:   BatchNorm2dZfill_)r#   moduler   r   r   _init_weights  s    

z!GLPNPreTrainedModel._init_weightsN)	r*   r+   r,   r   __annotations__Zbase_model_prefixZmain_input_nameZ_no_split_modulesr   r   r   r   r   r     s
   
r   c                	       sV   e Zd Z fddZdd Zed	ejee	 ee	 ee	 e
eef dddZ  ZS )


@auto_docstring
class GLPNModel(GLPNPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # hierarchical Transformer encoder
        self.encoder = GLPNEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_outputs = self.encoder(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]

        if not return_dict:
            return (sequence_output,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class GLPNSelectiveFeatureFusion(nn.Module):
    """
    Selective Feature Fusion module, as explained in the [paper](https://huggingface.co/papers/2201.07436) (section
    3.4). This module adaptively selects and integrates local and global features by attaining an attention map for
    each feature.
    """

    def __init__(self, in_channel=64):
        super().__init__()

        self.convolutional_layer1 = nn.Sequential(
            nn.Conv2d(in_channels=int(in_channel * 2), out_channels=in_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_channel),
            nn.ReLU(),
        )

        self.convolutional_layer2 = nn.Sequential(
            nn.Conv2d(in_channels=in_channel, out_channels=int(in_channel / 2), kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(int(in_channel / 2)),
            nn.ReLU(),
        )

        self.convolutional_layer3 = nn.Conv2d(
            in_channels=int(in_channel / 2), out_channels=2, kernel_size=3, stride=1, padding=1
        )

        self.sigmoid = nn.Sigmoid()

    def forward(self, local_features, global_features):
        # concatenate features along the channel dimension
        features = torch.cat((local_features, global_features), dim=1)
        # pass through convolutional layers
        features = self.convolutional_layer1(features)
        features = self.convolutional_layer2(features)
        features = self.convolutional_layer3(features)
        # apply sigmoid to get a two-channel attention map
        attn = self.sigmoid(features)
        # construct hybrid features by adding the element-wise weighted features
        hybrid_features = local_features * attn[:, 0, :, :].unsqueeze(1) + global_features * attn[
            :, 1, :, :
        ].unsqueeze(1)

        return hybrid_features
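
# Fusion sketch: the two sigmoid channels act as per-pixel mixing weights,
#   hybrid = local * attn[:, 0:1] + global * attn[:, 1:2]
# letting the module favor the decoded or the skip-connection signal independently
# at every spatial location.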
(z"GLPNSelectiveFeatureFusion.forward)r   rF   r   r   r$   r   r     s   r   c                       s&   e Zd Z fddZdddZ  ZS )GLPNDecoderStagec                    sP   t    ||k}|s&tj||ddnt | _t|| _tjdddd| _	d S )Nr   )r5   r3   bilinearFZscale_factormodeZalign_corners)
r!   r"   r   r8   r   convolutionr   fusionUpsampleupsample)r#   r   r   should_skipr$   r   r   r"     s
    

zGLPNDecoderStage.__init__Nc                 C   s,   |  |}|d ur| ||}| |}|S r    )r   r   r   )r#   hidden_stateZresidualr   r   r   r(     s    

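
# The decoder below chains these stages top-down: the deepest (coarsest) encoder
# feature enters first without fusion (fusion is set to None for stage 0), and each
# later stage fuses the next-shallower encoder feature with the upsampled result.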
 zGLPNDecoderStage.forward)Nrl   r   r   r$   r   r     s   r   c                       s8   e Zd Z fddZeej eej dddZ  ZS )GLPNDecoderc                    s\   t    |jd d d }|j t fdd|D | _d | jd _tjdddd| _	d S )	NrY   c                    s   g | ]}t | qS r   )r   )r   r>   r   r   r   r   1  r   z(GLPNDecoder.__init__.<locals>.<listcomp>r   r3   r   Fr   )
r!   r"   r   decoder_hidden_sizer   r   stagesr   r   final_upsample)r#   rW   Zreserved_hidden_sizesr$   r   r   r"   *  s    
zGLPNDecoder.__init__r&   c                 C   sN   g }d }t |d d d | jD ]\}}|||}|| q| ||d< |S )NrY   )r   r   r   r   )r#   r'   Zstage_hidden_statesZstage_hidden_stater   Zstager   r   r   r(   8  s    
zGLPNDecoder.forward	r*   r+   r,   r"   listr   r/   r(   r1   r   r   r$   r   r   )  s   r   c                       s*   e Zd ZdZd fdd	Zdd Z  ZS )	SiLogLossz
    Implements the Scale-invariant log scale loss [Eigen et al., 2014](https://huggingface.co/papers/1406.2283).

    $$L=\sqrt{\frac{1}{n} \sum_{i} d_{i}^{2}-\frac{\lambda}{n^{2}}\left(\sum_{i} d_{i}\right)^{2}}$$ where
    $d_{i}=\log y_{i}-\log y_{i}^{*}$ and $\lambda$ is the variance-minimizing factor (0.5 by default).
    """

    def __init__(self, lambd=0.5):
        super().__init__()
        self.lambd = lambd

    def forward(self, pred, target):
        # only supervise pixels with a valid (positive) ground-truth depth
        valid_mask = (target > 0).detach()
        diff_log = torch.log(target[valid_mask]) - torch.log(pred[valid_mask])
        loss = torch.sqrt(torch.pow(diff_log, 2).mean() - self.lambd * torch.pow(diff_log.mean(), 2))

        return loss


class GLPNDepthEstimationHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        channels = config.decoder_hidden_size
        self.head = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=False),
            nn.Conv2d(channels, 1, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, hidden_states: list[torch.Tensor]) -> torch.Tensor:
        # use last features of the decoder
        hidden_states = hidden_states[self.config.head_in_index]

        hidden_states = self.head(hidden_states)

        # squash to (0, max_depth) and drop the channel dimension
        predicted_depth = torch.sigmoid(hidden_states) * self.config.max_depth
        predicted_depth = predicted_depth.squeeze(dim=1)

        return predicted_depth


@auto_docstring(
    custom_intro="""
    GLPN Model transformer with a lightweight depth estimation head on top e.g. for KITTI, NYUv2.
    )Zcustom_introc                
       s\   e Zd Z fddZedejeej ee ee ee e	e
ej ef dddZ  ZS )GLPNForDepthEstimationc                    s6   t  | t|| _t|| _t|| _|   d S r    )	r!   r"   r   r   r   decoderr   r   r   r   r$   r   r   r"   x  s
    


zGLPNForDepthEstimation.__init__N)rA   labelsre   r   r   r   c                 C   s   |dur|n| j j}|dur |n| j j}| j||d|d}|rD|jn|d }| |}| |}	d}
|dur|t }||	|}
|s|r|	f|dd  }n|	f|dd  }|
dur|
f| S |S t|
|	|r|jnd|j	dS )a  
        labels (`torch.FloatTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth depth estimation maps for computing the loss.

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, GLPNForDepthEstimation
        >>> import torch
        >>> import numpy as np
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("vinvino02/glpn-kitti")
        >>> model = GLPNForDepthEstimation.from_pretrained("vinvino02/glpn-kitti")

        >>> # prepare image for the model
        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> # interpolate to original size
        >>> post_processed_output = image_processor.post_process_depth_estimation(
        ...     outputs,
        ...     target_sizes=[(image.height, image.width)],
        ... )

        >>> # visualize the prediction
        >>> predicted_depth = post_processed_output[0]["predicted_depth"]
        >>> depth = predicted_depth * 255 / predicted_depth.max()
        >>> depth = depth.detach().cpu().numpy()
        >>> depth = Image.fromarray(depth.astype("uint8"))
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs = self.glpn(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=True,  # we need the intermediate hidden states
            return_dict=return_dict,
        )

        hidden_states = outputs.hidden_states if return_dict else outputs[1]

        out = self.decoder(hidden_states)
        predicted_depth = self.head(out)

        loss = None
        if labels is not None:
            loss_fct = SiLogLoss()
            loss = loss_fct(predicted_depth, labels)

        if not return_dict:
            if output_hidden_states:
                output = (predicted_depth,) + outputs[1:]
            else:
                output = (predicted_depth,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return DepthEstimatorOutput(
            loss=loss,
            predicted_depth=predicted_depth,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )


__all__ = ["GLPNForDepthEstimation", "GLPNLayer", "GLPNModel", "GLPNPreTrainedModel"]