a
    h                     @   s  d Z ddlZddlmZ ddlmZ ddlmZmZ ddl	Z	ddl
Z	ddl	mZ ddlmZmZmZ dd	lmZmZ e rdd
lmZ ddlmZ ddlmZ ddlmZmZmZmZmZmZm Z m!Z!m"Z" ddl#m$Z$ ddl%m&Z& ddlm'Z' ddl(m)Z) e'*e+Z,dd Z-dd Z.dd Z/G dd dej0Z1G dd dej0Z2G dd dej0Z3G dd  d ej0Z4G d!d" d"ej0Z5G d#d$ d$ej0Z6G d%d& d&eZ7G d'd( d(ej0Z8G d)d* d*ej0Z9G d+d, d,ej0Z:G d-d. d.ej0Z;G d/d0 d0ej0Z<G d1d2 d2ej0Z=G d3d4 d4ej0Z>eG d5d6 d6e$Z?eed7d8G d9d: d:eZ@eG d;d< d<e?ZAed=d8G d>d? d?e?ZBeG d@dA dAe?ZCedBd8G dCdD dDe?ZDedEd8G dFdG dGe?ZEeG dHdI dIe?ZFeG dJdK dKe?ZGeG dLdM dMe?ZHg dNZIdS )OzPyTorch FNet model.    N)	dataclass)partial)OptionalUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )auto_docstringis_scipy_available)linalg)ACT2FN)GradientCheckpointingLayer)	BaseModelOutputBaseModelOutputWithPoolingMaskedLMOutputModelOutputMultipleChoiceModelOutputNextSentencePredictorOutputQuestionAnsweringModelOutputSequenceClassifierOutputTokenClassifierOutput)PreTrainedModel)apply_chunking_to_forward)logging   )
FNetConfigc                 C   s:   | j d }|d|d|f }| tj} td| ||S )z4Applies 2D matrix multiplication to 3D input arrays.r   Nzbij,jk,ni->bnk)shapetypetorch	complex64Zeinsum)xmatrix_dim_onematrix_dim_two
seq_length r&   b/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/fnet/modeling_fnet.py_two_dim_matmul8   s    
r(   c                 C   s   t | ||S N)r(   )r"   r#   r$   r&   r&   r'   two_dim_matmulA   s    r*   c                 C   s4   | }t t| jdd D ]}tjj||d}q|S )z
    Applies n-dimensional Fast Fourier Transform (FFT) to input array.

    Args:
        x: Input n-dimensional array.

    Returns:
        n-dimensional Fourier transform of input n-dimensional array.
    r   N)axis)reversedrangendimr    fft)r"   outr+   r&   r&   r'   fftnF   s    
r1   c                       s*   e Zd ZdZ fddZdddZ  ZS )FNetEmbeddingszGConstruct the embeddings from word, position and token_type embeddings.c                    s   t    tj|j|j|jd| _t|j|j| _	t|j
|j| _tj|j|jd| _t|j|j| _t|j| _| jdt|jddd | jdtj| j tjddd d S )	N)padding_idxepsposition_ids)r   F)
persistenttoken_type_idsdtype)super__init__r   	Embedding
vocab_sizehidden_sizeZpad_token_idword_embeddingsmax_position_embeddingsposition_embeddingsZtype_vocab_sizetoken_type_embeddings	LayerNormlayer_norm_epsLinear
projectionDropouthidden_dropout_probdropoutregister_bufferr    Zarangeexpandzerosr6   sizelongselfconfig	__class__r&   r'   r=   Y   s    
zFNetEmbeddings.__init__Nc                 C   s   |d ur|  }n|  d d }|d }|d u rH| jd d d |f }|d u rt| dr| jd d d |f }||d |}|}ntj|tj| jjd}|d u r| 	|}| 
|}	||	 }
| |}|
|7 }
| |
}
| |
}
| |
}
|
S )Nr7   r   r9   r   r;   device)rO   r6   hasattrr9   rM   r    rN   rP   rW   rA   rD   rC   rE   rH   rK   )rR   	input_idsr9   r6   inputs_embedsinput_shaper%   buffered_token_type_ids buffered_token_type_ids_expandedrD   
embeddingsrC   r&   r&   r'   forwardo   s,    







