How to feed two inputs to a TensorFlow model using ragged tensors?

Posted on 2025-01-23 22:14:31

I am trying to create a model with two inputs. The model is very simple: it contains only one LSTM layer per input. The problem is that I want to feed lists of different lengths as inputs. To do that I am using ragged tensors, but training fails.

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.initializers import glorot_uniform

# Toy data: in each row, col_1 and col_2 hold lists of different lengths.
ds = pd.DataFrame({"col_1":[[0],[0,0],[0,0,0],[0,0,0,0],[0,0,0,0,0],[0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0]],"col_2":[8*[0],7*[1],6*[2],5*[3],4*[4],3*[5],2*[6],1*[7]]})
# Repeat each row 1250 times, then shuffle.
ds = ds.loc[ds.index.repeat(1250)].reset_index(drop=True)
ds = ds.sample(frac=1, random_state=43).reset_index(drop=True)

# One ragged input and one LSTM per column.
feat_1_inputs = [tf.keras.layers.Input(batch_shape=(None,None,1),ragged=True,name="col_1")]
feat_1 = tf.keras.layers.LSTM(10, return_sequences=True, return_state=False, stateful=False)(feat_1_inputs[0])

feat_2_inputs = [tf.keras.layers.Input(batch_shape=(None,None,1),ragged=True,name="col_2")]
feat_2 = tf.keras.layers.LSTM(10, return_sequences=True, return_state=False, stateful=False)(feat_2_inputs[0])

# Merge the two LSTM outputs (the traceback below points at this layer), then classify.
concat_inputs = tf.keras.layers.Concatenate()([feat_1, feat_2])
output = tf.keras.layers.Dense(10, activation='relu',kernel_initializer=glorot_uniform())(concat_inputs)
output = tf.keras.layers.Dense(10, kernel_initializer=glorot_uniform())(output)
output = tf.keras.layers.Activation(activation='softmax', dtype='float32')(output)

model = tf.keras.Model(feat_1_inputs + feat_2_inputs, output)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss=tf.keras.losses.sparse_categorical_crossentropy)

# Build ragged tensors of shape [batch, (time), 1] from the two columns.
col_1_data = [tf.expand_dims(tf.ragged.constant(ds['col_1'].values,dtype=np.int64),axis=-1)]
col_2_data = tf.expand_dims(tf.ragged.constant(ds['col_2'].values,dtype=np.int64),axis=-1)
col_1_data.append(col_2_data)  # col_1_data is now the list [col_1, col_2]

model.fit(x=col_1_data,y=col_2_data,epochs=10)

Error:

Epoch 1/10
Traceback (most recent call last):
  File "/home/user/.config/JetBrains/PyCharmCE2021.2/scratches/scratch_19.py", line 33, in <module>
    model.fit(x=col_1_data,y=col_2_data,epochs=10)
  File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/tensorflow/python/eager/execute.py", line 54, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:

Detected at node 'model/concatenate/RaggedConcat/assert_equal_1/Assert/AssertGuard/Assert' defined at (most recent call last):
    File "/home/user/.config/JetBrains/PyCharmCE2021.2/scratches/scratch_19.py", line 33, in <module>
      model.fit(x=col_1_data,y=col_2_data,epochs=10)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/engine/training.py", line 1384, in fit
      tmp_logs = self.train_function(iterator)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/engine/training.py", line 1021, in train_function
      return step_function(self, iterator)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/engine/training.py", line 1010, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/engine/training.py", line 1000, in run_step
      outputs = model.train_step(data)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/engine/training.py", line 859, in train_step
      y_pred = self(x, training=True)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/engine/functional.py", line 451, in call
      return self._run_internal_graph(
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/engine/functional.py", line 589, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/layers/merge.py", line 183, in call
      return self._merge_function(inputs)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/layers/merge.py", line 531, in _merge_function
      return backend.concatenate(inputs, axis=self.axis)
    File "/home/user/miniconda3/envs/model/lib/python3.9/site-packages/keras/backend.py", line 3311, in concatenate
      return tf.concat(tensors, axis)
        Node: 'model/concatenate/RaggedConcat/assert_equal_1/Assert/AssertGuard/Assert'
        assertion failed: [Inputs must have identical ragged splits] [Condition x == y did not hold element-wise:] [x (model/lstm/RaggedFromTensor/concat:0) = ] [0 8 11...] [y (model/lstm_1/RaggedFromTensor/concat:0) = ] [0 1 7...]
             [[{{node model/concatenate/RaggedConcat/assert_equal_1/Assert/AssertGuard/Assert}}]] [Op:__inference_train_function_9256]
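
A note on reading the assertion at the end of the traceback: a ragged tensor stores its rows as one flat list of values plus `row_splits`, the offsets at which each row starts and ends. The two vectors printed in the message, `[0 8 11...]` and `[0 1 7...]`, are the row splits of the two LSTM outputs. Below is a minimal pure-Python sketch of how such offsets arise. The helper `row_splits` is hypothetical, written only for illustration (it is not a TensorFlow function), and the two sample rows are chosen because their lengths match the message, not because they are known to be the first rows of the shuffled dataset.

```python
from itertools import accumulate


def row_splits(rows):
    """Offsets where each row starts/ends in the flattened values (hypothetical helper)."""
    return [0] + list(accumulate(len(r) for r in rows))


# Two rows as they appear in the dataset: col_1 has lengths 8 and 3,
# while col_2 has lengths 1 and 6 for the same two rows.
col_1 = [[0] * 8, [0] * 3]
col_2 = [[7] * 1, [2] * 6]

splits_1 = row_splits(col_1)
splits_2 = row_splits(col_2)

print(splits_1)              # [0, 8, 11]
print(splits_2)              # [0, 1, 7]
print(splits_1 == splits_2)  # False: this is the condition the assertion checks
```

Because the sequence lengths differ between the two columns, the offsets differ too, which is what "Inputs must have identical ragged splits" reports.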

If the rows in both columns contain lists of the same length, it works fine.
Is there a way to work with lists of different lengths using ragged tensors?

I am using TensorFlow 2.8.
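
For reference, the restriction can be reproduced outside Keras with two small ragged tensors. The sketch below is not a confirmed fix for the model above; it only shows that `tf.concat` along the feature axis fails when the row lengths differ, and that the concatenation succeeds once the ragged time axis has been reduced away. `tf.reduce_mean` is used as a stand-in reduction, and whether a per-sequence summary (for example, presumably, an LSTM with `return_sequences=False`) suits the task is an assumption. The sample values are made up.

```python
import tensorflow as tf

# Two ragged batches shaped [batch, (time), 1]; row lengths differ between them.
a = tf.expand_dims(tf.ragged.constant([[0.0, 0.0, 0.0], [0.0]]), axis=-1)
b = tf.expand_dims(tf.ragged.constant([[1.0], [2.0, 2.0]]), axis=-1)

# Concatenating along the feature axis requires identical row splits.
try:
    tf.concat([a, b], axis=-1)
    concat_failed = False
except tf.errors.InvalidArgumentError:
    concat_failed = True

# Reducing over the ragged time axis yields dense [batch, 1] tensors,
# which can be concatenated regardless of the original sequence lengths.
a_vec = tf.reduce_mean(a, axis=1)
b_vec = tf.reduce_mean(b, axis=1)
merged = tf.concat([a_vec, b_vec], axis=-1)

print(concat_failed)
print(merged.shape)
```

With this approach the model would produce one prediction per sequence rather than one per time step, so the labels would need a matching shape.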
