a
    hK                     @   s  d Z ddlZddlZddlZddlmZ ddlmZmZ ddlZ	ddl
mZ ddlmZmZmZmZ ddlmZ dd	lmZ eeZeG d
d dZG dd dZeG dd deZG dd dZG dd deZG dd deZG dd deZG dd deZG dd deeZ dS )zJ
Callbacks to use with the Trainer class and customize the training loop.
    N)	dataclass)OptionalUnion)tqdm   )HPSearchBackendIntervalStrategySaveStrategy
has_length)TrainingArguments)loggingc                   @   sv  e Zd ZU dZdZee ed< dZe	ed< dZ
e	ed< dZe	ed< dZe	ed	< dZe	ed
< dZee	 ed< dZe	ed< dZe	ed< dZeed< dZeeeef  ed< dZee ed< dZee	 ed< dZee ed< dZeed< dZeed< dZeed< dZee ed< dZeeeeee	ef f ed< dZ ed ed< dd Z!eddd Z"e#edd!d"Z$d#d$ Z%d%d& Z&dS )'TrainerStatea  
    A class containing the [`Trainer`] inner state that will be saved along the model and optimizer when checkpointing
    and passed to the [`TrainerCallback`].

    <Tip>

    In all this class, one step is to be understood as one update step. When using gradient accumulation, one update
    step may require several forward and backward passes: if you use `gradient_accumulation_steps=n`, then one update
    step requires going through *n* batches.

    </Tip>

    Args:
        epoch (`float`, *optional*):
            Only set during training, will represent the epoch the training is at (the decimal part being the
            percentage of the current epoch completed).
        global_step (`int`, *optional*, defaults to 0):
            During training, represents the number of update steps completed.
        max_steps (`int`, *optional*, defaults to 0):
            The number of update steps to do during the current training.
        logging_steps (`int`, *optional*, defaults to 500):
            Log every X updates steps
        eval_steps (`int`, *optional*):
            Run an evaluation every X steps.
        save_steps (`int`, *optional*, defaults to 500):
            Save checkpoint every X updates steps.
        train_batch_size (`int`, *optional*):
            The batch size for the training dataloader. Only needed when
            `auto_find_batch_size` has been used.
        num_input_tokens_seen (`int`, *optional*, defaults to 0):
            When tracking the inputs tokens, the number of tokens seen during training (number of input tokens, not the
            number of prediction tokens).
        total_flos (`float`, *optional*, defaults to 0):
            The total number of floating operations done by the model since the beginning of training (stored as floats
            to avoid overflow).
        log_history (`list[dict[str, float]]`, *optional*):
            The list of logs done since the beginning of training.
        best_metric (`float`, *optional*):
            When tracking the best model, the value of the best metric encountered so far.
        best_global_step (`int`, *optional*):
            When tracking the best model, the step at which the best metric was encountered.
            Used for setting `best_model_checkpoint`.
        best_model_checkpoint (`str`, *optional*):
            When tracking the best model, the value of the name of the checkpoint for the best model encountered so
            far.
        is_local_process_zero (`bool`, *optional*, defaults to `True`):
            Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
            several machines) main process.
        is_world_process_zero (`bool`, *optional*, defaults to `True`):
            Whether or not this process is the global main process (when training in a distributed fashion on several
            machines, this is only going to be `True` for one process).
        is_hyper_param_search (`bool`, *optional*, defaults to `False`):
            Whether we are in the process of a hyper parameter search using Trainer.hyperparameter_search. This will
            impact the way data will be logged in TensorBoard.
        stateful_callbacks (`list[StatefulTrainerCallback]`, *optional*):
            Callbacks attached to the `Trainer` that should have their states be saved or restored.
            Relevant callbacks should implement a `state` and `from_state` function.
    Nepochr   global_step	max_stepsi  logging_steps
