a
    h                    @   s  d Z ddlZddlZddlZddlmZ ddlmZmZ ddl	Z	ddl
Z	ddl	mZ ddlmZmZmZ ddlmZ dd	lmZ dd
lmZmZmZmZmZmZ ddlmZ ddlmZmZm Z  ddl!m"Z"m#Z# ddl$m%Z% e#&e'Z(g dZ)ee"ddG dd deZ*dd Z+G dd dej,Z-G dd dej,Z.G dd dej,Z/G dd dej,Z0G dd  d ej,Z1G d!d" d"ej,Z2G d#d$ d$ej,Z3G d%d& d&ej,Z4G d'd( d(eZ5G d)d* d*ej,Z6G d+d, d,ej,Z7G d-d. d.ej,Z8G d/d0 d0ej,Z9G d1d2 d2ej,Z:e"G d3d4 d4eZ;e"G d5d6 d6e;Z<e"d7dG d8d9 d9e;Z=e"G d:d; d;e;Z>e"G d<d= d=e;Z?e"G d>d? d?e;Z@g d@ZAdS )AzPyTorch CANINE model.    N)	dataclass)OptionalUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)GradientCheckpointingLayer)BaseModelOutputModelOutputMultipleChoiceModelOutputQuestionAnsweringModelOutputSequenceClassifierOutputTokenClassifierOutput)PreTrainedModel)apply_chunking_to_forward find_pruneable_heads_and_indicesprune_linear_layer)auto_docstringlogging   )CanineConfig)   +   ;   =   I   a   g   q                           a  
    Output type of [`CanineModel`]. Based on [`~modeling_outputs.BaseModelOutputWithPooling`], but with slightly
    different `hidden_states` and `attentions`, as these also include the hidden states and attentions of the shallow
    Transformer encoders.
    )Zcustom_introc                   @   sb   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeej  ed< dZeeej  ed< dS )CanineModelOutputWithPoolinga  
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the model (i.e. the output of the final
        shallow Transformer encoder).
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
        Hidden-state of the first token of the sequence (classification token) at the last layer of the deep
        Transformer encoder, further processed by a Linear layer and a Tanh activation function. The Linear layer
        weights are trained from the next sentence prediction (classification) objective during pretraining.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the input to each encoder + one for the output of each layer of each
        encoder) of shape `(batch_size, sequence_length, hidden_size)` and `(batch_size, sequence_length //
        config.downsampling_rate, hidden_size)`. Hidden-states of the model at the output of each layer plus the
        initial input to each Transformer encoder. The hidden states of the shallow encoders have length
        `sequence_length`, but the hidden states of the deep encoder have length `sequence_length` //
        `config.downsampling_rate`.
    attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of the 3 Transformer encoders of shape `(batch_size,
        num_heads, sequence_length, sequence_length)` and `(batch_size, num_heads, sequence_length //
        config.downsampling_rate, sequence_length // config.downsampling_rate)`. Attentions weights after the
        attention softmax, used to compute the weighted average in the self-attention heads.
    Nlast_hidden_statepooler_outputhidden_states
attentions)__name__
__module____qualname____doc__r+   r   torchFloatTensor__annotations__r,   r-   tupler.    r7   r7   f/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/canine/modeling_canine.pyr*   3   s
   
	r*   c                 C   sN  zddl }ddl}ddl}W n ty:   td  Y n0 tj|}t	d|  |j
|}g }g }	|D ]@\}
}t	d|
 d|  |j
||
}||
 |	| qpt||	D ]\}
}|
d}
tdd	 |
D rt	d
d|
  q|
d dkrd|
d< n|
d dkr0|
|
d  nh|
d dkrHd|
d< nP|
d dkrjdg|
dd  }
n.|
d dkr|
d dv rdg|
dd  }
| }|
D ]}|d|rd|vr|d|}n|g}|d dks|d dkrt|d}n|d dks|d dkr"t|d}n^|d d kr<t|d}nDzt||d }W n0 ty~   t	d
d|
  Y qY n0 t|d!krt|d }|| }q|d"d d#krt|d}n@|d$d d%d& td'D v rt|d}n|dkr||}|j|jkr,td(|j d)|j d*t	d+|
  t||_q| S ),z'Load tf checkpoints in a pytorch model.r   NzLoading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see https://www.tensorflow.org/install/ for installation instructions.z&Converting TensorFlow checkpoint from zLoading TF weight z with shape /c                 s   s   | ]}|d v V  qdS ))Zadam_vZadam_mZAdamWeightDecayOptimizerZAdamWeightDecayOptimizer_1Zglobal_stepclsZautoregressive_decoderZchar_output_weightsNr7   ).0nr7   r7   r8   	<genexpr>v   s   z,load_tf_weights_in_canine.<locals>.<genexpr>z	Skipping Zbertencoderr   
