a
    h                     @   sl  d Z ddlZddlZddlmZmZ ddlZddlmZ ddlm	Z	 ddl
mZ ddlmZ dd	lmZmZmZ dd
lmZ ddlmZ ddlmZ ddlmZmZmZ ddlmZ ddlmZm Z  ddl!m"Z"m#Z#m$Z$m%Z%m&Z& ddl'm(Z( ddl)m*Z* e# rddl+m,Z, ddl-m.Z. e&/e0Z1dZ2zddl3m4Z4 dZ2e15d W n0 e6yX   Y n e7yt   e18d Y n0 G dd dej9Z:e2se4Z:G dd dej9Z;G dd  d ej9Z<G d!d" d"ej9Z=G d#d$ d$ej9Z>G d%d& d&ej9Z?G d'd( d(ej9Z@G d)d* d*eZAe"G d+d, d,eZBG d-d. d.eBZCG d/d0 d0ej9ZDe"d1d2G d3d4 d4eBeZEd4d,gZFdS )5zPyTorch Pop2Piano model.    N)OptionalUnion)nn)CrossEntropyLoss)GenerationConfig   )ACT2FN)CacheDynamicCacheEncoderDecoderCache)GenerationMixin)AttentionMaskConverter)GradientCheckpointingLayer)BaseModelOutput)BaseModelOutputWithPastAndCrossAttentionsSeq2SeqLMOutput)PreTrainedModel) find_pruneable_heads_and_indicesprune_linear_layer)auto_docstringis_torch_flex_attn_availableis_torch_fx_proxyis_torchdynamo_compilinglogging)deprecate_kwarg   )Pop2PianoConfig)	BlockMask)make_flex_block_causal_maskT)FusedRMSNormFzVDiscovered apex.normalization.FusedRMSNorm - will use it instead of Pop2PianoLayerNormzIDiscovered apex but it failed to load, falling back to Pop2PianoLayerNormc                       s&   e Zd Zd fdd	Zdd Z  ZS )Pop2PianoLayerNormư>c                    s&   t    tt|| _|| _dS )zj
        Construct a layernorm module in the Pop2Piano style. No bias and no subtraction of mean.
        N)super__init__r   	Parametertorchonesweightvariance_epsilon)selfZhidden_sizeeps	__class__ l/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/pop2piano/modeling_pop2piano.pyr#   B   s    
zPop2PianoLayerNorm.__init__c                 C   s\   | tjdjddd}|t|| j  }| jjtj	tj
fv rR| | jj}| j| S )N   T)Zkeepdim)tor%   Zfloat32powmeanZrsqrtr(   r'   dtypefloat16Zbfloat16)r)   hidden_statesZvariancer-   r-   r.   forwardJ   s
    zPop2PianoLayerNorm.forward)r!   )__name__
__module____qualname__r#   r7   __classcell__r-   r-   r+   r.   r    A   s   r    c                       s*   e Zd Zed fddZdd Z  ZS )Pop2PianoDenseActDenseconfigc                    sT   t    tj|j|jdd| _tj|j|jdd| _t|j	| _
t|j | _d S NFbias)r"   r#   r   Lineard_modeld_ffwiwoDropoutdropout_ratedropoutr   dense_act_fnactr)   r>   r+   r-   r.   r#   `   s
    
zPop2PianoDenseActDense.__init__c                 C   sl   |  |}| |}| |}t| jjtjr^|j| jjjkr^| jjjtj	kr^|
| jjj}| |}|S N)rE   rK   rI   
isinstancerF   r'   r%   Tensorr4   int8r1   )r)   r6   r-   r-   r.   r7   g   s    



