a
    hW4                    @   s  d Z ddlmZ ddlmZ ddlmZmZmZ ddl	Z	ddl
Z	ddl	mZmZ ddlmZ dd	lmZmZ dd
lmZ ddlmZmZ ddlmZ ddlmZmZmZmZmZ ddlm Z m!Z!m"Z" e rddl#m$Z$ e%e&Z'e	je	jdddZ(e	je	jdddZ)eeG dd deZ*eedddZ+eedddZ,dd  Z-d!d" Z.eed#d$G d%d& d&eZ/eed'd$G d(d) d)eZ0G d*d+ d+ej1Z2G d,d- d-ej1Z3G d.d/ d/ej1Z4G d0d1 d1ej1Z5G d2d3 d3eZ6eG d4d5 d5eZ7G d6d7 d7ej1Z8G d8d9 d9ej1Z9G d:d; d;e7Z:G d<d= d=ej1Z;G d>d? d?e7Z<eG d@dA dAe7Z=G dBdC dCej1Z>G dDdE dEej1Z?G dFdG dGe7Z@g dHZAdS )IzPyTorch OWLv2 model.    )	dataclass)	lru_cache)AnyOptionalUnionN)Tensornn   )ACT2FN) _create_4d_causal_attention_mask_prepare_4d_attention_mask)GradientCheckpointingLayer)BaseModelOutputBaseModelOutputWithPooling)PreTrainedModel)ModelOutputauto_docstringis_vision_availablelogging	torch_int   )Owlv2ConfigOwlv2TextConfigOwlv2VisionConfig)center_to_corners_format)logitsreturnc                 C   s   t j| tjt| | jdS )Ndevice)r   
functionalZcross_entropytorcharangelenr   )r    r#   d/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/owlv2/modeling_owlv2.pycontrastive_loss-   s    r%   )
similarityr   c                 C   s    t | }t |  }|| d S )Ng       @)r%   t)r&   Zcaption_lossZ
image_lossr#   r#   r$   
owlv2_loss2   s    r(   c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZeej ed< dZeej ed< dZeed< dZeed	< ee d
ddZdS )Owlv2Outputa  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
        Contrastive loss for image-text similarity.
    logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
        The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
        similarity scores.
    logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
        The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
        similarity scores.
    text_embeds (`torch.FloatTensor` of shape `(batch_size * num_max_text_queries, output_dim`):
        The text embeddings obtained by applying the projection layer to the pooled output of [`Owlv2TextModel`].
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
        The image embeddings obtained by applying the projection layer to the pooled output of
        [`Owlv2VisionModel`].
    text_model_output (tuple[`BaseModelOutputWithPooling`]):
        The output of the [`Owlv2TextModel`].
    vision_model_output (`BaseModelOutputWithPooling`):
        The output of the [`Owlv2VisionModel`].
    Nlosslogits_per_imagelogits_per_texttext_embedsimage_embedstext_model_outputvision_model_outputr   c                    s   t  fdd  D S )Nc                 3   s,   | ]$}|d vr | nt  | V  qdS )r/   r0   Ngetattrto_tuple.0kselfr#   r$   	<genexpr>X   s   z'Owlv2Output.to_tuple.<locals>.<genexpr>tuplekeysr9   r#   r9   r$   r5   W   s    zOwlv2Output.to_tuple)__name__
__module____qualname____doc__r*   r   r    FloatTensor__annotations__r+   r,   r-   r.   r/   r   r0   r=   r   r5   r#   r#   r#   r$   r)   8   s   
r)   )r'   r   c                 C   sH   |   r&| jtjtjfv r| S |  S | jtjtjfv r<| S |  S d S N)	Zis_floating_pointdtyper    float32Zfloat64floatZint32Zint64int)r'   r#   r#   r$   _upcast_   s    rJ   )boxesr   c                 C   sH   t | } | dddf | dddf  | dddf | dddf   S )a  
    Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.

    Args:
        boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
            Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
            < x2` and `0 <= y1 < y2`.

    Returns:
        `torch.FloatTensor`: a tensor containing the area for each box.
    N   r   r	   r   )rJ   )rK   r#   r#   r$   box_areah   s    rM   c           
      C   s   t | }t |}t| d d d d df |d d d df }t| d d d dd f |d d dd f }|| jdd}|d d d d df |d d d d df  }|d d d f | | }|| }	|	|fS )NrL   r   minr   )rM   r    maxrO   clamp)
boxes1boxes2Zarea1Zarea2Zleft_topZright_bottomwidth_heightinterunioniour#   r#   r$   box_iouy   s    ..,rX   c                 C   s*  | ddddf | ddddf k  s:td|  |ddddf |ddddf k  sttd| t| |\}}t| dddddf |ddddf }t| dddddf |ddddf }|| jdd}|dddddf |dddddf  }||| |  S )z
    Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.

    Returns:
        `torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
    NrL   z<boxes1 must be in [x0, y0, x1, y1] (corner) format, but got z<boxes2 must be in [x0, y0, x1, y1] (corner) format, but got r   rN   r   )all
ValueErrorrX   r    rO   rP   rQ   )rR   rS   rW   rV   top_leftbottom_rightrT   Zarear#   r#   r$   generalized_box_iou   s    	,,..,r]   z5
    Output type of [`Owlv2ForObjectDetection`].
    )Zcustom_introc                   @   s   e Zd ZU dZdZeej ed< dZ	ee
 ed< dZeej ed< dZeej ed< dZeej ed< dZeej ed< dZeej ed	< dZeej ed
