a
    h                     @   s  d Z ddlZddlmZ ddlmZ ddlmZmZ ddl	Z	ddl
Z	ddl	mZ ddlmZ dd	lmZ dd
lmZ ddlmZmZmZmZmZmZ ddlmZ eeZdadd ZG dd de	j j!Z"d,ddZ#d-ddZ$G dd dej%Z&G dd dej%Z'G dd deZ(eG dd deZ)eedd G d!d" d"eZ*eed#d G d$d% d%eZ+eG d&d' d'e)Z,ed(d G d)d* d*e)eZ-g d+Z.dS ).zPyTorch RWKV model.    N)	dataclass)Path)OptionalUnion)nn   )GenerationMixin)GradientCheckpointingLayer)PreTrainedModel)ModelOutputauto_docstringis_bitsandbytes_availableis_ninja_availableis_torch_cuda_availablelogging   )
RwkvConfigc                    s   ddl m} tt jjjd d   fdddD }td urNtj| krNd S t	d|  d	 d
dddddd|  g}|d|  |t
 t
jk|da| t_d S )Nr   )loadZkernelsrwkvc                    s   g | ]} | qS  r   ).0fZkernel_folderr   b/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/rwkv/modeling_rwkv.py
<listcomp>5       z(load_wkv_cuda_kernel.<locals>.<listcomp>)z
wkv_op.cppzwkv_cuda.cuzwkv_cuda_bf16.cuz2Loading CUDA kernel for RWKV at context length of .z
-res-usagez--maxrregcount 60z--use_fast_mathz-O3z-Xptxas -O3z--extra-device-vectorizationz-DTmax=Zwkv_)namesourcesverboseZextra_cuda_cflags)Ztorch.utils.cpp_extensionr   r   __file__resolveparentrwkv_cuda_kernelmax_seq_lengthloggerinfor   Zget_verbosityDEBUG)context_lengthZload_kernelZcuda_kernel_filesflagsr   r   r   load_wkv_cuda_kernel/   s*    	r*   c                   @   s(   e Zd ZedddZedddZdS )	RwkvLinearAttentionNFc              	   C   s  |  \}}}	|tjkr0td| dtj d||	 t|	d dkrhtd| d|	 dt|	d d	|j| _|jjd
ks|jjd
ks|jjd
ks|jjd
krtdt	
|   }|jt	jkr| }| }| }| }| }| }t	j|t	jd}
|s|d ur|d u r^t	j||	dt	j|jt	jd}|d d d d df  d8  < nt	jdd |D dd }|jt	jkrtj}ntj}||||||
| n*|jt	jkrtjntj}||||||
 | |||||
 |d ur
dd t	j|dddD }|
| j|fS )NzCannot process a batch with z+ tokens at the same time, use a maximum of z with this model.    r   zThe product of batch size (z) and hidden size (z") needs to be a round multiple of r   cudazUCalling the CUDA kernel for wkv attention requires all tensors to be on CUDA devices.memory_formatr   )dtypedevicer/      籡*Gc                 S   s   g | ]}| d qS r2   Z	unsqueezer   sr   r   r   r   ~   r   z/RwkvLinearAttention.forward.<locals>.<listcomp>)dimc                 S   s   g | ]}| d qS r4   )Zsqueezer6   r   r   r   r      r   )sizer#   r$   
ValueErrorminr0   input_dtyper1   typetorchexpfloat
contiguousfloat16
empty_likecontiguous_formatzerosfloat32catbfloat16Zforward_with_state_bf16Zforward_with_stateZforward_bf16forwardZsave_for_backwardchunkto)ctx
time_decay
time_firstkeyvaluestatereturn_stateZ
batch_sizeZseq_lenhidden_sizeoutputZforward_funcr   r   r   rI   P   sl    





 
zRwkvLinearAttention.forwardc                 C   s   | j }| j\}}}}}tj|tj|tjkr0tjntjd}	tj|tjd}
tj|tjd}tj|tjd}|tjkr|| }|tjkrt	j
nt	j}||||||| |	|
||
 |	||
