Combining multiple DataLoaders sequentially

Posted 2025-01-19 10:58:26

I'm interested in how I'd go about combining multiple DataLoaders sequentially for training. I understand I can use ConcatDataset to combine datasets first, but this does not work for my use case. I have a custom collate_fn that is passed to each dataloader, and this function depends on an attribute of the underlying Dataset. So, I'll have a set of custom DataLoaders like the following:

import functools

import torch

def custom_collate(sample, ref):
    # clean_sample is a user-defined, batch-wise transform that needs the
    # dataset-specific `ref`.
    data = clean_sample(torch.stack([x[0] for x in sample]), ref)
    labels = torch.tensor([x[1] for x in sample])
    return data, labels

class CollateLoader(torch.utils.data.DataLoader):
    def __init__(self, ref, *args, **kwargs):
        # Bind this dataset's `ref` into the loader's collate_fn.
        collate_fn = functools.partial(custom_collate, ref=ref)
        super().__init__(*args, collate_fn=collate_fn, **kwargs)

Where ref is a property of the custom Dataset class and is passed on initialization of a CollateLoader. Also, I know transforms can be applied in the Dataset, but in my case it must be done batch-wise.
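
For concreteness, here is a hypothetical construction sketch (RefDataset, the tensors, and the stub clean_sample below are made up for illustration; CollateLoader and custom_collate are as defined above): each dataset carries its own ref, and each loader binds that ref into its collate_fn.

import torch
from torch.utils.data import Dataset

def clean_sample(batch, ref):
    # Stand-in for the real batch-wise transform; the actual clean_sample
    # is user-defined and not shown in the post.
    return batch - ref

class RefDataset(Dataset):
    # Hypothetical dataset that carries the `ref` its loader needs.
    def __init__(self, data, labels, ref):
        self.data, self.labels, self.ref = data, labels, ref

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

ds_1 = RefDataset(torch.randn(64, 8), torch.zeros(64, dtype=torch.long), ref=torch.ones(8))
ds_2 = RefDataset(torch.randn(32, 8), torch.ones(32, dtype=torch.long), ref=torch.zeros(8))

# Each loader gets a collate_fn bound to its own dataset's ref.
loader_1 = CollateLoader(ds_1.ref, ds_1, batch_size=16)
loader_2 = CollateLoader(ds_2.ref, ds_2, batch_size=16)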

So, how would I go about combining multiple DataLoaders? In the PyTorch-Lightning LightningDataModule, we can do something like

def train_dataloader(self):
    return [data_loader_1, data_loader_2]

But this will return a list of batches, not the batches sequentially.
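
For reference, outside of Lightning the sequential behaviour being asked about can be illustrated with plain Python chaining. This is only a sketch, assuming loader_1 and loader_2 are CollateLoader instances as in the snippet above; it is not from the original question.

import itertools

# Exhaust loader_1 first, then loader_2; each batch is collated by the
# collate_fn of the loader it came from.
for data, labels in itertools.chain(loader_1, loader_2):
    ...  # run one training step per batch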

Comments (2)

GRAY°灰色天空 2025-01-26 10:58:26

I ran into the same problem and found a workaround. I overrode the epoch training loop using the Loops API from PyTorch Lightning: I defined a class CustomLoop that inherits from pytorch_lightning.loops.TrainingEpochLoop and overrode its advance() method. I copy-pasted the source code of advance() from pytorch_lightning and replaced its batch-fetching lines with:

# Inside the copied advance(): keep track of which of the combined loaders
# to draw the next batch from, advancing (and wrapping) the index after
# each batch.
if not hasattr(self, 'dataloader_idx'):
    self.dataloader_idx = 0
if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
    batch_idx = self.batch_idx + 1
    batch = next(data_fetcher.dataloader.loaders[self.dataloader_idx])
    self.dataloader_idx += 1
    if self.dataloader_idx == len(data_fetcher.dataloader.loaders):
        self.dataloader_idx = 0
else:
    batch_idx, batch = next(data_fetcher)

That way, instead of iterating over the CombinedLoader, I make it iterate over one dataloader at a time.
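
For orientation, here is a skeleton of where that replacement lives, assuming the PyTorch Lightning 1.x Loops API; the body of advance() is elided because it is a copy of the upstream implementation with only the lines above swapped in.

from pytorch_lightning.loops import TrainingEpochLoop

class CustomLoop(TrainingEpochLoop):
    def advance(self, data_fetcher):
        # Copy of TrainingEpochLoop.advance() from pytorch_lightning, with
        # its batch-fetching lines replaced by the snippet shown above.
        ...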
Then, to make use of this custom loop, you have to replace the default loop in the Trainer:

trainer.fit_loop.replace(epoch_loop=CustomLoop)
trainer.fit(my_model)

゛清羽墨安 2025-01-26 10:58:26

You can return [train_dataloader, train_2_dataloader]; then each training step receives two batches, one from each dataloader, so you can loop over them with a for and sum the losses.
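
A minimal sketch of this approach (the LightningModule, the model, and the data_loader_1 / data_loader_2 names are placeholders; it assumes Lightning's default behaviour of delivering one batch per dataloader to each training step when train_dataloader returns a list):

import torch
import torch.nn.functional as F
import pytorch_lightning as pl

class MultiLoaderModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(8, 2)  # placeholder model

    def train_dataloader(self):
        # Placeholder loaders, each with its own ref-bound collate_fn.
        return [data_loader_1, data_loader_2]

    def training_step(self, batch, batch_idx):
        # `batch` is a list with one (data, labels) pair per dataloader.
        total_loss = 0.0
        for data, labels in batch:
            total_loss = total_loss + F.cross_entropy(self.net(data), labels)
        return total_loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)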
