a
    h2                     @   sx  d dl Z d dlZd dlmZ d dlmZmZmZ d dlZd dl	Zd dlm
Z
 ddlmZ ddlmZ ddlmZ dd	lmZmZ dd
lmZ ddlmZ ddlmZ ddlmZmZ ddlmZmZ ddl m!Z!m"Z"m#Z# ddl$m%Z%m&Z&m'Z'm(Z(m)Z)m*Z*m+Z+m,Z,m-Z-m.Z. ddl/m0Z0 e r:d dl1m2Z2 d dl3m4Z4m5Z5 n
d\Z2Z4Z5e r^d dl6m7Z7m8Z8 nd\Z8Z7e9e2e7e8fZ:dZ;e<e=Z>G dd dej
j?Z@G dd de-ZAG dd de)ZBG d d! d!eZCG d"d# d#e%ZDG d$d% d%e
j?ZEG d&d' d'e
j?ZFG d(d) d)e&ZGG d*d+ d+e+ZHG d,d- d-e*ZIG d.d/ d/eZJG d0d1 d1e,eJZKG d2d3 d3e'ZLG d4d5 d5e(ZMg d6ZNdS )7    N)cycle)CallableOptionalUnion)nn   )ACT2FN)FlashAttentionKwargs)BaseModelOutputWithPast)ALL_ATTENTION_FUNCTIONSPreTrainedModel)Unpack)logging)deprecate_kwarg)is_causal_conv1d_availableis_mamba_ssm_available   )LlamaRotaryEmbeddingapply_rotary_pos_emb)pad_tensor_by_sizereshape_into_chunkssegment_sum)
ZambaAttentionZambaAttentionDecoderLayerZambaForCausalLMZambaForSequenceClassificationZambaHybridDynamicCacheZambaHybridLayerZambaMambaDecoderLayer
ZambaModelZambaRMSNormeager_attention_forward   )Zamba2Config)selective_state_update)mamba_chunk_scan_combined mamba_split_conv1d_scan_combined)NNN)causal_conv1d_fncausal_conv1d_update)NNzZyphra/Zamba2-2.7Bc                       s(   e Zd Zd fdd	ZdddZ  ZS )	Zamba2RMSNormGatedư>c                    s,   t    tt|| _|| _|| _d S N)	super__init__r   	Parametertorchonesweightvariance_epsilon
group_size)selfhidden_sizer3   eps	__class__ e/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/zamba2/modular_zamba2.pyr-   K   s    
zZamba2RMSNormGated.__init__Nc           	      C   s   |j }|tj}|d ur2|tj|tj }|j^ }}|| j }|j	g ||| jR  }|
djddd}|t|| j  }|j	g ||| j R  }| j|| S )Nr   T)Zkeepdim)dtypetor/   float32r   
functionalsilushaper3   viewpowmeanZrsqrtr2   r1   )	r4   hidden_statesgateZinput_dtypeZprefix_dimsZlast_dimZgroup_countZhidden_states_groupZvariancer9   r9   r:   forwardQ   s    
zZamba2RMSNormGated.forward)r*   )N)__name__
__module____qualname__r-   rG   __classcell__r9   r9   r7   r:   r)   J   s   r)   c                   @   s   e Zd ZdS )Zamba2RMSNormNrH   rI   rJ   r9   r9   r9   r:   rL   _   s   rL   c                   @   sj   e Zd ZdZejdfeeeje	e
 dddZeejejejdddZd	d
 Zde	e edddZdS )Zamba2HybridDynamicCachea  
    A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
    (which has a constant shape regardless of seq_len).

    This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
    and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor
    For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
    while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
    For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
    while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
    and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
    N)config
