a
    hk                     @   s  d Z ddlZddlmZmZ ddlZddlmZ ddlZddlm	Z	m
Z
mZ ddlmZ ddlmZmZ ddlmZ dd	lmZ dd
lmZmZmZmZ ddlmZ ddlmZ ddlmZm Z m!Z!m"Z" ddl#m$Z$ ddl%m&Z&m'Z'm(Z( ddl)m*Z* ddl+m,Z, e! rddl-m.Z.m/Z/ G dd de*Z0G dd de(Z1G dd de&Z2G dd de'Z3e G dd deZ4e G dd  d e4Z5e d!d"G d#d$ d$e4eZ6e G d%d& d&e4Z7e d'd"G d(d) d)e4Z8g d*Z9dS )+zPyTorch BioGPT model.    N)OptionalUnion)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)CacheDynamicCache)GenerationMixin)AttentionMaskConverter))BaseModelOutputWithPastAndCrossAttentions!CausalLMOutputWithCrossAttentions SequenceClassifierOutputWithPastTokenClassifierOutput)PreTrainedModel)Unpack)TransformersKwargsauto_docstringis_torch_flex_attn_availablelogger)deprecate_kwarg   )BartAttentionBartDecoderLayerBartScaledWordEmbedding)OPTLearnedPositionalEmbedding   )BioGptConfig)	BlockMaskmake_flex_block_causal_maskc                       s0   e Zd Zdejeeej d fddZ  ZS ) BioGptLearnedPositionalEmbeddingr   N)attention_maskpast_key_values_lengthposition_idsc                    s   t  ||| dS )z3`input_ids_shape` is expected to be [bsz x seqlen].N)superforward)selfr"   r#   r$   	__class__ e/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/biogpt/modular_biogpt.pyr&   <   s    z(BioGptLearnedPositionalEmbedding.forward)r   N)	__name__
__module____qualname__torch
LongTensorintr   r&   __classcell__r*   r*   r(   r+   r!   ;   s     r!   c                   @   s   e Zd ZdS )BioGptScaledWordEmbeddingNr,   r-   r.   r*   r*   r*   r+   r3   F   s   r3   c                   @   s   e Zd ZdS )BioGptAttentionNr4   r*   r*   r*   r+   r5   J   s   r5   c                       s   e Zd Zdeee d fddZedddddej	eej	 eej	 ee
 ee ee eej eej	 ee eejeeejejf  f d
ddZ  ZS )BioGptDecoderLayerN)config	layer_idxc              	      sv   t  | |j| _t| j|j|jdd||d| _|j| _	t
|j | _t| j|j| _t|j| j| _| `| `d S )NT)	embed_dimZ	num_headsdropoutZ
is_decoderZ	is_causalr7   r8   )r%   __init__hidden_sizer9   r5   Znum_attention_headsZattention_probs_dropout_prob	self_attnhidden_dropout_probr:   r   Z
hidden_actactivation_fnnnLinearZintermediate_sizefc1fc2Zencoder_attnZencoder_attn_layer_norm)r'   r7   r8   r(   r*   r+   r;   O   s"    	zBioGptDecoderLayer.__init__Zpast_key_valuepast_key_valuesz4.58)new_nameversionFT)
hidden_statesr"   layer_head_maskrD   output_attentions	use_cacher$   cache_positionkwargsreturnc	              
   K   s   |}
|  |}| jf |||||||d|	\}}tjj|| j| jd}|
| }|}
| |}| |}| |}tjj|| j	| jd}| 
|}tjj|| j| jd}|
| }|f}|r||f7 }|S )a  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
                cache in the correct position and to infer the complete sequence length.
        )rG   rD   r"   rH   rI   r$   rK   ptraining)Zself_attn_layer_normr=   r@   
functionalr:   rP   Zfinal_layer_normrB   r?   Zactivation_dropoutrC   )r'   rG   r"   rH   rD   rI   rJ   r$   rK   rL   ZresidualZself_attn_weightsoutputsr*   r*   r+   r&   e   s6    







