pytorch_lightning.utilities.exceptions.MisconfigurationException when training in PyTorch Lightning
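I am training a sample model with dummy data and get the following error when I start training, even though I believe I have provided everything correctly: No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.
Is the problem caused by the way I feed the dummy data into the network, or is there another reason?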
import torch
from torch import nn, optim
import pytorch_lightning as pl
from torch.utils.data import DataLoader

class ImageClassifier(pl.LightningModule):
    def __init__(self, learning_rate=0.001):
        super().__init__()
        self.learning_rate = learning_rate
        self.conv_layer1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1)

    def forward(self,x):
        output = self.conv_layer1(x)
        print(output.shape)
        return output

    def training_step(self,batch, batch_idx):
        inputs, targets = batch
        output = self(inputs)
        accuracy = self.binary_accuracy(output, targets)
        loss = self.loss(output, targets)
        self.log('train_accuracy', accuracy, prog_bar=True)
        self.log('train_loss', loss)
        return {'loss':loss,"training_accuracy": accuracy}

    def test_step(self, batch, batch_idx):
        inputs, targets = batch
        outputs = self.inputs(inputs)
        accuracy = self.binary_accuracy(outputs, targets)
        loss = self.loss(outputs, targets)
        self.log('test_accuracy', accuracy)
        return {"test_loss":loss, "test_accuracy":accuracy}

    def configure_optimizer(self):
        params = self.parameters()
        optimizer = optim.Adam(params=params, lr=self.learning_rate)
        return optimizer

    def binary_accuracy(self, outputs, inputs):
        _, outputs = torch.max(outputs,1)
        correct_results_sum = (outputs == targets).sum().float()
        acc = correct_results_sum/targets.shape[0]
        return acc

model = ImageClassifier()
Input = DataLoader(torch.randn(1,3,28,28))
trainer = pl.Trainer(max_epochs=10, progress_bar_refresh_rate=1)
trainer.fit(model, train_dataloader = Input)
Comments (2)
In your code, the method is named `configure_optimizer()`, so no `configure_optimizers()` method is defined. The function name is missing its trailing "s".

I had the same problem and then realized that a wrong method name can cause this error. Just make sure you type method names correctly, and that you import the package and use it appropriately.
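To make the point in both comments concrete, here is a minimal plain-Python sketch of why a misspelled hook is silently ignored and how such a typo might be spotted. It does not use or reproduce Lightning's own code: `BaseModule`, `is_overridden`, `find_suspect_names`, `Misnamed`, and `Fixed` are hypothetical names made up for this illustration, and the "optimizer" is just a placeholder string. The idea is that a framework can only tell whether a hook was overridden by looking it up under its exact name, so `configure_optimizer` is treated as an unrelated extra method.

```python
import difflib


# Hypothetical stand-in for a framework base class; NOT Lightning's real code.
class BaseModule:
    REQUIRED_HOOKS = ("training_step", "configure_optimizers")

    def training_step(self, batch, batch_idx):
        raise NotImplementedError

    def configure_optimizers(self):
        raise NotImplementedError


def is_overridden(hook_name, instance, base=BaseModule):
    """Return True if the instance's class replaces the base implementation."""
    return getattr(type(instance), hook_name) is not getattr(base, hook_name)


def find_suspect_names(instance, base=BaseModule):
    """Map each hook that was not overridden to similarly named subclass methods."""
    # Public names defined on the subclass that the base class does not know.
    extra = [
        name for name in dir(type(instance))
        if not name.startswith("_") and not hasattr(base, name)
    ]
    suspects = {}
    for hook in base.REQUIRED_HOOKS:
        if not is_overridden(hook, instance, base):
            suspects[hook] = difflib.get_close_matches(hook, extra, n=1, cutoff=0.8)
    return suspects


class Misnamed(BaseModule):
    def training_step(self, batch, batch_idx):
        return 0.0

    def configure_optimizer(self):  # missing trailing "s", as in the question
        return "adam"  # placeholder value, not a real optimizer


class Fixed(BaseModule):
    def training_step(self, batch, batch_idx):
        return 0.0

    def configure_optimizers(self):  # exact hook name
        return "adam"  # placeholder value, not a real optimizer


print(find_suspect_names(Misnamed()))  # {'configure_optimizers': ['configure_optimizer']}
print(find_suspect_names(Fixed()))     # {}
```

In the question's `ImageClassifier`, the corresponding change is renaming `configure_optimizer` to `configure_optimizers`; the method body can stay the same.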