a
    hQ                  
   @   sh  d dl Z d dlZd dlmZmZmZ d dlZd dlZd dlm	Z	 d dl
mZ ddlmZ ddlmZ ddlmZ dd	lmZ dd
lmZ ddlmZmZmZ ddlmZmZ ddlmZ ddlm Z m!Z! ddl"m#Z# e!$e%Z&G dd deZ'G dd deZ(G dd deZ)G dd de	j*Z+G dd de	j*Z,G dd de	j*Z-G dd de	j*Z.d:e	j*ej/ej/ej/eej/ ee0 e0eej/ d d!d"Z1G d#d$ d$e	j*Z2G d%d& d&e	j*Z3G d'd( d(eZ4G d)d* d*e	j*Z5e G d+d, d,eZ6d;e7e8e8f e0e8eej9 e8ej:d-d.d/Z;e G d0d1 d1e6Z<dZ=e d2d3G d4d5 d5e6Z>e d6d3G d7d8 d8e6Z?g d9Z@dS )<    N)CallableOptionalUnion)nn)CrossEntropyLoss   )ACT2FN)is_deepspeed_zero3_enabled)is_fsdp_managed_module)FlashAttentionKwargs)GradientCheckpointingLayer)BaseModelOutputCausalLMOutputSequenceClassifierOutput)ALL_ATTENTION_FUNCTIONSPreTrainedModel)Unpack)auto_docstringlogging   )	SEWConfigc                       s&   e Zd Zd fdd	Zdd Z  ZS )SEWNoLayerNormConvLayerr   c                    sj   t    |dkr |j|d  nd| _|j| | _tj| j| j|j| |j| |j	d| _
t|j | _d S )Nr   r   kernel_sizestridebias)super__init__conv_dimin_conv_dimout_conv_dimr   Conv1dconv_kernelconv_stride	conv_biasconvr   feat_extract_activation
activationselfconfiglayer_id	__class__ `/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/sew/modeling_sew.pyr   /   s    
z SEWNoLayerNormConvLayer.__init__c                 C   s   |  |}| |}|S N)r%   r'   r)   hidden_statesr.   r.   r/   forward=   s    

zSEWNoLayerNormConvLayer.forward)r   __name__
__module____qualname__r   r3   __classcell__r.   r.   r,   r/   r   .   s   r   c                       s&   e Zd Zd fdd	Zdd Z  ZS )SEWLayerNormConvLayerr   c                    s|   t    |dkr |j|d  nd| _|j| | _tj| j| j|j| |j| |j	d| _
tj| jdd| _t|j | _d S )Nr   r   r   T)Zelementwise_affine)r   r   r   r   r    r   r!   r"   r#   r$   r%   	LayerNorm
layer_normr   r&   r'   r(   r,   r.   r/   r   D   s    
zSEWLayerNormConvLayer.__init__c                 C   s:   |  |}|dd}| |}|dd}| |}|S )N)r%   	transposer;   r'   r1   r.   r.   r/   r3   S   s    


zSEWLayerNormConvLayer.forward)r   r4   r.   r.   r,   r/   r9   C   s   r9   c                       s&   e Zd Zd fdd	Zdd Z  ZS )SEWGroupNormConvLayerr   c                    s   t    |dkr |j|d  nd| _|j| | _tj| j| j|j| |j| |j	d| _
t|j | _tj| j| jdd| _d S )Nr   r   r   T)Z
num_groupsZnum_channelsZaffine)r   r   r   r   r    r   r!   r"   r#   r$   r%   r   r&   r'   	GroupNormr;   r(   r,   r.   r/   r   _   s    
zSEWGroupNormConvLayer.__init__c                 C   s"   |  |}| |}| |}|S r0   )r%   r;   r'   r1   r.   r.   r/   r3   o   s    


zSEWGroupNormConvLayer.forward)r   r4   r.   r.   r,   r/   r?   ^   s   r?   c                       s$   e Zd Z fddZdd Z  ZS )SEWPositionalConvEmbeddingc                    s(  t    tj|j|j|j|jd |j|jd| _tj	j
}ttj	jdrRtj	jj
}t rdd l}|jj| jjdd" || jddd| _W d    n1 s0    Y  t| jdr| jjjj}| jjjj}n| jj}| jj}|j| | |j| | n|| jddd| _t|j| _t|j | _d S )	N   )r   paddinggroupsr   weight_normr   Zmodifier_rankweight)namedimparametrizations)r   r   r   r!   hidden_sizenum_conv_pos_embeddingsZnum_conv_pos_embedding_groupssqueeze_factorr%   utilsrE   hasattrrJ   r	   	deepspeedzeroGatheredParametersrG   Z	original0Z	original1weight_gweight_vZregister_external_parameterSEWSamePadLayerrC   r   r&   r'   )r)   r*   rE   rP   rS   rT   r,   r.   r/   r   w   s4    
	
0z#SEWPositionalConvEmbedding.__init__c                 C   s"   |  |}| |}| |}|S r0   )r%   rC   r'   r1   r.   r.   r/   r3      s    


z"SEWPositionalConvEmbedding.forwardr4   r.   r.   r,   r/   rA   v   s   "rA   c                       s$   e Zd Z fddZdd Z  ZS )rU   c                    s$   t    |d dkrdnd| _d S )NrB   r   r   )r   r   num_pad_remove)r)   rL   r,   r.   r/   r      s    
zSEWSamePadLayer.__init__c                 C   s,   | j dkr(|d d d d d | j  f }|S )Nr   )rV   r1   r.   r.   r/   r3      s    
zSEWSamePadLayer.forwardr4   r.   r.   r,   r/   rU      s   rU   c                       s$   e Zd Z fddZdd Z  ZS )SEWUpsamplingc                    s:   t    t|j|j|j | _t|j | _	|j| _d S r0   )
r   r   r   LinearrK   rM   
projectionr   r&   r'   r)   r*   r,   r.   r/   r      s    
zSEWUpsampling.__init__c                 C   sd   |  |}| |}| jdkr`| \}}}|| j }|| j }|||| j|}||||}|S )Nr   )rY   r'   rM   sizereshape)r)   r2   bszsrc_lenZsrc_embed_dimtgt_lenZtgt_embed_dimr.   r.   r/   r3      s    