< dZeed< dZeed< ee dddZdS )Owlv2ObjectDetectionOutputa  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)):
        Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a
        bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
        scale-invariant IoU loss.
    loss_dict (`Dict`, *optional*):
        A dictionary containing the individual losses. Useful for logging.
    logits (`torch.FloatTensor` of shape `(batch_size, num_patches, num_queries)`):
        Classification logits (including no-object) for all queries.
    objectness_logits (`torch.FloatTensor` of shape `(batch_size, num_patches, 1)`):
        The objectness logits of all image patches. OWL-ViT represents images as a set of image patches where the
        total number of patches is (image_size / patch_size)**2.
    pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
        Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
        values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
        possible padding). You can use [`~Owlv2ImageProcessor.post_process_object_detection`] to retrieve the
        unnormalized bounding boxes.
    text_embeds (`torch.FloatTensor` of shape `(batch_size, num_max_text_queries, output_dim`):
        The text embeddings obtained by applying the projection layer to the pooled output of [`Owlv2TextModel`].
    image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`):
        Pooled output of [`Owlv2VisionModel`]. OWLv2 represents images as a set of image patches and computes image
        embeddings for each patch.
    class_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`):
        Class embeddings of all image patches. OWLv2 represents images as a set of image patches where the total
        number of patches is (image_size / patch_size)**2.
    text_model_output (tuple[`BaseModelOutputWithPooling`]):
        The output of the [`Owlv2TextModel`].
    vision_model_output (`BaseModelOutputWithPooling`):
        The output of the [`Owlv2VisionModel`].
    Nr*   	loss_dictr   objectness_logits
pred_boxesr-   r.   class_embedsr/   r0   r1   c                    s   t  fdd  D S )Nc                 3   s,   | ]$}|d vr | nt  | V  qdS r2   r3   r6   r9   r#   r$   r;      s   z6Owlv2ObjectDetectionOutput.to_tuple.<locals>.<genexpr>r<   r9   r#   r9   r$   r5      s    z#Owlv2ObjectDetectionOutput.to_tuple)r?   r@   rA   rB   r*   r   r    rC   rD   r_   dictr   r`   ra   r-   r.   rb   r/   r   r0   r=   r   r5   r#   r#   r#   r$   r^      s   
r^   zL
    Output type of [`Owlv2ForObjectDetection.image_guided_detection`].
    c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZeej ed< dZeej ed< dZeej ed< dZeed	< dZeed