batch_sizer<   devicec              	      s  || _ |j| _d| _t|j|j | _|j| _|j	| _
|j| _g | _i | _i | _i | _i | _i | _t|jD ]n}tj | jd|j |j  | j
|d| j|< tj | j|j| j|d| j|< | j| dkrl| j| ql fddt|jD | _ fddt|jD | _d S )NFr   rQ   r<   hybridc                    s    g | ]}t jg g  d qS rQ   r/   Ztensor.0_rP   rQ   r9   r:   
<listcomp>       z5Zamba2HybridDynamicCache.__init__.<locals>.<listcomp>c                    s    g | ]}t jg g  d qS rT   rV   rW   rZ   r9   r:   r[      r\   )r<   layers_block_typehas_previous_stateintmamba_expandr5   intermediate_sizemamba_d_statessm_state_sizemamba_d_convconv_kernel_sizen_mamba_headstransformer_layersZ_modules_parameters_buffersconv_states
ssm_statesrangenum_hidden_layersr/   zerosmamba_ngroupsmamba_headdimappend	key_cacheZvalue_cache)r4   rO   rP   r<   rQ   ir9   rZ   r:   r-   q   s8    z!Zamba2HybridDynamicCache.__init__)	layer_idxnew_conv_statecache_positionreturnc                 C   sr   | j | }|d| jd }|jddd}||j|d d d d |f< | j |   | j |  |7  < | j | S )Nr   r"   r;   Zshiftsdims)rj   clampre   rollr=   rQ   zero_)r4   rt   ru   rv   
conv_stater9   r9   r:   update_conv_state   s    
z*Zamba2HybridDynamicCache.update_conv_statec                 C   s   | j   | j  d S r+   )rj   r|   rk   )r4   r9   r9   r:   reset   s    
zZamba2HybridDynamicCache.resetr   )rt   rw   c                 C   sL   || j vr| j d n|}t| j|ks8| j|  dkr<dS | j| jd S )zYReturns the sequence length of the cached states. A layer index can be optionally passed.r   )rg   lenrr   ZnumelrA   )r4   rt   r9   r9   r:   get_seq_length   s     z'Zamba2HybridDynamicCache.get_seq_length)r   )rH   rI   rJ   __doc__r/   Zfloat16r#   r_   r<   r   strr-   Tensor
LongTensorr~   r   r   r9   r9   r9   r:   rN   c   s   !
rN   c                   @   s   e Zd ZdS )Zamba2RotaryEmbeddingNrM   r9   r9   r9   r:   r      s   r   c                       s   e Zd ZdZdeee ee ee d fddZedddd	de	j
eee	j
 ee eee	j
e	j
f  ee ee	j
ee	j
 eee	j
  f d
ddZ  ZS )Zamba2AttentionaZ  
    Multi-headed attention from 'Attention Is All You Need' paper.

    Adapted from transformers.models.mistral.modeling_mistral.MistralAttention:
    The input dimension here is attention_hidden_size = 2 * hidden_size, and head_dim = attention_hidden_size // num_heads.
    The extra factor of 2 comes from the input being the concatenation of original_hidden_states with the output of the previous (mamba) layer
    (see fig. 2 in https://huggingface.co/papers/2405.16712).
    Additionally, replaced
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) with
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2)
    Finally, this attention layer contributes to tied transformer blocks aimed to increasing compute without increasing model size. Because this
    layer is tied, un-tied adapters (formally the same as LoRA but used in the base model) modules are added to the q, k, v projectors to increase
    expressivity with a small memory overhead (see Fig. 2 of https://huggingface.co/papers/2411.15242).
    N)rO   rt   num_fwd_mem_blocksblock_idc           	   	      sT  t  || || _|j| _|| _|jr:tg | _	tg | _
