返回介绍

数学基础

统计学习

深度学习

工具

Scala

二、Callbacks

发布于 2023-07-17 23:38:23 字数 14271 浏览 0 评论 0 收藏 0

  1. callbacks 是一种对象,它可以自定义 PyTorch Trainertraining loopTensorFlow 中尚未实现此功能)。例如,检查 training loop 状态(用于进度报告、在 TensorBoard 或其他 ML 平台上进行 logging)并做出决定(如 early stopping )。

    callbacks"read only" 的代码,除了它们返回的 TrainerControl 对象外,它们不能改变 training loop 中的任何东西。如果需要改变 training loop ,那么你应该对 Trainer 进行子类化并覆盖你想要改变的方法。

  2. 默认情况下,Trainer 将使用以下 callbacks

    • DefaultFlowCallback:处理 logging, saving, evaluation 的默认 callback
    • PrinterCallbackProgressCallback:显示训练进度,或打印日志。如果你通过 TrainingArguments 禁用 tqdm,那么Trainer 就使用 PrinterCallback;否则就使用 ProgressCallback
    • TensorBoardCallback:如果 tensorboard可用(安装了 PyTorch >= 1.4tensorboardX ),则Trainer 就使用 TensorBoardCallback
    • WandbCallback:如果 wandb 已安装,则 Trainer 使用 WandbCallback
    • CometCallback:如果 comet_ml 已安装,则 Trainer 使用 CometCallback
    • MLflowCallback:如果 mlflow 已安装,则 Trainer 使用 MLflowCallback
    • NeptuneCallback:如果 neptune 已安装,则 Trainer 使用 NeptuneCallback
    • AzureMLCallback:如果 azureml-sdk 已安装,则 Trainer 使用 AzureMLCallback
    • CodeCarbonCallback:如果 codecarbon 已安装,则 Trainer 使用 CodeCarbonCallback
    • ClearMLCallback:如果 clearml 已安装,则 Trainer 使用 ClearMLCallback
  3. 实现 callbacks 的主要类是 TrainerCallback 。它获得用于实例化 TrainerTrainingArguments ,可以通过 TrainerState 访问该 Trainer 的内部状态,并且可以通过 TrainerControltraining loop 采取一些行动。

