a
    hR                     @   s  d Z ddlZddlmZmZ ddlZddlmZ ddlmZm	Z	m
Z
 ddlmZ ddlmZmZmZmZmZmZmZ dd	lmZ dd
lmZmZ ddlmZ eeZG dd dejZ G dd dejZ!G dd dej"Z#G dd dejZ$G dd dejZ%G dd dejZ&G dd dejZ'G dd dejZ(G dd dejZ)G dd  d ejZ*G d!d" d"ejZ+G d#d$ d$ejZ,eG d%d& d&eZ-eG d'd( d(e-Z.eG d)d* d*e-Z/ed+d,G d-d. d.e-Z0eG d/d0 d0e-Z1eG d1d2 d2e-Z2eG d3d4 d4e-Z3g d5Z4dS )6zPyTorch SqueezeBert model.    N)OptionalUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)BaseModelOutputBaseModelOutputWithPoolingMaskedLMOutputMultipleChoiceModelOutputQuestionAnsweringModelOutputSequenceClassifierOutputTokenClassifierOutput)PreTrainedModel)auto_docstringlogging   )SqueezeBertConfigc                       s*   e Zd ZdZ fddZdddZ  ZS )SqueezeBertEmbeddingszGConstruct the embeddings from word, position and token_type embeddings.c                    s   t    tj|j|j|jd| _t|j|j| _	t|j
|j| _tj|j|jd| _t|j| _| jdt|jddd d S )N)padding_idxepsposition_ids)r   F)
persistent)super__init__r   	Embedding
vocab_sizeembedding_sizeZpad_token_idword_embeddingsZmax_position_embeddingsposition_embeddingsZtype_vocab_sizetoken_type_embeddings	LayerNormhidden_sizelayer_norm_epsDropouthidden_dropout_probdropoutZregister_buffertorchZarangeexpandselfconfig	__class__ p/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/squeezebert/modeling_squeezebert.pyr   0   s    
zSqueezeBertEmbeddings.__init__Nc           
      C   s   |d ur|  }n|  d d }|d }|d u rH| jd d d |f }|d u rftj|tj| jjd}|d u rx| |}| |}| |}|| | }	| 	|	}	| 
|	}	|	S )Nr   r   dtypedevice)sizer   r+   zeroslongr6   r"   r#   r$   r%   r*   )
r.   	input_idstoken_type_idsr   inputs_embedsinput_shapeZ
seq_lengthr#   r$   
embeddingsr2   r2   r3   forward@   s     





zSqueezeBertEmbeddings.forward)NNNN__name__
__module____qualname____doc__r   r?   __classcell__r2   r2   r0   r3   r   -   s   r   c                       s(   e Zd ZdZ fddZdd Z  ZS )MatMulWrapperz
    Wrapper for torch.matmul(). This makes flop-counting easier to implement. Note that if you directly call
    torch.matmul() in your code, the flop counter will typically ignore the flops of the matmul.
    c                    s   t    d S N)r   r   r.   r0   r2   r3   r   _   s    zMatMulWrapper.__init__c                 C   s   t ||S )a0  

        :param inputs: two torch tensors :return: matmul of these tensors

        Here are the typical dimensions found in BERT (the B is optional) mat1.shape: [B, <optional extra dims>, M, K]
        mat2.shape: [B, <optional extra dims>, K, N] output shape: [B, <optional extra dims>, M, N]
        )r+   matmul)r.   Zmat1Zmat2r2   r2   r3   r?   b   s    zMatMulWrapper.forwardr@   r2   r2   r0   r3   rF   Y   s   rF   c                   @   s"   e Zd ZdZdddZdd ZdS )	SqueezeBertLayerNormz
    This is a nn.LayerNorm subclass that accepts NCW data layout and performs normalization in the C dimension.

    N = batch C = channels W = sequence length
    -q=c                 C   s   t jj| ||d d S )N)Znormalized_shaper   )r   r%   r   )r.   r&   r   r2   r2   r3   r   t   s    zSqueezeBertLayerNorm.__init__c                 C   s*   | ddd}tj| |}| dddS )Nr      r   )permuter   r%   r?   )r.   xr2   r2   r3   r?   w   s    zSqueezeBertLayerNorm.forwardN)rK   )rA   rB   rC   rD   r   r?   r2   r2   r2   r3   rJ   m   s   
