"""PyTorch YOLOS model."""

import collections.abc
from dataclasses import dataclass
from typing import Callable, Optional, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging
from ...utils.generic import can_return_tuple, check_model_inputs
from .configuration_yolos import YolosConfig


logger = logging.get_logger(__name__)


@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`YolosForObjectDetection`].
    """
)
class YolosObjectDetectionOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
        Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
        bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
        scale-invariant IoU loss.
    loss_dict (`Dict`, *optional*):
        A dictionary containing the individual losses. Useful for logging.
    logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
        Classification logits (including no-object) for all queries.
    pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Normalized box coordinates for all queries, represented as (center_x, center_y, width, height). These
        values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
        possible padding). You can use [`~YolosImageProcessor.post_process`] to retrieve the unnormalized bounding
        boxes.
    auxiliary_outputs (`list[Dict]`, *optional*):
        Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
        and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
        `pred_boxes`) for each decoder layer.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Sequence of hidden-states at the output of the last layer of the decoder of the model.
    Nloss	loss_dictlogits
pred_boxesauxiliary_outputslast_hidden_statehidden_states
attentions)__name__
__module____qualname____doc__r   r   torchFloatTensor__annotations__r   dictr   r   r   listr   r    tupler!    r,   r,   d/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/yolos/modeling_yolos.pyr   '   s   
r   c                       s<   e Zd ZdZedd fddZejejdddZ  Z	S )	YolosEmbeddingszT
    Construct the CLS token, detection tokens, position and patch embeddings.

    Nconfigreturnc                    s   t    ttdd|j| _ttd|j|j| _	t
|| _| jj}ttd||j d |j| _t|j| _t|| _|| _d S Nr   )super__init__r   	Parameterr&   zeroshidden_size	cls_tokennum_detection_tokensdetection_tokensYolosPatchEmbeddingspatch_embeddingsnum_patchesposition_embeddingsDropouthidden_dropout_probdropout$InterpolateInitialPositionEmbeddingsinterpolationr0   )selfr0   r=   	__class__r,   r-   r4   T   s    


zYolosEmbeddings.__init__pixel_valuesr1   c                 C   s   |j \}}}}| |}| \}}}| j|dd}	| j|dd}
tj|	||
fdd}| | j	||f}|| }| 
|}|S )Nr   dim)shaper<   sizer8   expandr:   r&   catrC   r>   rA   )rD   rH   
batch_sizenum_channelsheightwidth
embeddingsseq_len_Z
cls_tokensr:   r>   r,   r,   r-   forwardc   s    

zYolosEmbeddings.forward
r"   r#   r$   r%   r   r4   r&   TensorrW   __classcell__r,   r,   rE   r-   r.   N   s   r.   c                       s4   e Zd Zdd fddZdejdddZ  ZS )	rB   Nr1   c                    s   t    || _d S Nr3   r4   r0   rD   r0   rE   r,   r-   r4   y   s    


class InterpolateInitialPositionEmbeddings(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.config = config

    def forward(self, pos_embed, img_size=(800, 1344)) -> torch.Tensor:
        cls_pos_embed = pos_embed[:, 0, :]
        cls_pos_embed = cls_pos_embed[:, None]
        det_pos_embed = pos_embed[:, -self.config.num_detection_tokens :, :]
        patch_pos_embed = pos_embed[:, 1 : -self.config.num_detection_tokens, :]
        patch_pos_embed = patch_pos_embed.transpose(1, 2)
        batch_size, hidden_size, seq_len = patch_pos_embed.shape

        patch_height, patch_width = (
            self.config.image_size[0] // self.config.patch_size,
            self.config.image_size[1] // self.config.patch_size,
        )
        patch_pos_embed = patch_pos_embed.view(batch_size, hidden_size, patch_height, patch_width)

        height, width = img_size
        new_patch_height, new_patch_width = height // self.config.patch_size, width // self.config.patch_size
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed, size=(new_patch_height, new_patch_width), mode="bicubic", align_corners=False
        )
        patch_pos_embed = patch_pos_embed.flatten(2).transpose(1, 2)
        scale_pos_embed = torch.cat((cls_pos_embed, patch_pos_embed, det_pos_embed), dim=1)
        return scale_pos_embed


class InterpolateMidPositionEmbeddings(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.config = config

    def forward(self, pos_embed, img_size=(800, 1344)) -> torch.Tensor:
        cls_pos_embed = pos_embed[:, :, 0, :]
        cls_pos_embed = cls_pos_embed[:, None]
        det_pos_embed = pos_embed[:, :, -self.config.num_detection_tokens :, :]
        patch_pos_embed = pos_embed[:, :, 1 : -self.config.num_detection_tokens, :]
        patch_pos_embed = patch_pos_embed.transpose(2, 3)
        depth, batch_size, hidden_size, seq_len = patch_pos_embed.shape

        patch_height, patch_width = (
            self.config.image_size[0] // self.config.patch_size,
            self.config.image_size[1] // self.config.patch_size,
        )
        patch_pos_embed = patch_pos_embed.view(depth * batch_size, hidden_size, patch_height, patch_width)
        height, width = img_size
        new_patch_height, new_patch_width = height // self.config.patch_size, width // self.config.patch_size
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed, size=(new_patch_height, new_patch_width), mode="bicubic", align_corners=False
        )
        patch_pos_embed = (
            patch_pos_embed.flatten(2)
            .transpose(1, 2)
            .contiguous()
            .view(depth, batch_size, new_patch_height * new_patch_width, hidden_size)
        )
        scale_pos_embed = torch.cat((cls_pos_embed, patch_pos_embed, det_pos_embed), dim=2)
        return scale_pos_embed


class YolosPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return embeddings


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    # Take the dot product between "query" and "key" to get the raw attention scores, then scale.
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling

    # Normalize the attention scores to probabilities (in float32 for numerical stability).
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    # Mask heads if we want to
    if attention_mask is not None:
        attn_weights = attn_weights * attention_mask

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
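

# Here `attention_mask` carries the optional per-head mask rather than a padding mask,
# which is why it is multiplied into the attention probabilities after the softmax
# instead of being added to the logits beforehand.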


class YolosSelfAttention(nn.Module):
    def __init__(self, config: YolosConfig) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.config = config
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.dropout_prob = config.attention_probs_dropout_prob
        self.scaling = self.attention_head_size**-0.5
        self.is_causal = False

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

    def forward(
        self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size = hidden_states.shape[0]
        new_shape = (batch_size, -1, self.num_attention_heads, self.attention_head_size)
        key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2)
        value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2)
        query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        context_layer, attention_probs = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            head_mask,
            is_causal=self.is_causal,
            scaling=self.scaling,
            dropout=0.0 if not self.training else self.dropout_prob,
        )

        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.reshape(new_context_layer_shape)

        return context_layer, attention_probs


class YolosSelfOutput(nn.Module):
    """
    The residual connection is defined in YolosLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: YolosConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states
ej	 ej	dd	d
Z  ZS )YolosAttentionr   c                    s*   t    t|| _t|| _t | _d S r\   )r3   r4   r   	attentionr   outputsetpruned_headsr^   rE   r,   r-   r4   ?  s    


zYolosAttention.__init__)headsc                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   rJ   )lenr   r   r   r   r   r   r   r   r   r   r   r   union)rD   r   indexr,   r,   r-   prune_headsE  s    zYolosAttention.prune_headsNr   c                 C   s    |  ||\}}| ||}|S r\   )r   r   )rD   r    r   Zself_attn_outputrV   r   r,   r,   r-   rW   W  s    zYolosAttention.forward)N)r"   r#   r$   r   r4   r   r   r   r&   rY   r   rW   rZ   r,   r,   rE   r-   r   >  s   r   c                       s6   e Zd Zed fddZejejdddZ  ZS )YolosIntermediater   c                    sB   t    t|j|j| _t|jt	r6t
|j | _n|j| _d S r\   )r3   r4   r   r   r7   intermediate_sizer   ry   Z
hidden_actstrr   intermediate_act_fnr^   rE   r,   r-   r4   _  s
    
zYolosIntermediate.__init__r    r1   c                 C   s   |  |}| |}|S r\   )r   r   )rD   r    r,   r,   r-   rW   g  s    

zYolosIntermediate.forward	r"   r#   r$   r   r4   r&   rY   rW   rZ   r,   r,   rE   r-   r   ^  s   r   c                       s:   e Zd Zed fddZejejejdddZ  ZS )YolosOutputr   c                    s.   t    t|j|j| _t|j| _	d S r\   )


class YolosOutput(nn.Module):
    def __init__(self, config: YolosConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = hidden_states + input_tensor
        return hidden_states


class YolosLayer(GradientCheckpointingLayer):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: YolosConfig) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = YolosAttention(config)
        self.intermediate = YolosIntermediate(config)
        self.output = YolosOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # in YOLOS, layernorm is applied before self-attention
        hidden_states_norm = self.layernorm_before(hidden_states)
        attention_output = self.attention(hidden_states_norm, head_mask)

        # first residual connection
        hidden_states = attention_output + hidden_states

        # in YOLOS, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)

        # second residual connection is done here
        layer_output = self.output(layer_output, hidden_states)

        return layer_output
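

# Each `YolosLayer` is a pre-norm transformer block: layernorm -> self-attention -> residual
# add, then layernorm -> intermediate MLP -> residual add (the second residual is applied
# inside `YolosOutput`).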
  ZS )	YolosEncoderNr/   c                    s   t     | _t fddt jD | _d| _d j	d  j	d   j
d    j } jrtt jd d| jnd | _ jrt nd | _d S )Nc                    s   g | ]}t  qS r,   )r   ).0rV   r   r,   r-   
<listcomp>      z)YolosEncoder.__init__.<locals>.<listcomp>Fr   r   r`   )r3   r4   r0   r   
ModuleListrangenum_hidden_layerslayerZgradient_checkpointingre   rf   r9   use_mid_position_embeddingsr5   r&   r6   r7   mid_position_embeddingsrv   rC   )rD   r0   Z
seq_lengthrE   r   r-   r4     s$    
 &	zYolosEncoder.__init__)r    rR   rS   r   r1   c           	      C   sz   | j jr| | j||f}t| jD ]J\}}|d ur<|| nd }|||}| j jr$|| j jd k r$|||  }q$t|dS )Nr   )r   )r0   r   rC   r   	enumerater   r   r
   )	rD   r    rR   rS   r   Z$interpolated_mid_position_embeddingsiZlayer_moduleZlayer_head_maskr,   r,   r-   rW     s    
zYolosEncoder.forward)N)r"   r#   r$   r   r4   r&   rY   r   r   r
   rW   rZ   r,   r,   rE   r-   r     s    r   c                   @   s`   e Zd ZU eed< dZdZdZg ZdZ	dZ
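

# When `config.use_mid_position_embeddings` is set, the encoder keeps one learnable position
# embedding per intermediate layer (num_hidden_layers - 1 of them) and adds the interpolated
# embedding for layer i to the hidden states right after block i.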
dZdZeedZeejejejf dddd	ZdS )
YolosPreTrainedModelr0   vitrH   T)r    r!   N)r   r1   c                 C   sj   t |tjtjfr@|jjjd| jjd |j	durf|j	j
  n&t |tjrf|j	j
  |jjd dS )zInitialize the weightsr   )meanZstdNg      ?)ry   r   r   r}   weightdataZnormal_r0   Zinitializer_ranger   Zzero_r   Zfill_)rD   r   r,   r,   r-   _init_weights  s    
z"YolosPreTrainedModel._init_weights)r"   r#   r$   r   r(   Zbase_model_prefixZmain_input_nameZsupports_gradient_checkpointingZ_no_split_modulesZ_supports_sdpaZ_supports_flash_attnZ_supports_flex_attnZ_supports_attention_backendr   r   Z_can_record_outputsr   r   r   r}   r   r   r,   r,   r,   r-   r     s   
r   c                       s~   e Zd Zdeed fddZedddZee	e


@auto_docstring
class YolosModel(YolosPreTrainedModel):
    def __init__(self, config: YolosConfig, add_pooling_layer: bool = True):
        r"""
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        """
        super().__init__(config)
        self.config = config

        self.embeddings = YolosEmbeddings(config)
        self.encoder = YolosEncoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pooler = YolosPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> YolosPatchEmbeddings:
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune: dict[int, list[int]]) -> None:
        """
        Prunes heads of the model.

        Args:
            heads_to_prune (`dict`):
                See base class `PreTrainedModel`. The input dictionary must have the following format: {layer_num:
                list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPooling:
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed: 1.0 in head_mask indicates we keep the head
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(pixel_values)
        height, width = pixel_values.shape[-2:]

        encoder_outputs: BaseModelOutput = self.encoder(
            embedding_output,
            height=height,
            width=width,
            head_mask=head_mask,
        )
        sequence_output = encoder_outputs.last_hidden_state
        sequence_output = self.layernorm(sequence_output)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
        )


class YolosPooler(nn.Module):
    def __init__(self, config: YolosConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class YolosMLPPredictionHead(nn.Module):
    """
    Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
    height and width of a bounding box w.r.t. an image.

    Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py

    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
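

# For illustration only (the 768 below is a hypothetical hidden size): a 3-layer head that
# maps detection-token embeddings to 4 box coordinates would be built as
#   bbox_head = YolosMLPPredictionHead(input_dim=768, hidden_dim=768, output_dim=4, num_layers=3)
# and, applied to a (batch_size, num_detection_tokens, 768) tensor, it returns a
# (batch_size, num_detection_tokens, 4) tensor of raw box values (a sigmoid is applied by
# the caller to normalize them).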


@auto_docstring(
    custom_intro="""
    YOLOS Model (consisting of a ViT encoder) with object detection heads on top, for tasks such as COCO detection.
    """
)
class YolosForObjectDetection(YolosPreTrainedModel):
    def __init__(self, config: YolosConfig):
        super().__init__(config)

        # YOLOS (ViT) encoder model
        self.vit = YolosModel(config, add_pooling_layer=False)

        # Object detection heads
        # We add one output class for the "no object" category
        self.class_labels_classifier = YolosMLPPredictionHead(
            input_dim=config.hidden_size,
            hidden_dim=config.hidden_size,
            output_dim=config.num_labels + 1,
            num_layers=3,
        )
        self.bbox_predictor = YolosMLPPredictionHead(
            input_dim=config.hidden_size, hidden_dim=config.hidden_size, output_dim=4, num_layers=3
        )

        # Initialize weights and apply final processing
        self.post_init()

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionaries with non-homogeneous values, such
        # as a dict having both a Tensor and a list
        return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: Optional[list[dict]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> YolosObjectDetectionOutput:
        r"""
        labels (`list[Dict]` of len `(batch_size,)`, *optional*):
            Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
            following 2 keys: `'class_labels'` and `'boxes'` (the class labels and bounding boxes of an image in the
            batch respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding
            boxes in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image,
            4)`.

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoModelForObjectDetection
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("hustvl/yolos-tiny")
        >>> model = AutoModelForObjectDetection.from_pretrained("hustvl/yolos-tiny")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
        >>> target_sizes = torch.tensor([image.size[::-1]])
        >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[
        ...     0
        ... ]

        >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        ...     box = [round(i, 2) for i in box.tolist()]
        ...     print(
        ...         f"Detected {model.config.id2label[label.item()]} with confidence "
        ...         f"{round(score.item(), 3)} at location {box}"
        ...     )
        Detected remote with confidence 0.991 at location [46.48, 72.78, 178.98, 119.3]
        Detected remote with confidence 0.908 at location [336.48, 79.27, 368.23, 192.36]
        Detected cat with confidence 0.934 at location [337.18, 18.06, 638.14, 373.09]
        Detected cat with confidence 0.979 at location [10.93, 53.74, 313.41, 470.67]
        Detected remote with confidence 0.974 at location [41.63, 72.23, 178.09, 119.99]
        ```"""
        # First, send pixel values through the YOLOS (ViT) base model to obtain hidden states
        outputs: BaseModelOutputWithPooling = self.vit(pixel_values, **kwargs)
        sequence_output = outputs.last_hidden_state

        # Take the final hidden states of the detection tokens
        sequence_output = sequence_output[:, -self.config.num_detection_tokens :, :]

        # Class logits + predicted bounding boxes
        logits = self.class_labels_classifier(sequence_output)
        pred_boxes = self.bbox_predictor(sequence_output).sigmoid()

        loss, loss_dict, auxiliary_outputs = None, None, None
        if labels is not None:
            outputs_class, outputs_coord = None, None
            if self.config.auxiliary_loss:
                hidden_states = outputs.hidden_states
                outputs_class = self.class_labels_classifier(hidden_states)
                outputs_coord = self.bbox_predictor(hidden_states).sigmoid()
            loss, loss_dict, auxiliary_outputs = self.loss_function(
                logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
            )

        return YolosObjectDetectionOutput(
            loss=loss,
            loss_dict=loss_dict,
            logits=logits,
            pred_boxes=pred_boxes,
            auxiliary_outputs=auxiliary_outputs,
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = ["YolosForObjectDetection", "YolosModel", "YolosPreTrainedModel"]