from typing import Optional, Union

import torch
import torch.utils.checkpoint

from transformers.models.instructblip.configuration_instructblip import (
    InstructBlipQFormerConfig,
    InstructBlipVisionConfig,
)
from transformers.models.instructblip.modeling_instructblip import (
    InstructBlipForConditionalGeneration,
    InstructBlipForConditionalGenerationModelOutput,
    InstructBlipModel,
    InstructBlipPreTrainedModel,
    InstructBlipQFormerModel,
    InstructBlipVisionModel,
    TransformersKwargs,
)

from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from ...processing_utils import Unpack
from ...utils import logging
from ..auto import CONFIG_MAPPING, AutoConfig


logger = logging.get_logger(__name__)


class InstructBlipVideoVisionConfig(InstructBlipVisionConfig):
    pass


class InstructBlipVideoQFormerConfig(InstructBlipQFormerConfig):
    pass


class InstructBlipVideoConfig(PretrainedConfig):
    r"""
    [`InstructBlipVideoConfig`] is the configuration class to store the configuration of a
    [`InstructBlipVideoForConditionalGeneration`]. It is used to instantiate an InstructBlipVideo model according to the specified
    arguments, defining the vision model, Q-Former model and language model configs. Instantiating a configuration with
    the defaults will yield a similar configuration to that of the InstructBlipVideo
    [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`InstructBlipVideoVisionConfig`].
        qformer_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`InstructBlipVideoQFormerConfig`].
        text_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize any [`PretrainedConfig`].
        num_query_tokens (`int`, *optional*, defaults to 32):
            The number of query tokens passed through the Transformer.
        video_token_index (`int`, *optional*):
            Token index of special video token.
        kwargs (*optional*):
            Dictionary of keyword arguments.

    Example:

    ```python
    >>> from transformers import (
    ...     InstructBlipVideoVisionConfig,
    ...     InstructBlipVideoQFormerConfig,
    ...     OPTConfig,
    ...     InstructBlipVideoConfig,
    ...     InstructBlipVideoForConditionalGeneration,
    ... )

    >>> # Initializing an InstructBlipVideoConfig with Salesforce/instruct-blip-flan-t5 style configuration
    >>> configuration = InstructBlipVideoConfig()

    >>> # Initializing an InstructBlipVideoForConditionalGeneration (with random weights) from the Salesforce/instruct-blip-flan-t5 style configuration
    >>> model = InstructBlipVideoForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config

    >>> # We can also initialize an InstructBlipVideoConfig from an InstructBlipVideoVisionConfig, an InstructBlipVideoQFormerConfig and any PretrainedConfig

    >>> # Initializing InstructBlipVideo vision, InstructBlipVideo Q-Former and language model configurations
    >>> vision_config = InstructBlipVideoVisionConfig()
    >>> qformer_config = InstructBlipVideoQFormerConfig()
    >>> text_config = OPTConfig()

    >>> config = InstructBlipVideoConfig.from_vision_qformer_text_configs(vision_config, qformer_config, text_config)
    ```Zinstructblipvideovideo_token_idvideo_token_index)text_configqformer_configvision_configN    c                    s   t  jf i | |d u r(i }td |d u r>i }td |d u rTi }td tf i || _tf i || _|dd}t	| f i || _
|| _|| _| jj| j_| j
jtv | _d| _d| _d S )NzZvision_config is None. initializing the InstructBlipVideoVisionConfig with default values.z\qformer_config is None. Initializing the InstructBlipVideoQFormerConfig with default values.zTtext_config is None. Initializing the text config with default values (`OPTConfig`).
model_typeoptg      ?g{Gz?)super__init__loggerinfor   r#   r   r"   getr   r!   num_query_tokensr    Zhidden_sizeZencoder_hidden_sizer%   r   use_decoder_only_language_modelZinitializer_factorZinitializer_range)selfr#   r"   r!   r,   r    kwargsZtext_model_type	__class__r   r   r(   x   s(    	


z InstructBlipVideoConfig.__init__r#   r"   r!   c                 K   s"   | f |  |  |  d|S )a  
        Instantiate a [`InstructBlipVideoConfig`] (or a derived class) from an InstructBlipVideo vision model, Q-Former and
        language model configurations.

        Returns:
            [`InstructBlipVideoConfig`]: An instance of a configuration object
        """
        return cls(
            vision_config=vision_config.to_dict(),
            qformer_config=qformer_config.to_dict(),
            text_config=text_config.to_dict(),
            **kwargs,
        )


class InstructBlipVideoPreTrainedModel(InstructBlipPreTrainedModel):
    pass


class InstructBlipVideoVisionModel(InstructBlipVisionModel):
    pass


class InstructBlipVideoQFormerModel(InstructBlipQFormerModel):
    pass


class InstructBlipVideoForConditionalGenerationModelOutput(InstructBlipForConditionalGenerationModelOutput):
    pass


class InstructBlipVideoModel(InstructBlipModel):
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        qformer_input_ids: torch.FloatTensor,
        qformer_attention_mask: Optional[torch.LongTensor] = None,
        input_ids: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        use_cache: Optional[bool] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Fold the frame dimension into the batch dimension so frames are processed like a batch of images.
        batch_size, frames, channel, height, width = pixel_values.shape
        pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        image_embeds = vision_outputs[0]
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        # Feed the learnable query tokens (together with the instruction) through the Q-Former,
        # cross-attending to the per-frame image embeddings.
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
        if qformer_attention_mask is None:
            qformer_attention_mask = torch.ones_like(qformer_input_ids)

        qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
        qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
        qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
        query_outputs = self.qformer(
            input_ids=qformer_input_ids,
            attention_mask=qformer_attention_mask,
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        query_output = query_outputs[0][:, : query_tokens.size(1), :]

        # Project to the language model's embedding space and unfold the frames again,
        # so each video contributes `num_query_tokens * frames` soft tokens.
        language_model_inputs = self.language_projection(query_output)
        language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)

        if inputs_embeds is None:
            inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        if input_ids is None:
            special_image_mask = inputs_embeds == self.language_model.get_input_embeddings()(
                torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
        else:
            special_image_mask = input_ids == self.config.video_token_id

        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)

        if self.config.use_decoder_only_language_model:
            outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                use_cache=use_cache,
                **kwargs,
            )
        else:
            outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                use_cache=use_cache,
                **kwargs,
            )

        return InstructBlipVideoForConditionalGenerationModelOutput(
            vision_outputs=vision_outputs,
            qformer_outputs=query_outputs,
            language_model_outputs=outputs,
        )


class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration):
    def get_video_features(
        self,
        pixel_values: torch.FloatTensor,
        qformer_input_ids: torch.LongTensor,
        qformer_attention_mask: Optional[torch.LongTensor] = None,
        interpolate_pos_encoding: Optional[bool] = False,
        return_dict: Optional[bool] = False,
    ):
        r"""
        Encodes videos into continuous embeddings that can be forwarded to the language model.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
                The tensors corresponding to the input images.
        T)r>   rI   rH   r   NrL   rM   rP   rR   )rA   rB   rS   rT   rU   rH   )r\   r]   r^   r_   r`   ra   rb   rO   rc   rd   re   rf   rg   rh   ri   rZ   r,   )r.   r>   r?   r@   rI   rH   rr   rs   rt   ru   rv   rW   rw   rx   rc   ry   rz   r{   r|   r   r   r   get_video_features1  s<      
$

z<InstructBlipVideoForConditionalGeneration.get_video_featuresc                 C   s   d S )Nr   )r.   r>   r?   r@   rI   rH   r   r   r   get_image_featuresm  s    z<InstructBlipVideoForConditionalGeneration.get_image_features)rA   rE   c                 C   s`   |du r8||   tj| jjtj|jdk}|d}n|| jjk}|d	|
|j}|S )zZ
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`.
        """
        if input_ids is None:
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
        else:
            special_image_mask = input_ids == self.config.video_token_id

        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        return special_image_mask

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        qformer_input_ids: torch.FloatTensor,
        qformer_attention_mask: Optional[torch.LongTensor] = None,
        input_ids: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        use_cache: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
        r"""
        qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length)):
            The sequence used as a prompt to be fed to the Q-Former module.
        qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
            Mask to avoid performing attention on padding token indices.

        Examples:

        ```python
        >>> from transformers import InstructBlipVideoProcessor, InstructBlipVideoForConditionalGeneration
        >>> import torch
        >>> from huggingface_hub import hf_hub_download
        >>> import av
        >>> import numpy as np

        >>> def read_video_pyav(container, indices):
        ...     '''
        ...     Decode the video with PyAV decoder.
        ...     Args:
        ...         container (`av.container.input.InputContainer`): PyAV container.
        ...         indices (`list[int]`): List of frame indices to decode.
        ...     Returns:
        ...         result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
        ...     '''
        ...     frames = []
        ...     container.seek(0)
        ...     start_index = indices[0]
        ...     end_index = indices[-1]
        ...     for i, frame in enumerate(container.decode(video=0)):
        ...         if i > end_index:
        ...             break
        ...         if i >= start_index and i in indices:
        ...             frames.append(frame)
        ...     return np.stack([x.to_ndarray(format="rgb24") for x in frames])

        >>> model = InstructBlipVideoForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", device_map="auto")
        >>> processor = InstructBlipVideoProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")

        >>> file_path = hf_hub_download(
        ...       repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
        ... )
        >>> container = av.open(file_path)

        >>> # sample uniformly 4 frames from the video
        >>> total_frames = container.streams.video[0].frames
        >>> indices = np.arange(0, total_frames, total_frames / 4).astype(int)
        >>> clip = read_video_pyav(container, indices)

        >>> prompt = "What is happening in the video?"
        >>> inputs = processor(text=prompt, images=clip, return_tensors="pt").to(model.device)

        >>> outputs = model.generate(
        ...     **inputs,
        ...     do_sample=False,
        ...     num_beams=5,
        ...     max_length=256,
        ...     repetition_penalty=1.5,
        ...     length_penalty=1.0,
        ... )
        >>> generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
        >>> print(generated_text)
        "A person is eating a bowl of pasta, and they are using a fork to eat it. The person is sitting at a table, and the plate of pasta is on the table in front"
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        language_model_inputs, vision_outputs, query_outputs = self.get_video_features(
            pixel_values,
            qformer_input_ids=qformer_input_ids,
            qformer_attention_mask=qformer_attention_mask,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=True,
        )
        vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs
        query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # Scatter the projected video features into the placeholder positions of the text embeddings.
        language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
        special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)

        if self.config.use_decoder_only_language_model:
            outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                use_cache=use_cache,
                **kwargs,
            )
            logits = outputs.logits if return_dict else outputs[0]
            loss = None
            if labels is not None:
                loss = self.loss_function(
                    logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
                )
        else:
            outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                labels=labels,
                use_cache=use_cache,
                **kwargs,
            )
            loss = outputs.loss if return_dict else outputs[0]
            logits = outputs.logits if return_dict else outputs[1]

        return InstructBlipVideoForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            vision_outputs=vision_outputs,
            qformer_outputs=query_outputs,
            language_model_outputs=outputs,
        )

    @torch.no_grad()
    def generate(
        self,
        pixel_values: torch.FloatTensor,
        qformer_input_ids: Optional[torch.LongTensor] = None,
        qformer_attention_mask: Optional[torch.LongTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        interpolate_pos_encoding: bool = False,
        **generate_kwargs,
    ) -> torch.LongTensor:
        r"""
        Overrides the `generate` function to be able to use the model as a conditional generator.

        Args:
            pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width) or
                (batch_size, num_frames, num_channels, height, width)): Input images or videos to be processed.
            qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                The sequence used as a prompt to be fed to the Q-Former module.
            qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                Mask to avoid performing attention on padding token indices.
            input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                The sequence used as a prompt for the generation.
            attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                Mask to avoid performing attention on padding token indices.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded representation of the inputs. Should be float, not int tokens.
            interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
                Whether to interpolate the positional encoding of the image embeddings.

        Returns:
            captions (list): A list of strings of length batch_size * num_captions.
        Zhf_device_mapr   Tr   N   rM   rR   r   )rE   rB   rA   )hasattrZ_preprocess_accelerater\   r   rZ   r    r,   r!   Zbos_token_idr_   rl   rb   rO   repeatrk   re   rp   rN   r   rq   rj   Zis_encoder_decodergenerate)r.   r>   r?   r@   rA   rB   rE   rI   Zgenerate_kwargsrr   r|   rW   rz   Zvideo_tokensZstart_tokensr}   inputsr~   r   r   r   r     s6    "




z2InstructBlipVideoForConditionalGeneration.generate)NFF)NFF)NNNNNNNNNNFN)NNNNNF)r   r   r   r_   r   r   r   r   r   r   r   r   r   r   r   r<   r   Zno_gradr   r   r   r   r   r   0  s      @   
            
       r   )r   r   r   r:   r9   r;   r=   r   )*typingr   r   r_   Ztorch.utils.checkpointZ;transformers.models.instructblip.configuration_instructblipr   r   Z6transformers.models.instructblip.modeling_instructblipr   r   r   r	   r
   r   r   Zconfiguration_utilsr   Zmodeling_flash_attention_utilsr   Zmodels.auto.modeling_autor   Zprocessing_utilsr   utilsr   autor   r   Z
get_loggerr   r)   r   r   r   r9   r:   r;   r<   r=   r   __all__r   r   r   r   <module>   s.   $

}m  /
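
# Minimal usage sketch (illustration only, not part of the module above). It assumes the tiny,
# arbitrary hyperparameters below, which are for demonstration and do not match any released
# checkpoint. The sketch shows (a) how the configuration classes compose via
# `from_vision_qformer_text_configs`, and (b) the frame/query-token bookkeeping performed in
# `InstructBlipVideoModel.forward`: frames are folded into the batch dimension for the vision
# tower and Q-Former, then unfolded so each video contributes `num_query_tokens * frames` soft
# tokens to the language model.
if __name__ == "__main__":
    from transformers import OPTConfig

    vision_config = InstructBlipVideoVisionConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=4)
    qformer_config = InstructBlipVideoQFormerConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=4)
    text_config = OPTConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=4, ffn_dim=64)

    config = InstructBlipVideoConfig.from_vision_qformer_text_configs(vision_config, qformer_config, text_config)
    # The Q-Former's `encoder_hidden_size` is tied to the vision hidden size inside `__init__`.
    print(config.qformer_config.encoder_hidden_size)  # 32
    # `generate` builds a default prompt with `num_query_tokens * 4` video placeholder tokens
    # (4 frames per video), so a custom prompt must carry the same number of placeholders.
    print(config.num_query_tokens * 4)  # 128 with the default num_query_tokens=32

    # Shape bookkeeping mirrored with random tensors (no model weights involved).
    batch_size, frames, hidden = 2, 4, 16
    pixel_values = torch.randn(batch_size, frames, 3, 224, 224)
    flat_pixel_values = pixel_values.reshape(batch_size * frames, 3, 224, 224)  # fed to the vision tower
    query_output = torch.randn(flat_pixel_values.shape[0], config.num_query_tokens, hidden)  # per-frame Q-Former output
    language_model_inputs = query_output.reshape(batch_size, config.num_query_tokens * frames, -1)
    print(language_model_inputs.shape)  # torch.Size([2, 128, 16])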