a
    hl                    @   s  d dl Z d dlZd dlmZmZ d dlZd dlZd dlm	Z	 d dl
m	  mZ d dlmZ ddlmZ ddlmZ ddlmZ ddlmZ dd	lmZmZmZmZmZmZ dd
lmZ ddlm Z m!Z!m"Z" ddl#m$Z$ e"%e&Z'G dd de	j(Z)G dd de	j(Z*G dd de	j(Z+G dd de	j(Z,G dd de	j(Z-G dd deZ.G dd deZ/G dd de	j(Z0G dd de	j(Z1G d d! d!e	j(Z2e G d"d# d#eZ3G d$d% d%eZ4G d&d' d'eZ5G d(d) d)eZ6G d*d+ d+e	j(Z7G d,d- d-e	j(Z8G d.d/ d/e	j(Z9dGe:e;e;f e<e;eej= e;ej>d0d1d2Z?eZ@e G d3d4 d4e3ZAd5ZBe d6d7G d8d9 d9e3ZCe d:d7G d;d< d<e3ZDe G d=d> d>e3ZEG d?d@ d@e	j(ZFG dAdB dBe	j(ZGe dCd7G dDdE dEe3ZHg dFZIdS )H    N)OptionalUnion)CrossEntropyLoss   )ACT2FN)is_deepspeed_zero3_enabled)is_fsdp_managed_module)GradientCheckpointingLayer)BaseModelOutputCausalLMOutputSequenceClassifierOutputTokenClassifierOutputWav2Vec2BaseModelOutputXVectorOutput)PreTrainedModel)auto_docstringis_peft_availablelogging   )WavLMConfigc                       s$   e Zd Z fddZdd Z  ZS )WavLMSamePadLayerc                    s$   t    |d dkrdnd| _d S N   r   r   )super__init__num_pad_remove)selfnum_conv_pos_embeddings	__class__ d/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/wavlm/modeling_wavlm.pyr   &   s    
zWavLMSamePadLayer.__init__c                 C   s,   | j dkr(|d d d d d | j  f }|S )Nr   )r   r   hidden_statesr    r    r!   forward*   s    
zWavLMSamePadLayer.forward__name__
__module____qualname__r   r$   __classcell__r    r    r   r!   r   %   s   r   c                       s$   e Zd Z fddZdd Z  ZS )WavLMPositionalConvEmbeddingc                    s$  t    tj|j|j|j|jd |jd| _tjj	}t
tjjdrNtjjj	}t rdd l}|jj| jjdd" || jddd| _W d    n1 s0    Y  t
| jdr| jjjj}| jjjj}n| jj}| jj}|j| | |j| | n|| jddd| _t|j| _t|j | _d S )	Nr   )kernel_sizepaddinggroupsweight_normr   )Zmodifier_rankweight)namedimparametrizations)r   r   nnConv1dhidden_sizer   Znum_conv_pos_embedding_groupsconvutilsr.   hasattrr2   r   	deepspeedzeroZGatheredParametersr/   Z	original0Z	original1weight_gweight_vZregister_external_parameterr   r,   r   feat_extract_activation
activation)r   configr.   r9   r;   r<   r   r    r!   r   1   s2    

0z%WavLMPositionalConvEmbedding.__init__c                 C   s:   | dd}| |}| |}| |}| dd}|S Nr   r   )	transposer6   r,   r>   r"   r    r    r!   r$   R   s    


z$WavLMPositionalConvEmbedding.forwardr%   r    r    r   r!   r*   0   s   !r*   c                       s$   e Zd Z fddZdd Z  ZS )WavLMFeatureProjectionc                    sJ   t    tj|jd |jd| _t|jd |j| _	t
|j| _d S )Neps)r   r   r3   	LayerNormconv_dimlayer_norm_eps
layer_normLinearr5   
projectionDropoutZfeat_proj_dropoutdropoutr   r?   r   r    r!   r   ^   s    
zWavLMFeatureProjection.__init__c                 C   s&   |  |}| |}| |}||fS N)rI   rK   rM   )r   r#   Znorm_hidden_statesr    r    r!   r$   d   s    


zWavLMFeatureProjection.forwardr%   r    r    r   r!   rB   ]   s   rB   c                       s   e Zd ZdZdeeeeeed fddZdej	e
ej	 e
ej	 eeej	e
ej	 e
eej	  f dddZejeejejf ejeejejfdddZeeejdddZejejdddZ  ZS )WavLMAttentionz=Multi-headed attention from 'Attention Is All You Need' paper        @     T	embed_dim	num_headsrM   num_bucketsmax_distancehas_relative_position_biasc                    s   t    || _|| _|| _|| | _| j| | jkrNtd| j d| d| jd | _t	||| _
t	||| _t	||| _t	||| _|| _|| _ttd| jdd| _t	| jd| _|rt| j| j| _d S )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).g      r      )r   r   rU   rV   rM   Zhead_dim
ValueErrorZscalingr3   rJ   k_projv_projq_projout_projrW   rX   	Parametertorchonesgru_rel_pos_constgru_rel_pos_linearZ	Embeddingrel_attn_embed)r   rU   rV   rM   rW   rX   rY   r   r    r!   r   o   s,    	


