PyTorch custom dataset's __getitem__ calls itself indefinitely when handling exceptions
I'm writing a script for my CustomDataset class, but I get an Index out of range error whenever I access the data using a for loop, like so:
cd = CustomDataset(df)
for img, target in cd:
    pass
I realized I might have a problem reading a few images (if they are corrupt), so I implemented a random_on_error feature, which chooses a random image if something is wrong with the current one. I'm sure that's where the problem is: all 2160 images in the dataset are read without any hiccups (I print the index number on every iteration), but the loop does not stop and tries to read the 2161st image, which results in an Index out of range exception that gets handled by reading a random image. This continues forever.
Here is my class:
class CustomDataset(Dataset):
    def __init__(self, data: pd.DataFrame, augmentations=None, exit_on_error=False, random_on_error: bool = True):
        """
        :param data: Pandas dataframe with paths as first column and target as second column
        :param augmentations: Image transformations
        :param exit_on_error: Stop execution once an exception rises. Cannot be used in conjunction with random_on_error
        :param random_on_error: Upon an exception while reading an image, pick a random image and process it instead.
        Cannot be used in conjunction with exit_on_error.
        """
        if exit_on_error and random_on_error:
            raise ValueError("Only one of 'exit_on_error' and 'random_on_error' can be true")
        self.image_paths = data.iloc[:, 0].to_numpy()
        self.targets = data.iloc[:, 1].to_numpy()
        self.augmentations = augmentations
        self.exit_on_error = exit_on_error
        self.random_on_error = random_on_error
    def __len__(self):
        return self.image_paths.shape[0]
    def __getitem__(self, index):
        image, target = None, None
        try:
            image, target = self.read_image_data(index)
        except:
            print(f"Exception occurred while reading image, {index}")
            if self.exit_on_error:
                print(self.image_paths[index])
                raise
            if self.random_on_error:
                random_index = np.random.randint(0, self.__len__())
                print(f"Replacing with random image, {random_index}")
                image, target = self.read_image_data(random_index)
            else:  # todo implement return logic when self.random_on_error is false
                return
        if self.augmentations is not None:
            aug_image = self.augmentations(image=image)
            image = aug_image["image"]
        image = np.transpose(image, (2, 0, 1))
        return (
            torch.tensor(image, dtype=torch.float),
            torch.tensor(target, dtype=torch.long)
        )
    def read_image_data(self, index: int) -> ImagePlusTarget:
        # reads image, converts to 3 channel ndarray if image is grey scale and converts rgba to rgb (if applicable)
        target = self.targets[index]
        image = io.imread(self.image_paths[index])
        if image.ndim == 2:
            image = np.expand_dims(image, 2)
        if image.shape[2] > 3:
            image = color.rgba2rgb(image)
        return image, target
I believe the problem is with the except block (line 27) in __getitem__(), because the code works fine when I remove it. But I cannot see what the problem is.
Any help is appreciated, thanks.
Comments (1)
How do you expect Python to know when to stop reading from your CustomDataset?
Defining a __getitem__ method in CustomDataset makes it an iterable object in Python. That is, Python can iterate over CustomDataset's items one by one. However, the iterable object must raise either StopIteration or IndexError for Python to know it has reached the end of the iteration. In your class, the bare except also catches the IndexError raised for the out-of-range index and replaces it with a random image, so the loop never ends.
You can either change the loop to explicitly use the __len__ of your dataset, or make sure you raise IndexError from your dataset if index is outside the range. The latter can be done using multiple except clauses.
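Both options can be sketched roughly as follows. This is a minimal illustration, not the asker's actual class: ToyDataset, its read method and the corrupt parameter are hypothetical stand-ins, chosen so the example runs without PyTorch or any image files. The key point is the separate except IndexError clause, which lets the out-of-range error escape before the generic handler can swallow it.

```python
import random


class ToyDataset:
    """Hypothetical stand-in for CustomDataset (not the asker's real class)."""

    def __init__(self, items, corrupt=()):
        self.items = list(items)
        self.corrupt = set(corrupt)  # indices that simulate unreadable images

    def __len__(self):
        return len(self.items)

    def read(self, index):
        # The list lookup raises IndexError past the end of the data;
        # "corrupt" indices raise a different error to mimic a bad file.
        item = self.items[index]
        if index in self.corrupt:
            raise ValueError(f"corrupt item at index {index}")
        return item

    def __getitem__(self, index):
        try:
            return self.read(index)
        except IndexError:
            # Re-raise so that a plain for loop knows where the data ends.
            raise
        except Exception:
            # Any other failure: fall back to a random readable item.
            # Assumes at least one item in the dataset is readable.
            while True:
                candidate = random.randrange(len(self))
                if candidate not in self.corrupt:
                    return self.read(candidate)


ds = ToyDataset(["a", "b", "c", "d"], corrupt={2})

# Option 1: plain iteration terminates, because IndexError is not swallowed.
via_iteration = [item for item in ds]

# Option 2: index explicitly over range(len(ds)), which never goes out of range.
via_len = [ds[i] for i in range(len(ds))]
```

With the asker's class, the same idea would mean adding an except IndexError: raise clause ahead of the bare except, or looping with for i in range(len(cd)) and indexing cd[i].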