Continue training an already trained model in SageMaker with Hugging Face, without reinitializing it
Let's say I have successfully trained a model on some training data for 10 epochs. How can I then access the very same model and train for a further 10 epochs?
The docs say "you need to specify a checkpoint output path through hyperparameters". How do I do that?
from sagemaker.huggingface import HuggingFace

# define my estimator the standard way
# (role, hyperparameters, metric_definitions and the input paths are defined earlier)
huggingface_estimator = HuggingFace(
    entry_point='train.py',
    source_dir='./scripts',
    instance_type='ml.p3.2xlarge',
    instance_count=1,
    role=role,
    transformers_version='4.10',
    pytorch_version='1.9',
    py_version='py38',
    hyperparameters=hyperparameters,
    metric_definitions=metric_definitions
)

# train the model
huggingface_estimator.fit(
    {'train': training_input_path, 'test': test_input_path}
)
If I run `huggingface_estimator.fit` again, it just starts the whole thing over and overwrites my previous training.
You can find the relevant checkpoint save/load code in Spot Instances - Amazon SageMaker x Hugging Face Transformers.
(The example enables Spot Instances, but the same approach works with on-demand instances.)
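On the train.py side, resuming comes down to finding the newest checkpoint folder and handing it to the Trainer. A minimal sketch, assuming train.py uses the transformers `Trainer` with its `output_dir` set to `/opt/ml/checkpoints`; `find_last_checkpoint` is a hypothetical helper name (transformers itself ships a similar `get_last_checkpoint` in `transformers.trainer_utils`):

```python
import os
import re
from typing import Optional

# The Hugging Face Trainer names its checkpoint folders "checkpoint-<global_step>"
CHECKPOINT_DIR_RE = re.compile(r"^checkpoint-(\d+)$")


def find_last_checkpoint(folder: str) -> Optional[str]:
    """Return the checkpoint-<step> subfolder of `folder` with the highest step, or None."""
    if not os.path.isdir(folder):
        return None  # first job in the series: nothing to resume from
    steps = []
    for name in os.listdir(folder):
        match = CHECKPOINT_DIR_RE.match(name)
        if match and os.path.isdir(os.path.join(folder, name)):
            steps.append((int(match.group(1)), name))
    if not steps:
        return None
    # compare steps numerically, so checkpoint-1000 beats checkpoint-500
    _, latest = max(steps)
    return os.path.join(folder, latest)


# Hypothetical use inside train.py, assuming `trainer` is a transformers.Trainer
# whose TrainingArguments.output_dir is "/opt/ml/checkpoints":
#
#     last_checkpoint = find_last_checkpoint("/opt/ml/checkpoints")
#     trainer.train(resume_from_checkpoint=last_checkpoint)  # None -> fresh start
```

If no checkpoint exists yet, `resume_from_checkpoint=None` makes the Trainer start from scratch, so the same script serves the first job and every later one. Note that when resuming, the Trainer treats `num_train_epochs` as the total number of epochs rather than additional ones, so training a further 10 epochs would presumably mean passing 20 to the second job.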
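In the hyperparameters, set `'output_dir':'/opt/ml/checkpoints'`, and in the Estimator set `checkpoint_s3_uri` to an S3 location that is unique to the series of jobs you'll run. SageMaker uploads whatever the training script writes to `/opt/ml/checkpoints` to that S3 location, and copies it back into `/opt/ml/checkpoints` when the next job with the same `checkpoint_s3_uri` starts, so train.py can resume from the latest checkpoint instead of starting over.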
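The `output_dir` hyperparameter and the `checkpoint_s3_uri` setting can be applied to the estimator arguments from the question. A sketch under assumptions: `with_checkpointing` is a hypothetical helper, train.py is assumed to forward the `output_dir` hyperparameter to `TrainingArguments`, and the S3 URI in the usage comment is made up:

```python
from typing import Any, Dict

# SageMaker's default checkpoint_local_path: it syncs this folder with checkpoint_s3_uri
CHECKPOINT_LOCAL_PATH = "/opt/ml/checkpoints"


def with_checkpointing(estimator_kwargs: Dict[str, Any], checkpoint_s3_uri: str) -> Dict[str, Any]:
    """Return a copy of the HuggingFace estimator kwargs with checkpointing switched on."""
    kwargs = dict(estimator_kwargs)
    hyperparameters = dict(kwargs.get("hyperparameters") or {})
    # assumed: train.py passes --output_dir to TrainingArguments, so the Trainer
    # writes its checkpoint-<step> folders into the synced directory
    hyperparameters["output_dir"] = CHECKPOINT_LOCAL_PATH
    kwargs["hyperparameters"] = hyperparameters
    # reuse the same URI for every job in the series so each job sees earlier checkpoints
    kwargs["checkpoint_s3_uri"] = checkpoint_s3_uri
    return kwargs


# Hypothetical usage (requires the sagemaker SDK and AWS credentials):
#
#     from sagemaker.huggingface import HuggingFace
#     kwargs = with_checkpointing(base_kwargs, "s3://my-bucket/my-series/checkpoints")
#     HuggingFace(**kwargs).fit({'train': training_input_path, 'test': test_input_path})
```

Building a new estimator with the same `checkpoint_s3_uri` and calling `fit` should then start a new training job that continues from the stored checkpoints rather than from a freshly initialized model.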