a
    hL                 
   @   s\  d Z ddlZddlmZ ddlmZmZmZ ddlZddl	m
Z
 ddlmZ ddlmZ ddlmZ dd	lmZ dd
lmZ ddlmZmZmZ ddlmZmZ ddlmZ eeZ G dd de
j!Z"G dd de
j!Z#G dd de
j!Z$G dd de
j!Z%G dd de
j!Z&G dd de
j!Z'dre
j!ej(ej(ej(eej( ee) e)eej( dddZ*G dd  d e
j!Z+G d!d" d"e
j!Z,G d#d$ d$e
j!Z-G d%d& d&e
j!Z.G d'd( d(e
j!Z/G d)d* d*e
j!Z0G d+d, d,e
j!Z1eG d-d. d.eZ2G d/d0 d0e
j!Z3dsej(e)ee4 e5e6d2d3d4Z7dtej(ee4e6f ee4 e6d5d6d7Z8G d8d9 d9e
j!Z9G d:d; d;e
j!Z:G d<d= d=e
j!Z;G d>d? d?e
j!Z<G d@dA dAe
j!Z=eedBdCG dDdE dEeZ>G dFdG dGe2Z?eedHdCG dIdJ dJeZ@edKdCG dLdM dMe2ZAeedNdCG dOdP dPeZBedQdCG dRdS dSe2ZCeedTdCG dUdV dVeZDeedWdCG dXdY dYeZEeedWdCG dZd[ d[eZFejGjHej(ej(d\d]d^ZIduej(eej( ej(d_d`daZJG dbdc dce2ZKeedddCG dedf dfeZLG dgdh dhe2ZMeedidCG djdk dkeZNG dldm dme
j!ZOedndCG dodp dpe2ZPg dqZQdS )vzPyTorch PatchTSMixer model.    N)	dataclass)CallableOptionalUnion)PreTrainedModel)ModelOutput   )FlashAttentionKwargs)ALL_ATTENTION_FUNCTIONS)Unpack)NegativeBinomialOutputNormalOutputStudentTOutput)auto_docstringlogging   )PatchTSMixerConfigc                       s0   e Zd ZdZeed fddZdd Z  ZS )PatchTSMixerGatedAttentionz
    Module that applies gated attention to input data.

    Args:
        in_size (`int`): The input size.
        out_size (`int`): The output size.
    in_sizeout_sizec                    s*   t    t||| _tjdd| _d S )Ndim)super__init__nnLinear
attn_layerZSoftmaxattn_softmax)selfr   r   	__class__ r/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/patchtsmixer/modeling_patchtsmixer.pyr   /   s    
z#PatchTSMixerGatedAttention.__init__c                 C   s   |  | |}|| }|S N)r   r   )r    inputsZattn_weightr#   r#   r$   forward4   s    z"PatchTSMixerGatedAttention.forward)__name__
__module____qualname____doc__intr   r'   __classcell__r#   r#   r!   r$   r   &   s   r   c                       s6   e Zd ZdZed fddZejdddZ  Z	S )PatchTSMixerBatchNormzP
    Compute batch normalization over the sequence length (time) dimension.
    configc                    s"   t    tj|j|jd| _d S )Neps)r   r   r   BatchNorm1dd_modelnorm_eps	batchnormr    r0   r!   r#   r$   r   @   s    
zPatchTSMixerBatchNorm.__init__r&   c                 C   s"   | dd}| |}| ddS )a  
        Parameters:
            inputs (`torch.Tensor` of shape `(batch_size, sequence_length, d_model)`):
                input for Batch norm calculation
        Returns:
            `torch.Tensor` of shape `(batch_size, sequence_length, d_model)`
        r      )	transposer6   )r    r&   outputr#   r#   r$   r'   D   s    
zPatchTSMixerBatchNorm.forward
r(   r)   r*   r+   r   r   torchTensorr'   r-   r#   r#   r!   r$   r.   ;   s   r.   c                       sL   e Zd ZdZed fddZeeejdddZ	e
jdd	d
Z  ZS )PatchTSMixerPositionalEncodingz'
    Class for positional encoding
    r/   c                    s:   t    |jr| || _ntt|j	|j
| _d S r%   )r   r   use_positional_encoding_init_peposition_encr   	Parameterr=   zerosnum_patchesr4   r7   r!   r#   r$   r   V   s    
z'PatchTSMixerPositionalEncoding.__init__)r0   returnc                 C   s   | j dkr&tjt| j| jdd}n| j dkrt| j| j}td| j	d}t
td| jdtd| j   }t|| |d d dd df< t|| |d d dd df< ||  }|| d	  }tj|d
d}nt| j  d|S )NrandomTZrequires_gradZsincosr   r   r9   g     @
   FzN is not a valid positional encoder. Available types are 'random' and 'sincos'.)positional_encoding_typer   rC   r=   ZrandnrE   r4   rD   Zarange	unsqueezeexpmathlogsincosmeanstd
ValueError)r0   rB   positionZdiv_termr#   r#   r$   rA   ^   s    

(  
z'PatchTSMixerPositionalEncoding._init_pepatch_inputc                 C   s   || j  }|S r%   )rB   )r    rV   hidden_stater#   r#   r$   r'   r   s    
z&PatchTSMixerPositionalEncoding.forward)r(   r)   r*   r+   r   r   staticmethodr   rC   rA   r=   r>   r'   r-   r#   r#   r!   r$   r?   Q   s
   r?   c                       s6   e Zd ZdZed fddZejdddZ  Z	S )PatchTSMixerNormLayerzeNormalization block

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    r/   c                    sD   t    |j| _d|j v r,t|| _ntj|j|j	d| _d S )Nbatchr1   )
r   r   norm_mlplowerr.   normr   	LayerNormr4   r5   r7   r!   r#   r$   r      s
    
zPatchTSMixerNormLayer.__init__r8   c                 C   sd   d| j  v rVt||jd |jd  |jd |jd f}| |}t||j}n
| |}|S )a  
        Args:
            inputs (`torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`):
                Input to the normalization layer.
        Returns:
            `torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`
        rZ   r   r   r9   r   )r[   r\   r=   reshapeshaper]   )r    r&   Zinputs_reshapedr#   r#   r$   r'      s    


zPatchTSMixerNormLayer.forwardr<   r#   r#   r!   r$   rY   x   s   
rY   c                       s,   e Zd Z fddZejdddZ  ZS )PatchTSMixerMLPc                    sP   t    ||j }t||| _t|j| _t||| _	t|j| _
d S r%   )r   r   Zexpansion_factorr   r   fc1Dropoutdropoutdropout1fc2dropout2)r    in_featuresout_featuresr0   Z
num_hiddenr!   r#   r$   r      s    

zPatchTSMixerMLP.__init__r8   c                 C   s0   |  tj| |}| |}| |}|S )z
        Args:
            inputs (`torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`):
                Input to the MLP layer.
        Returns:
            `torch.Tensor` of the same shape as `inputs`
        )re   r   
functionalZgelurb   rf   rg   )r    r&   r#   r#   r$   r'      s    

zPatchTSMixerMLP.forward)r(   r)   r*   r   r=   r>   r'   r-   r#   r#   r!   r$   ra      s   ra   c                       s6   e Zd ZdZed fddZejdddZ  Z	S )$PatchTSMixerChannelFeatureMixerBlockzThis module mixes the features in the channel dimension.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    r/   c                    sL   t    t|| _|j| _t|j|j|d| _|jrHt|j|jd| _	d S Nrh   ri   r0   r   )
r   r   rY   r]   
gated_attnra   num_input_channelsmlpr   gating_blockr7   r!   r#   r$   r      s    

z-PatchTSMixerChannelFeatureMixerBlock.__init__r8   c                 C   sT   |}|  |}|dddd}| jr.| |}| |}|dddd}|| }|S )z
        Args:
            inputs (`torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`):
                input to the MLP layer
        Returns:
            `torch.Tensor` of the same shape as `inputs`
        r   r   r9   r   )r]   Zpermutern   rq   rp   )r    r&   residualoutr#   r#   r$   r'      s    


z,PatchTSMixerChannelFeatureMixerBlock.forwardr<   r#   r#   r!   r$   rk      s   rk           )modulequerykeyvalueattention_maskscalingrd   	head_maskc                 K   s   |d u r| dd }t||dd| }	|d ur>|	| }	tjj|	dd}	|d urj|	|dddd }	tjj|	|| j	d}	t|	|}
|
dd
 }
|
|	fS )Nr         r9   r   r   r   )ptraining)sizer=   matmulr:   r   rj   Zsoftmaxviewrd   r~   
contiguous)ru   rv   rw   rx   ry   rz   rd   r{   kwargsattn_weightsattn_outputr#   r#   r$   eager_attention_forward   s    r   c                       s   e Zd ZdZdeeeeeeee d fddZ	de
jee
j ee
j ee
j ee ee ee
jee
j eee
j  f d	d
dZ  ZS )PatchTSMixerAttentionz=Multi-headed attention from 'Attention Is All You Need' paperrt   FTN)	embed_dim	num_headsrd   
is_decoderbias	is_causalr0   c                    s   t    || _|| _|| _|| | _|| _| j| | jkrTtd| j d| d| jd | _|| _	|| _
tj|||d| _tj|||d| _tj|||d| _tj|||d| _d S )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: ).r|   )r   )r   r   r   r   rd   head_dimr0   rS   rz   r   r   r   r   k_projv_projq_projout_proj)r    r   r   rd   r   r   r   r0   r!   r#   r$   r     s&    



zPatchTSMixerAttention.__init__)hidden_stateskey_value_statesry   layer_head_maskoutput_attentionsr   rF   c                 K   s  |du}|j dd \}}	|r(|j d n|	}
||	d| jf}||
d| jf}| |j| dd}|rh|n|}| |j| dd}| |j| dd}t}| jj	dkrt
| jj	 }|| ||||f| jsdn| j| j||d|\}}|||	d }| |}||dfS )z#Input shape: Batch x Time x ChannelNr   r   r9   eagerrt   )rd   rz   r   r{   )r`   r   r   r   r:   r   r   r   r0   Z_attn_implementationr
   r~   rd   rz   r_   r   r   )r    r   r   ry   r   r   r   Zis_cross_attentionZbszZtgt_lenZsrc_lenZq_input_shapeZkv_input_shapeZquery_statesZcurrent_statesZ
key_statesZvalue_statesZattention_interfacer   r   r#   r#   r$   r'   2  s:    


zPatchTSMixerAttention.forward)rt   FTFN)NNNF)r(   r)   r*   r+   r,   floatboolr   r   r   r=   r>   r   r	   tupler'   r-   r#   r#   r!   r$   r     s8        "    r   c                       s.   e Zd ZdZed fddZdd Z  ZS )PatchMixerBlockzxThis module mixes the patch dimension.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    r/   c                    s|   t    t|| _|j| _|j| _t|j|j|d| _|jrPt	|j|jd| _
|jrxt|j|j|j|d| _t|| _d S )Nrm   r   )r   r   rd   r0   )r   r   rY   r]   	self_attnrn   ra   rE   rp   r   rq   r   r4   Zself_attn_headsrd   self_attn_layer	norm_attnr7   r!   r#   r$   r   p  s&    

zPatchMixerBlock.__init__c                 C   s   |}|  |}| jrX|j\}}}}||| ||}| j|dd\}}	}	|||||}|dd}| |}| jr~| |}|dd}| jr| 	|| }|| }
|
S )z
        Args:
            hidden_state (`torch.Tensor`): Input tensor.

        Returns:
            `torch.Tensor`: Transformed tensor.
        F)r   r9   r   )