zPop2PianoDenseActDense.forwardr8   r9   r:   r   r#   r7   r;   r-   r-   r+   r.   r<   _   s   r<   c                       s*   e Zd Zed fddZdd Z  ZS )Pop2PianoDenseGatedActDenser=   c                    sj   t    tj|j|jdd| _tj|j|jdd| _tj|j|jdd| _t	|j
| _t|j | _d S r?   )r"   r#   r   rB   rC   rD   wi_0wi_1rF   rG   rH   rI   r   rJ   rK   rL   r+   r-   r.   r#   w   s    
z$Pop2PianoDenseGatedActDense.__init__c                 C   sz   |  | |}| |}|| }| |}t| jjtjrl|j	| jjj	krl| jjj	tj
krl|| jjj	}| |}|S rM   )rK   rS   rT   rI   rN   rF   r'   r%   rO   r4   rP   r1   )r)   r6   Zhidden_geluZhidden_linearr-   r-   r.   r7      s    


z#Pop2PianoDenseGatedActDense.forwardrQ   r-   r-   r+   r.   rR   v   s   rR   c                       s*   e Zd Zed fddZdd Z  ZS )Pop2PianoLayerFFr=   c                    sJ   t    |jrt|| _n
t|| _t|j|jd| _	t
|j| _d S )Nr*   )r"   r#   Zis_gated_actrR   DenseReluDenser<   r    rC   layer_norm_epsilon
layer_normr   rG   rH   rI   rL   r+   r-   r.   r#      s    

zPop2PianoLayerFF.__init__c                 C   s&   |  |}| |}|| | }|S rM   )rY   rW   rI   )r)   r6   Zforwarded_statesr-   r-   r.   r7      s    

zPop2PianoLayerFF.forwardrQ   r-   r-   r+   r.   rU      s   
rU   c                
       sb   e Zd Zdeee d fddZdd ZedddZ	dddZ
