import math
from typing import Callable, Optional, Union

import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, is_torch_flex_attn_available, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_biogpt import BioGptConfig


if is_torch_flex_attn_available():
    from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask


logger = logging.get_logger(__name__)


class BioGptLearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # BioGPT offsets the embedding ids by 2 so the first slots stay reserved for padding.
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(
        self,
        attention_mask: torch.LongTensor,
        past_key_values_length: int = 0,
        position_ids: Optional[torch.LongTensor] = None,
    ):
        """`input_ids_shape` is expected to be [bsz x seqlen]."""
        if position_ids is None:
            # create position ids from the attention mask so that padded tokens keep a constant position
            attention_mask = attention_mask.long()
            position_ids = torch.cumsum(attention_mask, dim=1)
            position_ids = (position_ids * attention_mask - 1).long()
            # cut positions if `past_key_values_length` is > 0
            position_ids = position_ids[:, past_key_values_length:]

        return super().forward(position_ids + self.offset)


class BioGptScaledWordEmbedding(nn.Embedding):
    """
    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor):
        return super().forward(input_ids) * self.embed_scale


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    dropout: float = 0.0,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
):
    if scaling is None:
        scaling = query.size(-1) ** -0.5

    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)

    if head_mask is not None:
        attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)

    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


class BioGptAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[BioGptConfig] = None,
        layer_idx: Optional[int] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal
        self.layer_idx = layer_idx
        if layer_idx is None and self.is_decoder:
            logger.warning_once(
                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended "
                "and will lead to errors during the forward call, if caching is used. Please make sure to provide a "
                "`layer_idx` when creating this class."
            )

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided, this layer is used as a cross-attention layer for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len = hidden_states.shape[:-1]
        src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
        q_input_shape = (bsz, tgt_len, -1, self.head_dim)
        kv_input_shape = (bsz, src_len, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)

        is_updated = False
        if past_key_values is not None:
            if isinstance(past_key_values, EncoderDecoderCache):
                is_updated = past_key_values.is_updated.get(self.layer_idx)
                # cross-attention and self-attention states live in separate sub-caches
                curr_past_key_value = (
                    past_key_values.cross_attention_cache
                    if is_cross_attention
                    else past_key_values.self_attention_cache
                )
            else:
                curr_past_key_value = past_key_values

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_values is not None and is_updated:
            # after the first generated token, re-use the cached cross-attention key/value states
            key_states = curr_past_key_value.layers[self.layer_idx].keys
            value_states = curr_past_key_value.layers[self.layer_idx].values
        else:
            key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
            value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)

            if past_key_values is not None:
                # save all key/value states to the cache to be re-used for fast auto-regressive generation
                cache_position = cache_position if not is_cross_attention else None
                key_states, value_states = curr_past_key_value.update(
                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                )
                # mark the cross-attention cache as filled so it can be re-used in subsequent calls
                if is_cross_attention:
                    past_key_values.is_updated[self.layer_idx] = True

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout,
            scaling=self.scaling,
            head_mask=layer_head_mask,
            **kwargs,
        )

        attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights


class BioGptDecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: BioGptConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.embed_dim = config.hidden_size

        self.self_attn = BioGptAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            dropout=config.attention_probs_dropout_prob,
            is_decoder=True,
            is_causal=True,
            config=config,
            layer_idx=layer_idx,
        )
        self.dropout = config.hidden_dropout_prob
        self.activation_fn = ACT2FN[config.hidden_act]
        self.activation_dropout = config.activation_dropout

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
        position_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
                cache in the correct position and to infer the complete sequence length.
        """
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Self Attention (position information is already folded into the embeddings; `position_ids`
        # is accepted for interface compatibility)
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs


@auto_docstring
class BioGptPreTrainedModel(PreTrainedModel):
    config: BioGptConfig
    base_model_prefix = "biogpt"
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = True

    def _update_causal_mask(
        self,
        attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
    ):
        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask)
            elif attention_mask is None:
                attention_mask = make_flex_block_causal_mask(
                    torch.ones((input_tensor.shape[0], input_tensor.shape[1]), device=input_tensor.device)
                )
            return attention_mask

        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None

        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False

        # When possible, rely on SDPA's `is_causal` argument instead of materializing a 4D mask.
        if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype = input_tensor.dtype
        sequence_length = input_tensor.shape[1]
        if using_compilable_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention_mask` is 2D, generate the causal 4D mask here.
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
        ):
            # Attend to all tokens in fully masked rows (e.g. the first rows when using left padding),
            # as required by the memory-efficient SDPA path.
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # The mask already comes in the inverted 4D form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask


@auto_docstring
class BioGptModel(BioGptPreTrainedModel):
    def __init__(self, config: BioGptConfig):
        super().__init__(config)
        self.layerdrop = config.layerdrop
        self.dropout = config.hidden_dropout_prob
        self.embed_dim = config.hidden_size
        self.padding_idx = config.pad_token_id
        embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0

        self.embed_tokens = BioGptScaledWordEmbedding(
            config.vocab_size, self.embed_dim, self.padding_idx, embed_scale=embed_scale
        )
        self.embed_positions = BioGptLearnedPositionalEmbedding(config.max_position_embeddings, self.embed_dim)

        self.layers = nn.ModuleList(
            [BioGptDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]
        )
        self.layer_norm = nn.LayerNorm(self.embed_dim)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor]]]] = None,
        use_cache: Optional[bool] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if use_cache and isinstance(past_key_values, tuple):
            logger.warning_once(
                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
                "You should pass an instance of `DynamicCache` instead, e.g. "
                "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
            )
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)

        batch_size, seq_length = inputs_embeds.size()[:-1]
        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0

        if cache_position is None:
            cache_position = torch.arange(
                past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
            )

        if attention_mask is None:
            mask_seq_length = past_key_values_length + seq_length
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)

        self_attn_cache = (
            past_key_values.self_attention_cache
            if isinstance(past_key_values, EncoderDecoderCache)
            else past_key_values
        )
        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, self_attn_cache)

        positions = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids)

        hidden_states = inputs_embeds + positions
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # LayerDrop (see https://arxiv.org/abs/1909.11556): randomly skip layers during training
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                position_ids=position_ids,
                cache_position=cache_position,
                **kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        hidden_states = self.layer_norm(hidden_states)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
                if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )


@auto_docstring(
    custom_intro="""
    BioGPT Model with a `language modeling` head on top for CLM fine-tuning.
    """
)
class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["output_projection.weight"]

    def __init__(self, config):
        super().__init__(config)

        self.biogpt = BioGptModel(config)
        self.output_projection = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.output_projection

    def set_output_embeddings(self, new_embeddings):
        self.output_projection = new_embeddings

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor]]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.biogpt(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        sequence_output = outputs[0]
        prediction_scores = self.output_projection(sequence_output)

        lm_loss = None
        if labels is not None:
            lm_loss = self.loss_function(
                prediction_scores, labels, vocab_size=self.config.vocab_size, **kwargs
            )

        if not return_dict:
            output = (prediction_scores,) + outputs[1:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )


@auto_docstring
class BioGptForTokenClassification(BioGptPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.biogpt = BioGptModel(config)
        if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
            classifier_dropout = config.classifier_dropout
        else:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[tuple[tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.biogpt(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = transformer_outputs[0]
        hidden_states = self.dropout(hidden_states)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep the active (non-padded) positions in the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
                )
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + transformer_outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


@auto_docstring(
    custom_intro="""
    The BioGpt Model transformer with a sequence classification head on top (linear layer).

    [`BioGptForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it is required to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """
)
class BioGptForSequenceClassification(BioGptPreTrainedModel):
    def __init__(self, config: BioGptConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.biogpt = BioGptModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[tuple[tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.biogpt(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size, sequence_length = input_ids.shape[:2]
        else:
            batch_size, sequence_length = inputs_embeds.shape[:2]

        # Pool the logits of the last non-padding token of each sequence
        if self.config.pad_token_id is None:
            sequence_length = -1
        else:
            if input_ids is not None:
                sequence_length = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
            else:
                sequence_length = -1
                logger.warning_once(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_length]

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def get_input_embeddings(self):
        return self.biogpt.embed_tokens

    def set_input_embeddings(self, value):
        self.biogpt.embed_tokens = value


__all__ = [
    "BioGptForCausalLM",
    "BioGptForTokenClassification",
    "BioGptForSequenceClassification",
    "BioGptModel",
    "BioGptPreTrainedModel",
]