r]   r   r`   r_   r   r:   rp   rn   rq   r   )r    rW   rr   
batch_sizeZn_varsrE   r4   Zhidden_state_reshapedZx_attn_rs   r#   r#   r$   r'     s     


zPatchMixerBlock.forwardr(   r)   r*   r+   r   r   r'   r-   r#   r#   r!   r$   r   h  s   r   c                       s6   e Zd ZdZed fddZejdddZ  Z	S )FeatureMixerBlockzThis module mixes the hidden feature dimension.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.

    r/   c                    sL   t    t|| _|j| _t|j|j|d| _|jrHt|j|jd| _	d S rl   )
r   r   rY   r]   rn   ra   r4   rp   r   rq   r7   r!   r#   r$   r     s    

zFeatureMixerBlock.__init__hiddenc                 C   s4   |}|  |}| |}| jr(| |}|| }|S )
        Args:
            hidden (`torch.Tensor` of shape `(batch_size, num_patches, d_model)`):
                Input tensor to the layer.

        Returns:
            `torch.Tensor`: Transformed tensor.
        )r]   rp   rn   rq   )r    r   rr   rs   r#   r#   r$   r'     s    	


zFeatureMixerBlock.forwardr<   r#   r#   r!   r$   r     s   r   c                       s6   e Zd ZdZed fddZejdddZ  Z	S )PatchTSMixerLayerz
    The `PatchTSMixer` layer that does all three kinds of mixing.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.

    r/   c                    sD   t    t|d| _t|d| _|j| _|jdkr@t|d| _d S )Nr/   mix_channel)	r   r   r   patch_mixerr   feature_mixermoderk   channel_feature_mixerr7   r!   r#   r$   r     s    

zPatchTSMixerLayer.__init__r   c                 C   s,   | j dkr| |}| |}| |}|S )r   r   )r   r   r   r   )r    r   r#   r#   r$   r'     s
    	



zPatchTSMixerLayer.forwardr<   r#   r#   r!   r$   r     s   	r   c                       s6   e Zd ZdZed fddZd	edddZ  ZS )
PatchTSMixerBlockzThe main computing framework of the `PatchTSMixer` model.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    r/   c                    s2   t     j}t fddt|D | _d S )Nc                    s   g | ]}t  d qS )r/   )r   .0r   r/   r#   r$   
<listcomp>      z.PatchTSMixerBlock.__init__.<locals>.<listcomp>)r   r   
num_layersr   Z
ModuleListrangemixers)r    r0   r   r!   r/   r$   r   	  s    
zPatchTSMixerBlock.__init__Foutput_hidden_statesc                 C   sB   g }|}| j D ]}||}|r|| q|r6||fS |dfS dS )as  
        Args:
            hidden_state (`torch.Tensor`): The input tensor.
            output_hidden_states (`bool`, *optional*, defaults to False.):
                Whether to output the hidden states as well.

        Returns:
            `torch.Tensor`: The embedding. `list`: List of all hidden states if `output_hidden_states` is set to
            `True`.
        N)r   append)r    rW   r   Zall_hidden_statesZ	embeddingmodr#   r#   r$   r'     s    
zPatchTSMixerBlock.forward)F)	r(   r)   r*   r+   r   r   r   r'   r-   r#   r#   r!   r$   r     s   r   c                       s0   e Zd ZdZded fddZdd Z  ZS )	PatchTSMixerForPredictionHeadzqPrediction Head for Forecasting

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    Nr/   c                    s|   t    |j| _| jd ur&| j  t|j| _|d u rVt|j	|j
 |j| _n||j	|j
 | _tjdd| _d S )NZ	start_dim)r   r   prediction_channel_indicessortr   rc   head_dropoutdropout_layerr   rE   r4   prediction_lengthbase_forecast_blockget_parameter_projectionFlattenflatten)r    r0   distribution_outputr!   r#   r$   r   2  s    



z&PatchTSMixerForPredictionHead.__init__c                    s     |} |} |}t|tr<tdd |D }n|dd} jdurt|trtt fdd|D }n|d jf }|S )ar  

        Args:
            hidden_features (`torch.Tensor` of shape `(batch_size, num_patch, d_model)` in `flatten` mode
                or `(batch_size, n_vars, num_patch, d_model)` in `common_channel`/`mix_channel` mode.): Input hidden
                features.

        Returns:
            `torch.Tensor` of shape `(batch_size, prediction_length, nvars)`.

        c                 s   s   | ]}| d dV  qdS )r   r   N)r:   r   zr#   r#   r$   	<genexpr>U  r   z8PatchTSMixerForPredictionHead.forward.<locals>.<genexpr>r   r   Nc                 3   s   | ]}|d  j f V  qdS ).N)r   r   r    r#   r$   r   [  r   .)r   r   r   
isinstancer   r:   r   r    hidden_featuresforecastr#   r   r$   r'   D  s    





z%PatchTSMixerForPredictionHead.forward)Nr   r#   r#   r!   r$   r   *  s   r   c                       s0   e Zd ZdZded fddZdd Z  ZS )	PatchTSMixerLinearHeadzLinear head for Classification and Regression.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    Nr/   c                    s   t    |j| _|j| _|jd u r,|j}nd}|| _|d u r\t|j|j	 | |j
| _n||j|j	 | | _|jd u rtjdd| _ntjdd| _t|j| _d S )Nr   r   r   )r   r   head_aggregationoutput_rangerE   r   r   r   r4   ro   num_targets
projectionr   r   r   rc   r   rd   )r    r0   r   Z
mul_factorr!   r#   r$   r   j  s&    


zPatchTSMixerLinearHead.__init__c                 C   s   | dd}| jdkr |d }n0| jdkr:|jddj}n| jdkrP|jdd}| jr`| |}| |}| |}| jdu r| j	durt
|| j	d	 | j	d
   | j	d
  }|S )ai  
        Args:
            hidden_features (`torch.Tensor` of shape `(batch_size x num_patch x d_model)` in `flatten` mode
                or `(batch_size x n_vars x num_patch x d_model)` in `common_channel`/`mix_channel` mode.): Input hidden
                features.

        Returns:
            `torch.Tensor` of shape `(batch_size x num_targets)`.
        r   r   Zuse_last).r   Zmax_poolr   Zavg_poolNr   r   )r:   r   maxvaluesrQ   r   rd   r   r   r   r=   Zsigmoid)r    r   r#   r#   r$   r'     s    