zFNetEmbeddings.forward)NNNN)__name__
__module____qualname____doc__r=   r_   __classcell__r&   r&   rT   r'   r2   V   s   r2   c                       s,   e Zd Z fddZdd Zdd Z  ZS )FNetBasicFourierTransformc                    s   t    | | d S r)   )r<   r=   _init_fourier_transformrQ   rT   r&   r'   r=      s    
z"FNetBasicFourierTransform.__init__c                 C   s   |j sttjjdd| _n~|jdkrt r| dtj	t
|jtjd | dtj	t
|jtjd tt| j| jd| _qtd t| _nt| _d S )	N)r      dim   dft_mat_hiddenr:   dft_mat_seq)r#   r$   zpSciPy is needed for DFT matrix calculation and is not found. Using TPU optimized fast fourier transform instead.)use_tpu_fourier_optimizationsr   r    r/   r1   fourier_transformrB   r   rL   Ztensorr   Zdftr@   r!   tpu_short_seq_lengthr*   rl   rk   r   warningrQ   r&   r&   r'   rf      s$    


z1FNetBasicFourierTransform._init_fourier_transformc                 C   s   |  |j}|fS r)   )rn   real)rR   hidden_statesoutputsr&   r&   r'   r_      s    z!FNetBasicFourierTransform.forward)r`   ra   rb   r=   rf   r_   rd   r&   r&   rT   r'   re      s   re   c                       s$   e Zd Z fddZdd Z  ZS )FNetBasicOutputc                    s"   t    tj|j|jd| _d S Nr4   )r<   r=   r   rE   r@   rF   rQ   rT   r&   r'   r=      s    
zFNetBasicOutput.__init__c                 C   s   |  || }|S r)   )rE   rR   rr   input_tensorr&   r&   r'   r_      s    zFNetBasicOutput.forwardr`   ra   rb   r=   r_   rd   r&   r&   rT   r'   rt      s   rt   c                       s$   e Zd Z fddZdd Z  ZS )FNetFourierTransformc                    s"   t    t|| _t|| _d S r)   )r<   r=   re   rR   rt   outputrQ   rT   r&   r'   r=      s    

zFNetFourierTransform.__init__c                 C   s$   |  |}| |d |}|f}|S Nr   )rR   rz   )rR   rr   Zself_outputsfourier_outputrs   r&   r&   r'   r_      s    
zFNetFourierTransform.forwardrx   r&   r&   rT   r'   ry      s   ry   c                       s0   e Zd Z fddZejejdddZ  ZS )FNetIntermediatec                    sB   t    t|j|j| _t|jt	r6t
|j | _n|j| _d S r)   )r<   r=   r   rG   r@   intermediate_sizedense
isinstance
hidden_actstrr   intermediate_act_fnrQ   rT   r&   r'   r=      s
    
zFNetIntermediate.__init__rr   returnc                 C   s   |  |}| |}|S r)   )r   r   rR   rr   r&   r&   r'   r_      s    

