a
    hA                    @   s2  d Z ddlZddlmZmZ ddlZddlZddlmZ ddlm	Z	m
Z
mZ ddlmZ ddlmZmZmZ dd	lmZ dd
lmZmZ ddlmZ ddlmZmZmZmZmZmZm Z  ddl!m"Z" ddl#m$Z$m%Z% ddl&m'Z' ddl(m)Z) e%*e+Z,ej-e.e.dddZ/G dd dej0Z1G dd dej2Z3G dd deZ4G dd deZ5G dd dej2Z6G dd  d ej2Z7e$G d!d" d"e"Z8G d#d$ d$e8Z9G d%d& d&e8Z:e$G d'd( d(e8Z;e$d)d*G d+d, d,e8eZ<e$d-d*G d.d/ d/e8Z=e$G d0d1 d1e8Z>G d2d3 d3e8Z?G d4d5 d5e8eZ@g d6ZAdS )7zPyTorch MVP model.    N)OptionalUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)CacheDynamicCacheEncoderDecoderCache)GenerationMixin)_prepare_4d_attention_mask!_prepare_4d_causal_attention_mask)GradientCheckpointingLayer)BaseModelOutput)BaseModelOutputWithPastAndCrossAttentions!CausalLMOutputWithCrossAttentionsSeq2SeqLMOutputSeq2SeqModelOutput#Seq2SeqQuestionAnsweringModelOutputSeq2SeqSequenceClassifierOutput)PreTrainedModel)auto_docstringlogging)deprecate_kwarg   )	MvpConfig)	input_idspad_token_iddecoder_start_token_idc                 C   sh   |  | j}| ddddf  |ddddf< ||dddf< |du rTtd||dk| |S )z1
    Shift input ids one token to the right.
    Nr   r   z1self.model.config.pad_token_id has to be defined.i)Z	new_zerosshapeclone
ValueErrorZmasked_fill_)r   r   r    Zshifted_input_ids r%   `/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/mvp/modeling_mvp.pyshift_tokens_right4   s    (r'   c                       sD   e Zd ZdZeed fddZd
ejeejd fdd	Z  Z	S )MvpLearnedPositionalEmbeddingzN
    This module learns positional embeddings up to a fixed maximum size.
    )num_embeddingsembedding_dimc                    s   d| _ t || j  | d S N   )offsetsuper__init__)selfr)   r*   	__class__r%   r&   r/   J   s    z&MvpLearnedPositionalEmbedding.__init__r   N)r   past_key_values_lengthposition_idsc                    s\   |du r@|j dd \}}tj||| tj| jjd|d}n
|d}t 	|| j
 S )z3`input_ids' shape is expected to be [bsz x seqlen].Nr,   )dtypedevicer!   r   )r"   torcharangelongweightr6   expandZ	unsqueezer.   forwardr-   )r0   r   r3   r4   bszZseq_lenr1   r%   r&   r<   P   s    
z%MvpLearnedPositionalEmbedding.forward)r   N)
__name__
__module____qualname____doc__intr/   r7   Tensorr<   __classcell__r%   r%   r1   r&   r(   E   s   r(   c                       s   e Zd ZdZdeeee ee ee ee d fddZe	d	d
ddde
jee
j ee ee
j ee
j ee
j eee
j ee
jee
j eee
j  f d	ddZ  ZS )MvpAttentionz=Multi-headed attention from 'Attention Is All You Need' paper        FTN)	embed_dim	num_headsdropout
is_decoderbias	layer_idxc                    s   t    || _|| _|| _|| | _| j| | jkrNtd| j d| d| jd | _|| _|| _	t
j|||d| _t
j|||d| _t
j|||d| _t
j|||d| _d S )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).g      ࿩rK   )r.   r/   rG   rH   rI   head_dimr$   scalingrJ   rL   r   Lineark_projv_projq_projout_proj)r0   rG   rH   rI   rJ   rK   rL   r1   r%   r&   r/   a   s$    	