&zPatchTSMixerLinearHead.forward)Nr   r#   r#   r!   r$   r   b  s   r   c                   @   s*   e Zd ZU eed< dZdZdZdd ZdS )PatchTSMixerPreTrainedModelr0   modelpast_valuesFc                 C   s   t |tr,| jjdkrtjj|jddd nt |tjtj	frZ|j
j  |jjd nbt |tr|jj
j  |jjjd n8t |tjr|jjjd| jjd |j
dur|j
j  dS )zInitialize weightsrG   rt   g?)rQ   rR         ?N)r   r?   r0   rJ   r   initZnormal_rB   r^   r3   r   dataZzero_weightZfill_r.   r6   r   Zinit_std)r    ru   r#   r#   r$   _init_weights  s    


z)PatchTSMixerPreTrainedModel._init_weightsN)	r(   r)   r*   r   __annotations__Zbase_model_prefixZmain_input_nameZsupports_gradient_checkpointingr   r#   r#   r#   r$   r     s
   
r   c                       s.   e Zd ZdZed fddZdd Z  ZS )PatchTSMixerPretrainHeadzcPretraining head.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    r/   c                    s.   t    t|j| _t|j|j| _	d S r%   )
r   r   r   rc   r   r   r   r4   patch_lengthbase_pt_blockr7   r!   r#   r$   r     s    
z!PatchTSMixerPretrainHead.__init__c                 C   s   |  |}| |}|S )a  
        Args:
            hidden_features (`torch.Tensor` of shape `(batch_size x num_patch x d_model)` in `flatten` mode
                or `(batch_size x n_vars x num_patch x d_model)` in `common_channel`/`mix_channel` mode.): Input hidden
                features.

        Returns:
            `torch.Tensor` of shape `(batch_size x n_vars x num_patch x patch_length)`.
        )r   r   r   r#   r#   r$   r'     s    

