a
    hU                    @   s  d dl Z d dlZd dlmZ d dlmZmZmZ d dlZd dlm	Z	m
Z
 ddlmZ ddlmZmZ ddlmZ dd	lmZ dd
lmZ ddlmZ ddlmZmZmZmZmZ ddlmZm Z  ddl!m"Z"m#Z#m$Z$m%Z% ddl&m'Z' ddl(m)Z)m*Z* ddl+m,Z,m-Z-m.Z. ddl/m0Z0 ddl1m2Z2m3Z3 ddl4m5Z5m6Z6 dd Z7G dd de
j8Z9dd Z:dd Z;G dd de
j8Z<d_e
j8ej	ej	ej	eej	 e=e=eej	 e'e, d!	d"d#Z>G d$d% d%e
j8Z?G d&d' d'e
j8Z@G d(d) d)e
j8ZAd*d+ ZBG d,d- d-e
j8ZCG d.d/ d/e
j8ZDG d0d1 d1eZEG d2d3 d3e
j8ZFG d4d5 d5e
j8ZGe-G d6d7 d7e$ZHG d8d9 d9eHZIG d:d; d;e
j8ZJG d<d= d=e
j8ZKG d>d? d?e
j8ZLee-G d@dA dAeZMG dBdC dCe
j8ZNG dDdE dEe
j8ZOedFG dGdH dHe
j8ZPG dIdJ dJe
j8ZQG dKdL dLe
j8ZRdMdN ZSd`dOdPZTej	eUej	dQdRdSZVG dTdU dUe
j8ZWG dVdW dWeZXe-G dXdY dYe$ZYG dZd[ d[eYZZG d\d] d]eYeZ[g d^Z\dS )a    N)	dataclass)CallableOptionalUnion)Tensornn   )ACT2FN)CacheDynamicCache)GenerationMixin)use_kernel_forward_from_hub)create_causal_mask)GradientCheckpointingLayer)"BaseModelOutputWithCrossAttentionsBaseModelOutputWithPast,BaseModelOutputWithPoolingAndCrossAttentionsCausalLMOutputWithPastModelOutput)ROPE_INIT_FUNCTIONSdynamic_rope_update)ALL_ATTENTION_FUNCTIONSModuleUtilsMixinPreTrainedModelget_parameter_dtype)Unpack) find_pruneable_heads_and_indicesprune_linear_layer)TransformersKwargsauto_docstringcan_return_tuple)deprecate_kwarg)OutputRecordercheck_model_inputs   )EvollaConfigSaProtConfigc                 C   s2   |  | }tj|dd|| }| | S )a  
    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        x: torch.Tensor x:

    Returns: torch.Tensor
    r$   dim)neinttorchZcumsumtype_aslong)	input_idspadding_idxmaskZincremental_indices r1   f/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/evolla/modeling_evolla.py"create_position_ids_from_input_ids5   s    r3   c                       s2   e Zd ZdZ fddZd	ddZdd Z  ZS )
EvollaSaProtEmbeddingszV
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
    c                    s   t    tj|j|j|jd| _|jr>tj	|j|j
d| _nd | _t|j| _t|dd| _| jdt|jddd |j| _| jdkrtj|j|j| jd| _|j| _|j| _d | _d S )	N)r/   epsposition_embedding_typeabsoluteposition_ids)r$   F
persistent)super__init__r   	Embedding
vocab_sizehidden_sizepad_token_idword_embeddingsZemb_layer_norm_before	LayerNormlayer_norm_eps
layer_normDropouthidden_dropout_probdropoutgetattrr7   register_bufferr+   arangemax_position_embeddingsexpandr/   position_embeddingstoken_dropoutmask_token_idr9   selfconfig	__class__r1   r2   r>   J   s$    

zEvollaSaProtEmbeddings.__init__Nc           
      C   s  |d u r(|d urt || j}n
| |}|d u r:| |}|}| jr|d ur||| jkdd}d}|d urz|dn|j	d }|| jkd
 | }|d|  d| d d d d f  |j}| jdkr| |}	||	 }| jd ur| |}|d ur||d |j}|S )Nr:           gQ?r$   r8   )r3   r/   &create_position_ids_from_inputs_embedsrC   rP   masked_fillrQ   	unsqueezesumshapefloattodtyper7   rO   rF   )
rS   r.   attention_maskr9   inputs_embeds
embeddingsZmask_ratio_trainZsrc_lengthsZmask_ratio_observedrO   r1   r1   r2   forwardc   s.    

	"




zEvollaSaProtEmbeddings.forwardc                 C   sN   |  dd }|d }tj| jd || j d tj|jd}|d|S )z
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
        Nr:   r$   r_   devicer   )sizer+   rL   r/   r-   re   rZ   rN   )rS   ra   input_shapeZsequence_lengthr9   r1   r1   r2   rX      s    	z=EvollaSaProtEmbeddings.create_position_ids_from_inputs_embeds)NNNN)__name__
__module____qualname____doc__r>   rc   rX   __classcell__r1   r1   rU   r2   r4   E   s       
1r4   c                 C   s&   | j ddd\}}tj| |fddS )N   r:   r'   )chunkr+   catxx1Zx2r1   r1   r2   rotate_half_esm   s    rs   c                 C   s`   |d d d d d | j d d d f }|d d d d d | j d d d f }| | t| |  S )N)r\   rs   )rq   cossinr1   r1   r2   apply_rotary_pos_emb_esm   s    &&rw   c                       s^   e Zd ZU dZejed< ed fddZdddZ	ejeje
ejejf d	d
dZ  ZS )EvollaSaProtRotaryEmbeddingz
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
    matrices which depend on their relative positions.
    inv_freqr'   c                    sT   t    ddtjd|dtjd |   }|}| d| d | _d | _d | _	d S )N      ?i'  r   rm   r_   ry   )
