a
    hL                 	   @   s  d dl Z d dlZd dlmZ d dlmZmZmZmZ d dl	Z	d dl	m
Z
 d dlmZmZmZ ddlmZ ddlmZ dd	lmZ dd
lmZ ddlmZ ddlmZmZmZ ddlmZmZ ddl m!Z!m"Z" ddl#m$Z$ ddl%m&Z&m'Z' ddl(m)Z) ddl*m+Z+m,Z, ddl-m.Z. e, r6d dl/m0Z0 d dl1m2Z2m3Z3 n
d\Z0Z2Z3e+ rZd dl4m5Z5m6Z6 nd\Z6Z5e'7e8Z9G dd de	j
j:Z;G dd de
j:Z<G dd dZ=G d d! d!e
j:Z>e	j?e@e	j?d"d#d$ZAdKe
j:e	j?e	j?e	j?ee	j? eBeBd&d'd(ZCd)d* ZDdLd+d,ZEG d-d. d.e
j:ZFe	j?e@d/d0d1ZGd2d3 ZHd4d5 ZIeJe0e5e6fZKG d6d7 d7e
j:ZLG d8d9 d9e
j:ZMG d:d; d;e
j:ZNG d<d= d=e
j:ZOG d>d? d?e
j:ZPG d@dA dAe"ZQe&G dBdC dCeQZRG dDdE dEeQeZSe&dFdGG dHdI dIeQZTg dJZUdS )M    N)cycle)AnyCallableOptionalUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)Cache)GenerationMixin)AttentionMaskConverter)FlashAttentionKwargs)BaseModelOutputWithPastCausalLMOutputWithPast SequenceClassifierOutputWithPast)ROPE_INIT_FUNCTIONSdynamic_rope_update)ALL_ATTENTION_FUNCTIONSPreTrainedModel)Unpack)auto_docstringlogging)deprecate_kwarg)is_causal_conv1d_availableis_mamba_ssm_available   )Zamba2Config)selective_state_update)mamba_chunk_scan_combined mamba_split_conv1d_scan_combined)NNN)causal_conv1d_fncausal_conv1d_update)NNc                       s(   e Zd Zd fdd	ZdddZ  ZS )	Zamba2RMSNormGatedư>c                    s,   t    tt|| _|| _|| _d S N)	super__init__r   	Parametertorchonesweightvariance_epsilon
group_size)selfhidden_sizer/   eps	__class__ f/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/zamba2/modeling_zamba2.pyr)   =   s    
zZamba2RMSNormGated.__init__Nc           	      C   s   |j }|tj}|d ur2|tj|tj }|j^ }}|| j }|j	g ||| jR  }|
djddd}|t|| j  }|j	g ||| j R  }| j|| S N   T)Zkeepdim)dtypetor+   float32r   
functionalsilushaper/   viewpowmeanrsqrtr.   r-   )	r0   hidden_statesgateinput_dtypeZprefix_dimsZlast_dimZgroup_countZhidden_states_groupvariancer5   r5   r6   forwardC   s    
zZamba2RMSNormGated.forward)r&   )N)__name__
__module____qualname__r)   rH   __classcell__r5   r5   r3   r6   r%   <   s   r%   c                       s.   e Zd Zd fdd	Zdd Zdd Z  ZS )	Zamba2RMSNormr&   c                    s&   t    tt|| _|| _dS )z<
        Zamba2RMSNorm is equivalent to T5LayerNorm
        N)r(   r)   r   r*   r+   r,   r-   r.   )r0   r1   r2   r3   r5   r6   r)   R   s    
zZamba2RMSNorm.__init__c                 C   sJ   |j }|tj}|djddd}|t|| j  }| j|| S r7   )	r:   r;   r+   r<   rA   rB   rC   r.   r-   )r0   rD   rF   rG   r5   r5   r6   rH   Z   s
    zZamba2RMSNorm.forwardc                 C   s   t | jj d| j S )Nz, eps=)tupler-   r?   r.   r0   r5   r5   r6   
extra_repra   s    zZamba2RMSNorm.extra_repr)r&   )rI   rJ   rK   r)   rH   rP   rL   r5   r5   r3   r6   rM   Q   s   rM   c                   @   s   e Zd ZdZdZejdfeeej	e
e dddZdd Zeeejejf d	d
dZdejejee
eeef  eejejf dddZejdddZde
e ed	ddZeejejejdddZdd ZdS )Zamba2HybridDynamicCachea  
    A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
    (which has a constant shape regardless of seq_len).

    This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
    and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor
    For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
    while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
    For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
    while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
    and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
    FN)config
batch_sizer:   devicec              	      s  || _ |j| _d| _t|j|j | _|j| _|j	| _
|j| _g | _i | _i | _i | _i | _i | _t|jD ]n}tj | jd|j |j  | j
|d| j|< tj | j|j| j|d| j|< | j| dkrl| j| ql fddt|jD | _ fddt|jD | _d S )NFr8   rT   r:   hybridc                    s    g | ]}t jg g  d qS rT   r+   Ztensor.0_rS   rT   r5   r6   
<listcomp>       z5Zamba2HybridDynamicCache.__init__.<locals>.<listcomp>c                    s    g | ]}t jg g  d qS rW   rY   rZ   r]   r5   r6   r^      r_   )r:   layers_block_typehas_previous_stateintmamba_expandr1   intermediate_sizemamba_d_statessm_state_sizemamba_d_convconv_kernel_sizen_mamba_headstransformer_layersZ_modules_parameters_buffersconv_states
ssm_statesrangenum_hidden_layersr+   zerosmamba_ngroupsmamba_headdimappend	key_cachevalue_cache)r0   rR   rS   r:   rT   ir5   r]   r6   r)   u   s8    z!Zamba2HybridDynamicCache.__init__c                 C   s
   t | jS r'   )lenru   rO   r5   r5   r6   __len__   s    z Zamba2HybridDynamicCache.__len__)	layer_idxreturnc                 C   s   | j | | j| fS r'   )ru   rv   r0   rz   r5   r5   r6   __getitem__   s    z$Zamba2HybridDynamicCache.__getitem__)
