将多类图像分类减少为Pytorch中的二进制分类

发布于 2025-01-21 13:48:26 字数 548 浏览 4 评论 0原文

我正在研究一个由10个不同类组成的STL-10图像数据集。我想将这个多类图像分类问题减少到二进制类图像分类,例如1类VS REST。我正在使用pytorch torchvision下载和使用STL数据,但我无法做到这一点,而其余的则无法做到。

train_data=torchvision.datasets.STL10(root='data',split='train',transform=data_transforms['train'], download=True)
test_data=torchvision.datasets.STL10(root='data',split='test',transform=data_transforms['val'], download=True)

train_dataloader = DataLoader(train_data,batch_size = 64,shuffle=True,num_workers=2)
test_dataloader = DataLoader(test_data,batch_size = 64,shuffle=True,num_workers=2)

I am working on an stl-10 image dataset that consists of 10 different classes. I want to reduce this multiclass image classification problem to the binary class image classification such as class 1 Vs rest. I am using PyTorch torchvision to download and use the stl data but I am unable to do it as one Vs the rest.

train_data=torchvision.datasets.STL10(root='data',split='train',transform=data_transforms['train'], download=True)
test_data=torchvision.datasets.STL10(root='data',split='test',transform=data_transforms['val'], download=True)

train_dataloader = DataLoader(train_data,batch_size = 64,shuffle=True,num_workers=2)
test_dataloader = DataLoader(test_data,batch_size = 64,shuffle=True,num_workers=2)

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(3

二智少女 2025-01-28 13:48:26

对于Torchvision数据集,有一种内置的方法可以做到这一点。您需要定义转换功能或类,并在创建数据集时将其添加到target_transform中。

torchvision.datasets.STL10(root: str, split: str = 'train', folds: Union[int, NoneType] = None, transform: Union[Callable, NoneType] = None, target_transform: Union[Callable, NoneType] = None, download: bool = False)

这是一个工作示例供参考:


import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms


class Multi2UniLabelTfm():
    def __init__(self,pos_label=5):
        if isinstance(pos_label,int) or isinstance(pos_label,float):
            pos_label = [pos_label,]
        self.pos_label = pos_label

    def __call__(self,y):
        # if y==self.pos_label:
        if y in self.pos_label:
            return 1
        else:
            return 0

if __name__=='__main__':

    test_tfms = transforms.Compose([
        transforms.ToTensor()
    ])
    data_transforms = {'val':test_tfms}


    #Original Labels
    # target_transform = None   

    # Label 5 is converted to 1. Rest are 0.
    # target_transform = Multi2UniLabelTfm(pos_label=5)     

    # Labels 5,6,7 are converted to 1. Rest are 0.
    target_transform = Multi2UniLabelTfm(pos_label=[5,6,7])
    test_data=torchvision.datasets.STL10(root='data',split='test',transform=data_transforms['val'], download=True, target_transform=target_transform)
    test_dataloader = DataLoader(test_data,batch_size = 64,shuffle=True,num_workers=2)

    for idx,(x,y) in enumerate(test_dataloader):
        print(idx,y)

        if idx == 5:
            break

For torchvision datasets, there is an inbuilt way to do this. You need to define a transformation function or class and add that into the target_transform while creating the dataset.

torchvision.datasets.STL10(root: str, split: str = 'train', folds: Union[int, NoneType] = None, transform: Union[Callable, NoneType] = None, target_transform: Union[Callable, NoneType] = None, download: bool = False)

Here is a working example for reference :


import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms


class Multi2UniLabelTfm():
    def __init__(self,pos_label=5):
        if isinstance(pos_label,int) or isinstance(pos_label,float):
            pos_label = [pos_label,]
        self.pos_label = pos_label

    def __call__(self,y):
        # if y==self.pos_label:
        if y in self.pos_label:
            return 1
        else:
            return 0

if __name__=='__main__':

    test_tfms = transforms.Compose([
        transforms.ToTensor()
    ])
    data_transforms = {'val':test_tfms}


    #Original Labels
    # target_transform = None   

    # Label 5 is converted to 1. Rest are 0.
    # target_transform = Multi2UniLabelTfm(pos_label=5)     

    # Labels 5,6,7 are converted to 1. Rest are 0.
    target_transform = Multi2UniLabelTfm(pos_label=[5,6,7])
    test_data=torchvision.datasets.STL10(root='data',split='test',transform=data_transforms['val'], download=True, target_transform=target_transform)
    test_dataloader = DataLoader(test_data,batch_size = 64,shuffle=True,num_workers=2)

    for idx,(x,y) in enumerate(test_dataloader):
        print(idx,y)

        if idx == 5:
            break
双马尾 2025-01-28 13:48:26

您需要重新标记图像。在开始时,类别0对应于标签0,类1对应于标签1,...和类别10对应于标签9。如果要实现二进制分类,则需要更改类别图1的标签(或其他)到0,以及所有其他类别的图片至1。

You need to relabel the image. At the beginning, class 0 corresponds to label 0, class 1 corresponds to label 1, ..., and class 10 corresponds to label 9. If you want to achieve binary classification, you need to change the label of the picture of category 1 (or other) to 0, and the picture of all other categories to 1.

假面具 2025-01-28 13:48:26

一种方法是在运行时更新标签值,然后将其传递给训练循环中的损失功能。假设我们要将5类Relabel Relabel Relabel As 1,其余为0:

my_class_id = 5
for imgs, labels in train_dataloader:
    labels = torch.where(labels == my_class_id, 1, 0)
    ...

您也可能需要对Test_dataloader进行类似的重新标记。另外,我不确定标签的数据类型。如果它的浮动,请相应地更改。

One way is to update label values at runtime before passing them to loss function in the training loop. Let's say we want to relabel class 5 as 1, and the rest as 0:

my_class_id = 5
for imgs, labels in train_dataloader:
    labels = torch.where(labels == my_class_id, 1, 0)
    ...

You may also need to do similar relabeling for test_dataloader. Also, I am not sure about the datatype of labels. If its float, change accordingly.

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文