zFNetIntermediate.forwardr`   ra   rb   r=   r    Tensorr_   rd   r&   r&   rT   r'   r}      s   r}   c                       s4   e Zd Z fddZejejejdddZ  ZS )
FNetOutputc                    sB   t    t|j|j| _tj|j|jd| _t	|j
| _d S ru   )r<   r=   r   rG   r~   r@   r   rE   rF   rI   rJ   rK   rQ   rT   r&   r'   r=      s    
zFNetOutput.__init__)rr   rw   r   c                 C   s&   |  |}| |}| || }|S r)   )r   rK   rE   rv   r&   r&   r'   r_      s    

zFNetOutput.forwardr   r&   r&   rT   r'   r      s   r   c                       s,   e Zd Z fddZdd Zdd Z  ZS )	FNetLayerc                    s:   t    |j| _d| _t|| _t|| _t|| _	d S Nr   )
r<   r=   chunk_size_feed_forwardseq_len_dimry   fourierr}   intermediater   rz   rQ   rT   r&   r'   r=      s    


zFNetLayer.__init__c                 C   s0   |  |}|d }t| j| j| j|}|f}|S r{   )r   r   feed_forward_chunkr   r   )rR   rr   Zself_fourier_outputsr|   layer_outputrs   r&   r&   r'   r_      s    
zFNetLayer.forwardc                 C   s   |  |}| ||}|S r)   )r   rz   )rR   r|   Zintermediate_outputr   r&   r&   r'   r     s    
zFNetLayer.feed_forward_chunk)r`   ra   rb   r=   r_   r   rd   r&   r&   rT   r'   r      s   r   c                       s&   e Zd Z fddZdddZ  ZS )FNetEncoderc                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r&   )r   ).0_rS   r&   r'   
<listcomp>      z(FNetEncoder.__init__.<locals>.<listcomp>F)	r<   r=   rS   r   Z
ModuleListr-   Znum_hidden_layerslayerZgradient_checkpointingrQ   rT   r   r'   r=     s    
 zFNetEncoder.__init__FTc                 C   sr   |rdnd }t | jD ]&\}}|r,||f }||}|d }q|rL||f }|sftdd ||fD S t||dS )Nr&   r   c                 s   s   | ]}|d ur|V  qd S r)   r&   )r   vr&   r&   r'   	<genexpr>   r   z&FNetEncoder.forward.<locals>.<genexpr>)last_hidden_staterr   )	enumerater   tupler   )rR   rr   output_hidden_statesreturn_dictZall_hidden_statesiZlayer_moduleZlayer_outputsr&   r&   r'   r_     s    


zFNetEncoder.forward)FTrx   r&   r&   rT   r'   r   
  s   r   c                       s0   e Zd Z fddZejejdddZ  ZS )
FNetPoolerc                    s*   t    t|j|j| _t | _d S r)   )r<   r=   r   rG   r@   r   ZTanh
activationrQ   rT   r&   r'   r=   '  s    
zFNetPooler.__init__r   c                 C   s(   |d d df }|  |}| |}|S r{   )r   r   )rR   rr   Zfirst_token_tensorpooled_outputr&   r&   r'   r_   ,  s    

zFNetPooler.forwardr   r&   r&   rT   r'   r   &  s   r   c                       s0   e Zd Z fddZejejdddZ  ZS )FNetPredictionHeadTransformc                    sV   t    t|j|j| _t|jtr6t	|j | _
n|j| _
tj|j|jd| _d S ru   )r<   r=   r   rG   r@   r   r   r   r   r   transform_act_fnrE   rF   rQ   rT   r&   r'   r=   7  s    
z$FNetPredictionHeadTransform.__init__r   c                 C   s"   |  |}| |}| |}|S r)   )r   r   rE   r   r&   r&   r'   r_   @  s    


z#FNetPredictionHeadTransform.forwardr   r&   r&   rT   r'   r   6  s   	r   c                       s2   e Zd Z fddZdd ZddddZ  ZS )	FNetLMPredictionHeadc                    sH   t    t|| _t|j|j| _t	t
|j| _| j| j_d S r)   )r<   r=   r   	transformr   rG   r@   r?   decoder	Parameterr    rN   biasrQ   rT   r&   r'   r=   H  s
    

zFNetLMPredictionHead.__init__c                 C   s   |  |}| |}|S r)   )r   r   r   r&   r&   r'   r_   S  s    

