Loading a PyTorch model from file with the correct parameters
I followed Chris McCormick's tutorial for creating a BERT fake news detector (link here). At the end, he saves the PyTorch model using the following code:
import os

output_dir = './model_save/'
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
model_to_save = model.module if hasattr(model, 'module') else model
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
As he says himself, it can be reloaded using from_pretrained(). Currently, the code creates an output directory containing six files:
config.json
merges.txt
pytorch_model.bin
special_tokens_map.json
tokenizer_config.json
vocab.json
So how can I use the from_pretrained() method to load the model with all of its arguments and respective weights, and which of the six files do I use?
I understand that a model can be loaded as such (from PyTorch documentation):
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
but how can I make use of the files in the output directory to do this?
Any help is appreciated!
1 Answer
I was able to accomplish this just feeding the model path to the from_pretrained() function. The from_pretrained function was able to identify relevant json config files and load the model. Like this:
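A minimal sketch of what that looks like, assuming the ./model_save/ directory from the question. I'm using the generic Auto* classes from transformers here for illustration; the concrete model class from the tutorial (e.g. BertForSequenceClassification) works the same way:

```python
import os

output_dir = './model_save/'

# from_pretrained() reads config.json to rebuild the model architecture
# and pytorch_model.bin for the weights; the tokenizer files (vocab.json,
# merges.txt, special_tokens_map.json, tokenizer_config.json) are picked
# up automatically by the tokenizer class. No manual file selection needed.
if os.path.isdir(output_dir):
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    model = AutoModelForSequenceClassification.from_pretrained(output_dir)
    tokenizer = AutoTokenizer.from_pretrained(output_dir)
    model.eval()  # disable dropout etc. for inference
```

This replaces the TheModelClass(...) / load_state_dict(...) pattern from the PyTorch docs: from_pretrained() handles both the architecture construction and the weight loading in one call.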
Sometimes it helps to just try some code and see if it works.