zBioGptDecoderLayer.forward)N)NNNFTNN)r,   r-   r.   r   r   r1   r;   r   r/   Tensorr	   boolr0   r   r   tupleFloatTensorr&   r2   r*   r*   r(   r+   r6   N   s*          r6   c                   @   sv   e Zd ZU eed< dZdZdZdZdZ	dZ
eeejdf  ejejedddZeejeeejejedd	d
ZdS )BioGptPreTrainedModelr7   biogptTr   )r"   input_tensorrK   rD   c                 C   sf  | j jdkrRt|tjr"t|}n,|d u rNttj|jd |jd f|jd}|S | j jdkrz|d urv|dk	 rv|S d S |d ur|
 nd}|d ur|jnd}| j jdkr|stj|||| jd	rd S |j}|jd }|r| }	n"t|tjr|jd
 n
|| d }	| j|||	|||jd d}
| j jdkrb|d urb|jjdv rbt|j}t|
|}
|
S )NZflex_attentionr   r   )sizedeviceZflash_attention_2g        FZsdpa)inputs_embedsr#   Zis_training)sequence_lengthtarget_lengthdtyperK   
batch_size)cudaZxpuZnpu)r7   Z_attn_implementation
isinstancer/   rS   r    onesshaper[   anyget_seq_lengthZis_compileabler   Z_ignore_causal_mask_sdparP   r`   Zget_max_cache_shape5_prepare_4d_causal_attention_mask_with_cache_positiontypefinfominZ_unmask_unattended)r'   r"   rY   rK   rD   Zpast_seen_tokensZusing_compilable_cacher`   r^   r_   causal_mask	min_dtyper*   r*   r+   _update_causal_mask   sd    







z)BioGptPreTrainedModel._update_causal_mask)r"   r^   r_   r`   rK   ra   c                 K   sF  | dur|   dkr| }n&t|j}tj||f|||jd}|dkrVtj|dd}|tj||jd|ddk9 }|ddddddf 	|ddd}| durB|
 }| jd }	|ddddddd|	f | ddddddf |j }
|
dk}
|ddddddd|	f |
||ddddddd|	f< |S )	aM  
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
        N   )Z
fill_valuer`   r[   r   )Zdiagonalr[   r]   r   )dimr/   rj   rk   fullr[   ZtriuarangeZreshapeexpandclonere   toZmasked_fill)r"   r^   r_   r`   rK   ra   rL   rl   rm   Zmask_lengthZpadding_maskr*   r*   r+   rh      s*     $

6  zKBioGptPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_positionN)r,   r-   r.   r   __annotations__Zbase_model_prefixZsupports_gradient_checkpointingZ_supports_flash_attnZ_supports_sdpaZ_supports_flex_attnZ_can_compile_fullgraphr   r   r/   rS   r	   rn   staticmethodr1   r`   rh   r*   r*   r*   r+   rW      s(   
LrW   c                       s   e Zd Zed fddZedeej eej	 eej	 eej	 ee
e
ej   ee eej ee ee ee eej ee ee
ef dddZ  ZS )	BioGptModelr7   c                    s   t     | _ j| _ j| _ j| _ j| _	 j
rDt jnd}t j| j| j	|d| _t j| j| _t fddt jD | _t| j| _d| _|   d S )Ng      ?)embed_scalec                    s   g | ]}t  |d qS ))r8   )r6   ).0irz   r*   r+   
<listcomp>I      z(BioGptModel.__init__.<locals>.<listcomp>F)r%   r;   r7   	layerdropr>   r:   r<   r9   pad_token_idZpadding_idxZscale_embeddingmathsqrtr3   
vocab_sizeembed_tokensr!   Zmax_position_embeddingsembed_positionsr@   Z
ModuleListrangeZnum_hidden_layerslayersZ	LayerNorm
layer_normgradient_checkpointing	post_init)r'   r7   r{   r(   rz   r+   r;   ;  s     zBioGptModel.__init__N)	input_idsr"   	head_maskr\   rD   rJ   r$   rI   output_hidden_statesreturn_dictrK   rL   rM   c                 K   sL  |d ur|n| j j}|	d ur |	n| j j}	|d ur4|n| j j}|
d urH|
n| j j}
|d u |d uA rjtdn\|d ur|}|j}|d|d }n8|d ur| d d }|d d d d df }ntd|d u r| 	|}| j
r| jr|rtd d}|r|d u rt| j d}|r8t|tr8td t|}| d d \}}|d ur^| nd}|d u rtj||| |jd	}|d u r|| }tj|||jd	}|}| ||||}|d u rtj|d
d}|| d
  }|d d |d f }| j|||d}|| }tjj|| j| jd}| j
rH| jrH|rHtd d}|	rRdnd }|r`dnd }d }t| jD ]\}}|	r||f7 }| jrt g }|| j!k rqr||f||d ur|| nd |||||d|}|d }|rr||d
 f7 }qr|	r||f7 }| "|}|
