a
    hk                     @   s  d Z ddlZddlmZ ddlmZmZ ddlZddlZddlm	Z	 ddl
mZ ddlmZ dd	lmZ dd
lmZ ddlmZmZmZ ddlmZmZ ddlmZ eeZe rddlmZ ddl m!Z!m"Z" n
d\Z!Z"Ze rddl#m$Z$m%Z% nd\Z%Z$e&ee!e"e$e%fZ'ej(e)dddZ*dd Z+dd Z,dd Z-G dd dZ.G dd  d ej	j/Z0G d!d" d"e	j/Z1G d#d$ d$e	j/Z2G d%d& d&eZ3eG d'd( d(eZ4eed)d*G d+d, d,eZ5eed-d*G d.d/ d/eZ6eG d0d1 d1e4Z7ed2d*G d3d4 d4e4eZ8g d5Z9dS )6zPyTorch MAMBA2 model.    N)	dataclass)OptionalUnion)nn   )ACT2FN)GenerationMixin)GradientCheckpointingLayer)PreTrainedModel)ModelOutputauto_docstringlogging)is_causal_conv1d_availableis_mamba_2_ssm_available   )Mamba2Config)selective_state_update)mamba_chunk_scan_combined mamba_split_conv1d_scan_combined)NNN)causal_conv1d_fncausal_conv1d_update)NN)input_tensorpad_sizec                 C   sH   t | jdkr"ddddd|ddfnddd|ddf}tjjj| |dddS )z
    Padding x tensor with `pad_size` on the seq_len dim (dim=1)

    Assumes that we only have tensors of either size 4 or 3
       r   Zconstant)modevalue)lenshapetorchr   
functionalpad)r   r   Z	pad_shape r!   f/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/mamba2/modeling_mamba2.pypad_tensor_by_sizeB   s    2r#   c                 C   s\   t | |} t| jdkr4| | jd d|| jd S | | jd d|| jd | jd S dS )z
    Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and
    simultaneously splitting it into chunk sequences.

    Assumes that we only have tensors of either size 4 or 3
    r   r      N)r#   r   r   reshape)r   r   
chunk_sizer!   r!   r"   reshape_into_chunksM   s    
r(   c                 C   s   |  d}| d jg |   |R  } tjtj||| jtjddd}| | d} tj| dd}tjtj||| jtjddd}|| tj	 }|S )zo
    More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions.
    r$   .Ndevicedtype)Zdiagonalr   dim)
sizeexpandr   Ztrilonesr+   boolZmasked_fillcumsuminf)r   r'   maskZtensor_segsumr!   r!   r"   segment_suma   s    
  r7   c                 C   sN   |durJ|j d dkrJ|j d dkrJ| j}| |dddddf  |} | S )zm
    Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
    Nr   r   )r   r,   to)hidden_statesattention_maskr,   r!   r!   r"   apply_mask_to_padding_statesu   s    $ r;   c                   @   sf   e Zd ZdZejdfeeeje	e
 dddZdeejeejddd	Zeejd
ddZdd ZdS )Mamba2Cachea  
    Arguments:
        config: Mamba2Config
        batch_size: int
        dtype: torch.dtype
        device: torch.device

    Attributes:
        dtype: (`torch.dtype`):
            The default `dtype` used to initializing the cache.
        conv_kernel_size: (`int`):
            Model's convolution kernel size taken from config.
        n_groups: (`int`):
            Model's number of groups taken from the config - similar to tensor parallel in Transformer.
        state_size: (`int`):
            Model's SSM state size taken from config.
        num_heads: (`int`):
            The number of heads used in the linear attention / SSM.
        head_dim: (`int`):
            The respective dimension of the heads used in the linear attention / SSM.
        intermediate_size: (`int`):
            Model's intermediate_size based on (expand * hidden_dim) from config.
        conv_states: (`torch.Tensor`):
            A tensor of shape `[num_layers, batch_size, conv_kernel_size, intermediate_size + 2 * n_groups * state_size]` that holds convolutional states.
        ssm_states: (`torch.Tensor`):
            A tensor of shape `[num_layers, batch_size, num_heads, head_dim, state_size]` that holds ssm states.
    N)config
