import math
from typing import Callable, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...integrations.flex_attention import compile_friendly_flex_attention
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import AttentionInterface, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available
from ...utils.deprecation import deprecate_kwarg
from ...utils.generic import OutputRecorder, check_model_inputs
from .configuration_doge import DogeConfig


if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import BlockMask
@use_kernel_forward_from_hub("RMSNorm")
class DogeRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        DogeRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class DogeRotaryEmbedding(nn.Module):
    inv_freq: torch.Tensor

    def __init__(self, config: DogeConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
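# Illustration of the rotary-embedding contract (hypothetical shapes, for reading convenience only):
#
#     rotary = DogeRotaryEmbedding(config)
#     cos, sin = rotary(hidden_states, position_ids)   # each [batch, seq_len, head_dim]
#     q, k = apply_rotary_pos_emb(q, k, cos, sin)      # q/k stay [batch, num_heads, seq_len, head_dim]
#
# `unsqueeze_dim=1` inserts the broadcast axis at the head dimension, which matches the
# [batch, num_heads, seq_len, head_dim] layout used by `DogeAttention` below.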
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


def flex_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Union[torch.Tensor, "BlockMask"],
    scaling: Optional[float] = None,
    softcap: Optional[float] = None,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    block_mask = None
    causal_mask = None
    if isinstance(attention_mask, BlockMask):
        block_mask = attention_mask
    else:
        causal_mask = attention_mask

    if causal_mask is not None:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
        if softcap is not None:
            score = softcap * torch.tanh(score / softcap)
        if causal_mask is not None:
            score = score + causal_mask[batch_idx][head_idx][q_idx][kv_idx]
        if head_mask is not None:
            score = score + head_mask[batch_idx][head_idx][0][0]
        return score

    attn_output, attention_weights = compile_friendly_flex_attention(
        query,
        key,
        value,
        score_mod=score_mod,
        block_mask=block_mask,
        enable_gqa=True,
        scale=scaling,
        return_lse=True,
    )
    # the log-sumexp is returned in float32; cast it back to the value dtype
    attention_weights = attention_weights.to(value.dtype)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attention_weights


ALL_ATTENTION_FUNCTIONS = AttentionInterface()
ALL_ATTENTION_FUNCTIONS["doge_flex_attention"] = flex_attention_forward


class DogeAttention(nn.Module):
    def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.keep_window_size = config.keep_window_size

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        # dynamic mask parameters for the QK^T attention score matrix
        self.A = nn.Parameter(torch.zeros(config.num_key_value_heads))
        self.dt_proj = nn.Linear(
            config.num_key_value_heads * self.head_dim, config.num_key_value_heads, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.q_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position is needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # calculate the dynamic mask states from the (possibly cached) value states
        dt_states = self.dt_proj(
            value_states.transpose(1, 2).reshape(value_states.shape[0], value_states.shape[-2], -1)
        )
        dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
        attn_mask = self.prepare_dynamic_mask(
            hidden_states=hidden_states,
            dt_states=dt_states,
            keep_window_size=self.keep_window_size,
            attention_mask=attention_mask,
        )
        attn_mask = repeat_kv(attn_mask, self.num_key_value_groups)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=attn_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

    def prepare_dynamic_mask(
        self,
        hidden_states: torch.Tensor,
        dt_states: torch.Tensor,
        keep_window_size: int = 2048,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """
        The core idea of DMA is to calculate the dynamic attention mask to mask the tokens that should be masked, so as to form sparse attention.

        Combine `dt_states` with `attention_mask` to generate the final `attn_mask`.

        Args:
            hidden_states (`torch.Tensor`): The input hidden_states, used to determine the minimum value of the current input precision.
            dt_states (`torch.Tensor`): dt_states of shape `(batch_size, num_heads, key_sequence_length)`.
            keep_window_size (`int`): The window size of tokens that are not dynamically masked, and dynamic masking is only performed when the sequence length exceeds this value.
            attention_mask (`torch.Tensor`, *optional*): attention mask of shape `(batch_size, 1, query_sequence_length, key_sequence_length)`.
        """
        min_dtype = torch.finfo(hidden_states.dtype).min
        dtype = hidden_states.dtype
        attn_mask = dt_states[:, :, None, :].expand(
            -1, -1, hidden_states.shape[1], -1
        )  # [batch_size, num_heads, query_len, key_len]
        if attention_mask is not None and isinstance(attention_mask, torch.Tensor):
            if attention_mask.dtype == torch.bool:
                attention_mask = torch.where(
                    attention_mask, torch.tensor(0.0, device=attention_mask.device, dtype=dtype), min_dtype
                )
            attn_mask = attn_mask.masked_fill(attention_mask[:, :, :, : attn_mask.shape[-1]] != 0, min_dtype)
        if attn_mask.shape[-1] > keep_window_size:
            # keep only the `keep_window_size` positions with the largest dynamic scores per query
            active_mask = torch.zeros_like(attn_mask, dtype=dtype, device=attn_mask.device)
            topk_indices = torch.topk(attn_mask, keep_window_size, dim=-1, largest=True, sorted=False).indices
            active_mask = active_mask.scatter(-1, topk_indices, 1.0)
            attn_mask = attn_mask.masked_fill(active_mask == 0.0, min_dtype)
        return attn_mask


class DogeMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj
class DogeCDMoE(nn.Module):
    def __init__(self, config: DogeConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.act_fn = ACT2FN[config.hidden_act]

        self.num_experts = config.num_experts
        self.num_keys = math.floor(math.sqrt(self.num_experts))
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob

        # shared expert
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)

        # routed experts, addressed with a product-key router
        self.router_gate = nn.Linear(self.hidden_size, self.num_keys * 2, bias=False)
        self.down_embed = nn.Embedding(self.num_experts, self.hidden_size)
        self.up_embed = nn.Embedding(self.num_experts, self.hidden_size)

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
        bsz, seq_len, _ = hidden_states.shape

        # get routing logits with the router gate
        router_logits = self.router_gate(hidden_states).view(2, bsz * seq_len, -1)

        # get the experts with the highest routing logits
        (scores_x, scores_y), (indices_x, indices_y) = router_logits.topk(self.num_keys, dim=-1)
        all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
        all_indices = indices_x.unsqueeze(-1) * self.num_keys + indices_y.unsqueeze(-2)
        all_scores = all_scores.view(*all_scores.shape[:-2], -1)
        all_indices = all_indices.view(*all_indices.shape[:-2], -1)
        scores, position_indices = all_scores.topk(self.top_k, dim=-1)
        indices = all_indices.gather(-1, position_indices)
        routing_weights = F.softmax(scores, dim=-1)
        if self.norm_topk_prob:
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

        # mix the selected experts' states with the routing weights
        down_embed = self.down_embed(indices)
        up_embed = self.up_embed(indices)
        experts_weights = torch.matmul(down_embed, hidden_states.view(bsz * seq_len, -1, 1)).view(bsz * seq_len, -1)
        experts_weights = self.act_fn(experts_weights) * routing_weights
        experts_states = torch.matmul(experts_weights.view(bsz * seq_len, 1, -1), up_embed).view(bsz, seq_len, -1)

        # mix the shared expert states with the routed expert states
        hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
        hidden_states = hidden_states + experts_states
        return hidden_states, router_logits
ej	ej	f eej	 eej ee
ej	  ee eej ee e
ejee
ejejf  f d
	ddZ  ZS )DogeDecoderLayerNr   c                    s   t    |j| _t|j|jd| _t||d| _t	
t|j| _t|j|jd| _|jsft|nt|| _t	
t|j| _d S )Nr   r   )r$   r%   hidden_dropoutr"   r,   r   input_layernormr   	self_attnr   r&   r'   r(   input_residualpost_attention_layernormZis_moer   r   mlppost_attention_residualr   r.   r0   r1   r%     s    
zDogeDecoderLayer.__init__r   r   r   r   F)	r:   r   rt   r^   r   	use_cacher   rw   rh   c              
   K   s   |}	|  |}| jf |||||||d|\}}
tj|| j| jd}| j|	 | }|}	| |}| |}tj|| j| jd}| j	|	 | }|S )N)r:   r   rt   r^   r   r   r   ry   )
r   r   r   rv   r   r{   r   r   r   r   )r+   r:   r   rt   r^   r   r   r   rw   ZresidualZself_attn_weightsr0   r0   r1   r;     s*    




zDogeDecoderLayer.forward)N)NNNFN)r?   r@   rA   r    r   r   r%   r   r'   r_   r<   r   r   r   r   FloatTensorr;   rB   r0   r0   r.   r1   r     s$        r   c                       sb   e Zd ZU eed< dZdZdgZdgZdZ	dZ
dZdZdZeeddeed	Z fd
dZ  ZS )DogePreTrainedModelrF   modelTr   r   Fr   )index)r   r:   
attentionsc                    sl   t  | t|tr.t|drh|jj  n:t|trht|drP|j	j
d t|drh|jj
d dS )zInitialize the weightsr   r   r   r   N)r$   _init_weightsrM   r   rL   r   dataZzero_r   r   Zfill_r   )r+   rp   r.   r0   r1   r     s    




z!DogePreTrainedModel._init_weights)r?   r@   rA   r    r`   Zbase_model_prefixZsupports_gradient_checkpointingZ_no_split_modulesZ_skip_keys_device_placementZ_supports_flash_attnZ_supports_sdpaZ_supports_flex_attnZ_can_compile_fullgraphZ_supports_attention_backendr   r   r   r   Z_can_record_outputsr   rB   r0   r0   r.   r1   r     s   

r   c                       st   e Zd Zed fddZeedeej	 eej
 eej	 ee eej ee eej	 ee ed	ddZ  ZS )		DogeModelrE   c                    s   t     j| _ j| _t j j| j| _t	 fddt
 jD | _t j jd| _t d| _d| _|   d S )Nc                    s   g | ]}t  |qS r0   )r   ).0r   rE   r0   r1   
@auto_docstring
class DogeModel(DogePreTrainedModel):
    def __init__(self, config: DogeConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [DogeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = DogeRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeModelOutputWithPast:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
        causal_mask = mask_function(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        # create position embeddings shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )


def load_balancing_loss_func(
    gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
    num_experts: Optional[int] = None,
    num_keys: Optional[int] = None,
    top_k: int = 2,
    attention_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, int]:
    r"""
    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.

    See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
    experts is too unbalanced.

    Args:
        gate_logits:
            Logits from the `router_gate`, should be a tuple of model.config.num_hidden_layers tensors of
            shape [2, batch_size * sequence_length, num_keys].
        num_experts:
            Number of experts
        num_keys:
            Number of keys
        top_k:
            The number of experts to route per-token, can be also interpreted as the `top-k` routing
            parameter.
        attention_mask (`torch.Tensor`, *optional*):
            The attention_mask used in forward function
            shape [batch_size X sequence_length] if not None.

    Returns:
        The auxiliary loss.
    """
    if gate_logits is None or not isinstance(gate_logits, tuple):
        return 0

    compute_dtype = gate_logits[0].dtype
    compute_device = gate_logits[0].device
    all_expert_indices = []
    all_routing_weights = []

    for layer_gate_logits in gate_logits:
        layer_gate_logits = layer_gate_logits.to(compute_device)

        # replay the product-key routing of `DogeCDMoE` to recover the selected experts
        (scores_x, scores_y), (indices_x, indices_y) = layer_gate_logits.topk(num_keys, dim=-1)
        all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
        all_indices = indices_x.unsqueeze(-1) * num_keys + indices_y.unsqueeze(-2)
        all_scores = all_scores.view(*all_scores.shape[:-2], -1)
        all_indices = all_indices.view(*all_indices.shape[:-2], -1)

        scores, position_indices = all_scores.topk(top_k, dim=-1)
        expert_indices = all_indices.gather(-1, position_indices)
        routing_weights = F.softmax(scores, dim=-1)

        all_expert_indices.append(expert_indices)
        all_routing_weights.append(routing_weights)

    all_expert_indices = torch.cat(all_expert_indices, dim=0).reshape(-1)
    all_routing_weights = torch.cat(all_routing_weights, dim=0).reshape(-1)

    if attention_mask is None:
        token_weights = torch.ones_like(all_expert_indices, dtype=compute_dtype, device=compute_device)
    else:
        batch_size, sequence_length = attention_mask.shape
        num_hidden_layers = len(gate_logits)
        # mask that zeroes out the padding tokens, flattened to line up with the expert indices
        token_weights = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, top_k))
            .reshape(-1)
            .to(device=compute_device, dtype=compute_dtype)
        )

    # fraction of (non-padding) routing slots assigned to each expert
    tokens_per_expert = torch.zeros(num_experts, dtype=compute_dtype, device=compute_device)
    tokens_per_expert = tokens_per_expert.scatter_add_(0, all_expert_indices, token_weights) / torch.sum(token_weights)

    # average routing probability mass assigned to each expert
    router_prob_per_expert = torch.zeros(num_experts, dtype=compute_dtype, device=compute_device)
    router_prob_per_expert = router_prob_per_expert.scatter_add_(
        0, all_expert_indices, all_routing_weights * token_weights
    ) / torch.sum(token_weights)

    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert)
    return overall_loss * num_experts


@auto_docstring
class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = DogeModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.router_aux_loss_coef = config.router_aux_loss_coef
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        output_router_logits: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeCausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, DogeForCausalLM

        >>> model = DogeForCausalLM.from_pretrained("SmallDoge/Doge-320M")
        >>> tokenizer = AutoTokenizer.from_pretrained("SmallDoge/Doge-320M")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )

        outputs: MoeModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size, **kwargs)

        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits,
                self.num_experts,
                math.floor(math.sqrt(self.num_experts)),
                self.num_experts_per_tok,
                attention_mask,
            )
            if labels is not None:
                # make sure the auxiliary loss lives on the same device as the main loss
                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )


class DogeForSequenceClassification(GenericForSequenceClassification, DogePreTrainedModel):
    pass


__all__ = ["DogeForCausalLM", "DogeModel", "DogePreTrainedModel", "DogeForSequenceClassification"]