zFNetLMPredictionHead.forwardN)r   c                 C   s*   | j jjjdkr| j| j _n
| j j| _d S )Nmeta)r   r   rW   r   rR   r&   r&   r'   _tie_weightsX  s    z!FNetLMPredictionHead._tie_weights)r`   ra   rb   r=   r_   r   rd   r&   r&   rT   r'   r   G  s   r   c                       s$   e Zd Z fddZdd Z  ZS )FNetOnlyMLMHeadc                    s   t    t|| _d S r)   )r<   r=   r   predictionsrQ   rT   r&   r'   r=   b  s    
zFNetOnlyMLMHead.__init__c                 C   s   |  |}|S r)   )r   )rR   sequence_outputprediction_scoresr&   r&   r'   r_   f  s    
zFNetOnlyMLMHead.forwardrx   r&   r&   rT   r'   r   a  s   r   c                       s$   e Zd Z fddZdd Z  ZS )FNetOnlyNSPHeadc                    s   t    t|jd| _d S Nrg   )r<   r=   r   rG   r@   seq_relationshiprQ   rT   r&   r'   r=   m  s    
zFNetOnlyNSPHead.__init__c                 C   s   |  |}|S r)   )r   )rR   r   seq_relationship_scorer&   r&   r'   r_   q  s    
zFNetOnlyNSPHead.forwardrx   r&   r&   rT   r'   r   l  s   r   c                       s$   e Zd Z fddZdd Z  ZS )FNetPreTrainingHeadsc                    s(   t    t|| _t|jd| _d S r   )r<   r=   r   r   r   rG   r@   r   rQ   rT   r&   r'   r=   x  s    

zFNetPreTrainingHeads.__init__c                 C   s   |  |}| |}||fS r)   )r   r   )rR   r   r   r   r   r&   r&   r'   r_   }  s    

zFNetPreTrainingHeads.forwardrx   r&   r&   rT   r'   r   w  s   r   c                   @   s&   e Zd ZU eed< dZdZdd ZdS )FNetPreTrainedModelrS   fnetTc                 C   s   t |tjr:|jjjd| jjd |jdur|jj	  nft |tj
rz|jjjd| jjd |jdur|jj|j 	  n&t |tjr|jj	  |jjd dS )zInitialize the weightsg        )meanZstdNg      ?)r   r   rG   weightdataZnormal_rS   Zinitializer_ranger   Zzero_r>   r3   rE   Zfill_)rR   moduler&   r&   r'   _init_weights  s    

z!FNetPreTrainedModel._init_weightsN)r`   ra   rb   r   __annotations__Zbase_model_prefixZsupports_gradient_checkpointingr   r&   r&   r&   r'   r     s   
r   z0
    Output type of [`FNetForPreTraining`].
    )Zcustom_introc                   @   s^   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZeeej  ed< dS )FNetForPreTrainingOutputa  
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        Total loss as the sum of the masked language modeling loss and the next sequence prediction
        (classification) loss.
    prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
        Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
        before SoftMax).
    Nlossprediction_logitsseq_relationship_logitsrr   )r`   ra   rb   rc   r   r   r    FloatTensorr   r   r   rr   r   r&   r&   r&   r'   r     s
   
r   c                       sz   e Zd ZdZd fdd	Zdd Zdd Zedee	j
 ee	j
 ee	j
 ee	j ee ee eeef d
ddZ  ZS )	FNetModelz

    The model can behave as an encoder, following the architecture described in [FNet: Mixing Tokens with Fourier
    Transforms](https://huggingface.co/papers/2105.03824) by James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, Santiago Ontanon.

    Tc                    sD   t  | || _t|| _t|| _|r2t|nd| _| 	  dS )zv
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        N)
r<   r=   rS   r2   r^   r   encoderr   pooler	post_init)rR   rS   Zadd_pooling_layerrT   r&   r'   r=     s    

