Neural network always predicts the average value




I'm trying to train a neural network to approximate a known scalar function of two variables; however, no matter what training parameters I use, the network always ends up predicting the average value of the true outputs.

I am using an MLP and have tried:

  • several network depths and widths
  • different optimizers (SGD and Adam)
  • different activations (ReLU and sigmoid)
  • different learning rates (several values in the range 0.1 to 0.001)
  • more training data (up to 10,000 points)
  • more epochs (up to 2,000)
  • and different random seeds,
    all to no avail.

My loss function is MSE, and it always plateaus at a value of about 5.14.

Regardless of changes I make, I get the following results:
[figure: the true function (blue surface) and its MLP approximation (green surface)]

The blue surface is the function to be approximated, and the green surface is the MLP approximation, which sits at roughly the average value of the true function over that domain (the true average is 2.15, whose square is 4.64, not far from the loss plateau value).
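For reference, a quick sanity check on the "predicts the mean" behaviour is to compare the plateau against the MSE of a constant predictor fixed at the mean of the targets, which is simply the variance of y (a minimal sketch, assuming y is the target array returned by generate_data in the code below):

mean_y = jnp.mean(y)                     # roughly 2.15 for my data
const_mse = jnp.mean((y - mean_y) ** 2)  # MSE of always predicting the mean, i.e. jnp.var(y)
print(mean_y, mean_y ** 2, const_mse)    # compare against the ~5.14 plateau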

I feel like I'm missing something very obvious and have just been looking at it for too long. Any help is greatly appreciated. Thanks!

I've attached my code here (I'm using JAX):

from typing import Sequence

import jax.numpy as jnp
from jax import grad, jit, vmap, random, value_and_grad
import flax
import flax.linen as nn
import optax


seed = 2
key, data_key = random.split(random.PRNGKey(seed))
x1, x2, y = generate_data(data_key)  # data generation function (defined elsewhere, not shown)

# Using Flax - define an MLP
class MLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for feat in self.features[:-1]:
      x = nn.relu(nn.Dense(feat)(x))
    x = nn.Dense(self.features[-1])(x)
    return x

# Define function that returns JITted loss function
def make_mlp_loss(input_data, true_y):

  def mlp_loss(params):
    pred_y = model.apply(params, input_data)
    loss_vector = jnp.square(true_y.reshape(-1) - pred_y)
    return jnp.average(loss_vector)

  # Closing over the outer scope captures the data and true output
  return jit(mlp_loss)


# Concatenate independent variable vectors to be proper input shape
input_data = jnp.hstack((x1.reshape(-1, 1), x2.reshape(-1, 1)))

# Create loss function with data and true output
mlp_loss = make_mlp_loss(input_data, y)

# Create function that returns loss and gradient
loss_and_grad = value_and_grad(mlp_loss)

# Example architectures I've tried
architectures = [[16, 16, 1], [8, 16, 1], [16, 8, 1], [8, 16, 8, 1], [32, 32, 1]]

# Only one seed shown here, but I've iterated over several
for seed in [645]:
  for architecture in architectures:
    # Create model
    model = MLP(architecture)
    
    # Initialize model with random parameters
    key, params_key = random.split(key)
    dummy = jnp.ones((1000, 2))
    params = model.init(params_key, dummy)

    # Create optimizer
    opt = optax.adam(learning_rate=0.01)  # have also tried optax.sgd
    opt_state = opt.init(params)

    
    epochs = 50
    for i in range(epochs):
      # Get loss and gradient 
      curr_loss, curr_grad = loss_and_grad(params)
      if i % 5 == 0:
        print(curr_loss)

      # Update
      updates, opt_state = opt.update(curr_grad, opt_state)
      params = optax.apply_updates(params, updates)
      
    print(f"Architecture: {architecture}\nLoss: {curr_loss}\nSeed: {seed}\n\n")
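In case it is relevant, here is a minimal shape check for the quantities flowing through the loss (a sketch, assuming model and params are whatever is left after the last pass of the training loop above):

# Inspect the shapes that the MSE is actually averaging over.
pred_y = model.apply(params, input_data)
print("pred_y:     ", pred_y.shape)                    # output of the MLP's final Dense layer
print("true_y:     ", y.reshape(-1).shape)             # targets flattened to 1-D
print("difference: ", (y.reshape(-1) - pred_y).shape)  # the array jnp.square is applied to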