< ee dddZdS )%Owlv2ImageGuidedObjectDetectionOutputa  
    logits (`torch.FloatTensor` of shape `(batch_size, num_patches, num_queries)`):
        Classification logits (including no-object) for all queries.
    image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`):
        Pooled output of [`Owlv2VisionModel`]. OWLv2 represents images as a set of image patches and computes
        image embeddings for each patch.
    query_image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`):
        Pooled output of [`Owlv2VisionModel`]. OWLv2 represents images as a set of image patches and computes
        image embeddings for each patch.
    target_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
        Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
        values are normalized in [0, 1], relative to the size of each individual target image in the batch
        (disregarding possible padding). You can use [`~Owlv2ImageProcessor.post_process_object_detection`] to
        retrieve the unnormalized bounding boxes.
    query_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
        Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
        values are normalized in [0, 1], relative to the size of each individual query image in the batch
        (disregarding possible padding). You can use [`~Owlv2ImageProcessor.post_process_object_detection`] to
        retrieve the unnormalized bounding boxes.
    class_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`):
        Class embeddings of all image patches. OWLv2 represents images as a set of image patches where the total
        number of patches is (image_size / patch_size)**2.
    text_model_output (tuple[`BaseModelOutputWithPooling`]):
        The output of the [`Owlv2TextModel`].
    vision_model_output (`BaseModelOutputWithPooling`):
        The output of the [`Owlv2VisionModel`].
    Nr   r.   query_image_embedstarget_pred_boxesquery_pred_boxesrb   r/   r0   r1   c                    s   t  fdd  D S )Nc                 3   s,   | ]$}|d vr | nt  | V  qdS r2   r3   r6   r9   r#   r$   r;     s   zAOwlv2ImageGuidedObjectDetectionOutput.to_tuple.<locals>.<genexpr>r<   r9   r#   r9   r$   r5     s    z.Owlv2ImageGuidedObjectDetectionOutput.to_tuple)r?   r@   rA   rB   r   r   r    rC   rD   r.   re   rf   rg   rb   r/   r   r0   r=   r   r5   r#   r#   r#   r$   rd      s   
rd   c                       sR   e Zd Zed fddZejeeejdddZdej	e
ejdd	d
Z  ZS )Owlv2VisionEmbeddingsconfigc                    s   t    |j| _|| _|j| _tt	|j| _
tj|j| j|j|jdd| _|j|j d | _| jd | _t| j| j| _| jdt| jddd d S )NF)Zin_channelsZout_channelsZkernel_sizeZstridebiasrL   r   position_idsr   
persistent)super__init__
patch_sizerj   hidden_size	embed_dimr   	Parameterr    Zrandnclass_embeddingZConv2dZnum_channelspatch_embedding
image_sizenum_patchesnum_positions	Embeddingposition_embeddingregister_bufferr!   expandr:   rj   	__class__r#   r$   rr     s     
zOwlv2VisionEmbeddings.__init__)
embeddingsheightwidthr   c                 C   s  |j d d }| jjd}|j d d }tj sP||krP||krP| | jS |ddddf }|ddddf }|j d }	|| j }
|| j }t	|d }|
d|||	}|dddd}tjj||
|fdd	d
}|dddddd|	}tj||fddS )a   
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        r   r   Nrn   g      ?r	   rL   ZbicubicF)sizemodeZalign_cornersdim)shaper}   weight	unsqueezer    Zjit
is_tracingrl   rs   r   reshapeZpermuter   r   Zinterpolateviewcat)r:   r   r   r   rz   r}   r{   Zclass_pos_embedZpatch_pos_embedr   Z
new_heightZ	new_widthZsqrt_num_positionsr#   r#   r$   interpolate_pos_encoding%  s*    



z.Owlv2VisionEmbeddings.interpolate_pos_encodingF)pixel_valuesr   r   c           
      C   sx   |j \}}}}| |}|ddd}| j|dd}tj||gdd}	|rd|	| |	|| }	n|	| 	| j
 }	|	S )NrL   r   rn   r   )r   rx   flatten	transposerw   r   r    r   r   r}   rl   )
r:   r   r   
batch_size_r   r   Zpatch_embedsrb   r   r#   r#   r$   forwardN  s    
zOwlv2VisionEmbeddings.forward)F)r?   r@   rA   r   rr   r    r   rI   r   rC   boolr   __classcell__r#   r#   r   r$   rh     s   )rh   c                       sL   e Zd Zed fddZdeej eej eej ej	dddZ
  ZS )	Owlv2TextEmbeddingsri   c                    sP   t    t|j|j| _t|j|j| _| j	dt
|jddd d S )Nrl   rm   Fro   )rq   rr   r   r|   Z
vocab_sizert   token_embeddingZmax_position_embeddingsr}   r~   r    r!   r   r   r   r#   r$   rr   ^  s    
zOwlv2TextEmbeddings.__init__N)	input_idsrl   inputs_embedsr   c                 C   sb   |d ur|j d n|j d }|d u r:| jd d d |f }|d u rL| |}| |}|| }|S )Nrn   )r   rl   r   r}   )r:   r   rl   r   Z
seq_lengthZposition_embeddingsr   r#   r#   r$   r   h  s    

zOwlv2TextEmbeddings.forward)NNN)r?   r@   rA   r   rr   r   r    
LongTensorrC   r   r   r   r#   r#   r   r$   r   ]  s      r   c                       sz   e Zd ZdZ fddZejeedddZdeje	ej e	ej e	e
 eeje	ej e	eej  f d	d
dZ  ZS )Owlv2Attentionz=Multi-headed attention from 'Attention Is All You Need' paperc                    s   t    || _|j| _|j| _| j| j | _| j| j | jkrZtd| j d| j d| jd | _	|j
| _t| j| j| _t| j| j| _t| j| j| _t| j| j| _d S )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).      )rq   rr   rj   rt   ru   Znum_attention_heads	num_headshead_dimrZ   scaleZattention_dropoutdropoutr   Lineark_projv_projq_projout_projr   r   r#   r$   rr     s"    

zOwlv2Attention.__init__)tensorseq_lenbszc                 C   s    | ||| j| jdd S )Nr   rL   )r   r   r   r   
contiguous)r:   r   r   r   r#   r#   r$   _shape  s    zOwlv2Attention._shapeNFhidden_statesattention_maskcausal_attention_maskoutput_attentionsr   c                 C   s  |  \}}}| || j }| | |d|}	| | |d|}
|| j d| jf}| |||j| }|	j| }	|
j| }
|	 d}t	
||	dd}|  || j ||fkrtd|| j ||f d|   |durD|  |d||fkrtd|d||f d|   ||| j||| }||| j ||}|dur|  |d||fkrtd|d||f d|   ||| j||| }||| j ||}tjj|dd}|r||| j||}||| j ||}nd}tjj|| j| jd	}||
j}t	
||
}|  || j || jfkr^td
|| j|| jf d|   ||| j|| j}|dd}||||}| |}||fS )z#Input shape: Batch x Time x Channelrn   r   rL   z$Attention weights should be of size z	, but is Nz!Attention mask should be of size r   )ptrainingz `attn_output` should be of size )r   r   r   r   r   r   r   r   r   r    Zbmmr   rZ   r   r   Zsoftmaxr   r   torF   r   r   )r:   r   r   r   r   r   Ztgt_lenru   Zquery_statesZ
key_statesZvalue_statesZ
proj_shapeZsrc_lenattn_weightsZattn_weights_reshapedZ
attn_probsZattn_outputr#   r#   r$   r     sf    	





zOwlv2Attention.forward)NNF)r?   r@   rA   rB   rr   r    r   rI   r   r   r   r=   r   r   r#   r#   r   r$   r   }  s      r   c                       s0   e Zd Z fddZejejdddZ  ZS )Owlv2MLPc                    sD   t    || _t|j | _t|j|j	| _
t|j	|j| _d S rE   )rq   rr   rj   r
   Z
hidden_actactivation_fnr   r   rt   Zintermediate_sizefc1fc2r   r   r#   r$   rr     s
    
zOwlv2MLP.__init__)r   r   c                 C   s"   |  |}| |}| |}|S rE   )r   r   r   )r:   r   r#   r#   r$   r     s    


zOwlv2MLP.forward)r?   r@   rA   rr   r    r   r   r   r#   r#   r   r$   r     s   r   c                       sJ   e Zd Zed fddZdejejejee e	ej
 dddZ  ZS )	Owlv2EncoderLayerri   c                    sR   t    |j| _t|| _tj| j|jd| _	t
|| _tj| j|jd| _d S Neps)rq   rr   rt   ru   r   	self_attnr   	LayerNormlayer_norm_epslayer_norm1r   mlplayer_norm2r   r   r#   r$   rr     s    


zOwlv2EncoderLayer.__init__Fr   c                 C   sd   |}|  |}| j||||d\}}|| }|}| |}| |}|| }|f}|r`||f7 }|S )aI  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
                `(config.encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        )r   r   r   r   )r   r   r   r   )r:   r   r   r   r   Zresidualr   outputsr#   r#   r$   r     s"    




zOwlv2EncoderLayer.forward)F)r?   r@   rA   r   rr   r    r   r   r   r=   rC   r   r   r#   r#   r   r$   r     s    r   c                   @   s4   e Zd ZU eed< dZdZdgZej	dddZ
dS )	Owlv2PreTrainedModelrj   owlv2Tr   )modulec                 C   s\  | j j}t|trF|jjjjd|d d |jjjjd|d d nt|t	rt
jj|jd|jd | d t
jj|jj|j j| d t
jj|jj|j j| d nNt|tr4|jd d|j j d  | }|jd | }t
jj|jj|d t
jj|jj|d t
jj|jj|d t
jj|jj|d nt|tr|j jd d|j j d  | }d|j j d | }t
jj|jj|d t
jj|jj|d nZt|trt
jj|jj|jd | d t
jj|jj|jd | d |jj | j j! t|t
j"r |j#j$  |jj d t|t
j%rX|jjjd|d |j#durX|j#j$  dS )	zInitialize the weights        g{Gz?)meanstdr   )r   rL         ?N)&rj   Zinitializer_factor
isinstancer   r   r   dataZnormal_r}   rh   r   initrw   ru   rx   Zinitializer_ranger   num_hidden_layersr   r   r   r   r   rt   r   r   
Owlv2Modeltext_projectiontext_embed_dimvisual_projectionvision_embed_dimlogit_scaleZfill_logit_scale_init_valuer   rk   Zzero_r   )r:   r   factorZin_proj_stdZout_proj_stdZfc_stdr#   r#   r$   _init_weights3  sJ    

  z"Owlv2PreTrainedModel._init_weightsN)r?   r@   rA   r   rD   Zbase_model_prefixZsupports_gradient_checkpointingZ_no_split_modulesr   Moduler   r#   r#   r#   r$   r   +  s
   
r   c                	       s`   e Zd ZdZed fddZd	eej eej ee	 ee	 ee	 e
eef dddZ  ZS )
Owlv2Encoderz
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`Owlv2EncoderLayer`].

    Args:
        config: Owlv2Config
    ri   c                    s4   t    t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r#   )r   )r7   r   ri   r#   r$   