rJ   c                       s(   e Zd ZdZ fddZdd Z  ZS )ConvDropoutLayerNormz8
    ConvDropoutLayerNorm: Conv, Dropout, LayerNorm
    c                    s8   t    tj||d|d| _t|| _t|| _d S Nr   Zin_channelsZout_channelsZkernel_sizegroups)	r   r   r   Conv1dconv1drJ   	layernormr(   r*   )r.   cincoutrR   dropout_probr0   r2   r3   r      s    

zConvDropoutLayerNorm.__init__c                 C   s*   |  |}| |}|| }| |}|S rG   )rT   r*   rU   )r.   hidden_statesZinput_tensorrN   r2   r2   r3   r?      s
    


zConvDropoutLayerNorm.forwardr@   r2   r2   r0   r3   rO   }   s   rO   c                       s(   e Zd ZdZ fddZdd Z  ZS )ConvActivationz*
    ConvActivation: Conv, Activation
    c                    s,   t    tj||d|d| _t| | _d S rP   )r   r   r   rS   rT   r	   act)r.   rV   rW   rR   r[   r0   r2   r3   r      s    
zConvActivation.__init__c                 C   s   |  |}| |S rG   )rT   r[   )r.   rN   outputr2   r2   r3   r?      s    
zConvActivation.forwardr@   r2   r2   r0   r3   rZ      s   rZ   c                       s>   e Zd Zd fdd	Zdd Zdd Zdd	 Zd
d Z  ZS )SqueezeBertSelfAttentionr   c                    s   t    ||j dkr0td| d|j d|j| _t||j | _| j| j | _tj||d|d| _	tj||d|d| _
tj||d|d| _t|j| _tjdd| _t | _t | _d	S )
z
        config = used for some things; ignored for others (work in progress...) cin = input channels = output channels
        groups = number of groups to use in conv1d layers
        r   zcin (z6) is not a multiple of the number of attention heads ()r   rQ   r   dimN)r   r   num_attention_heads
ValueErrorintattention_head_sizeall_head_sizer   rS   querykeyvaluer(   Zattention_probs_dropout_probr*   ZSoftmaxsoftmaxrF   	matmul_qk
matmul_qkv)r.   r/   rV   q_groupsk_groupsv_groupsr0   r2   r3   r      s    
z!SqueezeBertSelfAttention.__init__c                 C   s:   |  d | j| j|  d f}|j| }|ddddS )z
        - input: [N, C, W]
        - output: [N, C1, W, C2] where C1 is the head index, and C2 is one head's contents
        r   r   r   r   rL   )r7   ra   rd   viewrM   r.   rN   Znew_x_shaper2   r2   r3   transpose_for_scores   s     
z-SqueezeBertSelfAttention.transpose_for_scoresc                 C   s.   |  d | j| j|  d f}|j| }|S )z
        - input: [N, C, W]
        - output: [N, C1, C2, W] where C1 is the head index, and C2 is one head's contents
        r   r   )r7   ra   rd   ro   rp   r2   r2   r3   transpose_key_for_scores   s     
z1SqueezeBertSelfAttention.transpose_key_for_scoresc                 C   s>   | dddd }| d | j| d f}|j| }|S )zE
        - input: [N, C1, W, C2]
        - output: [N, C, W]
        r   r   r   rL   )rM   
contiguousr7   re   ro   rp   r2   r2   r3   transpose_output   s    
z)SqueezeBertSelfAttention.transpose_outputc                 C   s   |  |}| |}| |}| |}| |}| |}	| ||}
|
t| j }
|
| }
| 	|
}| 
|}| ||	}| |}d|i}|r|
|d< |S )z
        expects hidden_states in [N, C, W] data layout.

        The attention_mask data layout is [N, W], and it does not need to be transposed.
        context_layerattention_score)rf   rg   rh   rq   rr   rj   mathsqrtrd   ri   r*   rk   rt   )r.   rY   attention_maskoutput_attentionsZmixed_query_layerZmixed_key_layerZmixed_value_layerZquery_layerZ	key_layerZvalue_layerrv   Zattention_probsru   resultr2   r2   r3   r?      s"    