tg | _t| jD ]}||j |krttj| j| jjddtj| jj| jdd}ttj| j| jjddtj| jj| jdd}ttj| j| jjddtj| jj| jdd}nt }t }t }| j	| | j
| | j| qXdd t| jD | _d S )NFbiasc                 S   s   i | ]\}}||qS r9   r9   rX   indexvaluer9   r9   r:   
<dictcomp>   r\   z,Zamba2Attention.__init__.<locals>.<dictcomp>)r,   r-   r   hybrid_layer_idslayer_block_mapr   use_shared_attention_adapterr   
ModuleListlinear_q_adapter_listlinear_k_adapter_listlinear_v_adapter_listrl   num_mem_blocks
SequentialLinearZattention_hidden_sizerO   adapter_rankIdentityrq   	enumerate	layer_dic)	r4   rO   rt   r   r   rs   Zlinear_q_adapterZlinear_k_adapterZlinear_v_adapterr7   r9   r:   r-      s:    zZamba2Attention.__init__past_key_valuepast_key_values4.58new_nameversion)rE   rt   attention_maskr   position_embeddingskwargsrw   c                 K   st  |j d d }g |d| jR }| |}	| |}
| |}| jjr| j| }|	| j| | }	|
| j	| | }
|| j
| | }|	|dd}	|
|dd}
||dd}| jjr|\}}t|	|
||\}	}
|d ur||
||\}
}t}| jjdkrt| jj }|| |	|
||f| js2dn| j| jd|\}}|jg |dR   }| |}||fS )Nr;   r"   r   eagerg        )Zdropoutscaling)rA   head_dimZq_projZk_projZv_projrO   r   r   r   r   r   rB   	transposeuse_mem_roper   updater!   _attn_implementationr   trainingZattention_dropoutr   reshape
contiguousZo_proj)r4   rE   rt   r   r   r   r   Zinput_shapeZhidden_shapeZquery_statesZ
key_statesZvalue_statesZadapter_layer_idxcossinZattention_interfaceZattn_outputZattn_weightsr9   r9   r:   rG      sH    






zZamba2Attention.forward)NNN)NNN)rH   rI   rJ   r   r#   r   r_   r-   r   r/   r   rN   tupler   r	   rG   rK   r9   r9   r7   r:   r      s.      )   r   c                       s   e Zd ZdZdeee d fddZdej	ee
 eej	 dddZdee
 eej	 d	d
dZdee
 eej	 d	ddZ  ZS )Zamba2MambaMixeru  
    Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
    A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
    ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
    and is why Mamba is called **selective** state spaces)
    NrO   rt   c                    s  t    || _|j| _|j| _|j| _t|j	| j | _
|| _|j| _d| _t | _|j| _|j| _|j| _| jj| _|j| _|j| _|j| _|j| _| j
d| j | j  | _tj| j| jd|j| j|jd d| _| j
| j | j }tj| j||j d| _!t"t#$| j| _%t#&d| jd }t"t#'|| _(d| j(_)t*| j
| j
| j dd| _+t"t#$| j| _,d| j,_)tj| j
| j|j d| _-t.st/0d	 d S )
Nr@   r   Tr"   )Zin_channelsZout_channelsr   Zkernel_sizegroupspaddingr   gh㈵>)r3   r6   a  The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)` is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d)1r,   r-   rO   r5   rb   rc   rd   re   r_   r`   ra   rt   use_conv_bias
activationr   ZSiLUactuse_mem_eff_pathro   n_groupsrp   r   rf   	num_heads
chunk_sizetime_step_limittime_step_mintime_step_maxconv_dimZConv1dconv1dr   add_bias_linearin_projr.   r/   r0   dt_biasarangelogA_logZ_no_weight_decayr)   normDout_projis_fast_path_availableloggerwarning_once)r4   rO   rt   Zprojection_sizeAr7   r9   r:   r-   $  s^    



	zZamba2MambaMixer.__init__)rE   cache_paramsr   c                 C   sF  |j \}}}| j| j }d| j d| j | j  | j }|d ur6|jr6| |d}	|	j d | d }
