a
    h                     @   s  d Z ddlmZmZ ddlZddlZddlmZ ddlmZm	Z	m
Z
 ddlmZ ddlmZ dd	lmZmZmZmZmZ dd
lmZ ddlmZmZ ddlmZ eeZG dd dejZ G dd dejZ!ej"j#dd Z$ej"j#dd Z%ej"j#dd Z&ej"j#dd Z'ej"j#ej(e)dddZ*ej"j#ej(ej(dddZ+ej"j#ej(ej(e)d d!d"Z,ej"j#ej(ej(dd#d$Z-G d%d& d&ejZ.G d'd( d(ejZ/G d)d* d*ejZ0G d+d, d,ejZ1G d-d. d.ejZ2G d/d0 d0eZ3G d1d2 d2ejZ4eG d3d4 d4eZ5eG d5d6 d6e5Z6G d7d8 d8ejZ7G d9d: d:ejZ8G d;d< d<ejZ9G d=d> d>ejZ:G d?d@ d@ejZ;eG dAdB dBe5Z<G dCdD dDejZ=edEdFG dGdH dHe5Z>eG dIdJ dJe5Z?eG dKdL dLe5Z@g dMZAdS )NzPyTorch DeBERTa model.    )OptionalUnionN)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)GradientCheckpointingLayer)BaseModelOutputMaskedLMOutputQuestionAnsweringModelOutputSequenceClassifierOutputTokenClassifierOutput)PreTrainedModel)auto_docstringlogging   )DebertaConfigc                       s*   e Zd ZdZd fdd	Zdd Z  ZS )DebertaLayerNormzBLayerNorm module in the TF style (epsilon inside the square root).-q=c                    s8   t    tt|| _tt|| _|| _	d S N)
super__init__r   	Parametertorchonesweightzerosbiasvariance_epsilon)selfsizeeps	__class__ h/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/deberta/modeling_deberta.pyr   ,   s    
zDebertaLayerNorm.__init__c                 C   sj   |j }| }|jddd}|| djddd}|| t|| j  }||}| j| | j	 }|S )NT)Zkeepdim   )
dtypefloatmeanpowr   sqrtr    tor   r   )r!   hidden_statesZ
input_typer,   Zvarianceyr&   r&   r'   forward2   s    
zDebertaLayerNorm.forward)r   __name__
__module____qualname____doc__r   r2   __classcell__r&   r&   r$   r'   r   )   s   r   c                       s$   e Zd Z fddZdd Z  ZS )DebertaSelfOutputc                    s>   t    t|j|j| _t|j|j| _t	|j
| _d S r   )r   r   r   Linearhidden_sizedenser   layer_norm_eps	LayerNormDropouthidden_dropout_probdropoutr!   configr$   r&   r'   r   >   s    
zDebertaSelfOutput.__init__c                 C   s&   |  |}| |}| || }|S r   r<   rA   r>   r!   r0   Zinput_tensorr&   r&   r'   r2   D   s    

zDebertaSelfOutput.forwardr4   r5   r6   r   r2   r8   r&   r&   r$   r'   r9   =   s   r9   c                 C   s   |  d}| d}tj|tj| jd}tj|tj|jd}|dddf |dd|d }|d|ddf }|d}|S )a  
    Build relative position according to the query and key

    We assume the absolute position of query \(P_q\) is range from (0, query_size) and the absolute position of key
    \(P_k\) is range from (0, key_size), The relative positions from query to key is \(R_{q \rightarrow k} = P_q -
    P_k\)

    Args:
        query_size (int): the length of query
        key_size (int): the length of key

    Return:
        `torch.LongTensor`: A tensor with shape [1, query_size, key_size]

    r*   deviceNr   r(   r   )r"   r   arangelongrI   viewrepeat	unsqueeze)query_layer	key_layerZ
query_sizeZkey_sizeZq_idsZk_idsZrel_pos_idsr&   r&   r'   build_relative_positionK   s    