z SqueezeBertSelfAttention.forward)r   r   r   )	rA   rB   rC   r   rq   rr   rt   r?   rE   r2   r2   r0   r3   r]      s
   	

r]   c                       s$   e Zd Z fddZdd Z  ZS )SqueezeBertModulec                    s   t    |j}|j}|j}|j}t|||j|j|jd| _t	|||j
|jd| _t|||j|jd| _t	|||j|jd| _dS )a  
        - hidden_size = input chans = output chans for Q, K, V (they are all the same ... for now) = output chans for
          the module
        - intermediate_size = output chans for intermediate layer
        - groups = number of groups for all layers in the BertModule. (eventually we could change the interface to
          allow different groups for different layers)
        )r/   rV   rl   rm   rn   )rV   rW   rR   rX   )rV   rW   rR   r[   N)r   r   r&   Zintermediate_sizer]   rl   rm   rn   	attentionrO   Zpost_attention_groupsr)   post_attentionrZ   Zintermediate_groups
hidden_actintermediateZoutput_groupsr\   )r.   r/   Zc0c1c2c3r0   r2   r3   r      s    
zSqueezeBertModule.__init__c           
      C   sT   |  |||}|d }| ||}| |}| ||}d|i}	|rP|d |	d< |	S )Nru   feature_maprv   )r}   r~   r   r\   )
r.   rY   ry   rz   ZattZattention_outputZpost_attention_outputZintermediate_outputlayer_outputZoutput_dictr2   r2   r3   r?     s    
zSqueezeBertModule.forwardrA   rB   rC   r   r?   rE   r2   r2   r0   r3   r|      s   r|   c                       s&   e Zd Z fddZdddZ  ZS )	SqueezeBertEncoderc                    sB   t     j jksJ dt fddt jD | _d S )NzIf you want embedding_size != intermediate hidden_size, please insert a Conv1d layer to adjust the number of channels before the first SqueezeBertModule.c                 3   s   | ]}t  V  qd S rG   )r|   ).0_r/   r2   r3   	<genexpr>.      z.SqueezeBertEncoder.__init__.<locals>.<genexpr>)	r   r   r!   r&   r   Z
ModuleListrangenum_hidden_layerslayersr-   r0   r   r3   r   %  s
    
zSqueezeBertEncoder.__init__NFTc                 C   s  |d u rd}n| d t|kr&d}nd}|du s:J d|ddd}|rPdnd }|r\dnd }	| jD ]V}
|r|ddd}||f7 }|ddd}|
|||}|d }|rf|	|d	 f7 }	qf|ddd}|r||f7 }|std
d |||	fD S t|||	dS )NTFzAhead_mask is not yet supported in the SqueezeBert implementation.r   rL   r   r2   r   rv   c                 s   s   | ]}|d ur|V  qd S rG   r2   )r   vr2   r2   r3   r   [  r   z-SqueezeBertEncoder.forward.<locals>.<genexpr>)last_hidden_staterY   
attentions)countlenrM   r   r?   tupler
   )r.   rY   ry   	head_maskrz   output_hidden_statesreturn_dictZhead_mask_is_all_noneZall_hidden_statesZall_attentionslayerr   r2   r2   r3   r?   0  s4    	


zSqueezeBertEncoder.forward)NNFFTr   r2   r2   r0   r3   r   $  s        r   c                       s$   e Zd Z fddZdd Z  ZS )SqueezeBertPoolerc                    s*   t    t|j|j| _t | _d S rG   )r   r   r   Linearr&   denseZTanh
activationr-   r0   r2   r3   r   b  s    
zSqueezeBertPooler.__init__c                 C   s(   |d d df }|  |}| |}|S )Nr   )r   r   )r.   rY   Zfirst_token_tensorpooled_outputr2   r2   r3   r?   g  s    

zSqueezeBertPooler.forwardr   r2   r2   r0   r3   r   a  s   r   c                       s$   e Zd Z fddZdd Z  ZS )"SqueezeBertPredictionHeadTransformc                    sV   t    t|j|j| _t|jtr6t	|j | _
n|j| _
tj|j|jd| _d S )Nr   )r   r   r   r   r&   r   
isinstancer   strr	   transform_act_fnr%   r'   r-   r0   r2   r3   r   q  s    
z+SqueezeBertPredictionHeadTransform.__init__c                 C   s"   |  |}| |}| |}|S rG   )r   r   r%   r.   rY   r2   r2   r3   r?   z  s    


