tensorflow LSTM时间序列预测问题

发布于 2021-12-06 23:32:56 字数 2091 浏览 794 评论 1

#coding=utf-8
import numpy as np
import tensorflow as tf
import matplotlib as mpl
mpl.use('Agg')
from matplotlib import pyplot as plt



learn=tf.contrib.learn

HIDDEN_SIZE=30    #LSTM中隐藏节点的个数
NUM_LAYERS=2    #LSTM层数
TIMESTEPS=10    #循环神经网络的截断长度
TRAINING_STEPS=10000    #训练轮数
BATCH_SIZE=32       #batch大小
TRAINING_EXAMPLES=10000      #训练数据个数
TESTING_EXAMPLES=1000    #测试数据个数
SAMPLE_GAP=0.01       #采样间隔

def generate_data(seq):
    X=[]
    y=[]
    for i in range(len(seq)-TIMESTEPS-1):
        X.append([seq[i:i+TIMESTEPS]])
        y.append([seq[i+TIMESTEPS]])
    return np.array(X,dtype=np.float32),np.array(y,dtype=np.float32)

def lstm_model(X,y):
    lstm_cell=tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE)
    cell=tf.nn.rnn_cell.MultiRNNCell([lstm_cell]*NUM_LAYERS)

    x_=tf.unstack(X,axis=1)

    output,_=tf.nn.dynamic_rnn(cell,x_,dtype=tf.float32)

    output=output[-1]
    prediction,loss=learn.models.linear_regression(output,y)
    train_op=tf.contrib.layers.optimize_loss(loss,tf.contrib.framework.get_global_step(),optimizer="Adagrad",learning_rate=0.1)
    return prediction,loss,train_op

regressor=learn.Estimator(model_fn=lstm_model)

test_start=TRAINING_EXAMPLES*SAMPLE_GAP
test_end=(TRAINING_EXAMPLES+TESTING_EXAMPLES)*SAMPLE_GAP
train_X,train_y=generate_data(np.sin(np.linspace(0,test_start,TRAINING_EXAMPLES,dtype=np.float32)))
test_X,test_y=generate_data(np.sin(np.linspace(test_start,test_end,TESTING_EXAMPLES,dtype=np.float32)))

regressor.fit(train_X,train_y,batch_size=BATCH_SIZE,steps=TRAINING_STEPS)

predicted=[[pred] for pred in regressor.predict(test_X)]

rmse=np.sqrt(((predicted-test_y)**2).mean(axis=0))
print('Mean square error is: %f'%rmse[0])

fig=plt.figure()
plot_predicted=plt.plot(predicted,label='predicted')
plot_test=plt.plot(test_y,label='real_sin')
plt.legend([plot_predicted,plot_test],['predicted','real_sin'])
fig.savefig('sin.png')

这是我的代码,预测正弦函数的深度学习算法。

提示报错ValueError: Shape (10, ?) must have rank at least 3

应该是output,_=tf.nn.dynamic_rnn(cell,x_,dtype=tf.float32)这一行开始出现了问题。

请教一下诸位大神,这个应该怎么解决

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

伪装你 2021-12-07 19:24:31

请问你解决这个问题了吗?我也报错相同的问题

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文