为什么model训练了,可是无法预测?
import pandas as pd
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
if __name__ == '__main__':
boston = load_boston()
col_names = ['feature_{}'.format(i) for i in range(boston['data'].shape[1])]
df_full = pd.DataFrame(boston['data'], columns=col_names)
scalers_dict = {}
for col in col_names:
scaler = StandardScaler()
df_full[col] = scaler.fit_transform(df_full[col].values.reshape(-1, 1))
scalers_dict[col] = scaler
x_train, x_test, y_train, y_test = train_test_split(df_full.values, boston['target'], test_size=0.2, random_state=2)
model = torch.nn.Sequential(torch.nn.Linear(x_train.shape[1], 1), torch.nn.ReLU())
criterion = torch.nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
n_epochs = 2000
train_loss = []
test_loss = []
x_train = Variable(torch.from_numpy(x_train).float(), requires_grad=True)
y_train = Variable(torch.from_numpy(y_train).float())
for epoch in range(n_epochs):
y_hat = model(x_train)
loss = criterion(y_hat, y_train)
optimizer.zero_grad()
loss.backward()
optimizer.step()
epoch_loss = loss.data ** (1/2)
train_loss.append(epoch_loss)
if (epoch + 1) % 250 == 0:
print("{}:loss = {}".format(epoch + 1, epoch_loss))
order = y_train.argsort()
y_train = y_train[order]
x_train = x_train[order, :]
model.eval()
predicted = model(x_train).detach().numpy()
actual = y_train.numpy()
print('predicted:", predicted[:5].flatten(), actual[:5])
plt.plot(predicted.flatten(), 'r-', label='predicted')
plt.plot(actual, 'g-', label='actual')
plt.show()
前面训练都可以,loss也是稳步下降,可是用model将训练数据测试显示,画到matplotlib上时和真实值对比时,就发现,预测值是一条水平线,而真实值是一条折线,完全不匹配?
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论
评论(1)
首先,你的网络模型太浅,一层全连接很难回归一个13维的数据。我加了两层,取得了很好的效果。再者,你的y_train输入有问题,维度应该是(404,1)而不是(404,)。经过更正以后,预测值和真实值非常接近。如果不明白,可以研究我更正后的代码。