PyTorch custom dataset's __getitem__ is called indefinitely when handling exceptions

Posted 2025-01-11 04:59:27

I'm writing a script for my CustomDataset class, but I get an "Index out of range" error whenever I access the data with a for loop like so:

cd = CustomDataset(df)
for img, target in cd:
   pass

I realized that reading some images might fail (if they are corrupt), so I implemented a random_on_error feature that picks a random image when something is wrong with the current one. I'm sure that is where the problem is: all 2160 images in the dataset are read without any hiccups (I print the index on every iteration), but the loop does not stop. It goes on to read a 2161st image, which raises an "Index out of range" exception that is then handled by reading a random image. This continues forever.

Here is my class:

class CustomDataset(Dataset):
    def __init__(self, data: pd.DataFrame, augmentations=None, exit_on_error=False, random_on_error: bool = True):
        """
        :param data: Pandas dataframe with paths as first column and target as second column
        :param augmentations: Image transformations
        :param exit_on_error: Stop execution once an exception rises. Cannot be used in conjunction with random_on_error
        :param random_on_error: Upon an exception while reading an image, pick a random image and process it instead.
        Cannot be used in conjunction with exit_on_error.
        """
 
        if exit_on_error and random_on_error:
            raise ValueError("Only one of 'exit_on_error' and 'random_on_error' can be true")
 
        self.image_paths = data.iloc[:, 0].to_numpy()
        self.targets = data.iloc[:, 1].to_numpy()
        self.augmentations = augmentations
        self.exit_on_error = exit_on_error
        self.random_on_error = random_on_error
 
    def __len__(self):
        return self.image_paths.shape[0]
 
    def __getitem__(self, index):
        image, target = None, None
        try:
            image, target = self.read_image_data(index)
        except:
            print(f"Exception occurred while reading image, {index}")
            if self.exit_on_error:
                print(self.image_paths[index])
                raise
            if self.random_on_error:
                random_index = np.random.randint(0, self.__len__())
                print(f"Replacing with random image, {random_index}")
                image, target = self.read_image_data(random_index)
 
            else:  # todo implement return logic when self.random_on_error is false
                return
 
        if self.augmentations is not None:
            aug_image = self.augmentations(image=image)
            image = aug_image["image"]
 
        image = np.transpose(image, (2, 0, 1))
 
        return (
            torch.tensor(image, dtype=torch.float),
            torch.tensor(target, dtype=torch.long)
        )
 
    def read_image_data(self, index: int) -> ImagePlusTarget: 
        # reads image, converts to 3 channel ndarray if image is grey scale and converts rgba to rgb (if applicable)
        target = self.targets[index]
        image = io.imread(self.image_paths[index])
        if image.ndim == 2:
            image = np.expand_dims(image, 2)
        if image.shape[2] > 3:
            image = color.rgba2rgb(image)
 
        return image, target

I believe the problem is with the except block (line 27) in __getitem__(), as when I remove it the code works fine. But I cannot see what the problem here is.

Any help is appreciated, thanks

Comments (1)

空心空情空意 2025-01-18 04:59:27

How do you expect Python to know when to stop reading from your CustomDataset?

Defining a __getitem__ method on CustomDataset makes it iterable in Python: when a class has no __iter__, a for loop falls back to calling __getitem__ with 0, 1, 2, ... in turn. The loop never consults __len__; it only ends when __getitem__ raises IndexError (or StopIteration). Your bare except catches that IndexError and substitutes a random image, so the loop never learns that it has reached the end.
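To make this concrete, here is a minimal sketch of that fallback iteration, independent of PyTorch. The class names Bounded and Swallowing are made up for illustration, and a fixed fallback item stands in for the random replacement so the result is predictable. The second loop is capped with itertools.islice because it would otherwise never end.

```python
from itertools import islice


class Bounded:
    """Iterable only through __getitem__ (no __iter__ defined)."""

    def __init__(self, items):
        self.items = list(items)

    def __getitem__(self, index):
        # The underlying list raises IndexError past the end,
        # which is what tells the for loop to stop.
        return self.items[index]


class Swallowing(Bounded):
    """Same data, but every exception is replaced by a fallback item."""

    def __getitem__(self, index):
        try:
            return super().__getitem__(index)
        except Exception:
            # IndexError is swallowed here too, so iteration never ends.
            return self.items[0]


bounded_result = [x for x in Bounded("abc")]
# Cap the second loop at 6 items; without islice it would run forever.
swallowing_result = list(islice(Swallowing("abc"), 6))

print(bounded_result)
print(swallowing_result)
```

The first loop stops after three items, while the second keeps producing the fallback item for indices 3, 4, 5 and beyond.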

You can either change the loop to explicitly use the __len__ of your dataset:

for i in range(len(cd)):
  img, target = cd[i] 

Alternatively, make sure your dataset raises IndexError when index is out of range. One way to do this is with multiple except clauses, something like:

try:
  image, target = self.read_image_data(index)
except IndexError:
  raise  # do not handle this error: it tells the for loop to stop
except Exception:
  # treat all other exceptions (corrupt images) here
  ...
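As a self-contained illustration of that second option, the sketch below uses a slight variant: it checks the bounds explicitly before the try block, so an out-of-range index can never be mistaken for a corrupt image. Everything here is hypothetical scaffolding rather than the asker's code: SafeDataset, fake_loader and max_retries are made-up names, and the loader simply fails on paths starting with "bad" to simulate corrupt files. It also retries a bounded number of times, since the random replacement may itself be corrupt.

```python
import random


class SafeDataset:
    """Hypothetical stand-in for CustomDataset; `loader` replaces read_image_data."""

    def __init__(self, paths, loader, max_retries=10):
        self.paths = list(paths)
        self.loader = loader
        self.max_retries = max_retries

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        # Bounds check outside any try block, so a for loop can terminate.
        # (Negative indices are rejected in this sketch for simplicity.)
        if not 0 <= index < len(self):
            raise IndexError(f"index {index} out of range for {len(self)} items")
        try:
            return self.loader(self.paths[index])
        except Exception as err:
            # Corrupt item: try random replacements, but only a few times.
            for _ in range(self.max_retries):
                replacement = random.randrange(len(self))
                try:
                    return self.loader(self.paths[replacement])
                except Exception:
                    continue
            raise RuntimeError(f"no readable replacement for item {index}") from err


def fake_loader(path):
    # Assumption for the demo: paths starting with "bad" simulate corrupt files.
    if path.startswith("bad"):
        raise ValueError(f"cannot read {path}")
    return path.upper()


ds = SafeDataset(["a.png", "bad.png", "c.png", "d.png", "e.png"], fake_loader)
items = [item for item in ds]  # ends after len(ds) items
print(items)
```

With the bounds check in place, a plain for loop over the dataset stops after the last item, and only genuine read failures trigger the random replacement.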