Modify the next i tensor values each time the value 1 appears in the tensor

Posted on 2025-02-13 23:20:33


I have two tensors with the same size:

a = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b = [0, 1, 1, 1, 1,  1,  1,  1,  0,  1,  1,  1,  1,  1,  0,  0,  0,  1]

Tensor a has three regions, delimited by runs of consecutive values: region 1 is [1, 2, 3, 4, 5], region 2 is [10, 11, 12, 13] and region 3 is [20, 21, 22, 23, 24, 25, 26, 27, 28].

For each of those regions, I want to apply the following logic: if one of the values of b is 1, the following i values are set to 0 (values that are already 0 stay 0). After those i values have been changed, nothing happens until another value of b is 1; in that case, the next i values are forced to 0 again, and so on.

Some examples:

# i = 1

a     = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b_new = [0, 1, 0, 1, 0,  1,  0,  1,  0,  1,  0,  1,  0,  1,  0,  0,  0,  1]


# i = 2

a     = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b_new = [0, 1, 0, 0, 1,  1,  0,  0,  0,  1,  0,  0,  1,  0,  0,  0,  0,  1]


# i = 4

a     = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b_new = [0, 1, 0, 0, 0,  1,  0,  0,  0,  1,  0,  0,  0,  0,  0,  0,  0,  1]
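As a plain-Python reference for the rule and the examples above (a sketch, not part of the question; the name zero_next_i is hypothetical), one pass with a countdown counter appears to be enough: the counter is set to i whenever a surviving 1 is seen, each following position is forced to 0 while the counter is positive, and the counter is cleared at every region boundary so zeroing never crosses into the next region.

```python
def zero_next_i(a, b, i):
    """Keep a 1 from b, then force the next i values of the same region to 0.

    Hypothetical reference helper: a region is a run of consecutive
    integers in a, and the countdown never carries across regions.
    """
    out = []
    remaining = 0  # how many more positions must still be forced to 0
    for idx, (value, flag) in enumerate(zip(a, b)):
        if idx > 0 and value != a[idx - 1] + 1:
            remaining = 0  # new region: stop zeroing
        if remaining > 0:
            out.append(0)
            remaining -= 1
        elif flag == 1:
            out.append(1)  # surviving 1 opens a new zeroing window
            remaining = i
        else:
            out.append(0)
    return out


a = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b = [0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1]
print(zero_next_i(a, b, 2))
# [0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]
```

Because each position is visited once, this runs in O(len(a)) regardless of i.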

Not sure if this would help, but I was able to separate the regions into segments by doing:

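import tensorflow as tf

a = tf.constant([1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28])
a_shifted = tf.roll(a - 1, shift=-1, axis=0)
a_shifted_segs = tf.math.cumsum(tf.cast(a_shifted != a, dtype=tf.int64), exclusive=True)

# a_shifted_segs = [0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2]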
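The tf.roll trick works because exclusive=True drops the flag at the last position, which is exactly the one affected by the wrap-around (the last element ends up compared with a[0] - 1). The same segmentation can be written without any wrap-around by comparing each element only with its predecessor. The sketch below uses NumPy, which is an assumption since the question is about TensorFlow; in TensorFlow, a[1:] - a[:-1] would play the role of np.diff.

```python
import numpy as np

a = np.array([1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28])

# breaks[k] is True when a[k + 1] does not continue the run that ends at a[k]
breaks = np.diff(a) != 1
# Element 0 opens segment 0; every break opens the next segment
segment_ids = np.concatenate([[0], np.cumsum(breaks)])
print(segment_ids)
# [0 0 0 0 0 1 1 1 1 2 2 2 2 2 2 2 2 2]
```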

Do you know any way of doing this efficiently?


仙女 2025-02-20 23:20:33


Here is a TensorFlow solution based on tf.scan. I know the conditionals are a bit complicated; if you have suggestions on how to simplify them, I'm open to them. However, once you know how to read the conditionals, it should be quite clear what the code does.

Here, the variable i tells us, for each position in the array, how many more b values have to be overwritten with 0.

import tensorflow as tf 

a = tf.constant([1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28])
b = tf.constant([0, 1, 1, 1, 1,  1,  1,  1,  0,  1,  1,  1,  1,  1,  0,  0,  0,  1])

# Extract switches inside a
switches = tf.scan(
    lambda e, new_a: {'a': new_a, 'out': new_a != (e['a']+1)}, 
    a, 
    initializer={'a': tf.reduce_min(a)-2, 'out': tf.constant(False)}
)['out']

# Define inputs for the scan iterations
initializer = {'b': tf.constant(False), 'i': tf.constant(0)}
elems = {'switches': switches, 'b': tf.cast(b, dtype=tf.bool)}

@tf.function
def step(last_out, new_in, max_i):
    new_i = tf.cond(
        last_out['i'] > 0, # If we are currently overwriting with 0
        lambda: tf.cond(
            new_in['switches'], # Is there a segment switch?
            lambda: tf.cond( # if switches:
                new_in['b'], # Check if b == 1
                lambda: tf.constant(max_i), # if b == 1: i = max_i
                lambda: tf.constant(0) # if b == 0: i = 0
            ),
            lambda: tf.maximum(last_out['i']-1, 0) # If no switch, decrement i
        ),
        lambda: tf.cond( # if we are currently not overwriting with 0
            new_in['b'], # check if b == 1
            lambda: tf.constant(max_i), # if b == 1: i = max_i
            lambda: tf.constant(0) # if b == 0: i = 0
        )
    )
    b = tf.cond(
        tf.equal(new_i, max_i), # Have we just reset i ?
        lambda: tf.constant(True), # If yes, we want to write b = 1
        lambda: tf.constant(False) # Otherwise, we write b = 0
    )
    
    return {'b': b, 'i': new_i}

Examples:

outp_1 = tf.scan(lambda _last, _inp: step(_last, _inp, max_i=1), elems=elems, initializer=initializer)
print( tf.cast(outp_1['b'], tf.int32) )
# tf.Tensor([0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 0 0 1], shape=(18,), dtype=int32)

outp_2 = tf.scan(lambda _last, _inp: step(_last, _inp, max_i=2), elems=elems, initializer=initializer)
print( tf.cast(outp_2['b'], tf.int32) )
# tf.Tensor([0 1 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 1], shape=(18,), dtype=int32)

outp_4 = tf.scan(lambda _last, _inp: step(_last, _inp, max_i=4), elems=elems, initializer=initializer)
print( tf.cast(outp_4['b'], tf.int32) )
# tf.Tensor([0 1 0 0 0 1 0 0 0 1 0 0 0 0 0 0 0 1], shape=(18,), dtype=int32)
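On simplifying the conditionals: the nested tf.cond calls in step can probably be collapsed into a single reset flag. b only needs to be re-examined when no zeroing is in progress or a new segment starts; otherwise the counter just counts down. The plain-Python sketch below mirrors that logic (step_py and run_scan are hypothetical names, and the for loop stands in for tf.scan). In TensorFlow the same step could presumably be expressed with tf.logical_or, tf.where and tf.logical_and instead of nested tf.cond, but that variant is not shown or tested here.

```python
def step_py(last_i, switch, b, max_i):
    """One scan step: returns (output bit, new countdown value)."""
    # Re-examine b only when no zeroing is in progress or a new segment starts
    reset = switch or last_i == 0
    if reset:
        new_i = max_i if b else 0
    else:
        new_i = last_i - 1  # still inside a zeroed window, so the output is 0
    # A 1 is written exactly when a fresh window is opened
    return reset and b, new_i


def run_scan(a, b, max_i):
    """Sequential stand-in for tf.scan over (switches, b)."""
    out, last_i = [], 0
    for idx, (value, flag) in enumerate(zip(a, b)):
        switch = idx == 0 or value != a[idx - 1] + 1  # segment boundary
        keep, last_i = step_py(last_i, switch, bool(flag), max_i)
        out.append(int(keep))
    return out


a = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b = [0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1]
print(run_scan(a, b, 2))
# [0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]
```

Returning reset and b as the output, instead of testing new_i == max_i, gives the same result for max_i >= 1 and also leaves b unchanged for max_i = 0 (where the equality test would mark every position).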

This answer is sponsored by lambda.

遇见了你 2025-02-20 23:20:33


Here is a pure TensorFlow approach that works in both eager execution and graph mode:

# copy, paste, acknowledge

import tensorflow as tf

def split_regions_and_modify(a, b, i):
  indices = tf.squeeze(tf.where(a[:-1] != a[1:] - 1), axis=-1) + 1
  # Region end offsets: the start index of every later region, followed by len(a)
  row_splits = tf.cast(tf.concat([indices, tf.shape(a, out_type=tf.int64)], axis=0), dtype=tf.int32)

  def body(i, j, k, tensor, row_splits):
    k = tf.cond(tf.equal(row_splits[k], j), lambda: tf.add(k, 1), lambda: k)
    current_indices = tf.range(j + 1, tf.minimum(j + 1 + i, row_splits[k]), dtype=tf.int32)

    tensor = tf.cond(tf.logical_and(tf.equal(tensor[j], 1), tf.not_equal(j,  row_splits[k])), lambda: 
                  tf.tensor_scatter_nd_update(tensor, current_indices[..., None], tf.zeros_like(current_indices)), lambda: tensor)
    return i, tf.add(j, 1), k, tensor, row_splits 

  j0 = tf.constant(0)
  k0 = tf.constant(0)
  c = lambda i, j0, k0, b, row_splits: tf.logical_and(tf.less(j0, tf.shape(b)[0]), tf.less(k0, tf.shape(row_splits)[0]))
  _, _, _, output, _ = tf.while_loop(c, body, loop_vars=[i, j0, k0, b, row_splits])
  return output

Usage:

a = tf.constant([1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28])
b = tf.constant([0, 1, 1, 1, 1,  1,  1,  1,  0,  1,  1,  1,  1,  1,  0,  0,  0,  1])

split_regions_and_modify(a, b, 1)
# <tf.Tensor: shape=(18,), dtype=int32, numpy=array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1], dtype=int32)>

split_regions_and_modify(a, b, 2)
# <tf.Tensor: shape=(18,), dtype=int32, numpy=array([0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1], dtype=int32)>

split_regions_and_modify(a, b, 4)
# <tf.Tensor: shape=(18,), dtype=int32, numpy=array([0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1], dtype=int32)>
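To make the tf.while_loop easier to follow, here is a plain-Python transliteration of what it appears to do (an illustrative sketch; split_regions_and_modify_py is a hypothetical name). row_splits holds the exclusive end index of each region, [5, 9, 18] for the example, k points at the end of the region that contains j, and every surviving 1 zeroes up to i following values without crossing that end. Because zeroing happens in place, a 1 that was already overwritten never triggers another window.

```python
def split_regions_and_modify_py(a, b, i):
    """Plain-Python mirror of the tf.while_loop above (illustrative sketch)."""
    n = len(a)
    # Region ends: every index where a run of consecutive values breaks, plus n
    row_splits = [j for j in range(1, n) if a[j] != a[j - 1] + 1] + [n]
    tensor = list(b)
    k = 0  # index into row_splits of the current region's end
    for j in range(n):
        if row_splits[k] == j:  # j is the first element of the next region
            k += 1
        if tensor[j] == 1:
            # Zero up to i following values without crossing the region end
            for z in range(j + 1, min(j + 1 + i, row_splits[k])):
                tensor[z] = 0
    return tensor


a = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b = [0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1]
print(split_regions_and_modify_py(a, b, 4))
# [0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1]
```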
﹉夏雨初晴づ 2025-02-20 23:20:33


If I understand correctly, within every section defined by list a you want to keep the first 1 you meet in b, zero out the next i elements of that section in b, then check the remaining elements for another 1 and apply the same logic (zero out the next i elements), and then move on to the next section. One way to implement it:

def get_new_b(a, b, i):
    sect_idx = []
    start_idx = 0
    new_b = b.copy()
    for idx in range(1, len(a)):  # Find sections of array a
        if (a[idx] - a[idx-1]) != 1:
            sect_idx.append([start_idx, idx])
            start_idx = idx
    sect_idx.append([start_idx, len(a)])  # Close the last section (end is exclusive)

    for sect_start, sect_stop in sect_idx:
        for b_idx in range(sect_start, sect_stop):
            if new_b[b_idx] == 1:
                for b_zer in range(b_idx + 1, min(b_idx + 1 + i, sect_stop)):
                    new_b[b_zer] = 0
    return new_b

For:

a = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]

b = [0, 1, 1, 1, 1,  1,  1,  1,  0,  1,  1,  1,  1,  1,  0,  0,  0, 1]

The results would be:

print(get_new_b(a=a, b=b, i=1))
print(get_new_b(a=a, b=b, i=2))
print(get_new_b(a=a, b=b, i=4))

>>> [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1]
>>> [0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]
>>> [0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1]
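Since the three answers implement the rule in quite different ways, a small property check can help compare their outputs on other inputs. The helper below (hypothetical, plain Python) encodes the rule as stated in the question: a 1 survives at position j exactly when b[j] is 1 and no surviving 1 lies in the previous i positions of the same region. It returns True only if the candidate output matches that definition everywhere.

```python
def satisfies_rule(a, b, out, i):
    """Return True if out is exactly what the zeroing rule produces from b."""
    if not (len(a) == len(b) == len(out)):
        return False
    region_start = 0
    for j in range(len(a)):
        if j > 0 and a[j] != a[j - 1] + 1:
            region_start = j  # a new run of consecutive values starts here
        # Surviving values in the previous i positions of the same region
        window = out[max(region_start, j - i):j]
        expected = 1 if (b[j] == 1 and 1 not in window) else 0
        if out[j] != expected:
            return False
    return True


a = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b = [0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1]
expected_2 = [0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]
print(satisfies_rule(a, b, expected_2, 2))  # True
print(satisfies_rule(a, b, b, 2))  # False: b itself keeps adjacent 1s
```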