embeddingsZsegment_embeddingstoken_type_embeddingsinitial_char_encoderchars_to_moleculesfinal_char_encoder)	LayerNormconv
projectionz[A-Za-z]+_\d+ZEmbedderz_(\d+)ZkernelgammaweightZoutput_biasbetabiasZoutput_weights   iZ_embeddingsic                 S   s   g | ]}d | qS )Z	Embedder_r7   )r;   ir7   r7   r8   
<listcomp>       z-load_tf_weights_in_canine.<locals>.<listcomp>   zPointer shape z and array shape z mismatchedzInitialize PyTorch weight )renumpyZ
tensorflowImportErrorloggererrorospathabspathinfotrainZlist_variablesZload_variableappendzipsplitanyjoinremove	fullmatchgetattrAttributeErrorlenintrange	transposeshape
ValueErrorr3   Z
from_numpydata)modelconfigZtf_checkpoint_pathrQ   nptfZtf_pathZ	init_varsnamesZarraysnamerh   arrayZpointerZm_nameZscope_namesnumr7   r7   r8   load_tf_weights_in_canineX   s    




 

rs   c                       st   e Zd ZdZ fddZeedddZeeeddd	Zdee	j
 ee	j
 ee	j
 ee	j e	jdddZ  ZS )CanineEmbeddingsz<Construct the character, position and token_type embeddings.c                    s   t    || _|j|j }t|jD ]$}d| }t| |t|j	| q&t|j	|j| _
t|j|j| _tj|j|jd| _t|j| _| jdt|jddd t|dd| _d S )	NHashBucketCodepointEmbedder_epsposition_ids)r   F)
persistentposition_embedding_typeabsolute)super__init__rl   hidden_sizenum_hash_functionsrf   setattrr   	Embeddingnum_hash_bucketschar_position_embeddingsZtype_vocab_sizer@   rE   layer_norm_epsDropouthidden_dropout_probdropoutZregister_bufferr3   arangemax_position_embeddingsexpandrb   r{   )selfrl   Zshard_embedding_sizerM   rp   	__class__r7   r8   r~      s    

zCanineEmbeddings.__init__
num_hashesnum_bucketsc                 C   sV   |t tkrtdt t td| }g }|D ]}|d | | }|| q2|S )a  
        Converts ids to hash bucket ids via multiple hashing.

        Args:
            input_ids: The codepoints or other IDs to be hashed.
            num_hashes: The number of hash functions to use.
            num_buckets: The number of hash buckets (i.e. embeddings in each table).

        Returns:
            A list of tensors, each of which is the hash bucket IDs from one hash function.
        z`num_hashes` must be <= Nr   )rd   _PRIMESri   r[   )r   	input_idsr   r   ZprimesZresult_tensorsprimehashedr7   r7   r8   _hash_bucket_tensors   s    z%CanineEmbeddings._hash_bucket_tensors)embedding_sizer   r   c                 C   sx   || dkr"t d| d| d| j|||d}g }t|D ]*\}}d| }	t| |	|}
||
 q>tj|ddS )	zDConverts IDs (e.g. codepoints) into embeddings via multiple hashing.r   zExpected `embedding_size` (z) % `num_hashes` (z) == 0r   ru   ry   dim)ri   r   	enumeraterb   r[   r3   cat)r   r   r   r   r   Zhash_bucket_tensorsZembedding_shardsrM   Zhash_bucket_idsrp   Zshard_embeddingsr7   r7   r8   _embed_hash_buckets   s    
z$CanineEmbeddings._embed_hash_bucketsN)r   token_type_idsrx   inputs_embedsreturnc           
      C   s   |d ur|  }n|  d d }|d }|d u rH| jd d d |f }|d u rftj|tj| jjd}|d u r| || jj| jj	| jj
}| |}|| }| jdkr| |}	||	7 }| |}| |}|S )Nry   r   dtypedevicer|   )sizerx   r3   zeroslongr   r   rl   r   r   r   r@   r{   r   rE   r   )
r   r   r   rx   r   input_shape
seq_lengthr@   r?   Zposition_embeddingsr7   r7   r8   forward   s(    





zCanineEmbeddings.forward)NNNN)r/   r0   r1   r2   r~   re   r   r   r   r3   
LongTensorr4   r   __classcell__r7   r7   r   r8   rt      s       rt   c                       s4   e Zd ZdZ fddZejejdddZ  ZS )CharactersToMoleculeszeConvert character sequence to initial molecule sequence (i.e. downsample) using strided convolutions.c                    sJ   t    tj|j|j|j|jd| _t|j | _	tj
|j|jd| _
d S )NZin_channelsZout_channelskernel_sizestriderv   )r}   r~   r   Conv1dr   downsampling_raterF   r
   