zSEWUpsampling.forwardr4   r.   r.   r,   r/   rW      s   rW   c                       s0   e Zd ZdZ fddZdd Zdd Z  ZS )SEWFeatureEncoderz.Construct the features from raw audio waveformc                    s   t     jdkr@t ddg fddt jd D  }n6 jdkrd fddt jD }ntd	 j d
t|| _	d| _
d| _d S )Ngroupr   r+   c                    s   g | ]}t  |d  dqS )r   rb   )r   .0ir*   r.   r/   
<listcomp>   s   z.SEWFeatureEncoder.__init__.<locals>.<listcomp>r   layerc                    s   g | ]}t  |d qS )rb   )r9   rc   rf   r.   r/   rg          z`config.feat_extract_norm` is z), but has to be one of ['group', 'layer']FT)r   r   Zfeat_extract_normr?   rangeZnum_feat_extract_layers
ValueErrorr   
ModuleListconv_layersgradient_checkpointing_requires_grad)r)   r*   rm   r,   rf   r/   r      s    



zSEWFeatureEncoder.__init__c                 C   s   |   D ]
}d|_qd| _d S )NF)
parametersrequires_gradro   r)   paramr.   r.   r/   _freeze_parameters   s    z$SEWFeatureEncoder._freeze_parametersc                 C   s:   |d d d f }| j r"| jr"d|_| jD ]}||}q(|S )NT)ro   trainingrq   rm   )r)   input_valuesr2   Z
conv_layerr.   r.   r/   r3      s    

zSEWFeatureEncoder.forward)r5   r6   r7   __doc__r   rt   r3   r8   r.   r.   r,   r/   r`      s   r`           )modulequerykeyvalueattention_maskscalingdropout	head_maskc                 K   s   |d u r| dd }t||dd| }	|d ur>|	| }	tjj|	dd}	|d urj|	|dddd }	tjj|	|| j	d}	t|	|}
|
dd
 }
