from dataclasses import dataclass
from typing import Callable, Optional, Union

import torch
from torch import nn

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
from .configuration_vjepa2 import VJEPA2Config


logger = logging.get_logger(__name__)


@dataclass
@auto_docstring(
    custom_intro="""
    VJEPA Predictor outputs that also contains the masked encoder outputs
    """
)
class VJEPA2WithMaskedInputPredictorOutput(ModelOutput):
    r"""
    masked_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, returned when `context_mask` is provided which is applied on VJEPA2Encoder outputs):
        The masked hidden state of the model.
    target_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, returned when `target_mask` is provided which is applied on VJEPA2Encoder outputs):
        The target hidden state of the model.
    """

    last_hidden_state: torch.FloatTensor
    masked_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    target_hidden_state: Optional[torch.FloatTensor] = None


@dataclass
@auto_docstring(
    custom_intro="""
    VJEPA outputs that also contains the masked encoder outputs
    Optionally contains the predictor outputs
    """
)
class VJEPA2WithMaskedInputModelOutput(ModelOutput):
    r"""
    masked_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, returned when `context_mask` is provided which is applied on VJEPA2Encoder outputs):
        The masked hidden state of the model.
    predictor_output (`VJEPA2WithMaskedInputPredictorOutput`, *optional*):
        The output from the Predictor module.
    """

    last_hidden_state: torch.FloatTensor
    masked_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    predictor_output: Optional[VJEPA2WithMaskedInputPredictorOutput] = None

    def to_tuple(self):
        output = list(super().to_tuple())
        if isinstance(output[-1], VJEPA2WithMaskedInputPredictorOutput):
            output[-1] = output[-1].to_tuple()
        return tuple(output)


