将多类图像分类减少为Pytorch中的二进制分类
我正在研究一个由10个不同类组成的STL-10图像数据集。我想将这个多类图像分类问题减少到二进制类图像分类,例如1类VS REST。我正在使用pytorch torchvision下载和使用STL数据,但我无法做到这一点,而其余的则无法做到。
train_data=torchvision.datasets.STL10(root='data',split='train',transform=data_transforms['train'], download=True)
test_data=torchvision.datasets.STL10(root='data',split='test',transform=data_transforms['val'], download=True)
train_dataloader = DataLoader(train_data,batch_size = 64,shuffle=True,num_workers=2)
test_dataloader = DataLoader(test_data,batch_size = 64,shuffle=True,num_workers=2)
I am working on an stl-10 image dataset that consists of 10 different classes. I want to reduce this multiclass image classification problem to the binary class image classification such as class 1 Vs rest. I am using PyTorch torchvision to download and use the stl data but I am unable to do it as one Vs the rest.
train_data=torchvision.datasets.STL10(root='data',split='train',transform=data_transforms['train'], download=True)
test_data=torchvision.datasets.STL10(root='data',split='test',transform=data_transforms['val'], download=True)
train_dataloader = DataLoader(train_data,batch_size = 64,shuffle=True,num_workers=2)
test_dataloader = DataLoader(test_data,batch_size = 64,shuffle=True,num_workers=2)
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
data:image/s3,"s3://crabby-images/d5906/d59060df4059a6cc364216c4d63ceec29ef7fe66" alt="扫码二维码加入Web技术交流群"
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论
评论(3)
对于Torchvision数据集,有一种内置的方法可以做到这一点。您需要定义转换功能或类,并在创建数据集时将其添加到
target_transform
中。这是一个工作示例供参考:
For torchvision datasets, there is an inbuilt way to do this. You need to define a transformation function or class and add that into the
target_transform
while creating the dataset.Here is a working example for reference :
您需要重新标记图像。在开始时,类别0对应于标签0,类1对应于标签1,...和类别10对应于标签9。如果要实现二进制分类,则需要更改类别图1的标签(或其他)到0,以及所有其他类别的图片至1。
You need to relabel the image. At the beginning, class 0 corresponds to label 0, class 1 corresponds to label 1, ..., and class 10 corresponds to label 9. If you want to achieve binary classification, you need to change the label of the picture of category 1 (or other) to 0, and the picture of all other categories to 1.
一种方法是在运行时更新标签值,然后将其传递给训练循环中的损失功能。假设我们要将5类Relabel Relabel Relabel As 1,其余为0:
您也可能需要对Test_dataloader进行类似的重新标记。另外,我不确定
标签
的数据类型。如果它的浮动,请相应地更改。One way is to update label values at runtime before passing them to loss function in the training loop. Let's say we want to relabel class 5 as 1, and the rest as 0:
You may also need to do similar relabeling for test_dataloader. Also, I am not sure about the datatype of
labels
. If its float, change accordingly.