使用python实现一段bp神经网络后,由于数据量增多(76条),出现nan报错
使用python实现一段bp神经网络后,由于数据量增多(76条),出现nan报错
代码如下:
# coding:utf-8
import numpy as np
import matplotlib.pyplot as plt
def logsig(x):
return 1.0/(1+np.exp(-x))
#76条归一化数据
#输出层
# 订单数
ordernum=[12668619, 13103780, 13257004, 14638190, 14726263, 12600723, 12949334, 13341628, 13169524, 13301632, 15020992, 15606709, 13446703, 13371539, 13262251, 13162602, 13168550, 13598886, 12370827, 11962626, 11717528, 11869467, 12318038, 13303466, 14288269, 12600621, 12957388, 13200367, 13271152, 13051823, 15049197, 16090773, 13568785, 13152151, 12927545, 12827040, 13172809, 14904629, 15239996, 12646447, 12834756, 13102306, 13447091, 13632212, 15747318, 16151994, 13377774, 13633451, 13317559, 13592490, 14065724, 15377027, 15486313, 13024443, 13557400, 13234309, 13538881, 14339158, 15770610, 16017075, 13932597, 14057435, 14136860, 14621550, 15745729, 17343336, 17072830, 14217876, 14457732, 14380455, 14505519, 14957927, 16587364, 16489540, 13943044, 14326370]
# 输入层 3 个变量
#是否节假日(0代表工作日;1代表节假日)
holiday=[0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0]
#整体在线商家数
shopnum= [0, 0.0171, 0.0372, 0.0415, 0.0432, 0.0618, 0.0794, 0.0971, 0.1153, 0.1313, 0.1359, 0.1377, 0.1537, 0.1685, 0.1799, 0.1882, 0.1949, 0.1989, 0.2021, 0.2016, 0.1977, 0.1931, 0.1933, 0.1945, 0.1901, 0.2004, 0.2148, 0.2323, 0.2508, 0.2686, 0.2724, 0.2725, 0.2854, 0.3036, 0.3230, 0.3428, 0.3610, 0.3637, 0.3642, 0.3770, 0.3920, 0.4077, 0.4199, 0.4304, 0.4301, 0.4299, 0.4389, 0.4428, 0.4576, 0.4702, 0.4837, 0.4837, 0.4850, 0.5022, 0.5165, 0.5273, 0.5427, 0.5560, 0.5566, 0.5566, 0.5724, 0.5871, 0.6020, 0.6167, 0.6307, 0.6319, 0.6332, 0.6489, 0.6601, 0.6670, 0.6713, 0.6759, 0.6673, 0.6655, 0.6733, 0.6720]
#整体补贴率
subsidyrate=[0.5333, 0.7238, 0.5905, 0.5619, 0.5762, 0.4619, 0.4143, 0.4571, 0.4667, 0.2762, 0.3238, 0.3381, 0.2095, 0.1952, 0.2333, 0.3571, 0.3714, 0.3524, 0.0667, 0.0143, 0.0000, 0.1000, 0.0095, 0.1048, 0.1619, 0.0714, 0.0619, 0.0857, 0.2238, 0.1143, 0.2000, 0.2143, 0.0857, 0.1810, 0.2286, 0.4333, 0.3571, 0.3810, 0.3905, 0.3571, 0.4000, 0.5000, 0.9095, 0.9238, 0.9190, 0.9476, 1.0000, 0.9381, 0.7905, 0.9619, 0.8095, 0.6286, 0.5714, 0.5143, 0.4952, 0.4619, 0.5857, 0.5190, 0.6714, 0.5952, 0.5571, 0.6333, 0.6714, 0.7190, 0.6476, 0.6952, 0.5286, 0.4190, 0.4000, 0.4190, 0.5810, 0.5476, 0.5714, 0.5524, 0.4762, 0.4714]
# 输入值
samplein = np.mat([holiday,subsidyrate,shopnum]) #3*76
sampleinminmax = np.array([samplein.min(axis=1).T.tolist()[0],samplein.max(axis=1).T.tolist()[0]]).transpose()#3*2,对应最大值最小值
# 输出值
sampleout = np.mat([ordernum])#1*76
sampleoutminmax = np.array([sampleout.min(axis=1).T.tolist()[0],sampleout.max(axis=1).T.tolist()[0]]).transpose()#2*2,对应最大值最小值
#3*76
sampleinnorm = (2*(np.array(samplein.T)-sampleinminmax.transpose()[0])/(sampleinminmax.transpose()[1]-sampleinminmax.transpose()[0])-1).transpose()
#1*76
sampleoutnorm = (2*(np.array(sampleout.T).astype(float)-sampleoutminmax.transpose()[0])/(sampleoutminmax.transpose()[1]-sampleoutminmax.transpose()[0])-1).transpose()
#给输出样本添加噪音
noise = 0.03*np.random.rand(sampleoutnorm.shape[0],sampleoutnorm.shape[1])
sampleoutnorm += noise
# 训练次数
maxepochs = 1500
# 学习速率
learnrate = 0.035
# 误差率
errorfinal = 0.65*10**(-3)
#errorfinal = 0.05
# 数据量
samnum = 76
# 输入数
indim = 3
# 输出数
outdim = 1
hiddenunitnum = 8 # 2~10
# 随机产生权值和偏置
w1 = 0.5*np.random.rand(hiddenunitnum,indim)-0.1
b1 = 0.5*np.random.rand(hiddenunitnum,1)-0.1
w2 = 0.5*np.random.rand(outdim,hiddenunitnum)-0.1
b2 = 0.5*np.random.rand(outdim,1)-0.1
errhistory = []
for i in range(maxepochs):
hiddenout = logsig((np.dot(w1,sampleinnorm).transpose()+b1.transpose())).transpose()
hiddenout = np.nan_to_num(hiddenout)
networkout = (np.dot(w2,np.nan_to_num(hiddenout)).transpose()+b2.transpose()).transpose()
err = sampleoutnorm - np.nan_to_num(networkout)
sse = sum(sum(err**2))
errhistory.append(sse)
if sse < errorfinal:
break
delta2 = err
delta1 = np.dot(w2.transpose(),delta2)*hiddenout*(1-hiddenout)
dw2 = np.dot(delta2,hiddenout.transpose())
db2 = np.dot(delta2,np.ones((samnum,1)))
dw1 = np.dot(delta1,sampleinnorm.transpose())
dw1 = np.nan_to_num(dw1)
db1 = np.dot(delta1,np.ones((samnum,1)))
# db1 = np.nan_to_num(db1)
# w2 += learnrate*dw2
# b2 += learnrate*db2
w2 = np.nan_to_num(w2+learnrate * dw2)
b2 = np.nan_to_num(b2+learnrate*db2)
w1 += learnrate*dw1
b1 += learnrate * db1
# b1 = np.nan_to_num(b1+learnrate*db1)
# print 'db1=', db1
# print 'db1=', db1
# print 'err=',err,'w1=',w1,'w2=',w2,'b1=',b1,'b2=',b2
# 误差曲线图
errhistory10 = np.log10(errhistory)
minerr = min(errhistory10)
plt.plot(errhistory10)
plt.plot(range(0,i+1000,1000),[minerr]*len(range(0,i+1000,1000)))
ax=plt.gca()
ax.set_yticks([-2,-1,0,1,2,minerr])
ax.set_yticklabels([u'$10^{-2}$',u'$10^{-1}$',u'$1$',u'$10^{1}$',u'$10^{2}$',str(('%.4f'%np.power(10,minerr)))])
ax.set_xlabel('iteration')
ax.set_ylabel('error')
ax.set_title('Error History')
plt.savefig('total-errorhistory.png',dpi=700)
plt.close()
# 仿真输出和实际输出对比图
hiddenout = logsig((np.dot(w1,sampleinnorm).transpose()+b1.transpose())).transpose()
networkout = (np.dot(w2,hiddenout).transpose()+b2.transpose()).transpose()
diff = sampleoutminmax[:,1]-sampleoutminmax[:,0]
networkout2 = (networkout+1)/2
networkout2[0] = networkout2[0]*diff[0]+sampleoutminmax[0][0]
#networkout2[1] = networkout2[1]*diff[1]+sampleoutminmax[1][0]
sampleout = np.array(sampleout)
print 'networkout',networkout
fig,axes = plt.subplots(nrows=2,ncols=1,figsize=(12,10))
line1, =axes[0].plot(networkout2[0],'k',marker = u'$\circ$')
line2, = axes[0].plot(sampleout[0],'r',markeredgecolor='b',marker = u'$\star$',markersize=9)
print "chazhi=",(networkout2[0]-sampleout[0])
print"准确率=",(sampleout[0]/networkout2[0])
print"误差率=",((networkout2[0]-sampleout[0])/networkout2[0])
axes[0].legend((line1,line2),('simulation output','real output'),loc = 'upper left')
yticks = [10000000,12000000,14000000,16000000,18000000]
ytickslabel = [u'$10$',u'$12$',u'$14$',u'$16$',u'$18$']
axes[0].set_yticks(yticks)
axes[0].set_yticklabels(ytickslabel)
axes[0].set_ylabel('ordernum')
xticks = range(0,30,1)
xtickslabel = range(0,30,1)
axes[0].set_xticks(xticks)
axes[0].set_xticklabels(xtickslabel)
axes[0].set_xlabel(u'date')
axes[0].set_title('ordernum')
fig.savefig('total-simulation.png',dpi=500,bbox_inches='tight')
plt.close()
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
![扫码二维码加入Web技术交流群](/public/img/jiaqun_03.jpg)
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论
评论(1)
请问你解决了吗?我也遇到了类似的问题