z PatchTSMixerPretrainHead.forwardr   r#   r#   r!   r$   r     s   r   Fr&   
mask_ratiounmasked_channel_indiceschannel_consistent_masking
mask_valuec                 C   s,  |dk s|dkr t d| d| j\}}}}| j}	t|d|  }
|rjtj|d||	d}|d|d}ntj||||	d}tj||||	d}d|ddddd|
f< tj|dd}tj|dd}tj	|d|d	}|
dddd|}|durd|dd|ddddf< | | |}||d
 fS )a  random_masking: Mask the input considering the control variables.

    Args:
        inputs (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, num_features)`):
            The input tensor to mask.
        mask_ratio (`float`):
            Masking ratio applied to mask the input data during random pretraining. It is the number between 0 and 1.
        unmasked_channel_indices (list, *optional*):
            Indices of channels that will not be masked.
        channel_consistent_masking (bool, *optional*, defaults to `False`):
            When true, masking will be same across all channels of a timeseries. Otherwise, masking positions will vary
            across channels.
        mask_value (int, *optional*, defaults to 0):
            Define the value of masked patches for pretraining.

    Returns:
        `tuple(torch.Tensor)`: inputs_mask, masked input, same shape as input Tensor and mask tensor of shape [bs x c x
        n]
    r   r   zMask ratio z has to be between 0 and 1.deviceNr   r   )r   index.r   )rS   r`   r   r,   r=   ZrandrepeatZonesZargsortZgatherrK   masked_fillr   )r&   r   r   r   r   r   num_channelssequence_lengthnum_featuresr   Zlen_keepnoisemaskZids_shuffleZids_restoreinputs_maskr#   r#   r$   random_masking  s&    
r   r&   num_forecast_mask_patchesr   r   c                 C   s  t |tr|g}dd |D }| j\}}}}tj|||| jd}	g }
d}t|}t||D ]P\}}|dksr||krtd| dt|| | }|
	|||g ||7 }qZt
|
dd d	}
||k r|
d d
 ||  |
d d
< n&||kr|
d d
 ||  |
d d
< d}|
D ]4\}}}|| }d|	||dd| df< |}qt|	jd }|	| }	|	dddd|}	|durd|	dd|ddddf< | |	 |}||	d fS )a  Forecast masking that masks the last K patches where K is from the num_forecast_mask_patches.
    If num_forecast_mask_patches is a list, samples in the batch will be randomly masked by numbers defined in the list.

    Parameters:
        inputs (`torch.Tensor`):
            Input of shape `(bs, num_channels, num_patch, patch_length)`
        num_forecast_mask_patches (`list`):
            Number of patches to be masked at the end of each batch sample. e.g. 4 or [3, 5].
        unmasked_channel_indices (`list`, *optional*):
            Indices of channels that are not masked.
        mask_value (`int`, *optional*, defaults to 0):
            Values in the masked patches will be filled by `mask_value`.

    Returns:
        `tuple(torch.Tensor)`: inputs_mask, masked input, same shape as inputs Tensor and Mask tensor of shape `(bs,
        num_channels , num_patch)` or `(bs, tsg1, tsg2, num_channels, num_patch)`
    c                 S   s   g | ]}d qS )r   r#   r   r#   r#   r$   r   7  r   z$forecast_masking.<locals>.<listcomp>r   r   znum_forecast_mask_patches z6 should be greater than 0 and less than total patches.c                 S   s   | d S Nr9   r#   )xr#   r#   r$   <lambda>I  r   z"forecast_masking.<locals>.<lambda>)rw   r9   r   r   Nr   )r   r,   r`   r=   rD   r   sumziprS   r   sortedZrandpermrK   r   r   r   )r&   r   r   r   Zforecast_mask_ratiosr   r   r   r   r   Zt_listtotal_lengthtotal_ratior   ratioZtemp_lenZbatch1Z	patch_lenr   Zbatch2permr   r#   r#   r$   forecast_masking  sB    




r   c                       s6   e Zd ZdZed fddZejdddZ  Z	S )PatchTSMixerPatchifyz
    A class to patchify the time series sequence into different patches

    Returns:
        `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
    r/   c                    s   t    |j| _|j| _|j| _| j| jkrHtd| j d| j dt| j| j| j | j d | _| j| j| jd   }| j| | _	d S )NzSequence length (z+) has to be greater than the patch length ()r   )
r   r   Zcontext_lengthr   r   patch_striderS   r   rE   sequence_start)r    r0   Znew_sequence_lengthr!   r#   r$   r   j  s    
 zPatchTSMixerPatchify.__init__)r   c                 C   sp   |j d }|| jkr,td| d| j d|dd| jdddf }|jd| j| jd}|dd }|S )a!  
        Parameters:
            past_values (`torch.Tensor` of shape `(batch_size, sequence_length, num_channels)`, *required*):
                Input for patchification

        Returns:
            `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
        r   zInput sequence length (z%) doesn't match model configuration (r   N)	dimensionr   stepr   )	r`   r   rS   r  Zunfoldr   r  r:   r   )r    r   r   r;   r#   r#   r$   r'   {  s    	

zPatchTSMixerPatchify.forwardr<   r#   r#   r!   r$   r   b  s   r   c                       s6   e Zd ZdZed fddZejdddZ  Z	S )PatchTSMixerMaskinga  
    Class to perform random or forecast masking.

    Parameters:
        config (`PatchTSMixerConfig`): model config
    Returns:
        x_mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
            Masked patched input
        mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
            Bool tensor indicating True on masked points
    r/   c                    sT   t    |j| _|j| _|j| _|j| _|j| _|j| _| jd urPt| j| _d S r%   )	r   r   random_mask_ratior   	mask_typer   r   r   r   r7   r!   r#   r$   r     s    

zPatchTSMixerMasking.__init__rU   c                 C   sr   | j dkr*t|| j| j| j| jd\}}n8| j dkrPt|| j| j| jd\}}ntd| j  d|	 }||fS )a  
        Parameters:
            patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
                Patch input

        Return:
            masked_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
                Masked patched input
            mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
                Bool tensor indicating True on masked points

        rG   r   r   r   zInvalid mask type .)
r  r   r  r   r   r   r   r   rS   r   )r    rV   Zmasked_inputr   r#   r#   r$   r'     s$    

zPatchTSMixerMasking.forwardr<   r#   r#   r!   r$   r    s   r  c                       sL   e Zd ZdZed fddZejejeejejejf dddZ	  Z
S )PatchTSMixerStdScalerz
    Standardize features by calculating the mean and scaling along the first dimension, and then normalizes it by
    subtracting from the mean and dividing by the standard deviation.
    r/   c                    sP   t    t|dr|jnd| _t|dr0|jnd| _t|drF|jnd| _d S )Nscaling_dimr   keepdimTminimum_scalegh㈵>)r   r   hasattrr  r   r  r  r7   r!   r#   r$   r     s    