|
|
| j| j| jg}t	j
|	|dd\}}}}}t||j| j | jjd| jj| j}t	j
|| j||gdd\}}}t	| j  }|d d d df d d d d d f d| j| jjt	jd}|d d d d d f dd| j}| jd d d df d| j}| jd d d df d| j}||| j|j d | j }||| j|j d | j }||| j| j}t|j| j ||||||d |dd
}||| j| j }| ||}| |d d d df }n|d urvt	 |dksv|j!}||d d d d d f  |}| |}t	| j  }| j"d u ri nd	| j"i}|d urt	 |dk}nd}| j#rN| j$rN|d u rN|rNt%|| jjd| jj| j|f| j| j&d | j| jj| jj'| jj| jj| j| jd
dd|\}}nt	j
|| j| j| jgdd\}}}|d ur|(dd}t)j*+|| j,|j d  df}|j| j -| t.d u s| jdvr| /| |(dd(ddd d d |f }n@t.|(dd| jjd| jj| jd(ddd d d |f }t	j
|| j||gdd\}}}|d urt	 |dks|j!}||d d d d d f  |}t0|||d| j|||||| jd|||| jdf| j&| jd d d| jdd|\}}|d ur|d ur|j| j -| |||d}| ||}| |}|S )Nr   r"   r;   dim.r<   T)zr   dt_softplusZdt_limitF)r   r   seq_idxr   Zrmsnorm_weightZrmsnorm_epsZoutproj_weightZoutproj_biasZheaddimZngroupsZnorm_before_gatereturn_final_statesr   )r@   Zswish)xr1   r   r   )r   r   r   r   r   r   r   )1rA   r   rc   ra   r   r^   r   squeezer   r/   splitr(   rj   rt   r   r1   r   r   expr   floatexpandr   r=   r>   r   r   rB   r$   rk   r   r   allr<   r   r   r   r&   r   r2   r   r   r?   padre   copy_r'   r   r%   )r4   rE   r   r   rP   seq_lenrY   Zgroups_time_state_sizeZd_to_removeZin_projected_statesd_mlpZsplit_projection_dimrF   Zhidden_states_B_CdtBCr   r   r   Zhidden_states_reshapedoutr<   projected_statesZdt_limit_kwargsZinput_not_masked	ssm_stateZ	time_stepZhidden_states_B_C_tr}   scan_outputr9   r9   r:   cuda_kernels_forwarde  s    

<"
 

 
(

 

z%Zamba2MambaMixer.cuda_kernels_forward)r   r   c           1   
      s	  |j \}}}|j}|d ur2|jr2|d}n@|d urht|dksh||d d d d d f  |}|}|j d dj  dj	 j
  j d }	|j|	|	jjjgdd\}}}
}}|d urv|jj  }||j}|jr|
d}
|jj }tj|ddd}|jdkr>|d d dd d f n||d d d d df< |jj | tj||jjjd d dd d f  dd}jr|jj7 }||d d d df }n|dd}tj |j!|j d  df}|jj | |ddd d d |d d f }|d urt|dks|j}||d d d d d f  |}nLtj"|jj#j
f|j|d	}|dddd |f dd}tj|jj	j
 j	j
 gdd\}}}t$j%&  }|d ur|jr|jdkr2|d d d df n"|d d dd d f d d d df }|dd'||j d j#}j(d
 'j(j d j#}tjj)|||j }t*|j+}|d 'jj#j
jtj,d}t$|d
 | }|-|j	ddd d d f }|'|j	jj	 |j d . }|-|d|j d }|d
 |dd d d f  }|-|dj#}||d
  }|jj |jj | |  |-|j	ddd d d f }|'|j	jj	 |j d . }|-|d|j d }|jj |j}|/|j j#j
}|/|j j
d}t0||}|/|jj#}j1d
 'j1j d j#}|||  |j}|-|dd d d df }n8tj)|j( }t*|j+}|-||dj#& }|-||dj
