a
    h                     @   sj  d Z ddlZddlZddlmZ ddlmZmZ ddl	Z	ddl
Z	ddl	mZ ddlmZmZmZ ddlmZ dd	lmZ dd
lmZ ddlmZ ddlmZmZmZ ddlmZ ddlmZ e e!Z"eeddG dd deZ#eeddG dd deZ$eeddG dd deZ%eeddG dd deZ&G dd dej'Z(G dd  d ej'Z)d@e	j*e+e,e	j*d#d$d%Z-G d&d' d'ej'Z.G d(d) d)ej'Z/G d*d+ d+ej'Z0G d,d- d-ej'Z1G d.d/ d/eZ2G d0d1 d1ej'Z3eG d2d3 d3eZ4eG d4d5 d5e4Z5ed6dG d7d8 d8e4Z6ed9dG d:d; d;e4Z7ed<dG d=d> d>e4eZ8g d?Z9dS )AzPyTorch FocalNet model.    N)	dataclass)OptionalUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)GradientCheckpointingLayer)BackboneOutput)PreTrainedModel)ModelOutputauto_docstringlogging)BackboneMixin   )FocalNetConfigzC
    FocalNet encoder's outputs, with potential hidden states.
    )Zcustom_introc                   @   sP   e Zd ZU dZdZeej ed< dZ	ee
ej  ed< dZee
ej  ed< dS )FocalNetEncoderOutputa  
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    Nlast_hidden_statehidden_statesreshaped_hidden_states)__name__
__module____qualname____doc__r   r   torchFloatTensor__annotations__r   tupler    r    r    j/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/focalnet/modeling_focalnet.pyr   '   s   
	r   zZ
    FocalNet model's outputs that also contains a pooling of the last hidden states.
    c                   @   sb   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeej  ed< dZeeej  ed< dS )FocalNetModelOutputa  
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
        Average pooling of the last layer hidden-state.
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    Nr   pooler_outputr   r   )r   r   r   r   r   r   r   r   r   r#   r   r   r   r    r    r    r!   r"   <   s
   
r"   z.
    FocalNet masked image model outputs.
    c                   @   sb   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeej  ed< dZeeej  ed< dS )!FocalNetMaskedImageModelingOutputa  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
        Masked image modeling (MLM) loss.
    reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
        Reconstructed pixel values.
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    Nlossreconstructionr   r   )r   r   r   r   r%   r   r   r   r   r&   r   r   r   r    r    r    r!   r$   T   s
   
r$   z4
    FocalNet outputs for image classification.
    c                   @   sb   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeej  ed< dZeeej  ed< dS )FocalNetImageClassifierOutputa7  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification (or regression if config.num_labels==1) loss.
    logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Classification (or regression if config.num_labels==1) scores (before SoftMax).
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    Nr%   logitsr   r   )r   r   r   r   r%   r   r   r   r   r(   r   r   r   r    r    r    r!   r'   n   s
   
r'   c                       sH   e Zd ZdZd	 fdd	Zd
eej eej e	ej
 dddZ  ZS )FocalNetEmbeddingszX
    Construct the patch embeddings and layernorm. Optionally, also the mask token.
    Fc              	      s|   t    t||j|j|j|j|jdd| _| jj	| _
|rPttdd|jnd | _tj|j|jd| _t|j| _d S )NT)config
image_size
patch_sizenum_channels	embed_dimuse_conv_embedis_stemr   eps)super__init__FocalNetPatchEmbeddingsr+   r,   r-   r.   r/   patch_embeddings	grid_size
patch_gridr   	Parameterr   Zzeros
mask_token	LayerNormlayer_norm_epsnormDropouthidden_dropout_probdropout)selfr*   use_mask_token	__class__r    r!   r4      s    
	
 zFocalNetEmbeddings.__init__N)pixel_valuesbool_masked_posreturnc           
      C   st   |  |\}}| |}| \}}}|d urb| j||d}|d|}	|d|	  ||	  }| |}||fS )N      ?)r6   r=   sizer:   expand	unsqueezeZtype_asr@   )
