(conv1d)张量和jax产生了相同输入的不同输出

发布于 2025-01-22 04:02:48 字数 2055 浏览 1 评论 0原文

我正在尝试使用Conv1D函数在JAX和TensorFlow上进行反式卷发。我为CON1D_transposed操作阅读了JAX和TensorFlow的文档,但它们的输出对相同的输入产生了不同的输出。

我找不到问题是什么。而且我不知道哪一个产生正确的结果。请帮我。

我的JAX实现(JAX代码)

x = np.asarray([[[1, 2, 3, 4, -5], [1, 2, 3, 4, 5]]], dtype=np.float32).transpose((0, 2, 1))
filters = np.array([[[1, 0, -1], [-1,  0,  1]], 
                    [[1, 1,  1], [-1, -1, -1]]], 
                    dtype=np.float32).transpose((2, 1, 0))

kernel_rot = np.rot90(np.rot90(filters))

print(f"x strides:  {x.strides}\nfilters strides: {kernel_rot.strides}\nx shape: {x.shape}\nfilters shape: {filters.shape}\nx: \n{x}\nfilters: \n{filters}\n")

dn1 = lax.conv_dimension_numbers(x.shape, filters.shape,('NWC', 'WIO', 'NWC'))
print(dn1)

res = lax.conv_general_dilated(x,kernel_rot,(1,),'SAME',(1,),(1,),dn1)     

res = np.asarray(res)
print(f"result strides: {res.strides}\nresult shape: {res.shape}\nresult: \n{res}\n")

我的TensorFlow实现(TensorFlow代码)

x = np.asarray([[[1, 2, 3, 4, -5], [1, 2, 3, 4, 5]]], dtype=np.float32).transpose((0, 2, 1))
filters = np.array([[[1, 0, -1], [-1,  0,  1]], 
                    [[1, 1,  1], [-1, -1, -1]]], 
                    dtype=np.float32).transpose((2, 1, 0))

print(f"x strides:  {x.strides}\nfilters strides: {filters.strides}\nx shape: {x.shape}\nfilters shape: {filters.shape}\nx: \n{x}\nfilters: \n{filters}\n")
    
res = tf.nn.conv1d_transpose(x, filters, output_shape = x.shape, strides = (1, 1, 1), padding = 'SAME', data_format='NWC', dilations=1)

res = np.asarray(res)
print(f"result strides: {res.strides}\nresult shape: {res.shape}\nresult: \n{res}\n")

JAX输出输出

result strides: (40, 8, 4)
result shape: (1, 5, 2)
result: 
[[[ 0.  0.]
  [ 0.  0.]
  [ 0.  0.]
  [10. 10.]
  [ 0. 10.]]]

从TensorFlow的

result strides: (40, 8, 4)
result shape: (1, 5, 2)
result: 
[[[  5.  -5.]
  [  8.  -8.]
  [ 11. -11.]
  [  4.  -4.]
  [  5.  -5.]]]

I am trying to use conv1d functions to make a transposed convlotion repectively at jax and tensorflow. I read the documentation of both of jax and tensorflow for the con1d_transposed operation but they are resulting with different outputs for the same input.

I can not find out what the problem is. And I don't know which one produces the correct results. Help me please.

My Jax Implementation (Jax Code)

x = np.asarray([[[1, 2, 3, 4, -5], [1, 2, 3, 4, 5]]], dtype=np.float32).transpose((0, 2, 1))
filters = np.array([[[1, 0, -1], [-1,  0,  1]], 
                    [[1, 1,  1], [-1, -1, -1]]], 
                    dtype=np.float32).transpose((2, 1, 0))

kernel_rot = np.rot90(np.rot90(filters))

print(f"x strides:  {x.strides}\nfilters strides: {kernel_rot.strides}\nx shape: {x.shape}\nfilters shape: {filters.shape}\nx: \n{x}\nfilters: \n{filters}\n")

dn1 = lax.conv_dimension_numbers(x.shape, filters.shape,('NWC', 'WIO', 'NWC'))
print(dn1)

res = lax.conv_general_dilated(x,kernel_rot,(1,),'SAME',(1,),(1,),dn1)     

res = np.asarray(res)
print(f"result strides: {res.strides}\nresult shape: {res.shape}\nresult: \n{res}\n")

My TensorFlow Implementation (TensorFlow Code)

x = np.asarray([[[1, 2, 3, 4, -5], [1, 2, 3, 4, 5]]], dtype=np.float32).transpose((0, 2, 1))
filters = np.array([[[1, 0, -1], [-1,  0,  1]], 
                    [[1, 1,  1], [-1, -1, -1]]], 
                    dtype=np.float32).transpose((2, 1, 0))

print(f"x strides:  {x.strides}\nfilters strides: {filters.strides}\nx shape: {x.shape}\nfilters shape: {filters.shape}\nx: \n{x}\nfilters: \n{filters}\n")
    
res = tf.nn.conv1d_transpose(x, filters, output_shape = x.shape, strides = (1, 1, 1), padding = 'SAME', data_format='NWC', dilations=1)

res = np.asarray(res)
print(f"result strides: {res.strides}\nresult shape: {res.shape}\nresult: \n{res}\n")

Output from the Jax

result strides: (40, 8, 4)
result shape: (1, 5, 2)
result: 
[[[ 0.  0.]
  [ 0.  0.]
  [ 0.  0.]
  [10. 10.]
  [ 0. 10.]]]

Output from the TensorFlow

result strides: (40, 8, 4)
result shape: (1, 5, 2)
result: 
[[[  5.  -5.]
  [  8.  -8.]
  [ 11. -11.]
  [  4.  -4.]
  [  5.  -5.]]]

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

各空 2025-01-29 04:02:48

函数 [filter_width,output_channels,in_channels] 。如果过滤器在上面的片段中被转移以满足此形状,则jax返回正确的结果,而计算dn1参数应为woi> woi(<强> w idth - o utput_channels- i nput_channels)而不是wio> wio w idth- i nput_channels- o utput_channels)。之后:

result.strides = (40, 8, 4)
result.shape = (1, 5, 2)
result: 
[[[ -5.,   5.],
  [ -8.,   8.],
  [-11.,  11.],
  [ -4.,   4.],
  [ -5.,   5.]]]

结果与TensorFlow不同,但是JAX的内核被翻转了,因此实际上是预期的。

Function conv1d_transpose expects filters in shape [filter_width, output_channels, in_channels]. If filters in snippet above were transposed to satisfy this shape, then for jax to return correct results, while computing dn1 parameter should be WOI (Width - Output_channels - Input_channels) and not WIO (Width - Input_channels - Output_channels). After that:

result.strides = (40, 8, 4)
result.shape = (1, 5, 2)
result: 
[[[ -5.,   5.],
  [ -8.,   8.],
  [-11.,  11.],
  [ -4.,   4.],
  [ -5.,   5.]]]

Results not same as with tensorflow, but kernels for jax were flipped, so actually that was expected.

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文