zWavLMAttention.__init__NFr   )r#   attention_maskposition_biasoutput_attentionsreturnc                 C   s  |  \}}}|du rH| ||}|d|ddd|| j ||}||jdd | jdf }	|	dddd}	| |	}
|
|	jdd d 	d}
t
|
jddd\}}||| j d	  d
 }||| j dd| }|d||f}| ||||\}}|||fS )z'Attention layer with relative attentionNr   r   rC   r   r   )r      r1         ?g       @)sizecompute_bias	unsqueezerepeatviewrV   shapepermuterd   sumra   Zsigmoidchunkrc   torch_multi_head_self_attention)r   r#   rf   rg   rh   indexZbszZtgt_len_Zgated_hidden_statesZrelative_position_projZgate_aZgate_bZgate_outputgated_position_biasattn_outputattn_weightsr    r    r!   r$      s"    	$
zWavLMAttention.forward)r#   rf   ry   rh   ri   c                 C   s   | dd } }}|dur&|dnd}d }	}
d}tj|||| j| jtdgt| j	j
| jj
| jj
f|	|
|| j| jj| jj
| j|||d| j	j| jj| jjd\}}| dd}|dur|dddf |jdd | jf |jdd  }||fS )zCsimple wrapper around torch's multi_head_attention_forward functionr   r   NFT)Zuse_separate_proj_weightZq_proj_weightZk_proj_weightZv_proj_weight)rA   neFZmulti_head_attention_forwardrU   rV   ra   emptycatr^   biasr\   r]   rM   r_   r/   trainingbroadcast_torr   )r   r#   rf   ry   rh   querykeyvalueZkey_padding_maskZbias_kZbias_vZadd_zero_attnrz   r{   r    r    r!   rv      sB    	

"z.WavLMAttention.torch_multi_head_self_attention)query_length
key_lengthri   c                 C   sv   t j|t jdd d d f }t j|t jdd d d f }|| }| |}|| jjj}| |}|g d}|S )Ndtype)r   r   r   )	ra   arangelong_relative_positions_buckettore   r/   devicers   )r   r   r   Zcontext_positionZmemory_positionZrelative_positionZrelative_position_bucketvaluesr    r    r!   rn      s    

zWavLMAttention.compute_bias)relative_positionsri   c                 C   s   | j d }|dktj| }t|}|d }||k }t| | }|t| j|  }|||  }|| tj}t	|t
||d }|t|||7 }|S r   )rW   r   ra   r   abslogfloatmathrX   minZ	full_likewhere)r   r   rW   Zrelative_bucketsZ	max_exactZis_smallZrelative_positions_if_largeZrelative_position_if_larger    r    r!   r      s    

z)WavLMAttention._relative_positions_bucket)rQ   rR   rS   T)NNFr   )r&   r'   r(   __doc__intr   boolr   ra   Tensorr   tupler$   FloatTensorr   
LongTensorZ
BoolTensorrv   rn   r   r)   r    r    r   r!   rP   l   s@       '    +
7
rP   c                       s$   e Zd Z fddZdd Z  ZS )WavLMFeedForwardc                    sp   t    t|j| _t|j|j| _	t
|jtrDt|j | _n|j| _t|j|j| _t|j| _d S rO   )r   r   r3   rL   Zactivation_dropoutintermediate_dropoutrJ   r5   Zintermediate_sizeintermediate_dense
isinstanceZ
hidden_actstrr   intermediate_act_fnoutput_densehidden_dropoutoutput_dropoutrN   r   r    r!   r     s    
zWavLMFeedForward.__init__c                 C   s6   |  |}| |}| |}| |}| |}|S rO   )r   r   r   r   r   r"   r    r    r!   r$      s    




zWavLMFeedForward.forwardr%   r    r    r   r!   r     s   r   c                       s0   e Zd Zd
eed fddZddd	Z  ZS )WavLMEncoderLayerTr?   rY   c                    sn   t    t|j|j|j|j|j|d| _t	
|j| _t	j|j|jd| _t|| _t	j|j|jd| _d S NrT   rD   r   r   rP   r5   Znum_attention_headsZattention_dropoutrW   Zmax_bucket_distance	attentionr3   rL   r   rM   rF   rH   rI   r   feed_forwardfinal_layer_normr   r?   rY   r   r    r!   r   +  s    

zWavLMEncoderLayer.__init__NFr   c           	      C   sl   |}| j |||||d\}}}| |}|| }| |}|| | }| |}||f}|rh||f7 }|S )Nrf   rg   rh   rw   )r   rM   rI   r   r   )	r   r#   rf   rg   rh   rw   attn_residualr{   outputsr    r    r!   r$   :  s"    



zWavLMEncoderLayer.forward)T)NNFr   r&   r'   r(   r   r   r   r$   r)   r    r    r   r!   r   *  s   r   c                       s0   e Zd Zd	eed fddZd
ddZ  ZS ) WavLMEncoderLayerStableLayerNormTr   c                    sn   t    t|j|j|j|j|j|d| _t	
|j| _t	j|j|jd| _t|| _t	j|j|jd| _d S r   r   r   r   r    r!   r   T  s    

z)WavLMEncoderLayerStableLayerNorm.__init__NFc                 C   sf   |}|  |}| j||||d\}}}| |}|| }|| | | }||f}|rb||f7 }|S )N)rf   rg   rh   )rI   r   rM   r   r   )r   r#   rf   rg   rh   r   r{   r   r    r    r!   r$   c  s    