eval_steps
save_stepstrain_batch_sizenum_train_epochsnum_input_tokens_seen
total_floslog_historybest_metricbest_global_stepbest_model_checkpointTis_local_process_zerois_world_process_zeroFis_hyper_param_search
trial_nametrial_paramsTrainerCallbackstateful_callbacksc                 C   s   | j d u rg | _ | jd u r"i | _nt| jtr0n~i }| jD ]l}t|tsZtdt| |jj}||v rt|| t	s|| g||< || 
|  q:| ||< q:|| _d S )NzNAll callbacks passed to be saved must inherit `ExportableState`, but received )r   r"   
isinstancedictExportableState	TypeErrortype	__class____name__listappendstate)selfr"   callbackname r0   Y/var/www/html/assistant/venv/lib/python3.9/site-packages/transformers/trainer_callback.py__post_init__u   s&    



zTrainerState.__post_init__)	json_pathc                 C   sV   t jt| dddd }t|ddd}|| W d   n1 sH0    Y  dS )	zDSave the content of this instance in JSON format inside `json_path`.   T)indent	sort_keys
wutf-8encodingN)jsondumpsdataclassesasdictopenwrite)r-   r3   Zjson_stringfr0   r0   r1   save_to_json   s    zTrainerState.save_to_jsonc                 C   sH   t |dd}| }W d   n1 s*0    Y  | f i t|S )z3Create an instance from the content of `json_path`.r9   r:   N)r@   readr<   loads)clsr3   rB   textr0   r0   r1   load_from_json   s    &zTrainerState.load_from_jsonc                 C   sN   dD ]D}t || d}|dur|dk r6t|| }t| | d| qdS )z
        Calculates and stores the absolute value for logging,
        eval, and save steps based on if it was a proportion
        or not.
        )r   evalsaveZ_stepsNr   )getattrmathceilsetattr)r-   argsr   Z	step_kindZ	num_stepsr0   r0   r1   compute_steps   s    zTrainerState.compute_stepsc                 C   s   |j dur"|jdur"| |j| _d| _|dur\ddlm} |jtjkrN|j	n|}||| _|| _
|| _| | _| | _dS )zI
        Stores the initial training references needed in `self`
        Nr   )	hp_params)Zhp_nameZ_trialr   r    Ztransformers.integrationsrQ   Zhp_search_backendr   ZSIGOPTassignmentsr   r   r   r   )r-   Ztrainerr   r   ZtrialrQ   rR   r0   r0   r1   init_training_references   s    

z%TrainerState.init_training_references)'r)   
__module____qualname____doc__r   r   float__annotations__r   intr   r   r   r   r   r   r   r   r   r*   r$   strr   r   r   r   boolr   r   r   r    r   r"   r2   rC   classmethodrH   rP   rS   r0   r0   r0   r1   r   #   s6   
; r   c                   @   s*   e Zd ZdZedddZedd ZdS )r%   aj  
    A class for objects that include the ability to have its state
    be saved during `Trainer._save_checkpoint` and loaded back in during
    `Trainer._load_from_checkpoint`.

    These must implement a `state` function that gets called during the respective
    Trainer function call. It should only include parameters and attributes needed to
    recreate the state at a particular time, to avoid utilizing pickle/maintain standard
    file IO writing.

    Example:

    ```python
    class EarlyStoppingCallback(TrainerCallback, ExportableState):
        def __init__(self, early_stopping_patience: int = 1, early_stopping_threshold: Optional[float] = 0.0):
            self.early_stopping_patience = early_stopping_patience
            self.early_stopping_threshold = early_stopping_threshold
            # early_stopping_patience_counter denotes the number of times validation metrics failed to improve.
            self.early_stopping_patience_counter = 0

        def state(self) -> dict:
            return {
                "args": {
                    "early_stopping_patience": self.early_stopping_patience,
                    "early_stopping_threshold": self.early_stopping_threshold,
                },
                "attributes": {
                    "early_stopping_patience_counter": self.early_stopping_patience_counter,
                }
            }
    ```returnc                 C   s   t dd S )Nz<You must implement a `state` function to utilize this class.)NotImplementedErrorr-   r0   r0   r1   r,      s    zExportableState.statec                 C   s8   | f i |d }|d   D ]\}}t||| q|S )NrO   
