a
    hN                     @   s   d Z ddlZddlmZ ddlmZmZ ddlmZm	Z	 ddl
mZ ddlmZmZmZmZ dd	lmZmZ e rzddlZe rddlZG d
d deddZG dd deddZG dd deZdgZdS )zProcessor class for Dia    N)Path)OptionalUnion   )
AudioInputmake_list_of_audio)BatchFeature)AudioKwargsProcessingKwargsProcessorMixinUnpack)is_soundfile_availableis_torch_availablec                   @   s:   e Zd ZU eed< eed< eed< ee ed< eed< dS )DiaAudioKwargsbos_token_ideos_token_idpad_token_iddelay_pattern
generationN)__name__
__module____qualname__int__annotations__listbool r   r   b/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/dia/processing_dia.pyr   "   s
   
r   F)totalc                   @   s@   e Zd ZU eed< dddddddg d	dd
dddidZdS )DiaProcessorKwargsaudio_kwargsTrightF)paddingZpadding_sideZadd_special_tokensi   i  i  )	r      	   
                  iD  )r   r   r   r   r   sampling_ratereturn_tensorspt)text_kwargsr    common_kwargsN)r   r   r   r   r   	_defaultsr   r   r   r   r   *   s   
r   c                	       s   e Zd ZdZdZdZdZ fddZdee	e
e	 f ee ee ee d	d
dZd dee ee e
d dddZd!dee ee ddddZdee edddZeee	ee
ee	ef  f ee dddZed"eeee
e eed dddZedeeed ddddZ  ZS )#DiaProcessora  
    Constructs a Dia processor which wraps a [`DiaFeatureExtractor`], [`DiaTokenizer`], and a [`DacModel`] into
    a single processor. It inherits, the audio feature extraction, tokenizer, and audio encode/decode functio-
    nalities. See [`~DiaProcessor.__call__`], [`~DiaProcessor.encode`], and [`~DiaProcessor.decode`] for more
    information.

    Args:
        feature_extractor (`DiaFeatureExtractor`):
            An instance of [`DiaFeatureExtractor`]. The feature extractor is a required input.
        tokenizer (`DiaTokenizer`):
            An instance of [`DiaTokenizer`]. The tokenizer is a required input.
        audio_tokenizer (`DacModel`):
            An instance of [`DacModel`] used to encode/decode audio into/from codebooks. It is is a required input.
    ZDiaFeatureExtractorZDiaTokenizerZDacModelc                    s   t  j|||d d S )N)audio_tokenizer)super__init__)selffeature_extractor	tokenizerr2   	__class__r   r   r4   R   s    zDiaProcessor.__init__NF)textaudiooutput_labelskwargsc           '   	   K   s  t  std|du rtd| jtfi |}|d }|d }|d }|dd}	|	dkrnt| jj d	i }
t|tr|g}n(t|t	t
frtd
d |D std| j|fi |}|
| |dd}|dd}|dd}|dd}|dd}|du s,|du s,|du s,|du r4td|rV|rVtd| d| d|
d jd }t|}t|}|dur2t|}| j|fi |}t| jjj}|d d jd | }g }g }t|d |d D ]8\}}| jj}t|jdd| | }|| }|| }t B |ddd|f | jj}| j |j!"dd}W d   n1 sf0    Y  |stj#j$j%|d d!|d"}tj#j$j%|dd|d dddfd!|d"}|d | }||rdnd7 }tj&dg| dg|  tj'd#dddf } |(| |(|  qtj)|dd}tj)|dd}n@|rjtj*|d|f|tj'd#}tj+|d| ftj'd$}ntd%||jd krtd&| d'|jd  d(|jd }!|!| }"| j,||!||d)d*}#tj*||!|f|tj-d+}$||$ddd|"f< | j.|$|||#d,}%|
|%|d- |r|
d. / ddddf }&d/|&|&|k< d/|&|&|k< |&"dd0|| d1 ' |
d0< |
d. ddddf |
d.< |
d1 ddddf |
d1< t2|
|	d2S )3a  
        Main method to prepare text(s) and audio to be fed as input to the model. The `audio` argument is
        forwarded to the DiaFeatureExtractor's [`~DiaFeatureExtractor.__call__`] and subsequently to the
        DacModel's [`~DacModel.encode`]. The `text` argument to [`~DiaTokenizer.__call__`]. Please refer
        to the docstring of the above methods for more information.
        zThe `DiaProcessor` relies on the `audio_tokenizer` which requires `torch` but we couldn't find it in your environment. You can install torch via `pip install torch`.Nz0You need to specify the `text` input to process.r.   r    r/   r,   r-   z% only supports `return_tensors='pt'`.c                 s   s   | ]}t |tV  qd S N)
