使用 flax.nn.Module 实现 RNN

发布于 2025-01-14 01:02:09 字数 548 浏览 2 评论 0原文

我正在尝试使用 flax.nn.Module 实现基本的 RNN 单元。实现 RNN 单元的方程非常简单:

a_t = W * h_{t-1} + U * x_t + b

h_t = tanh(a_t)

o_t = V * h_t + c

其中 h_t 是时间 t 的更新状态,x_t 是输入,o_t 是输出,Tanh 是我们的激活函数。

我的代码使用flax.nn.Module

class ElmanCell(nn.Module):
  @nn.compact
  def __call__(self, h, x):
    nextState = jnp.tanh(jnp.dot(W, h) * jnp.dot(U, x) + b)
    return nextState

我不知道如何实现参数W、U和b。它们应该是 nn.Module 的属性吗?

I am trying to implement a basic RNN cell with flax.nn.Module. the equations to implement the RNN cell are quite simple:

a_t = W * h_{t-1} + U * x_t + b

h_t = tanh(a_t)

o_t = V * h_t + c

where h_t is the updated state at time t, x_t is the input and o_t is the output and Tanh is our activation function.

My code uses flax.nn.Module,

class ElmanCell(nn.Module):
  @nn.compact
  def __call__(self, h, x):
    nextState = jnp.tanh(jnp.dot(W, h) * jnp.dot(U, x) + b)
    return nextState

I don't know hoe to implement the parameters W, U and b. Are they supposed to be attributes of nn.Module?

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

圈圈圆圆圈圈 2025-01-21 01:02:10

尝试这样的事情:

class RNNCell(nn.Module):
  @nn.compact
  def __call__(self, state, x):
    # Wh @ h + Wx @ x + b can be efficiently computed
    # by concatenating the vectors and then having a single dense layer
    x = np.concatenate([state, x])
    new_state = np.tanh(nn.Dense(state.shape[0])(x))
    return new_state

这样参数就会被学习。请参阅 https://schmit.github.io /jax/2021/06/20/jax-language-model-rnn.html

Try something like:

class RNNCell(nn.Module):
  @nn.compact
  def __call__(self, state, x):
    # Wh @ h + Wx @ x + b can be efficiently computed
    # by concatenating the vectors and then having a single dense layer
    x = np.concatenate([state, x])
    new_state = np.tanh(nn.Dense(state.shape[0])(x))
    return new_state

This way the parameters will be learned. See https://schmit.github.io/jax/2021/06/20/jax-language-model-rnn.html

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文