edddddddZ  ZS )Pop2PianoAttentionFN)r>   	layer_idxc                    s  t    |j| _|| _|j| _|j| _|j| _|j| _|j	| _
|j| _| j
| j | _|| _|d u r| jrtd| jj d tj| j| jdd| _tj| j| jdd| _tj| j| jdd| _tj| j| jdd| _| jrt| j| j
| _t | _d| _d S )NzInstantiating a decoder z without passing `layer_idx` is not recommended and will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` when creating this class.Fr@   )r"   r#   
is_decoderhas_relative_attention_biasrelative_attention_num_bucketsrelative_attention_max_distancerC   d_kvkey_value_proj_dim	num_headsn_headsrH   rI   	inner_dimr[   loggerwarning_oncer,   r8   r   rB   qkvo	Embeddingrelative_attention_biassetpruned_headsgradient_checkpointingr)   r>   r]   r[   r+   r-   r.   r#      s.    
zPop2PianoAttention.__init__c                 C   s   t |dkrd S t|| j| j| j\}}t| j|| _t| j|| _t| j|| _t| j	|dd| _	| jt | | _| j| j | _
| j|| _d S )Nr   r   dim)lenr   rc   ra   rn   r   rg   rh   ri   rj   rd   union)r)   Zheadsindexr-   r-   r.   prune_heads   s    zPop2PianoAttention.prune_headsT       c                 C   s   d}|r4|d }|| dk tj| 7 }t| } nt| t|  } |d }| |k }|t|  | t||  ||   tj }t|t	||d }|t
|| |7 }|S )a  
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        r   r/   r   )r1   r%   longabsminZ
zeros_likelogfloatmathZ	full_likewhere)relative_positionbidirectionalnum_bucketsmax_distanceZrelative_bucketsZ	max_exactZis_smallZrelative_position_if_larger-   r-   r.   _relative_position_bucket   s,    z,Pop2PianoAttention._relative_position_bucketc           
      C   s   |du r| j jj}|du r:tj|tj|ddddf }n|dddf |}tj|tj|ddddf }|| }| j|| j | j	| j
d}|  |}	|	g dd}	|	S )z%Compute binned relative position biasN)r4   device)r   r   r   )r/   r   r   r   )rl   r'   r   r%   arangery   r1   r   r\   r^   r_   Zpermute	unsqueeze)
r)   query_length
key_lengthr   cache_positionZcontext_positionZmemory_positionr   Zrelative_position_bucketvaluesr-   r-   r.   compute_bias  s     
 
zPop2PianoAttention.compute_biaspast_key_valuepast_key_values4.58new_nameversionc                 C   s  |j dd \}}|du}| |}||d| j| jdd}|durtt|trt|j	| j
}|rl|j}qx|j}n|}|r|n|}|r|dur|r|j| j
 j}|j| j
 j}n| |}| |}||d| j| jdd}||d| j| jdd}|durB|s|
nd}
|||| j
d|
i\}}|rBd|j| j
< t||dd}|du r0|j d }|durx|n
|
d d }| jstjd| j||f|j|jd	}| jr| jrd|_n6| j|||j|
d
}|dddd| dddf }|dur0|ddddddd|j d f }|| }| jrlt|j d }d|t| j< |dd|  f }n|}||7 }t!j"j#|$ dd%|}t!j"j&|| j&| jd}|dur|| }t||}|dd' }||d| j(}| )|}||f}|	r||f }|S )z
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        Nr/   r0   r   r   Tr   )r   r4   )r   r   r   rq   )ptraining)*shaperg   viewrc   ra   Z	transposerN   r   
is_updatedgetr[   Zcross_attention_cacheself_attention_cacheZlayerskeysr   rh   ri   updater%   matmulr]   Zzerosr   r4   ro   r   Zrequires_gradr   rn   r&   listboolr   Z
functionalZsoftmaxr}   Ztype_asrI   
contiguousrd   rj   )r)   r6   maskkey_value_statesposition_biasr   layer_head_maskr   	use_cacheoutput_attentionsr   
batch_size
seq_lengthZis_cross_attentionZquery_statesr   Zcurr_past_key_valueZcurrent_statesZ
key_statesZvalue_statesZscoresr   Zreal_seq_lengthcausal_maskZposition_bias_maskedZattn_weightsZattn_outputoutputsr-   r-   r.   r7     sx    






"
&


zPop2PianoAttention.forward)FN)Trw   rx   )NN)	NNNNNNFFN)r8   r9   r:   r   r   intr#   rv   staticmethodr   r   r   r7   r;   r-   r-   r+   r.   rZ      s*     #/
         rZ   c                       s@   e Zd Zdee d fddZedddd	dd
dZ  ZS )Pop2PianoLayerSelfAttentionFNr[   c                    s>   t    t|||d| _t|j|jd| _t	|j
| _d S )Nr]   r[   rV   )r"   r#   rZ   SelfAttentionr    rC   rX   rY   r   rG   rH   rI   rp   r+   r-   r.   r#     s    
z$Pop2PianoLayerSelfAttention.__init__r   r   r   r   c	              
   C   sL   |  |}	| j|	|||||||d}
|| |
d  }|f|
dd   }|S )N)r   r   r   r   r   r   r   r   r   )rY   r   rI   )r)   r6   attention_maskr   r   r   r   r   r   normed_hidden_statesattention_outputr   r-   r-   r.   r7     s    

z#Pop2PianoLayerSelfAttention.forward)FN)NNNNFFN	r8   r9   r:   r   r   r#   r   r7   r;   r-   r-   r+   r.   r     s          r   c                	       s@   e Zd Zdee d fddZedddddd
dZ  ZS )Pop2PianoLayerCrossAttentionNr   c                    s>   t    t|d|d| _t|j|jd| _t	|j
| _d S )NFr   rV   )r"   r#   rZ   EncDecAttentionr    rC   rX   rY   r   rG   rH   rI   )r)   r>   r[   r+   r-   r.   r#     s    
z%Pop2PianoLayerCrossAttention.__init__r   r   r   r   Fc                 C   sP   |  |}| j|||||||||	|
d
}|| |d  }|f|dd   }|S )N)	r   r   r   r   r   r   r   r   r   r   r   )rY   r   rI   )r)   r6   r   r   r   r   r   r   r   r   r   r   r   Zlayer_outputr   r-   r-   r.   r7     s     
z$Pop2PianoLayerCrossAttention.forward)N)NNNNFNFNr   r-   r-   r+   r.   r     s           r   c                       s@   e Zd Zdee d fddZedddd	dddZ  ZS )Pop2PianoBlockFNr   c                    s`   t    |j| _t | _| jt|||d | jrL| jt||d | jt	| d S )Nr   r   )
r"   r#   r\   r   
ModuleListlayerappendr   r   rU   rp   r+   r-   r.   r#     s    

