如何在Pytorch中正确加载GAN检查点?
我在 256x256 图像上训练了 GAN,基本上扩展了 PyTorch 自己的 DCGAN 教程 中的代码以适应更高分辨率的图像。模型和优化器初始化如下所示:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gen = Generator(...).to(device)
disc = Discriminator(...).to(device)
opt_gen = optim.Adam(gen.parameters(), ...)
opt_disc = optim.Adam(disc.parameters(), ...)
gen.train()
disc.train()
GAN 生成了高质量的样本。在每个时期,我都会使用相同的输入向量 fixed_noise
生成一些图像(并使用 SummaryWriter
在 Tensorboard 上查看它们)到生成器:
with torch.no_grad():
fake = gen(fixed_noise)
img_grid_real = torchvision.utils.make_grid(
real[:NUM_VISUALIZATION_SAMPLES], normalize=True
)
img_grid_fake = torchvision.utils.make_grid(
fake[:NUM_VISUALIZATION_SAMPLES], normalize=True
)
writer_real.add_image("Real", img_grid_real, global_step=step)
writer_fake.add_image("Fake", img_grid_fake, global_step=step)
我在之后保存了 GAN每个训练周期如下:
checkpoint = {
"gen_state": gen.state_dict(),
"gen_optimizer": opt_gen.state_dict(),
"disc_state": disc.state_dict(),
"disc_optimizer": opt_disc.state_dict()
}
torch.save(checkpoint, f"checkpoints/checkpoint_{epoch_number}.pth.tar")
到目前为止,我已经在配备 NVIDIA T4 GPU 和 PyTorch 1.11.0 的 CentOS7.9 机器上训练了 GAN。然后,我将一些检查点(已如上所述保存)同步到我的个人计算机(Windows 10、NVIDIA GTX1050Ti、PyTorch 1.10.1)。对 GAN 使用完全相同的类定义,并以相同的方式初始化它(参见第一个代码片段,除了将它们设置为训练模式),我这样加载了一个检查点:
checkpoint = torch.load(f"checkpoints/checkpoint_10.pth.tar")
gen.load_state_dict(checkpoint["gen_state"])
opt_gen.load_state_dict(checkpoint["gen_optimizer"])
disc.load_state_dict(checkpoint["disc_state"])
opt_disc.load_state_dict(checkpoint["disc_optimizer"])
然后我使用了与第二个代码中相同的代码使用经过训练的 GAN 生成一些图像的代码片段,现在在我的机器中加载了检查点。这产生了垃圾输出:
我尝试使用我拥有的所有检查点,并且所有输出都是废话。我在 PyTorch 论坛中查找问题(1, 2,3),但似乎没有任何帮助。
我保存/加载模型是否错误?
I trained a GAN on 256x256 images, basically extending the code in PyTorch' own DCGAN tutorial to accommodate larger resolution images. The model and optimizer initialization look like this:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gen = Generator(...).to(device)
disc = Discriminator(...).to(device)
opt_gen = optim.Adam(gen.parameters(), ...)
opt_disc = optim.Adam(disc.parameters(), ...)
gen.train()
disc.train()
The GAN produced good quality samples. A few times during each epoch, I generated a few images (and viewed them on Tensorboard using SummaryWriter
) using the same input vector fixed_noise
to the generator:
with torch.no_grad():
fake = gen(fixed_noise)
img_grid_real = torchvision.utils.make_grid(
real[:NUM_VISUALIZATION_SAMPLES], normalize=True
)
img_grid_fake = torchvision.utils.make_grid(
fake[:NUM_VISUALIZATION_SAMPLES], normalize=True
)
writer_real.add_image("Real", img_grid_real, global_step=step)
writer_fake.add_image("Fake", img_grid_fake, global_step=step)
I saved the GAN after each training epoch as such:
checkpoint = {
"gen_state": gen.state_dict(),
"gen_optimizer": opt_gen.state_dict(),
"disc_state": disc.state_dict(),
"disc_optimizer": opt_disc.state_dict()
}
torch.save(checkpoint, f"checkpoints/checkpoint_{epoch_number}.pth.tar")
Thus far, I had trained the GAN on a CentOS7.9 machine with an NVIDIA T4 GPU, with PyTorch 1.11.0. I then rsync
'd a few checkpoints (that had been saved as described above) to my personal machine (Windows 10, NVIDIA GTX1050Ti, PyTorch 1.10.1). Using the exact same class definition for the GAN, and initializing it the same way (cf. first code snippet, except for setting them in training mode), I loaded a checkpoint as such:
checkpoint = torch.load(f"checkpoints/checkpoint_10.pth.tar")
gen.load_state_dict(checkpoint["gen_state"])
opt_gen.load_state_dict(checkpoint["gen_optimizer"])
disc.load_state_dict(checkpoint["disc_state"])
opt_disc.load_state_dict(checkpoint["disc_optimizer"])
I then used the same code as in the second code snippet to generate some images with the trained GAN, now in my machine with the loaded checkpoint. This yielded garbage output:
I tried using all the checkpoints I had, and all output nonsense. I looked in the PyTorch forums for questions (1, 2, 3), but none seemed to help.
Am I saving/loading the model wrong?
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论