key_statesvalue_statesrz   cache_kwargsr{   c                 C   sz   | j | jd dkr*|| j |< || j|< n<tj| j | |gdd| j |< tj| j| |gdd| j|< | j | | j| fS )Nr9   r   r8   dim)ru   r?   rv   r+   cat)r0   r~   r   rz   r   r5   r5   r6   update   s    
zZamba2HybridDynamicCache.update)beam_idxc                 C   s   t t| jD ]}| j| j}| j| d||| j|< | j| j}| j| d||| j|< | j| j}| j| d||| j|< | j| j}| j| d||| j|< qdS )zDReorders the cache for beam search, given the selected beam indices.r   N)	ro   rx   ru   rT   Zindex_selectr;   rv   rm   rn   )r0   r   rz   rT   r5   r5   r6   reorder_cache   s    z&Zamba2HybridDynamicCache.reorder_cacher   c                 C   sL   || j vr| j d n|}t| j|ks8| j|  dkr<dS | j| jd S )zYReturns the sequence length of the cached states. A layer index can be optionally passed.r   )rj   rx   ru   Znumelr?   r|   r5   r5   r6   get_seq_length   s     z'Zamba2HybridDynamicCache.get_seq_length)rz   new_conv_statecache_positionr{   c                 C   sr   | j | }|d| jd }|jddd}||j|d d d d |f< | j |   | j |  |7  < | j | S )Nr   r   r9   Zshiftsdims)rm   clamprh   rollr;   rT   zero_)r0   rz   r   r   
conv_stater5   r5   r6   update_conv_state   s    
z*Zamba2HybridDynamicCache.update_conv_statec                 C   s   | j   | j  d S r'   )rm   r   rn   rO   r5   r5   r6   reset   s    
zZamba2HybridDynamicCache.reset)N)r   )rI   rJ   rK   __doc__Zis_compileabler+   Zfloat16r   rb   r:   r   strr)   ry   rN   Tensorr}   dictr   r   
LongTensorr   r   r   r   r5   r5   r5   r6   rQ   e   s,     	
rQ   c                       sD   e Zd ZU ejed< ded fddZe e	dd Z
  ZS )	Zamba2RotaryEmbeddinginv_freqNrR   c                    s   t    t|dr:t|jtr:|jd|jd| _nd| _|j| _	|j| _
|| _t| j | _| | j|\}| _| jd|dd | j| _d S )Nrope_scaling	rope_typetypedefaultr   F)
persistent)r(   r)   hasattr
isinstancer   r   getr   max_position_embeddingsZmax_seq_len_cachedZoriginal_max_seq_lenrR   r   Zrope_init_fnattention_scalingZregister_bufferr   Zoriginal_inv_freq)r0   rR   rT   r   r3   r5   r6   r)      s    
zZamba2RotaryEmbedding.__init__c           
      C   s   | j d d d d f  |jd dd|j}|d d d d d f  }t|jjtrl|jjdkrl|jjnd}t	j
|ddV | |  dd}t	j||fdd	}| | j }| | j }	W d    n1 s0    Y  |j|jd
|	j|jd
fS )Nr   r9   r   ZmpscpuF)device_typeZenabledr8   r   r:   )r   floatexpandr?   r;   rT   r   r   r   r+   Zautocast	transposer   cosr   sinr:   )
r0   xposition_idsZinv_freq_expandedZposition_ids_expandedr   ZfreqsZembr   r   r5   r5   r6   rH      s    0&,zZamba2RotaryEmbedding.forward)N)rI   rJ   rK   r+   r   __annotations__r   r)   Zno_gradr   rH   rL   r5   r5   r3   r6   r      s
   

r   )rD   n_repr{   c                 C   s^   | j \}}}}|dkr| S | dddddddddf |||||} | ||| ||S )z
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    r   N)r?   r   reshape)rD   r   batchnum_key_value_headsslenhead_dimr5   r5   r6   	repeat_kv   s
    0r           )modulequerykeyvalueattention_maskscalingdropoutc                 K   s   t || j}t || j}	t||dd| }
|d urf|d d d d d d d |jd f }|
| }
tjj|
dtj	d
|j}
tjj|
|| jd}
t|
|	}|dd }||
fS )Nr8   r   r   r9   )r   r:   )ptrainingr   )r   num_key_value_groupsr+   matmulr   r?   r   r=   Zsoftmaxr<   r;   r:   r   r   
contiguous)r   r   r   r   r   r   r   kwargsr~   r   attn_weightscausal_maskattn_outputr5   r5   r6   eager_attention_forward  s    
&r   c                 C   sH   | dd| j d d f }| d| j d d df }tj| |fddS )z*Rotates half the hidden dims of the input..Nr9   r8   r   )r?   r+   r   )r   x1Zx2r5   r5   r6   rotate_half  s    r   c                 C   sD   | |}| |}| | t| |  }|| t||  }||fS )a  Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    )	unsqueezer   )qkr   r   r   Zunsqueeze_dimZq_embedZk_embedr5   r5   r6   apply_rotary_pos_emb#  s
    