<listcomp>h      z)Owlv2Encoder.__init__.<locals>.<listcomp>F)rq   rr   r   Z
ModuleListranger   layersZgradient_checkpointingr   r   ri   r$   rr   f  s    
 zOwlv2Encoder.__init__N)r   r   r   output_hidden_statesreturn_dictr   c                 C   s   |dur|n| j j}|dur |n| j j}|dur4|n| j j}|rDdnd}|rPdnd}|}	| jD ]<}
|rp||	f }|
|	|||d}|d }	|r^||d f }q^|r||	f }|stdd |	||fD S t|	||dS )	a  
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`).
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
                [What are attention masks?](../glossary#attention-mask)
            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Causal mask for the text model. Mask values selected in `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        Nr#   )r   r   r   c                 s   s   | ]}|d ur|V  qd S rE   r#   )r7   vr#   r#   r$   r;     r   z'Owlv2Encoder.forward.<locals>.<genexpr>)last_hidden_stater   
attentions)rj   r   r   use_return_dictr   r=   r   )r:   r   r   r   r   r   r   Zencoder_statesZall_attentionsr   Zencoder_layerZlayer_outputsr#   r#   r$   r   k  s4    


zOwlv2Encoder.forward)NNNNN)r?   r@   rA   rB   r   rr   r   r    r   r   r   r=   r   r   r   r#   r#   r   r$   r   ]  s        
r   c                       sd   e Zd Zed fddZedejeej eej ee	 ee	 ee	 e
eef dddZ  ZS )	Owlv2TextTransformerri   c                    s@   t    || _|j}t|| _t|| _tj	||j
d| _d S r   )rq   rr   rj   rt   r   r   r   encoderr   r   r   final_layer_norm)r:   rj   ru   r   r#   r$   rr     s    


zOwlv2TextTransformer.__init__N)r   r   rl   r   r   r   r   c                 C   s  |dur|n| j j}|dur |n| j j}|dur4|n| j j}| }|d|d }| j||d}t||j|j	d}	|durt
||j}| j|||	|||d}
|
d }| |}|tj|jd |j	d|tjjdd|j	f }|s||f|
dd  S t|||
j|
jd	S )
a|  
        input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
            IDs?](../glossary#input-ids)
        Nrn   )r   rl   r   )r   r   r   r   r   r   r   r   r   r   Zpooler_outputr   r   )rj   r   r   r   r   r   r   r   rF   r   r   r   r   r    r!   r   r   rI   Zargmaxr   r   r   )r:   r   r   rl   r   r   r   Zinput_shaper   r   encoder_outputsr   pooled_outputr#   r#   r$   r     sF    
	
zOwlv2TextTransformer.forward)NNNNN)r?   r@   rA   r   rr   r   r    r   r   r   r   r=   r   r   r   r#   r#   r   r$   r     s         
r   c                
       s~   e Zd ZU eed< ed fddZejdddZdd	 Z	e
