I'm interested in how I'd go about combining multiple DataLoaders sequentially for training. I understand I can use ConcatDataset to combine datasets first, but this does not work for my use case. I have a custom collate_fn that is passed to each DataLoader, and this function depends on an attribute of the underlying Dataset. So, I'll have a set of custom DataLoaders like the following:
import functools
import torch

def custom_collate(sample, ref):
    # `clean_sample` is my own batch-level preprocessing function
    data = clean_sample(torch.stack([x[0] for x in sample]), ref)
    labels = torch.tensor([x[1] for x in sample])
    return data, labels

class CollateLoader(torch.utils.data.DataLoader):
    def __init__(self, ref, *args, **kwargs):
        # Bind the dataset-specific `ref` into the collate function
        collate_fn = functools.partial(custom_collate, ref=ref)
        super().__init__(*args, collate_fn=collate_fn, **kwargs)
Where ref is a property of the custom Dataset class and is passed on initialization of a CollateLoader. Also, I know transforms can be applied in the Dataset, but in my case it must be done batch-wise.
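For concreteness, here is a minimal sketch of how such a loader gets wired up; ToyDataset is just a hypothetical stand-in for my real dataset, which exposes the ref attribute the collate function needs:

import torch
from torch.utils.data import Dataset

class ToyDataset(Dataset):
    # Hypothetical dataset exposing the `ref` attribute consumed by custom_collate
    def __init__(self, n, ref):
        self.ref = ref
        self.data = torch.randn(n, 3)
        self.labels = torch.zeros(n, dtype=torch.long)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

dataset = ToyDataset(100, ref=0.5)
loader = CollateLoader(dataset.ref, dataset, batch_size=32, shuffle=True)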
So, how would I go about combining multiple DataLoaders? In the PyTorch Lightning LightningDataModule, we can do something like
def train_dataloader(self):
    return [data_loader_1, data_loader_2]
But this will return a list of batches, not the batches sequentially.
Comments (2)
I ran into the same problem and found a workaround. I overrode the epoch training loop using the Loops API from PyTorch Lightning: I defined a class CustomLoop that inherits from pytorch_lightning.loops.TrainingEpochLoop and overrode its advance() method. I copy-pasted the source code of advance() from pytorch_lightning and replaced the lines that fetch the next batch, as sketched below.
That way, instead of iterating over the CombinedLoader, I make it iterate over one dataloader at a time.
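The exact lines differ between Lightning versions, so treat the following only as a rough skeleton of the idea under the pre-2.0 Loops API; the body of advance() has to be copied from the TrainingEpochLoop of the installed pytorch_lightning, and only the batch-fetching part changes:

from pytorch_lightning.loops import TrainingEpochLoop

class CustomLoop(TrainingEpochLoop):
    def advance(self, *args, **kwargs):
        # Paste the body of TrainingEpochLoop.advance() from the installed
        # pytorch_lightning version here, then change the batch-fetching lines
        # so that `batch` comes from a single dataloader's iterator (moving on
        # to the next dataloader once the current one is exhausted) instead of
        # from the CombinedLoader, which yields one batch per dataloader.
        ...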
Then, to make use of this custom loop, you have to replace the default loop in the Trainer:
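That step looks roughly like this (model and datamodule are placeholders for your own LightningModule and data; depending on the Lightning version you may need trainer.fit_loop.connect(epoch_loop=...) instead of replace()):

from pytorch_lightning import Trainer

trainer = Trainer(max_epochs=10)
# Swap the default epoch loop for the custom one (pre-2.0 Loops API)
trainer.fit_loop.replace(epoch_loop=CustomLoop)
trainer.fit(model, datamodule=datamodule)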
You can return [train_dataloader, train_2_dataloader], and then each step you get two batches, one from each dataloader, so you can apply a for loop over them and sum the losses.
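A rough sketch of that, with a placeholder model (when train_dataloader() returns a list, Lightning hands training_step a list containing one collated batch per dataloader at every step):

import torch
import torch.nn.functional as F
import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(3, 2)  # placeholder network

    def training_step(self, batch, batch_idx):
        # `batch` is a list with one (data, labels) pair per train dataloader
        losses = []
        for data, labels in batch:
            losses.append(F.cross_entropy(self.net(data), labels))
        return torch.stack(losses).sum()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)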