|
|	fS )Nr=         rB   r   rI   r   )pru   )r[   torchmatmulr>   r   
functionalsoftmaxviewr   ru   
contiguous)ry   rz   r{   r|   r}   r~   r   r   kwargsattn_weightsattn_outputr.   r.   r/   eager_attention_forward   s    r   c                       s   e Zd ZdZdeeeeeeee d fddZ	de
jee
j ee
j ee
j ee ee ee
jee
j eee
j  f d	d
dZ  ZS )SEWAttentionz=Multi-headed attention from 'Attention Is All You Need' paperrx   FTN)	embed_dim	num_headsr   
is_decoderr   	is_causalr*   c                    s   t    || _|| _|| _|| | _|| _| j| | jkrTtd| j d| d| jd | _|| _	|| _
tj|||d| _tj|||d| _tj|||d| _tj|||d| _d S )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).r   )r   )r   r   r   r   r   head_dimr*   rk   r~   r   r   r   rX   k_projv_projq_projout_proj)r)   r   r   r   r   r   r   r*   r,   r.   r/   r   	  s&    



zSEWAttention.__init__)r2   key_value_statesr}   layer_head_maskoutput_attentionsr   returnc                 K   s  |du}|j dd \}}	|r(|j d n|	}
||	d| jf}||
d| jf}| |j| dd}|rh|n|}| |j| dd}| |j| dd}t}| jj	dkrt
| jj	 }|| ||||f| jsdn| j| j||d|\}}|||	d }| |}||dfS )z#Input shape: Batch x Time x ChannelNr=   r   rB   eagerrx   )r   r~   r   r   )shaper   r   r   r>   r   r   r   r*   _attn_implementationr   ru   r   r~   r\   r   r   )r)   r2   r   r}   r   r   r   Zis_cross_attentionr]   r_   r^   Zq_input_shapeZkv_input_shapeZquery_statesZcurrent_statesZ
key_statesZvalue_statesZattention_interfacer   r   r.   r.   r/   r3   (  s:    


zSEWAttention.forward)rx   FTFN)NNNF)r5   r6   r7   rw   intfloatboolr   r   r   r   Tensorr   r   tupler3   r8   r.   r.   r,   r/   r     s8        "    r   c                       s$   e Zd Z fddZdd Z  ZS )SEWFeedForwardc                    sp   t    t|j| _t|j|j| _	t
|jtrDt|j | _n|j| _t|j|j| _t|j| _d S r0   )r   r   r   DropoutZactivation_dropoutintermediate_dropoutrX   rK   Zintermediate_sizeintermediate_dense
isinstanceZ
hidden_actstrr   intermediate_act_fnoutput_densehidden_dropoutoutput_dropoutrZ   r,   r.   r/   r   _  s    
zSEWFeedForward.__init__c                 C   s6   |  |}| |}| |}| |}| |}|S r0   )r   r   r   r   r   r1   r.   r.   r/   r3   l  s    




zSEWFeedForward.forwardr4   r.   r.   r,   r/   r   ^  s   r   c                       s&   e Zd Z fddZdddZ  ZS )SEWEncoderLayerc                    sh   t    t|j|j|jd|d| _t|j	| _
tj|j|jd| _t|| _tj|j|jd| _d S )NF)r   r   r   r   r*   eps)r   r   r   rK   Znum_attention_headsZattention_dropout	attentionr   r   r   r   r:   layer_norm_epsr;   r   feed_forwardfinal_layer_normrZ   r,   r.   r/   r   w  s    

zSEWEncoderLayer.__init__NFc                 C   sf   |}| j |||d\}}}| |}|| }| |}|| | }| |}|f}|rb||f7 }|S )Nr}   r   )r   r   r;   r   r   )r)   r2   r}   r   Zattn_residualr   _outputsr.   r.   r/   r3     s    