2.1 API

  1. class transformers.TrainerCallbackTrainerCallback ,它将在一些事件中检查 training loop 的状态并作出一些决定。

    初始化参数:

    • args:一个 TrainingArguments,指定用于实例化 Trainer 的训练参数。

    • state:一个 TrainerState,指定训练器的当前状态。

    • control:一个 TrainerControl,指定返回给训练器的对象,它可以用来做一些决定。

    • model:一个 PreTrainedModeltorch.nn.Module,指定正在训练的模型。

    • tokenizer:一个 PreTrainedTokenizer,指定用于对数据进行编码的 tokenizer

    • optimizer:一个 torch.optim.Optimizer,指定用于训练的优化器。

    • lr_scheduler:一个 torch.optim.lr_scheduler.LambdaLR,指定用于训练的学习率调度器。

    • train_dataloader:一个 torch.utils.data.DataLoader,指定 training dataloader

    • eval_dataloader:一个 torch.utils.data.DataLoader,指定 evaluation dataloader

    • metrics:一个字典 Dict[str, float],指定由上一次 evaluation 阶段计算得到的指标。

      它仅在 on_evaluate 事件中才能访问。

    • logs:一个字典 Dict[str, float],指定需要 log 的内容。

      它只能在事件 on_log 中访问。

    方法(这些参数参考初始化参数):

    • on_epoch_begin(args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):在一个 epoch 的开始时被调用的事件。

    • on_epoch_end(args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):在一个 epoch 的结束时被调用的事件。

    • on_evaluate(args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):在 evaluation 阶段之后被调用的事件。

    • on_init_end(args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):在 Trainer 的初始化结束之后被调用的事件。

    • on_log(args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):在 logging last logs 之后被调用的事件。

    • on_predict(args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics, **kwargs) :在一个成功的预测之后被调用的事件。

    • on_prediction_step(args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):在一个 prediction step 之后被调用的事件。

    • on_save(args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):在一个 checkpoint save 之后被调用的事件。

    • on_step_begin(args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):在一个 training step 之前被调用的事件。

      如果使用梯度累积gradient accumulation,那么一个 training step 可能需要若干个 inputs

    • on_step_end(args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):在一个 training step 之后被调用的事件。

      如果使用梯度累积gradient accumulation,那么一个 training step 可能需要若干个 inputs

    • on_substep_end(args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):在 gradient accumulation 期间的每个 training substep 之后被调用的事件。

    • on_train_begin(args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):在训练开始时被调用的事件。

    • on_train_end(args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):在训练结束时被调用的事件。

    在每个事件中,都有以下参数:

    • control 对象:是唯一可以被 callback 改变的对象,在这种情况下,改变它的事件应该返回修改后的版本。
    • args, state, control 是所有事件中的位置参数,而其他参数都位于 kwargs 关键字参数。你可以 unpack 你需要的关键字参数。例如:
    
    
    xxxxxxxxxx
    class PrinterCallback(TrainerCallback): def on_log(self, args, state, control, logs=None, **kwargs): _ = logs.pop("total_flos", None) if state.is_local_process_zero: print(logs)
  2. 将一个自定义的 callback 注册到 PyTorch Trainer 的例子:

    
    
    xxxxxxxxxx
    class MyCallback(TrainerCallback): def on_train_begin(self, args, state, control, **kwargs): print("Starting training") trainer = Trainer(model,args,train_dataset=train_dataset,eval_dataset=eval_dataset, callbacks=[MyCallback], # 可以传入一个类,也可以传入一个 callback 对象 )

    也可以通过如下的方式注册:

    
    
    xxxxxxxxxx
    trainer = Trainer(...) trainer.add_callback(MyCallback) # 或者 trainer.add_callback(MyCallback())
  3. library 中目前可用的 TrainerCallback

    
    
    xxxxxxxxxx
    class transformers.integrations.CometCallback() # send logs to CometML def setup(args, state, model) class transformers.DefaultFlowCallback() # default callback for logging, saving, evaluation class transformers.PrinterCallback() # just prints the logs class transformers.ProgressCallback() # displays the progress of training or evaluation class transformers.EarlyStoppingCallback( # handles early stopping, Use with TrainingArguments metric_for_best_model early_stopping_patience: int = 1, early_stopping_threshold: typing.Optional[float] = 0.0) class transformers.integrations.TensorBoardCallback( tb_writer = None ) # sends the logs to TensorBoard class transformers.integrations.WandbCallback() # sends the logs to Weight and Biases def setup(args, state, model, **kwargs ) class transformers.integrations.MLflowCallback() # sends the logs to MLflow def setup(args, state, model) class transformers.integrations.AzureMLCallback(azureml_run = None) # sends the logs to AzureML class transformers.integrations.CodeCarbonCallback() # tracks the CO2 emission of training class transformers.integrations.NeptuneCallback( # sends the logs to Neptune api_token: typing.Optional[str] = None, project: typing.Optional[str] = None, name: typing.Optional[str] = None, base_namespace: str = 'finetuning', run: typing.Optional[ForwardRef('Run')] = None, log_parameters: bool = True, log_checkpoints: typing.Optional[str] = None, **neptune_run_kwargs ) class transformers.integrations.ClearMLCallback() # sends the logs to ClearML
  4. class transformers.TrainerState:一个包含Trainer 内部状态的类,在 checkpointing 时将伴随着模型和优化器保存并传递给 TrainerCallback

    
    
    xxxxxxxxxx
    class transformers.TrainerState( epoch: typing.Optional[float] = None, global_step: int = 0, max_steps: int = 0, num_train_epochs: int = 0, total_flos: float = 0, log_history: typing.List[typing.Dict[str, float]] = None, best_metric: typing.Optional[float] = None, best_model_checkpoint: typing.Optional[str] = None, is_local_process_zero: bool = True, is_world_process_zero: bool = True, is_hyper_param_search: bool = False, trial_name: str = None, trial_params: typing.Dict[str, typing.Union[str, float, int, bool]] = None )

    参数:

    • epoch:一个浮点数,仅用于训练期间,指定当前训练所处的 epoch(小数部分代表当前 epoch 完成的百分比)。
    • global_step:一个整数,仅用于训练期间,指定已经完成的 update steps 数量。
    • max_steps:一个整数,指定当前训练需要执行的 update steps 数量。
    • total_flos:一个浮点数,指定从训练开始以来,模型所做的浮点预算的总和。以浮点形式存储,避免溢出。
    • log_history:一个关于字典的列表 List[Dict[str, float]],指定自训练开始以来完成的日志列表。
    • best_metric:一个浮点数,指定当 tracking best model 时,到目前为止遇到的最佳指标值。
    • best_model_checkpoint:一个浮点数,指定当 tracking best model 时,到目前为止遇到的最佳模型的 checkpoint 的名称。
    • is_local_process_zero:一个布尔值,指定当前进程是否是 local 的主进程(用于分布式训练的场景)。
    • is_world_process_zero:一个布尔值,指定当前进程是否是 global 的主进程。当以分布式的方式在几台机器上进行训练时,只有一个进程为 True
    • is_hyper_param_search:一个布尔值,指定我们是否正在使用 Trainer.hyperparameter_search 进行超参数搜索。这将影响数据在 TensorBoard 中的记录方式。

    注意,在 TrainerState 中,一个 step 应理解为一个 update step 。当使用 gradient accumulation 时,一个 update step 可能需要几个前向和反向传播:如果你使用 gradient_accumulation_steps=n ,那么一个 update step 需要经过n$ n $ 个 batch

    方法:

    • load_from_json(json_path: str ):从 json_path 的内容创建一个 TrainerState 实例。
    • save_to_json(json_path: str ):将当前实例的内容以 JSON 格式存储到 json_path
  5. class class transformers.TrainerControl:一个处理 Trainer 控制流的类。这个类被 TrainerCallback 用来激活 training loop 中的一些开关。

    
    
    xxxxxxxxxx
    class transformers.TrainerControl( should_training_stop: bool = False, should_epoch_stop: bool = False, should_save: bool = False, should_evaluate: bool = False, should_log: bool = False )

    参数:

    • should_training_stop:一个布尔值,指定训练是否应该被中断。如果为 True,那么这个变量将没有机会被设置为 False,因为训练将直接停止。
    • should_epoch_stop:一个布尔值,指定当前的 epoch 是否应该被中断。如果是 True ,这个变量将在下一个 epoch 的开始被设置为 False
    • should_save:一个布尔值,指定当前 step 是否应该保存模型。如果是 True ,这个变量将在下一个 step 开始时被设置为False
    • should_evaluate:一个布尔值,指定当前 step 是否应该评估模型。如果是 True ,这个变量将在下一个 step 开始时被设置为False
    • should_log:一个布尔值,指定当前 step 是否应该上报日志。如果是 True ,这个变量将在下一个 step 开始时被设置为False

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文