zMvpAttention.__init__past_key_valuepast_key_values4.58new_nameversion)	hidden_stateskey_value_statesrV   attention_masklayer_head_maskattn_promptoutput_attentionscache_positionreturnc	                 C   s  |du}	|  \}
}}| || j }|dur^t|trZ|j| j}|	rR|j}q^|j	}n|}|	rf|n|}|	r|dur|r|j
| j j}|j
| j j}n| |}| |}||
d| j| jdd}||
d| j| jdd}|dur&|	s|nd}|||| jd|i\}}|	r&d|j| j< |durtj|d |
ddd|gdd}tj|d |
ddd|gdd}|durt|
d||d  d|j}tj||gdd}|
| j d| jf}||
|| j| jdd}|j| }|j| }|j| }| d}t||dd}|  |
| j ||fkrZtd	|
| j ||f d
|   |dur|  |
d||fkrtd|
d||f d
|   ||
| j||| }||
| j ||}tjj|dd}|durB|  | jfkrtd| jf d
|   |dddd||
| j|| }||
| j ||}|rp||
| j||}||
| j ||}nd}tjj || j | j!d}t||}|  |
| j || jfkrtd|
| j|| jf d
|   ||
| j|| j}|dd}||
|| j"}| #|}||fS )z#Input shape: Batch x Time x ChannelNr!   r   r,   ra   Tr   dimz$Attention weights should be of size z	, but is z!Attention mask should be of size z/Head mask for a single layer should be of size ptrainingz `attn_output` should be of size )$sizerS   rO   
isinstancer   
is_updatedgetrL   Zcross_attention_cacheZself_attention_cachelayerskeysvaluesrQ   rR   viewrH   rN   Z	transposeupdater7   catr;   zerostor6   ZreshapeZbmmr$   r   
functionalZsoftmaxrI   rg   rG   rT   )r0   r[   r\   rV   r]   r^   r_   r`   ra   Zis_cross_attentionr=   tgt_len_Zquery_statesrj   Zcurr_past_key_valueZcurrent_statesZ
key_statesZvalue_statesZprompt_maskZ
proj_shapeZsrc_lenattn_weightsZattn_weights_reshapedZ
attn_probsZattn_outputr%   r%   r&   r<   ~   s    




""
"





"
zMvpAttention.forward)rF   FTN)NNNNNFN)r>   r?   r@   rA   rB   r   floatboolr/   r   r7   rC   r
   tupler<   rD   r%   r%   r1   r&   rE   ^   s@              rE   c                
       sX   e Zd Zed fddZdejejejejee e	ejeej f dddZ
  ZS )	MvpEncoderLayerconfigc                    s   t    |j| _t| j|j|jd| _t	| j| _
|j| _t|j | _|j| _t| j|j| _t|j| j| _t	| j| _d S )N)rG   rH   rI   )r.   r/   d_modelrG   rE   encoder_attention_headsattention_dropout	self_attnr   	LayerNormself_attn_layer_normrI   r	   activation_functionactivation_fnactivation_dropoutrP   Zencoder_ffn_dimfc1fc2final_layer_normr0   r}   r1   r%   r&   r/      s    
zMvpEncoderLayer.__init__F)r[   r]   r^   self_attn_promptr`   rb   c           	      C   s   |}| j |||||d\}}tjj|| j| jd}|| }| |}|}| | |}tjj|| j| jd}| 	|}tjj|| j| jd}|| }| 
|}|jtjkrt| st| rt|jjd }tj|| |d}||fS )a@  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            self_attn_prompt (`torch.FloatTensor`): prompt of self attention of shape
                `(2, encoder_attention_heads, pro_len, head_dim)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        )r[   r]   r^   r_   r`   re   i  )minmax)r   r   rt   rI   rg   r   r   r   r   r   r   r5   r7   Zfloat16isinfanyisnanZfinfor   clamp)	r0   r[   r]   r^   r   r`   residualrw   Zclamp_valuer%   r%   r&   r<     s4    



zMvpEncoderLayer.forward)F)r>   r?   r@   r   r/   r7   FloatTensorr   ry   rz   r<   rD   r%   r%   r1   r&   r{      s    r{   c                       s   e Zd Zded fddZedddddejeej eej eej eej eej eej eej ee	 ee
 ee
 eej eejeeejejf  f dddZ  ZS )MvpDecoderLayerNr|   c                    s   t    |j| _t| j|j|jd|d| _|j| _t	|j
 | _|j| _t| j| _t| j|j|jd|d| _t| j| _t| j|j| _t|j| j| _t| j| _d S )NT)rG   rH   rI   rJ   rL   )rI   rJ   rL   )r.   r/   r~   rG   rE   decoder_attention_headsr   r   rI   r	   r   r   r   r   r   r   encoder_attnencoder_attn_layer_normrP   Zdecoder_ffn_dimr   r   r   )r0   r}   rL   r1   r%   r&   r/   B  s0    
zMvpDecoderLayer.__init__rU   rV   rW   rX   FT)r[   r]   encoder_hidden_statesencoder_attention_maskr^   cross_attn_layer_head_maskr   cross_attn_promptrV   r`   	use_cachera   rb   c              	   C   s  |}| j ||	||||
|d\}}tjj|| j| jd}|| }| |}d}|dur|}| j||||||	|
d\}}tjj|| j| jd}|| }| |}|}| | 	|}tjj|| j
| jd}| |}tjj|| j| jd}|| }| |}|f}|
r|||f7 }|S )aD  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
                size `(decoder_attention_heads,)`.
            self_attn_prompt (`torch.FloatTensor`): prompt of self attention of shape
                `(2, decoder_attention_heads, pro_len, head_dim)`.
            cross_attn_prompt (`torch.FloatTensor`): prompt of cross attention of shape
                `(2, decoder_attention_heads, pro_len, head_dim)`.
            past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        )r[   rV   r]   r^   r_   r`   ra   re   N)r[   r\   r]   r^   r_   rV   r`   )r   r   rt   rI   rg   r   r   r   r   r   r   r   r   )r0   r[   r]   r   r   r^   r   r   r   rV   r`   r   ra   r   Zself_attn_weightsZcross_attn_weightsoutputsr%   r%   r&   r<   ^  sN    &
	

	


zMvpDecoderLayer.forward)N)NNNNNNNNFTN)r>   r?   r@   r   r/   r   r7   rC   r   r
   ry   rz   r   r<   rD   r%   r%   r1   r&   r   A  s8              r   c                       s@   e Zd ZdZeeeed fddZejejdddZ	  Z
S )MvpClassificationHeadz-Head for sentence-level classification tasks.)	input_dim	inner_dimnum_classespooler_dropoutc                    s8   t    t||| _tj|d| _t||| _d S )Nrf   )r.   r/   r   rP   denseDropoutrI   rT   )r0   r   r   r   r   r1   r%   r&   r/     s    
zMvpClassificationHead.__init__)r[   rb   c                 C   s6   |  |}| |}t|}|  |}| |}|S N)rI   r   r7   tanhrT   )r0   r[   r%   r%   r&   r<     s    




zMvpClassificationHead.forward)r>   r?   r@   rA   rB   rx   r/   r7   rC   r<   rD   r%   r%   r1   r&   r     s   r   c                       s8   e Zd ZdZ fddZejeej dddZ  Z	S )	MvpPromptz)Layer-wise prompt for encoder or decoder.c              	      s   t    |j| _|| _|| _|j| | _tj|j	d| _	t
|j|j| _tt|j|jt t|j|d |j | _d S )Nr   r,   )r.   r/   prompt_length
num_layersrH   r~   rN   r   r   rI   	Embeddingprompt_embeddingZ
SequentialrP   Zprompt_mid_dimZGELUprompt_trans)r0   r}   r   rH   r1   r%   r&   r/     s    
zMvpPrompt.__init__)
prompt_idsrb   c                 C   sN   |  | |}|| j| jd | j| j}| |}|g d	d}|S )Nr,   )r   r,   r   r   )
r   r   ro   r   r   rH   rN   rI   Zpermutesplit)r0   r   promptr%   r%   r&   r<     s
    
zMvpPrompt.forward)
r>   r?   r@   rA   r/   r7   rC   rz   r<   rD   r%   r%   r1   r&   r     s   r   c                   @   s2   e Zd ZU eed< dZdZdd Zedd Z	dS )	MvpPreTrainedModelr}   modelTc                 C   s|   | j j}t|tjr>|jjjd|d |jd urx|jj	  n:t|tj
rx|jjjd|d |jd urx|jj|j 	  d S )NrF   )meanstd)r}   Zinit_stdri   r   rP   r:   dataZnormal_rK   Zzero_r   padding_idx)r0   moduler   r%   r%   r&   _init_weights  s    

z MvpPreTrainedModel._init_weightsc                 C   s>   | j j}tjg ddddd|gg| jd}|||d}|S )N)r      
      r,   r         r,   r6   )r]   r   )r}   r   r7   Ztensorr6   ne)r0   Z	pad_tokenr   dummy_inputsr%   r%   r&   r     s    "zMvpPreTrainedModel.dummy_inputsN)
r>   r?   r@   r   __annotations__Zbase_model_prefixZsupports_gradient_checkpointingr   propertyr   r%   r%   r%   r&   r     s   
r   c                       s   e Zd ZdZd
eeej ee d fddZ	dee
j ee
j ee
j ee
j ee ee ee eeef ddd	Z  ZS )
MvpEncodera  
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`MvpEncoderLayer`].

    Args:
        config: MvpConfig
        embed_tokens (nn.Embedding): output embedding
        use_prompt (bool): whether to use prompt
    NFr}   embed_tokens
use_promptc                    s   t     j| _ j| _ j} j| _ j| _	 j
rBt|nd| _|d urX|| _nt j|| j| _t j|| _t fddt jD | _t|| _|| _|rʈ j| _t  j j| _d| _|    d S )N      ?c                    s   g | ]}t  qS r%   )r{   ).0rv   r|   r%   r&   
<listcomp>&      z'MvpEncoder.__init__.<locals>.<listcomp>F)!r.   r/   rI   Zencoder_layerdrop	layerdropr~   r   r   max_position_embeddingsZmax_source_positionsscale_embeddingmathsqrtembed_scaler   r   r   
vocab_sizer(   embed_positions
ModuleListrangeZencoder_layersrl   r   layernorm_embeddingr   r   r   r   r   gradient_checkpointing	post_init)r0   r}   r   r   rG   r1   r|   r&   r/     s4     zMvpEncoder.__init__)r   r]   	head_maskinputs_embedsr`   output_hidden_statesreturn_dictrb   c                 C   s~  |dur|n| j j}|dur |n| j j}|dur4|n| j j}|durV|durVtdn\|durz|}|j}	|d|	d }n8|dur| dd }	|dddddf }ntd|du r| || j	 }| 
|}
||
 }| |}tjj|| j| jd}| jr"t| j| j}| |}|dur8t||j}|rBdnd}|rPdnd}|dur| d t| jkrtdt| j d	| d  d
