tensorflow实现mnist手写数字识别报错?
代码:
import tensorflow as tf
import numpy as np
tf.enable_eager_execution()
class DataLoader():
def __init__(self):
mnist = tf.keras.datasets.mnist.load_data(path = 'mnist.npz')
self.train_data = mnist[0][0]
self.train_data = np.reshape(self.train_data,(self.train_data.shape[0],28*28))
self.train_labels = mnist[0][1]
self.eval_data = mnist[1][0]
self.train_data = np.reshape(self.train_data,(self.train_data.shape[0],28*28))
self.eval_labels = mnist[1][1]
def get_batch(self, batch_size):
indexs = np.random.randint(0,self.train_data.shape[0],batch_size)
return self.train_data[indexs, :], self.train_labels[indexs]
'''
class MLP(tf.keras.Modle):
'''
class MLP(tf.keras.Model):
def __init__(self):
super().__init__()
self.dense1 = tf.keras.layers.Dense(units=100, activation= tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(units=10,activation =None)
def call(self, inputs):
x = self.dense1(inputs)
y = self.dense2(x)
return y
def predict(self, inputs):
logits = self(inputs)
return tf.argmax(logits, axis=-1)
num_batches = 1000
batch_size = 50
learning_rate = 0.001
model = MLP()
data_loader = DataLoader()
optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
for batch_index in range(num_batches):
X , y = data_loader.get_batch(batch_size)
print(np.shape(X))
with tf.GradientTape() as tape:
X = tf.convert_to_tensor(X, dtype = tf.int64, name = 'X')
print(X)
y_logit_pred = model(X)
loss = tf.losses.sparse_softmax_cross_entropy(labels = y, logits = y_logit_pred)
print('batch %d: loss %f' % (batch_index, loss.numpy()))
grads = tape.gradient(loss, model.variables)
optimizer.apply_gradients(grads_and_vars = zip(grads, model.variables))
num_eval_samples = np.shape(data_loader.eval_labels)[0]
y_pred = model.predict(data_loader.eval_data).numpy()
print("test accuracy: %f" % (sum(y_pred == data_loader.eval_labels) / num_eval_samples))
错误信息:
/home/kalarea/.conda/envs/py35/bin/python /home/kalarea/PycharmProjects/start_tensorflow/start_tensorflow.py
/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from float
to np.floating
is deprecated. In future, it will be treated as np.float64 == np.dtype(float).type
.
from ._conv import register_converters as _register_converters
(50, 784)
2018-10-14 18:28:18.977966: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
tf.Tensor(
[[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
...
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]], shape=(50, 784), dtype=int64)
Traceback (most recent call last):
File "/home/kalarea/PycharmProjects/start_tensorflow/start_tensorflow.py", line 55, in <module>
y_logit_pred = model(X)
File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/keras/engine/base_layer.py", line 769, in call
outputs = self.call(inputs, *args, **kwargs)
File "/home/kalarea/PycharmProjects/start_tensorflow/start_tensorflow.py", line 30, in call
x = self.dense1(inputs)
File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/keras/engine/base_layer.py", line 759, in call
self.build(input_shapes)
File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/keras/layers/core.py", line 921, in build
trainable=True)
File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/keras/engine/base_layer.py", line 586, in add_weight
aggregation=aggregation)
File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/training/checkpointable/base.py", line 591, in _add_variable_with_custom_getter
**kwargs_for_getter)
File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/keras/engine/base_layer.py", line 1986, in make_variable
aggregation=aggregation)
File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/variables.py", line 145, in call
return cls._variable_call(*args, **kwargs)
File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/variables.py", line 141, in _variable_call
aggregation=aggregation)
File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/variables.py", line 120, in <lambda>
previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/variable_scope.py", line 2434, in default_variable_creator
import_scope=import_scope)
File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/variables.py", line 147, in call
return super(VariableMetaclass, cls).__call__(*args, **kwargs)
File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/resource_variable_ops.py", line 297, in init
constraint=constraint)
File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/resource_variable_ops.py", line 420, in _init_from_args
initial_value = initial_value()
File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/keras/engine/base_layer.py", line 1970, in <lambda>
shape, dtype=dtype, partition_info=partition_info)
File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/init_ops.py", line 483, in call
shape, -limit, limit, dtype, seed=self.seed)
File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/random_ops.py", line 240, in random_uniform
shape, minval, maxval, seed=seed1, seed2=seed2, name=name)
File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/gen_random_ops.py", line 848, in random_uniform_int
_six.raise_from(_core._status_to_exception(e.code, message), None)
File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: Need minval < maxval, got 0 >= 0 [Op:RandomUniformInt] name: mlp/dense/kernel/random_uniform/
Process finished with exit code 1
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论