z(WavLMEncoderLayerStableLayerNorm.forward)T)NNFr   r    r    r   r!   r   S  s   r   c                       s&   e Zd Z fddZdddZ  ZS )	WavLMEncoderc                    sf   t     | _t | _tj j jd| _	t
 j| _t fddt jD | _d| _d S )NrD   c                    s   g | ]}t  |d kdqS r   )rY   )r   .0ir?   r    r!   
<listcomp>      z)WavLMEncoder.__init__.<locals>.<listcomp>Fr   r   r?   r*   pos_conv_embedr3   rF   r5   rH   rI   rL   r   rM   
ModuleListrangenum_hidden_layerslayersgradient_checkpointingrN   r   r   r!   r   y  s    

zWavLMEncoder.__init__NFTc                 C   sB  |rdnd }|rdnd }|d urD| ddd|jd }d|| < | |}	||	 }| |}| |}t pvt| }
d }t| j	D ]~\}}|r||f }t
g }| jo|dko|| jjk }|r|
r||||||d}|d d \}}|rd}|r||d f }q|r||f }|s4tdd	 |||fD S t|||d
S )Nr    rC   r   r   r   r   NNNc                 s   s   | ]}|d ur|V  qd S rO   r    r   vr    r    r!   	<genexpr>  r   z'WavLMEncoder.forward.<locals>.<genexpr>last_hidden_stater#   
attentions)ro   rp   rr   r   rI   rM   r   r   	enumerater   ra   randr   r?   	layerdropr   r
   r   r#   rf   rh   output_hidden_statesreturn_dictZall_hidden_statesZall_self_attentionsZexpand_attention_maskZposition_embeddingsZsynced_gpusrg   r   layerZdropout_probabilityZskip_the_layerZlayer_outputsr    r    r!   r$     sL    






zWavLMEncoder.forward)NFFTr%   r    r    r   r!   r   x  s       r   c                       s&   e Zd Z fddZdddZ  ZS )	WavLMEncoderStableLayerNormc                    sf   t     | _t | _tj j jd| _	t
 j| _t fddt jD | _d| _d S )NrD   c                    s   g | ]}t  |d kdqS r   )r   r   r   r    r!   r     s   z8WavLMEncoderStableLayerNorm.__init__.<locals>.<listcomp>Fr   rN   r   r   r!   r     s    


z$WavLMEncoderStableLayerNorm.__init__NFTc                 C   s@  |rdnd }|rdnd }|d urD| ddd|jd }d|| < | |}	||	 }| |}t plt| }
d }t| jD ]|\}}|r||f }t	
g }| jo|dko|| jjk }|r|
r|||||d}|d d \}}|rd}|r|||d f }q|| |}|r||f }|s2tdd	 |||fD S t|||d
S )Nr    rC   r   r   r   )rf   rh   rg   r   c                 s   s   | ]}|d ur|V  qd S rO   r    r   r    r    r!   r     r   z6WavLMEncoderStableLayerNorm.forward.<locals>.<genexpr>r   )ro   rp   rr   r   rM   r   r   r   r   ra   r   r   r?   r   rI   r   r
   r   r    r    r!   r$     sF    






z#WavLMEncoderStableLayerNorm.forward)NFFTr%   r    r    r   r!   r     s       r   c                       s4   e Zd ZdZ fddZedd Zdd Z  ZS )WavLMGumbelVectorQuantizerz
    Vector quantization using gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH
    GUMBEL-SOFTMAX](https://huggingface.co/papers/1611.01144) for more information.
    c                    s   t    |j| _|j| _|j| j dkrDtd|j d| j dt	t
d| j| j |j| j | _t|jd | j| j | _d| _d S )Nr   z`config.codevector_dim z5 must be divisible by `config.num_codevector_groups` z for concatenation.r   rC   r   )r   r   Znum_codevector_groups
num_groupsZnum_codevectors_per_groupnum_varsZcodevector_dimr[   r3   r`   ra   r   codevectorsrJ   rG   weight_projtemperaturerN   r   r    r!   r     s    

z#WavLMGumbelVectorQuantizer.__init__c                 C   s8   | j dd}ttj|t|d  dd  }|S )Nr   rk   gHz>rC   )meanra   exprt   r   )ZprobsZmarginal_probs
perplexityr    r    r!   _compute_perplexity(  s    (z.WavLMGumbelVectorQuantizer._compute_perplexityc                 C   s  |j \}}}| |}||| | j d}| jrtjj| | j	dd}|
|}tj||| | jd dd}| |}nH|jdd}|j|j  d|ddd}||| | jd}| |}||| d}|d| j }	|	|| | j| jd}
|
d||d}
|
|fS )NrC   T)tauhardrk   r   rl   )rr   r   rq   r   r   r3   
functionalZgumbel_softmaxr   r   Ztype_asra   softmaxr   argmaxZ	new_zerosZscatter_ro   r   r   rt   )r   r#   
batch_sizesequence_lengthr5   Zcodevector_probsZcodevector_soft_distr   Zcodevector_idxZcodevectors_per_groupr   r    r    r!   r$   .  s*    


z"WavLMGumbelVectorQuantizer.forward)	r&   r'   r(   r   r   staticmethodr   r$   r)   r    r    r   r!   r     s
   
