a
    hM                     @   s   d dl mZmZ ddlmZ ddlmZmZ ddlm	Z	m
Z
mZmZ ddlmZmZ ddlmZ e rnd dlZG d	d
 d
e
ddZG dd deZdgZdS )    )OptionalUnion   )BatchFeature)
ImageInputis_valid_image)MultiModalDataProcessingKwargsProcessorMixinUnpack)PreTokenizedInput	TextInput)is_torch_availableNc                   @   s&   e Zd ZddidddddidZd	S )
ColQwen2ProcessorKwargspaddingZlongestZchannels_firstT)Zdata_formatZdo_convert_rgbZreturn_tensorspt)text_kwargsimages_kwargsZcommon_kwargsN)__name__
__module____qualname__	_defaults r   r   l/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/colqwen2/processing_colqwen2.pyr   #   s   r   F)totalc                       s  e Zd ZdZddgZdZdZd"ee ee d fdd	Z	d#e
eeeee ee f ee ed
ddZd$ddZeedddZd%e
ee edddZeeee f ee edddZd&eded f eded f eed edef ddddZed d! Z  ZS )'ColQwen2Processora  
    Constructs a ColQwen2 processor which wraps a Qwen2VLProcessor and special methods to process images and queries, as
    well as to compute the late-interaction retrieval score.

    [`ColQwen2Processor`] offers all the functionalities of [`Qwen2VLProcessor`]. See the [`~Qwen2VLProcessor.__call__`]
    for more information.

    Args:
        image_processor ([`Qwen2VLImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`Qwen2TokenizerFast`], *optional*):
            The tokenizer is a required input.
        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
            in a chat into a tokenizable string.
        visual_prompt_prefix (`str`, *optional*): A string that gets tokenized and prepended to the image tokens.
        query_prefix (`str`, *optional*): A prefix to be used for the query.
    image_processor	tokenizerZAutoImageProcessor)ZQwen2TokenizerZQwen2TokenizerFastN)visual_prompt_prefixquery_prefixc                    sf   t  j|||d t|ds dn|j| _t|ds6dn|j| _|d u rJd}|| _|d u r\d}|| _d S )N)chat_templateimage_tokenz<|image_pad|>video_tokenz<|video_pad|>zf<|im_start|>user
<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|endoftext|>zQuery: )super__init__hasattrr!   r"   r   r   )selfr   r   r    r   r   kwargs	__class__r   r   r$   H   s    	zColQwen2Processor.__init__)imagestextr'   returnc                 K   s  | j tfd| jji|}|d dd}|du}|du rJ|du rJtd|durb|durbtd|dur0t|r||g}nHt|trt|d rn0t|trt|d trt|d d std| j	gt
| }	| jf d	|i|d
 }
|
d }|dur|| jjd }d}tt
|	D ]`}| j|	| v rb|	| | jd||  |  d|	|< |d7 }q|	| d| j|	|< q| j|	fddi|d }ti ||
d}|d dddf |d dddf  }tt|d | }tjjjj|dd|d< |r,|d |d dkd}|d|i |S |durt|trN|g}n$t|trjt|d tsrtd|du r| jd }g }|D ]}| j| | }|| q| j|fddi|d }|S dS )a	  
        Main method to prepare for the model either (1) one or several texts, either (2) one or several image(s). This method is a custom
        wrapper around the Qwen2VLProcessor's [`~Qwen2VLProcessor.__call__`] method adapted for the ColQwen2 model. It cannot process
        both text and images at the same time.

        When preparing the the text(s), this method forwards the `text` and `kwargs` arguments to Qwen2TokenizerFast's
        [`~Qwen2TokenizerFast.__call__`].
        When preparing the the image(s), this method forwards the `images` and `kwargs` arguments to Qwen2VLImageProcessor's
        [`~Qwen2VLImageProcessor.__call__`].
        Please refer to the doctsring of the above two methods for more information.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
                number of channels, H and W are image height and width.
            text (`str`, `list[str]`, `list[list[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
        Ztokenizer_init_kwargsr   suffixNz&Either text or images must be providedz5Only one of text or images can be processed at a timer   zAimages must be an image, list of images or list of list of imagesr*   r   image_grid_thw   z<|placeholder|>   return_token_type_idsF)datapixel_valuesT)batch_firstZ	input_idsZtoken_type_idsilabelsz*Text must be a string or a list of strings
   )Z_merge_kwargsr   r   Zinit_kwargspop
ValueErrorr   
isinstancelistr   lenr   
merge_sizeranger!   replaceprodr   torchsplittolistnnutilsrnnpad_sequenceZmasked_fillupdatestrquery_augmentation_tokenr   append)r&   r*   r+   ZaudioZvideosr'   Zoutput_kwargsr-   r1   Z	texts_docZimage_inputsr.   Zmerge_lengthindexiZtext_inputsZreturn_dataoffsetsr3   r5   Ztexts_queryqueryZaugmented_queryZbatch_queryr   r   r   __call__]   s    -
(
(




zColQwen2Processor.__call__c                    s|   i }|durnt jdi   |  ddp6jj fdd|D }fdd|D }|||d tf i |S )a  
        Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
        Args:
            image_sizes (`list[list[int]]`, *optional*):
                The input sizes formatted as (height, width) per each image.
        Returns:
            `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
            input modalities, along with other useful data.
        Nr   r<   c                    s"   g | ]}j jg | R  qS r   )r   Zget_number_of_image_patches).0Z
image_size)r   r&   r   r   
<listcomp>   s   z@ColQwen2Processor._get_num_multimodal_tokens.<locals>.<listcomp>c                    s   g | ]}| d   qS )r/   r   )rP   Znum_patches)r<   r   r   rQ          )num_image_tokensnum_image_patches)r   r   getrG   r   r<   r   )r&   Zimage_sizesr'   Zvision_datarT   rS   r   )r   r<   r&   r   _get_num_multimodal_tokens   s    
z,ColQwen2Processor._get_num_multimodal_tokens)r,   c                 C   s   | j jS )z
        Return the query augmentation token.

        Query augmentation buffers are used as reasoning buffers during inference.
        )r   Z	pad_token)r&   r   r   r   rI      s    z*ColQwen2Processor.query_augmentation_token)r*   r'   r,   c                 K   s   | j f d|i|S )a  
        Prepare for the model one or several image(s). This method is a wrapper around the `__call__` method of the ColQwen2Processor's
        [`ColQwen2Processor.__call__`].

        This method forwards the `images` and `kwargs` arguments to the image processor.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
                number of channels, H and W are image height and width.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
        r*   rO   )r&   r*   r'   r   r   r   process_images  s    !z ColQwen2Processor.process_images)r+   r'   r,   c                 K   s   | j f d|i|S )a  
        Prepare for the model one or several texts. This method is a wrapper around the `__call__` method of the ColQwen2Processor's
        [`ColQwen2Processor.__call__`].

        This method forwards the `text` and `kwargs` arguments to the tokenizer.

        Args:
            text (`str`, `list[str]`, `list[list[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
        r+   rW   )r&   r+   r'   r   r   r   process_queries(  s     z!ColQwen2Processor.process_queries   cpuztorch.Tensorztorch.dtypeztorch.device)query_embeddingspassage_embeddings
batch_sizeoutput_dtypeoutput_devicer,   c              	   C   s@  t |dkrtdt |dkr(td|d j|d jkrDtd|d j|d jkr`td|du rr|d j}g }tdt ||D ]}g }tjjjj	||||  ddd}	tdt ||D ]N}
tjjjj	||
|
|  ddd}|
td	|	|jd
dd jdd q|
tj|dd|| qtj|ddS )a[  
        Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
        query embeddings (`qs`) and passage embeddings (`ps`). For ColQwen2, a passage is the
        image of a document page.

        Because the embedding tensors are multi-vector and can thus have different shapes, they
        should be fed as:
        (1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim)
        (2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim) -> usually
            obtained by padding the list of tensors.

        Args:
            query_embeddings (`Union[torch.Tensor, list[torch.Tensor]`): Query embeddings.
            passage_embeddings (`Union[torch.Tensor, list[torch.Tensor]`): Passage embeddings.
            batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores.
            output_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): The dtype of the output tensor.
                If `None`, the dtype of the input embeddings is used.
            output_device (`torch.device` or `str`, *optional*, defaults to "cpu"): The device of the output tensor.

        Returns:
            `torch.Tensor`: A tensor of shape `(n_queries, n_passages)` containing the scores. The score
            tensor is saved on the "cpu" device.
        r   zNo queries providedzNo passages providedz/Queries and passages must be on the same devicez-Queries and passages must have the same dtypeNT)r4   Zpadding_valuezbnd,csd->bcnsr   )dimr/   r0   )r;   r8   ZdeviceZdtyper=   r@   rC   rD   rE   rF   rJ   Zeinsummaxsumcatto)r&   r\   r]   r^   r_   r`   ZscoresrL   Zbatch_scoresZbatch_queriesjZbatch_passagesr   r   r   score_retrievalJ  s2     


 "z!ColQwen2Processor.score_retrievalc                 C   s&   | j j}| jj}dd |D }|| S )Nc                 S   s   g | ]}|d vr|qS ))Zpixel_values_videosZvideo_grid_thwr   )rP   namer   r   r   rQ     s   z7ColQwen2Processor.model_input_names.<locals>.<listcomp>)r   model_input_namesr   )r&   Ztokenizer_input_namesZimage_processor_input_namesr   r   r   ri     s    z#ColQwen2Processor.model_input_names)NNNNN)NNNN)N)N)rZ   Nr[   )r   r   r   __doc__
attributesZimage_processor_classZtokenizer_classr   rH   r$   r   r   r   r   r:   r   r   r   rO   rV   propertyrI   rX   rY   intrg   ri   __classcell__r   r   r(   r   r   0   sd             

 %&   
@r   )typingr   r   Zfeature_extraction_utilsr   Zimage_utilsr   r   Zprocessing_utilsr   r	   r
   r   Ztokenization_utils_baser   r   rD   r   r@   r   r   __all__r   r   r   r   <module>   s     i