5.1 数据处理
在解决深度学习问题的过程中,往往需要花费大量的精力去处理数据,包括图像、文本、语音或其它二进制数据等。数据的处理对训练神经网络来说十分重要,良好的数据处理不仅会加速模型训练,更会提高模型效果。考虑到这点,PyTorch 提供了几个高效便捷的工具,以便使用者进行数据处理或增强等操作,同时可通过并行化加速数据加载。
5.1.1 数据加载
在 PyTorch 中,数据加载可通过自定义的数据集对象。数据集对象被抽象为 Dataset
类,实现自定义的数据集需要继承 Dataset,并实现两个 Python 魔法方法:
__getitem__
:返回一条数据,或一个样本。obj[index]
等价于obj.__getitem__(index)
__len__
:返回样本的数量。len(obj)
等价于obj.__len__()
这里我们以 Kaggle 经典挑战赛 Dogs vs. Cat 的数据为例,来详细讲解如何处理数据。Dogs vs. Cats 是一个分类问题,判断一张图片是狗还是猫,其所有图片都存放在一个文件夹下,根据文件名的前缀判断是狗还是猫。
%env LS_COLORS = None
!tree --charset ascii data/dogcat/
env: LS_COLORS=None
data/dogcat/
|-- cat.12484.jpg
|-- cat.12485.jpg
|-- cat.12486.jpg
|-- cat.12487.jpg
|-- dog.12496.jpg
|-- dog.12497.jpg
|-- dog.12498.jpg
`-- dog.12499.jpg
0 directories, 8 files
import torch as t
from torch.utils import data
import os
from PIL import Image
import numpy as np
class DogCat(data.Dataset):
def __init__(self, root):
imgs = os.listdir(root)
# 所有图片的绝对路径
# 这里不实际加载图片,只是指定路径,当调用__getitem__时才会真正读图片
self.imgs = [os.path.join(root, img) for img in imgs]
def __getitem__(self, index):
img_path = self.imgs[index]
# dog->1, cat->0
label = 1 if 'dog' in img_path.split('/')[-1] else 0
pil_img = Image.open(img_path)
array = np.asarray(pil_img)
data = t.from_numpy(array)
return data, label
def __len__(self):
return len(self.imgs)
dataset = DogCat('./data/dogcat/')
img, label = dataset[0] # 相当于调用 dataset.__getitem__(0)
for img, label in dataset:
print(img.size(), img.float().mean(), label)
torch.Size([375, 499, 3]) 116.81384992206635 1
torch.Size([499, 379, 3]) 171.80847887507645 0
torch.Size([236, 289, 3]) 130.30038805153168 0
torch.Size([377, 499, 3]) 151.7174171508357 1
torch.Size([374, 499, 3]) 115.51768778198108 0
torch.Size([375, 499, 3]) 150.50795635715878 1
torch.Size([400, 300, 3]) 128.154975 1
torch.Size([500, 497, 3]) 106.4915063715627 0
通过上面的代码,我们学习了如何自定义自己的数据集,并可以依次获取。但这里返回的数据不适合实际使用,因其具有如下两方面问题:
- 返回样本的形状不一,因每张图片的大小不一样,这对于需要取 batch 训练的神经网络来说很不友好
- 返回样本的数值较大,未归一化至[-1, 1]
针对上述问题,PyTorch 提供了 torchvision ^1 。它是一个视觉工具包,提供了很多视觉图像处理的工具,其中 transforms
模块提供了对 PIL Image
对象和 Tensor
对象的常用操作。
对 PIL Image 的操作包括:
Scale
:调整图片尺寸,长宽比保持不变CenterCrop
、RandomCrop
、RandomSizedCrop
: 裁剪图片Pad
:填充ToTensor
:将 PIL Image 对象转成 Tensor,会自动将[0, 255]归一化至[0, 1]
对 Tensor 的操作包括:
- Normalize:标准化,即减均值,除以标准差
- ToPILImage:将 Tensor 转为 PIL Image 对象
如果要对图片进行多个操作,可通过`Compose`函数将这些操作拼接起来,类似于`nn.Sequential`。注意,这些操作定义后是以函数的形式存在,真正使用时需调用它的`__call__`方法,这点类似于`nn.Module`。例如要将图片调整为$224\times 224$,首先应构建这个操作`trans = Resize((224, 224))`,然后调用`trans(img)`。下面我们就用 transforms 的这些操作来优化上面实现的 dataset。
import os
from PIL import Image
import numpy as np
from torchvision import transforms as T
transform = T.Compose([
T.Resize(224), # 缩放图片(Image),保持长宽比不变,最短边为 224 像素
T.CenterCrop(224), # 从图片中间切出 224*224 的图片
T.ToTensor(), # 将图片(Image) 转成 Tensor,归一化至[0, 1]
T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 标准化至[-1, 1],规定均值和标准差
])
class DogCat(data.Dataset):
def __init__(self, root, transforms=None):
imgs = os.listdir(root)
self.imgs = [os.path.join(root, img) for img in imgs]
self.transforms=transforms
def __getitem__(self, index):
img_path = self.imgs[index]
label = 0 if 'dog' in img_path.split('/')[-1] else 1
data = Image.open(img_path)
if self.transforms:
data = self.transforms(data)
return data, label
def __len__(self):
return len(self.imgs)
dataset = DogCat('./data/dogcat/', transforms=transform)
img, label = dataset[0]
for img, label in dataset:
print(img.size(), label)
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 1
除了上述操作之外,transforms 还可通过 Lambda
封装自定义的转换策略。例如想对 PIL Image 进行随机旋转,则可写成这样 trans=T.Lambda(lambda img: img.rotate(random()*360))
。
torchvision 已经预先实现了常用的 Dataset,包括前面使用过的 CIFAR-10,以及 ImageNet、COCO、MNIST、LSUN 等数据集,可通过诸如 torchvision.datasets.CIFAR10
来调用,具体使用方法请参看官方文档 ^1 。在这里介绍一个会经常使用到的 Dataset—— ImageFolder
,它的实现和上述的 DogCat
很相似。 ImageFolder
假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名,其构造函数如下:
ImageFolder(root, transform=None, target_transform=None, loader=default_loader)
它主要有四个参数:
root
:在 root 指定的路径下寻找图片transform
:对 PIL Image 进行的转换操作,transform 的输入是使用 loader 读取图片的返回对象target_transform
:对 label 的转换loader
:给定路径后如何读取图片,默认读取为 RGB 格式的 PIL Image 对象
label 是按照文件夹名顺序排序后存成字典,即{类名:类序号(从 0 开始)},一般来说最好直接将文件夹命名为从 0 开始的数字,这样会和 ImageFolder 实际的 label 一致,如果不是这种命名规范,建议看看 self.class_to_idx
属性以了解 label 和文件夹名的映射关系。
!tree --charset ASCII data/dogcat_2/
data/dogcat_2/
|-- cat
| |-- cat.12484.jpg
| |-- cat.12485.jpg
| |-- cat.12486.jpg
| `-- cat.12487.jpg
`-- dog
|-- dog.12496.jpg
|-- dog.12497.jpg
|-- dog.12498.jpg
`-- dog.12499.jpg
2 directories, 8 files
from torchvision.datasets import ImageFolder
dataset = ImageFolder('data/dogcat_2/')
# cat 文件夹的图片对应 label 0,dog 对应 1
dataset.class_to_idx
{'cat': 0, 'dog': 1}
# 所有图片的路径和对应的 label
dataset.imgs
[('data/dogcat_2/cat/cat.12484.jpg', 0),
('data/dogcat_2/cat/cat.12485.jpg', 0),
('data/dogcat_2/cat/cat.12486.jpg', 0),
('data/dogcat_2/cat/cat.12487.jpg', 0),
('data/dogcat_2/dog/dog.12496.jpg', 1),
('data/dogcat_2/dog/dog.12497.jpg', 1),
('data/dogcat_2/dog/dog.12498.jpg', 1),
('data/dogcat_2/dog/dog.12499.jpg', 1)]
# 没有任何的 transform,所以返回的还是 PIL Image 对象
dataset[0][1] # 第一维是第几张图,第二维为 1 返回 label
dataset[0][0] # 为 0 返回图片数据
<PIL.Image.Image image mode=RGB size=497x500 at 0x7FE0A70256D8>
# 加上 transform
normalize = T.Normalize(mean=[0.4, 0.4, 0.4], std=[0.2, 0.2, 0.2])
transform = T.Compose([
T.RandomResizedCrop(224),
T.RandomHorizontalFlip(),
T.ToTensor(),
normalize,
])
dataset = ImageFolder('data/dogcat_2/', transform=transform)
# 深度学习中图片数据一般保存成 CxHxW,即通道数 x 图片高 x 图片宽
dataset[0][0].size()
torch.Size([3, 224, 224])
to_img = T.ToPILImage()
# 0.2 和 0.4 是标准差和均值的近似
to_img(dataset[0][0]*0.2+0.4)
<PIL.Image.Image image mode=RGB size=224x224 at 0x7FE0A59AF978>
Dataset
只负责数据的抽象,一次调用 __getitem__
只返回一个样本。前面提到过,在训练神经网络时,最好是对一个 batch 的数据进行操作,同时还需要对数据进行 shuffle 和并行加速等。对此,PyTorch 提供了 DataLoader
帮助我们实现这些功能。
DataLoader 的函数定义如下: DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)
- dataset:加载的数据集(Dataset 对象)
- batch_size:batch size
- shuffle::是否将数据打乱
- sampler: 样本抽样,后续会详细介绍
- num_workers:使用多进程加载的进程数,0 代表不使用多进程
- collate_fn: 如何将多个样本数据拼接成一个 batch,一般使用默认的拼接方式即可
- pin_memory:是否将数据保存在 pin memory 区,pin memory 中的数据转到 GPU 会快一些
- droplast:dataset 中的数据个数可能不是 batchsize 的整数倍,drop_last 为 True 会将多出来不足一个 batch 的数据丢弃
from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, batch_size=3, shuffle=True, num_workers=0, drop_last=False)
dataiter = iter(dataloader)
imgs, labels = next(dataiter)
imgs.size() # batch_size, channel, height, weight
torch.Size([3, 3, 224, 224])
dataloader 是一个可迭代的对象,意味着我们可以像使用迭代器一样使用它,例如:
for batch_datas, batch_labels in dataloader:
train()
或
dataiter = iter(dataloader)
batch_datas, batch_labesl = next(dataiter)
在数据处理中,有时会出现某个样本无法读取等问题,比如某张图片损坏。这时在 __getitem__
函数中将出现异常,此时最好的解决方案即是将出错的样本剔除。如果实在是遇到这种情况无法处理,则可以返回 None 对象,然后在 Dataloader
中实现自定义的 collate_fn
,将空对象过滤掉。但要注意,在这种情况下 dataloader 返回的 batch 数目会少于 batch_size。
class NewDogCat(DogCat): # 继承前面实现的 DogCat 数据集
def __getitem__(self, index):
try:
# 调用父类的获取函数,即 DogCat.__getitem__(self, index)
return super(NewDogCat,self).__getitem__(index)
except:
return None, None
from torch.utils.data.dataloader import default_collate # 导入默认的拼接方式
def my_collate_fn(batch):
'''
batch 中每个元素形如(data, label)
'''
# 过滤为 None 的数据
batch = list(filter(lambda x:x[0] is not None, batch))
if len(batch) == 0: return t.Tensor()
return default_collate(batch) # 用默认方式拼接过滤后的 batch 数据
dataset = NewDogCat('data/dogcat_wrong/', transforms=transform)
dataset[5]
(
( 0 ,.,.) =
1.9608 1.9804 2.0000 ... 1.3922 1.3333 1.2549
1.9804 2.0000 2.0392 ... 1.2157 1.0784 0.9412
1.9804 2.0196 2.0588 ... 0.9804 0.8039 0.6667
... ⋱ ...
1.5490 1.5098 1.4706 ... 1.6078 1.6078 1.6275
1.6863 1.6078 1.5098 ... 1.6078 1.6078 1.6275
1.7647 1.6863 1.5490 ... 1.6078 1.6275 1.6471
( 1 ,.,.) =
2.0980 2.1373 2.1569 ... 1.1961 1.1176 1.0392
2.0196 2.0588 2.0980 ... 1.0588 0.9020 0.7843
1.9804 2.0196 2.0588 ... 0.8431 0.6863 0.5294
... ⋱ ...
1.4314 1.4314 1.4118 ... 1.4706 1.4706 1.4902
1.5490 1.5098 1.4510 ... 1.4706 1.4706 1.4902
1.6275 1.5882 1.4902 ... 1.4706 1.4902 1.5098
( 2 ,.,.) =
1.9412 1.9412 1.9412 ... 0.5098 0.4706 0.4118
1.9608 1.9804 1.9804 ... 0.2353 0.1176 0.0000
1.9608 1.9804 2.0000 ... -0.0392 -0.1961 -0.3137
... ⋱ ...
0.2941 0.3137 0.3333 ... 0.4510 0.4510 0.4314
0.4510 0.4510 0.4314 ... 0.4510 0.4510 0.4314
0.5686 0.5686 0.5098 ... 0.4510 0.4510 0.4510
[torch.FloatTensor of size 3x224x224], 0)
dataloader = DataLoader(dataset, 2, collate_fn=my_collate_fn, num_workers=1,shuffle=True)
for batch_datas, batch_labels in dataloader:
print(batch_datas.size(),batch_labels.size())
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([1, 3, 224, 224]) torch.Size([1])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([1, 3, 224, 224]) torch.Size([1])
来看一下上述 batchsize 的大小。其中第 2 个的 batchsize 为 1,这是因为有一张图片损坏,导致其无法正常返回。而最后 1 个的 batchsize 也为 1,这是因为共有 9 张(包括损坏的文件)图片,无法整除 2(batchsize),因此最后一个 batch 的数据会少于 batchszie,可通过指定 drop_last=True
来丢弃最后一个不足 batchsize 的 batch。
对于诸如样本损坏或数据集加载异常等情况,还可以通过其它方式解决。例如但凡遇到异常情况,就随机取一张图片代替:
class NewDogCat(DogCat):
def __getitem__(self, index):
try:
return super(NewDogCat, self).__getitem__(index)
except:
new_index = random.randint(0, len(self)-1)
return self[new_index]
相比较丢弃异常图片而言,这种做法会更好一些,因为它能保证每个 batch 的数目仍是 batch_size。但在大多数情况下,最好的方式还是对数据进行彻底清洗。
DataLoader 里面并没有太多的魔法方法,它封装了 Python 的标准库 multiprocessing
,使其能够实现多进程加速。在此提几点关于 Dataset 和 DataLoader 使用方面的建议:
- 高负载的操作放在
__getitem__
中,如加载图片等。 - dataset 中应尽量只包含只读对象,避免修改任何可变对象,利用多线程进行操作。
第一点是因为多进程会并行的调用 __getitem__
函数,将负载高的放在 __getitem__
函数中能够实现并行加速。 第二点是因为 dataloader 使用多进程加载,如果在 Dataset
实现中使用了可变对象,可能会有意想不到的冲突。在多线程/多进程中,修改一个可变对象,需要加锁,但是 dataloader 的设计使得其很难加锁(在实际使用中也应尽量避免锁的存在),因此最好避免在 dataset 中修改可变对象。例如下面就是一个不好的例子,在多进程处理中 self.num
可能与预期不符,这种问题不会报错,因此难以发现。如果一定要修改可变对象,建议使用 Python 标准库 Queue
中的相关数据结构。
class BadDataset(Dataset):
def __init__(self):
self.datas = range(100)
self.num = 0 # 取数据的次数
def __getitem__(self, index):
self.num += 1
return self.datas[index]
使用 Python multiprocessing
库的另一个问题是,在使用多进程时,如果主程序异常终止(比如用 Ctrl+C 强行退出),相应的数据加载进程可能无法正常退出。这时你可能会发现程序已经退出了,但 GPU 显存和内存依旧被占用着,或通过 top
、 ps aux
依旧能够看到已经退出的程序,这时就需要手动强行杀掉进程。建议使用如下命令:
ps x | grep <cmdline> | awk '{print $1}' | xargs kill
ps x
:获取当前用户的所有进程grep <cmdline>
:找到已经停止的 PyTorch 程序的进程,例如你是通过 python train.py 启动的,那你就需要写grep 'python train.py'
awk '{print $1}'
:获取进程的 pidxargs kill
:杀掉进程,根据需要可能要写成xargs kill -9
强制杀掉进程
在执行这句命令之前,建议先打印确认一下是否会误杀其它进程
ps x | grep <cmdline> | ps x
PyTorch 中还单独提供了一个 sampler
模块,用来对数据进行采样。常用的有随机采样器: RandomSampler
,当 dataloader 的 shuffle
参数为 True 时,系统会自动调用这个采样器,实现打乱数据。默认的是采用 SequentialSampler
,它会按顺序一个一个进行采样。这里介绍另外一个很有用的采样方法: WeightedRandomSampler
,它会根据每个样本的权重选取数据,在样本比例不均衡的问题中,可用它来进行重采样。
构建 WeightedRandomSampler
时需提供两个参数:每个样本的权重 weights
、共选取的样本总数 num_samples
,以及一个可选参数 replacement
。权重越大的样本被选中的概率越大,待选取的样本数目一般小于全部的样本数目。 replacement
用于指定是否可以重复选取某一个样本,默认为 True,即允许在一个 epoch 中重复采样某一个数据。如果设为 False,则当某一类的样本被全部选取完,但其样本数目仍未达到 num_samples 时,sampler 将不会再从该类中选择数据,此时可能导致 weights
参数失效。下面举例说明。
dataset = DogCat('data/dogcat/', transforms=transform)
# 狗的图片被取出的概率是猫的概率的两倍
# 两类图片被取出的概率与 weights 的绝对大小无关,只和比值有关
weights = [2 if label == 1 else 1 for data, label in dataset]
weights
[1, 2, 2, 1, 2, 1, 1, 2]
from torch.utils.data.sampler import WeightedRandomSampler
sampler = WeightedRandomSampler(weights,\
num_samples=9,\
replacement=True)
dataloader = DataLoader(dataset,
batch_size=3,
sampler=sampler)
for datas, labels in dataloader:
print(labels.tolist())
[0, 0, 1]
[0, 0, 0]
[0, 1, 1]
可见猫狗样本比例约为 1:2,另外一共只有 8 个样本,但是却返回了 9 个,说明肯定有被重复返回的,这就是 replacement 参数的作用,下面将 replacement 设为 False 试试。
sampler = WeightedRandomSampler(weights, 8, replacement=False)
dataloader = DataLoader(dataset, batch_size=4, sampler=sampler)
for datas, labels in dataloader:
print(labels.tolist())
[1, 1, 0, 1]
[0, 0, 1, 0]
在这种情况下,num_samples 等于 dataset 的样本总数,为了不重复选取,sampler 会将每个样本都返回,这样就失去 weight 参数的意义了。
从上面的例子可见 sampler 在样本采样中的作用:如果指定了 sampler,shuffle 将不再生效,并且 sampler.num_samples 会覆盖 dataset 的实际大小,即一个 epoch 返回的图片总数取决于 sampler.num_samples
。
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论