r   c                       s   e Zd ZdZdeee ee ee d fddZedddd	de	j
eee	j
 ee eee	j
e	j
f  ee ee	j
ee	j
 eee	j
  f d
ddZ  ZS )Zamba2AttentionaZ  
    Multi-headed attention from 'Attention Is All You Need' paper.

    Adapted from transformers.models.mistral.modeling_mistral.MistralAttention:
    The input dimension here is attention_hidden_size = 2 * hidden_size, and head_dim = attention_hidden_size // num_heads.
    The extra factor of 2 comes from the input being the concatenation of original_hidden_states with the output of the previous (mamba) layer
    (see fig. 2 in https://huggingface.co/papers/2405.16712).
    Additionally, replaced
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) with
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2)
    Finally, this attention layer contributes to tied transformer blocks aimed to increasing compute without increasing model size. Because this
    layer is tied, un-tied adapters (formally the same as LoRA but used in the base model) modules are added to the q, k, v projectors to increase
    expressivity with a small memory overhead (see Fig. 2 of https://huggingface.co/papers/2411.15242).
    N)rR   rz   num_fwd_mem_blocksblock_idc           	   	      s  t    || _|| _|j| _|j| _|j|j | _	|j
| _
| jd d | _d| _|j| _tj|j|j| j dd| _tj|j|j| j dd| _tj|j|j| j dd| _tj|j| j |jdd| _|| _|j| _|| _|jrtg | _tg | _tg | _t| jD ]}||j |krt tj| j| jj!ddtj| jj!| jdd}t tj| j| jj!ddtj| jj!| jdd}t tj| j| jj!ddtj| jj!| jdd}nt" }t" }t" }| j#| | j#| | j#| qdd t$| jD | _%d S )Nr8   g      TFbiasc                 S   s   i | ]\}}||qS r5   r5   r[   indexr   r5   r5   r6   
<dictcomp>  r_   z,Zamba2Attention.__init__.<locals>.<dictcomp>)&r(   r)   rR   rz   attention_hidden_sizeZattention_head_dimr   Znum_attention_headsr   r   r   r   Z	is_causalattention_dropoutr   Linearq_projk_projv_projr1   o_projr   hybrid_layer_idslayer_block_mapr   use_shared_attention_adapter
ModuleListlinear_q_adapter_listlinear_k_adapter_listlinear_v_adapter_listro   num_mem_blocks
Sequentialadapter_rankIdentityrt   	enumerate	layer_dic)	r0   rR   rz   r   r   rw   Zlinear_q_adapterZlinear_k_adapterZlinear_v_adapterr3   r5   r6   r)   N  sT    
zZamba2Attention.__init__past_key_valuepast_key_values4.58new_nameversion)rD   rz   r   r   position_embeddingsr   r{   c                 K   st  |j d d }g |d| jR }| |}	| |}
| |}| jjr| j| }|	| j| | }	|
| j	| | }
|| j
| | }|	|dd}	|
|dd}
||dd}| jjr|\}}t|	|
||\}	}
|d ur||
||\}
}t}| jjdkrt| jj }|| |	|
||f| js2dn| j| jd|\}}|jg |dR   }| |}||fS )Nr9   r   r8   eagerr   )r   r   )r?   r   r   r   r   rR   r   r   r   r   r   r@   r   use_mem_roper   r   r   _attn_implementationr   r   r   r   r   r   r   )r0   rD   rz   r   r   r   r   Zinput_shapeZhidden_shapeZquery_statesr~   r   Zadapter_layer_idxr   r   Zattention_interfacer   r   r5   r5   r6   rH     sH    






zZamba2Attention.forward)NNN)NNN)rI   rJ   rK   r   r   r   rb   r)   r   r+   r   rQ   rN   r   r   rH   rL   r5   r5   r3   r6   r   >  s.      8   r   )input_tensorpad_sizec                 C   sH   t | jdkr"ddddd|ddfnddd|ddf}tjjj| |dddS )z
    Padding x tensor with `pad_size` on the seq_len dim (dim=1)

    Assumes that we only have tensors of either size 4 or 3
       r   Zconstant)moder   )rx   r?   r+   r   r=   pad)r   r   Z	pad_shaper5   r5   r6   pad_tensor_by_size  s    2r   c                 C   s\   t | |} t| jdkr4| | jd d|| jd S | | jd d|| jd | jd S dS )z
    Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and
    simultaneously splitting it into chunk sequences.

    Assumes that we only have tensors of either size 4 or 3
    r   r   r9   r8   N)r   rx   r?   r   )r   r   
chunk_sizer5   r5   r6   reshape_into_chunks  s    
r   c                 C   s   |  d}| d jg |   |R  } tjtj||| jtjddd}| | d} tj| dd}tjtj||| jtjddd}|| tj	 }|S )zo
    More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions.
    r9   .NrU   Zdiagonalr   r   r   )
sizer   r+   Ztrilr,   rT   boolmasked_fillcumsuminf)r   r   maskZtensor_segsumr5   r5   r6   segment_sum  s    
  r  c                       s   e Zd ZdZdeee d fddZdej	ee
 eej	 dddZdee
 eej	 d	d
dZdee
 eej	 d	ddZ  ZS )Zamba2MambaMixeru  
    Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
    A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
    ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
    and is why Mamba is called **selective** state spaces)
    NrR   rz   c                    s  t    || _|j| _|j| _|j| _t|j	| j | _