dejeej ee ee ee eeef dddZ  ZS )Owlv2TextModelrj   ri   c                    s"   t  | t|| _|   d S rE   )rq   rr   r   
text_model	post_initr   r   r#   r$   rr     s    
zOwlv2TextModel.__init__r1   c                 C   s
   | j jjS rE   r   r   r   r9   r#   r#   r$   get_input_embeddings  s    z#Owlv2TextModel.get_input_embeddingsc                 C   s   || j j_d S rE   r   )r:   valuer#   r#   r$   set_input_embeddings  s    z#Owlv2TextModel.set_input_embeddingsNr   r   r   r   r   r   c                 C   s   | j |||||dS )a  
        input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
            IDs?](../glossary#input-ids)

        Examples:
        ```python
        >>> from transformers import AutoProcessor, Owlv2TextModel

        >>> model = Owlv2TextModel.from_pretrained("google/owlv2-base-patch16")
        >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16")
        >>> inputs = processor(
        ...     text=[["a photo of a cat", "a photo of a dog"], ["photo of a astranaut"]], return_tensors="pt"
        ... )
        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
        ```r   r   r   r   r   )r   )r:   r   r   r   r   r   r#   r#   r$   r   
  s    zOwlv2TextModel.forward)NNNN)r?   r@   rA   r   rD   rr   r   r   r   r   r   r    r   r   r   r   r=   r   r   r   r#   r#   r   r$   r     s"   
    
r   c                
       sZ   e Zd Zed fddZed	ejee	 ee	 ee	 ee	 e
eef dddZ  ZS )
Owlv2VisionTransformerri   c                    sP   t    || _t|| _tj|j|jd| _	t
|| _tj|j|jd| _d S r   )rq   rr   rj   rh   r   r   r   rt   r   pre_layernormr   r   post_layernormr   r   r#   r$   rr   3  s    


zOwlv2VisionTransformer.__init__NFr   r   r   r   r   r   c                 C   s   |d ur|n| j j}|d ur |n| j j}|d ur4|n| j j}| jjjj}||}| j||d}| 	|}| j
||||d}|d }	|	d d dd d f }
| |
}
|s|	|
f|dd   S t|	|
|j|jdS )N)r   )r   r   r   r   r   r   r   )rj   r   r   r   r   rx   r   rF   r   r   r   r   r   r   r   )r:   r   r   r   r   r   Zexpected_input_dtyper   r   r   r   r#   r#   r$   r   <  s2    	


zOwlv2VisionTransformer.forward)NNFN)r?   r@   rA   r   rr   r   r    rC   r   r   r   r=   r   r   r   r#   r#   r   r$   r   2  s   	    
r   c                
       sx   e Zd ZU eed< dZed fddZejdddZ	e
deej ee ee eee eeef dddZ  ZS )Owlv2VisionModelrj   r   ri   c                    s"   t  | t|| _|   d S rE   )rq   rr   r   vision_modelr   r   r   r#   r$   rr   n  s    
zOwlv2VisionModel.__init__r1   c                 C   s
   | j jjS rE   )r   r   rx   r9   r#   r#   r$   r   t  s    z%Owlv2VisionModel.get_input_embeddingsNFr   c                 C   s   | j |||||dS )a  
        Examples:
        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Owlv2VisionModel

        >>> model = Owlv2VisionModel.from_pretrained("google/owlv2-base-patch16")
        >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16")
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled CLS states
        ```r   r   r   r   r   )r   )r:   r   r   r   r   r   r#   r#   r$   r   w  s    zOwlv2VisionModel.forward)NNNFN)r?   r@   rA   r   rD   Zmain_input_namerr   r   r   r   r   r   r    rC   r   r   r=   r   r   r   r#   r#   r   r$   r   j  s$   
     
r   c                       s   e Zd ZU eed< ed fddZedeej	 eej	 ee
 ee
 ee
 ejdddZedeej ee
 ee
 e
ee
 ejd
ddZedeej eej eej	 ee
 ee
 ee
 e
ee
 ee
 eeef d
ddZ  ZS )r   rj   ri   c                    s   t  | t|jts.tdt|j dt|jtsPtdt|j d|j}|j}|j	| _	|j
| _|j
| _t|| _t|| _tj| j| j	dd| _tj| j| j	dd| _tt|j| _|   d S )NzLconfig.text_config is expected to be of type Owlv2TextConfig but is of type .zPconfig.vision_config is expected to be of type Owlv2VisionConfig but is of type F)rk   )rq   rr   r   text_configr   	TypeErrortypevision_configr   Zprojection_dimrt   r   r   r   r   r   r   r   r   r   r   rv   r    r   r   r   r   )r:   rj   r   r  r   r#   r$   rr     s0    

zOwlv2Model.__init__Nr   c           	      C   s:   |dur|n| j j}| j|||d}|d }| |}|S )aL  
        input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
            IDs?](../glossary#input-ids)

        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`Owlv2TextModel`].

        Examples:
        ```python
        >>> from transformers import AutoProcessor, Owlv2Model

        >>> model = Owlv2Model.from_pretrained("google/owlv2-base-patch16-ensemble")
        >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
        >>> inputs = processor(
        ...     text=[["a photo of a cat", "a photo of a dog"], ["photo of a astranaut"]], return_tensors="pt"
        ... )
        >>> text_features = model.get_text_features(**inputs)
        ```N)r   r   r   r   )rj   r   r   r   )	r:   r   r   r   r   r   Ztext_outputr   Ztext_featuresr#   r#   r$   get_text_features  s
    
zOwlv2Model.get_text_featuresFr   c           	      C   sf   |dur|n| j j}|dur |n| j j}|dur4|n| j j}| j|||||d}|d }| |}|S )aO  
        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`Owlv2VisionModel`].

        Examples:
        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Owlv2Model

        >>> model = Owlv2Model.from_pretrained("google/owlv2-base-patch16-ensemble")
        >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> inputs = processor(images=image, return_tensors="pt")
        >>> image_features = model.get_image_features(**inputs)
        ```Nr   r   )rj   r   r   r   r   r   )	r:   r   r   r   r   r   vision_outputsr   image_featuresr#   r#   r$   get_image_features  s    