batch_sizer,   r+   c              	   C   s   || _ |j| _|j| _|j| _|j| _|j| _t|j|j	 | _
tj|j|| j
d| j | j  | j||d| _tj|j|| j| j| j||d| _d S )Nr%   r*   )r,   conv_kernelconv_kernel_sizen_groups
state_size	num_headshead_dimintr1   hidden_sizeintermediate_sizer   Zzerosnum_hidden_layersconv_states
ssm_states)selfr=   r>   r,   r+   r!   r!   r"   __init__   s0    zMamba2Cache.__init__F)	layer_idxnew_conv_state
cache_initreturnc                 C   sv   |r| | jj| j|< nR| j| jddd| j|< |d d dd d f  | jj| j| d d d d df< | j| S )Nr$   )Zshiftsdimsr   )r8   rI   r+   Zroll)rK   rM   rN   rO   r!   r!   r"   update_conv_state   s
    8zMamba2Cache.update_conv_staterM   new_ssm_statec                 C   s   | | jj| j|< | j| S N)r8   rJ   r+   )rK   rM   rT   r!   r!   r"   update_ssm_state   s    zMamba2Cache.update_ssm_statec                 C   s   | j   | j  d S rU   )rI   Zzero_rJ   rK   r!   r!   r"   reset   s    
zMamba2Cache.reset)F)__name__
__module____qualname____doc__r   Zfloat16r   rE   r,   r   strrL   Tensorr3   rR   rV   rX   r!   r!   r!   r"   r<      s    
r<   c                       s(   e Zd Zd fdd	ZdddZ  ZS )	MambaRMSNormGatedư>c                    s&   t    tt|| _|| _d S rU   superrL   r   	Parameterr   r2   weightvariance_epsilonrK   rF   eps	__class__r!   r"   rL      s    
zMambaRMSNormGated.__init__Nc                 C   sj   |j }|tj}|d ur2|tj|tj }|djddd}|t	|| j
  }| j|| S Nr%   r$   T)Zkeepdim)r,   r8   r   float32r   r   silupowmeanrsqrtre   rd   )rK   r9   gateinput_dtypevariancer!   r!   r"   forward   s    zMambaRMSNormGated.forward)r`   )NrY   rZ   r[   rL   rs   __classcell__r!   r!   rh   r"   r_      s   r_   c                       s   e Zd ZdZeed fddZdeje	e
 e	ej e	ej dddZdeje	e
 e	ej e	ej dd	d
Zde	e
 e	ej e	ej dddZ  ZS )Mamba2Mixeru  
    Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
    A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
    ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
    and is why Mamba is called **selective** state spaces)
    )r=   rM   c                    s  t    |j| _|j| _|j| _|j| _t|j	| j | _