zSEWEncoderLayer.forward)NFr4   r.   r.   r,   r/   r   v  s   r   c                       s&   e Zd Z fddZdddZ  ZS )	
SEWEncoderc                    s   t     | _t | _t j j| _tj	 j
 jd| _t j| _t fddt jD | _t | _d| _d S )Nr   c                    s   g | ]}t  qS r.   )r   rd   r   rf   r.   r/   rg     ri   z'SEWEncoder.__init__.<locals>.<listcomp>F)r   r   r*   rA   pos_conv_embedr   Z	AvgPool1drM   poolr:   rK   r   r;   r   r   r   rl   rj   num_hidden_layerslayersrW   upsamplern   rZ   r,   rf   r/   r     s    

 
zSEWEncoder.__init__NFTc              	   C   s  |rdnd }|rdnd }|d ur8| ddd|jd }| jjdkrld|| < |d urfd|v rf|nd }nd|| < | d}	|	| jj }
|jd | jj }tj	d||
j
ddd|
jd d}||
ddk  }d	|d d d d d d f j|jd
 }|t|jj }||jd d|jd |jd }|jd }|dd}| |}| |}t|d|d}|dd |f |dd |f  }|dd}| |}| |}t pt| }| jD ]t}|r||f }tg }| jo|| jjk }|r|r"||||d}|d }|r,d}|r||d f }q|rT||f }| |}|jd |k rtj |ddd||jd  f}|st!dd |||fD S t"|||dS )Nr.   r=   r   rB   Zflash_attention_2rx   r   device      ?dtype.r   )NNc                 s   s   | ]}|d ur|V  qd S r0   r.   )rd   vr.   r.   r/   	<genexpr>  ri   z%SEWEncoder.forward.<locals>.<genexpr>Zlast_hidden_stater2   
attentions)#	unsqueezerepeatr   r*   r   longsumrM   r   aranger   r   expandtor   Zfinfominr>   r   r   r[   r;   r   r	   r
   r   randru   Z	layerdropr   r   r   padr   r   )r)   r2   r}   r   output_hidden_statesreturn_dictZall_hidden_statesZall_self_attentionsZexpand_attention_maskinput_lengthsoutput_lengthsZmax_encoder_lengthZattention_idsZn_input_timestepsZposition_embeddingsZpooled_hidden_statesZ
min_lengthZsynced_gpusrh   Zdropout_probabilityZskip_the_layerZlayer_outputsr.   r.   r/   r3     sv    



&


 






 zSEWEncoder.forward)NFFTr4   r.   r.   r,   r/   r     s       r   c                   @   s`   e Zd ZU eed< dZdZdZdZdZ	dZ
dd Zeejef dd	d
ZeejdddZdS )SEWPreTrainedModelr*   sewrv   TFc              	   C   s  t |trTtjj|jjddtd|jj	d |jj
   d tj|jjd n,t |tjrz|jjjd| jjd nt |tjtjfr|jj  |jjd nt |tjrt rpddl}t|dr*t|d	r*|jj|j|jgdd
  tj|jj W d   n1 s0    Y  nD|jj|jdd
  tj|jj W d   n1 sd0    Y  ntj|jj t |tjtjfr|jdur|jj  dS )zInitialize the weightsr   rB   r   )meanZstdrx   r   NrT   rS   rF   )r   rA   r   initZnormal_r%   rG   mathsqrtr   Zin_channelsZ	constant_r   rX   datar*   Zinitializer_ranger:   r@   Zzero_Zfill_r!   r	   rP   rO   rQ   rR   rT   rS   Zkaiming_normal_)r)   ry   rP   r.   r.   r/   _init_weights
  s.    
 22 z SEWPreTrainedModel._init_weights)r   c                 C   s4   dd }t | jj| jjD ]\}}||||}q|S )zH
        Computes the output length of the convolutional layers
        c                 S   s   t j| | |ddd S )Nfloor)Zrounding_moder   )r   div)input_lengthr   r   r.   r.   r/   _conv_out_length/  s    zMSEWPreTrainedModel._get_feat_extract_output_lengths.<locals>._conv_out_length)zipr*   r"   r#   )r)   r   r   r   r   r.   r.   r/    _get_feat_extract_output_lengths*  s    z3SEWPreTrainedModel._get_feat_extract_output_lengths)feature_vector_lengthr}   c                 C   s~   |  |dtj}|jd }tj||f|j|jd}d|tj	|jd |jd|d f< |
dgd
dg }|S )Nr=   r   )r   r   r   r   )r   r   r   r   r   r   zerosr   r   r   flipZcumsumr   )r)   r   r}   r   
batch_sizer.   r.   r/   "_get_feature_vector_attention_mask9  s    
"z5SEWPreTrainedModel._get_feature_vector_attention_maskN)r5   r6   r7   r   __annotations__Zbase_model_prefixZmain_input_nameZsupports_gradient_checkpointingZ_supports_flash_attnZ_supports_sdpaZ_supports_flex_attnr   r   r   
LongTensorr   r   r   r.   r.   r.   r/   r      s   
 r   )r   	mask_probmask_lengthr}   	min_masksr   c                    s  | \}dk rt dkr6t d d dtjd   fdd}|durt| d	 nfd
dt|D }tj	|ft
d}g }	|}
|
dkr|S |D ]v}||}tjjt|d  |dd}t|dkrd }n|d }t|tj|
| tjd| g}|	| qt|	}	t|	dddddf ||
f}	|	||
 }	tddddf }t|||
f||
 }|	| }	|	 d kr҈d |	|	d k< t||	dd	 |S )an  
    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
    ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
    CPU as part of the preprocessing during training.

    Args:
        shape: The shape for which to compute masks. This should be of a tuple of size 2 where
               the first element is the batch size and the second element is the length of the axis to span.
        mask_prob:  The percentage of the whole axis (between 0 and 1) which will be masked. The number of
                    independently generated mask spans of length `mask_length` is computed by
                    `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
                    actual percentage will be smaller.
        mask_length: size of the mask
        min_masks: minimum number of masked spans
        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
                        each batch dimension.
    r   z&`mask_length` has to be bigger than 0.zO`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: z and `sequence_length`: `c                    sX   t |     }t|}| kr2 }| d  |k rTt| d  d}|S )z;Given input length, compute how many spans should be maskedr   r   )r   max)r   num_masked_spanepsilonr   r   r   sequence_lengthr.   r/   compute_num_masked_spanl  s    