rA   rE   rF   
embeddingsoutput_dimensions
batch_sizeZseq_len_Zmask_tokensmaskr    r    r!   forward   s    

zFocalNetEmbeddings.forward)F)N)r   r   r   r   r4   r   r   r   
BoolTensorr   TensorrR   __classcell__r    r    rC   r!   r)      s    r)   c                       sJ   e Zd Zd	 fdd	Zdd Zeej eej	ee
 f dddZ  ZS )
r5   Fc	                    s  t    t|tjjr|n||f}t|tjjr6|n||f}|d |d  |d |d   }	|| _|| _|| _|	| _	|d |d  |d |d  f| _
|r|rd}
d}d}nd}
d}d}tj|||
||d| _ntj||||d| _|rtj||jd	| _nd | _d S )
Nr   r            r	   )kernel_sizestridepadding)rY   rZ   r1   )r3   r4   
isinstancecollectionsabcIterabler+   r,   r-   num_patchesr7   r   Conv2d
projectionr;   r<   r=   )rA   r*   r+   r,   r-   r.   add_normr/   r0   r`   rY   r[   rZ   rC   r    r!   r4      s0    
 "

z FocalNetPatchEmbeddings.__init__c                 C   s   || j d  dkr<d| j d || j d   f}tj||}|| j d  dkr|ddd| j d || j d   f}tj||}|S )Nr   r   )r,   r   
functionalpad)rA   rE   heightwidthZ
pad_valuesr    r    r!   	maybe_pad   s     z!FocalNetPatchEmbeddings.maybe_pad)rE   rG   c                 C   s|   |j \}}}}|| jkr td| |||}| |}|j \}}}}||f}|ddd}| jd urt| |}||fS )NzeMake sure that the channel dimension of the pixel values match with the one set in the configuration.rW   r   )shaper-   
ValueErrorrh   rb   flatten	transposer=   )rA   rE   rP   r-   rf   rg   rM   rN   r    r    r!   rR      s    



zFocalNetPatchEmbeddings.forward)FFF)r   r   r   r4   rh   r   r   r   r   rT   intrR   rU   r    r    rC   r!   r5      s      *	r5           F)input	drop_probtrainingrG   c                 C   sd   |dks|s| S d| }| j d fd| jd   }|tj|| j| jd }|  | || }|S )aF  
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    rn   r   r   )r   )dtypedevice)ri   ndimr   Zrandrr   rs   Zfloor_div)ro   rp   rq   Z	keep_probri   Zrandom_tensoroutputr    r    r!   	drop_path   s    
rw   c                       sP   e Zd ZdZdee dd fddZejejdddZ	e
d	d
dZ  ZS )FocalNetDropPathzXDrop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).N)rp   rG   c                    s   t    || _d S N)r3   r4   rp   )rA   rp   rC   r    r!   r4     s    
zFocalNetDropPath.__init__)r   rG   c                 C   s   t || j| jS ry   )rw   rp   rq   )rA   r   r    r    r!   rR     s    zFocalNetDropPath.forward)rG   c                 C   s   d| j  S )Nzp=)rp   rA   r    r    r!   
extra_repr  s    zFocalNetDropPath.extra_repr)N)r   r   r   r   r   floatr4   r   rT   rR   strr{   rU   r    r    rC   r!   rx     s   rx   c                       s&   e Zd Zd fdd	Zdd Z  ZS )	FocalNetModulationrW   Trn   c           	         s"  t    || _|j| | _|j| | _|| _|j| _|j	| _	t
j|d| | jd  |d| _t
j||dd|d| _t
 | _t
||| _t
|| _t
 | _g | _t| jD ]P}| j| | j }| jt
t
j|||d||d ddt
  | j| q| jrt
j||jd| _d S )NrW   r   )bias)rY   rZ   r   F)rY   rZ   groupsr[   r   r1   )r3   r4   dimZfocal_windowsZfocal_windowZfocal_levelsfocal_levelfocal_factor use_post_layernorm_in_modulationnormalize_modulatorr   Linearprojection_inra   projection_contextZGELU
activationprojection_outr>   projection_dropout
ModuleListfocal_layersZkernel_sizesrangeappend
Sequentialr;   r<   	layernorm)	rA   r*   indexr   r   r   r   krY   rC   r    r!   r4     s6    
 

zFocalNetModulation.__init__c                 C   s&  |j d }| |dddd }t|||| jd fd\}}}d}t| jD ]2}| j| |}|||dd||d f   }qR| 	|j
dddj
ddd}	||	|dd| jdf   }| jr|| jd  }| |}
||
 }|dddd }| jr| |}| |}| |}|S )	z
        Args:
            hidden_state:
                Input features with shape of (batch_size, height, width, num_channels)
        rH   r   r	   r   rW   NT)Zkeepdim)ri   r   permute
contiguousr   splitr   r   r   r   meanr   r   r   r   r   r   )rA   hidden_stater-   xqctxZgatesZctx_alllevelZ
ctx_globalZ	modulatorZx_outr    r    r!   rR   =  s&    
 "



zFocalNetModulation.forward)rW   Trn   r   r   r   r4   rR   rU   r    r    rC   r!   r~     s   !r~   c                       s&   e Zd Zd fdd	Zdd Z  ZS )FocalNetMlpNrn   c                    sR   t    |p|}|p|}t||| _t|j | _t||| _t	|| _
d S ry   )r3   r4   r   r   fc1r
   Z
hidden_actr   fc2r>   drop)rA   r*   in_featureshidden_featuresout_featuresr   rC   r    r!   r4   c  s    
zFocalNetMlp.__init__c                 C   s6   |  |}| |}| |}| |}| |}|S ry   )r   r   r   r   )rA   r   r    r    r!   rR   l  s    




zFocalNetMlp.forward)NNrn   r   r    r    rC   r!   r   b  s   	r   c                       s*   e Zd ZdZd fdd	Zdd Z  ZS )FocalNetLayera  Focal Modulation Network layer (block).

    Args:
        config (`FocalNetConfig`):
            Model config.
        index (`int`):
            Layer index.
        dim (`int`):
            Number of input channels.
        input_resolution (`tuple[int]`):
            Input resolution.
        drop_path (`float`, *optional*, defaults to 0.0):
            Stochastic depth rate.
    rn   c                    s   t    || _|| _|| _|j| _|j| _tj	||j
d| _t|||| jd| _|dkrbt|nt | _tj	||j
d| _t||j }t|||| jd| _d| _d| _|jrtj|jt| dd| _tj|jt| dd| _d S )Nr1   )r*   r   r   r   rn   )r*   r   r   r   rI   T)Zrequires_grad)r3   r4   r*   r   input_resolutionr?   r   use_post_layernormr   r;   r<   norm1r~   
modulationrx   Identityrw   norm2rm   Z	mlp_ratior   mlpgamma_1gamma_2use_layerscaler9   layerscale_valuer   Zones)rA   r*   r   r   r   rw   Zmlp_hidden_dimrC   r    r!   r4     s,    
zFocalNetLayer.__init__c           	   	   C   s   |\}}|j \}}}|}| jr"|n| |}|||||}| |||| |}| js^|n| |}|| | j|  }|| | j| jr| | 	|n| 	| |  }|S ry   )
ri   r   r   viewr   rw   r   r   r   r   )	rA   r   input_dimensionsrf   rg   rO   rP   r-   Zshortcutr    r    r!   rR     s    $zFocalNetLayer.forward)rn   )r   r   r   r   r4   rR   rU   r    r    rC   r!   r   u  s    r   c                       s>   e Zd Z fddZejeeef eej dddZ  Z	S )FocalNetStagec              
      s$  t     | _t j| _ fddt| jD }| | jd k rV|d  nd }| jd k rltnd }dd tj	d j
t jddD }|t jd  t jd d   t fddt j D | _|d ur| d	|d
 jdd| _nd | _d| _d S )Nc                    s   g | ]} j d |  qS )rW   )r.   .0ir*   r    r!   
<listcomp>      z*FocalNetStage.__init__.<locals>.<listcomp>r   c                 S   s   g | ]}|  qS r    )item)r   r   r    r    r!   r     r   r   cpu)rs   c              
      s0   g | ](}t  ttr$| nd qS ))r*   r   r   r   rw   )r   r\   listr   r*   r   rw   r   r   r    r!   r     s   rW   TF)r*   r+   r,   r-   r.   rc   r/   r0   )r3   r4   r*   lendepths
num_stagesr   r5   r   ZlinspaceZdrop_path_ratesumr   r   layersr/   
downsampleZpointing)rA   r*   r   r   r.   Zout_dimr   ZdprrC   r   r!   r4     s6    
$,

zFocalNetStage.__init__)r   r   rG   c           	      C   s|   |\}}| j D ]}|||}q|}| jd urb|\}}|dd|jd d||}| |\}}n||||f}|||f}|S )Nr   rW   r   rH   )r   r   rl   reshaperi   )	rA   r   r   rf   rg   Zlayer_module!hidden_states_before_downsamplingrN   stage_outputsr    r    r!   rR     s    


zFocalNetStage.forward)
r   r   r   r4   r   rT   r   rm   rR   rU   r    r    rC   r!   r     s   ,r   c                	       sT   e Zd Z fddZdejeeef ee	 ee	 ee	 e
eef dddZ  ZS )	FocalNetEncoderc                    sH   t    t j| _ | _t fddt| jD | _	d| _
d S )Nc              	      s6   g | ].}t  |d  d|  d d|  fdqS )r   rW   r   )r*   r   r   )r   )r   Zi_layerr*   r7   r    r!   r     s   z,FocalNetEncoder.__init__.<locals>.<listcomp>F)r3   r4   r   r   r   r*   r   r   r   stagesZgradient_checkpointing)rA   r*   r7   rC   r   r!   r4     s    
zFocalNetEncoder.__init__FT)r   r   output_hidden_states(output_hidden_states_before_downsamplingreturn_dictrG   c                 C   s  |rdnd }|rdnd }|rb|j \}}	}
|j|g||
R  }|dddd}||f7 }||f7 }t| jD ]\}}|||}|d }|d }|d }|d |d f}|r|r|j \}}	}
|j|g|d |d f|
R  }|dddd}||f7 }||f7 }ql|rl|sl|j \}}	}
|j|g||
R  }|dddd}||f7 }||f7 }ql|srtdd	 ||fD S t|||d
S )Nr    r   r	   r   rW   rH   c                 s   s   | ]}|d ur|V  qd S ry   r    )r   vr    r    r!   	<genexpr>>  r   z*FocalNetEncoder.forward.<locals>.<genexpr>)r   r   r   )ri   r   r   	enumerater   r   r   )rA   r   r   r   r   r   Zall_hidden_statesZall_reshaped_hidden_statesrO   rP   Zhidden_sizeZreshaped_hidden_stater   Zstage_moduler   r   rN   r    r    r!   rR     sN    




zFocalNetEncoder.forward)FFT)r   r   r   r4   r   rT   r   rm   r   boolr   r   rR   rU   r    r    rC   r!   r     s      

r   c                   @   s0   e Zd ZU eed< dZdZdZdgZdd Z	dS )	FocalNetPreTrainedModelr*   focalnetrE   Tr   c                 C   s   t |tjtjfr@|jjjd| jjd |j	dur|j	j
  nt |tjrh|j	j
  |jjd nXt |tr|jdur|jj
  n6t |tr| jjr|jj| jj |jj| jj dS )zInitialize the weightsrn   )r   ZstdNrI   )r\   r   r   ra   weightdataZnormal_r*   Zinitializer_ranger   Zzero_r;   Zfill_r)   r:   r   r   r   r   r   )rA   moduler    r    r!   _init_weightsO  s    



z%FocalNetPreTrainedModel._init_weightsN)
r   r   r   r   r   Zbase_model_prefixZmain_input_nameZsupports_gradient_checkpointingZ_no_split_modulesr   r    r    r    r!   r   G  s   
r   c                	       s^   e Zd Zd fdd	Zdd Zedeej eej	 ee
 ee
 eeef dd	d
Z  ZS )FocalNetModelTFc                    s   t  | || _t|j| _t|jd| jd   | _t	||d| _
t|| j
j| _tj| j|jd| _|rxtdnd| _|   dS )z
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        use_mask_token (`bool`, *optional*, defaults to `False`):
            Whether to use a mask token for masked image modeling.
        rW   r   )rB   r1   N)r3   r4   r*   r   r   r   rm   r.   num_featuresr)   rM   r   r8   encoderr   r;   r<   r   ZAdaptiveAvgPool1dpooler	post_init)rA   r*   add_pooling_layerrB   rC   r    r!   r4   e  s    zFocalNetModel.__init__c                 C   s   | j jS ry   )rM   r6   rz   r    r    r!   get_input_embeddingsz  s    z"FocalNetModel.get_input_embeddingsNrE   rF   r   r   rG   c                 C   s   |dur|n| j j}|dur |n| j j}|du r8td| j||d\}}| j||||d}|d }| |}d}	| jdur| |dd}	t	
|	d}	|s||	f|dd  }
|
S t||	|j|jdS )	z
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        Nz You have to specify pixel_values)rF   r   r   r   r   rW   )r   r#   r   r   )r*   r   use_return_dictrj   rM   r   r   r   rl   r   rk   r"   r   r   )rA   rE   rF   r   r   Zembedding_outputr   Zencoder_outputssequence_outputpooled_outputrv   r    r    r!   rR   }  s6    

zFocalNetModel.forward)TF)NNNN)r   r   r   r4   r   r   r   r   r   rS   r   r   r   r"   rR   rU   r    r    rC   r!   r   c  s       
r   a  
    FocalNet Model with a decoder on top for masked image modeling.

    This follows the same implementation as in [SimMIM](https://huggingface.co/papers/2111.09886).

    <Tip>

    Note that we provide a script to pre-train this model on custom data in our [examples
    directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).

    </Tip>
    c                	       sT   e Zd Z fddZedeej eej ee	 ee	 e
eef dddZ  ZS )FocalNetForMaskedImageModelingc                    sz   t  | t|ddd| _t|j| _t|jd| jd   }t	
t	j||jd |j ddt	|j| _|   d S )NFT)r   rB   rW   r   )Zin_channelsZout_channelsrY   )r3   r4   r   r   r   r   r   rm   r.   r   r   ra   Zencoder_strider-   ZPixelShuffledecoderr   )rA   r*   r   rC   r    r!   r4     s    
z'FocalNetForMaskedImageModeling.__init__Nr   c                 C   s8  |dur|n| j j}| j||||d}|d }|dd}|j\}}}	t|	d  }
}||||
|}| |}d}|dur| j j	| j j
 }|d||}|| j j
d| j j
dd }tjj||dd	}||  | d
  | j j }|s$|f|dd  }|dur |f| S |S t|||j|jdS )a?  
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, FocalNetConfig, FocalNetForMaskedImageModeling
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/focalnet-base-simmim-window6-192")
        >>> config = FocalNetConfig()
        >>> model = FocalNetForMaskedImageModeling(config)

        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
        >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
        >>> # create random boolean mask of shape (batch_size, num_patches)
        >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits
        >>> list(reconstructed_pixel_values.shape)
        [1, 3, 192, 192]
        ```N)rF   r   r   r   r   rW   g      ?rH   none)Z	reductiongh㈵>)r%   r&   r   r   )r*   r   r   rl   ri   mathfloorr   r   r+   r,   Zrepeat_interleaverL   r   r   rd   Zl1_lossr   r-   r$   r   r   )rA   rE   rF   r   r   outputsr   rO   r-   Zsequence_lengthrf   rg   Zreconstructed_pixel_valuesZmasked_im_lossrJ   rQ   Zreconstruction_lossrv   r    r    r!   rR     sD    $
 z&FocalNetForMaskedImageModeling.forward)NNNN)r   r   r   r4   r   r   r   r   rS   r   r   r   r$   rR   rU   r    r    rC   r!   r     s       
