返回介绍

6.4 数据加载

发布于 2024-01-28 10:44:17 字数 3271 浏览 0 评论 0 收藏 0

数据的相关处理主要保存在 data/dataset.py 中。关于数据加载的相关操作,在上一章中我们已经提到过,其基本原理就是使用 Dataset 提供数据集的封装,再使用 Dataloader 实现数据并行加载。Kaggle 提供的数据包括训练集和测试集,而我们在实际使用中,还需专门从训练集中取出一部分作为验证集。

对于这三类数据集,其相应操作也不太一样,而如果专门写三个 Dataset ,则稍显复杂和冗余,因此这里通过加一些判断来区分。对于训练集,我们希望做一些数据增强处理,如随机裁剪、随机翻转、加噪声等,而验证集和测试集则不需要。下面看 dataset.py 的代码:

import os
from PIL import Image
from torch.utils import data
import numpy as np
from torchvision import transforms as T

class DogCat(data.Dataset):
    
    def __init__(self, root, transforms=None, train=True, test=False):
        '''
        目标:获取所有图片地址,并根据训练、验证、测试划分数据
        '''
        self.test = test
        imgs = [os.path.join(root, img) for img in os.listdir(root)] 

        # test1: data/test1/8973.jpg
        # train: data/train/cat.10004.jpg 
        if self.test:
            imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2].split('/')[-1]))
        else:
            imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2]))
            
        imgs_num = len(imgs)
        
        # 划分训练、验证集,验证:训练 = 3:7
        if self.test:
            self.imgs = imgs
        elif train:
            self.imgs = imgs[:int(0.7*imgs_num)]
        else :
            self.imgs = imgs[int(0.7*imgs_num):]            
    
        if transforms is None:
        
            # 数据转换操作,测试验证和训练的数据转换有所区别
	        
            normalize = T.Normalize(mean = [0.485, 0.456, 0.406], 
                                     std = [0.229, 0.224, 0.225])

            # 测试集和验证集
            if self.test or not train: 
                self.transforms = T.Compose([
                    T.Scale(224),
                    T.CenterCrop(224),
                    T.ToTensor(),
                    normalize
                ]) 
            # 训练集
            else :
                self.transforms = T.Compose([
                    T.Scale(256),
                    T.RandomSizedCrop(224),
                    T.RandomHorizontalFlip(),
                    T.ToTensor(),
                    normalize
                ]) 
                
        
    def __getitem__(self, index):
        '''
        返回一张图片的数据
        对于测试集,没有 label,返回图片 id,如 1000.jpg 返回 1000
        '''
        img_path = self.imgs[index]
        if self.test: 
             label = int(self.imgs[index].split('.')[-2].split('/')[-1])
        else: 
             label = 1 if 'dog' in img_path.split('/')[-1] else 0
        data = Image.open(img_path)
        data = self.transforms(data)
        return data, label
    
    def __len__(self):
        '''
        返回数据集中所有图片的个数
        '''
        return len(self.imgs)

关于数据集使用的注意事项,在上一章中已经提到,将文件读取等费时操作放在 __getitem__ 函数中,利用多进程加速。避免一次性将所有图片都读进内存,不仅费时也会占用较大内存,而且不易进行数据增强等操作。另外在这里,我们将训练集中的 30%作为验证集,可用来检查模型的训练效果,避免过拟合。在使用时,我们可通过 dataloader 加载数据。

train_dataset = DogCat(opt.train_data_root, train=True)
trainloader = DataLoader(train_dataset,
                        batch_size = opt.batch_size,
                        shuffle = True,
                        num_workers = opt.num_workers)
                  
for ii, (data, label) in enumerate(trainloader):
	train()

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文