$
rQ   c                 C   s*   |  |d|d|d|dgS )Nr   r   r)   r(   expandr"   )c2p_posrO   relative_posr&   r&   r'   c2p_dynamic_expandh   s    rV   c                 C   s*   |  |d|d|d|dgS )Nr   r   rG   rR   )rT   rO   rP   r&   r&   r'   p2c_dynamic_expandm   s    rW   c                 C   s*   |  | d d | d|df S )Nr)   rG   rR   )	pos_indexp2c_attrP   r&   r&   r'   pos_dynamic_expandr   s    rZ   rO   scale_factorc                 C   s    t t j| dt jd| S )Nr(   r*   )r   r.   tensorr"   r+   r[   r&   r&   r'   scaled_size_sqrtz   s    r_   )rO   rP   c                 C   s&   |  d| dkrt| |S |S d S NrG   )r"   rQ   )rO   rP   rU   r&   r&   r'   
build_rpos   s    
ra   rO   rP   max_relative_positionsc                 C   s"   t tt| d|d|S r`   )r   r^   minmaxr"   rb   r&   r&   r'   compute_attention_span   s    rf   c                 C   sV   | d| dkrN|d d d d d d df d}tj| dt|| |dS | S d S )NrG   r   r(   r)   dimindex)r"   rN   r   gatherrZ   )rY   rO   rP   rU   rX   r&   r&   r'   uneven_size_corrected   s    "rk   c                       s   e Zd ZdZ fddZdd Zdejejee	ej e	ej e	ej e
eje	ej f dd	d
ZejejejejedddZ  ZS )DisentangledSelfAttentiona  
    Disentangled self-attention module

    Parameters:
        config (`str`):
            A model config class instance with the configuration to build a new model. The schema is similar to
            *BertConfig*, for more details, please refer [`DebertaConfig`]

    c                    s  t    |j|j dkr4td|j d|j d|j| _t|j|j | _| j| j | _tj	|j| jd dd| _
ttj| jtjd| _ttj| jtjd| _|jd ur|jng | _t|d	d| _t|d
d| _| jrtj	|j|jdd| _tj	|j|jdd| _nd | _d | _| jrt|dd| _| jdk rH|j| _t|j| _d| jv rxtj	|j| jdd| _d| jv rt	|j| j| _t|j| _d S )Nr   zThe hidden size (z6) is not a multiple of the number of attention heads ()r   Fr   r]   relative_attentiontalking_headrc   r(   r   c2pp2c) r   r   r;   num_attention_heads
ValueErrorintZattention_head_sizeZall_head_sizer   r:   in_projr   r   r   r+   q_biasv_biaspos_att_typegetattrro   rp   head_logits_projhead_weights_projrc   max_position_embeddingsr?   r@   pos_dropoutpos_proj
pos_q_projZattention_probs_dropout_probrA   rB   r$   r&   r'   r      s>    

z"DisentangledSelfAttention.__init__c                 C   s4   |  d d | jdf }||}|ddddS )Nr(   r   r)   r   r   )r"   rs   rL   permute)r!   xZnew_x_shaper&   r&   r'   transpose_for_scores   s    
z.DisentangledSelfAttention.transpose_for_scoresFN)r0   attention_maskoutput_attentionsquery_statesrU   rel_embeddingsreturnc                    s  |du r.  |} |jddd\}}	}
n j jj jd dd fddtdD }t|d | j	|d j
d}t|d	 | j	|d	 j
d}t|d
 | j	|d
 j
d} fdd|||fD \}}	}
|  jddddf  }|
  jddddf  }
d}d	t j }t||}||j	|j
d }t||	dd} jr|dur|dur |} ||	|||}|dur|| } jdur؈ |dd
dd	ddd	d
}| }|| t|j
j}tjj|dd} |} jdur@ |dd
dd	ddd	d
}t||
}|dd
d	d }|  dd d }|!|}|s|dfS ||fS )a  
        Call the module

        Args:
            hidden_states (`torch.FloatTensor`):
                Input states to the module usually the output from previous layer, it will be the Q,K and V in
                *Attention(Q,K,V)*

            attention_mask (`torch.BoolTensor`):
                An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
                sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
                th token.

            output_attentions (`bool`, *optional*):
                Whether return the attention matrix.

            query_states (`torch.FloatTensor`, *optional*):
                The *Q* state in *Attention(Q,K,V)*.

            relative_pos (`torch.LongTensor`):
                The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with
                values ranging in [*-max_relative_positions*, *max_relative_positions*].

            rel_embeddings (`torch.FloatTensor`):
                The embedding of relative distances. It's a tensor of shape [\(2 \times
                \text{max_relative_positions}\), *hidden_size*].


        Nr   r(   rh   r   c                    s0   g | ]( t j fd dtjD ddqS )c                    s   g | ]}|d     qS )r   r&   ).0i)kwsr&   r'   
