如何根据输出张量从pytorch模型中删除预测头?

发布于 2025-01-15 01:31:57 字数 482 浏览 0 评论 0原文

我正在开发一个 ViT(Vision Transformer)相关项目,一些低级定义位于 timm 库的深处,我无法更改。低级库定义涉及线性分类预测头,它不是我的网络的一部分。

一切都很好,直到我切换到 DDP 并行实现。 Pytorch 抱怨一些参数对损失没有影响,它指示我使用“find_unused_pa​​rameters=True”。事实上,这是一个常见的场景,如果我将这个“find_unused_pa​​rameters=True”添加到训练例程中,它会再次起作用。但是,我只能更改代码库中的模型定义,但不能修改与训练相关的任何内容……

所以我想我现在唯一能做的就是从模型中“删除”线性头。 虽然我无法深入研究 ViT 的低级定义,但我可以像这样输出这个张量:

encoder_output,   linear_head_output =  ViT(input)

是否可以根据这个 Linear_head_output 张量删除这个线性预测头?

I am working on a ViT (Vision Transformer) related project and some low level definition is deep inside timm library, which I can not change. The low level library definition involves a linear classification prediction head, which is not a part of my network.

Every thing was fine until I switched to DDP parallel implementation. Pytorch complained about some parameters which didn’t contribute to the loss, and it instructed me to use “find_unused_parameters=True”. In fact, it is a common scenario and it worked again if I added this “find_unused_parameters=True” to the training routine. However, I am only allowed to change the model definition in our code base, but I cannot modify anything related to training …

So I guess the only thing I can do right now, is to “remove” the linear head from the model.
Although I cannot dig into the low level definition of ViT, but I can output this tensor like this:

encoder_output,   linear_head_output =  ViT(input)

Is it possible to remove this linear prediction head based on this linear_head_output tensor?

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

别靠近我心 2025-01-22 01:31:57

只需在调用 timm.create_model() 创建 ViT 模型时设置 num_classes=0 即可。

以下是有关特征提取的 TIMM 文档中的示例:

import torch
import timm
m = timm.create_model('resnet50', pretrained=True, num_classes=0, global_pool='')
o = m(torch.randn(2, 3, 224, 224))
print(f'Unpooled shape: {o.shape}')

Just set the num_classes=0 when you create your ViT model by calling timm.create_model().

Here is an example from TIMM documentation on Feature Extraction:

import torch
import timm
m = timm.create_model('resnet50', pretrained=True, num_classes=0, global_pool='')
o = m(torch.randn(2, 3, 224, 224))
print(f'Unpooled shape: {o.shape}')
~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文