r   z
    FocalNet Model with an image classification head on top (a linear layer on top of the pooled output) e.g. for
    ImageNet.
    c                	       sT   e Zd Z fddZedeej eej ee	 ee	 e
eef dddZ  ZS )FocalNetForImageClassificationc                    sP   t  | |j| _t|| _|jdkr:t| jj|jnt | _	| 
  d S )Nr   )r3   r4   
num_labelsr   r   r   r   r   r   
classifierr   rA   r*   rC   r    r!   r4   '  s    
"z'FocalNetForImageClassification.__init__N)rE   labelsr   r   rG   c                 C   sf  |dur|n| j j}| j|||d}|d }| |}d}|dur"| j jdu r| jdkrdd| j _n4| jdkr|jtjks|jtj	krd| j _nd| j _| j jdkrt
 }	| jdkr|	| | }n
|	||}nN| j jdkrt }	|	|d| j|d}n| j jdkr"t }	|	||}|sR|f|dd  }
|durN|f|
 S |
S t|||j|jd	S )
a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Nr   r   Z
regressionZsingle_label_classificationZmulti_label_classificationrH   rW   )r%   r(   r   r   )r*   r   r   r   Zproblem_typer   rr   r   longrm   r   Zsqueezer   r   r   r'   r   r   )rA   rE   r   r   r   r   r   r(   r%   Zloss_fctrv   r    r    r!   rR   5  sH    



