文章来源于网络收集而来,版权归原创者所有,如有侵权请及时联系!
6.4 数据加载
数据的相关处理主要保存在 data/dataset.py
中。关于数据加载的相关操作,在上一章中我们已经提到过,其基本原理就是使用 Dataset
提供数据集的封装,再使用 Dataloader
实现数据并行加载。Kaggle 提供的数据包括训练集和测试集,而我们在实际使用中,还需专门从训练集中取出一部分作为验证集。
对于这三类数据集,其相应操作也不太一样,而如果专门写三个 Dataset
,则稍显复杂和冗余,因此这里通过加一些判断来区分。对于训练集,我们希望做一些数据增强处理,如随机裁剪、随机翻转、加噪声等,而验证集和测试集则不需要。下面看 dataset.py
的代码:
import os
from PIL import Image
from torch.utils import data
import numpy as np
from torchvision import transforms as T
class DogCat(data.Dataset):
def __init__(self, root, transforms=None, train=True, test=False):
'''
目标:获取所有图片地址,并根据训练、验证、测试划分数据
'''
self.test = test
imgs = [os.path.join(root, img) for img in os.listdir(root)]
# test1: data/test1/8973.jpg
# train: data/train/cat.10004.jpg
if self.test:
imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2].split('/')[-1]))
else:
imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2]))
imgs_num = len(imgs)
# 划分训练、验证集,验证:训练 = 3:7
if self.test:
self.imgs = imgs
elif train:
self.imgs = imgs[:int(0.7*imgs_num)]
else :
self.imgs = imgs[int(0.7*imgs_num):]
if transforms is None:
# 数据转换操作,测试验证和训练的数据转换有所区别
normalize = T.Normalize(mean = [0.485, 0.456, 0.406],
std = [0.229, 0.224, 0.225])
# 测试集和验证集
if self.test or not train:
self.transforms = T.Compose([
T.Scale(224),
T.CenterCrop(224),
T.ToTensor(),
normalize
])
# 训练集
else :
self.transforms = T.Compose([
T.Scale(256),
T.RandomSizedCrop(224),
T.RandomHorizontalFlip(),
T.ToTensor(),
normalize
])
def __getitem__(self, index):
'''
返回一张图片的数据
对于测试集,没有 label,返回图片 id,如 1000.jpg 返回 1000
'''
img_path = self.imgs[index]
if self.test:
label = int(self.imgs[index].split('.')[-2].split('/')[-1])
else:
label = 1 if 'dog' in img_path.split('/')[-1] else 0
data = Image.open(img_path)
data = self.transforms(data)
return data, label
def __len__(self):
'''
返回数据集中所有图片的个数
'''
return len(self.imgs)
关于数据集使用的注意事项,在上一章中已经提到,将文件读取等费时操作放在 __getitem__
函数中,利用多进程加速。避免一次性将所有图片都读进内存,不仅费时也会占用较大内存,而且不易进行数据增强等操作。另外在这里,我们将训练集中的 30%作为验证集,可用来检查模型的训练效果,避免过拟合。在使用时,我们可通过 dataloader 加载数据。
train_dataset = DogCat(opt.train_data_root, train=True)
trainloader = DataLoader(train_dataset,
batch_size = opt.batch_size,
shuffle = True,
num_workers = opt.num_workers)
for ii, (data, label) in enumerate(trainloader):
train()
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论