Training loop does not start with num_workers > 1 in the DataLoader
I have an NLP classification problem where I create a DataLoader object like this:

train_patentload = DataLoader(train_patentset, batch_size=4, shuffle=True, num_workers=2)

When I run the training loop, it gets stuck and never starts, although the code runs normally when num_workers=2 is removed. I have been stuck for a while now and I'd appreciate any help.

Here is the code of the Dataset:
class PatentDataset(Dataset):
    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)  # was len(df): referencing the module-level df, not the instance attribute

    def __getitem__(self, ind):
        # map the fractional score in the last column to a cumulative multi-hot label
        conv_dict = {0: [1., 0., 0., 0., 0.],
                     0.25: [1., 1., 0., 0., 0.],
                     0.5: [1., 1., 1., 0., 0.],
                     0.75: [1., 1., 1., 1., 0.],
                     1: [1., 1., 1., 1., 1.]}
        inputs = self.df.iloc[ind, 1:-1].to_list()  # was df.iloc: again the global, not self.df
        text = ' #^& '.join(inputs)
        label = torch.as_tensor(np.array(conv_dict[self.df.iloc[ind, -1]]))
        # tokenizer is a module-level object here, which each worker must be able to re-create
        text = tokenizer(text, padding='max_length', max_length=256,
                         truncation=True, return_tensors="pt")
        return text, label

train_patentset = PatentDataset(train)
where train is the DataFrame.
1 Answer
num_workers=2 causes PyTorch to spawn 2 worker processes via multiprocessing. Both processes then get their own instance of train_patentset. If you omit num_workers, it defaults to 0, which means the data loading is done in the main thread/process.

Since you did not post more code, the only advice is to look through the code of your dataset and check whether any parts of it are not safe for multiprocessing. Do you somehow use information from outside the dataset class, or rely on communication via global variables or something similar? In the code shown, __len__ and __getitem__ read the module-level df instead of self.df, and tokenizer is also a module-level object; every worker process needs to be able to pickle or re-import whatever the Dataset touches.

Best would be to update the question and include the code for the dataset.
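A common cause of this exact hang, especially on Windows and macOS where workers are started with the "spawn" method, is module-level code that runs again when each worker re-imports the script. A minimal sketch of a worker-safe setup (the ToyDataset below is a hypothetical stand-in for PatentDataset, not the asker's actual code):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    def __init__(self, n):
        # everything __getitem__ needs is stored on the instance,
        # so it pickles cleanly into each worker process
        self.data = torch.arange(n, dtype=torch.float32)

    def __len__(self):
        return len(self.data)  # use self.data, never a module-level variable

    def __getitem__(self, ind):
        return self.data[ind]

def main():
    ds = ToyDataset(8)
    # num_workers > 0 spawns/forks worker processes; each gets its own
    # copy of ds, so the Dataset must not depend on globals that only
    # exist in the parent process
    loader = DataLoader(ds, batch_size=4, shuffle=True, num_workers=2)
    for batch in loader:
        print(batch.shape)

if __name__ == "__main__":  # guard so spawned workers don't re-execute main()
    main()
```

Without the __main__ guard, a spawned worker re-runs the top-level script, which can recursively create loaders or simply deadlock, matching the "stuck" behavior described above.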