"""PyTorch TVP Model"""

import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import prune_linear_layer
from ...utils import auto_docstring, logging
from ...utils.backbone_utils import load_backbone
from .configuration_tvp import TvpConfig


logger = logging.get_logger(__name__)
@dataclass
@auto_docstring
class TvpVideoGroundingOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
        Temporal-Distance IoU loss for video grounding.
    logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
        Contains start_time/duration and end_time/duration. It is the time slot of the videos corresponding to the
        input texts.
    attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
class TvpLoss(nn.Module):
    """
    This class computes the losses for `TvpForVideoGrounding`. The process happens in two steps: 1) we compute
    hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair of matched
    ground-truth / prediction (supervise class and box).

    Args:
        losses (`list[str]`):
            List of all the losses to be applied.
    """

    def __init__(self, losses):
        super().__init__()
        self.loss_map = {
            "iou": self.loss_iou,
            "distance": self.loss_distance,
            "duration": self.loss_duration,
        }
        for loss in losses:
            if loss not in self.loss_map:
                raise ValueError(f"Loss {loss} not supported")

        self.losses = losses

    def loss_iou(self, start_time, end_time, candidates_start_time, candidates_end_time, duration):
        """
        Measure the intersection over union.
        """
        inter = torch.min(end_time, candidates_end_time) - torch.max(start_time, candidates_start_time)
        union = torch.max(end_time, candidates_end_time) - torch.min(start_time, candidates_start_time)
        iou = 1 - inter.clamp(min=0) / union

        return iou

    def loss_distance(self, start_time, end_time, candidates_start_time, candidates_end_time, duration):
        """
        Measure the distance of mid points.
        """
        mid_candidates = torch.div(torch.add(candidates_start_time, candidates_end_time), 2.0)
        mid_groundtruth = torch.div(torch.add(start_time, end_time), 2.0)
        distance_diff = torch.div(
            torch.max(mid_candidates, mid_groundtruth) - torch.min(mid_candidates, mid_groundtruth), duration
        ).clamp(min=0.2)

        return distance_diff

    def loss_duration(self, start_time, end_time, candidates_start_time, candidates_end_time, duration):
        """
        Measure the difference of duration.
        """
        duration_candidates = torch.sub(candidates_end_time, candidates_start_time)
        duration_groundtruth = torch.sub(end_time, start_time)
        duration_diff = torch.square(torch.div(torch.sub(duration_candidates, duration_groundtruth), duration))
        duration_diff = duration_diff.clamp(min=0.4)

        return duration_diff

    def forward(self, logits, labels):
        """
        This performs the loss computation.

        Args:
            logits (`torch.FloatTensor`):
                The output logits of head module.
            labels (`list[torch.FloatTensor]`):
                List of tensors ([start, end, duration]), which contains start time, end time of the video corresponding to the text, and also the duration.
        """
        duration, start_time, end_time = labels
        # predicted fractions of the clip, rescaled to absolute timestamps
        candidates = torch.mul(logits, duration)
        candidates_start_time, candidates_end_time = candidates[:, 0].float(), candidates[:, 1].float()

        losses_dict = {}
        for loss in self.losses:
            losses_dict.update(
                {loss: self.loss_map[loss](start_time, end_time, candidates_start_time, candidates_end_time, duration)}
            )

        return losses_dict


class TvpVisionModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.backbone = load_backbone(config)

        if config.backbone_config is not None:
            in_channels = config.backbone_config.hidden_sizes[-1]
        elif hasattr(self.backbone, "config") and hasattr(self.backbone.config, "hidden_sizes"):
            in_channels = self.backbone.config.hidden_sizes[-1]
        elif hasattr(self.backbone, "config") and hasattr(self.backbone.config, "hidden_size"):
            in_channels = self.backbone.config.hidden_size
        else:
            raise ValueError("Backbone config not found")

        self.grid_encoder_conv = nn.Conv2d(
            in_channels,
            config.hidden_size,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=1,
            bias=False,
        )

    def forward(self, pixel_values):
        batch_size, num_frames, num_channels, height, width = pixel_values.shape
        # fold frames into the batch dimension: (batch_size * num_frames, num_channels, height, width)
        pixel_values = pixel_values.view(batch_size * num_frames, num_channels, height, width)
        grid_feat_outputs = self.backbone(pixel_values)["feature_maps"][0]
        grid = self.grid_encoder_conv(grid_feat_outputs)
        grid = nn.functional.max_pool2d(grid, kernel_size=2, stride=2)
        grid = nn.functional.relu(grid, inplace=True)
        new_channel, new_height, new_width = grid.shape[-3:]
        # (batch_size, num_frames, num_channels, height, width)
        grid = grid.view(batch_size, num_frames, new_channel, new_height, new_width)
        # (batch_size, num_frames, height, width, num_channels)
        grid = grid.permute(0, 1, 3, 4, 2)

        return grid
