Unable to convert a PyTorch model to TorchScript format
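I loaded a pretrained PyTorch model file, and when I try to run it through `torch.jit.script` I get the error below. When I try the same thing with one of the built-in pretrained models from pytorch.org, it works fine. (Example code, which contains the pretrained model weights: https://github.com/aaltoml/gp-mvs), (pretrained weights)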
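```python
import torch

# enCoder is presumably the encoder class from the gp-mvs repository; `weights`
# is presumably the checkpoint dict loaded earlier with torch.load(...).
encoder = enCoder()
encoder = torch.nn.DataParallel(encoder)
encoder.load_state_dict(weights['state_dict'])
encoder.eval()
torchscript_model = torch.jit.script(encoder)
```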
### Error

```
---------------------------------------------------------------------------
NotSupportedError Traceback (most recent call last)
<ipython-input-30-1d9f30e14902> in <module>()
1 # torch.quantization.convert(encoder, inplace=True)
2
----> 3 torchscript_model = torch.jit.script(encoder)
8 frames
/usr/local/lib/python3.7/dist-packages/torch/jit/_script.py in script(obj, optimize, _frames_up, _rcb, example_inputs)
1256 obj = call_prepare_scriptable_func(obj)
1257 return torch.jit._recursive.create_script_module(
-> 1258 obj, torch.jit._recursive.infer_methods_to_compile
1259 )
1260
/usr/local/lib/python3.7/dist-packages/torch/jit/_recursive.py in create_script_module(nn_module, stubs_fn, share_types, is_tracing)
449 if not is_tracing:
450 AttributeTypeIsSupportedChecker().check(nn_module)
--> 451 return create_script_module_impl(nn_module, concrete_type, stubs_fn)
452
453 def create_script_module_impl(nn_module, concrete_type, stubs_fn):
/usr/local/lib/python3.7/dist-packages/torch/jit/_recursive.py in create_script_module_impl(nn_module, concrete_type, stubs_fn)
461 """
462 cpp_module = torch._C._create_module_with_type(concrete_type.jit_type)
--> 463 method_stubs = stubs_fn(nn_module)
464 property_stubs = get_property_stubs(nn_module)
465 hook_stubs, pre_hook_stubs = get_hook_stubs(nn_module)
/usr/local/lib/python3.7/dist-packages/torch/jit/_recursive.py in infer_methods_to_compile(nn_module)
730 stubs = []
731 for method in uniqued_methods:
--> 732 stubs.append(make_stub_from_method(nn_module, method))
733 return overload_stubs + stubs
734
/usr/local/lib/python3.7/dist-packages/torch/jit/_recursive.py in make_stub_from_method(nn_module, method_name)
64 # In this case, the actual function object will have the name `_forward`,
65 # even though we requested a stub for `forward`.
---> 66 return make_stub(func, method_name)
67
68
/usr/local/lib/python3.7/dist-packages/torch/jit/_recursive.py in make_stub(func, name)
49 def make_stub(func, name):
50 rcb = _jit_internal.createResolutionCallbackFromClosure(func)
---> 51 ast = get_jit_def(func, name, self_name="RecursiveScriptModule")
52 return ScriptMethodStub(rcb, ast, func)
53
/usr/local/lib/python3.7/dist-packages/torch/jit/frontend.py in get_jit_def(fn, def_name, self_name, is_classmethod)
262 pdt_arg_types = type_trace_db.get_args_types(qualname)
263
--> 264 return build_def(parsed_def.ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)
265
266 # TODO: more robust handling of recognizing ignore context manager
/usr/local/lib/python3.7/dist-packages/torch/jit/frontend.py in build_def(ctx, py_def, type_line, def_name, self_name, pdt_arg_types)
300 py_def.col_offset + len("def"))
301
--> 302 param_list = build_param_list(ctx, py_def.args, self_name, pdt_arg_types)
303 return_type = None
304 if getattr(py_def, 'returns', None) is not None:
/usr/local/lib/python3.7/dist-packages/torch/jit/frontend.py in build_param_list(ctx, py_args, self_name, pdt_arg_types)
324 expr = py_args.kwarg
325 ctx_range = ctx.make_range(expr.lineno, expr.col_offset - 1, expr.col_offset + len(expr.arg))
--> 326 raise NotSupportedError(ctx_range, _vararg_kwarg_err)
327 if py_args.vararg is not None:
328 expr = py_args.vararg
NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
File "/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/data_parallel.py", line 147
def forward(self, *inputs, **kwargs):
~~~~~~~ <--- HERE
with torch.autograd.profiler.record_function("DataParallel.forward"):
if not self.device_ids:
```
### Versions
Collecting environment information...
PyTorch version: 1.10.0+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: 6.0.0-1ubuntu2 (tags/RELEASE_600/final)
CMake version: version 3.12.0
Libc version: glibc-2.26
Python version: 3.7.13 (default, Mar 16 2022, 17:37:17) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.144+-x86_64-with-Ubuntu-18.04-bionic
Is CUDA available: False
CUDA runtime version: 11.1.105
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.0.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.21.6
[pip3] torch==1.10.0+cu111
[pip3] torchaudio==0.10.0+cu111
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.11.0
[pip3] torchvision==0.11.1+cu111
[conda] Could not collect
Any help is appreciated.
### Answer
`torch.jit.script` creates a compiled `ScriptModule` (or a `ScriptFunction` for a plain function), i.e. code backed by a TorchScript graph, by parsing the Python source of `module.forward()`. If your module uses Python constructs that the TorchScript parser does not support, scripting fails; this is especially common for objects that have no static type. Here, the unsupported construct is the `*inputs, **kwargs` signature of `DataParallel.forward`, which is exactly what the error message points at.
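A minimal sketch of the failure and two possible workarounds, assuming a tiny hypothetical `TinyEncoder` in place of the repository's `enCoder` (whose code is not shown here) and a CPU-only run, as in the question's environment. It reproduces the error by scripting the `DataParallel` wrapper, then (1) scripts the wrapped module via `.module`, and (2) loads the `module.`-prefixed checkpoint keys into an unwrapped model so `DataParallel` is not needed at all.

```python
import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for the repo's enCoder; forward() has a fixed signature."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.conv(x))


# Same pattern as the question (assumes CPU only, so DataParallel just wraps the module).
wrapped = nn.DataParallel(TinyEncoder()).eval()
try:
    torch.jit.script(wrapped)
    wrapper_scripted = True
except Exception as err:  # the question's traceback shows NotSupportedError here
    wrapper_scripted = False
    print(type(err).__name__)

# Workaround 1: script the inner module that DataParallel wraps.
scripted = torch.jit.script(wrapped.module)

# Workaround 2: skip DataParallel entirely. A state_dict saved from a
# DataParallel model has keys prefixed with "module.", so strip that prefix.
dp_state = wrapped.state_dict()  # stands in for weights['state_dict'] in the question
plain_state = {
    (k[len("module."):] if k.startswith("module.") else k): v
    for k, v in dp_state.items()
}
plain = TinyEncoder()
plain.load_state_dict(plain_state)
plain.eval()
scripted_plain = torch.jit.script(plain)

x = torch.randn(1, 3, 16, 16)
print(scripted(x).shape)  # torch.Size([1, 8, 16, 16])
print(torch.allclose(scripted(x), scripted_plain(x)))  # True
```

Both workarounds give a module whose `forward` has an ordinary, fixed signature, which is what `torch.jit.script` needs.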
Using `torch.jit.trace` avoids such problems. It builds the graph by recording the operations as they are actually called (on the C++ side) instead of parsing Python source, so it does not trip over unsupported syntax. However, it cannot handle if-else branches: only the branch taken for the example input is recorded. If your model has branches, you would have to re-trace it on every iteration, which costs 2 forward passes and 1 backward pass per training step. With a branch-free model, you can reuse the traced module.
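To illustrate the branch limitation, here is a sketch loosely modeled on the decision-gate example in the official TorchScript introduction tutorial; the `Gate` module is hypothetical. Tracing with a positive example input records only the `x * 2` branch, so the traced module returns the wrong result for a negative input, while the scripted module keeps the `if`.

```python
import torch
import torch.nn as nn


class Gate(nn.Module):
    """Hypothetical module with a data-dependent if/else branch."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:
            return x * 2
        return -x


model = Gate().eval()
pos = torch.ones(3)
neg = -torch.ones(3)

# trace() runs the model once on `pos` and records only the ops that executed,
# i.e. the `x * 2` branch. PyTorch emits a TracerWarning about converting the
# tensor condition to a Python bool.
traced = torch.jit.trace(model, pos)

# script() compiles the Python source, so both branches are kept.
scripted = torch.jit.script(model)

print(model(neg))     # tensor([1., 1., 1.])     eager: takes the `-x` branch
print(traced(neg))    # tensor([-2., -2., -2.])  traced: still replays `x * 2`
print(scripted(neg))  # tensor([1., 1., 1.])     scripted: branch preserved
```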