"""PyTorch ProphetNet model, ported from ProphetNet repo(fairsequery_states version)."""

import copy
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.utils.checkpoint
from torch import Tensor, nn
from torch.nn import LayerNorm

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...generation import GenerationMixin
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_prophetnet import ProphetNetConfig


logger = logging.get_logger(__name__)


def softmax(hidden_state, dim, onnx_trace=False):
    if onnx_trace:
        return nn.functional.softmax(hidden_state.float(), dim=dim)
    else:
        return nn.functional.softmax(hidden_state, dim=dim, dtype=torch.float32)
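
# Quick dtype check (illustrative sketch, assuming a float16 input): the
# non-ONNX branch computes the softmax in float32 for numerical stability and
# returns a float32 tensor even for half-precision inputs.
#
#     >>> softmax(torch.zeros(2, dtype=torch.float16), dim=-1).dtype
#     torch.float32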

def ngram_attention_bias(sequence_length, ngram, device, dtype):
    """
    This function computes the bias for the predict stream
    """
    left_block = (
        torch.ones((ngram, sequence_length, sequence_length), device=device, dtype=dtype) * torch.finfo(dtype).min
    )
    right_block = left_block.detach().clone()
    # create bias
    for stream_idx in range(ngram):
        right_block[stream_idx].fill_diagonal_(0, wrap=False)
        left_block[stream_idx].triu_(-stream_idx + 1)

    left_block[:, :, 0] = 0
    return torch.cat([left_block, right_block], dim=2)
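
# Shape sketch (assumes a CPU / float32 run, purely for illustration): the
# returned bias places a causal mask for the main stream in the left half and
# a diagonal mask for the predict stream itself in the right half.
#
#     >>> bias = ngram_attention_bias(sequence_length=4, ngram=2, device="cpu", dtype=torch.float32)
#     >>> bias.shape
#     torch.Size([2, 4, 8])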
|| d   }|t || | }|S )zo
    This function computes individual parts of the relative position buckets. For more detail, see paper.
    r   r%   r   )r   lt
zeros_likeintabsmaxlogr   mathr(   Z	ones_likewhere)	num_bucketsmax_distancerelative_positionsis_bidirectionalZinv_relative_positionsZrel_positions_bucketZ	max_exactZis_smallZval_if_larger!   r!   r"   compute_relative_bucketsB   s(    r<   c                 C   s   | dd|dd}|| d }tj|d |fdd d}|d|dd}|| d }t| ||dd}t| ||dd}||fS )zm
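
# A small worked example (sketch; values follow the default unidirectional
# branch): past positions closer than num_buckets // 2 map to their own
# bucket, while "future" positions collapse into bucket 0.
#
#     >>> rel = torch.tensor([[-5, 0, 3]])
#     >>> compute_relative_buckets(num_buckets=32, max_distance=128, relative_positions=rel)
#     tensor([[5, 0, 0]], dtype=torch.int32)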

def compute_all_stream_relative_buckets(num_buckets, max_distance, position_ids):
    """
    This function computes both main and predict relative position buckets. For more detail, see paper.
    """
    # main stream
    main_stream_relative_positions = position_ids.unsqueeze(1).repeat(1, position_ids.size(-1), 1)
    main_stream_relative_positions = main_stream_relative_positions - position_ids.unsqueeze(-1)

    # predicting stream
    predicting_stream_relative_positions = torch.cat((position_ids - 1, position_ids), dim=-1).unsqueeze(1)
    predicting_stream_relative_positions = predicting_stream_relative_positions.repeat(1, position_ids.size(-1), 1)
    predicting_stream_relative_positions = predicting_stream_relative_positions - position_ids.unsqueeze(-1)

    # get both position buckets
    main_relative_position_buckets = compute_relative_buckets(
        num_buckets, max_distance, main_stream_relative_positions, is_bidirectional=False
    )
    predict_relative_position_buckets = compute_relative_buckets(
        num_buckets, max_distance, predicting_stream_relative_positions, is_bidirectional=False
    )
    return main_relative_position_buckets, predict_relative_position_buckets
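
# Resulting shapes (sketch, assuming `position_ids` of shape (batch, seq_len)):
# the main-stream buckets compare every position with every other position,
# while the predict-stream buckets also cover the shifted copy of the
# sequence, hence the doubled key length.
#
#     >>> pos = torch.arange(1, 5).unsqueeze(0)  # (1, 4)
#     >>> main, predict = compute_all_stream_relative_buckets(32, 128, pos)
#     >>> main.shape, predict.shape
#     (torch.Size([1, 4, 4]), torch.Size([1, 4, 8]))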

@dataclass
@auto_docstring(
    custom_intro="""
    Base class for sequence-to-sequence language models outputs.
    """
)
class ProphetNetSeq2SeqLMOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss.
    logits (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, config.vocab_size)`):
        Prediction scores of the main stream language modeling head (scores for each vocabulary token before
        SoftMax).
    logits_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, config.vocab_size)`):
        Prediction scores of the predict stream language modeling head (scores for each vocabulary token before
        SoftMax).
    past_key_values (`list[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
        num_attn_heads, decoder_sequence_length, embed_size_per_head)`).

        Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
        used (see `past_key_values` input) to speed up sequential decoding.
    decoder_ngram_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
        shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.

        Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
        outputs.
    decoder_ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
        decoder_sequence_length, decoder_sequence_length)`.

        Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
        weighted average in the self-attention heads.
    encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
        Sequence of hidden-states at the output of the last layer of the encoder of the model.
    Nlosslogitslogits_ngrampast_key_valuesdecoder_hidden_statesdecoder_ngram_hidden_statesdecoder_attentionsdecoder_ngram_attentionscross_attentionsencoder_last_hidden_stateencoder_hidden_statesencoder_attentionsc                 C   s   t dt | jS Nzi`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions` instead.warningswarnFutureWarningrN   selfr!   r!   r"   decoder_cross_attentions   s
    z2ProphetNetSeq2SeqLMOutput.decoder_cross_attentions)__name__
__module____qualname____doc__rF   r   r   FloatTensor__annotations__rG   rH   rI   tuplerJ   rK   rL   rM   rN   rO   rP   rQ   propertyrY   r!   r!   r!   r"   rE   t   s   

@dataclass
@auto_docstring(
    custom_intro="""
    Base class for model encoder's outputs that also contains: pre-computed hidden states that can speed up sequential
    decoding.
    """
)
class ProphetNetSeq2SeqModelOutput(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, hidden_size)`):
        Sequence of main stream hidden-states at the output of the last layer of the decoder of the model.

        If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
        hidden_size)` is output.
    last_hidden_state_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`, *optional*):
        Sequence of predict stream hidden-states at the output of the last layer of the decoder of the model.
    past_key_values (`list[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
        num_attn_heads, decoder_sequence_length, embed_size_per_head)`).

        Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
        used (see `past_key_values` input) to speed up sequential decoding.
    decoder_ngram_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
        shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.

        Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
        outputs.
    decoder_ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
        decoder_sequence_length, decoder_sequence_length)`.

        Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
        weighted average in the self-attention heads.
    encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
        Sequence of hidden-states at the output of the last layer of the encoder of the model.
    """

    last_hidden_state: torch.FloatTensor
    last_hidden_state_ngram: Optional[torch.FloatTensor] = None
    past_key_values: Optional[tuple[torch.FloatTensor]] = None
    decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    decoder_ngram_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
    decoder_ngram_attentions: Optional[tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[tuple[torch.FloatTensor]] = None

    @property
    def decoder_cross_attentions(self):
        warnings.warn(
            "`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions`"
            " instead.",
            FutureWarning,
        )
        return self.cross_attentions


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
    """
)
class ProphetNetDecoderModelOutput(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, hidden_size)`):
        Sequence of main stream hidden-states at the output of the last layer of the decoder of the model.

        If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
        hidden_size)` is output.
    last_hidden_state_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`):
        Sequence of predict stream hidden-states at the output of the last layer of the decoder of the model.
    past_key_values (`list[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
        num_attn_heads, decoder_sequence_length, embed_size_per_head)`).

        Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
        used (see `past_key_values` input) to speed up sequential decoding.
    hidden_states_ngram (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
        shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.

        Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
        outputs.
    ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
        decoder_sequence_length, decoder_sequence_length)`.

        Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
        weighted average in the self-attention heads.
    """

    last_hidden_state: torch.FloatTensor
    last_hidden_state_ngram: Optional[torch.FloatTensor] = None
    past_key_values: Optional[tuple[torch.FloatTensor]] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    hidden_states_ngram: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    ngram_attentions: Optional[tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[tuple[torch.FloatTensor]] = None


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
    """
)
class ProphetNetDecoderLMOutput(ModelOutput):
    r"""
    ngram_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
        shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.

        Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
        outputs.
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss.
    logits (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, config.vocab_size)`):
        Prediction scores of the main stream language modeling head (scores for each vocabulary token before
        SoftMax).
    logits_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, config.vocab_size)`):
        Prediction scores of the predict stream language modeling head (scores for each vocabulary token before
        SoftMax).
    past_key_values (`list[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
        num_attn_heads, decoder_sequence_length, embed_size_per_head)`).

        Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
        used (see `past_key_values` input) to speed up sequential decoding.
    hidden_states_ngram (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
        shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.

        Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
        outputs.
    ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
        decoder_sequence_length, decoder_sequence_length)`.

        Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
        weighted average in the self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    logits_ngram: Optional[torch.FloatTensor] = None
    past_key_values: Optional[tuple[torch.FloatTensor]] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    hidden_states_ngram: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    ngram_attentions: Optional[tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[tuple[torch.FloatTensor]] = None


@auto_docstring
class ProphetNetPreTrainedModel(PreTrainedModel):
    config: ProphetNetConfig
    base_model_prefix = "prophetnet"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _shift_right(self, input_ids):
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id

        assert decoder_start_token_id is not None, (
            "self.model.config.decoder_start_token_id has to be defined. In ProphetNet it is usually set to the"
            " pad_token_id. See ProphetNet docs for more information."
        )

        # shift inputs to the right
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
        shifted_input_ids[..., 0] = decoder_start_token_id

        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"

        return shifted_input_ids
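
# Worked example of the shift-right trick above (sketch; assumes
# pad_token_id=0 and decoder_start_token_id=0 purely for illustration):
#
#     labels:             [[312, 453, -100]]
#     after shifting:     [[  0, 312, 453]]   # start token prepended, last label dropped
#     after -100 -> pad:  [[  0, 312, 453]]   # no -100 entries left to replace here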

class ProphetNetPositionalEmbeddings(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting
    based on padding_idx or by setting padding_idx to None and ensuring that the appropriate position ids are passed to
    the forward function.
    """

    def __init__(self, config: ProphetNetConfig) -> None:
        self.max_length = config.max_position_embeddings
        super().__init__(config.max_position_embeddings, config.hidden_size, config.pad_token_id)

    def forward(self, inputs_shape, device, attention_mask=None, past_key_values=None, position_ids=None):
        assert (position_ids is None) or (self.padding_idx is None), (
            "If position_ids is pre-computed then padding_idx should not be set."
        )

        if position_ids is None:
            if past_key_values is not None and past_key_values.get_seq_length() > 0:
                # position_ids is the same for every token when decoding a single step
                prev_num_input_ids = past_key_values.get_seq_length()
                num_input_ids = inputs_shape[1] + prev_num_input_ids
                position_ids = torch.ones((1, 1), dtype=torch.long, device=device) * int(
                    self.padding_idx + num_input_ids
                )
            else:
                if attention_mask is None:
                    attention_mask = torch.ones(inputs_shape, dtype=torch.long, device=device)

                # retrieve position_ids from input_ids / attention_mask
                position_ids = (
                    torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask
                ).long() + self.padding_idx

                # make sure position_ids are not bigger than max_length
                position_ids = position_ids.clamp(0, self.max_length - 1)

        return super().forward(position_ids), position_ids

    def _forward(self, position_ids):
        return super().forward(position_ids)


class ProphetNetAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: ProphetNetConfig, num_attn_heads: int, layer_idx: Optional[int] = None):
        super().__init__()
        hidden_size = config.hidden_size

        self.attention_dropout = config.attention_dropout
        self.dropout = config.dropout
        self.num_attn_heads = num_attn_heads
        self.head_dim = hidden_size // num_attn_heads
        self.layer_idx = layer_idx

        assert self.head_dim * num_attn_heads == hidden_size, (
            "`config.hidden_size` must be divisible by `config.num_encoder_attention_heads` and"
            " `config.num_decoder_attention_heads`"
        )

        self.key_proj = nn.Linear(hidden_size, hidden_size)
        self.value_proj = nn.Linear(hidden_size, hidden_size)
        self.query_proj = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states,
        key_value_states: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        layer_head_mask: Optional[Tensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: bool = False,
        cache_position: Optional[torch.Tensor] = None,
    ) -> tuple[Tensor, Optional[Tensor]]:
        batch_size, tgt_len, hidden_size = hidden_states.size()

        # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
        is_cross_attention = key_value_states is not None
        assert list(hidden_states.size()) == [batch_size, tgt_len, hidden_size], (
            f"Size of hidden states should be {batch_size, tgt_len, hidden_size}, but is {hidden_states.size()}"
        )

        query_states = self.query_proj(hidden_states) / (self.head_dim**0.5)

        is_updated = False
        curr_past_key_value = None
        if past_key_values is not None:
            if isinstance(past_key_values, EncoderDecoderCache):
                is_updated = past_key_values.is_updated.get(self.layer_idx, False)
                # after the first generated id, we can subsequently re-use all key/value_states from cache
                curr_past_key_value = (
                    past_key_values.cross_attention_cache
                    if is_cross_attention
                    else past_key_values.self_attention_cache
                )
            else:
                curr_past_key_value = past_key_values

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_values is not None and is_updated:
            # reuse k, v from cross-attention cache
            key_states = curr_past_key_value.layers[self.layer_idx].keys
            value_states = curr_past_key_value.layers[self.layer_idx].values
        else:
            key_states = self.key_proj(current_states)
            value_states = self.value_proj(current_states)
            key_states = key_states.view(batch_size, -1, self.num_attn_heads, self.head_dim).transpose(1, 2)
            value_states = value_states.view(batch_size, -1, self.num_attn_heads, self.head_dim).transpose(1, 2)

            if past_key_values is not None:
                # save all key/value_states to cache to be re-used for fast auto-regressive generation
                cache_position = cache_position if not is_cross_attention else None
                key_states, value_states = curr_past_key_value.update(
                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                )
                # set flag so cross-attention keys/values can be re-used in subsequent calls
                if is_cross_attention:
                    past_key_values.is_updated[self.layer_idx] = True

        query_states = query_states.view(batch_size, tgt_len, self.num_attn_heads, self.head_dim).transpose(1, 2)

        src_len = key_states.size(2)
        attn_weights = torch.einsum("bsij,bsjk->bsik", query_states, key_states.transpose(2, 3))
        expected_shape = (batch_size, self.num_attn_heads, tgt_len, src_len)
        if attn_weights.size() != expected_shape:
            raise ValueError(f"Attention weights should have size {expected_shape}, but is {attn_weights.size()}")

        # This is part of a workaround to get around fork/join parallelism not supporting Optional types.
        if attention_mask is not None and attention_mask.dim() == 0:
            attention_mask = None

        expected_shape = (batch_size, self.num_attn_heads, 1, src_len)
        if attention_mask is not None and attention_mask.size() != expected_shape:
            raise ValueError(f"Attention mask should have size {expected_shape}, but is {attention_mask.size()}")
        if attention_mask is not None:  # don't attend to padding symbols
            attn_weights = attn_weights + attention_mask
        if output_attentions:
            attn_weights_reshaped = attn_weights
        else:
            attn_weights_reshaped = None

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            assert layer_head_mask.size() == (self.num_attn_heads,), (
                f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
                f" {layer_head_mask.size()}"
            )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights
            if attn_weights_reshaped is not None:
                attn_weights_reshaped = layer_head_mask.view(1, -1, 1, 1) * attn_weights_reshaped

        attn_probs = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.einsum("bsij,bsjk->bsik", attn_probs, value_states)
        expected_shape = (batch_size, self.num_attn_heads, tgt_len, self.head_dim)
        if attn_output.size() != expected_shape:
            raise ValueError(f"`attn_output` should have shape {expected_shape}, but is of shape {attn_output.size()}")

        attn_output = attn_output.transpose(1, 2).reshape(batch_size, tgt_len, hidden_size)
        attn_output = self.out_proj(attn_output)

        attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
        return attn_output, attn_weights_reshaped
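
# Shape walk-through for the multi-head attention above (sketch, assuming
# batch=2, tgt_len=5, hidden_size=16 and num_attn_heads=2, so head_dim=8):
#
#     query_proj(hidden)                 -> (2, 5, 16)
#     .view(2, 5, 2, 8).transpose(1, 2)  -> (2, 2, 5, 8)
#     attn_weights = q @ k^T             -> (2, 2, 5, src_len)
#     attn_probs @ v                     -> (2, 2, 5, 8) -> reshape -> (2, 5, 16)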

class ProphetNetFeedForward(nn.Module):
    """
    This is the residual two feed-forward layer block based on the original Transformer implementation.
    """

    def __init__(self, config: ProphetNetConfig, ffn_dim: int):
        super().__init__()
        self.activation_fn = ACT2FN[config.activation_function]
        self.intermediate = nn.Linear(config.hidden_size, ffn_dim)
        self.output = nn.Linear(ffn_dim, config.hidden_size)
        self.activation_dropout = config.activation_dropout
        self.dropout = config.dropout

    def forward(self, hidden_states):
        hidden_states = self.intermediate(hidden_states)
        hidden_states = self.activation_fn(hidden_states)

        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.output(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        return hidden_states


class ProphetNetNgramSelfAttention(nn.Module):
    def __init__(self, config: ProphetNetConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.num_buckets = config.num_buckets
        self.relative_max_distance = config.relative_max_distance
        self.num_attn_heads = config.num_decoder_attention_heads
        self.dropout = config.dropout
        self.attention_dropout = config.attention_dropout
        self.head_dim = config.hidden_size // self.num_attn_heads
        self.ngram = config.ngram
        self.layer_idx = layer_idx

        assert self.head_dim * self.num_attn_heads == config.hidden_size, (
            "config.hidden_size must be divisible by num_attn_heads"
        )
        # key, value, query projection
        self.key_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.value_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.query_proj = nn.Linear(config.hidden_size, config.hidden_size)

        # out projection
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)

        # rel position embeddings
        self.relative_pos_embeddings = nn.Linear(config.hidden_size, self.num_buckets * self.num_attn_heads)

        # for onnx runtime
        self.onnx_trace = False

    def _shape(self, tensor, seq_len, batch_size):
        return tensor.view(batch_size, seq_len, self.num_attn_heads, self.head_dim).transpose(1, 2).contiguous()

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states,
        past_key_values: Optional[Cache] = None,
        attention_mask=None,
        layer_head_mask=None,
        extended_predict_attention_mask=None,
        main_relative_position_buckets=None,
        predict_relative_position_buckets=None,
        position_ids=None,
    ):
        batch_size, ngram_sequence_length, hidden_size = hidden_states.size()
        assert list(hidden_states.size()) == [batch_size, ngram_sequence_length, hidden_size], (
            f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape"
            f" {hidden_states.shape}"
        )

        # project
        query_states = self.query_proj(hidden_states)
        key_states = self.key_proj(hidden_states)
        value_states = self.value_proj(hidden_states)

        # normalize
        query_states = query_states / (self.head_dim**0.5)

        # reshape
        query_states = self._shape(query_states, ngram_sequence_length, batch_size)
        key_states = self._shape(key_states, -1, batch_size)
        value_states = self._shape(value_states, -1, batch_size)
        proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim)

        query_states = query_states.view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        # chunk into main stream and predict stream
        hidden_states_list = hidden_states.chunk(1 + self.ngram, dim=1)
        query_states_list = query_states.chunk(1 + self.ngram, dim=2)
        key_states_list = key_states.chunk(1 + self.ngram, dim=2)
        value_states_list = value_states.chunk(1 + self.ngram, dim=2)

        main_hidden_states, hidden_states_predict_list = hidden_states_list[0], hidden_states_list[1:]
        main_query_states, predict_query_states_list = query_states_list[0], query_states_list[1:]
        main_key_states, predict_key_states_list = key_states_list[0], key_states_list[1:]
        main_value_states, predict_value_states_list = value_states_list[0], value_states_list[1:]

        if past_key_values is not None:
            curr_past_key_value = (
                past_key_values.self_attention_cache
                if isinstance(past_key_values, EncoderDecoderCache)
                else past_key_values
            )
            main_key_states, main_value_states = curr_past_key_value.update(
                main_key_states, main_value_states, self.layer_idx, {"cache_position": position_ids}
            )

        sequence_length = ngram_sequence_length // (1 + self.ngram)

        # MAIN-STREAM
        main_attn_weights = torch.einsum("bntc,bncs->bnts", main_query_states, main_key_states.transpose(2, 3))

        # retrieve relative position embeddings for each layer -> see paper for more details
        main_relative_pos_embeddings = self.get_main_relative_pos_embeddings(
            main_hidden_states, main_attn_weights, position_ids, main_relative_position_buckets
        )
        main_attn_weights = main_attn_weights + main_relative_pos_embeddings

        if attention_mask is not None:
            main_attn_weights = main_attn_weights + attention_mask

        main_attn_probs = softmax(main_attn_weights, dim=-1, onnx_trace=self.onnx_trace).type_as(main_attn_weights)

        if layer_head_mask is not None:
            assert layer_head_mask.size() == (self.num_attn_heads,), (
                f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
                f" {layer_head_mask.size()}"
            )
            main_attn_probs = layer_head_mask.view(1, -1, 1, 1) * main_attn_probs

        main_attn_probs = nn.functional.dropout(main_attn_probs, p=self.attention_dropout, training=self.training)
        # project to attn_output
        main_attn_output = torch.einsum("bntc,bncs->bnts", main_attn_probs, main_value_states)
        # reshape so that num_heads dim is merged into last `head_dim` axis
        main_attn_output = main_attn_output.transpose(1, 2).reshape(batch_size, sequence_length, hidden_size)
        main_attn_output = self.out_proj(main_attn_output)

        # PREDICT-STREAM
        # [batch_size, ngram, num_heads, sequence_length, head_dim]
        predict_query_states = torch.stack(predict_query_states_list, 1).view(
            batch_size, self.ngram, self.num_attn_heads, sequence_length, self.head_dim
        )
        # [batch_size, ngram, num_heads, 2*sequence_length, head_dim]
        predict_key_states = torch.stack(
            [torch.cat([main_key_states, key], 2) for key in predict_key_states_list], 1
        )
        # [batch_size, sequence_length, ngram, hidden_size]
        predict_hidden_states = torch.stack(hidden_states_predict_list, dim=2)
        # [batch_size, ngram, num_heads, 2*sequence_length, head_dim]
        predict_value_states = torch.cat(
            [torch.cat([main_value_states, v_p], 2).unsqueeze(1) for v_p in predict_value_states_list], 1
        )

        # [batch_size, ngram, num_heads, sequence_length, 2*sequence_length]
        predict_attn_weights = torch.einsum("bnhtc,bnhsc->bnhts", predict_query_states, predict_key_states)

        # retrieve relative position embeddings for each layer -> see paper for more details
        predict_relative_pos_embeddings = self.get_predict_relative_pos_embeddings(
            predict_hidden_states, predict_attn_weights, position_ids, predict_relative_position_buckets
        )
        predict_attn_weights = predict_attn_weights + predict_relative_pos_embeddings

        if extended_predict_attention_mask is not None:
            extended_predict_attention_mask = extended_predict_attention_mask.permute(0, 2, 1, 3, 4)
            extended_predict_attention_mask = extended_predict_attention_mask.to(predict_attn_weights.dtype)
            predict_attn_weights = predict_attn_weights + extended_predict_attention_mask

        predict_attn_probs = softmax(predict_attn_weights, dim=-1, onnx_trace=self.onnx_trace).type_as(
            predict_attn_weights
        )

        if layer_head_mask is not None:
            assert layer_head_mask.size() == (self.num_attn_heads,), (
                f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
                f" {layer_head_mask.size()}"
            )
            predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs

        predict_attn_probs = nn.functional.dropout(
            predict_attn_probs, p=self.attention_dropout, training=self.training
        )
        # project to attention output: [batch_size, ngram, num_heads, sequence_length, head_dim]
        predict_attn_output = torch.einsum("bnhts,bnhsc->bnhtc", predict_attn_probs, predict_value_states)

        # reshape so that num_heads dim is merged into last `head_dim` axis
        predict_attn_output = predict_attn_output.permute(0, 1, 3, 2, 4).reshape(
            batch_size, self.ngram * sequence_length, hidden_size
        )
        predict_attn_output = self.out_proj(predict_attn_output)

        # concat to single attn output: [batch_size, (1+ngram)*sequence_length, hidden_size]
        attn_output = torch.cat([main_attn_output, predict_attn_output], 1).view(batch_size, -1, hidden_size)

        attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)

        return attn_output, main_attn_probs, predict_attn_probs

    def get_main_relative_pos_embeddings(
        self, hidden_states, attn_weights, position_ids, main_relative_position_buckets
    ):
        # input hidden_states [batch_size, sequence_length, hidden_size]
        # input attn_weights [batch_size, num_heads, sequence_length, sequence_length]
        batch_size, num_attn_heads, tgt_len, src_len = attn_weights.shape
        if main_relative_position_buckets is None:
            batch_size, sequence_length = hidden_states.shape[:2]
            relative_positions = (
                torch.arange(1, attn_weights.shape[-1] + 1)
                .unsqueeze(0)
                .unsqueeze(0)
                .repeat(batch_size, sequence_length, 1)
                .to(position_ids.device)
            )
            relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(batch_size, sequence_length, 1)
            main_relative_position_buckets = compute_relative_buckets(
                self.num_buckets, self.relative_max_distance, relative_positions, False
            )

        # [batch_size, sequence_length, num_buckets * num_heads]
        rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)
        rel_pos_embeddings = rel_pos_embeddings.view(
            rel_pos_embeddings.shape[:2] + (self.num_buckets, self.num_attn_heads)
        )
        rel_pos_embeddings = rel_pos_embeddings.permute(0, 3, 1, 2)
        # [batch_size * num_heads * sequence_length, num_buckets]
        rel_pos_embeddings = rel_pos_embeddings.reshape(-1, self.num_buckets)

        main_relative_position_buckets = main_relative_position_buckets.repeat(1, self.num_attn_heads, 1)
        # [batch_size * num_heads * sequence_length, sequence_length]
        main_relative_position_buckets = main_relative_position_buckets.view(
            -1, main_relative_position_buckets.shape[-1]
        ).long()

        main_relative_pos_embeddings = torch.gather(rel_pos_embeddings, dim=1, index=main_relative_position_buckets)
        main_relative_pos_embeddings = main_relative_pos_embeddings.view(batch_size, num_attn_heads, tgt_len, -1)
        return main_relative_pos_embeddings

    def get_predict_relative_pos_embeddings(
        self, hidden_states, attn_weights, position_ids, predict_relative_position_buckets
    ):
        # input hidden_states [batch_size, sequence_length, ngram, hidden_size]
        # input attn_weights [batch_size, ngram, num_heads, sequence_length, 2*sequence_length]
        batch_size, sequence_length = hidden_states.shape[0:2]

        if predict_relative_position_buckets is None:
            key_sequence_length = attn_weights.shape[-1]
            assert position_ids[0][0] == key_sequence_length - 1, (
                "`position_ids` are incorrect. They should be of the format 1 2 3 4 5 ... (key_sequence_length - 1)"
            )
            relative_positions = (
                torch.arange(0, key_sequence_length)
                .unsqueeze(0)
                .unsqueeze(0)
                .repeat(batch_size, sequence_length, 1)
                .to(position_ids.device)
            )
            relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(batch_size, sequence_length, 1)
            predict_relative_position_buckets = compute_relative_buckets(
                self.num_buckets, self.relative_max_distance, relative_positions, False
            )

        # [batch_size, ngram, sequence_length, hidden_size]
        hidden_states = hidden_states.transpose(1, 2)
        rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)

        # [batch_size, ngram, sequence_length, num_buckets, num_heads]
        rel_pos_embeddings = rel_pos_embeddings.view(
            hidden_states.shape[:-1] + (self.num_buckets, self.num_attn_heads)
        )
        rel_pos_embeddings = rel_pos_embeddings.permute(0, 2, 1, 4, 3)
        # [batch_size * ngram * sequence_length * num_heads, num_buckets]
        rel_pos_embeddings = rel_pos_embeddings.reshape(-1, self.num_buckets)
        predict_relative_position_buckets = predict_relative_position_buckets.unsqueeze(0)
        predict_relative_position_buckets = predict_relative_position_buckets.repeat(
            self.ngram, 1, self.num_attn_heads, 1
        )
        # [ngram * batch_size * num_heads * sequence_length, -1]
        predict_relative_position_buckets = predict_relative_position_buckets.view(
            -1, predict_relative_position_buckets.size(-1)
        ).long()

        predict_relative_pos_embeddings = torch.gather(
            rel_pos_embeddings, dim=1, index=predict_relative_position_buckets
        )

        # [batch_size, ngram, num_heads, sequence_length, -1]
        predict_relative_pos_embeddings = predict_relative_pos_embeddings.view(
            batch_size, self.ngram, self.num_attn_heads, sequence_length, -1
        )

        return predict_relative_pos_embeddings
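
# Layout of the (1 + ngram) streams along the time axis (sketch, ngram=2 and
# sequence_length=3, so the concatenated input has length 9): `.chunk(3, dim=1)`
# splits it into [main, 1st predict, 2nd predict].
#
#     >>> x = torch.arange(9).view(1, 9, 1)
#     >>> [c.flatten().tolist() for c in x.chunk(3, dim=1)]
#     [[0, 1, 2], [3, 4, 5], [6, 7, 8]]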

class ProphetNetEncoderLayer(GradientCheckpointingLayer):
    """
    Encoder block for Prophetnet
    """

    def __init__(self, config: ProphetNetConfig):
        super().__init__()
        # 1st residual block
        self.self_attn = ProphetNetAttention(config, config.num_encoder_attention_heads)
        self.self_attn_layer_norm = LayerNorm(config.hidden_size)

        # 2nd residual block
        self.feed_forward = ProphetNetFeedForward(config, config.encoder_ffn_dim)
        self.feed_forward_layer_norm = LayerNorm(config.hidden_size)

    def forward(self, hidden_states, attention_mask, layer_head_mask, output_attentions: bool = False):
        # 1st residual block
        attention_output, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = self.self_attn_layer_norm(attention_output + hidden_states)

        # 2nd residual block
        feed_forward_output = self.feed_forward(hidden_states)
        hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class ProphetNetDecoderLayer(GradientCheckpointingLayer):
    """
    Decoder block for Prophetnet
    """

    def __init__(self, config: ProphetNetConfig, layer_idx: Optional[int] = None):
        super().__init__()
        # 1st residual block
        self.self_attn = ProphetNetNgramSelfAttention(config, layer_idx=layer_idx)
        self.self_attn_layer_norm = LayerNorm(config.hidden_size)

        # 2nd residual block
        if config.add_cross_attention:
            self.cross_attn = ProphetNetAttention(config, config.num_decoder_attention_heads, layer_idx=layer_idx)
            self.cross_attn_layer_norm = LayerNorm(config.hidden_size)

        # 3rd residual block
        self.feed_forward = ProphetNetFeedForward(config, config.decoder_ffn_dim)
        self.feed_forward_layer_norm = LayerNorm(config.hidden_size)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attn_mask=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        extended_predict_attention_mask=None,
        main_relative_position_buckets=None,
        predict_relative_position_buckets=None,
        position_ids=None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = True,
        output_attentions: Optional[bool] = False,
        cache_position: Optional[torch.Tensor] = None,
    ):
        # 1st residual block
        ngram_attention_output, self_attn_weights, self_attn_weights_ngram = self.self_attn(
            hidden_states=hidden_states,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            extended_predict_attention_mask=extended_predict_attention_mask,
            main_relative_position_buckets=main_relative_position_buckets,
            predict_relative_position_buckets=predict_relative_position_buckets,
            position_ids=position_ids,
        )
        hidden_states = self.self_attn_layer_norm(hidden_states + ngram_attention_output)

        cross_attn_weights = None
        if encoder_hidden_states is not None:
            # 2nd residual block
            attention_output, cross_attn_weights = self.cross_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attn_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
            )
            hidden_states = self.cross_attn_layer_norm(attention_output + hidden_states)

        # 3rd residual block
        feed_forward_output = self.feed_forward(hidden_states)
        hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, self_attn_weights_ngram, cross_attn_weights)

        return outputs

@auto_docstring(
    custom_intro="""
    The standalone encoder part of the ProphetNetModel.
    """
)
class ProphetNetEncoder(ProphetNetPreTrainedModel):
    def __init__(self, config: ProphetNetConfig, word_embeddings: nn.Embedding = None):
        r"""
        word_embeddings (`torch.nn.Embeddings` of shape `(config.vocab_size, config.hidden_size)`, *optional*):
            The word embedding parameters. This can be used to initialize [`ProphetNetEncoder`] with pre-defined word
            embeddings instead of randomly initialized word embeddings.
        """
        super().__init__(config)

        self.word_embeddings = (
            word_embeddings
            if word_embeddings is not None
            else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        )
        self.position_embeddings = ProphetNetPositionalEmbeddings(config)
        self.embeddings_layer_norm = LayerNorm(config.hidden_size)

        self.layers = nn.ModuleList([ProphetNetEncoderLayer(config) for _ in range(config.num_encoder_layers)])

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.word_embeddings

    def set_input_embeddings(self, value):
        self.word_embeddings = value

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutput]:
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer, ProphetNetEncoder
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
        >>> model = ProphetNetEncoder.from_pretrained("patrickvonplaten/prophetnet-large-uncased-standalone")
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None and inputs_embeds is None:
            raise ValueError("Either input_ids or inputs_embeds has to be passed.")
        elif input_ids is not None and inputs_embeds is not None:
            raise ValueError("Make sure to only pass input_ids or inputs_embeds.")
        elif input_ids is not None and inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        # prepare attention mask
        if attention_mask is not None:
            extended_attention_mask = (
                1.0 - attention_mask[:, None, None, :].repeat(1, self.config.num_encoder_attention_heads, 1, 1)
            ) * torch.finfo(self.dtype).min
            extended_attention_mask = extended_attention_mask.to(inputs_embeds.dtype)
        else:
            extended_attention_mask = None

        position_embeddings, position_ids = self.position_embeddings(inputs_embeds.shape[:2], inputs_embeds.device)

        hidden_states = inputs_embeds + position_embeddings
        hidden_states = self.embeddings_layer_norm(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.config.dropout, training=self.training)

        encoder_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            assert head_mask.size()[0] == (len(self.layers)), (
                f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
                f" {head_mask.size()[0]}."
            )
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_hidden_states = encoder_hidden_states + (hidden_states,)

            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask=extended_attention_mask,
                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                output_attentions=output_attentions,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_hidden_states = encoder_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_hidden_states, attentions=all_attentions
        )

@auto_docstring(
    custom_intro="""
    The standalone decoder part of the ProphetNetModel.
    """
)
class ProphetNetDecoder(ProphetNetPreTrainedModel):
    def __init__(self, config: ProphetNetConfig, word_embeddings: Optional[nn.Embedding] = None):
        r"""
        word_embeddings (`torch.nn.Embeddings` of shape `(config.vocab_size, config.hidden_size)`, *optional*):
            The word embedding parameters. This can be used to initialize [`ProphetNetDecoder`] with pre-defined word
            embeddings instead of randomly initialized word embeddings.
        """
        super().__init__(config)

        self.ngram = config.ngram
        self.num_buckets = config.num_buckets
        self.relative_max_distance = config.relative_max_distance
        self.dropout = config.dropout
        self.max_target_positions = config.max_position_embeddings

        self.word_embeddings = (
            word_embeddings
            if word_embeddings is not None
            else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        )
        self.position_embeddings = ProphetNetPositionalEmbeddings(config)

        self.ngram_embeddings = nn.Embedding(self.ngram, config.hidden_size, None)
        self.layers = nn.ModuleList(
            [ProphetNetDecoderLayer(config, layer_idx=i) for i in range(config.num_decoder_layers)]
        )
        self.embeddings_layer_norm = LayerNorm(config.hidden_size)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.word_embeddings

    def set_input_embeddings(self, value):
        self.word_embeddings = value

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[tuple[tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
    ) -> Union[tuple, ProphetNetDecoderModelOutput]:
        r"""
        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, ProphetNetDecoder
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
        >>> model = ProphetNetDecoder.from_pretrained("microsoft/prophetnet-large-uncased", add_cross_attention=False)
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None and inputs_embeds is None:
            raise ValueError("Either `decoder_input_ids` or `decoder_inputs_embeds` has to be passed.")
        elif input_ids is not None and inputs_embeds is not None:
            raise ValueError("Make sure to only pass `decoder_input_ids` or `decoder_inputs_embeds`.")
        elif input_ids is not None and inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        batch_size, sequence_length = inputs_embeds.shape[:2]

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        if use_cache and past_key_values is None:
            past_key_values = (
                EncoderDecoderCache(DynamicCache(), DynamicCache())
                if encoder_hidden_states is not None
                else DynamicCache()
            )
        if use_cache and isinstance(past_key_values, tuple):
            logger.warning_once(
                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
                "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
                "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
            )
            past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)

        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
        main_stream_pos_embed, position_ids = self.position_embeddings(
            (batch_size, sequence_length),
            device=inputs_embeds.device,
            past_key_values=past_key_values,
        )

        if past_key_values_length > 0:
            main_relative_position_buckets, predict_relative_position_buckets = None, None
        else:
            main_relative_position_buckets, predict_relative_position_buckets = (
                self.compute_buffered_relative_buckets(position_ids)
            )
        predicting_stream_pos_embed = self.position_embeddings._forward(position_ids + 1)

        # add position embeddings
        hidden_states = inputs_embeds + main_stream_pos_embed

        ngram_embeddings = self.ngram_embeddings.weight

        # prepare attention mask
        if past_key_values_length > 0:
            assert hidden_states.size(1) == 1, (
                "At the moment `use_cache` is only supported for `decoder_input_ids` of length 1"
            )

            ngram_hidden_states = [
                (ngram_embeddings[ngram] + predicting_stream_pos_embed).repeat(batch_size, 1, 1)
                for ngram in range(self.ngram)
            ]
            extended_attention_mask = None
            extended_predict_attention_mask = None
        else:
            ngram_hidden_states = [
                (ngram_embeddings[ngram] + predicting_stream_pos_embed) for ngram in range(self.ngram)
            ]
            extended_attention_mask = self.prepare_attention_mask(hidden_states, attention_mask)
            extended_predict_attention_mask = self.prepare_predict_attention_mask(hidden_states, attention_mask)

        # prepare encoder attention mask
        if encoder_attention_mask is not None:
            extended_encoder_attention_mask = (
                1.0 - encoder_attention_mask[:, None, None, :].repeat(1, self.config.num_decoder_attention_heads, 1, 1)
            ) * torch.finfo(self.dtype).min
            extended_encoder_attention_mask = extended_encoder_attention_mask.to(inputs_embeds.dtype)
        else:
            extended_encoder_attention_mask = None

        hidden_states = torch.cat([hidden_states] + ngram_hidden_states, 1)

        if self.embeddings_layer_norm:
            hidden_states = self.embeddings_layer_norm(hidden_states)

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        # init attentions and hidden_states with empty tuples
        all_main_stream_hidden_states = () if output_hidden_states else None
        all_ngram_stream_hidden_states = () if output_hidden_states and self.config.ngram > 0 else None

        all_main_stream_attns = () if output_attentions else None
        all_ngram_stream_attns = () if output_attentions else None
        all_cross_attns = () if output_attentions and self.config.add_cross_attention else None

        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
            if attn_mask is not None:
                assert attn_mask.size()[0] == (len(self.layers)), (
                    f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                    f" {head_mask.size()[0]}."
                )
        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                # grad cannot be kept because tensor is sliced
                all_main_stream_hidden_states += (hidden_states[:, :sequence_length],)
                if self.config.ngram > 0:
                    all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=extended_attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attn_mask=extended_encoder_attention_mask,
                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
                extended_predict_attention_mask=extended_predict_attention_mask,
                main_relative_position_buckets=main_relative_position_buckets,
                predict_relative_position_buckets=predict_relative_position_buckets,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                output_attentions=output_attentions,
                cache_position=cache_position,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_main_stream_attns += (layer_outputs[1],)
                all_ngram_stream_attns += (layer_outputs[2],)

                if self.config.add_cross_attention:
                    all_cross_attns += (layer_outputs[3],)

        if output_hidden_states:
            all_main_stream_hidden_states += (hidden_states[:, :sequence_length],)
            if self.config.ngram > 0:
                all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],)

        # split last_hidden_state for return
        last_hidden_state = hidden_states[:, :sequence_length]
        last_hidden_state_ngram = hidden_states[:, sequence_length:] if self.config.ngram > 0 else None

        if not return_dict:
            return tuple(
                v
                for v in [
                    last_hidden_state,
                    last_hidden_state_ngram,
                    past_key_values,
                    all_main_stream_hidden_states,
                    all_ngram_stream_hidden_states,
                    all_main_stream_attns,
                    all_ngram_stream_attns,
                    all_cross_attns,
                ]
                if v is not None
            )
        return ProphetNetDecoderModelOutput(
            last_hidden_state=last_hidden_state,
            last_hidden_state_ngram=last_hidden_state_ngram,
            past_key_values=past_key_values,
            hidden_states=all_main_stream_hidden_states,
            hidden_states_ngram=all_ngram_stream_hidden_states,
            attentions=all_main_stream_attns,
            ngram_attentions=all_ngram_stream_attns,
            cross_attentions=all_cross_attns,
        )

    def compute_buffered_relative_buckets(self, position_ids):
        batch_size, sequence_length = position_ids.shape

        position_ids = torch.arange(1, self.max_target_positions).to(position_ids.device).repeat(1, 1)
        main_relative_buckets, predict_relative_buckets = compute_all_stream_relative_buckets(
            self.num_buckets, self.relative_max_distance, position_ids
        )

        # buffer relative buckets
        main_relative_buckets = main_relative_buckets[:, :sequence_length, :sequence_length].repeat(batch_size, 1, 1)
        predict_relative_buckets = torch.cat(
            [
                predict_relative_buckets[:, :sequence_length, :sequence_length],
                predict_relative_buckets[
                    :, :sequence_length, self.max_target_positions : self.max_target_positions + sequence_length
                ],
            ],
            2,
        ).repeat(batch_size, 1, 1)

        return main_relative_buckets, predict_relative_buckets

    def prepare_attention_mask(self, hidden_states, attention_mask):
        batch_size, seq_length = hidden_states.shape[:2]

        # get causal mask
        causal_mask = torch.full(
            (seq_length, seq_length),
            torch.finfo(hidden_states.dtype).min,
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        causal_mask = torch.triu(causal_mask, 1)

        extended_causal_mask = causal_mask[:seq_length, :seq_length][None, None, :, :].expand(
            (batch_size, self.config.num_decoder_attention_heads) + causal_mask.shape
        )

        # add usual attention mask
        if attention_mask is not None:
            extended_attention_mask = (1.0 - attention_mask[:, None, None, :]) * torch.finfo(self.dtype).min
            extended_attention_mask = extended_causal_mask + extended_attention_mask
        else:
            extended_attention_mask = extended_causal_mask
        return extended_attention_mask.to(hidden_states.dtype)

    def prepare_predict_attention_mask(self, hidden_states, attention_mask):
        batch_size, seq_length = hidden_states.shape[:2]

        # get causal mask
        predict_causal_mask = ngram_attention_bias(
            self.max_target_positions, self.ngram, hidden_states.device, hidden_states.dtype
        )
        predict_causal_mask = torch.cat(
            [
                predict_causal_mask[:, :seq_length, :seq_length],
                predict_causal_mask[
                    :, :seq_length, self.max_target_positions : self.max_target_positions + seq_length
                ],
            ],
            dim=-1,
        )
        extended_predict_causal_mask = predict_causal_mask[None, None, :, :, :].expand(
            (batch_size, self.config.num_decoder_attention_heads) + predict_causal_mask.shape
        )

        # add usual attention mask
        if attention_mask is not None:
            extended_attention_mask = (1.0 - attention_mask[:, None, None, None, :]) * torch.finfo(self.dtype).min
            extended_attention_mask = extended_attention_mask.expand(
                (batch_size, self.config.num_decoder_attention_heads, self.ngram, seq_length, seq_length)
            )
            # predicted stream attention_mask should always be 0
            extended_attention_mask = torch.cat(
                [extended_attention_mask, torch.zeros_like(extended_attention_mask)], dim=-1
            )
            extended_predict_attention_mask = extended_predict_causal_mask + extended_attention_mask
        else:
            extended_predict_attention_mask = extended_predict_causal_mask
        return extended_predict_attention_mask.to(hidden_states.dtype)


@auto_docstring
class ProphetNetModel(ProphetNetPreTrainedModel):
    _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight"]

    def __init__(self, config: ProphetNetConfig):
        super().__init__(config)
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        encoder_config.tie_encoder_decoder = False
        self.encoder = ProphetNetEncoder(encoder_config, self.word_embeddings)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.tie_encoder_decoder = False
        self.decoder = ProphetNetDecoder(decoder_config, self.word_embeddings)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.word_embeddings

    def set_input_embeddings(self, value):
        self.word_embeddings = value
        self.encoder.word_embeddings = self.word_embeddings
        self.decoder.word_embeddings = self.word_embeddings

    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.encoder.word_embeddings, self.word_embeddings)
            self._tie_or_clone_weights(self.decoder.word_embeddings, self.word_embeddings)

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[tuple] = None,
        past_key_values: Optional[tuple[tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        decoder_inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
    ) -> Union[tuple, ProphetNetSeq2SeqModelOutput]:
        r"""
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            ProphetNet uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If
            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, ProphetNetModel

        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
        >>> model = ProphetNetModel.from_pretrained("microsoft/prophetnet-large-uncased")

        >>> input_ids = tokenizer(
        ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1
        >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)

        >>> last_hidden_states = outputs.last_hidden_state  # main stream hidden states
        >>> last_hidden_states_ngram = outputs.last_hidden_state_ngram  # predict hidden states
        ```"""
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs
        return ProphetNetSeq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            last_hidden_state_ngram=decoder_outputs.last_hidden_state_ngram,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_ngram_hidden_states=decoder_outputs.hidden_states_ngram,
            decoder_attentions=decoder_outputs.attentions,
            decoder_ngram_attentions=decoder_outputs.ngram_attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

@auto_docstring(
    custom_intro="""
    The ProphetNet Model with a language modeling head. Can be used for sequence generation tasks.
    """
)
class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight", "lm_head.weight"]

    def __init__(self, config: ProphetNetConfig):
        super().__init__(config)
        self.prophetnet = ProphetNetModel(config)
        self.padding_idx = config.pad_token_id
        self.disable_ngram_loss = config.disable_ngram_loss

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.prophetnet.word_embeddings, self.lm_head)

    def get_input_embeddings(self):
        return self.prophetnet.word_embeddings

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[torch.Tensor] = None,
        past_key_values: Optional[tuple[tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        decoder_inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
    ) -> Union[tuple, ProphetNetSeq2SeqLMOutput]:
        r"""
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            ProphetNet uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If
            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
            labels in `[0, ..., config.vocab_size]`

        Example:

        ```python
        >>> from transformers import AutoTokenizer, ProphetNetForConditionalGeneration

        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
        >>> model = ProphetNetForConditionalGeneration.from_pretrained("microsoft/prophetnet-large-uncased")

        >>> input_ids = tokenizer(
        ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1
        >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)

        >>> logits_next_token = outputs.logits  # logits to predict next token as usual
        >>> logits_ngram_next_tokens = outputs.logits_ngram  # logits to predict 2nd, 3rd, ... next tokens
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        outputs = self.prophetnet(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        batch_size, sequence_length = (
            decoder_input_ids.shape if decoder_input_ids is not None else decoder_inputs_embeds.shape[:2]
        )

        predicting_streams = outputs[1].view(batch_size, self.config.ngram, sequence_length, -1)
        predict_logits = self.lm_head(predicting_streams)

        logits = predict_logits[:, 0]
        logits_ngram = predict_logits[:, 1:] if self.config.ngram > 1 else None

        # To use .view in loss computation, make sure that logits is contiguous.
        if not logits.is_contiguous():
            logits = logits.contiguous()

        loss = None
        if labels is not None:
            loss = self._compute_loss(predict_logits, labels)

        if not return_dict:
            all_logits = tuple(v for v in [logits, logits_ngram] if v is not None)
            return (loss,) + all_logits + outputs[2:] if loss is not None else all_logits + outputs[2:]
        else:
            return ProphetNetSeq2SeqLMOutput(
                loss=loss,
                logits=logits,
                logits_ngram=logits_ngram,
                past_key_values=outputs.past_key_values,
                decoder_hidden_states=outputs.decoder_hidden_states,
                decoder_ngram_hidden_states=outputs.decoder_ngram_hidden_states,
                decoder_attentions=outputs.decoder_attentions,
                decoder_ngram_attentions=outputs.decoder_ngram_attentions,
                cross_attentions=outputs.cross_attentions,
                encoder_last_hidden_state=outputs.encoder_last_hidden_state,
                encoder_hidden_states=outputs.encoder_hidden_states,
                encoder_attentions=outputs.encoder_attentions,
            )

    def _compute_loss(self, logits, labels, ignore_index=-100):
        expend_targets = labels.new_zeros(self.config.ngram, labels.size(0), labels.size(1)).fill_(ignore_index)

        for i in range(self.config.ngram):
            if i > 0 and self.disable_ngram_loss:
                break
            expend_targets[i, :, :] = labels

        logits = logits.transpose(0, 1).contiguous()
        lprobs = nn.functional.log_softmax(
            logits.view(-1, logits.size(-1)),
            dim=-1,
            dtype=torch.float32,
        )

        loss = nn.functional.nll_loss(lprobs, expend_targets.view(-1), reduction="mean")

        if self.config.eps > 0.0:
            smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
            non_masked_tokens = expend_targets.ne(ignore_index).view(-1)
            smooth_loss = smooth_loss[non_masked_tokens]
            smooth_loss = smooth_loss.mean()

            eps_i = self.config.eps / lprobs.size(-1)
            loss = (1.0 - self.config.eps) * loss + eps_i * smooth_loss

        return loss

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return self._shift_right(labels)

    def get_encoder(self):
        return self.prophetnet.encoder

    def get_decoder(self):
        return self.prophetnet.decoder

@auto_docstring(
    custom_intro="""
    The standalone decoder part of the ProphetNetModel with a lm head on top. The model can be used for causal
    language modeling.
    """
)
class ProphetNetForCausalLM(ProphetNetPreTrainedModel, GenerationMixin):
    _tied_weights_keys = [
        "prophetnet.word_embeddings.weight",
        "prophetnet.decoder.word_embeddings.weight",
        "lm_head.weight",
    ]

    def __init__(self, config: ProphetNetConfig):
        # set config for CLM and avoid loading encoder weights
        config = copy.deepcopy(config)
        config.is_decoder = True
        config.is_encoder_decoder = False

        super().__init__(config)
        self.prophetnet = ProphetNetDecoderWrapper(config)

        self.padding_idx = config.pad_token_id
        self.disable_ngram_loss = config.disable_ngram_loss

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.prophetnet.decoder.word_embeddings

    def set_input_embeddings(self, value):
        self.prophetnet.decoder.word_embeddings = value

    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.prophetnet.decoder.word_embeddings, self.lm_head)

    def set_decoder(self, decoder):
        self.prophetnet.decoder = decoder

    def get_decoder(self):
        return self.prophetnet.decoder

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[tuple[tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ProphetNetDecoderLMOutput]:
        r"""
        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`

        Example:

        ```python
        >>> from transformers import AutoTokenizer, ProphetNetForCausalLM
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
        >>> model = ProphetNetForCausalLM.from_pretrained("microsoft/prophetnet-large-uncased")
        >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> logits = outputs.logits

        >>> # Model can also be used with EncoderDecoder framework
        >>> from transformers import BertTokenizer, EncoderDecoderModel, AutoTokenizer
        >>> import torch

        >>> tokenizer_enc = BertTokenizer.from_pretrained("google-bert/bert-large-uncased")
        >>> tokenizer_dec = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
        >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained(
        ...     "google-bert/bert-large-uncased", "microsoft/prophetnet-large-uncased"
        ... )

        >>> ARTICLE = (
        ...     "the us state department said wednesday it had received no "
        ...     "formal word from bolivia that it was expelling the us ambassador there "
        ...     "but said the charges made against him are `` baseless ."
        ... )
        >>> input_ids = tokenizer_enc(ARTICLE, return_tensors="pt").input_ids
        >>> labels = tokenizer_dec(
        ...     "us rejects charges against its ambassador in bolivia", return_tensors="pt"
        ... ).input_ids
        >>> outputs = model(input_ids=input_ids, decoder_input_ids=labels[:, :-1], labels=labels[:, 1:])

        >>> loss = outputs.loss
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.prophetnet.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        batch_size, sequence_length = input_ids.shape if input_ids is not None else inputs_embeds.shape[:2]

        predicting_streams = outputs[1].view(batch_size, self.config.ngram, sequence_length, -1)
        predict_logits = self.lm_head(predicting_streams)

        logits = predict_logits[:, 0]
        logits_ngram = predict_logits[:, 1:] if self.config.ngram > 1 else None

        loss = None
        if labels is not None:
            loss = self._compute_loss(predict_logits, labels)

        if not return_dict:
            all_logits = tuple(v for v in [logits, logits_ngram] if v is not None)
            return (loss,) + all_logits + outputs[2:] if loss is not None else all_logits + outputs[2:]
        else:
            return ProphetNetDecoderLMOutput(
                loss=loss,
                logits=logits,
                logits_ngram=logits_ngram,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                hidden_states_ngram=outputs.hidden_states_ngram,
                attentions=outputs.attentions,
                ngram_attentions=outputs.ngram_attentions,
                cross_attentions=outputs.cross_attentions,
            )

    def _compute_loss(self, logits, labels, ignore_index=-100):
        expend_targets = labels.new_zeros(self.config.ngram, labels.size(0), labels.size(1)).fill_(ignore_index)

        for i in range(self.config.ngram):
            if i > 0 and self.disable_ngram_loss:
                break
            expend_targets[i, :, :] = labels

        logits = logits.transpose(0, 1).contiguous()
        lprobs = nn.functional.log_softmax(
            logits.view(-1, logits.size(-1)),
            dim=-1,
            dtype=torch.float32,
        )

        loss = nn.functional.nll_loss(lprobs, expend_targets.view(-1), reduction="mean")

        if self.config.eps > 0.0:
            smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
            non_masked_tokens = expend_targets.ne(ignore_index).view(-1)
            smooth_loss = smooth_loss[non_masked_tokens]
            smooth_loss = smooth_loss.mean()

            eps_i = self.config.eps / lprobs.size(-1)
            loss = (1.0 - self.config.eps) * loss + eps_i * smooth_loss

        return loss

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        use_cache=None,
        **kwargs,
    ):
        # if the model is used as a decoder in an encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)

        if past_key_values is not None and past_key_values.get_seq_length() > 0:
            input_ids = input_ids[:, -1:]
        # first step, decoder_cached_states are empty
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }


class ProphetNetDecoderWrapper(ProphetNetPreTrainedModel):
    """
    This is a wrapper class, so that [`ProphetNetForCausalLM`] can correctly be loaded from pretrained prophetnet
    classes.
    """

    def __init__(self, config: ProphetNetConfig):
        super().__init__(config)

        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.decoder = ProphetNetDecoder(config, word_embeddings=self.word_embeddings)

        # Initialize weights and apply final processing
        self.post_init()

    def _tie_weights(self):
        self._tie_or_clone_weights(self.word_embeddings, self.decoder.get_input_embeddings())

    def forward(self, *args, **kwargs):
        return self.decoder(*args, **kwargs)


__all__ = [
    "ProphetNetDecoder",
    "ProphetNetEncoder",
    "ProphetNetForCausalLM",
    "ProphetNetForConditionalGeneration",
    "ProphetNetModel",
    "ProphetNetPreTrainedModel",
]