zPop2PianoBlock.__init__r   r   r   r   Tc                 C   s  | j d |||||	|
||d}|d }|dd  }|jtjkrtt| t|jjd t|jj}tj	|| |d}| j
o|d u}|r$| j d ||||||	|d d |
|d	}|d }|jtjkrtt| t|jjd t|jj}tj	|| |d}||dd   }| j d |}|jtjkrtt| t|jjd t|jj}tj	|| |d}|f}|| S )Nr   )r   r   r   r   r   r   r   r   i  )r{   maxr0   )r   r   r   r   r   r   r   r   )r   r4   r%   r5   r   isinfanyfinfor   clampr\   )r)   r6   r   r   encoder_hidden_statesencoder_attention_maskencoder_decoder_position_biasr   cross_attn_layer_head_maskr   r   r   return_dictr   Zself_attention_outputsZattention_outputsZclamp_valueZdo_cross_attentionZcross_attention_outputsr   r-   r-   r.   r7     sh    

zPop2PianoBlock.forward)FN)NNNNNNNNFFTNr   r-   r-   r+   r.   r     s               r   c                   @   sB   e Zd ZU eed< dZdZdZdZdgZ	dgZ
dd Zd	d
 ZdS )Pop2PianoPreTrainedModelr>   ZtransformerFTr   rF   c                 C   s  | j j}t|tr(|jj|d  nt|trN|jjjj	d|d d nt|t
r|jjjj	d|d d t|dr| j js|jjjj	d|d d nLt|tr>|jjjj	d|| j jd  d t|jdr|jjdur|jjj  |jjjj	d|| j jd  d t|jdr|jjdur|jjj  nt|tr*|jjjj	d|| j jd  d t|jdr|jjdur|jjj  |jjjj	d|| j jd  d t|jdr|jjdur|jjj  |jjjj	d|| j jd  d t|jdr|jjdur|jjj  nt|tr| j j}| j j}| j j}|jjjj	d||| d  d |jjjj	d||d  d |jjjj	d||d  d |jjjj	d||| d  d |j r|j!jjj	d||d  d dS )zInitialize the weights      ?        )r3   Zstdlm_head      rA   N)"r>   Zinitializer_factorrN   r    r'   dataZfill_Pop2PianoConcatEmbeddingToMel	embeddingZnormal_!Pop2PianoForConditionalGenerationsharedhasattrtie_word_embeddingsr   r<   rE   rC   rA   Zzero_rF   rD   rR   rS   rT   rZ   r`   rb   rg   rh   ri   rj   r]   rl   )r)   modulefactorrC   ra   rc   r-   r-   r.   _init_weightsM  sH    


       z&Pop2PianoPreTrainedModel._init_weightsc                 C   s   | j j}| j j}|d u r tdt|rbt|jd d d |}tj||dd df gdd}n4|	|j}|dd df 
 |ddd f< ||d< |d u rtd||d	k| |S )