z6_compute_mask_indices.<locals>.compute_num_masked_spanNr=   c                    s   g | ]} qS r.   r.   r   )r   r.   r/   rg     ri   z)_compute_mask_indices.<locals>.<listcomp>r   r   F)replace)rk   nprandomr   itemdetachr   tolistrj   r   r   choicer   lenZconcatenateonesZint32appendarrayZbroadcast_tor\   r   Zput_along_axis)r   r   r   r}   r   r   r   r   Zspec_aug_maskZspec_aug_mask_idxsZmax_num_masked_spanr   r   Zspec_aug_mask_idxZdummy_mask_idxoffsetsr.   r   r/   _compute_mask_indicesF  s\    

r  c                       s   e Zd Zed fddZdejeej eej dddZ	e
deej eej eej ee ee ee eeef dd	d
Z  ZS )SEWModelrf   c                    s   t  | || _t|| _tj|jd |jd| _	|jd |j
k| _| jrbt|jd |j
| _t|j| _|jdks|jdkrtt|j
 | _t|| _|   d S )Nr=   r   rx   )r   r   r*   r`   feature_extractorr   r:   r   r   r;   rK   project_featuresrX   feature_projectionr   Zfeat_proj_dropoutfeature_dropoutmask_time_probmask_feature_prob	Parameterr   r   Zuniform_masked_spec_embedr   encoder	post_initrZ   r,   r.   r/   r     s    

zSEWModel.__init__N)r2   mask_time_indicesr}   c                 C   s  t | jdds|S | \}}}|dur<| j|j||< nZ| jjdkr| jrt||f| jj| jj	|| jj
d}tj||jtjd}| j|j||< | jjdkr| jrt||f| jj| jj| jjd}tj||jtjd}|dddf d|d}d||< |S )	z
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://huggingface.co/papers/1904.08779).
        Zapply_spec_augmentTNr   )r   r   r}   r   )r   r   )r   r   r   r=   )getattrr*   r[   r
  r   r   r  ru   r  Zmask_time_lengthZmask_time_min_masksr   Ztensorr   r   r  Zmask_feature_lengthZmask_feature_min_masksr   )r)   r2   r  r}   r   r   rK   Zmask_feature_indicesr.   r.   r/   _mask_hidden_states  s4    zSEWModel._mask_hidden_states)rv   r}   r  r   r   r   r   c           
      C   s   |dur|n| j j}|dur |n| j j}|dur4|n| j j}| |}|dd}| |}| jrl| |}| 	|}|dur| 
|jd |}| j||d}| j|||||d}	|	d }|s|f|	dd  S t||	j|	jdS )a/  
        mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
            masked extracted features in *config.proj_codevector_dim* space.
        Nr   rB   )r  r}   r   r   r   r   r   )r*   r   r   use_return_dictr  r>   r;   r  r  r  r   r   r  r  r   r2   r   )
r)   rv   r}   r  r   r   r   Zextract_featuresr2   Zencoder_outputsr.   r.   r/   r3     s8    



zSEWModel.forward)NN)NNNNN)r5   r6   r7   r   r   r   ZFloatTensorr   r   r  r   r   r   r   r   r   r3   r8   r.   r.   r,   r/   r    s.     .     
r  zk
    SEW Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
    )Zcustom_introc                       s   e Zd Zdee d fddZdd Zdd Zd	d
 Zdd Z	e
deej eej ee ee ee eej eeef dddZ  ZS )	SEWForCTCN)target_langc                    s~   t  | t|| _t|j| _|| _|j	du rFt
d| j dt|dr\|jr\|jn|j}t||j	| _|   dS )a-  
        target_lang (`str`, *optional*):
            Language id of adapter weights. Adapter weights are stored in the format adapter.<lang>.safetensors or
            adapter.<lang>.bin. Only relevant when using an instance of [`SEWForCTC`] with adapters. Uses 'eng' by
            default.
        NzYou are trying to instantiate z with a configuration that does not define the vocabulary size of the language model head. Please instantiate the model as follows: `SEWForCTC.from_pretrained(..., vocab_size=vocab_size)`. or define `vocab_size` of your model's configuration.add_adapter)r   r   r  r   r   r   Zfinal_dropoutr   r  
vocab_sizerk   r-   rO   r  output_hidden_sizerK   rX   lm_headr  )r)   r*   r  r  r,   r.   r/   r   A  s    

zSEWForCTC.__init__c                 C   sr   | j }|dur2t| jdddu r2td| dn<|du rXt| jdddurXtd n|durn| j|dd dS )a'  
        This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
        passing `target_lang=...` to `from_pretrained(...)`.

        This method is **not** supposed to be called by the user and is prone to be changed in the future.
        NZadapter_attn_dimzCannot pass `target_lang`: z- if `config.adapter_attn_dim` is not defined.z)By default `target_lang` is set to 'eng'.T)Z
force_load)r  r  r*   rk   loggerinfoZload_adapter)r)   r  r.   r.   r/   tie_weights^  s    zSEWForCTC.tie_weightsc                 C   s   t dt |   dS )
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. Please use the equivalent `freeze_feature_encoder` method instead.NwarningswarnFutureWarningfreeze_feature_encoderr)   r.   r.   r/   freeze_feature_extractors  s
    z"SEWForCTC.freeze_feature_extractorc                 C   s   | j j  dS r  Nr   r  rt   r"  r.   r.   r/   r!    s    z SEWForCTC.freeze_feature_encoderc                 C   s   | j  D ]
}d|_q
dS z
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        FNr   rp   rq   rr   r.   r.   r/   freeze_base_model  s    zSEWForCTC.freeze_base_modelrv   r}   r   r   r   labelsr   c              
   C   s  |dur|n| j j}|dur>| | j jkr>td| j j | j|||||d}|d }| |}| |}	d}
|dur@|dur|ntj	|tj
d}| |dtj
}|dk}|d}||}tjj|	dtjddd}tjjjd	d
6 tjj||||| j j| j j| j jd}
W d   n1 s60    Y  |sp|	f|td  }|
durl|
f| S |S t|
|	|j|jdS )a  
        labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
            Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
            the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
            All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
            config.vocab_size - 1]`.
        Nz$Label values must be <= vocab_size: r  r   r   r=   )rI   r   r   F)Zenabled)blankZ	reductionZzero_infinitylosslogitsr2   r   )r*   r  r   r  rk   r   r   r  r   Z	ones_liker   r   r   r   Zmasked_selectr   r   Zlog_softmaxZfloat32r>   backendsZcudnnflagsZctc_lossZpad_token_idZctc_loss_reductionZctc_zero_infinity_HIDDEN_STATES_START_POSITIONr   r2   r   )r)   rv   r}   r   r   r   r*  r   r2   r.  r-  r   Zlabels_maskZtarget_lengthsZflattened_targetsZ	log_probsoutputr.   r.   r/   r3     sL    




