Error when checking target: expected dense_3 to have 2 dimension
问题描述
使用Tensorflow的keras来处理一个文本情感多分类问题的,但是在构建模型的时候出现了一些莫名其妙的问题?
Error when checking target: expected dense_3 to have 2 dimensions, but got array with shape (4500, 4, 4)
问题出现的环境背景及自己尝试过哪些方法
Layer (type) Output Shape Param #
=================================================================
lstm_9 (LSTM) (150, 1, 32) 772224
_________________________________________________________________
lstm_10 (LSTM) (150, 1, 32) 8320
_________________________________________________________________
lstm_11 (LSTM) (150, 32) 8320
_________________________________________________________________
dense_3 (Dense) (150, 4) 132
=================================================================
Total params: 788,996
Trainable params: 788,996
Non-trainable params: 0
很奇怪,dense_3上一层明明输出的是2 dimensions,为什么这里还会报dim不符合的错误呢?
相关代码
// 请把代码文本粘贴到下方(请勿用图片代替代码)
timesteps = 1
data_dim = 6000
num_classes = 4
batch_size = 150
x_train = train_x_array
x_val = val_x_array
x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=6000, value=0)
x_val = keras.preprocessing.sequence.pad_sequences(x_val, maxlen=6000, value=0)
print('x_train shape: {}'.format(x_train.shape))
print('x_val shape: {}'.format(x_val.shape))
x_train = x_train.reshape([batch_size * 30, timesteps, data_dim])
x_val = x_val.reshape([batch_size * 10, timesteps, data_dim])
print('x_train shape: {}'.format(x_train.shape))
print('x_val shape: {}'.format(x_val.shape))
// x_train shape: (4500, 6000)
// x_val shape: (1500, 6000)
// x_train shape: (4500, 1, 6000)
// x_val shape: (1500, 1, 6000)
model = Sequential()
model.add(LSTM(32, return_sequences=True, stateful=True,
batch_input_shape=(batch_size, timesteps, data_dim)))
model.add(LSTM(32, return_sequences=True, stateful=True))
model.add(LSTM(32, stateful=True))
model.add(Dense(4, activation='softmax'))
model.compile(loss='categorical_crossentropy',
optimizer='adam',
metrics=['accuracy'])
model.summary()
model.fit(x_train, y_train,
batch_size=batch_size, epochs=3, shuffle=False,
validation_data=(x_val, y_val))
请问如何解决?
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论
评论(1)
老哥解决了吗?我有同样的问题