tqdm does not update across epochs
for epoch in range(args.num_epochs):
    model.train()
    # print(f"Epoch {epoch}")
    with tqdm(total=len(input_tensor_catted), unit="ba") as pbar:
        pbar.set_description(f"Epoch {epoch}")
        pbar.update(1)
        # for step, batch in enumerate(train_dataloader):
        for step in range(len(input_tensor_catted) // args.batch_size):
            indices = torch.multinomial(
                torch.ones(len(input_tensor_catted)) / len(input_tensor_catted),
                args.batch_size,
                replacement=True,
            )
            clean_inputs = input_tensor_catted[indices, :]
            clean_conditioning = original_cost_tensor_catted[indices, :].to(clean_inputs.device)
            # clean_inputs = batch["input"]
            noise_samples = torch.randn(clean_inputs.shape).to(clean_inputs.device)
            bsz = clean_inputs.shape[0]
            timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_inputs.device).long()

            # add noise onto the clean images according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_images = noise_scheduler.training_step(clean_inputs, noise_samples, timesteps)

            if step % args.gradient_accumulation_steps != 0:
                with accelerator.no_sync(model):
                    # from noisy images, predict epsilon
                    output = model(noisy_images, timesteps, clean_conditioning)
                    # predict the noise residual
                    loss = F.mse_loss(output, noise_samples)
                    loss = loss / args.gradient_accumulation_steps
                    accelerator.backward(loss)
            else:
                output = model(noisy_images, timesteps, clean_conditioning)
                # predict the noise residual
                loss = F.mse_loss(output, noise_samples)
                loss = loss / args.gradient_accumulation_steps
                accelerator.backward(loss)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            pbar.update(1)
            pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])

    optimizer.step()
This is my code.
This is an example of what is printed to the console:
The point is that the top counter (the bar labelled Epoch 1) only updates from 1/10000 to 10/10000 and then always stops there, even when the epoch number is greater than 10.
Comments (1)
You can do something like the following every epoch. This creates a new tqdm bar for each epoch, so you don't have to worry about resetting it.
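The snippet this answer refers to is not shown on this page; the following is a minimal sketch of the per-epoch bar it describes, reusing the names from the question (args, model, input_tensor_catted) and leaving out the actual training step:

from tqdm import tqdm

for epoch in range(args.num_epochs):
    model.train()
    steps_per_epoch = len(input_tensor_catted) // args.batch_size
    # A fresh bar is created for every epoch. Its total equals the number of
    # iterations, so it reaches 100% at the end of each epoch, unlike a bar
    # whose total is the dataset size but which is only advanced once per step.
    for step in tqdm(range(steps_per_epoch), desc=f"Epoch {epoch}", unit="ba"):
        # ... run the training step for this batch here ...
        pass

Each epoch then prints its own completed bar on a new line, so nothing has to be reset between epochs.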