hidden_act
activationrE   r   r   rl   r   r7   r8   r~   !  s    
zCharactersToMolecules.__init__)char_encodingr   c                 C   s   |d d ddd d f }t |dd}| |}t |dd}| |}|d d ddd d f }t j||gdd}| |}|S )Nr   r   rL   ry   r   )r3   rg   rF   r   r   rE   )r   r   Zcls_encodingZdownsampledZdownsampled_truncatedresultr7   r7   r8   r   0  s    


zCharactersToMolecules.forward)	r/   r0   r1   r2   r~   r3   Tensorr   r   r7   r7   r   r8   r     s   r   c                       s>   e Zd ZdZ fddZdejeej ejdddZ  Z	S )	ConvProjectionz
    Project representations from hidden_size*2 back to hidden_size across a window of w = config.upsampling_kernel_size
    characters.
    c                    s`   t    || _tj|jd |j|jdd| _t|j	 | _
tj|j|jd| _t|j| _d S )NrL   r   r   rv   )r}   r~   rl   r   r   r   upsampling_kernel_sizerF   r
   r   r   rE   r   r   r   r   r   r   r7   r8   r~   R  s    
zConvProjection.__init__N)inputsfinal_seq_char_positionsr   c           
      C   s   t |dd}| jjd }|d }|| }t||fd}| ||}t |dd}| |}| |}| 	|}|}|d urt
dn|}	|	S )Nr   rL   r   z,CanineForMaskedLM is currently not supported)r3   rg   rl   r   r   ZConstantPad1drF   r   rE   r   NotImplementedError)
r   r   r   Z	pad_totalZpad_begZpad_endpadr   Zfinal_char_seqZ	query_seqr7   r7   r8   r   a  s    



zConvProjection.forward)N)
r/   r0   r1   r2   r~   r3   r   r   r   r   r7   r7   r   r8   r   L  s    r   c                
       sZ   e Zd Z fddZdejejeej eej ee e	ejeej f dddZ
  ZS )	CanineSelfAttentionc                    s   t    |j|j dkr>t|ds>td|j d|j d|j| _t|j|j | _| j| j | _t	
|j| j| _t	
|j| j| _t	
|j| j| _t	|j| _t|dd| _| jdks| jd	kr|j| _t	d
|j d | j| _d S )Nr   r   zThe hidden size (z6) is not a multiple of the number of attention heads ()r{   r|   relative_keyrelative_key_queryrL   r   )r}   r~   r   num_attention_headshasattrri   re   attention_head_sizeall_head_sizer   Linearquerykeyvaluer   Zattention_probs_dropout_probr   rb   r{   r   r   distance_embeddingr   r   r7   r8   r~     s$    

zCanineSelfAttention.__init__NF)from_tensor	to_tensorattention_mask	head_maskoutput_attentionsr   c                 C   s>  |j \}}}| ||d| j| jdd}	| ||d| j| jdd}
| ||d| j| jdd}t	||	dd}| j
dks| j
dkrb| d }tj|tj|jddd}tj|tj|jddd}|| }| || j d }|j|jd}| j
dkr.td	||}|| }n4| j
dkrbtd	||}td
|	|}|| | }|t| j }|d ur|jdkrtj|dd}d|  t|jj }|| }tjj|dd}| |}|d ur|| }t	||
}|dddd  }| d d | j!f }|j| }|r4||fn|f}|S )Nry   r   rL   rC   r   r   r   )r   zbhld,lrd->bhlrzbhrd,lrd->bhlrr	   r         ?r   )"rh   r   viewr   r   rg   r   r   r3   matmulr{   r   r   r   r   r   r   tor   ZeinsummathsqrtndimZ	unsqueezefloatZfinfominr   Z
functionalZsoftmaxr   Zpermute
contiguousr   )r   r   r   r   r   r   
batch_sizer   _Z	key_layerZvalue_layerZquery_layerZattention_scoresZposition_ids_lZposition_ids_rZdistanceZpositional_embeddingZrelative_position_scoresZrelative_position_scores_queryZrelative_position_scores_keyZattention_probsZcontext_layerZnew_context_layer_shapeoutputsr7   r7   r8   r     sd    







zCanineSelfAttention.forward)NNF)r/   r0   r1   r~   r3   r   r   r4   boolr6   r   r   r7   r7   r   r8   r     s      r   c                       sB   e Zd Z fddZeej ejeejejf dddZ  ZS )CanineSelfOutputc                    sB   t    t|j|j| _tj|j|jd| _t|j	| _
d S Nrv   )r}   r~   r   r   r   denserE   r   r   r   r   r   r   r7   r8   r~     s    
zCanineSelfOutput.__init__r-   input_tensorr   c                 C   s&   |  |}| |}| || }|S Nr   r   rE   r   r-   r   r7   r7   r8   r     s    

zCanineSelfOutput.forward	r/   r0   r1   r~   r6   r3   r4   r   r   r7   r7   r   r8   r     s   r   c                	       sx   e Zd ZdZdeeeeeed fddZdd Zdee	j
 ee	j
 ee	j
 ee ee	j
ee	j
 f d
ddZ  ZS )CanineAttentionav  
    Additional arguments related to local attention:

        - **local** (`bool`, *optional*, defaults to `False`) -- Whether to apply local attention.
        - **always_attend_to_first_position** (`bool`, *optional*, defaults to `False`) -- Should all blocks be able to
          attend
        to the `to_tensor`'s first position (e.g. a [CLS] position)? - **first_position_attends_to_all** (`bool`,
        *optional*, defaults to `False`) -- Should the *from_tensor*'s first position be able to attend to all
        positions within the *from_tensor*? - **attend_from_chunk_width** (`int`, *optional*, defaults to 128) -- The
        width of each block-wise chunk in `from_tensor`. - **attend_from_chunk_stride** (`int`, *optional*, defaults to
        128) -- The number of elements to skip when moving to the next block in `from_tensor`. -
        **attend_to_chunk_width** (`int`, *optional*, defaults to 128) -- The width of each block-wise chunk in
        *to_tensor*. - **attend_to_chunk_stride** (`int`, *optional*, defaults to 128) -- The number of elements to
        skip when moving to the next block in `to_tensor`.
    F   )always_attend_to_first_positionfirst_position_attends_to_allattend_from_chunk_widthattend_from_chunk_strideattend_to_chunk_widthattend_to_chunk_stridec	           	         st   t    t|| _t|| _t | _|| _||k r<t	d||k rLt	d|| _
|| _|| _|| _|| _|| _d S )Nze`attend_from_chunk_width` < `attend_from_chunk_stride` would cause sequence positions to get skipped.z``attend_to_chunk_width` < `attend_to_chunk_stride`would cause sequence positions to get skipped.)r}   r~   r   r   r   outputsetpruned_headslocalri   r   r   r   r   r   r   	r   rl   r   r   r   r   r   r   r   r   r7   r8   r~     s&    