r   c                   @   sj   e Zd ZU eed< dZdZdZdZdZ	dZ
dd Zdeejef ee d	d
dZdeejdddZdS )WavLMPreTrainedModelr?   wavlminput_valuesTFc              	   C   s  t |tr>|jjjjddd |jjj  tj	
|j njt |trtj	j|jjddtd|jjd |jj   d tj	|jjd nt |trtd|jj }tj	j
|jj| |d tj	j
|jj| |d nt |tjr|jjjd| jjd |jdur|jj  nt |tjtjfrN|jj  |jjd nZt |tjrtj	|j |jdurt|j|j|jd   }tj	j
|j| |d dS )	zInitialize the weightsrQ   r   )r   stdr   r   )abNrl   )r   r   r   r/   dataZnormal_r   Zzero_r3   inituniform_r   r*   r6   r   sqrtr+   Zin_channelsZ	constant_rB   rK   Zin_featuresrJ   r?   Zinitializer_rangerF   	GroupNormZfill_r4   Zkaiming_normal_r-   )r   modulekr    r    r!   _init_weights]  s6    

 
z"WavLMPreTrainedModel._init_weightsN)input_lengthsadd_adapterc                 C   sn   |du r| j jn|}dd }t| j j| j jD ]\}}||||}q.|rjt| j jD ]}||d| j j}qT|S )zH
        Computes the output length of the convolutional layers
        Nc                 S   s   t j| | |ddd S )Nfloor)Zrounding_moder   )ra   divinput_lengthr+   strider    r    r!   _conv_out_length  s    zOWavLMPreTrainedModel._get_feat_extract_output_lengths.<locals>._conv_out_lengthr   )r?   r   zipconv_kernelconv_strider   num_adapter_layersadapter_stride)r   r   r   r   r+   r   rx   r    r    r!    _get_feat_extract_output_lengths~  s    z5WavLMPreTrainedModel._get_feat_extract_output_lengths)feature_vector_lengthrf   c                 C   s   |j ddd d df }| j||d}|tj}|jd }tj||f|j|jd}d|tj	|jd |jd|d f< |
dg d
dg }|S )NrC   rk   r   r   )r   r   r   )r   )Zcumsumr   r   ra   r   rr   zerosr   r   r   flipr   )r   r  rf   r   Znon_padded_lengthsZoutput_lengthsr   r    r    r!   "_get_feature_vector_attention_mask  s    
"z7WavLMPreTrainedModel._get_feature_vector_attention_mask)N)N)r&   r'   r(   r   __annotations__Zbase_model_prefixZmain_input_nameZsupports_gradient_checkpointingZ_supports_flash_attnZ_supports_sdpaZ_supports_flex_attnr   r   ra   r   r   r   r   r   r  r    r    r    r!   r   S  s    
"  r   c                       s&   e Zd Zd fdd	Zdd Z  ZS )WavLMNoLayerNormConvLayerr   c                    sj   t    |dkr |j|d  nd| _|j| | _tj| j| j|j| |j| |j	d| _
t|j | _d S )Nr   r   r+   r   r   )r   r   rG   in_conv_dimout_conv_dimr3   r4   r   r   	conv_biasr6   r   r=   r>   r   r?   layer_idr   r    r!   r     s    
z"WavLMNoLayerNormConvLayer.__init__c                 C   s   |  |}| |}|S rO   )r6   r>   r"   r    r    r!   r$     s    

z!WavLMNoLayerNormConvLayer.forward)r   r%   r    r    r   r!   r    s   r  c                       s&   e Zd Zd fdd	Zdd Z  ZS )WavLMLayerNormConvLayerr   c                    s|   t    |dkr |j|d  nd| _|j| | _tj| j| j|j| |j| |j	d| _
tj| jdd| _t|j | _d S )Nr   r   r  T)Zelementwise_affine)r   r   rG   r	  r
  r3   r4   r   r   r  r6   rF   rI   r   r=   r>   r  r   r    r!   r     s    
z WavLMLayerNormConvLayer.__init__c                 C   s:   |  |}|dd}| |}|dd}| |}|S )Nr   rC   )r6   rA   rI   r>   r"   r    r    r!   r$     s    


zWavLMLayerNormConvLayer.forward)r   r%   r    r    r   r!   r    s   r  c                       s&   e Zd Zd fdd	Zdd Z  ZS )WavLMGroupNormConvLayerr   c                    s   t    |dkr |j|d  nd| _|j| | _tj| j| j|j| |j| |j	d| _
t|j | _tj| j| jdd| _d S )Nr   r   r  T)r   Znum_channelsZaffine)r   r   rG   r	  r
  r3   r4   r   r   r  r6   r   r=   r>   r   rI   r  r   r    r!   r     s    
z WavLMGroupNormConvLayer.__init__c                 C   s"   |  |}| |}| |}|S rO   )r6   rI   r>   r"   r    r    r!   r$     s    


zWavLMGroupNormConvLayer.forward)r   r%   r    r    r   r!   r    s   r  c                       s0   e Zd ZdZ fddZdd Zdd Z  ZS )WavLMFeatureEncoderz.Construct the features from raw audio waveformc                    s   t     jdkr@t ddg fddt jd D  }n6 jdkrd fddt jD }ntd	 j d
t|| _	d| _
d| _d S )Ngroupr   r  c                    s   g | ]}t  |d  dqS )r   r  )r  r   r   r    r!   r     s   z0WavLMFeatureEncoder.__init__.<locals>.<listcomp>r   r   c                    s   g | ]}t  |d qS )r  )r  r   r   r    r!   r     r   z`config.feat_extract_norm` is z), but has to be one of ['group', 'layer']FT)r   r   Zfeat_extract_normr  r   Znum_feat_extract_layersr[   r3   r   conv_layersr   _requires_grad)r   r?   r  r   r   r!   r     s    



