a
    hJ                 
   @   s  d Z ddlZddlmZ ddlmZmZmZ ddlZddlm	Z	 ddl
mZ ddlmZ dd	lmZ dd
lmZmZ ddlmZ ddlmZmZmZ ddlmZmZmZ ddlmZ ee Z!dfe	j"ej#ej#ej#eej# ee$ e$eej# dddZ%G dd de	j"Z&G dd de	j"Z'dgej#e$ee( e)e*dddZ+dhej#ee(e*f ee( e*dddZ,G dd  d e	j"Z-G d!d" d"e	j"Z.G d#d$ d$e	j"Z/eG d%d& d&eZ0G d'd( d(e	j"Z1G d)d* d*e	j"Z2G d+d, d,e0Z3eed-d.G d/d0 d0eZ4eed1d.G d2d3 d3eZ5eed4d.G d5d6 d6eZ6eed7d.G d8d9 d9eZ7eed:d.G d;d< d<eZ8eed=d.G d>d? d?eZ9ej:j;ej#ej#d@dAdBZ<diej#eej# ej#dCdDdEZ=G dFdG dGe	j"Z>G dHdI dIe	j"Z?G dJdK dKe	j"Z@G dLdM dMe	j"ZAeG dNdO dOe0ZBG dPdQ dQe	j"ZCedRd.G dSdT dTe0ZDG dUdV dVe	j"ZEedWd.G dXdY dYe0ZFedZd.G d[d\ d\e	j"ZGed]d.G d^d_ d_e0ZHG d`da dae	j"ZIedbd.G dcdd dde0ZJg deZKdS )jzPyTorch PatchTST model.    N)	dataclass)CallableOptionalUnion)nn   )ACT2CLS)FlashAttentionKwargs)BaseModelOutput)ALL_ATTENTION_FUNCTIONSPreTrainedModel)Unpack)NegativeBinomialOutputNormalOutputStudentTOutput)ModelOutputauto_docstringlogging   )PatchTSTConfig        )modulequerykeyvalueattention_maskscalingdropout	head_maskc                 K   s   |d u r| dd }t||dd| }	|d ur>|	| }	tjj|	dd}	|d urj|	|dddd }	tjj|	|| j	d}	t|	|}
|
dd
 }
|
|	fS )N         r   dimr   )ptraining)sizetorchmatmul	transposer   Z
functionalZsoftmaxviewr   r%   
contiguous)r   r   r   r   r   r   r   r   kwargsattn_weightsattn_output r/   j/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/models/patchtst/modeling_patchtst.pyeager_attention_forward&   s    r1   c                       s   e Zd ZdZdeeeeeeee d fddZ	de
jee
j ee
j ee
j ee ee ee
jee
j eee
j  f d	d
dZ  ZS )PatchTSTAttentionz=Multi-headed attention from 'Attention Is All You Need' paperr   FTN)	embed_dim	num_headsr   
is_decoderbias	is_causalconfigc                    s   t    || _|| _|| _|| | _|| _| j| | jkrTtd| j d| d| jd | _|| _	|| _
tj|||d| _tj|||d| _tj|||d| _tj|||d| _d S )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: ).r    r6   )super__init__r3   r4   r   head_dimr8   
ValueErrorr   r5   r7   r   Lineark_projv_projq_projout_proj)selfr3   r4   r   r5   r6   r7   r8   	__class__r/   r0   r<   H   s&    



zPatchTSTAttention.__init__)hidden_stateskey_value_statesr   layer_head_maskoutput_attentionsr,   returnc                 K   s  |du}|j dd \}}	|r(|j d n|	}
||	d| jf}||
d| jf}| |j| dd}|rh|n|}| |j| dd}| |j| dd}t}| jj	dkrt
| jj	 }|| ||||f| jsdn| j| j||d|\}}|||	d }| |}||dfS )z#Input shape: Batch x Time x ChannelNr   r   r!   eagerr   )r   r   rJ   r   )shaper=   rB   r*   r)   r@   rA   r1   r8   Z_attn_implementationr   r%   r   r   reshaper+   rC   )rD   rG   rH   r   rI   rJ   r,   Zis_cross_attentionZbszZtgt_lenZsrc_lenZq_input_shapeZkv_input_shapeZquery_statesZcurrent_statesZ
key_statesZvalue_statesZattention_interfacer.   r-   r/   r/   r0   forwardg   s:    


zPatchTSTAttention.forward)r   FTFN)NNNF)__name__
__module____qualname____doc__intfloatboolr   r   r<   r'   Tensorr   r	   tuplerO   __classcell__r/   r/   rE   r0   r2   E   s8        "    r2   c                       s6   e Zd ZdZed fddZejdddZ  Z	S )PatchTSTBatchNormzP
    Compute batch normalization over the sequence length (time) dimension.
    r8   c                    s"   t    tj|j|jd| _d S )Neps)r;   r<   r   ZBatchNorm1dd_modelnorm_eps	batchnormrD   r8   rE   r/   r0   r<      s    
zPatchTSTBatchNorm.__init__)inputsc                 C   s"   | dd}| |}| ddS )a  
        Parameters:
            inputs (`torch.Tensor` of shape `(batch_size, sequence_length, d_model)`):
                input for Batch norm calculation
        Returns:
            `torch.Tensor` of shape `(batch_size, sequence_length, d_model)`
        r   r!   )r)   r`   )rD   rb   outputr/   r/   r0   rO      s    
zPatchTSTBatchNorm.forward
rP   rQ   rR   rS   r   r<   r'   rW   rO   rY   r/   r/   rE   r0   rZ      s   rZ   Frb   
mask_ratiounmasked_channel_indiceschannel_consistent_masking
mask_valuec                 C   s,  |dk s|dkr t d| d| j\}}}}| j}	t|d|  }
|rjtj|d||	d}|d|d}ntj||||	d}tj||||	d}d|ddddd|
f< tj|dd}tj|dd}tj	|d|d	}|
dddd|}|durd|dd|ddddf< | | |}||d
 fS )a  random_masking: Mask the input considering the control variables.

    Args:
        inputs (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, num_features)`):
            The input tensor to mask.
        mask_ratio (`float`):
            Masking ratio applied to mask the input data during random pretraining. It is the number between 0 and 1.
        unmasked_channel_indices (list, *optional*):
            Indices of channels that will not be masked.
        channel_consistent_masking (bool, *optional*, defaults to `False`):
            When true, masking will be same across all channels of a timeseries. Otherwise, masking positions will vary
            across channels.
        mask_value (int, *optional*, defaults to 0):
            Define the value of masked patches for pretraining.

    Returns:
        `tuple(torch.Tensor)`: inputs_mask, masked input, same shape as input Tensor and mask tensor of shape [bs x c x
        n]
    r   r   zMask ratio z has to be between 0 and 1.deviceNr   r"   )r#   index.r   )r>   rM   rk   rT   r'   ZrandrepeatZonesZargsortZgather	unsqueezemasked_fillrV   )rb   rf   rg   rh   ri   
