a
    h
Z                     @   s  d Z ddlZddlmZ ddlmZmZ ddlZddlm	Z	 ddl
m	  mZ ddlmZ ddlmZmZ dd	lmZ d
dlmZ eG dd deZeG dd deZeG dd deZG dd de	jZG dd de	jZG dd de	jZG dd de	jZG dd de	jZG dd de	jZG dd de	jZ G d d! d!e	jZ!eG d"d# d#eZ"ed$d%G d&d' d'e"Z#d'd#gZ$dS )(zTransformers Xcodec model.    N)	dataclass)OptionalUnion   )PreTrainedAudioTokenizerBase)ModelOutputauto_docstring   )	AutoModel   )XcodecConfigc                   @   s6   e Zd ZU dZdZeej ed< dZ	eej
 ed< dS )XcodecOutputao  
    Args:
        audio_codes (`torch.LongTensor`  of shape `(batch_size, num_quantizers, codes_length)`, *optional*):
            Discrete code indices computed using `model.encode`.
        audio_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`, *optional*)
            Decoded audio values obtained using the decoder part of Xcodec.
    Naudio_codesaudio_values)__name__
__module____qualname____doc__r   r   torch
LongTensor__annotations__r   FloatTensor r   r   f/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/xcodec/modeling_xcodec.pyr      s   
r   c                   @   s$   e Zd ZU dZdZeej ed< dS )XcodecEncoderOutputz
    Args:
        audio_codes (`torch.LongTensor`  of shape `(batch_size, num_quantizers, codes_length)`, *optional*):
            Discrete code indices computed using `model.encode`.
    Nr   )	r   r   r   r   r   r   r   r   r   r   r   r   r   r   -   s   
r   c                   @   s$   e Zd ZU dZdZeej ed< dS )XcodecDecoderOutputz
    Args:
        audio_values (`torch.FloatTensor`  of shape `(batch_size, channels, num_samples)`, *optional*):
            Decoded audio values obtained using the decoder part of Xcodec.
    Nr   )	r   r   r   r   r   r   r   r   r   r   r   r   r   r   8   s   
r   c                       s@   e Zd ZdZeeeed fddZejejdddZ	  Z
S )ResidualUnitzFResidual block for SemanticEncoder and SemanticDecoder used in Xcodec.)configin_channelsout_channelsdilationc              
      s\   t    t | _|jd d | }tj|||jd||ddd| _tj||ddd| _d S )Nr   r	   F)stridepaddingr    groupsbias)r   r   kernel_sizer$   )	super__init__nnZELU
activationZunit_kernel_sizeConv1dconv1conv2)selfr   r   r   r    r"   	__class__r   r   r'   F   s    


zResidualUnit.__init__hidden_statereturnc                 C   s0   |  |}| |}|  |}| |}|| S N)r)   r+   r,   )r-   r1   Zoutput_tensorr   r   r   forwardV   s
    



zResidualUnit.forward)r   r   r   r   r   intr'   r   Tensorr4   __classcell__r   r   r.   r   r   C   s   r   c                       s<   e Zd Zeeeed fddZejejdddZ  Z	S )SemanticEncoderBlockr   r   r   r!   c                    sd   t    t fdd jD | _|dkr4dnd| }|d d }tj||||dd| _d S )Nc                    s   g | ]}t  |qS r   r   .0r    r   r   r   r   
<listcomp>b       z1SemanticEncoderBlock.__init__.<locals>.<listcomp>r   r   r	   Tr%   r!   r"   r$   )r&   r'   r(   
ModuleListblock_dilations	res_unitsr*   conv)r-   r   r   r   r!   Zkernelr"   r.   r=   r   r'   _   s    
zSemanticEncoderBlock.__init__r0   c                 C   s"   | j D ]}||}q| |}|S r3   )rC   rD   r-   r1   unitr   r   r   r4   j   s    


zSemanticEncoderBlock.forward
r   r   r   r   r5   r'   r   r6   r4   r7   r   r   r.   r   r8   ^   s   r8   c                       s0   e Zd Z fddZejejdddZ  ZS )SemanticEncoderc                    s   t    t|jt|jkr&tdtj|j|j|j	d|j	d dd| _
|j}g }t|jD ]4\}}t|j|j|  }|t||||g7 }|}q^t|| _d S )Nz:Number of strides must match the number of channel_ratios.r   r	   Fr$   )r&   r'   lenstrideschannel_ratios
ValueErrorr(   r*   semantic_hidden_sizer%   rD   	enumerater5   r8   rA   conv_blocks)r-   r   r   rP   ir!   r   r.   r   r   r'   r   s$    
	zSemanticEncoder.__init__r0   c                 C   s"   |  |}| jD ]}||}q|S r3   )rD   rP   r-   r1   blockr   r   r   r4      s    


zSemanticEncoder.forwardr   r   r   r'   r   r6   r4   r7   r   r   r.   r   rH   q   s   rH   c                       s<   e Zd Zeeeed fddZejejdddZ  Z	S )SemanticDecoderBlockr9   c              	      s   t    |dkr,tj|ddddd| _nBd| }|d d }|d dkrPdnd}tj|||||dd| _t fd	d
 jD | _d S )Nr   r   Tr@   r	   r   FrI   c                    s   g | ]}t  |qS r   r:   r;   r   r   r   r   r>      r?   z1SemanticDecoderBlock.__init__.<locals>.<listcomp>)	r&   r'   r(   r*   rD   ConvTranspose1drA   rB   rC   )r-   r   r   r   r!   r%   r"   output_paddingr.   rV   r   r'      s&    

	zSemanticDecoderBlock.__init__r0   c                 C   s"   |  |}| jD ]}||}q|S r3   )rD   rC   rE   r   r   r   r4      s    


zSemanticDecoderBlock.forwardrG   r   r   r.   r   rU      s   rU   c                       s0   e Zd Z fddZejejdddZ  ZS )SemanticDecoderc                    s   t    tj|jt|j|jd  |jd|jd dd| _g }t	|j
D ]b\}}t|j|j|  }|t|jd k rt|j|j|d   }n|j}|t||||g7 }qJt|| _tj|j|j|jd|jd dd| _d S )Nr   r   r	   F)r   r   r%   r!   r"   r$   )r!   r"   r$   )r&   r'   r(   r*   rN   r5   rL   r%   r+   rO   rK   rJ   rU   rA   rP   r,   )r-   r   rP   rQ   r!   r   r   r.   r   r   r'      s2    
zSemanticDecoder.__init__r0   c                 C   s,   |  |}| jD ]}||}q| |}|S r3   )r+   rP   r,   rR   r   r   r   r4      s
    



zSemanticDecoder.forwardrT   r   r   r.   r   rY      s   rY   c                       s8   e Zd ZdZ fddZdd Zdd Zdd	 Z  ZS )
XcodecEuclideanCodebookz!Codebook with Euclidean distance.c                    sj   t    t|j|j}|j| _| dtdg | dt|j | d| | d|  d S )NinitedTZcluster_sizeembedZ	embed_avg)	r&   r'   r   Zzeroscodebook_sizeZcodebook_dimZregister_bufferr6   clone)r-   r   r\   r.   r   r   r'      s    
z XcodecEuclideanCodebook.__init__c                 C   sV   | j  }|djddd}|d| |  |djddd  }|jddj}|S )Nr	   r   T)Zkeepdimr   dim)r\   tpowsummaxindices)r-   hidden_statesr\   Zscaled_statesdist	embed_indr   r   r   quantize   s
    
&z XcodecEuclideanCodebook.quantizec                 C   s8   |j }|d|d f}| |}|j|d d  }|S )Nr_   )shapeZreshaperj   view)r-   rg   rk   ri   r   r   r   encode   s
    
zXcodecEuclideanCodebook.encodec                 C   s   t || j}|S r3   )FZ	embeddingr\   )r-   ri   	quantizedr   r   r   decode   s    zXcodecEuclideanCodebook.decode)	r   r   r   r   r'   rj   rm   rp   r7   r   r   r.   r   rZ      s
   
rZ   c                       s6   e Zd ZdZed fddZdd Zdd Z  ZS )	XcodecVectorQuantizationzY
    Vector quantization implementation. Currently supports only euclidean distance.
    r   c                    s   t    t|| _d S r3   )r&   r'   rZ   codebookr-   r   r.   r   r   r'      s    
z!XcodecVectorQuantization.__init__c                 C   s   | ddd}| j|}|S Nr   r	   r   )permuters   rm   )r-   rg   Zembed_inr   r   r   rm      s    zXcodecVectorQuantization.encodec                 C   s   | j |}|ddd}|S ru   )rs   rp   rv   )r-   ri   rj   r   r   r   rp     s    zXcodecVectorQuantization.decode)	r   r   r   r   r   r'   rm   rp   r7   r   r   r.   r   rq      s   rq   c                       sh   e Zd ZdZed fddZdd Zdedd	d
Zde	j
e	j
dddZe	j
e	j
dddZ  ZS ) XcodecResidualVectorQuantizationzv
    Residual vector quantization implementation. Follows Algorithm 1 in https://huggingface.co/papers/2107.03312
    rr   c                    sF   t    t fddt jD | _ j| _ j| _ j| _d S )Nc                    s   g | ]}t  qS r   )rq   )r<   _rr   r   r   r>     r?   z=XcodecResidualVectorQuantization.__init__.<locals>.<listcomp>)	r&   r'   r(   rA   rangenum_quantizers
quantizers
frame_rater]   rt   r.   rr   r   r'     s
    
 z)XcodecResidualVectorQuantization.__init__c                 C   s   t | j| j d S )zReturn bandwidth per quantizer.i  )mathlog2r]   r|   )r-   r   r   r   get_bandwidth_per_quantizer  s    z<XcodecResidualVectorQuantization.get_bandwidth_per_quantizerN)r2   c                 C   s:   |   }| j}|dur6|dkr6ttdt|| }|S )z:Return num_quantizers based on specified target bandwidth.N        r   )r   rz   r5   re   r}   floor)r-   	bandwidthZbw_per_qrz   r   r   r    get_num_quantizers_for_bandwidth  s
    zAXcodecResidualVectorQuantization.get_num_quantizers_for_bandwidth)
embeddingsr2   c           
      C   sZ   |  |}|}g }| jd| D ]*}||}||}|| }|| q t|}	|	S )a  
        Encode the input tensor into discrete indices using RVQ, with the number of quantizers selected based on the given bandwidth.
        Each quantizer /codebook residually quantizes the input and returns the nearest indices in terms of Euclidian distance.
        N)r   r{   rm   rp   appendr   stack)
r-   r   r   rz   ZresidualZall_indices	quantizerrf   ro   Zout_indicesr   r   r   rm   $  s    



z'XcodecResidualVectorQuantization.encode)codesr2   c                 C   sB   t jd|jd}t|D ]$\}}| j| }||}|| }q|S )z9Decode the given codes to their quantized representation.r   )device)r   Ztensorr   rO   r{   rp   )r-   r   Zquantized_outrQ   rf   r   ro   r   r   r   rp   4  s    


z'XcodecResidualVectorQuantization.decode)N)N)r   r   r   r   r   r'   r   r5   r   r   r6   rm   rp   r7   r   r   r.   r   rw     s   rw   c                   @   s4   e Zd ZdZeZdZdZdd Zdd Z	dd	 Z
d
S )XcodecPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    Zxcodecinput_valuesc                 C   s   t |tjr:|jjjd| jjd |jdur|jj	  nt |tj
tjfrh|jj	  |jjd nVt |tjrtj|j |jdurt|j|j|jd   }tjj|j| |d dS )zInitialize the weightsr   )meanZstdNg      ?r   )ab)
isinstancer(   LinearweightdataZnormal_r   Zinitializer_ranger$   Zzero_Z	LayerNormZ	GroupNormZfill_r*   initZkaiming_normal_r}   sqrtr#   r   r%   Zuniform_)r-   modulekr   r   r   _init_weightsI  s    

z#XcodecPreTrainedModel._init_weightsc                 C   s   t jjj}tt jjjdr&t jjjj}|| jj || jj | jj	D ]8}||j |j
|j|jfD ]}||j ||j qdqF|| jjdd || jjdd | jj	D ]D}||jdd |j
|j|jfD ] }||jdd ||jdd qqdS )znApply weight norm in the acoustic encoder and decoder because the original checkpoint has weight norm applied.weight_normr   nameN)r   r(   utilsr   hasattrparametrizationsacoustic_encoderr+   r,   rS   Z	res_unit1Z	res_unit2Z	res_unit3acoustic_decoderZconv_t1)r-   r   rS   Zres_unitr   r   r   apply_weight_normY  s"    


z'XcodecPreTrainedModel.apply_weight_normc                 C   s|   | j | jfD ]j}| D ]\}ztjjj|dd W n ttfyH   Y n0 t	|drd|j
v rtjjjj|ddd qqdS )z=Remove the weight norm from the acoustic encoder and decoder.r   r   r   T)Zleave_parametrizedN)r   r   modulesr   r(   r   remove_weight_normrM   AttributeErrorr   r   ZparametrizeZremove_parametrizations)r-   r   mr   r   r   r   q  s    z(XcodecPreTrainedModel.remove_weight_normN)r   r   r   r   r   Zconfig_classZbase_model_prefixZmain_input_namer   r   r   r   r   r   r   r   >  s   r   z$The Xcodec neural audio codec model.)Zcustom_introc                
       s   e Zd Z fddZeejdddZej	ej	dddZ
edejee ee eejef d
ddZedejee eejef dddZedejeej ee ee eeejejf ef dddZ  ZS )XcodecModelc                    s   t  | || _|jd | _t|j}|j| _	|j
| _| | j t|| _t|| _t|j | _t|j|j| _t|j|jj| _t|j|jj| _t|| _d S )Nr	   )r&   r'   r   Z
hop_lengthpadr
   from_configZacoustic_model_configencoderr   decoderr   _adjust_dac_decoderrH   encoder_semanticrY   Zdecoder_semanticZsemantic_model_configevalsemantic_modelr(   r   Zhidden_sizefcZfc1fc2rw   r   )r-   r   Zacoustic_modelr.   r   r   r'     s    

zXcodecModel.__init__)r   c                 C   sh   |   D ]8}t|tjrt|jtr.|jd n|j}|d f|_qt| drdt| jtj	rdt
 | _dS )z
        DAC implemented in Xcodec is slightly different from the HF version.
        DAC in Xcodec adjusts the output padding in every ConvTranspose1d in the decoder and removes
        the final `nn.Tanh` activation function.
        r   r	   tanhN)r   r   r(   rW   r!   tuplerX   r   r   ZTanhZIdentity)r   r   r!   r   r   r   r     s    zXcodecModel._adjust_dac_decoder)r   r2   c                 C   s   |d d dd d f }t || j| jf}t $ | j|dd}|j}W d    n1 s\0    Y  tj|dd}|jddS )Nr   T)Zoutput_hidden_statesr   r`   )rn   r   r   Zno_gradr   rg   r   r   )r-   r   outputsrg   Zstackedr   r   r   _extract_semantic_features  s    