zCanineAttention.__init__c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   r   )rd   r   r   r   r   r   r   r   r   r   r   r   r   union)r   headsindexr7   r7   r8   prune_heads2  s    zCanineAttention.prune_headsNr-   r   r   r   r   c                 C   sL  | j s$| |||||}|d }n|jd  }}| }	}
g }| jrT|d d}nd}t||| jD ]"}t||| j }|||f qfg }| jr|d|f td|| j	D ]"}t||| j
 }|||f qt|t|krtd| d| dg }g }t||D ]\\}}\}}|	d d ||d d f }|
d d ||d d f }|d d ||||f }| jr|d d ||ddf }tj||gdd}|
d d ddd d f }tj||gdd}| |||||}||d  |r||d  qtj|dd}| ||}|f}| j s<||dd   }n|t| }|S )	Nr   r   )r   r   z/Expected to have same number of `from_chunks` (z) and `to_chunks` (z). Check strides.rL   r   )r   r   rh   r   r[   rf   r   r   r   r   r   rd   ri   r\   r   r3   r   r   r6   )r   r-   r   r   r   Zself_outputsattention_outputfrom_seq_lengthto_seq_lengthr   r   Zfrom_chunksZ
from_startZchunk_startZ	chunk_endZ	to_chunksZattention_output_chunksZattention_probs_chunksZfrom_endZto_startZto_endZfrom_tensor_chunkZto_tensor_chunkZattention_mask_chunkZcls_attention_maskZcls_positionZattention_outputs_chunkr   r7   r7   r8   r   D  sb    

zCanineAttention.forward)FFFr   r   r   r   )NNF)r/   r0   r1   r2   r   re   r~   r   r6   r3   r4   r   r   r   r7   r7   r   r8   r      s6          !   r   c                       s0   e Zd Z fddZejejdddZ  ZS )CanineIntermediatec                    sB   t    t|j|j| _t|jt	r6t
|j | _n|j| _d S r   )r}   r~   r   r   r   intermediate_sizer   
isinstancer   strr
   intermediate_act_fnr   r   r7   r8   r~     s
    