zOwlv2Model.get_image_features)
r   r   r   return_lossr   r   r   return_base_image_embedsr   r   c
              	   C   s>  |dur|n| j j}|dur |n| j j}|	dur4|	n| j j}	| j|||||	d}
| j|||||	d}|d }| |}|
d }| |}|tj	j
|dddd }|tj	j
|dddd }| j |j}t|| | }| }d}|rt|}|}|	s(||||||
f}|dur$|f| S |S t|||||||
d	S )
a4  
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
        return_base_image_embeds (`bool`, *optional*):
            Whether or not to return the base image embeddings.

        Examples:
        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Owlv2Model

        >>> model = Owlv2Model.from_pretrained("google/owlv2-base-patch16-ensemble")
        >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
        >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
        ```Nr   r   r   rL   rn   T)ordr   keepdim)r*   r+   r,   r-   r.   r/   r0   )rj   r   r   r   r   r   r   r   r    linalgnormr   expr   r   matmulr'   r(   r)   )r:   r   r   r   r  r   r   r   r  r   r  text_outputsr-   r.   Ztext_embeds_normr   r,   r+   r*   outputr#   r#   r$   r     sV    #	

zOwlv2Model.forward)NNNNN)NNNFN)	NNNNNNFNN)r?   r@   rA   r   rD   rr   r   r   r    r   r   rC   r  r  r   r   r=   r)   r   r   r#   r#   r   r$   r     sh   
      '     .         
r   c                       s:   e Zd Zdeed fddZejejdddZ	  Z
S )	Owlv2BoxPredictionHead   )rj   out_dimc                    sJ   t    |jj}t||| _t||| _t | _	t||| _
d S rE   )rq   rr   r  rt   r   r   dense0dense1ZGELUgeludense2)r:   rj   r  r   r   r#   r$   rr   w  s    

zOwlv2BoxPredictionHead.__init__r  r   c                 C   s6   |  |}| |}| |}| |}| |}|S rE   )r  r  r  r  )r:   r  r  r#   r#   r$   r     s    




zOwlv2BoxPredictionHead.forward)r  )r?   r@   rA   r   rI   rr   r    r   rC   r   r   r#   r#   r   r$   r  v  s   	r  c                       sJ   e Zd Zed fddZejeej eej e	ej dddZ
  ZS )Owlv2ClassPredictionHeadri   c                    sZ   t    |jj}|jj| _t| j|| _t| jd| _	t| jd| _
t | _d S )Nr   )rq   rr   r   rt   r  	query_dimr   r   r  logit_shiftr   ZELUelu)r:   rj   r  r   r#   r$   rr     s    

z!Owlv2ClassPredictionHead.__init__)r.   query_embeds
query_maskr   c                 C   s  |  |}|d u rJ|j}|jd d \}}t||| jf|}||fS |tjj|dddd  }|tjj|dddd  }t	d||}| 
|}	| |}
| |
d }
||	 |
 }|d ur|jdkrtj|dd	}t|d
kt|jj|}|tj}||fS )NrL   rn   T)r   r
  gư>z...pd,...qd->...pqr   r   r   r   )r  r   r   r    Zzerosr  r   r  r  einsumr  r   r  ndimr   whereZfinforF   rO   rG   )r:   r.   r  r  image_class_embedsr   r   rz   pred_logitsr  r   r#   r#   r$   r     s&    




z Owlv2ClassPredictionHead.forward)r?   r@   rA   r   rr   r    rC   r   r   r=   r   r   r#   r#   r   r$   r    s   r  c                       s  e Zd ZU eed< ed fddZeeeej	dddZ
ejejdd	d
Zeddd'eeeej ej	dddZd(ejejeejdddZd)ejeej eej	 eej dddZd*ej	ejej	ee ee eeej dddZd+ejee ee eeej dddZd,ejejeejddd Zed-ejeej ee ee eee ed!d"d#Zed.ej	ejeej	 ee ee eee ed$d%d&Z  ZS )/Owlv2ForObjectDetectionrj   ri   c                    s   t  | t|| _t|| _t|| _t|dd| _t	j
|jj|jjd| _t	 | _|| _| jjj| jjj | _| jjj| jjj | _| | j| j| _|   d S )Nr   )r  r   )rq   rr   r   r   r  
class_headr  box_headobjectness_headr   r   r  rt   r   
layer_normZSigmoidsigmoidrj   ry   rs   num_patches_heightnum_patches_widthcompute_box_biasbox_biasr   r   r   r#   r$   rr     s    



z Owlv2ForObjectDetection.__init__)r*  r+  r   c                 C   s   t jd|d t jd}t jd| d t jd}t j||dd\}}t j||fdd}|d  |  < |d  |   < |dd	}|S )
Nr   )rF   Zxy)Zindexingrn   r   .r   .r   rL   )r    r!   rG   Zmeshgridstackr   )r*  r+  Zx_coordinatesZy_coordinatesxxyybox_coordinatesr#   r#   r$   !normalize_grid_corner_coordinates  s    z9Owlv2ForObjectDetection.normalize_grid_corner_coordinatesr  c                 C   s   |  }| |}|d }|S )a#  Predicts the probability that each image feature token is an object.

        Args:
            image_features (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_dim)`)):
                Features extracted from the image.
        Returns:
            Objectness scores.
        r.  )detachr'  )r:   r  r`   r#   r#   r$   objectness_predictor  s    	