$z&XcodecModel._extract_semantic_featuresN)r   r   return_dictr2   c           
   	   C   s8  |dur|n| j j}|jd }|dkr4td| |du rJ| j jd }n&|| j jvrptd| d| j j d| | }| |dd}| 	|}|jd |jd kr| 	t
|ddd	ddf | j| jfd}tj||gdd
}| |dddd}| j||}	|	d	d}	|s0|	S t|	S )ac  
        input_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`):
            Float values of the input audio waveform.
        bandwidth (`float`, *optional*):
            The target bandwidth in (kbps) supports only values in `config.target_bandwidths`.
            Defaults to the highest available bandwidth `4.0` kbps.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`].

        Returns:
            `torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)` containing the discrete encoded audio codes.
        Nr   zAudio must be mono, but got r_   z)This model doesn't support the bandwidth z. Select one of .r	   r   r`   )r   r   rk   rM   Ztarget_bandwidthsr   detachr   	transposer   rn   r   Z	unsqueezer   catr   r   rm   r   )
r-   r   r   r   ZchannelsZe_semantic_inputZ
e_semanticZ
e_acousticr   r   r   r   r   rm     s,    

2zXcodecModel.encode)r   r   r2   c                 C   s`   |dur|n| j j}|dd}| j|}| |dddd}| |}|sX|S t|S )a  
        audio_codes (`torch.LongTensor`  of shape `(batch_size, num_quantizers, codes_length)`):
            Discrete code indices computed using `model.encode`.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`]

        Returns:
            Decoded audio values of shape `(batch_size, channels, num_samples)` obtained using the decoder part of
            Xcodec.
        Nr   r   r	   )r   r   r   r   rp   r   r   r   )r-   r   r   ro   Zquantized_acousticr   r   r   r   rp     s    