Nzoself.model.config.decoder_start_token_id has to be defined. In Pop2Piano it is usually set to the pad_token_id.r0   )r   .rq   r   ).r   z1self.model.config.pad_token_id has to be defined.)r>   decoder_start_token_idpad_token_id
ValueErrorr   r%   fullr   catZ	new_zeroscloneZmasked_fill_)r)   	input_idsr   r   Zshifted_input_idsr-   r-   r.   _shift_right{  s       z%Pop2PianoPreTrainedModel._shift_rightN)r8   r9   r:   r   __annotations__Zbase_model_prefixZis_parallelizableZsupports_gradient_checkpointingZ_can_compile_fullgraphZ_no_split_modulesZ_keep_in_fp32_modulesr   r   r-   r-   r-   r.   r   B  s   
.r   c                       sx   e Zd Zd fdd	Zdd ZdddZdeejd	f ejeje	e
d
ddZeejeeejejedddZ  ZS )Pop2PianoStackNc                    sx   t    || _ j| _t fddt jD | _t	 j
 jd| _t j| _|   d| _d | _d| _d S )Nc                    s"   g | ]}t  t|d k|dqS )r   r   )r   r   ).0ir=   r-   r.   
<listcomp>  s   z+Pop2PianoStack.__init__.<locals>.<listcomp>rV   F)r"   r#   embed_tokensr\   r   r   range
num_layersblockr    rC   rX   final_layer_normrG   rH   rI   	post_initZmodel_parallelZ
device_mapro   )r)   r>   r   r+   r=   r.   r#     s    
zPop2PianoStack.__init__c                 C   s
   || _ d S rM   )r   r)   Znew_embeddingsr-   r-   r.   set_input_embeddings  s    z#Pop2PianoStack.set_input_embeddingsc           %      C   sD  |	d ur|	n| j j}	|
d ur |
n| j j}
|d ur4|n| j j}|d urH|n| j j}|d ur|d ur| jrjdnd}td| d| dn`|d ur| }|d|d }n>|d ur| d d }n$| jrdnd}td| d| d	| j	r
| j
r
|	r
td
 d}	|d u r2| jd u r(td| |}|\}}|	du r\| js\td|  d| jr|	r|d u r| j jrtt| j dt| j d}nt| j d}n| jsd }|d ur| nd}|d u rtj||| |jd}|d u rt s|| }tj|||jd}| j jrF| |||t|tr<|jn||
}n<|d d d d d d f }|j|jd}d| t|jj }| jr|d ur| \}}}||f}|d u rtj||jd}| |}nd }| || j j }| || j j }|rdnd }|
rdnd }|
r"| jr"dnd }d }d }| !|}t"| j#D ]\} }!||  }"||  }#|rj||f }|!|||||||"|#||	|
|d}$|$d }|$d }| jr|d ur|$|
rdnd }|
rB||$d f }| jrB||$d f }qB| $|}| !|}|r||f }|s2t%dd |||||fD S t&|||||dS )NZdecoder_ zYou cannot specify both zinput_ids and zinputs_embeds at the same timer0   zYou have to specify either zinput_ids or inputs_embedszZ`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...Fz<You have to initialize the model with valid token embeddingsTz)`use_cache` can only be set to `True` if z is used as a decoderr=   r   r   )r4   r   r-   )r   r   r   r   r   r   r   r   r/      c                 s   s   | ]}|d ur|V  qd S rM   r-   )r   ri   r-   r-   r.   	<genexpr>Q  s   z)Pop2PianoStack.forward.<locals>.<genexpr>)last_hidden_stater   r6   
attentionscross_attentions)'r>   r   r   output_hidden_statesuse_return_dictr\   r   sizer   ro   r   re   rf   r   Zis_encoder_decoderr   r
   get_seq_lengthr%   r   r   r   r&   _update_causal_maskrN   r   r1   r4   r   r{   Zinvert_attention_maskZget_head_maskr   rI   	enumerater   r   tupler   )%r)   r   r   r   r   r   	head_maskcross_attn_head_maskr   r   r   r   r   r   Zerr_msg_prefixZinput_shaper   r   past_key_values_lengthZmask_seq_lengthr   Zencoder_batch_sizeZencoder_sequence_length_Zencoder_hidden_shapeZencoder_extended_attention_maskZall_hidden_statesZall_attentionsZall_cross_attentionsr   r   r6   r   Zlayer_moduler   r   Zlayer_outputsr-   r-   r.   r7     s    













