TensorFlow:使用并行通道模型的错误

发布于 2025-02-06 19:18:08 字数 7828 浏览 0 评论 0原文

我正在调试自己创建的一个模型,该模型可接受可变数量的输入通道(每个通道都是一张 RGB 图像)。我怀疑并非所有通道都被正确连接。

模型代码:

# Shape of one RGB image as fed to each per-channel sub-model;
# presumably (height, width, color channels), i.e. channels-last — confirm.
IMG_SHAPE = (160, 160, 3)

def get_ch_model_simple():
    """Build the feature extractor applied to a single image channel.

    The returned functional model takes one image of shape ``IMG_SHAPE`` and
    applies, in order: pixel rescaling by 1/255, a 32-filter 3x3 convolution
    with ReLU activation, and 2x2 max pooling.

    Returns:
        A new ``tf.keras.Model`` with freshly created layers on every call.
    """
    image_in = tf.keras.Input(shape=IMG_SHAPE)

    # Layers are listed in the order they are applied to the image.
    pipeline = (
        tf.keras.layers.Rescaling(1.0 / 255),  # scale pixels to float
        tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    )

    features = image_in
    for layer in pipeline:
        features = layer(features)

    return tf.keras.Model(image_in, features)
    
def get_model(n_chan=2):
    """Build a classifier over ``n_chan`` parallel RGB image channels.

    The model takes a single tensor of shape ``(batch, n_chan, *IMG_SHAPE)``.
    Each channel is sliced out along axis 1, run through its own
    ``get_ch_model_simple()`` sub-model, flattened and passed through
    dropout. The per-channel features are then concatenated and fed to a
    2-unit softmax layer.

    Args:
        n_chan: Number of image channels stacked along axis 1 (at least 1).

    Returns:
        An uncompiled ``tf.keras.Model`` mapping the stacked input to a
        ``(batch, 2)`` softmax output.

    Raises:
        ValueError: If ``n_chan`` is less than 1.
    """
    if n_chan < 1:
        raise ValueError(f"n_chan must be at least 1, got {n_chan}")

    # Derive the per-channel shape from IMG_SHAPE so this input always
    # matches the input expected by get_ch_model_simple().
    inputs = tf.keras.Input(shape=(n_chan, *IMG_SHAPE))

    ch_features = []
    for ch in range(n_chan):
        # A fresh sub-model per channel: weights are not shared.
        ch_model = get_ch_model_simple()
        # Every slice op takes the same full input tensor; the integer index
        # `ch` on axis 1 is what selects a different channel for each branch.
        ch_model_input = inputs[:, ch, :, :, :]
        i_ch_features = tf.keras.layers.Flatten()(ch_model(ch_model_input))
        i_ch_features = tf.keras.layers.Dropout(0.5)(i_ch_features)
        ch_features.append(i_ch_features)

    all_ch_features = tf.keras.layers.concatenate(ch_features)
    outputs = tf.keras.layers.Dense(2, activation="softmax")(all_ch_features)

    return tf.keras.Model(inputs, outputs)

检查模型:

# Build the model with two image channels and print its layer-by-layer summary.
m = get_model(n_chan=2)
m.summary()

摘要输出:

Model: "model_49"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_25 (InputLayer)          [(None, 2, 160, 160  0           []                               
                                , 3)]                                                             
                                                                                                  
 tf.__operators__.getitem_16 (S  (None, 160, 160, 3)  0          ['input_25[0][0]']               
 licingOpLambda)                                                                                  
                                                                                                  
 tf.__operators__.getitem_17 (S  (None, 160, 160, 3)  0          ['input_25[0][0]']               
 licingOpLambda)                                                                                  
                                                                                                  
 model_47 (Functional)          (None, 79, 79, 32)   896         ['tf.__operators__.getitem_16[0][
                                                                 0]']                             
                                                                                                  
 model_48 (Functional)          (None, 79, 79, 32)   896         ['tf.__operators__.getitem_17[0][
                                                                 0]']                             
                                                                                                  
 flatten_19 (Flatten)           (None, 199712)       0           ['model_47[0][0]']               
                                                                                                  
 flatten_20 (Flatten)           (None, 199712)       0           ['model_48[0][0]']               
                                                                                                  
 dropout_19 (Dropout)           (None, 199712)       0           ['flatten_19[0][0]']             
                                                                                                  
 dropout_20 (Dropout)           (None, 199712)       0           ['flatten_20[0][0]']             
                                                                                                  
 concatenate_8 (Concatenate)    (None, 399424)       0           ['dropout_19[0][0]',             
                                                                  'dropout_20[0][0]']             
                                                                                                  
 dense_10 (Dense)               (None, 2)            798850      ['concatenate_8[0][0]']          
                                                                                                  
