给定一个(5,2)张量,删除行在第二列中具有重复的行

发布于 2025-01-24 07:50:22 字数 186 浏览 0 评论 0原文

因此,让我们假设我有这样的张量:

[[0,18],
[1,19],
[2, 3],
[3,19],
[4, 18]]

我需要仅使用TensorFlow删除第二列中包含重复的行。最终的输出应该是:

[[0,18],
[1,19],
[2, 3]]

So, let's assume I have a tensor like this:

[[0,18],
[1,19],
[2, 3],
[3,19],
[4, 18]]

I need to delete rows that contains duplicates in the second column only by using tensorflow. The final output should be this:

[[0,18],
[1,19],
[2, 3]]

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

只等公子 2025-01-31 07:50:22

您应该能够使用tf.math.unsorted_segrent_mintf.gather

import tensorflow as tf

x = tf.constant([[0,18],
                 [1,19],
                 [2, 3],
                 [3,19],
                 [4, 18]])


y, idx = tf.unique(x[:, 1])
indices = tf.math.unsorted_segment_min(tf.range(tf.shape(x)[0]), idx, tf.shape(y)[0])
result = tf.gather(x, indices)
print(result)
tf.Tensor(
[[ 0 18]
 [ 1 19]
 [ 2  3]], shape=(3, 2), dtype=int32)

这是对调用tf.unique tf.unique <的解决方案的解决。 /code>:

You should be able to solve this with tf.math.unsorted_segment_min and tf.gather:

import tensorflow as tf

x = tf.constant([[0,18],
                 [1,19],
                 [2, 3],
                 [3,19],
                 [4, 18]])


y, idx = tf.unique(x[:, 1])
indices = tf.math.unsorted_segment_min(tf.range(tf.shape(x)[0]), idx, tf.shape(y)[0])
result = tf.gather(x, indices)
print(result)
tf.Tensor(
[[ 0 18]
 [ 1 19]
 [ 2  3]], shape=(3, 2), dtype=int32)

Here is a simple explanation to what is happening after calling tf.unique:

enter image description here

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文