参数不会在python中的非线性最小二平方实现中以较低的公差收敛
我将一些R代码转换为Python作为学习过程,尤其是为Autodiff尝试JAX
。
在实现非线性最小值的函数中,当我将公差设置为1e-8时,估计的参数几乎是相同的,几乎是相同的,但是算法似乎永远不会收敛。
但是,R码在TOL = 1E-8和TOL = 1E-9时在第12个INTER上收敛。估计的参数几乎与python实现产生的参数相同。
我认为这与浮点有关,但不确定我可以改进哪个步骤以使其与R相同。
这是我的代码,大多数步骤与R相同
import jax
import jax.numpy as jnp
import numpy as np
import scipy.linalg as ola
def update_parm(X, y, fun, dfun, parm, theta, wt):
len_y = len(y)
mean_fun = fun(X, parm)
if (type(wt) == bool):
if (wt):
var_fun = np.exp(theta * np.log(mean_fun))
sqrtW = 1 / np.sqrt(var_fun ** 2)
else:
sqrtW = 1
else:
sqrtW = wt
gradX = dfun(x, parm)
weighted_X = sqrtW.reshape(len_y, 1) * gradX
z = gradX @ parm + (y - mean_fun)
weighted_z = sqrtW * z
qr_gradX = ola.qr(weighted_X, mode="economic")
Q = qr_gradX[0]
R = qr_gradX[1]
new_parm = ola.solve(R, np.dot(Q.T, weighted_z))
return new_parm
def nls_irwls(X, y, fun, dfun, init, theta = 1, tol = 1e-8, maxiter = 500):
old_parm = init
iter = 0
while (iter < maxiter):
new_parm = update_parm(X, y, fun, dfun, parm=old_parm, theta=theta, wt=True)
parm_diff = np.max(np.abs(new_parm - old_parm) / np.abs(old_parm))
print(parm_diff)
if (parm_diff < tol) :
break
else:
old_parm = new_parm
iter += 1
print(new_parm)
if (iter == maxiter):
print("The algorithm failed to converge")
else:
return {"Estimated coefficient": new_parm}
x = np.array([0.25, 0.5, 0.75, 1, 1.25, 2, 3, 4, 5, 6, 8])
y = np.array([2.05, 1.04, 0.81, 0.39, 0.30, 0.23, 0.13, 0.11, 0.08, 0.10, 0.06])
def model(x, W):
comp1 = jnp.exp(W[0])
comp2 = jnp.exp(-jnp.exp(W[1]) * x)
comp3 = jnp.exp(W[2])
comp4 = jnp.exp(-jnp.exp(W[3]) * x)
return comp1 * comp2 + comp3 * comp4
init = np.array([0.69, 0.69, -1.6, -1.6])
#autodiff
model_grad = jax.jit(jax.jacfwd(model, argnums=1))
#manual derivative
def dModel(x, W):
e1 = np.exp(W[1])
e2 = np.exp(W[3])
e5 = np.exp(-(x * e1))
e6 = np.exp(-(x * e2))
e7 = np.exp(W[0])
e8 = np.exp(W[2])
b1 = e5 * e7
b2 = -(x * e5 * e7 * e1)
b3 = e6 * e8
b4 = -(x * e6 * e8 * e2)
return np.array([b1, b2, b3, b4]).T
nls_irwls(x, y, model, model_grad, init=init, theta=1, tol=1e-8, maxiter=50)
nls_irwls(x, y, model, dModel, init=init, theta=1, tol=1e-8, maxiter=50)
I am translating some of my R codes to Python as a learning process, especially trying JAX
for autodiff.
In functions to implement non-linear least square, when I set tolerance at 1e-8, the estimated parameters are nearly identical after several iterations, but the algorithm never appear to converge.
However, the R codes converge at the 12th inter at tol=1e-8 and 14th inter at tol=1e-9. The estimated parameters are almost the same as the ones resulted from Python implementation.
I think this has something to do with floating point, but not sure which step I could improve to make the converge as quickly as seen in R.
Here are my codes, and most steps are the same as in R
import jax
import jax.numpy as jnp
import numpy as np
import scipy.linalg as ola
def update_parm(X, y, fun, dfun, parm, theta, wt):
len_y = len(y)
mean_fun = fun(X, parm)
if (type(wt) == bool):
if (wt):
var_fun = np.exp(theta * np.log(mean_fun))
sqrtW = 1 / np.sqrt(var_fun ** 2)
else:
sqrtW = 1
else:
sqrtW = wt
gradX = dfun(x, parm)
weighted_X = sqrtW.reshape(len_y, 1) * gradX
z = gradX @ parm + (y - mean_fun)
weighted_z = sqrtW * z
qr_gradX = ola.qr(weighted_X, mode="economic")
Q = qr_gradX[0]
R = qr_gradX[1]
new_parm = ola.solve(R, np.dot(Q.T, weighted_z))
return new_parm
def nls_irwls(X, y, fun, dfun, init, theta = 1, tol = 1e-8, maxiter = 500):
old_parm = init
iter = 0
while (iter < maxiter):
new_parm = update_parm(X, y, fun, dfun, parm=old_parm, theta=theta, wt=True)
parm_diff = np.max(np.abs(new_parm - old_parm) / np.abs(old_parm))
print(parm_diff)
if (parm_diff < tol) :
break
else:
old_parm = new_parm
iter += 1
print(new_parm)
if (iter == maxiter):
print("The algorithm failed to converge")
else:
return {"Estimated coefficient": new_parm}
x = np.array([0.25, 0.5, 0.75, 1, 1.25, 2, 3, 4, 5, 6, 8])
y = np.array([2.05, 1.04, 0.81, 0.39, 0.30, 0.23, 0.13, 0.11, 0.08, 0.10, 0.06])
def model(x, W):
comp1 = jnp.exp(W[0])
comp2 = jnp.exp(-jnp.exp(W[1]) * x)
comp3 = jnp.exp(W[2])
comp4 = jnp.exp(-jnp.exp(W[3]) * x)
return comp1 * comp2 + comp3 * comp4
init = np.array([0.69, 0.69, -1.6, -1.6])
#autodiff
model_grad = jax.jit(jax.jacfwd(model, argnums=1))
#manual derivative
def dModel(x, W):
e1 = np.exp(W[1])
e2 = np.exp(W[3])
e5 = np.exp(-(x * e1))
e6 = np.exp(-(x * e2))
e7 = np.exp(W[0])
e8 = np.exp(W[2])
b1 = e5 * e7
b2 = -(x * e5 * e7 * e1)
b3 = e6 * e8
b4 = -(x * e6 * e8 * e2)
return np.array([b1, b2, b3, b4]).T
nls_irwls(x, y, model, model_grad, init=init, theta=1, tol=1e-8, maxiter=50)
nls_irwls(x, y, model, dModel, init=init, theta=1, tol=1e-8, maxiter=50)
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论
评论(1)
要注意的一件事是,默认情况下,JAX以32位执行计算,而R和Numpy之类的工具则以64位执行计算。由于
1E-8
处于32位浮点精度的边缘,因此我怀疑这就是为什么您的程序无法收敛的原因。您可以通过将其放置在脚本的开头来启用64位计算:
在执行此操作之后,您的程序会按预期收敛。有关更多信息,请参见。
One thing to be aware of is that by default, JAX performs computations in 32-bit, while tools like R and numpy perform computations in 64-bit. Since
1E-8
is at the edge of 32-bit floating point precision, I suspect this is why your program is failing to converge.You can enable 64-bit computation by putting this at the beginning of your script:
After doing this, your program converges as expected. For more information, see JAX Sharp Bits: Double Precision.