从文件中加载具有正确参数的 pytorch 模型

发布于 2025-01-14 03:16:29 字数 1251 浏览 6 评论 0原文

按照 Chris McCormick 的教程创建 BERT 假新闻检测器(链接此处),最后他使用以下代码保存 PyTorch 模型:

output_dir = './model_save/'
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
    
    
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
model_to_save = model.module if hasattr(model, 'module') else model
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

正如他自己所说,可以使用以下命令重新加载它: from_pretrained()。目前,代码的作用是创建一个包含 6 个文件的输出目录:

config.json
merges.txt
pytorch_model.bin
special_tokens_map.json
tokenizer_config.json
vocab.json

那么如何使用 from_pretrained() 方法来加载模型及其所有参数和各自的权重,以及我要加载哪些文件从六个开始使用?

我知道模型可以这样加载(来自 PyTorch 文档):

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

但是如何加载我可以利用输出目录中的文件来执行此操作吗?

任何帮助表示赞赏!

Having followed Chris McCormick's tutorial for creating a BERT Fake News Detector (link here), at the end he saves the PyTorch model using the following code:

output_dir = './model_save/'
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
    
    
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
model_to_save = model.module if hasattr(model, 'module') else model
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

As he says himself, it can be reloaded using from_pretrained(). Currently, what the code does is create an output directory with 6 files:

config.json
merges.txt
pytorch_model.bin
special_tokens_map.json
tokenizer_config.json
vocab.json

So how can I use the from_pretrained() method to load the model with all of its arguments and respective weights, and which files do I use from the six?

I understand that a model can be loaded as such (from PyTorch documentation):

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

but how can I make use of the files in the output directory to do this?

Any help is appreciated!

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

静若繁花 2025-01-21 03:16:29

我只需将模型路径提供给 from_pretrained() 函数就可以完成此任务。 from_pretrained 函数能够识别相关的 json 配置文件并加载模型。像这样:

TheModelClass.from_pretrained(output_dir)

有时尝试一些代码并看看它是否有效会有所帮助。

I was able to accomplish this just feeding the model path to the from_pretrained() function. The from_pretrained function was able to identify relevant json config files and load the model. Like this:

TheModelClass.from_pretrained(output_dir)

Sometimes it helps to just try some code and see if it works.

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文