使用 flax.nn.Module 实现 RNN
我正在尝试使用 flax.nn.Module 实现基本的 RNN 单元。实现 RNN 单元的方程非常简单:
a_t = W * h_{t-1} + U * x_t + b
h_t = tanh(a_t)
o_t = V * h_t + c
其中 h_t 是时间 t 的更新状态,x_t 是输入,o_t 是输出,Tanh 是我们的激活函数。
我的代码使用flax.nn.Module
,
class ElmanCell(nn.Module):
@nn.compact
def __call__(self, h, x):
nextState = jnp.tanh(jnp.dot(W, h) * jnp.dot(U, x) + b)
return nextState
我不知道如何实现参数W、U和b。它们应该是 nn.Module 的属性吗?
I am trying to implement a basic RNN cell with flax.nn.Module
. the equations to implement the RNN cell are quite simple:
a_t = W * h_{t-1} + U * x_t + b
h_t = tanh(a_t)
o_t = V * h_t + c
where h_t is the updated state at time t, x_t is the input and o_t is the output and Tanh is our activation function.
My code uses flax.nn.Module
,
class ElmanCell(nn.Module):
@nn.compact
def __call__(self, h, x):
nextState = jnp.tanh(jnp.dot(W, h) * jnp.dot(U, x) + b)
return nextState
I don't know hoe to implement the parameters W, U and b. Are they supposed to be attributes of nn.Module?
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论
评论(1)
尝试这样的事情:
这样参数就会被学习。请参阅 https://schmit.github.io /jax/2021/06/20/jax-language-model-rnn.html
Try something like:
This way the parameters will be learned. See https://schmit.github.io/jax/2021/06/20/jax-language-model-rnn.html