zXcodecModel.decode)r   r   r   r   r2   c                 C   sl   |dur|n| j j}|jd }|du r6| j||dd}| j||dd dd|f }|s`||fS t||dS )a  
        input_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`):
            The raw float values of the input audio waveform.
        audio_codes (`torch.LongTensor`  of shape `(batch_size, num_quantizers, codes_length)`:
            Discrete code indices computed using `model.encode`.
        bandwidth (`float`, *optional*):
            Target bandwidth in kbps. Must be one of `config.target_bandwidths`. Defaults to the highest available bandwidth.
        bandwidth (`float`, *optional*):
            Target bandwidth in kbps. Must be one of `config.target_bandwidths`. Defaults to the highest available bandwidth.
        return_dict (`bool`, *optional*):
            Whether to return a [`XcodecOutput`] instead of a plain tuple.

        Returns:
            `XcodecOutput` or tuple `(audio_codes, audio_values)`:
            - `audio_codes` of shape `(batch_size, num_quantizers, codes_length)`: the quantized discrete codes.
            - `audio_values` of shape `(batch_size, channels, num_samples)`: the reconstructed audio waveform given the codes.

        Example:

        ```python
        >>> from datasets import load_dataset
        >>> from transformers import AutoFeatureExtractor, XcodecModel

        >>> model_id = "hf-audio/xcodec-hubert-librispeech"
        >>> model = XcodecModel.from_pretrained(model_id)
        >>> feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)

        >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> dataset = dataset.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate))
        >>> audio_sample = dataset[0]['audio']['array']

        >>> inputs = feature_extractor(raw_audio=audio_sample, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> audio_codes = outputs.audio_codes
        >>> audio_values = outputs.audio_values
        ```
        Nr_   F)r   r   .)r   r   )r   r   rk   rm   rp   r   )r-   r   r   r   r   lengthr   r   r   r   r4     s    .
zXcodecModel.forward)NN)N)NNN)r   r   r   r'   staticmethodr(   Moduler   r   r   r   r   r6   r   floatboolr   r   rm   r   rp   r   r   r4   r7   r   r   r.   r   r   }  s>   
  1    r   )%r   r}   dataclassesr   typingr   r   r   Ztorch.nnr(   Ztorch.nn.functionalZ
functionalrn   Zmodeling_utilsr   r   r   r   autor
   Zconfiguration_xcodecr   r   r   r   r   r   r8   rH   rU   rY   rZ   rq   rw   r   r   __all__r   r   r   r   <module>   s<   

( 2> 4