“无类型” Pytorch 模型测试部分中的对象不可迭代?

发布于 2025-01-16 02:19:53 字数 2528 浏览 0 评论 0原文

我正在使用 Pytorch 模型研究高光谱图像。当我想测试模型时,出现“NoneType objectis not iterable”错误。图像上有一个滑动窗口,我将其添加到测试部分。在 for 循环中,批处理值变为 None,我不知道如何修复它。我附上下面的工作代码。感谢您的帮助。

def test(net, img, hyperparams):

net.eval()
patch_size = hyperparams["patch_size"]
center_pixel = hyperparams["center_pixel"]
batch_size, device = hyperparams["batch_size"], hyperparams["device"]
n_classes = hyperparams["n_classes"]

kwargs = {
    "step": hyperparams["test_stride"],
    "window_size": (patch_size, patch_size),
}
probs = np.zeros(img.shape[:2] + (n_classes,))

iterations = count_sliding_window(img, **kwargs) // batch_size
for batch in tqdm(
    grouper(batch_size, sliding_window(img, **kwargs)),
    total=(iterations),
    desc="Inference on the image",
):
    with torch.no_grad():
        if patch_size == 1:
            data = [b[0][0, 0] for b in batch]
            data = np.copy(data)
            data = torch.from_numpy(data)
        else:
            data = [b[0] for b in batch]
            data = np.copy(data)
            data = data.transpose(0, 3, 1, 2)
            data = torch.from_numpy(data)
            data = data.unsqueeze(1)

        indices = [b[1:] for b in batch]
        data = data.to(device)
        output = net(data)
        if isinstance(output, tuple):
            output = output[0]
        output = output.to("cpu")

        if patch_size == 1 or center_pixel:
            output = output.numpy()
        else:
            output = np.transpose(output.numpy(), (0, 2, 3, 1))
        for (x, y, w, h), out in zip(indices, output):
            if center_pixel:
                probs[x + w // 2, y + h // 2] += out
            else:
                probs[x : x + w, y : y + h] += out
return probs

probabilities = test(model, img, hyperparams)

然后我得到这个错误:

Inference on the image:   0%|          | 0/210 [00:00<?, ?it/s]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-55-cabf8b6fa06d> in <module>()
----> 1 probabilities = test(model, img, hyperparams)
      2 prediction = np.argmax(probabilities, axis=-1)

<ipython-input-54-b248a53c93ce> in test(net, img, hyperparams)
        23         with torch.no_grad():
        24             if patch_size == 1:
   ---> 25                 data = [b[0][0, 0] for b in batch]
        26                 data = np.copy(data)
        27                 data = torch.from_numpy(data)

I' m studying on a hyperspectral image with Pytorch model. When i want to test the model i got the "NoneType objectis not iterable" error. There is a window sliding on the image and I added it to the test part. Within the for loop, the batch values are getting None and I don't know how to fix it. I am attaching the working code below. Thanks for your help.

def test(net, img, hyperparams):

net.eval()
patch_size = hyperparams["patch_size"]
center_pixel = hyperparams["center_pixel"]
batch_size, device = hyperparams["batch_size"], hyperparams["device"]
n_classes = hyperparams["n_classes"]

kwargs = {
    "step": hyperparams["test_stride"],
    "window_size": (patch_size, patch_size),
}
probs = np.zeros(img.shape[:2] + (n_classes,))

iterations = count_sliding_window(img, **kwargs) // batch_size
for batch in tqdm(
    grouper(batch_size, sliding_window(img, **kwargs)),
    total=(iterations),
    desc="Inference on the image",
):
    with torch.no_grad():
        if patch_size == 1:
            data = [b[0][0, 0] for b in batch]
            data = np.copy(data)
            data = torch.from_numpy(data)
        else:
            data = [b[0] for b in batch]
            data = np.copy(data)
            data = data.transpose(0, 3, 1, 2)
            data = torch.from_numpy(data)
            data = data.unsqueeze(1)

        indices = [b[1:] for b in batch]
        data = data.to(device)
        output = net(data)
        if isinstance(output, tuple):
            output = output[0]
        output = output.to("cpu")

        if patch_size == 1 or center_pixel:
            output = output.numpy()
        else:
            output = np.transpose(output.numpy(), (0, 2, 3, 1))
        for (x, y, w, h), out in zip(indices, output):
            if center_pixel:
                probs[x + w // 2, y + h // 2] += out
            else:
                probs[x : x + w, y : y + h] += out
return probs

probabilities = test(model, img, hyperparams)

Then i got this error:

Inference on the image:   0%|          | 0/210 [00:00<?, ?it/s]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-55-cabf8b6fa06d> in <module>()
----> 1 probabilities = test(model, img, hyperparams)
      2 prediction = np.argmax(probabilities, axis=-1)

<ipython-input-54-b248a53c93ce> in test(net, img, hyperparams)
        23         with torch.no_grad():
        24             if patch_size == 1:
   ---> 25                 data = [b[0][0, 0] for b in batch]
        26                 data = np.copy(data)
        27                 data = torch.from_numpy(data)

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

随梦而飞# 2025-01-23 02:19:53

看起来是批次的内容。找到它并不容易,但这里有一些关于如何更接近错误的建议:

  1. Python 3.10 有更好的错误消息(显示错误在行中的位置!)

  2. 尝试通过批次进行 enumerate(),然后找到缺陷索引

  3. 使用 try/expect (只是临时的)来中断此 ValueError 并打印出来并检查整个批次

  4. 如果还不错,则该特定批次为“none”,您可以跳过它:

    data = [b[0][0, 0] for b in batch]

    变成

    data = [b[0][0, 0] for b in batch if b]

Looks like it's the content of batch. It's not easy to find, but here some advice on how to get closer to the bug:

  1. Python 3.10 has better error messages (shows where in the line the error is!)

  2. try to enumerate() through batches and then find the defect index

  3. use a try/expect (just temporary) to break on this ValueError and print out and inspect the whole batch

  4. If it's not bad, that this specific batch is "none" you can skip it:

    data = [b[0][0, 0] for b in batch]

    becomes

    data = [b[0][0, 0] for b in batch if b]

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文