==================================================================================================
Total params: 800,642
Trainable params: 800,642
Non-trainable params: 0
__________________________________________________________________________________________________

我担心的是,这 2 个切片算子都连接到 input_25[0][0],它们似乎取到的是同一个通道切片,而不是各自取到不同的通道。

另外,如果我尝试创建一个子模型来检查 ch_model 的输入和输出,就会报错:

r_m = tf.keras.Model(model.inputs, model.layers[3].input) 可以正常工作

但是 r_m2 = tf.keras.Model(model.inputs, model.layers[3].output) 会报错:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [175], in <cell line: 1>()
----> 1 r_m2 = tf.keras.Model(model.inputs, model.layers[3].output)

File ~/PycharmProjects/venv39/lib/python3.9/site-packages/tensorflow/python/training/tracking/base.py:629, in no_automatic_dependency_tracking.<locals>._method_wrapper(self, *args, **kwargs)
    627 self._self_setattr_tracking = False  # pylint: disable=protected-access
    628 try:
--> 629   result = method(self, *args, **kwargs)
    630 finally:
    631   self._self_setattr_tracking = previous_value  # pylint: disable=protected-access

File ~/PycharmProjects/venv39/lib/python3.9/site-packages/keras/engine/functional.py:146, in Functional.__init__(self, inputs, outputs, name, trainable, **kwargs)
    143   if not all([functional_utils.is_input_keras_tensor(t)
    144               for t in tf.nest.flatten(inputs)]):
    145     inputs, outputs = functional_utils.clone_graph_nodes(inputs, outputs)
--> 146 self._init_graph_network(inputs, outputs)

File ~/PycharmProjects/venv39/lib/python3.9/site-packages/tensorflow/python/training/tracking/base.py:629, in no_automatic_dependency_tracking.<locals>._method_wrapper(self, *args, **kwargs)
    627 self._self_setattr_tracking = False  # pylint: disable=protected-access
    628 try:
--> 629   result = method(self, *args, **kwargs)
    630 finally:
    631   self._self_setattr_tracking = previous_value  # pylint: disable=protected-access

File ~/PycharmProjects/venv39/lib/python3.9/site-packages/keras/engine/functional.py:229, in Functional._init_graph_network(self, inputs, outputs)
    226   self._input_coordinates.append((layer, node_index, tensor_index))
    228 # Keep track of the network's nodes and layers.
--> 229 nodes, nodes_by_depth, layers, _ = _map_graph_network(
    230     self.inputs, self.outputs)
    231 self._network_nodes = nodes
    232 self._nodes_by_depth = nodes_by_depth

File ~/PycharmProjects/venv39/lib/python3.9/site-packages/keras/engine/functional.py:1036, in _map_graph_network(inputs, outputs)
   1034 for x in tf.nest.flatten(node.keras_inputs):
   1035   if id(x) not in computable_tensors:
-> 1036     raise ValueError(
   1037         f'Graph disconnected: cannot obtain value for tensor {x} '
   1038         f'at layer "{layer.name}". The following previous layers '
   1039         f'were accessed without issue: {layers_with_complete_input}')
   1040 for x in tf.nest.flatten(node.outputs):
   1041   computable_tensors.add(id(x))

ValueError: Graph disconnected: cannot obtain value for tensor KerasTensor(type_spec=TensorSpec(shape=(None, 160, 160, 3), dtype=tf.float32, name='input_8'), name='input_8', description="created by layer 'input_8'") at layer "rescaling_4". The following previous layers were accessed without issue: []

