"""PyTorch CLIPSeg model."""
# NOTE: interface reconstructed from the compiled (.pyc) module text; method bodies that are not
# recoverable from the compiled file are elided with `...`.

import copy
import math
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
from .configuration_clipseg import CLIPSegConfig, CLIPSegTextConfig, CLIPSegVisionConfig


logger = logging.get_logger(__name__)


def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))


def clipseg_loss(similarity: torch.Tensor) -> torch.Tensor:
    caption_loss = contrastive_loss(similarity)
    image_loss = contrastive_loss(similarity.t())
    return (caption_loss + image_loss) / 2.0
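# Illustrative sketch (not part of the original module): how `clipseg_loss` above is meant to be
# used. It expects a square (text_batch, image_batch) similarity matrix whose diagonal entries are
# the matching text/image pairs; the tensors and the temperature of 100.0 below are made up.
def _clipseg_loss_example() -> torch.Tensor:
    text_embeds = nn.functional.normalize(torch.randn(4, 512), dim=-1)  # hypothetical batch of 4
    image_embeds = nn.functional.normalize(torch.randn(4, 512), dim=-1)
    logits_per_text = text_embeds @ image_embeds.t() * 100.0  # temperature plays the role of logit_scale.exp()
    return clipseg_loss(logits_per_text)  # symmetric cross-entropy over captions and images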
@dataclass
@auto_docstring
class CLIPSegOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
        Contrastive loss for image-text similarity.
    logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
        The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
        similarity scores.
    logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
        The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
        similarity scores.
    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
        The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPSegTextModel`].
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
        The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPSegVisionModel`].
    text_model_output (`BaseModelOutputWithPooling`):
        The output of the [`CLIPSegTextModel`].
    vision_model_output (`BaseModelOutputWithPooling`):
        The output of the [`CLIPSegVisionModel`].
    """

    loss: Optional[torch.FloatTensor] = None
    logits_per_image: Optional[torch.FloatTensor] = None
    logits_per_text: Optional[torch.FloatTensor] = None
    text_embeds: Optional[torch.FloatTensor] = None
    image_embeds: Optional[torch.FloatTensor] = None
    text_model_output: BaseModelOutputWithPooling = None
    vision_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> tuple[Any]:
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )


@dataclass
@auto_docstring
class CLIPSegDecoderOutput(ModelOutput):
    r"""
    logits (`torch.FloatTensor` of shape `(batch_size, height, width)`):
        Classification scores for each pixel.
    """

    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None


@dataclass
@auto_docstring
class CLIPSegImageSegmentationOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Binary cross entropy loss for segmentation.
    logits (`torch.FloatTensor` of shape `(batch_size, height, width)`):
        Classification scores for each pixel.
    conditional_embeddings (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
        Conditional embeddings used for segmentation.
    pooled_output (`torch.FloatTensor` of shape `(batch_size, embed_dim)`):
        Pooled output of the [`CLIPSegVisionModel`].
    vision_model_output (`BaseModelOutputWithPooling`):
        The output of the [`CLIPSegVisionModel`].
    decoder_output (`CLIPSegDecoderOutput`):
        The output of the [`CLIPSegDecoder`].
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    conditional_embeddings: Optional[torch.FloatTensor] = None
    pooled_output: Optional[torch.FloatTensor] = None
    vision_model_output: BaseModelOutputWithPooling = None
    decoder_output: CLIPSegDecoderOutput = None

    def to_tuple(self) -> tuple[Any]:
        return tuple(
            self[k] if k not in ["vision_model_output", "decoder_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )


class CLIPSegVisionEmbeddings(nn.Module):
    def __init__(self, config: CLIPSegVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """
        # Body not recoverable from the compiled file: bicubic interpolation of the patch position
        # embeddings to the new (height // patch_size, width // patch_size) grid, keeping the class
        # token position unchanged.
        ...

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = True) -> torch.Tensor:
        # Patch-embeds the image, prepends the class embedding and adds (optionally interpolated)
        # position embeddings; raises a ValueError if the image size does not match the model and
        # interpolation is disabled.
        ...
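# Illustrative sketch (not part of the original module) of the resizing step that
# `interpolate_pos_encoding` performs, written as a standalone helper. `pos_embed` is assumed to
# have shape (1, 1 + num_patches, dim) with the class-token position first and a square patch grid.
def _interpolate_patch_positions_sketch(pos_embed: torch.Tensor, new_h: int, new_w: int) -> torch.Tensor:
    class_pos = pos_embed[:, :1]                # class-token position is kept as-is
    patch_pos = pos_embed[:, 1:]                # (1, num_patches, dim)
    dim = patch_pos.shape[-1]
    old_size = int(patch_pos.shape[1] ** 0.5)   # side length of the original square grid
    patch_pos = patch_pos.reshape(1, old_size, old_size, dim).permute(0, 3, 1, 2)
    patch_pos = nn.functional.interpolate(patch_pos, size=(new_h, new_w), mode="bicubic", align_corners=False)
    patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, new_h * new_w, dim)
    return torch.cat((class_pos, patch_pos), dim=1)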
class CLIPSegTextEmbeddings(nn.Module):
    def __init__(self, config: CLIPSegTextConfig):
        super().__init__()
        embed_dim = config.hidden_size

        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        # Token embeddings plus learned absolute position embeddings; raises a ValueError if the
        # sequence is longer than `config.max_position_embeddings`.
        ...


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


class CLIPSegAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Union[CLIPSegTextConfig, CLIPSegVisionConfig]):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""
        # Projects to per-head queries/keys/values and dispatches to the configured attention
        # implementation (eager, sdpa or flash-attention-2), falling back to eager when attention
        # weights are requested with `output_attentions=True`.
        ...


class CLIPSegMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class CLIPSegEncoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: CLIPSegConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = CLIPSegAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = CLIPSegMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        # Pre-norm residual block: LayerNorm -> self-attention -> residual, then
        # LayerNorm -> MLP -> residual.
        ...
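# Illustrative sketch (not part of the original module): calling `eager_attention_forward` directly
# on dummy per-head tensors to show the expected shapes. Batch size, head count, sequence length
# and head dimension below are arbitrary.
def _eager_attention_example() -> tuple[torch.Tensor, torch.Tensor]:
    batch, heads, seq_len, head_dim = 2, 8, 10, 64
    query = torch.randn(batch, heads, seq_len, head_dim)
    key = torch.randn(batch, heads, seq_len, head_dim)
    value = torch.randn(batch, heads, seq_len, head_dim)
    attn_output, attn_weights = eager_attention_forward(
        nn.Identity(),  # stands in for the attention module (only its `.training` flag is read)
        query,
        key,
        value,
        attention_mask=None,
        scaling=head_dim**-0.5,
    )
    # attn_output: (batch, seq_len, heads, head_dim); attn_weights: (batch, heads, seq_len, seq_len)
    return attn_output, attn_weights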
@auto_docstring
class CLIPSegPreTrainedModel(PreTrainedModel):
    config: CLIPSegConfig
    base_model_prefix = "clip"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        # Module-specific normal initialization (embeddings, attention/MLP projections and the two
        # projection heads), scaled by `config.initializer_factor`; LayerNorm weights are reset to 1
        # and biases to 0.
        ...


class CLIPSegEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`CLIPSegEncoderLayer`].

    Args:
        config: CLIPSegConfig
    """

    def __init__(self, config: CLIPSegConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([CLIPSegEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Causal mask for the text model. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        # Runs the stack of encoder layers, optionally collecting per-layer hidden states and
        # attentions, and returns a `BaseModelOutput`.
        ...


class CLIPSegTextTransformer(nn.Module):
    def __init__(self, config: CLIPSegTextConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = CLIPSegTextEmbeddings(config)
        self.encoder = CLIPSegEncoder(config)
        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        # For `pooled_output` computation
        self.eos_token_id = config.eos_token_id

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        # Embeds the tokens, builds the causal (and optional padding) attention masks, runs the
        # encoder, applies the final layer norm and pools the hidden state at the EOS token.
        ...


class CLIPSegTextModel(CLIPSegPreTrainedModel):
    config: CLIPSegTextConfig

    _no_split_modules = ["CLIPSegTextEmbeddings", "CLIPSegEncoderLayer"]

    def __init__(self, config: CLIPSegTextConfig):
        super().__init__(config)
        self.text_model = CLIPSegTextTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value):
        self.text_model.embeddings.token_embedding = value

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        r"""
        Examples:

        ```python
        >>> from transformers import AutoTokenizer, CLIPSegTextModel

        >>> tokenizer = AutoTokenizer.from_pretrained("CIDAS/clipseg-rd64-refined")
        >>> model = CLIPSegTextModel.from_pretrained("CIDAS/clipseg-rd64-refined")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
        ```"""
        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


class CLIPSegVisionTransformer(nn.Module):
    def __init__(self, config: CLIPSegVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = CLIPSegVisionEmbeddings(config)
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)  # (sic) name kept as in the checkpoints
        self.encoder = CLIPSegEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = True,
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        # Patch embeddings -> pre-layernorm -> encoder -> post-layernorm; the pooled output is the
        # layer-normalized hidden state of the class token.
        ...


class CLIPSegVisionModel(CLIPSegPreTrainedModel):
    config: CLIPSegVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: CLIPSegVisionConfig):
        super().__init__(config)
        self.vision_model = CLIPSegVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = True,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        r"""
        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, CLIPSegVisionModel

        >>> processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
        >>> model = CLIPSegVisionModel.from_pretrained("CIDAS/clipseg-rd64-refined")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled CLS states
        ```"""
        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )


@auto_docstring
class CLIPSegModel(CLIPSegPreTrainedModel):
    config: CLIPSegConfig

    def __init__(self, config: CLIPSegConfig):
        super().__init__(config)

        if not isinstance(config.text_config, CLIPSegTextConfig):
            raise TypeError(
                "config.text_config is expected to be of type CLIPSegTextConfig but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(config.vision_config, CLIPSegVisionConfig):
            raise TypeError(
                "config.vision_config is expected to be of type CLIPSegVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        text_config = config.text_config
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        self.text_model = CLIPSegTextTransformer(text_config)
        self.vision_model = CLIPSegVisionTransformer(vision_config)

        self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
        self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))

        # Initialize weights and apply final processing
        self.post_init()

    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`CLIPSegTextModel`].

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, CLIPSegModel

        >>> tokenizer = AutoTokenizer.from_pretrained("CIDAS/clipseg-rd64-refined")
        >>> model = CLIPSegModel.from_pretrained("CIDAS/clipseg-rd64-refined")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
        >>> text_features = model.get_text_features(**inputs)
        ```"""
        # Runs the text transformer and projects its pooled output with `self.text_projection`.
        ...

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = True,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`CLIPSegVisionModel`].

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, CLIPSegModel

        >>> processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
        >>> model = CLIPSegModel.from_pretrained("CIDAS/clipseg-rd64-refined")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> image_features = model.get_image_features(**inputs)
        ```"""
        # Runs the vision transformer and projects its pooled output with `self.visual_projection`.
        ...

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = True,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, CLIPSegOutput]:
        r"""
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, CLIPSegModel

        >>> processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
        >>> model = CLIPSegModel.from_pretrained("CIDAS/clipseg-rd64-refined")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(
        ...     text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
        ... )

        >>> outputs = model(**inputs)
        >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
        >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
        ```"""
        # Encodes both modalities, L2-normalizes the projected embeddings, scales their dot product
        # by `logit_scale.exp()` and, when `return_loss=True`, adds the symmetric contrastive loss
        # computed by `clipseg_loss`.
        ...
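# Illustrative sketch (not part of the original module) of the similarity computation described in
# `CLIPSegOutput`: both projected embeddings are L2-normalized and their scaled dot product gives
# `logits_per_text` / `logits_per_image`. The embedding tensors and the temperature are made up.
def _similarity_logits_example(temperature: float = 100.0) -> torch.Tensor:
    # `temperature` plays the role of `logit_scale.exp()` in the model; 100.0 is an arbitrary value.
    text_embeds = torch.randn(2, 512)    # hypothetical projected text embeddings
    image_embeds = torch.randn(3, 512)   # hypothetical projected image embeddings
    text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
    image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
    logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * temperature
    logits_per_image = logits_per_text.t()  # (image_batch_size, text_batch_size)
    return logits_per_image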
class CLIPSegDecoderLayer(nn.Module):
    """
    CLIPSeg decoder layer, which is identical to `CLIPSegEncoderLayer`, except that normalization is applied after
    self-attention/MLP, rather than before.
    """

    def __init__(self, config: CLIPSegConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = CLIPSegAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = CLIPSegMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        # Post-norm residual block: self-attention -> residual -> LayerNorm, then
        # MLP -> residual -> LayerNorm.
        ...


class CLIPSegDecoder(CLIPSegPreTrainedModel):
    def __init__(self, config: CLIPSegConfig):
        super().__init__(config)

        self.conditional_layer = config.conditional_layer

        # FiLM-style conditioning of the decoder activations on the conditional (prompt) embedding
        self.film_mul = nn.Linear(config.projection_dim, config.reduce_dim)
        self.film_add = nn.Linear(config.projection_dim, config.reduce_dim)

        depth = len(config.extract_layers)
        self.reduces = nn.ModuleList(
            [nn.Linear(config.vision_config.hidden_size, config.reduce_dim) for _ in range(depth)]
        )

        # The decoder layers reuse the CLIP vision layout but with their own width and head count
        decoder_config = copy.deepcopy(config.vision_config)
        decoder_config.hidden_size = config.reduce_dim
        decoder_config.num_attention_heads = config.decoder_num_attention_heads
        decoder_config.intermediate_size = config.decoder_intermediate_size
        decoder_config.hidden_act = "relu"
        self.layers = nn.ModuleList([CLIPSegDecoderLayer(decoder_config) for _ in range(depth)])

        # Head that turns the decoded tokens back into a dense (height, width) logit map; the exact
        # (transposed-)convolution stack depends on `config.use_complex_transposed_convolution` and
        # is not reproduced here.
        ...

    def forward(
        self,
        hidden_states: tuple[torch.Tensor],
        conditional_embeddings: torch.Tensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
    ):
        # Reduces the selected CLIP activations, applies the FiLM conditioning at
        # `conditional_layer`, runs the decoder layers and upsamples the result to a
        # (batch_size, height, width) logit map returned as a `CLIPSegDecoderOutput`.
        ...
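# Illustrative sketch (not part of the original module): the FiLM-style conditioning the decoder
# applies at `conditional_layer`, i.e. a feature-wise affine transform of the token features by the
# projected conditional (text or image prompt) embedding. All dimensions below are arbitrary.
def _film_conditioning_example() -> torch.Tensor:
    reduce_dim, projection_dim = 64, 512
    film_mul = nn.Linear(projection_dim, reduce_dim)
    film_add = nn.Linear(projection_dim, reduce_dim)
    tokens = torch.randn(2, 485, reduce_dim)               # (batch, seq_len, reduce_dim) decoder activations
    conditional_embeddings = torch.randn(2, projection_dim)
    scale = film_mul(conditional_embeddings).unsqueeze(1)   # broadcast over the token dimension
    shift = film_add(conditional_embeddings).unsqueeze(1)
    return scale * tokens + shift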
@auto_docstring(
    custom_intro="""
    CLIPSeg model with a Transformer-based decoder on top for zero-shot and one-shot image segmentation.
    """
)
class CLIPSegForImageSegmentation(CLIPSegPreTrainedModel):
    config: CLIPSegConfig

    def __init__(self, config: CLIPSegConfig):
        super().__init__(config)

        self.config = config

        self.clip = CLIPSegModel(config)
        self.extract_layers = config.extract_layers

        self.decoder = CLIPSegDecoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_conditional_embeddings(
        self,
        batch_size: Optional[int] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        conditional_pixel_values: Optional[torch.Tensor] = None,
    ):
        # Computes the conditional embedding under `torch.no_grad()` from either text prompts
        # (`input_ids`) or prompt images (`conditional_pixel_values`), checking that their number
        # matches the query batch size; raises a ValueError if neither is provided.
        ...

    def forward(
        self,
        input_ids: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        conditional_pixel_values: Optional[torch.FloatTensor] = None,
        conditional_embeddings: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = True,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, CLIPSegImageSegmentationOutput]:
        r"""
        conditional_pixel_values (`torch.FloatTensor`, *optional*):
            The pixel values of the conditional images.
        conditional_embeddings (`torch.FloatTensor` of shape `(batch_size, config.projection_dim)`, *optional*):
            The conditional embeddings for the query images. If provided, the model will use this instead of computing
            the embeddings from the conditional_pixel_values.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> from transformers import AutoProcessor, CLIPSegForImageSegmentation
        >>> from PIL import Image
        >>> import requests

        >>> processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
        >>> model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> texts = ["a cat", "a remote", "a blanket"]
        >>> inputs = processor(text=texts, images=[image] * len(texts), padding=True, return_tensors="pt")

        >>> outputs = model(**inputs)

        >>> logits = outputs.logits
        >>> print(logits.shape)
        torch.Size([3, 352, 352])
        ```"""
        # Runs the CLIP vision encoder under `torch.no_grad()`, extracts the activations at
        # `config.extract_layers`, feeds them together with the conditional embedding to the
        # decoder and, if `labels` are given, computes a `BCEWithLogitsLoss` on the per-pixel
        # logits, returning a `CLIPSegImageSegmentationOutput`.
        ...
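# Illustrative sketch (not part of the original module): turning the per-pixel logits returned by
# `CLIPSegForImageSegmentation` into probability maps and binary masks. The 0.5 threshold is an
# arbitrary choice, not something prescribed by the model.
def _masks_from_logits(logits: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    probs = torch.sigmoid(logits)  # (batch_size, height, width), matching the BCE-with-logits training objective
    return probs > threshold       # boolean mask, one per prompt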
__all__ = [
    "CLIPSegModel",
    "CLIPSegPreTrainedModel",
    "CLIPSegTextModel",
    "CLIPSegVisionModel",
    "CLIPSegForImageSegmentation",
]