zFNetModel.__init__c                 C   s   | j jS r)   r^   rA   r   r&   r&   r'   get_input_embeddings  s    zFNetModel.get_input_embeddingsc                 C   s   || j _d S r)   r   )rR   valuer&   r&   r'   set_input_embeddings  s    zFNetModel.set_input_embeddingsN)rY   r9   r6   rZ   r   r   r   c                 C   s~  |d ur|n| j j}|d ur |n| j j}|d urB|d urBtdnD|d ur\| }|\}}	n*|d ur~| d d }|\}}	ntd| j jr|	dkr| j j|	krtd|d ur|jn|j}
|d u rt| j	dr| j	j
d d d |	f }|||	}|}ntj|tj|
d}| j	||||d}| j|||d	}|d
 }| jd urP| |nd }|sn||f|dd   S t|||jdS )NzDYou cannot specify both input_ids and inputs_embeds at the same timer7   z5You have to specify either input_ids or inputs_embedsrj   zThe `tpu_short_seq_length` in FNetConfig should be set equal to the sequence length being passed to the model when using TPU optimizations.r9   rV   )rY   r6   r9   rZ   )r   r   r   r   )r   pooler_outputrr   )rS   r   use_return_dict
ValueErrorrO   rm   ro   rW   rX   r^   r9   rM   r    rN   rP   r   r   r   rr   )rR   rY   r9   r6   rZ   r   r   r[   Z
batch_sizer%   rW   r\   r]   Zembedding_outputZencoder_outputsr   r   r&   r&   r'   r_     s`    




zFNetModel.forward)T)NNNNNN)r`   ra   rb   rc   r=   r   r   r   r   r    Z
LongTensorr   boolr   r   r   r_   rd   r&   r&   rT   r'   r     s(         
r   z
    FNet Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
    sentence prediction (classification)` head.
    c                       s   e Zd ZddgZ fddZdd Zdd Zedee	j
 ee	j
 ee	j
 ee	j
 ee	j
 ee	j
 ee ee eeef d
	ddZ  ZS )FNetForPreTrainingcls.predictions.decoder.biascls.predictions.decoder.weightc                    s,   t  | t|| _t|| _|   d S r)   )r<   r=   r   r   r   clsr   rQ   rT   r&   r'   r=   "  s    

zFNetForPreTraining.__init__c                 C   s
   | j jjS r)   r   r   r   r   r&   r&   r'   get_output_embeddings+  s    z(FNetForPreTraining.get_output_embeddingsc                 C   s   || j j_|j| j j_d S r)   r   r   r   r   rR   Znew_embeddingsr&   r&   r'   set_output_embeddings.  s    
z(FNetForPreTraining.set_output_embeddingsN)	rY   r9   r6   rZ   labelsnext_sentence_labelr   r   r   c	                 C   s   |dur|n| j j}| j||||||d}	|	dd \}
}| |
|\}}d}|dur|durt }||d| j j|d}||dd|d}|| }|s||f|	dd  }|dur|f| S |S t||||	jdS )aH  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring) Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, FNetForPreTraining
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
        >>> model = FNetForPreTraining.from_pretrained("google/fnet-base")
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> prediction_logits = outputs.prediction_logits
        >>> seq_relationship_logits = outputs.seq_relationship_logits
        ```Nr9   r6   rZ   r   r   rg   r7   )r   r   r   rr   )	rS   r   r   r   r   viewr?   r   rr   )rR   rY   r9   r6   rZ   r   r   r   r   rs   r   r   r   r   
total_lossloss_fctmasked_lm_lossnext_sentence_lossrz   r&   r&   r'   r_   2  s4    %	zFNetForPreTraining.forward)NNNNNNNN)r`   ra   rb   _tied_weights_keysr=   r   r   r   r   r    r   r   r   r   r   r_   rd   r&   r&   rT   r'   r     s0   	        
r   c                       s   e Zd ZddgZ fddZdd Zdd Zedee	j
 ee	j
 ee	j
 ee	j
 ee	j
 ee ee eeef d
ddZ  ZS )FNetForMaskedLMr   r   c                    s,   t  | t|| _t|| _|   d S r)   )r<   r=   r   r   r   r   r   rQ   rT   r&   r'   r=   |  s    

zFNetForMaskedLM.__init__c                 C   s
   | j jjS r)   r   r   r&   r&   r'   r     s    z%FNetForMaskedLM.get_output_embeddingsc                 C   s   || j j_|j| j j_d S r)   r   r   r&   r&   r'   r     s    