<listcomp>       z@DisentangledSelfAttention.forward.<locals>.<listcomp>.<listcomp>r   r   )r   catrangers   )r   r!   r   )r   r'   r      r   z5DisentangledSelfAttention.forward.<locals>.<listcomp>r]   r   r)   c                    s   g | ]}  |qS r&   )r   )r   r   r!   r&   r'   r      r   rG   r(   )"rv   r   chunkr   rs   r   r   matmultr/   r*   rw   rx   lenry   r_   	transposero   r~   disentangled_att_biasr{   r   boolZmasked_fillZfinford   r   Z
functionalZsoftmaxrA   r|   
contiguousr"   rL   )r!   r0   r   r   r   rU   r   ZqprO   rP   Zvalue_layerZqkvwqr   vZrel_attr\   scaleZattention_scoresZattention_probsZcontext_layerZnew_context_layer_shaper&   r   r'   r2      sH    &
"""


"
"
z!DisentangledSelfAttention.forward)rO   rP   rU   r   r\   c                 C   s  |d u rt |||j}| dkr4|dd}n6| dkrL|d}n| dkrjtd|  t||| j}| }|| j| | j| d d f d}d}d| jv r| 	|}| 
|}t||dd	}	t|| d|d d }
tj|	dt|
||d
}	||	7 }d| jv r| |}| 
|}|t|| }t|||}t| | d|d d }t||dd	j|jd}tj|dt|||d
dd	}t||||}||7 }|S )Nr)   r   r   r      z2Relative position ids must be of dim 2 or 3 or 4. rq   r(   rG   rg   rr   r]   )rQ   rI   rh   rN   rt   rf   rc   rK   ry   r   r   r   r   r   clamprj   rV   r   r_   ra   r/   r*   rW   rk   )r!   rO   rP   rU   r   r\   Zatt_spanZscoreZpos_key_layerZc2p_attrT   Zpos_query_layerZr_posZp2c_posrY   r&   r&   r'   r   %  sT    



z/DisentangledSelfAttention.disentangled_att_bias)FNNN)r4   r5   r6   r7   r   r   r   Tensorr   r   tupler2   ru   r   r8   r&   r&   r$   r'   rl      s,   
&	    Yrl   c                       s*   e Zd ZdZ fddZdddZ  ZS )DebertaEmbeddingszGConstruct the embeddings from word, position and token_type embeddings.c                    s   t    t|dd}t|d|j| _tj|j| j|d| _t|dd| _	| j	sXd | _
nt|j| j| _
|jdkrt|j| j| _nd | _| j|jkrtj| j|jdd| _nd | _t|j|j| _t|j| _|| _| jd	t|jd
dd d S )Npad_token_idr   embedding_size)padding_idxposition_biased_inputTFrn   position_ids)r   r(   )
persistent)r   r   rz   r;   r   r   	Embedding
vocab_sizeword_embeddingsr   position_embeddingsr}   Ztype_vocab_sizetoken_type_embeddingsr:   
embed_projr   r=   r>   r?   r@   rA   rC   Zregister_bufferr   rJ   rS   )r!   rC   r   r$   r&   r'   r   a  s(    

zDebertaEmbeddings.__init__Nc                 C   sN  |d ur|  }n|  d d }|d }|d u rH| jd d d |f }|d u rftj|tj| jjd}|d u rx| |}| jd ur| | }n
t|}|}	| j	r|	| }	| j
d ur| 
|}
|	|
 }	| jd ur| |	}	| |	}	|d ur@| |	 kr,| dkr"|dd}|d}||	j}|	| }	| |	}	|	S )Nr(   r   rH   r   r)   )r"   r   r   r   rK   rI   r   r   Z
zeros_liker   r   r   r>   rh   squeezerN   r/   r*   rA   )r!   	input_idstoken_type_idsr   maskinputs_embedsinput_shapeZ
seq_lengthr   
embeddingsr   r&   r&   r'   r2     s>    











zDebertaEmbeddings.forward)NNNNNr3   r&   r&   r$   r'   r   ^  s   r   c                       s>   e Zd Z fddZdeeejeej f dddZ	  Z
S )	DebertaAttentionc                    s(   t    t|| _t|| _|| _d S r   )r   r   rl   r!   r9   outputrC   rB   r$   r&   r'   r     s    


zDebertaAttention.__init__FNr   r   c           
      C   sJ   | j ||||||d\}}|d u r&|}| ||}	|r>|	|fS |	d fS d S )N)r   rU   r   )r!   r   )
r!   r0   r   r   r   rU   r   Zself_output
att_matrixattention_outputr&   r&   r'   r2     s    	
zDebertaAttention.forward)FNNNr4   r5   r6   r   r   r   r   r   r   r2   r8   r&   r&   r$   r'   r     s   
    r   c                       s0   e Zd Z fddZejejdddZ  ZS )DebertaIntermediatec                    sB   t    t|j|j| _t|jt	r6t
|j | _n|j| _d S r   )r   r   r   r:   r;   intermediate_sizer<   
isinstance
hidden_actstrr	   intermediate_act_fnrB   r$   r&   r'   r     s
    
zDebertaIntermediate.__init__)r0   r   c                 C   s   |  |}| |}|S r   )r<   r   r!   r0   r&   r&   r'   r2     s    

zDebertaIntermediate.forwardr4   r5   r6   r   r   r   r2   r8   r&   r&   r$   r'   r     s   r   c                       s$   e Zd Z fddZdd Z  ZS )DebertaOutputc                    sD   t    t|j|j| _t|j|j| _	t
|j| _|| _d S r   )r   r   r   r:   r   r;   r<   r   r=   r>   r?   r@   rA   rC   rB   r$   r&   r'   r     s
    
zDebertaOutput.__init__c                 C   s&   |  |}| |}| || }|S r   rD   rE   r&   r&   r'   r2     s    

zDebertaOutput.forwardrF   r&   r&   r$   r'   r     s   r   c                       s>   e Zd Z fddZdeeejeej f dddZ	  Z
S )	DebertaLayerc                    s,   t    t|| _t|| _t|| _d S r   )r   r   r   	attentionr   intermediater   r   rB   r$   r&   r'   r     s    


zDebertaLayer.__init__NFr   c                 C   sH   | j ||||||d\}}| |}	| |	|}
|r<|
|fS |
d fS d S )Nr   r   rU   r   )r   r   r   )r!   r0   r   r   rU   r   r   r   r   Zintermediate_outputZlayer_outputr&   r&   r'   r2     s    	

zDebertaLayer.forward)NNNFr   r&   r&   r$   r'   r     s   
    r   c                       sV   e Zd ZdZ fddZdd Zdd Zdd	d
Zdej	ej	e
e
e
dddZ  ZS )DebertaEncoderz8Modified BertEncoder with relative position bias supportc                    s~   t    t fddt jD | _t dd| _| jrtt dd| _	| j	dk r^ j
| _	t| j	d  j| _d| _d S )	Nc                    s   g | ]}t  qS r&   )r   r   _rC   r&   r'   r     r   z+DebertaEncoder.__init__.<locals>.<listcomp>ro   Frc   r(   r   r)   )r   r   r   Z
ModuleListr   Znum_hidden_layerslayerrz   ro   rc   r}   r   r;   r   Zgradient_checkpointingrB   r$   r   r'   r     s    
 
zDebertaEncoder.__init__c                 C   s   | j r| jjnd }|S r   )ro   r   r   )r!   r   r&   r&   r'   get_rel_embedding  s    z DebertaEncoder.get_rel_embeddingc                 C   sL   |  dkr2|dd}||dd }n|  dkrH|d}|S )Nr)   r   rG   r(   r   )rh   rN   r   )r!   r   Zextended_attention_maskr&   r&   r'   get_attention_mask#  s    
z!DebertaEncoder.get_attention_maskNc                 C   s0   | j r,|d u r,|d ur"t||}n
t||}|S r   )ro   rQ   )r!   r0   r   rU   r&   r&   r'   get_rel_pos,  s
    
zDebertaEncoder.get_rel_posTF)r0   r   output_hidden_statesr   return_dictc              	   C   s   |  |}| |||}|r"|fnd }|r.dnd }	|}
|  }t| jD ]N\}}||
|||||d\}}|rv||f }|d ur|}n|}
|rH|	|f }	qH|stdd |||	fD S t|||	dS )Nr&   )r   rU   r   r   c                 s   s   | ]}|d ur|V  qd S r   r&   )r   r   r&   r&   r'   	<genexpr>]  r   z)DebertaEncoder.forward.<locals>.<genexpr>Zlast_hidden_stater0   
attentions)r   r   r   	enumerater   r   r   )r!   r0   r   r   r   r   rU   r   Zall_hidden_statesZall_attentionsZnext_kvr   r   Zlayer_moduleZatt_mr&   r&   r'   r2   4  s6    


	
zDebertaEncoder.forward)NN)TFNNT)r4   r5   r6   r7   r   r   r   r   r   r   r   r2   r8   r&   r&   r$   r'   r     s"   	
     r   c                   @   s,   e Zd ZU eed< dZdgZdZdd ZdS )DebertaPreTrainedModelrC   debertar   Tc                 C   s   t |tjr:|jjjd| jjd |jdur|jj	  nt |tj
rz|jjjd| jjd |jdur|jj|j 	  njt |tjtfr|jjd |jj	  n>t |tr|jj	  |jj	  nt |ttfr|jj	  dS )zInitialize the weights.g        )r,   ZstdNg      ?)r   r   r:   r   dataZnormal_rC   Zinitializer_ranger   Zzero_r   r   r>   r   Zfill_rl   rw   rx   LegacyDebertaLMPredictionHeadDebertaLMPredictionHead)r!   moduler&   r&   r'   _init_weightsj  s     


z$DebertaPreTrainedModel._init_weightsN)	r4   r5   r6   r   __annotations__Zbase_model_prefixZ"_keys_to_ignore_on_load_unexpectedZsupports_gradient_checkpointingr   r&   r&   r&   r'   r   c  s
   
r   c                       s   e Zd Z fddZdd Zdd Zdd Zedee	j
 ee	j
 ee	j
 ee	j
 ee	j
 ee ee ee eeef d
	ddZ  ZS )DebertaModelc                    s8   t  | t|| _t|| _d| _|| _|   d S Nr   )	r   r   r   r   r   encoderz_stepsrC   	post_initrB   r$   r&   r'   r     s    

zDebertaModel.__init__c                 C   s   | j jS r   r   r   r   r&   r&   r'   get_input_embeddings  s    z!DebertaModel.get_input_embeddingsc                 C   s   || j _d S r   r   r!   Znew_embeddingsr&   r&   r'   set_input_embeddings  s    z!DebertaModel.set_input_embeddingsc                 C   s   t ddS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        z7The prune function is not implemented in DeBERTa model.N)NotImplementedError)r!   Zheads_to_pruner&   r&   r'   _prune_heads  s    zDebertaModel._prune_headsN)	r   r   r   r   r   r   r   r   r   c	              	      s  |d ur|n j j}|d ur |n j j}|d ur4|n j j}|d urV|d urVtdn@|d urt || | }	n"|d ur| d d }	ntd|d ur|jn|j}
|d u rtj	|	|
d}|d u rtj
|	tj|
d} j|||||d} j||d||d}|d	 } jd	kr|d
 } fddt jD }|d } j } j|} j|}|d	d  D ]$}|||d|||d}|| ql|d }|s|f||rd	ndd   S t||r|jnd |jdS )NzDYou cannot specify both input_ids and inputs_embeds at the same timer(   z5You have to specify either input_ids or inputs_embeds)rI   rH   )r   r   r   r   r   T)r   r   r   r   rG   c                    s   g | ]} j jd  qS r   )r   r   r   r   r&   r'   r     r   z(DebertaModel.forward.<locals>.<listcomp>Fr   r)   r   )rC   r   r   use_return_dictrt   Z%warn_if_padding_and_no_attention_maskr"   rI   r   r   r   rK   r   r   r   r   r   r   r   appendr   r0   r   )r!   r   r   r   r   r   r   r   r   r   rI   Zembedding_outputZencoder_outputsZencoded_layersr0   Zlayersr   r   Zrel_posr   sequence_outputr&   r   r'   r2     sr    


zDebertaModel.forward)NNNNNNNN)r4   r5   r6   r   r   r   r   r   r   r   r   r   r   r   r   r2   r8   r&   r&   r$   r'   r     s0   
        
r   c                       s$   e Zd Z fddZdd Z  ZS )$LegacyDebertaPredictionHeadTransformc                    sf   t    t|d|j| _t|j| j| _t|j	t
rFt|j	 | _n|j	| _tj| j|jd| _d S )Nr   )r#   )r   r   rz   r;   r   r   r:   r<   r   r   r   r	   transform_act_fnr>   r=   rB   r$   r&   r'   r     s    
z-LegacyDebertaPredictionHeadTransform.__init__c                 C   s"   |  |}| |}| |}|S r   )r<   r   r>   r   r&   r&   r'   r2     s    


z,LegacyDebertaPredictionHeadTransform.forwardrF   r&   r&   r$   r'   r     s   r   c                       s,   e Zd Z fddZdd Zdd Z  ZS )r   c                    s\   t    t|| _t|d|j| _tj| j|j	dd| _
tt|j	| _| j| j
_d S )Nr   Frn   )r   r   r   	transformrz   r;   r   r   r:   r   decoderr   r   r   r   rB   r$   r&   r'   r     s    

z&LegacyDebertaLMPredictionHead.__init__c                 C   s   | j | j_ d S r   )r   r   r   r&   r&   r'   _tie_weights  s    z*LegacyDebertaLMPredictionHead._tie_weightsc                 C   s   |  |}| |}|S r   )r   r   r   r&   r&   r'   r2     s    

z%LegacyDebertaLMPredictionHead.forward)r4   r5   r6   r   r   r2   r8   r&   r&   r$   r'   r     s   r   c                       s0   e Zd Z fddZejejdddZ  ZS )LegacyDebertaOnlyMLMHeadc                    s   t    t|| _d S r   )r   r   r   predictionsrB   r$   r&   r'   r     s    
z!LegacyDebertaOnlyMLMHead.__init__)r   r   c                 C   s   |  |}|S r   )r   )r!   r   prediction_scoresr&   r&   r'   r2     s    
z LegacyDebertaOnlyMLMHead.forwardr   r&   r&   r$   r'   r     s   r   c                       s(   e Zd ZdZ fddZdd Z  ZS )r   zMhttps://github.com/microsoft/DeBERTa/blob/master/DeBERTa/deberta/bert.py#L270c                    sl   t    t|j|j| _t|jtr6t	|j | _
n|j| _
tj|j|jdd| _tt|j| _d S )NT)r#   Zelementwise_affine)r   r   r   r:   r;   r<   r   r   r   r	   r   r>   r=   r   r   r   r   r   rB   r$   r&   r'   r   $  s    
z DebertaLMPredictionHead.__init__c                 C   s:   |  |}| |}| |}t||j | j }|S r   )r<   r   r>   r   r   r   r   r   )r!   r0   r   r&   r&   r'   r2   2  s    

zDebertaLMPredictionHead.forwardr3   r&   r&   r$   r'   r   !  s   r   c                       s$   e Zd Z fddZdd Z  ZS )DebertaOnlyMLMHeadc                    s   t    t|| _d S r   )r   r   r   lm_headrB   r$   r&   r'   r   =  s    
zDebertaOnlyMLMHead.__init__c                 C   s   |  ||}|S r   )r   )r!   r   r   r   r&   r&   r'   r2   B  s    zDebertaOnlyMLMHead.forwardrF   r&   r&   r$   r'   r   <  s   r   c                       s   e Zd ZddgZ fddZdd Zdd Zedee	j
 ee	j
 ee	j
 ee	j
 ee	j
 ee	j
 ee ee ee eeef d

ddZ  ZS )DebertaForMaskedLMzcls.predictions.decoder.weightzcls.predictions.decoder.biasc                    sP   t  | |j| _t|| _| jr0t|| _nddg| _t|| _	| 
  d S )Nzlm_predictions.lm_head.weightz)deberta.embeddings.word_embeddings.weight)r   r   legacyr   r   r   cls_tied_weights_keysr   lm_predictionsr   rB   r$   r&   r'   r   K  s    


zDebertaForMaskedLM.__init__c                 C   s   | j r| jjjS | jjjS d S r   )r   r   r   r   r   r   r<   r   r&   r&   r'   get_output_embeddingsX  s    
z(DebertaForMaskedLM.get_output_embeddingsc                 C   s8   | j r|| jj_|j| jj_n|| jj_|j| jj_d S r   )r   r   r   r   r   r   r   r<   r   r&   r&   r'   set_output_embeddings^  s
    

z(DebertaForMaskedLM.set_output_embeddingsN
r   r   r   r   r   labelsr   r   r   r   c
              
   C   s   |	dur|	n| j j}	| j||||||||	d}
|
d }| jrH| |}n| || jjj}d}|durt }||	d| j j
|	d}|	s|f|
dd  }|dur|f| S |S t|||
j|
jdS )a  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        Nr   r   r   r   r   r   r   r   r(   r   losslogitsr0   r   )rC   r   r   r   r   r   r   r   r   rL   r   r   r0   r   )r!   r   r   r   r   r   r  r   r   r   outputsr   r   Zmasked_lm_lossloss_fctr   r&   r&   r'   r2   f  s8    zDebertaForMaskedLM.forward)	NNNNNNNNN)r4   r5   r6   r   r   r   r   r   r   r   r   r   r   r   r   r2   r8   r&   r&   r$   r'   r   G  s4            
r   c                       s0   e Zd Z fddZdd Zedd Z  ZS )ContextPoolerc                    s4   t    t|j|j| _t|j| _|| _	d S r   )
r   r   r   r:   Zpooler_hidden_sizer<   r?   Zpooler_dropoutrA   rC   rB   r$   r&   r'   r     s    
zContextPooler.__init__c                 C   s8   |d d df }|  |}| |}t| jj |}|S r   )rA   r<   r	   rC   Zpooler_hidden_act)r!   r0   Zcontext_tokenpooled_outputr&   r&   r'   r2     s
    

zContextPooler.forwardc                 C   s   | j jS r   )rC   r;   r   r&   r&   r'   
output_dim  s    zContextPooler.output_dim)r4   r5   r6   r   r2   propertyr
  r8   r&   r&   r$   r'   r    s   
r  z
    DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    )Zcustom_introc                       s   e Zd Z fddZdd Zdd Zedeej	 eej	 eej	 eej	 eej	 eej	 ee
 ee
 ee
 eeef d
d	d
Z  ZS ) DebertaForSequenceClassificationc                    s   t  | t|dd}|| _t|| _t|| _| jj}t	
||| _t|dd }|d u rd| jjn|}t	|| _|   d S )N
num_labelsr)   Zcls_dropout)r   r   rz   r  r   r   r  poolerr
  r   r:   
classifierrC   r@   r?   rA   r   )r!   rC   r  r
  Zdrop_outr$   r&   r'   r     s    

z)DebertaForSequenceClassification.__init__c                 C   s
   | j  S r   )r   r   r   r&   r&   r'   r     s    z5DebertaForSequenceClassification.get_input_embeddingsc                 C   s   | j | d S r   )r   r   r   r&   r&   r'   r     s    z5DebertaForSequenceClassification.set_input_embeddingsNr   c
              
   C   sJ  |	dur|	n| j j}	| j||||||||	d}
|
d }| |}| |}| |}d}|dur| j jdu rx| jdkrt	 }|
d|j}|||
d}n| dks|ddkrT|dk }| }|ddkrBt|d||d|d}t|d|
d}t }||
d| j |
d}ntd|}n"td}||| d  }n| j jdkrt	 }| jdkr|| | }n
|||}nN| j jdkrt }||
d| j|
d}n| j jdkrt }|||}|	s6|f|
dd  }|dur2|f| S |S t|||
j|
jd	S )
a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        N)r   r   r   r   r   r   r   r   r   r(   Z
regressionZsingle_label_classificationZmulti_label_classificationr  )rC   r   r   r  rA   r  Zproblem_typer  r   r   rL   r/   r*   rh   r"   ZnonzerorK   r   rj   rS   r   r+   r^   Z
LogSoftmaxsumr,   r   r   r   r0   r   )r!   r   r   r   r   r   r  r   r   r   r  Zencoder_layerr	  r  r  Zloss_fnZlabel_indexZlabeled_logitsr  Zlog_softmaxr   r&   r&   r'   r2     sh    




 

z(DebertaForSequenceClassification.forward)	NNNNNNNNN)r4   r5   r6   r   r   r   r   r   r   r   r   r   r   r   r2   r8   r&   r&   r$   r'   r    s2            
r  c                       sz   e Zd Z fddZedeej eej eej eej eej eej ee ee ee e	e
ef d
ddZ  ZS )DebertaForTokenClassificationc                    sJ   t  | |j| _t|| _t|j| _t	|j
|j| _|   d S r   )r   r   r  r   r   r   r?   r@   rA   r:   r;   r  r   rB   r$   r&   r'   r   &  s    
z&DebertaForTokenClassification.__init__Nr   c
              
   C   s   |	dur|	n| j j}	| j||||||||	d}
|
d }| |}| |}d}|durvt }||d| j|d}|	s|f|
dd  }|dur|f| S |S t|||
j	|
j
dS )z
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        Nr  r   r(   r   r  )rC   r   r   rA   r  r   rL   r  r   r0   r   )r!   r   r   r   r   r   r  r   r   r   r  r   r  r  r  r   r&   r&   r'   r2   1  s0    

z%DebertaForTokenClassification.forward)	NNNNNNNNN)r4   r5   r6   r   r   r   r   r   r   r   r   r   r2   r8   r&   r&   r$   r'   r  $  s.            
r  c                       s   e Zd Z fddZedeej eej eej eej eej eej eej ee ee ee e	e
ef dddZ  ZS )DebertaForQuestionAnsweringc                    s<   t  | |j| _t|| _t|j|j| _| 	  d S r   )
r   r   r  r   r   r   r:   r;   
qa_outputsr   rB   r$   r&   r'   r   d  s
    
z$DebertaForQuestionAnswering.__init__N)r   r   r   r   r   start_positionsend_positionsr   r   r   r   c              
   C   sN  |
d ur|
n| j j}
| j|||||||	|
d}|d }| |}|jddd\}}|d }|d }d }|d ur|d urt| dkr|d}t| dkr|d}|d}|	d|}|	d|}t
|d}|||}|||}|| d }|
s8||f|dd   }|d ur4|f| S |S t||||j|jdS )	Nr  r   r   r(   r   )Zignore_indexr)   )r  start_logits
end_logitsr0   r   )rC   r   r   r  splitr   r   r   r"   r   r   r   r0   r   )r!   r   r   r   r   r   r  r  r   r   r   r  r   r  r  r  Z
total_lossZignored_indexr  Z
start_lossZend_lossr   r&   r&   r'   r2   n  sN    






z#DebertaForQuestionAnswering.forward)
NNNNNNNNNN)r4   r5   r6   r   r   r   r   r   r   r   r   r   r2   r8   r&   r&   r$   r'   r  b  s2   
          
r  )r   r  r  r  r   r   )Br7   typingr   r   r   Ztorch.utils.checkpointr   Ztorch.nnr   r   r   Zactivationsr	   Zmodeling_layersr
   Zmodeling_outputsr   r   r   r   r   Zmodeling_utilsr   utilsr   r   Zconfiguration_debertar   Z
get_loggerr4   loggerModuler   r9   ZjitscriptrQ   rV   rW   rZ   r   ru   r_   ra   rf   rk   rl   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r  r  r  __all__r&   r&   r&   r'   <module>   sv   




 GQ#!Rj
Vj=K