z*SqueezeBertPredictionHeadTransform.forwardr   r2   r2   r0   r3   r   p  s   	r   c                       s2   e Zd Z fddZddddZdd Z  ZS )	SqueezeBertLMPredictionHeadc                    sL   t    t|| _tj|j|jdd| _t	t
|j| _| j| j_d S )NF)bias)r   r   r   	transformr   r   r&   r    decoder	Parameterr+   r8   r   r-   r0   r2   r3   r     s
    

z$SqueezeBertLMPredictionHead.__init__N)returnc                 C   s   | j | j_ d S rG   )r   r   rH   r2   r2   r3   _tie_weights  s    z(SqueezeBertLMPredictionHead._tie_weightsc                 C   s   |  |}| |}|S rG   )r   r   r   r2   r2   r3   r?     s    

z#SqueezeBertLMPredictionHead.forward)rA   rB   rC   r   r   r?   rE   r2   r2   r0   r3   r     s   r   c                       s$   e Zd Z fddZdd Z  ZS )SqueezeBertOnlyMLMHeadc                    s   t    t|| _d S rG   )r   r   r   predictionsr-   r0   r2   r3   r     s    
zSqueezeBertOnlyMLMHead.__init__c                 C   s   |  |}|S rG   )r   )r.   sequence_outputprediction_scoresr2   r2   r3   r?     s    
zSqueezeBertOnlyMLMHead.forwardr   r2   r2   r0   r3   r     s   r   c                   @   s"   e Zd ZU eed< dZdd ZdS )SqueezeBertPreTrainedModelr/   transformerc                 C   s   t |tjtjfr@|jjjd| jjd |j	dur|j	j
  n~t |tjr|jjjd| jjd |jdur|jj|j 
  n>t |tjr|j	j
  |jjd nt |tr|j	j
  dS )zInitialize the weightsg        )meanZstdNg      ?)r   r   r   rS   weightdataZnormal_r/   Zinitializer_ranger   Zzero_r   r   r%   Zfill_r   )r.   moduler2   r2   r3   _init_weights  s    


z(SqueezeBertPreTrainedModel._init_weightsN)rA   rB   rC   r   __annotations__Zbase_model_prefixr   r2   r2   r2   r3   r     s   
r   c                       s   e Zd Z fddZdd Zdd Zdd Zedee	j
 ee	j
 ee	j
 ee	j
 ee	j
 ee	j ee ee ee eeef d

ddZ  ZS )SqueezeBertModelc                    s6   t  | t|| _t|| _t|| _|   d S rG   )	r   r   r   r>   r   encoderr   pooler	post_initr-   r0   r2   r3   r     s
    


zSqueezeBertModel.__init__c                 C   s   | j jS rG   r>   r"   rH   r2   r2   r3   get_input_embeddings  s    z%SqueezeBertModel.get_input_embeddingsc                 C   s   || j _d S rG   r   r.   Znew_embeddingsr2   r2   r3   set_input_embeddings  s    z%SqueezeBertModel.set_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr   r   r}   Zprune_heads)r.   Zheads_to_pruner   Zheadsr2   r2   r3   _prune_heads  s    zSqueezeBertModel._prune_headsN)
r:   ry   r;   r   r   r<   rz   r   r   r   c
                 C   s^  |d ur|n| j j}|d ur |n| j j}|	d ur4|	n| j j}	|d urV|d urVtdn@|d urt| || | }
n"|d ur| d d }
ntd|d ur|jn|j}|d u rtj	|
|d}|d u rtj
|
tj|d}| ||
}| || j j}| j||||d}| j||||||	d}|d }| |}|	sJ||f|d	d   S t|||j|jd
S )NzDYou cannot specify both input_ids and inputs_embeds at the same timer   z5You have to specify either input_ids or inputs_embeds)r6   r4   )r:   r   r;   r<   )rY   ry   r   rz   r   r   r   r   )r   Zpooler_outputrY   r   )r/   rz   r   use_return_dictrb   Z%warn_if_padding_and_no_attention_maskr7   r6   r+   Zonesr8   r9   Zget_extended_attention_maskZget_head_maskr   r>   r   r   r   rY   r   )r.   r:   ry   r;   r   r   r<   rz   r   r   r=   r6   Zextended_attention_maskZembedding_outputZencoder_outputsr   r   r2   r2   r3   r?     sP    


