How do I implement a set lookup in TensorFlow?
While preprocessing a TensorFlow dataset, I need to check whether a certain value is contained in an immutable set. If it isn't, I need to replace it with a default value. Essentially, this is about censoring (replacing) certain outliers.
In Python I would do something like this:
```python
def map_id(value):
    s = frozenset([1, 2, 3])
    if value in s:
        return value
    else:
        return 0  # default for all outliers
```
This `map_id` function is called like this:
```python
def preprocess(item):
    return (map_id(item["investment_id"]), item["features"]), item["target"]
```
The `preprocess` function is called like this:
```python
import tensorflow as tf

def make_dataset(file_paths, batch_size=4096, mode="train"):
    ds = tf.data.TFRecordDataset(file_paths)
    ds = ds.map(decode_function)  # parses each serialized record (not shown in the question)
    ds = ds.map(preprocess)
    if mode == "train":
        ds = ds.shuffle(batch_size * 4)
    ds = ds.batch(batch_size).cache().prefetch(tf.data.AUTOTUNE)
    return ds
```
How do I write this `map_id` function in TensorFlow 2.x?
I am not sure what your data looks like, but you should be able to use a simple `StaticHashTable` as a `set` alternative for your use case. It works in graph mode, and `Dataset.map` traces your function into a graph.
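The following is a minimal sketch of that idea. It assumes `investment_id` is an integer tensor, because TFRecord integer features are parsed as `tf.int64`. `ALLOWED_IDS` and `DEFAULT_ID` are illustrative names. The small in-memory dataset is a hypothetical stand-in for the TFRecord pipeline in the question. Each allowed ID maps to itself, and `default_value` covers every other key, which reproduces the `frozenset` check.

```python
import tensorflow as tf

# Assumed values: the allowed IDs (the "set") and the replacement for outliers.
ALLOWED_IDS = tf.constant([1, 2, 3], dtype=tf.int64)
DEFAULT_ID = 0

# Each allowed ID maps to itself; any key missing from the table
# returns default_value, which mirrors the frozenset membership check.
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys=ALLOWED_IDS, values=ALLOWED_IDS),
    default_value=tf.constant(DEFAULT_ID, dtype=tf.int64),
)


def map_id(value):
    # Keys must match the table's key dtype (int64), so cast first.
    # lookup() is element-wise: it accepts scalars or tensors of any shape.
    return table.lookup(tf.cast(value, tf.int64))


def preprocess(item):
    return (map_id(item["investment_id"]), item["features"]), item["target"]


# Hypothetical in-memory stand-in for the decoded TFRecord dataset.
ds = tf.data.Dataset.from_tensor_slices({
    "investment_id": tf.constant([1, 5, 3, 42], dtype=tf.int64),
    "features": tf.zeros([4, 2]),
    "target": tf.constant([0.1, 0.2, 0.3, 0.4]),
})
ds = ds.map(preprocess)  # preprocess is traced into a graph here

ids = [int(x) for (x, _), _ in ds.as_numpy_iterator()]
print(ids)  # [1, 0, 3, 0]
```

The table is built once, outside `map_id`, so the traced function captures it rather than rebuilding it for every element. Because `lookup` is element-wise, the same `map_id` should also work if it is applied after `batch()`.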