attributes)itemsrN   )rF   r,   instancekvr0   r0   r1   
from_state   s    zExportableState.from_stateN)r)   rT   rU   rV   r$   r,   r\   rf   r0   r0   r0   r1   r%      s    r%   c                   @   st   e Zd ZU dZdZeed< dZeed< dZeed< dZ	eed< dZ
eed< dd	 Zd
d Zdd ZedddZdS )TrainerControlaA  
    A class that handles the [`Trainer`] control flow. This class is used by the [`TrainerCallback`] to activate some
    switches in the training loop.

    Args:
        should_training_stop (`bool`, *optional*, defaults to `False`):
            Whether or not the training should be interrupted.

            If `True`, this variable will not be set back to `False`. The training will just stop.
        should_epoch_stop (`bool`, *optional*, defaults to `False`):
            Whether or not the current epoch should be interrupted.

            If `True`, this variable will be set back to `False` at the beginning of the next epoch.
        should_save (`bool`, *optional*, defaults to `False`):
            Whether or not the model should be saved at this step.

            If `True`, this variable will be set back to `False` at the beginning of the next step.
        should_evaluate (`bool`, *optional*, defaults to `False`):
            Whether or not the model should be evaluated at this step.

            If `True`, this variable will be set back to `False` at the beginning of the next step.
        should_log (`bool`, *optional*, defaults to `False`):
            Whether or not the logs should be reported at this step.

            If `True`, this variable will be set back to `False` at the beginning of the next step.
    Fshould_training_stopshould_epoch_stopshould_saveshould_evaluate
should_logc                 C   s
   d| _ dS )z<Internal method that resets the variable for a new training.FN)rh   r`   r0   r0   r1   _new_training  s    zTrainerControl._new_trainingc                 C   s
   d| _ dS )z9Internal method that resets the variable for a new epoch.FN)ri   r`   r0   r0   r1   