s:tdd |||||fD S t#|||||dS )NzTYou cannot specify both decoder_input_ids and decoder_inputs_embeds at the same timer]   zEYou have to specify either decoder_input_ids or decoder_inputs_embedsz[`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`...Frz   zPassing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. You should pass an instance of `DynamicCache` instead, e.g. `past_key_values=DynamicCache.from_legacy_cache(past_key_values)`.r   rp   r   )rq   )r$   rN   zZ`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...r*   )r"   rH   rD   rI   rJ   r$   rK   c                 s   s   | ]}|d ur|V  qd S Nr*   )r|   vr*   r*   r+   	<genexpr>  s   z&BioGptModel.forward.<locals>.<genexpr>)Zlast_hidden_staterD   rG   
attentionscross_attentions)$r7   rI   r   rJ   use_return_dict
ValueErrorre   viewrZ   r   r   rP   r   warning_oncer
   rc   rU   Zfrom_legacy_cacherg   r/   rs   r[   rd   rn   Zcumsumlongr   r@   rQ   r:   	enumerater   Zrandr   r   r   )r'   r   r"   r   r\   rD   rJ   r$   rI   r   r   rK   rL   inputZinput_shapera   Z
seq_lengthr#   Zmask_seq_lengthZself_attn_cacherl   Z	positionsrG   Zall_hidden_statesZall_self_attnsZall_cross_attentionsidxZdecoder_layerZdropout_probabilityZlayer_outputsr*   r*   r+   r&   P  s    







	

zBioGptModel.forward)NNNNNNNNNNN)r,   r-   r.   r   r;   r   r   r/   r0   rV   rU   rS   rT   r   r   r   r   r&   r2   r*   r*   r(   r+   ry   9  s8              
ry   zR
    BioGPT Model with a `language modeling` head on top for CLM fine-tuning.
    )Zcustom_introc                       s   e Zd ZdgZ fddZdd Zdd Zedee	j
 ee	j ee	j ee	j eeee	j   ee	j
 ee ee	j
 ee ee ee ee	j ee eeef d	d
dZ  ZS )BioGptForCausalLMzoutput_projection.weightc                    s8   t  | t|| _tj|j|jdd| _| 	  d S NF)Zbias)
r%   r;   ry   rX   r@   rA   r<   r   output_projectionr   r'   r7   r(   r*   r+   r;     s    
zBioGptForCausalLM.__init__c                 C   s   | j S r   r   r'   r*   r*   r+   get_output_embeddings  s    z'BioGptForCausalLM.get_output_embeddingsc                 C   s
   || _ d S r   r   )r'   Znew_embeddingsr*   r*   r+   set_output_embeddings  s    z'BioGptForCausalLM.set_output_embeddingsN)r   r"   r   r\   rD   labelsrJ   r$   rI   r   r   rK   rL   rM   c                 K   s   |dur|n| j j}| j|f|||||||	|
||d
|}|d }| |}d}|durv| j||fd| j ji|}|s|f|dd  }|dur|f| S |S t|||j|j|j	|j
dS )a  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        N)
r"   r   r\   rD   rJ   r$   rI   r   r   rK   r   r   r   )losslogitsrD   rG   r   r   )r7   r   rX   r   Zloss_functionr   r   rD   rG   r   r   )r'   r   r"   r   r\   rD   r   rJ   r$   rI   r   r   rK   rL   rR   Zsequence_outputZprediction_scoresZlm_lossoutputr*   r*   r+   r&     sP    
zBioGptForCausalLM.forward)NNNNNNNNNNNN)r,   r-   r.   Z_tied_weights_keysr;   r   r   r   r   r/   r0   rV   rU   rS   rT   r   r   r   r   r&   r2   r*   r*   r(   r+   r     sB   	            
r   c                       s   e Zd Z fddZedeej eej eej eej ee	e	ej
   eej eej ee eej ee ee ee eej
 ee	ef dddZ  ZS )BioGptForTokenClassificationc                    sj   t  | |j| _t|| _t|dr:|jd ur:|j}n|j}t	|| _
