a
    h=                     @   s   d dl mZmZ d dlmZmZmZ ddlmZ ddl	m
Z
mZ ddlmZmZ ddlmZmZ ddlmZmZ e r~d d	lZeeZG d
d deddZG dd deZdgZd	S )    )OptionalUnion)IMAGE_TOKENPaliGemmaProcessorbuild_string_from_input   )BatchFeature)
ImageInputmake_flat_list_of_images)ProcessingKwargsUnpack)PreTokenizedInput	TextInput)is_torch_availableloggingNc                   @   s&   e Zd ZddidddddidZd	S )
ColPaliProcessorKwargspaddingZlongestZchannels_firstT)Zdata_formatZdo_convert_rgbZreturn_tensorspt)text_kwargsimages_kwargsZcommon_kwargsN)__name__
__module____qualname__	_defaults r   r   g/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/colpali/modular_colpali.pyr   #   s   r   F)totalc                       s   e Zd ZdZdeed fddZeedd	d
Zdee	e
eee
 ee f ee edddZdeee edddZe	e
ee
 f ee edddZde	ded f e	ded f eed e	def ddddZ  ZS ) ColPaliProcessora  
    Constructs a ColPali processor which wraps a PaliGemmaProcessor and special methods to process images and queries, as
    well as to compute the late-interaction retrieval score.

    [`ColPaliProcessor`] offers all the functionalities of [`PaliGemmaProcessor`]. See the [`~PaliGemmaProcessor.__call__`]
    for more information.

    Args:
        image_processor ([`SiglipImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`LlamaTokenizerFast`], *optional*):
            The tokenizer is a required input.
        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
            in a chat into a tokenizable string.
        visual_prompt_prefix (`str`, *optional*, defaults to `"Describe the image."`):
            A string that gets tokenized and prepended to the image tokens.
        query_prefix (`str`, *optional*, defaults to `"Question: "`):
            A prefix to be used for the query.
    NDescribe the image.
Question: )visual_prompt_prefixquery_prefixc                    s"   t  j|||d || _|| _d S )N)image_processor	tokenizerchat_template)super__init__r    r!   )selfr"   r#   r$   r    r!   	__class__r   r   r&   E   s    zColPaliProcessor.__init__)returnc                 C   s   | j jS )z
        Return the query augmentation token.

        Query augmentation buffers are used as reasoning buffers during inference.
        )r#   Z	pad_tokenr'   r   r   r   query_augmentation_tokenQ   s    z)ColPaliProcessor.query_augmentation_token)imagestextkwargsr*   c                    s   j tfd jji|}|d dd}|du}|du rJ|du rJtd|durb|durbtd|durZ j|}t|} j	gt
| }	dd |D } fd	dt|	|D }
 j|fi |d
 d }|d dddur|d d   j7  <  j|
fddi|d }i |d|i}|rP|d |d dkd}|d|i t|dS |durt|trx|g}n$t|trt|d tstd|du r jd }g }|D ]*} jj j | | d }|| q|d dd|d d<  j|fddi|d }|S dS )a	  
        Main method to prepare for the model either (1) one or several texts, either (2) one or several image(s). This method is a custom
        wrapper around the PaliGemmaProcessor's [`~PaliGemmaProcessor.__call__`] method adapted for the ColPali model. It cannot process
        both text and images at the same time.

        When preparing the text(s), this method forwards the `text` and `kwargs` arguments to LlamaTokenizerFast's
        [`~LlamaTokenizerFast.__call__`].
        When preparing the image(s), this method forwards the `images` and `kwargs` arguments to SiglipImageProcessor's
        [`~SiglipImageProcessor.__call__`].
        Please refer to the docstring of the above two methods for more information.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
                number of channels, H and W are image height and width.
            text (`str`, `list[str]`, `list[list[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
        Ztokenizer_init_kwargsr   suffixNz&Either text or images must be providedz5Only one of text or images can be processed at a timec                 S   s   g | ]}| d qS )RGB)convert).0imager   r   r   
<listcomp>       z-ColPaliProcessor.__call__.<locals>.<listcomp>c              
      s:   g | ]2\}}t | jj jtt|tr.t|nd dqS )   )prompt	bos_tokenZimage_seq_lenZimage_tokenZ