t|j| _|| _|j| _|j| _t|j | _|j| _|j| _|j| _|j| _|j| _|j| _|j| _|j| _| j
d| j | j  | _tj| j| j|j|j| j|jd d| _| j
| j | j }tj| j||jd| _ t!t"#| j| _$t"%d| jd }t!t"&|| _'d| j'_(t)| j
| jd| _*t!t"#| j| _+d| j+_(tj| j
| j|jd| _,|j| _t-st./d d S )Nr%   r   )Zin_channelsZout_channelsbiasZkernel_sizegroupspaddingrw   Trg   a  The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)` is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d)0rb   rL   rC   rF   rB   ssm_state_sizer?   r@   rE   r1   rG   Ztime_step_rankrM   use_conv_biasZ
hidden_act
activationr   actlayer_norm_epsilonZrms_normrA   rD   r'   time_step_limittime_step_mintime_step_maxconv_dimr   ZConv1dconv1dLinearZuse_biasin_projrc   r   r2   dt_biasarangelogA_log_no_weight_decayr_   normDout_projis_fast_path_availableloggerZwarning_once)rK   r=   rM   Zprojection_sizeArh   r!   r"   rL      s^    

	zMamba2Mixer.__init__N)r9   cache_paramscache_positionr:   c                 C   s  t ||}| |}|j\}}}| j| j }	|jd d| j  d| j | j  | j d }
|d urF|d urF|d dkrF|dj|
|
| j| j	| jgdd\}}}}}t
||j| j | jjd| jj| j}tj|| j|	|	gdd\}}}t| j  }|d d d df d d d d d f d| j| jjtjd}|d d d d d f dd| j}| jd d d df d| j}| jd d d df d| j}||| j|jd | j }||| j|jd | j }||| j| j}t|j| j ||||||d |dd	
}||| j| j }| ||}| |d d d df }nJt| j  }| j d
tdfkrpi nd| j i}| j!r|d u rt"|| jjd| jj| j|f| j| j#d | j| jj| jj$| jj| jj| j| jddd|}n|j|
|
| j| j	| jgdd\}}}}}|d urZ|%dd}t&j'(||j)|jd  df}|j*| j|dd | jdvr| +| |%dddd |f %dd}n0t,|%dd| jjd| jj| jd%dd}t ||}tj|| j|	|	gdd\}}}t-|||d| j|||||| jd|||| jdf| j#| jd d d| jdd|\}}|d url|d url|j.| j|d |||d}| ||}| |}|S )Nr$   r%   r   r   r.   .r,   T)zr   dt_softplusg        r5   Zdt_limitF)r   r'   seq_idxr~   Zrmsnorm_weightZrmsnorm_epsZoutproj_weightZoutproj_biasZheaddimZngroupsZnorm_before_gatereturn_final_statesrM   rN   rO   )rl   Zswish)xrd   rw   r~   )r'   r   r   r   r   r   r   rS   )/r;   r   r   rA   r|   rG   rC   squeezesplitr   r   rI   rM   r   rd   rw   r~   r   expr   floatr1   rD   r8   rk   r   r   viewr   rJ   r   r   r   trainingr   r'   re   	transposer   r   r    r@   rR   r   r   r   rV   )rK   r9   r   r   r:   projected_statesr>   seq_len_Zgroups_time_state_sized_mlprp   hidden_states_B_CdtBCr   r   r   Zhidden_states_reshapedoutZdt_limit_kwargshidden_states_B_C_transposedrI   scan_output	ssm_stater!   r!   r"   cuda_kernels_forward)  s   

"


<"
"

$




z Mamba2Mixer.cuda_kernels_forwardc           2   
      s  |j \}}}|j}t||}|}	|	j d dj  dj j  j d }
|	j|
|
jj	jgdd\}}}}}|d ur|d ur|d dkr|j
j|dd |jj jjjjd}tj|jjd dd}jr|jj }|}nr|d urD|dd}tj||j|j d  df}|j
j|d	d |ddd
d |f dd}t||}tj|jjj jj gdd\}}}tj  }|d urp|d urp|d dkrp|jj}|d d dd d f d d d d
f }|dd ||j d j!}j"d  j"j d j!}tjj#|||j }t$|j%d j%d }|d  jj!jjtj&d}t|d | j|d}|'|jdd
d d d f }| |jjj |j d ( }|'|d|j d }|d |d
d d d f  }|'|dj!}||d  j|d}|j)j|jj | | d |'|jdd
d d d f }| |jjj |j d ( }|'|d|j d }|jj j|j|jd}|*|j j!j}|*|j jd}t+||}|*|jj!}j,d  j,j d j!}|||  |j}|'|dd d d d
f }ntj#|j" }t$|j%d j%d }|'||dj! }|'||dj }|'||dj }|j-jj djd}|j-jj djd}j.|j.  j.  j,d t/|  }||d  }||j| } fdd||||fD \}}}}|0dddd}tj1|dd}tt2|}|d d d d d d d d d d d f |d d d d d d d d d d d f  } | jdd}!|!d |0dddddd  }"|"jdd}#|#d |d d d d d f  jdd}$t|d d d d d d dd f | }%||%0ddddd  }&|&d
d d d f |d  jdd}'|d ur |d ur |d dkr |jj d d d d
f j|'jd}(nt3|'d d d df }(tj4|(|'gdd}'tt2tj|d d d d d d df d})|)dd})|)d |'d d d d d d
f  jdd}*|*d d d df |*d d df  }'}+t|},|d
d d d f |'d d d d d d
f  }-|,0dddd}.|-d|.d  }/|$|/ }|'|djj!}|| } dkrZ|d d d |d d d d f }|'||d}|+d ur|d ur|j)j|+d 5||}06|0|}1|1S )Nr$   r%   r.   r   Fr   r+   r   T.r)   ).NNr   rS   r*   )r/   Zoutput_sizec                    s   g | ]}t | jqS r!   )r(   r'   ).0tr   rK   r!   r"   
<listcomp>K      z-Mamba2Mixer.torch_forward.<locals>.<listcomp>r   r   r-   )r   r   )7r   r,   r;   r   rG   rA   r|   rC   r   r   rR   rM   rI   r8   r   rd   r+   r   sumr   r}   rw   r   r   r   r   r    r@   r   r   r   rJ   r1   rD   r   Zsoftplusclampr   rk   r&   
contiguousrV   r   Zbmmr   Zrepeat_interleaver'   r#   Zpermuter4   r7   Z
zeros_likecatr   r   )2rK   r9   r   r   r:   r>   r   r   r,   r   r   rp   r   r   rI   r   r   r   r   Zcache_devicer   ZdAZdBZdBxrJ   Zssm_states_reshapedZ
C_reshapedyr   Z
D_residualZA_cumsumLZG_intermediateGZM_intermediateMZY_diagZdecay_statesZB_decayZstatesZprevious_statesZdecay_chunkZ
new_statesr   Zstate_decay_outZC_times_statesZstate_decay_out_permutedZY_offr   Zcontextualized_statesr!   r   r"   torch_forward  s    

.
,
"$"$$$P&*""&0(&
*
 zMamba2Mixer.torch_forwardr   r   r:   c                 C   s4   t r$d| jjjjv r$| ||||S | ||||S )Ncuda)r   r   rd   r+   typer   r   )rK   r9   r   r   r:   r!   r!   r"   rs     s    zMamba2Mixer.forward)NNN)NNN)NNN)rY   rZ   r[   r\   r   rE   rL   r   r^   r   r<   
LongTensorr   r   rs   ru   r!   r!   rh   r"   rv      s<   E    '    I   rv   c                       s&   e Zd Zd fdd	Zdd Z  ZS )Mamba2RMSNormr`   c                    s&   t    tt|| _|| _dS )zM
        Mamba2RMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
        Nra   rf   rh   r!   r"   rL     s    