zPatchTSMixerStdScaler.__init__r   observed_indicatorrF   c                 C   sz   |j | j| jd}|d}|| j | j| jd| }|| | d j | j| jd| }t|| j }|| | ||fS )C  
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                input for Batch norm calculation
            observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Calculating the scale on the observed indicator.
        Returns:
            tuple of `torch.Tensor` of shapes
                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
                `(batch_size, 1, num_input_channels)`)
        r  r   r9   )r   r   r  Z	clamp_minr=   sqrtr  )r    r   r  denominatorlocZvariancescaler#   r#   r$   r'     s    
"zPatchTSMixerStdScaler.forwardr(   r)   r*   r+   r   r   r=   r>   r   r'   r-   r#   r#   r!   r$   r
    s
   r
  c                       sL   e Zd ZdZed fddZejejeejejejf dddZ	  Z
S )PatchTSMixerMeanScalerz
    Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data
    accordingly.
    r/   c                    sf   t    t|dr|jnd| _t|dr0|jnd| _t|drF|jnd| _t|dr\|jnd | _d S )Nr  r   r  Tr  绽|=default_scale)r   r   r  r  r   r  r  r  r7   r!   r#   r$   r     s
    
zPatchTSMixerMeanScaler.__init__r  c           
      C   s   ||   j| jdd}|j| jdd}|tj|dd }| jdu rt|jdd}tj|ddd}t|| }n| jt| }t|dk||}tj|| j	d}|| }	| j
s|j| jd}|	t||fS )r  Tr  r   minNr   r   )absr   r   r=   clampr  Zsqueeze	ones_likewherer  r  
zeros_like)
r    r   r  Zts_sumZnum_observedr  Z	batch_sumZbatch_observationsr  Zscaled_datar#   r#   r$   r'     s    
zPatchTSMixerMeanScaler.forwardr  r#   r#   r!   r$   r    s
   r  c                       sR   e Zd ZdZed fddZd	ejeej e	ejejejf dddZ
  ZS )
PatchTSMixerNOPScalerz|
    Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data.
    r/   c                    s:   t    t|dr|jnd| _t|dr0|jnd| _d S )Nr  r   r  T)r   r   r  r  r   r  r7   r!   r#   r$   r   0  s    
zPatchTSMixerNOPScaler.__init__Nr  c                 C   sB   t j|ddj| j| jd}t j|ddj| j| jd}|||fS )a  
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                input for Batch norm calculation
        Returns:
            tuple of `torch.Tensor` of shapes
                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
                `(batch_size, 1, num_input_channels)`)
        FrH   )r   r  )r=   r  rQ   r   r  r!  )r    r   r  r  r  r#   r#   r$   r'   5  s    zPatchTSMixerNOPScaler.forward)N)r(   r)   r*   r+   r   r   r=   r>   r   r   r'   r-   r#   r#   r!   r$   r"  +  s    r"  zS
    Base class for `PatchTSMixerEncoderOutput`, with potential hidden states.
    )Zcustom_introc                   @   s:   e Zd ZU dZdZeej ed< dZ	ee
ej  ed< dS )PatchTSMixerEncoderOutputa-  
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
        Hidden-state at the output of the last layer of the model.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Hidden-states of the model at the output of each layer.
    Nlast_hidden_stater   )r(   r)   r*   r+   r$  r   r=   FloatTensorr   r   r   r#   r#   r#   r$   r#  F  s   
r#  c                       sR   e Zd ZdZed fddZed
eje	e
 e	e
 eeef ddd	Z  ZS )PatchTSMixerEncoderz
    Encoder for PatchTSMixer which inputs patched time-series and outputs patched embeddings.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    r/   c                    s^   t  | |j| _t|j|j| _|jr:t	|d| _
nd | _
t|d| _|jrZ|   d S )Nr/   )r   r   use_return_dictr   r   r   r4   patcherr@   r?   positional_encoderr   mlp_mixer_encoder	post_initr7   r!   r#   r$   r   a  s    zPatchTSMixerEncoder.__init__FN)r   r   return_dictrF   c                 C   sh   |dur|n| j }| |}| jdur0| |}| j||d\}}|s\tdd ||fD S t||dS )a  
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a pretraining task, this denotes the input time series to
            predict the masked portion. For a forecasting task, this denotes the history/past time series values.
            Similarly, for classification or regression tasks, it denotes the appropriate context values of the
            time series.

            For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series,
            it is greater than 1.

        Returns:
            `torch.FloatTensor` of shape `(batch_size, n_vars, num_patches, d_model)`
        Nr   c                 s   s   | ]
}|V  qd S r%   r#   r   vr#   r#   r$   r     s   z.PatchTSMixerEncoder.forward.<locals>.<genexpr>)r$  r   )r'  r(  r)  r*  r   r#  )r    r   r   r,  Zpatchesr$  r   r#   r#   r$   r'   q  s    


zPatchTSMixerEncoder.forward)FN)r(   r)   r*   r+   r   r   r   r=   r>   r   r   r   r   r#  r'   r-   r#   r#   r!   r$   r&  X  s     
r&  zG
    Base class for model's outputs, with potential hidden states.
    c                   @   s   e Zd ZU dZdZeej ed< dZ	ee
ej  ed< dZeej ed< dZeej ed< dZeej ed< dZeej ed< dS )	PatchTSMixerModelOutputa  
    last_hidden_state (`torch.FloatTensor`  of shape `(batch_size, num_channels, num_patches, d_model)`):
        Hidden-state at the output of the last layer of the model.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Hidden-states of the model at the output of each layer.
    patch_input (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`):
        Patched input data to the model.
    mask (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches)`, *optional*):
        Bool Tensor indicating True in masked patches and False otherwise.
    loc (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
        Gives the mean of the context window per channel. Used for revin denorm outside the model, if revin
        enabled.
    scale (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
        Gives the std dev of the context window per channel. Used for revin denorm outside the model, if revin
        enabled.
    Nr$  r   rV   r   r  r  )r(   r)   r*   r+   r$  r   r=   r%  r   r   r   rV   r   r  r  r#   r#   r#   r$   r/    s   
r/  z=
    The PatchTSMixer Model for time-series forecasting.
    c                       sR   e Zd Zd	eed fddZed
eje	ej e	e e	e e
dddZ  ZS )PatchTSMixerModelF)r0   
mask_inputc                    s   t  | |j| _t|| _t|| _|du r<t|| _nd| _|j	dkrXt
|| _n*|j	dksl|j	du rxt|| _n
t|| _|jr|   dS )z
        mask_input (bool, *optional*, defaults to `False`):
            Whether to mask the input using the [`PatchTSMixerMasking`] module.
        TNrQ   rR   )r   r   r'  r&  encoderr   patchingr  maskingrz   r  scalerr
  r"  r+  )r    r0   r1  r!   r#   r$   r     s    



