
Estimating parameters of a linear regression model

Published 2025-02-25 23:43:57

We will show how to estimate regression parameters using a simple linear model

\[y \sim ax + b\]

We can restate the linear model

\[y = ax + b + \epsilon\]

with \(\epsilon \sim \mathcal{N}(0, \sigma^2)\), as sampling each observation from a normal distribution centered on the regression line

\[y \sim \mathcal{N}(ax + b, \sigma^2)\]
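To make the equivalence concrete, here is a minimal sketch in plain Python (standard library only). It evaluates the log-likelihood of observations under \(y \sim \mathcal{N}(ax + b, \sigma^2)\). The function names are illustrative and are not part of PyMC. Each observation contributes the log density of \(y_i\) under a normal centered at \(a x_i + b\). This is the same as scoring the residual \(y_i - (a x_i + b)\) under \(\mathcal{N}(0, \sigma^2)\).

```python
import math


def normal_logpdf(value, mean, sd):
    """Log density of N(mean, sd**2) evaluated at value."""
    z = (value - mean) / sd
    return -0.5 * math.log(2 * math.pi) - math.log(sd) - 0.5 * z * z


def log_likelihood(a, b, sigma, xs, ys):
    """Sum over observations of log N(y_i | a*x_i + b, sigma**2).

    Equivalent to scoring each residual y_i - (a*x_i + b) under N(0, sigma**2).
    """
    return sum(normal_logpdf(y, a * x + b, sigma) for x, y in zip(xs, ys))


# Illustrative points lying exactly on y = 6x + 2
xs = [0.0, 0.5, 1.0]
ys = [2.0, 5.0, 8.0]
print(log_likelihood(6, 2, 2, xs, ys))  # the line that generated the points
print(log_likelihood(5, 2, 2, xs, ys))  # a worse slope scores lower
```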

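Now we can use PyMC to estimate the parameters \(a\), \(b\) and \(\sigma\). PyMC2 parameterizes the normal distribution by its precision \(\tau = 1/\sigma^2\) rather than by \(\sigma\). We therefore place the prior on \(\tau\) and recover \(\sigma = 1/\sqrt{\tau}\) afterwards. We will assume the following priors, where the second argument of \(\mathcal{N}\) is the variance (a variance of 100 corresponds to tau=1.0/10**2 in the code below):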
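The transformation between standard deviation and precision is a one-liner in each direction. The helper names below are hypothetical and serve only to show the arithmetic. The same formula converts a prior standard deviation of 10 into the tau=1.0/10**2 used in the code. Applied elementwise, it also turns a sampled trace of \(\tau\) back into values of \(\sigma\).

```python
import math


def sigma_to_tau(sigma):
    """Precision is the reciprocal of the variance: tau = 1 / sigma**2."""
    return 1.0 / sigma**2


def tau_to_sigma(tau):
    """Invert the relation above: sigma = 1 / sqrt(tau)."""
    return 1.0 / math.sqrt(tau)


# A prior standard deviation of 10 (variance 100) gives the precision used for the priors
print(sigma_to_tau(10))
# Converting a precision back to a standard deviation
print(tau_to_sigma(0.25))
```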

\[\begin{split}a \sim \mathcal{N}(0, 100) \\ b \sim \mathcal{N}(0, 100) \\ \tau \sim \text{Gamma}(0.1, 0.1)\end{split}\]
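These priors can also be written as a single log-density. MCMC targets the sum of this log-density and the log-likelihood. The sketch below makes two assumptions. First, the second argument of \(\mathcal{N}\) is the variance, which matches tau=1.0/10**2 in the code. Second, \(\text{Gamma}(0.1, 0.1)\) uses the shape-rate form. Under the second assumption the prior on \(\tau\) has mean \(\alpha/\beta = 1\) and variance \(\alpha/\beta^2 = 10\). The helper names are hypothetical.

```python
import math

PRIOR_VAR = 100.0        # variance of the Normal priors on a and b (tau = 1/10**2)
ALPHA, BETA = 0.1, 0.1   # Gamma prior on tau, assumed shape-rate parameterization


def normal_logpdf_var(value, mean, var):
    """Log density of N(mean, var), parameterized by the variance."""
    return -0.5 * math.log(2 * math.pi * var) - (value - mean) ** 2 / (2 * var)


def gamma_logpdf(value, alpha, beta):
    """Log density of Gamma(shape=alpha, rate=beta); -inf outside the support."""
    if value <= 0:
        return -math.inf
    return (alpha * math.log(beta) - math.lgamma(alpha)
            + (alpha - 1) * math.log(value) - beta * value)


def log_prior(a, b, tau):
    """Joint log prior density of (a, b, tau), with the three priors independent."""
    return (normal_logpdf_var(a, 0.0, PRIOR_VAR)
            + normal_logpdf_var(b, 0.0, PRIOR_VAR)
            + gamma_logpdf(tau, ALPHA, BETA))


# Moments of the Gamma(0.1, 0.1) prior under the shape-rate assumption
print("prior mean of tau:", ALPHA / BETA)          # alpha / beta
print("prior variance of tau:", ALPHA / BETA**2)   # alpha / beta**2
print("log prior at a=0, b=0, tau=1:", log_prior(0.0, 0.0, 1.0))
```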

We also need a helper function that tells PyMC the mean is a deterministic function of the parameters \(a\) and \(b\) and the data \(x\). PyMC2 does this with the pymc.deterministic decorator, like so:

@pymc.deterministic
def mu(a=a, b=b, x=x):
    return a*x + b
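import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pymc  # PyMC 2.x API: pymc.Normal(..., tau=...), pymc.MCMC

# observed data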
n = 21
a = 6
b = 2
sigma = 2
x = np.linspace(0, 1, n)
y_obs = a*x + b + np.random.normal(0, sigma, n)
data = pd.DataFrame(np.array([x, y_obs]).T, columns=['x', 'y'])
data.plot(x='x', y='y', kind='scatter', s=50);
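Because the data are simulated with \(a = 6\), \(b = 2\) and \(\sigma = 2\), a non-Bayesian reference point can help when reading the posterior. The sketch below computes the closed-form ordinary least-squares estimates. This is a different technique from the MCMC approach in this article and is included only as a sanity check. With the vague priors above, the posterior means should typically land close to these values, though not exactly on them. The sketch uses Python's random module with an arbitrary seed instead of np.random.normal, so its draws will not match the NumPy data.

```python
import random


def ols_fit(xs, ys):
    """Closed-form least-squares slope and intercept for y = a*x + b."""
    n = len(xs)
    x_mean = sum(xs) / n
    y_mean = sum(ys) / n
    sxy = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys))
    sxx = sum((x - x_mean) ** 2 for x in xs)
    a_hat = sxy / sxx
    b_hat = y_mean - a_hat * x_mean
    return a_hat, b_hat


# Simulate data the same way as above: a=6, b=2, sigma=2, 21 points on [0, 1].
# The seed is an arbitrary choice made here for reproducibility.
rng = random.Random(42)
n, a_true, b_true, sigma_true = 21, 6.0, 2.0, 2.0
xs = [i / (n - 1) for i in range(n)]  # same grid as np.linspace(0, 1, n)
ys = [a_true * x + b_true + rng.gauss(0, sigma_true) for x in xs]

a_hat, b_hat = ols_fit(xs, ys)
print("OLS slope:", a_hat, "OLS intercept:", b_hat)
```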

# define priors
a = pymc.Normal('slope', mu=0, tau=1.0/10**2)
b = pymc.Normal('intercept', mu=0, tau=1.0/10**2)
tau = pymc.Gamma("tau", alpha=0.1, beta=0.1)

# define likelihood
@pymc.deterministic
def mu(a=a, b=b, x=x):
    return a*x + b

y = pymc.Normal('y', mu=mu, tau=tau, value=y_obs, observed=True)

# inference
# build the model from its nodes: priors, deterministic mean and likelihood
m = pymc.Model([a, b, tau, mu, y])
mc = pymc.MCMC(m)
mc.sample(iter=11000, burn=10000)
# [-----------------100%-----------------] 11000 of 11000 complete in 6.1 sec
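PyMC's MCMC object hides the sampling loop. For intuition, here is a minimal random-walk Metropolis sketch of the same posterior in plain Python. It is an illustration only, not PyMC's implementation, since PyMC2 assigns and tunes its own step methods. The function names, step sizes, starting point and seeds are all assumptions made here. Proposals are made on \(\log\tau\) so that \(\tau\) stays positive, which adds the Jacobian term \(\log\tau\) to the target. It mirrors iter=11000, burn=10000 above: the first 10,000 iterations are discarded and the remaining 1,000 are kept.

```python
import math
import random


def log_posterior(a, b, tau, xs, ys, prior_var=100.0, alpha=0.1, beta=0.1):
    """Unnormalized log posterior for y ~ N(a*x + b, 1/tau).

    Priors: a, b ~ N(0, prior_var); tau ~ Gamma(alpha, beta) in shape-rate form.
    """
    if tau <= 0:
        return -math.inf
    sse = sum((y - (a * x + b)) ** 2 for x, y in zip(xs, ys))
    log_lik = 0.5 * len(xs) * math.log(tau) - 0.5 * tau * sse
    log_prior = (-(a * a + b * b) / (2 * prior_var)
                 + (alpha - 1) * math.log(tau) - beta * tau)
    return log_lik + log_prior


def metropolis(xs, ys, n_iter, burn, step=(0.5, 0.5, 0.3), seed=0):
    """Random-walk Metropolis on (a, b, log tau); returns the post-burn-in draws."""
    rng = random.Random(seed)
    a, b, log_tau = 0.0, 0.0, 0.0  # arbitrary starting point (tau = 1)

    def target(a, b, log_tau):
        # Working on the log scale of tau adds the Jacobian term log_tau.
        return log_posterior(a, b, math.exp(log_tau), xs, ys) + log_tau

    current = target(a, b, log_tau)
    a_trace, b_trace, tau_trace = [], [], []
    for i in range(n_iter):
        prop = (a + rng.gauss(0, step[0]),
                b + rng.gauss(0, step[1]),
                log_tau + rng.gauss(0, step[2]))
        proposed = target(*prop)
        # Accept with probability min(1, exp(proposed - current)); 1 - random() avoids log(0).
        if math.log(1.0 - rng.random()) < proposed - current:
            a, b, log_tau = prop
            current = proposed
        if i >= burn:  # keep only draws after the burn-in period
            a_trace.append(a)
            b_trace.append(b)
            tau_trace.append(math.exp(log_tau))
    return a_trace, b_trace, tau_trace


# Simulated data as in the PyMC example (a=6, b=2, sigma=2, 21 points); arbitrary seed.
data_rng = random.Random(123)
xs = [i / 20 for i in range(21)]
ys = [6 * x + 2 + data_rng.gauss(0, 2) for x in xs]

a_tr, b_tr, tau_tr = metropolis(xs, ys, n_iter=11000, burn=10000)
print(len(a_tr))  # 1000 draws kept, as with iter=11000, burn=10000
print("posterior means:", sum(a_tr) / len(a_tr), sum(b_tr) / len(b_tr))
```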
abar = a.stats()['mean']
bbar = b.stats()['mean']
data.plot(x='x', y='y', kind='scatter', s=50);
xp = np.array([x.min(), x.max()])
# one faint red line per posterior draw of (slope, intercept), drawn over xp
plt.plot(xp, a.trace()*xp[:, None] + b.trace(), c='red', alpha=0.01)
plt.plot(xp, abar*xp + bbar, linewidth=2, c='red');
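The plot above draws the posterior-mean line in bold and overlays one faint line per draw to show uncertainty. A numerical counterpart is to report, for each trace and for the line at a given \(x\), the posterior mean and an equal-tailed 95% interval. The helpers below are hypothetical and are not PyMC API. The draws shown are placeholders for illustration only. With PyMC2, something like list(a.trace()) would supply the real samples.

```python
def quantile(samples, q):
    """Empirical q-quantile, interpolating linearly between order statistics."""
    s = sorted(samples)
    pos = q * (len(s) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(s) - 1)
    return s[lo] + (pos - lo) * (s[hi] - s[lo])


def summarize(samples):
    """Posterior mean and equal-tailed 95% interval of a list of draws."""
    return {"mean": sum(samples) / len(samples),
            "2.5%": quantile(samples, 0.025),
            "97.5%": quantile(samples, 0.975)}


def line_summary(a_samples, b_samples, x):
    """Summarize the posterior of the regression line a*x + b at one value of x."""
    return summarize([a * x + b for a, b in zip(a_samples, b_samples)])


# Placeholder draws for illustration only; real ones would come from the traces.
a_draws = [5.0, 6.0, 7.0]
b_draws = [2.5, 2.0, 1.5]
print(summarize(a_draws))
for x in (0.0, 1.0):  # the two endpoints used for xp above
    print(x, line_summary(a_draws, b_draws, x))
```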

pymc.Matplot.plot(mc)
# Plotting intercept
# Plotting slope
# Plotting tau
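pymc.Matplot.plot produces summary plots for each variable (here intercept, slope and tau), which are inspected for how well the chain mixes. One numerical complement is the lag-k sample autocorrelation of each trace. Strong autocorrelation means the 1,000 kept draws carry less information than 1,000 independent draws would. This is a minimal sketch with a hypothetical helper name. With PyMC2 it could be applied to list(a.trace()), and the two traces in the demo are made up.

```python
def autocorr(samples, lag):
    """Lag-k sample autocorrelation of a trace (1.0 at lag 0)."""
    n = len(samples)
    if not 0 <= lag < n:
        raise ValueError("lag must satisfy 0 <= lag < len(samples)")
    mean = sum(samples) / n
    var = sum((s - mean) ** 2 for s in samples)
    if var == 0:
        raise ValueError("autocorrelation is undefined for a constant trace")
    cov = sum((samples[i] - mean) * (samples[i + lag] - mean) for i in range(n - lag))
    return cov / var


# Two made-up traces: one drifts slowly (poor mixing), one flips sign every draw.
slow = [i // 10 for i in range(100)]
flip = [(-1.0) ** i for i in range(100)]
for lag in (1, 5, 10):
    print(lag, round(autocorr(slow, lag), 3), round(autocorr(flip, lag), 3))
```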