_new_epoch  s    zTrainerControl._new_epochc                 C   s   d| _ d| _d| _dS )z8Internal method that resets the variable for a new step.FN)rj   rk   rl   r`   r0   r0   r1   	_new_step  s    zTrainerControl._new_stepr]   c                 C   s    | j | j| j| j| jdi dS )Nrh   ri   rj   rk   rl   rO   ra   rp   r`   r0   r0   r1   r,     s    zTrainerControl.stateN)r)   rT   rU   rV   rh   r[   rX   ri   rj   rk   rl   rm   rn   ro   r$   r,   r0   r0   r0   r1   rg      s   
rg   c                   @   s  e Zd ZdZeeedddZeeedddZeeedddZ	eeedd	d
Z
eeedddZeeedddZeeedddZeeedddZeeedddZeeedddZeeedddZeeedddZeeedddZeeedddZeeeddd Zd!S )"r!   a	  
    A class for objects that will inspect the state of the training loop at some events and take some decisions. At
    each of those events the following arguments are available:

    Args:
        args ([`TrainingArguments`]):
            The training arguments used to instantiate the [`Trainer`].
        state ([`TrainerState`]):
            The current state of the [`Trainer`].
        control ([`TrainerControl`]):
            The object that is returned to the [`Trainer`] and can be used to make some decisions.
        model ([`PreTrainedModel`] or `torch.nn.Module`):
            The model being trained.
        tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer used for encoding the data. This is deprecated in favour of `processing_class`.
        processing_class ([`PreTrainedTokenizer` or `BaseImageProcessor` or `ProcessorMixin` or `FeatureExtractionMixin`]):
            The processing class used for encoding the data. Can be a tokenizer, a processor, an image processor or a feature extractor.
        optimizer (`torch.optim.Optimizer`):
            The optimizer used for the training steps.
        lr_scheduler (`torch.optim.lr_scheduler.LambdaLR`):
            The scheduler used for setting the learning rate.
        train_dataloader (`torch.utils.data.DataLoader`, *optional*):
            The current dataloader used for training.
        eval_dataloader (`torch.utils.data.DataLoader`, *optional*):
            The current dataloader used for evaluation.
        metrics (`dict[str, float]`):
            The metrics computed by the last evaluation phase.

            Those are only accessible in the event `on_evaluate`.
        logs  (`dict[str, float]`):
            The values to log.

            Those are only accessible in the event `on_log`.

    The `control` object is the only one that can be changed by the callback, in which case the event that changes it
    should return the modified version.

    The argument `args`, `state` and `control` are positionals for all events, all the others are grouped in `kwargs`.
    You can unpack the ones you need in the signature of the event using them. As an example, see the code of the
    simple [`~transformers.PrinterCallback`].

    Example:

    ```python
    class PrinterCallback(TrainerCallback):
        def on_log(self, args, state, control, logs=None, **kwargs):
            _ = logs.pop("total_flos", None)
            if state.is_local_process_zero:
                print(logs)
    ```rO   r,   controlc                 K   s   dS )zS
        Event called at the end of the initialization of the [`Trainer`].
        Nr0   r-   rO   r,   rs   kwargsr0   r0   r1   on_init_end^  s    zTrainerCallback.on_init_endc                 K   s   dS )z<
        Event called at the beginning of training.
        Nr0   rt   r0   r0   r1   on_train_begind  s    zTrainerCallback.on_train_beginc                 K   s   dS )z6
        Event called at the end of training.
        Nr0   rt   r0   r0   r1   on_train_endj  s    zTrainerCallback.on_train_endc                 K   s   dS )z<
        Event called at the beginning of an epoch.
        Nr0   rt   r0   r0   r1   on_epoch_beginp  s    zTrainerCallback.on_epoch_beginc                 K   s   dS )z6
        Event called at the end of an epoch.
        Nr0   rt   r0   r0   r1   on_epoch_endv  s    zTrainerCallback.on_epoch_endc                 K   s   dS )z
        Event called at the beginning of a training step. If using gradient accumulation, one training step might take
        several inputs.
        Nr0   rt   r0   r0   r1   on_step_begin|  s    zTrainerCallback.on_step_beginc                 K   s   dS )zv
        Event called before the optimizer step but after gradient clipping. Useful for monitoring gradients.
        Nr0   rt   r0   r0   r1   on_pre_optimizer_step  s    z%TrainerCallback.on_pre_optimizer_stepc                 K   s   dS )z}
        Event called after the optimizer step but before gradients are zeroed out. Useful for monitoring gradients.
        Nr0   rt   r0   r0   r1   on_optimizer_step  s    z!TrainerCallback.on_optimizer_stepc                 K   s   dS )zU
        Event called at the end of an substep during gradient accumulation.
        Nr0   rt   r0   r0   r1   on_substep_end  s    zTrainerCallback.on_substep_endc                 K   s   dS )z
        Event called at the end of a training step. If using gradient accumulation, one training step might take
        several inputs.
        Nr0   rt   r0   r0   r1   on_step_end  s    zTrainerCallback.on_step_endc                 K   s   dS )z9
        Event called after an evaluation phase.
        Nr0   rt   r0   r0   r1   on_evaluate  s    zTrainerCallback.on_evaluatec                 K   s   dS )z=
        Event called after a successful prediction.
        Nr0   )r-   rO   r,   rs   metricsru   r0   r0   r1   
on_predict  s    zTrainerCallback.on_predictc                 K   s   dS )z7
        Event called after a checkpoint save.
        Nr0   rt   r0   r0   r1   on_save  s    zTrainerCallback.on_savec                 K   s   dS )z;
        Event called after logging the last logs.
        Nr0   rt   r0   r0   r1   on_log  s    zTrainerCallback.on_logc                 K   s   dS )z7
        Event called after a prediction step.
        Nr0   rt   r0   r0   r1   on_prediction_step  s    z"TrainerCallback.on_prediction_stepN)r)   rT   rU   rV   r   r   rg   rv   rw   rx   ry   rz   r{   r|   r}   r~   r   r   r   r   r   r   r0   r0   r0   r1   r!   )  s    3r!   c                   @   sR  e Zd ZdZdd Zdd Zdd Zdd	 Zed