|| _|j| _d| _t | _|j| _|j| _|j| _| jj| _|j| _|j| _|j| _|j| _| j
d| j | j  | _tj| j| jd|j| j|jd d| _| j
| j | j }tj| j||j d| _!t"t#$| j| _%t#&d| jd }t"t#'|| _(d| j(_)t*| j
| j
| j dd| _+t"t#$| j| _,d| j,_)tj| j
| j|j d| _-t.st/0d	 d S )
Nr>   r8   Tr   )Zin_channelsZout_channelsr   Zkernel_sizegroupspaddingr   gh㈵>)r/   r2   a  The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)` is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d)1r(   r)   rR   r1   re   rf   rg   rh   rb   rc   rd   rz   use_conv_bias
activationr   ZSiLUactuse_mem_eff_pathrr   n_groupsrs   r   ri   	num_headsr   time_step_limittime_step_mintime_step_maxconv_dimZConv1dconv1dr   add_bias_linearin_projr*   r+   r,   dt_biasarangelogA_logZ_no_weight_decayr%   normDout_projis_fast_path_availableloggerwarning_once)r0   rR   rz   Zprojection_sizeAr3   r5   r6   r)     s^    



	zZamba2MambaMixer.__init__rD   cache_paramsr   c                 C   sF  |j \}}}| j| j }d| j d| j | j  | j }|d ur6|jr6| |d}	|	j d | d }
|
|
| j| j| jg}t	j
|	|dd\}}}}}t||j| j | jjd| jj| j}t	j
|| j||gdd\}}}t	| j  }|d d d df d d d d d f d| j| jjt	jd}|d d d d d f dd| j}| jd d d df d| j}| jd d d df d| j}||| j|j d | j }||| j|j d | j }||| j| j}t|j| j ||||||d |dd
}||| j| j }| ||}| |d d d df }n|d urvt	 |dksv|j!}||d d d d d f  |}| |}t	| j  }| j"d u ri nd	| j"i}|d urt	 |dk}nd}| j#rN| j$rN|d u rN|rNt%|| jjd| jj| j|f| j| j&d | j| jj| jj'| jj| jj| j| jd
dd|\}}nt	j
|| j| j| jgdd\}}}|d ur|(dd}t)j*+|| j,|j d  df}|j| j -| t.d u s| jdvr| /| |(dd(ddd d d |f }n@t.|(dd| jjd| jj| jd(ddd d d |f }t	j
|| j||gdd\}}}|d urt	 |dks|j!}||d d d d d f  |}t0|||d| j|||||| jd|||| jdf| j&| jd d d| jdd|\}}|d ur|d ur|j| j -| |||d}| ||}| |}|S )Nr8   r   r9   r   .r   T)zr  dt_softplusZdt_limitF)r  r   seq_idxr  Zrmsnorm_weightZrmsnorm_epsZoutproj_weightZoutproj_biasZheaddimZngroupsZnorm_before_gatereturn_final_statesr   )r>   Zswish)r   r-   r   r  )r   r  r$  r&  r'  r  r%  )1r?   r  rf   rd   r  ra   r  squeezer  r+   splitr$   rm   rz   r  r-   r   r  expr  r   r   r   r;   r<   r  r  r@   r    rn   r  r  allr:   r  r  r   r"   r   r.   r   r   r=   r   rh   copy_r#   r  r!   )r0   rD   r#  r   rS   seq_lenr\   Zgroups_time_state_sizeZd_to_removeZin_projected_statesd_mlpZsplit_projection_dimrE   Zhidden_states_B_CdtBCr!  r  r  Zhidden_states_reshapedoutr:   projected_statesZdt_limit_kwargsZinput_not_masked	ssm_stateZ	time_stepZhidden_states_B_C_tr   scan_outputr5   r5   r6   cuda_kernels_forward=  s    

<"
 

 
(

 

z%Zamba2MambaMixer.cuda_kernels_forward)r#  r   c           1   
      s	  |j \}}}|j}|d ur2|jr2|d}n@|d urht|dksh||d d d d d f  |}|}|j d dj  dj	 j
  j d }	|j|	|	jjjgdd\}}}
}}|d urv|jj  }||j}|jr|
d}
|jj }tj|ddd}|jdkr>|d d dd d f n||d d d d df< |jj | tj||jjjd d dd d f  dd}jr|jj7 }||d d d df }n|dd}tj |j!|j d  df}|jj | |ddd d d |d d f }|d urt|dks|j}||d d d d d f  |}nLtj"|jj#j
f|j|d	}|dddd |f dd}tj|jj	j
 j	j
 gdd\}}}t$j%&  }|d ur|jr|jdkr2|d d d df n"|d d dd d f d d d df }|dd'||j d j#}j(d
 'j(j d j#}tjj)|||j }t*|j+}|d 'jj#j
jtj,d}t$|d
 | }|-|j	ddd d d f }|'|j	jj	 |j d . }|-|d|j d }|d
 |dd d d f  }|-|dj#}||d
  }|jj |jj | |  |-|j	ddd d d f }|'|j	jj	 |j d . }|-|d|j d }|jj |j}|/|j j#j
}|/|j j
d}t0||}|/|jj#}j1d
 'j1j d j#}|||  |j}|-|dd d d df }n8tj)|j( }t*|j+}|-||dj#& }|-||dj
& }|-||dj
& }|j2jj	 djd}|j2jj	 djd}j3|j3  j3  j1d
 t4|  }||d
  }||j| } fdd||||fD \}}}}|5dddd}tj6|dd}t$t7|}|d d d d d d d d d d d f |d d d d d d d d d d d f  }|jdd}|d
 |5dddddd
  } | jdd}!|!d
 |d d d d d f  d}"t$|d d d d d d dd f | }#||#5ddddd
  }$|$5dddddd
 |5ddddddd d d f  jdd5ddddd}%|d ur0|jr0|jj d d d df }&nt8|%d d d df }&tj9|&|%gdd}%t$t7tj |d d d d d d df d}'|%5ddddd}(|'d |(d d d d d df  jdd})|)5ddddd}*|*d d d df |*d d df  }%}t$|}+|dd d d f |%d d d d d df  },|+5dddd}-|,d|-d
  }.|"|. }|-|djj#}|| } dk	r|d d d |d d d d f }|-||d}|d u	r|d u	r|jj | :||
}/;|/|}0|0S )Nr   r9   r8   r   r   r   r   .rU   r   ).NNr   )r   Zoutput_sizec                    s   g | ]}t | jqS r5   )r   r   )r[   tr   r0   r5   r6   r^   O  r_   z2Zamba2MambaMixer.torch_forward.<locals>.<listcomp>r   )r   r   )<r?   r:   ra   r  r(  r+   r+  r;   rd   r  rf   r  r)  r  rn   rz   clonerT   r   rm   r   ndimr,  sumr  r-   r
  r   r  r   r   r=   r   rh   rq   r   r*  r  r   r   r  Zsoftplusr   r  r<   r   r   r@   Zbmmr  Zrepeat_interleaver   r   Zpermuter  r  Z
zeros_liker   r  r  )1r0   Zinput_statesr#  r   rS   r-  r\   r:   r3  r.  rE   rD   r/  r4  r   r0  r1  r!  r  ZdAZdBZdBxrn   Zssm_states_reshapedZ
C_reshapedyr  Z
D_residualZA_cumsumLZG_intermediateGZM_intermediateMZY_diagZdecay_statesZB_decay_contractionZstatesZprevious_statesZdecay_chunkZstates_permutedresultZ
new_statesZstate_decay_outZC_times_statesZstate_decay_out_permutedZY_offr5  Zcontextualized_statesr5   r8  r6   torch_forward  s     
.

80 .",.B"$$$P$*L0(&
*
 zZamba2MambaMixer.torch_forwardc                 C   s0   t r"d| jjjjv r"| |||S | |||S )Ncuda)r  r  r-   rT   r   r6  rA  )r0   rD   r#  r   r5   r5   r6   rH     s    zZamba2MambaMixer.forward)N)NN)NN)NN)rI   rJ   rK   r   r   r   rb   r)   r+   r   rQ   r6  rA  rH   rL   r5   r5   r3   r6   r    s$   D    F  r  c                       s4   e Zd Zdeee d fddZdddZ  ZS )		Zamba2MLPN)rR   r   c              	      s   t    || _|j| _|j| _|| _|| _tj| jd| j |j	d| _
tj| j| j|j	d| _t|j | _tg | _t| jD ]^}||j |krttj| jj| jjddtj| jjd| j dd}nt }| j| q|j}dd t|D | _dS )aQ  
        This MLP layer contributes to tied transformer blocks aimed to increasing compute without increasing model size. Because this layer
        is tied, un-tied adapter modules (formally same as LoRA, but used in the base model) are added to the up and gate projectors to increase expressivity with a small memory overhead.
        r8   r   Fc                 S   s   i | ]\}}||qS r5   r5   r   r5   r5   r6   r     r_   z&Zamba2MLP.__init__.<locals>.<dictcomp>N)r(   r)   rR   r1   rd   r   r   r   r   r  gate_up_proj	down_projr   Z
hidden_actact_fnr   gate_up_proj_adapter_listro   r   r   r   r   rt   r   r   r   )r0   rR   r   r   rw   Zgate_up_proj_adapterr   r3   r5   r6   r)     s(    
zZamba2MLP.__init__c                 C   sZ   |  |}| j| }|| j| | }tj|ddd}| |d |d  }| |}|S )Nr8   r9   r   r   r   )rD  r   rG  r+   chunkrF  rE  )r0   Zhidden_staterz   Zgate_up_stateoutputr5   r5   r6   rH     s    


zZamba2MLP.forward)NN)N)	rI   rJ   rK   r   r   rb   r)   rH   rL   r5   r5   r3   r6   rC    s   rC  c                       s   e Zd Zdeee ee d fddZedddddej	ej	eeej	 ee
 ee eej ee eejeeejejf  f d
	ddZ  ZS )Zamba2AttentionDecoderLayerN)rR   r   rz   c                    sd   t    || _t|j}t|d||d| _t|||d| _t	|j
|jd| _t	|j|jd| _d S )Nr9   )rz   r   r   )r   r   r2   )r(   r)   r   rx   r   r   	self_attnrC  feed_forwardrM   r   rms_norm_epsinput_layernormr1   pre_ff_layernorm)r0   rR   r   rz   Znum_gsr3   r5   r6   r)     s    

z$Zamba2AttentionDecoderLayer.__init__r   r   r   r   F)	rD   original_hidden_statesrz   r   r   output_attentionsr   r   r{   c              	   K   sl   t j||gdd}| |}| jf ||||||d|\}}	| |}| ||}|f}
|rh|
|	f7 }
|
S )a  
        Args:
            hidden_states (`torch.FloatTensor`): output of previous Mamba layer of shape `(batch, seq_len, embed_dim)`
            original_hidden_states (`torch.FloatTensor`): word embedding output of shape `(batch, seq_len, embed_dim)`.
                This is concatenated with `hidden_states` (which is the output of the previous (mamba) layer). The
                concatenated tensor is then used as input of the pre-attention RMSNorm
                (see fig. 2 in https://huggingface.co/papers/2405.16712).
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            past_key_values (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                with `head_dim` being the embedding dimension of each attention head.
        r9   r   )rD   rz   r   r   rR  r   )r+   ZconcatenaterO  rL  rP  rM  )r0   rD   rQ  rz   r   r   rR  r   r   self_attn_weightsoutputsr5   r5   r6   rH     s$     




z#Zamba2AttentionDecoderLayer.forward)NN)NNFN)rI   rJ   rK   r   r   rb   r)   r   r+   r   rQ   r   r   r   r   rN   FloatTensorrH   rL   r5   r5   r3   r6   rJ    s"    	    rJ  c                       s   e Zd Zeed fddZedddddeje	ej e	e e	ej e	ej e	e
 e	e e	e e	ej e	ej eeje	eejejf  f d
ddZ  ZS )Zamba2MambaDecoderLayerr  c                    s4   t    t||d| _t|j|jd| _|| _d S )Nr  rK  )	r(   r)   r  mambarM   r1   rN  rO  rz   )r0   rR   rz   r3   r5   r6   r)     s    
z Zamba2MambaDecoderLayer.__init__r   r   r   r   NF)rD   rQ  rz   r   r   r   rR  	use_cacher   transformer_hidden_statesr{   c                 K   sd   |}|
dur||
 n|}|  |}| j|||d}d}|| }|f}|rR||f7 }|r`||f7 }|S )a  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            past_key_values (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence.
        Nr"  )rO  rW  )r0   rD   rQ  rz   r   r   r   rR  rX  r   rY  r   ZresidualrS  rT  r5   r5   r6   rH     s"    


zZamba2MambaDecoderLayer.forward)	NNNNNFFNN)rI   rJ   rK   r   rb   r)   r   r+   r   r   rQ   r   r   rN   rU  rH   rL   r5   r5   r3   r6   rV    s0            rV  c                       s   e Zd Zeejed fddZeddddde	j
ee	j
 ee ee	j
 ee	j
 ee ee ee ee	j ee	jeee	je	jf  f d

