How do I save a checkpoint of a Transformers GPT2 model to continue training?

Posted 2025-01-09 05:43:58


I am retraining the GPT2 language model, following this blog:

https://towardsdatascience.com/train-gpt-2-in-your-own-language-fc6ad4d60171

There, they train a network based on GPT2, and I am trying to recreate the same setup. However, my dataset is too large (250 MB) to train in a single session, so I want to train in intervals. In other words, I want to checkpoint the model during training. How can I do this?


Comments (1)

贪恋 2025-01-16 05:43:58
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir=model_checkpoint,  # directory where the model will be saved
    # other hyper-params
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_set,
    eval_dataset=dev_set,
    tokenizer=tokenizer,
)

trainer.train()
# Save the model (and tokenizer) to output_dir
trainer.save_model()

To resume later, reload the saved weights and hand them to a fresh Trainer:

from transformers import AutoModelForCausalLM

def prepare_model(tokenizer, model_name_path):
    # Load the saved weights and resize the embeddings to match the tokenizer
    model = AutoModelForCausalLM.from_pretrained(model_name_path)
    model.resize_token_embeddings(len(tokenizer))
    return model

# Assuming the tokenizer is defined, you can simply pass the saved model directory path.
model = prepare_model(tokenizer, model_checkpoint)
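Alternatively, the Trainer can write checkpoints on its own during training and pick up where it left off. A minimal sketch, assuming the same model, tokenizer, and datasets as above; the save_steps and save_total_limit values are illustrative, not recommendations:

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir=model_checkpoint,
    save_steps=500,       # illustrative: write a checkpoint every 500 steps
    save_total_limit=2,   # illustrative: keep only the two newest checkpoints
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_set,
    eval_dataset=dev_set,
    tokenizer=tokenizer,
)

# resume_from_checkpoint=True resumes from the latest checkpoint-* folder in
# output_dir, restoring optimizer and scheduler state along with the weights.
trainer.train(resume_from_checkpoint=True)

This suits training in intervals, since the optimizer state is restored for you; the manual save_model()/from_pretrained() route above is enough when you only need the final weights.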