Tensorflow多元线性回归参数不收敛的问题

发布于 2022-09-06 22:03:51 字数 1723 浏览 28 评论 0

在使用Tensorflow进行多元线性回归的时候,遇到了参数不收敛的问题。问题在于优化方法的选择上:如果使用tf.train.AdamOptimizer(0.01).minimize(loss)进行,参数会收敛,损失函数也比较合理,但是权重和偏置项与原来的不一致,这是第一个不明白的地方;如果使用opt = tf.train.GradientDescentOptimizer(0.01).minimize(loss),则损失函数会一直增大,找不到原因。如果初学者一直找不到原因,希望大家有明白的,可以帮忙解释一下,代码量并不大。下面是代码:

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# 模拟输入数据,区间均匀分布
X1 = np.matrix(np.random.uniform(-10, 10, 100)).T
X2 = np.matrix(np.linspace(-10, 10, 100)).T
X3 = np.matrix(np.linspace(-10, 10, 100)).T
X_input = np.concatenate((X1, X2, X3), axis=1)
# 权重应该是 20,, -35, 4.3 偏置项是25
Y_input = 20 * X1 - 35 * X2 + 4.3 * X3 + 25 * np.ones((100, 1))

# 权重向量和偏置项
W = tf.Variable(tf.random_uniform(shape=[3, 1]))
b = tf.Variable(tf.random_uniform(shape=[1, 1]))

# 占位符
X = tf.placeholder(dtype=tf.float32, shape=[None, 3])
Y = tf.placeholder(dtype=tf.float32, shape=[None, 1])

# 预测值
Y_pred = tf.matmul(X, W) + b * np.ones((100, 1))

# 损失函数
loss = tf.reduce_sum(tf.square(Y_pred - Y)) / 100

# Adma算法优化,学习步长是0.01
opt = tf.train.AdamOptimizer(0.01).minimize(loss)
# 梯度下降
# opt = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

# 用于绘图
x_axis = []
y_axis = []

with tf.Session() as sess:
    # 初始化变量
    sess.run(tf.global_variables_initializer())
    print("training,please wait...")
    for i in range(20000):
        sess.run(opt, feed_dict={Y: Y_input, X: X_input})
        x_axis.append(i)
        y_axis.append(sess.run(loss, feed_dict={Y: Y_input, X: X_input}))
    print("finish training!")
    print("W:", sess.run(W), "\nb:", sess.run(b))
    print(sess.run(loss, feed_dict={Y: Y_input, X: X_input}))
    plt.plot(x_axis, y_axis)
    plt.show()

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

不美如何 2022-09-13 22:03:51

这么点数据搞2万轮的话很容易overfitting
但这不是主要问题,主要是GD没有动量,容易陷入局部最优解;而adam自带动量,一般来说不容易陷入局部最优,性能会比较好。
至于你开始设的权重的话只是为了计算Y——input值,而神经网络是自己拟合权值的,完全无视你设的权值,所以不一样是正常的,一样才是吊鬼了。

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文