d Z	e
eedddZe
eedddZe
eedddZe
eedddZe
eedddZe
eedddZe
eedddZe
eedddZe
eedddZe
eeddd Ze
eedd!d"Ze
eedd#d$Ze
eedd%d&Ze
eedd'd(Ze
eedd)d*Zd+d, Zd-S ).CallbackHandlerz>Internal class that just calls the list of callbacks in order.c                 C   sf   g | _ |D ]}| | q
|| _|| _|| _|| _d | _d | _tdd | j D sbt	
d| j  d S )Nc                 s   s   | ]}t |tV  qd S N)r#   DefaultFlowCallback.0cbr0   r0   r1   	<genexpr>      z+CallbackHandler.__init__.<locals>.<genexpr>zThe Trainer will not work properly if you don't have a `DefaultFlowCallback` in its callbacks. You
should add one before training with `trainer.add_callback(DefaultFlowCallback). The current list ofcallbacks is
:)	callbacksadd_callbackmodelprocessing_class	optimizerlr_schedulertrain_dataloadereval_dataloaderanyloggerwarningcallback_list)r-   r   r   r   r   r   r   r0   r0   r1   __init__  s    zCallbackHandler.__init__c                 C   sh   t |tr| n|}t |tr"|n|j}|dd | jD v rXtd| dd | j  | j| d S )Nc                 S   s   g | ]
}|j qS r0   )r(   )r   cr0   r0   r1   
<listcomp>  r   z0CallbackHandler.add_callback.<locals>.<listcomp>zYou are adding a zH to the callbacks of this Trainer, but there is already one. The currentzlist of callbacks is
:)r#   r'   r(   r   r   r   r   r+   )r-   r.   r   Zcb_classr0   r0   r1   r     s    
zCallbackHandler.add_callbackc                 C   sb   t |tr6| jD ]"}t ||r| j| |  S qn(| jD ] }||kr<| j| |  S q<d S r   r#   r'   r   remover-   r.   r   r0   r0   r1   pop_callback  s    



zCallbackHandler.pop_callbackc                 C   sD   t |tr4| jD ] }t ||r| j|  d S qn| j| d S r   r   r   r0   r0   r1   remove_callback  s    



zCallbackHandler.remove_callbackc                 C   s   d dd | jD S )Nr7   c                 s   s   | ]}|j jV  qd S r   )r(   r)   r   r0   r0   r1   r     r   z0CallbackHandler.callback_list.<locals>.<genexpr>)joinr   r`   r0   r0   r1   r     s    zCallbackHandler.callback_listrr   c                 C   s   |  d|||S )Nrv   
call_eventr-   rO   r,   rs   r0   r0   r1   rv     s    zCallbackHandler.on_init_endc                 C   s   d|_ | d|||S )NFrw   )rh   r   r   r0   r0   r1   rw     s    zCallbackHandler.on_train_beginc                 C   s   |  d|||S )Nrx   r   r   r0   r0   r1   rx     s    zCallbackHandler.on_train_endc                 C   s   d|_ | d|||S )NFry   )ri   r   r   r0   r0   r1   ry     s    zCallbackHandler.on_epoch_beginc                 C   s   |  d|||S )Nrz   r   r   r0   r0   r1   rz     s    zCallbackHandler.on_epoch_endc                 C   s"   d|_ d|_d|_| d|||S )NFr{   )rl   rk   rj   r   r   r0   r0   r1   r{     s    zCallbackHandler.on_step_beginc                 C   s   |  d|||S )Nr|   r   r   r0   r0   r1   r|     s    z%CallbackHandler.on_pre_optimizer_stepc                 C   s   |  d|||S )Nr}   r   r   r0   r0   r1   r}     s    z!CallbackHandler.on_optimizer_stepc                 C   s   |  d|||S )Nr~   r   r   r0   r0   r1   r~     s    zCallbackHandler.on_substep_endc                 C   s   |  d|||S )Nr   r   r   r0   r0   r1   r     s    zCallbackHandler.on_step_endc                 C   s   d|_ | jd||||dS )NFr   r   )rk   r   r-   rO   r,   rs   r   r0   r0   r1   r     s    zCallbackHandler.on_evaluatec                 C   s   | j d||||dS )Nr   r   r   r   r0   r0   r1   r     s    zCallbackHandler.on_predictc                 C   s   d|_ | d|||S )NFr   )rj   r   r   r0   r0   r1   r     s    zCallbackHandler.on_savec                 C   s   d|_ | jd||||dS )NFr   )logs)rl   r   )r-   rO   r,   rs   r   r0   r0   r1   r   #  s    zCallbackHandler.on_logc                 C   s   |  d|||S )Nr   r   r   r0   r0   r1   r   '  s    z"CallbackHandler.on_prediction_stepc              
   K   sP   | j D ]D}t|||||f| j| j| j| j| j| jd|}|d ur|}q|S )N)r   r   r   r   r   r   )r   rK   r   r   r   r   r   r   )r-   eventrO   r,   rs   ru   r.   resultr0   r0   r1   r   *  s$    