t| jD ]\}}|r||f }d}| jrtg }|| jk rd}|rd}n<||||dur|| nd| jr|| nd|d}|d }|r||d f }q|rR||f }|sptdd |||fD S t|||dS )a~  
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        NzDYou cannot specify both input_ids and inputs_embeds at the same timer!   z5You have to specify either input_ids or inputs_embedsre   r%   r   z&The head_mask should be specified for  layers, but it is for .FT)NN)r^   r   r`   r   c                 s   s   | ]}|d ur|V  qd S r   r%   r   vr%   r%   r&   	<genexpr>  r   z%MvpEncoder.forward.<locals>.<genexpr>last_hidden_stater[   
attentions) r}   r`   r   use_return_dictr$   r"   ro   rh   r   r   r   r   r   rt   rI   rg   r   r7   r8   r   rs   r6   r   r   r5   lenrl   	enumeraterandr   rz   r   )r0   r   r]   r   r   r`   r   r   inputinput_shapeZ	embed_posr[   r   r   Zencoder_statesZall_attentionsidxZencoder_layerZto_dropdropout_probabilitylayer_outputsr%   r%   r&   r<   6  sz    .









zMvpEncoder.forward)NF)NNNNNNN)r>   r?   r@   rA   r   r   r   r   ry   r/   r7   
LongTensorrC   r   r   rz   r   r<   rD   r%   r%   r1   r&   r     s,    (       
r   c                       s   e Zd ZdZd
eeej ee d fddZ	dee
j ee
j ee
j ee
j ee
j ee
j eee
j  ee
j ee ee ee ee ee
j eeef ddd	Z  ZS )
MvpDecoderz
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`MvpDecoderLayer`]

    Args:
        config: MvpConfig
        embed_tokens (nn.Embedding): output embedding
        use_prompt (bool): whether to use prompt
    NFr   c                    s   t     j| _ j| _ j| _ j| _ j	r>t
 jnd| _|d urT|| _nt j j| j| _t j j| _t fddt jD | _t j| _|| _|rވ j| _t  j j| _t  j j| _d| _ | !  d S )Nr   c                    s   g | ]}t  |d qS ))rL   )r   )r   ir|   r%   r&   r     r   z'MvpDecoder.__init__.<locals>.<listcomp>F)"r.   r/   rI   Zdecoder_layerdropr   r   r   r   Zmax_target_positionsr   r   r   r~   r   r   r   r   r   r(   r   r   r   Zdecoder_layersrl   r   r   r   r   r   r   r   r   r   r   )r0   r}   r   r   r1   r|   r&   r/     s<     zMvpDecoder.__init__)r   r]   r   r   r   cross_attn_head_maskrV   r   r   r`   r   r   ra   rb   c                 C   s  |
