我正在使用Resnet进行神经网络分类,并希望尝试在预先训练和未经训练的网络之间进行比较。但是,我确实想使用偏置项,这不是Pytorch的Resnet模块中的默认设置。
是否有一种方法可以包括预先训练的模型并在其中使用偏差术语?
我当前代码的简短片段,我从这里重新定义了重新连接体系结构 - 和 set bias = true true
net = resnet18(pretrained=True)
net.fc = nn.Linear(512, num_classes)
现在明显的错误现在是
resnet加载state_dict中的错误(s):
state_dict中缺少键:“ conv1.bias”,“ layer1.0.conv1.bias”,“ layer1.0.conv2.bias”,“ layer1.1.conv1.bias”,“ layer1.conv2 .bias“,” layer2.0.conv1.bias”,“ layer2.0.conv2.bias”,“ layer2.0.downsample.0.bias”,“ layer2.1.conv1.bias”,“ layer2.1 。 .1.conv2.bias“,” layer4.0.conv1.bias”,“ layer4.0.conv2.bias”,“ layer4.0.downsample.0.bias”,“ layer4.1.conv1.bias”, “ layer4.1.conv2.bias”。
I am using a ResNet for neural network classification and wish to try out a comparison between pre-trained and non-pre-trained networks. However, I do want to use the Bias term which is not the default setting in Pytorch's ResNet modules.
Is there a way to include a pre-trained model and use bias terms on top of that?
A very brief snippet of my current code, I redefine ResNet architecture from here - https://pytorch.org/vision/0.8/_modules/torchvision/models/resnet.html and set Bias = True
net = resnet18(pretrained=True)
net.fc = nn.Linear(512, num_classes)
The obvious error right now is
Error(s) in loading state_dict for ResNet:
Missing key(s) in state_dict: "conv1.bias", "layer1.0.conv1.bias", "layer1.0.conv2.bias", "layer1.1.conv1.bias", "layer1.1.conv2.bias", "layer2.0.conv1.bias", "layer2.0.conv2.bias", "layer2.0.downsample.0.bias", "layer2.1.conv1.bias", "layer2.1.conv2.bias", "layer3.0.conv1.bias", "layer3.0.conv2.bias", "layer3.0.downsample.0.bias", "layer3.1.conv1.bias", "layer3.1.conv2.bias", "layer4.0.conv1.bias", "layer4.0.conv2.bias", "layer4.0.downsample.0.bias", "layer4.1.conv1.bias", "layer4.1.conv2.bias".
发布评论
评论(2)
来更改摘要中给出的_Resnet函数,它应该忽略非匹配键并避免崩溃
您应该通过添加strict = false
You should change the _resnet function given in the snippet
By adding strict=False it should ignore non matching keys and avoid crashing
默认_Resnet函数应如下更改:
这将使我们从state_dict加载预训练的权重并忽略非匹配键。
The default _resnet function should be changed as follows:
This would let us load the pre-trained weights from state_dict and ignore the non-matching keys.