zCallbackHandler.call_eventN)r)   rT   rU   rV   r   r   r   r   propertyr   r   r   rg   rv   rw   rx   ry   rz   r{   r|   r}   r~   r   r   r   r   r   r   r   r0   r0   r0   r1   r     s.   	
r   c                   @   s4   e Zd ZdZeeedddZeeedddZdS )r   zx
    A [`TrainerCallback`] that handles the default flow of the training loop for logs, evaluation and checkpoints.
    rr   c                 K   s   |j dkr|jrd|_|jtjkr8|j |j dkr8d|_|jtjkrf|j |j dkrf|j	|j krfd|_
|jtjkr|jdkr|j |j dkrd|_|j |jkrd|_|jtjkrd|_|S )Nr   Tr   )r   Zlogging_first_steprl   logging_strategyr   ZSTEPSr   eval_strategyr   
eval_delayrk   save_strategyr	   r   rj   r   rh   rt   r0   r0   r1   r   C  s.    


zDefaultFlowCallback.on_step_endc                 K   sF   |j tjkrd|_|jtjkr0|j|jkr0d|_|jt	jkrBd|_
|S )NT)r   r   EPOCHrl   r   r   r   rk   r   r	   rj   rt   r0   r0   r1   rz   c  s    z DefaultFlowCallback.on_epoch_endN)	r)   rT   rU   rV   r   r   rg   r   rz   r0   r0   r0   r1   r   >  s    r   c                   @   s\   e Zd ZdZdedddZdd Zdd	 ZdddZdd Z	dd Z
dddZdd Zd
S )ProgressCallbackz
    A [`TrainerCallback`] that displays the progress of training or evaluation.
    You can modify `max_str_len` to control how long strings are truncated when logging.
    d   )max_str_lenc                 C   s   d| _ d| _|| _dS )a!  
        Initialize the callback with optional max_str_len parameter to control string truncation length.

        Args:
            max_str_len (`int`):
                Maximum length of strings to display in logs.
                Longer strings will be truncated with a message.
        N)training_barprediction_barr   )r-   r   r0   r0   r1   r   y  s    	zProgressCallback.__init__c                 K   s    |j rt|jdd| _d| _d S )NT)totaldynamic_ncolsr   )r   r   r   r   current_steprt   r0   r0   r1   rw     s    zProgressCallback.on_train_beginc                 K   s&   |j r"| j|j| j  |j| _d S r   )r   r   updater   r   rt   r0   r0   r1   r     s    zProgressCallback.on_step_endNc                 K   sB   |j r>t|r>| jd u r2tt|| jd u dd| _| jd d S )NT)r   Zleaver   r   )r   r
   r   r   lenr   r   )r-   rO   r,   rs   r   ru   r0   r0   r1   r     s    
z#ProgressCallback.on_prediction_stepc                 K   s$   |j r | jd ur| j  d | _d S r   r   r   closert   r0   r0   r1   r     s    

zProgressCallback.on_evaluatec                 K   s$   |j r | jd ur| j  d | _d S r   r   rt   r0   r0   r1   r     s    

zProgressCallback.on_predictc           
      K   s   |j r| jd uri }| D ]F\}}t|trZt|| jkrZdt| d| j d||< q|||< q|dd }	d|v rt|d d|d< | j	t| d S )Nz%[String too long to display, length: z > z/. Consider increasing `max_str_len` if needed.]r   r   r4   )