zMamba2RMSNorm.__init__c                 C   sJ   |j }|tj}|djddd}|t|| j  }| j|| S rj   )	r,   r8   r   rk   rm   rn   ro   re   rd   )rK   r9   rq   rr   r!   r!   r"   rs     s
    zMamba2RMSNorm.forward)r`   rt   r!   r!   rh   r"   r     s   r   c                       s@   e Zd Z fddZdee eej eej dddZ	  Z
S )Mamba2Blockc                    sB   t    || _|| _|j| _t|j|jd| _t	||d| _
d S )Nr{   rM   )rb   rL   r=   rM   residual_in_fp32r   rF   r   r   rv   mixer)rK   r=   rM   rh   r!   r"   rL     s    
zMamba2Block.__init__Nr   c                 C   sL   |}|  |j| j jjd}| jr.|tj}| j||||d}|| }|S )Nr   r   )r   r8   rd   r,   r   r   rk   r   )rK   r9   r   r   r:   Zresidualr!   r!   r"   rs     s    zMamba2Block.forward)NNN)rY   rZ   r[   rL   r   r<   r   r   r^   rs   ru   r!   r!   rh   r"   r     s      r   c                   @   s0   e Zd ZU eed< dZdgZdZdZdd Z	dS )Mamba2PreTrainedModelr=   backboner   Tc                 C   s  | j j}t|trVtd| j jd }|jt	| d|j_
d|j_
|jjd tt| j jt	| j jt	| j j  t	| j j j| j jd}|t	t|   }|j| d|j_tjj|jjtdd |jjdurt|jjdd	stj |jj tjj|j!jtdd | j j"rV|j!j}|t| j j# }t|tj$rt|jdd	stjj%|j|d
 |jdurt|jdd	stj |j n@t|t&t'fr|jjd n t|tj(rtjj%|j|d
 dS )zInitialize the weights.r   Tg      ?)min   )aN
_no_reinitF)std))r=   Zinitializer_range
isinstancerv   r   r   rC   r   Zcopy_r   r   r   dataZfill_r   Zrandmathr   r   r   Ztime_step_floorexpm1r   r   r   initZkaiming_uniform_r   rd   sqrtrw   getattrZzeros_r   Zrescale_prenorm_residualrH   r   Znormal_r   r_   	Embedding)rK   moduler   r   r   Zinv_dtpr!   r!   r"   _init_weights  sJ    
z#Mamba2PreTrainedModel._init_weightsN)
rY   rZ   r[   r   __annotations__Zbase_model_prefixZ_no_split_modulesZsupports_gradient_checkpointingZ_is_statefulr   r!   r!   r!   r"   r     s   
r   z-
    Class for the MAMBA2 model outputs.
    )Zcustom_introc                   @   sJ   e Zd ZU dZdZeej ed< dZ	ee
 ed< dZeeej  ed< dS )Mamba2Outputa:  
    cache_params (`Mamba2Cache`):
        The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
        avoid providing the old `input_ids`.

        Includes both the State space model state matrices after the selective scan, and the Convolutional states
    Nlast_hidden_stater   r9   )rY   rZ   r[   r\   r   r   r   FloatTensorr   r   r<   r9   tupler!   r!   r!   r"   r   	  s   
r   zK
    Base class for causal language model (or autoregressive) outputs.
    c                   @   s\   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
ee ed< dZeeej  ed< dS )Mamba2CausalLMOutputa  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    cache_params (`Mamba2Cache`):
        The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
        avoid providing the old `input_ids`.

        Includes both the State space model state matrices after the selective scan, and the Convolutional states
    Nlosslogitsr   r9   )rY   rZ   r[   r\   r   r   r   r   r   r   r   r<   r9   r   r!   r!   r!   r"   r     s
   
r   c                       s   e Zd Z fddZdd Zdd Zdd Zedee	j
 ee	j
 ee ee ee ee ee	j
 ee	j eeef d
	ddZ  ZS )Mamba2Modelc                    sn   t    t j j| _t fddt j	D | _
d| _t j jd| _| | j |   d S )Nc                    s   g | ]}t  |d qS )r   )r   )r   idxr=   r!   r"   r   >  r   z(Mamba2Model.__init__.<locals>.<listcomp>Fr{   )rb   rL   r   r   
vocab_sizerF   
embeddingsZ
ModuleListrangerH   layersgradient_checkpointingr   r   norm_fZ"_register_load_state_dict_pre_hook	load_hook	post_initrK   r=   rh   r   r"   rL   :  s     zMamba2Model.__init__c                 G   s0   |D ]&}d|v r| |||dd<  q,qd S )Nz
embedding.zembeddings.)popreplace)rK   Z
state_dictprefixargskr!   r!   r"   r   F  s    zMamba2Model.load_hookc                 C   s   | j S rU   r   rW   r!   r!   r"   get_input_embeddingsL  s    z Mamba2Model.get_input_embeddingsc                 C   s
   || _ d S rU   r   rK   Znew_embeddingsr!   r!   r"   set_input_embeddingsO  s    z Mamba2Model.set_input_embeddingsN)		input_idsinputs_embedsr   	use_cacheoutput_hidden_statesreturn_dictr   r:   rP   c	                 K   sd  |dur|n| j j}|dur |n| js.| j jnd}|dur>|n| j j}|du |duA r^td|du rp| |}| jr| jr|rd}|r|du rt| j |	d|j
|jd}tjd| j j|j
d}q|du rtdnd}|}
|rdnd}| jD ]"}||
|||d	}
|r||
f }q| |
}
|r.||
f }|sLtd
d |
||fD S t|
|rZ|nd|dS )a  
        cache_params (`Mamba2Cache`, *optional*):
            If passed along, the model uses the previous state in all the blocks (which will give the output for the
            `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
        use_cache (`bool`, *optional*):
            If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
        cache_position (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            The position of the current input in the cache. This is used to ensure that the cache is correctly updated.
            If `cache_params` is passed, `cache_position` should also be passed.
        NFz:You must specify exactly one of input_ids or inputs_embedsr   r*   r   zYou have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will be initialized for you automaticallyr!   r   c                 s   s   | ]}|d ur|V  qd S rU   r!   )r   vr!   r!   r"   	<genexpr>  r   z&Mamba2Model.forward.<locals>.<genexpr>)r   r   r9   )r=   r   r   r   use_return_dict
