How can I train the last few layers of a RegNet-800MF backbone using PyTorch Lightning?
I am trying to get better results by allowing a few of the final layers of my previously frozen backbone (RegNet-800MF) to be trained. How can I achieve this in PyTorch Lightning? I am new to ML, so please forgive me if I have left out any important information.
My model (MechClassifier) calls another class (ParametersClassifier), which includes the pretrained RegNet as its frozen backbone. During training, the forward function passes only through the ParametersClassifier's backbone and not its classification layers. I will include the init functions of both below.
My MechClassifier model:
class MechClassifier(pl.LightningModule):
    def __init__(
        self,
        num_classes,
        lr=4e-3,
        weight_decay=1e-8,
        gpus=1,
        max_epochs=30,
    ):
        super().__init__()
        self.lr = lr
        self.weight_decay = weight_decay
        self.__dict__.update(locals())
        self.backbone = ParametersClassifier.load_from_checkpoint(
            checkpoint_path="checkpoints/param_classifier/last.ckpt",
            num_classes=3,
            gpus=1,
        )
        self.backbone.freeze()
        self.backbone.eval()
        self.mf_classifier = nn.Sequential(
            nn.Linear(self.backbone.num_ftrs, 8),
            nn.ReLU(),
            nn.Linear(8, num_classes),
        )
        self.wd_classifier = nn.Sequential(
            nn.Linear(self.backbone.num_ftrs, 8),
            nn.ReLU(),
            nn.Linear(8, num_classes),
        )

    def forward(self, x):
        self.backbone.eval()
        with torch.no_grad():
            x = self.backbone.model(x)
        # x = self.model(x)
        out1 = self.mf_classifier(x)
        out2 = self.wd_classifier(x)
        # print(out1.size())
        return (out1, out2)
My ParametersClassifier (loaded from checkpoint):
class ParametersClassifier(pl.LightningModule):
    def __init__(
        self,
        num_classes,
        lr=4e-3,
        weight_decay=0.05,
        gpus=1,
        max_epochs=30,
    ):
        super().__init__()
        self.lr = lr
        self.weight_decay = weight_decay
        self.__dict__.update(locals())
        self.model = models.regnet_y_800mf(pretrained=True)
        self.num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Identity()
        self.fc1 = nn.Linear(self.num_ftrs, num_classes)
        self.fc2 = nn.Linear(self.num_ftrs, num_classes)
        self.fc3 = nn.Linear(self.num_ftrs, num_classes)
        self.fc4 = nn.Linear(self.num_ftrs, num_classes)

    def forward(self, x):
        x = self.model(x)
        out1 = self.fc1(x)
        out2 = self.fc2(x)
        out3 = self.fc3(x)
        out4 = self.fc4(x)
        return (out1, out2, out3, out4)
Answer:
You can look at the implementation of the RegNet model you are using in the torchvision source (torchvision/models/regnet.py). Its forward function looks like this (paraphrased from the torchvision source):
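def forward(self, x):
    x = self.stem(x)            # initial convolutional stem
    x = self.trunk_output(x)    # the RegNet stages (block1, block2, ...)
    x = self.avgpool(x)         # global average pooling
    x = x.flatten(start_dim=1)  # flatten to (batch, num_ftrs)
    x = self.fc(x)              # final fully connected layer
    return x

Since ParametersClassifier replaces self.model.fc with nn.Identity(), calling self.backbone.model(x) returns the flattened, pooled feature vector of size num_ftrs.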
Instead of using a torch.no_grad context manager as you did, you should rather switch the requires_grad flag on and off as necessary. By default, module parameters have their requires_grad flag set to True, which means they are able to perform gradient computation. If this flag is set to False, you can consider those components frozen.
Depending on which layers you want to freeze and which you want to finetune, you can do that manually. For example, if you want to freeze the backbone and finetune the fully connected layer of the RegNet, replace the following lines in MechClassifier's __init__:
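self.backbone = ParametersClassifier.load_from_checkpoint(
    checkpoint_path="checkpoints/param_classifier/last.ckpt",
    num_classes=3,
    gpus=1,
)
self.backbone.freeze()
self.backbone.eval()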
with something like the following, which freezes the RegNet trunk via requires_grad_ instead of freezing the whole module:
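self.backbone = ParametersClassifier.load_from_checkpoint(
    checkpoint_path="checkpoints/param_classifier/last.ckpt",
    num_classes=3,
    gpus=1,
)
# Freeze only the convolutional trunk; requires_grad_(False) recursively
# clears the requires_grad flag on every parameter of the submodule.
self.backbone.model.stem.requires_grad_(False)
self.backbone.model.trunk_output.requires_grad_(False)

The submodule names stem and trunk_output come from torchvision's RegNet. If your actual goal is to unfreeze only the last few layers of the backbone, you can selectively re-enable the last stage (named block4 in torchvision's implementation):

self.backbone.model.trunk_output.block4.requires_grad_(True)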
And perform inference on MechClassifier with a forward function like so, dropping the torch.no_grad block so that gradients can reach the unfrozen layers:
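def forward(self, x):
    # No torch.no_grad() here: which parameters receive gradients is now
    # controlled entirely by the requires_grad flags set in __init__.
    x = self.backbone.model(x)
    out1 = self.mf_classifier(x)
    out2 = self.wd_classifier(x)
    return (out1, out2)

Parameters whose requires_grad flag is False get no gradient, so the optimizer leaves them untouched; only the unfrozen backbone layers and your mf_classifier / wd_classifier heads are updated. If you also want the frozen layers' BatchNorm running statistics to stay fixed, you can still call .eval() on just those submodules, but that is a separate concern from gradient computation.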