r   r   rb   r#   rZ   r   r   poproundrA   )
r-   rO   r,   rs   r   ru   Zshallow_logsrd   re   _r0   r0   r1   r     s    
zProgressCallback.on_logc                 K   s   |j r| j  d | _d S r   )r   r   r   rt   r0   r0   r1   rx     s    
zProgressCallback.on_train_end)r   )N)N)r)   rT   rU   rV   rY   r   rw   r   r   r   r   r   rx   r0   r0   r0   r1   r   s  s   

r   c                   @   s   e Zd ZdZdddZdS )PrinterCallbackz?
    A bare [`TrainerCallback`] that just prints the logs.
    Nc                 K   s   | dd }|jrt| d S )Nr   )r   r   print)r-   rO   r,   rs   r   ru   r   r0   r0   r1   r     s    zPrinterCallback.on_log)N)r)   rT   rU   rV   r   r0   r0   r0   r1   r     s   r   c                   @   sL   e Zd ZdZdeee dddZdd Zd	d
 Z	dd Z
edddZdS )EarlyStoppingCallbacka1  
    A [`TrainerCallback`] that handles early stopping.

    Args:
        early_stopping_patience (`int`):
            Use with `metric_for_best_model` to stop training when the specified metric worsens for
            `early_stopping_patience` evaluation calls.
        early_stopping_threshold(`float`, *optional*):
            Use with TrainingArguments `metric_for_best_model` and `early_stopping_patience` to denote how much the
            specified metric must improve to satisfy early stopping conditions. `

    This callback depends on [`TrainingArguments`] argument *load_best_model_at_end* functionality to set best_metric
    in [`TrainerState`]. Note that if the [`TrainingArguments`] argument *save_steps* differs from *eval_steps*, the
    early stopping will not occur until the next save step.
    r           early_stopping_patienceearly_stopping_thresholdc                 C   s   || _ || _d| _d S )Nr   r   r   early_stopping_patience_counter)r-   r   r   r0   r0   r1   r     s    zEarlyStoppingCallback.__init__c                 C   sV   |j rtjntj}|jd u s<|||jrDt||j | jkrDd| _n|  jd7  _d S )Nr   r   )Zgreater_is_betternpZgreaterlessr   absr   r   )r-   rO   r,   rs   metric_valueoperatorr0   r0   r1   check_metric_value  s    

z(EarlyStoppingCallback.check_metric_valuec                 K   s:   |j std |jd us"J d|jtjks6J dd S )NzUsing EarlyStoppingCallback without load_best_model_at_end=True. Once training is finished, the best model will not be loaded automatically.zBEarlyStoppingCallback requires metric_for_best_model to be definedzAEarlyStoppingCallback requires IntervalStrategy of steps or epoch)Zload_best_model_at_endr   r   metric_for_best_modelr   r   NOrt   r0   r0   r1   rw     s    z$EarlyStoppingCallback.on_train_beginc                 K   sh   |j }|dsd| }||}|d u rBtd| d d S | |||| | j| jkrdd|_d S )NZeval_z@early stopping required metric_for_best_model, but did not find z so early stopping is disabledT)	r   
startswithgetr   r   r   r   r   rh   )r-   rO   r,   rs   r   ru   Zmetric_to_checkr   r0   r0   r1   r     s    



z!EarlyStoppingCallback.on_evaluater]   c                 C   s   | j | jdd| jidS )Nr   r   rq   r   r`   r0   r0   r1   r,     s    zEarlyStoppingCallback.stateN)r   r   )r)   rT   rU   rV   rY   r   rW   r   r   rw   r   r$   r,   r0   r0   r0   r1   r     s   r   )!rV   r>   r<   rL   r   typingr   r   numpyr   Z	tqdm.autor   Ztrainer_utilsr   r   r	   r
   Ztraining_argsr   utilsr   Z
get_loggerr)   r   r   r%   rg   r!   r   r   r   r   r   r0   r0   r0   r1   <module>   s2   
 ,=  5J