'NoneType' object is not iterable in the testing part of a PyTorch model?
I'm working on a hyperspectral image with a PyTorch model. When I try to test the model, I get a "'NoneType' object is not iterable" error. A window slides over the image, and I added it to the testing part. Inside the for loop, the batch values become None, and I don't know how to fix it. I'm attaching my code below. Thanks for your help.
```python
import numpy as np
import torch
from tqdm import tqdm

# grouper, sliding_window and count_sliding_window are helper functions defined
# elsewhere in the project; model, img and hyperparams are also defined earlier (not shown).


def test(net, img, hyperparams):
    net.eval()
    patch_size = hyperparams["patch_size"]
    center_pixel = hyperparams["center_pixel"]
    batch_size, device = hyperparams["batch_size"], hyperparams["device"]
    n_classes = hyperparams["n_classes"]

    kwargs = {
        "step": hyperparams["test_stride"],
        "window_size": (patch_size, patch_size),
    }
    probs = np.zeros(img.shape[:2] + (n_classes,))

    iterations = count_sliding_window(img, **kwargs) // batch_size
    for batch in tqdm(
        grouper(batch_size, sliding_window(img, **kwargs)),
        total=(iterations),
        desc="Inference on the image",
    ):
        with torch.no_grad():
            if patch_size == 1:
                data = [b[0][0, 0] for b in batch]
                data = np.copy(data)
                data = torch.from_numpy(data)
            else:
                data = [b[0] for b in batch]
                data = np.copy(data)
                data = data.transpose(0, 3, 1, 2)
                data = torch.from_numpy(data)
                data = data.unsqueeze(1)

            indices = [b[1:] for b in batch]
            data = data.to(device)
            output = net(data)
            if isinstance(output, tuple):
                output = output[0]
            output = output.to("cpu")

            if patch_size == 1 or center_pixel:
                output = output.numpy()
            else:
                output = np.transpose(output.numpy(), (0, 2, 3, 1))
            for (x, y, w, h), out in zip(indices, output):
                if center_pixel:
                    probs[x + w // 2, y + h // 2] += out
                else:
                    probs[x : x + w, y : y + h] += out
    return probs


probabilities = test(model, img, hyperparams)
```
Then I get this error:
```
Inference on the image:   0%|          | 0/210 [00:00<?, ?it/s]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-55-cabf8b6fa06d> in <module>()
----> 1 probabilities = test(model, img, hyperparams)
      2 prediction = np.argmax(probabilities, axis=-1)

<ipython-input-54-b248a53c93ce> in test(net, img, hyperparams)
     23         with torch.no_grad():
     24             if patch_size == 1:
---> 25                 data = [b[0][0, 0] for b in batch]
     26                 data = np.copy(data)
     27                 data = torch.from_numpy(data)
```
Comments (1)
It looks like the problem is the content of `batch`. It's not easy to pin down, but here is some advice on how to get closer to the bug:
- Newer Python versions have better error messages: 3.10 improved many of them, and from 3.11 on, tracebacks also mark where in the line the error occurred.
- Loop over the batches with `enumerate()` to find the index of the defective batch.
- Temporarily wrap the failing line in a `try`/`except` that catches this `TypeError`, then print and inspect the whole batch.
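The `enumerate()` and `try`/`except` tips can be combined into a small debugging helper. This is a minimal sketch: `find_bad_batches` and its `extract` callback are hypothetical names, not part of the question's code, and the sample batches are made up to mimic the `(patch, x, y, w, h)` tuples that the test loop indexes.

```python
def find_bad_batches(batches, extract):
    """Hypothetical debugging helper: report the batches on which `extract` fails.

    `batches` is any iterable of batches (for example the grouper(...) generator
    from the question) and `extract` is applied to every entry of every batch,
    mirroring the failing list comprehension.
    """
    bad = []
    for i, batch in enumerate(batches):  # enumerate() gives the index of each batch
        try:
            for b in batch:              # raises "not iterable" if batch itself is None
                extract(b)               # with an indexing extract: "not subscriptable" if b is None
        except TypeError as err:         # temporary: catch, report, keep going
            print(f"batch {i} failed with {err!r}; batch = {batch!r}")
            bad.append((i, repr(err)))
    return bad


# Made-up batches of (patch, x, y, w, h) tuples; strings stand in for patches.
sample_batches = [
    (("p0", 0, 0, 1, 1), ("p1", 1, 0, 1, 1)),  # fine
    (("p2", 2, 0, 1, 1), None),                # one entry is None
    None,                                      # the whole batch is None
]
bad = find_bad_batches(sample_batches, lambda b: b[0])
```

On the made-up data this reports batch 1 with `'NoneType' object is not subscriptable` (an entry is `None`) and batch 2 with `'NoneType' object is not iterable` (the batch itself is `None`, which is the message in the question). Filtering the entries, as with `if b`, only helps in the first case. For the real loop, something like `find_bad_batches(grouper(batch_size, sliding_window(img, **kwargs)), lambda b: b[0][0, 0])` would run the same check on the actual windows.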
If it's harmless that entries of this particular batch are `None`, you can skip them:
```python
data = [b[0][0, 0] for b in batch]
```
becomes
```python
data = [b[0][0, 0] for b in batch if b]
```
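In the question's loop, the same filter also has to be applied to `indices = [b[1:] for b in batch]`: that line raises the same kind of `TypeError` on a `None` entry, and filtering only one of the two lists would misalign patches and coordinates in the final `zip(indices, output)`. A minimal sketch, assuming entries are `(patch, x, y, w, h)` tuples as the loop's indexing implies; `clean_batch` and `split_batch` are hypothetical helpers, and `clean_batch` also treats a batch that is itself `None` as empty, which an `if b` filter on the entries cannot handle.

```python
def clean_batch(batch):
    """Hypothetical helper: drop None entries; treat a None batch as empty."""
    if batch is None:
        return []
    # `is not None` instead of the bare `if b`: same result for tuples, but it
    # avoids NumPy's ambiguous truth-value error if an entry were an array.
    return [b for b in batch if b is not None]


def split_batch(batch):
    """Split (patch, x, y, w, h) entries into patches and coordinates, in the same order."""
    entries = clean_batch(batch)
    data = [b[0] for b in entries]      # patches (b[0][0, 0] when patch_size == 1)
    indices = [b[1:] for b in entries]  # (x, y, w, h) for each kept patch
    return data, indices


# Made-up batches; strings stand in for the image patches.
batches = [
    (("p0", 0, 0, 1, 1), ("p1", 1, 0, 1, 1)),
    (("p2", 2, 0, 1, 1), None),
    None,
]
results = []
for batch in batches:
    data, indices = split_batch(batch)
    if not data:  # nothing usable: skip the batch instead of calling the network
        continue
    results.append((data, indices))
```

`results` keeps patches `p0`, `p1` and `p2` with their matching coordinates and drops the rest. Keep in mind that skipping only hides the symptom: windows that come out as `None` never add anything to `probs`, so it is still worth checking why `sliding_window` or `grouper` produce `None` in the first place.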