dur|
n| j j}
|dur |n| j j}|	dur4|	n| j j}	|durH|n| j j}|durj|durjtdn\|dur|}|j}|d|d }n8|dur| dd }|dddddf }ntd|du r| 	|| j
 }| jr| jr|	rtd d}	|	r<|du r<|dur0tt| j dt| j dn
t| j d}|	rbt|trbtd t|}|durt| nd	}t||||}|dur|durt||j|d d
}| ||}|| }| |}tjj|| j| jd}| jrt| j !| j"}| #|}| $|}|rdnd}|
r*dnd}|
rB|durBdnd}t%||gddgD ]V\}}|durX| d	 t&| j'krXtd| dt&| j' d| d	  dqXt(| j'D ]\}}|r||f7 }| jrt)g }|| j*k rq||||||dur|| nd|dur$|| nd| jr6|| nd| jrH|| nd||
|	|d}|d	 }|
r||d f7 }|dur||d f7 }q|r||f7 }|stdd |||||fD S t+|||||dS )a  
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                of the decoder.
            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
                Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
                selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
                cross-attention on hidden heads. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        NzTYou cannot specify both decoder_input_ids and decoder_inputs_embeds at the same timer!   zEYou have to specify either decoder_input_ids or decoder_inputs_embedszZ`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...Fr|   zPassing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.r   )ru   re   r%   r   r   zThe `z` should be specified for r   r   )	r   r^   r   r   r   rV   r`   r   ra   r   r,   c                 s   s   | ]}|d ur|V  qd S r   r%   r   r%   r%   r&   r     s   z%MvpDecoder.forward.<locals>.<genexpr>)r   rV   r[   r   cross_attentions),r}   r`   r   r   r   r$   r"   ro   rh   r   r   r   rg   loggerZwarning_oncer   r   ri   rz   Zfrom_legacy_cacheZget_seq_lengthr   r   r5   r   r   r   rt   rI   r   r7   r8   r   rs   r6   r   r   zipr   rl   r   r   r   r   )r0   r   r]   r   r   r   r   rV   r   r   r`   r   r   ra   r   r   r3   Z	positionsr[   r   r   r   Zall_hidden_statesZall_self_attnsZall_cross_attentionsZ	attn_maskZ	mask_namer   Zdecoder_layerr   r   r%   r%   r&   r<     s    Q











zMvpDecoder.forward)NF)NNNNNNNNNNNNN)r>   r?   r@   rA   r   r   r   r   ry   r/   r7   r   rC   r   listr   rz   r   r<   rD   r%   r%   r1   r&   r     sD   
 *             
