"""RAG model implementation."""

import copy
from dataclasses import dataclass
from typing import Callable, Optional, Union

import torch
from torch import nn

from ...cache_utils import Cache, EncoderDecoderCache
from ...configuration_utils import PretrainedConfig
from ...generation import GenerationConfig, GenerationMixin, LogitsProcessorList, StoppingCriteriaList
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, logging
from .configuration_rag import RagConfig
from .retrieval_rag import RagRetriever


logger = logging.get_logger(__name__)


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for retriever augmented marginalized models outputs.
    """
)
class RetrievAugLMMarginOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss.
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head. The score is possibly marginalized over all documents for
        each vocabulary token.
    doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
        Score between each retrieved document embedding (see `retrieved_doc_embeds`) and
        `question_encoder_last_hidden_state`.
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
        num_heads, sequence_length, embed_size_per_head)`.

        Contains precomputed hidden-states (key and values in the attention blocks) of the decoder that can be used
        (see `past_key_values` input) to speed up sequential decoding.
    retrieved_doc_embeds (`torch.FloatTensor` of shape `(batch_size, config.n_docs, hidden_size)`, *optional*, returned when *output_retrieved=True*):
        Embedded documents retrieved by the retriever. Used with `question_encoder_last_hidden_state` to compute
        the `doc_scores`.
    retrieved_doc_ids (`torch.LongTensor` of shape `(batch_size, config.n_docs)`, *optional*, returned when *output_retrieved=True*):
        The indexes of the embedded documents retrieved by the retriever.
    context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
        Input ids post-processed from the retrieved documents and the question encoder input_ids by the retriever.
    context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
        Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
        retriever.
    question_encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Sequence of hidden states at the output of the last layer of the question encoder (the pooled output of
        the model).
    question_enc_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
        shape `(batch_size, sequence_length, hidden_size)`.

        Hidden states of the question encoder at the output of each layer plus the initial embedding outputs.
    question_enc_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attention weights of the question encoder, after the attention softmax, used to compute the weighted
        average in the self-attention heads.
    generator_enc_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Sequence of hidden-states at the output of the last layer of the generator encoder of the model.
    generator_enc_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
        shape `(batch_size, sequence_length, hidden_size)`.

        Hidden states of the generator encoder at the output of each layer plus the initial embedding outputs.
    generator_enc_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attention weights of the generator encoder, after the attention softmax, used to compute the weighted
        average in the self-attention heads.
    generator_dec_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
        shape `(batch_size, sequence_length, hidden_size)`.

        Hidden states of the generator decoder at the output of each layer plus the initial embedding outputs.
    generator_dec_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attention weights of the generator decoder, after the attention softmax, used to compute the weighted
        average in the self-attention heads.
    generator_cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Cross-attention weights of the generator decoder, after the attention softmax, used to compute the
        weighted average in the cross-attention heads.
    Nlosslogits
doc_scorespast_key_valuesretrieved_doc_embedsretrieved_doc_idscontext_input_idscontext_attention_mask"question_encoder_last_hidden_state.question_enc_hidden_statesquestion_enc_attentionsgenerator_enc_last_hidden_stategenerator_enc_hidden_statesgenerator_enc_attentionsgenerator_dec_hidden_statesgenerator_dec_attentionsgenerator_cross_attentions)__name__
__module____qualname____doc__r   r   torchFloatTensor__annotations__r   r   r   r   r   r   
LongTensorr   r   r   r    tupler!   r"   r#   r$   r%   r&   r'    r1   r1   `/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/rag/modeling_rag.pyr   %   s$   
Gr   c                   @   sh  e Zd ZU dZdZeej ed< dZ	eej ed< dZ
ee ed< dZeej ed< dZeej ed< dZeej ed< dZeej ed	< dZeej ed
< dZeeejdf  ed< dZeeejdf  ed< dZeej ed< dZeeejdf  ed< dZeeejdf  ed< dZeeejdf  ed< dZeeejdf  ed< dZeeejdf  ed< dS )RetrievAugLMOutputa7  
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head. The score is possibly marginalized over all documents for
        each vocabulary token.
    doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
        Score between each retrieved document embedding (see `retrieved_doc_embeds`) and
        `question_encoder_last_hidden_state`.
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
        num_heads, sequence_length, embed_size_per_head)`.

        Contains precomputed hidden-states (key and values in the attention blocks) of the decoder that can be used
        (see `past_key_values` input) to speed up sequential decoding.
    retrieved_doc_embeds (`torch.FloatTensor` of shape `(batch_size, config.n_docs, hidden_size)`, *optional*, returned when *output_retrieved=True*):
        Embedded documents retrieved by the retriever. Used with `question_encoder_last_hidden_state` to compute
        the `doc_scores`.
    retrieved_doc_ids (`torch.LongTensor` of shape `(batch_size, config.n_docs)`, *optional*, returned when *output_retrieved=True*):
        The indexes of the embedded documents retrieved by the retriever.
    context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
        Input ids post-processed from the retrieved documents and the question encoder input_ids by the retriever.
    context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
        Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
        retriever.
    question_encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Sequence of hidden states at the output of the last layer of the question encoder (the pooled output of
        the model).
    question_enc_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
        shape `(batch_size, sequence_length, hidden_size)`.

        Hidden states of the question encoder at the output of each layer plus the initial embedding outputs.
    question_enc_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attention weights of the question encoder, after the attention softmax, used to compute the weighted
        average in the self-attention heads.
    generator_enc_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Sequence of hidden-states at the output of the last layer of the generator encoder of the model.
    generator_enc_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
        shape `(batch_size, sequence_length, hidden_size)`.

        Hidden states of the generator encoder at the output of each layer plus the initial embedding outputs.
    generator_enc_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attention weights of the generator encoder, after the attention softmax, used to compute the weighted
        average in the self-attention heads.
    generator_dec_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
        shape `(batch_size, sequence_length, hidden_size)`.

        Hidden states of the generator decoder at the output of each layer plus the initial embedding outputs.
    generator_dec_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attention weights of the generator decoder, after the attention softmax, used to compute the weighted
        average in the self-attention heads.
    generator_cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Cross-attention weights of the generator decoder, after the attention softmax, used to compute the
        weighted average in the cross-attention heads.
    """

    logits: Optional[torch.FloatTensor] = None
    doc_scores: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    retrieved_doc_embeds: Optional[torch.FloatTensor] = None
    retrieved_doc_ids: Optional[torch.LongTensor] = None
    context_input_ids: Optional[torch.LongTensor] = None
    context_attention_mask: Optional[torch.LongTensor] = None
    question_encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    question_enc_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    question_enc_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    generator_enc_last_hidden_state: Optional[torch.FloatTensor] = None
    generator_enc_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    generator_enc_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    generator_dec_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    generator_dec_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    generator_cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None


@auto_docstring(
    custom_intro="""
    RAG models were released with the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP
    Tasks](https://huggingface.co/papers/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandra Piktus et al.

    RAG is a retriever-augmented model that encapsulates three components: a question encoder, a dataset retriever
    and a generator. The encoder and generator are trainable, while the retriever is just an indexed dataset.
    """
)
class RagPreTrainedModel(PreTrainedModel):
    config: RagConfig
    base_model_prefix = "rag"
    _supports_flash_attn = True
    _supports_sdpa = True

    @classmethod
    def from_pretrained_question_encoder_generator(
        cls,
        question_encoder_pretrained_model_name_or_path: Optional[str] = None,
        generator_pretrained_model_name_or_path: Optional[str] = None,
        retriever: RagRetriever = None,
        **kwargs,
    ) -> PreTrainedModel:
        r"""
        Instantiates a question encoder and a generator from one or two base classes of the library from pretrained
        model checkpoints.

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
        the model, you need to first set it back in training mode with `model.train()`.

        Params:
            question_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
                Information necessary to initiate the question encoder. Can be either:

                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                    - A path to a *directory* containing model weights saved using
                      [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
                    - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
                      this case, `from_tf` should be set to `True` and a configuration object should be provided as
                      `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
                      PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            generator_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
                Information necessary to initiate the generator. Can be either:

                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                    - A path to a *directory* containing model weights saved using
                      [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
                    - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
                      this case, `from_tf` should be set to `True` and a configuration object should be provided as
                      `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
                      PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            model_args (remaining positional arguments, *optional*):
                All remaining positional arguments will be passed to the underlying model's `__init__` method.
            retriever ([`RagRetriever`], *optional*):
                The retriever to use.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                `output_attentions=True`).

                - To update the question_encoder configuration, use the prefix *question_encoder_* for each
                  configuration parameter.
                - To update the generator configuration, use the prefix *generator_* for each configuration parameter.
                - To update the parent model configuration, do not use a prefix for each configuration parameter.

                Behaves differently depending on whether a `config` is provided or automatically loaded.

        Example:

        ```python
        >>> from transformers import RagModel

        >>> # initialize a RAG from two pretrained models.
        >>> model = RagModel.from_pretrained_question_encoder_generator(
        ...     "facebook/dpr-question_encoder-single-nq-base", "google-t5/t5-small"
        ... )
        >>> # saving model after fine-tuning
        >>> model.save_pretrained("./rag")
        >>> # load fine-tuned model
        >>> model = RagModel.from_pretrained("./rag")
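        >>> # A retriever can also be attached at initialization time (illustrative sketch; any
        >>> # `RagRetriever` instance works here):
        >>> # model = RagModel.from_pretrained_question_encoder_generator(
        >>> #     "facebook/dpr-question_encoder-single-nq-base", "google-t5/t5-small", retriever=retriever
        >>> # )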
        ```"""
        kwargs_question_encoder = {
            argument[len("question_encoder_") :]: value
            for argument, value in kwargs.items()
            if argument.startswith("question_encoder_")
        }

        kwargs_generator = {
            argument[len("generator_") :]: value
            for argument, value in kwargs.items()
            if argument.startswith("generator_")
        }

        # remove question_encoder and generator kwargs from the shared kwargs
        for key in kwargs_question_encoder.keys():
            del kwargs["question_encoder_" + key]
        for key in kwargs_generator.keys():
            del kwargs["generator_" + key]

        # Load and initialize the question encoder and the generator
        question_encoder = kwargs_question_encoder.pop("model", None)
        if question_encoder is None:
            assert question_encoder_pretrained_model_name_or_path is not None, (
                "If `model` is not defined as an argument, a `question_encoder_pretrained_model_name_or_path` has to"
                " be defined"
            )
            from ..auto.modeling_auto import AutoModel

            if "config" not in kwargs_question_encoder:
                from ..auto.configuration_auto import AutoConfig

                question_encoder_config, kwargs_question_encoder = AutoConfig.from_pretrained(
                    question_encoder_pretrained_model_name_or_path,
                    **kwargs_question_encoder,
                    return_unused_kwargs=True,
                )
                kwargs_question_encoder["config"] = question_encoder_config

            question_encoder = AutoModel.from_pretrained(
                question_encoder_pretrained_model_name_or_path, **kwargs_question_encoder
            )

        generator = kwargs_generator.pop("model", None)
        if generator is None:
            assert generator_pretrained_model_name_or_path is not None, (
                "If `generator_model` is not defined as an argument, a `generator_pretrained_model_name_or_path` has"
                " to be defined"
            )
            from ..auto.modeling_auto import AutoModelForSeq2SeqLM

            if "config" not in kwargs_generator:
                from ..auto.configuration_auto import AutoConfig

                generator_config, kwargs_generator = AutoConfig.from_pretrained(
                    generator_pretrained_model_name_or_path, **kwargs_generator, return_unused_kwargs=True
                )
                kwargs_generator["config"] = generator_config

            generator = AutoModelForSeq2SeqLM.from_pretrained(
                generator_pretrained_model_name_or_path, **kwargs_generator
            )

        # instantiate config with corresponding kwargs
        config = kwargs.get("config")
        if config is None:
            config = RagConfig.from_question_encoder_generator_configs(
                question_encoder.config, generator.config, **kwargs
            )

        return cls(question_encoder=question_encoder, generator=generator, config=config, retriever=retriever)

@auto_docstring
class RagModel(RagPreTrainedModel):
    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        question_encoder: Optional[PreTrainedModel] = None,
        generator: Optional[PreTrainedModel] = None,
        retriever: Optional[RagRetriever] = None,
        **kwargs,
    ):
        r"""
        question_encoder (`PreTrainedModel`, *optional*):
            The model responsible for encoding the question into hidden states for retrieval.
        generator (`PreTrainedModel`, *optional*):
            The model responsible for generating text based on retrieved documents.
        retriever (`RagRetriever`, *optional*):
            The component responsible for retrieving documents from a knowledge base given the encoded question.
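
        Example (a minimal sketch; assumes the DPR and BART checkpoints below are available):

        ```python
        >>> from transformers import AutoModel, AutoModelForSeq2SeqLM, RagModel

        >>> # question_encoder = AutoModel.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
        >>> # generator = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large")
        >>> # model = RagModel(question_encoder=question_encoder, generator=generator)
        ```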
        """
        assert config is not None or (question_encoder is not None and generator is not None), (
            "Either a configuration or a question_encoder and a generator has to be provided."
        )
        if config is None:
            config = RagConfig.from_question_encoder_generator_configs(
                question_encoder.config, generator.config, **kwargs
            )
        else:
            assert isinstance(config, self.config_class), f"config: {config} has to be of type {self.config_class}"
        super().__init__(config)
        if question_encoder is None:
            from ..auto.modeling_auto import AutoModel

            question_encoder = AutoModel.from_config(config.question_encoder)

        if generator is None:
            from ..auto.modeling_auto import AutoModelForSeq2SeqLM

            generator = AutoModelForSeq2SeqLM.from_config(config.generator)

        self.retriever = retriever
        if self.retriever is not None:
            assert isinstance(retriever, RagRetriever), (
                f"`self.retriever` is of type {type(self.retriever)}, but should be of type `RagRetriever`"
            )
            self.retriever = retriever

        self.question_encoder = question_encoder
        self.generator = generator

        self.ctx_encoder = None
        self.context_encoder_training = False

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        past_key_values: Optional[Cache] = None,
        doc_scores: Optional[torch.FloatTensor] = None,
        context_input_ids: Optional[torch.LongTensor] = None,
        context_attention_mask: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_retrieved: Optional[bool] = None,
        n_docs: Optional[int] = None,
    ) -> Union[tuple[torch.Tensor], RetrievAugLMOutput]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. [`RagConfig`], used to initialize the model, specifies
            which generator to use, it also specifies a compatible generator tokenizer. Use that tokenizer class to
            obtain the indices.

            [What are input IDs?](../glossary#input-ids)
        encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
            Tuple consists of (`generator_enc_last_hidden_state`, *optional*: `generator_enc_hidden_states`,
            *optional*: `generator_enc_attentions`). `generator_enc_last_hidden_state` of shape `(batch_size, n_docs *
            sequence_length, hidden_size)` is a sequence of hidden-states at the output of the last layer of the
            generator's encoder.

            Used by the ([`RagModel`]) model during decoding.
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Provide for generation tasks. `None` by default, construct as per instructions for the generator model
            you're using with your RAG instance.
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size,  target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
            Score between each retrieved document embedding (see `retrieved_doc_embeds`) and
            `question_encoder_last_hidden_state`. If the model is not initialized with a `retriever`, `doc_scores`
            has to be provided to the forward pass. `doc_scores` can be computed via
            `question_encoder_last_hidden_state` and `retrieved_doc_embeds`; see examples for more information.
        context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
            Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
            retriever. If the model was not initialized with a `retriever`, `context_input_ids` has to be provided
            to the forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
        context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
            Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
            retriever. If the model is not initialized with a `retriever`, `context_attention_mask` has to be
            provided to the forward pass. `context_attention_mask` is returned by [`~RagRetriever.__call__`].
        output_retrieved (`bool`, *optional*):
            Whether or not to return the `retrieved_doc_embeds`, `retrieved_doc_ids`, `context_input_ids` and
            `context_attention_mask`. See returned tensors for more detail.
        n_docs (`int`, *optional*):
            The number of documents to retrieve.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, RagRetriever, RagModel
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-token-base")
        >>> retriever = RagRetriever.from_pretrained(
        ...     "facebook/rag-token-base", index_name="exact", use_dummy_dataset=True
        ... )
        >>> # initialize with RagRetriever to do everything in one forward call
        >>> model = RagModel.from_pretrained("facebook/rag-token-base", retriever=retriever)

        >>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
        >>> outputs = model(input_ids=inputs["input_ids"])
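        >>> # Illustrative note: outputs.doc_scores has shape (batch_size, n_docs), one retrieval score
        >>> # per document, and outputs.logits keeps one row of generator logits per (example, document)
        >>> # pair, i.e. batch_size * n_docs rows. Marginalization over documents happens in the
        >>> # RagSequenceForGeneration / RagTokenForGeneration heads, not in RagModel itself.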
        ```"""
        n_docs = n_docs if n_docs is not None else self.config.n_docs
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_retrieved = output_retrieved if output_retrieved is not None else self.config.output_retrieved

        # whether retriever has to be used
        has_to_retrieve = (
            self.retriever is not None
            and (context_input_ids is None or context_attention_mask is None or doc_scores is None)
            and encoder_outputs is None
        )
        # encoder_outputs are pre-computed during RAG-token generation
        if encoder_outputs is None:
            if has_to_retrieve:
                question_enc_outputs = self.question_encoder(
                    input_ids, attention_mask=attention_mask, return_dict=True
                )
                # hidden states of question encoder
                question_encoder_last_hidden_state = question_enc_outputs[0]

                retriever_outputs = self.retriever(
                    input_ids,
                    question_encoder_last_hidden_state.cpu().detach().to(torch.float32).numpy(),
                    prefix=self.generator.config.prefix,
                    n_docs=n_docs,
                    return_tensors="pt",
                )
                if self.context_encoder_training:
                    (
                        context_input_ids,
                        context_attention_mask,
                        retrieved_doc_embeds,
                        retrieved_doc_input_ids,
                        retrieved_doc_attention_mask,
                        retrieved_doc_ids,
                    ) = (
                        retriever_outputs["context_input_ids"],
                        retriever_outputs["context_attention_mask"],
                        retriever_outputs["retrieved_doc_embeds"],
                        retriever_outputs["tokenized_doc_ids"],
                        retriever_outputs["tokenized_doc_attention_mask"],
                        retriever_outputs["doc_ids"],
                    )

                    context_input_ids = context_input_ids.to(input_ids)
                    context_attention_mask = context_attention_mask.to(input_ids)

                    retrieved_doc_input_ids = retrieved_doc_input_ids.to(input_ids)
                    retrieved_doc_attention_mask = retrieved_doc_attention_mask.to(input_ids)
                    retrieved_doc_embeds = self.ctx_encoder(
                        retrieved_doc_input_ids, attention_mask=retrieved_doc_attention_mask, return_dict=True
                    ).pooler_output
                    retrieved_doc_embeds = retrieved_doc_embeds.view(
                        -1, n_docs, question_encoder_last_hidden_state.shape[1]
                    )  # reshaping

                    # compute doc_scores involving ctx_encoder
                    doc_scores = torch.bmm(
                        question_encoder_last_hidden_state.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)
                    ).squeeze(1)
                else:
                    context_input_ids, context_attention_mask, retrieved_doc_embeds, retrieved_doc_ids = (
                        retriever_outputs["context_input_ids"],
                        retriever_outputs["context_attention_mask"],
                        retriever_outputs["retrieved_doc_embeds"],
                        retriever_outputs["doc_ids"],
                    )

                    # set to correct device
                    retrieved_doc_embeds = retrieved_doc_embeds.to(question_encoder_last_hidden_state)
                    context_input_ids = context_input_ids.to(input_ids)
                    context_attention_mask = context_attention_mask.to(input_ids)

                    # compute doc_scores
                    doc_scores = torch.bmm(
                        question_encoder_last_hidden_state.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)
                    ).squeeze(1)
            else:
                assert context_input_ids is not None, (
                    "Make sure that `context_input_ids` are passed, if no `retriever` is set. Alternatively, you can"
                    " set a retriever using the `set_retriever(...)` function."
                )
                assert context_attention_mask is not None, (
                    "Make sure that `context_attention_mask` are passed, if no `retriever` is set. Alternatively, you"
                    " can set a retriever using the `set_retriever(...)` function."
                )
                assert doc_scores is not None, (
                    "Make sure that `doc_scores` are passed, if no `retriever` is set. Alternatively, you can set a"
                    " retriever using the `set_retriever(...)` function."
                )

        assert doc_scores is not None, (
            "Make sure that `doc_scores` are passed when passing `encoder_outputs` to the forward function."
        )

        assert (context_input_ids.shape[0] % n_docs) == 0, (
            f"The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is"
            f" {context_input_ids.shape[0]}."
        )

        # Decoder input without context documents
        if decoder_input_ids is not None:
            decoder_input_ids = decoder_input_ids.repeat_interleave(n_docs, dim=0)

        if decoder_attention_mask is not None:
            decoder_attention_mask = decoder_attention_mask.repeat_interleave(n_docs, dim=0)

        gen_outputs = self.generator(
            input_ids=context_input_ids,
            attention_mask=context_attention_mask,
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            return_dict=True,
        )

        if not has_to_retrieve:
            question_encoder_last_hidden_state = None
            question_enc_hidden_states = None
            question_enc_attentions = None
            retrieved_doc_embeds = None
            retrieved_doc_ids = None
        else:
            question_enc_hidden_states = question_enc_outputs.hidden_states
            question_enc_attentions = question_enc_outputs.attentions

        if not has_to_retrieve or not output_retrieved:
            # don't output retrieved docs
            context_input_ids = None
            context_attention_mask = None
            retrieved_doc_embeds = None
            retrieved_doc_ids = None

        return RetrievAugLMOutput(
            logits=gen_outputs.logits,
            doc_scores=doc_scores,
            past_key_values=gen_outputs.past_key_values,
            context_input_ids=context_input_ids,
            context_attention_mask=context_attention_mask,
            retrieved_doc_embeds=retrieved_doc_embeds,
            retrieved_doc_ids=retrieved_doc_ids,
            question_encoder_last_hidden_state=question_encoder_last_hidden_state,
            question_enc_hidden_states=question_enc_hidden_states,
            question_enc_attentions=question_enc_attentions,
            generator_enc_last_hidden_state=gen_outputs.encoder_last_hidden_state,
            generator_enc_hidden_states=gen_outputs.encoder_hidden_states,
            generator_enc_attentions=gen_outputs.encoder_attentions,
            generator_dec_hidden_states=gen_outputs.decoder_hidden_states,
            generator_dec_attentions=gen_outputs.decoder_attentions,
            generator_cross_attentions=gen_outputs.cross_attentions,
        )


@auto_docstring(
    custom_intro="""
    A RAG-sequence model implementation. It performs RAG-sequence specific marginalization in the forward pass.
    """
)
class RagSequenceForGeneration(RagPreTrainedModel):
    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        question_encoder: Optional[PreTrainedModel] = None,
        generator: Optional[PreTrainedModel] = None,
        retriever: Optional[RagRetriever] = None,
        **kwargs,
    ):
        r"""
        question_encoder (`PreTrainedModel`, *optional*):
            The model responsible for encoding the question into hidden states for retrieval.
        generator (`PreTrainedModel`, *optional*):
            The model responsible for generating text based on retrieved documents.
        retriever (`RagRetriever`, *optional*):
            The component responsible for retrieving documents from a knowledge base given the encoded question.
        """
        assert config is not None or (question_encoder is not None and generator is not None), (
            "Either a configuration or an encoder and a generator has to be provided."
        )

        if config is None:
            config = RagConfig.from_question_encoder_generator_configs(
                question_encoder.config, generator.config, **kwargs
            )
        super().__init__(config)

        # instantiate model
        self.rag = RagModel(config=config, question_encoder=question_encoder, generator=generator, retriever=retriever)

    def set_retriever(self, retriever: RagRetriever):
        self.rag.retriever = retriever

    def set_context_encoder_for_training(self, ctx_encoder: PreTrainedModel):
        self.rag.context_encoder_training = True
        self.rag.ctx_encoder = ctx_encoder

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        past_key_values: Optional[Cache] = None,
        context_input_ids: Optional[torch.LongTensor] = None,
        context_attention_mask: Optional[torch.LongTensor] = None,
        doc_scores: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_retrieved: Optional[bool] = None,
        exclude_bos_score: Optional[bool] = None,
        reduce_loss: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        n_docs: Optional[int] = None,
        **kwargs,  # needs kwargs for generation
    ) -> RetrievAugLMMarginOutput:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. [`RagConfig`], used to initialize the model, specifies
            which generator to use, it also specifies a compatible generator tokenizer. Use that tokenizer class to
            obtain the indices.

            [What are input IDs?](../glossary#input-ids)
        encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
            Tuple consists of (`generator_enc_last_hidden_state`, *optional*: `generator_enc_hidden_states`,
            *optional*: `generator_enc_attentions`). `generator_enc_last_hidden_state` of shape `(batch_size, n_docs *
            sequence_length, hidden_size)` is a sequence of hidden-states at the output of the last layer of the
            generator's encoder.

            Used by the ([`RagModel`]) model during decoding.
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Provide for generation tasks. `None` by default, construct as per instructions for the generator model
            you're using with your RAG instance.
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size,  target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
            Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
            retriever. If the model was not initialized with a `retriever`, `context_input_ids` has to be provided
            to the forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
        context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
            Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
            retriever. If the model is not initialized with a `retriever`, `context_attention_mask` has to be
            provided to the forward pass. `context_attention_mask` is returned by [`~RagRetriever.__call__`].
        doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
            Score between each retrieved document embedding (see `retrieved_doc_embeds`) and
            `question_encoder_last_hidden_state`. If the model is not initialized with a `retriever`, `doc_scores`
            has to be provided to the forward pass. `doc_scores` can be computed via
            `question_encoder_last_hidden_state` and `retrieved_doc_embeds`; see examples for more information.
        output_retrieved (`bool`, *optional*):
            Whether or not to return the `retrieved_doc_embeds`, `retrieved_doc_ids`, `context_input_ids` and
            `context_attention_mask`. See returned tensors for more detail.
        exclude_bos_score (`bool`, *optional*):
            Only relevant if `labels` is passed. If `True`, the score of the BOS token is disregarded when computing
            the loss.
        reduce_loss (`bool`, *optional*):
            Only relevant if `labels` is passed. If `True`, the NLL loss is reduced using the `torch.Tensor.sum`
            operation.
        n_docs (`int`, *optional*):
            The number of documents to retrieve.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, RagRetriever, RagSequenceForGeneration
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-sequence-nq")
        >>> retriever = RagRetriever.from_pretrained(
        ...     "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
        ... )
        >>> # initialize with RagRetriever to do everything in one forward call
        >>> model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)

        >>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
        >>> targets = tokenizer(text_target="In Paris, there are 10 million people.", return_tensors="pt")
        >>> input_ids = inputs["input_ids"]
        >>> labels = targets["input_ids"]
        >>> outputs = model(input_ids=input_ids, labels=labels)

        >>> # or use retriever separately
        >>> model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", use_dummy_dataset=True)
        >>> # 1. Encode
        >>> question_hidden_states = model.question_encoder(input_ids)[0]
        >>> # 2. Retrieve
        >>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt")
        >>> doc_scores = torch.bmm(
        ...     question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)
        ... ).squeeze(1)
        >>> # 3. Forward to generator
        >>> outputs = model(
        ...     context_input_ids=docs_dict["context_input_ids"],
        ...     context_attention_mask=docs_dict["context_attention_mask"],
        ...     doc_scores=doc_scores,
        ...     decoder_input_ids=labels,
        ... )
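        >>> # Illustrative summary of the RAG-sequence loss computed in `get_nll` when `labels` is
        >>> # passed (not an exact transcription): whole sequences are scored per document, then
        >>> # loss ~ -logsumexp_d(log p(d|x) + sum_t log p(y_t | x, d, y_<t))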
        ```"""
        n_docs = n_docs if n_docs is not None else self.config.n_docs
        exclude_bos_score = exclude_bos_score if exclude_bos_score is not None else self.config.exclude_bos_score
        reduce_loss = reduce_loss if reduce_loss is not None else self.config.reduce_loss

        if labels is not None:
            if decoder_input_ids is None:
                decoder_input_ids = labels
            use_cache = False

        outputs = self.rag(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            context_input_ids=context_input_ids,
            context_attention_mask=context_attention_mask,
            doc_scores=doc_scores,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_retrieved=output_retrieved,
            n_docs=n_docs,
        )

        loss = None
        if labels is not None:
            loss = self.get_nll(
                outputs.logits,
                outputs.doc_scores,
                decoder_input_ids,
                reduce_loss=reduce_loss,
                epsilon=self.config.label_smoothing,
                exclude_bos_score=exclude_bos_score,
                n_docs=n_docs,
            )

        return RetrievAugLMMarginOutput(
            loss=loss,
            logits=outputs.logits,
            doc_scores=outputs.doc_scores,
            past_key_values=outputs.past_key_values,
            context_input_ids=outputs.context_input_ids,
            context_attention_mask=outputs.context_attention_mask,
            retrieved_doc_embeds=outputs.retrieved_doc_embeds,
            retrieved_doc_ids=outputs.retrieved_doc_ids,
            question_encoder_last_hidden_state=outputs.question_encoder_last_hidden_state,
            question_enc_hidden_states=outputs.question_enc_hidden_states,
            question_enc_attentions=outputs.question_enc_attentions,
            generator_enc_last_hidden_state=outputs.generator_enc_last_hidden_state,
            generator_enc_hidden_states=outputs.generator_enc_hidden_states,
            generator_enc_attentions=outputs.generator_enc_attentions,
            generator_dec_hidden_states=outputs.generator_dec_hidden_states,
            generator_dec_attentions=outputs.generator_dec_attentions,
            generator_cross_attentions=outputs.generator_cross_attentions,
        )

    @property
    def retriever(self):
        return self.rag.retriever

    @property
    def generator(self):
        return self.rag.generator

    @property
    def question_encoder(self):
        return self.rag.question_encoder

    @torch.no_grad()
    def generate(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        context_input_ids: Optional[torch.LongTensor] = None,
        context_attention_mask: Optional[torch.LongTensor] = None,
        doc_scores: Optional[torch.FloatTensor] = None,
        do_deduplication: Optional[bool] = None,  # defaults to True
        num_return_sequences: Optional[int] = None,  # defaults to 1
        num_beams: Optional[int] = None,  # defaults to 1
        n_docs: Optional[int] = None,
        **model_kwargs,
    ) -> torch.LongTensor:
        r"""
        Implements RAG sequence "thorough" decoding. Read the [`~generation.GenerationMixin.generate`] documentation
        for more information on how to set other generate input parameters.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                The sequence used as a prompt for the generation. If `input_ids` is not passed, then
                `context_input_ids` has to be provided.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
                Input IDs post-processed from the retrieved documents and the question encoder input_ids by the
                retriever.
            context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
                Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
                retriever.

                If the model is not initialized with a `retriever` or `input_ids` is not given, `context_input_ids` and
                `context_attention_mask` have to be provided to the forward pass. They are returned by
                [`~RagRetriever.__call__`].
            doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
                Score between each retrieved document embedding (see `retrieved_doc_embeds`) and
                `question_encoder_last_hidden_state`.

                If the model is not initialized with a `retriever` or `input_ids` is not given, `doc_scores` has to be
                provided to the forward pass. `doc_scores` are returned by [`~RagRetriever.__call__`].
            do_deduplication (`bool`, *optional*):
                Whether or not to deduplicate the generations from different context documents for a given input. Has
                to be set to `False` if used while training with distributed backend.
            num_return_sequences (`int`, *optional*, defaults to 1):
                The number of independently computed returned sequences for each element in the batch. Note that this
                is not the value we pass to the `generator`'s [`~generation.GenerationMixin.generate`] function,
                where we set `num_return_sequences` to `num_beams`.
            num_beams (`int`, *optional*, defaults to 1):
                Number of beams for beam search. 1 means no beam search.
            n_docs (`int`, *optional*, defaults to `config.n_docs`):
                Number of documents to retrieve and/or number of documents for which to generate an answer.
            kwargs (`dict[str, Any]`, *optional*):
                Additional kwargs will be passed to [`~generation.GenerationMixin.generate`].

        Return:
            `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated
            sequences. The second dimension (sequence length) is either equal to `max_length` or shorter if all batches
            finished early due to the `eos_token_id`.
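
        Example (an illustrative sketch, reusing the tokenizer, retriever and model from the `forward`
        example above):

        ```python
        >>> # generated = model.generate(input_ids=input_ids, num_beams=4, num_return_sequences=2)
        >>> # tokenizer.batch_decode(generated, skip_special_tokens=True)
        ```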
        """
        n_docs = n_docs if n_docs is not None else self.config.n_docs
        do_deduplication = do_deduplication if do_deduplication is not None else self.config.do_deduplication
        num_doc_return_sequences = (
            num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
        )
        num_beams = num_beams if num_beams is not None else self.config.num_beams

        assert input_ids is not None or context_input_ids is not None, (
            " At least one of input_ids or context_input_ids must be given"
        )

        if self.retriever is not None and context_input_ids is None:
            question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
            context_input_ids = self.retriever(
                input_ids,
                question_hidden_states.cpu().detach().to(torch.float32).numpy(),
                prefix=self.generator.config.prefix,
                n_docs=n_docs,
                return_tensors="pt",
            )["context_input_ids"]

            # set to correct device
            context_input_ids = context_input_ids.to(input_ids)

        hypos = []
        model_kwargs["num_beams"] = num_beams
        model_kwargs["num_return_sequences"] = num_beams
        model_kwargs["attention_mask"] = None

        batch_size = input_ids.shape[0] if input_ids is not None else context_input_ids.shape[0] // n_docs

        for index in range(batch_size):
            # first, generate beams from documents:
            generator_input_ids = context_input_ids[index * n_docs : (index + 1) * n_docs]  # (n_docs, max_len)

            output_sequences = self.generator.generate(
                generator_input_ids,
                **model_kwargs,
            )  # n_docs * n_beam, tgt_len
            if do_deduplication:
                # do_deduplication -- for TF, work on max_output_len
                output_sequences = torch.stack(list({str(k.tolist()): k for k in output_sequences}.values()))

            num_candidates = output_sequences.shape[0]  # after deduplication, this number can be less than n_docs * n_beam

            # then, run model forwards to get nll scores:
            if input_ids is not None:
                new_input_ids = input_ids[index : index + 1].repeat(num_candidates, 1)
                outputs = self(new_input_ids, labels=output_sequences, exclude_bos_score=True)
            else:  # input_ids is None, need context_input_ids/mask and doc_scores
                assert context_attention_mask is not None, (
                    "Make sure that `context_attention_mask` are passed, if no `input_ids` is set. Alternatively, you"
                    " can set a retriever using the `set_retriever(...)` function."
                )
                assert doc_scores is not None, (
                    "Make sure that `doc_scores` are passed, if no `input_ids` is set. Alternatively, you can set a"
                    " retriever using the `set_retriever(...)` function."
                )

                individual_input_ids = generator_input_ids.repeat(
                    num_candidates, 1
                )  # (num_candidates * n_docs, max_len)

                individual_attention_mask = context_attention_mask[index * n_docs : (index + 1) * n_docs]
                individual_attention_mask = individual_attention_mask.repeat(num_candidates, 1)

                individual_doc_scores = doc_scores[index : (index + 1), :]  # doc_scores.shape = [batch, n_docs]
                individual_doc_scores = individual_doc_scores.repeat(num_candidates, 1)  # [num_candidates, n_docs]

                outputs = self(
                    context_input_ids=individual_input_ids,
                    context_attention_mask=individual_attention_mask,
                    doc_scores=individual_doc_scores,
                    labels=output_sequences,
                    exclude_bos_score=True,
                )

            top_cand_inds = (-outputs["loss"]).topk(num_doc_return_sequences)[1]

            # add hypothesis
            hypos.append(output_sequences[top_cand_inds])

        return self._cat_and_pad(hypos, pad_token_id=self.config.generator.pad_token_id)

    def get_nll(
        self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0, exclude_bos_score=False, n_docs=None
    ):
        # shift tokens left
        target = torch.cat(
            [target[:, 1:], target.new(target.shape[0], 1).fill_(self.config.generator.pad_token_id)], 1
        )

        n_docs = n_docs if n_docs is not None else self.config.n_docs

        # bos_token_id is None for T5
        bos_token_id = self.config.bos_token_id or self.config.generator.bos_token_id
        use_bos = bos_token_id is not None and target[:, 0].eq(bos_token_id).all()

        def _mask_pads(ll, smooth_obj):
            pad_mask = target.eq(self.config.generator.pad_token_id)
            if pad_mask.any():
                ll.masked_fill_(pad_mask, 0.0)
                smooth_obj.masked_fill_(pad_mask, 0.0)
            return ll.squeeze(-1), smooth_obj.squeeze(-1)

        # seq_logits dim = (batch * n_docs, tgt_len, vocab_size)
        seq_logprobs = nn.functional.log_softmax(seq_logits, dim=-1).view(
            seq_logits.shape[0] // n_docs, n_docs, -1, seq_logits.size(-1)
        )  # batch_size x n_docs x tgt_len x vocab_size
        doc_logprobs = nn.functional.log_softmax(doc_scores, dim=1).unsqueeze(-1).unsqueeze(-1)

        # RAG-sequence marginalization
        first_token_scores = seq_logprobs[:, :, :1, :]
        second_token_scores = seq_logprobs[:, :, 1:2, :]
        remainder = seq_logprobs[:, :, 2:, :]
        rag_logprobs = torch.cat([first_token_scores, second_token_scores + doc_logprobs, remainder], dim=2)

        # calculate loss
        target = target.unsqueeze(1).unsqueeze(-1).repeat(1, n_docs, 1, 1)
        assert target.dim() == rag_logprobs.dim()

        ll = rag_logprobs.gather(dim=-1, index=target)
        smooth_obj = rag_logprobs.sum(dim=-1, keepdim=True)  # total sum of all (normalised) logits

        ll, smooth_obj = _mask_pads(ll, smooth_obj)

        # sum over tokens, exclude bos while scoring
        ll = ll[:, :, 1:].sum(2) if exclude_bos_score and use_bos else ll.sum(2)
        smooth_obj = smooth_obj.sum(2)
        ll = ll.logsumexp(1)  # logsumexp over docs
        smooth_obj = smooth_obj.logsumexp(1)

        nll_loss = -ll
        smooth_loss = -smooth_obj

        if reduce_loss:
            nll_loss = nll_loss.sum()
            smooth_loss = smooth_loss.sum()

        eps_i = epsilon / rag_logprobs.size(-1)
        loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
        return loss

    @staticmethod
    def _cat_and_pad(tensors, pad_token_id):
        output = (
            tensors[0].new(sum([t.shape[0] for t in tensors]), max([t.shape[1] for t in tensors])).fill_(pad_token_id)
        )
        ind = 0
        for t in tensors:
            output[ind : ind + t.shape[0], : t.shape[1]] = t
            ind += t.shape[0]
        return output


@auto_docstring(
    custom_intro="""
    A RAG-token model implementation. It performs RAG-token specific marginalization in the forward pass.
    """
)
class RagTokenForGeneration(RagPreTrainedModel, GenerationMixin):
    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        question_encoder: Optional[PreTrainedModel] = None,
        generator: Optional[PreTrainedModel] = None,
        retriever: Optional[RagRetriever] = None,
        **kwargs,
    ):
        r"""
        question_encoder (`PreTrainedModel`, *optional*):
            The model responsible for encoding the question into hidden states for retrieval.
        generator (`PreTrainedModel`, *optional*):
            The model responsible for generating text based on retrieved documents.
        retriever (`RagRetriever`, *optional*):
            The component responsible for retrieving documents from a knowledge base given the encoded question.
        """
        assert config is not None or (question_encoder is not None and generator is not None), (
            "Either a configuration or an encoder and a generator has to be provided."
        )

        if config is None:
            config = RagConfig.from_question_encoder_generator_configs(
                question_encoder.config, generator.config, **kwargs
            )
        super().__init__(config)

        # instantiate model
        self.rag = RagModel(config=config, question_encoder=question_encoder, generator=generator, retriever=retriever)

    def set_retriever(self, retriever: RagRetriever):
        self.rag.retriever = retriever

    def set_context_encoder_for_training(self, ctx_encoder: PreTrainedModel):
        self.rag.context_encoder_training = True
        self.rag.ctx_encoder = ctx_encoder

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        use_cache=None,
        encoder_outputs=None,
        doc_scores=None,
        n_docs=None,
        **kwargs,
    ):
        if past_key_values is not None:
            # if past is defined, only the last token is passed
            decoder_input_ids = decoder_input_ids[:, -1:]
        return {
            "input_ids": None,
            "encoder_outputs": encoder_outputs,
            "doc_scores": doc_scores,
            "context_attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "do_marginalize": True,
            "n_docs": n_docs,
        }

    @property
    def retriever(self):
        return self.rag.retriever

    @property
    def generator(self):
        return self.rag.generator

    @property
    def question_encoder(self):
        return self.rag.question_encoder

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorders cache for generation. BART-inspired but we need to take care of the extra dimension for docs"""

        def _reorder_stacked(hidden_states, new_order):
            n_docs = hidden_states.shape[0] // new_order.shape[0]
            hidden_states = hidden_states.view(-1, n_docs, *hidden_states.shape[1:])
            hidden_states = hidden_states.index_select(0, new_order)
            result = hidden_states.view(-1, *hidden_states.shape[2:])
            return result

        reordered_past = ()
        for layer_past in past_key_values:
            # get the correct batch idx from decoder layer's batch dim for cross and self-attn
            reordered_past += (
                tuple(_reorder_stacked(past_state, beam_idx.to(past_state.device)) for past_state in layer_past),
            )

        if isinstance(past_key_values, EncoderDecoderCache):
            reordered_past = EncoderDecoderCache.from_legacy_cache(reordered_past)

        return reordered_past

    def marginalize(self, seq_logits, doc_scores, n_docs=None):
        n_docs = n_docs if n_docs is not None else self.config.n_docs

        # RAG-token marginalization
        seq_logprobs = nn.functional.log_softmax(seq_logits, dim=-1).view(
            seq_logits.shape[0] // n_docs, n_docs, -1, seq_logits.size(-1)
        )
        doc_logprobs = torch.log_softmax(doc_scores, dim=1)
        log_prob_sum = seq_logprobs + doc_logprobs.unsqueeze(-1).unsqueeze(-1)
        return torch.logsumexp(log_prob_sum, dim=1)

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        past_key_values: Optional[Cache] = None,
        context_input_ids: Optional[torch.LongTensor] = None,
        context_attention_mask: Optional[torch.LongTensor] = None,
        doc_scores: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_retrieved: Optional[bool] = None,
        do_marginalize: Optional[bool] = None,
        reduce_loss: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        n_docs: Optional[int] = None,
        **kwargs,  # needs kwargs for generation
    ) -> RetrievAugLMMarginOutput:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. [`RagConfig`], used to initialize the model, specifies
            which generator to use, it also specifies a compatible generator tokenizer. Use that tokenizer class to
            obtain the indices.

            [What are input IDs?](../glossary#input-ids)
        encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
            Tuple consists of (`generator_enc_last_hidden_state`, *optional*: `generator_enc_hidden_states`,
            *optional*: `generator_enc_attentions`). `generator_enc_last_hidden_state` of shape `(batch_size, n_docs *
            sequence_length, hidden_size)` is a sequence of hidden-states at the output of the last layer of the
            generator's encoder.

            Used by the ([`RagModel`]) model during decoding.
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Provide for generation tasks. `None` by default, construct as per instructions for the generator model
            you're using with your RAG instance.
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size,  target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
            Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
            retriever. If the model was not initialized with a `retriever` ``context_input_ids` has to be provided to
            the forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
        context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`,*optional*, returned when *output_retrieved=True*):
            Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
            retriever. If the model has is not initialized with a `retriever` `context_attention_mask` has to be
            provided to the forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
        doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
            Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
            `question_encoder_last_hidden_state`. If the model has is not initialized with a `retriever` `doc_scores`
            has to be provided to the forward pass. `doc_scores` can be computed via
            `question_encoder_last_hidden_state` and `retrieved_doc_embeds`, see examples for more information.
        output_retrieved (`bool`, *optional*):
            Whether or not to return the `retrieved_doc_embeds`, `retrieved_doc_ids`, `context_input_ids` and
            `context_attention_mask`. See returned tensors for more detail.
        do_marginalize (`bool`, *optional*):
            If `True`, the logits are marginalized over all documents by making use of
            `torch.nn.functional.log_softmax`.
        reduce_loss (`bool`, *optional*):
            Only relevant if `labels` is passed. If `True`, the NLL loss is reduced using the `torch.Tensor.sum`
            operation.
        n_docs (`int`, *optional*):
            The number of documents to retrieve.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, RagRetriever, RagTokenForGeneration
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-token-nq")
        >>> retriever = RagRetriever.from_pretrained(
        ...     "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
        ... )
        >>> # initialize with RagRetriever to do everything in one forward call
        >>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

        >>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
        >>> targets = tokenizer(text_target="In Paris, there are 10 million people.", return_tensors="pt")
        >>> input_ids = inputs["input_ids"]
        >>> labels = targets["input_ids"]
        >>> outputs = model(input_ids=input_ids, labels=labels)

        >>> # or use retriever separately
        >>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", use_dummy_dataset=True)
        >>> # 1. Encode
        >>> question_hidden_states = model.question_encoder(input_ids)[0]
        >>> # 2. Retrieve
        >>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt")
        >>> doc_scores = torch.bmm(
        ...     question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)
        ... ).squeeze(1)
        >>> # 3. Forward to generator
        >>> outputs = model(
        ...     context_input_ids=docs_dict["context_input_ids"],
        ...     context_attention_mask=docs_dict["context_attention_mask"],
        ...     doc_scores=doc_scores,
        ...     decoder_input_ids=labels,
        ... )

        >>> # or directly generate
        >>> generated = model.generate(
        ...     context_input_ids=docs_dict["context_input_ids"],
        ...     context_attention_mask=docs_dict["context_attention_mask"],
        ...     doc_scores=doc_scores,
        ... )
        >>> generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)
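        >>> # Illustrative note: with `do_marginalize=True` (always set during generation), the returned
        >>> # logits are already log-marginalized over the retrieved documents, step by step:
        >>> # logits[t] ~ logsumexp_d(log p(d|x) + log p(y_t | x, d, y_<t)) -- see `marginalize` above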
        ```"""
        n_docs = n_docs if n_docs is not None else self.config.n_docs
        do_marginalize = do_marginalize if do_marginalize is not None else self.config.do_marginalize
        reduce_loss = reduce_loss if reduce_loss is not None else self.config.reduce_loss

        if labels is not None:
            if decoder_input_ids is None:
                decoder_input_ids = labels
            use_cache = False

        outputs = self.rag(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            context_input_ids=context_input_ids,
            context_attention_mask=context_attention_mask,
            doc_scores=doc_scores,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_retrieved=output_retrieved,
            n_docs=n_docs,
        )

        loss = None
        logits = outputs.logits
        if labels is not None:
            assert decoder_input_ids is not None
            loss = self.get_nll(
                outputs.logits,
                outputs.doc_scores,
                labels,
                reduce_loss=reduce_loss,
                epsilon=self.config.label_smoothing,
                n_docs=n_docs,
            )

        if do_marginalize:
            logits = self.marginalize(logits, outputs.doc_scores, n_docs)

        return RetrievAugLMMarginOutput(
            loss=loss,
            logits=logits,
            doc_scores=outputs.doc_scores,
            past_key_values=outputs.past_key_values,
            context_input_ids=outputs.context_input_ids,
            context_attention_mask=outputs.context_attention_mask,
            retrieved_doc_embeds=outputs.retrieved_doc_embeds,
            retrieved_doc_ids=outputs.retrieved_doc_ids,
            question_encoder_last_hidden_state=outputs.question_encoder_last_hidden_state,
            question_enc_hidden_states=outputs.question_enc_hidden_states,
            question_enc_attentions=outputs.question_enc_attentions,
            generator_enc_last_hidden_state=outputs.generator_enc_last_hidden_state,
            generator_enc_hidden_states=outputs.generator_enc_hidden_states,
            generator_enc_attentions=outputs.generator_enc_attentions,
            generator_dec_hidden_states=outputs.generator_dec_hidden_states,
            generator_dec_attentions=outputs.generator_dec_attentions,
            generator_cross_attentions=outputs.generator_cross_attentions,
        )

    @torch.no_grad()
    def generate(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        context_input_ids: Optional[torch.LongTensor] = None,
        context_attention_mask: Optional[torch.LongTensor] = None,
        doc_scores: Optional[torch.FloatTensor] = None,
        n_docs: Optional[int] = None,
        generation_config: Optional[GenerationConfig] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
        logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
        stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
        **kwargs,
    ) -> torch.LongTensor:
        r"""
        Implements RAG token decoding.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                The sequence used as a prompt for the generation. If `input_ids` is not passed, then
                `context_input_ids` has to be provided.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
                Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
                retriever.

                If the model is not initialized with a `retriever`, `context_input_ids` has to be provided to the
                forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
            context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
                Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
                retriever.

                If the model is not initialized with a `retriever`, `context_attention_mask` has to be provided to
                the forward pass. `context_attention_mask` is returned by [`~RagRetriever.__call__`].
            doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
                Score between each retrieved document embedding (see `retrieved_doc_embeds`) and
                `question_encoder_last_hidden_state`.

                If the model has is not initialized with a `retriever`, `context_input_ids` has to be provided to the
                forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
            n_docs (`int`, *optional*, defaults to `config.n_docs`):
                Number of documents to retrieve and/or number of documents for which to generate an answer.
            generation_config (`~generation.GenerationConfig`, *optional*):
                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
                passed to generate matching the attributes of `generation_config` will override them. If
                `generation_config` is not provided, the default will be used, which has the following loading
                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
                default values, whose documentation should be checked to parameterize generation.
            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*):
                If provided, this function constrains the beam search to allowed tokens only at each step. If not
                provided, no constraint is applied. This function takes 2 arguments: `input_ids` and the batch ID
                `batch_id`. It has to return a list with the allowed tokens for the next generation step,
                conditioned on the previously generated tokens `input_ids` and the batch ID `batch_id`. This
                argument is useful for constrained generation conditioned on the prefix, as described in
                [Autoregressive Entity Retrieval](https://huggingface.co/papers/2010.00904).
            logits_processor (`LogitsProcessorList`, *optional*):
                Custom logits processors that complement the default logits processors built from arguments and a
                model's config. If a logits processor is passed that was already created with those arguments or
                with the model's config, an error is thrown.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                Custom stopping criteria that complement the default stopping criteria built from arguments and a
                model's config. If a stopping criterion is passed that was already created with those arguments or
                with the model's config, an error is thrown.
            kwargs (`dict[str, Any]`, *optional*):
                Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
                forwarded to the `forward` function of the model.

        Return:
            `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated
            sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches
            finished early due to the `eos_token_id`.
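
        Example (a hedged sketch, not taken from this file's compiled remains; it
        assumes the public `facebook/rag-token-nq` checkpoint with its dummy
        evaluation index, and `allowed_ids` is a hypothetical candidate list):

        ```python
        >>> from transformers import RagRetriever, RagTokenForGeneration, RagTokenizer

        >>> tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
        >>> retriever = RagRetriever.from_pretrained(
        ...     "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
        ... )
        >>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

        >>> inputs = tokenizer("who holds the record in 100m freestyle", return_tensors="pt")

        >>> # let the model retrieve internally; beam-search arguments are folded into `generation_config`
        >>> generated = model.generate(input_ids=inputs["input_ids"], num_beams=4, num_return_sequences=2)

        >>> # constrain every step to a fixed candidate set via `prefix_allowed_tokens_fn`
        >>> generated = model.generate(
        ...     input_ids=inputs["input_ids"],
        ...     prefix_allowed_tokens_fn=lambda batch_id, input_ids: allowed_ids,
        ... )
        >>> answers = tokenizer.batch_decode(generated, skip_special_tokens=True)
        ```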
        """
        # Reconstructed outline of the compiled body (only identifiers survive in
        # this dump; the exact tensor plumbing is assumed):
        #   1. deep-copy `generation_config` (default: `self.generation_config`),
        #      fold `**kwargs` into it and prepare the special tokens.
        #   2. if the context tensors are missing, encode the question and call
        #      `self.retriever(...)` (hidden states cast to float32 numpy) to get
        #      `context_input_ids`, `context_attention_mask` and
        #      `retrieved_doc_embeds`, then
        #      `doc_scores = torch.bmm(question_hidden_states.unsqueeze(1),
        #       retrieved_doc_embeds.transpose(1, 2)).squeeze(1)`.
        #   3. check `context_input_ids.shape[0] % n_docs == 0`, run the generator
        #      encoder once, and tile `last_hidden_state`, the attention mask and
        #      `doc_scores` to `batch_size * num_beams * n_docs` rows with the
        #      local helper `extend_enc_output` (reshape/expand, plus
        #      `repeat_interleave` for `doc_scores`).
        #   4. seed the decoder `input_ids` with `decoder_start_token_id`, then
        #      build `self._get_logits_processor(...)`,
        #      `self._get_stopping_criteria(...)` and call
        #      `self._prepare_cache_for_generation(...)`.
        #   5. dispatch on `generation_config.num_beams`:
        #      * `num_beams == 1`: requires `num_return_sequences == 1`, runs
        #        `self._sample(...)`;
        #      * `num_beams > 1`: requires `num_return_sequences <= num_beams`,
        #        runs `self._beam_search(...)`;
        #      * otherwise raises `ValueError`.
        ...
zRagTokenForGeneration.generatec                 C   s   |  ||}|S r~   )r   )rc   r   r   r1   r1   r2   _temporary_reorder_cacheG  s    z.RagTokenForGeneration._temporary_reorder_cachec                 C   s   | j j S r~   )r6   rM   get_input_embeddingsr   r1   r1   r2   r  N  s    z*RagTokenForGeneration.get_input_embeddingsc                 C   s   | j j S r~   )r6   rM   get_output_embeddingsr   r1   r1   r2   r  Q  s    z+RagTokenForGeneration.get_output_embeddingsc                 C   s   | j j|S r~   )r6   rM   set_output_embeddings)rc   Znew_embeddingsr1   r1   r2   r  T  s    z+RagTokenForGeneration.set_output_embeddingsc                 C   sX   |du r| j j}||j}|ddddf  |ddddf< ||dddf< |S )zCShift input ids one token to the right, and pad with start_token_idNrx   r   r   )r5   r  Z	new_zerosr   clone)rc   rf   Zstart_token_idZshifted_input_idsr1   r1   r2   shift_tokens_rightW  s    (z(RagTokenForGeneration.shift_tokens_rightFr   c                    s  |d ur|n j j}td d dd f jd d j jjgd fdd} 	|||}
d | ksJ |jdd}	|jddd}
||	|
\}	}
|	d}	|
d}
|	 }|
 }|r| }| }||d }d	| | ||  }|S )
Nr   r   c                    sD     jjj}| r0| |d ||d | d|dfS r   r   r   r   r1   r2   r   g  s
    z1RagTokenForGeneration.get_nll.<locals>._mask_padsrx   r   Tr   r   )r5   ro   r,   r   r   r   r   rM   r   r   r   r}   r   r   r   )rc   r   r   r   r   r   ro   r   r   r   r   r   r   r   r   r1   r   r2   r   `  s*    2


zRagTokenForGeneration.get_nll)NNNN)NNNNNN)N)NNNNNNNNNNNNNNNNN)N)Fr   N),r(   r)   r*   r   r
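
# Hedged usage sketch for the `logits_processor` hook of
# `RagTokenForGeneration.generate` (not from the original file; assumes `model`
# and `inputs` as in the docstring example above):
#
#     from transformers import (
#         LogitsProcessorList,
#         MinLengthLogitsProcessor,
#         NoRepeatNGramLogitsProcessor,
#     )
#
#     generated = model.generate(
#         input_ids=inputs["input_ids"],
#         logits_processor=LogitsProcessorList([
#             MinLengthLogitsProcessor(5, eos_token_id=model.config.generator.eos_token_id),
#             NoRepeatNGramLogitsProcessor(2),
#         ]),
#     )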

__all__ = ["RagModel", "RagPreTrainedModel", "RagSequenceForGeneration", "RagTokenForGeneration"]