zPop2PianoStack.forwardFr   )r   input_tensorr   r   r   c                 C   sB  | j jdkr(|d ur$|dk r$|S d S | j jdkrLt|tjrHt|}|S |d ur\| nd}|d urn|jnd}| j jdkr|s|st	j
|||| jdrd S |j}|jd }	|r| }
n"t|tjr|jd	 n
||	 d }
| j||	|
|||jd d
}| j jdkr>|d ur>|jjdv r>|s>t|j}t	||}|S )NZflash_attention_2r   Zflex_attentionr   FZsdpa)r   r   Zis_trainingr   r0   )sequence_lengthtarget_lengthr4   r   r   )cudaZxpuZnpu)r>   Z_attn_implementationr   rN   r%   rO   r   r   Zis_compileabler   Z_ignore_causal_mask_sdpar   r4   r   Zget_max_cache_shape5_prepare_4d_causal_attention_mask_with_cache_positionr   typer   r{   Z_unmask_unattended)r)   r   r   r   r   r   Zpast_seen_tokensZusing_compilable_cacher4   r   r   r   	min_dtyper-   r-   r.   r   e  sZ    






	z"Pop2PianoStack._update_causal_mask)r   r   r   r4   r   r   c                 K   sF  | dur|   dkr| }n&t|j}tj||f|||jd}|dkrVtj|dd}|tj||jd|ddk9 }|ddddddf 	|ddd}| durB|
 }| jd }	|ddddddd|	f | ddddddf |j }
|
dk}
|ddddddd|	f |
||ddddddd|	f< |S )	aM  
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
        Nr   )Z
fill_valuer4   r   r   )Zdiagonalr   r0   r   )rr   r%   r   r{   r   r   Ztriur   Zreshapeexpandr   r   r1   Zmasked_fill)r   r   r   r4   r   r   kwargsr   r  Zmask_lengthZpadding_maskr-   r-   r.   r    s*     $

6  zDPop2PianoStack._prepare_4d_causal_attention_mask_with_cache_position)N)NNNNNNNNNNNNN)F)r8   r9   r:   r#   r   r7   r   r%   rO   r	   r   r   r   r   r4   r  r;   r-   r-   r+   r.   r     sB                
 : Dr   c                       s(   e Zd ZdZ fddZdd Z  ZS )r   z'Embedding Matrix for `composer` tokens.c                    s"   t    tj|j|jd| _d S )N)Znum_embeddingsZembedding_dim)r"   r#   r   rk   composer_vocab_sizerC   r   rL   r+   r-   r.   r#     s    
z&Pop2PianoConcatEmbeddingToMel.__init__c                 C   s.   || }|  |d}tj||gdd}|S )Nr   rq   )r   r   r%   r   )r)   featureindex_valueembedding_offsetZindex_shiftedZcomposer_embeddingr   r-   r-   r.   r7     s    z%Pop2PianoConcatEmbeddingToMel.forward)r8   r9   r:   __doc__r#   r7   r;   r-   r-   r+   r.   r     s   r   zA
    Pop2Piano Model with a `language modeling` head on top.
    )Zcustom_introc                       s6  e Zd Zg dZed fddZdd Zdd Zd	d
 Zde	j
eeee	j
 dddZedee	j ee	j
 ee	j ee	j ee	j
 ee	j
 ee	j eeee	j   ee ee	j
 ee	j
 ee	j
 ee	j ee ee ee ee ee	j eee	j
 ef dddZe	 d fdd	Ze	jdddZ  ZS )r   )zencoder.embed_tokens.weightzdecoder.embed_tokens.weightzlm_head.weightr=   c                    s   t  | || _|j| _t|j|j| _t	|| _
t|}d|_d|_d|_t|| j| _t|}d|_d|_|j|_t|| j| _tj|j|jdd| _|   d S )NFTr@   )r"   r#   r>   rC   	model_dimr   rk   Z
vocab_sizer   r   mel_conditionercopydeepcopyr\   r   Ztie_encoder_decoderr   encoderZnum_decoder_layersr   decoderrB   r   r   )r)   r>   Zencoder_configZdecoder_configr+   r-   r.   r#     s"    