r=   r>   r+   rL   Zint64r]   rK   _seq_len_cached_cos_cached_sin_cached)rS   r(   ry   rU   r1   r2   r>      s    
$z$EvollaSaProtRotaryEmbedding.__init__rm   c                 C   s   |j | }|| jks"| jj|jkr|| _tj|j | |jd| j}t|| j}tj	||fdd
|j}| d d d d d d f | _| d d d d d d f | _| j| jfS )Nre   r:   r'   )r\   r|   r}   re   r+   rL   r,   ry   outerro   r^   ru   rv   r~   )rS   rq   seq_dimensionZseq_lentfreqsembr1   r1   r2   _update_cos_sin_tables   s    
z2EvollaSaProtRotaryEmbedding._update_cos_sin_tables)qkreturnc                 C   sJ   | j |dd\| _| _t|| j| jj|jdt|| j| jj|jdfS )Nrt   )r   r{   )r   r}   r~   rw   r^   r_   )rS   r   r   r1   r1   r2   rc      s    z#EvollaSaProtRotaryEmbedding.forward)rm   )rh   ri   rj   rk   r+   r   __annotations__r*   r>   r   tuplerc   rl   r1   r1   rU   r2   rx      s
   


rx   rW   )	modulequerykeyvaluer`   scalingrI   	head_maskkwargsc                 K   s  t ||dd| }	t| dr| jdv r|jd }
t j|
t j|	jd	dd}t j|
t j|	jd	dd}|| }| 
|| j d }|j|jd}| jd	krt d
||}n.| jdkrt d
||}t d||}|| }|	| }	|d ur |d d d d d d d |jd f }|	| }	tjj|	dt jd|j}	tjj|	|| jd}	|d urb|	| }	t |	|}|dd }||	fS )Nrm   r   r7   relative_keyrelative_key_queryrd   r:   r$   r{   r   zbhld,lrd->bhlrr   zbhrd,lrd->bhlrrt   )r(   r_   )ptraining)r+   matmul	transposehasattrr7   r\   rL   r-   re   viewdistance_embeddingrM   r^   r_   Zeinsumr   Z
functionalsoftmaxfloat32rI   r   
contiguous)r   r   r   r   r`   r   rI   r   r   attn_weights
seq_lengthZposition_ids_lZposition_ids_rZdistanceZpositional_embeddingZrelative_position_scoresZrelative_position_scores_queryZrelative_position_scores_keycausal_maskattn_outputr1   r1   r2   eager_attention_forward   s2    



&
r   c                	       s^   e Zd Zd fdd	Zd	ejeej eej eej eej ee	 e
ej dddZ  ZS )
EvollaSaProtSelfAttentionNFc                    s<  t    || _|j|j dkrDt|dsDtd|j d|j d|j| _t|j|j | _| j| j | _	t
|j| j	| _t
|j| j	| _t
|j| j	| _|j| _|pt|dd| _d | _| jdks| jd	kr|j| _t
d
|j d | j| _n| jdkrt| jd| _|j| _|| _d| _| jo4| | _d S )Nr   Zembedding_sizezThe hidden size (z6) is not a multiple of the number of attention heads ()r7   r8   r   r   rm   r$   rotaryr'   rz   )r=   r>   rT   rA   num_attention_headsr   
ValueErrorr*   attention_head_sizeall_head_sizer   Linearr   r   r   attention_probs_dropout_probrI   rJ   r7   rotary_embeddingsrM   r?   r   rx   
is_decoder	layer_idxr   	is_causal)rS   rT   r7   r   is_cross_attentionrU   r1   r2   r>     s8    

z"EvollaSaProtSelfAttention.__init__)hidden_statesr`   r   encoder_hidden_statesencoder_attention_maskr   r   c                 K   s@  |j d d \}}||d| jf}	| ||	dd}
|d u}|rH|n|}|rT|n|}| ||	dd}| ||	dd}|
| jd  }
| jdkr| |
|\}
}t	}| j
jdkr| jdv rtd| j
j d	| j d
t| j
j }|| |
|||f| jsdn| j| j|d|\}}|||d }||fS )Nr:   r$   rm         r   eagerr   zESM z attention does not support z^ embeddings. Set attention explicitly to 'eager' with `model.set_attn_implementation('eager')`rW   )rI   r   r   )r\   r   r   r   r   r   r   r7   r   r   rT   _attn_implementationr   r   r   rI   r   reshaper   )rS   r   r`   r   r   r   r   
batch_sizer   hidden_shapequery_layerr   Zcurrent_states	key_layervalue_layerattention_interfacer   r   r1   r1   r2   rc   4  sB    	

	
z!EvollaSaProtSelfAttention.forward)NNF)NNNN)rh   ri   rj   r>   r+   r   r   FloatTensorr   r   r   rc   rl   r1   r1   rU   r2   r     s   %    r   c                       s$   e Zd Z fddZdd Z  ZS )EvollaSaProtSelfOutputc                    s.   t    t|j|j| _t|j| _d S N)	r=   r>   r   r   rA   denserG   rH   rI   rR   rU   r1   r2   r>   k  s    
zEvollaSaProtSelfOutput.__init__c                 C   s    |  |}| |}|| }|S r   r   rI   rS   r   Zinput_tensorr1   r1   r2   rc   p  s    

zEvollaSaProtSelfOutput.forwardrh   ri   rj   r>   rc   rl   r1   r1   rU   r2   r   j  s   r   c                       s:   e Zd Zd
 fdd	Zdd Zdee ddd	Z  ZS )EvollaSaProtAttentionNFc                    sD   t    t|||d| _t|| _t | _tj	|j
|jd| _	d S )N)r   r   r5   )r=   r>   r   rS   r   outputsetpruned_headsr   rD   rA   rE   )rS   rT   r   r   rU   r1   r2   r>   x  s
    

zEvollaSaProtAttention.__init__c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r$   r'   )lenr   rS   r   r   r   r   r   r   r   r   r   r   union)rS   headsindexr1   r1   r2   prune_heads  s    z!EvollaSaProtAttention.prune_headsr   c           
      K   s:   |  |}| j|f||||d|\}}	| ||}|S )Nr`   r   r   r   )rD   rS   r   )
