# coding=utf-8
# NOTE: readable source reconstructed from the byte-compiled module at this path; where the compiled
# constants leave gaps, the bodies below are a best-effort sketch rather than verbatim upstream code.
"""PyTorch XGLM model."""

import math
from typing import Optional, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_xglm import XGLMConfig


logger = logging.get_logger(__name__)

class XGLMScaledWordEmbedding(nn.Embedding):
    """
    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor):
        return super().forward(input_ids) * self.embed_scale


class XGLMSinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length."""

    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
        super().__init__()
        self.offset = 2
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)

    def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
        emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
        if hasattr(self, "weights"):
            # when resizing, keep the dtype and device of the existing buffer
            emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)

        self.register_buffer("weights", emb_weights, persistent=False)

    @staticmethod
    def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
        """
        Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
        "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad the extra dimension when embedding_dim is odd
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0

        return emb.to(torch.get_default_dtype())

    @torch.no_grad()
    def forward(self, position_ids: torch.Tensor = None, past_key_values_length: int = 0):
        bsz, seq_len = position_ids.size()
        position_ids += self.offset

        # Expand the embedding table if the requested positions exceed its current size.
        max_pos = 2 + seq_len + past_key_values_length
        if max_pos > self.weights.size(0):
            self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)

        return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach()

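
# Illustrative helper, not part of the upstream module: a quick look at the sinusoidal table built by
# `get_embedding` above. The helper name and the tiny sizes are assumptions chosen only for clarity.
def _demo_sinusoidal_table():
    table = XGLMSinusoidalPositionalEmbedding.get_embedding(num_embeddings=16, embedding_dim=8)
    # One row per position; the first half of each row holds the sin terms, the second half the cos terms.
    print(table.shape)   # torch.Size([16, 8])
    print(table[0, 4:])  # at position 0 every cos column equals 1.0
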
class XGLMAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        layer_idx: Optional[int] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        if self.head_dim * num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.layer_idx = layer_idx

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        cache_position: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided, this layer is used as a cross-attention layer for the decoder
        is_cross_attention = key_value_states is not None
        bsz, tgt_len, _ = hidden_states.size()

        # get query projection (scaling is folded in here)
        query_states = self.q_proj(hidden_states) * self.scaling

        if past_key_values is not None:
            if isinstance(past_key_values, EncoderDecoderCache):
                is_updated = past_key_values.is_updated.get(self.layer_idx)
                # pick the self-attention or cross-attention sub-cache of the joint cache
                curr_past_key_value = (
                    past_key_values.cross_attention_cache
                    if is_cross_attention
                    else past_key_values.self_attention_cache
                )
            else:
                curr_past_key_value = past_key_values

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_values is not None and is_updated:
            # re-use the cross-attention key/value states computed on the first decoding step
            key_states = curr_past_key_value.layers[self.layer_idx].keys
            value_states = curr_past_key_value.layers[self.layer_idx].values
        else:
            key_states = self.k_proj(current_states)
            value_states = self.v_proj(current_states)
            key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
            value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)

            if past_key_values is not None:
                # save all key/value states to the cache to be re-used for fast auto-regressive generation
                cache_position = cache_position if not is_cross_attention else None
                key_states, value_states = curr_past_key_value.update(
                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                )
                # mark the cross-attention cache of this layer as filled so it can be re-used on later steps
                if is_cross_attention:
                    past_key_values.is_updated[self.layer_idx] = True

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
        query_states = query_states.reshape(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = torch.max(
                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        # upcast to fp32 when computing the softmax in half precision to avoid overflow
        if attn_weights.dtype == torch.float16:
            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
        else:
            attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # reshape twice so that the returned attention weights keep their gradient
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped

zXGLMAttention.forward)rP   FTN)NNNNFN)r&   r'   r(   r)   r*   r   r+   boolr   r   r,   r-   r   tupler%   r.   r"   r"   r    r#   rO   i   s<             rO   c                       s   e Zd Zded fddZedddddejeej eej eej eej eej ee	 ee
 ee
 eej ejdddZ  ZS )XGLMDecoderLayerNconfigc                    s   t    |j| _t| j|j|jd|d| _|j| _t	|j
 | _|j| _|jrvt| j|j|jd|d| _t| j| _t| j| _t| j|j| _t|j| j| _t| j| _d S )NT)rQ   rR   rS   rT   rV   )r   r   d_modelrQ   rO   Zattention_headsZattention_dropout	self_attnrS   r   Zactivation_functionactivation_fnactivation_dropoutZadd_cross_attentionencoder_attnr   	LayerNormencoder_attn_layer_normself_attn_layer_normr[   Zffn_dimfc1fc2final_layer_norm)r   r   rV   r    r"   r#   r   
  s2    
zXGLMDecoderLayer.__init__r`   ra   rb   rc   FT)rf   rh   encoder_hidden_statesencoder_attention_maskri   cross_attn_layer_head_maskra   rj   	use_cacherk   rl   c              	   C   s  |}|  |}| j||||||
d\}}tjj|| j| jd}|| }d}|dur|}| |}| j|||||||
d\}}tjj|| j| jd}|| }|}| |}| 	| 
|}tjj|| j| jd}| |}tjj|| j| jd}|| }|f}|r|||f7 }|S )a  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
                size `(decoder_attention_heads,)`.
            past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        )rf   ra   rh   ri   rj   rk   rm   N)rf   rg   rh   ri   ra   rj   rk   )r   r   r   ry   rS   ro   r   r   r   r   r   r   r   )r   rf   rh   r   r   ri   r   ra   rj   r   rk   ZresidualZself_attn_weightsZcross_attn_weightsoutputsr"   r"   r#   r%   (  sL    !



	

zXGLMDecoderLayer.forward)N)	NNNNNNFTN)r&   r'   r(   r   r   r   r,   r-   r   r   r|   r%   r.   r"   r"   r    r#   r~   	  s0            r~   c                   @   s,   e Zd ZU eed< dZdZdgZdd ZdS )XGLMPreTrainedModelr   modelTr~   c                 C   s|   | j j}t|tjr>|jjjd|d |jd urx|jj	  n:t|tj
rx|jjjd|d |jd urx|jj|j 	  d S )NrP   )meanstd)r   Zinit_stdrp   r   r[   weightdataZnormal_rU   Zzero_	Embeddingr   )r   moduler   r"   r"   r#   _init_weights  s    

z!XGLMPreTrainedModel._init_weightsN)	r&   r'   r(   r   __annotations__base_model_prefixZsupports_gradient_checkpointingZ_no_split_modulesr   r"   r"   r"   r#   r   {  s
   
r   c                       s   e Zd Zdeeej d fddZed	ee	j
 ee	j
 ee	j
 ee	j
 ee	j
 ee	j
 ee	j
 eee	j  ee	j
 ee ee ee ee ee	j
 eee	j
 ef dddZ  ZS )
	XGLMModelN)r   embed_tokensc                    s   t     j| _ j| _ j| _ j| _ jr>t	
 jnd}|durR|| _nt j j| j|d| _t j j j| _t fddt jD | _t j| _d| _|   dS )zZ
        embed_tokens (`nn.Embedding`, *optional*):
            output embeddings
        r   N)r   c                    s   g | ]}t  |d qS ))rV   )r~   ).0ir   r"   r#   
<listcomp>      z&XGLMModel.__init__.<locals>.<listcomp>F)r   r   rS   	layerdroppad_token_idr   Zmax_position_embeddingsZmax_target_positionsZscale_embeddingr?   sqrtr   r   r   
vocab_sizer/   embed_positionsr   Z
ModuleListrangeZ
num_layersrs   r   
layer_normgradient_checkpointing	post_init)r   r   r   r   r    r   r#   r     s(     zXGLMModel.__init__)r$   rh   rH   r   r   	head_maskcross_attn_head_maskra   inputs_embedsr   rj   output_hidden_statesreturn_dictrk   rl   c                 C   s  |dur|n| j j}|dur |n| j j}|
dur4|
n| j j}
|durH|n| j j}|durj|	durjtdnP|dur| || | }|d|d }n"|	dur|	 dd }ntd|	du r| 	|}	| j
r| jr|
rtd d}
|
r*|du r*|durtt| j dt| j dn
t| j d}|
rPt|trPtd t|}|durb| nd	}t|||	|}|du rtj||d | tj|dur|jn|	jd
}|d	}|dur|durt||	j|d d}|	| |||	j }tjj |t!| j | jd}|rdnd}|r(dnd}|r@|dur@dnd}t"||gddgD ]V\}}|durV| d	 t#| j$krVtd| dt#| j$ d| d	  dqVt%| j$D ]\}}|r||f7 }| jrt&g }|| j'k rq||||||dur|| nd|dur"|| nd|||
|d
}|d	 }|r||d f7 }|dur||d f7 }q| (|}|r||f7 }|stdd |||||fD S t)|||||dS )a  
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
            the decoder.
        encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
            Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
            selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        cross_attn_head_mask (`torch.Tensor` of shape `(num_layers, attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        NzDYou cannot specify both input_ids and inputs_embeds at the same timer>   z5You have to specify either input_ids or inputs_embedsz_`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache = False`...Fr   zPassing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.r   r5   )rz   rm   r"   r   r   zThe `z` should be specified for z layers, but it is for .)r   ri   r   ra   rj   r   rk   r   r1   c                 s   s   | ]}|d ur|V  qd S r   r"   )r   vr"   r"   r#   	<genexpr>J  s   z$XGLMModel.forward.<locals>.<genexpr>)Zlast_hidden_statera   rf   
attentionscross_attentions)*r   rj   r   r   use_return_dictrY   Z%warn_if_padding_and_no_attention_maskrJ   rG   r   r   ro   loggerZwarning_oncer	   r   rp   r}   Zfrom_legacy_cacheZget_seq_lengthr   r,   rB   longr7   rC   r   r6   r   r;   r   ry   rS   r+   ziplenrs   	enumerateZrandr   r   r   )r   r$   rh   rH   r   r   r   r   ra   r   r   rj   r   r   rk   Zinput_shaperI   rf   Zall_hidden_statesZall_self_attnsZall_cross_attentionsZ	attn_maskZ	mask_nameidxZdecoder_layerZdropout_probabilityZlayer_outputsr"   r"   r#   r%     s    $













zXGLMModel.forward)N)NNNNNNNNNNNNNN)r&   r'   r(   r   r   r   r   r   r   r,   r-   listFloatTensorr|   r   r}   r   r%   r.   r"   r"   r    r#   r     sB                 r   z
    The XGLM Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    )Zcustom_introc                       s   e Zd ZdZdgZ fddZed	eej	 eej	 eej	 eej	 eej	 eej	 eej	 ee
ej  eej	 eej	 ee ee ee ee eej	 eeej	 ef dddZ  ZS )
XGLMForCausalLMr   zlm_head.weightc                    s8   t  | t|| _tj|j|jdd| _| 	  d S )NFrW   )
r   r   r   r   r   r[   Zhidden_sizer   lm_headr   )r   r   r    r"   r#   r   b  s    
zXGLMForCausalLM.__init__N)r$   rh   rH   r   r   r   r   ra   r   labelsr   rj   r   r   rk   rl   c                 K   s   |dur|n| j j}|dur |n| j j}|dur4|n| j j}| j|||||||||	|||||d}| |d }d}|
dur| j||
f| j j| j jd|}|s|f|dd  }|dur|f| S |S t	|||j
|j|j|jdS )a  
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
            the decoder.
        encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
            Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
            selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        cross_attn_head_mask (`torch.Tensor` of shape `(num_layers, attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        N)r$   rh   rH   r   r   r   r   ra   r   r   rj   r   r   rk   r   )r   r   r   )losslogitsra   rf   r   r   )r   rj   r   r   r   r   Zloss_functionr   r   r   ra   rf   r   r   )r   r$   rh   rH   r   r   r   r   ra   r   r   r   rj   r   r   rk   kwargsr   r   r   outputr"   r"   r#   r%   j  sV    +zXGLMForCausalLM.forward)NNNNNNNNNNNNNNN)r&   r'   r(   r   Z_tied_weights_keysr   r   r   r,   r-   r   r   r|   r   r}   r   r%   r.   r"   r"   r    r#   r   X  sJ                  r   )r   r   r   ).r)   r?   typingr   r   r,   Ztorch.utils.checkpointr   Zactivationsr   Zcache_utilsr   r   r	   Z
generationr
   Zmodeling_attn_mask_utilsr   r   Zmodeling_layersr   Zmodeling_outputsr   r   Zmodeling_utilsr   utilsr   r   Zutils.deprecationr   Zconfiguration_xglmr   Z
get_loggerr&   r   r   r   Moduler/   rO   r~   r   r   r   __all__r"   r"   r"   r#   <module>   s>   
4 !r Ji