z*Pop2PianoForConditionalGeneration.__init__c                 C   s   | j S rM   )r   r)   r-   r-   r.   get_input_embeddings  s    z6Pop2PianoForConditionalGeneration.get_input_embeddingsc                 C   s"   || _ | j| | j| d S rM   )r   r  r   r  r   r-   r-   r.   r     s    z6Pop2PianoForConditionalGeneration.set_input_embeddingsc                 C   s   | j S rM   )r  r  r-   r-   r.   get_encoder  s    z-Pop2PianoForConditionalGeneration.get_encoderN)input_featurescomposergeneration_configr   c                 C   s   |j }||vr*tdt|  d| || }tj|| jd}||jd }t	|
 }| j|||d}|durd||dddf   < tj|dddf dd	|gd	d
}||fS |dfS )a  
        This method is used to concatenate mel conditioner tokens at the front of the input_features in order to
        control the type of MIDI token generated by the model.

        Args:
            input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                input features extracted from the feature extractor.
            composer (`str`):
                composer token which determines the type of MIDI tokens to be generated.
            generation_config (`~generation.GenerationConfig`):
                The generation is used to get the composer-feature_token pair.
            attention_mask (``, *optional*):
                For batched generation `input_features` are padded to have the same shape across all examples.
                `attention_mask` helps to determine which areas were padded and which were not.
                - 1 for tokens that are **not padded**,
                - 0 for tokens that are **padded**.
        zPlease choose a composer from z. Composer received - r   r   )r  r  r	  Nr   r0   r   )Zaxis)composer_to_feature_tokenr   r   r   r%   Ztensorr   repeatr   r{   r   r  r   Zconcatenater   )r)   r  r  r  r   r  Zcomposer_valuer	  r-   r-   r.   get_mel_conditioner_outputs  s&    &z=Pop2PianoForConditionalGeneration.get_mel_conditioner_outputs)r   r   decoder_input_idsdecoder_attention_maskr   decoder_head_maskr   encoder_outputsr   r   r  decoder_inputs_embedslabelsr   r   r   r   r   returnc                 C   s  |dur|n| j j}|dur |n| j j}|