z%FNetForMaskedLM.set_output_embeddingsNrY   r9   r6   rZ   r   r   r   r   c                 C   s   |dur|n| j j}| j||||||d}|d }	| |	}
d}|durjt }||
d| j j|d}|s|
f|dd  }|dur|f| S |S t||
|jdS )a  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        Nr   r   r7   rg   r   logitsrr   )	rS   r   r   r   r   r   r?   r   rr   )rR   rY   r9   r6   rZ   r   r   r   rs   r   r   r   r   rz   r&   r&   r'   r_     s&    	
zFNetForMaskedLM.forward)NNNNNNN)r`   ra   rb   r   r=   r   r   r   r   r    r   r   r   r   r   r_   rd   r&   r&   rT   r'   r   x  s,   	       
r   zT
    FNet Model with a `next sentence prediction (classification)` head on top.
    c                       sl   e Zd Z fddZedeej eej eej eej eej ee ee e	e
ef dddZ  ZS )FNetForNextSentencePredictionc                    s,   t  | t|| _t|| _|   d S r)   )r<   r=   r   r   r   r   r   rQ   rT   r&   r'   r=     s    

z&FNetForNextSentencePrediction.__init__Nr   c                 K   s   d|v rt dt |d}|dur*|n| jj}| j||||||d}	|	d }
| |
}d}|durt }||	dd|	d}|s|f|	dd  }|dur|f| S |S t
|||	jdS )	a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring). Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, FNetForNextSentencePrediction
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
        >>> model = FNetForNextSentencePrediction.from_pretrained("google/fnet-base")
        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
        >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
        >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
        >>> logits = outputs.logits
        >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
        ```r   zoThe `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.Nr   r   r7   rg   r   )warningswarnFutureWarningpoprS   r   r   r   r   r   r   rr   )rR   rY   r9   r6   rZ   r   r   r   kwargsrs   r   Zseq_relationship_scoresr   r   rz   r&   r&   r'   r_     s:    $
	
z%FNetForNextSentencePrediction.forward)NNNNNNN)r`   ra   rb   r=   r   r   r    r   r   r   r   r   r_   rd   r&   r&   rT   r'   r     s&   	       
r   z
    FNet Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    c                       sl   e Zd Z fddZedeej eej eej eej eej ee ee e	e
ef dddZ  ZS )FNetForSequenceClassificationc                    sJ   t  | |j| _t|| _t|j| _t	|j
|j| _|   d S r)   r<   r=   
num_labelsr   r   r   rI   rJ   rK   rG   r@   
classifierr   rQ   rT   r&   r'   r=     s    
z&FNetForSequenceClassification.__init__Nr   c                 C   sr  |dur|n| j j}| j||||||d}|d }	| |	}	| |	}
d}|dur2| j jdu r| jdkrtd| j _n4| jdkr|jtj	ks|jtj
krd| j _nd| j _| j jdkrt }| jdkr||
 | }n
||
|}nN| j jdkrt }||
d| j|d}n| j jdkr2t }||
|}|sb|
f|dd  }|dur^|f| S |S t||
|jd	S )
a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Nr   r   Z
regressionZsingle_label_classificationZmulti_label_classificationr7   rg   r   )rS   r   r   rK   r   Zproblem_typer   r;   r    rP   intr	   squeezer   r   r   r   rr   )rR   rY   r9   r6   rZ   r   r   r   rs   r   r   r   r   rz   r&   r&   r'   r_   #  sF    	




"