zCanineIntermediate.__init__r-   r   c                 C   s   |  |}| |}|S r   )r   r   r   r-   r7   r7   r8   r     s    

zCanineIntermediate.forward)r/   r0   r1   r~   r3   r4   r   r   r7   r7   r   r8   r     s   r   c                       s8   e Zd Z fddZeej ejejdddZ  ZS )CanineOutputc                    sB   t    t|j|j| _tj|j|jd| _t	|j
| _d S r   )r}   r~   r   r   r   r   r   rE   r   r   r   r   r   r   r7   r8   r~     s    
zCanineOutput.__init__r   c                 C   s&   |  |}| |}| || }|S r   r   r   r7   r7   r8   r     s    

zCanineOutput.forwardr   r7   r7   r   r8   r    s   r  c                	       sb   e Zd Z fddZd
eej eej eej ee eejeej f dddZ	dd	 Z
  ZS )CanineLayerc	           	   	      sH   t    |j| _d| _t||||||||| _t|| _t|| _	d S Nr   )
r}   r~   chunk_size_feed_forwardseq_len_dimr   	attentionr   intermediater  r   r   r   r7   r8   r~     s    


zCanineLayer.__init__NFr   c           	      C   sH   | j ||||d}|d }|dd  }t| j| j| j|}|f| }|S )N)r   r   r   )r  r   feed_forward_chunkr  r  )	r   r-   r   r   r   Zself_attention_outputsr   r   layer_outputr7   r7   r8   r     s    
zCanineLayer.forwardc                 C   s   |  |}| ||}|S r   )r  r   )r   r   Zintermediate_outputr
  r7   r7   r8   r	    s    
