"""PyTorch I-BERT model."""

import math
from typing import Optional, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import gelu
from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import auto_docstring, logging
from .configuration_ibert import IBertConfig
from .quant_modules import IntGELU, IntLayerNorm, IntSoftmax, QuantAct, QuantEmbedding, QuantLinear


logger = logging.get_logger(__name__)


class IBertEmbeddings(nn.Module):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
    """

    def __init__(self, config):
        ...

    def forward(
        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        ...

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
        """
        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]

        position_ids = torch.arange(
            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
        )
        return position_ids.unsqueeze(0).expand(input_shape)


class IBertSelfAttention(nn.Module):
    def __init__(self, config):
        ...

    def forward(
        self,
        hidden_states,
        hidden_states_scaling_factor,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
    ):
        ...


class IBertSelfOutput(nn.Module):
    def __init__(self, config):
        ...

    def forward(self, hidden_states, hidden_states_scaling_factor, input_tensor, input_tensor_scaling_factor):
        ...


class IBertAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.quant_mode = config.quant_mode
        self.self = IBertSelfAttention(config)
        self.output = IBertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        hidden_states_scaling_factor,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
    ):
        ...


class IBertIntermediate(nn.Module):
    def __init__(self, config):
        ...

    def forward(self, hidden_states, hidden_states_scaling_factor):
        ...


class IBertOutput(nn.Module):
    def __init__(self, config):
        ...

    def forward(self, hidden_states, hidden_states_scaling_factor, input_tensor, input_tensor_scaling_factor):
        ...


class IBertLayer(nn.Module):
    def __init__(self, config):
        ...

    def forward(
        self,
        hidden_states,
        hidden_states_scaling_factor,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
    ):
        ...

    def feed_forward_chunk(self, attention_output, attention_output_scaling_factor):
        ...


class IBertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.quant_mode = config.quant_mode
        self.layer = nn.ModuleList([IBertLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        hidden_states,
        hidden_states_scaling_factor,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        ...


class IBertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.quant_mode = config.quant_mode
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # Pool by taking the hidden state corresponding to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


@auto_docstring
class IBertPreTrainedModel(PreTrainedModel):
    config: IBertConfig
    base_model_prefix = "ibert"

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (QuantLinear, nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (QuantEmbedding, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, (IntLayerNorm, nn.LayerNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, IBertLMHead):
            module.bias.data.zero_()

    def resize_token_embeddings(self, new_num_tokens=None):
        raise NotImplementedError("`resize_token_embeddings` is not supported for I-BERT.")


@auto_docstring
class IBertModel(IBertPreTrainedModel):
    """

    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
    all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    """

    def __init__(self, config, add_pooling_layer=True):
        r"""
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        """
        super().__init__(config)
        self.config = config
        self.quant_mode = config.quant_mode

        self.embeddings = IBertEmbeddings(config)
        self.encoder = IBertEncoder(config)

        self.pooler = IBertPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[BaseModelOutputWithPoolingAndCrossAttentions, tuple[torch.FloatTensor]]:
        ...


@auto_docstring
class IBertForMaskedLM(IBertPreTrainedModel):
    _tied_weights_keys = ["lm_head.decoder.bias", "lm_head.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)

        self.ibert = IBertModel(config, add_pooling_layer=False)
        self.lm_head = IBertLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings
        self.lm_head.bias = new_embeddings.bias

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[MaskedLMOutput, tuple[torch.FloatTensor]]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """
        ...


class IBertLMHead(nn.Module):
    """I-BERT Head for masked language modeling."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.decoder.bias = self.bias

    def forward(self, features, **kwargs):
        x = self.dense(features)
        x = gelu(x)
        x = self.layer_norm(x)

        # project back to size of vocabulary with bias
        x = self.decoder(x)

        return x

    def _tie_weights(self) -> None:
        # Tie the bias back to the decoder if it was disconnected (e.g. on a meta device).
        if self.decoder.bias.device.type == "meta":
            self.decoder.bias = self.bias
        else:
            self.bias = self.decoder.bias


@auto_docstring(
    custom_intro="""
    I-BERT Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """
)
class IBertForSequenceClassification(IBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.ibert = IBertModel(config, add_pooling_layer=False)
        self.classifier = IBertClassificationHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[SequenceClassifierOutput, tuple[torch.FloatTensor]]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        ...


@auto_docstring
class IBertForMultipleChoice(IBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.ibert = IBertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[MultipleChoiceModelOutput, tuple[torch.FloatTensor]]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        """
        ...


@auto_docstring
class IBertForTokenClassification(IBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.ibert = IBertModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[TokenClassifierOutput, tuple[torch.FloatTensor]]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        ...


class IBertClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        hidden_states = features[:, 0, :]  # take the hidden state of the first token
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.dense(hidden_states)
        hidden_states = torch.tanh(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.out_proj(hidden_states)
        return hidden_states


@auto_docstring
class IBertForQuestionAnswering(IBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.ibert = IBertModel(config, add_pooling_layer=False)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[QuestionAnsweringModelOutput, tuple[torch.FloatTensor]]:
        ...


def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    """
    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's *utils.make_positions*.

    Args:
    input_ids (`torch.LongTensor`):
           Indices of input sequence tokens in the vocabulary.

    Returns: torch.Tensor
    """
    # Positions count up from padding_idx + 1 over non-padding tokens only; padding positions stay at padding_idx.
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
    return incremental_indices.long() + padding_idx


__all__ = [
    "IBertForMaskedLM",
    "IBertForMultipleChoice",
    "IBertForQuestionAnswering",
    "IBertForSequenceClassification",
    "IBertForTokenClassification",
    "IBertModel",
    "IBertPreTrainedModel",
]
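# A minimal usage sketch (not part of the upstream module): loading an I-BERT checkpoint and
# running a forward pass through `IBertModel`. The checkpoint name and example sentence below
# are assumptions for illustration only; substitute any I-BERT checkpoint available to you.
#
#   from transformers import AutoTokenizer, IBertModel
#
#   tokenizer = AutoTokenizer.from_pretrained("kssteven/ibert-roberta-base")  # assumed checkpoint
#   model = IBertModel.from_pretrained("kssteven/ibert-roberta-base")
#
#   inputs = tokenizer("Integer-only BERT inference.", return_tensors="pt")
#   outputs = model(**inputs)
#   last_hidden_state = outputs.last_hidden_state  # shape: (batch_size, seq_len, hidden_size)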