zPatchTSMixerModel.__init__N)r   observed_maskr   r,  rF   c                 C   s   |dur|n| j }d}|du r(t|}| ||\}}}| |}	|	}
| jdur`| |	\}
}| j|
||d}t|trt	| }|stdd |j
|j|	|||fD S t|j
|j|	|||dS )a  
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a pretraining task, this denotes the input time series to predict
            the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
            for classification or regression tasks, it denotes the appropriate context values of the time series.

            For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
            greater than 1.
        observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
            Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
            in `[0, 1]`:
            - 1 for values that are **observed**,
            - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
        Nr   r,  c                 s   s   | ]
}|V  qd S r%   r#   r-  r#   r#   r$   r     s   z,PatchTSMixerModel.forward.<locals>.<genexpr>)r$  r   rV   r   r  r  )r'  r=   r  r5  r3  r4  r2  r   r   r#  r$  r   r/  )r    r   r6  r   r,  r   Zscaled_past_valuesr  r  Z	patched_xZ	enc_inputZencoder_outputr#   r#   r$   r'     sD    



zPatchTSMixerModel.forward)F)NFN)r(   r)   r*   r   r   r   r   r=   r>   r   r/  r'   r-   r#   r#   r!   r$   r0    s      r0  z>
    Output type of [`PatchTSMixerForPreTrainingOutput`].
    c                   @   s^   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZeeej  ed< dS ) PatchTSMixerForPreTrainingOutputa@  
    loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
        Total loss
    prediction_outputs (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, patch_length)`):
        Prediction output from the pretrain head.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
        Backbone embeddings before passing through the head.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Hidden-states of the model at the output of each layer.
    Nlossprediction_outputsr$  r   r(   r)   r*   r+   r9  r   r=   r%  r   r:  r$  r   r   r#   r#   r#   r$   r8  #  s
   
r8  z.
    `PatchTSMixer` for mask pretraining.
    c                	       sP   e Zd Zed fddZed
ejeej ee	 e	ee	 e
ddd	Z  ZS )PatchTSMixerForPretrainingr/   c                    sH   t  | t|dd| _t|d| _|j| _|j| _|jrD|   d S )NT)r1  r/   )	r   r   r0  r   r   headmasked_lossr'  r+  r7   r!   r#   r$   r   A  s    z#PatchTSMixerForPretraining.__init__NFT)r   r6  r   return_lossr,  rF   c           
      C   s   |dur|n| j }| jdu r,tjjdd}ntjjdd}| j||||d}t|tr^t| }| 	|j
}|du r|||j}	nd}	| jdu r|	dur|	jdd|j  |j d	  }	|std
d |	||j
|jfD S t|	||j
|jdS )aT  
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a pretraining task, this denotes the input time series to predict
            the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
            for classification or regression tasks, it denotes the appropriate context values of the time series.

            For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
            greater than 1.
        observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
            Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
            in `[0, 1]`:
            - 1 for values that are **observed**,
            - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
        return_loss (`bool`,  *optional*):
            Whether to return the loss in the `forward` call.
        NTnoneZ	reductionrQ   r6  r   r,  r   r   r  c                 s   s   | ]
}|V  qd S r%   r#   r-  r#   r#   r$   r     s   z5PatchTSMixerForPretraining.forward.<locals>.<genexpr>r9  r:  r$  r   )r'  r>  r=   r   MSELossr   r   r   r/  r=  r$  rV   rQ   r   r   r   r8  )
r    r   r6  r   r?  r,  r9  model_outputZx_hatloss_valr#   r#   r$   r'   L  s@    

$
z"PatchTSMixerForPretraining.forward)NFTN)r(   r)   r*   r   r   r   r=   r>   r   r   r8  r'   r-   r#   r#   r!   r$   r<  ;  s       r<  z=
    Output type of [`PatchTSMixerForPredictionOutput`].
    c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZeeej  ed< dZeej ed< dZeej ed< dS )	PatchTSMixerForPredictionOutputaD  
    loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
        Total loss.
    prediction_outputs (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_input_channels)`):
        Prediction output from the forecast head.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
        Backbone embeddings before passing through the head.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    loc (`torch.FloatTensor`, *optional* of shape `(batch_size, 1, num_input_channels)`):
        Input mean
    scale (`torch.FloatTensor`, *optional* of shape `(batch_size, 1, num_input_channels)`):
        Input std dev
    Nr9  r:  r$  r   r  r  )r(   r)   r*   r+   r9  r   r=   r%  r   r:  r$  r   r   r  r  r#   r#   r#   r$   rG    s   
rG  z
    Base class for time series model's predictions outputs that contains the sampled values from the chosen
    distribution.
    c                   @   s$   e Zd ZU dZdZeej ed< dS )"SamplePatchTSMixerPredictionOutput
    sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length, number_channels)`):
        Sampled values from the chosen distribution.
    N	sequences	r(   r)   r*   r+   rJ  r   r=   r%  r   r#   r#   r#   r$   rH    s   
rH  c                   @   s$   e Zd ZU dZdZeej ed< dS )"SamplePatchTSMixerRegressionOutputrI  NrJ  rK  r#   r#   r#   r$   rL    s   
rL  )inputtargetrF   c                 C   s   |  | S )zc
    Computes the negative log likelihood loss from input distribution with respect to target.
    )Zlog_prob)rM  rN  r#   r#   r$   nll  s    rO  )input_tensorweightsrF   c                 C   sr   |durbt |dk| | t | }t j|r8|j|dn| dd}|rV|j|dn| | S | j|dS dS )aj  
    Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero,
    meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.

    Args:
        input_tensor (`torch.FloatTensor`):
            Input tensor, of which the average must be computed.
        weights (`torch.FloatTensor`, *optional*):
            Weights tensor, of the same shape as `input_tensor`.
        dim (`int`, *optional*):
            The dim along which to average `input_tensor`.

    Returns:
        `torch.FloatTensor`: The tensor with values averaged along the specified `dim`.
    Nr   r   r   r  )r=   r   r!  r  r   rQ   )rP  rQ  r   Zweighted_tensorZsum_weightsr#   r#   r$   weighted_average  s
    "rR  c                
       s   e Zd ZdZed fddZedeje	ej e	ej e	e
 e
e	e
 edd	d
Ze deje	ej edddZ  ZS )PatchTSMixerForPredictionz
    `PatchTSMixer` for forecasting application.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.

    Returns:
        `None`.
    r/   c                    s   t  | |j| _|j| _|j| _|j| _|jdkr>d | _nF|j}tt	t
