"""PyTorch GroupViT model."""

import collections.abc
from dataclasses import dataclass
from typing import Any, Optional, Union

import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring, logging, torch_int
from .configuration_groupvit import GroupViTConfig, GroupViTTextConfig, GroupViTVisionConfig


logger = logging.get_logger(__name__)


def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))


def groupvit_loss(similarity: torch.Tensor) -> torch.Tensor:
    caption_loss = contrastive_loss(similarity)
    image_loss = contrastive_loss(similarity.t())
    return (caption_loss + image_loss) / 2.0


def hard_softmax(logits: torch.Tensor, dim: int):
    y_soft = logits.softmax(dim)
    # straight-through: one-hot in the forward pass, softmax gradients in the backward pass
    index = y_soft.max(dim, keepdim=True)[1]
    y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
    ret = y_hard - y_soft.detach() + y_soft

    return ret


def gumbel_softmax(logits: torch.Tensor, tau: float = 1, hard: bool = False, dim: int = -1) -> torch.Tensor:
    gumbel_dist = torch.distributions.gumbel.Gumbel(
        torch.tensor(0.0, device=logits.device, dtype=logits.dtype),
        torch.tensor(1.0, device=logits.device, dtype=logits.dtype),
    )
    gumbels = gumbel_dist.sample(logits.shape)

    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits, tau)
    y_soft = gumbels.softmax(dim)

    if hard:
        # straight-through
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # reparameterization trick
        ret = y_soft
    return ret


def resize_attention_map(attentions, height, width, align_corners=False):
    """
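# A minimal usage sketch, not part of the upstream GroupViT API: it only illustrates the
# straight-through behaviour of `hard_softmax` / `gumbel_softmax` defined above. The tensor
# shape and the helper name `_demo_straight_through_assignment` are assumptions made here.
def _demo_straight_through_assignment():
    logits = torch.randn(2, 4, 6, requires_grad=True)  # [batch, groups, tokens]
    hard_assign = hard_softmax(logits, dim=1)  # deterministic one-hot over groups in the forward pass
    soft_sample = gumbel_softmax(logits, tau=1.0, hard=True, dim=1)  # stochastic one-hot sample
    # gradients still flow through the underlying softmax despite the one-hot forward values
    loss = hard_assign[:, 0].sum() + soft_sample[:, 0].sum()
    loss.backward()
    return hard_assign, soft_sample, logits.grad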
    Args:
        attentions (`torch.Tensor`): attention map of shape [batch_size, groups, feat_height*feat_width]
        height (`int`): height of the output attention map
        width (`int`): width of the output attention map
        align_corners (`bool`, *optional*): the `align_corner` argument for `nn.functional.interpolate`.

    Returns:
        `torch.Tensor`: resized attention map of shape [batch_size, groups, height, width]
    """

    scale = (height * width // attentions.shape[2]) ** 0.5
    if height > width:
        feat_width = int(np.round(width / scale))
        feat_height = attentions.shape[2] // feat_width
    else:
        feat_height = int(np.round(height / scale))
        feat_width = attentions.shape[2] // feat_height

    batch_size = attentions.shape[0]
    # number of group tokens
    groups = attentions.shape[1]
    # [batch_size, groups, height x width] -> [batch_size, groups, height, width]
    attentions = attentions.reshape(batch_size, groups, feat_height, feat_width)
    attentions = nn.functional.interpolate(
        attentions, size=(height, width), mode="bilinear", align_corners=align_corners
    )
    return attentions


def get_grouping_from_attentions(attentions, hw_shape):
    """
    Args:
        attentions (`tuple(torch.FloatTensor)`: tuple of attention maps returned by `GroupViTVisionTransformer`
        hw_shape (`tuple(int)`): height and width of the output attention map
    Returns:
        `torch.Tensor`: the attention map of shape [batch_size, groups, height, width]
    """

    attn_maps = []
    with torch.no_grad():
        prev_attn_masks = None
        for attn_masks in attentions:
            # [batch_size, num_groups, height x width] -> [batch_size, height x width, num_groups]
            attn_masks = attn_masks.permute(0, 2, 1).contiguous()
            if prev_attn_masks is None:
                prev_attn_masks = attn_masks
            else:
                prev_attn_masks = prev_attn_masks @ attn_masks
            # [batch_size, height x width, num_groups] -> [batch_size, num_groups, height, width]
            cur_attn_map = resize_attention_map(prev_attn_masks.permute(0, 2, 1).contiguous(), *hw_shape)
            attn_maps.append(cur_attn_map)

    # the attention map of the last grouping stage is the final grouping
    final_grouping = attn_maps[-1]

    return final_grouping


class GroupViTCrossAttentionLayer(nn.Module):
    def __init__(self, config: GroupViTVisionConfig):
        super().__init__()
        self.attn = GroupViTAttention(config)
        self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = GroupViTMLP(config)
        self.norm_post = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, query, key):
        x = query
        x = x + self.attn(query, encoder_hidden_states=key)[0]
        x = x + self.mlp(self.norm2(x))
        x = self.norm_post(x)
        return x


class GroupViTAssignAttention(nn.Module):
    def __init__(self, config: GroupViTVisionConfig):
        super().__init__()
        self.scale = config.hidden_size**-0.5

        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.assign_eps = config.assign_eps

    def get_attn(self, attn, gumbel=True, hard=True):
        if gumbel and self.training:
            attn = gumbel_softmax(attn, dim=-2, hard=hard)
        else:
            if hard:
                attn = hard_softmax(attn, dim=-2)
            else:
                attn = nn.functional.softmax(attn, dim=-2)

        return attn

    def forward(self, query, key):
        value = key
        # [batch_size, query_length, channels]
        query = self.q_proj(query)

        # [batch_size, key_length, channels]
        key = self.k_proj(key)

        # [batch_size, key_length, channels]
        value = self.v_proj(value)

        # [batch_size, query_length, key_length]
        raw_attn = (query @ key.transpose(1, 2)) * self.scale

        attn = self.get_attn(raw_attn)
        soft_attn = self.get_attn(raw_attn, gumbel=False, hard=False)

        attn = attn / (attn.sum(dim=-1, keepdim=True) + self.assign_eps)

        out = attn @ value

        out = self.proj(out)

        return out, soft_attn


class GroupViTTokenAssign(nn.Module):
    def __init__(self, config: GroupViTVisionConfig, num_group_token, num_output_group):
        super().__init__()
        self.num_output_group = num_output_group
        # norm on group_tokens
        self.norm_tokens = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        assign_mlp_ratio = (
            config.assign_mlp_ratio
            if isinstance(config.assign_mlp_ratio, collections.abc.Iterable)
            else (config.assign_mlp_ratio, config.assign_mlp_ratio)
        )
        tokens_dim, channels_dim = [int(x * config.hidden_size) for x in assign_mlp_ratio]
        self.mlp_inter = GroupViTMixerMLP(config, num_group_token, tokens_dim, num_output_group)
        self.norm_post_tokens = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # norm on x
        self.norm_x = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pre_assign_attn = GroupViTCrossAttentionLayer(config)

        self.assign = GroupViTAssignAttention(config)
        self.norm_new_x = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp_channels = GroupViTMLP(config, config.hidden_size, channels_dim, config.hidden_size)

    def project_group_token(self, group_tokens):
        """
        Args:
            group_tokens (torch.Tensor): group tokens, [batch_size, num_group_tokens, channels]

        Returns:
            projected_group_tokens (torch.Tensor): [batch_size, num_output_groups, channels]
        """
        # [batch_size, num_output_groups, channels] <- [batch_size, num_group_tokens, channels]
        projected_group_tokens = self.mlp_inter(group_tokens)
        projected_group_tokens = self.norm_post_tokens(projected_group_tokens)
        return projected_group_tokens

    def forward(self, image_tokens, group_tokens):
        """
        Args:
            image_tokens (`torch.Tensor`): image tokens, of shape [batch_size, input_length, channels]
            group_tokens (`torch.Tensor`): group tokens, [batch_size, num_group_tokens, channels]
        """

        group_tokens = self.norm_tokens(group_tokens)
        image_tokens = self.norm_x(image_tokens)
        # [batch_size, num_output_groups, channels]
        projected_group_tokens = self.project_group_token(group_tokens)
        projected_group_tokens = self.pre_assign_attn(projected_group_tokens, image_tokens)
        new_image_tokens, attention = self.assign(projected_group_tokens, image_tokens)
        new_image_tokens += projected_group_tokens

        new_image_tokens = new_image_tokens + self.mlp_channels(self.norm_new_x(new_image_tokens))

        return new_image_tokens, attention


@dataclass
@auto_docstring
class GroupViTModelOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
        Contrastive loss for image-text similarity.
    logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
        The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
        similarity scores.
    logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
        The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
        similarity scores.
    segmentation_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
        Classification scores for each pixel.

        <Tip warning={true}>

        The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
        to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the
        original image size as post-processing. You should always check your logits shape and resize as needed.

        </Tip>
    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
        The text embeddings obtained by applying the projection layer to the pooled output of
        [`GroupViTTextModel`].
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
        The image embeddings obtained by applying the projection layer to the pooled output of
        [`GroupViTVisionModel`].
    text_model_output (`BaseModelOutputWithPooling`):
        The output of the [`GroupViTTextModel`].
    vision_model_output (`BaseModelOutputWithPooling`):
        The output of the [`GroupViTVisionModel`].
    """

    loss: Optional[torch.FloatTensor] = None
    logits_per_image: Optional[torch.FloatTensor] = None
    logits_per_text: Optional[torch.FloatTensor] = None
    segmentation_logits: Optional[torch.FloatTensor] = None
    text_embeds: Optional[torch.FloatTensor] = None
    image_embeds: Optional[torch.FloatTensor] = None
    text_model_output: BaseModelOutputWithPooling = None
    vision_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> tuple[Any]:
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )
class GroupViTPatchEmbeddings(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self,
        image_size: int = 224,
        patch_size: Union[int, tuple[int, int]] = 16,
        num_channels: int = 3,
        embed_dim: int = 768,
    ):
        super().__init__()
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape
        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model"
                f" ({self.image_size[0]}*{self.image_size[1]})."
            )
        x = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return x


class GroupViTVisionEmbeddings(nn.Module):
    def __init__(self, config: GroupViTVisionConfig):
        super().__init__()

        self.patch_embeddings = GroupViTPatchEmbeddings(
            image_size=config.image_size,
            patch_size=config.patch_size,
            num_channels=config.num_channels,
            embed_dim=config.hidden_size,
        )
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches, config.hidden_size))
        self.dropout = nn.Dropout(config.dropout)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.patch_size = config.patch_size
        self.config = config

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing and no class embeddings.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """

        num_patches = embeddings.shape[1]
        num_positions = self.position_embeddings.shape[1]

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embeddings

        patch_pos_embed = self.position_embeddings

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return patch_pos_embed

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape
        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        embeddings = self.layernorm(embeddings)

        batch_size, seq_len, _ = embeddings.size()

        # add positional encoding to each token
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embeddings

        embeddings = self.dropout(embeddings)

        return embeddings


class GroupViTTextEmbeddings(nn.Module):
    def __init__(self, config: GroupViTTextConfig):
        super().__init__()
        embed_dim = config.hidden_size

        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
        max_position_embedding = self.position_embedding.weight.shape[0]

        if seq_length > max_position_embedding:
            raise ValueError(
                f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
                f"{seq_length} and max_position_embeddings: {max_position_embedding}"
            )

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        position_embeddings = self.position_embedding(position_ids)
        embeddings = inputs_embeds + position_embeddings

        return embeddings


class GroupViTStage(nn.Module):
    """This corresponds to the `GroupingLayer` class in the GroupViT implementation."""

    def __init__(
        self,
        config: GroupViTVisionConfig,
        depth: int,
        num_prev_group_token: int,
        num_group_token: int,
        num_output_group: int,
    ):
        super().__init__()
        self.depth = depth
        self.num_group_token = num_group_token
        if num_group_token > 0:
            self.group_token = nn.Parameter(torch.zeros(1, num_group_token, config.hidden_size))
        else:
            self.group_token = None
        self.layers = nn.ModuleList([GroupViTEncoderLayer(config) for _ in range(depth)])

        if num_group_token > 0:
            self.downsample = GroupViTTokenAssign(
                config=config,
                num_group_token=num_group_token,
                num_output_group=num_output_group,
            )
        else:
            self.downsample = None

        if num_prev_group_token > 0 and num_group_token > 0:
            self.group_projector = nn.Sequential(
                nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps),
                GroupViTMixerMLP(config, num_prev_group_token, config.hidden_size // 2, num_group_token),
            )
        else:
            self.group_projector = None

    @property
    def with_group_token(self):
        return self.group_token is not None

    def split_x(self, x):
        if self.with_group_token:
            return x[:, : -self.num_group_token], x[:, -self.num_group_token :]
        else:
            return x, None

    def concat_x(self, x: torch.Tensor, group_token: Optional[torch.Tensor] = None) -> torch.Tensor:
        if group_token is None:
            return x
        return torch.cat([x, group_token], dim=1)

    def forward(
        self,
        hidden_states: torch.Tensor,
        prev_group_token: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
                `(config.encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the grouping tensors of Grouping block.
        """
        if self.with_group_token:
            group_token = self.group_token.expand(hidden_states.size(0), -1, -1)
            if self.group_projector is not None:
                group_token = group_token + self.group_projector(prev_group_token)
        else:
            group_token = None

        x = hidden_states

        cat_x = self.concat_x(x, group_token)
        for layer in self.layers:
            layer_out = layer(cat_x, attention_mask=None, causal_attention_mask=None)
            cat_x = layer_out[0]

        x, group_token = self.split_x(cat_x)

        attention = None
        if self.downsample is not None:
            x, attention = self.downsample(x, group_token)

        outputs = (x, group_token)

        if output_attentions:
            outputs = outputs + (attention,)

        return outputs


class GroupViTMLP(nn.Module):
    def __init__(
        self,
        config: GroupViTVisionConfig,
        hidden_size: Optional[int] = None,
        intermediate_size: Optional[int] = None,
        output_size: Optional[int] = None,
    ):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        hidden_size = hidden_size if hidden_size is not None else config.hidden_size
        intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
        output_size = output_size if output_size is not None else hidden_size
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.fc2 = nn.Linear(intermediate_size, output_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class GroupViTMixerMLP(GroupViTMLP):
    def forward(self, x):
        # mix along the token dimension instead of the channel dimension
        x = super().forward(x.transpose(1, 2))
        return x.transpose(1, 2)


class GroupViTAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""

        bsz, tgt_len, embed_dim = hidden_states.size()
        is_cross_attention = encoder_hidden_states is not None

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scale
        if is_cross_attention:
            key_states = self._shape(self.k_proj(encoder_hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(encoder_hidden_states), -1, bsz)
        else:
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        # apply the causal mask first, then the padding mask
        if causal_attention_mask is not None:
            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                    f" {causal_attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if output_attentions:
            # keep a reshaped copy so the attention probabilities can be returned to the caller
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped


class GroupViTEncoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: GroupViTConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = GroupViTAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = GroupViTMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
                `(config.encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


@auto_docstring
class GroupViTPreTrainedModel(PreTrainedModel):
    config: GroupViTConfig
    base_model_prefix = "groupvit"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""

        init_range = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=init_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

        factor = self.config.initializer_factor
        if isinstance(module, GroupViTTextEmbeddings):
            module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
            module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
        elif isinstance(module, GroupViTAttention):
            in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            out_proj_std = (module.embed_dim**-0.5) * factor
            nn.init.normal_(module.q_proj.weight, std=in_proj_std)
            nn.init.normal_(module.k_proj.weight, std=in_proj_std)
            nn.init.normal_(module.v_proj.weight, std=in_proj_std)
            nn.init.normal_(module.out_proj.weight, std=out_proj_std)
        elif isinstance(module, GroupViTMLP):
            in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
            nn.init.normal_(module.fc1.weight, std=fc_std)
            nn.init.normal_(module.fc2.weight, std=in_proj_std)


class GroupViTVisionEncoder(nn.Module):
    def __init__(self, config: GroupViTVisionConfig) -> None:
        super().__init__()
        self.config = config
        self.stages = nn.ModuleList(
            [
                GroupViTStage(
                    config=config,
                    depth=config.depths[i],
                    num_group_token=config.num_group_tokens[i],
                    num_output_group=config.num_output_groups[i],
                    num_prev_group_token=config.num_output_groups[i - 1] if i > 0 else 0,
                )
                for i in range(len(config.depths))
            ]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        all_hidden_states = () if output_hidden_states else None
        all_groupings = () if output_attentions else None

        group_tokens = None

        for i, stage in enumerate(self.stages):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = stage(hidden_states, group_tokens, output_attentions)

            hidden_states = layer_outputs[0]
            group_tokens = layer_outputs[1]

            if output_attentions and layer_outputs[2] is not None:
                all_groupings = all_groupings + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_groupings] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_groupings
        )


class GroupViTTextEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self-attention layers. Each layer is a
    [`GroupViTEncoderLayer`].

    Args:
        config: GroupViTTextConfig
    """

    def __init__(self, config: GroupViTTextConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([GroupViTEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Causal mask for the text model. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)

            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                causal_attention_mask,
                output_attentions=output_attentions,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )


class GroupViTTextTransformer(nn.Module):
    def __init__(self, config: GroupViTTextConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = GroupViTTextEmbeddings(config)
        self.encoder = GroupViTTextEncoder(config)
        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        # used for `pooled_output` computation
        self.eos_token_id = config.eos_token_id

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        # GroupViT's text model uses a causal mask; prepare it here
        causal_attention_mask = _create_4d_causal_attention_mask(
            input_shape, hidden_states.dtype, device=hidden_states.device
        )
        if attention_mask is not None:
            # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
            attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.final_layer_norm(last_hidden_state)

        if self.eos_token_id == 2:
            # legacy behaviour: take features from the eot embedding (the highest token id in each sequence)
            pooled_output = last_hidden_state[
                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
                input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
            ]
        else:
            # newer configs: take features from the first `eos_token_id` position in each sequence
            pooled_output = last_hidden_state[
                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
                (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
                .int()
                .argmax(dim=-1),
            ]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class GroupViTTextModel(GroupViTPreTrainedModel):
    config: GroupViTTextConfig

    def __init__(self, config: GroupViTTextConfig):
        super().__init__(config)
        self.text_model = GroupViTTextTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value):
        self.text_model.embeddings.token_embedding = value

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        r"""
        Examples:

        ```python
        >>> from transformers import CLIPTokenizer, GroupViTTextModel

        >>> tokenizer = CLIPTokenizer.from_pretrained("nvidia/groupvit-gcc-yfcc")
        >>> model = GroupViTTextModel.from_pretrained("nvidia/groupvit-gcc-yfcc")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
        ```"""
        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


class GroupViTVisionTransformer(nn.Module):
    def __init__(self, config: GroupViTVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = GroupViTVisionEmbeddings(config)
        self.encoder = GroupViTVisionEncoder(config)
        self.layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values)

        encoder_outputs = self.encoder(
            hidden_states=hidden_states,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]

        # normalize the last hidden state
        last_hidden_state = self.layernorm(last_hidden_state)
        pooled_output = last_hidden_state.mean(dim=1)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class GroupViTVisionModel(GroupViTPreTrainedModel):
    config: GroupViTVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: GroupViTVisionConfig):
        super().__init__(config)
        self.vision_model = GroupViTVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> GroupViTPatchEmbeddings:
        return self.vision_model.embeddings.patch_embeddings

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        r"""
        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, GroupViTVisionModel

        >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")
        >>> model = GroupViTVisionModel.from_pretrained("nvidia/groupvit-gcc-yfcc")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled CLS states
        ```"""
        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


@auto_docstring
class GroupViTModel(GroupViTPreTrainedModel):
    config: GroupViTConfig

    def __init__(self, config: GroupViTConfig):
        super().__init__(config)

        if not isinstance(config.text_config, GroupViTTextConfig):
            raise TypeError(
                "config.text_config is expected to be of type GroupViTTextConfig but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(config.vision_config, GroupViTVisionConfig):
            raise TypeError(
                "config.vision_config is expected to be of type GroupViTVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        text_config = config.text_config
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.projection_intermediate_dim = config.projection_intermediate_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        self.text_model = GroupViTTextTransformer(text_config)
        self.vision_model = GroupViTVisionTransformer(vision_config)

        self.visual_projection = nn.Sequential(
            nn.Linear(self.vision_embed_dim, self.projection_intermediate_dim, bias=True),
            nn.BatchNorm1d(self.projection_intermediate_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.projection_intermediate_dim, self.projection_dim, bias=True),
        )
        self.text_projection = nn.Sequential(
            nn.Linear(self.text_embed_dim, self.projection_intermediate_dim, bias=True),
            nn.BatchNorm1d(self.projection_intermediate_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.projection_intermediate_dim, self.projection_dim, bias=True),
        )
        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))

        # Initialize weights and apply final processing
        self.post_init()

    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`GroupViTTextModel`].

        Examples:

        ```python
        >>> from transformers import CLIPTokenizer, GroupViTModel

        >>> model = GroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
        >>> tokenizer = CLIPTokenizer.from_pretrained("nvidia/groupvit-gcc-yfcc")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
        >>> text_features = model.get_text_features(**inputs)
        ```"""
        # Use GroupViT model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = text_outputs[1]
        text_features = self.text_projection(pooled_output)

        return text_features

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`GroupViTVisionModel`].

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, GroupViTModel

        >>> model = GroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
        >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> image_features = model.get_image_features(**inputs)
        ```"""
        # Use GroupViT model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = vision_outputs[1]
        image_features = self.visual_projection(pooled_output)

        return image_features

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_segmentation: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, GroupViTModelOutput]:
        r"""
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
        output_segmentation (`bool`, *optional*):
            Whether or not to return the segmentation logits.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, GroupViTModel

        >>> model = GroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
        >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(
        ...     text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
        ... )

        >>> outputs = model(**inputs)
        >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
        >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
        ```"""
        # Use GroupViT model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_segmentation = (
            output_segmentation if output_segmentation is not None else self.config.output_segmentation
        )
        if output_segmentation:
            output_attentions = True
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        image_embeds = vision_outputs[1]
        image_embeds = self.visual_projection(image_embeds)

        text_embeds = text_outputs[1]
        text_embeds = self.text_projection(text_embeds)

        # normalized features
        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.t()

        seg_logits = None
        if output_segmentation:
            # grouped features: [batch_size_image, num_group, hidden_size]
            image_group_embeds = vision_outputs[0]
            # [batch_size_image x num_group, hidden_size]
            image_group_embeds = self.visual_projection(image_group_embeds.reshape(-1, image_group_embeds.shape[-1]))
            if output_hidden_states:
                attentions = vision_outputs[3]
            else:
                attentions = vision_outputs[2]
            # [batch_size_image, num_group, height, width]
            grouping = get_grouping_from_attentions(attentions, pixel_values.shape[2:])

            # normalized features
            image_group_embeds = image_group_embeds / image_group_embeds.norm(dim=-1, keepdim=True)
            # [batch_size_image x num_group, batch_size_text]
            logits_per_image_group = torch.matmul(image_group_embeds, text_embeds.t()) * logit_scale
            # [batch_size_image, batch_size_text, num_group]
            logits_per_image_group = logits_per_image_group.reshape(
                image_embeds.shape[0], -1, text_embeds.shape[0]
            ).permute(0, 2, 1)

            # [batch_size_image, batch_size_text, height x width]
            flatten_grouping = grouping.reshape(grouping.shape[0], grouping.shape[1], -1)

            # [batch_size_image, batch_size_text, height, width]
            seg_logits = torch.matmul(logits_per_image_group, flatten_grouping) * logit_scale
            seg_logits = seg_logits.reshape(
                seg_logits.shape[0], seg_logits.shape[1], grouping.shape[2], grouping.shape[3]
            )

        loss = None
        if return_loss:
            loss = groupvit_loss(logits_per_text)

        if not return_dict:
            if seg_logits is not None:
                output = (
                    logits_per_image,
                    logits_per_text,
                    seg_logits,
                    text_embeds,
                    image_embeds,
                    text_outputs,
                    vision_outputs,
                )
            else:
                output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
            return ((loss,) + output) if loss is not None else output

        return GroupViTModelOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            segmentation_logits=seg_logits,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )


__all__ = ["GroupViTModel", "GroupViTPreTrainedModel", "GroupViTTextModel", "GroupViTVisionModel"]