
Estimating parameters of a linear regression model

Published on 2025-02-25 23:43:58

We will show how to estimate regression parameters using a simple linear model

\[y \sim ax + b\]

We can restate the linear model

\[y = ax + b + \epsilon\]

as sampling from a probability distribution

\[y \sim \mathcal{N}(ax + b, \sigma^2)\]
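
The equivalence of the two forms can be checked numerically. The following is a minimal sketch using only NumPy; the values of a, b and sigma and the helper name normal_logpdf are illustrative assumptions, not part of the original notebook. It shows that scoring the residuals under \(\mathcal{N}(0, \sigma^2)\) gives the same log-likelihood as scoring \(y\) under \(\mathcal{N}(ax + b, \sigma^2)\).

```python
import numpy as np

rng = np.random.default_rng(0)
a, b, sigma = 6.0, 2.0, 1.0          # illustrative values (assumed)
x = np.linspace(0, 1, 11)

def normal_logpdf(value, mu, sigma):
    """Log density of N(mu, sigma^2) evaluated at value."""
    return -0.5 * np.log(2 * np.pi * sigma**2) - (value - mu)**2 / (2 * sigma**2)

# Additive-noise form: y = ax + b + eps, with eps ~ N(0, sigma^2)
mu = a * x + b
eps = rng.normal(0.0, sigma, size=x.size)
y = mu + eps

# Sampling form: y ~ N(ax + b, sigma^2)
loglik_y = normal_logpdf(y, mu, sigma).sum()

# Same quantity computed from the residuals under N(0, sigma^2)
loglik_eps = normal_logpdf(y - mu, 0.0, sigma).sum()

# A badly wrong line (a = b = 0) scores far lower on the same data
loglik_wrong = normal_logpdf(y, np.zeros_like(x), sigma).sum()

print(loglik_y, loglik_eps, loglik_wrong)
```

The log-likelihood is the quantity that the samplers used in the rest of this section combine with the priors to explore the posterior.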

Now we can use pymc to estimate the parameters \(a\), \(b\) and \(\sigma\). (pymc2 parameterizes the normal distribution by the precision \(\tau = 1/\sigma^2\), which would require a simple transformation; the code below passes the standard deviation directly through the sd argument.) The code below uses the following priors

\[\begin{split}a \sim \mathcal{N}(0, 20^2) \\ b \sim \mathcal{N}(0, 20^2) \\ \sigma \sim \text{Uniform}(0, 20)\end{split}\]
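
When moving between the precision and standard-deviation parameterizations, the conversion is a one-liner in each direction. The sketch below is only an illustration; the helper names precision_to_sd and sd_to_precision are hypothetical and do not belong to any pymc API.

```python
import numpy as np

def precision_to_sd(tau):
    """Convert a precision tau = 1/sigma^2 into a standard deviation sigma."""
    return 1.0 / np.sqrt(tau)

def sd_to_precision(sigma):
    """Convert a standard deviation sigma into a precision tau = 1/sigma^2."""
    return 1.0 / sigma**2

sigma = 20.0
tau = sd_to_precision(sigma)          # 1 / 400
sigma_back = precision_to_sd(tau)     # recovers the original sigma

print(tau, sigma_back)
```

The same functions work elementwise on NumPy arrays, so a whole trace of sampled precisions can be converted to standard deviations in one call.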

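import numpy as np
import matplotlib.pyplot as plt
import pymc3 as pm  # assumed alias; the API used here is early PyMC3
niter = 1000  # number of MCMC samples to draw

# observed data
n = 11
_a = 6
_b = 2
x = np.linspace(0, 1, n)
y = _a*x + _b + np.random.randn(n)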

with pm.Model() as model:
    a = pm.Normal('a', mu=0, sd=20)
    b = pm.Normal('b', mu=0, sd=20)
    sigma = pm.Uniform('sigma', lower=0, upper=20)

    y_est = a*x + b # simple auxiliary variables

    likelihood = pm.Normal('y', mu=y_est, sd=sigma, observed=y)
    # inference
    start = pm.find_MAP()
    step = pm.NUTS() # Hamiltonian MCMC with No U-Turn Sampler
    trace = pm.sample(niter, step, start, random_seed=123, progressbar=True)
    pm.traceplot(trace);
[-----------------100%-----------------] 1000 of 1000 complete in 8.9 sec
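
As a sanity check that does not depend on pymc, the same kind of data can be fitted by ordinary least squares. This is a different technique from the MCMC estimation above, shown only for comparison: with priors as weak as the ones used here, the posterior means of a and b should land close to the least-squares estimates. The seed below is an assumption added for reproducibility.

```python
import numpy as np

np.random.seed(123)                   # assumed seed, for reproducibility only
n = 11
_a, _b = 6, 2
x = np.linspace(0, 1, n)
y = _a * x + _b + np.random.randn(n)

# Ordinary least squares fit of a straight line (degree-1 polynomial);
# np.polyfit returns the coefficients from highest degree to lowest.
a_hat, b_hat = np.polyfit(x, y, deg=1)

# Residual standard deviation with n - 2 degrees of freedom
resid = y - (a_hat * x + b_hat)
sigma_hat = np.sqrt(np.sum(resid**2) / (n - 2))

print(a_hat, b_hat, sigma_hat)
```

With only 11 noisy points the estimates can sit noticeably away from the true values 6 and 2, which is also why the posterior distributions in the trace plot are fairly wide.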

Alternative formulation using GLM formulas

data = dict(x=x, y=y)

with pm.Model() as model:
    pm.glm.glm('y ~ x', data)
    step = pm.NUTS()
    trace = pm.sample(2000, step, progressbar=True)
[-----------------100%-----------------] 2000 of 2000 complete in 8.1 sec
pm.traceplot(trace);

plt.figure(figsize=(7, 7))
plt.scatter(x, y, s=30, label='data')
pm.glm.plot_posterior_predictive(trace, samples=100,
                                 label='posterior predictive regression lines',
                                 c='blue', alpha=0.2)
plt.plot(x, _a*x + _b, label='true regression line', lw=3., c='red')
plt.legend(loc='best');
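
To make the formula notation concrete: 'y ~ x' describes a linear predictor with an implicit intercept plus one coefficient for x. The sketch below builds the corresponding design matrix by hand with NumPy; the coefficient values in beta are hypothetical and chosen only to match the true line used to simulate the data.

```python
import numpy as np

x = np.linspace(0, 1, 11)

# Design matrix implied by the formula 'y ~ x': a column of ones for the
# intercept followed by the predictor column.
X = np.column_stack([np.ones_like(x), x])

# Hypothetical coefficient values, in the order (Intercept, x)
beta = np.array([2.0, 6.0])

# Linear predictor: the mean of y under the model
mu = X @ beta

print(X.shape)
print(mu[:3])
```

Each posterior sample of (Intercept, x) defines one such line, which is what the posterior predictive regression lines in the plot above represent.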

Simple Logistic model

We have observations of height and weight and want to use a logistic model to guess the sex.

import pandas as pd

# observed data
df = pd.read_csv('HtWt.csv')
df.head()
   male  height  weight
0     0    63.2   168.7
1     0    68.7   169.8
2     0    64.8   176.6
3     0    67.9   246.8
4     1    68.9   151.6
niter = 1000
with pm.Model() as model:
    pm.glm.glm('male ~ height + weight', df, family=pm.glm.families.Binomial())
    trace = pm.sample(niter, step=pm.Slice(), random_seed=123, progressbar=True)
[-----------------100%-----------------] 1000 of 1000 complete in 3.2 sec
# note that height and weight in trace refer to the coefficients

df_trace = pm.trace_to_dataframe(trace)
pd.scatter_matrix(df_trace[-1000:], diagonal='kde');

plt.figure(figsize=(12, 4))
plt.subplot(121)
plt.plot(df_trace.ix[-1000:, 'height'], linewidth=0.7)
plt.subplot(122)
plt.plot(df_trace.ix[-1000:, 'weight'], linewidth=0.7);

There is no convergence!

Because of the strong correlation between height and weight, a one-at-a-time sampler such as the slice or Gibbs sampler will take a long time to converge. A Hamiltonian Monte Carlo sampler such as NUTS does much better.
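
The slow mixing of one-at-a-time samplers under strong correlation can be illustrated without pymc. The following is a minimal sketch, assuming a standard bivariate normal target with correlation rho rather than the logistic posterior; the function names gibbs_bivariate_normal and lag1_autocorr are hypothetical. For this target each coordinate's chain is autoregressive with coefficient rho squared, so the lag-1 autocorrelation approaches 1 as the correlation grows.

```python
import numpy as np

def gibbs_bivariate_normal(rho, n_iter=5000, seed=0):
    """Gibbs sampler for a standard bivariate normal with correlation rho.

    Each iteration updates one coordinate at a time from its full conditional.
    """
    rng = np.random.default_rng(seed)
    cond_sd = np.sqrt(1.0 - rho**2)
    samples = np.empty((n_iter, 2))
    x, y = 0.0, 0.0
    for i in range(n_iter):
        x = rng.normal(rho * y, cond_sd)   # x | y ~ N(rho * y, 1 - rho^2)
        y = rng.normal(rho * x, cond_sd)   # y | x ~ N(rho * x, 1 - rho^2)
        samples[i] = x, y
    return samples

def lag1_autocorr(chain):
    """Lag-1 autocorrelation of a 1-D chain."""
    c = chain - chain.mean()
    return np.dot(c[:-1], c[1:]) / np.dot(c, c)

# Compare mixing of the first coordinate for weak and strong correlation
results = {}
for rho in (0.1, 0.9, 0.99):
    chain = gibbs_bivariate_normal(rho)[:, 0]
    results[rho] = lag1_autocorr(chain)
    print(rho, results[rho])
```

A lag-1 autocorrelation near 1 means successive samples barely move, so far more iterations are needed to obtain the same number of effectively independent draws.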

niter = 1000
with pm.Model() as model:
    pm.glm.glm('male ~ height + weight', df, family=pm.glm.families.Binomial())
    trace = pm.sample(niter, step=pm.NUTS(), random_seed=123, progressbar=True)
[-----------------100%-----------------] 1001 of 1000 complete in 27.0 sec
df_trace = pm.trace_to_dataframe(trace)
pd.scatter_matrix(df_trace[-1000:], diagonal='kde');

pm.summary(trace);
Intercept:

  Mean             SD               MC Error         95% HPD interval
  -------------------------------------------------------------------

  -51.393          11.299           0.828            [-73.102, -29.353]

  Posterior quantiles:
  2.5            25             50             75             97.5
  |--------------|==============|==============|--------------|

  -76.964        -58.534        -50.383        -43.856        -30.630


height:

  Mean             SD               MC Error         95% HPD interval
  -------------------------------------------------------------------

  0.747            0.170            0.012            [0.422, 1.096]

  Posterior quantiles:
  2.5            25             50             75             97.5
  |--------------|==============|==============|--------------|

  0.453          0.630          0.732          0.853          1.139


weight:

  Mean             SD               MC Error         95% HPD interval
  -------------------------------------------------------------------

  0.011            0.012            0.001            [-0.012, 0.034]

  Posterior quantiles:
  2.5            25             50             75             97.5
  |--------------|==============|==============|--------------|

  -0.012         0.002          0.010          0.019          0.034
import seaborn as sn
sn.kdeplot(trace['weight'], trace['height'])
plt.xlabel('Weight', fontsize=20)
plt.ylabel('Height', fontsize=20)
plt.style.use('ggplot')

We can use the logistic regression results to classify subjects as male or female based on their height and weight, using a predicted probability of 0.5 as the cutoff, as shown in the plot below. Green = male, yellow = female (true labels); a red halo marks a misclassified subject. The blue line shows the 0.5 cutoff.

intercept, height, weight = df_trace[-niter//2:].mean(0)

def predict(w, h, height=height, weight=weight):
    """Return the predicted probability of being male given weight (w) and height (h)."""
    v = intercept + height*h + weight*w
    return np.exp(v)/(1+np.exp(v))

# calculate predictions on grid
xs = np.linspace(df.weight.min(), df.weight.max(), 100)
ys = np.linspace(df.height.min(), df.height.max(), 100)
X, Y = np.meshgrid(xs, ys)
Z = predict(X, Y)

plt.figure(figsize=(6,6))

# plot 0.5 contour line - classify as male if above this line
plt.contour(X, Y, Z, levels=[0.5])

# classify all subjects
colors = ['lime' if i else 'yellow' for i in df.male]
ps = predict(df.weight, df.height)
# a subject is misclassified when the predicted label differs from the true one
errs = (ps >= 0.5) != df.male.astype(bool)
plt.scatter(df.weight[errs], df.height[errs], facecolors='red', s=150)
plt.scatter(df.weight, df.height, facecolors=colors, edgecolors='k', s=50, alpha=1);
plt.xlabel('Weight', fontsize=16)
plt.ylabel('Height', fontsize=16)
plt.title('Gender classification by weight and height', fontsize=16)
plt.tight_layout();
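
The 0.5 cutoff corresponds to the linear predictor being zero, so the decision boundary is a straight line in the weight-height plane. The sketch below solves for that line using the posterior means printed in the pm.summary output above; this is an illustrative shortcut (the code above averages only the last half of the trace), and the names prob_male, boundary_height and the example weights are hypothetical.

```python
import numpy as np

# Posterior means copied from the pm.summary output above (illustrative use)
intercept, height_coef, weight_coef = -51.393, 0.747, 0.011

def prob_male(w, h):
    """Logistic model: P(male | weight w, height h)."""
    v = intercept + height_coef * h + weight_coef * w
    return 1.0 / (1.0 + np.exp(-v))

def boundary_height(w):
    """Height at which P(male) = 0.5 for a given weight, i.e. where v = 0."""
    return -(intercept + weight_coef * w) / height_coef

w = np.array([120.0, 160.0, 200.0])   # example weights (assumed)
h_star = boundary_height(w)

print(h_star)
print(prob_male(w, h_star))
```

Because the height coefficient is positive, any subject taller than the boundary height for their weight is classified as male, which matches the region above the contour line in the plot.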