zWavLMFeatureEncoder.__init__c                 C   s   |   D ]
}d|_qd| _d S )NF)
parametersrequires_gradr  r   paramr    r    r!   _freeze_parameters  s    z&WavLMFeatureEncoder._freeze_parametersc                 C   s:   |d d d f }| j r"| jr"d|_| jD ]}||}q(|S )NT)r  r   r  r  )r   r   r#   Z
conv_layerr    r    r!   r$     s    

zWavLMFeatureEncoder.forward)r&   r'   r(   r   r   r  r$   r)   r    r    r   r!   r    s   r  c                       s$   e Zd Z fddZdd Z  ZS )WavLMAdapterLayerc                    s0   t    tj|jd|j |j|jdd| _d S )Nr   r   )r   r,   )r   r   r3   r4   output_hidden_sizeZadapter_kernel_sizer   r6   rN   r   r    r!   r     s    
zWavLMAdapterLayer.__init__c                 C   s   |  |}tjj|dd}|S )Nr   rk   )r6   r3   r   Zglur"   r    r    r!   r$   #  s    
zWavLMAdapterLayer.forwardr%   r    r    r   r!   r    s   
r  c                       s$   e Zd Z fddZdd Z  ZS )WavLMAdapterc                    sp   t     j jkr8t j j| _t j| _nd  | _| _t	 fddt
 jD | _ j| _d S )Nc                 3   s   | ]}t  V  qd S rO   )r  r   rx   r   r    r!   r   5  r   z(WavLMAdapter.__init__.<locals>.<genexpr>)r   r   r  r5   r3   rJ   projrF   proj_layer_normr   r   r   r   r   rN   r   r   r!   r   +  s    
 zWavLMAdapter.__init__c                 C   sr   | j d ur(| jd ur(|  |}| |}|dd}| jD ]&}tj }| jrX|| jkr:||}q:|dd}|S r@   )r  r  rA   r   nprandomr   r   )r   r#   r   Zlayerdrop_probr    r    r!   r$   8  s    




zWavLMAdapter.forwardr%   r    r    r   r!   r  *  s   r  )rr   	mask_probmask_lengthrf   	min_masksri   c                    s  | \}dk rt dkr6t d d dtjd   fdd}|durt| d	 nfd
dt|D }tj	|ft
d}g }	|}
|
dkr|S |D ]v}||}tjjt|d  |dd}t|dkrd }n|d }t|tj|
| tjd| g}|	| qt|	}	t|	dddddf ||
f}	|	||
 }	tddddf }t|||
f||
 }|	| }	|	 d kr҈d |	|	d k< t||	dd	 |S )an  
    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
    ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
    CPU as part of the preprocessing during training.

    Args:
        shape: The shape for which to compute masks. This should be of a tuple of size 2 where
               the first element is the batch size and the second element is the length of the axis to span.
        mask_prob:  The percentage of the whole axis (between 0 and 1) which will be masked. The number of
                    independently generated mask spans of length `mask_length` is computed by
                    `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
                    actual percentage will be smaller.
        mask_length: size of the mask
        min_masks: minimum number of masked spans
        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
                        each batch dimension.
    r   z&`mask_length` has to be bigger than 0.zO`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: z and `sequence_length`: `c                    sX   t |     }t|}| kr2 }| d  |k rTt| d  d}|S )z;Given input length, compute how many spans should be maskedr   r   )r   max)r   num_masked_spanepsilonr#  r"  r$  r   r    r!   compute_num_masked_spano  s    
z6_compute_mask_indices.<locals>.compute_num_masked_spanNrC   c                    s   g | ]} qS r    r    r  )r   r    r!   r     r   z)_compute_mask_indices.<locals>.<listcomp>r   r   F)replace)r[   r   r!  r   itemdetachrt   tolistr   r  r   choicer   lenZconcatenaterb   Zint32appendarrayr   Zreshaper&  Zput_along_axis)rr   r"  r#  rf   r$  r   r*  r   Zspec_aug_maskZspec_aug_mask_idxsZmax_num_masked_spanr   r'  Zspec_aug_mask_idxZdummy_mask_idxoffsetsr    r(  r!   _compute_mask_indicesI  s\    

r4  c                       s   e Zd Zed fddZdd Zdd Zdeje	ej e	ej
 d	d
dZede	ej e	ej e	ej e	e e	e e	e eeef dddZ  ZS )
WavLMModelr   c                    s   t  | || _t|| _t|| _|jdks:|jdkrRt	
t|j | _|jrdt|| _n
t|| _|jr|t|nd | _|   d S )NrQ   )r   r   r?   r  feature_extractorrB   feature_projectionmask_time_probmask_feature_probr3   r`   ra   r   r5   r   masked_spec_embedZdo_stable_layer_normr   encoderr   r   r  adapter	post_initrN   r   r    r!   r     s    