rS   r   r`   r   r   r   r   Zhidden_states_lnr   _r1   r1   r2   rc     s    	

zEvollaSaProtAttention.forward)NF)NNNN)	rh   ri   rj   r>   r   r   r   rc   rl   r1   r1   rU   r2   r   w  s       r   c                 C   s    | d dt | td   S )zz
    This is the gelu implementation from the original EVOLLA_SA_PROT repo. Using F.gelu yields subtly wrong results.
    g      ?rz   g       @)r+   erfmathsqrt)rq   r1   r1   r2   gelu  s    r   c                       s0   e Zd Z fddZejejdddZ  ZS )EvollaSaProtIntermediatec                    s    t    t|j|j| _d S r   )r=   r>   r   r   rA   intermediate_sizer   rR   rU   r1   r2   r>     s    
z!EvollaSaProtIntermediate.__init__r   r   c                 C   s   |  |}t|}|S r   )r   r   )rS   r   r1   r1   r2   rc     s    
z EvollaSaProtIntermediate.forwardrh   ri   rj   r>   r+   r   rc   rl   r1   r1   rU   r2   r     s   r   c                       s$   e Zd Z fddZdd Z  ZS )EvollaSaProtOutputc                    s.   t    t|j|j| _t|j| _	d S r   )
r=   r>   r   r   r   rA   r   rG   rH   rI   rR   rU   r1   r2   r>     s    
zEvollaSaProtOutput.__init__c                 C   s    |  |}| |}|| }|S r   r   r   r1   r1   r2   rc     s    

zEvollaSaProtOutput.forwardr   r1   r1   rU   r2   r     s   r   c                       s8   e Zd Z fddZd	ee dddZdd Z  ZS )
EvollaSaProtLayerc                    s   t    |j| _d| _t|| _|j| _|j| _| jrZ| jsLt|  dt|dd| _	t
|| _t|| _tj|j|jd| _d S )Nr$   z> should be used as a decoder model if cross attention is addedT)r   r5   )r=   r>   Zchunk_size_feed_forwardZseq_len_dimr   	attentionr   Zadd_cross_attentionRuntimeErrorcrossattentionr   intermediater   r   r   rD   rA   rE   rR   rU   r1   r2   r>     s    



zEvollaSaProtLayer.__init__Nr   c           	      K   sj   | j |f||d|}| jr\|d ur\t| ds@td|  d| j|f||||d|}| |}|S )N)r`   r   r   z'If `encoder_hidden_states` are passed, z` has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`r   )r   r   r   AttributeErrorr   feed_forward_chunk)	rS   r   r`   r   r   r   r   attention_outputlayer_outputr1   r1   r2   rc     s2    	

	
zEvollaSaProtLayer.forwardc                 C   s$   |  |}| |}| ||}|S r   )rD   r   r   )rS   r   Zattention_output_lnZintermediate_outputr   r1   r1   r2   r     s    

z$EvollaSaProtLayer.feed_forward_chunk)NNNN)	rh   ri   rj   r>   r   r   rc   r   rl   r1   r1   rU   r2   r     s       #r   c                       s4   e Zd Z fddZedee dddZ  ZS )EvollaSaProtEncoderc                    sN   t     | _t fddt jD | _tj j	 j
d| _d| _d S )Nc                    s   g | ]}t  qS r1   )r   .0r   rT   r1   r2   
<listcomp>      z0EvollaSaProtEncoder.__init__.<locals>.<listcomp>r5   F)r=   r>   rT   r   
ModuleListrangenum_hidden_layerslayerrD   rA   rE   emb_layer_norm_aftergradient_checkpointingrR   rU   r   r2   r>     s
    
 zEvollaSaProtEncoder.__init__Nr   c           
      K   s\   t | jD ]6\}}|d ur"|| nd }	||f||	||d|}q