batch_sizenum_channelssequence_lengthnum_featuresrk   Zlen_keepnoisemaskZids_shuffleZids_restoreinputs_maskr/   r/   r0   random_masking   s&    
rx   rb   num_forecast_mask_patchesrg   ri   c                 C   s  t |tr|g}dd |D }| j\}}}}tj|||| jd}	g }
d}t|}t||D ]P\}}|dksr||krtd| dt|| | }|
	|||g ||7 }qZt
|
dd d	}
||k r|
d d
 ||  |
d d
< n&||kr|
d d
 ||  |
d d
< d}|
D ]4\}}}|| }d|	||dd| df< |}qt|	jd }|	| }	|	dddd|}	|durd|	dd|ddddf< | |	 |}||	d fS )a  Forecast masking that masks the last K patches where K is from the num_forecast_mask_patches.
    If num_forecast_mask_patches is a list, samples in the batch will be randomly masked by numbers defined in the list.

    Parameters:
        inputs (`torch.Tensor`):
            Input of shape `(bs, num_channels, num_patch, patch_length)`
        num_forecast_mask_patches (`list`):
            Number of patches to be masked at the end of each batch sample. e.g. 4 or [3, 5].
        unmasked_channel_indices (`list`, *optional*):
            Indices of channels that are not masked.
        mask_value (`int`, *optional*, defaults to 0):
            Values in the masked patches will be filled by `mask_value`.

    Returns:
        `tuple(torch.Tensor)`: inputs_mask, masked input, same shape as inputs Tensor and Mask tensor of shape `(bs,
        num_channels , num_patch)` or `(bs, tsg1, tsg2, num_channels, num_patch)`
    c                 S   s   g | ]}d qS )r   r/   .0_r/   r/   r0   
<listcomp>      z$forecast_masking.<locals>.<listcomp>rj   r   znum_forecast_mask_patches z6 should be greater than 0 and less than total patches.c                 S   s   | d S )Nr!   r/   )xr/   r/   r0   <lambda>  r   z"forecast_masking.<locals>.<lambda>)r   r!   r   r   Nrm   )
isinstancerT   rM   r'   zerosrk   sumzipr>   appendsortedZrandpermro   rn   rp   rV   )rb   rz   rg   ri   Zforecast_mask_ratiosrq   rr   rs   rt   rv   Zt_listtotal_lengthtotal_ratiopatch_lengthratioZtemp_lenZbatch1Z	patch_lenr}   Zbatch2permrw   r/   r/   r0   forecast_masking   sB    




r   c                       s6   e Zd ZdZed fddZejdddZ  Z	S )PatchTSTPatchifyz
    A class to patchify the time series sequence into different patches

    Returns:
        `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
    r[   c                    s   t    |j| _|j| _|j| _| j| jkrHtd| j d| j dt| j| j| j | j d | _| j| j| jd   }| j| | _	d S )NzSequence length (z+) has to be greater than the patch length ()r   )
r;   r<   context_lengthrs   r   patch_strider>   maxnum_patchessequence_start)rD   r8   Znew_sequence_lengthrE   r/   r0   r<   9  s    
 zPatchTSTPatchify.__init__)past_valuesc                 C   sp   |j d }|| jkr,td| d| j d|dd| jdddf }|jd| j| jd}|dd }|S )a!  
        Parameters:
            past_values (`torch.Tensor` of shape `(batch_size, sequence_length, num_channels)`, *required*):
                Input for patchification

        Returns:
            `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
        zInput sequence length (z%) doesn't match model configuration (r9   N)	dimensionr&   step)	rM   rs   r>   r   Zunfoldr   r   r)   r+   )rD   r   rs   rc   r/   r/   r0   rO   J  s    	

zPatchTSTPatchify.forwardrd   r/   r/   rE   r0   r   1  s   r   c                       s6   e Zd ZdZed fddZejdddZ  Z	S )PatchTSTMaskinga  
    Class to perform random or forecast masking.

    Parameters:
        config (`PatchTSTConfig`): model config
    Returns:
        x_mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
            Masked patched input
        mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
            Bool tensor indicating True on masked points
    r[   c                    sT   t    |j| _|j| _|j| _|j| _|j| _|j| _| jd urPt| j| _d S N)	r;   r<   random_mask_ratiorh   	mask_typerz   rg   ri   r   ra   rE   r/   r0   r<   n  s    

zPatchTSTMasking.__init__patch_inputc                 C   sr   | j dkr*t|| j| j| j| jd\}}n8| j dkrPt|| j| j| jd\}}ntd| j  d|	 }||fS )a  
        Parameters:
            patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
                Patch input

        Return:
            masked_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
                Masked patched input
            mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
                Bool tensor indicating True on masked points

        randomre   Zforecastry   zInvalid mask type .)
r   rx   r   rg   rh   ri   r   rz   r>   rV   )rD   r   Zmasked_inputrv   r/   r/   r0   rO   y  s$    

zPatchTSTMasking.forwardrd   r/   r/   rE   r0   r   a  s   r   c                       s>   e Zd ZdZed fddZd	ejee	 dddZ
  ZS )
PatchTSTEncoderLayerz 
    PatchTST encoder layer
    r[   c              
      s  t    |j| _t|j|j|j|d| _|jdkr@t	
|jnt	 | _|jdkr`t|| _n0|jdkrt	j|j|jd| _nt|j d| jr|jdkrt	
|jnt	 | _|jdkrt|| _n0|jdkrt	j|j|jd| _nt|j dt	t	j|j|j|jdt|j  |jdkr6t	
|jnt	 t	j|j|j|jd| _|jdkrnt	
|jnt	 | _|jdkrt|| _n2|jdkrt	j|j|jd| _nt|j d|j| _d S )N)r3   r4   r   r8   r   r`   Z	layernormr\   z$ is not a supported norm layer type.r:   ) r;   r<   channel_attentionr2   r^   Znum_attention_headsZattention_dropout	self_attnZpath_dropoutr   DropoutIdentitydropout_path1Z	norm_typerZ   norm_sublayer1	LayerNormr_   r>   dropout_path2norm_sublayer2Z
Sequentialr?   Zffn_dimr6   r   Zactivation_functionZ
ff_dropoutffdropout_path3norm_sublayer3pre_normra   rE   r/   r0   r<     sD    
 

 


