from __future__ import annotations

import json
import logging
import os
import shutil
from collections.abc import Iterable
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable

import numpy as np
import torch
import transformers
from packaging import version
from torch import Tensor, nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from tqdm.autonotebook import trange
from transformers import TrainerCallback, TrainerControl, TrainerState

from sentence_transformers.datasets.NoDuplicatesDataLoader import NoDuplicatesDataLoader
from sentence_transformers.datasets.SentenceLabelDataset import SentenceLabelDataset
from sentence_transformers.training_args import (
    BatchSamplers,
    MultiDatasetBatchSamplers,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.util import batch_to_device, fullname, is_datasets_available

from .evaluation import SentenceEvaluator
from .model_card_templates import ModelCardTemplate

if is_datasets_available():
    from datasets import Dataset, DatasetDict

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sentence_transformers.readers.InputExample import InputExample
    from sentence_transformers.SentenceTransformer import SentenceTransformer


class SaveModelCallback(TrainerCallback):
    """A Callback to save the model to the `output_dir`.

    There are two cases:
    1. save_best_model is True and evaluator is defined:
        We save on evaluate, but only if the new model is better than the currently saved one
        according to the evaluator.
    2. If evaluator is not defined:
        We save after the model has been trained.
    """

    def __init__(self, output_dir: str, evaluator: SentenceEvaluator | None, save_best_model: bool) -> None:
        super().__init__()
        self.output_dir = output_dir
        self.evaluator = evaluator
        self.save_best_model = save_best_model
        self.best_metric = None

    def is_better(self, new_metric: float) -> bool:
        if getattr(self.evaluator, "greater_is_better", True):
            return new_metric > self.best_metric
        return new_metric < self.best_metric

    def on_evaluate(
        self,
        args: SentenceTransformerTrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        metrics: dict[str, Any],
        model: SentenceTransformer,
        **kwargs,
    ) -> None:
        if self.evaluator is not None and self.save_best_model:
            metric_key = getattr(self.evaluator, "primary_metric", "evaluator")
            for key, value in metrics.items():
                if key.endswith(metric_key) and (self.best_metric is None or self.is_better(value)):
                    self.best_metric = value
                    model.save(self.output_dir)

    def on_train_end(
        self,
        args: SentenceTransformerTrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        model: SentenceTransformer,
        **kwargs,
    ) -> None:
        if self.evaluator is None:
            model.save(self.output_dir)


class EvaluatorCallback(TrainerCallback):
    """The SentenceTransformers.fit method always ran the evaluator on every epoch,
    in addition to every "evaluation_steps". This callback is responsible for that.

    The `.trainer` must be provided after the trainer has been created.
    """

    def __init__(self, evaluator: SentenceEvaluator, output_path: str | None = None) -> None:
        super().__init__()
        self.evaluator = evaluator
        self.output_path = output_path
        if self.output_path is not None:
            self.output_path = os.path.join(self.output_path, "eval")
            os.makedirs(self.output_path, exist_ok=True)

        self.metric_key_prefix = "eval"
        self.trainer = None

    def on_epoch_end(
        self,
        args: SentenceTransformerTrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        model: SentenceTransformer,
        **kwargs,
    ) -> None:
        evaluator_metrics = self.evaluator(
            model, output_path=self.output_path, epoch=state.epoch, steps=state.global_step
        )
        if not isinstance(evaluator_metrics, dict):
            evaluator_metrics = {"evaluator": evaluator_metrics}

        # Prefix all keys with the metric key prefix (e.g. "eval_") so they match Trainer conventions
        for key in list(evaluator_metrics.keys()):
            if not key.startswith(f"{self.metric_key_prefix}_"):
                evaluator_metrics[f"{self.metric_key_prefix}_{key}"] = evaluator_metrics.pop(key)

        if self.trainer is not None:
            self.trainer.callback_handler.on_evaluate(args, state, control, metrics=evaluator_metrics)


class OriginalCallback(TrainerCallback):
    """A Callback to invoke the original callback function that was provided to SentenceTransformer.fit()

    This callback has the following signature: `(score: float, epoch: int, steps: int) -> None`
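
    Example:
        An illustrative sketch; ``log_score`` is a hypothetical user-provided function, not part of this module::

            def log_score(score: float, epoch: int, steps: int) -> None:
                print(f"epoch={epoch} steps={steps} score={score:.4f}")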
    """

    def __init__(self, callback: Callable[[float, int, int], None], evaluator: SentenceEvaluator) -> None:
        super().__init__()
        self.callback = callback
        self.evaluator = evaluator

    def on_evaluate(
        self,
        args: transformers.TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        metrics: dict[str, Any],
        **kwargs,
    ) -> None:
        metric_key = getattr(self.evaluator, "primary_metric", "evaluator")
        for key, value in metrics.items():
            if key.endswith(metric_key):
                return self.callback(value, state.epoch, state.global_step)


class FitMixin:
    """Mixin class for injecting the `fit` and `old_fit` methods into SentenceTransformer models"""

    def fit(
        self,
        train_objectives: Iterable[tuple[DataLoader, nn.Module]],
        evaluator: SentenceEvaluator | None = None,
        epochs: int = 1,
        steps_per_epoch=None,
        scheduler: str = "WarmupLinear",
        warmup_steps: int = 10000,
        optimizer_class: type[Optimizer] = torch.optim.AdamW,
        optimizer_params: dict[str, object] = {"lr": 2e-5},
        weight_decay: float = 0.01,
        evaluation_steps: int = 0,
        output_path: str | None = None,
        save_best_model: bool = True,
        max_grad_norm: float = 1,
        use_amp: bool = False,
        callback: Callable[[float, int, int], None] | None = None,
        show_progress_bar: bool = True,
        checkpoint_path: str | None = None,
        checkpoint_save_steps: int = 500,
        checkpoint_save_total_limit: int = 0,
        resume_from_checkpoint: bool = False,
    ) -> None:
        """
        Deprecated training method from before Sentence Transformers v3.0, it is recommended to use
        :class:`~sentence_transformers.trainer.SentenceTransformerTrainer` instead. This method uses
        :class:`~sentence_transformers.trainer.SentenceTransformerTrainer` behind the scenes, but does
        not provide as much flexibility as the Trainer itself.

        This training approach uses a list of DataLoaders and Loss functions to train the model. Each DataLoader
        is sampled in turn for one batch. We sample only as many batches from each DataLoader as there are in the
        smallest one to make sure of equal training with each dataset, i.e. round robin sampling.

        This method should produce equivalent results in v3.0+ as before v3.0, but if you encounter any issues
        with your existing training scripts, then you may wish to use
        :meth:`SentenceTransformer.old_fit <sentence_transformers.SentenceTransformer.old_fit>` instead.
        That uses the old training method from before v3.0.

        Args:
            train_objectives: Tuples of (DataLoader, LossFunction). Pass
                more than one for multi-task learning
            evaluator: An evaluator (sentence_transformers.evaluation)
                evaluates the model performance during training on held-
                out dev data. It is used to determine the best model
                that is saved to disk.
            epochs: Number of epochs for training
            steps_per_epoch: Number of training steps per epoch. If set
                to None (default), one epoch is equal to the DataLoader
                size from train_objectives.
            scheduler: Learning rate scheduler. Available schedulers:
                constantlr, warmupconstant, warmuplinear, warmupcosine,
                warmupcosinewithhardrestarts
            warmup_steps: Behavior depends on the scheduler. For
                WarmupLinear (default), the learning rate is increased
                from 0 up to the maximal learning rate. After these many
                training steps, the learning rate is decreased linearly
                back to zero.
            optimizer_class: Optimizer
            optimizer_params: Optimizer parameters
            weight_decay: Weight decay for model parameters
            evaluation_steps: If > 0, evaluate the model using evaluator
                after each number of training steps
            output_path: Storage path for the model and evaluation files
            save_best_model: If true, the best model (according to
                evaluator) is stored at output_path
            max_grad_norm: Used for gradient normalization.
            use_amp: Use Automatic Mixed Precision (AMP). Only for
                Pytorch >= 1.6.0
            callback: Callback function that is invoked after each
                evaluation. It must accept the following three
                parameters in this order: `score`, `epoch`, `steps`
            show_progress_bar: If True, output a tqdm progress bar
            checkpoint_path: Folder to save checkpoints during training
            checkpoint_save_steps: Will save a checkpoint after so many
                steps
            checkpoint_save_total_limit: Total number of checkpoints to
                store
            resume_from_checkpoint: If true, searches for checkpoints
                to continue training from.
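
        Example:
            A minimal illustrative sketch; the model name and training pairs are placeholders, not part of
            this module::

                from torch.utils.data import DataLoader

                from sentence_transformers import InputExample, SentenceTransformer, losses

                model = SentenceTransformer("all-MiniLM-L6-v2")
                train_examples = [
                    InputExample(texts=["A sentence", "A paraphrase"], label=1.0),
                    InputExample(texts=["Another sentence", "Something unrelated"], label=0.0),
                ]
                train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
                train_loss = losses.CosineSimilarityLoss(model)
                model.fit(
                    train_objectives=[(train_dataloader, train_loss)],
                    epochs=1,
                    warmup_steps=10,
                )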
        """
        if not is_datasets_available():
            raise ImportError("Please install `datasets` to use this function: `pip install datasets`.")

        from sentence_transformers.trainer import SentenceTransformerTrainer

        data_loaders, loss_fns = zip(*train_objectives)

        # Clear the dataloaders from collate functions as we just want raw InputExamples
        def identity(batch):
            return batch

        for data_loader in data_loaders:
            data_loader.collate_fn = identity

        batch_size = 8
        batch_sampler = BatchSamplers.BATCH_SAMPLER
        # Convert each dataloader into a `datasets.Dataset` so the v3 Trainer can consume it
        train_dataset_dict = {}
        for loader_idx, data_loader in enumerate(data_loaders, start=1):
            if isinstance(data_loader, NoDuplicatesDataLoader):
                batch_sampler = BatchSamplers.NO_DUPLICATES
            elif hasattr(data_loader, "dataset") and isinstance(data_loader.dataset, SentenceLabelDataset):
                batch_sampler = BatchSamplers.GROUP_BY_LABEL

            batch_size = getattr(data_loader, "batch_size", batch_size)

            texts = []
            labels = []
            for batch in data_loader:
                batch_texts, batch_labels = zip(*[(example.texts, example.label) for example in batch])
                texts += batch_texts
                labels += batch_labels
            dataset = Dataset.from_dict({f"sentence_{idx}": text for idx, text in enumerate(zip(*texts))})
            # Add a label column, unless all labels are 0 (the default value for `label` in InputExample)
            add_label_column = True
            try:
                if set(labels) == {0}:
                    add_label_column = False
            except TypeError:
                pass
            if add_label_column:
                dataset = dataset.add_column("label", labels)
            train_dataset_dict[f"_dataset_{loader_idx}"] = dataset

        train_dataset_dict = DatasetDict(train_dataset_dict)

        def _default_checkpoint_dir() -> str:
            dir_name = "checkpoints/model"
            idx = 1
            while Path(dir_name).exists() and len(list(Path(dir_name).iterdir())) != 0:
                dir_name = f"checkpoints/model_{idx}"
                idx += 1
            return dir_name

        # Convert the losses into a dict matching the dataset names
        loss_fn_dict = {f"_dataset_{idx}": loss_fn for idx, loss_fn in enumerate(loss_fns, start=1)}

        # Use steps_per_epoch if it is set
        max_steps = -1
        if steps_per_epoch is not None and steps_per_epoch > 0:
            if epochs == 1:
                max_steps = steps_per_epoch
            else:
                logger.warning(
                    "Setting `steps_per_epoch` alongside `epochs` > 1 no longer works. "
                    "We will train with the full datasets per epoch."
                )
                steps_per_epoch = None

        # Transformers renamed `evaluation_strategy` to `eval_strategy` in v4.41.0
        eval_strategy_key = (
            "eval_strategy"
            if version.parse(transformers.__version__) >= version.parse("4.41.0")
            else "evaluation_strategy"
        )
        args = SentenceTransformerTrainingArguments(
            output_dir=checkpoint_path or _default_checkpoint_dir(),
            batch_sampler=batch_sampler,
            multi_dataset_batch_sampler=MultiDatasetBatchSamplers.ROUND_ROBIN,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            num_train_epochs=epochs,
            max_steps=max_steps,
            **{
                eval_strategy_key: "steps" if evaluation_steps is not None and evaluation_steps > 0 else "no",
            },
            eval_steps=evaluation_steps,
            max_grad_norm=max_grad_norm,
            fp16=use_amp,
            disable_tqdm=not show_progress_bar,
            save_strategy="steps" if checkpoint_path is not None else "no",
            save_steps=checkpoint_save_steps,
            save_total_limit=checkpoint_save_total_limit,
        )

        if steps_per_epoch is None or steps_per_epoch == 0:
            steps_per_epoch = min([len(train_dataset) // batch_size for train_dataset in train_dataset_dict.values()])
        num_train_steps = int(steps_per_epoch * epochs)

        # Prepare optimizer & scheduler, excluding bias and LayerNorm parameters from weight decay
        param_optimizer = list(self.named_parameters())

        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                "weight_decay": weight_decay,
            },
            {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
        ]

        optimizer = optimizer_class(optimizer_grouped_parameters, **optimizer_params)
        scheduler_obj = self._get_scheduler(
            optimizer, scheduler=scheduler, warmup_steps=warmup_steps, t_total=num_train_steps
        )

        callbacks = []
        if evaluator is not None:
            callbacks.append(EvaluatorCallback(evaluator, output_path))
            if callback is not None:
                callbacks.append(OriginalCallback(callback, evaluator))

        trainer = SentenceTransformerTrainer(
            model=self,
            args=args,
            train_dataset=train_dataset_dict,
            eval_dataset=None,
            loss=loss_fn_dict,
            evaluator=evaluator,
            optimizers=(optimizer, scheduler_obj),
            callbacks=callbacks,
        )
        # The EvaluatorCallback needs a reference to the trainer to propagate its metrics
        for callback_obj in trainer.callback_handler.callbacks:
            if isinstance(callback_obj, EvaluatorCallback):
                callback_obj.trainer = trainer

        if output_path is not None:
            trainer.add_callback(SaveModelCallback(output_path, evaluator, save_best_model))

        if resume_from_checkpoint and checkpoint_path is not None:
            if os.path.exists(checkpoint_path) and os.path.isdir(checkpoint_path):
                logger.info(f"Looking for checkpoints in: {checkpoint_path}")
                all_checkpoints = [
                    checkpoint
                    for checkpoint in os.listdir(checkpoint_path)
                    if checkpoint.startswith("checkpoint-") and checkpoint.split("-")[1].isdigit()
                ]
                if all_checkpoints:
                    latest_checkpoint = max(all_checkpoints, key=lambda x: int(x.split("-")[1]))
                    resume_from_checkpoint = os.path.join(checkpoint_path, latest_checkpoint)
                    logger.info(f"Resuming from latest checkpoint: {resume_from_checkpoint}")
                else:
                    logger.warning(f"No checkpoints found in checkpoint directory: {checkpoint_path}")
                    resume_from_checkpoint = False
            else:
                logger.warning(f"Checkpoint directory does not exist or is not a directory: {checkpoint_path}")
                resume_from_checkpoint = False

        trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    @staticmethod
    def _get_scheduler(optimizer, scheduler: str, warmup_steps: int, t_total: int) -> LambdaLR:
        """
        Returns the correct learning rate scheduler. Available scheduler:

        - constantlr,
        - warmupconstant,
        - warmuplinear,
        - warmupcosine,
        - warmupcosinewithhardrestarts
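
        Example:
            An illustrative sketch; ``optimizer`` is assumed to already exist::

                scheduler = FitMixin._get_scheduler(
                    optimizer, scheduler="warmuplinear", warmup_steps=100, t_total=1000
                )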
        """
        scheduler = scheduler.lower()
        if scheduler == "constantlr":
            return transformers.get_constant_schedule(optimizer)
        elif scheduler == "warmupconstant":
            return transformers.get_constant_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps)
        elif scheduler == "warmuplinear":
            return transformers.get_linear_schedule_with_warmup(
                optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total
            )
        elif scheduler == "warmupcosine":
            return transformers.get_cosine_schedule_with_warmup(
                optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total
            )
        elif scheduler == "warmupcosinewithhardrestarts":
            return transformers.get_cosine_with_hard_restarts_schedule_with_warmup(
                optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total
            )
        else:
            raise ValueError(f"Unknown scheduler {scheduler}")

    def smart_batching_collate(self, batch: list[InputExample]) -> tuple[list[dict[str, Tensor]], Tensor]:
        """
        Transforms a batch from a SmartBatchingDataset to a batch of tensors for the model
        Here, batch is a list of InputExample instances: [InputExample(...), ...]

        Args:
            batch: a batch from a SmartBatchingDataset

        Returns:
            a batch of tensors for the model
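
        Example:
            Illustrative; this mirrors how :meth:`old_fit` attaches the method to each dataloader::

                train_dataloader.collate_fn = model.smart_batching_collate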
        """
        texts = [example.texts for example in batch]
        # Transpose so that each text position is tokenized as one column across the batch
        sentence_features = [self.tokenize(sentence) for sentence in zip(*texts)]
        labels = [example.label for example in batch]
        if labels and isinstance(labels[0], np.ndarray):
            labels_tensor = torch.from_numpy(np.stack(labels))
        else:
            labels_tensor = torch.tensor(labels)
        return sentence_features, labels_tensor

    def old_fit(
        self,
        train_objectives: Iterable[tuple[DataLoader, nn.Module]],
        evaluator: SentenceEvaluator | None = None,
        epochs: int = 1,
        steps_per_epoch=None,
        scheduler: str = "WarmupLinear",
        warmup_steps: int = 10000,
        optimizer_class: type[Optimizer] = torch.optim.AdamW,
        optimizer_params: dict[str, object] = {"lr": 2e-5},
        weight_decay: float = 0.01,
        evaluation_steps: int = 0,
        output_path: str | None = None,
        save_best_model: bool = True,
        max_grad_norm: float = 1,
        use_amp: bool = False,
        callback: Callable[[float, int, int], None] | None = None,
        show_progress_bar: bool = True,
        checkpoint_path: str | None = None,
        checkpoint_save_steps: int = 500,
        checkpoint_save_total_limit: int = 0,
    ) -> None:
        """
        Deprecated training method from before Sentence Transformers v3.0, it is recommended to use
        :class:`sentence_transformers.trainer.SentenceTransformerTrainer` instead. This method should
        only be used if you encounter issues with your existing training scripts after upgrading to v3.0+.

        This training approach uses a list of DataLoaders and Loss functions to train the model. Each DataLoader
        is sampled in turn for one batch. We sample only as many batches from each DataLoader as there are in the
        smallest one to make sure of equal training with each dataset, i.e. round robin sampling.

        Args:
            train_objectives: Tuples of (DataLoader, LossFunction). Pass
                more than one for multi-task learning
            evaluator: An evaluator (sentence_transformers.evaluation)
                evaluates the model performance during training on held-
                out dev data. It is used to determine the best model
                that is saved to disk.
            epochs: Number of epochs for training
            steps_per_epoch: Number of training steps per epoch. If set
                to None (default), one epoch is equal to the DataLoader
                size from train_objectives.
            scheduler: Learning rate scheduler. Available schedulers:
                constantlr, warmupconstant, warmuplinear, warmupcosine,
                warmupcosinewithhardrestarts
            warmup_steps: Behavior depends on the scheduler. For
                WarmupLinear (default), the learning rate is increased
                from 0 up to the maximal learning rate. After these many
                training steps, the learning rate is decreased linearly
                back to zero.
            optimizer_class: Optimizer
            optimizer_params: Optimizer parameters
            weight_decay: Weight decay for model parameters
            evaluation_steps: If > 0, evaluate the model using evaluator
                after each number of training steps
            output_path: Storage path for the model and evaluation files
            save_best_model: If true, the best model (according to
                evaluator) is stored at output_path
            max_grad_norm: Used for gradient normalization.
            use_amp: Use Automatic Mixed Precision (AMP). Only for
                Pytorch >= 1.6.0
            callback: Callback function that is invoked after each
                evaluation. It must accept the following three
                parameters in this order: `score`, `epoch`, `steps`
            show_progress_bar: If True, output a tqdm progress bar
            checkpoint_path: Folder to save checkpoints during training
            checkpoint_save_steps: Will save a checkpoint after so many
                steps
            checkpoint_save_total_limit: Total number of checkpoints to
                store
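
        Example:
            Usage mirrors :meth:`FitMixin.fit`; an illustrative sketch with placeholder variables::

                model.old_fit(
                    train_objectives=[(train_dataloader, train_loss)],
                    epochs=1,
                    warmup_steps=100,
                )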
        """
        # Add info to the model card
        info_loss_functions = []
        for dataloader, loss in train_objectives:
            info_loss_functions.extend(ModelCardTemplate.get_train_objective_info(dataloader, loss))
        info_loss_functions = "\n\n".join([text for text in info_loss_functions])

        info_fit_parameters = json.dumps(
            {
                "evaluator": fullname(evaluator),
                "epochs": epochs,
                "steps_per_epoch": steps_per_epoch,
                "scheduler": scheduler,
                "warmup_steps": warmup_steps,
                "optimizer_class": str(optimizer_class),
                "optimizer_params": optimizer_params,
                "weight_decay": weight_decay,
                "evaluation_steps": evaluation_steps,
                "max_grad_norm": max_grad_norm,
            },
            indent=4,
            sort_keys=True,
        )
        self._model_card_text = None
        self._model_card_vars["{TRAINING_SECTION}"] = ModelCardTemplate.__TRAINING_SECTION__.replace(
            "{LOSS_FUNCTIONS}", info_loss_functions
        ).replace("{FIT_PARAMETERS}", info_fit_parameters)

        if use_amp:
            from torch.cuda.amp import autocast

            scaler = torch.cuda.amp.GradScaler()

        self.to(self.device)

        dataloaders = [dataloader for dataloader, _ in train_objectives]

        # Use smart batching
        for dataloader in dataloaders:
            dataloader.collate_fn = self.smart_batching_collate

        loss_models = [loss for _, loss in train_objectives]
        for loss_model in loss_models:
            loss_model.to(self.device)

        self.best_score = -9999999

        if steps_per_epoch is None or steps_per_epoch == 0:
            steps_per_epoch = min([len(dataloader) for dataloader in dataloaders])

        num_train_steps = int(steps_per_epoch * epochs)

        # Prepare one optimizer and scheduler per training objective
        optimizers = []
        schedulers = []
        for loss_model in loss_models:
            param_optimizer = list(loss_model.named_parameters())

            no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                    "weight_decay": weight_decay,
                },
                {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
            ]

            optimizer = optimizer_class(optimizer_grouped_parameters, **optimizer_params)
            scheduler_obj = self._get_scheduler(
                optimizer, scheduler=scheduler, warmup_steps=warmup_steps, t_total=num_train_steps
            )

            optimizers.append(optimizer)
            schedulers.append(scheduler_obj)

        global_step = 0
        data_iterators = [iter(dataloader) for dataloader in dataloaders]

        num_train_objectives = len(train_objectives)

        skip_scheduler = False
        for epoch in trange(epochs, desc="Epoch", disable=not show_progress_bar):
            training_steps = 0

            for loss_model in loss_models:
                loss_model.zero_grad()
                loss_model.train()

            for _ in trange(steps_per_epoch, desc="Iteration", smoothing=0.05, disable=not show_progress_bar):
                # Round robin over the training objectives: one batch from each per step
                for train_idx in range(num_train_objectives):
                    loss_model = loss_models[train_idx]
                    optimizer = optimizers[train_idx]
                    scheduler = schedulers[train_idx]
                    data_iterator = data_iterators[train_idx]

                    try:
                        data = next(data_iterator)
                    except StopIteration:
                        data_iterator = iter(dataloaders[train_idx])
                        data_iterators[train_idx] = data_iterator
                        data = next(data_iterator)

                    features, labels = data
                    labels = labels.to(self.device)
                    features = list(map(lambda batch: batch_to_device(batch, self.device), features))

                    if use_amp:
                        with autocast():
                            loss_value = loss_model(features, labels)

                        scale_before_step = scaler.get_scale()
                        scaler.scale(loss_value).backward()
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(loss_model.parameters(), max_grad_norm)
                        scaler.step(optimizer)
                        scaler.update()

                        skip_scheduler = scaler.get_scale() != scale_before_step
                    else:
                        loss_value = loss_model(features, labels)
                        loss_value.backward()
                        torch.nn.utils.clip_grad_norm_(loss_model.parameters(), max_grad_norm)
                        optimizer.step()

                    optimizer.zero_grad()

                    if not skip_scheduler:
                        scheduler.step()

                training_steps += 1
                global_step += 1

                if evaluation_steps > 0 and training_steps % evaluation_steps == 0:
                    self._eval_during_training(
                        evaluator, output_path, save_best_model, epoch, training_steps, callback
                    )

                    for loss_model in loss_models:
                        loss_model.zero_grad()
                        loss_model.train()

                if (
                    checkpoint_path is not None
                    and checkpoint_save_steps is not None
                    and checkpoint_save_steps > 0
                    and global_step % checkpoint_save_steps == 0
                ):
                    self._save_checkpoint(checkpoint_path, checkpoint_save_total_limit, global_step)

            self._eval_during_training(evaluator, output_path, save_best_model, epoch, -1, callback)

        if evaluator is None and output_path is not None:  # No evaluator, but output path: save final model version
            self.save(output_path)

        if checkpoint_path is not None:
            self._save_checkpoint(checkpoint_path, checkpoint_save_total_limit, global_step)

    def _eval_during_training(self, evaluator, output_path, save_best_model, epoch, steps, callback) -> None:
        """Runs evaluation during the training"""
        eval_path = output_path
        if output_path is not None:
            os.makedirs(output_path, exist_ok=True)
            eval_path = os.path.join(output_path, "eval")
            os.makedirs(eval_path, exist_ok=True)

        if evaluator is not None:
            score = evaluator(self, output_path=eval_path, epoch=epoch, steps=steps)
            if callback is not None:
                callback(score, epoch, steps)
            if score > self.best_score:
                self.best_score = score
                if save_best_model:
                    self.save(output_path)

    def _save_checkpoint(self, checkpoint_path, checkpoint_save_total_limit, step) -> None:
        # Store a new checkpoint
        self.save(os.path.join(checkpoint_path, str(step)))

        # Delete old checkpoints
        if checkpoint_save_total_limit is not None and checkpoint_save_total_limit > 0:
            old_checkpoints = []
            for subdir in os.listdir(checkpoint_path):
                if subdir.isdigit():
                    old_checkpoints.append({"step": int(subdir), "path": os.path.join(checkpoint_path, subdir)})

            if len(old_checkpoints) > checkpoint_save_total_limit:
                old_checkpoints = sorted(old_checkpoints, key=lambda x: x["step"])
                shutil.rmtree(old_checkpoints[0]["path"])