"""PyTorch MPT model."""

import math
from typing import Optional, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
from torch.nn import functional as F

from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_mpt import MptConfig


logger = logging.get_logger(__name__)


def build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max=8, device=None):
    r"""
    Link to paper: https://huggingface.co/papers/2108.12409 - Alibi tensor is not causal as the original paper mentions, it
    relies on a translation invariance of softmax for quick implementation. This implementation has been copied from
    the alibi implementation of MPT source code that led to slightly different results than the Bloom alibi:
    https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L292
    """
    alibi = torch.arange(1 - sequence_length, 1, dtype=torch.int32, device=device).view(1, 1, 1, sequence_length)
    num_heads_power_of_2 = 2 ** math.ceil(math.log2(num_heads))

    base = torch.arange(1, num_heads_power_of_2 + 1, dtype=torch.int64, device=device).float()
    base = base * (alibi_bias_max / num_heads_power_of_2)

    slopes = 1.0 / torch.pow(2, base)
    slopes = slopes.view(1, num_heads_power_of_2, 1, 1)

    if num_heads_power_of_2 != num_heads:
        # interleave the slopes so the first `num_heads` entries match the MPT ordering
        slopes = torch.concat([slopes[:, 1::2, ...], slopes[:, ::2, ...]], dim=1)[:, :num_heads, ...]

    alibi = alibi * slopes
    return alibi.squeeze(0)
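
# Illustrative sketch (not part of the upstream module): shapes and slopes produced
# by the function above. With 4 heads (a power of two) and sequence_length=3:
#
#     alibi = build_mpt_alibi_tensor(num_heads=4, sequence_length=3)
#     alibi.shape  # torch.Size([4, 1, 3])
#     # Row h holds the positions [-2, -1, 0] scaled by the head-specific slope
#     # 1 / 2 ** (alibi_bias_max * (h + 1) / 4), i.e. 1/4, 1/16, 1/64, 1/256.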


class MptAttention(nn.Module):
    """Multi-head self attention.
    Using torch or triton attention implementation enables user to also use additive bias.
    """

    def __init__(self, config: MptConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.n_heads = config.n_heads
        self.max_seq_length = config.max_seq_len
        self.head_dim = self.hidden_size // self.n_heads
        self.softmax_scale = config.attn_config.softmax_scale
        if self.softmax_scale is None:
            self.softmax_scale = 1 / math.sqrt(self.hidden_size / self.n_heads)

        self.attn_dropout_p = config.attn_config.attn_pdrop
        self.clip_qkv = config.attn_config.clip_qkv
        self.Wqkv = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
        self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.layer_idx = layer_idx

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_bias: torch.Tensor,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
    ):
        batch_size, seq_length = hidden_states.shape[:2]

        mixed_qkv = self.Wqkv(hidden_states)
        if self.clip_qkv:
            mixed_qkv = mixed_qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)

        query_states, key_states, value_states = mixed_qkv.chunk(3, dim=2)
        query_states = query_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)

        if past_key_values is not None:
            cache_kwargs = {"cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale

        query_length = seq_length if past_key_values is None else seq_length + past_key_values.get_seq_length()

        if position_bias is not None:
            if len(position_bias.shape) != 3:
                raise ValueError(f"Expecting position_bias shape to be 3 dimensions, got {len(position_bias.shape)}")
            key_length = key_states.shape[-2]

            # slice the alibi tensor from the right so it matches the current query/key lengths
            position_bias_query_index = max(0, position_bias.size(1) - query_length)
            position_bias_key_index = max(0, position_bias.size(2) - key_length)

            position_bias = position_bias[:, position_bias_query_index:, position_bias_key_index:]

            attention_scores = attention_scores + position_bias

        if attention_mask is not None:
            attention_scores = attention_scores.masked_fill(attention_mask, torch.finfo(query_states.dtype).min)

        # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.softmax(attention_scores.float(), dim=-1).to(value_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attn_dropout_p, training=self.training)

        context_states = torch.matmul(attn_weights, value_states)
        context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1)
        attn_output = self.out_proj(context_states)

        return attn_output, attn_weights


class MptMLP(nn.Module):
    def __init__(self, config: MptConfig):
        super().__init__()
        hidden_size = config.hidden_size

        self.up_proj = nn.Linear(hidden_size, 4 * hidden_size, bias=False)
        self.act = nn.GELU(approximate="none")
        self.down_proj = nn.Linear(4 * hidden_size, hidden_size, bias=False)
        self.hidden_dropout = config.attn_config.attn_pdrop

    def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        hidden_states = self.act(self.up_proj(hidden_states))

        intermediate_output = self.down_proj(hidden_states)

        output = F.dropout(intermediate_output, p=self.hidden_dropout, training=self.training)
        output = output + residual

        return output
ejejejee	 e
e
eej dddZ  ZS )MptBlockNr6   c                    sz   t    |j}t||jd| _d | j_|j| _t	||| _
t||jd| _d | j_t|| _|jj| _t| j| _d S )Neps)r;   r<   r=   r   layer_norm_epsilonnorm_1r:   r>   r-   r5   attnnorm_2rv   ffnrA   rD   Zdropout_rater   Dropoutresid_attn_dropout)rJ   r7   r8   r=   rK   r2   r3   r<      s    


zMptBlock.__init__F)rS   rT   rU   
layer_past	use_cacheoutput_attentionsrV   c                 C   sV   |  |}|}	| j|||||d\}
}| |
|	 }| |}|}	| ||	}||fS )N)rT   rU   rN   rV   )r   r   r   r   r   )rJ   rS   rT   rU   r   r   r   rV   Zlayernorm_outputr~   Zattn_outputsrm   r   r2   r2   r3   rn      s    


zMptBlock.forward)N)NFFN)ro   rp   rq   r   r   rs   r<   r"   rt   r   boolrn   ru   r2   r2   rK   r3   r      s       r   c                       s   e Zd ZU eed< dZdZdgZdgZ fddZ	e
jdd	d
Zeeddddeeejejf  eeejejf  dddZ  ZS )MptPreTrainedModelr7   transformerTr   z
lm_head.*.c                    s   t  j|i | d S N)r;   r<   )rJ   inputskwargsrK   r2   r3   r<      s    zMptPreTrainedModel.__init__)modulec                 C   s   t |tjr:|jjjd| jjd |jdur|jj	  nnt |tj
rz|jjjd| jjd |jdur|jj|j 	  n.t |tr|jdur|jj	  |jjd dS )zInitialize the weights.g        )meanZstdNr   )
isinstancer   rG   weightdataZnormal_r7   Zinitializer_ranger:   Zzero_	EmbeddingZpadding_idxr   Zfill_)rJ   r   r2   r2   r3   _init_weights   s    



z MptPreTrainedModel._init_weightsrM   rN   rO   rP   )rN   r   c                    s8   | d d j \}}||  t fdd| D S )zw
        Converts the cache to the format expected by Mpt, i.e. to tuple(tuple([batch_size * num_heads, ...]))
        """
        batch_size, num_heads, head_dim, seq_length = past_key_values[0][0].shape
        batch_size_times_num_heads = batch_size * num_heads
        # key:   [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
        return tuple(
            (
                layer_past[0].reshape(batch_size_times_num_heads, head_dim, seq_length),
                layer_past[1].reshape(batch_size_times_num_heads, seq_length, head_dim),
            )
            for layer_past in past_key_values
        )


@auto_docstring
class MptModel(MptPreTrainedModel):
    def __init__(self, config: MptConfig):
        super().__init__(config)

        self.hidden_size = config.hidden_size
        self.num_heads = config.n_heads

        # Embedding + LN Embedding
        self.wte = nn.Embedding(config.vocab_size, self.hidden_size)

        # Transformer blocks
        self.blocks = nn.ModuleList([MptBlock(config, layer_idx=i) for i in range(config.n_layers)])

        # Final Layer Norm
        self.norm_f = LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
        # backward compatibility with weights on the Hub
        self.norm_f.bias = None

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.wte

    def build_mpt_alibi_tensor(self, num_heads, sequence_length, alibi_bias_max=8, device=None):
        return build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max, device)

    def set_input_embeddings(self, new_embeddings: torch.Tensor):
        self.wte = new_embeddings

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[tuple[tuple[torch.Tensor, torch.Tensor], ...], Cache]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
            (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if use_cache and isinstance(past_key_values, tuple):
            logger.warning_once(
                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
                "You should pass an instance of `DynamicCache` instead, e.g. "
                "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
            )
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)

        hidden_states = inputs_embeds

        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
        seq_length_with_past = seq_length + past_key_values_length
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
        else:
            attention_mask = attention_mask.to(hidden_states.device)

        alibi = self.build_mpt_alibi_tensor(self.num_heads, self.config.max_seq_len, device=hidden_states.device)

        causal_mask = _prepare_4d_causal_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )
        causal_mask = causal_mask.bool()

        for block in self.blocks:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            outputs = block(
                hidden_states,
                layer_past=past_key_values,
                attention_mask=causal_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
                position_bias=alibi,
                cache_position=cache_position,
            )

            hidden_states = outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[1],)

        # Add last hidden state
        hidden_states = self.norm_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


@auto_docstring(
    custom_intro="""
    The MPT Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """
)
class MptForCausalLM(MptPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: MptConfig):
        super().__init__(config)
        self.transformer = MptModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def set_output_embeddings(self, new_embeddings: torch.Tensor):
        self.lm_head = new_embeddings

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[tuple[tuple[torch.Tensor, torch.Tensor], ...], Cache]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
            (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        N)rN   rU   r   r   r   r   r   rV   r   r   r   losslogitsrN   rS   r   )r7   r   r   r   rh   r   Zloss_functionr   r   rN   rS   r   )rJ   r   rN   rU   r   r   r   r   r   r   rV   r   transformer_outputsrS   Z	lm_logitsr   r   r2   r2   r3   rn     sF     
zMptForCausalLM.forward)
NNNNNNNNNN)ro   rp   rq   Z_tied_weights_keysr   r<   r"   rt   r   r   r   r   r   r   r   r   rn   ru   r2   r2   rK   r3   r     s6             r   a  
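
# Illustrative usage sketch (not part of the upstream module; assumes the
# "mosaicml/mpt-7b" checkpoint and its tokenizer are available on the Hub):
#
#     from transformers import AutoTokenizer, MptForCausalLM
#
#     tokenizer = AutoTokenizer.from_pretrained("mosaicml/mpt-7b")
#     model = MptForCausalLM.from_pretrained("mosaicml/mpt-7b")
#     inputs = tokenizer("MPT uses ALiBi instead of", return_tensors="pt")
#     output_ids = model.generate(**inputs, max_new_tokens=16)
#     print(tokenizer.decode(output_ids[0], skip_special_tokens=True))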


@auto_docstring(
    custom_intro="""
    The MPT Model transformer with a sequence classification head on top (linear layer).

    [`MptForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-1) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """
)
class MptForSequenceClassification(MptPreTrainedModel):
    def __init__(self, config: MptConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = MptModel(config)
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[tuple[tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
            (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            last_non_pad_token = -1
        elif input_ids is not None:
            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            last_non_pad_token = -1
            logger.warning_once(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
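
# Illustrative sketch (not part of the upstream module): the last-non-pad pooling
# used above, with a hypothetical pad_token_id of 0 and one right-padded row:
#
#     input_ids = torch.tensor([[5, 6, 0, 0]])
#     non_pad_mask = (input_ids != 0).int()        # tensor([[1, 1, 0, 0]])
#     token_indices = torch.arange(4)              # tensor([0, 1, 2, 3])
#     (token_indices * non_pad_mask).argmax(-1)    # tensor([1]) -> pool position 1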
ej
f df  eej
 eej
 eej
 ee ee ee ee ee	ej
 ef d
ddZ  ZS )
MptForTokenClassificationrw   c                    s   t  | |j| _t|| _t|dr:|jd ur:|j}n t|drV|jd urV|j}nd}t	|| _
t|j|j| _|   d S )Nclassifier_dropoutr}   g?)r;   r<   r   r   r   hasattrr   r}   r   r   ri   rG   r=   
classifierr   )rJ   r7   r   rK   r2   r3   r<     s    
z"MptForTokenClassification.__init__N.r   c
              
   K   s   |	dur|	n| j j}	| j||||||||	d}|d }| |}| |}d}|dur||j}|j\}}t }||	|| | j
|	|| }|	s|f|dd  }|dur|f| S |S t|||j|jdS )r   Nr   r   r   )r   r   rS   r   )r7   r   r   ri   r   rh   r   r^   r   r%   r   r   rS   r   )rJ   r   rN   rU   r   r   r   r   r   r   Zdeprecated_argumentsr   rS   r   r   rk   rl   r   r   r2   r2   r3   rn     s>    


z!MptForTokenClassification.forward)	NNNNNNNNN)ro   rp   rq   r   r<   r   r   r"   r   r   rt   r   r   r   rn   ru   r2   r2   rK   r3   r     s.            r   c                       sr   e Zd Z fddZedeej eej eej eej eej ee	 ee	 ee	 e
eef d	ddZ  ZS )MptForQuestionAnsweringc                    s2   t  | t|| _t|jd| _|   d S )Nr   )	r;   r<   r   r   r   rG   r=   
qa_outputsr   r   rK   r2   r3   r<     s    
z MptForQuestionAnswering.__init__N)	r   rU   r   start_positionsend_positionsr   r   r   r   c	                 C   sF  |dur|n| j j}| j||||||d}	|	d }
| |
}|jddd\}}|d }|d }d}|dur|durt| dkr|d}t| dkr|d}|d}|	d|}|	d|}t
|d}|||}|||}|| d }|s0||f|	dd  }|dur,|f| S |S t||||	j|	jd	S )
r   N)rU   r   r   r   r   r   r   rY   r    )Zignore_indexr   )r   start_logits
end_logitsrS   r   )r7   r   r   r   splitr,   rj   re   rg   r_   r   r   rS   r   )rJ   r   rU   r   r   r   r   r   r   r   Zsequence_outputr   r   r   Z
total_lossZignored_indexr   Z
start_lossZend_lossr   r2   r2   r3   rn     sJ    	






zMptForQuestionAnswering.forward)NNNNNNNN)ro   rp   rq   r<   r   r   r"   r   ZFloatTensorr   r   r   r   rn   ru   r2   r2   rK   r3   r     s*           

__all__ = [
    "MptForCausalLM",
    "MptModel",
    "MptPreTrainedModel",
    "MptForSequenceClassification",
    "MptForTokenClassification",
    "MptForQuestionAnswering",
]