r   c                       s   e Zd ZdgZddgZed fddZdd Zd	d
 Zdd Z	dd Z
edeej eej eej eej eej eej eej eeej  eeej  eej eej ee ee ee ee eej eeef dddZ  ZS )MvpModelfinal_logits_biasencoder.embed_tokens.weightdecoder.embed_tokens.weightr|   c                    sd   t  | |j|j }}|j| _t||j|| _t	|| j|j| _
t|| j|j| _|   d S r   )r.   r/   r   r   r   r   r   r~   sharedr   encoderr   decoderr   )r0   r}   r   r   r1   r%   r&   r/     s    zMvpModel.__init__c                 C   s   | j S r   )r   r0   r%   r%   r&   get_input_embeddings  s    zMvpModel.get_input_embeddingsc                 C   s   || _ | j | j_| j | j_d S r   )r   r   r   r   r0   valuer%   r%   r&   set_input_embeddings  s    
zMvpModel.set_input_embeddingsc                 C   s   | j S r   )r   r   r%   r%   r&   get_encoder  s    zMvpModel.get_encoderc                 C   sF   | j sJ d| d | jjd | jjd | jjd d S )NzHIf you want to use lightweight tuning, make sure that `use_prompt=True`.FT)r   requires_grad_r   r   r   r   r   r%   r%   r&   set_lightweight_tuning  s
    