&
zSEWForCTC.forward)N)NNNNN)r5   r6   r7   r   r   r   r  r#  r!  r(  r   r   r   r   r   r   r   r3   r8   r.   r.   r,   r/   r  ;  s(        
r  z
    SEW Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
    SUPERB Keyword Spotting.
    c                       sz   e Zd Z fddZdd Zdd Zdd Zedee	j
 ee	j
 ee ee ee ee	j
 eeef d
ddZ  ZS )SEWForSequenceClassificationc                    s   t  | t|dr$|jr$tdt|| _|jd }|jrTt	
t|| | _t	|j|j| _t	|j|j| _|   d S )Nr  zZSequence classification does not support the use of SEW adapters (config.add_adapter=True)r   )r   r   rO   r  rk   r  r   r   use_weighted_layer_sumr   r	  r   r   layer_weightsrX   rK   Zclassifier_proj_size	projector
num_labels
classifierr  )r)   r*   Z
num_layersr,   r.   r/   r     s    

z%SEWForSequenceClassification.__init__c                 C   s   t dt |   dS )z
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.
        r  Nr  r"  r.   r.   r/   r#    s
    z5SEWForSequenceClassification.freeze_feature_extractorc                 C   s   | j j  dS r$  r%  r"  r.   r.   r/   r!    s    z3SEWForSequenceClassification.freeze_feature_encoderc                 C   s   | j  D ]
}d|_q
dS r&  r'  rr   r.   r.   r/   r(    s    z.SEWForSequenceClassification.freeze_base_modelNr)  c                 C   s  |dur|n| j j}| j jr dn|}| j|||||d}| j jr|t }tj|dd}tjj	| j
dd}	||	ddd jdd}n|d }| |}|du r|jdd}
nV| |jd |}|ddd|jd }d	|| < |jdd|jdddd }
| |
}d}|dur<t }||d| j j|d}|sl|f|td  }|durh|f| S |S t|||j|jd
S )a  
        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
            into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
            (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
            To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
            into a tensor of type `torch.FloatTensor`. See [`SEWProcessor.__call__`] for details.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        NTr  r   r   r=   r   rB   rx   r,  )r*   r  r4  r   r1  r   stackr   r   r   r5  r   r   r6  r   r   r   r   r   r8  r   r7  r   r2   r   )r)   rv   r}   r   r   r   r*  r   r2   Znorm_weightsZpooled_outputZpadding_maskZexpand_padding_maskr.  r-  Zloss_fctr2  r.   r.   r/   r3   	  sH    

 

z$SEWForSequenceClassification.forward)NNNNN)r5   r6   r7   r   r#  r!  r(  r   r   r   r   r   r   r   r   r3   r8   r.   r.   r,   r/   r3    s&        
r3  )r  r3  r  r   )Nrx   N)Nr   )Ar   r  typingr   r   r   numpyr   r   r   Ztorch.nnr   Zactivationsr   Zintegrations.deepspeedr	   Zintegrations.fsdpr
   Zmodeling_flash_attention_utilsr   Zmodeling_layersr   Zmodeling_outputsr   r   r   Zmodeling_utilsr   r   Zprocessing_utilsr   rN   r   r   Zconfiguration_sewr   Z
get_loggerr5   r  r   r9   r?   ModulerA   rU   rW   r`   r   r   r   r   r   r   r   r   r   r   r   Zndarrayr  r  r1  r  r3  __all__r.   r.   r.   r/   <module>   s   
+,   X$fI  
wz s