"zPatchTSTEncoderLayer.__init__Nhidden_staterJ   c                 C   s  |j \}}}}||| ||}| jrP| j| ||d\}}}	|| | }n(| j||d\}}}	| || | }|||||}| jr*|dd	 }||| ||}| jr| j| 
||d\}}
}	|| | }n(| j||d\}}
}	| 
|| | }|||||}|dd	 }||| ||}| jr`|| | | | }n| || | | }|||||}|f}|r|| jr||
fn|f7 }|S )a  
        Parameters:
            hidden_state (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, d_model)`, *required*):
                Past values of the time series
            output_attentions (`bool`, *optional*):
                Whether or not to return the output attention of all layers
        Return:
            `torch.Tensor` of shape `(batch_size, num_channels, sequence_length, d_model)`

        )rG   rJ   r!   r   )rM   r*   r   r   r   r   rN   r   r)   r+   r   r   r   r   r   )rD   r   rJ   rq   num_input_channelsrs   r^   r.   r-   r}   Zchannel_attn_weightsoutputsr/   r/   r0   rO     sF    

zPatchTSTEncoderLayer.forward)N)rP   rQ   rR   rS   r   r<   r'   rW   r   rV   rO   rY   r/   r/   rE   r0   r     s   2r   c                   @   s<   e Zd ZU eed< dZdZdZej	dddZ
ddd	Zd
S )PatchTSTPreTrainedModelr8   modelr   F)r   c                 C   s   t |trdt| jj| jj| jj | jj d }| jjrRtj	j
|jdd |d7 }|| j||_nt |tjr|jj  |jjd nbt |tr|jjj  |jjjd n8t |tjr|jjj
d| jjd |jdur|jj  dS )z$
        Initialize weights
        r   g{Gz?)std      ?r   )meanr   N)r   PatchTSTPositionalEncodingr   r8   r   r   r   use_cls_tokenr   initZnormal_	cls_token_init_peposition_encr   r6   dataZzero_weightZfill_rZ   r`   r?   Zinit_std)rD   r   r   r/   r/   r0   _init_weights/  s(    


z%PatchTSTPreTrainedModel._init_weightsc                 C   s   t |tr||_d S r   )r   PatchTSTEncodergradient_checkpointing)rD   r   r   r/   r/   r0   _set_gradient_checkpointingI  s    
z3PatchTSTPreTrainedModel._set_gradient_checkpointingN)F)rP   rQ   rR   r   __annotations__Zbase_model_prefixZmain_input_nameZsupports_gradient_checkpointingr   Moduler   r   r/   r/   r/   r0   r   (  s   
r   c                       s2   e Zd Zed fddZejdddZ  ZS )PatchTSTEmbeddingr[   c                    sj   t    |j| _|j| _| jr4t|j|j| _n2t	 | _t
|jD ]}| jt|j|j qHd S r   )r;   r<   r   share_embeddingr   r?   r   r^   input_embedding
ModuleListranger   )rD   r8   r}   rE   r/   r0   r<   O  s    

zPatchTSTEmbedding.__init__r   c                    sh    j d }|jkr,tdj d| djr> }n& fddt|D }tj|dd}|S )a%  
        Parameters:
            patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
                Patch input for embedding
        return:
            `torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`
        r   z&The defined number of input channels (zQ) in the config has to be the same as the number of channels in the batch input (r   c              	      s2   g | ]*}j |  d d |d d d d f qS r   )r   r|   ir   rD   r/   r0   r~   m  r   z-PatchTSTEmbedding.forward.<locals>.<listcomp>r"   )rM   r   r>   r   r   r   r'   stack)rD   r   r   Z
embeddingsr/   r   r0   rO   [  s    	


zPatchTSTEmbedding.forward	rP   rQ   rR   r   r<   r'   rW   rO   rY   r/   r/   rE   r0   r   N  s   r   c                       sP   e Zd ZdZeed fddZeeeej	dddZ
ejdd	d
Z  ZS )r   z'
    Class for positional encoding
    r8   r   c                    st   t    |j| _|j| _|jrBttddd|j| _	|d7 }| 
||| _|jdkrft|jnt | _d S )Nr   r   )r;   r<   r   r   r   	Parameterr'   r   r^   r   r   r   positional_dropoutr   r   rD   r8   r   rE   r/   r0   r<   w  s    
z#PatchTSTPositionalEncoding.__init__)r8   r   rK   c                 C   s   | j dkr$tjt|| jdd}n| j dkrt|| j}td|d}t	td| jdt
d| j   }t|| |d d dd df< t|| |d d dd df< ||  }|| d	  }tj|d
d}nt| j  d|S )Nr   TZrequires_gradZsincosr   r   r!   g     @
   FzN is not a valid positional encoder. Available types are 'random' and 'sincos'.)Zpositional_encoding_typer   r   r'   Zrandnr^   r   Zarangero   expmathlogsincosr   r   r>   )r8   r   r   positionZdiv_termr/   r/   r0   r     s    

(  
z#PatchTSTPositionalEncoding._init_per   c                 C   s   | j rn| || jdd d d f  }| j| jd dd d f  }||jd | jdd}tj||fdd}n| || j }|S )Nr   r   r   r!   r"   )	r   r   r   r   expandrM   r   r'   cat)rD   r   r   Z
cls_tokensr   r/   r/   r0   rO     s     z"PatchTSTPositionalEncoding.forward)rP   rQ   rR   rS   r   rT   r<   staticmethodr   r   r   r'   rW   rO   rY   r/   r/   rE   r0   r   r  s
   r   c                       sH   e Zd ZdZeed fddZd	eje	e
 e	e
 edddZ  ZS )
r   z
    PatchTST Encoder
    r   c                    sT   t    d| _t | _t || _t fddt	 j
D | _|   d S )NFc                    s   g | ]}t  qS r/   )r   r   r[   r/   r0   r~     r   z,PatchTSTEncoder.__init__.<locals>.<listcomp>)r;   r<   r   r   embedderr   positional_encoderr   r   r   Znum_hidden_layerslayers	post_initr   rE   r[   r0   r<     s    
 zPatchTSTEncoder.__init__N)r   output_hidden_statesrJ   rK   c           	      C   s   |dur|n| j j}|dur |n| j j}| |}| |}|rDdnd}|rPdnd}| jD ]8}|rl||f }|||d}|d }|rZ||d f }qZt|||dS )a  
        Parameters:
            patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
                Past values of the time series
            output_hidden_states (bool, optional): Indicates if hidden states should be outputted.
            output_attentions (bool, optional): Indicates if attentions should be outputted.

        return:
            `BaseModelOutput`
        Nr/   r   r   r   )last_hidden_staterG   
attentions)r8   rJ   r   r   r   r   r
   )	rD   r   r   rJ   r   encoder_statesZall_attentionsZencoder_layerZlayer_outputsr/   r/   r0   rO     s    



zPatchTSTEncoder.forward)NN)rP   rQ   rR   rS   r   rT   r<   r'   rW   r   rV   r
   rO   rY   r/   r/   rE   r0   r     s     r   zG
    Base class for model's outputs, with potential hidden states.
    )Zcustom_introc                   @   s   e Zd ZU dZdZeej ed< dZ	ee
ej  ed< dZee
ej  ed< dZeej ed< dZeej ed< dZeej ed< dZeej ed	< dS )
PatchTSTModelOutputa>  
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`):
        Sequence of hidden-states at the output of the last layer of the model.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
        one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of
        the model at the output of each layer plus the optional initial embedding outputs.
    mask (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches)`, *optional*):
        Bool masked tensor indicating which patches are masked
    loc (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
        Mean of the input data (batch_size, sequence_length, num_channels) over the sequence_length
    scale (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
        Std of the input data (batch_size, sequence_length, num_channels) over the sequence_length
    patch_input (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`):
        Patched input to the Transformer
    Nr   rG   r   rv   locscaler   )rP   rQ   rR   rS   r   r   r'   FloatTensorr   rG   rX   r   rv   r   r   r   r/   r/   r/   r0   r     s   