num_images)r   r#   r9   image_seq_lengthr   
isinstancelistlen)r3   r8   Z
image_listr+   r   r   r5      s   r   pixel_values
max_lengthreturn_token_type_idsFZ	input_idsZtoken_type_idsr   ilabels)dataz*Text must be a string or a list of strings
   
2   )Z_merge_kwargsr   r#   Zinit_kwargspop
ValueErrorr"   Zfetch_imagesr
   r    r=   zipgetr:   Zmasked_fillupdater   r;   strr<   r,   r9   r!   append)r'   r-   r.   ZaudioZvideosr/   Zoutput_kwargsr0   r@   Z	texts_docZinput_stringsr>   inputsZreturn_datarA   Ztexts_queryqueryZbatch_queryr   r+   r   __call__Z   sp    -






zColPaliProcessor.__call__)r-   r/   r*   c                 K   s   | j f d|i|S )a  
        Prepare for the model one or several image(s). This method is a wrapper around the `__call__` method of the ColPaliProcessor's
        [`ColPaliProcessor.__call__`].

        This method forwards the `images` and `kwargs` arguments to the image processor.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
                number of channels, H and W are image height and width.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
        r-   rO   )r'   r-   r/   r   r   r   process_images   s    !zColPaliProcessor.process_images)r.   r/   r*   c                 K   s   | j f d|i|S )a  
        Prepare for the model one or several texts. This method is a wrapper around the `__call__` method of the ColPaliProcessor's
        [`ColPaliProcessor.__call__`].

        This method forwards the `text` and `kwargs` arguments to the tokenizer.

        Args:
            text (`str`, `list[str]`, `list[list[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
        r.   rP   )r'   r.   r/   r   r   r   process_queries   s     z ColPaliProcessor.process_queries   cpuztorch.Tensorztorch.dtypeztorch.device)query_embeddingspassage_embeddings
batch_sizeoutput_dtypeoutput_devicer*   c              	   C   s@  t |dkrtdt |dkr(td|d j|d jkrDtd|d j|d jkr`td|du rr|d j}g }tdt ||D ]}g }tjjjj	||||  ddd}	tdt ||D ]N}
tjjjj	||
|
|  ddd}|
td	|	|jd
dd jdd q|
tj|dd|| qtj|ddS )aZ  
        Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
        query embeddings (`qs`) and passage embeddings (`ps`). For ColPali, a passage is the
        image of a document page.

        Because the embedding tensors are multi-vector and can thus have different shapes, they
        should be fed as:
        (1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim)
        (2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim) -> usually
            obtained by padding the list of tensors.

        Args:
            query_embeddings (`Union[torch.Tensor, list[torch.Tensor]`): Query embeddings.
            passage_embeddings (`Union[torch.Tensor, list[torch.Tensor]`): Passage embeddings.
            batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores.
            output_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): The dtype of the output tensor.
                If `None`, the dtype of the input embeddings is used.
            output_device (`torch.device` or `str`, *optional*, defaults to "cpu"): The device of the output tensor.

        Returns:
            `torch.Tensor`: A tensor of shape `(n_queries, n_passages)` containing the scores. The score
            tensor is saved on the "cpu" device.
        r   zNo queries providedzNo passages providedz/Queries and passages must be on the same devicez-Queries and passages must have the same dtypeNT)Zbatch_firstZpadding_valuezbnd,csd->bcnsr   )dim   r7   )r=   rG   ZdeviceZdtyperangetorchnnutilsZrnnZpad_sequencerL   Zeinsummaxsumcatto)r'   rU   rV   rW   rX   rY   ZscoresiZbatch_scoresZbatch_queriesjZbatch_passagesr   r   r   score_retrieval  s2     


 "z ColPaliProcessor.score_retrieval)NNNr   r   )NNNN)N)rS   NrT   )r   r   r   __doc__rK   r&   propertyr,   r	   r   r   r   r<   r   r   r   rO   rQ   rR   intr   rf   __classcell__r   r   r(   r   r   0   sV        
    y %&   
r   )typingr   r   Z2transformers.models.paligemma.processing_paligemmar   r   r   Zfeature_extraction_utilsr   Zimage_utilsr	   r
   Zprocessing_utilsr   r   Ztokenization_utils_baser   r   r_   r   r   r]   Z
get_loggerr   loggerr   r   __all__r   r   r   r   <module>   s   
  *