zSqueezeBertModel.forward)	NNNNNNNNN)rA   rB   rC   r   r   r   r   r   r   r+   TensorZFloatTensorboolr   r   r   r?   rE   r2   r2   r0   r3   r     s4   
         
r   c                       s   e Zd ZddgZ fddZdd Zdd Zedee	j
 ee	j
 ee	j
 ee	j
 ee	j
 ee	j
 ee	j
 ee ee ee eeef d
ddZ  ZS )SqueezeBertForMaskedLMzcls.predictions.decoder.weightzcls.predictions.decoder.biasc                    s,   t  | t|| _t|| _|   d S rG   )r   r   r   r   r   clsr   r-   r0   r2   r3   r     s    

zSqueezeBertForMaskedLM.__init__c                 C   s
   | j jjS rG   )r   r   r   rH   r2   r2   r3   get_output_embeddings&  s    z,SqueezeBertForMaskedLM.get_output_embeddingsc                 C   s   || j j_|j| j j_d S rG   )r   r   r   r   r   r2   r2   r3   set_output_embeddings)  s    
z,SqueezeBertForMaskedLM.set_output_embeddingsNr:   ry   r;   r   r   r<   labelsrz   r   r   r   c                 C   s   |
dur|
n| j j}
| j||||||||	|
d	}|d }| |}d}|durpt }||d| j j|d}|
s|f|dd  }|dur|f| S |S t|||j|j	dS )a  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        Nry   r;   r   r   r<   rz   r   r   r   r   rL   losslogitsrY   r   )
r/   r   r   r   r   ro   r    r   rY   r   )r.   r:   ry   r;   r   r   r<   r   rz   r   r   outputsr   r   Zmasked_lm_lossloss_fctr\   r2   r2   r3   r?   -  s6    
zSqueezeBertForMaskedLM.forward)
NNNNNNNNNN)rA   rB   rC   Z_tied_weights_keysr   r   r   r   r   r+   r   r   r   r   r   r?   rE   r2   r2   r0   r3   r     s8   	          
r   z
    SqueezeBERT Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    )Zcustom_introc                       s   e Zd Z fddZedeej eej eej eej eej eej eej ee ee ee e	e
ef dddZ  ZS )$SqueezeBertForSequenceClassificationc                    sR   t  | |j| _|| _t|| _t|j| _	t
|j| jj| _|   d S rG   )r   r   
num_labelsr/   r   r   r   r(   r)   r*   r   r&   
classifierr   r-   r0   r2   r3   r   j  s    
z-SqueezeBertForSequenceClassification.__init__Nr   c                 C   s|  |
dur|
n| j j}
| j||||||||	|
d	}|d }| |}| |}d}|dur8| j jdu r| jdkrzd| j _n4| jdkr|jtj	ks|jtj
krd| j _nd| j _| j jdkrt }| jdkr|| | }n
|||}nN| j jdkrt }||d| j|d}n| j jdkr8t }|||}|
sh|f|dd  }|durd|f| S |S t|||j|jd	S )
a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Nr   r   Z
regressionZsingle_label_classificationZmulti_label_classificationr   rL   r   )r/   r   r   r*   r   Zproblem_typer   r5   r+   r9   rc   r   squeezer   ro   r   r   rY   r   )r.   r:   ry   r;   r   r   r<   r   rz   r   r   r   r   r   r   r   r\   r2   r2   r3   r?   v  sV    




"


z,SqueezeBertForSequenceClassification.forward)
NNNNNNNNNN)rA   rB   rC   r   r   r   r+   r   r   r   r   r   r?   rE   r2   r2   r0   r3   r   c  s2             
r   c                       s   e Zd Z fddZedeej eej eej eej eej eej eej ee ee ee e	e
ef dddZ  ZS )SqueezeBertForMultipleChoicec                    s@   t  | t|| _t|j| _t|j	d| _
|   d S )Nr   )r   r   r   r   r   r(   r)   r*   r   r&   r   r   r-   r0   r2   r3   r     s
    