r   z4
    Output type of [`PatchTSTForPretraining`].
    c                   @   sb   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeej  ed< dZeeej  ed< dS )PatchTSTForPretrainingOutputa  
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        MSE loss.
    prediction_output (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction outputs of the time series modeling heads.
    Nlossprediction_outputrG   r   )rP   rQ   rR   rS   r   r   r'   r   r   r   rG   rX   r   r/   r/   r/   r0   r   	  s
   
r   z3
    Output type of [`PatchTSTForRegression`].
    c                   @   sb   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeej  ed< dZeeej  ed< dS )PatchTSTForRegressionOutputa  
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        MSE loss.
    regression_outputs (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
        Regression outputs of the time series modeling heads.
    Nr   regression_outputsrG   r   )rP   rQ   rR   rS   r   r   r'   r   r   r   rG   rX   r   r/   r/   r/   r0   r     s
   
r   z3
    Output type of [`PatchTSTForPrediction`].
    c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeej  ed< dZeeej  ed< dZeej ed< dZeej ed< dS )	PatchTSTForPredictionOutputa!  
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        MSE loss.
    prediction_outputs (`torch.FloatTensor` of shape `(batch_size, prediction_length, -1)`):
        Prediction outputs of the time series modeling heads.
    attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
        heads.
    loc: (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*)
        Mean of the input data (batch_size, sequence_length, num_channels) over the sequence_length
    scale: (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*)
        Std of the input data (batch_size, sequence_length, num_channels) over the sequence_length
    Nr   prediction_outputsrG   r   r   r   )rP   rQ   rR   rS   r   r   r'   r   r   r   rG   rX   r   r   r   r/   r/   r/   r0   r   1  s   
r   z7
    Output type of [`PatchTSTForClassification`].
    c                   @   sb   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeej  ed< dZeeej  ed< dS )PatchTSTForClassificationOutputa  
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        Total loss as the sum of the masked language modeling loss and the next sequence prediction
        (classification) loss.
    prediction_logits (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
        Prediction scores of the PatchTST modeling head (scores before SoftMax).
    Nr   prediction_logitsrG   r   )rP   rQ   rR   rS   r   r   r'   r   r   r   rG   rX   r   r/   r/   r/   r0   r   Q  s
   
