How to implement a set lookup in TensorFlow?

Posted on 2025-01-18 04:50:43

While preprocessing a TensorFlow dataset, I need to check whether a value is contained in an immutable set. If it is not, I need to replace it with a default value. Essentially, this is about censoring/replacing certain outliers.

In Python I would do something like this:

def map_id(value):
  s = frozenset([1,2,3])
  if value in s:
    return value
  else:
    return 0 # default for all outliers

This map_id function is called like this:

def preprocess(item):
  return (map_id(item["investment_id"]), item["features"]), item["target"]

The preprocess function is called like this:

def make_dataset(file_paths, batch_size=4096, mode="train"):
  ds = tf.data.TFRecordDataset(file_paths)
  ds = ds.map(decode_function)
  ds = ds.map(preprocess)
  if mode == "train":
    ds = ds.shuffle(batch_size * 4)
  ds = ds.batch(batch_size).cache().prefetch(tf.data.AUTOTUNE)
  return ds

How can I write this map_id function in TensorFlow 2.x?


Comments (1)

风居住的梦幻卍 2025-01-25 04:50:43

I am not sure what your data looks like, but you should be able to use a simple tf.lookup.StaticHashTable as a set substitute for your use case, since it runs in graph mode:

import tensorflow as tf

data = {
    "investment_id": [1, 2, 3, 4, 5], 
    "features": [12, 912, 28, 90, 17],
    "target": [1, 0, 1, 1, 1]
}

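# Keys and values are identical: a hit returns the ID itself,
# a miss returns default_value (the replacement for outliers).
keys_tensor = tf.constant([1, 2, 3])
vals_tensor = tf.constant([1, 2, 3])
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor),
    default_value=0)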


ds = tf.data.Dataset.from_tensor_slices(data)
ds = ds.map(lambda item: (table.lookup(item['investment_id']), item['features'], item['target']))

for d in ds:
  print(d)

Output:

(<tf.Tensor: shape=(), dtype=int32, numpy=1>, <tf.Tensor: shape=(), dtype=int32, numpy=12>, <tf.Tensor: shape=(), dtype=int32, numpy=1>)
(<tf.Tensor: shape=(), dtype=int32, numpy=2>, <tf.Tensor: shape=(), dtype=int32, numpy=912>, <tf.Tensor: shape=(), dtype=int32, numpy=0>)
(<tf.Tensor: shape=(), dtype=int32, numpy=3>, <tf.Tensor: shape=(), dtype=int32, numpy=28>, <tf.Tensor: shape=(), dtype=int32, numpy=1>)
(<tf.Tensor: shape=(), dtype=int32, numpy=0>, <tf.Tensor: shape=(), dtype=int32, numpy=90>, <tf.Tensor: shape=(), dtype=int32, numpy=1>)
(<tf.Tensor: shape=(), dtype=int32, numpy=0>, <tf.Tensor: shape=(), dtype=int32, numpy=17>, <tf.Tensor: shape=(), dtype=int32, numpy=1>)
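
To connect this back to the question, here is a minimal sketch of how the table could be wrapped in a map_id function with the same nested output structure as the original preprocess. The names ALLOWED_IDS, DEFAULT_ID and map_id_where are made up for illustration, and the cast to int64 is an assumption (integer features decoded from TFRecords are typically int64; adjust it to your actual dtype). map_id_where is an alternative that avoids the table by comparing against the allowed IDs directly, which may be good enough for small sets.

```python
import tensorflow as tf

ALLOWED_IDS = [1, 2, 3]  # hypothetical name: the "set" from the question
DEFAULT_ID = 0           # hypothetical name: replacement for outliers

_keys = tf.constant(ALLOWED_IDS, dtype=tf.int64)
# Build the table once, outside the mapped function, so it is
# captured by the dataset graph instead of being recreated per element.
_table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(_keys, _keys),
    default_value=tf.constant(DEFAULT_ID, dtype=tf.int64))


def map_id(value):
    """Return value if it is in ALLOWED_IDS, otherwise DEFAULT_ID."""
    # Assumption: IDs are integers; the cast makes the key dtype match.
    return _table.lookup(tf.cast(value, tf.int64))


def map_id_where(value):
    """Same result without a table: compare against every allowed ID."""
    value = tf.cast(value, tf.int64)
    # Add a trailing axis so the comparison broadcasts over the allowed IDs.
    is_member = tf.reduce_any(tf.equal(value[..., tf.newaxis], _keys), axis=-1)
    return tf.where(is_member, value, tf.constant(DEFAULT_ID, dtype=tf.int64))


def preprocess(item):
    # Same structure as in the question: ((id, features), target)
    return (map_id(item["investment_id"]), item["features"]), item["target"]


data = {
    "investment_id": [1, 2, 3, 4, 5],
    "features": [12, 912, 28, 90, 17],
    "target": [1, 0, 1, 1, 1],
}
ds = tf.data.Dataset.from_tensor_slices(data).map(preprocess)
mapped_ids = [int(inputs[0].numpy()) for inputs, _ in ds]
print(mapped_ids)
```

Both functions also accept batched tensors, so the mapping could be applied after batch() instead of per element if that turns out to be faster for your pipeline.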