"""PyTorch SegGpt model."""

import collections.abc
from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring, logging, torch_int
from .configuration_seggpt import SegGptConfig


logger = logging.get_logger(__name__)


@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`SegGptEncoderOutput`].
    """
)
class SegGptEncoderOutput(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, patch_height, patch_width, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the model.
    hidden_states (`tuple[torch.FloatTensor]`, `optional`, returned when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
        of shape `(batch_size, patch_height, patch_width, hidden_size)`.
    attentions (`tuple[torch.FloatTensor]`, `optional`, returned when `config.output_attentions=True`):
        Tuple of *torch.FloatTensor* (one for each layer) of shape
        `(batch_size, num_heads, seq_len, seq_len)`.
    intermediate_hidden_states (`tuple[torch.FloatTensor]`, *optional*, returned when `config.intermediate_hidden_state_indices` is set):
        Tuple of `torch.FloatTensor` of shape `(batch_size, patch_height, patch_width, hidden_size)`.
        Each element in the Tuple corresponds to the output of the layer specified in `config.intermediate_hidden_state_indices`.
        Additionally, each feature passes through a LayerNorm.
    """

    last_hidden_state: torch.FloatTensor
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    intermediate_hidden_states: Optional[tuple[torch.FloatTensor]] = None


@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`SegGptImageSegmentationOutput`].
    """
)
class SegGptImageSegmentationOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
        The loss value.
    pred_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
        The predicted masks.
    hidden_states (`tuple[torch.FloatTensor]`, `optional`, returned when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
        of shape `(batch_size, patch_height, patch_width, hidden_size)`.
    attentions (`tuple[torch.FloatTensor]`, `optional`, returned when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape
        `(batch_size, num_heads, seq_len, seq_len)`.
    Nloss
pred_masksr   r   )r   r   r   r   r!   r   r   r   r   r"   r   r   r   r   r   r   r   r    @   s
   
r    c                       s(   e Zd ZdZ fddZdd Z  ZS )SegGptPatchEmbeddingsz
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values):
        batch_size, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        if height != self.image_size[0] or width != self.image_size[1]:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
            )
        embeddings = self.projection(pixel_values).permute(0, 2, 3, 1)
        return embeddings


class SegGptEmbeddings(nn.Module):
    """
    Construct the embeddings from patch, position embeddings for input and prompt.
    Nr3   returnc                    s   t    ttddd|j| _ttddd|j| _ttddd|j| _	ttddd|j| _
ttddd|j| _t|| _|j|j d d }ttd||j| _t|j| _d S )Nr   r8   )r%   r&   r   	Parameterr   zerosr*   
mask_tokensegment_token_inputsegment_token_prompttype_token_semantictype_token_instancer#   patch_embeddingsZpretrain_image_sizer(   Zrandnposition_embeddingsZDropoutZhidden_dropout_probZdropout)r2   r3   Znum_positionsr4   r   r   r&      s    

zSegGptEmbeddings.__init__)r>   r?   rE   c                 C   s   | j d d dd f }|jd }t|d }tj sF||ksF||krtj|d||d	dddd||fddd	}|	ddddS |d||dS d S )
Nr         ?r   r   r8   ZbicubicF)sizemodeZalign_corners)
rN   r9   r   r   Zjit
is_tracingFinterpolatereshaper;   )r2   r>   r?   Zpatch_pos_embedr/   Zpretrain_patch_sizer   r   r   interpolate_pos_encoding   s    
z)SegGptEmbeddings.interpolate_pos_encoding)r<   prompt_pixel_valuesbool_masked_posembedding_typerE   c                 C   s   |  |}|  |}|j\}}}	}
| j|||	d}|d|d||	d}|d|  ||  }|d urp|nd}| ||	}|| j }|| j	 }|| }|| }|dkr| j
}n|dkr| j}ntd| || }|| }tj||fdd}|S )NrP   r   instanceZsemanticzBEmbedding type should be either 'semantic' or 'instance', but got r   dim)rM   r9   rH   expand	unsqueezeZtype_asrV   rW   rI   rJ   rK   rL   r:   r   cat)r2   r<   rX   rY   rZ   Zinput_embeddingsZprompt_embeddingsr=   patch_heightpatch_width_rH   wZ	pos_embedZtype_embeddingr@   r   r   r   rA      s*    



zSegGptEmbeddings.forward)NN)r   r   r   r   r   r&   intr   TensorrW   r   
BoolTensorstrrA   rB   r   r   r4   r   rC   ~   s     rC   c                       s   e Zd ZdZ fddZeeejejdddZejejejeje	eef e	eef ejddd	Z
dejejdddZ  ZS )SegGptAttentionz=Multi-head Attention block with relative position embeddings.c                    s  t    |j|j }}t|tjjr*|n||f}t|tjjrD|n||f}|d |j |d |j f}|j|j	 }|j	| _	|d | _
tj|j|jd |jd| _t|j|j| _|j| _| jr|d u rtdttd|d  d || _ttd|d  d || _d S )Nr   r   g      r   biaszBInput size must be provided if using relative positional encoding.r8   )r%   r&   r'   r(   r+   r,   r-   r.   r*   num_attention_headsscaler   LinearZqkv_biasqkvproj use_relative_position_embeddingsr:   rF   r   rG   	rel_pos_h	rel_pos_w)r2   r3   r'   r(   Z
input_sizeZhead_dimr4   r   r   r&      s     

 zSegGptAttention.__init__)q_sizek_sizerel_posrE   c           	      C   s   t dt|| d }tj|d|jd dddd|dd}|d|dd}t|dddf t|| d }t|dddf t|| d }|| |d t|| d  }||	  S )	a  
        Get relative positional embeddings according to the relative positions of
            query and key sizes.

        Args:
            q_size (int):
                size of the query.
            k_size (int):
                size of key k.
            rel_pos (`torch.Tensor`):
                relative position embeddings (L, channel).

        Returns:
            Extracted positional embeddings according to relative positions.
        """
        max_rel_dist = int(2 * max(q_size, k_size) - 1)
        # Interpolate rel pos.
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)

        # Scale the coords with short length if shapes for q and k are different.
        q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
        k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
        relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

        return rel_pos_resized[relative_coords.long()]

    def add_decomposed_rel_pos(
        self,
        attn: torch.Tensor,
        query: torch.Tensor,
        rel_pos_h: torch.Tensor,
        rel_pos_w: torch.Tensor,
        q_size: tuple[int, int],
        k_size: tuple[int, int],
    ) -> torch.Tensor:
        """
        Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
        https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

        Args:
            attn (`torch.Tensor`):
                attention map.
            query (`torch.Tensor`):
                query q in the attention layer with shape (batch_size, query_height * query_width, channel).
            rel_pos_h (`torch.Tensor`):
                relative position embeddings (Lh, channel) for height axis.
            rel_pos_w (`torch.Tensor`):
                relative position embeddings (Lw, channel) for width axis.
            q_size (tuple):
                spatial sequence size of query q with (query_height, query_width).
            k_size (tuple):
                spatial sequence size of key k with (key_height, key_width).

        Returns:
            attn (`torch.Tensor`):
                attention map with added relative positional embeddings.
        zbhwc,hkc->bhwkzbhwc,wkc->bhwkN)rz   r9   rV   r   Zeinsum)r2   r{   r|   rr   rs   rt   ru   Zquery_heightZquery_widthZ
key_heightZ	key_widthZrelative_position_heightZrelative_position_widthr=   rc   r]   Zreshaped_queryZrel_hZrel_wr   r   r   add_decomposed_rel_pos  s    Hz&SegGptAttention.add_decomposed_rel_posFr   rE   c              	   C   s:  |j \}}}}| |||| d| jdddddd}|d|| j || dd\}}	}
|| j |	dd }| jr| 	||| j
| j||f||f}tjjj|tjdd|j}|r||| j|| d}||| j || d}nd }||
 || j||d}|ddddd|||d}| |}||fS )	Nr   rP   r8   r   r      )dtyper]   )r9   ro   rV   rl   r;   Zunbindrm   Z	transposerq   r}   rr   rs   r   r   r   Zsoftmaxfloat32tor   viewrp   )r2   r   output_attentionsr=   r>   r?   rc   ro   r|   keyvalueZattn_weightsZattn_weights_reshapedZattn_outputr   r   r   rA   ;  s,    

&
zSegGptAttention.forward)F)r   r   r   r   r&   re   r   rf   rz   r   r}   rA   rB   r   r   r4   r   ri      s   "

-ri   c                       s0   e Zd Z fddZejejdddZ  ZS )	SegGptMlpc                    s>   t    t|j|j| _t|j|j| _t|j	 | _
d S N)r%   r&   r   rn   r*   Zmlp_dimlin1lin2r   
hidden_actactr2   r3   r4   r   r   r&   c  s    
zSegGptMlp.__init__r~   c                 C   s"   |  |}| |}| |}|S r   )r   r   r   r2   r   r   r   r   rA   i  s    


zSegGptMlp.forward)r   r   r   r&   r   rf   rA   rB   r   r   r4   r   r   b  s   r           F)input	drop_probtrainingrE   c                 C   sd   |dks|s| S d| }| j d fd| jd   }|tj|| j| jd }|  | || }|S )aF  
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize
    output = input.div(keep_prob) * random_tensor
    return output


class SegGptDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f"p={self.drop_prob}"
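

# A minimal REPL sketch of the stochastic-depth behaviour above (illustrative only, not
# part of the original module): in training mode each sample's residual branch is either
# zeroed or rescaled by 1 / keep_prob, so the expectation is unchanged; in eval mode
# `drop_path` is the identity.
#
#     >>> dp = SegGptDropPath(drop_prob=0.5)
#     >>> _ = dp.train()
#     >>> dp(torch.ones(4, 2, 3))  # each of the 4 samples comes out all 0.0 or all 2.0
#     >>> _ = dp.eval()
#     >>> torch.equal(dp(torch.ones(4, 2, 3)), torch.ones(4, 2, 3))
#     True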
eejejf eej f dddZ  ZS )
SegGptLayerN)r3   drop_path_raterE   c                    sd   t    t|| _t|| _|dkr.t|nt | _	tj
|j|jd| _tj
|j|jd| _d S )Nr   eps)r%   r&   ri   	attentionr   mlpr   r   ZIdentityr   	LayerNormr*   layer_norm_epslayernorm_beforelayernorm_after)r2   r3   r   r4   r   r   r&     s    


zSegGptLayer.__init__F)r   ensemble_condfeature_ensembler   rE   c                 C   s  | j | ||d}|d }|dd  }|r|jd d |kr|j|jd d dd\}}	|dkr|jd d }
|	d|
d}	|	jddd|	}	|	j|j }	n|	jddd|	}	tj||	gdd}| 	|| }|}| 
|}| |}|| 	| }|f| }|S )	N)r   r   r   r8   r\   rP   T)r]   Zkeepdim)r   r   r9   splitrV   meanZ	expand_asr   r`   r   r   r   )r2   r   r   r   r   Zself_attention_outputsZattention_outputoutputspromptinputsZnum_promptsZresidualr   r   r   rA     s,    


zSegGptLayer.forward)FF)r   r   r   r   r   r&   r   rf   re   boolr   r   rA   rB   r   r   r4   r   r     s     r   c                	       sH   e Zd Zedd fddZd
ejeeeeee	e
f ddd	Z  ZS )SegGptEncoderNrD   c                    sp   t     | _dd tjd j jddD t fddt	 jD | _
tj j jd| _d| _d S )	Nc                 S   s   g | ]}|  qS r   )item).0xr   r   r   
<listcomp>      z*SegGptEncoder.__init__.<locals>.<listcomp>r   cpu)r   c                    s   g | ]}t  | qS r   )r   )r   ir3   Zdprr   r   r     r   r   F)r%   r&   r3   r   Zlinspacer   Znum_hidden_layersr   Z
ModuleListrangelayersr   r*   r   	layernormZgradient_checkpointingr   r4   r   r   r&     s    
 "zSegGptEncoder.__init__FT)r   r   r   output_hidden_statesreturn_dictrE   c                 C   s  |rdnd }|rdnd }g }t | jD ]\}	}
|r<||f }| jj|	krLdnd}|
||||}|d }|	| jjkr|d |jd d  ||jd d d   d }|	| jjv r|| | |r&||d f }q&|r||f }|stdd ||||fD S t	||||dS )	Nr   r8   r   r   rO   c                 s   s   | ]}|d ur|V  qd S r   r   )r   vr   r   r   	<genexpr>  s   z(SegGptEncoder.forward.<locals>.<genexpr>)r   r   r   r   )
	enumerater   r3   Zmerge_indexr9   !intermediate_hidden_state_indicesappendr   r   r   )r2   r   r   r   r   r   Zall_hidden_statesZall_self_attentionsr   r   Zlayer_moduler   Zlayer_outputsr   r   r   rA     s:    
*

zSegGptEncoder.forward)FFFT)r   r   r   r   r&   r   rf   r   r   r   r   rA   rB   r   r   r4   r   r     s       
r   c                       s@   e Zd ZdZddd fdd
Zejejd fdd	Z  ZS )
SegGptLayerNormaA  LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
    width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
    gư>channels_last)r   data_formatc                   s8   t  j|fd|i| |dvr.td| || _d S )Nr   )r   channels_firstzUnsupported data format: )r%   r&   NotImplementedErrorr   )r2   normalized_shaper   r   kwargsr4   r   r   r&     s    zSegGptLayerNorm.__init__)featuresrE   c                    sH   | j dkr8|dddd}t |}|dddd}nt |}|S )z
        Args:
            features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
        r   r   r8   r   r   )r   r;   r%   rA   )r2   r   r4   r   r   rA     s    
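

# Quick comment sketch of the two `SegGptLayerNorm` layouts (illustrative only): with
# `data_format="channels_first"` a (batch, channels, height, width) input is permuted to
# channels-last, normalized over the channel axis, and permuted back.
#
#     >>> ln = SegGptLayerNorm(64, data_format="channels_first")
#     >>> ln(torch.randn(1, 64, 8, 8)).shape
#     torch.Size([1, 64, 8, 8])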
zSegGptLayerNorm.forward)	r   r   r   r   r&   r   rf   rA   rB   r   r   r4   r   r      s   r   c                       s,   e Zd Z fddZejdddZ  ZS )SegGptDecoderHeadc                    s\   t    tj|j|jddd| _t|j|jdd| _t	|j
 | _tj|jdddd| _d S )Nr   r   )r$   paddingr   )r   r   r   T)r$   rk   )r%   r&   r   r0   decoder_hidden_sizeconvr   r   r   r   r   act_fctheadr   r4   r   r   r&     s    

zSegGptDecoderHead.__init__r   c                 C   s,   |  |}| |}| |}| |}|S r   )r   r   r   r   r   r   r   r   rA   )  s
    



zSegGptDecoderHead.forward)r   r   r   r&   r   r   rA   rB   r   r   r4   r   r     s   r   c                       s@   e Zd Z fddZejejdddZejdddZ  ZS )	SegGptDecoderc                    sX   t    tj|jt|j |jd |j dd| _	t
|| _|j| _|j| _|| _d S )Nr8   Trj   )r%   r&   r   rn   r*   lenr   r(   r   decoder_embedr   decoder_predr3   r   r4   r   r   r&   3  s    

zSegGptDecoder.__init__r~   c                 C   s`   |j \}}}}||||| j| j| j}|dddddd}|j|d|| j || j fd}|S )	Nr      r   r   r8   r   rP   r9   )r9   rV   r(   r   r;   )r2   r   r=   ra   rb   rc   r   r   r   _reshape_hidden_states?  s    z$SegGptDecoder._reshape_hidden_statesr   c                 C   s"   |  |}| |}| |}|S r   )r   r   r   r   r   r   r   rA   K  s    


zSegGptDecoder.forward)	r   r   r   r&   r   r   r   rA   rB   r   r   r4   r   r   2  s   r   c                   @   s<   e Zd ZU eed< dZdZdZddgZe	j
ddd	d
ZdS )SegGptPreTrainedModelr3   modelr<   TrC   r   N)modulerE   c                 C   s  | j j}t|tjtjfr`tjj|jj	
tjd|d
|jj|j_	|jdur\|jj	  n.t|tjtfr|jj	  |jj	d n t|trtjj|jj	
tjd|d
|jj|j_	tjj|jj	
tjd|d
|jj|j_	nt|trtjj|jj	
tjd|d
|jj|j_	tjjj|j|d tjjj|j|d tjjj|j|d tjjj|j|d tjjj|j|d dS )zInitialize the weightsr   )r   stdNrw   )r   )r3   Zinitializer_ranger+   r   rn   r0   initZtrunc_normal_weightdatar   r   r   r   rk   Zzero_r   r   Zfill_ri   rr   rs   rC   rN   Znormal_rH   rI   rJ   rK   rL   )r2   r   r   r   r   r   _init_weights[  sL    


z#SegGptPreTrainedModel._init_weights)r   r   r   r   r   Zbase_model_prefixZmain_input_nameZsupports_gradient_checkpointingZ_no_split_modulesr   Moduler   r   r   r   r   r   S  s   


@auto_docstring
class SegGptModel(SegGptPreTrainedModel):
    def __init__(self, config: SegGptConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = SegGptEmbeddings(config)
        self.encoder = SegGptEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> SegGptPatchEmbeddings:
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune: dict[int, list[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.Tensor,
        prompt_pixel_values: torch.Tensor,
        prompt_masks: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        feature_ensemble: Optional[bool] = None,
        embedding_type: Optional[str] = None,
        labels: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, SegGptEncoderOutput]:
        r"""
        prompt_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Prompt pixel values. Prompt pixel values can be obtained using [`AutoImageProcessor`]. See
            [`SegGptImageProcessor.__call__`] for details.
        prompt_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Prompt mask. Prompt mask can be obtained using [`AutoImageProcessor`]. See [`SegGptImageProcessor.__call__`] for
            details.
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        feature_ensemble (`bool`, *optional*):
            Boolean indicating whether to use feature ensemble or not. If `True`, the model will use feature ensemble
            if we have at least two prompts. If `False`, the model will not use feature ensemble. This argument should
            be considered when doing few-shot inference on an input image i.e. more than one prompt for the same image.
        embedding_type (`str`, *optional*):
            Embedding type. Indicates whether the prompt is a semantic or instance embedding. Can be either
            instance or semantic.
        labels (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, `optional`):
            Ground truth mask for input images.

        Examples:

        ```python
        >>> from transformers import SegGptImageProcessor, SegGptModel
        >>> from PIL import Image
        >>> import requests

        >>> image_input_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_2.jpg"
        >>> image_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1.jpg"
        >>> mask_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1_target.png"

        >>> image_input = Image.open(requests.get(image_input_url, stream=True).raw)
        >>> image_prompt = Image.open(requests.get(image_prompt_url, stream=True).raw)
        >>> mask_prompt = Image.open(requests.get(mask_prompt_url, stream=True).raw).convert("L")

        >>> checkpoint = "BAAI/seggpt-vit-large"
        >>> model = SegGptModel.from_pretrained(checkpoint)
        >>> image_processor = SegGptImageProcessor.from_pretrained(checkpoint)

        >>> inputs = image_processor(images=image_input, prompt_images=image_prompt, prompt_masks=mask_prompt, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> list(outputs.last_hidden_state.shape)
        [1, 56, 28, 1024]
        ```
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        feature_ensemble = feature_ensemble if feature_ensemble is not None else False

        expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
        pixel_values = pixel_values.to(expected_dtype)
        prompt_pixel_values = prompt_pixel_values.to(expected_dtype)

        # Prepare inputs: the prompt image is stacked on top of the input image along the height
        # axis, and the prompt mask on top of the labels (or of itself at inference time)
        pixel_values = torch.cat((prompt_pixel_values, pixel_values), dim=2)
        prompt_pixel_values = (
            torch.cat((prompt_masks, prompt_masks), dim=2)
            if labels is None
            else torch.cat((prompt_masks, labels), dim=2)
        )

        if bool_masked_pos is None and labels is not None:
            logger.warning_once(
                "Labels were provided, but bool_masked_pos were not. It will be set to default value. If you're "
                "training the model, make sure to provide a bool_masked_pos."
            )

        # We concat on the height axis so SegGPT can handle the pair as a single image; we therefore
        # mask the bottom half of the mask prompt, which corresponds to the prediction region
        if bool_masked_pos is None:
            num_patches = self.embeddings.patch_embeddings.num_patches
            bool_masked_pos_zeros = torch.zeros(num_patches // 2, dtype=torch.bool, device=pixel_values.device)
            bool_masked_pos_ones = torch.ones(
                num_patches - num_patches // 2, dtype=torch.bool, device=pixel_values.device
            )
            bool_masked_pos = torch.cat([bool_masked_pos_zeros, bool_masked_pos_ones])
            bool_masked_pos = bool_masked_pos.unsqueeze(0)

        embedding_output = self.embeddings(
            pixel_values, prompt_pixel_values, embedding_type=embedding_type, bool_masked_pos=bool_masked_pos
        )

        encoder_outputs = self.encoder(
            embedding_output,
            feature_ensemble=feature_ensemble,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return encoder_outputs


def patchify(tensor: torch.Tensor, patch_size: int) -> torch.Tensor:
    batch_size, num_channels, height, width = tensor.shape
    patch_height = height // patch_size
    patch_width = width // patch_size

    tensor = tensor.reshape(shape=(batch_size, num_channels, patch_height, patch_size, patch_width, patch_size))
    tensor = tensor.permute(0, 2, 4, 3, 5, 1)
    tensor = tensor.reshape(shape=(batch_size, patch_height * patch_width, patch_size**2 * 3))

    return tensor


def unpatchify(tensor: torch.Tensor, patch_height: int, patch_width: int) -> torch.Tensor:
    batch_size = tensor.shape[0]
    patch_size = int((tensor.shape[-1] / 3) ** 0.5)
    if patch_height * patch_width != tensor.shape[1]:
        raise ValueError(
            f"Number of patches {tensor.shape[1]} does not match patch height ({patch_height}) and width ({patch_width})."
        )

    tensor = tensor.reshape(shape=(batch_size, patch_height, patch_width, patch_size, patch_size, 3))
    tensor = tensor.permute(0, 5, 1, 3, 2, 4)
    tensor = tensor.reshape(shape=(batch_size, 3, patch_height * patch_size, patch_width * patch_size))

    return tensor
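

# Hedged round-trip sketch for the two helpers above (not part of the original file):
# `patchify` flattens non-overlapping RGB patches into vectors and `unpatchify` inverts
# it exactly, since both are pure reshape/permute operations.
#
#     >>> x = torch.randn(2, 3, 64, 48)
#     >>> patches = patchify(x, patch_size=16)
#     >>> patches.shape  # (batch_size, num_patches, patch_size**2 * 3)
#     torch.Size([2, 12, 768])
#     >>> torch.equal(unpatchify(patches, 64 // 16, 48 // 16), x)
#     True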


class SegGptLoss(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.beta = config.beta
        self.patch_size = config.patch_size

    def forward(
        self,
        prompt_masks: torch.FloatTensor,
        pred_masks: torch.FloatTensor,
        labels: torch.FloatTensor,
        bool_masked_pos: torch.BoolTensor,
    ):
        """Computes the L1 loss between the predicted masks and the ground truth masks.

        Args:
            prompt_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                Pixel values from mask prompt.

            pred_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, 2*height, width)`):
                Predicted masks.

            labels (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                Ground truth mask for input images.

            bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
                Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Returns:
            `torch.FloatTensor`: The mean L1 loss between the predicted masks and the ground truth masks.
        """
        ground_truth = torch.cat((prompt_masks, labels), dim=2)

        mask = bool_masked_pos[:, :, None].repeat(1, 1, self.patch_size**2 * 3)
        mask = unpatchify(mask, ground_truth.shape[2] // self.patch_size, ground_truth.shape[3] // self.patch_size)

        loss = F.smooth_l1_loss(pred_masks, ground_truth, reduction="none", beta=self.beta)
        loss = (loss * mask).sum() / mask.sum()  # mean loss on masked patches

        return loss
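

# Hedged shape sketch for `SegGptLoss` (illustrative only; 896x448 inputs and patch size 16
# are the `SegGptConfig` defaults): prompt masks and labels are stacked on the height axis,
# and the per-patch `bool_masked_pos` is expanded to pixel level so only the masked
# (to-be-predicted) half contributes to the loss.
#
#     >>> from transformers import SegGptConfig
#     >>> config = SegGptConfig()
#     >>> loss_fn = SegGptLoss(config)
#     >>> prompt_masks = torch.rand(1, 3, 448, 448)
#     >>> labels = torch.rand(1, 3, 448, 448)
#     >>> pred_masks = torch.rand(1, 3, 896, 448)  # prompt half + prediction half
#     >>> num_patches = (896 // 16) * (448 // 16)
#     >>> bool_masked_pos = torch.cat(
#     ...     [torch.zeros(num_patches // 2, dtype=torch.bool), torch.ones(num_patches // 2, dtype=torch.bool)]
#     ... ).unsqueeze(0)
#     >>> loss_fn(prompt_masks, pred_masks, labels, bool_masked_pos).shape
#     torch.Size([])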


@auto_docstring(
    custom_intro="""
    SegGpt model with a decoder on top for one-shot image segmentation.
    """
)
class SegGptForImageSegmentation(SegGptPreTrainedModel):
    def __init__(self, config: SegGptConfig):
        super().__init__(config)
        self.config = config

        self.model = SegGptModel(config)
        self.decoder = SegGptDecoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.Tensor,
        prompt_pixel_values: torch.Tensor,
        prompt_masks: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        feature_ensemble: Optional[bool] = None,
        embedding_type: Optional[str] = None,
        labels: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, SegGptImageSegmentationOutput]:
        r"""
        prompt_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Prompt pixel values. Prompt pixel values can be obtained using [`AutoImageProcessor`]. See
            [`SegGptImageProcessor.__call__`] for details.
        prompt_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Prompt mask. Prompt mask can be obtained using [`AutoImageProcessor`]. See [`SegGptImageProcessor.__call__`] for
            details.
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        feature_ensemble (`bool`, *optional*):
            Boolean indicating whether to use feature ensemble or not. If `True`, the model will use feature ensemble
            if we have at least two prompts. If `False`, the model will not use feature ensemble. This argument should
            be considered when doing few-shot inference on an input image i.e. more than one prompt for the same image.
        embedding_type (`str`, *optional*):
            Embedding type. Indicates whether the prompt is a semantic or instance embedding. Can be either
            instance or semantic.
        labels (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, `optional`):
            Ground truth mask for input images.

        Examples:

        ```python
        >>> from transformers import SegGptImageProcessor, SegGptForImageSegmentation
        >>> from PIL import Image
        >>> import requests

        >>> image_input_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_2.jpg"
        >>> image_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1.jpg"
        >>> mask_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1_target.png"

        >>> image_input = Image.open(requests.get(image_input_url, stream=True).raw)
        >>> image_prompt = Image.open(requests.get(image_prompt_url, stream=True).raw)
        >>> mask_prompt = Image.open(requests.get(mask_prompt_url, stream=True).raw).convert("L")

        >>> checkpoint = "BAAI/seggpt-vit-large"
        >>> model = SegGptForImageSegmentation.from_pretrained(checkpoint)
        >>> image_processor = SegGptImageProcessor.from_pretrained(checkpoint)

        >>> inputs = image_processor(images=image_input, prompt_images=image_prompt, prompt_masks=mask_prompt, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> result = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[(image_input.height, image_input.width)])[0]
        >>> print(list(result.shape))
        [170, 297]
        ```
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if bool_masked_pos is None:
            num_patches = self.model.embeddings.patch_embeddings.num_patches
            bool_masked_pos_zeros = torch.zeros(num_patches // 2, dtype=torch.bool, device=pixel_values.device)
            bool_masked_pos_ones = torch.ones(
                num_patches - num_patches // 2, dtype=torch.bool, device=pixel_values.device
            )
            bool_masked_pos = torch.cat([bool_masked_pos_zeros, bool_masked_pos_ones])
            bool_masked_pos = bool_masked_pos.unsqueeze(0)

        outputs = self.model(
            pixel_values=pixel_values,
            prompt_pixel_values=prompt_pixel_values,
            prompt_masks=prompt_masks,
            bool_masked_pos=bool_masked_pos,
            feature_ensemble=feature_ensemble,
            embedding_type=embedding_type,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        intermediate_hidden_states = outputs.intermediate_hidden_states if return_dict else outputs[-1]
        intermediate_hidden_states = torch.cat(intermediate_hidden_states, dim=-1)
        pred_masks = self.decoder(intermediate_hidden_states)

        loss = None
        if labels is not None:
            loss_fn = SegGptLoss(self.config)
            loss = loss_fn(prompt_masks, pred_masks, labels, bool_masked_pos)

        if not return_dict:
            output = (pred_masks,)
            if output_hidden_states:
                output = output + (outputs[1],)

            if output_attentions:
                idx = 2 if output_hidden_states else 1
                output = output + (outputs[idx],)

            if loss is not None:
                output = (loss,) + output
            return output

        return SegGptImageSegmentationOutput(
            loss=loss,
            pred_masks=pred_masks,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = ["SegGptModel", "SegGptPreTrainedModel", "SegGptForImageSegmentation"]