How to add data augmentation with brightness to an image classification framework?

Posted 2025-01-14 07:15:02

I am using PyTorch for image classification, based on this code from GitHub. I need to add data augmentation before training my model, and I chose Albumentations for this. Here is my code after adding Albumentations:

import albumentations as A
from albumentations.pytorch import ToTensorV2

data_transform = {
    "train": A.Compose([
        A.RandomResizedCrop(224, 224),
        A.HorizontalFlip(p=0.5),
        A.RandomGamma(gamma_limit=(80, 120), eps=None, always_apply=False, p=0.5),
        A.RandomBrightnessContrast(p=0.5),
        A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), always_apply=False, p=0.5),
        A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
        A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
        A.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]),
    "val": A.Compose([
        A.Resize(256, 256),
        A.CenterCrop(224, 224),
        A.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]),
}

I got this error:

KeyError: Caught KeyError in DataLoader worker process 0.

KeyError: 'You have to pass data to augmentations as named arguments, for example: aug(image=image)'
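
The message is raised by Albumentations' Compose when it receives positional arguments. torchvision's ImageFolder (which the training script in this thread uses) applies its transform as transform(sample), with a single positional PIL image, so passing an A.Compose pipeline directly as transform= fails inside the DataLoader worker. The stdlib-only sketch below reproduces the two calling conventions with hypothetical stand-ins (keyword_only_pipeline, imagefolder_style_call, KeywordAdapter); it mimics the behaviour, it is not the real library code, and it shows how a small adapter bridges the two.

```python
def keyword_only_pipeline(*args, **data):
    """Hypothetical stand-in for calling an Albumentations Compose:
    inputs must be named (image=...) and a dict is returned."""
    if args:
        # Same message Albumentations raises for positional arguments
        raise KeyError("You have to pass data to augmentations as named "
                       "arguments, for example: aug(image=image)")
    # Pretend augmentation: horizontal flip of a nested-list "image"
    flipped = [list(reversed(row)) for row in data["image"]]
    return {**data, "image": flipped}


def imagefolder_style_call(transform, sample):
    """Mimics how torchvision's DatasetFolder.__getitem__ applies its transform."""
    return transform(sample)


class KeywordAdapter:
    """Wraps a keyword-only pipeline so it can be called as transform(sample)."""

    def __init__(self, pipeline):
        self.pipeline = pipeline

    def __call__(self, sample):
        return self.pipeline(image=sample)["image"]


img = [[1, 2, 3], [4, 5, 6]]

try:
    imagefolder_style_call(keyword_only_pipeline, img)
except KeyError as exc:
    print("direct call fails:", exc)

print(imagefolder_style_call(KeywordAdapter(keyword_only_pipeline), img))
# [[3, 2, 1], [6, 5, 4]]
```

With the real libraries, the adapter's __call__ would also convert the PIL image to a NumPy array first, e.g. return self.aug(image=np.array(img))["image"].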


Comments (3)

远山浅 2025-01-21 07:15:02

An Albumentations pipeline must be called with the image passed as the named (keyword) argument 'image', e.g. transforms(image=img), and it returns a dictionary from which you read the "image" key. Here is a sample of how to use it:

import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2

transforms = A.Compose([
    A.Rotate(limit=15, p=0.5),
    A.Perspective(scale=(0, 0.1), keep_size=False, fit_output=False, p=1),
    A.Resize(224, 224),
    A.HorizontalFlip(p=0.5),
    A.GaussNoise(var_limit=(10.0, 50.0), mean=0),
    A.RandomToneCurve(scale=0.5, p=1),
    A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.225, 0.225, 0.225]),
    ToTensorV2(),
])

img = cv2.imread("dog.png")  # OpenCV loads images as BGR
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# Call with a named argument; the result is a dict, the augmented image is under "image"
transformed_img = transforms(image=img)["image"]

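Since the question title asks about brightness augmentation in particular: roughly, RandomBrightnessContrast (used in the question's pipeline) draws a contrast factor alpha from 1 ± contrast_limit and a brightness offset beta from ± brightness_limit (both limits default to 0.2), then maps every 8-bit pixel p to clip(alpha * p + beta * 255, 0, 255) when brightness_by_max is left at its default of True. This is an approximation of the library's behaviour, not its code; the stdlib sketch below (adjust_brightness_contrast and random_brightness_contrast are hypothetical helpers) applies the formula to a flat list of pixel values, and exact rounding may differ between Albumentations versions.

```python
import random


def adjust_brightness_contrast(pixels, alpha, beta):
    """out = clip(alpha * p + beta * 255, 0, 255) for 8-bit pixel values.

    alpha scales contrast (1.0 = unchanged); beta shifts brightness as a
    fraction of the maximum value 255 (0.0 = unchanged).
    """
    out = []
    for p in pixels:
        value = alpha * p + beta * 255
        out.append(int(min(255.0, max(0.0, value))))  # clip, then truncate
    return out


def random_brightness_contrast(pixels, brightness_limit=0.2, contrast_limit=0.2,
                               p=0.5, rng=random):
    """With probability p, sample alpha and beta uniformly and apply them."""
    if rng.random() >= p:
        return list(pixels)  # augmentation skipped
    alpha = 1.0 + rng.uniform(-contrast_limit, contrast_limit)
    beta = rng.uniform(-brightness_limit, brightness_limit)
    return adjust_brightness_contrast(pixels, alpha, beta)


print(adjust_brightness_contrast([0, 100, 200, 250], alpha=1.0, beta=0.1))
# [25, 125, 225, 255]
print(adjust_brightness_contrast([0, 100, 200, 250], alpha=1.5, beta=0.0))
# [0, 150, 255, 255]
```

Raising brightness_limit widens the range of brightness shifts; raising contrast_limit widens the range of contrast scaling.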
你的心境我的脸 2025-01-21 07:15:02

You can do what you want by writing a dataset class like the one below:

import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset


class ImageDataset(Dataset):
    def __init__(self, images_filepaths, transform=None):
        self.images_filepaths = images_filepaths
        self.transform = transform

    def __len__(self):
        return len(self.images_filepaths)

    def __getitem__(self, idx):
        image_filepath = self.images_filepaths[idx]
        image = cv2.imread(image_filepath)  # OpenCV loads images as BGR
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform is not None:
            # Albumentations expects a named argument and returns a dict
            image = self.transform(image=image)["image"]
        # Note: only the image is returned; a classification loop also needs its label
        return image


train_transform = A.Compose([
    A.RandomResizedCrop(224, 224),
    A.HorizontalFlip(p=0.5),
    A.RandomGamma(gamma_limit=(80, 120), eps=None, always_apply=False, p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), always_apply=False, p=0.5),
    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
    A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
    A.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ToTensorV2(),
])

val_transform = A.Compose([
    A.Resize(256, 256),
    A.CenterCrop(224, 224),
    A.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ToTensorV2(),
])

# train_images_filepaths / val_images_filepaths: lists of image file paths you build yourself
train_dataset = ImageDataset(images_filepaths=train_images_filepaths, transform=train_transform)
val_dataset = ImageDataset(images_filepaths=val_images_filepaths, transform=val_transform)

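The answer leaves train_images_filepaths and val_images_filepaths undefined, and its __getitem__ returns only the image, while a classification training loop expects (image, label) pairs. Assuming the same root/<class_name>/<image> folder layout that ImageFolder uses (e.g. train/bad/... and train/good/...), a stdlib sketch like the one below could build both lists without torchvision. list_images_with_labels is a hypothetical helper; sorting the class folder names mirrors how ImageFolder assigns class indices, and the extension tuple is only a subset of what ImageFolder accepts.

```python
import os

# A subset of the extensions torchvision's ImageFolder accepts
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")


def list_images_with_labels(root):
    """Return (filepaths, labels, class_to_idx) for a root/<class_name>/<image> layout.

    Class indices follow the sorted folder names, e.g. {'bad': 0, 'good': 1}.
    """
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    filepaths, labels = [], []
    for name in classes:
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            # Keep image files only, matching extensions case-insensitively
            if fname.lower().endswith(IMG_EXTENSIONS):
                filepaths.append(os.path.join(class_dir, fname))
                labels.append(class_to_idx[name])
    return filepaths, labels, class_to_idx
```

With these lists, ImageDataset could store the labels next to the paths and return image, self.labels[idx] from __getitem__; the root path you pass in (for example your own train folder) is up to you.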
鱼忆七猫命九 2025-01-21 07:15:02

Am I using your suggestion correctly?
I have a dataset of good and bad images (underwater images):

import os
import sys
import json

import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from torchvision import datasets
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2

from model import resnet34  # model definition from the GitHub project


class ImageDataset(Dataset):
    def __init__(self, samples, transform=None):
        # samples: list of (image_filepath, class_index) pairs, e.g. ImageFolder(root).samples
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image_filepath, label = self.samples[idx]
        image = cv2.imread(image_filepath)  # OpenCV loads images as BGR
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform is not None:
            # Albumentations must be called with a named argument and returns a dict
            image = self.transform(image=image)["image"]
        return image, label


train_transform = A.Compose([
    A.RandomResizedCrop(224, 224),
    A.HorizontalFlip(p=0.5),
    A.RandomGamma(gamma_limit=(80, 120), eps=None, always_apply=False, p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), always_apply=False, p=0.5),
    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
    A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
    A.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ToTensorV2(),
])

val_transform = A.Compose([
    A.Resize(256, 256),
    A.CenterCrop(224, 224),
    A.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ToTensorV2(),
])


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_root = "/content/gdrive"  # data root path
    image_path = os.path.join(data_root, "MyDrive", "totalimages")  # good/bad underwater image set
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

    # ImageFolder is only used to scan the class folders and assign labels;
    # the Albumentations transform is applied inside ImageDataset
    train_folder = datasets.ImageFolder(root=os.path.join(image_path, "train"))
    train_dataset = ImageDataset(train_folder.samples, transform=train_transform)
    train_num = len(train_dataset)

    # class indices follow the sorted folder names, e.g. {'bad': 0, 'good': 1}
    cla_dict = dict((val, key) for key, val in train_folder.class_to_idx.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 64
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    val_folder = datasets.ImageFolder(root=os.path.join(image_path, "val"))
    validate_dataset = ImageDataset(val_folder.samples, transform=val_transform)
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    net = resnet34()
    # load pretrained weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "/content/gdrive/MyDrive/resnet34-333f7ec4.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    net.load_state_dict(torch.load(model_weight_path, map_location=device))
    # for param in net.parameters():
    #     param.requires_grad = False

    # change fc layer structure: one output per class folder (2 for bad/good)
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, len(train_folder.classes))
    net.to(device)

    # define loss function
    loss_function = nn.CrossEntropyLoss()

    # construct an optimizer
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    epochs = 10
    best_acc = 0.0
    save_path = './resNet34.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for images, labels in train_bar:
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss.item())

        # validate
        net.eval()
        acc = 0.0  # number of correct predictions in this epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_images, val_labels in val_bar:
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()
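
On the brightness side, the script above also uses RandomGamma(gamma_limit=(80, 120)). Roughly, the limits are given in percent, so gamma is drawn from [0.8, 1.2] and each 8-bit pixel p becomes 255 * (p / 255) ** gamma: gamma below 1 brightens the mid-tones, gamma above 1 darkens them, and pure black and white stay fixed. The stdlib sketch below uses hypothetical helpers (apply_gamma, sample_gamma) and rounds to the nearest integer, which may differ slightly from the library's own lookup-table implementation.

```python
import random


def apply_gamma(pixels, gamma):
    """out = 255 * (p / 255) ** gamma for 8-bit pixel values, rounded to int."""
    return [round(255 * (p / 255) ** gamma) for p in pixels]


def sample_gamma(gamma_limit=(80, 120), rng=random):
    """Draw gamma uniformly from gamma_limit, which is given in percent (80 -> 0.8)."""
    return rng.uniform(gamma_limit[0], gamma_limit[1]) / 100.0


pixels = [0, 64, 128, 255]
print(apply_gamma(pixels, 1.0))  # [0, 64, 128, 255] (unchanged)
print(apply_gamma(pixels, 0.8))  # mid-tones get brighter
print(apply_gamma(pixels, 1.2))  # mid-tones get darker
print(apply_gamma(pixels, sample_gamma()))  # one random draw, as RandomGamma would make
```

Unlike the additive brightness shift of RandomBrightnessContrast, gamma changes mid-tones most and leaves the extremes untouched, so the two transforms complement each other.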