I'm debugging a model I created to accept a variable number of input channels (each channel is an RGB image). I suspect that not all the channels are properly connected.

Model code:

# Shape of one RGB image as fed to each per-channel sub-model;
# presumably (height, width, color channels), i.e. channels-last — confirm.
IMG_SHAPE = (160, 160, 3)

def get_ch_model_simple():
    """Build the feature extractor applied to a single image channel.

    The returned functional model takes one image of shape ``IMG_SHAPE`` and
    applies, in order: pixel rescaling by 1/255, a 32-filter 3x3 convolution
    with ReLU activation, and 2x2 max pooling.

    Returns:
        A new ``tf.keras.Model`` with freshly created layers on every call.
    """
    image_in = tf.keras.Input(shape=IMG_SHAPE)

    # Layers are listed in the order they are applied to the image.
    pipeline = (
        tf.keras.layers.Rescaling(1.0 / 255),  # scale pixels to float
        tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    )

    features = image_in
    for layer in pipeline:
        features = layer(features)

    return tf.keras.Model(image_in, features)
    
def get_model(n_chan=2):
    """Build a classifier over ``n_chan`` parallel RGB image channels.

    The model takes a single tensor of shape ``(batch, n_chan, *IMG_SHAPE)``.
    Each channel is sliced out along axis 1, run through its own
    ``get_ch_model_simple()`` sub-model, flattened and passed through
    dropout. The per-channel features are then concatenated and fed to a
    2-unit softmax layer.

    Args:
        n_chan: Number of image channels stacked along axis 1 (at least 1).

    Returns:
        An uncompiled ``tf.keras.Model`` mapping the stacked input to a
        ``(batch, 2)`` softmax output.

    Raises:
        ValueError: If ``n_chan`` is less than 1.
    """
    if n_chan < 1:
        raise ValueError(f"n_chan must be at least 1, got {n_chan}")

    # Derive the per-channel shape from IMG_SHAPE so this input always
    # matches the input expected by get_ch_model_simple().
    inputs = tf.keras.Input(shape=(n_chan, *IMG_SHAPE))

    ch_features = []
    for ch in range(n_chan):
        # A fresh sub-model per channel: weights are not shared.
        ch_model = get_ch_model_simple()
        # Every slice op takes the same full input tensor; the integer index
        # `ch` on axis 1 is what selects a different channel for each branch.
        ch_model_input = inputs[:, ch, :, :, :]
        i_ch_features = tf.keras.layers.Flatten()(ch_model(ch_model_input))
        i_ch_features = tf.keras.layers.Dropout(0.5)(i_ch_features)
        ch_features.append(i_ch_features)

    all_ch_features = tf.keras.layers.concatenate(ch_features)
    outputs = tf.keras.layers.Dense(2, activation="softmax")(all_ch_features)

    return tf.keras.Model(inputs, outputs)

Checking the model:

# Build the model with two image channels and print its layer-by-layer summary.
m = get_model(n_chan=2)
m.summary()

Summary output:

Model: "model_49"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_25 (InputLayer)          [(None, 2, 160, 160  0           []                               
                                , 3)]                                                             
                                                                                                  
 tf.__operators__.getitem_16 (S  (None, 160, 160, 3)  0          ['input_25[0][0]']               
 licingOpLambda)                                                                                  
                                                                                                  
 tf.__operators__.getitem_17 (S  (None, 160, 160, 3)  0          ['input_25[0][0]']               
 licingOpLambda)                                                                                  
                                                                                                  
 model_47 (Functional)          (None, 79, 79, 32)   896         ['tf.__operators__.getitem_16[0][
                                                                 0]']                             
                                                                                                  
 model_48 (Functional)          (None, 79, 79, 32)   896         ['tf.__operators__.getitem_17[0][
                                                                 0]']                             
                                                                                                  
 flatten_19 (Flatten)           (None, 199712)       0           ['model_47[0][0]']               
                                                                                                  
 flatten_20 (Flatten)           (None, 199712)       0           ['model_48[0][0]']               
                                                                                                  
 dropout_19 (Dropout)           (None, 199712)       0           ['flatten_19[0][0]']             
                                                                                                  
 dropout_20 (Dropout)           (None, 199712)       0           ['flatten_20[0][0]']             
                                                                                                  
 concatenate_8 (Concatenate)    (None, 399424)       0           ['dropout_19[0][0]',             
                                                                  'dropout_20[0][0]']             
                                                                                                  
 dense_10 (Dense)               (None, 2)            798850      ['concatenate_8[0][0]']          
                                                                                                  