z,Owlv2ForObjectDetection.objectness_predictorrL   )maxsizeN)r*  r+  feature_mapr   c           	      C   s   |d urt d| ||}t|dd}t|d t| d  }t|d}|d  |  < |d  |  < t|d t| d  }tj||gdd}|S )	NzOfeature_map has been deprecated as an input. Please pass in num_patches insteadr   r   g-C6?r.  r/  rn   r   )rZ   r4  r    Zcliploglog1pZ	full_liker   )	r:   r*  r+  r8  r3  Zbox_coord_biasZbox_sizeZbox_size_biasr-  r#   r#   r$   r,    s    z(Owlv2ForObjectDetection.compute_box_biasF)image_featsr8  r   r   c           	      C   sR   |  |}|r*|j\}}}}| ||}n| j}||j}||7 }| |}|S )a  
        Args:
            image_feats:
                Features extracted from the image, returned by the `image_text_embedder` method.
            feature_map:
                A spatial re-arrangement of image_features, also returned by the `image_text_embedder` method.
            interpolate_pos_encoding:
                Whether to interpolate the pre-trained position encodings.
        Returns:
            pred_boxes:
                List of predicted boxes (cxcywh normalized to 0, 1) nested within a dictionary.
        )r&  r   r,  r-  r   r   r)  )	r:   r;  r8  r   ra   r   r*  r+  r-  r#   r#   r$   box_predictor  s    

z%Owlv2ForObjectDetection.box_predictor)r;  r  r  r   c                 C   s   |  |||\}}||fS )a8  
        Args:
            image_feats:
                Features extracted from the `image_text_embedder`.
            query_embeds:
                Text query embeddings.
            query_mask:
                Must be provided with query_embeddings. A mask indicating which query embeddings are valid.
        )r%  )r:   r;  r  r  r#  r"  r#   r#   r$   class_predictor*  s    z'Owlv2ForObjectDetection.class_predictor)r   r   r   r   r   r   r   c              	   C   s   | j ||||||dd}|rH|j\}}}	}
|	| jjj }|
| jjj }n| j}| j}|jd }| j j	|}t
|d d d dd d f |d d d df j}|d d dd d d f | }| |}|jd |||jd f}||}|d }|||fS )NT)r   r   r   r   r   r   r   r   r   rn   )r   r   rj   r  rs   r*  r+  r0   r   r   r    broadcast_tor(  r   )r:   r   r   r   r   r   r   r   r   r   r   r*  r+  r   r.   class_token_outnew_sizer-   r#   r#   r$   image_text_embedder>  s8    


4

z+Owlv2ForObjectDetection.image_text_embedder)r   r   r   r   r   c                 C   s   | j j||dd}|rB|j\}}}}|| jjj }	|| jjj }
n| j}	| j}
|d }| j j|}t	
|d d d dd d f |d d d df j}|d d dd d d f | }| |}|jd |	|
|jd f}||}||fS )NT)r   r   r   r   r   rn   )r   r   r   rj   r  rs   r*  r+  r   r    r?  r(  r   )r:   r   r   r   r   r  r   r   r   r*  r+  r   r.   r@  rA  r#   r#   r$   image_embedderr  s*    4

z&Owlv2ForObjectDetection.image_embedder)query_image_featuresquery_feature_mapr   r   c                 C   s<  |  |\}}| |||}t|}g }g }	|j}
t|jd D ]}tjg dg|
d}|| }t||\}}t	|d dkrt
||}t|d }|d |k }| r@|| |d }tj|| dd}td||}|t| }||| |  |	| q@|r*t|}t|	}nd	\}}|||fS )
Nr   )r   r   r   r   r   r   g?r   )Zaxiszd,id->i)NN)r=  r<  r   r   r   r   r    r   rX   rY   r]   rP   ZnonzeroZnumelZsqueezer   r  Zargminappendr0  )r:   rD  rE  r   r   rb   ra   Zpred_boxes_as_cornersZbest_class_embedsbest_box_indicesZpred_boxes_deviceiZeach_query_boxZeach_query_pred_boxesZiousZiou_thresholdZselected_indsZselected_embeddingsZmean_embedsZmean_simZbest_box_indr  Zbox_indicesr#   r#   r$   embed_image_query  s4    