ddZ  ZS )Zamba2HybridLayer)shared_transformerlinearrW  c                    s    t    || _|| _|| _d S r'   )r(   r)   r\  mamba_decoderr[  )r0   r[  r\  rW  r3   r5   r6   r)   S  s    
zZamba2HybridLayer.__init__r   r   r   r   NF)
rD   rQ  rz   r   r   r   rR  rX  r   r{   c
              	   C   sn   | j |||||||	d}
|
d }|r,|
d }| |}| j|||||||	d}
|rj|
d |f|
dd  }
|
S )aY  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            original_hidden_states (`torch.FloatTensor`): word embedding output that will be concatenated with
            hidden activations to form the input of the shared transformer layer.
            layer_idx (`int`): layer number.
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            past_key_values (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                with `head_dim` being the embedding dimension of each attention head.
        )rQ  rz   r   r   rR  r   r   r   )rY  r   r   rR  rX  r   r8   N)r[  r\  r]  )r0   rD   rQ  rz   r   r   r   rR  rX  r   layer_outputsrY  rS  r5   r5   r6   rH   [  s2    !


zZamba2HybridLayer.forward)NNNNNFFN)rI   rJ   rK   rJ  r   r   rV  r)   r   r+   r   r   rb   rQ   r   r   rN   rU  rH   rL   r5   r5   r3   r6   rZ  R  s.           rZ  c                       sJ   e Zd ZU eed< dZdZddgZdZdZ	dZ
dZdZ fddZ  ZS )	Zamba2PreTrainedModelrR   modelTrJ  rV  r   c                    s   t  | t|trtt| jjt	
| jjt	
| jj  t	
| jj j| jjd}|t
t|   }|jj| td|jd }|jjt
| |jjd d S )N)minr   g      ?)r(   _init_weightsr   r  r+   r*  ZrandrR   ri   mathr  r  r  r   Ztime_step_floorexpm1r  datar,  r  r  r  r  Zfill_)r0   r   r/  Zinv_dtr!  r3   r5   r6   rb    s     
z#Zamba2PreTrainedModel._init_weights)rI   rJ   rK   r   r   Zbase_model_prefixZsupports_gradient_checkpointingZ_no_split_modulesZ_skip_keys_device_placementZ_supports_flash_attnZ_supports_flex_attnZ_supports_sdpaZ_is_statefulrb  rL   r5   r5   r3   r6   r_    s   
r_  c                       s   e Zd ZdZed fddZedeej	 eej
 eej	 ee eej ee ee ee ee eej	 eeef dddZd	d
 Zdd Z  ZS )Zamba2Modelzh
    Model consisting of *config.num_hidden_layers* layers.

    Args:
        config: Zamba2Config
    r   c                    sR  t     | _ j| _ j| _t j j| j| _	 fddt
 jD }g }g } j| _t
 jD ]d} j| dkr|t |d qj j| dkrj|tj| jj| jjdd |t |d qjt|}t|}t|}| |||}t|| _ j| _t j jd| _ jr@ jr6td	 t | _d| _ | !  d S )