|||||d d fS )N)r/   r0   r.   )r<   Zsaved_tensorsr>   rC   rD   rH   rF   rB   r@   r#   Zbackward_bf16backwardrA   rK   )rL   Zg_outputZg_stater<   rM   rN   rO   rP   rT   Zg_time_decayZg_time_firstZg_keyZg_valueZbackward_funcr   r   r   rU      s@    
zRwkvLinearAttention.backward)NF)N)__name__
__module____qualname__staticmethodrI   rU   r   r   r   r   r+   O   s   >r+   Fc                 C   s  |  \}}}t|}|d u rztj|d d df tjd}	tj|d d df tjd}
tj|d d df tjdd }n
|\}	}
}t|  } t|D ]}|d d |f  }|d d |f }t||| }t|| }t|| | }||	 ||  }||
 | }|| |j	|d d |f< t||  |}t||  | }t|| }||	 ||  }	||
 | }
|}q|s|d ur|	|
|g}||fS )Nr   )r0   r3   )
r9   r>   Z
zeros_likerF   r?   ranger@   maximumrK   r0   )rM   rN   rO   rP   rQ   rR   _Z
seq_lengthrT   Z	num_stateZ	den_stateZ	max_stateZcurrent_indexcurrent_keycurrent_valueZmax_for_outpute1e2	numeratordenominatorZmax_for_stater   r   r   rwkv_linear_attention_cpu   s4    
"

rc   c                 C   sd   t dd | |||fD }|ddk}td u s8|s8|rLt| |||||dS t| |||||S d S )Nc                 s   s   | ]}|j jd kV  qdS )r-   N)r1   r=   )r   tr   r   r   	<genexpr>   r   z(rwkv_linear_attention.<locals>.<genexpr>r   rQ   rR   )anyr9   r#   rc   r+   apply)rM   rN   rO   rP   rQ   rR   Zno_cudaZ	one_tokenr   r   r   rwkv_linear_attention   s
    ri   c                       s2   e Zd Zd
 fdd	ZdddZddd	Z  ZS )RwkvSelfAttentionr   c                    sD  t    || _td uo"tj|jk}t r`t r`|s`zt|j W n t	y^   t
d Y n0 || _|j}|jd ur||jn|}|| _tt|| _tt|| _ttdd|| _ttdd|| _ttdd|| _td| _tj||dd| _tj||dd| _tj||dd| _tj||dd| _d S )Nz9Could not load the custom CUDA kernel for RWKV attention.r   r   r   r   Fbias)super__init__configr#   r$   r(   r   r   r*   	Exceptionr%   r&   layer_idrS   attention_hidden_sizer   	Parameterr>   emptyrM   rN   time_mix_keytime_mix_valuetime_mix_receptance	ZeroPad2d
time_shiftLinearrO   rP   
receptancerT   )selfrq   rs   Zkernel_loadedrS   rt   	__class__r   r   rp      s.    
zRwkvSelfAttention.__init__Nc                 C   s  | ddkr4|d ur4|d d d d d | jf }n:| |}|d urn|d d d d d | jf |d d df< || j |d| j   }|| j |d| j   }|| j |d| j   }| |}| |}t	| 
|}|d ur|d d df |d d d d d | jf< ||||fS Nr   r   rl   )r9   rs   r{   rw   rx   ry   rO   rP   r>   sigmoidr}   )r~   hiddenrQ   shiftedrO   rP   r}   r   r   r   extract_key_value  s    
(


(z#RwkvSelfAttention.extract_key_valueFc           	         s    j ||d\}}}}|d ur<t fdd|dd  D nd }t j j||||d\}}|d ur|d |d d d d d  jf< |d |d d d d d  jf< |d |d	 d d d d  jf<  || |fS )
NrQ   c                 3   s&   | ]}|d d d d  j f V  qd S Nrs   r6   r~   r   r   re   $  r   z,RwkvSelfAttention.forward.<locals>.<genexpr>r2   rf   r   r   r      )r   tupleri   rM   rN   rs   rT   )	r~   r   rQ   	use_cacher}   rO   rP   Zlayer_stater   r   r   r   rI   "  s    *
	   zRwkvSelfAttention.forward)r   )N)NF)rV   rW   rX   rp   r   rI   __classcell__r   r   r   r   rj      s   