zMvpModel.set_lightweight_tuningN)r   r]   decoder_input_idsdecoder_attention_maskr   decoder_head_maskr   encoder_outputsrV   r   decoder_inputs_embedsr   r`   r   r   ra   rb   c                 C   sL  |du r4|du r4|du r t dt|| jj| jj}|dur@|n| jj}|durT|n| jj}|durh|n| jj}|dur||n| jj}|du r| j	||||
|||d}nH|rt
|tst|d t|dkr|d ndt|dkr|d ndd}| j|||d ||||	||||||d}|s$|| S t|j|j|j|j|j|j|j|jd	S )
a*  
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            Mvp uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

            For translation and summarization training, `decoder_input_ids` should be provided. If no
            `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
            for denoising pre-training following the paper.
        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.

            If you want to change padding behavior, you should read [`modeling_mvp._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
            information on the default strategy.
        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
            1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        NzIf no `decoder_input_ids` or `decoder_inputs_embeds` are passed, `input_ids` cannot be `None`. Please pass either `input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`.)r   r]   r   r   r`   r   r   r   r   r,   r   )r   r]   r   r   r   r   rV   r   r   r`   r   r   ra   )r   rV   decoder_hidden_statesdecoder_attentionsr   encoder_last_hidden_stater   encoder_attentions)r$   r'   r}   r   r    r`   r   r   r   r   ri   r   r   r   r   r   rV   r[   r   r   )r0   r   r]   r  r  r   r  r   r  rV   r   r  r   r`   r   r   ra   Zdecoder_outputsr%   r%   r&   r<     sp    3
zMvpModel.forward)NNNNNNNNNNNNNNNN)r>   r?   r@   Z"_keys_to_ignore_on_load_unexpected_tied_weights_keysr   r/   r   r   r  r  r   r   r7   r   rC   r   r   ry   r   rz   r   r<   rD   r%   r%   r1   r&   r     sV                   
r   ze
    The MVP Model with a language modeling head. Can be used for various text generation tasks.
    )Zcustom_introc                       s&  e Zd Zg dZed fddZdd Zdd Zdee	e e
ejd fddZed	dddZdd Zede	ej e	ej e	ej e	ej e	ej e	ej e	ej e	eej  e	eej  e	ej e	ej e	ej e	e
 e	e
 e	e
 e	e
 e	ej eeef dddZejdddZ  ZS )MvpForConditionalGeneration)r   r   lm_head.weightr|   c                    sX   t  | t|| _| dtd| jjjf t	j
|j| jjjdd| _|   d S )Nr   r   FrM   )r.   r/   r   r   register_bufferr7   rr   r   r)   r   rP   r~   lm_headr   r   r1   r%   r&   r/   f  s
    
z$MvpForConditionalGeneration.__init__c                 C   s
   | j  S r   )r   r  r   r%   r%   r&   r  o  s    z'MvpForConditionalGeneration.get_encoderc                 C   s
   | j  S r   )r   get_decoderr   r%   r%   r&   r  r  s    z'MvpForConditionalGeneration.get_decoderNT)new_num_tokenspad_to_multiple_ofmean_resizingrb   c                    s   t  |||}| | |S r   )r.   resize_token_embeddings_resize_final_logits_bias)r0   r  r  r  Znew_embeddingsr1   r%   r&   r  u  s    
z3MvpForConditionalGeneration.resize_token_embeddings)r  rb   c                 C   sj   | j jd }||kr,| j d d d |f }n.tjd|| f| j jd}tj| j |gdd}| d| d S )Nr!   r   r   rc   r   )r   r"   r7   rr   r6   rq   r  )r0   r  Zold_num_tokensZnew_biasZ
extra_biasr%   r%   r&   r  |  s    z5MvpForConditionalGeneration._resize_final_logits_biasc                 C   s   | j   | jd d S NFr   r  r  r  r   r%   r%   r&   r    s    
z2MvpForConditionalGeneration.set_lightweight_tuning)r   r]   r  r  r   r  r   r  rV   r   r  labelsr   r`   r   r   ra   rb   c                 C   s  |dur|n| j j}|durR|r*td d}|du rR|du rRt|| j j| j j}| j|||||||||	|
||||||d}| |d | j	 }d}|durt
 }||d| j j|d}|s|f|dd  }|dur|f| S |S t|||j|j|j|j|j|j|jd	S )	a  
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            Mvp uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

            For translation and summarization training, `decoder_input_ids` should be provided. If no
            `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
            for denoising pre-training following the paper.
        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.

            If you want to change padding behavior, you should read [`modeling_mvp._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
            information on the default strategy.
        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
            1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example of summarization:

        Fine-tuning a model
        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, MvpForConditionalGeneration

        >>> tokenizer = AutoTokenizer.from_pretrained("RUCAIBox/mvp")
        >>> model = MvpForConditionalGeneration.from_pretrained("RUCAIBox/mvp")

        >>> inputs = tokenizer(
        ...     "Summarize: You may want to stick it to your boss and leave your job, but don't do it if these are your reasons.",
        ...     return_tensors="pt",
        ... )
        >>> labels = tokenizer("Bad Reasons To Quit Your Job", return_tensors="pt")["input_ids"]

        >>> loss = model(**inputs, labels=labels).loss
        >>> loss.backward()
        ```

        Inference after the model fine-tuned
        ```python
        >>> with torch.no_grad():
        ...     generated_ids = model.generate(**inputs)

        >>> generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        ```
        NzJThe `use_cache` argument is changed to `False` since `labels` is provided.F)r]   r  r  r  r   r  r   rV   r   r  r   r`   r   r   ra   r   r!   r   	losslogitsrV   r	  r
  r   r  r   r  )r}   r   r   warningr'   r   r    r   r  r   r   ro   r   r   rV   r	  r
  r   r  r   r  )r0   r   r]   r  r  r   r  r   r  rV   r   r  r  r   r`   r   r   ra   r   Z	lm_logitsZmasked_lm_lossloss_fctoutputr%   r%   r&   r<     s\    R
z#MvpForConditionalGeneration.forward)r  c                 C   s   t || jj| jjS r   )r'   r}   r   r    )r0   r  r%   r%   r&   %prepare_decoder_input_ids_from_labels  s    zAMvpForConditionalGeneration.prepare_decoder_input_ids_from_labels)NT)NNNNNNNNNNNNNNNNN)r>   r?   r@   r  r   r/   r  r  rB   r   ry   r   r   r  r  r  r   r7   r   rC   r   r   r   rz   r   r<   r!  rD   r%   r%   r1   r&   r  ^  sf   	 
	                 
 r  z
    Mvp model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
    tasks.
    c                       s   e Zd ZddgZed fddZdd Zedee	j
 ee	j ee	j
 ee	j
 ee	j ee	j ee	j eee	j  ee	j ee	j ee	j
 ee ee ee ee eeef d	d
dZ  ZS )MvpForSequenceClassificationr   r   r|   c                    sB   t  j|fi | t|| _t|j|j|j|j| _| 	  d S r   )
r.   r/   r   r   r   r~   
num_labelsZclassifier_dropoutclassification_headr   )r0   r}   kwargsr1   r%   r&   r/     s    
z%MvpForSequenceClassification.__init__c                 C   s   | j   | jd d S r  )r   r  r$  r  r   r%   r%   r&   r  )  s    
z3MvpForSequenceClassification.set_lightweight_tuningN)r   r]   r  r  r   r  r   r  r   r  r  r   r`   r   r   rb   c                 C   sB  |dur|n| j j}|dur d}|du rB|	durBtd| jj | j|||||||||	|
||||d}|d }|| j j|j	}t
t|ddkrtd||ddf |dd|ddddddf }| |}d}|dur| j jdu rX| j jdkrd	| j _n<| j jdkrP|jtjksF|jtjkrPd
| j _nd| j _| j jd	krt }| j jdkr|| | }n
|||}nP| j jd
krt }||d| j j|d}n| j jdkrt }|||}|s|f|dd  }|dur|f| S |S t|||j|j|j|j|j |j!|j"d	S )a  
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            Mvp uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

            For translation and summarization training, `decoder_input_ids` should be provided. If no
            `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
            for denoising pre-training following the paper.
        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.

            If you want to change padding behavior, you should read [`modeling_mvp._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
            information on the default strategy.
        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
            1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Example of single-label classification:

        Fine-tuning a model on `num_labels` classes
        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, MvpForSequenceClassification

        >>> num_labels = 2  # for example, this is a binary classification task
        >>> tokenizer = AutoTokenizer.from_pretrained("RUCAIBox/mvp")
        >>> model = MvpForSequenceClassification.from_pretrained("RUCAIBox/mvp", num_labels=num_labels)

        >>> inputs = tokenizer("Classify: Hello, my dog is cute", return_tensors="pt")
        >>> labels = torch.tensor(1)  # the real label for inputs

        >>> loss = model(**inputs, labels=labels).loss
        >>> loss.backward()
        ```

        Inference after the model fine-tuned
        ```python
        >>> with torch.no_grad():
        ...     logits = model(**inputs).logits

        >>> predicted_class_id = logits.argmax()
        ```
        NFz8Passing input embeddings is currently not supported for r]   r  r  r   r  r   r  r   r  r   r`   r   r   r   r   z7All examples must have the same number of <eos> tokens.r!   Z