Nc                    s   g | ]}t  |d qS ))r   )rJ  )r[   r   r   r5   r6   r^     r_   z(Zamba2Model.__init__.<locals>.<listcomp>rW  rz   rV   Fr   rK  ze`use_long_context` set to `True`: using rescaled `rope_theta` and extended `max_position_embeddings`.)"r(   r)   rR   pad_token_idZpadding_idx
vocab_sizer   Z	Embeddingr1   embed_tokensro   r   r`   rp   rt   rV  r   iterr   
get_layersr   layersr   rM   rN  final_layernormr   Zuse_long_contextr  r   r   
rotary_embgradient_checkpointing	post_init)r0   rR   blocksmamba_layerslinear_layersrw   rm  r3   r   r6   r)     s<    
zZamba2Model.__init__N)	input_idsr   r   r   inputs_embedsrX  rR  output_hidden_statesreturn_dictr   r{   c                 C   sx  |d ur|n| j j}|d ur |n| j j}|d ur4|n| j j}|	d urH|	n| j j}	|d u |d uA rhtd| jr| jr|rt	d d}|d u r| 
|}|}t|}|r|d u r|d ur|jd n|jd }t| j || j| jd}|
d u r&|d ur|j| jdnd}tj|||jd  |jd}
|d u r:|
d}| |||
}| j jr`| ||}nd }|rnd	nd }|r|d	nd }t| jD ]\}}|r||f7 }| jr| jr| |j|||||||||
}n||||||||||d
	}|d }|r|d d ur||d f7 }q| |}|r4||f7 }|d urL|jsLd|_t||rZ|nd ||d}|	rp|S | S )NzaYou cannot specify both input_ids and inputs_embeds at the same time, and must specify either onezX`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.Fr   r:   rT   rg  r   rX   r5   )rQ  rz   r   r   r   rR  rX  r   T)Zlast_hidden_stater   rD   
attentions) rR   rR  rw  rX  use_return_dict
ValueErrorrp  r   r  r   rj  r+   r9  r?   rQ   r:   rT   r   first_transformer_layer_idr  r   _update_causal_maskr   ro  r   rm  Z_gradient_checkpointing_func__call__rn  ra   r   Zto_tuple)r0   ru  r   r   r   rv  rX  rR  rw  rx  r   rD   rQ  rS   Zpast_seen_tokensr   r   Zall_hidden_statesZall_self_attnsrz   layerr^  rI  r5   r5   r6   rH     s    








zZamba2Model.forwardc                 C   s  | j jdkr$|d ur d|v r |S d S |j|j }}t|j}|jd }|d d }tj||f|||d}	|dkrtj	|	dd}	|	tj
||d|ddk9 }	|	d d d d d d f |jd ddd}	|d urJ|	 }	| d	krJ|jd }
|	d
d |
f d|d d d d d d f d }|	d
d |
f |||	d
d |
f< | j jdkr||d ur||jjdv r|t|	|}	|	S )NZflash_attention_2r   r   r9   )Z
fill_valuer:   rT   r   rX   r   r8   .Zsdpa)rB  ZxpuZnpu)rR   r   r:   rT   r+   Zfinfora  r?   fullZtriur  r   r   r9  r   eqr  r   r   Z_unmask_unattended)r0   r   r   r   r:   rT   Z	min_dtypeZsequence_lengthZtarget_lengthr   Zmask_lengthZpadding_maskr5   r5   r6   r~  `  s6    
*