zWavLMModel.__init__c                 C   s   t dt |   dS z
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.
        The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. Please use the equivalent `freeze_feature_encoder` method instead.NwarningswarnFutureWarningfreeze_feature_encoderr   r    r    r!   freeze_feature_extractor  s
    z#WavLMModel.freeze_feature_extractorc                 C   s   | j   dS 
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        N)r6  r  rE  r    r    r!   rD    s    z!WavLMModel.freeze_feature_encoderN)r#   mask_time_indicesrf   c                 C   s  t | jdds|S | \}}}|dur<| j|j||< nZ| jjdkr| jrt||f| jj| jj	|| jj
d}tj||jtjd}| j|j||< | jjdkr| jrt||f| jj| jj| jjd}tj||jtjd}|dddf d|d}d||< |S )	z
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://huggingface.co/papers/1904.08779).
        Zapply_spec_augmentTNr   )r"  r#  rf   r$  )r   r   )r"  r#  r$  rC   )getattrr?   rm   r:  r   r   r8  r   r4  Zmask_time_lengthZmask_time_min_masksra   Ztensorr   r   r9  Zmask_feature_lengthZmask_feature_min_masksexpand)r   r#   rI  rf   r   r   r5   Zmask_feature_indicesr    r    r!   _mask_hidden_states  s4    zWavLMModel._mask_hidden_states)r   rf   rI  rh   r   r   ri   c           
      C   s   |dur|n| j j}|dur |n| j j}|dur4|n| j j}| |}|dd}|durp| j|jd |dd}| |\}}| j	|||d}| j
|||||d}	|	d }| jdur| |}|s||f|	dd  S t|||	j|	jd	S )
a/  
        mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
            masked extracted features in *config.proj_codevector_dim* space.
        Nr   r   Fr  )rI  rf   rf   rh   r   r   r   )r   extract_featuresr#   r   )r?   rh   r   use_return_dictr6  rA   r  rr   r7  rL  r;  r<  WavLMBaseModelOutputr#   r   )
r   r   rf   rI  rh   r   r   rN  r#   Zencoder_outputsr    r    r!   r$     s@    


zWavLMModel.forward)NN)NNNNN)r&   r'   r(   r   r   rF  rD  ra   r   r   r   rL  r   r   r   r   r   rP  r$   r)   r    r    r   r!   r5    s2   
  .     
r5  r   zm
    WavLM Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
    )Zcustom_introc                       s   e Zd Zdee d fddZdd Zdd Zd	d
 Zdd Z	e
deej eej ee ee ee eej eeef dddZ  ZS )WavLMForCTCN)target_langc                    s~   t  | t|| _t|j| _|| _|j	du rFt
d| j dt|dr\|jr\|jn|j}t||j	| _|   dS )a/  
        target_lang (`str`, *optional*):
            Language id of adapter weights. Adapter weights are stored in the format adapter.<lang>.safetensors or
            adapter.<lang>.bin. Only relevant when using an instance of [`WavLMForCTC`] with adapters. Uses 'eng' by
            default.
        NzYou are trying to instantiate z with a configuration that does not define the vocabulary size of the language model head. Please instantiate the model as follows: `WavLMForCTC.from_pretrained(..., vocab_size=vocab_size)`. or define `vocab_size` of your model's configuration.r   )r   r   r5  r   r3   rL   Zfinal_dropoutrM   rR  
vocab_sizer[   r   r8   r   r  r5   rJ   lm_headr=  )r   r?   rR  r  r   r    r!   r   ^  s    

zWavLMForCTC.__init__c                 C   sr   | j }|dur2t| jdddu r2td| dn<|du rXt| jdddurXtd n|durn| j|dd dS )a'  
        This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
        passing `target_lang=...` to `from_pretrained(...)`.

        This method is **not** supposed to be called by the user and is prone to be changed in the future.
        NZadapter_attn_dimzCannot pass `target_lang`: z- if `config.adapter_attn_dim` is not defined.z)By default `target_lang` is set to 'eng'.T)Z
force_load)rR  rJ  r?   r[   loggerinfoZload_adapter)r   rR  r    r    r!   tie_weights{  s    zWavLMForCTC.tie_weightsc                 C   s   t dt |   dS rH  r?  Nr@  rE  r    r    r!   rF    s
    z$WavLMForCTC.freeze_feature_extractorc                 C   s   | j j  dS rG  r   r6  r  rE  r    r    r!   rD    s    z"WavLMForCTC.freeze_feature_encoderc                 C   s   | j  D ]
}d|_q
dS z
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        FNr   r  r  r  r    r    r!   freeze_base_model  s    zWavLMForCTC.freeze_base_modelr   rf   rh   r   r   labelsri   c              
   C   s  |dur|n| j j}|dur>| | j jkr>td| j j | j|||||d}|d }| |}| |}	d}