d}||jd }|d urt||d| _ntd|j t|| _t|| jd| _|jr|   d S )NmseZ	student_tnormalnegative_binomialr   Unknown distribution output r0   r   )r   r   r9  r'  r   num_parallel_samplesr   r   r   r   r   getrS   r0  r   r   r=  r+  )r    r0   r   distribution_output_mapoutput_classr!   r#   r$   r     s.    

z"PatchTSMixerForPrediction.__init__NFT)r   r6  future_valuesr   r?  r,  rF   c                 C   s  | j dkrtjdd}n| j dkr(t}ntd|dur<|n| j}| j||||d}t|trft	| }| 
|j}	d}
| jdur.| jr| jj|	|jd| jf |jd| jf d	}|dur|d
u r|||d| jf }
t|
}
nL|	|jd| jf  |jd| jf  }	|dur|d
u r||	|d| jf }
nt| jrt| jj|	|j|jd	}|dur|d
u r|||}
t|
}
n.|	|j |j }	|dur|d
u r||	|}
| jdur|jd| jf }|jd| jf }n|j}|j}|stdd |
|	|j|j||fD S t|
|	|j|j||dS )a  
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a pretraining task, this denotes the input time series to predict
            the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
            for classification or regression tasks, it denotes the appropriate context values of the time series.

            For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
            greater than 1.
        observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
            Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
            in `[0, 1]`:
            - 1 for values that are **observed**,
            - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
        future_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,:
            `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*):
            Target values of the time series, that serve as labels for the model. The `future_values` is what the
            Transformer needs during training to learn to output, given the `past_values`. Note that, this is NOT
            required for a pretraining task.

            For a forecasting task, the shape is be `(batch_size, target_len, num_input_channels)`. Even if we want
            to forecast only specific channels by setting the indices in `prediction_channel_indices` parameter,
            pass the target data with all channels, as channel Filtering for both prediction and target will be
            manually applied before the loss computation.
        return_loss (`bool`,  *optional*):
            Whether to return the loss in the `forward` call.
        rT  rQ   rA  rO  2Invalid loss function: Allowed values: mse and nllNrB  .r  r  Tc                 s   s   | ]
}|V  qd S r%   r#   r-  r#   r#   r$   r     s   z4PatchTSMixerForPrediction.forward.<locals>.<genexpr>)r9  r:  r$  r   r  r  )r9  r   rD  rO  rS   r'  r   r   r   r/  r=  r$  r   r   distributionr  r  rR  r   rG  )r    r   r6  r^  r   r?  r,  r9  rE  y_hatrF  ra  r  r  r#   r#   r$   r'     s    $







z!PatchTSMixerForPrediction.forward)r   r6  rF   c                    s\   | j }| |d|dd}| jj|j|j|jd  fddt|D }tj|dd}t	|d	S )
a  
        Generate sequences of sample predictions from a model with a probability distribution head.

        Args:
            past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Past values of the time series that serves as context in order to predict the future.

            observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
                in `[0, 1]`:

                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).

        Return:
            [`SamplePatchTSMixerPredictionOutput`] where the outputs `sequences` tensor will have shape `(batch_size,
            number of samples, prediction_length, num_input_channels)`.
        NF)r   r^  r6  r   r`  c                    s   g | ]}   qS r#   sampler   ra  r#   r$   r     r   z6PatchTSMixerForPrediction.generate.<locals>.<listcomp>r   r   rJ  )
rZ  r   ra  r:  r  r  r   r=   stackrH  )r    r   r6  rZ  outputssamplesr#   re  r$   generate  s    	z"PatchTSMixerForPrediction.generate)NNFTN)N)r(   r)   r*   r+   r   r   r   r=   r>   r   r   rG  r'   no_gradrH  rj  r-   r#   r#   r!   r$   rS    s0         y rS  zK
    Output type of [`PatchTSMixerForTimeSeriesClassificationOutput`].
    c                   @   s^   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZeeej  ed< dS )-PatchTSMixerForTimeSeriesClassificationOutputaP  
    loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
        Total loss.
    prediction_outputs (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
        Prediction output from the classification head.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
        Backbone embeddings before passing through the head.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    Nr9  r:  r$  r   r;  r#   r#   r#   r$   rl    s
   
rl  c                	       sT   e Zd ZdZed fddZedeje	ej e	e
 e
e	e
 edd	d
Z  ZS )'PatchTSMixerForTimeSeriesClassificationz
    `PatchTSMixer` for classification application.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.

    Returns:
        `None`.
    r/   c                    s`   t  | t|| _t|d| _|j| _|jdv rHt|j	|j
d| _nd | _|jr\|   d S )Nr/   rR   rQ   Tr4   rE   )r   r   r0  r   r   r=  r'  rz   InjectScalerStatistics4Dr4   rE   inject_scaler+  r7   r!   r#   r$   r     s    

z0PatchTSMixerForTimeSeriesClassification.__init__NFTr   target_valuesr   r?  r,  rF   c           
      C   s   t j }|dur|n| j}| j|||d}t|tr>t| }| jdur`| j|j	|j
|jd|_	| |j	}|dur|du r|||}	nd}	|stdd |	||j	|jfD S t|	||j	|jdS )a  
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a pretraining task, this denotes the input time series to predict
            the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
            for classification or regression tasks, it denotes the appropriate context values of the time series.

            For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
            greater than 1.
        target_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
            `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*):
            Target
            values of the time series, that serve as labels for the model. The `target_values` is what the
            Transformer needs during training to learn to output, given the `past_values`. Note that, this is NOT
            required for a pretraining task.

            For a forecasting task, the shape is be `(batch_size, target_len, num_input_channels)`. Even if we want
            to forecast only specific channels by setting the indices in `prediction_channel_indices` parameter,
            pass the target data with all channels, as channel Filtering for both prediction and target will be
            manually applied before the loss computation.

            For a classification task, it has a shape of `(batch_size,)`.

            For a regression task, it has a shape of `(batch_size, num_targets)`.
        return_loss (`bool`, *optional*):
            Whether to return the loss in the `forward` call.
        Nr7  r`  Tc                 s   s   | ]
}|V  qd S r%   r#   r-  r#   r#   r$   r   >  s   zBPatchTSMixerForTimeSeriesClassification.forward.<locals>.<genexpr>rC  )r=   r   ZCrossEntropyLossr'  r   r   r   r/  rq  r$  r  r  r=  r   rl  )
r    r   rs  r   r?  r,  r9  rE  rb  rF  r#   r#   r$   r'     sB    $



