从TensorFlow数据集中提取元素

发布于 2025-01-22 15:08:11 字数 538 浏览 0 评论 0原文

我有一个包含我所有数据和标签的TensorFlow数据集。 前20个元素使用以下代码提取到另一个数据集中:

train_dataset = big_dataset.take(20)

但是如何将BIG_DATASET的最后20个元素提取到新数据集中?

谢谢,我进步了!

编辑: 以下代码显示了我如何定义big_dataset:

big_dataset = tf.data.Dataset.from_tensor_slices((all_points, all_labels))

现在有效获取第一个elemets的是以下代码(train_size as eg 20):

train_dataset = big_dataset.take(train_size)
train_dataset = train_dataset.shuffle(train_size).map(augment).batch(BATCH_SIZE)

但是使用.skip()。

I have a tensorflow dataset containing all my data and labels.
The first 20 elements are extracted into another dataset using following code:

train_dataset = big_dataset.take(20)

But how do i extract for example the last 20 elements from big_dataset into a new dataset?

Thanks i advance!

EDIT:
The following code shows how i define the big_dataset:

big_dataset = tf.data.Dataset.from_tensor_slices((all_points, all_labels))

What works now to get the first elemets is the following code (where train_size is e.g. 20):

train_dataset = big_dataset.take(train_size)
train_dataset = train_dataset.shuffle(train_size).map(augment).batch(BATCH_SIZE)

But using the .skip().take() results in an empty database

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

雨巷深深 2025-01-29 15:08:11

尝试使用跳过。例如,假设您有120个数据示例和一个批次示例1,而您尚未将数据改组,那么您可以尝试以下内容:

train_dataset = big_dataset.skip(100).take(20)

对于特定数据集,请尝试:

import tensorflow as tf

samples = 29
all_points  = tf.random.normal((samples, 5))
all_labels  = tf.random.normal((samples, 1))
big_dataset = tf.data.Dataset.from_tensor_slices((all_points, all_labels))
train_size = 20 
train_dataset = big_dataset.skip(9).take(train_size)
print(len(train_dataset))
20

Try using skip. For example, suppose you have 120 data samples and a batch_size of 1 and you have not shuffled your data, then you can try something like the following:

train_dataset = big_dataset.skip(100).take(20)

For your specific dataset, try:

import tensorflow as tf

samples = 29
all_points  = tf.random.normal((samples, 5))
all_labels  = tf.random.normal((samples, 1))
big_dataset = tf.data.Dataset.from_tensor_slices((all_points, all_labels))
train_size = 20 
train_dataset = big_dataset.skip(9).take(train_size)
print(len(train_dataset))
20
~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文