regressionZsingle_label_classificationZmulti_label_classificationr  )#r}   r   NotImplementedErrorr2   r>   r   eqZeos_token_idrs   r6   r   r7   Zunique_consecutivesumr$   ro   rh   r$  Zproblem_typer#  r5   r9   rB   r   squeezer   r   r   rV   r	  r
  r   r  r   r  )r0   r   r]   r  r  r   r  r   r  r   r  r  r   r`   r   r   r   r[   Zeos_maskZsentence_representationr  r  r  r   r%   r%   r&   r<   -  s    M$


*

z$MvpForSequenceClassification.forward)NNNNNNNNNNNNNNN)r>   r?   r@   r  r   r/   r  r   r   r7   r   rC   r   r   ry   r   rz   r   r<   rD   r%   r%   r1   r&   r"    sJ                  
r"  c                       s   e Zd ZddgZ fddZdd Zedeej	 eej	 eej
 eej
 eej	 eej	 eej	 eeej  eej
 eej
 eej eej ee ee ee ee eeef dd	d
Z  ZS )MvpForQuestionAnsweringr   r   c                    sB   t  | d|_|j| _t|| _t|j|j| _| 	  d S r+   )
r.   r/   r#  r   r   r   rP   hidden_size
qa_outputsr   r   r1   r%   r&   r/     s    
z MvpForQuestionAnswering.__init__c                 C   s   | j   | jd d S r  )r   r  r-  r  r   r%   r%   r&   r    s    
z.MvpForQuestionAnswering.set_lightweight_tuningN)r   r]   r  r  r   r  r   r  start_positionsend_positionsr   r  r   r`   r   r   rb   c                 C   s  |dur|n| j j}|	dur(|
dur(d}| j||||||||||||||d}|d }| |}|jddd\}}|d }|d }d}|	dur&|
dur&t|	 dkr|	d}	t|
 dkr|
d}
|d}|		d|}	|
	d|}
t
|d}|||	}|||
}|| d	 }|sX||f|dd  }|durT|f| S |S t||||j|j|j|j|j|j|jd

S )a`  
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            Mvp uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

            For translation and summarization training, `decoder_input_ids` should be provided. If no
            `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
            for denoising pre-training following the paper.
        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.

            If you want to change padding behavior, you should read [`modeling_mvp._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
            information on the default strategy.
        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
            1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        Example:

        Fine-tuning a model for extrative question answering, and our model also supports generative question answering
        using `BartForConditionalGeneration`
        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, MvpForQuestionAnswering

        >>> tokenizer = AutoTokenizer.from_pretrained("RUCAIBox/mvp")
        >>> model = MvpForQuestionAnswering.from_pretrained("RUCAIBox/mvp")

        >>> inputs = tokenizer(
        ...     "Answer the following question: Who was Jim Henson? [SEP] Jim Henson was a nice puppet",
        ...     return_tensors="pt",
        ... )
        >>> target_start_index = torch.tensor([18])
        >>> target_end_index = torch.tensor([19])

        >>> loss = model(**inputs, start_positions=target_start_index, end_positions=target_end_index).loss
        >>> loss.backward()
        ```

        Inference after the model fine-tuned
        ```python
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> answer_start_index = outputs.start_logits.argmax()
        >>> answer_end_index = outputs.end_logits.argmax()

        >>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
        >>> predict_answer = tokenizer.decode(predict_answer_tokens)
        ```
        NFr&  r   r   r!   rc   )Zignore_indexr,   )