isinstancestr).0tr   r   r   	<genexpr>}       z(DiaProcessor.__call__.<locals>.<genexpr>zAInvalid input text. Please provide a string, or a list of stringsr   r   r   r   r   TzTo enable processing for Dia, we need the `bos_token_id`, `eos_token_id`, `pad_token_id`, and `delay_pattern`. You may have accidentally overwritten one of those.z9Labels with `generation` is incompatible, got generation=z, output_labels=.	input_idsr   padding_maskZinput_valuesdim.      )r   r   r   rK   r   r   Zconstant)padmodevaluedtype)sizerQ   z;If you try to train, you should provide audio data as well.zNNeed the same amount of samples for both text and audio, but got text samples=z and audio samples = z	 instead.Fbszseq_lennum_channelsr   revert)Z
fill_valuerQ   r;   r   r   precomputed_idx)decoder_input_idsdecoder_attention_maskrZ   ilabelsr[   )dataZtensor_type)3r   
ValueError_merge_kwargsr   popr9   r   r?   r@   r   tupleallr7   updateshapelenmaxr   r6   mathprodr2   configZdownsampling_ratioszipZ
hop_lengthceilsumtorchno_gradtodeviceencodeaudio_codes	transposennZ
functionalrM   tensorlongappendcatfullZonesbuild_indicesr   apply_audio_delayclonereshape
contiguousr   )'r5   r:   r;   r<   r=   output_kwargsr.   r    r/   r,   r]   	encodingsr   audio_bos_token_idZaudio_eos_token_idaudio_pad_token_idr   Z
batch_sizerV   Z	max_delayZinput_audiosZcompression_rateZmax_encoded_sequence_lenrZ   r[   rG   Zbase_pad_lenZcurrent_audio_lenZencoded_sequence_lenZpadding_lenrF   Znum_valid_inputsZattention_maskZmax_seq_lenZmax_audio_lenrY   ZprefillZdelayed_decoder_input_idsr\   r   r   r   __call__U   s    
 


6,


$zDiaProcessor.__call__torch.Tensor)rZ   audio_prompt_lenr=   returnc                 K   s  | j tfi |}|d }|dd}|dd}|dd}|du sV|du sV|du r^td|durtj||jtjd}|d |j	d }	n"|dddddf |kj
d	d
}	|j	d |dddddf |kj
d	d
 d }