| jrR| |}t|dS )Nr   )last_hidden_state)	enumerater   r   r   )
rS   r   r`   r   r   r   r   iZlayer_moduleZlayer_head_maskr1   r1   r2   rc     s    
	
zEvollaSaProtEncoder.forward)NNNN)	rh   ri   rj   r>   r    r   r   rc   rl   r1   r1   rU   r2   r      s       r   c                       s0   e Zd Z fddZejejdddZ  ZS )EvollaSaProtPoolerc                    s*   t    t|j|j| _t | _d S r   )r=   r>   r   r   rA   r   ZTanh
activationrR   rU   r1   r2   r>   $  s    
zEvollaSaProtPooler.__init__r   c                 C   s(   |d d df }|  |}| |}|S )Nr   )r   r   )rS   r   Zfirst_token_tensorZpooled_outputr1   r1   r2   rc   )  s    

zEvollaSaProtPooler.forwardr   r1   r1   rU   r2   r   #  s   r   c                   @   sT   e Zd ZU eed< dgZdZdZdZe	e
edddge
edddgdZd	d
 ZdS )EvollaSaProtPreTrainedModelrT   r   Tr$   r   )r   Z
layer_namer   )r   
attentionscross_attentionsc                 C   s   | j j}t|tjr>|jjjd|d |jdur|jj	  nbt|tj
rz|jjjd|d |jdur|jj|j 	  n&t|tjr|jj	  |jjd dS )zInitialize the weightsrW   meanstdNrz   )rT   initializer_range
isinstancer   r   weightdatanormal_biaszero_r?   r/   rD   fill_rS   r   r   r1   r1   r2   _init_weightsB  s    

z)EvollaSaProtPreTrainedModel._init_weightsN)rh   ri   rj   r&   r   _no_split_modules_supports_flash_attn_supports_sdpa_supports_attention_backendr   r"   r   _can_record_outputsr  r1   r1   r1   r2   r   2  s   
r   c                       s   e Zd Zed fddZdd Zdd Zdd	 Zede	e
j e	e
j eee
j ef dddZdeee e
je
jedddZ  ZS )EvollaSaProtProteinEncoderr   c                    s$   t  | t|| _t|| _d S r   )r=   r>   r4   rb   r   encoderrR   rU   r1   r2   r>   S  s    
z#EvollaSaProtProteinEncoder.__init__c                 C   s   | j jS r   rb   rC   rS   r1   r1   r2   get_input_embeddingsX  s    z/EvollaSaProtProteinEncoder.get_input_embeddingsc                 C   s   || j _d S r   r  rS   r   r1   r1   r2   set_input_embeddings[  s    z/EvollaSaProtProteinEncoder.set_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr  r   r   r   )rS   Zheads_to_pruner   r   r1   r1   r2   _prune_heads^  s    z'EvollaSaProtProteinEncoder._prune_headsN)r.   r`   r   c                 C   sv   |  }|\}}|j}|d u r0tj||f|d}| j||d}| ||}| j||d}	|	d }
t|
|	j|	j	|	j
dS )Nr   r.   r`   )r`   r   )r   r   r   r   )rf   re   r+   onesrb   get_extended_attention_maskr  r   r   r   r   )rS   r.   r`   rg   r   r   re   ra   extended_attention_maskZencoder_outputsZsequence_outputr1   r1   r2   rc   f  s    z"EvollaSaProtProteinEncoder.forward)r`   rg   re   r_   r   c                 C   s   |du rt | }| dkr$| jjs8|dur8tdt | dkrb|dddddddf }nV| dkr| jjrt|||}q|ddddddf }nt	d| d|j
 d|j|d}d	| t|j }|S )
a  
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (`Tuple[int]`):
                The shape of the input to the model.

        Returns:
            `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
        Nrm   zNThe `device` argument is deprecated and will be removed in v5 of Transformers.r   z!Wrong shape for input_ids (shape z) or attention_mask (shape r   r{   rz   )r   r(   rT   r   warningswarnFutureWarningr   Z*create_extended_attention_mask_for_decoderr   r\   r^   r+   finfomin)rS   r`   rg   re   r_   r  r1   r1   r2   r    s*    	z6EvollaSaProtProteinEncoder.get_extended_attention_mask)N)NN)rh   ri   rj   r&   r>   r  r  r  r#   r   r+   r   r   r   r   rc   r*   re   r]   r  rl   r1   r1   rU   r2   r
  R  s      r
  c                       s&   e Zd Zd fdd	Zdd Z  ZS )!EvollaSequenceCompressorAttention@      c                    sx   t    |d | _|| _|| }t|| _t|| _tj||dd| _	tj||d dd| _
tj||dd| _d S )Nr   Fr   rm   )r=   r>   scaler   r   rD   
norm_medianorm_latentsr   to_qto_kvto_out)rS   r(   dim_headr   	inner_dimrU   r1   r2   r>     s    

z*EvollaSequenceCompressorAttention.__init__c                 C   s  |  |}| |}| j}| |}tj||fdd}| |jddd\}}||	d|	d|d
dddd}||	d|	d|d
dddd}||	d|	d|d
dddd}|| j }t||dd}	|	|	jddd	  }	|	j\}
}}}t|||j}|d
d
d
d
d
d
f }|d
d
d
d
d
d
f }|| }|	d|  d}	|	jdd}t||}|
dddd}||	d|	dd}| |S )z
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latent (torch.Tensor): latent features
                shape (b, n2, D);  n2: num of latent tokens
        rt   r'   rm   r:   r   r$   r   Tr(   keepdimNg     )r!  r"  r   r#  r+   ro   r$  rn   r   rf   permuter   r   r   amaxdetachr\   r  r^   re   rY   boolr   r   r%  )rS   rq   latentsr0   hr   Zkv_inputr   vsimbsZnhZskdZokdr  Zmask_expZones_expattnoutr1   r1   r2   rc     s2    




(((
z)EvollaSequenceCompressorAttention.forward)r  r  r   r1   r1   rU   r2   r    s   r  c                       s&   e Zd Zd fdd	Zdd Z  ZS )EvollaFeedForward   c                    sT   t    t|| }t|| _tj||dd| _t | _	tj||dd| _
d S NFr  )r=   r>   r*   r   rD   normr   fc1ZGELUr   fc2)rS   r(   multr'  rU   r1   r2   r>     s    

zEvollaFeedForward.__init__c              	   C   s   |  | | | |S r   )r:  r   r9  r8  )rS   rq   r1   r1   r2   rc     s    zEvollaFeedForward.forward)r6  r   r1   r1   rU   r2   r5    s   	r5  c                       s*   e Zd Zed fddZdd Z  ZS )!EvollaSequenceCompressorResamplerr   c              
      s   t    |jj}|j| _tjt	| j|dd| _