r  start_logits
end_logitsrV   r	  r
  r   r  r   r  )r}   r   r   r-  r   r*  
contiguousr   rh   r   r   r   rV   r	  r
  r   r  r   r  )r0   r   r]   r  r  r   r  r   r  r.  r/  r   r  r   r`   r   r   r   Zsequence_outputr  r0  r1  Z
total_lossZignored_indexr  Z
start_lossZend_lossr   r%   r%   r&   r<     sp    S







zMvpForQuestionAnswering.forward)NNNNNNNNNNNNNNNN)r>   r?   r@   r  r/   r  r   r   r7   rC   r   r   r   ry   r   rz   r   r<   rD   r%   r%   r1   r&   r+    sN                   
r+  c                       s(   e Zd ZdZ fddZdd Z  ZS )MvpDecoderWrapperz
    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
    used in combination with the [`EncoderDecoderModel`] framework.
    c                    s   t  | t|| _d S r   )r.   r/   r   r   r   r1   r%   r&   r/   u  s    zMvpDecoderWrapper.__init__c                 O   s   | j |i |S r   )r   )r0   argsr%  r%   r%   r&   r<   y  s    zMvpDecoderWrapper.forward)r>   r?   r@   rA   r/   r<   rD   r%   r%   r1   r&   r3  o  s   r3  c                       s   e Zd ZdgZ fddZdd Zdd Zdd	 Zd
d Zdd Z	e
deej eej eej eej eej eej eeej  eej eej ee ee ee ee eej eeef dddZ  ZS )MvpForCausalLMr  c                    sD   d|_ d|_t | t|| _tj|j|j	dd| _
|   d S )NTFrM   )rJ   Zis_encoder_decoderr.   r/   r3  r   r   rP   r,  r   r  r   r   r1   r%   r&   r/     s    
zMvpForCausalLM.__init__c                 C   s
   | j jjS r   r   r   r   r   r%   r%   r&   r     s    z#MvpForCausalLM.get_input_embeddingsc                 C   s   || j j_d S r   r6  r   r%   r%   r&   r     s    z#MvpForCausalLM.set_input_embeddingsc                 C   s   || j _d S r   r   r   )r0   r   r%   r%   r&   set_decoder  s    zMvpForCausalLM.set_decoderc                 C   s   | j jS r   r7  r   r%   r%   r&   r    s    zMvpForCausalLM.get_decoderc                 C   s   | j   | jd d S r  r  r   r%   r%   r&   r    s    
z%MvpForCausalLM.set_lightweight_tuningN)r   r]   r   r   r   r   rV   r   r  r   r`   r   r   ra   rb   c                 C   s   |dur|n| j j}|dur |n| j j}|dur4|n| j j}| jj|||||||||
|||d}| |d }d}|	durt }||d| j j	|	d}|s|f|dd  }|dur|f| S |S t
|||j|j|j|jdS )a  
        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, MvpForCausalLM

        >>> tokenizer = AutoTokenizer.from_pretrained("RUCAIBox/mvp")
        >>> model = MvpForCausalLM.from_pretrained("RUCAIBox/mvp", add_cross_attention=False)

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> logits = outputs.logits
        >>> list(logits.shape)
        [1, 8, 50267]
        ```N)r   r]   r   r   r   r   rV   r   r   r`   r   r   r   r!   r   )r  r  rV   r[   r   r   )r}   r`   r   r   r   r   r  r   ro   r   r   rV   r[   r   r   )r0   r   r]   r   r   r   r   rV   r   r  r   r`   r   r   ra   r   r  r  r  r   r%   r%   r&   r<     sD    -zMvpForCausalLM.forward)NNNNNNNNNNNNNN)r>   r?   r@   r  r/   r   r   r8  r  r  r   r   r7   r   rC   r   r   ry   r   rz   r   r<   rD   r%   r%   r1   r&   r5  }  sN                 
r5  )r5  r  r+  r"  r   r   )BrA   r   typingr   r   r7   Ztorch.utils.checkpointr   Ztorch.nnr   r   r   Zactivationsr	   Zcache_utilsr
   r   r   Z
generationr   Zmodeling_attn_mask_utilsr   r   Zmodeling_layersr   Zmodeling_outputsr   r   r   r   r   r   r   Zmodeling_utilsr   utilsr   r   Zutils.deprecationr   Zconfiguration_mvpr   Z
get_loggerr>   r   rC   rB   r'   r   r(   ModulerE   r{   r   r   r   r   r   r   r   r  r"  r+  r3  r5  __all__r%   r%   r%   r&   <module>   sf   $	
 !Cw 5  
  1 - *v