z%SqueezeBertForMultipleChoice.__init__Nr   c                 C   st  |
dur|
n| j j}
|dur&|jd n|jd }|durJ|d|dnd}|durh|d|dnd}|dur|d|dnd}|dur|d|dnd}|dur|d|d|dnd}| j||||||||	|
d	}|d }| |}| |}|d|}d}|dur0t }|||}|
s`|f|dd  }|dur\|f| S |S t	|||j
|jdS )a[  
        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where *num_choices* is the size of the second dimension of the input tensors. (see
            *input_ids* above)
        Nr   r   r   rL   r   )r/   r   shapero   r7   r   r*   r   r   r   rY   r   )r.   r:   ry   r;   r   r   r<   r   rz   r   r   Znum_choicesr   r   r   Zreshaped_logitsr   r   r\   r2   r2   r3   r?     sL    ,



z$SqueezeBertForMultipleChoice.forward)
NNNNNNNNNN)rA   rB   rC   r   r   r   r+   r   r   r   r   r   r?   rE   r2   r2   r0   r3   r     s2   
          
r   c                       s   e Zd Z fddZedeej eej eej eej eej eej eej ee ee ee e	e
ef dddZ  ZS )!SqueezeBertForTokenClassificationc                    sJ   t  | |j| _t|| _t|j| _t	|j
|j| _|   d S rG   )r   r   r   r   r   r   r(   r)   r*   r   r&   r   r   r-   r0   r2   r3   r   *  s    
z*SqueezeBertForTokenClassification.__init__Nr   c                 C   s   |
dur|
n| j j}
| j||||||||	|
d	}|d }| |}| |}d}|durxt }||d| j|d}|
s|f|dd  }|dur|f| S |S t|||j	|j
dS )z
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        Nr   r   r   rL   r   )r/   r   r   r*   r   r   ro   r   r   rY   r   )r.   r:   ry   r;   r   r   r<   r   rz   r   r   r   r   r   r   r   r\   r2   r2   r3   r?   5  s8    

z)SqueezeBertForTokenClassification.forward)
NNNNNNNNNN)rA   rB   rC   r   r   r   r+   r   r   r   r   r   r?   rE   r2   r2   r0   r3   r   (  s2             
r   c                       s   e Zd Z fddZedeej eej eej eej eej eej eej eej ee ee ee e	e
ef dddZ  ZS )SqueezeBertForQuestionAnsweringc                    s<   t  | |j| _t|| _t|j|j| _| 	  d S rG   )
r   r   r   r   r   r   r   r&   
qa_outputsr   r-   r0   r2   r3   r   m  s
    
z(SqueezeBertForQuestionAnswering.__init__N)r:   ry   r;   r   r   r<   start_positionsend_positionsrz   r   r   r   c                 C   sP  |d ur|n| j j}| j|||||||	|
|d	}|d }| |}|jddd\}}|d }|d }d }|d ur|d urt| dkr|d}t| dkr|d}|d}|	d|}|	d|}t
|d}|||}|||}|| d }|s:||f|dd   }|d ur6|f| S |S t||||j|jdS )	Nr   r   r   r   r_   )Zignore_indexrL   )r   start_logits
end_logitsrY   r   )r/   r   r   r   splitr   rs   r   r7   clampr   r   rY   r   )r.   r:   ry   r;   r   r   r<   r   r   rz   r   r   r   r   r   r   r   Z
total_lossZignored_indexr   Z
start_lossZend_lossr\   r2   r2   r3   r?   w  sP    






z'SqueezeBertForQuestionAnswering.forward)NNNNNNNNNNN)rA   rB   rC   r   r   r   r+   r   r   r   r   r   r?   rE   r2   r2   r0   r3   r   k  s6   
           
r   )r   r   r   r   r   r   r|   r   )5rD   rw   typingr   r   r+   r   Ztorch.nnr   r   r   Zactivationsr	   Zmodeling_outputsr
   r   r   r   r   r   r   Zmodeling_utilsr   utilsr   r   Zconfiguration_squeezebertr   Z
get_loggerrA   loggerModuler   rF   r%   rJ   rO   rZ   r]   r|   r   r   r   r   r   r   r   r   r   r   r   r   __all__r2   r2   r2   r3   <module>   sP   $	
,Z*=
^IWgBM