==================================================================================================
Total params: 800,642
Trainable params: 800,642
Non-trainable params: 0
__________________________________________________________________________________________________

I'm concerned that the 2 slicing operators are both connected to input_25[0][0]; they seem to be getting the same channel slice instead of each getting a different channel.

In addition, if I try to create a submodel to check the ch_model input and output, I get an error:

r_m = tf.keras.Model(model.inputs, model.layers[3].input) works fine

However, r_m2 = tf.keras.Model(model.inputs, model.layers[3].output) errors out:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [175], in <cell line: 1>()
----> 1 r_m2 = tf.keras.Model(model.inputs, model.layers[3].output)

File ~/PycharmProjects/venv39/lib/python3.9/site-packages/tensorflow/python/training/tracking/base.py:629, in no_automatic_dependency_tracking.<locals>._method_wrapper(self, *args, **kwargs)
    627 self._self_setattr_tracking = False  # pylint: disable=protected-access
    628 try:
--> 629   result = method(self, *args, **kwargs)
    630 finally:
    631   self._self_setattr_tracking = previous_value  # pylint: disable=protected-access

File ~/PycharmProjects/venv39/lib/python3.9/site-packages/keras/engine/functional.py:146, in Functional.__init__(self, inputs, outputs, name, trainable, **kwargs)
    143   if not all([functional_utils.is_input_keras_tensor(t)
    144               for t in tf.nest.flatten(inputs)]):
    145     inputs, outputs = functional_utils.clone_graph_nodes(inputs, outputs)
--> 146 self._init_graph_network(inputs, outputs)

File ~/PycharmProjects/venv39/lib/python3.9/site-packages/tensorflow/python/training/tracking/base.py:629, in no_automatic_dependency_tracking.<locals>._method_wrapper(self, *args, **kwargs)
    627 self._self_setattr_tracking = False  # pylint: disable=protected-access
    628 try:
--> 629   result = method(self, *args, **kwargs)
    630 finally:
    631   self._self_setattr_tracking = previous_value  # pylint: disable=protected-access

File ~/PycharmProjects/venv39/lib/python3.9/site-packages/keras/engine/functional.py:229, in Functional._init_graph_network(self, inputs, outputs)
    226   self._input_coordinates.append((layer, node_index, tensor_index))
    228 # Keep track of the network's nodes and layers.
--> 229 nodes, nodes_by_depth, layers, _ = _map_graph_network(
    230     self.inputs, self.outputs)
    231 self._network_nodes = nodes
    232 self._nodes_by_depth = nodes_by_depth

File ~/PycharmProjects/venv39/lib/python3.9/site-packages/keras/engine/functional.py:1036, in _map_graph_network(inputs, outputs)
   1034 for x in tf.nest.flatten(node.keras_inputs):
   1035   if id(x) not in computable_tensors:
-> 1036     raise ValueError(
   1037         f'Graph disconnected: cannot obtain value for tensor {x} '
   1038         f'at layer "{layer.name}". The following previous layers '
   1039         f'were accessed without issue: {layers_with_complete_input}')
   1040 for x in tf.nest.flatten(node.outputs):
   1041   computable_tensors.add(id(x))

ValueError: Graph disconnected: cannot obtain value for tensor KerasTensor(type_spec=TensorSpec(shape=(None, 160, 160, 3), dtype=tf.float32, name='input_8'), name='input_8', description="created by layer 'input_8'") at layer "rescaling_4". The following previous layers were accessed without issue: []

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文