zCanineLayer.feed_forward_chunk)NNF)r/   r0   r1   r~   r6   r3   r4   r   r   r   r	  r   r7   r7   r   r8   r    s      r  c                
       s`   e Zd Zd
 fdd	Zdeej eej eej ee ee ee e	ee
f ddd	Z  ZS )CanineEncoderFr   c	           	   
      sH   t    | _t fddtjD | _d| _d S )Nc                    s"   g | ]}t  qS r7   )r  )r;   r   r   r   r   r   r   rl   r   r   r7   r8   rN     s   z*CanineEncoder.__init__.<locals>.<listcomp>F)	r}   r~   rl   r   Z
ModuleListrf   num_hidden_layerslayerZgradient_checkpointingr   r   r  r8   r~     s    
zCanineEncoder.__init__NT)r-   r   r   r   output_hidden_statesreturn_dictr   c                 C   s   |rdnd }|rdnd }t | jD ]R\}	}
|r8||f }|d urH||	 nd }|
||||}|d }|r"||d f }q"|r||f }|stdd |||fD S t|||dS )Nr7   r   r   c                 s   s   | ]}|d ur|V  qd S r   r7   r;   vr7   r7   r8   r=   !  rO   z(CanineEncoder.forward.<locals>.<genexpr>)r+   r-   r.   )r   r  r6   r   )r   r-   r   r   r   r  r  all_hidden_statesall_self_attentionsrM   Zlayer_moduleZlayer_head_maskZlayer_outputsr7   r7   r8   r     s&    	

zCanineEncoder.forward)FFFr   r   r   r   )NNFFT)r/   r0   r1   r~   r6   r3   r4   r   r   r   r   r   r   r7   r7   r   r8   r    s,          !     
r  c                       s4   e Zd Z fddZeej ejdddZ  ZS )CaninePoolerc                    s*   t    t|j|j| _t | _d S r   )r}   r~   r   r   r   r   ZTanhr   r   r   r7   r8   r~   *  s    
zCaninePooler.__init__r   c                 C   s(   |d d df }|  |}| |}|S )Nr   )r   r   )r   r-   Zfirst_token_tensorpooled_outputr7   r7   r8   r   /  s    

zCaninePooler.forwardr   r7   r7   r   r8   r  )  s   r  c                       s4   e Zd Z fddZeej ejdddZ  ZS )CaninePredictionHeadTransformc                    sV   t    t|j|j| _t|jtr6t	|j | _
n|j| _
tj|j|jd| _d S r   )r}   r~   r   r   r   r   r   r   r   r
   transform_act_fnrE   r   r   r   r7   r8   r~   9  s    
z&CaninePredictionHeadTransform.__init__r   c                 C   s"   |  |}| |}| |}|S r   )r   r  rE   r  r7   r7   r8   r   B  s    


z%CaninePredictionHeadTransform.forwardr   r7   r7   r   r8   r  8  s   	r  c                       s4   e Zd Z fddZeej ejdddZ  ZS )CanineLMPredictionHeadc                    sL   t    t|| _tj|j|jdd| _t	t
|j| _| j| j_d S )NF)rK   )r}   r~   r  	transformr   r   r   Z
vocab_sizedecoder	Parameterr3   r   rK   r   r   r7   r8   r~   J  s
    

zCanineLMPredictionHead.__init__r   c                 C   s   |  |}| |}|S r   )r  r  r  r7   r7   r8   r   W  s    

zCanineLMPredictionHead.forwardr   r7   r7   r   r8   r  I  s   r  c                       s8   e Zd Z fddZeej eej dddZ  ZS )CanineOnlyMLMHeadc                    s   t    t|| _d S r   )r}   r~   r  predictionsr   r   r7   r8   r~   ^  s    
zCanineOnlyMLMHead.__init__)sequence_outputr   c                 C   s   |  |}|S r   )r  )r   r  Zprediction_scoresr7   r7   r8   r   b  s    
zCanineOnlyMLMHead.forward)	r/   r0   r1   r~   r6   r3   r   r   r   r7   r7   r   r8   r  ]  s   r  c                   @   s*   e Zd ZU eed< eZdZdZdd Z	dS )CaninePreTrainedModelrl   canineTc                 C   s   t |tjtjfr@|jjjd| jjd |j	dur|j	j
  nft |tjr|jjjd| jjd |jdur|jj|j 
  n&t |tjr|j	j
  |jjd dS )zInitialize the weightsg        )meanZstdNr   )r   r   r   r   rI   rj   Znormal_rl   Zinitializer_rangerK   Zzero_r   Zpadding_idxrE   Zfill_)r   moduler7   r7   r8   _init_weightsq  s    

z#CaninePreTrainedModel._init_weightsN)
r/   r0   r1   r   r5   rs   Zload_tf_weightsZbase_model_prefixZsupports_gradient_checkpointingr$  r7   r7   r7   r8   r   j  s
   
r   c                       s   e Zd Zd fdd	Zdd Zdd Zejedd	d
Z	ejeejdddZ
edeej eej eej eej eej eej ee ee ee eeef d
ddZ  ZS )CanineModelTc              
      s   t  | || _t|}d|_t|| _t|ddd|j	|j	|j	|j	d| _
t|| _t|| _t|| _t|| _|rt|nd| _|   dS )zv
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        r   TF)r   r   r   r   r   r   r   N)r}   r~   rl   copydeepcopyr  rt   char_embeddingsr  Zlocal_transformer_striderA   r   rB   r>   r   rG   rD   r  pooler	post_init)r   rl   Zadd_pooling_layerZshallow_configr   r7   r8   r~     s*    






zCanineModel.__init__c                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr>   r  r  r   )r   Zheads_to_pruner  r   r7   r7   r8   _prune_heads  s    zCanineModel._prune_headsc                 C   s\   |j d |j d  }}|j d }t||d|f }tj||dftj|jd}|| }|S )aP  
        Create 3D attention mask from a 2D tensor mask.

        Args:
            from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
            to_mask: int32 Tensor of shape [batch_size, to_seq_length].

        Returns:
            float Tensor of shape [batch_size, from_seq_length, to_seq_length].
        r   r   )r   r   r   )rh   r3   reshaper   onesZfloat32r   )r   r   Zto_maskr   r   r   Zbroadcast_onesmaskr7   r7   r8   )_create_3d_attention_mask_from_input_mask  s    
z5CanineModel._create_3d_attention_mask_from_input_mask)char_attention_maskr   c                 C   sF   |j \}}t||d|f}tjj||d| }tj|dd}|S )z[Downsample 2D character attention mask to 2D molecule attention mask using MaxPool1d layer.r   )r   r   ry   r   )rh   r3   r-  r   Z	MaxPool1dr   squeeze)r   r1  r   r   Zchar_seq_lenZpoolable_char_maskZpooled_molecule_maskmolecule_attention_maskr7   r7   r8   _downsample_attention_mask  s    
z&CanineModel._downsample_attention_mask)	moleculeschar_seq_lengthr   c           	      C   sz   | j j}|ddddddf }tj||dd}|ddddddf }|| }tj||| dd}tj||gddS )zDRepeats molecules to make them the same length as the char sequence.Nr   rC   )Zrepeatsr   ry   r   )rl   r   r3   Zrepeat_interleaver   )	r   r5  r6  ZrateZmolecules_without_extra_clsZrepeatedZlast_moleculeZremainder_lengthZremainder_repeatedr7   r7   r8   _repeat_molecules  s    zCanineModel._repeat_moleculesN)
r   r   r   rx   r   r   r   r  r  r   c
           "      C   s  |d ur|n| j j}|d ur |n| j j}|r0dnd }
|r<dnd }|	d urL|	n| j j}	|d urn|d urntdn@|d ur| || | }n"|d ur| d d }ntd|\}}|d ur|jn|j}|d u rtj	||f|d}|d u rtj
|tj|d}| ||}| j|| j jd}| |||jd f}| || j j}| j||||d}| |d urf|n||}| j||||d	}|j}| |}| j||||||	d
}|d }| jd ur| |nd }| j||d d}tj||gdd}| |}| j||||d	}|j}|r<|	r |jn|d }|
|j | |j }
|rj|	rN|jn|d } ||j |  |j }|	s||f}!|!tdd |
|fD 7 }!|!S t |||
|dS )Nr7   zDYou cannot specify both input_ids and inputs_embeds at the same timery   z5You have to specify either input_ids or inputs_embeds)r   r   )r   )r   rx   r   r   )r   r   r  )r   r   r   r  r  r   )r6  r   r   c                 s   s   | ]}|d ur|V  qd S r   r7   r  r7   r7   r8   r=     rO   z&CanineModel.forward.<locals>.<genexpr>)r+   r,   r-   r.   )!rl   r   r  use_return_dictri   Z%warn_if_padding_and_no_attention_maskr   r   r3   r.  r   r   Zget_extended_attention_maskr4  r   rh   Zget_head_maskr  r(  r0  rA   r+   rB   r>   r)  r7  r   rG   rD   r-   r.   r6   r*   )"r   r   r   r   rx   r   r   r   r  r  r  r  r   r   r   r   Zextended_attention_maskr3  Z extended_molecule_attention_maskZinput_char_embeddingsr1  Zinit_chars_encoder_outputsZinput_char_encodingZinit_molecule_encodingZencoder_outputsZmolecule_sequence_outputr  Zrepeated_moleculesconcatr  Zfinal_chars_encoder_outputsZdeep_encoder_hidden_statesZdeep_encoder_self_attentionsr   r7   r7   r8   r     s    

	


zCanineModel.forward)T)	NNNNNNNNN)r/   r0   r1   r~   r,  r0  r3   r   re   r4  r7  r   r   r   r4   r   r   r6   r*   r   r   r7   r7   r   r8   r%    s6   "         
r%  z
    CANINE Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    c                       s   e Zd Z fddZedeej eej eej eej eej eej eej ee	 ee	 ee	 e
eef dddZ  ZS )CanineForSequenceClassificationc                    sJ   t  | |j| _t|| _t|j| _t	|j
|j| _|   d S r   r}   r~   
num_labelsr%  r!  r   r   r   r   r   r   
classifierr*  r   r   r7   r8   r~     s    
z(CanineForSequenceClassification.__init__Nr   r   r   rx   r   r   labelsr   r  r  r   c                 C   s|  |
dur|
n| j j}
| j||||||||	|
d	}|d }| |}| |}d}|dur8| j jdu r| jdkrzd| j _n4| jdkr|jtj	ks|jtj
krd| j _nd| j _| j jdkrt }| jdkr|| | }n
|||}nN| j jdkrt }||d| j|d}n| j jdkr8t }|||}|
sh|f|dd  }|durd|f| S |S t|||j|jd	S )
a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Nr   r   rx   r   r   r   r  r  r   Z
regressionZsingle_label_classificationZmulti_label_classificationry   rL   losslogitsr-   r.   )rl   r8  r!  r   r=  Zproblem_typer<  r   r3   r   re   r   r2  r   r   r   r   r-   r.   )r   r   r   r   rx   r   r   r?  r   r  r  r   r  rC  rB  loss_fctr   r7   r7   r8   r     sV    




"


z'CanineForSequenceClassification.forward)
NNNNNNNNNN)r/   r0   r1   r~   r   r   r3   r   r4   r   r   r6   r   r   r   r7   r7   r   r8   r:    s2             
r:  c                       s   e Zd Z fddZedeej eej eej eej eej eej eej ee	 ee	 ee	 e
eef dddZ  ZS )CanineForMultipleChoicec                    s@   t  | t|| _t|j| _t|j	d| _
|   d S r  )r}   r~   r%  r!  r   r   r   r   r   r   r=  r*  r   r   r7   r8   r~     s
    
z CanineForMultipleChoice.__init__Nr>  c                 C   st  |
dur|
n| j j}
|dur&|jd n|jd }|durJ|d|dnd}|durh|d|dnd}|dur|d|dnd}|dur|d|dnd}|dur|d|d|dnd}| j||||||||	|
d	}|d }| |}| |}|d|}d}|dur0t }|||}|
s`|f|dd  }|dur\|f| S |S t	|||j
|jdS )a[  
        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        Nr   ry   rC   r@  rL   rA  )rl   r8  rh   r   r   r!  r   r=  r   r   r-   r.   )r   r   r   r   rx   r   r   r?  r   r  r  Znum_choicesr   r  rC  Zreshaped_logitsrB  rD  r   r7   r7   r8   r     sL    ,