tg | _t|jD ]2}| jtt||j|jdt||jdg qJt|j| _t||j| _d S )NT)Zrequires_grad)r(   r&  r   )r(   r;  )r=   r>   protein_encoder_configrA   Zresampler_num_latentsnum_latentsr   	Parameterr+   Zrandnr.  r   layersr   Zresampler_depthappendr  Zresampler_dim_headZresampler_headsr5  Zresampler_ff_multrD   r8  r   protein_projector)rS   rT   Zprotein_repr_dimr   rU   r1   r2   r>     s"    

z*EvollaSequenceCompressorResampler.__init__c                 C   s   |j d }|j \}}t|| j|j}tj||fdd}t|| jj}| jd  |ddd }||j	}| j
D ]$\}	}
|	|||| }|
|| }qz| |}| |S )Nr   r$   r'   r:   )r\   r+   r  r>  r^   re   ro   r.  r   r_   r@  rB  r8  )rS   Zembedsr0   br2  r   Zlatent_maskr  r.  r3  ffZtransformed_featurer1   r1   r2   rc     s    


z)EvollaSequenceCompressorResampler.forward)rh   ri   rj   r%   r>   rc   rl   r1   r1   rU   r2   r<    s   r<  c                   @   sb   e Zd ZU dZejed< dZeej ed< dZ	ee
ejdf  ed< dZee
ejdf  ed< dS )EvollaProteinEncoderModelOutputNsequence_compressor_outputr   .r   r   )rh   ri   rj   rF  r+   r   r   r   r   r   r   r   r1   r1   r1   r2   rE  &  s   
rE  c                       s:   e Zd Zed fddZeejejdddZ	  Z
S )EvollaProteinEncoderr   c                    s(   t    t|jd| _t|d| _d S )Nr   )r=   r>   r
  r=  modelr<  sequence_compressor_resamplerrR   rU   r1   r2   r>   0  s    
zEvollaProteinEncoder.__init__r  c                 K   s.   | j ||d}|j}| ||}t||jdS )Nr  )rF  r   )rH  r   rI  rE  )rS   r.   r`   r   Zprotein_outputZprotein_embedsZsequence_reprr1   r1   r2   rc   5  s    zEvollaProteinEncoder.forward)rh   ri   rj   r%   r>   r    r+   
LongTensorr   rc   rl   r1   r1   rU   r2   rG  /  s   rG  c                       sT   e Zd Zdee ee ee d fddZdd Zeddd	d
dddZ  Z	S )#EvollaSequenceAlignerCrossAttentionN)protein_encoder_dimstructure_encoder_dimmsa_encoder_dimc                    sv  t    |j| _|j| _| jd | _t| j| j | _| j| j | _|j}|j	}|j
}t| j| j| _|d urt|| j| _t|| j| _nd | _d | _|d urt|| j| _t|| j| _nd | _d | _|d urt|| j| _t|| j| _nd | _d | _t| j| _t|| _tj| j| j|d| _t| j|| _ttdg| _ttdg| _d S )Nr   r  rW   ) r=   r>   rA   r   r   r*   r   r   Z$aligner_attention_probs_dropout_probZaligner_enable_biasZaligner_ffn_multr   r   r   key_proteinvalue_proteinkey_structurevalue_structurekey_msa	value_msaEvollaRMSNormattention_normrG   rI   out_projr5  rD  r?  r+   tensorgate_attentiongate_ffw)rS   rT   rL  rM  rN  r   Zenable_biasZffn_multrU   r1   r2   r>   B  s>    

z,EvollaSequenceAlignerCrossAttention.__init__c	                 C   s  |||g}	dd |	D }	|	s$t dtj|	dd}	| |}
| |
}
| jdurz| jdurz||}| |}| |}nd}d}| jdur| j	dur||}| |}| 	|}nd}d}| j
dur| jdur||}| 
|}| |}nd}d}|||g}dd |D }tj|dd}|||g}dd |D }tj|dd}|
 dd	 | j| jf }|
j| d
ddd}
| dd	 | j| jf }|j| d
ddd}| dd	 | j| jf }|j| d
ddd}|
| j }
|du rt|d
|d|j}|ddddddf |	ddddddf  }t|
|d	d}||jd	dd  }|d|  t|jj}tjd	d|}t||}|d
ddd }| dd | j f }|j| }| !|}|S )z
        query_states: text
        key_value_states: protein
        query_states: [bs, query_seq_len, dim]
        key_value_states: [bs, kv_seq_len, dim]
        query_attn_mask: [bs, query_seq_len]
        kv_attn_mask: [bs, kv_seq_len]
        c                 S   s   g | ]}|d ur|qS r   r1   r   r1   r1   r2   r     r   zGEvollaSequenceAlignerCrossAttention.cross_attention.<locals>.<listcomp>z=At least one modality should be provided for cross attention.r$   r'   Nc                 S   s   g | ]}|d ur|qS r   r1   r   r1   r1   r2   r     r   c                 S   s   g | ]}|d ur|qS r   r1   r   r1   r1   r2   r     r   r:   r   rm   r   rt   Tr(  )"r   r+   ro   rV  r   rO  rP  r^   rQ  rR  rS  rT  rf   r   r   r   r*  r   r  re   r   r   r+  r,  rY   r-  r  r_   r  r   ZSoftmaxr   r   rW  )rS   query_statesprotein_key_value_statesstructure_key_value_statesmsa_key_value_statesquery_attn_maskprotein_kv_attn_maskstructure_kv_attn_maskmsa_kv_attn_maskZkv_attn_maskr   Zkey_layer_proteinZvalue_layer_proteinZkey_layer_structureZvalue_layer_structureZkey_layer_msaZvalue_layer_msar   r   Znew_query_layer_shapeZnew_key_layer_shapeZnew_value_layer_shaper`   r   Zattention_scoresZattention_probsZcontext_layerZnew_context_layer_shaper1   r1   r2   cross_attentionu  s|    












 0