r   z
    Base class for time series model's predictions outputs that contains the sampled values from the chosen
    distribution.
    c                   @   s$   e Zd ZU dZdZeej ed< dS )SamplePatchTSTOutputz
    sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length, num_targets)`):
        Sampled values from the chosen distribution.
    N	sequences)	rP   rQ   rR   rS   r   r   r'   r   r   r/   r/   r/   r0   r   f  s   
r   )inputtargetrK   c                 C   s   |  | S )zc
    Computes the negative log likelihood loss from input distribution with respect to target.
    )Zlog_prob)r   r   r/   r/   r0   nllw  s    r   )input_tensorweightsrK   c                 C   sr   |durbt |dk| | t | }t j|r8|j|dn| dd}|rV|j|dn| | S | j|dS dS )aj  
    Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero,
    meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.

    Args:
        input_tensor (`torch.FloatTensor`):
            Input tensor, of which the average must be computed.
        weights (`torch.FloatTensor`, *optional*):
            Weights tensor, of the same shape as `input_tensor`.
        dim (`int`, *optional*):
            The dim along which to average `input_tensor`.

    Returns:
        `torch.FloatTensor`: The tensor with values averaged along the specified `dim`.
    Nr   r"   r   min)r'   where
zeros_likeclampr   r   )r   r   r#   Zweighted_tensorZsum_weightsr/   r/   r0   weighted_average  s
    "r  c                       sL   e Zd ZdZed fddZejejeejejejf dddZ	  Z
S )PatchTSTStdScalerz
    Standardize features by calculating the mean and scaling along the first dimension, and then normalizes it by
    subtracting from the mean and dividing by the standard deviation.
    r[   c                    sP   t    t|dr|jnd| _t|dr0|jnd| _t|drF|jnd| _d S )Nscaling_dimr   keepdimTminimum_scalegh㈵>)r;   r<   hasattrr  r#   r  r  ra   rE   r/   r0   r<     s    
zPatchTSTStdScaler.__init__r   observed_indicatorrK   c                 C   sz   |j | j| jd}|d}|| j | j| jd| }|| | d j | j| jd| }t|| j }|| | ||fS )C  
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                input for Batch norm calculation
            observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Calculating the scale on the observed indicator.
        Returns:
            tuple of `torch.Tensor` of shapes
                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
                `(batch_size, 1, num_input_channels)`)
        r  r   r!   )r   r#   r  Z	clamp_minr'   sqrtr  )rD   r   r  denominatorr   Zvariancer   r/   r/   r0   rO     s    
"zPatchTSTStdScaler.forwardrP   rQ   rR   rS   r   r<   r'   rW   rX   rO   rY   r/   r/   rE   r0   r    s
   r  c                       sL   e Zd ZdZed fddZejejeejejejf dddZ	  Z
S )PatchTSTMeanScalerz
    Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data
    accordingly.
    r[   c                    sf   t    t|dr|jnd| _t|dr0|jnd| _t|drF|jnd| _t|dr\|jnd | _d S )Nr  r   r  Tr  绽|=default_scale)r;   r<   r  r  r#   r  r  r  ra   rE   r/   r0   r<     s
    
zPatchTSTMeanScaler.__init__r  c           
      C   s   ||   j| jdd}|j| jdd}|tj|dd }| jdu rt|jdd}tj|ddd}t|| }n| jt| }t|dk||}tj|| j	d}|| }	| j
s|j| jd}|	t||fS )r	  Tr
  r   r   Nr   r"   )absr   r#   r'   r   r  Zsqueeze	ones_liker   r  r  r   )
rD   r   r  Zts_sumZnum_observedr   Z	batch_sumZbatch_observationsr  Zscaled_datar/   r/   r0   rO     s    
zPatchTSTMeanScaler.forwardr  r/   r/   rE   r0   r    s
   r  c                       sR   e Zd ZdZed fddZd	ejeej e	ejejejf dddZ
  ZS )
PatchTSTNOPScalerz|
    Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data.
    r[   c                    s:   t    t|dr|jnd| _t|dr0|jnd| _d S )Nr  r   r  T)r;   r<   r  r  r#   r  ra   rE   r/   r0   r<     s    
zPatchTSTNOPScaler.__init__Nr  c                 C   sB   t j|ddj| j| jd}t j|ddj| j| jd}|||fS )a  
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                input for Batch norm calculation
        Returns:
            tuple of `torch.Tensor` of shapes
                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
                `(batch_size, 1, num_input_channels)`)
        Fr   )r#   r  )r'   r  r   r#   r  r   )rD   r   r  r   r   r/   r/   r0   rO     s    zPatchTSTNOPScaler.forward)N)rP   rQ   rR   rS   r   r<   r'   rW   r   rX   rO   rY   r/   r/   rE   r0   r    s    r  c                       sH   e Zd Zed fddZejejeejejejf dddZ  Z	S )PatchTSTScalerr[   c                    sN   t    |jdks|jdu r*t|| _n |jdkr@t|| _n
t|| _d S )Nr   Tr   )r;   r<   r   r  scalerr  r  ra   rE   r/   r0   r<     s    

zPatchTSTScaler.__init__r  c                 C   s   |  ||\}}}|||fS )a>  
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Input for scaler calculation
            observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Calculating the scale on the observed indicator.
        Returns:
            tuple of `torch.Tensor` of shapes
                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
                `(batch_size, 1, um_input_channels)`)
        )r  )rD   r   r  r   r   r/   r/   r0   rO     s    zPatchTSTScaler.forward)
rP   rQ   rR   r   r<   r'   rW   rX   rO   rY   r/   r/   rE   r0   r    s   
r  c                
       s`   e Zd Zed fddZdejeej eej ee ee ee e	e
ef dddZ  ZS )	PatchTSTModelr[   c                    sf   t  | t|| _t|| _|j| _| jj}| jrBt|| _	n
t
 | _	t||d| _|   d S )N)r   )r;   r<   r  r  r   
patchifierdo_mask_inputr   r   maskingr   r   r   encoderr   r   rE   r/   r0   r<   ,  s    


zPatchTSTModel.__init__Nr   past_observed_maskfuture_valuesr   rJ   return_dictrK   c              	   C   s   |dur|n| j j}|dur |n| j j}|dur4|n| j j}|du rNt|}| ||\}}}	| |}
| jr| 	|
\}}n| 	|
d }}| j
|||d}|s|j|j|jf}||||	|
f }tdd |D S t|j|j|j|||	|
dS )a  
        Parameters:
            past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
                Input sequence to the model
            past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
                in `[0, 1]`:

                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
            future_values (`torch.BoolTensor` of shape `(batch_size, prediction_length, num_input_channels)`, *optional*):
                Future target values associated with the `past_values`
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers
            output_attentions (`bool`, *optional*):
                Whether or not to return the output attention of all layers
            return_dict (`bool`, *optional*):
                Whether or not to return a `ModelOutput` instead of a plain tuple.

        Returns:
            `PatchTSTModelOutput` or tuple of `torch.Tensor` (if `return_dict`=False or `config.return_dict`=False)

        Examples:

        ```python
        >>> from huggingface_hub import hf_hub_download
        >>> import torch
        >>> from transformers import PatchTSTModel

        >>> file = hf_hub_download(
        ...     repo_id="hf-internal-testing/etth1-hourly-batch", filename="train-batch.pt", repo_type="dataset"
        ... )
        >>> batch = torch.load(file)

        >>> model = PatchTSTModel.from_pretrained("namctin/patchtst_etth1_pretrain")

        >>> # during training, one provides both past and future values
        >>> outputs = model(
        ...     past_values=batch["past_values"],
        ...     future_values=batch["future_values"],
        ... )

        >>> last_hidden_state = outputs.last_hidden_state
        ```N)r   r   rJ   c                 s   s   | ]}|d ur|V  qd S r   r/   )r|   vr/   r/   r0   	<genexpr>  r   z(PatchTSTModel.forward.<locals>.<genexpr>)r   rG   r   rv   r   r   r   )r8   use_return_dictrJ   r   r'   r  r  r  r  r  r  r   rG   r   rX   r   )rD   r   r  r  r   rJ   r  Zscaled_past_valuesr   r   Zpatched_valuesZmasked_valuesrv   Zencoder_outputr   r/   r/   r0   rO   >  s6    6