ValueErrorr   r   r<   r0   r+   r,   r   r   r?   r   r   r   r   )rK   r   r   r   r   r   r   r   r:   kwargsr9   Zall_hidden_statesZmixer_blockr!   r!   r"   rs   R  sT    



zMamba2Model.forward)NNNNNNNN)rY   rZ   r[   rL   r   r   r   r   r   r   r   r<   r3   r^   r   r   r   rs   ru   r!   r!   rh   r"   r   8  s0           
r   z
    The MAMBA2 Model transformer with a language modeling head on top (linear layer with weights not tied to the input
    embeddings).
    c                       s   e Zd Zg Z fddZdd Zdd Zdee ee	j
 ee	j dd	d
Zedee	j
 ee	j ee ee	j
 ee ee ee ee	j ee	j eeef d
ddZ  ZS )Mamba2ForCausalLMc                    s8   t  | t|| _tj|j|jdd| _| 	  d S )NFrz   )
rb   rL   r   r   r   r   rF   r   lm_headr   r   rh   r!   r"   rL     s    
zMamba2ForCausalLM.__init__c                 C   s
   | j  S rU   )r   r   rW   r!   r!   r"   r     s    z&Mamba2ForCausalLM.get_input_embeddingsc                 C   s   | j |S rU   )r   r   r   r!   r!   r"   r     s    z&Mamba2ForCausalLM.set_input_embeddingsNr   c           
      K   s   d|  i}|rn|d u rntjd| jjj|jd}|d urLd|i}|d}	n