z%FNetForSequenceClassification.forward)NNNNNNN)r`   ra   rb   r=   r   r   r    r   r   r   r   r   r_   rd   r&   r&   rT   r'   r     s&          
r   c                       sl   e Zd Z fddZedeej eej eej eej eej ee ee e	e
ef dddZ  ZS )FNetForMultipleChoicec                    s@   t  | t|| _t|j| _t|j	d| _
|   d S r   )r<   r=   r   r   r   rI   rJ   rK   rG   r@   r   r   rQ   rT   r&   r'   r=   b  s
    
zFNetForMultipleChoice.__init__Nr   c                 C   sL  |dur|n| j j}|dur&|jd n|jd }|durJ|d|dnd}|durh|d|dnd}|dur|d|dnd}|dur|d|d|dnd}| j||||||d}	|	d }
| |
}
| |
}|d|}d}|durt }|||}|s<|f|	dd  }|dur8|f| S |S t	|||	j
dS )a[  
        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        Nr   r7   r   rg   r   )rS   r   r   r   rO   r   rK   r   r   r   rr   )rR   rY   r9   r6   rZ   r   r   r   Znum_choicesrs   r   r   Zreshaped_logitsr   r   rz   r&   r&   r'   r_   l  s:    )	



zFNetForMultipleChoice.forward)NNNNNNN)r`   ra   rb   r=   r   r   r    r   r   r   r   r   r_   rd   r&   r&   rT   r'   r   `  s&   
       
r   c                       sl   e Zd Z fddZedeej eej eej eej eej ee ee e	e
ef dddZ  ZS )FNetForTokenClassificationc                    sJ   t  | |j| _t|| _t|j| _t	|j
|j| _|   d S r)   r   rQ   rT   r&   r'   r=     s    
z#FNetForTokenClassification.__init__Nr   c                 C   s   |dur|n| j j}| j||||||d}|d }	| |	}	| |	}
d}|durrt }||
d| j|d}|s|
f|dd  }|dur|f| S |S t||
|j	dS )z
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        Nr   r   r7   rg   r   )
rS   r   r   rK   r   r   r   r   r   rr   )rR   rY   r9   r6   rZ   r   r   r   rs   r   r   r   r   rz   r&   r&   r'   r_     s(    	

z"FNetForTokenClassification.forward)NNNNNNN)r`   ra   rb   r=   r   r   r    r   r   r   r   r   r_   rd   r&   r&   rT   r'   r     s&          
r   c                       st   e Zd Z fddZedeej eej eej eej eej eej ee ee e	e
ef d	ddZ  ZS )FNetForQuestionAnsweringc                    s<   t  | |j| _t|| _t|j|j| _| 	  d S r)   )
r<   r=   r   r   r   r   rG   r@   
qa_outputsr   rQ   rT   r&   r'   r=     s
    
z!FNetForQuestionAnswering.__init__N)	rY   r9   r6   rZ   start_positionsend_positionsr   r   r   c	                 C   sB  |d ur|n| j j}| j||||||d}	|	d }
| |
}|jddd\}}|d }|d }d }|d ur|d urt| dkr|d}t| dkr|d}|d}|	d|}|	d|}t
|d}|||}|||}|| d }|s0||f|	dd   }|d ur,|f| S |S t||||	jdS )	Nr   r   r   r7   rh   )Zignore_indexrg   )r   start_logits
end_logitsrr   )rS   r   r   r   splitr   
contiguouslenrO   clampr   r   rr   )rR   rY   r9   r6   rZ   r   r   r   r   rs   r   r   r   r   r   Zignored_indexr   Z
start_lossZend_lossrz   r&   r&   r'   r_     sB    	







z FNetForQuestionAnswering.forward)NNNNNNNN)r`   ra   rb   r=   r   r   r    r   r   r   r   r   r_   rd   r&   r&   rT   r'   r     s*           
r   )
r   r   r   r   r   r   r   r   r   r   )Jrc   r   dataclassesr   	functoolsr   typingr   r   r    Ztorch.utils.checkpointr   Ztorch.nnr   r   r	   utilsr   r   Zscipyr   Zactivationsr   Zmodeling_layersr   Zmodeling_outputsr   r   r   r   r   r   r   r   r   Zmodeling_utilsr   Zpytorch_utilsr   r   Zconfiguration_fnetr   Z
get_loggerr`   loggerr(   r*   r1   Moduler2   re   rt   ry   r}   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   __all__r&   r&   r&   r'   <module>   s   ,
	=&
eY>UI[9D