"


z&FocalNetForImageClassification.forward)NNNN)r   r   r   r4   r   r   r   r   Z
LongTensorr   r   r   r'   rR   rU   r    r    rC   r!   r     s       
r   zG
    FocalNet backbone, to be used with frameworks like X-Decoder.
    c                       sJ   e Zd ZdZed fddZed	eje	e
 e	e
 edddZ  ZS )
FocalNetBackboneFr   c                    s>   t  | t  | |jg|j | _t|| _|   d S ry   )	r3   r4   Z_init_backboner.   Zhidden_sizesr   r   r   r   r   rC   r    r!   r4   z  s
    
zFocalNetBackbone.__init__N)rE   r   r   rG   c           
      C   s   |dur|n| j j}|dur |n| j j}| j|ddd}|j}d}t| jD ] \}}|| jv rL||| f7 }qL|s|f}	|r|	|jf7 }	|	S t	||r|jndddS )aj  
        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("microsoft/focalnet-tiny-lrf")
        >>> model = AutoBackbone.from_pretrained("microsoft/focalnet-tiny-lrf")

        >>> inputs = processor(image, return_tensors="pt")
        >>> outputs = model(**inputs)
        ```NTr   r    )feature_mapsr   Z
attentions)
r*   r   r   r   r   r   Zstage_namesr   r   r   )
rA   rE   r   r   r   r   r   idxZstagerv   r    r    r!   rR     s&    
zFocalNetBackbone.forward)NN)r   r   r   Zhas_attentionsr   r4   r   r   rT   r   r   r   rR   rU   r    r    rC   r!   r   r  s   
  r   )r   r   r   r   r   )rn   F):r   collections.abcr]   r   dataclassesr   typingr   r   r   Ztorch.utils.checkpointr   Ztorch.nnr   r   r   Zactivationsr
   Zmodeling_layersr   Zmodeling_outputsr   Zmodeling_utilsr   utilsr   r   r   Zutils.backbone_utilsr   Zconfiguration_focalnetr   Z
get_loggerr   loggerr   r"   r$   r'   Moduler)   r5   rT   r|   r   rw   rx   r~   r   r   r   r   r   r   r   r   r   __all__r    r    r    r!   <module>   s|   
(HGEBKKbMA