& }|-||dj
& }|j2jj	 djd}|j2jj	 djd}j3|j3  j3  j1d
 t4|  }||d
  }||j| } fdd||||fD \}}}}|5dddd}tj6|dd}t$t7|}|d d d d d d d d d d d f |d d d d d d d d d d d f  }|jdd}|d
 |5dddddd
  } | jdd}!|!d
 |d d d d d f  d}"t$|d d d d d d dd f | }#||#5ddddd
  }$|$5dddddd
 |5ddddddd d d f  jdd5ddddd}%|d ur0|jr0|jj d d d df }&nt8|%d d d df }&tj9|&|%gdd}%t$t7tj |d d d d d d df d}'|%5ddddd}(|'d |(d d d d d df  jdd})|)5ddddd}*|*d d d df |*d d df  }%}t$|}+|dd d d f |%d d d d d df  },|+5dddd}-|,d|-d
  }.|"|. }|-|djj#}|| } dk	r|d d d |d d d d f }|-||d}|d u	r|d u	r|jj | :||
}/;|/|}0|0S )Nr"   r;   r   r   rx   r   r   .rR   ).N).NNr   )r   Zoutput_sizec                    s   g | ]}t | jqS r9   )r   r   )rX   tZpad_sizer4   r9   r:   r[   w  r\   z2Zamba2MambaMixer.torch_forward.<locals>.<listcomp>   )r"   r   )<rA   r<   r^   r   r   r/   r   r=   ra   r   rc   r   r   r   rk   rt   clonerQ   	unsqueezerj   r{   ndimr   sumr   r1   r   r   r   r   r   r?   r   re   rn   r   r   r   r   r   r   Zsoftplusrz   r   r>   r   r   rB   Zbmmr   Zrepeat_interleaver   r   ZpermuteZcumsumr   Z
zeros_likecatr   r   )1r4   Zinput_statesr   r   rP   r   rY   r<   r   r   rF   rE   r   r   r}   r   r   r   r   ZdAZdBZdBxrk   Zssm_states_reshapedZ
C_reshapedyr   Z
D_residualZA_cumsumLZG_intermediateGZM_intermediateMZY_diagZdecay_statesZB_decay_contractionZstatesZprevious_statesZdecay_chunkZstates_permutedresultZ
new_statesZstate_decay_outZC_times_statesZstate_decay_out_permutedZY_offr   Zcontextualized_statesr9   r   r:   torch_forward  s     
.

80 .",.B"$$$P$*L0(&
*
 zZamba2MambaMixer.torch_forwardc                 C   s0   t r"d| jjjjv r"| |||S | |||S )Ncuda)r   r   r1   rQ   typer   r   )r4   rE   r   r   r9   r9   r:   rG     s    zZamba2MambaMixer.forward)N)NN)NN)NN)rH   rI   rJ   r   r#   r   r_   r-   r/   r   rN   r   r   rG   rK   r9   r9   r7   r:   r     s$   D    F  r   c                       s4   e Zd Zdeee d fddZdddZ  ZS )		Zamba2MLPN)rO   r   c              	      s   t    || _|j| _|j| _|| _|| _tj| jd| j |j	d| _
tj| j| j|j	d| _t|j | _tg | _t| jD ]^}||j |krttj| jj| jjddtj| jjd| j dd}nt }| j| q|j}dd t|D | _dS )aQ  
        This MLP layer contributes to tied transformer blocks aimed to increasing compute without increasing model size. Because this layer
        is tied, un-tied adapter modules (formally same as LoRA, but used in the base model) are added to the up and gate projectors to increase expressivity with a small memory overhead.
        r   r   Fc                 S   s   i | ]\}}||qS r9   r9   r   r9   r9   r:   r     r\   z&Zamba2MLP.__init__.<locals>.<dictcomp>N)r,   r-   rO   r5   ra   r   r   r   r   r   gate_up_proj	down_projr   Z
