Pyrhon 上批量梯度下降的训练误差与测试误差,无法理解我应该做什么
我有 此数据集 包含训练数据和测试数据,并且我必须绘制具有平方和逻辑损失的梯度下降算法的训练和测试误差。 我是一个初学者,我有点迷茫不知道该怎么做。
到目前为止,我认为我已经成功地实现了梯度下降算法,这是我为平方损失所做的:
import numpy as np
import scipy.io
import matplotlib.pyplot as plt
data = scipy.io.loadmat('data_orsay_2017.mat')
x0,x1=data['Xtrain'],data['Xtest']
y0,y1=data['ytrain'],data['Ytrain']
def gradientdescent(x,y,n,alpha,max_iterations): # n is the sample size, alpha is the learning rate
d = x.shape[1] # dimension of the data
theta = np.random.random(d)
error = []
for j in range(max_iterations):
prediction = x.dot(theta)
cost = 1/(2*n)*sum((y[i,0]-prediction[i])**2 for i in range(n))
error.append(cost)
grad = (1/n) * sum((prediction[i] - y[i,0])*x[i] for i in range(n))
theta-=alpha*grad
return (theta,error)
现在,我可以简单地在 x0,y0 和 x1,y1 上运行该算法并绘制误差,但我我不认为那是我应该做的吗?我认为我应该以不同的方式处理训练数据和测试数据,但我不知道如何处理。
I have this dataset containing training data and testing data , and I have to plot training and testing errors for the gradient descent algorithm with square and logistic losses.
I'm a beginner, I'm a bit lost on what to do.
So far, I think I have managed to implement the gradient descent algorithm successfully, here is the one i did for the square loss:
import numpy as np
import scipy.io
import matplotlib.pyplot as plt
data = scipy.io.loadmat('data_orsay_2017.mat')
x0,x1=data['Xtrain'],data['Xtest']
y0,y1=data['ytrain'],data['Ytrain']
def gradientdescent(x,y,n,alpha,max_iterations): # n is the sample size, alpha is the learning rate
d = x.shape[1] # dimension of the data
theta = np.random.random(d)
error = []
for j in range(max_iterations):
prediction = x.dot(theta)
cost = 1/(2*n)*sum((y[i,0]-prediction[i])**2 for i in range(n))
error.append(cost)
grad = (1/n) * sum((prediction[i] - y[i,0])*x[i] for i in range(n))
theta-=alpha*grad
return (theta,error)
Now, i could simply run the algorithm on x0,y0 and x1,y1 and plot the errors, but I don't think that's what I'm supposed to do is it ? I assume I am supposed to treat the training data and the testing data differently, but I don't know how.
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
data:image/s3,"s3://crabby-images/d5906/d59060df4059a6cc364216c4d63ceec29ef7fe66" alt="扫码二维码加入Web技术交流群"
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论