z/PatchTSMixerForTimeSeriesClassification.forward)NFTN)r(   r)   r*   r+   r   r   r   r=   r>   r   r   rl  r'   r-   r#   r#   r!   r$   rm    s       rm  z=
    Output type of [`PatchTSMixerForRegressionOutput`].
    c                   @   s^   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZeeej  ed< dS )PatchTSMixerForRegressionOutputaM  
    loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
        Total loss.
    regression_outputs (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
        Prediction output from the regression head.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
        Backbone embeddings before passing through the head.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    Nr9  regression_outputsr$  r   )r(   r)   r*   r+   r9  r   r=   r%  r   ru  r$  r   r   r#   r#   r#   r$   rt  P  s
   
rt  c                       s@   e Zd Zdeeed fddZejejejdddZ  ZS )	rp  r9   )r4   rE   	expansionc                    s`   t    t|d || | _t|| || _tdd| | _td| d| _|| _d S r   )	r   r   r   r   inverse_trans_expansioninverse_trans_compressionmap_scale_expansionmap_scale_compressionrE   )r    r4   rE   rv  r!   r#   r$   r   i  s    
z!InjectScalerStatistics4D.__init__)r&   r  r  c                 C   s   | dd}|d}|dd| jd}| dd}|d}|dd| jd}tj||gdd}| |}| |}tj||gdd}| |}| 	|}|S )a  
        Args:
            inputs (`torch.Tensor` of shape `(batch_size, num_input_channels, num_patch, d_model)`)
            loc (`torch.Tensor` of shape `(batch_size, 1, num_input_channels)`)
            scale (`torch.Tensor` of shape `(batch_size, 1, num_input_channels)`)
        Returns:
            `torch.Tensor` of shape `(batch_size, num_input_channels, num_patch, d_model)`
        r   r   r   r   )
r:   rK   r   rE   r=   catry  rz  rw  rx  )r    r&   r  r  rQ   stdevZconcat_statsr#   r#   r$   r'   r  s    






z InjectScalerStatistics4D.forward)r9   )	r(   r)   r*   r,   r   r=   r>   r'   r-   r#   r#   r!   r$   rp  h  s   	rp  z4
    `PatchTSMixer` for regression application.
    c                	       sj   e Zd Zed fddZedejeej ee	 e	ee	 e
ddd	Ze ejed
ddZ  ZS )PatchTSMixerForRegressionr/   c                    s   t  | t|| _|j| _|j| _|j| _|j| _|jdkrHd | _n@tt	t
d}||j}|d urx||jd| _ntd|j |jdv rt|j|jd| _nd | _t|| jd| _|jr|   d S )NrT  rU  r   rX  rn  ro  rY  )r   r   r0  r   r9  r   r'  rZ  r   r   r   r[  r   rS   rz   rp  r4   rE   rq  r   r=  r+  )r    r0   r\  r]  r!   r#   r$   r     s2    


z"PatchTSMixerForRegression.__init__NFTrr  c                    sL   j dkrtjdd}n j dkr(t}ntd|dur<|n j} j|||d}t|trdt	| } j
dur j
|j|j|jd|_ |j}|dur|d	u r jr jd
krt|dk rtd j|}	t fdd|D }||	|}
t|
}
n
|||}
nd}
|s8tdd |
||j|jfD S t|
||j|jdS )a  
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a pretraining task, this denotes the input time series to predict
            the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
            for classification or regression tasks, it denotes the appropriate context values of the time series.

            For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
            greater than 1.
        target_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
            `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*):
            Target values of the time series, that serve as labels for the model. The `target_values` is what the
            Transformer needs during training to learn to output, given the `past_values`. Note that, this is NOT
            required for a pretraining task.

            For a forecasting task, the shape is be `(batch_size, target_len, num_input_channels)`. Even if we want
            to forecast only specific channels by setting the indices in `prediction_channel_indices` parameter,
            pass the target data with all channels, as channel Filtering for both prediction and target will be
            manually applied before the loss computation.

            For a classification task, it has a shape of `(batch_size,)`.

            For a regression task, it has a shape of `(batch_size, num_targets)`.
        return_loss (`bool`, *optional*):
            Whether to return the loss in the `forward` call.
        rT  rQ   rA  rO  r_  Nr7  r`  TrW  r   zDtarget_values cannot be negative for negative_binomial distribution.c                 3   s   | ]}| d  jjV  qdS )r   N)r   r0   r   )r   itemr   r#   r$   r     r   z4PatchTSMixerForRegression.forward.<locals>.<genexpr>c                 s   s   | ]
}|V  qd S r%   r#   r-  r#   r#   r$   r   	  s   )r9  ru  r$  r   )r9  r   rD  rO  rS   r'  r   r   r   r/  rq  r$  r  r  r=  r   r=   any	Exceptionra  rR  r   rt  )r    r   rs  r   r?  r,  r9  rE  rb  ra  rF  r#   r   r$   r'     sX    #






z!PatchTSMixerForRegression.forward)r   rF   c                    s^   | j }| |ddd}| j|j  fddt|D }tj|ddd|| jj	}t
|d	S )
a
  
        Generate sequences of sample predictions from a model with a probability distribution head.

        Args:
            past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Past values of the time series that serves as context in order to predict the target values.

        Return:
            [`SamplePatchTSMixerRegressionOutput`] where the outputs `sequences` tensor will have shape `(batch_size,
            number of samples, num_targets)`.
        NF)r   rs  r   c                    s   g | ]}   qS r#   rc  r   re  r#   r$   r   8  s   z6PatchTSMixerForRegression.generate.<locals>.<listcomp>r   r   r   rf  )rZ  r   ra  ru  r   r=   rg  r   r0   r   rL  )r    r   rZ  rh  ri  r#   re  r$   rj    s    
z"PatchTSMixerForRegression.generate)NFTN)r(   r)   r*   r   r   r   r=   r>   r   r   rt  r'   rk  rL  rj  r-   r#   r#   r!   r$   r}    s$   '    \r}  )r   r0  r<  rS  rm  r}  )Nrt   N)NFr   )Nr   )NN)Rr+   rM   dataclassesr   typingr   r   r   r=   Ztorch.nnr   Ztransformers.modeling_utilsr   Ztransformers.utilsr   Zmodeling_flash_attention_utilsr	   Zmodeling_utilsr
   Zprocessing_utilsr   Ztime_series_utilsr   r   r   utilsr   r   Zconfiguration_patchtsmixerr   Z
get_loggerr(   loggerModuler   r.   r?   rY   ra   rk   r>   r   r   r   r   r   r   r   r   r   r   r   listr   r,   r   r   r   r  r
  r  r"  r#  r&  r/  r0  r8  r<  rG  rH  rL  distributionsDistributionrO  rR  rS  rl  rm  rt  rp  r}  __all__r#   r#   r#   r$   <module>   s   
'17   XF-&)8G"   >  
E1=$7EaT	
 Xn( -