4$

zZamba2Model._update_causal_maskc                 C   s|  g }g | _ d| _t| jD ]Z\}}|dkrh| jdkr>|| _t|}| jjt| jj dkrLd| d}t	
|d d d d	 d
 }	| j |	 d}
| jD ]H}|dkr|
| jj |jkrt	
dt|
 d }| j | |
d7 }
q| jjrLd}
| jD ]L}|dkrB|
| jj |jkrBt	
dt|
 d }| j | |
d7 }
q|t|t|t| q|t| q|S )Nr   rV   r   z	^layers\.z\.shared_transformer\.z(?:z3self_attn\.(?:q_proj|k_proj|v_proj|o_proj)\.weight|z1feed_forward\.(?:gate_up_proj|down_proj)\.weight|z,(?:input_layernorm|pre_ff_layernorm)\.weightz)$z>^shared_transformer\.feed_forward\.gate_up_proj_adapter_list\.z\.(?:0|1)\.weight$zg^shared_transformer\.self_attn\.(?:linear_q_adapter_list|linear_k_adapter_list|linear_v_adapter_list)\.)_tied_weights_keysr}  r   r`   nextrR   r   rx   r   recompilert   r   r   r   rZ  )r0   rr  rt  rs  rm  Zlayer_idZ
layer_typeblockZprefix_patternZmain_keys_patternZ
adapter_idZ_layer_typeZadapter_patternZattn_adapter_patternr5   r5   r6   rl    sh    






zZamba2Model.get_layers)
NNNNNNNNNN)rI   rJ   rK   r   r   r)   r   r   r+   r   r   rQ   rU  r   r   rN   r   rH   r~  rl  rL   r5   r5   r3   r6   rf    s8   $          
x#rf  c                       s   e Zd Zed fddZedeej eej	 eej ee
 eej eej ee ee ee ee eej eeej	f eeef dddZdd
dZ  ZS )Zamba2ForCausalLMr   c                    sP   t  | t|| _dg| jj| _|j| _tj|j|jdd| _	| 
  d S )Nzlm_head.weightFr   )r(   r)   rf  r`  r  ri  r   r   r1   lm_headrq  r0   rR   r3   r5   r6   r)     s    