|d}	t| jj|	| j| j	d}|r|d dkr|d d df 
d  |d< d }|s|d urd|i}|||||d |S )Nr   r   r   r   r*   r$   )r   r   r   r:   )r   r   r   r   r=   r?   r+   r0   r<   r,   Z	unsqueezeupdate)
rK   r   r   r   r   r   r:   r   Zmodel_inputsZmax_batch_sizer!   r!   r"   prepare_inputs_for_generation  s*    
z/Mamba2ForCausalLM.prepare_inputs_for_generation)
r   r   r   labelsr   r   r   r   r:   rP   c
              
   K   s   |dur|n| j j}| j||||||||	d}|d }| || jjj }d}|durx| jf ||| j j	d|
}|s|f|dd  }|dur|f| S |S t
|||j|jdS )ao  
        cache_params (`Mamba2Cache`, *optional*):
            If passed along, the model uses the previous state in all the blocks (which will give the output for the
            `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        use_cache (`bool`, *optional*):
            If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
        cache_position (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            The position of the current input in the cache. This is used to ensure that the cache is correctly updated.
            If `cache_params` is passed, `cache_position` should also be passed.
        N)r   r   r   r   r   r   r:   r   )r   r  r   r   )r   r   r   r9   )r=   r   r   r  r8   rd   r,   r   Zloss_functionr   r   r   r9   )rK   r   r   r   r  r   r   r   r   r:   r   Zmamba2_outputsr9   r   r   outputr!   r!   r"   rs     s2    
zMamba2ForCausalLM.forward)NNNNN)	NNNNNNNNN)rY   rZ   r[   Z_tied_weights_keysrL   r   r   r   r<   r   r   r^   r  r   r   r3   r   r   r   rs   ru   r!   r!   rh   r"   r    sH        *         
r  )r  r   r   ):r\   r   dataclassesr   typingr   r   r   Ztorch.utils.checkpointr   Zactivationsr   Z
generationr   Zmodeling_layersr	   Zmodeling_utilsr
   utilsr   r   r   Zutils.import_utilsr   r   Zconfiguration_mamba2r   Z
get_loggerrY   r   Z+mamba_ssm.ops.triton.selective_state_updater   Z!mamba_ssm.ops.triton.ssd_combinedr   r   Zcausal_conv1dr   r   allr   r^   rE   r#   r(   r7   r;   r<   Moduler_   rv   r   r   r   r   r   r   r  __all__r!   r!   r!   r"   <module>   sx   

M   A>mv