|dur@|dur|ntj	|tj
d}| |dtj
}|dk}|d}||}tjj|	dtjddd}tjjjd	d
6 tjj||||| j j| j j| j jd}
W d   n1 s60    Y  |sp|	f|td  }|
durl|
f| S |S t|
|	|j|jdS )a  
        labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
            Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
            the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
            All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
            config.vocab_size - 1]`.
        Nz$Label values must be <= vocab_size: rM  r   r   rC   )r1   r   r   F)Zenabled)blankZ	reductionZzero_infinitylosslogitsr#   r   )r?   rO  r&  rS  r[   r   rM   rT  ra   Z	ones_liker   r   rt   r   Zmasked_selectr3   r   Zlog_softmaxZfloat32rA   backendsZcudnnflagsZctc_lossZpad_token_idZctc_loss_reductionZctc_zero_infinity_HIDDEN_STATES_START_POSITIONr   r#   r   )r   r   rf   rh   r   r   r^  r   r#   rb  ra  r   Zlabels_maskZtarget_lengthsZflattened_targetsZ	log_probsoutputr    r    r!   r$     sL    




&
zWavLMForCTC.forward)N)NNNNN)r&   r'   r(   r   r   r   rW  rF  rD  r\  r   ra   r   r   r   r   r   r$   r)   r    r    r   r!   rQ  X  s(        
rQ  z
    WavLM Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
    SUPERB Keyword Spotting.
    c                       sz   e Zd Z fddZdd Zdd Zdd Zedee	j
 ee	j
 ee ee ee ee	j
 eeef d
ddZ  ZS )WavLMForSequenceClassificationc                    s   t  | t|dr$|jr$tdt|| _|jd }|jrTt	
t|| | _t	|j|j| _t	|j|j| _|   d S )Nr   z\Sequence classification does not support the use of WavLM adapters (config.add_adapter=True)r   )r   r   r8   r   r[   r5  r   r   use_weighted_layer_sumr3   r`   ra   rb   layer_weightsrJ   r5   Zclassifier_proj_size	projector
num_labels
classifierr=  r   r?   
num_layersr   r    r!   r     s    

z'WavLMForSequenceClassification.__init__c                 C   s   t dt |   dS r>  r@  rE  r    r    r!   rF    s
    z7WavLMForSequenceClassification.freeze_feature_extractorc                 C   s   | j j  dS rG  rY  rE  r    r    r!   rD    s    z5WavLMForSequenceClassification.freeze_feature_encoderc                 C   s   | j  D ]
}d|_q
dS rZ  r[  r  r    r    r!   r\    s    z0WavLMForSequenceClassification.freeze_base_modelNr]  c                 C   s  |dur|n| j j}| j jr dn|}| j|||||d}| j jr|t }tj|dd}tjj	| j
dd}	||	ddd jdd}n|d }| |}|du r|jdd}
nV| |jd |}|ddd|jd }d	|| < |jdd|jdddd }
| |
}d}|dur<t }||d| j j|d}|sl|f|td  }|durh|f| S |S t|||j|jd
S )	  
        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
            into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
            (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
            To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
            into a tensor of type `torch.FloatTensor`. See [`WavLMProcessor.__call__`] for details.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        NTrM  r   rk   rC   r   r   rQ   r`  )r?   rO  rh  r   re  ra   stackr3   r   r   ri  rq   rt   rj  r   r  rr   ro   rp   rl  r   rk  r   r#   r   )r   r   rf   rh   r   r   r^  r   r#   norm_weightsZpooled_outputZpadding_maskZexpand_padding_maskrb  ra  loss_fctrf  r    r    r!   r$   &  sH    

 

z&WavLMForSequenceClassification.forward)NNNNN)r&   r'   r(   r   rF  rD  r\  r   r   ra   r   r   r   r   r   r$   r)   r    r    r   r!   rg    s&        
rg  c                       sz   e Zd Z fddZdd Zdd Zdd Zedee	j
 ee	j
 ee	j
 ee ee ee eeef d
ddZ  ZS ) WavLMForAudioFrameClassificationc                    sz   t  | t|dr$|jr$tdt|| _|jd }|jrTt	
t|| | _t	|j|j| _|j| _|   d S )Nr   z_Audio frame classification does not support the use of WavLM adapters (config.add_adapter=True)r   )r   r   r8   r   r[   r5  r   r   rh  r3   r`   ra   rb   ri  rJ   r5   rk  rl  init_weightsrm  r   r    r!   r   n  s    

z)WavLMForAudioFrameClassification.__init__c                 C   s   t dt |   dS rX  r@  rE  r    r    r!   rF  ~  s
    z9WavLMForAudioFrameClassification.freeze_feature_extractorc                 C   s   | j j  dS rG  rY  rE  r    r    r!   rD    s    z7WavLMForAudioFrameClassification.freeze_feature_encoderc                 C   s   | j  D ]
}d|_q
dS rZ  r[  r  r    r    r!   r\    s    z2WavLMForAudioFrameClassification.freeze_base_modelN)r   rf   r^  rh   r   r   ri   c                 C   s   |dur|n| j j}| j jr dn|}| j|||||d}| j jr|t }tj|dd}tjj	| j
dd}	||	ddd jdd}n|d }| |}
d}|durt }||
d| jtj|d| jdd}|s|
f|td  }|S t||
|j|jd	S )
ro  NTrM  r   rk   rC   r   )Zaxisr`  )r?   rO  rh  r   re  ra   rp  r3   r   r   ri  rq   rt   rl  r   rk  r   r   r#   r   )r   r   rf   r^  rh   r   r   r   r#   rq  rb  ra  rr  rf  r    r    r!   r$     s:    
(z(WavLMForAudioFrameClassification.forward)NNNNN)r&   r'   r(   r   rF  rD  r\  r   r   ra   r   r   r   r   r   r$   r)   r    r    r   r!   rs  l  s&        
rs  c                       s&   e Zd Zd fdd	Zdd Z  ZS )AMSoftmaxLoss      >@皙?c                    sB   t    || _|| _|| _tjt||dd| _	t
 | _d S )NT)r  )r   r   scalemarginrk  r3   r`   ra   Zrandnr/   r   ra  )r   Z	input_dimrk  rx  ry  r   r    r!   r     s    
