如何将tensorflow中通过py_func自定义的操作用到keras中?
通过keras自定义层的时候,如果使用了tensorflow的py_func函数自定义的操作后,会导致在创建模型时提示AttributeError: 'NoneType' object has no attribute '_inbound_nodes'
。
测试代码是:
import tensorflow as tf
from keras import layers
from keras.layers import Layer
from keras.models import Model
class T(Layer):
def __init__(self, **kwargs):
super(T, self).__init__(**kwargs)
def call(self, inputs, **kwargs):
return tf.zeros(shape=(1, 1))
def compute_output_shape(self, input_shape):
return 1, 1
def direct_return(tensor1, tensor2):
return tensor1, tensor2
def main():
input1 = layers.Input(shape=(2, 2))
input2 = layers.Input(shape=(2, 2))
ret1, ret2 = tf.py_func(direct_return, [input1, input2], [tf.float32, tf.float32])
ret1.set_shape((2, 2))
ret2.set_shape((2, 2))
t1 = T()(ret1)
t2 = T()(ret2)
model = Model(inputs=[input1, input2], outputs=[t1, t2])
main()
报错为:
Using TensorFlow backend.
Traceback (most recent call last):
File "XXX/test.py", line 36, in <module>
main()
File "XXX/test.py", line 33, in main
model = Model(inputs=[input1, input2], outputs=[t1, t2])
File "……\conda\envs\tensorflow\lib\site-packages\keras\legacy\interfaces.py", line 91, in wrapper
return func(*args, **kwargs)
File "……\conda\envs\tensorflow\lib\site-packages\keras\engine\network.py", line 93, in __init__
self._init_graph_network(*args, **kwargs)
File "……\conda\envs\tensorflow\lib\site-packages\keras\engine\network.py", line 237, in _init_graph_network
self.inputs, self.outputs)
File "……\conda\envs\tensorflow\lib\site-packages\keras\engine\network.py", line 1353, in _map_graph_network
tensor_index=tensor_index)
File "……\conda\envs\tensorflow\lib\site-packages\keras\engine\network.py", line 1340, in build_map
node_index, tensor_index)
File "……\conda\envs\tensorflow\lib\site-packages\keras\engine\network.py", line 1312, in build_map
node = layer._inbound_nodes[node_index]
AttributeError: 'NoneType' object has no attribute '_inbound_nodes'
如果将t1
和t2
的两行修改为:
t1 = T()(input1)
t2 = T()(input2)
则一切正常。
那么请问应该如何使keras自定义的层包含自定义的操作呢?
同样遇到这个问题,想在keras调用tf的py_func 如果解决了,能加我q775301251 交流下吗?谢谢您!