z3EvollaSequenceAlignerCrossAttention.cross_attentionpast_key_valuepast_key_values4.58new_nameversionc              
   C   s  |d urL|j \}}}|d u rPt|||	j|	j||fdj |j}nd }|d ur|j \}}}|d u rt|||	j|
j||fdj |j}nd }|d ur|j \}}}|d u rt|||	j|j||fdj |j}nd }|}|d ur| s0|d ur| s0|d ur| r|}| j||||||||d}t	| j
| }|| }|}| |t	| j }|| }|S )N)rf   )r[  r\  r]  r^  r_  r`  ra  rb  )r\   r+   r  r^   re   rN   Tanyrc  tanhrY  rD  rZ  )rS   r[  protein_kv_statesstructure_kv_statesmsa_kv_statesr_  r`  ra  rb  protein_batch_maskstructure_batch_maskmsa_batch_maskre  r2  Zprotein_kv_seq_lenr(   Zstructure_kv_seq_lenZmsa_kv_seq_lenr   residualr1   r1   r2   rc     sx    z+EvollaSequenceAlignerCrossAttention.forward)NNN)NNNNNNN)
rh   ri   rj   r   r*   r>   rc  r!   rc   rl   r1   r1   rU   r2   rK  A  s$      3p       rK  ZRMSNormc                       s.   e Zd Zd fdd	Zdd Zdd Z  ZS )	rU  ư>c                    s&   t    tt|| _|| _dS )z<
        EvollaRMSNorm is equivalent to T5LayerNorm
        N)r=   r>   r   r?  r+   r  r   variance_epsilon)rS   rA   r6   rU   r1   r2   r>   2  s    
zEvollaRMSNorm.__init__c                 C   sJ   |j }|tj}|djddd}|t|| j  }| j|| S )Nrm   r:   T)r)  )	r_   r^   r+   r   powr   Zrsqrtru  r   )rS   r   Zinput_dtypeZvariancer1   r1   r2   rc   :  s
    zEvollaRMSNorm.forwardc                 C   s   t | jj d| j S )Nz, eps=)r   r   r\   ru  r  r1   r1   r2   
extra_reprA  s    zEvollaRMSNorm.extra_repr)rt  )rh   ri   rj   r>   rc   rw  rl   r1   r1   rU   r2   rU  0  s   rU  c                       sD   e Zd ZU ejed< ded fddZe e	dd Z
  ZS )	EvollaRotaryEmbeddingry   Nr   c                    s   t    t|dr:t|jtr:|jd|jd| _nd| _|j| _	|j| _
|| _t| j | _| | j|\}| _| jd|dd | j| _d S )Nrope_scaling	rope_typetypedefaultry   Fr;   )r=   r>   r   r   ry  dictgetrz  rM   Zmax_seq_len_cachedZoriginal_max_seq_lenrT   r   Zrope_init_fnattention_scalingrK   ry   Zoriginal_inv_freq)rS   rT   re   ry   rU   r1   r2   r>   H  s    
zEvollaRotaryEmbedding.__init__c           
      C   s   | j d d d d f  |jd dd|j}|d d d d d f  }t|jjtrl|jjdkrl|jjnd}t	j
|ddV | |  dd}t	j||fdd	}| | j }| | j }	W d    n1 s0    Y  |j|jd
|	j|jd
fS )Nr   r:   r$   ZmpscpuF)device_typeZenabledrm   r'   r{   )ry   r]   rN   r\   r^   re   r   r{  strr+   Zautocastr   ro   ru   r  rv   r_   )
rS   rq   r9   Zinv_freq_expandedZposition_ids_expandedr  r   r   ru   rv   r1   r1   r2   rc   Y  s    0&,zEvollaRotaryEmbedding.forward)N)rh   ri   rj   r+   r   r   r%   r>   Zno_gradr   rc   rl   r1   r1   rU   r2   rx  E  s
   

rx  c                       s$   e Zd Z fddZdd Z  ZS )	EvollaMLPc                    sx   t    || _|j| _|j| _tj| j| j|jd| _tj| j| j|jd| _	tj| j| j|jd| _
t|j | _d S )Nr  )r=   r>   rT   rA   r   r   r   Zmlp_bias	gate_projup_proj	down_projr	   Z
hidden_actact_fnrR   rU   r1   r2   r>   j  s    
zEvollaMLP.__init__c                 C   s$   |  | | || | }|S r   )r  r  r  r  )rS   rq   r  r1   r1   r2   rc   t  s     zEvollaMLP.forwardr   r1   r1   rU   r2   r  i  s   
r  c                 C   sH   | dd| j d d f }| d| j d d df }tj| |fddS )z*Rotates half the hidden dims of the input..Nr:   rm   r'   )r\   r+   ro   rp   r1   r1   r2   rotate_halfy  s    r  c                 C   sD   | |}| |}| | t| |  }|| t||  }||fS )a  Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    )rZ   r  )r   r   ru   rv   r9   Zunsqueeze_dimZq_embedZk_embedr1   r1   r2   apply_rotary_pos_emb  s
    