hidden_actact_fnr   gate_up_proj_adapter_listrl   r   r   r   r   rq   r   r   r   )r4   rO   r   r   rs   Zgate_up_proj_adapterr   r7   r9   r:   r-     s(    
zZamba2MLP.__init__c                 C   sZ   |  |}| j| }|| j| | }tj|ddd}| |d |d  }| |}|S )Nr   r;   r   r   r"   )r   r   r  r/   chunkr   r   )r4   Zhidden_statert   Zgate_up_stateoutputr9   r9   r:   rG     s    


zZamba2MLP.forward)NN)N)	rH   rI   rJ   r#   r   r_   r-   rG   rK   r9   r9   r7   r:   r     s   r   c                       s   e Zd Zdeee ee d fddZedddddej	ej	eeej	 ee
 ee eej ee eejeeejejf  f d
	ddZ  ZS )Zamba2AttentionDecoderLayerN)rO   r   rt   c                    sD   || _ t|j}t || t|d||d| _t|||d| _d S )Nr;   )rt   r   r   )r   r   )	r   r   r   r,   r-   r   	self_attnr   feed_forward)r4   rO   r   rt   Znum_gsr7   r9   r:   r-     s
    
z$Zamba2AttentionDecoderLayer.__init__r   r   r   r   F)	rE   original_hidden_statesrt   r   r   output_attentionsr   r   rw   c              	   K   sl   t j||gdd}| |}| jf ||||||d|\}}	| |}| ||}|f}
|rh|
|	f7 }
|
S )a  
        Args:
            hidden_states (`torch.FloatTensor`): output of previous Mamba layer of shape `(batch, seq_len, embed_dim)`
            original_hidden_states (`torch.FloatTensor`): word embedding output of shape `(batch, seq_len, embed_dim)`.
                This is concatenated with `hidden_states` (which is the output of the previous (mamba) layer). The
                concatenated tensor is then used as input of the pre-attention RMSNorm
                (see fig. 2 in https://huggingface.co/papers/2405.16712).
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            past_key_values (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                with `head_dim` being the embedding dimension of each attention head.
        r;   r   )rE   rt   r   r   r  r   )r/   Zconcatenateinput_layernormr  Zpre_ff_layernormr  )r4   rE   r  rt   r   r   r  r   r   self_attn_weightsoutputsr9   r9   r:   rG     s$     




z#Zamba2AttentionDecoderLayer.forward)NN)NNFN)rH   rI   rJ   r#   r   r_   r-   r   r/   r   rN   boolr   r   r	   r   FloatTensorrG   rK   r9   r9   r7   r:   r    s"        r  c                       s$   e Zd Zeed fddZ  ZS )Zamba2MambaDecoderLayerr   c                    s2   t  || t||d| _t|j|jd| _d S )Nr   r6   )r,   r-   r   mambarL   r5   rms_norm_epsr	  )r4   rO   rt   r7   r9   r:   r-   4  s    z Zamba2MambaDecoderLayer.__init__)rH   rI   rJ   r#   r_   r-   rK   r9   r9   r7   r:   r  3  s   r  c                       s   e Zd Zeejed fddZeddddde	j
ee	j
 ee ee	j
 ee	j
 ee ee ee ee	j ee	jeee	je	jf  f d