|j	\}}}| j||||dd}| j|d	d	|ddd}g }t x t|	j	d D ]X}||dd|	| |
| f d }|| jj}| jj|dj  }|| q8W d   n1 s0    Y  |S )a  
        Decodes a batch of audio codebook sequences into their respective audio waveforms via the
        `audio_tokenizer`. See [`~DacModel.decode`] for more information.

        Args:
            decoder_input_ids (`torch.Tensor`): The complete output sequence of the decoder.
            audio_prompt_len (`int`): The audio prefix length (e.g. when using voice cloning).
        r    r   Nr   r   zTo enable decoding for Dia, we need the `bos_token_id`, `pad_token_id`, and `delay_pattern`. You may have accidentally overwritten one of those.)rp   rQ   r   rH   rI   rK   TrS   rX   rL   )N.)rr   )r_   r   r`   r^   rm   ru   rp   rv   expandrd   rl   rz   r{   rs   rn   rangero   r2   decodeZaudio_valuescpuZsqueezerw   )r5   rZ   r   r=   r   r    r   r   r   Zstart_of_generation_idxZend_of_generation_idxrT   rU   rV   rY   Zoutput_sequencesZaudiosiZoutput_iZaudio_ir   r   r   batch_decode  sV    ".

".zDiaProcessor.batch_decodec                 K   s<   |j d dkr$td|j d  d| j||fi |d S )z
        Decodes a single sequence of audio codebooks into the respective audio waveform via the
        `audio_tokenizer`. See [`~DacModel.decode`] and [`~DiaProcessor.batch_decode`] for more information.
        r   rK   z5Expecting a single output to be decoded but received z samples instead.)rd   r^   r   )r5   rZ   r   r=   r   r   r   r   I  s
    
zDiaProcessor.decode)r[   r=   r   c                 K   sH   | j tfi |}|d }|dd}|du r6td|jd t| S )z0Utility function to get the audio prompt length.r    r   NzTo enable the utility of retrieving the prompt length for Dia, we need the `delay_pattern`. You may have accidentally overwritten this.rK   )r_   r   r`   r^   rd   rf   )r5   r[   r=   r   r    r   r   r   r   get_audio_prompt_lenZ  s    z!DiaProcessor.get_audio_prompt_len)r;   saving_pathr=   c           	      K   s   t  stdt|}t|ttfr,|g}n(t|ttfrLtdd |D sTt	dt
|t
|krlt	d| jtfi |}|d }|d }t||D ]2\}}t|tjr|   }t||| qd S )Nz/Please install `soundfile` to save audio files.c                 s   s   | ]}t |ttfV  qd S r>   )r?   r@   r   )rA   pr   r   r   rC     rD   z*DiaProcessor.save_audio.<locals>.<genexpr>zAInvalid input path. Please provide a string, or a list of stringsz5The number of audio and saving paths must be the samer    r+   )r   ImportErrorr   r?   r@   r   r   ra   rb   r^   re   r_   r   rj   rm   ZTensorr   floatnumpysfwrite)	r5   r;   r   r=   r   r    r+   Zaudio_valuer   r   r   r   
save_audioo  s(     zDiaProcessor.save_audio)r   r   )rT   rU   rV   r   rW   r   c                 C   s   t j|t jd}t j|t jddddf | |d }|sT||ddddf  }n||ddddf  }t |d|d }t j| t jdddddf | ||}t j|t jdddddf | ||}	t j|d|d|	dgdd }
||
fS )a  
        Precompute (sequence_idx, all_idx) so that out[seq, channel] = in[seq - delay[channel], channel]
        or in[seq, channel] = out[seq + delay[channel], channel] if `revert`.
        Negative sequence_idx => BOS; sequence_idx >= seq_len => PAD.
        rP   N).Nr   rK   rH   rI   )	rm   ru   Zint32Zaranger   clampstackr}   rv   )rT   rU   rV   r   rW   Zdelay_arraysequence_idxvalid_sequence_idx	batch_idxchannel_idxall_idxr   r   r   rz     s    (((
zDiaProcessor.build_indices)r;   r   r   rY   r   c              	   C   s   | j }|\}}||}||}tj|dd\}}}	| |||	f |  }
|dk }|| jd k}t||t|||
}|S )a  
        Applies or reverts the delay pattern to batched audio tokens using precomputed indices,
        inserting BOS where sequence_idx < 0 and PAD where sequence_idx >= seq_len.

        Args:
            audio: audio tokens of shape [bsz, seq_len, num_channels]
            pad_token_id: the PAD token
            bos_token_id: the BOS token
            precomputed_idx: from `build_indices`

        Returns:
            final_audio: delayed or reverted audio tokens of shape [bsz, seq_len, num_channels]
        rH   rI   r   rK   )rp   ro   rm   ZunbindviewrR   rd   where)r;   r   r   rY   rp   r   r   r   r   r   Zgathered_audioZmask_bosZmask_padZfinal_audior   r   r   r{     s    

zDiaProcessor.apply_audio_delay)NF)N)N)F)r   r   r   __doc__Zfeature_extractor_classZtokenizer_classZaudio_tokenizer_classr4   r   r@   r   r   r   r   r   r   r   r   r   r   r   r   r   staticmethodra   rz   r{   __classcell__r   r   r8   r   r1   >   sj      1 J " "r1   )r   rg   pathlibr   typingr   r   Zaudio_utilsr   r   Zfeature_extraction_utilsr   Zprocessing_utilsr	   r
   r   r   utilsr   r   rm   Z	soundfiler   r   r   r1   __all__r   r   r   r   <module>   s$      