r  )r   n_repr   c                 C   s^   | j \}}}}|dkr| S | dddddddddf |||||} | ||| ||S )z
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    r$   N)r\   rN   r   )r   r  batchnum_key_value_headsslenhead_dimr1   r1   r2   	repeat_kv  s
    0r  c                       s   e Zd ZdZeed fddZedddddej	e
ej	ej	f eej	 ee eej ee e
ej	ej	f d
ddZ  ZS )EvollaAttentionz=Multi-headed attention from 'Attention Is All You Need' paperrT   r   c                    s   t    || _|| _t|d|j|j | _|j|j | _	| jd | _
|j| _d| _tj|j|j| j |jd| _tj|j|j| j |jd| _tj|j|j| j |jd| _tj|j| j |j|jd| _d S )Nr  r   Tr  )r=   r>   rT   r   rJ   rA   r   r  r  Znum_key_value_groupsr   attention_dropoutr   r   r   Zattention_biasq_projk_projv_projo_projrS   rT   r   rU   r1   r2   r>     s(    
zEvollaAttention.__init__rd  re  rf  rg  N)r   rO   r`   re  cache_positionr   r   c                 K   s$  |j d d }g |d| jR }| ||dd}	| ||dd}
| ||dd}|\}}t|	|
||\}	}
|d ur|||d}||
|| j	|\}
}t
}| jjdkrt| jj }|| |	|
||f| jsdn| j| jd|\}}|jg |dR   }| |}||fS )Nr:   r$   rm   )rv   ru   r  r   rW   )rI   r   )r\   r  r  r   r   r  r  r  updater   r   rT   r   r   r   r  r   r   r   r  )rS   r   rO   r`   re  r  r   rg   r   r[  Z
key_statesZvalue_statesru   rv   Zcache_kwargsr   r   r   r1   r1   r2   rc     s8    


zEvollaAttention.forward)NN)rh   ri   rj   rk   r%   r*   r>   r!   r+   r   r   r   r
   rJ  r   r   rc   rl   r1   r1   rU   r2   r    s     r  c                       s   e Zd Zeed fddZedddddeje	ejejf e
ej e
ej e
e e
e e
ej e
ej e
ej e
ej e
ej e
ej e
ej e
ej ejd
ddZ  ZS )EvollaDecoderLayerr  c                    s   t    |j| _t||d| _t|| _t|j|jd| _	t|j|jd| _
|d t|j|j d dkr|t||jd| _d S )Nr  r5   r$   r   )rL  )r=   r>   rA   r  	self_attnr  mlprU  rms_norm_epsinput_layernormpost_attention_layernormmaxr   Zaligner_num_add_layersrK  adapterr  rU   r1   r2   r>     s    

zEvollaDecoderLayer.__init__rd  re  rf  rg  NF)r   rO   r`   r9   re  	use_cacher  rm  rn  ro  rp  rq  rr  r_  r   c              
   K   s   |}|  |}| jf |||||||d|\}}|| }|}| |}| |}|| }t| dr~| j|||	|
||||d}|S )N)r   r`   r9   re  r  r  rO   r  )r[  rm  rn  ro  r_  rp  rq  rr  )r  r  r  r  r   r  )rS   r   rO   r`   r9   re  r  r  rm  rn  ro  rp  rq  rr  r_  r   rs  r   r1   r1   r2   rc     s<    





zEvollaDecoderLayer.forward)NNNFNNNNNNNN)rh   ri   rj   r%   r*   r>   r!   r+   r   r   r   rJ  r
   r-  rc   rl   r1   r1   rU   r2   r    s>               r  c                       sZ   e Zd ZU eed< dZdZg dZdgZdZ	dZ
dZdZdZeedZ fdd	Z  ZS )
EvollaPreTrainedModelrT   rH  T)r  r<  rK  re  F)r   r   c                    sd   | j j}t | t|trD|j  |j  |j	j
jd nt|tr`|jjjd|d d S )Nrz   rW   r   )rT   r   r=   r  r   rK  rY  r  rZ  rV  r   r   r  r<  r.  r   r  rU   r1   r2   r  M  s    



z#EvollaPreTrainedModel._init_weights)rh   ri   rj   r%   r   Zbase_model_prefixZsupports_gradient_checkpointingr  Z_skip_keys_device_placementr  r  Z_supports_flex_attnZ_can_compile_fullgraphr  r  r  r	  r  rl   r1   r1   rU   r2   r  7  s   
r  c                       s   e Zd Zed fddZdd Zdd Zeede	j
ee	j ee	j
 ee ee	j ee ee	j
 ee	j
 ee	j ee	j ee	j ee	j ee	j eeef d	d
dZ  ZS )EvollaModelr   c                    s   t     j| _ j| _t| j j| j| _t	 d| _
t fddt jD | _t j jd| _t d| _t dd| _|   d S )Nr   c                    s   g | ]}t  |d qS )r  )r  )r   r   r   r1   r2   r   `  s
   z(EvollaModel.__init__.<locals>.<listcomp>r5   r   F)r=   r>   rB   r/   r@   r   r?   rA   embed_tokensrG  protein_encoderr   r   r   r@  rU  r  r8  rx  
rotary_embrJ   r   	post_initrR   rU   r   r2   r>   Y  s    

zEvollaModel.__init__c                 C   s   | j S r   r  r  r1   r1   r2   r  n  s    z EvollaModel.get_input_embeddingsc                 C   s
   || _ d S r   r  r  r1   r1   r2   r  q  s    z EvollaModel.set_input_embeddingsN)r.   r`   r9   re  ra   r  r  protein_input_idsprotein_attention_maskstructure_feats	msa_featsrq  rr  r   c                 K   sJ  |du |duA rt d|du r*| |}|rB|du rBt| jd}|du rz|durZ| nd}tj|||jd  |jd}|du r|	d}d}d}|dur|	dur| j