zPatchTSTModel.forward)NNNNN)rP   rQ   rR   r   r<   r'   rW   r   rV   r   rX   r   rO   rY   r/   r/   rE   r0   r  *  s        
r  c                       s:   e Zd ZdZed fddZejejdddZ  Z	S )PatchTSTMaskPretrainHeadz-
    Pretraining head for mask modelling
    r[   c                    sH   t    |jdkr t|jnt | _t|j|j	| _
|j| _d S )Nr   )r;   r<   head_dropoutr   r   r   r   r?   r^   r   linearr   ra   rE   r/   r0   r<     s    
 z!PatchTSTMaskPretrainHead.__init__)	embeddingrK   c                 C   s:   |  | |}| jr6|ddddddddf }|S )a  
        Parameters:
            embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
                    `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
                Embedding from the model
        Returns:
            `torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
                            `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True

        Nr   )r$  r   r   )rD   r%  r/   r/   r0   rO     s     z PatchTSTMaskPretrainHead.forwardrd   r/   r/   rE   r0   r"    s   r"  z*
    The PatchTST for pretrain model.
    c                	       sX   e Zd Zed fddZdejeej ee ee ee e	e
ef dddZ  ZS )	PatchTSTForPretrainingr[   c                    s4   t  | d|_t|d| _t|| _|   d S )NTr[   )r;   r<   r  r  r   r"  headr   ra   rE   r/   r0   r<     s
    
zPatchTSTForPretraining.__init__N)r   r  r   rJ   r  rK   c                 C   s   |dur|n| j j}| j||||dd}| |j}tjdd}|||j}	|	jdd|j	 
 |j	
 d  }
|j}|s|f|d	d
  }|
dur|
f| n|}|S t|
|||jdS )a	  
        Parameters:
            past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
                Input sequence to the model
            past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
                in `[0, 1]`:

                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers
            output_attentions (`bool`, *optional*):
                Whether or not to return the output attention of all layers
            return_dict (`bool`, *optional*): Whether or not to return a `ModelOutput` instead of a plain tuple.

        Returns:
            `PatchTSTForPretrainingOutput` or tuple of `torch.Tensor` (if `return_dict`=False or
            `config.return_dict`=False)

        Examples:

        ```python
        >>> from huggingface_hub import hf_hub_download
        >>> import torch
        >>> from transformers import PatchTSTConfig, PatchTSTForPretraining

        >>> file = hf_hub_download(
        ...     repo_id="hf-internal-testing/etth1-hourly-batch", filename="train-batch.pt", repo_type="dataset"
        ... )
        >>> batch = torch.load(file)

        >>> # Config for random mask pretraining
        >>> config = PatchTSTConfig(
        ...     num_input_channels=7,
        ...     context_length=512,
        ...     patch_length=12,
        ...     stride=12,
        ...     mask_type='random',
        ...     random_mask_ratio=0.4,
        ...     use_cls_token=True,
        ... )
        >>> # Config for forecast mask pretraining
        >>> config = PatchTSTConfig(
        ...     num_input_channels=7,
        ...     context_length=512,
        ...     patch_length=12,
        ...     stride=12,
        ...     mask_type='forecast',
        ...     num_forecast_mask_patches=5,
        ...     use_cls_token=True,
        ... )
        >>> model = PatchTSTForPretraining(config)

        >>> # during training, one provides both past and future values
        >>> outputs = model(past_values=batch["past_values"])

        >>> loss = outputs.loss
        >>> loss.backward()
        ```NTr   r  r   rJ   r  noneZ	reductionr   r"   r  r   )r   r   rG   r   )r8   r!  r   r'  r   r   MSELossr   r   rv   r   rG   r   r   )rD   r   r  r   rJ   r  model_outputZx_hatr   loss_valZmasked_lossr   r   r/   r/   r0   rO     s(    E
$
zPatchTSTForPretraining.forward)NNNN)rP   rQ   rR   r   r<   r'   rW   r   rV   r   rX   r   rO   rY   r/   r/   rE   r0   r&    s       
r&  c                       s2   e Zd Zed fddZejdddZ  ZS )PatchTSTClassificationHeadr[   c                    sd   t    |j| _|j| _tjdd| _|jdkr>t|jnt	 | _
t|j|j |j| _d S Nr   Z	start_dimr   )r;   r<   r   pooling_typer   Flattenflattenr#  r   r   r   r?   r   r^   num_targetsr$  ra   rE   r/   r0   r<   ,  s    
 z#PatchTSTClassificationHead.__init__r%  c                 C   s   | j r$|dddddddf }nD| jdkr<|jdd}n,| jdkrV|jddj}ntd| j d| |}| | |}|S )	a[  
        Parameters:
            embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
                     `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
                Embedding from the model
        Returns:
            `torch.Tensor` of shape `(bs, num_targets)`

        Nr   r   r!   r"   r   pooling operator  is not implemented yet)	r   r2  r   r   valuesr>   r4  r$  r   rD   r%  pooled_embeddingrc   r/   r/   r0   rO   4  s    



z"PatchTSTClassificationHead.forwardr   r/   r/   rE   r0   r/  +  s   r/  z0
    The PatchTST for classification model.
    c                       sb   e Zd Zed fddZedejeej ee	 ee	 ee	 ee	 e
eef dddZ  ZS )	PatchTSTForClassificationr[   c                    sB   t  | |jr"td d|_t|| _t|| _| 	  d S )N+Setting `do_mask_input` parameter to False.F)
r;   r<   r  loggerwarningr  r   r/  r'  r   ra   rE   r/   r0   r<   V  s    


z"PatchTSTForClassification.__init__Nr   target_valuesr  r   rJ   r  rK   c                 C   s   |dur|n| j j}| j||||dd}| |j}d}	|durRt }
|
||}	|s|f|dd  }|	durz|	f| n|}|S t|	||j|j	dS )ac  
        past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
            Input sequence to the model
        target_values (`torch.Tensor`, *optional*):
            Labels associates with the `past_values`
        past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
            Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
            in `[0, 1]`:

            - 1 for values that are **observed**,
            - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).

        Examples:

        ```python
        >>> from transformers import PatchTSTConfig, PatchTSTForClassification

        >>> # classification task with two input channel2 and 3 classes
        >>> config = PatchTSTConfig(
        ...     num_input_channels=2,
        ...     num_targets=3,
        ...     context_length=512,
        ...     patch_length=12,
        ...     stride=12,
        ...     use_cls_token=True,
        ... )
        >>> model = PatchTSTForClassification(config=config)

        >>> # during inference, one only provides past values
        >>> past_values = torch.randn(20, 512, 2)
        >>> outputs = model(past_values=past_values)
        >>> labels = outputs.prediction_logits
        ```NTr(  r   r   )r   r   rG   r   )