rj   c                       s(   e Zd Zd fdd	ZdddZ  ZS )	RwkvFeedForwardr   c                    s   t    || _|| _|j}|jd ur,|jnd|j }td| _t	t
dd|| _t	t
dd|| _tj||dd| _tj||dd| _tj||dd| _d S )Nr   rk   r   Frm   )ro   rp   rq   rs   rS   intermediate_sizer   rz   r{   ru   r>   rv   rw   ry   r|   rO   r}   rP   )r~   rq   rs   rS   r   r   r   r   rp   7  s    
zRwkvFeedForward.__init__Nc                 C   s
  | ddkr4|d ur4|d d d d d | jf }n:| |}|d urn|d d d d d | jf |d d df< || j |d| j   }|| j |d| j   }tt| |}| 	|}t
| |}|d ur|d d df |d d d d d | jf< || |fS r   )r9   rs   r{   rw   ry   r>   ZsquareZrelurO   rP   r   r}   )r~   r   rQ   r   rO   r}   rP   r   r   r   rI   H  s    
(
(zRwkvFeedForward.forward)r   )NrV   rW   rX   rp   rI   r   r   r   r   r   r   6  s   r   c                       s&   e Zd Z fddZdddZ  ZS )	RwkvBlockc                    sv   t    || _|| _|dkr2tj|j|jd| _tj|j|jd| _	tj|j|jd| _
t||| _t||| _d S )Nr   )eps)ro   rp   rq   rs   r   	LayerNormrS   Zlayer_norm_epsilonpre_lnln1ln2rj   	attentionr   feed_forward)r~   rq   rs   r   r   r   rp   ]  s    
zRwkvBlock.__init__NFc                 C   sz   | j dkr| |}| j| |||d\}}|| }| j| ||d\}}|| }||f}|rn||f7 }n|d7 }|S )Nr   )rQ   r   r   r   )rs   r   r   r   r   r   )r~   r   rQ   r   output_attentionsr   r   outputsr   r   r   rI   k  s    

zRwkvBlock.forward)NFFr   r   r   r   r   r   \  s   r   c                   @   s@   e Zd ZU eed< dZdgZddgZdZdZ	e
jddd	Zd
S )RwkvPreTrainedModelrq   r   r   rM   rN   T)modulec                    s  t |tr<|j}|jj}|jj|j ||d  d||  }tjfddt	D |j
j|j
jd}|ddddf } fddt	 D }tj||jj|jjd}tjdd t	 D |jj|jjdd	 }||j_t|jtd
 | |j_t|||j
_t||d
  |j_t|d	| |j_nt |tr|j}|jj}|jjd||  }tjfddt	D |j
j|j
jd}|ddddf }t|||j
_t|||j_nt |tjrn|jjj}d}	d}
|jdur|jj  |d |d kr*t|d |d  }	|d | jjkrR|d | jjkrRd	}
|	|
9 }	tjj |j|	d npt |tj!r|jjj}dtt"|d |d  }	tjj |j|	d n(t |tj#r|jj$d |jj  dS )zInitialize the weights.r   g      ?c                    s   g | ]}|  qS r   r   r   irS   r   r   r     r   z5RwkvPreTrainedModel._init_weights.<locals>.<listcomp>r0   r1   Nc                    s,   g | ]$}d d| d  dd     qS )   r   gffffff?g?r   )r   h)rt   ratio_0_to_1r   r   r     s   c                 S   s   g | ]}|d  d d  qS )r   r   r   r   r   r   r   r     r   g      ?g333333?c                    s   g | ]}|  qS r   r   r   r   r   r   r     r   r   )gaing-C6?)%
isinstancerj   rs   rq   num_hidden_layersrS   rt   r>   ZtensorrZ   rw   r0   r1   rM   rN   dataZ	ones_likemathlogpowrx   ry   r   r   r|   weightshapern   Zzero_sqrt
vocab_sizeinitZorthogonal_	Embeddingmaxr   Zfill_)r~   r   rs   r   Zratio_1_to_almost0Ztime_weightZdecay_speedZzigzagr   r   scaler   )rt   rS   r   r   _init_weights  s|    	
$
z!RwkvPreTrainedModel._init_weightsN)rV   rW   rX   r   __annotations__Zbase_model_prefixZ_no_split_modulesZ_keep_in_fp32_modulesZsupports_gradient_checkpointingZ_is_statefulr   Moduler   r   r   r   r   r   ~  s   
r   z+
    Class for the RWKV model outputs.
    )Zcustom_introc                   @   sn   e Zd ZU dZdZeej ed< dZ	ee
