   @   s  d dl Z d dlmZmZmZ d dlZd dlmZ ddlmZ ddl	m
Z
 ddlmZ ddlmZ dd	lmZmZmZ dd
lmZmZ ddlmZ ddlmZmZmZmZ ddlmZ ddlm Z m!Z! ddl"m#Z#m$Z$ e%e&Z'd&ej(ej)ej)ej)eej) ee* e*eej) dddZ+G dd dej(Z,G dd deZ-eG dd deZ.eddG dd de.Z/G d d! d!ej(Z0ed"dG d#d$ d$e.eZ1g d%Z2dS )'    N)CallableOptionalUnion)nn   )ACT2FN)Cache)GenerationMixin)GradientCheckpointingLayer)BaseModelOutputBaseModelOutputWithPastCausalLMOutputWithPast)ALL_ATTENTION_FUNCTIONSPreTrainedModel)Unpack)TransformersKwargsauto_docstringcan_return_tuplelogging)check_model_inputs   )	AutoModelAutoModelForCausalLM   )VoxtralConfigVoxtralEncoderConfig        )modulequerykeyvalueattention_maskscalingdropout	head_maskc                 K   s   |d u r| dd }t||dd| }	|d urj|jdkrj|	|d d d d d d d |jd f  }	tjj|	dd}	|d ur|	|	dddd }	tjj
|	|| jd	}	t|	|}
|
dd }
|
|	fS )
N      r   r      )dimr   ptraining)sizetorchmatmul	transposendimshaper   
functionalZsoftmaxviewr#   r,   
contiguous)r   r   r   r    r!   r"   r#   r$   kwargsattn_weightsattn_output r9   h/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/voxtral/modeling_voxtral.pyeager_attention_forward,   s    *r;   c                       s   e Zd ZdZdeeeeeeee ee d fddZ	e
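# Shape reference for `eager_attention_forward` above (illustrative reading aid, derived from the code
# itself rather than from any checkpoint):
#   query:  (batch, num_heads, q_len, head_dim)
#   key:    (batch, num_heads, kv_len, head_dim); key.transpose(2, 3) -> (batch, num_heads, head_dim, kv_len)
#   attn_weights = query @ key^T * scaling        -> (batch, num_heads, q_len, kv_len)
#   attn_output  = attn_weights @ value           -> (batch, num_heads, q_len, head_dim)
# The final transpose(1, 2) hands back (batch, q_len, num_heads, head_dim) for the caller to reshape.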
jeed	d
dZde
jee
j ee
j eee
jee
j eee
j  f dddZ  ZS )VoxtralAttentionz=Multi-headed attention from 'Attention Is All You Need' paperr   FTN)	embed_dim	num_headsr#   
is_decoderbias	is_causal	layer_idxconfigc	           	         s   t    || _|| _|| _|| | _|| _| j| | jkrTtd| j d| d| jd | _|| _	|| _
|d u r|rtd| jj d || _tj||dd| _tj|||d| _tj|||d| _tj|||d| _d S )	Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).r&   zInstantiating a decoder z without passing `layer_idx` is not recommended and will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` when creating this class.Fr@   )super__init__r=   r>   r#   head_dimrC   
ValueErrorr"   r?   rA   loggerZwarning_once	__class____name__rB   r   Lineark_projv_projq_projout_proj)	selfr=   r>   r#   r?   r@   rA   rB   rC   rJ   r9   r:   rF   M   s0    


zVoxtralAttention.__init__)tensorseq_lenbszc                 C   s    | ||| j| jdd S )Nr   r   )r4   r>   rG   r0   r5   )rQ   rS   rT   rU   r9   r9   r:   _shapeu   s    zVoxtralAttention._shapehidden_statesr!   layer_head_maskoutput_attentionsreturnc                 K   s   |  \}}}| | || j ||}	| | |d|}
| | |d|}t}| jjdkrlt	| jj }|| |	|
||f| j
sdn| jd||d|\}}|||d }| |}||fS )z#Input shape: Batch x Time x Channelr%   eagerr         ?)r#   r"   rZ   r$   )r-   rV   rO   r"   rM   rN   r;   rC   Z_attn_implementationr   r,   r#   reshaper5   rP   )rQ   rX   r!   rY   rZ   r6   rU   Ztgt_len_Zquery_statesZ
key_statesZvalue_statesZattention_interfacer8   r7   r9   r9   r:   forwardx   s0    



zVoxtralAttention.forward)r   FTFNN)NNF)rK   
__module____qualname____doc__intfloatboolr   r   rF   r.   TensorrV   tupler`   __classcell__r9   r9   rR   r:   r<   J   s8         (   r<   c                       sB   e Zd Zed fddZdejejejeejdddZ  Z	S )	VoxtralEncoderLayerrC   c                    s   t    |j| _t| j|j|j|d| _t	| j| _
|j| _t|j | _|j| _t| j|j| _t|j| j| _t	| j| _d S )N)r=   r>   r#   rC   )rE   rF   d_modelr=   r<   Zencoder_attention_headsZattention_dropout	self_attnr   	LayerNormself_attn_layer_normr#   r   Zactivation_functionactivation_fnactivation_dropoutrL   Zencoder_ffn_dimfc1fc2final_layer_normrQ   rC   rR   r9   r:   rF      s    
zVoxtralEncoderLayer.__init__FrW   c                 C   s   |}|  |}| j||||d\}}tjj|| j| jd}|| }|}| |}| | |}tjj|| j	| jd}| 
|}tjj|| j| jd}|| }|jtjkrt|jjd }tj|| |d}||fS )a  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        if hidden_states.dtype == torch.float16:
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        return hidden_states, attn_weights


@auto_docstring
class VoxtralPreTrainedModel(PreTrainedModel):
    config: VoxtralConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = None
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_attention_backend = True
    _can_compile_fullgraph = True

    def _init_weights(self, module):
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else self.config.audio_config.initializer_range
        )

        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


@auto_docstring(
    custom_intro="""
    The Voxtral encoder, which is a Whisper encoder.
    """
)
class VoxtralEncoder(VoxtralPreTrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`VoxtralEncoderLayer`].

    Args:
        config: VoxtralEncoderConfig
    """

    config: VoxtralEncoderConfig
    main_input_name = "input_features"
    _no_split_modules = ["VoxtralEncoderLayer"]
    _can_record_outputs = {
        "attentions": VoxtralAttention,
        "hidden_states": VoxtralEncoderLayer,
    }

    def __init__(self, config: VoxtralEncoderConfig):
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        embed_dim = config.d_model
        self.num_mel_bins = config.num_mel_bins
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_source_positions
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)

        self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
        self.embed_positions.requires_grad_(False)

        self.layers = nn.ModuleList([VoxtralEncoderLayer(config) for _ in range(config.encoder_layers)])
        self.layer_norm = nn.LayerNorm(config.d_model)
        self.avg_pooler = nn.AvgPool1d(2, stride=2)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def _freeze_parameters(self):
        for param in self.parameters():
            param.requires_grad = False
        self._requires_grad = False

    def get_input_embeddings(self) -> nn.Module:
        return self.conv1

    def set_input_embeddings(self, value: nn.Module):
        self.conv1 = value

    @check_model_inputs
    def forward(
        self,
        input_features,
        attention_mask=None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        r"""
        Args:
            input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
                Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
                obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
                `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
                `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
                and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
            attention_mask (`torch.Tensor`, *optional*):
                Voxtral does not support masking of the `input_features`; this argument is preserved for compatibility
                but is not used. By default, the silence in the input log mel spectrogram is ignored.
        """
        expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0]
        if input_features.shape[-1] != expected_seq_length:
            raise ValueError(
                f"Qwen2Audio expects the mel input features to be of length {expected_seq_length}, but found "
                f"{input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
            )

        input_features = input_features.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device)

        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        embed_pos = self.embed_positions.weight

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        for idx, encoder_layer in enumerate(self.layers):
            layer_outputs = encoder_layer(
                hidden_states,
                None,
                layer_head_mask=None,
            )
            hidden_states = layer_outputs[0]

        hidden_states = self.layer_norm(hidden_states)
        return BaseModelOutput(last_hidden_state=hidden_states)

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """
        Computes the output length of the convolutional layers and the output length of the audio encoder
        """
        input_lengths = (input_lengths - 1) // 2 + 1
        output_lengths = (input_lengths - 2) // 2 + 1
        return input_lengths, output_lengths
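# Illustrative sketch (comments only, not executed): running the encoder standalone, assuming `model` is
# a loaded `VoxtralForConditionalGeneration` and `inputs.input_features` came from the checkpoint's
# processor, already padded to the expected mel length checked in `forward` above:
#
#   encoder_out = model.audio_tower(inputs.input_features)    # BaseModelOutput
#   frames = encoder_out.last_hidden_state                     # (batch, num_frames, d_model)
#   # `mel_lengths` is a hypothetical tensor of unpadded mel-frame lengths per sample:
#   conv_len, enc_len = model.audio_tower._get_feat_extract_output_lengths(mel_lengths)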
class VoxtralMultiModalProjector(nn.Module):
    def __init__(self, config: VoxtralConfig):
        super().__init__()
        self.linear_1 = nn.Linear(config.audio_config.intermediate_size, config.text_config.hidden_size, bias=False)
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=False)

    def forward(self, audio_features):
        hidden_states = self.linear_1(audio_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


@auto_docstring(
    custom_intro="""
    The Voxtral model, which consists of a Whisper encoder, a multi-modal projector and a Llama language model.
    """
)
class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.vocab_size = config.text_config.vocab_size
        self.audio_tower = AutoModel.from_config(config.audio_config)
        self.language_model = AutoModelForCausalLM.from_config(config.text_config)
        self.multi_modal_projector = VoxtralMultiModalProjector(config)

        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.language_model.get_decoder()

    def get_audio_embeds(self, input_features: torch.FloatTensor):
        """
        This method is used to get the audio embeddings from input features (a log mel spectrogram), i.e. it runs the audio encoder followed by the multi-modal projector.
        Args:
            input_features (`torch.FloatTensor`):
                Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
                obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
                `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
                `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
                and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]

        Returns:
            `torch.FloatTensor`:
                The audio embeddings.
        """
        audio_outputs = self.audio_tower(input_features)
        audio_hidden_states = audio_outputs.last_hidden_state
        audio_hidden_states = audio_hidden_states.reshape(-1, self.config.audio_config.intermediate_size)
        audio_embeds = self.multi_modal_projector(audio_hidden_states)
        return audio_embeds

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        input_features: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> from transformers import VoxtralForConditionalGeneration, AutoProcessor
        >>> import torch

        >>> device = "cuda" if torch.cuda.is_available() else "cpu"
        >>> repo_id = "mistralai/Voxtral-Mini-3B-2507"

        >>> processor = AutoProcessor.from_pretrained(repo_id)
        >>> model = VoxtralForConditionalGeneration.from_pretrained(repo_id, dtype=torch.bfloat16, device_map=device)

        >>> conversation = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "audio",
                        "url": "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/dude_where_is_my_car.wav",
                    },
                    {"type": "text", "text": "What can you tell me about this audio?"},
                ],
            }
        ]

        >>> inputs = processor.apply_chat_template(conversation)
        >>> inputs = inputs.to(device, dtype=torch.bfloat16)

        >>> outputs = model.generate(**inputs, max_new_tokens=30)
        >>> processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
        ["This audio is a humorous conversation between two friends, likely in English, where one of them is trying to figure out what the other's tattoo says."]
        ```"""
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        if input_features is not None:
            audio_embeds = self.get_audio_embeds(input_features)

            # replace the audio placeholder tokens with the projected audio embeddings
            audio_token_mask = input_ids == self.config.audio_token_id
            inputs_embeds[audio_token_mask] = audio_embeds

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )
        return outputs

    def prepare_inputs_for_generation(self, *args, **kwargs):
        # `input_features` should only be forwarded on the first (prefill) generation step
        input_features = kwargs.pop("input_features", None)
        cache_position = kwargs.get("cache_position")

        model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)

        if cache_position is not None and cache_position[0] == 0:
            model_inputs["input_features"] = input_features

        return model_inputs
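# Illustrative sketch (comments only): extracting the projected audio embeddings without generating,
# assuming `model` and `inputs` are built as in the `forward` docstring example above:
#
#   audio_embeds = model.get_audio_embeds(inputs.input_features)
#   # -> (number_of_audio_placeholder_tokens, config.text_config.hidden_size); inside `forward` these
#   #    rows are written into `inputs_embeds` wherever `input_ids == config.audio_token_id`.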
__all__ = ["VoxtralForConditionalGeneration", "VoxtralPreTrainedModel", "VoxtralEncoder"]