Pytorch:如何为 CNN 制作自定义数据加载器?
我正在尝试从 CNN 的自定义数据集创建自己的数据加载器。原始的 Dataloader 是通过以下方式创建的:
train_loader = torch.utils.data.DataLoader(mnist_data, batch_size=64)
如果我检查上面的形状,我会得到
i1, l1 = next(iter(train_loader))
print(i1.shape) # torch.Size([64, 1, 28, 28])
print(l1.shape) # torch.Size([64])
当我将这个 train_loader 输入到我的 CNN 中时,它工作得很好。但是,我有一个自定义数据集。我已经完成了以下操作:
mnist_data = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())
trainset = mnist_data
testset = mnist_data
x_train = np.array(trainset.data)
y_train = np.array(trainset.targets)
# modify x_train/y_train
现在,我如何才能将 x_train、y_train 放入与第一个类似的数据加载器中?我已经完成了以下操作:
train_data = []
for i in range(len(x_train)):
train_data.append([x_train[i], y_train[i]])
train_loader = torch.utils.data.DataLoader(train_data, batch_size=64)
for i, (images, labels) in enumerate(train_loader):
images = images.unsqueeze(1)
但是,我仍然缺少频道列(应该是 1)。我该如何解决这个问题?
I'm trying to create my own Dataloader from a custom dataset for a CNN. The original Dataloader was created by writing:
train_loader = torch.utils.data.DataLoader(mnist_data, batch_size=64)
If I check the shape of the above, I get
i1, l1 = next(iter(train_loader))
print(i1.shape) # torch.Size([64, 1, 28, 28])
print(l1.shape) # torch.Size([64])
When I feed this train_loader into my CNN, it works beautifully. However, I have a custom dataset. I have done the following:
mnist_data = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())
trainset = mnist_data
testset = mnist_data
x_train = np.array(trainset.data)
y_train = np.array(trainset.targets)
# modify x_train/y_train
Now, how would I be able to take x_train, y_train and make it into a Dataloader similar to the first one? I have done the following:
train_data = []
for i in range(len(x_train)):
train_data.append([x_train[i], y_train[i]])
train_loader = torch.utils.data.DataLoader(train_data, batch_size=64)
for i, (images, labels) in enumerate(train_loader):
images = images.unsqueeze(1)
However, I'm still missing the channel column (which should be 1). How would I fix this?
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论
评论(1)
我无权访问您的
x_train
和y_train
,但这可能有效:I don't have access to your
x_train
andy_train
, but probably this works: