返回介绍

数学基础

统计学习

深度学习

工具

Scala

三、Keras callbacks

发布于 2023-07-17 23:38:23 字数 5867 浏览 0 评论 0 收藏 0

3.1 API

  1. class transformers.KerasMetricCallback:用于 kerascallback,用于在每个 epoch 结束时计算指标。

    与普通的 Keras 指标不同,这些指标不需要由 TF 来编译。它对于像 BLEUROUGE 这样需要字符串操作或 generation loop 的常见 NLP 指标特别有用,这些指标不能被编译。预测(或生成)将在 eval_dataset 上计算,然后以 np.ndarray 格式传递给metric_fnmetric_fn应该计算指标并返回一个字典,字典的键为指标名、值为指标值。

    
    
    xxxxxxxxxx
    class transformers.KerasMetricCallback( metric_fn: typing.Callable, eval_dataset: typing.Union[tensorflow.python.data.ops.dataset_ops.DatasetV2, numpy.ndarray, tensorflow.python.framework.ops.Tensor, tuple, dict], output_cols: typing.Optional[typing.List[str]] = None, label_cols: typing.Optional[typing.List[str]] = None, batch_size: typing.Optional[int] = None, predict_with_generate: bool = False, use_xla_generation: bool = False, generate_kwargs: typing.Optional[dict] = None )

    参数:

    • metric_fn:一个可调用对象,指定度量函数。调用metric_fn 时需要提供两个参数:predictionslabels,它们分别对应了模型的输出结果、以及 ground-truth labelmetric_fn 函数需要返回一个字典,字典的键为指标名、值为指标值。

      下面是一个摘要模型计算 ROUGE 分数的 metric_fn 的示例:

      
      
      xxxxxxxxxx
      from datasets import load_metric rouge_metric = load_metric("rouge") def rouge_fn(predictions, labels): decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True) decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) result = rouge_metric.compute(predictions=decoded_predictions, references=decoded_labels) return {key: value.mid.fmeasure * 100 for key, value in result.items()}
    • eval_dataset:一个 tf.data.Dataset 或字典或元组或 np.ndarraytf.Tensor,指定验证数据集。

    • output_cols:一个关于字符串的列表,指定模型输出中的哪些列作为 predictions 。默认为所有列。

    • label_cols:一个关于字符串的列表,指定验证集中的哪些列作为 label 列。如果未提供,则自动检测。

    • batch_size:一个整数,指定 batch size 。只有在验证集不是 pre-batched tf.data.Dataset 时才起作用。

    • predict_with_generate:一个布尔值,指定是否应该使用 model.generate() 来获取模型的输出。

    • use_xla_generation:一个布尔值,如果我们要执行 generating ,是否要用 XLA 来编译 model generation 。这可以极大地提高生成的速度(最多可以提高 100 倍),但是需要对每个 input shape 进行新的 XLA 编译。当使用 XLA generation 时,最好将你的输入填充到相同的大小,或者在你的 tokenizerDataCollator 中使用 pad_to_multiple_of 参数,这将减少 unique input shape 的数量,并节省大量的编译时间。

      如果 predict_with_generate = False ,该参数没有影响。

    • generate_kwargs:关键字参数,用于 generating 时传递给 model.generate() 的关键字参数。

      如果 predict_with_generate = False ,该参数没有影响。

  2. class transformers.PushToHubCallback:用于 kerascallback,用于定期保存和推送模型到 Hub

    
    
    xxxxxxxxxx
    class transformers.PushToHubCallback( output_dir: typing.Union[str, pathlib.Path], save_strategy: typing.Union[str, transformers.trainer_utils.IntervalStrategy] = 'epoch', save_steps: typing.Optional[int] = None, tokenizer: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None, hub_model_id: typing.Optional[str] = None, hub_token: typing.Optional[str] = None, checkpoint: bool = False, **model_card_args )

    参数:

    • output_dir:一个字符串,指定输出目录,model predictionsmodel checkpoints 将被写入该目录并与 Hub 上的 repo 同步。
    • save_strategy/save_steps:参考 transformers.TrainingArguments
    • tokenizer:一个 PreTrainedTokenizerBase,指定模型使用的 tokenizer 。如果提供,将与模型权重一起上传到 repo
    • hub_model_id/hub_token:参考 transformers.TrainingArguments
    • checkpoint:一个布尔值,指定是否保存完整的 training checkpoints (包括 epochoptimizer state )以允许恢复训练。只在 save_strategy="epoch" 时可用。

    示例:

    
    
    xxxxxxxxxx
    from transformers.keras_callbacks import PushToHubCallback push_to_hub_callback = PushToHubCallback( output_dir="./model_save", tokenizer=tokenizer, hub_model_id="gpt5-7xlarge", ) model.fit(train_dataset, callbacks=[push_to_hub_callback])

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文