"""PyTorch VideoMAE (masked autoencoder) model."""

import collections.abc
from copy import deepcopy
from dataclasses import dataclass
from typing import Callable, Optional

import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import MSELoss

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging
from ...utils.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from ...utils.generic import can_return_tuple, check_model_inputs
from .configuration_videomae import VideoMAEConfig


logger = logging.get_logger(__name__)


@dataclass
@auto_docstring(
    custom_intro="""
    Class for VideoMAEDecoder's outputs, with potential hidden states and attentions.
    """
)
class VideoMAEDecoderOutput(ModelOutput):
    r"""
    logits (`torch.FloatTensor` of shape `(batch_size, patch_size ** 2 * num_channels)`):
        Pixel reconstruction logits.
    """

    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None


@dataclass
@auto_docstring(
    custom_intro="""
    Class for VideoMAEForPreTraining's outputs, with potential hidden states and attentions.
    """
)
class VideoMAEForPreTrainingOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`):
        Pixel reconstruction loss.
    logits (`torch.FloatTensor` of shape `(batch_size, patch_size ** 2 * num_channels)`):
        Pixel reconstruction logits.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None


def get_sinusoid_encoding_table(n_position, d_hid):
    """Sinusoid position encoding table"""

    def get_position_angle_vec(position):
        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i + 1

    return torch.FloatTensor(sinusoid_table).unsqueeze(0)
class VideoMAEEmbeddings(nn.Module):
    """
    Construct the patch and position embeddings.

    """

    def __init__(self, config):
        super().__init__()

        self.patch_embeddings = VideoMAEPatchEmbeddings(config)
        self.num_patches = self.patch_embeddings.num_patches
        # fixed sin-cos position embeddings
        self.position_embeddings = get_sinusoid_encoding_table(self.num_patches, config.hidden_size)
        self.config = config

    def forward(self, pixel_values, bool_masked_pos):
        # create patch embeddings
        embeddings = self.patch_embeddings(pixel_values)

        # add position embeddings
        embeddings = embeddings + self.position_embeddings.detach().type_as(embeddings).to(
            embeddings.device, copy=True
        )
        # only keep visible patches
        # ~bool_masked_pos means visible
        if bool_masked_pos is not None:
            batch_size, _, num_channels = embeddings.shape
            embeddings = embeddings[~bool_masked_pos]
            embeddings = embeddings.reshape(batch_size, -1, num_channels)

        return embeddings
class VideoMAEPatchEmbeddings(nn.Module):
    """
    Video to Patch Embedding. This module turns a batch of videos of shape (batch_size, num_frames, num_channels,
    height, width) into a tensor of shape (batch_size, seq_len, hidden_size) to be consumed by a Transformer encoder.

    The seq_len (the number of patches) equals (number of frames // tubelet_size) * (height // patch_size) * (width //
    patch_size).

    """

    def __init__(self, config):
        super().__init__()

        image_size = config.image_size
        patch_size = config.patch_size
        num_channels = config.num_channels
        hidden_size = config.hidden_size
        num_frames = config.num_frames
        tubelet_size = config.tubelet_size

        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        self.image_size = image_size
        self.patch_size = patch_size
        self.tubelet_size = int(tubelet_size)
        num_patches = (
            (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) * (num_frames // self.tubelet_size)
        )
        self.num_channels = num_channels
        self.num_patches = num_patches
        self.projection = nn.Conv3d(
            in_channels=num_channels,
            out_channels=hidden_size,
            kernel_size=(self.tubelet_size, patch_size[0], patch_size[1]),
            stride=(self.tubelet_size, patch_size[0], patch_size[1]),
        )

    def forward(self, pixel_values):
        batch_size, num_frames, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        if height != self.image_size[0] or width != self.image_size[1]:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
            )
        # permute to (batch_size, num_channels, num_frames, height, width)
        pixel_values = pixel_values.permute(0, 2, 1, 3, 4)
        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return embeddings
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    if attention_mask is not None:
        attn_weights = attn_weights * attention_mask

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


class VideoMAESelfAttention(nn.Module):
    def __init__(self, config: VideoMAEConfig) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.config = config
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.dropout_prob = config.attention_probs_dropout_prob
        self.scaling = self.attention_head_size**-0.5
        self.is_causal = False

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=False)

        if config.qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(self.all_head_size))
            self.v_bias = nn.Parameter(torch.zeros(self.all_head_size))
        else:
            self.q_bias = None
            self.v_bias = None

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size, seq_length, _ = hidden_states.shape
        # the key projection has no learnable bias; a zero bias is used when q/v biases are enabled
        k_bias = torch.zeros_like(self.v_bias, requires_grad=False) if self.q_bias is not None else None
        keys = nn.functional.linear(input=hidden_states, weight=self.key.weight, bias=k_bias)
        values = nn.functional.linear(input=hidden_states, weight=self.value.weight, bias=self.v_bias)
        queries = nn.functional.linear(input=hidden_states, weight=self.query.weight, bias=self.q_bias)

        key_layer = keys.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
        value_layer = values.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
        query_layer = queries.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        context_layer, attention_probs = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            head_mask,
            is_causal=self.is_causal,
            scaling=self.scaling,
            dropout=0.0 if not self.training else self.dropout_prob,
        )

        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.reshape(new_context_layer_shape)

        return context_layer, attention_probs
class VideoMAESelfOutput(nn.Module):
    """
    The residual connection is defined in VideoMAELayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: VideoMAEConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states


class VideoMAEAttention(nn.Module):
    def __init__(self, config: VideoMAEConfig) -> None:
        super().__init__()
        self.attention = VideoMAESelfAttention(config)
        self.output = VideoMAESelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: set[int]) -> None:
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        self_attn_output, _ = self.attention(hidden_states, head_mask)
        attention_output = self.output(self_attn_output, hidden_states)
        return attention_output


class VideoMAEIntermediate(nn.Module):
    def __init__(self, config: VideoMAEConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states


class VideoMAEOutput(nn.Module):
    def __init__(self, config: VideoMAEConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        hidden_states = hidden_states + input_tensor

        return hidden_states
S )
VideoMAELayerz?This corresponds to the Block class in the timm implementation.r   c                    sb   t    |j| _d| _t|| _t|| _t|| _	t
j|j|jd| _t
j|j|jd| _d S )Nr   eps)r=   r>   Zchunk_size_feed_forwardZseq_len_dimr   r   r   intermediater   r   r   	LayerNormrB   layer_norm_epslayernorm_beforelayernorm_afterrE   rG   r(   r)   r>   e  s    



zVideoMAELayer.__init__Nr   c                 C   sB   |  |}| ||}|| }| |}| |}| ||}|S r<   )r   r   r   r   r   )rF   r   r   Zhidden_states_normZattention_outputZlayer_outputr(   r(   r)   rX   o  s    


zVideoMAELayer.forward)N)r    r!   r"   r#   r   r>   r$   r   r   rX   rZ   r(   r(   rG   r)   r   b  s   
r   c                       s>   e Zd Zed fddZdejeej edddZ	  Z
S )	VideoMAEEncoderr   c                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r(   r   r/   rV   r   r(   r)   r2     r3   z,VideoMAEEncoder.__init__.<locals>.<listcomp>F)	r=   r>   rD   r   
ModuleListr4   num_hidden_layerslayergradient_checkpointingrE   rG   r   r)   r>     s    
 zVideoMAEEncoder.__init__Nr   c                 C   s<   t | jD ]&\}}|d ur"|| nd }|||}q
t|dS )Nlast_hidden_state)	enumerater   r   )rF   r   r   ilayer_moduleZlayer_head_maskr(   r(   r)   rX     s    zVideoMAEEncoder.forward)N)r    r!   r"   r   r>   r$   r   r   r   rX   rZ   r(   r(   rG   r)   r     s   r   c                   @   sD   e Zd ZU eed< dZdZdZdZdZ	dZ
dZeedZdd ZdS )	VideoMAEPreTrainedModelrD   videomaerR   T)r   r   c                 C   sj   t |tjtjfr@|jjjd| jjd |j	durf|j	j
  n&t |tjrf|j	j
  |jjd dS )zInitialize the weightsrn   )meanstdNg      ?)r_   r   r   rd   r   dataZnormal_rD   Zinitializer_ranger   Zzero_r   Zfill_)rF   ro   r(   r(   r)   _init_weights  s    
z%VideoMAEPreTrainedModel._init_weightsN)r    r!   r"   r   r&   Zbase_model_prefixZmain_input_nameZsupports_gradient_checkpointingZ_supports_sdpaZ_supports_flash_attnZ_supports_flex_attnZ_supports_attention_backendr   r   Z_can_record_outputsr   r(   r(   r(   r)   r     s   
r   c                	       s^   e Zd Z fddZdd Zdd Zeedej	e
ej e
ej ee edd	d
Z  ZS )VideoMAEModelc                    sT   t  | || _t|| _t|| _|jr4d | _nt	j
|j|jd| _|   d S )Nr   )r=   r>   rD   r;   rT   r   encoderuse_mean_pooling	layernormr   r   rB   r   	post_initrE   rG   r(   r)   r>     s    

zVideoMAEModel.__init__c                 C   s   | j jS r<   )rT   r@   )rF   r(   r(   r)   get_input_embeddings  s    z"VideoMAEModel.get_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr   r   r   r   )rF   Zheads_to_pruner   r   r(   r(   r)   _prune_heads  s    zVideoMAEModel._prune_headsNrR   rS   r   r~   r   c                 K   sN   |  || jj}| ||}| j||d}|j}| jdurD| |}t|dS )a  
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Each video in the
            batch must have the same number of masked patches. If `None`, then all patches are considered. Sequence
            length is `(num_frames // tubelet_size) * (image_size // patch_size) ** 2`.

        Examples:

        ```python
        >>> import av
        >>> import numpy as np

        >>> from transformers import AutoImageProcessor, VideoMAEModel
        >>> from huggingface_hub import hf_hub_download

        >>> np.random.seed(0)


        >>> def read_video_pyav(container, indices):
        ...     '''
        ...     Decode the video with PyAV decoder.
        ...     Args:
        ...         container (`av.container.input.InputContainer`): PyAV container.
        ...         indices (`list[int]`): List of frame indices to decode.
        ...     Returns:
        ...         result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
        ...     '''
        ...     frames = []
        ...     container.seek(0)
        ...     start_index = indices[0]
        ...     end_index = indices[-1]
        ...     for i, frame in enumerate(container.decode(video=0)):
        ...         if i > end_index:
        ...             break
        ...         if i >= start_index and i in indices:
        ...             frames.append(frame)
        ...     return np.stack([x.to_ndarray(format="rgb24") for x in frames])


        >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
        ...     '''
        ...     Sample a given number of frame indices from the video.
        ...     Args:
        ...         clip_len (`int`): Total number of frames to sample.
        ...         frame_sample_rate (`int`): Sample every n-th frame.
        ...         seg_len (`int`): Maximum allowed index of sample's last frame.
        ...     Returns:
        ...         indices (`list[int]`): List of sampled frame indices
        ...     '''
        ...     converted_len = int(clip_len * frame_sample_rate)
        ...     end_idx = np.random.randint(converted_len, seg_len)
        ...     start_idx = end_idx - converted_len
        ...     indices = np.linspace(start_idx, end_idx, num=clip_len)
        ...     indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
        ...     return indices


        >>> # video clip consists of 300 frames (10 seconds at 30 FPS)
        >>> file_path = hf_hub_download(
        ...     repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
        ... )
        >>> container = av.open(file_path)

        >>> # sample 16 frames
        >>> indices = sample_frame_indices(clip_len=16, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
        >>> video = read_video_pyav(container, indices)

        >>> image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base")
        >>> model = VideoMAEModel.from_pretrained("MCG-NJU/videomae-base")

        >>> # prepare video for the model
        >>> inputs = image_processor(list(video), return_tensors="pt")

        >>> # forward pass
        >>> outputs = model(**inputs)
        >>> last_hidden_states = outputs.last_hidden_state
        >>> list(last_hidden_states.shape)
        [1, 1568, 768]
        ```"""
        # Prepare head mask if needed: 1.0 in head_mask indicates we keep the head.
        # The mask is broadcast to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length].
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(pixel_values, bool_masked_pos)

        encoder_outputs = self.encoder(embedding_output, head_mask=head_mask)
        sequence_output = encoder_outputs.last_hidden_state

        if self.layernorm is not None:
            sequence_output = self.layernorm(sequence_output)

        return BaseModelOutput(last_hidden_state=sequence_output)


class VideoMAEDecoder(nn.Module):
    def __init__(self, config: VideoMAEConfig):
        super().__init__()
        decoder_num_labels = config.num_channels * config.tubelet_size * config.patch_size**2

        decoder_config = deepcopy(config)
        decoder_config.hidden_size = config.decoder_hidden_size
        decoder_config.num_hidden_layers = config.decoder_num_hidden_layers
        decoder_config.num_attention_heads = config.decoder_num_attention_heads
        decoder_config.intermediate_size = config.decoder_intermediate_size
        self.decoder_layers = nn.ModuleList(
            [VideoMAELayer(decoder_config) for _ in range(config.decoder_num_hidden_layers)]
        )

        self.norm = nn.LayerNorm(config.decoder_hidden_size)
        self.head = (
            nn.Linear(config.decoder_hidden_size, decoder_num_labels) if decoder_num_labels > 0 else nn.Identity()
        )

        self.gradient_checkpointing = False
        self.config = config

    def forward(self, hidden_states: torch.Tensor, return_token_num: int):
        # apply Transformer layers (blocks)
        for layer_module in self.decoder_layers:
            hidden_states = layer_module(hidden_states, None)

        if return_token_num > 0:
            # only keep the tokens corresponding to the masked patches
            hidden_states = hidden_states[:, -return_token_num:]

        # predictor projection
        hidden_states = self.norm(hidden_states)
        logits = self.head(hidden_states)

        return VideoMAEDecoderOutput(logits=logits)


@auto_docstring(
    custom_intro="""
    The VideoMAE Model transformer with the decoder on top for self-supervised pre-training.
    """
)
class VideoMAEForPreTraining(VideoMAEPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.videomae = VideoMAEModel(config)

        self.encoder_to_decoder = nn.Linear(config.hidden_size, config.decoder_hidden_size, bias=False)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_hidden_size))
        self.position_embeddings = get_sinusoid_encoding_table(
            self.videomae.embeddings.num_patches, config.decoder_hidden_size
        )

        self.decoder = VideoMAEDecoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        bool_masked_pos: torch.BoolTensor,
        head_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> VideoMAEForPreTrainingOutput:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Each video in the
            batch must have the same number of masked patches. Sequence length is `(num_frames // tubelet_size) *
            (image_size // patch_size) ** 2`.

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, VideoMAEForPreTraining
        >>> import numpy as np
        >>> import torch

        >>> num_frames = 16
        >>> video = list(np.random.randint(0, 256, (num_frames, 3, 224, 224)))

        >>> image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base")
        >>> model = VideoMAEForPreTraining.from_pretrained("MCG-NJU/videomae-base")

        >>> pixel_values = image_processor(video, return_tensors="pt").pixel_values

        >>> num_patches_per_frame = (model.config.image_size // model.config.patch_size) ** 2
        >>> seq_length = (num_frames // model.config.tubelet_size) * num_patches_per_frame
        >>> bool_masked_pos = torch.randint(0, 2, (1, seq_length)).bool()

        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss = outputs.loss
        ```"""
        outputs = self.videomae(pixel_values, bool_masked_pos=bool_masked_pos, head_mask=head_mask, **kwargs)

        sequence_output = outputs.last_hidden_state
        sequence_output = self.encoder_to_decoder(
            sequence_output
        )  # [batch_size, num_visible_patches, decoder_hidden_size]
        batch_size, seq_len, num_channels = sequence_output.shape

        # we don't unshuffle the correct visible token order, but shuffle the position embeddings accordingly.
        if bool_masked_pos is None:
            raise ValueError("One must provided a boolean mask ")
        expanded_position_embeddings = self.position_embeddings.expand(batch_size, -1, -1).type_as(pixel_values)
        expanded_position_embeddings = expanded_position_embeddings.to(pixel_values.device, copy=True).detach()
        pos_emb_visible = expanded_position_embeddings[~bool_masked_pos].reshape(batch_size, -1, num_channels)
        pos_emb_mask = expanded_position_embeddings[bool_masked_pos].reshape(batch_size, -1, num_channels)

        # [batch_size, num_patches, decoder_hidden_size]
        x_full = torch.cat([sequence_output + pos_emb_visible, self.mask_token + pos_emb_mask], dim=1)

        # [batch_size, num_masked_patches, num_channels * patch_size * patch_size]
        decoder_outputs = self.decoder(x_full, pos_emb_mask.shape[1])
        logits = decoder_outputs.logits

        loss = None
        with torch.no_grad():
            # calculate the labels to be predicted
            if self.config.num_channels != 3:
                # Can't unnormalize with default means/stds
                frames = pixel_values
            else:
                # first, unnormalize the frames
                device = pixel_values.device
                dtype = pixel_values.dtype
                mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN, device=device, dtype=dtype)[None, None, :, None, None]
                std = torch.as_tensor(IMAGENET_DEFAULT_STD, device=device, dtype=dtype)[None, None, :, None, None]
                frames = pixel_values * std + mean  # in [0, 1]

            batch_size, time, num_channels, height, width = frames.shape
            tubelet_size, patch_size = self.config.tubelet_size, self.config.patch_size
            if self.config.norm_pix_loss:
                # step 1: split up dimensions (time by tubelet_size, height by patch_size, width by patch_size)
                frames = frames.view(
                    batch_size,
                    time // tubelet_size,
                    tubelet_size,
                    num_channels,
                    height // patch_size,
                    patch_size,
                    width // patch_size,
                    patch_size,
                )
                # step 2: move dimensions to concatenate
                frames = frames.permute(0, 1, 4, 6, 2, 5, 7, 3).contiguous()
                # step 3: concatenate, keeping the channel dimension last
                frames = frames.view(
                    batch_size,
                    time // tubelet_size * height // patch_size * width // patch_size,
                    tubelet_size * patch_size * patch_size,
                    num_channels,
                )
                # step 4: normalize each patch with its own mean and standard deviation
                frames_norm = (frames - frames.mean(dim=-2, keepdim=True)) / (
                    frames.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6
                )
                # step 5: reshape to (batch_size, T // ts * H // ps * W // ps, ts * ps * ps * C)
                videos_patch = frames_norm.view(
                    batch_size,
                    time // tubelet_size * height // patch_size * width // patch_size,
                    tubelet_size * patch_size * patch_size * num_channels,
                )
            else:
                if self.config.num_channels != 3:
                    raise ValueError(
                        "Can't unnormalize non-RGB images. Consider setting config.norm_pix_loss to False."
                    )
                # step 1: split up dimensions (time by tubelet_size, height by patch_size, width by patch_size)
                frames = frames.view(
                    batch_size,
                    time // tubelet_size,
                    tubelet_size,
                    num_channels,
                    height // patch_size,
                    patch_size,
                    width // patch_size,
                    patch_size,
                )
                # step 2: move dimensions to concatenate: (batch_size, T//ts, H//ps, W//ps, ts, ps, ps, C)
                frames = frames.permute(0, 1, 4, 6, 2, 5, 7, 3).contiguous()
                # step 3: concatenate
                videos_patch = frames.view(
                    batch_size,
                    time // tubelet_size * height // patch_size * width // patch_size,
                    tubelet_size * patch_size * patch_size * num_channels,
                )

            batch_size, _, num_channels = videos_patch.shape
            # only the masked patches are reconstructed and supervised
            labels = videos_patch[bool_masked_pos].reshape(batch_size, -1, num_channels)

        loss_fct = MSELoss()
        loss = loss_fct(logits, labels)

        return VideoMAEForPreTrainingOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@auto_docstring(
    custom_intro="""
    VideoMAE Model transformer with a video classification head on top (a linear layer on top of the average pooled hidden
    states of all tokens) e.g. for ImageNet.
    """
)
class VideoMAEForVideoClassification(VideoMAEPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.videomae = VideoMAEModel(config)

        # Classifier head
        self.fc_norm = nn.LayerNorm(config.hidden_size) if config.use_mean_pooling else None
        self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> ImageClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> import av
        >>> import torch
        >>> import numpy as np

        >>> from transformers import AutoImageProcessor, VideoMAEForVideoClassification
        >>> from huggingface_hub import hf_hub_download

        >>> np.random.seed(0)


        >>> def read_video_pyav(container, indices):
        ...     '''
        ...     Decode the video with PyAV decoder.
        ...     Args:
        ...         container (`av.container.input.InputContainer`): PyAV container.
        ...         indices (`list[int]`): List of frame indices to decode.
        ...     Returns:
        ...         result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
        ...     '''
        ...     frames = []
        ...     container.seek(0)
        ...     start_index = indices[0]
        ...     end_index = indices[-1]
        ...     for i, frame in enumerate(container.decode(video=0)):
        ...         if i > end_index:
        ...             break
        ...         if i >= start_index and i in indices:
        ...             frames.append(frame)
        ...     return np.stack([x.to_ndarray(format="rgb24") for x in frames])


        >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
        ...     '''
        ...     Sample a given number of frame indices from the video.
        ...     Args:
        ...         clip_len (`int`): Total number of frames to sample.
        ...         frame_sample_rate (`int`): Sample every n-th frame.
        ...         seg_len (`int`): Maximum allowed index of sample's last frame.
        ...     Returns:
        ...         indices (`list[int]`): List of sampled frame indices
        ...     '''
        ...     converted_len = int(clip_len * frame_sample_rate)
        ...     end_idx = np.random.randint(converted_len, seg_len)
        ...     start_idx = end_idx - converted_len
        ...     indices = np.linspace(start_idx, end_idx, num=clip_len)
        ...     indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
        ...     return indices


        >>> # video clip consists of 300 frames (10 seconds at 30 FPS)
        >>> file_path = hf_hub_download(
        ...     repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
        ... )
        >>> container = av.open(file_path)

        >>> # sample 16 frames
        >>> indices = sample_frame_indices(clip_len=16, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
        >>> video = read_video_pyav(container, indices)

        >>> image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")
        >>> model = VideoMAEForVideoClassification.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")

        >>> inputs = image_processor(list(video), return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        ...     logits = outputs.logits

        >>> # model predicts one of the 400 Kinetics-400 classes
        >>> predicted_label = logits.argmax(-1).item()
        >>> print(model.config.id2label[predicted_label])
        eating spaghetti
        ```"""
        outputs = self.videomae(pixel_values, head_mask=head_mask, **kwargs)

        sequence_output = outputs.last_hidden_state

        if self.fc_norm is not None:
            # mean pooling over all tokens, followed by the classifier norm
            sequence_output = sequence_output.mean(1)
            sequence_output = self.fc_norm(sequence_output)
        else:
            sequence_output = sequence_output[:, 0]

        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(labels, logits, self.config, **kwargs)

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = ["VideoMAEForPreTraining", "VideoMAEModel", "VideoMAEPreTrainedModel", "VideoMAEForVideoClassification"]