r8   r!  r   r'  r   r   ZCrossEntropyLossr   rG   r   )rD   r   rA  r  r   rJ   r  r-  y_hatr.  r   r   r/   r/   r0   rO   d  s.    ,
z!PatchTSTForClassification.forward)NNNNN)rP   rQ   rR   r   r<   r   r'   rW   r   rV   r   rX   r   rO   rY   r/   r/   rE   r0   r<  P  s         
r<  z,
    The PatchTST for regression Model.
    c                       s6   e Zd Zdeed fddZejdddZ  Z	S )	PatchTSTPredictionHeadNr   c                    sF  t    |j| _|j| _|j| _|j| _| js6| jr>|j}n
|j| }| jst | _	t | _
t | _t| jD ]p}| jtjdd |du r| j	t||j n| j	|| | j
|jdkrt|jnt  qvnXtjdd| _|du rt||j| _n||| _|jdkr8t|jnt | _dS )a  
        num_patches (`int`):
            The number of patches in the input sequence.
        distribution_output (`DistributionOutput`, *optional*):
            The distribution output layer for probabilistic forecasting. If None, a linear output layer is used.
        r!   r1  Nr   )r;   r<   share_projectionr   r   r2  r^   r   r   projectionsdropoutsflattensr   r   r3  r?   prediction_lengthget_parameter_projectionr#  r   r   r4  
projectionr   )rD   r8   r   distribution_outputr=   r   rE   r/   r0   r<     s.    




*
zPatchTSTPredictionHead.__init__r6  c                 C   s  | j r$|dddddddf }n6| jdkr<|jdd}n| jdkrV|jddj}n|}| jsg }t| jD ]J}| j| |dd|ddf }| j	| |}| j
| |}|| qntj|dd}n| |}| |}| |}t|trtdd	 |D }n|dd}|S )
aj  
        Parameters:
            embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
                     `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
                Embedding from the model
        Returns:
            `torch.Tensor` of shape `(bs, forecast_len, num_channels)`

        Nr   r   r!   r"   r   r   c                 s   s   | ]}| d dV  qdS )r!   r   N)r)   )r|   zr/   r/   r0   r     r   z1PatchTSTPredictionHead.forward.<locals>.<genexpr>)r   r2  r   r   r9  rD  r   r   rG  rF  rE  r   r'   r   r4  r   rJ  r   rX   r)   )rD   r%  r;  rc   r   r/   r/   r0   rO     s,    


 


zPatchTSTPredictionHead.forward)N)
rP   rQ   rR   r   rT   r<   r'   rW   rO   rY   r/   r/   rE   r0   rC    s   +rC  z,
    The PatchTST for prediction model.
    c                
       s   e Zd Zed fddZdejeej eej ee ee ee e	e
ef dddZe dejeej edd	d
Z  ZS )PatchTSTForPredictionr[   c                    s   t  | |jr"td d|_t|| _|jdkr>d | _n^|jdkrXt	|j
d| _nD|jdkrrt|j
d| _n*|jdkrt|j
d| _ntd|j t|| jjj| jd	| _|   d S )
Nr=  Fmse	student_tr"   normalnegative_binomialUnknown distribution output )rK  )r;   r<   r  r>  r?  r  r   r   rK  r   rH  r   r   r>   rC  r  r   r'  r   ra   rE   r/   r0   r<     s$    





zPatchTSTForPrediction.__init__Nr  c                 C   s   |dur|n| j j}| j||||dd}| |j}d}	| jrD|}
n||j |j }
|dur| jr| jj||j|jd}t	||}	t
|	}	ntjdd}||
|}	|j}|j}|s|
f|dd  }|	dur|	f| n|}|S t|	|
|j|j||d	S )
aV	  
        Parameters:
            past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
                Input sequence to the model
            past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
                in `[0, 1]`:

                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
            future_values (`torch.Tensor` of shape `(bs, forecast_len, num_input_channels)`, *optional*):
                Future target values associated with the `past_values`
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers
            output_attentions (`bool`, *optional*):
                Whether or not to return the output attention of all layers
            return_dict (`bool`, *optional*):
                Whether or not to return a `ModelOutput` instead of a plain tuple.

        Returns:
            `PatchTSTForPredictionOutput` or tuple of `torch.Tensor` (if `return_dict`=False or
            `config.return_dict`=False)

        Examples:

        ```python
        >>> from huggingface_hub import hf_hub_download
        >>> import torch
        >>> from transformers import PatchTSTConfig, PatchTSTForPrediction

        >>> file = hf_hub_download(
        ...     repo_id="hf-internal-testing/etth1-hourly-batch", filename="train-batch.pt", repo_type="dataset"
        ... )
        >>> batch = torch.load(file)

        >>> # Prediction task with 7 input channels and prediction length is 96
        >>> model = PatchTSTForPrediction.from_pretrained("namctin/patchtst_etth1_forecast")

        >>> # during training, one provides both past and future values
        >>> outputs = model(
        ...     past_values=batch["past_values"],
        ...     future_values=batch["future_values"],
        ... )

        >>> loss = outputs.loss
        >>> loss.backward()

        >>> # during inference, one only provides past values, the model outputs future values
        >>> outputs = model(past_values=batch["past_values"])
        >>> prediction_outputs = outputs.prediction_outputs
        ```NTr(  r   r   r   r*  r   r   )r   r   rG   r   r   r   )r8   r!  r   r'  r   rK  r   r   distributionr   r  r   r,  r   rG   r   )rD   r   r  r  r   rJ   r  r-  rB  r.  Z	y_hat_outrT  r   r   r   r   r/   r/   r0   rO   4  sH    =



zPatchTSTForPrediction.forwardr   r  rK   c                    sr   | j j}| |d|dd}| jr\| jj|j|j|jd  fddt|D }tj	|dd}n|j
d}t|d	S )
a   
        Generate sequences of sample predictions from a model with a probability distribution head.

        Parameters:
            past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Past values of the time series that serves as context in order to predict the future.
            past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
                in `[0, 1]`:

                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).

        Return:
            [`SamplePatchTSTOutput`] where the outputs `sequences` tensor will have shape `(batch_size, number of
            samples, prediction_length, 1)` or `(batch_size, number of samples, prediction_length, num_input_channels)`
            for multivariate predictions.
        NF)r   r  r  r   rS  c                    s   g | ]}   qS r/   sampler{   rT  r/   r0   r~     r   z2PatchTSTForPrediction.generate.<locals>.<listcomp>r   r"   r   )r8   num_parallel_samplesrK  rT  r   r   r   r   r'   r   ro   r   rD   r   r  rZ  r   Zsamplesr/   rX  r0   generate  s    zPatchTSTForPrediction.generate)NNNNN)N)rP   rQ   rR   r   r<   r'   rW   r   rV   r   rX   r   rO   no_gradr   r\  rY   r/   r/   rE   r0   rM    s,         
m rM  c                       s8   e Zd ZdZd	ed fddZejdddZ  Z	S )
PatchTSTRegressionHeadz
    Regression head
    Nr[   c                    s   t    |j| _|j| _|j| _|| _|j|j }t	j
dd| _|jdkrXt	|jnt	 | _|d u r|t	||j| _n||| _d S r0  )r;   r<   Zoutput_rangey_ranger   r2  rK  r   r^   r   r3  r4  r#  r   r   r   r?   r5  rJ  rI  )rD   r8   rK  r=   rE   r/   r0   r<     s    
 zPatchTSTRegressionHead.__init__r6  c                 C   s   | j r$|dddddddf }nD| jdkr<|jdd}n,| jdkrV|jddj}ntd| j d| | |}| |}| j	du | j
du@ rt|| j
d	 | j
d   | j
d  }|S )
aY  
        Parameters:
            embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
                    `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
                Embedding from the model
        Returns:
            `torch.Tensor` of shape `(bs, output_dim)`

        Nr   r   r!   r"   r   r7  r8  r   )r   r2  r   r   r9  r>   r   r4  rJ  rK  r_  r'   Zsigmoidr:  r/   r/   r0   rO     s    