t|j|j| _|   d S )Nclassifier_dropout)r%   r;   
num_labelsry   rX   hasattrr   r>   r@   ZDropoutr:   rA   r<   
classifierr   )r'   r7   r   r(   r*   r+   r;   ?  s    
z%BioGptForTokenClassification.__init__N)r   token_type_idsr"   r   rD   r\   r   rJ   r$   rI   r   r   rK   rM   c                 C   s  |dur|n| j j}| j|||||||	|
|||d}|d }| |}| |}d}|durt }|dur|ddk}|d| j}t	||dt
|j|}|||}n||d| j|d}|s|f|dd  }|dur|f| S |S t|||j|jdS )  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        N
rD   r"   r   r\   rJ   r$   rI   r   r   rK   r   r]   r   r   )r   r   rG   r   )r7   r   rX   r:   r   r   r   r   r/   whereZtensorZignore_indexZtype_asr   rG   r   )r'   r   r   r"   r   rD   r\   r   rJ   r$   rI   r   r   rK   transformer_outputsrG   r   r   loss_fctZactive_lossZactive_logitsZactive_labelsr   r*   r*   r+   r&   M  sJ    

z$BioGptForTokenClassification.forward)NNNNNNNNNNNNN)r,   r-   r.   r;   r   r   r/   r0   rV   rU   rS   rT   r   r   r&   r2   r*   r*   r(   r+   r   =  s>                
r   a  
    The BioGpt Model transformer with a sequence classification head on top (linear layer).

    [`BioGptForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it is required to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    c                       s   e Zd Zed fddZedeej eej	 eej	 ee
e
ej   eej	 eej ee eej ee ee ee eej ee
ef dddZdd	 Zd
d Z  ZS )BioGptForSequenceClassificationrz   c                    s@   t  | |j| _t|| _tj|j| jdd| _| 	  d S r   )
r%   r;   r   ry   rX   r@   rA   r<   scorer   r   r(   r*   r+   r;     s
    
z(BioGptForSequenceClassification.__init__N)r   r"   r   rD   r\   r   rJ   r$   rI   r   r   rK   rM   c                 C   s$  |dur|n| j j}| j||||||||	|
||d}|d }| |}|durb|jdd \}}n|jdd \}}| j jdu rd}nD|durt|| j jdd 	|j
}nd}t| jj d |tj||j
d|f }d}|dur| j jdu rN| jdkrd	| j _n:| jdkrF|jtjks<|jtjkrFd
| j _nd| j _| j jd	krt }| jdkr|| | }n
|||}nN| j jd
krt }||d| j|d}n| j jdkrt }|||}|s|f|dd  }|dur|f| S |S t|||j|j|jdS )r   Nr   r   r   r]   r   z will not detect padding tokens in `inputs_embeds`. Results may be unexpected if using padding tokens in conjunction with `inputs_embeds.`rp   Z
regressionZsingle_label_classificationZmulti_label_classification)r   r   rD   rG   r   )r7   r   rX   r   re   r   r/   nesumrv   r[   r   r   r)   r,   rs   Zproblem_typer   r`   r   r1   r   Zsqueezer   r   r   r   rD   rG   r   )r'   r   r"   r   rD   r\   r   rJ   r$   rI   r   r   rK   r   rG   r   ra   r^   Zpooled_logitsr   r   r   r*   r*   r+   r&     sr    
$

(

z'BioGptForSequenceClassification.forwardc                 C   s   | j jS r   rX   r   r   r*   r*   r+   get_input_embeddings  s    z4BioGptForSequenceClassification.get_input_embeddingsc                 C   s   || j _d S r   r   )r'   valuer*   r*   r+   set_input_embeddings
  s    z4BioGptForSequenceClassification.set_input_embeddings)NNNNNNNNNNNN)r,   r-   r.   r   r;   r   r   r/   r0   rV   rU   rS   rT   r   r   r&   r   r   r2   r*   r*   r(   r+   r     s>   	            
\r   )r   r   r   ry   rW   ):__doc__r   typingr   r   r/   Ztorch.nnr@   Ztorch.utils.checkpointr   r   r   Zactivationsr   Zcache_utilsr	   r
   Z
generationr   Zmodeling_attn_mask_utilsr   Zmodeling_outputsr   r   r   r   Zmodeling_utilsr   Zprocessing_utilsr   utilsr   r   r   r   Zutils.deprecationr   Zbart.modeling_bartr   r   r   Zopt.modeling_optr   Zconfiguration_biogptr   Zintegrations.flex_attentionr   r    r!   r3   r5   r6   rW   ry   r   r   r   __all__r*   r*   r*   r+   <module>   sR   Z  +TTn