在 Sagemaker 和 Huggingface 中训练已经训练好的模型,无需重新初始化

发布于 2025-01-20 23:17:46 字数 842 浏览 4 评论 0原文

假设我已经在一些训练数据上成功训练了一个模型 10 个时期。然后我如何访问相同的模型并再训练 10 个 epoch?

在文档中,它建议“您需要指定通过超参数的检查点输出路径” -->如何?

# define my estimator the standard way
huggingface_estimator = HuggingFace(
    entry_point='train.py',
    source_dir='./scripts',
    instance_type='ml.p3.2xlarge',
    instance_count=1,
    role=role,
    transformers_version='4.10',
    pytorch_version='1.9',
    py_version='py38',
    hyperparameters = hyperparameters,
    metric_definitions=metric_definitions
)

# train the model
huggingface_estimator.fit(
    {'train': training_input_path, 'test': test_input_path}
)

如果我再次运行 huggingface_estimator.fit ,它就会重新开始整个过程​​并覆盖我之前的训练。

Let's say I have successfully trained a model on some training data for 10 epochs. How can I then access the very same model and train for a further 10 epochs?

In the docs it suggests "you need to specify a checkpoint output path through hyperparameters" --> how?

# define my estimator the standard way
huggingface_estimator = HuggingFace(
    entry_point='train.py',
    source_dir='./scripts',
    instance_type='ml.p3.2xlarge',
    instance_count=1,
    role=role,
    transformers_version='4.10',
    pytorch_version='1.9',
    py_version='py38',
    hyperparameters = hyperparameters,
    metric_definitions=metric_definitions
)

# train the model
huggingface_estimator.fit(
    {'train': training_input_path, 'test': test_input_path}
)

If I run huggingface_estimator.fit again it will just start the whole thing over again and overwrite my previous training.

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

西瑶 2025-01-27 23:17:46

您可以在 Spot 中找到相关的检查点保存/加载代码实例 - Amazon SageMaker x Hugging Face Transformers
(该示例启用 Spot 实例,但您可以按需使用)。

  1. 在超参数中设置:'output_dir':'/opt/ml/checkpoints'
  2. 您在估算器中定义一个 checkpoint_s3_uri(这对于您将运行的一系列作业来说是唯一的)。
  3. 您为 train.py 添加代码以支持检查点:
from Transformers.trainer_utils import get_last_checkpoint

# 检查检查点是否存在,如果存在则继续训练
如果 get_last_checkpoint(args.output_dir) 不是 None:
    logger.info("*****继续训练*****")
    最后检查点 = get_最后检查点(args.output_dir)
    trainer.train(resume_from_checkpoint=last_checkpoint)
别的:
    训练师.train()

You can find the relevant checkpoint save/load code in Spot Instances - Amazon SageMaker x Hugging Face Transformers.
(The example enables Spot instances, but you can use on-demand).

  1. In hyperparameters you set: 'output_dir':'/opt/ml/checkpoints'.
  2. You define a checkpoint_s3_uri in the Estimator (which is unique to the series of jobs you'll run).
  3. You add code for train.py to support checkpointing:
from transformers.trainer_utils import get_last_checkpoint

# check if checkpoint existing if so continue training
if get_last_checkpoint(args.output_dir) is not None:
    logger.info("***** continue training *****")
    last_checkpoint = get_last_checkpoint(args.output_dir)
    trainer.train(resume_from_checkpoint=last_checkpoint)
else:
    trainer.train()
~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文