zZamba2ForCausalLM.__init__Nr   )ru  r   r   r   rv  labelsrX  rR  rw  rx  r   logits_to_keepr{   c                 K   s   |dur|n| j j}|	dur |	n| j j}	|
dur4|
n| j j}
| j||||||||	||
d
}|d }t|trxt| dn|}| |dd|ddf }d}|dur| j	||| j
fi |}|
s|f|dd  }|dur|f| S |S t|||j|j|jdS )al  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Zamba2ForCausalLM

        >>> model = Zamba2ForCausalLM.from_pretrained("Zyphra/Zamba2-7B-v1")
        >>> tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B-v1")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```N)
ru  r   r   r   rv  rX  rR  rw  r   rx  r   r   losslogitsr   rD   rz  )rR   rR  rw  r{  r`  r   rb   slicer  Zloss_functionri  r   r   rD   rz  )r0   ru  r   r   r   rv  r  rX  rR  rw  rx  r   r  r   rT  rD   Zslice_indicesr  r  rI  r5   r5   r6   rH     s@    (zZamba2ForCausalLM.forwardTc              	   K   s  |d u }	|	sj|d us&|d |j d krD|d d |j d  d f }q|j d |j d kr|d d |f }nt| j|j d | j| jd}|d ur|d u r| dd }||dkd |	s|d d |j d  d f }|d ur|	rd|i}
nd| i}
|
	||||| jj
|d |
S )Nr9   r   r   ry  rv  ru  )r   r   rX  r   r  r   )r?   rQ   rR   r:   rT   longr  Zmasked_fill_r   r   Znum_logits_to_keep)r0   ru  r   r   rv  r   r   rX  r   Zempty_past_kvZmodel_inputsr5   r5   r6   prepare_inputs_for_generation  s<    

z/Zamba2ForCausalLM.prepare_inputs_for_generation)NNNNNNNNNNNr   )NNNNNT)rI   rJ   rK   r   r)   r   r   r+   r   r   rQ   rU  r   r   rb   rN   r   rH   r  rL   r5   r5   r3   r6   r    sH   
            
T      r  a  
    The Zamba2 Model with a sequence classification head on top (linear layer).

    [`Zamba2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    )Zcustom_introc                       s   e Zd Z fddZedeej eej eej ee	e
eej f  eej eej ee ee ee ee e	eef dddZ  ZS )Zamba2ForSequenceClassificationc                    sJ   t  | |j| _t|| _| jj| _tj|j| jdd| _	| 
  d S )NFr   )r(   r)   
num_labelsrf  r`  r  r   r   r1   scorerq  r  r3   r5   r6   r)   ]  s    

z(Zamba2ForSequenceClassification.__init__N)ru  r   r   r   rv  r  rX  rR  rw  rx  r{   c                 C   sV  |
dur|
n| j j}
| j||||||||	|
d	}|d }| |}|durV|jd }n
|jd }| j jdu r||dkr|td| j jdu rd}nb|dur|| j jk|jt	j
}t	j|jd |jt	j
d}|| d}nd}t| jj d |t	j||jd	|f }d}|dur||j}| j jdu r| jdkrFd
| j _n:| jdkrx|jt	jksn|jt	jkrxd| j _nd| j _| j jd
krt }| jdkr|| | }n
|||}nN| j jdkrt }||d| j|d}n| j jdkrt }|||}|
s>|f|dd  }|dur:|f| S |S t|||j|j|jdS )a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        N)r   r   r   rv  rX  rR  rw  rx  r   r   z=Cannot handle batch sizes > 1 if no padding token is defined.r9   rU   z will not detect padding tokens in `inputs_embeds`. Results may be unexpected if using padding tokens in conjunction with `inputs_embeds.`rX   Z
regressionZsingle_label_classificationZmulti_label_classificationr  )rR   r{  r`  r  r?   rh  r|  r;   rT   r+   Zint32r  Zargmaxr  r   r4   rI   Zproblem_typer  r:   r  rb   r
   r(  r	   r@   r   r   r   rD   rz  )r0   ru  r   r   r   rv  r  rX  rR  rw  rx  Ztransformer_outputsrD   r  rS   Zlast_non_pad_tokenZnon_pad_maskZtoken_indicesZpooled_logitsr  Zloss_fctrI  r5   r5   r6   rH   g  sx    



(

z'Zamba2ForSequenceClassification.forward)
NNNNNNNNNN)rI   rJ   rK   r)   r   r   r+   r   r   r   r   listrU  r   rN   r   rH   rL   r5   r5   r3   r6   r  N  s2   
          
r  )r  r  rf  r_  )r   )Nr   )Vrc  r  	itertoolsr   typingr   r   r   r   r+   r   Ztorch.nnr   r	   r
   Zactivationsr   Zcache_utilsr   Z
generationr   Zmodeling_attn_mask_utilsr   Zmodeling_flash_attention_utilsr   Zmodeling_outputsr   r   r   Zmodeling_rope_utilsr   r   Zmodeling_utilsr   r   Zprocessing_utilsr   utilsr   r   Zutils.deprecationr   Zutils.import_utilsr   r   Zconfiguration_zamba2r   Z+mamba_ssm.ops.triton.selective_state_updater    Z!mamba_ssm.ops.triton.ssd_combinedr!   r"   Zcausal_conv1dr#   r$   Z
get_loggerrI   r  Moduler%   rM   rQ   r   r   rb   r   r   r   r   r   r   r   r   r  r+  r  r  rC  rJ  rV  rZ  r_  rf  r  r  __all__r5   r5   r5   r6   <module>   s   

m$ 
    1*AEK { j