ej  ed< dZeeejdf  ed< dZeeejdf  ed< dS )
RwkvOutputa  
    state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
        The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
        avoid providing the old `input_ids`.
    Nlast_hidden_staterQ   .hidden_states
attentions)rV   rW   rX   __doc__r   r   r>   FloatTensorr   rQ   listr   r   r   r   r   r   r   r     s
   
r   zK
    Base class for causal language model (or autoregressive) outputs.
    c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeej  ed< dZeeejdf  ed< dZeeejdf  ed< dS )	RwkvCausalLMOutputap  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
        The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
        avoid providing the old `input_ids`.
    NlosslogitsrQ   .r   r   )rV   rW   rX   r   r   r   r>   r   r   r   rQ   r   r   r   r   r   r   r   r   r     s   

r   c                       s   e Zd Z fddZdd Zdd Zedeej	 eej	 eej
 eeej
  ee ee ee ee eeef d	d	d
Zdd Zdd Z  ZS )	RwkvModelc                    sd   t    t j j| _t fddt j	D | _
t j| _d| _d| _|   d S )Nc                    s   g | ]}t  |d qS )r   )r   )r   idxrq   r   r   r     r   z&RwkvModel.__init__.<locals>.<listcomp>F)ro   rp   r   r   r   rS   
embeddingsZ
ModuleListrZ   r   blocksr   ln_outlayers_are_rescaledgradient_checkpointing	post_initr~   rq   r   r   r   rp      s     zRwkvModel.__init__c                 C   s   | j S r   r   r   r   r   r   get_input_embeddings  s    zRwkvModel.get_input_embeddingsc                 C   s
   || _ d S r   r   r~   Znew_embeddingsr   r   r   set_input_embeddings  s    zRwkvModel.set_input_embeddingsN)		input_idsattention_maskinputs_embedsrQ   r   r   output_hidden_statesreturn_dictreturnc	                    s"  |dur|n| j j}|dur |n| j j}|dur4|n| jsB| j jnd}|durR|n| j j}|durltd | j| jkr| 	  |dur durt
dn|du r du rt
d du r| | |r|du r d| j j| j jf fddtd	D }|d
  d8  < | jr8| jr8|r8td d} }	|rFdnd}