durB|durBtdn|durV|
du rV|}
|du rx| j|||
||||d}nH|rt|tst|d t|dkr|d ndt|dkr|d ndd}|d }|dur|du r|du r| |}| j	||||	|||||||||d}|d }| j j
r.|| jd	  }| |}d}|durntd
d}||d|d|d}|s|f|dd  | }|dur|f| S |S t|||j|j|j|j|j|j|jd	S )a2
  
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Pop2Piano is a model with relative position embeddings
            so you should be able to pad the inputs on both the right and the left. Indices can be obtained using
            [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for detail.
            [What are input IDs?](../glossary#input-ids) To know more on how to prepare `input_ids` for pretraining
            take a look a [Pop2Piano Training](./Pop2Piano#training).
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
            [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
            [What are decoder input IDs?](../glossary#decoder-input-ids) Pop2Piano uses the `pad_token_id` as the
            starting token for `decoder_input_ids` generation. If `past_key_values` is used, optionally only the last
            `decoder_input_ids` have to be input (see `past_key_values`). To know more on how to prepare
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
            1]`:
            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
            `[0, 1]`:
            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
            labels in `[0, ..., config.vocab_size]`
        NzSBoth `inputs_embeds` and `input_features` received! Please provide only one of them)r   r   r   r   r   r   r   r   r   r/   )r   r6   r   )r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   )Zignore_indexr0   )	lossZlogitsr   Zdecoder_hidden_statesZdecoder_attentionsr   Zencoder_last_hidden_stater   Zencoder_attentions)r>   r   r   r   r  rN   r   rs   r   r  r   r  r   r   r   r   r   r   r6   r   r   r   )r)   r   r   r  r  r   r  r   r  r   r   r  r  r  r   r   r   r   r   r6   Zdecoder_outputsZsequence_outputZ	lm_logitsr!  Zloss_fctoutputr-   r-   r.   r7   O  s|    5
	




z)Pop2PianoForConditionalGeneration.forward	composer1c                    s   |du r| j }|jf i | t|ds0tdt|j| jjkrbtd| jj dt|j d| j||||d\}}t	 j
f d|||d|S )	a  
        Generates token ids for midi outputs.

        <Tip warning={true}>

        Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
        model's default generation configuration. You can override any `generation_config` by passing the corresponding
        parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. For an overview of generation
        strategies and code examples, check out the [following guide](./generation_strategies).

        </Tip>

        Parameters:
            input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                This is the featurized version of audio generated by `Pop2PianoFeatureExtractor`.
            attention_mask:
                For batched generation `input_features` are padded to have the same shape across all examples.
                `attention_mask` helps to determine which areas were padded and which were not.
                - 1 for tokens that are **not padded**,
                - 0 for tokens that are **padded**.
            composer (`str`, *optional*, defaults to `"composer1"`):
                This value is passed to `Pop2PianoConcatEmbeddingToMel` to generate different embeddings for each
                `"composer"`. Please make sure that the composet value is present in `composer_to_feature_token` in
                `generation_config`. For an example please see
                https://huggingface.co/sweetcocoa/pop2piano/blob/main/generation_config.json .
            generation_config (`~generation.GenerationConfig`, *optional*):
                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
                passed to generate matching the attributes of `generation_config` will override them. If
                `generation_config` is not provided, the default will be used, which had the following loading
                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
                default values, whose documentation should be checked to parameterize generation.
            kwargs:
                Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
                forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
                specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
        Return:
            [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
            or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
                Since Pop2Piano is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
                [`~utils.ModelOutput`] types are:
                    - [`~generation.GenerateEncoderDecoderOutput`],
                    - [`~generation.GenerateBeamEncoderDecoderOutput`]
        Nr  z`composer_to_feature_token` was not found! Please refer to https://huggingface.co/sweetcocoa/pop2piano/blob/main/generation_config.jsonand parse a dict like that.ztconfig.composer_vocab_size must be same as the number of keys in generation_config.composer_to_feature_token! Found z vs .)r  r   r  r  )inputsr   r   r  )r  r   r   r   rs   r  r>   r  r  r"   generate)r)   r  r   r  r  r  r+   r-   r.   r&    s:    6

z*Pop2PianoForConditionalGeneration.generate)r  c                 C   s
   |  |S rM   )r   )r)   r  r-   r-   r.   %prepare_decoder_input_ids_from_labels.  s    zGPop2PianoForConditionalGeneration.prepare_decoder_input_ids_from_labels)N)NNNNNNNNNNNNNNNNNN)Nr#  N)r8   r9   r:   Z_tied_weights_keysr   r#   r  r   r  r%   ZFloatTensorstrr   r   r  r   Z
LongTensorZ
BoolTensorrO   r   r	   r   r   r   r7   Zno_gradr&  r'  r;   r-   r-   r+   r.   r     sv    1                      Yr   )Gr
  r  r~   typingr   r   r%   r   Ztorch.nnr   Ztransformers.generationr   Zactivationsr   Zcache_utilsr	   r
   r   Z
generationr   Zmodeling_attn_mask_utilsr   Zmodeling_layersr   Zmodeling_outputsr   r   r   Zmodeling_utilsr   Zpytorch_utilsr   r   utilsr   r   r   r   r   Zutils.deprecationr   Zconfiguration_pop2pianor   Z!torch.nn.attention.flex_attentionr   Zintegrations.flex_attentionr   Z
get_loggerr8   re   Z_load_pop2piano_layer_normZapex.normalizationr   infoImportError	ExceptionwarningModuler    r<   rR   rU   rZ   r   r   r   r   r   r   r   __all__r-   r-   r-   r.   <module>   sp   

 j&(dS  N  ?