zAMSoftmaxLoss.__init__c           	      C   sx   |  }tjj| jdd}tjj|dd}t||}|| j }tj|| j	}| j
t| || }| ||}|S )Nr   rk   r   )flattenr3   r   	normalizer/   ra   mmry  Zone_hotrk  rx  r   r   ra  )	r   r#   r^  r/   Z	cos_thetapsiZonehotrb  ra  r    r    r!   r$     s    
zAMSoftmaxLoss.forward)rv  rw  r%   r    r    r   r!   ru    s   ru  c                       s2   e Zd Zd fdd	ZejejdddZ  ZS )	TDNNLayerr   c                    sv   t    |dkr |j|d  n|j| | _|j| | _|j| | _|j| | _t	
| j| j | j| _t	 | _d S )Nr   r   )r   r   tdnn_dimr	  r
  tdnn_kernelr+   Ztdnn_dilationdilationr3   rJ   kernelZReLUr>   r  r   r    r!   r     s    
"zTDNNLayer.__init__)r#   ri   c                 C   s   t  rddlm} t  r.t| j|r.td |dd}| jj	| j
| j| jdd}tjj||| jj| jd}|dd}| |}|S )Nr   )	LoraLayerzDetected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. You should exclude TDNNLayer from LoRA's target modules.r   r   )r  )r   Zpeft.tuners.lorar  r   r  rA  rB  rA   r/   rq   r
  r+   r	  r3   r   Zconv1dr   r  r>   )r   r#   r  r/   r    r    r!   r$     s     
zTDNNLayer.forward)r   )r&   r'   r(   r   ra   r   r$   r)   r    r    r   r!   r~    s   
r~  zi
    WavLM Model with an XVector feature extraction head on top for tasks like Speaker Verification.
    c                       s   e Zd Z fddZdd Zdd Zdd Zeej	e
f d	d
dZedeej eej ee ee ee eej eeef dddZ  ZS )WavLMForXVectorc                    s   t    t | _ jd } jr<tt	|| | _
t j jd | _ fddtt jD }t|| _t jd d  j| _t j j| _t j j| _|   d S )Nr   r   c                    s   g | ]}t  |qS r    )r~  r   r   r    r!   r     r   z,WavLMForXVector.__init__.<locals>.<listcomp>rC   r   )r   r   r5  r   r   rh  r3   r`   ra   rb   ri  rJ   r5   r  rj  r   r0  r   tdnnZxvector_output_dimr6  rl  ru  rk  	objectivert  )r   r?   rn  Ztdnn_layersr   r   r!   r     s    

zWavLMForXVector.__init__c                 C   s   t dt |   dS rX  r@  rE  r    r    r!   rF  &  s
    z(WavLMForXVector.freeze_feature_extractorc                 C   s   | j j  dS rG  rY  rE  r    r    r!   rD  2  s    z&WavLMForXVector.freeze_feature_encoderc                 C   s   | j  D ]
}d|_q
dS rZ  r[  r  r    r    r!   r\  9  s    z!WavLMForXVector.freeze_base_model)r   c                 C   s&   dd }| j jD ]}|||d}q|S )z?
        Computes the output length of the TDNN layers
        c                 S   s   | | | d S )Nr   r    r   r    r    r!   r   F  s    zBWavLMForXVector._get_tdnn_output_lengths.<locals>._conv_out_lengthr   )r?   r  )r   r   r   r+   r    r    r!   _get_tdnn_output_lengthsA  s    z(WavLMForXVector._get_tdnn_output_lengthsNr]  c                 C   s  |dur|n| j j}| j jr dn|}| j|||||d}| j jr|t }tj|dd}tjj	| j
dd}	||	ddd jdd}n|d }| |}| jD ]}
|
|}q|du r|jdd}|jdd}n| |jdd}| |}g }g }t|D ]D\}}|||d|f jdd |||d|f jdd qt|}t|}tj||gdd}| |}| |}d}|dur| ||}|s||f|td  }|dur|f| S |S t||||j|jdS )	ro  NTrM  r   rk   rC   r   )ra  rb  Z
embeddingsr#   r   )r?   rO  rh  r   re  ra   rp  r3   r   r   ri  rq   rt   rj  r  r   r   r   r  r   r1  r   r6  rl  r  r   r#   r   )r   r   rf   rh   r   r   r^  r   r#   rq  Z
tdnn_layerZmean_featuresZstd_featuresZfeat_extract_output_lengthsZtdnn_output_lengthsr   lengthZstatistic_poolingZoutput_embeddingsrb  ra  rf  r    r    r!   r$   P  s\    



 




zWavLMForXVector.forward)NNNNN)r&   r'   r(   r   rF  rD  r\  r   ra   r   r   r  r   r   r   r   r   r   r$   r)   r    r    r   r!   r    s(        
r  )rs  rQ  rg  r  r5  r   )Nr   )Jr   rA  typingr   r   numpyr   ra   Ztorch.nnr3   Ztorch.nn.functionalr   r}   r   Zactivationsr   Zintegrations.deepspeedr   Zintegrations.fsdpr   Zmodeling_layersr	   Zmodeling_outputsr
   r   r   r   r   r   Zmodeling_utilsr   r7   r   r   r   Zconfiguration_wavlmr   Z
get_loggerr&   rU  Moduler   r*   rB   rP   r   r   r   r   r   r   r   r  r  r  r  r  r  r   r   r   r   Zndarrayr4  rP  r5  re  rQ  rg  rs  ru  r~  r  __all__r    r    r    r!   <module>   s    
- ')%JKFV&#  
w  si  