z)Owlv2ForObjectDetection.embed_image_query)r   query_pixel_valuesr   r   r   r   r   c              
   C   s*  |dur|n| j j}|dur |n| j j}|dur4|n| j j}| j||dd }| j||||d\}}	|j\}
}}}t||
|| |f}|j\}
}}}t||
|| |f}| |||\}}}| j	||d\}}| 
|||}|s|||||||	 f}tdd |D }|S t||||||d|	dS )	a  
        query_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values of query image(s) to be detected. Pass in one query image per target image.

        Examples:
        ```python
        >>> import requests
        >>> from PIL import Image
        >>> import torch
        >>> from transformers import AutoProcessor, Owlv2ForObjectDetection

        >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
        >>> model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> query_url = "http://images.cocodataset.org/val2017/000000001675.jpg"
        >>> query_image = Image.open(requests.get(query_url, stream=True).raw)
        >>> inputs = processor(images=image, query_images=query_image, return_tensors="pt")

        >>> # forward pass
        >>> with torch.no_grad():
        ...     outputs = model.image_guided_detection(**inputs)

        >>> target_sizes = torch.Tensor([image.size[::-1]])

        >>> # Convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
        >>> results = processor.post_process_image_guided_detection(
        ...     outputs=outputs, threshold=0.9, nms_threshold=0.3, target_sizes=target_sizes
        ... )
        >>> i = 0  # Retrieve predictions for the first image
        >>> boxes, scores = results[i]["boxes"], results[i]["scores"]
        >>> for box, score in zip(boxes, scores):
        ...     box = [round(i, 2) for i in box.tolist()]
        ...     print(f"Detected similar object with confidence {round(score.item(), 3)} at location {box}")
        Detected similar object with confidence 0.938 at location [327.31, 54.94, 547.39, 268.06]
        Detected similar object with confidence 0.959 at location [5.78, 360.65, 619.12, 366.39]
        Detected similar object with confidence 0.902 at location [2.85, 360.01, 627.63, 380.8]
        Detected similar object with confidence 0.985 at location [176.98, -29.45, 672.69, 182.83]
        Detected similar object with confidence 1.0 at location [6.53, 14.35, 624.87, 470.82]
        Detected similar object with confidence 0.998 at location [579.98, 29.14, 615.49, 489.05]
        Detected similar object with confidence 0.985 at location [206.15, 10.53, 247.74, 466.01]
        Detected similar object with confidence 0.947 at location [18.62, 429.72, 646.5, 457.72]
        Detected similar object with confidence 0.996 at location [523.88, 20.69, 586.84, 483.18]
        Detected similar object with confidence 0.998 at location [3.39, 360.59, 617.29, 499.21]
        Detected similar object with confidence 0.969 at location [4.47, 449.05, 614.5, 474.76]
        Detected similar object with confidence 0.966 at location [31.44, 463.65, 654.66, 471.07]
        Detected similar object with confidence 0.924 at location [30.93, 468.07, 635.35, 475.39]
        ```N)r   r   r   )r   r   r   r   )r;  r  c                 s   s   | ]}|d ur|V  qd S rE   r#   r7   xr#   r#   r$   r;   1  r   zAOwlv2ForObjectDetection.image_guided_detection.<locals>.<genexpr>)r.   re   rf   rg   r   rb   r/   r0   )rj   r   r   r   rC  r   r    r   rI  r=  r<  r5   r=   rd   )r:   r   rJ  r   r   r   r   rE  r8  r  r   r*  r+  
hidden_dimr;  Zquery_image_featsr  rG  rg   r#  rb   rf   r  r#   r#   r$   image_guided_detection  s^    ;

	z.Owlv2ForObjectDetection.image_guided_detection)r   r   r   r   r   r   r   r   c              
   C   sD  |dur|n| j j}|dur |n| j j}|dur4|n| j j}| j||||||d\}}	}
|
j}|
j}|	j\}}}}t	|	||| |f}|jd | }|	|||jd }|	|||jd }|d dk}| 
|||\}}| |}| ||	|}|s,|||||	|| | f}tdd |D }|S t|	|||||||dS )	a	  
        input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`, *optional*):
            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
            IDs?](../glossary#input-ids).
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the last hidden state. See `text_model_last_hidden_state` and
            `vision_model_last_hidden_state` under returned tensors for more detail.

        Examples:
        ```python
        >>> import requests
        >>> from PIL import Image
        >>> import torch

        >>> from transformers import Owlv2Processor, Owlv2ForObjectDetection

        >>> processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
        >>> model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> text_labels = [["a photo of a cat", "a photo of a dog"]]
        >>> inputs = processor(text=text_labels, images=image, return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
        >>> target_sizes = torch.tensor([(image.height, image.width)])
        >>> # Convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
        >>> results = processor.post_process_grounded_object_detection(
        ...     outputs=outputs, target_sizes=target_sizes, threshold=0.1, text_labels=text_labels
        ... )
        >>> # Retrieve predictions for the first image for the corresponding text queries
        >>> result = results[0]
        >>> boxes, scores, text_labels = result["boxes"], result["scores"], result["text_labels"]
        >>> for box, score, text_label in zip(boxes, scores, text_labels):
        ...     box = [round(i, 2) for i in box.tolist()]
        ...     print(f"Detected {text_label} with confidence {round(score.item(), 3)} at location {box}")
        Detected a photo of a cat with confidence 0.614 at location [341.67, 23.39, 642.32, 371.35]
        Detected a photo of a cat with confidence 0.665 at location [6.75, 51.96, 326.62, 473.13]
        ```N)r   r   r   r   r   r   r   rn   r.  c                 s   s   | ]}|d ur|V  qd S rE   r#   rK  r#   r#   r$   r;     r   z2Owlv2ForObjectDetection.forward.<locals>.<genexpr>)r.   r-   ra   r   r`   rb   r/   r0   )rj   r   r   r   rB  r/   r0   r   r    r   r=  r6  r<  r5   r=   r^   )r:   r   r   r   r   r   r   r   r  r8  r   r  r  r   r*  r+  rM  r;  Zmax_text_queriesr  r#  rb   r`   ra   r  r#   r#   r$   r   ?  sZ    4


zOwlv2ForObjectDetection.forward)N)F)NN)NNF)NNF)F)NNNFN)NNNFN)r?   r@   rA   r   rD   rr   staticmethodrI   r    r   r4  rC   r6  r   r   r,  r   r<  r=   r=  rB  rC  rI  r   rd   rN  r^   r   r   r#   r#   r   r$   r$    s   
  %     7   / ,     u     r$  )r   r   r   r   r$  )BrB   dataclassesr   	functoolsr   typingr   r   r   r    Ztorch.utils.checkpointr   r   Zactivationsr
   Zmodeling_attn_mask_utilsr   r   Zmodeling_layersr   Zmodeling_outputsr   r   Zmodeling_utilsr   utilsr   r   r   r   r   Zconfiguration_owlv2r   r   r   Ztransformers.image_transformsr   Z
get_loggerr?   loggerr%   r(   r)   rJ   rM   rX   r]   r^   rd   r   rh   r   r   r   r   r   r   r   r   r   r   r   r  r  r$  __all__r#   r#   r#   r$   <module>   sr   
%	2.N l20QM781 Z0   ~