||	d}|j}tjdg|jd  |jd}t| j||||d	}|}| ||}| jD ]0}||f||||||||
|||||d
|}q| |}t||d}|S )a;  
        protein_input_ids (torch.LongTensor):
            The input IDs for the protein sequence in structure-aware tokens. Should be of shape `(batch_size, protein_seq_length)` and type `torch.LongTensor`.
        protein_attention_mask (torch.Tensor):
            The attention mask for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.Tensor`.
        structure_feats (torch.FloatTensor):
            The input IDs for purely structure-based features. Should be of shape `(batch_size, structure_seq_length, structure_feat_dim)` and type `torch.FloatTensor`. Dummy input for now.
        msa_feats (torch.FloatTensor):
            The input IDs for purely MSA-based features. Should be of shape `(batch_size, msa_seq_length, msa_feat_dim)` and type `torch.FloatTensor`. Dummy input for now.
        structure_batch_mask (torch.Tensor):
            The batch mask to decide which protein sequences are purely structure-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `structure_feats`. Dummpy input for now.
        msa_batch_mask (torch.Tensor):
            The batch mask to decide which protein sequences are purely MSA-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `msa_feats`. Dummpy input for now.
        Nz:You must specify exactly one of input_ids or inputs_embedsr   r   r$   r   r  T)rT   Zinput_embedsr`   r  re  )r`   r9   re  r  r  rO   rm  rn  ro  rp  rq  rr  r_  )r   re  )r   r  r   rT   Zget_seq_lengthr+   rL   r\   re   rZ   r  rF  rX  r   r  r@  r8  r   )rS   r.   r`   r9   re  ra   r  r  r  r  r  r  rq  rr  r   Zpast_seen_tokensZprotein_featsrp  Zprotein_outputsr   r   rO   Zdecoder_layerr   r1   r1   r2   rc   t  sr    !



zEvollaModel.forward)NNNNNNNNNNNNN)rh   ri   rj   r%   r>   r  r  r   r#   r+   rJ  r   r   r
   r   r-  r   r   r   rc   rl   r1   r1   rU   r2   r  X  sD                
r  c                       sp   e Zd Z fddZdd Zdd Zeedej	e
ej e
ej e
ej	 ej	e
ej e
e dd	d
Z  ZS )EvollaForProteinText2Textc                    s@   t  | t|| _|j| _tj|j| jdd| _| 	  d S r7  )
r=   r>   r  rH  r@   r   r   rA   lm_headr  rR   rU   r1   r2   r>     s
    
z"EvollaForProteinText2Text.__init__c                 C   s
   | j  S r   )rH  r  r  r1   r1   r2   r    s    z.EvollaForProteinText2Text.get_input_embeddingsc                 C   s   | j |S r   )rH  r  r  r1   r1   r2   r    s    z.EvollaForProteinText2Text.set_input_embeddingsN)r.   r`   ra   labelsr  r  r  c              	   K   sr   | j f ||||||d|}	|	d }
| |
}d}|durV| jf ||| jd|}t|||	j|	j|	jd}|S )a,  
        protein_input_ids (torch.LongTensor):
            The input IDs for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.LongTensor`.
        protein_attention_mask (torch.Tensor):
            The attention mask for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.Tensor`.

        Example:

        ```python
        >>> from transformers import EvollaProcessor, EvollaForProteinText2Text
        >>> model = EvollaForProteinText2Text.from_pretrained("westlake/Evolla-10B-hf")
        >>> processor = EvollaProcessor.from_pretrained("westlake/Evolla-10B-hf")

        >>> protein_information = {
            "aa_seq": "your amino acid sequence",
            "foldseek": "your foldseek sequence",
        }
        >>> question = "What is the function of this protein?"
        >>> message = [
            {"role": "system", "content": "You are an AI expert that can answer any questions about protein."},
            {"role": "user", "content": question},
        ]

        >>> inputs = processor(proteins=[protein_information], messages_list=[message], return_tensors="pt", padding="longest")
        >>> outputs = model.generate(**inputs)

        >>> print(processor.batch_decode(outputs, skip_special_tokens=True))
        ```)r.   r`   ra   r  r  r  r   N)logitsr  r@   )lossr  re  r   r   )rH  r  Zloss_functionr@   r   re  r   r   )rS   r.   r`   ra   r  r  r  r  r   outputsr   r  r  Z
lm_outputsr1   r1   r2   rc     s.    *	
z!EvollaForProteinText2Text.forward)NNNNNNN)rh   ri   rj   r>   r  r  r    r   r+   rJ  r   r   r   r-  rc   rl   r1   r1   rU   r2   r    s*          r  )r  r  r  )rW   N)Nr$   )]r   r  dataclassesr   typingr   r   r   r+   r   r   Zactivationsr	   Zcache_utilsr
   r   Z
generationr   Zintegrationsr   Zmasking_utilsr   Zmodeling_layersr   Zmodeling_outputsr   r   r   r   r   Zmodeling_rope_utilsr   r   Zmodeling_utilsr   r   r   r   Zprocessing_utilsr   Zpytorch_utilsr   r   utilsr   r   r    Zutils.deprecationr!   Zutils.genericr"   r#   Zconfiguration_evollar%   r&   r3   Moduler4   rs   rw   rx   r]   r   r   r   r   r   r   r   r   r   r   r   r
  r  r5  r<  rE  rG  rK  rU  rx  r  r  r  r*   r  r  r  r  r  r  __all__r1   r1   r1   r2   <module>   s   a4  2Y0:#b:* p$
GI  S