ddZ  ZS )Zamba2HybridLayer)shared_transformerlinearr  c                    s   t  ||| | `|| _d S r+   )r,   r-   Zshared_transfr  )r4   r  r  r  r7   r9   r:   r-   ;  s    zZamba2HybridLayer.__init__r   r   r   r   NF)
rE   r  rt   r   causal_maskr   r  	use_cacher   rw   c
              	   C   sn   | j |||||||	d}
|
d }|r,|
d }| |}| j|||||||	d}
|rj|
d |f|
dd  }
|
S )aY  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            original_hidden_states (`torch.FloatTensor`): word embedding output that will be concatenated with
            hidden activations to form the input of the shared transformer layer.
            layer_idx (`int`): layer number.
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            past_key_values (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                with `head_dim` being the embedding dimension of each attention head.
        )r  rt   r   r   r  r   r   r"   )transformer_hidden_statesr   r   r  r  r   r   N)r  r  Zmamba_decoder)r4   rE   r  rt   r   r  r   r  r  r   layer_outputsr  r
  r9   r9   r:   rG   B  s2    !


zZamba2HybridLayer.forward)NNNNNFFN)rH   rI   rJ   r  r   r   r  r-   r   r/   r   r   r_   rN   r  r   r   r  rG   rK   r9   r9   r7   r:   r  :  s.           r  c                       sJ   e Zd ZU eed< dZdZddgZdZdZ	dZ
dZdZ fddZ  ZS )	Zamba2PreTrainedModelrO   modelTr  r  r   c                    s   t  | t|trtt| jjt	
| jjt	
| jj  t	
| jj j| jjd}|t
t|   }|jj| td|jd }|jjt
| |jjd d S )N)minr"   g      ?)r,   _init_weights
isinstancer   r/   r   ZrandrO   rf   mathr   r   r   rz   Ztime_step_floorexpm1r   datar   r   r   r   r   Zfill_)r4   moduler   Zinv_dtr   r7   r9   r:   r    s     
z#Zamba2PreTrainedModel._init_weights)rH   rI   rJ   r#   __annotations__Zbase_model_prefixZsupports_gradient_checkpointingZ_no_split_modulesZ_skip_keys_device_placementZ_supports_flash_attnZ_supports_flex_attnZ_supports_sdpaZ_is_statefulr  rK   r9   r9   r7   r:   r    s   
r  c                   @   s   e Zd ZdZedddZdd Zdeej	 eej
 eej	 ee eej ee ee ee ee eej	 eeef dd	d
ZdS )Zamba2Modelzh
    Model consisting of *config.num_hidden_layers* layers.

    Args:
        config: Zamba2Config
    rO   c                    sR  t |    | _ j| _ j| _t j j| j| _	 fddt
 jD }g }g } j| _t
 jD ]d} j| dkr|t |d qj j| dkrj|tj| jj| jjdd |t |d qjt|}t|}t|}| |||}t|| _ j| _t j jd| _ jr@ jr6td	 t | _d| _ | !  d S )
Nc                    s   g | ]}t  |d qS ))r   )r  )rX   kr$  r9   r:   r[     r\   z(Zamba2Model.__init__.<locals>.<listcomp>r  rt   rS   Fr   r  ze`use_long_context` set to `True`: using rescaled `rope_theta` and extended `max_position_embeddings`.)"r  r-   rO   Zpad_token_idZpadding_idxZ
vocab_sizer   Z	Embeddingr5   embed_tokensrl   r   r]   rm   rq   r  r   iterr   
get_layersr   layersr   rL   r  final_layernormr   Zuse_long_contextr   r   r   
rotary_embgradient_checkpointingZ	post_init)r4   rO   blocksmamba_layerslinear_layersrs   r*  r9   r$  r:   r-     s<    
zZamba2Model.__init__c                 C   s|  g }g | _ d| _t| jD ]Z\}}|dkrh| jdkr>|| _t|}| jjt| jj dkrLd| d}t	
|d d d d	 d
 }	| j |	 d}
| jD ]H}|dkr|
| jj |jkrt	
dt|
 d }| j | |
d7 }
q| jjrLd}
| jD ]L}|dkrB|
| jj |jkrBt	
dt|
 d }| j | |
d7 }
q|t|t|t| q|t| q|S )Nr   rS   r"   z	^layers\.z\.shared_transformer\.z(?:z3self_attn\.(?:q_proj|k_proj|v_proj|o_proj)\.weight|z1feed_forward\.(?:gate_up_proj|down_proj)\.weight|z,(?:input_layernorm|pre_ff_layernorm)\.weightz)$z>^shared_transformer\.feed_forward\.gate_up_proj_adapter_list\.z\.(?:0|1)\.weight$zg^shared_transformer\.self_attn\.(?:linear_q_adapter_list|linear_k_adapter_list|linear_v_adapter_list)\.)Z_tied_weights_keysfirst_transformer_layer_idr   r]   nextrO   r   r   r   recompilerq   r   r   r   r  )r4   r.  r0  r/  r*  Zlayer_idZ
layer_typeblockZprefix_patternZmain_keys_patternZ
adapter_idZ_layer_typeZadapter_patternZattn_adapter_patternr9   r9   r:   r)    sh    






zZamba2Model.get_layersN)	input_idsr   position_idsr   inputs_embedsr  r  output_hidden_statesreturn_dictrv   rw   c                 C   sx  |d ur|n| j j}|d ur |n| j j}|d ur4|n| j j}|	d urH|	n| j j}	|d u |d uA rhtd| jr| jr|rt	d d}|d u r| 
|}|}t|}|r|d u r|d ur|jd n|jd }t| j || j| jd}|
d u r&|d ur|j| jdnd}tj|||jd  |jd}
|d u r:|
d}| |||
}| j jr`| ||}nd }|rnd	nd }|r|d	nd }t| jD ]\}}|r||f7 }| jr| jr| |j|||||||||
}n||||||||||d
	}|d }|r|d d ur||d f7 }q| |}|r4||f7 }|d urL|jsLd|_t||rZ|nd ||d}|	rp|S | S )NzaYou cannot specify both input_ids and inputs_embeds at the same time, and must specify either onezX`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.Fr   )r<   rQ   r&  r"   rU   r9   )r  rt   r   r  r   r  r  r   T)Zlast_hidden_stater   rE   Z
attentions) rO   r  r9  r  Zuse_return_dict
ValueErrorr-  r   r   r   r'  r/   r   rA   rN   r<   rQ   r   r1  r   r   Z_update_causal_maskr   r,  r   r*  Z_gradient_checkpointing_func__call__r+  r^   r
   Zto_tuple)r4   r6  r   r7  r   r8  r  r  r9  r:  rv   rE   r  rP   Zpast_seen_tokensr  r   Zall_hidden_statesZall_self_attnsrt   layerr  r  r9   r9   r:   rG     s    








zZamba2Model.forward)
NNNNNNNNNN)rH   rI   rJ   r   r#   r-   r)  r   r/   r   r   rN   r  r  r   r   r
   rG   r9   r9   r9   r:   r#    s4   $2          
r#  c                   @   s   e Zd ZdS )Zamba2ForCausalLMNrM   r9   r9   r9   r:   r>  v  s   r>  c                   @   s   e Zd ZdS )Zamba2ForSequenceClassificationNrM   r9   r9   r9   r:   r?  z  s   r?  )r>  r?  r#  r  )Or  r3  	itertoolsr   typingr   r   r   r/   Ztorch.utils.checkpointr   Zactivationsr   Zmodeling_flash_attention_utilsr	   Zmodeling_outputsr
   Zmodeling_utilsr   r   Zprocessing_utilsr   utilsr   Zutils.deprecationr   Zutils.import_utilsr   r   Zllama.modeling_llamar   r   Zmamba2.modeling_mamba2r   r   r   Zzamba.modeling_zambar   r   r   r   r   r   r   r   r    r!   Zconfiguration_zamba2r#   Z+mamba_ssm.ops.triton.selective_state_updater$   Z!mamba_ssm.ops.triton.ssd_combinedr%   r&   Zcausal_conv1dr'   r(   r   r   Z_CONFIG_FOR_DOCZ
get_loggerrH   r   Moduler)   rL   rN   r   r   r   r   r  r  r  r  r#  r>  r?  __all__r9   r9   r9   r:   <module>   s^   0

Gn   1*?J V