Torchvision RetinaNet 预测不需要的班级背景

发布于 2025-01-16 22:34:04 字数 1324 浏览 1 评论 0原文

我想用我的自定义数据集和 2 个类(无背景)训练来自 torchvision 的预训练 RetinaNet。为了使用 RetinaNet 进行训练,我做了以下修改:

num_classes = 3 # num of objects to identify + background class

model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=True)
# replace classification layer 
in_features = model.head.classification_head.conv[0].in_channels
num_anchors = model.head.classification_head.num_anchors
model.head.classification_head.num_classes = num_classes

cls_logits = torch.nn.Conv2d(in_features, num_anchors * num_classes, kernel_size = 3, stride=1, padding=1)
torch.nn.init.normal_(cls_logits.weight, std=0.01)  # as per pytorch code
torch.nn.init.constant_(cls_logits.bias, -math.log((1 - 0.01) / 0.01))  # as per pytorcch code 
# assign cls head to model
model.head.classification_head.cls_logits = cls_logits

问题是我得到了 0 类的检测,无论 num_classes 是 2 还是 3,它都是背景。

我尝试理解源代码并找不到类似的内容 fasterrcnn roi_head

# remove predictions with the background label
boxes = boxes[:, 1:]
scores = scores[:, 1:]
labels = labels[:, 1:]

我该如何解决这个问题?任何帮助将不胜感激!

I want to train the pretrained RetinaNet from torchvision with my custom dataset with 2 classes (without background). To train with RetinaNet, I did follow modifications:

num_classes = 3 # num of objects to identify + background class

model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=True)
# replace classification layer 
in_features = model.head.classification_head.conv[0].in_channels
num_anchors = model.head.classification_head.num_anchors
model.head.classification_head.num_classes = num_classes

cls_logits = torch.nn.Conv2d(in_features, num_anchors * num_classes, kernel_size = 3, stride=1, padding=1)
torch.nn.init.normal_(cls_logits.weight, std=0.01)  # as per pytorch code
torch.nn.init.constant_(cls_logits.bias, -math.log((1 - 0.01) / 0.01))  # as per pytorcch code 
# assign cls head to model
model.head.classification_head.cls_logits = cls_logits

The problem is that I got the detections for class 0, which is background no matter whether the num_classes is 2 or 3.

I tried to understand the source code and could not find anything like in fasterrcnn roi_head

# remove predictions with the background label
boxes = boxes[:, 1:]
scores = scores[:, 1:]
labels = labels[:, 1:]

How can I solve this problem? Any help would be very appreciated!

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文