class VJEPA2PatchEmbeddings3D(nn.Module):
    """
    Image to Patch Embedding
    """

    def __init__(self, config: VJEPA2Config, hidden_size: int = 1024):
        super().__init__()
        self.patch_size = config.patch_size
        self.tubelet_size = config.tubelet_size
        self.hidden_size = hidden_size

        self.proj = nn.Conv3d(
            in_channels=config.in_chans,
            out_channels=hidden_size,
            kernel_size=(config.tubelet_size, config.patch_size, config.patch_size),
            stride=(config.tubelet_size, config.patch_size, config.patch_size),
        )

    @staticmethod
    def num_patches(config):
        return (
            (config.frames_per_clip // config.tubelet_size)
            * (config.crop_size // config.patch_size)
            * (config.crop_size // config.patch_size)
        )

    def forward(self, pixel_values_videos: torch.Tensor) -> torch.Tensor:
        x = self.proj(pixel_values_videos).flatten(2).transpose(1, 2)
        return x


class VJEPA2Embeddings(nn.Module):
    """
    Construct mask token, position and patch embeddings.
    """

    def __init__(self, config: VJEPA2Config, hidden_size: int = 1024):
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size
        self.patch_embeddings = VJEPA2PatchEmbeddings3D(config, hidden_size=hidden_size)
        self.num_patches = self.patch_embeddings.num_patches
        self.patch_size = config.patch_size

    def forward(self, pixel_values_videos: torch.Tensor) -> torch.Tensor:
        num_frames = pixel_values_videos.shape[1]

        # Swap `frames` and `channels` dims, the result is:
        # (batch_size, channels, num_frames, height, width)
        pixel_values_videos = pixel_values_videos.permute(0, 2, 1, 3, 4)

        # If the input video has fewer frames than one tubelet, tile the frames
        # along the temporal dimension so at least one tubelet can be formed.
        if num_frames < self.config.tubelet_size:
            pixel_values_videos = pixel_values_videos.repeat(1, 1, self.config.tubelet_size, 1, 1)

        target_dtype = self.patch_embeddings.proj.weight.dtype
        pixel_values_videos = pixel_values_videos.to(dtype=target_dtype)
        embeddings = self.patch_embeddings(pixel_values_videos)
        return embeddings


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    # Take the dot product between "query" and "key" to get the raw attention scores.
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling

    # Normalize the attention scores to probabilities.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    # Mask heads if we want to
    if attention_mask is not None:
        attn_weights = attn_weights * attention_mask

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


def rotate_queries_or_keys(x, pos):
    B, num_heads, N, D = x.size()

    # Compute the rotation angle for each position and frequency band
    # (equivalent to inv_freq = 1.0 / (theta ** (arange(0, dim, 2) / dim))).
    omega = torch.arange(D // 2, dtype=x.dtype, device=x.device)
    omega /= D / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)
    freq = pos.unsqueeze(-1) * omega  # (..., N, D/2), outer product

    # -- build rotation matrix and apply
    emb_sin = freq.sin()  # (..., N, D/2)
    emb_cos = freq.cos()  # (..., N, D/2)

    emb_sin = emb_sin.squeeze(-1).repeat(1, 1, 1, 2)
    emb_cos = emb_cos.squeeze(-1).repeat(1, 1, 1, 2)

    # Rotate pairs of channels: (y0, y1) -> (-y1, y0)
    y = x.unflatten(-1, (-1, 2))
    y1, y2 = y.unbind(dim=-1)
    y = torch.stack((-y2, y1), dim=-1)
    y = y.flatten(-2)
    return (x * emb_cos) + (y * emb_sin)


class VJEPA2RopeAttention(nn.Module):
    def __init__(self, config: VJEPA2Config, hidden_size: int = 1024, num_attention_heads: int = 16):
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"The hidden size {(hidden_size,)} is not a multiple of the number of attention heads {num_attention_heads}."
            )

        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)

        self.proj = nn.Linear(hidden_size, hidden_size)
        self.dropout_prob = config.attention_probs_dropout_prob
        self.dropout = nn.Dropout(self.dropout_prob)

        self.grid_size = self.config.crop_size // self.config.patch_size
        self.grid_depth = self.config.frames_per_clip // self.config.tubelet_size

        # Split the head dimension into three rotary sub-spaces: depth (time), height, width
        self.d_dim = int(2 * ((self.attention_head_size // 3) // 2))
        self.h_dim = int(2 * ((self.attention_head_size // 3) // 2))
        self.w_dim = int(2 * ((self.attention_head_size // 3) // 2))

        self.scaling = self.attention_head_size**-0.5
        self.is_causal = False

    def _get_frame_pos(self, ids):
        tokens_per_frame = int(self.grid_size * self.grid_size)
        return ids // tokens_per_frame

    def _get_height_pos(self, ids):
        # Remove frame component from ids
        tokens_per_frame = int(self.grid_size * self.grid_size)
        frame_ids = self._get_frame_pos(ids)
        ids = ids - tokens_per_frame * frame_ids
        tokens_per_row = self.grid_size
        return ids // tokens_per_row

    def get_position_ids(self, x, masks=None):
        device = x.device
        token_size = x.size(1)

        # Note: when masks is none, we use a 1d id vector instead of a
        # (batch x num_attention_heads) mask, as it broadcasts to the correct shapes.
        if masks is not None:
            ids = masks.unsqueeze(1).repeat(1, self.num_attention_heads, 1)
        else:
            ids = torch.arange(token_size, device=device)

        tokens_per_frame = int(self.grid_size * self.grid_size)
        frame_ids = self._get_frame_pos(ids)

        tokens_per_row = self.grid_size
        height_ids = self._get_height_pos(ids)

        # Remove frame component (1st term) and height component (2nd term) from ids
        width_ids = (ids - tokens_per_frame * frame_ids) - tokens_per_row * height_ids
        return frame_ids, height_ids, width_ids

    def apply_rotary_embeddings(self, qk, pos_ids):
        d_mask, h_mask, w_mask = pos_ids
        s = 0
        qkd = rotate_queries_or_keys(qk[..., s : s + self.d_dim], pos=d_mask)
        s += self.d_dim
        qkh = rotate_queries_or_keys(qk[..., s : s + self.h_dim], pos=h_mask)
        s += self.h_dim
        qkw = rotate_queries_or_keys(qk[..., s : s + self.w_dim], pos=w_mask)
        s += self.w_dim
        # Combine rotated dimensions, passing any leftover channels through unrotated
        if s < self.attention_head_size:
            qkr = qk[..., s:]
            qk = torch.cat([qkd, qkh, qkw, qkr], dim=-1)
        else:
            qk = torch.cat([qkd, qkh, qkw], dim=-1)
        return qk

    def forward(
        self,
        hidden_states,
        position_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        head_mask: Optional[torch.Tensor] = None,
    ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]:
        batch_size, seq_length, _ = hidden_states.shape
        query_layer = (
            self.query(hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )
        key_layer = (
            self.key(hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )
        value_layer = (
            self.value(hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )

        pos_ids = self.get_position_ids(hidden_states, masks=position_mask)
        key_layer = self.apply_rotary_embeddings(key_layer, pos_ids)
        query_layer = self.apply_rotary_embeddings(query_layer, pos_ids)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. "
                    "Falling back to eager attention. This warning can be removed using the argument "
                    '`attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        context_layer, attention_probs = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            head_mask,
            is_causal=self.is_causal,
            scaling=self.scaling,
            dropout=0.0 if not self.training else self.dropout_prob,
        )

        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = self.proj(context_layer.reshape(new_context_layer_shape))

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize
    output = input.div(keep_prob) * random_tensor
    return output


class VJEPA2DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f"p={self.drop_prob}"


class VJEPA2MLP(nn.Module):
    def __init__(self, config: VJEPA2Config, hidden_size: int = 1024, mlp_ratio: float = 4.0):
        super().__init__()
        in_features = out_features = hidden_size
        hidden_features = int(hidden_size * mlp_ratio)
        self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
        self.activation = ACT2FN[config.hidden_act]
        self.fc2 = nn.Linear(hidden_features, out_features, bias=True)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        hidden_state = self.fc1(hidden_state)
        hidden_state = self.activation(hidden_state)
        hidden_state = self.fc2(hidden_state)
        return hidden_state


class VJEPA2Layer(GradientCheckpointingLayer):
    """This corresponds to the Block class in the original implementation."""

    def __init__(
        self,
        config: VJEPA2Config,
        drop_path_rate: float = 0.0,
        hidden_size: int = 1024,
        num_attention_heads: int = 16,
        mlp_ratio: float = 4.0,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.mlp_ratio = mlp_ratio

        self.norm1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.attention = VJEPA2RopeAttention(config, hidden_size, num_attention_heads)
        self.drop_path = VJEPA2DropPath(drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
        self.norm2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.mlp = VJEPA2MLP(config, hidden_size=hidden_size, mlp_ratio=mlp_ratio)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> tuple[torch.Tensor, ...]:
        # Self-attention with a residual connection
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        self_attention_outputs = self.attention(
            hidden_states,
            position_mask=position_mask,
            output_attentions=output_attentions,
            head_mask=head_mask,
        )
        attention_output = self_attention_outputs[0]
        hidden_states = self.drop_path(attention_output) + residual

        # MLP with a residual connection
        residual = hidden_states
        hidden_states = self.norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.drop_path(hidden_states) + residual

        # Add self attentions if we output attention weights
        outputs = self_attention_outputs[1:]
        outputs = (hidden_states,) + outputs

        return outputs


class VJEPA2Encoder(nn.Module):
    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.config = config

        self.embeddings = VJEPA2Embeddings(config, hidden_size=config.hidden_size)
        drop_path_rates = [
            (config.drop_path_rate * i / (config.num_hidden_layers - 1) if config.num_hidden_layers > 1 else 0.0)
            for i in range(config.num_hidden_layers)
        ]
        self.layer = nn.ModuleList(
            [
                VJEPA2Layer(
                    config,
                    drop_path_rate=drop_path_rates[i],
                    hidden_size=config.hidden_size,
                    num_attention_heads=config.num_attention_heads,
                    mlp_ratio=config.mlp_ratio,
                )
                for i in range(config.num_hidden_layers)
            ]
        )
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gradient_checkpointing = False

    @can_return_tuple
    def forward(
        self,
        pixel_values_videos: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        **kwargs,
    ) -> BaseModelOutput:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        hidden_states = self.embeddings(pixel_values_videos)

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            layer_head_mask = head_mask[i] if head_mask is not None else None
            layer_outputs = layer_module(hidden_states, None, layer_head_mask, output_attentions)
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        hidden_states = self.layernorm(hidden_states)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


def apply_masks(tensor: torch.Tensor, masks: list[torch.Tensor]) -> torch.Tensor:
    """
    Args:
        tensor (`torch.Tensor`):
            Tensor of shape [batch_size, num_patches, feature_dim]
        masks (`List[torch.Tensor]`):
            List of tensors of shape [batch_size, num_patches] containing indices of patches to keep
    """
    all_masked_tensors = []
    for mask in masks:
        mask = mask.to(tensor.device)
        mask_keep = mask.unsqueeze(-1).repeat(1, 1, tensor.size(-1))
        all_masked_tensors += [torch.gather(tensor, dim=1, index=mask_keep)]

    return torch.cat(all_masked_tensors, dim=0)


class VJEPA2PredictorEmbeddings(nn.Module):
    """
    Construct mask token, position and patch embeddings.
    """

    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.config = config
        self.predictor_embeddings = nn.Linear(config.hidden_size, config.pred_hidden_size)
        self.zero_init_mask_tokens = config.pred_zero_init_mask_tokens
        self.num_mask_tokens = config.pred_num_mask_tokens
        self.mask_tokens = nn.Parameter(torch.zeros(self.num_mask_tokens, 1, 1, config.pred_hidden_size))
        self.patch_size = config.patch_size

    @staticmethod
    def num_patches(config):
        if config.frames_per_clip > 1:
            return (
                (config.frames_per_clip // config.tubelet_size)
                * (config.crop_size // config.patch_size)
                * (config.crop_size // config.patch_size)
            )
        else:
            return (config.crop_size // config.patch_size) * (config.crop_size // config.patch_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        context_mask: list[torch.Tensor],
        target_mask: list[torch.Tensor],
        mask_index: int = 1,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        hidden_states : encoder outputs (context)
        context_mask: tokens of the context (outputs from the encoder)
        target_mask: tokens to predict
        mask_index: index of the target mask to choose (useful for multiclip?)
        """
        B = hidden_states.size(0)
        context = self.predictor_embeddings(hidden_states)

        # Make target tokens
        mask_index = mask_index % self.num_mask_tokens
        target = self.mask_tokens[mask_index]
        # Use the provided target mask to derive the number of patches, so the
        # predictor can run on more tokens than the config's frames_per_clip implies
        max_patch_num = target_mask[0].max() + 1  # one extra to include the last patch
        target = target.repeat(B, max_patch_num, 1)
        target = apply_masks(target, target_mask)

        # Concatenate context & target tokens
        context = context.repeat(len(context_mask), 1, 1)
        embeddings = torch.cat([context, target], dim=1)

        # Positions of context & target tokens
        cm = torch.cat(context_mask, dim=0)
        tm = torch.cat(target_mask, dim=0)
        masks = torch.cat([cm, tm], dim=1)

        return embeddings, masks


class VJEPA2Predictor(nn.Module):
    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.config = config
        self.gradient_checkpointing = False
        self.embeddings = VJEPA2PredictorEmbeddings(config)
        drop_path_rates = [
            (
                config.drop_path_rate * i / (config.pred_num_hidden_layers - 1)
                if config.pred_num_hidden_layers > 1
                else 0.0
            )
            for i in range(config.pred_num_hidden_layers)
        ]
        self.layer = nn.ModuleList(
            [
                VJEPA2Layer(
                    config,
                    drop_path_rate=drop_path_rates[i],
                    hidden_size=config.pred_hidden_size,
                    num_attention_heads=config.pred_num_attention_heads,
                    mlp_ratio=config.pred_mlp_ratio,
                )
                for i in range(config.pred_num_hidden_layers)
            ]
        )
        self.layernorm = nn.LayerNorm(config.pred_hidden_size, eps=config.layer_norm_eps)
        self.proj = nn.Linear(config.pred_hidden_size, config.hidden_size, bias=True)

    def sort_tokens(self, hidden_states, position_masks, argsort, head_mask=None):
        # gather position masks
        position_masks = position_masks.to(argsort.device)
        position_masks = torch.gather(position_masks, dim=1, index=argsort)

        # gather hidden states
        argsort = argsort.to(hidden_states.device)
        hidden_states_argsort = argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1))
        hidden_states = torch.gather(hidden_states, dim=1, index=hidden_states_argsort)

        # gather head mask
        if head_mask is not None and head_mask[0] is not None:
            head_mask = head_mask.permute(1, 0, 2, 3, 4)
            argsort_4d = (
                argsort.unsqueeze(1)
                .unsqueeze(1)
                .expand(-1, head_mask.size(1), head_mask.size(2), -1)
                .unsqueeze(-1)
                .expand(-1, -1, -1, -1, head_mask.size(-1))
            )
            head_mask = torch.gather(head_mask, dim=3, index=argsort_4d)
            argsort_5d = (
                argsort.unsqueeze(1)
                .unsqueeze(1)
                .unsqueeze(1)
                .expand(-1, head_mask.size(1), head_mask.size(2), head_mask.size(3), -1)
            )
            head_mask = torch.gather(head_mask, dim=4, index=argsort_5d)
            head_mask = head_mask.permute(1, 0, 2, 3, 4)

        return hidden_states, position_masks, head_mask

    def unsort_tokens(self, hidden_states, argsort):
        argsort = argsort.to(hidden_states.device)
        reverse_argsort = torch.argsort(argsort, dim=1)
        reverse_argsort = reverse_argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1))
        hidden_states = torch.gather(hidden_states, dim=1, index=reverse_argsort)
        return hidden_states

    @can_return_tuple
    def forward(
        self,
        encoder_hidden_states: torch.Tensor,
        context_mask: list[torch.Tensor],
        target_mask: list[torch.Tensor],
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        **kwargs,
    ) -> BaseModelOutput:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        # Mask out the encoder hidden states: masking is applied here because in
        # V-JEPA training the encoder itself is never masked
        encoder_hidden_states = apply_masks(encoder_hidden_states, context_mask)
        _, N_ctxt, D = encoder_hidden_states.shape
        hidden_states, position_masks = self.embeddings(encoder_hidden_states, context_mask, target_mask)

        # Put tokens in sorted order
        argsort = torch.argsort(position_masks, dim=1)  # [B, N]
        hidden_states, position_masks, head_mask = self.sort_tokens(hidden_states, position_masks, argsort, head_mask)

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            layer_head_mask = head_mask[i] if head_mask is not None else None
            layer_outputs = layer_module(hidden_states, position_masks, layer_head_mask, output_attentions)
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = self.layernorm(hidden_states)
        # Unsort the tokens and keep only the predicted (target) positions
        hidden_states = self.unsort_tokens(hidden_states, argsort)
        hidden_states = hidden_states[:, N_ctxt:]
        # Project back to the encoder's hidden size
        hidden_states = self.proj(hidden_states)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class VJEPA2PoolerSelfAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""

        batch_size, seq_length, embed_dim = hidden_states.shape

        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states)
        values = self.v_proj(hidden_states)

        queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. "
                    "Falling back to eager attention. This warning can be removed using the argument "
                    '`attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            dropout=0.0 if not self.training else self.dropout,
        )

        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights


class VJEPA2PoolerCrossAttention(nn.Module):
    """It's different from other cross-attention layers, doesn't have output projection layer (o_proj)"""

    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        queries: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""

        batch_size, q_seq_length, embed_dim = queries.shape
        kv_seq_length = keys.shape[1]

        queries = self.q_proj(queries)
        keys = self.k_proj(keys)
        values = self.v_proj(values)

        queries = queries.view(batch_size, q_seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(batch_size, kv_seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(batch_size, kv_seq_length, self.num_heads, self.head_dim).transpose(1, 2)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. "
                    "Falling back to eager attention. This warning can be removed using the argument "
                    '`attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            dropout=0.0 if not self.training else self.dropout,
        )

        attn_output = attn_output.reshape(batch_size, q_seq_length, embed_dim).contiguous()

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights


class VJEPA2PoolerSelfAttentionLayer(GradientCheckpointingLayer):
    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.self_attn = VJEPA2PoolerSelfAttention(config)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = VJEPA2MLP(config, hidden_size=config.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.Tensor, ...]:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`):
                Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class VJEPA2PoolerCrossAttentionLayer(GradientCheckpointingLayer):
    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.cross_attn = VJEPA2PoolerCrossAttention(config)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = VJEPA2MLP(config, hidden_size=config.hidden_size)

    def forward(
        self,
        queries: torch.Tensor,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> tuple[torch.Tensor, ...]:
        # Apply cross-attention between the learned queries and the hidden states
        residual = queries
        hidden_state = self.layer_norm1(hidden_state)
        hidden_state, *attn_weights = self.cross_attn(
            queries,
            hidden_state,
            hidden_state,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_state = residual + hidden_state

        residual = hidden_state
        hidden_state = self.layer_norm2(hidden_state)
        hidden_state = self.mlp(hidden_state)
        hidden_state = residual + hidden_state

        outputs = (hidden_state,)
        if output_attentions:
            outputs += tuple(attn_weights)

        return outputs


class VJEPA2AttentivePooler(nn.Module):
    """Attentive Pooler"""

    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.query_tokens = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.cross_attention_layer = VJEPA2PoolerCrossAttentionLayer(config)
        self.self_attention_layers = nn.ModuleList(
            [VJEPA2PoolerSelfAttentionLayer(config) for _ in range(config.num_pooler_layers)]
        )

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        for layer in self.self_attention_layers:
            hidden_state = layer(hidden_state, attention_mask=None)[0]
        queries = self.query_tokens.repeat(hidden_state.shape[0], 1, 1)
        hidden_state = self.cross_attention_layer(queries, hidden_state)[0]
        return hidden_state.squeeze(1)


@auto_docstring
class VJEPA2PreTrainedModel(PreTrainedModel):
    config: VJEPA2Config
    base_model_prefix = "vjepa2"
    main_input_name = "pixel_values_videos"
    supports_gradient_checkpointing = True
    _no_split_modules = [
        "VJEPA2Layer",
        "VJEPA2PoolerSelfAttentionLayer",
        "VJEPA2PoolerCrossAttentionLayer",
    ]
    _supports_sdpa = True
    _supports_flash_attn = True

    def _init_weights(self, module):
        """Initialize the weights"""
        init_std = self.config.initializer_range

        # Cast to float32 before truncated-normal init so low-precision params are initialized correctly
        def trunc_normal_f32_(weight, std):
            data_float_32 = weight.data.to(torch.float32)
            data_init = nn.init.trunc_normal_(data_float_32, mean=0.0, std=std)
            weight.data = data_init.to(weight.dtype)

        if isinstance(module, VJEPA2AttentivePooler):
            trunc_normal_f32_(module.query_tokens, std=init_std)
            for i, layer in enumerate(module.self_attention_layers, 1):
                std = init_std / (i**0.5)
                trunc_normal_f32_(layer.self_attn.out_proj.weight, std=std)
                trunc_normal_f32_(layer.mlp.fc2.weight, std=std)
            std = init_std / (len(module.self_attention_layers) + 1) ** 0.5
            trunc_normal_f32_(module.cross_attention_layer.mlp.fc2.weight, std=std)
        elif isinstance(module, VJEPA2PredictorEmbeddings):
            if module.zero_init_mask_tokens:
                module.mask_tokens.data.zero_()
            else:
                trunc_normal_f32_(module.mask_tokens, std=init_std)
        elif isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            trunc_normal_f32_(module.weight, std=init_std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


def _convert_head_mask_to_5d(head_mask, num_hidden_layers):
    """
    Inputs:
        - head_mask: bsz x seq_length x seq_length | None
    Returns
        - [num_hidden_layers x batch x num_heads x seq_length x seq_length] | [num_hidden_layers]
    """
    if head_mask is not None:
        head_mask = head_mask.unsqueeze(1).unsqueeze(0)
        head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
    else:
        head_mask = [None] * num_hidden_layers
    return head_mask


@auto_docstring
class VJEPA2Model(VJEPA2PreTrainedModel):
    def __init__(self, config: VJEPA2Config):
        super().__init__(config)
        self.config = config

        self.encoder = VJEPA2Encoder(config)
        self.predictor = VJEPA2Predictor(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> VJEPA2PatchEmbeddings3D:
        return self.encoder.embeddings.patch_embeddings

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values_videos: torch.Tensor,
        context_head_mask: Optional[torch.Tensor] = None,
        context_mask: Optional[list[torch.Tensor]] = None,
        target_head_mask: Optional[torch.Tensor] = None,
        target_mask: Optional[list[torch.Tensor]] = None,
        skip_predictor: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> VJEPA2WithMaskedInputModelOutput:
        r"""
        context_head_mask (`torch.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
            The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard) for the context.
        context_mask (`torch.Tensor` with shape `[batch_size, patch_size, 1]`, *optional*):
            The mask position ids indicating which encoder output patches are going to be exposed to the predictor.
            By default, this mask is created as torch.arange(N).unsqueeze(0).repeat(B,1), indicating full context
            available to the predictor.
        target_head_mask (`torch.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
            The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard) for the target.
        target_mask (`torch.Tensor` with shape `[batch_size, patch_size, 1]`, *optional*):
            The mask position ids indicating which encoder output patches are going to be used as a prediction target
            for the predictor. By default, this mask is created as torch.arange(N).unsqueeze(0).repeat(B,1), indicating
            that the predictor should predict all encoder patches.
        skip_predictor (bool):
            flag to skip the predictor forward, useful if you just need the encoder outputs
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if pixel_values_videos is None:
            raise ValueError("You have to specify pixel_values_videos")

        # Prepare head masks if needed
        context_head_mask = _convert_head_mask_to_5d(context_head_mask, self.config.num_hidden_layers)
        target_head_mask = _convert_head_mask_to_5d(target_head_mask, self.config.pred_num_hidden_layers)

        encoder_outputs: BaseModelOutput = self.encoder(
            pixel_values_videos=pixel_values_videos,
            head_mask=context_head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = encoder_outputs.last_hidden_state

        if context_mask is None and target_mask is None:
            B = pixel_values_videos.size(0)
            N = sequence_output.size(1)  # ensure we are using dynamic num_patches
            context_mask = [torch.arange(N, device=pixel_values_videos.device).unsqueeze(0).repeat((B, 1))]
            target_mask = [torch.arange(N, device=pixel_values_videos.device).unsqueeze(0).repeat((B, 1))]

        if not skip_predictor:
            predictor_outputs: BaseModelOutput = self.predictor(
                encoder_hidden_states=sequence_output,
                context_mask=context_mask,
                target_mask=target_mask,
                head_mask=target_head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            predictor_output = VJEPA2WithMaskedInputPredictorOutput(
                last_hidden_state=predictor_outputs.last_hidden_state,
                target_hidden_state=apply_masks(sequence_output, target_mask),
                hidden_states=predictor_outputs.hidden_states,
                attentions=predictor_outputs.attentions,
            )
        else:
            predictor_output = None

        encoder_output = VJEPA2WithMaskedInputModelOutput(
            last_hidden_state=sequence_output,
            masked_hidden_state=apply_masks(sequence_output, context_mask),
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            predictor_output=predictor_output,
        )

        return encoder_output

    def get_vision_features(self, pixel_values_videos) -> torch.Tensor:
        encoder_output = self.forward(pixel_values_videos)
        return encoder_output.last_hidden_state


@auto_docstring(
    custom_intro="""
    V-JEPA 2 Model transformer with a video classification head on top (a linear layer on top of the attentive pooler).
    """
)
class VJEPA2ForVideoClassification(VJEPA2PreTrainedModel):
    def __init__(self, config: VJEPA2Config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.vjepa2 = VJEPA2Model(config)

        # Classifier head
        self.pooler = VJEPA2AttentivePooler(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels, bias=True)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values_videos: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> import torch
        >>> import numpy as np
        >>> from transformers import AutoVideoProcessor, VJEPA2ForVideoClassification

        >>> device = "cuda"

        >>> video_processor = AutoVideoProcessor.from_pretrained("facebook/vjepa2-vitl-fpc16-256-ssv2")
        >>> model = VJEPA2ForVideoClassification.from_pretrained("facebook/vjepa2-vitl-fpc16-256-ssv2").to(device)

        >>> video = np.ones((64, 256, 256, 3))  # 64 frames, 256x256 RGB
        >>> inputs = video_processor(video, return_tensors="pt").to(device)

        >>> # For inference
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        >>> logits = outputs.logits

        >>> predicted_label = logits.argmax(-1).item()
        >>> print(model.config.id2label[predicted_label])

        >>> # For training
        >>> labels = torch.ones(1, dtype=torch.long, device=device)
        >>> loss = model(**inputs, labels=labels).loss

        ```"""
        outputs = self.vjepa2(
            pixel_values_videos=pixel_values_videos,
            skip_predictor=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        last_hidden_state = outputs.last_hidden_state
        pooler_output = self.pooler(last_hidden_state)
        logits = self.classifier(pooler_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(pooled_logits=logits, labels=labels, config=self.config)

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

__all__ = ["VJEPA2Model", "VJEPA2PreTrainedModel", "VJEPA2ForVideoClassification"]