class TvpVisualInputEmbedding(nn.Module):
    """
    Takes input of both image and video (multi-frame)
    """

    def __init__(self, config):
        super().__init__()
        # sequence embedding
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.row_position_embeddings = nn.Embedding(config.max_grid_row_position_embeddings, config.hidden_size)
        self.col_position_embeddings = nn.Embedding(config.max_grid_col_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(1, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.max_grid_row_position_embeddings = config.max_grid_row_position_embeddings
        self.max_grid_col_position_embeddings = config.max_grid_col_position_embeddings

    def interpolate_pos_encoding(self, embedding: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on
        collections of high resolution images (high resolution videos).
        """
        h0 = w0 = 1
        # if height dimension is to be interpolated
        if height > self.max_grid_row_position_embeddings:
            h0 = height / self.max_grid_row_position_embeddings
        # if width dimension is to be interpolated
        if width > self.max_grid_col_position_embeddings:
            w0 = width / self.max_grid_col_position_embeddings
        embedding = embedding.permute(0, 3, 1, 2)  # (batch_size, hidden_dim, height, width)
        embedding = nn.functional.interpolate(
            embedding,
            scale_factor=(h0, w0),
            mode="bicubic",
            align_corners=False,
        )
        embedding = embedding.permute(0, 2, 3, 1)  # (batch_size, height, width, hidden_dim)
        return embedding

    def add_2d_positional_embeddings(self, grid, interpolate_pos_encoding: bool = False):
        """
        Args:
            grid: (batch_size, height, width, hidden_dim)
            interpolate_pos_encoding: (`bool`, *optional*, defaults to `False`):
                Whether to interpolate the pre-trained position encodings.
        Returns:
            grid + col_position_embeddings.view(*col_shape): (batch_size, *, height, width, hidden_dim)
        """
        batch_size, height, width, hidden_dim = grid.shape

        # add row-wise position embeddings
        row_height = min(self.max_grid_row_position_embeddings, height)
        row_position_ids = torch.arange(row_height, dtype=torch.long, device=grid.device)
        # (height, hidden_dim)
        row_position_embeddings = self.row_position_embeddings(row_position_ids)
        row_shape = (1,) * (len(grid.shape) - 3) + (row_height, 1, hidden_dim)
        # (1, height, 1, hidden_dim)
        row_position_embeddings = row_position_embeddings.view(*row_shape)

        # add column-wise position embeddings
        row_width = min(self.max_grid_col_position_embeddings, width)
        col_position_ids = torch.arange(row_width, dtype=torch.long, device=grid.device)
        # (width, hidden_dim)
        col_position_embeddings = self.col_position_embeddings(col_position_ids)
        col_shape = (1,) * (len(grid.shape) - 3) + (1, row_width, hidden_dim)
        # (1, 1, width, hidden_dim)
        col_position_embeddings = col_position_embeddings.view(*col_shape)
        positional_embeddings = row_position_embeddings + col_position_embeddings

        # interpolation is only triggered when the input is larger than the pre-trained grid
        if interpolate_pos_encoding and (
            height > self.max_grid_row_position_embeddings or width > self.max_grid_col_position_embeddings
        ):
            grid = grid + self.interpolate_pos_encoding(positional_embeddings, height, width)
        else:
            grid = grid + positional_embeddings
        return grid

    def forward(self, grid, interpolate_pos_encoding: bool = False):
        """
        Args:
            grid: Array of shape (batch_size, num_frames, height, width, num_channels).
                It contains processed frames extracted from videos, and is generated by Tvp image preprocessor. Note,
                num_frames can be 1
            interpolate_pos_encoding: (bool, *optional*, defaults to `False`):
                Whether to interpolate the pre-trained position encodings.

        Returns:
            embeddings: The embedding of grid with size (batch_size, height*width, num_channels)
        """
        batch_size, num_frames, height, width, num_channels = grid.shape
        # temporal mean pooling over frames: (batch_size, height, width, hidden_size)
        grid = grid.mean(1)
        grid = self.add_2d_positional_embeddings(grid, interpolate_pos_encoding=interpolate_pos_encoding)
        # flatten the grid into an image token sequence: (batch_size, height*width, num_channels)
        visual_tokens = grid.view(batch_size, -1, num_channels)
        visual_tokens_shape = visual_tokens.shape[:-1]
        device = visual_tokens.device

        # image token type embeddings
        token_type_ids = torch.zeros(visual_tokens_shape, dtype=torch.long, device=device)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = visual_tokens + token_type_embeddings
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)

        return embeddings
zTvpVisualInputEmbedding.forward)F)F)r   r   r   r   r'   r   Tensorintr}   boolr   rC   rD   r   r   r/   r    rc      s
   )rc   c                       s*   e Zd ZdZ fddZdddZ  ZS )TvpTextInputEmbeddingszGConstruct the embeddings from word, position and token_type embeddings.c                    sl   t    tj|j|j|jd| _t|j|j| _	t|j
|j| _tj|j|jd| _t|j| _d S )N)Zpadding_idxrd   )r&   r'   r   rf   Z
vocab_sizerI   Zpad_token_idword_embeddingsrg   rh   Ztype_vocab_sizerm   rn   ro   rp   rq   rr   rs   rt   r/   r   r    r'   &  s    
zTvpTextInputEmbeddings.__init__Nc                 C   s   |d ur|  }n|  d d }|d }|d ur8|jn|j}|d u rhtj|tj|d}|d|}|d u rtj|tj|d}|d u r| |}| 	|}| 
|}	|| |	 }
| |
}
| |
}
|
S )NrF   r   r   r   )sizer   r   r   r   Z	unsqueezeexpandr   r   rh   rm   rp   rs   )r.   	input_idsr   Zposition_idsZinputs_embedsZinput_shapeZ
seq_lengthr   rh   rm   r   r   r   r    rC   .  s$    





zTvpTextInputEmbeddings.forward)NNNNr   r   r   r   r'   rC   rD   r   r   r/   r    r   #  s   r   c                       sL   e Zd Z fddZdd ZejeedddZde	e
 d	d
dZ  ZS )TvpAttentionc                    s   t    |j|j dkr<t|ds<td|j d|j |j| _t|j|j | _| j| j | _t	
|j| j| _t	
|j| j| _t	
|j| j| _t	|j| _t	
|j|j| _t	j|j|jd| _t	|j| _t | _d S )Nr   Zembedding_sizezThe hidden size z4 is not a multiple of the number of attention heads rd   )r&   r'   rI   num_attention_headsrP   r,   r   attention_head_sizeall_head_sizer   Linearquerykeyvaluerq   Zattention_probs_dropout_probattn_dropoutdensern   ro   rp   rr   rs   setpruned_headsrt   r/   r   r    r'   H  s     
zTvpAttention.__init__c                    s   t |dkrd S t| j| j}t|| j }|D ](  t fdd| jD   d| < q2|d	 
d}tt ||  }t| j|| _t| j|| _t| j|| _t| j|dd| _| jt | | _| j| j | _| j|| _d S )Nr   c                 3   s   | ]}| k rd ndV  qdS )r   r   Nr   ).0hheadr   r    	<genexpr>d      z+TvpAttention.prune_heads.<locals>.<genexpr>rF   r   dim)r   r   onesr   r   r   r   sumrW   
contiguouseqr   r   r   r   r   r   r   r   r:   )r.   headsmaskindexr   r   r    prune_heads]  s     
zTvpAttention.prune_heads)tensorsequence_lengthr\   c                 C   s    | ||| j| jdd S )Nr   rS   )rW   r   r   	transposer   )r.   r   r   r\   r   r   r    _reshapet  s    zTvpAttention._reshapeNoutput_attentionsc                 C   s  |j d d \}}| |}| |}| |}	| |||}
| |||}| |	||}t|
|dd}|t	| j
 }|d ur|| }tjj|dd}| |}|d ur|| }t||}|dd }|||| j}| |}| |}| || }|r||fn|f}|S )NrS   rF   r   r   )rV   r   r   r   r   r   matmulr   mathsqrtr   r   rX   Zsoftmaxr   r   reshaper   r   rs   rp   )r.   r   attention_mask	head_maskr   r\   r   Zmixed_query_layerZmixed_key_layerZmixed_value_layerZquery_layerZ	key_layerZvalue_layerZattention_scoresZattention_probsZattn_outputoutputsr   r   r    rC   {  s.    





zTvpAttention.forward)NNN)r   r   r   r'   r   r   r   r   r   r   r   rC   rD   r   r   r/   r    r   G  s   
   r   c                       s0   e Zd Z fddZejejdddZ  ZS )TvpIntermediatec                    sB   t    t|j|j| _t|jt	r6t
|j | _n|j| _d S N)r&   r'   r   r   rI   intermediate_sizer   
isinstanceZ
hidden_actstrr   intermediate_act_fnrt   r/   r   r    r'     s
    
zTvpIntermediate.__init__r   rv   c                 C   s   |  |}| |}|S r   )r   r   )r.   r   r   r   r    rC     s    

zTvpIntermediate.forwardr   r   r   r'   r   r   rC   rD   r   r   r/   r    r     s   r   c                       s4   e Zd Z fddZejejejdddZ  ZS )TvpOutputLayerc                    sB   t    t|j|j| _tj|j|jd| _	t
|j| _d S )Nrd   )r&   r'   r   r   r   rI   r   rn   ro   rp   rq   rr   rs   rt   r/   r   r    r'     s    
zTvpOutputLayer.__init__)r   input_tensorrv   c                 C   s&   |  |}| |}| || }|S r   )r   rs   rp   )r.   r   r   r   r   r    rC     s    

zTvpOutputLayer.forwardr   r   r   r/   r    r     s   r   c                       s0   e Zd Z fddZdee dddZ  ZS )TvpEncodeLayerc                    s,   t    t|| _t|| _t|| _d S r   )r&   r'   r   	attentionr   intermediater   outputrt   r/   r   r    r'     s    


zTvpEncodeLayer.__init__Nr   c           
      C   sJ   | j ||||d}|d }|dd  }| |}| ||}	|	f| }|S )Nr   r   r   )r   r   r   )
r.   r   r   r   r   Zself_attention_outputsZattention_outputr   Zintermediate_outputZlayer_outputr   r   r    rC     s    

zTvpEncodeLayer.forward)NNN)r   r   r   r'   r   r   rC   rD   r   r   r/   r    r     s   	   r   c                       sD   e Zd Z fddZdeej ee ee ee dddZ  Z	S )
TvpEncoderc                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r   )r   )r   _rG   r   r    
<listcomp>  r   z'TvpEncoder.__init__.<locals>.<listcomp>F)	r&   r'   rG   r   Z
ModuleListrangenum_hidden_layerslayerZgradient_checkpointingrt   r/   r   r    r'     s    
 zTvpEncoder.__init__N)r   r   output_hidden_statesreturn_dictc                 C   s   |d ur|n| j j}|d ur |n| j j}|d ur4|n| j j}d}d}t| jD ]B\}	}
|rd||f }|
||||	 |}|d }|rN||d f }qN|r||f }|s|f}|r||f }|r||f }|S t||r|nd |r|nd dS )Nr   r   r   )last_hidden_stater   r   )rG   r   r   r   	enumerater   r   )r.   r   r   r   r   r   r   Zall_hidden_statesZall_attentionsiZlayer_moduleZlayer_outputsr   r   r   r    rC     s6    	





zTvpEncoder.forward)NNNNN)
r   r   r   r'   r   r   r   r   rC   rD   r   r   r/   r    r     s   	     r   c                       s0   e Zd Z fddZejejdddZ  ZS )	TvpPoolerc                    s*   t    t|j|j| _t | _d S r   )r&   r'   r   r   rI   r   ZTanh
activationrt   r/   r   r    r'     s    
zTvpPooler.__init__r   c                 C   s(   |d d df }|  |}| |}|S )Nr   )r   r   )r.   r   Zfirst_token_tensorpooled_outputr   r   r    rC     s    

zTvpPooler.forwardr   r   r   r/   r    r     s   r   c                   @   s.   e Zd ZU eed< dZdZejdddZ	dS )TvpPreTrainedModelrG   modelT)modulec                 C   s0  t |tjtjfr*|jjjd| jjd n|t |tj	rR|j
j  |jjd nTt |tjrtjj|jddd |j
durtj|j
d nt |trtj|j t |tjr|j
dur|j
j  t|d	rtj|j t|d
rtj|j t|drtj|j t|dr,tj|j dS )zInitialize the weights        )r   Zstdg      ?Zfan_outrY   )ry   ZnonlinearityNr   pad_uppad_downpad_left	pad_right)r   r   r   rf   weightdataZnormal_rG   Zinitializer_rangern   rN   Zzero_Zfill_rQ   initZkaiming_normal_Z	constant_TvpModeltext_promptrP   r   r   r   r   )r.   r   r   r   r    _init_weights.  s*    



z TvpPreTrainedModel._init_weightsN)
r   r   r   r   r   Zbase_model_prefixZsupports_gradient_checkpointingr   Moduler   r   r   r   r    r   (  s   
r   c                       s(   e Zd ZdZ fddZdd Z  ZS )TvpFrameDownPadPrompterz>
    Pad frames extracted from videos only at the bottom.
    c              	      sb   |j dvrtdt   |j| _|j| _|j| _|j | _ tt	
d|jd|j|jg| _d S )Nr<   replaceremove9`visual_prompter_apply` must be in (add, replace, remove)r   r   )visual_prompter_applyr,   r&   r'   visual_prompt_sizeZ	frame_nummax_img_sizer   	Parameterr   randnr   rt   r/   r   r    r'   O  s    

z TvpFrameDownPadPrompter.__init__c                 C   s   | j dkrLtj| j| jg|j|jd}d|| j| j | jd d f< ||9 }| j dkrtj|jd |jd d| j| jg|jd}| j| j }| j	|d d d d d d || jd d f< ||
|j7 }|S )	Nr<   r   r   r   r   r   r   r   )r   r   r   r   r   r   r   r   rV   r   to)r.   r[   visual_prompt_maskpromptZstart_pointr   r   r    rC   ]  s    

*zTvpFrameDownPadPrompter.forwardr   r   r   r/   r    r   J  s   r   c                       sH   e Zd ZdZ fddZejeeejdddZde	dd	d
Z
  ZS )TvpFramePadPrompterz?
    Pad frames extracted from videos in the surroundings.
    c              
      s   |j dvrtdt   |j| _|j| _|j | _ |j|jd  | _t	t
d|jd|j|jg| _t	t
d|jd|j|jg| _t	t
d|jd|j|jd  |jg| _t	t
d|jd|j|jd  |jg| _d S )Nr   r   rS   r   r   )r   r,   r&   r'   r]   r   r   	base_sizer   r   r   r   r   r   r   r   rt   r/   r   r    r'   t  sB    

zTvpFramePadPrompter.__init__)r   r_   r`   rv   c                 C   sh   || j  || j   }}|j\}}}}	}
||| ||	|
}tjj|||fddd}||||||}|S )z
        This method allows to interpolate the pre-trained pad weights, to be able to use the model on collection of high
        resolution images (high resolution videos).

        rw   Frx   )r   rV   r   r   rX   rz   )r.   r   r_   r`   r{   r|   batchr]   ZchannelsZprompt_heightZprompt_widthr   r   r    interpolate_pad_encoding  s    z,TvpFramePadPrompter.interpolate_pad_encodingFr   c                 C   s   |r|j d |j d fn
| j| jf\}}| jdvrBtd| j | jdv rltj||g|j|jd}||9 }| jdv rtjd| j	d	| j
| j
|jd
}tj| j|| jgdd}tj| j|| jgd	d}t|d|g }|r| |||}|||j }|S )Nr   rF   )r<   r   r   z$Invalid visual_prompter_apply value )r   r   r   )r   r<   r   r   r   rU   r   r   )rV   r   r   r,   r   r   r   r   r   r]   r   catr   r   r   r   r   r   r   )r.   r[   r   r_   r`   r   baser   r   r   r    rC     s$    



zTvpFramePadPrompter.forward)F)r   r   r   r   r'   r   r   r   r   r   rC   rD   r   r   r/   r    r   o  s   &r   )ZframedownpadZframepadzw
@auto_docstring(
    custom_intro="""
    The bare Tvp Model transformer outputting BaseModelOutputWithPooling object without any specific head on top.
    """
)
class TvpModel(TvpPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.vision_model = TvpVisionModel(config)
        self.embeddings = TvpTextInputEmbeddings(config)
        self.visual_embeddings = TvpVisualInputEmbedding(config)
        self.encoder = TvpEncoder(config)
        self.pooler = TvpPooler(config)
        self.text_prompt = nn.Parameter(torch.randn([1, 10, config.hidden_size]))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        if config.visual_prompter_type not in TVP_PROMPTER_CLASSES_MAPPING:
            raise ValueError("`visual_prompter_type` must be in (framedownpad, framepad)")
        self.visual_prompter = TVP_PROMPTER_CLASSES_MAPPING[config.visual_prompter_type](config)

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """Prunes heads of the model.
        heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ):
        r"""
        Examples:
        ```python
        >>> import torch
        >>> from transformers import AutoConfig, AutoTokenizer, TvpModel

        >>> model = TvpModel.from_pretrained("Jiqing/tiny-random-tvp")

        >>> tokenizer = AutoTokenizer.from_pretrained("Jiqing/tiny-random-tvp")

        >>> pixel_values = torch.rand(1, 1, 3, 448, 448)
        >>> text_inputs = tokenizer("This is an example input", return_tensors="pt")
        >>> output = model(text_inputs.input_ids, pixel_values, text_inputs.attention_mask)
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # add the visual prompt; it compensates for the spatiotemporal information loss in 2D visual features
        pixel_values = self.vision_model(
            self.visual_prompter(pixel_values, interpolate_pad_encoding=interpolate_pos_encoding)
        )
        # (batch_size, sequence_length, hidden_size)
        text_embedding_output = self.embeddings(input_ids=input_ids)
        # (batch_size, visual_sequence_length, hidden_size)
        visual_embedding_output = self.visual_embeddings(
            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
        )
        if attention_mask is not None:
            # (batch_size, visual_sequence_length)
            visual_attention_mask = attention_mask.new_ones(visual_embedding_output.shape[:2])
            pt_mask = torch.ones(attention_mask.shape[0], 10).to(
                device=attention_mask.device, dtype=attention_mask.dtype
            )
            attention_mask = torch.cat([pt_mask, attention_mask, visual_attention_mask], dim=-1)
            attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.size()).to(input_ids.device)
        text_prompt = self.text_prompt.expand(text_embedding_output.shape[0], -1, -1)
        # (batch_size, text_prompt_length + sequence_length + visual_sequence_length, hidden_size)
        embedding_output = torch.cat([text_prompt, text_embedding_output, visual_embedding_output], dim=1)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=attention_mask,
            head_mask=self.get_head_mask(head_mask, self.config.num_hidden_layers),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = encoder_outputs.last_hidden_state if return_dict else encoder_outputs[0]
        pooled_output = self.pooler(last_hidden_state)
        last_hidden_state = self.dropout(last_hidden_state)
        pooled_output = self.dropout(pooled_output)
        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class TvpVideoGroundingHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer_0 = nn.Linear(config.hidden_size, config.hidden_size * 2)
        self.layer_1 = nn.Linear(config.hidden_size * 2, 2)
        self.activation_0 = nn.ReLU()
        self.activation_1 = nn.Sigmoid()

    def forward(self, pooler_output):
        logits = self.activation_0(self.layer_0(pooler_output))
        logits = self.activation_1(self.layer_1(logits))
        return logits
@auto_docstring(
    custom_intro="""
    Tvp Model with a video grounding head on top computing IoU, distance, and duration loss.
    """
)
class TvpForVideoGrounding(TvpPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.model = TvpModel(config)
        self.video_grounding_head = TvpVideoGroundingHead(config)

        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        labels: Optional[tuple[torch.Tensor]] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ):
        r"""
        labels (`torch.FloatTensor` of shape `(batch_size, 3)`, *optional*):
            The labels contains duration, start time, and end time of the video corresponding to the text.

        Examples:
        ```python
        >>> import torch
        >>> from transformers import AutoConfig, AutoTokenizer, TvpForVideoGrounding

        >>> model = TvpForVideoGrounding.from_pretrained("Jiqing/tiny-random-tvp")

        >>> tokenizer = AutoTokenizer.from_pretrained("Jiqing/tiny-random-tvp")

        >>> pixel_values = torch.rand(1, 1, 3, 448, 448)
        >>> text_inputs = tokenizer("This is an example input", return_tensors="pt")
        >>> output = model(text_inputs.input_ids, pixel_values, text_inputs.attention_mask)
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.model(
            input_ids,
            pixel_values,
            attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        pooler_output = outputs[1]
        logits = self.video_grounding_head(pooler_output)

        loss = None
        if labels is not None:
            criterion = TvpLoss(["iou", "distance", "duration"])
            criterion.to(self.device)
            loss_dict = criterion(logits, labels)
            loss = (
                loss_dict["iou"]
                + self.config.distance_loss_weight * loss_dict["distance"]
                + self.config.duration_loss_weight * loss_dict["duration"]
            )
        if not return_dict:
            outputs = (logits,) + outputs[2:]
            if loss is not None:
                outputs = (loss,) + outputs
            return outputs

        return TvpVideoGroundingOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = ["TvpModel", "TvpPreTrainedModel", "TvpForVideoGrounding"]