zCanineForMultipleChoice.forward)
NNNNNNNNNN)r/   r0   r1   r~   r   r   r3   r   r4   r   r   r6   r   r   r   r7   r7   r   r8   rE    s2   
          
rE  c                       s   e Zd Z fddZedeej eej eej eej eej eej eej ee	 ee	 ee	 e
eef dddZ  ZS )CanineForTokenClassificationc                    sJ   t  | |j| _t|| _t|j| _t	|j
|j| _|   d S r   r;  r   r   r7   r8   r~   X  s    
z%CanineForTokenClassification.__init__Nr>  c                 C   s   |
dur|
n| j j}
| j||||||||	|
d	}|d }| |}| |}d}|durxt }||d| j|d}|
s|f|dd  }|dur|f| S |S t|||j	|j
dS )a  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, CanineForTokenClassification
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google/canine-s")
        >>> model = CanineForTokenClassification.from_pretrained("google/canine-s")

        >>> inputs = tokenizer(
        ...     "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt"
        ... )

        >>> with torch.no_grad():
        ...     logits = model(**inputs).logits

        >>> predicted_token_class_ids = logits.argmax(-1)

        >>> # Note that tokens are classified rather then input words which means that
        >>> # there might be more predicted token classes than words.
        >>> # Multiple token classes might account for the same word
        >>> predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]]
        >>> predicted_tokens_classes  # doctest: +SKIP
        ```

        ```python
        >>> labels = predicted_token_class_ids
        >>> loss = model(**inputs, labels=labels).loss
        >>> round(loss.item(), 2)  # doctest: +SKIP
        ```Nr@  r   ry   rL   rA  )rl   r8  r!  r   r=  r   r   r<  r   r-   r.   )r   r   r   r   rx   r   r   r?  r   r  r  r   r  rC  rB  rD  r   r7   r7   r8   r   c  s8    0

z$CanineForTokenClassification.forward)
NNNNNNNNNN)r/   r0   r1   r~   r   r   r3   r   r4   r   r   r6   r   r   r   r7   r7   r   r8   rF  V  s2             
rF  c                       s   e Zd Z fddZedeej eej eej eej eej eej eej eej ee	 ee	 ee	 e
eef dddZ  ZS )CanineForQuestionAnsweringc                    s<   t  | |j| _t|| _t|j|j| _| 	  d S r   )
r}   r~   r<  r%  r!  r   r   r   
qa_outputsr*  r   r   r7   r8   r~     s
    
z#CanineForQuestionAnswering.__init__N)r   r   r   rx   r   r   start_positionsend_positionsr   r  r  r   c                 C   sD  |d ur|n| j j}| j|||||||	|
|d	}|d }| |}|jddd\}}|d}|d}d }|d ur|d urt| dkr|d}t| dkr|d}|d}|d| |d| t	|d}|||}|||}|| d }|s.||f|dd   }|d ur*|f| S |S t
||||j|jdS )	Nr@  r   r   ry   r   )Zignore_indexrL   )rB  start_logits
end_logitsr-   r.   )rl   r8  r!  rH  r]   r2  rd   r   Zclamp_r   r   r-   r.   )r   r   r   r   rx   r   r   rI  rJ  r   r  r  r   r  rC  rK  rL  Z
total_lossZignored_indexrD  Z
start_lossZend_lossr   r7   r7   r8   r     sP    








z"CanineForQuestionAnswering.forward)NNNNNNNNNNN)r/   r0   r1   r~   r   r   r3   r   r4   r   r   r6   r   r   r   r7   r7   r   r8   rG    s6   
           
rG  )rE  rG  r:  rF  r  r%  r   rs   )Br2   r&  r   rV   dataclassesr   typingr   r   r3   Ztorch.utils.checkpointr   Ztorch.nnr   r   r   Zactivationsr
   Zmodeling_layersr   Zmodeling_outputsr   r   r   r   r   r   Zmodeling_utilsr   Zpytorch_utilsr   r   r   utilsr   r   Zconfiguration_caniner   Z
get_loggerr/   rT   r   r*   rs   Modulert   r   r   r   r   r   r   r  r  r  r  r  r  r  r   r%  r:  rE  rF  rG  __all__r7   r7   r7   r8   <module>   sp    
ae.:j :C  Ug`M