如何将tensorflow中通过py_func自定义的操作用到keras中?

发布于 09-07 21:14 字数 2279 浏览 24 评论 0

通过keras自定义层的时候,如果使用了tensorflow的py_func函数自定义的操作后,会导致在创建模型时提示AttributeError: 'NoneType' object has no attribute '_inbound_nodes'

测试代码是:

import tensorflow as tf
from keras import layers
from keras.layers import Layer
from keras.models import Model


class T(Layer):
    def __init__(self, **kwargs):
        super(T, self).__init__(**kwargs)

    def call(self, inputs, **kwargs):
        return tf.zeros(shape=(1, 1))

    def compute_output_shape(self, input_shape):
        return 1, 1


def direct_return(tensor1, tensor2):
    return tensor1, tensor2


def main():
    input1 = layers.Input(shape=(2, 2))
    input2 = layers.Input(shape=(2, 2))

    ret1, ret2 = tf.py_func(direct_return, [input1, input2], [tf.float32, tf.float32])
    ret1.set_shape((2, 2))
    ret2.set_shape((2, 2))

    t1 = T()(ret1)
    t2 = T()(ret2)

    model = Model(inputs=[input1, input2], outputs=[t1, t2])


main()

报错为:

Using TensorFlow backend.
Traceback (most recent call last):
  File "XXX/test.py", line 36, in <module>
    main()
  File "XXX/test.py", line 33, in main
    model = Model(inputs=[input1, input2], outputs=[t1, t2])
  File "……\conda\envs\tensorflow\lib\site-packages\keras\legacy\interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "……\conda\envs\tensorflow\lib\site-packages\keras\engine\network.py", line 93, in __init__
    self._init_graph_network(*args, **kwargs)
  File "……\conda\envs\tensorflow\lib\site-packages\keras\engine\network.py", line 237, in _init_graph_network
    self.inputs, self.outputs)
  File "……\conda\envs\tensorflow\lib\site-packages\keras\engine\network.py", line 1353, in _map_graph_network
    tensor_index=tensor_index)
  File "……\conda\envs\tensorflow\lib\site-packages\keras\engine\network.py", line 1340, in build_map
    node_index, tensor_index)
  File "……\conda\envs\tensorflow\lib\site-packages\keras\engine\network.py", line 1312, in build_map
    node = layer._inbound_nodes[node_index]
AttributeError: 'NoneType' object has no attribute '_inbound_nodes'

如果将t1t2的两行修改为:

    t1 = T()(input1)
    t2 = T()(input2)

则一切正常。

那么请问应该如何使keras自定义的层包含自定义的操作呢?

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

温柔戏命师2022-09-14 21:14:36

同样遇到这个问题,想在keras调用tf的py_func 如果解决了,能加我q775301251 交流下吗?谢谢您!

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文