(zPatchTSTRegressionHead.forward)Nrd   r/   r/   rE   r0   r^    s   r^  z,
    The PatchTST for regression model.
    c                       s   e Zd Zed fddZedejeej eej ee	 ee	 ee	 e
eef dddZe dejeej edd	d
Z  ZS )PatchTSTForRegressionr[   c                    s   t  | |jr"td d|_t|| _|jdkr>d | _n^|jdkrXt	|j
d| _nD|jdkrrt|j
d| _n*|jdkrt|j
d| _ntd|j t|| j| _|   d S )	Nr=  FrN  rO  r"   rP  rQ  rR  )r;   r<   r  r>  r?  r  r   r   rK  r   r5  r   r   r>   r^  r'  r   ra   rE   r/   r0   r<     s     





zPatchTSTForRegression.__init__Nr@  c                    s   |dur|n j j} j||||dd} |j}d}	|dur jr| j|}
t fdd|D }t|
|}	t	|	}	nt
jdd}	|	||}	|s|f|dd	  }|	dur|	f| n|}|S t|	||j|jd
S )a#  
        past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
            Input sequence to the model
        target_values (`torch.Tensor` of shape `(bs, num_input_channels)`):
            Target values associates with the `past_values`
        past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
            Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
            in `[0, 1]`:

            - 1 for values that are **observed**,
            - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
            Whether or not to return a `ModelOutput` instead of a plain tuple.

        Examples:

        ```python
        >>> from transformers import PatchTSTConfig, PatchTSTForRegression

        >>> # Regression task with 6 input channels and regress 2 targets
        >>> model = PatchTSTForRegression.from_pretrained("namctin/patchtst_etth1_regression")

        >>> # during inference, one only provides past values, the model outputs future values
        >>> past_values = torch.randn(20, 512, 6)
        >>> outputs = model(past_values=past_values)
        >>> regression_outputs = outputs.regression_outputs
        ```NTr(  c                 3   s   | ]}| d  jjV  qdS )r   N)r*   r8   r5  )r|   itemrD   r/   r0   r   _  r   z0PatchTSTForRegression.forward.<locals>.<genexpr>r   r*  r   r   )r   r   rG   r   )r8   r!  r   r'  r   rK  rT  rX   r   r  r   r,  r   rG   r   )rD   r   rA  r  r   rJ   r  r-  rB  r   rT  r   r/   rb  r0   rO   )  s8    %


zPatchTSTForRegression.forwardrU  c                    sb   | j j}| |d|dd}| j|j  fddt|D }tj|ddd|| j j	}t
|d	S )
a  
        Generate sequences of sample predictions from a model with a probability distribution head.

        Parameters:
            past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Past values of the time series that serves as context in order to predict the future.
            past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
                in `[0, 1]`:

                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).

        Return:
            [`SamplePatchTSTOutput`] where the outputs `sequences` tensor will have shape `(batch_size, number of
            samples, num_targets)`.
        NF)r   rA  r  r   c                    s   g | ]}   qS r/   rV  r{   rX  r/   r0   r~     r   z2PatchTSTForRegression.generate.<locals>.<listcomp>r   r"   r   rY  )r8   rZ  rK  rT  r   r   r'   r   r*   r5  r   r[  r/   rX  r0   r\  s  s    zPatchTSTForRegression.generate)NNNNN)N)rP   rQ   rR   r   r<   r   r'   rW   r   rV   r   rX   r   rO   r]  r   r\  rY   r/   r/   rE   r0   r`  	  s.        
I r`  )r  r   rM  r&  r`  r<  )Nr   N)NFr   )Nr   )NN)LrS   r   dataclassesr   typingr   r   r   r'   r   Zactivationsr   Zmodeling_flash_attention_utilsr	   Zmodeling_outputsr
   Zmodeling_utilsr   r   Zprocessing_utilsr   Ztime_series_utilsr   r   r   utilsr   r   r   Zconfiguration_patchtstr   Z
get_loggerrP   r>  r   rW   rU   r1   r2   rZ   listrV   rT   rx   r   r   r   r   r   r   r   r   r   r   r   r   r   r   distributionsDistributionr   r  r  r  r  r  r  r"  r&  r/  r<  rC  rM  r^  r`  __all__r/   r/   r/   r0   <module>   s   

   X   =  
D0< %$8>
$7po%W` =7 