|rTdnd}t| jD ]t\}}||	|||d\}	}}| jr| j jdkr|d | j j dkr|	d }	|r||	f }|rb|
|f }
qb| |	}	|r||	f }|stdd |	|||
fD S t|	|||
dS )a  
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        state (tuple of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`, *optional*):
            If passed along, the model uses the previous state in all the blocks (which will give the output for the
            `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
        use_cache (`bool`, *optional*):
            If set to `True`, the last state is returned and can be used to quickly generate the next logits.
        NFz<`attention_mask` was passed, but it is unused in this model.zDYou cannot specify both input_ids and inputs_embeds at the same timez5You have to specify either input_ids or inputs_embedsr   c                    s0   g | ](}t j|d kr jnt j jdqS )r   r   )r>   rE   r0   rF   r1   r   r   r   r   r   r   J  s   z%RwkvModel.forward.<locals>.<listcomp>   r   gꌠ9Y>)FzZ`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...r   )rQ   r   r   r   r2   c                 s   s   | ]}|d ur|V  qd S r   r   )r   xr   r   r   re   u  r   z$RwkvModel.forward.<locals>.<genexpr>)r   rQ   r   r   )rq   r   r   trainingr   use_return_dictr%   Zwarning_oncer   _rescale_layersr:   r   r9   rS   r   rZ   r   	enumerater   rescale_everyr   r   r   )r~   r   r   r   rQ   r   r   r   r   r   Zall_self_attentionsZall_hidden_statesr   blockr   r   r   r   rI     sp    






zRwkvModel.forwardc                 C   s  | j | j krd S | jjdkrtt 8 t| jD ]\}}| jr|jj	j
dt|| jj   |jjj
dt|| jj   q6t|jj	j
dr|jj	j
jdt|| jj   |jjj
jdt|| jj   q6t|jj	j
dr| |jj	| | |jj| q6|jj	j
dt|| jj   |jjj
dt|| jj   q6W d    n1 sj0    Y  | j | _ d S )Nr   r2   SCBquant_state)r   r   rq   r   r>   Zno_gradr   r   r   rT   r   Zmul_intr   rP   hasattrr   div_ _bnb_4bit_dequantize_and_rescale)r~   block_idr   r   r   r   r   ~  s"     ""$ BzRwkvModel._rescale_layersc                 C   st   t  stdddl}|j|jj|jj}|dt	|| j
j   |jj|ddd|j}t|d| dS )	z
        Perform the dequantization and rescaling of the weights of a given layer. After that operation the layer will
        be quantized again.
        z/Please install bitsandbytes to use this method.r   Nr2   cpuF)Zrequires_gradr   )r   ImportErrorZbitsandbytesZ
functionalZdequantize_4bitr   r   r   r   r   rq   r   r   Z
Params4bitrK   r1   setattr)r~   Ztarget_layerr   ZbnbZdequant_weightsZquant_weightr   r   r   r     s    z*RwkvModel._bnb_4bit_dequantize_and_rescale)NNNNNNNN)rV   rW   rX   rp   r   r   r   r   r>   
LongTensorr   r   boolr   r   r   rI   r   r   r   r   r   r   r   r     s2           
ir   z
    The RWKV Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    c                       s   e Zd ZdgZ fddZdd Zdd Zdd	d
Zede	e
j e	e
j e	e
j e	ee
j  e	e
j e	e e	e e	e e	e eeef d
ddZ  ZS )RwkvForCausalLMzhead.weightc                    s8   t  | t|| _tj|j|jdd| _| 	  d S )NFrm   )
ro   rp   r   r   r   r|   rS   r   headr   r   r   r   r   rp     s    
zRwkvForCausalLM.__init__c                 C   s   | j S r   r   r   r   r   r   get_output_embeddings  s    z%RwkvForCausalLM.get_output_embeddingsc                 C   s
   || _ d S r   r   r   r   r   r   set_output_embeddings  s    z%RwkvForCausalLM.set_output_embeddingsNc                 K   sT   |d ur|d d df  d}|d ur8|d u r8d|i}nd|i}||d< ||d< |S )Nrl   r   r   rQ   r   r5   )r~   r   rQ   r   r   kwargsZmodel_inputsr   r   r   prepare_inputs_for_generation  s    
z-RwkvForCausalLM.prepare_inputs_for_generation)
r   r   r   rQ   labelsr   r   r   r   r   c
              	   K   s   |	dur|	n| j j}	| j|||||||	d}|d }| |}d}|durf| j||fd| j ji|
}|	s|f|dd  }|dur|f| S |S t|||j|j|j	dS )aJ  
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        state (tuple of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`, *optional*):
            If passed along, the model uses the previous state in all the blocks (which will give the output for the
            `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        use_cache (`bool`, *optional*):
            If set to `True`, the last state is returned and can be used to quickly generate the next logits.
        N)r   rQ   r   r   r   r   r   r   r   )r   r   rQ   r   r   )
rq   r   r   r   Zloss_functionr   r   rQ   r   r   )r~   r   r   r   rQ   r   r   r   r   r   r   Zrwkv_outputsr   r   r   rT   r   r   r   rI     s@    %	
zRwkvForCausalLM.forward)NNN)	NNNNNNNNN)rV   rW   rX   Z_tied_weights_keysrp   r   r   r   r   r   r>   r   r   r   r   r   r   r   rI   r   r   r   r   r   r     s6   
         
r   )r   r   r   )NF)NF)/r   r   dataclassesr   pathlibr   typingr   r   r>   Ztorch.utils.checkpointr   Z
generationr   Zmodeling_layersr	   Zmodeling_utilsr
   utilsr   r   r   r   r   r   Zconfiguration_rwkvr   Z
get_loggerrV   r%   r#   r*   ZautogradFunctionr+   rc   ri   r   rj   r   r   r   r   r   r   r   __all__r   r   r   r   <module>   sR    
 j
,
F&"T .l