import warnings
from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.utils.checkpoint
from torch import Tensor, nn

from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask
from ...modeling_outputs import (
    BaseModelOutputWithPast,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithPast,
    ModelOutput,
)
from ...modeling_utils import ModuleUtilsMixin, PreTrainedModel, get_parameter_dtype
from ...utils import auto_docstring, can_return_tuple, logging
from ...utils.deprecation import deprecate_kwarg
from ...utils.generic import OutputRecorder, check_model_inputs
from ..esm.modeling_esm import (
    EsmAttention,
    EsmEmbeddings,
    EsmEncoder,
    EsmIntermediate,
    EsmLayer,
    EsmOutput,
    EsmPooler,
    EsmSelfAttention,
    EsmSelfOutput,
)
from ..llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaMLP,
    LlamaPreTrainedModel,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
)
from .configuration_evolla import EvollaConfig, SaProtConfig


logger = logging.get_logger(__name__)


class EvollaSaProtEmbeddings(EsmEmbeddings):
    def __init__(self, config):
        super().__init__(config)
        # SaProt does not use the absolute position ids inherited from EsmEmbeddings
        self.position_ids = None


def rotate_half_esm(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_esm(x, cos, sin):
    cos = cos[:, :, : x.shape[-2], :]
    sin = sin[:, :, : x.shape[-2], :]
    return (x * cos) + (rotate_half_esm(x) * sin)
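
# A worked example of the two helpers above (values chosen for illustration):
# for a last dimension [1, 2, 3, 4], chunking gives x1 = [1, 2] and x2 = [3, 4],
# so rotate_half_esm returns [-3, -4, 1, 2]; combined with the cos/sin tables,
# apply_rotary_pos_emb_esm implements the RoPE rotation x*cos + rotate_half(x)*sin.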
ejejf d	d
dZ  ZS )EvollaSaProtRotaryEmbeddingz
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
    matrices which depend on their relative positions.
    inv_freqr=   c                    sT   t    ddtjd|dtjd |   }|}| d| d | _d | _d | _	d S )N      ?i'  r   r   dtyperK   )
r.   r/   r@   arangeZint64floatZregister_buffer_seq_len_cached_cos_cached_sin_cached)r2   r>   rK   r4   r6   r7   r/   ^   s    
$z$EvollaSaProtRotaryEmbedding.__init__r   c                 C   s   |j | }|| jks"| jj|jkr|| _tj|j | |jd| j}t|| j}tj	||fdd
|j}| d d d d d d f | _| d d d d d d f | _| j| jfS )Ndevicer<   r=   )rF   rQ   rR   rU   r@   rO   Ztype_asrK   outerrA   torG   rH   rS   )r2   rB   seq_dimensionZseq_lentZfreqsZembr6   r6   r7   _update_cos_sin_tablesi   s    
z2EvollaSaProtRotaryEmbedding._update_cos_sin_tables)qkreturnc                 C   sJ   | j |dd\| _| _t|| j| jj|jdt|| j| jj|jdfS )NrE   )rX   rM   )rZ   rR   rS   rI   rW   rN   )r2   r[   r\   r6   r6   r7   forwardy   s    z#EvollaSaProtRotaryEmbedding.forward)r   )r8   r9   r:   __doc__r@   r   __annotations__intr/   rZ   tupler^   r;   r6   r6   r4   r7   rJ   U   s
   


rJ   c                   @   s   e Zd ZdddZdS )EvollaSaProtSelfAttentionNFc                 C   s>  t j|  || _|j|j dkrFt|dsFtd|j d|j d|j| _t|j|j | _	| j| j	 | _
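
# A minimal usage sketch of the module above (shapes assumed, not taken from a
# real checkpoint): queries and keys are rotated in place, so their dot
# products depend only on relative positions.
#
#     rope = EvollaSaProtRotaryEmbedding(dim=64)
#     q = torch.randn(1, 8, 10, 64)  # (batch, heads, seq_len, head_dim)
#     k = torch.randn(1, 8, 10, 64)
#     q_rot, k_rot = rope(q, k)      # same shapes, rotation applied per position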
t |j| j
| _t |j| j
| _t |j| j
| _|j| _|pt|dd| _d | _| jdks| jd	kr|j| _t d
|j d | j	| _n| jdkrt| j	d| _|j| _|| _d| _| jo6| | _d S )Nr   Zembedding_sizezThe hidden size (z6) is not a multiple of the number of attention heads ()position_embedding_typeabsoluteZrelative_keyZrelative_key_queryr   r)   Zrotaryr=   rL   )r   Moduler/   r3   hidden_sizenum_attention_headshasattr
ValueErrorra   attention_head_sizeall_head_sizeLinearquerykeyvalueattention_probs_dropout_probdropoutgetattrre   Zrotary_embeddingsZmax_position_embeddings	EmbeddingZdistance_embeddingrJ   
is_decoder	layer_idxZscalingZ	is_causal)r2   r3   re   rw   Zis_cross_attentionr6   r6   r7   r/      s8    
z"EvollaSaProtSelfAttention.__init__)NNF)r8   r9   r:   r/   r6   r6   r6   r7   rc      s   rc   c                   @   s   e Zd ZdS )EvollaSaProtSelfOutputNr8   r9   r:   r6   r6   r6   r7   rx      s   rx   c                   @   s   e Zd ZdS )EvollaSaProtAttentionNry   r6   r6   r6   r7   rz      s   rz   c                   @   s   e Zd ZdS )EvollaSaProtIntermediateNry   r6   r6   r6   r7   r{      s   r{   c                   @   s   e Zd ZdS )EvollaSaProtOutputNry   r6   r6   r6   r7   r|      s   r|   c                   @   s   e Zd ZdS )EvollaSaProtLayerNry   r6   r6   r6   r7   r}      s   r}   c                   @   s   e Zd ZdS )EvollaSaProtEncoderNry   r6   r6   r6   r7   r~      s   r~   c                   @   s   e Zd ZdS )EvollaSaProtPoolerNry   r6   r6   r6   r7   r      s   r   c                   @   sT   e Zd ZU eed< dgZdZdZdZe	e
edddge
edddgdZd	d
 ZdS )EvollaSaProtPreTrainedModelr3   r}   Tr)   	attention)indexZ
layer_nameZcrossattention)hidden_states
attentionscross_attentionsc                 C   s   | j j}t|tjr>|jjjd|d |jdur|jj	  nbt|tj
rz|jjjd|d |jdur|jj|j 	  n&t|tjr|jj	  |jjd dS )zInitialize the weights        meanstdNrL   )r3   initializer_range
isinstancer   rn   weightdatanormal_biaszero_ru   padding_idx	LayerNormfill_r2   moduler   r6   r6   r7   _init_weights   s    

z)EvollaSaProtPreTrainedModel._init_weightsN)r8   r9   r:   r+   r`   _no_split_modules_supports_flash_attnZ_supports_sdpa_supports_attention_backendr}   r   rc   Z_can_record_outputsr   r6   r6   r6   r7   r      s   
r   c                       s   e Zd Zed fddZdd Zdd Zdd	 Zede	e
j e	e
j eee
j ef dddZdeee e
je
jedddZ  ZS )EvollaSaProtProteinEncoderr3   c                    s$   t  | t|| _t|| _d S r-   )r.   r/   r,   
embeddingsr~   encoderr1   r4   r6   r7   r/      s    
z#EvollaSaProtProteinEncoder.__init__c                 C   s   | j jS r-   r   Zword_embeddingsr2   r6   r6   r7   get_input_embeddings   s    z/EvollaSaProtProteinEncoder.get_input_embeddingsc                 C   s   || j _d S r-   r   r2   rq   r6   r6   r7   set_input_embeddings   s    z/EvollaSaProtProteinEncoder.set_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr   layerr   Zprune_heads)r2   Zheads_to_pruner   headsr6   r6   r7   _prune_heads   s    z'EvollaSaProtProteinEncoder._prune_headsN)	input_idsattention_maskr]   c                 C   sv   |  }|\}}|j}|d u r0tj||f|d}| j||d}| ||}| j||d}	|	d }
t|
|	j|	j	|	j
dS )NrT   r   r   )r   r   )last_hidden_stater   r   r   )sizerU   r@   onesr   get_extended_attention_maskr   r   r   r   r   )r2   r   r   input_shapeZ
batch_sizeZ
seq_lengthrU   inputs_embedsextended_attention_maskZencoder_outputsZsequence_outputr6   r6   r7   r^      s    z"EvollaSaProtProteinEncoder.forward)r   r   rU   rN   r]   c                 C   s   |du rt | }| dkr$| jjs8|dur8tdt | dkrb|dddddddf }nV| dkr| jjrt|||}q|ddddddf }nt	d| d|j
 d|j|d}d	| t|j }|S )
a  
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (`Tuple[int]`):
                The shape of the input to the model.

        Returns:
            `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
        Nr   zNThe `device` argument is deprecated and will be removed in v5 of Transformers.r   z!Wrong shape for input_ids (shape z) or attention_mask (shape rd   rM   rL   )r   r>   r3   rv   warningswarnFutureWarningr   Z*create_extended_attention_mask_for_decoderrk   rF   rW   r@   finfomin)r2   r   r   rU   rN   r   r6   r6   r7   r     s*    	z6EvollaSaProtProteinEncoder.get_extended_attention_mask)N)NN)r8   r9   r:   r+   r/   r   r   r   r   r   r@   r   r   rb   r   r^   ra   rU   rP   r   r;   r6   r6   r4   r7   r      s      r   c                       s&   e Zd Zd fdd	Zdd Z  ZS )!EvollaSequenceCompressorAttention@      c                    sx   t    |d | _|| _|| }t|| _t|| _tj||dd| _	tj||d dd| _
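
# A short sketch of the mask arithmetic implemented above (toy values): a
# padding mask [1, 1, 0] becomes additive biases [0, 0, dtype_min], which
# drive masked positions to ~zero probability after the softmax.
#
#     mask = torch.tensor([[1.0, 1.0, 0.0]])  # (batch, seq_len)
#     ext = mask[:, None, None, :]            # (batch, 1, 1, seq_len)
#     ext = (1.0 - ext) * torch.finfo(torch.float32).min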
tj||dd| _d S )N      Fr   r   )r.   r/   scaler   r   r   
norm_medianorm_latentsrn   to_qto_kvto_out)r2   r>   dim_headr   	inner_dimr4   r6   r7   r/   E  s    

z*EvollaSequenceCompressorAttention.__init__c                 C   s  |  |}| |}| j}| |}tj||fdd}| |jddd\}}||	d|	d|d
dddd}||	d|	d|d
dddd}||	d|	d|d
dddd}|| j }t||dd}	|	|	jddd	  }	|	j\}
}}}t|||j}|d
d
d
d
d
d
f }|d
d
d
d
d
d
f }|| }|	d|  d}	|	jdd}t||}|
dddd}||	d|	dd}| |S )z
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latent (torch.Tensor): latent features
                shape (b, n2, D);  n2: num of latent tokens
        rE   r=   r   r<   r   r)   r   Tr>   ZkeepdimNg     )r   r   r   r   r@   rA   r   r?   viewr   permuter   matmul	transposeamaxdetachrF   r   rW   rU   masked_fillboolZsoftmaxZreshaper   )r2   rB   latentsmaskhr[   Zkv_inputr\   vsimbsZnhZskdZokdr   Zmask_expZones_expattnoutr6   r6   r7   r^   R  s2    




(((
z)EvollaSequenceCompressorAttention.forward)r   r   r8   r9   r:   r/   r^   r;   r6   r6   r4   r7   r   D  s   r   c                       s&   e Zd Zd fdd	Zdd Z  ZS )EvollaFeedForward   c                    sT   t    t|| }t|| _tj||dd| _t | _	tj||dd| _
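
# A usage sketch for the compressor attention (dims assumed, not taken from a
# real config): a fixed set of latents cross-attends over the protein features
# concatenated with the latents themselves.
#
#     attn = EvollaSequenceCompressorAttention(dim=1280, dim_head=64, heads=8)
#     feats = torch.randn(2, 100, 1280)   # (b, n1, D) per-residue features
#     latents = torch.randn(2, 64, 1280)  # (b, n2, D) learned latent tokens
#     mask = torch.ones(2, 100 + 64)      # keys cover n1 + n2 positions
#     out = attn(feats, latents, mask)    # -> (2, 64, 1280)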
d S NFr   )r.   r/   ra   r   r   normrn   fc1ZGELU
activationfc2)r2   r>   multr   r4   r6   r7   r/     s    

zEvollaFeedForward.__init__c              	   C   s   |  | | | |S r-   )r   r   r   r   )r2   rB   r6   r6   r7   r^     s    zEvollaFeedForward.forward)r   r   r6   r6   r4   r7   r   ~  s   	r   c                       s*   e Zd Zed fddZdd Z  ZS )!EvollaSequenceCompressorResamplerr   c              

class EvollaSequenceCompressorResampler(nn.Module):
    def __init__(self, config: EvollaConfig):
        super().__init__()
        protein_repr_dim = config.protein_encoder_config.hidden_size
        self.num_latents = config.resampler_num_latents

        self.latents = nn.Parameter(torch.randn(self.num_latents, protein_repr_dim), requires_grad=True)
        self.layers = nn.ModuleList([])
        for _ in range(config.resampler_depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        EvollaSequenceCompressorAttention(
                            dim=protein_repr_dim,
                            dim_head=config.resampler_dim_head,
                            heads=config.resampler_heads,
                        ),
                        EvollaFeedForward(dim=protein_repr_dim, mult=config.resampler_ff_mult),
                    ]
                )
            )
        self.norm = nn.LayerNorm(protein_repr_dim)
        self.protein_projector = nn.Linear(protein_repr_dim, config.hidden_size)

    def forward(self, embeds, mask):
        b = embeds.shape[0]
        bs = mask.shape[0]

        # latent positions are always visible to the attention
        latent_mask = torch.ones(bs, self.num_latents).to(mask.device)
        mask = torch.cat((mask, latent_mask), dim=1)

        # blocks
        latents = self.latents.repeat(b, 1, 1).to(embeds.dtype)
        for attn, ff in self.layers:
            latents = attn(embeds, latents, mask) + latents
            latents = ff(latents) + latents

        transformed_feature = self.norm(latents)
        return self.protein_projector(transformed_feature)
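
# Sketch of the resampler's role (config values assumed): it maps a
# variable-length stack of per-residue embeddings to a fixed number of latent
# tokens already projected into the language model's hidden size.
#
#     resampler = EvollaSequenceCompressorResampler(config)
#     seq_repr = resampler(protein_feats, protein_mask)
#     # -> (batch, config.resampler_num_latents, config.hidden_size)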
ejdf  ed< dZee
ejdf  ed< dS )EvollaProteinEncoderModelOutputNsequence_compressor_outputr   .r   r   )r8   r9   r:   r   r@   FloatTensorr`   r   r   r   rb   r   r6   r6   r6   r7   r     s   
r   c                       s:   e Zd Zed fddZeejejdddZ	  Z
S )EvollaProteinEncoderr   c                    s(   t    t|jd| _t|d| _d S )Nr   )r.   r/   r   r   modelr   sequence_compressor_resamplerr1   r4   r6   r7   r/     s    
zEvollaProteinEncoder.__init__r   c                 K   s.   | j ||d}|j}| ||}t||jdS )Nr   )r   r   )r   r   r   r   )r2   r   r   kwargsZprotein_outputZprotein_embedsZsequence_reprr6   r6   r7   r^     s    zEvollaProteinEncoder.forward)r8   r9   r:   r*   r/   r   r@   
LongTensorr   r^   r;   r6   r6   r4   r7   r     s   r   c                       sT   e Zd Zdee ee ee d fddZdd Zeddd	d
dddZ  Z	S )#EvollaSequenceAlignerCrossAttentionN)protein_encoder_dimstructure_encoder_dimmsa_encoder_dimc                    sv  t    |j| _|j| _| jd | _t| j| j | _| j| j | _|j}|j	}|j
}t| j| j| _|d urt|| j| _t|| j| _nd | _d | _|d urt|| j| _t|| j| _nd | _d | _|d urt|| j| _t|| j| _nd | _d | _t| j| _t|| _tj| j| j|d| _t| j|| _ttdg| _ttdg| _d S )Nr   r   r   ) r.   r/   rh   ri   r   ra   rl   rm   Z$aligner_attention_probs_dropout_probZaligner_enable_biasZaligner_ffn_multr   rn   ro   key_proteinvalue_proteinkey_structurevalue_structurekey_msa	value_msaEvollaRMSNormattention_normZDropoutrs   out_projr   r   r   r@   tensorgate_attentiongate_ffw)r2   r3   r   r   r   rr   Zenable_biasZffn_multr4   r6   r7   r/     s>    

z,EvollaSequenceAlignerCrossAttention.__init__c	                 C   s  |||g}	dd |	D }	|	s$t dtj|	dd}	| |}
| |
}
| jdurz| jdurz||}| |}| |}nd}d}| jdur| j	dur||}| |}| 	|}nd}d}| j
dur| jdur||}| 
|}| |}nd}d}|||g}dd |D }tj|dd}|||g}dd |D }tj|dd}|
 dd	 | j| jf }|
j| d
ddd}
| dd	 | j| jf }|j| d
ddd}| dd	 | j| jf }|j| d
ddd}|
| j }
|du rt|d
|d|j}|ddddddf |	ddddddf  }t|
|d	d}||jd	dd  }|d|  t|jj}tjd	d|}t||}|d
ddd }| dd | j f }|j| }| !|}|S )z
        query_states: text
        key_value_states: protein
        query_states: [bs, query_seq_len, dim]
        key_value_states: [bs, kv_seq_len, dim]
        query_attn_mask: [bs, query_seq_len]
        kv_attn_mask: [bs, kv_seq_len]
        c                 S   s   g | ]}|d ur|qS r-   r6   .0r   r6   r6   r7   
<listcomp>      zGEvollaSequenceAlignerCrossAttention.cross_attention.<locals>.<listcomp>z=At least one modality should be provided for cross attention.r)   r=   Nc                 S   s   g | ]}|d ur|qS r-   r6   r  r6   r6   r7   r  A  r  c                 S   s   g | ]}|d ur|qS r-   r6   r  r6   r6   r7   r  E  r  r<   r   r   r   rE   Tr   )"rk   r@   rA   r   ro   r   r   rW   r   r   r   r   r   ri   rl   r   r   r   r   rU   r   r   r   r   r   r   r   rN   r   r   ZSoftmax
contiguousrm   r   )r2   query_statesprotein_key_value_statesstructure_key_value_statesmsa_key_value_statesquery_attn_maskprotein_kv_attn_maskstructure_kv_attn_maskmsa_kv_attn_maskZkv_attn_maskZquery_layerZkey_layer_proteinZvalue_layer_proteinZkey_layer_structureZvalue_layer_structureZkey_layer_msaZvalue_layer_msaZ	key_layerZvalue_layerZnew_query_layer_shapeZnew_key_layer_shapeZnew_value_layer_shaper   Zattn_weightsZattention_scoresZattention_probsZcontext_layerZnew_context_layer_shaper6   r6   r7   cross_attention  s|    












 0

z3EvollaSequenceAlignerCrossAttention.cross_attentionpast_key_valuepast_key_values4.58new_nameversionc              
   C   s  |d urL|j \}}}|d u rPt|||	j|	j||fdj |j}nd }|d ur|j \}}}|d u rt|||	j|
j||fdj |j}nd }|d ur|j \}}}|d u rt|||	j|j||fdj |j}nd }|}|d ur| s0|d ur| s0|d ur| r|}| j||||||||d}t	| j
| }|| }|}| |t	| j }|| }|S )N)r   )r  r  r  r	  r
  r  r  r  )rF   r@   r   rW   rU   expandTanyr  tanhr   r   r   )r2   r  protein_kv_statesstructure_kv_statesmsa_kv_statesr
  r  r  r  protein_batch_maskstructure_batch_maskmsa_batch_maskr  r   Zprotein_kv_seq_lenr>   Zstructure_kv_seq_lenZmsa_kv_seq_lenr   residualr6   r6   r7   r^   u  sx    z+EvollaSequenceAlignerCrossAttention.forward)NNN)NNNNNNN)

class EvollaRMSNorm(LlamaRMSNorm):
    pass


class EvollaRotaryEmbedding(LlamaRotaryEmbedding):
    pass


class EvollaMLP(LlamaMLP):
    pass


class EvollaAttention(LlamaAttention):
    pass


class EvollaDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: EvollaConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # attach a cross-attention adapter to every
        # num_hidden_layers // aligner_num_add_layers-th layer
        if (layer_idx + 1) % max(config.num_hidden_layers // config.aligner_num_add_layers, 1) == 0:
            self.adapter = EvollaSequenceAlignerCrossAttention(config, protein_encoder_dim=config.hidden_size)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        protein_kv_states: Optional[torch.Tensor] = None,
        structure_kv_states: Optional[torch.Tensor] = None,
        msa_kv_states: Optional[torch.Tensor] = None,
        protein_batch_mask: Optional[torch.Tensor] = None,
        structure_batch_mask: Optional[torch.Tensor] = None,
        msa_batch_mask: Optional[torch.Tensor] = None,
        query_attn_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        if hasattr(self, "adapter"):
            hidden_states = self.adapter(
                query_states=hidden_states,
                protein_kv_states=protein_kv_states,
                structure_kv_states=structure_kv_states,
                msa_kv_states=msa_kv_states,
                query_attn_mask=query_attn_mask,
                protein_batch_mask=protein_batch_mask,
                structure_batch_mask=structure_batch_mask,
                msa_batch_mask=msa_batch_mask,
            )

        return hidden_states
jd nt|tr`|jjjd|d d S )NrL   r   r   )r3   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r6   r6   r7   r     s    



z#EvollaPreTrainedModel._init_weightsN)r8   r9   r:   r   Z_supports_flex_attnr   r   r   r6   r6   r6   r7   r+    s

@auto_docstring
class EvollaModel(EvollaPreTrainedModel):
    def __init__(self, config: EvollaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size, self.padding_idx)
        self.protein_encoder = EvollaProteinEncoder(config=config)
        self.layers = nn.ModuleList(
            [EvollaDecoderLayer(config, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = EvollaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = EvollaRotaryEmbedding(config=config)
        self.gradient_checkpointing = getattr(config, "gradient_checkpointing", False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        protein_input_ids: Optional[torch.LongTensor] = None,
        protein_attention_mask: Optional[torch.Tensor] = None,
        structure_feats: Optional[torch.FloatTensor] = None,
        msa_feats: Optional[torch.FloatTensor] = None,
        structure_batch_mask: Optional[torch.Tensor] = None,
        msa_batch_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[tuple, BaseModelOutputWithPast]:
        r"""
        protein_input_ids (torch.LongTensor):
            The input IDs for the protein sequence in structure-aware tokens. Should be of shape `(batch_size, protein_seq_length)` and type `torch.LongTensor`.
        protein_attention_mask (torch.Tensor):
            The attention mask for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.Tensor`.
        structure_feats (torch.FloatTensor):
            The input IDs for purely structure-based features. Should be of shape `(batch_size, structure_seq_length, structure_feat_dim)` and type `torch.FloatTensor`. Dummy input for now.
        msa_feats (torch.FloatTensor):
            The input IDs for purely MSA-based features. Should be of shape `(batch_size, msa_seq_length, msa_feat_dim)` and type `torch.FloatTensor`. Dummy input for now.
        structure_batch_mask (torch.Tensor):
            The batch mask to decide which protein sequences are purely structure-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `structure_feats`. Dummy input for now.
        msa_batch_mask (torch.Tensor):
            The batch mask to decide which protein sequences are purely MSA-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `msa_feats`. Dummy input for now.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        protein_feats = None
        protein_batch_mask = None
        if protein_input_ids is not None and protein_attention_mask is not None:
            protein_outputs = self.protein_encoder(
                input_ids=protein_input_ids, attention_mask=protein_attention_mask
            )
            protein_feats = protein_outputs.sequence_compressor_output
            protein_batch_mask = torch.tensor([True] * protein_input_ids.shape[0], device=protein_input_ids.device)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
        )

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                protein_kv_states=protein_feats,
                structure_kv_states=structure_feats,
                msa_kv_states=msa_feats,
                protein_batch_mask=protein_batch_mask,
                structure_batch_mask=structure_batch_mask,
                msa_batch_mask=msa_batch_mask,
                query_attn_mask=attention_mask,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        output = BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
        return output
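
# Minimal call sketch for the backbone (tensors assumed to come from the
# Evolla processor/tokenizer):
#
#     model = EvollaModel(config)
#     out = model(
#         input_ids=text_ids,                  # (b, t)
#         attention_mask=text_mask,            # (b, t)
#         protein_input_ids=saprot_ids,        # (b, p) structure-aware tokens
#         protein_attention_mask=saprot_mask,  # (b, p)
#     )
#     out.last_hidden_state                    # (b, t, hidden_size)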
ej e
ej e
ej	 ej	e
ej e
e dd	d
Z  ZS )EvollaForProteinText2Textc                    s@   t  | t|| _|j| _tj|j| jdd| _| 	  d S r   )
r.   r/   r,  r   r/  r   rn   rh   lm_headr3  r1   r4   r6   r7   r/     s
    
z"EvollaForProteinText2Text.__init__c                 C   s
   | j  S r-   )r   r   r   r6   r6   r7   r     s    z.EvollaForProteinText2Text.get_input_embeddingsc                 C   s   | j |S r-   )r   r   r   r6   r6   r7   r     s    z.EvollaForProteinText2Text.set_input_embeddingsN)r   r   r   labelsr5  r6  r)  c              	   K   sr   | j f ||||||d|}	|	d }
| |
}d}|durV| jf ||| jd|}t|||	j|	j|	jd}|S )a,  
        protein_input_ids (torch.LongTensor):
            The input IDs for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.LongTensor`.
        protein_attention_mask (torch.Tensor):
            The attention mask for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.Tensor`.

        Example:

        ```python
        >>> from transformers import EvollaProcessor, EvollaForProteinText2Text
        >>> model = EvollaForProteinText2Text.from_pretrained("westlake/Evolla-10B-hf")
        >>> processor = EvollaProcessor.from_pretrained("westlake/Evolla-10B-hf")

        >>> protein_information = {
            "aa_seq": "your amino acid sequence",
            "foldseek": "your foldseek sequence",
        }
        >>> question = "What is the function of this protein?"
        >>> message = [
            {"role": "system", "content": "You are an AI expert that can answer any questions about protein."},
            {"role": "user", "content": question},
        ]

        >>> inputs = processor(proteins=[protein_information], messages_list=[message], return_tensors="pt", padding="longest")
        >>> outputs = model.generate(**inputs)

        >>> print(processor.batch_decode(outputs, skip_special_tokens=True))
        ```)r   r   r   r5  r6  r)  r   N)logitsr<  r/  )lossr=  r  r   r   )r   r;  Zloss_functionr/  r   r  r   r   )r2   r   r   r   r<  r5  r6  r)  r   outputsr   r=  r>  Z
lm_outputsr6   r6   r7   r^     s.    *	
z!EvollaForProteinText2Text.forward)NNNNNNN)r8   r9   r:   r/   r   r   r   r   r@   r   r   r   r   r   r^   r;   r6   r6   r4   r7   r:    s*          r:  )r:  r,  r+  )Xr   dataclassesr   typingr   r   r@   Ztorch.utils.checkpointr   r   Zcache_utilsr   r	   Z

__all__ = ["EvollaForProteinText2Text", "EvollaModel", "EvollaPreTrainedModel"]