数据集:快速了解
- 从 numpy 数组读取内存数据。
- 逐行读取 csv 文件。
基本输入
学习如何获取数组的片段,是开始学习 tf.data
最简单的方式。
Premade Estimators
def train_input_fn(features, labels, batch_size):
"""一个用来训练的输入函数"""
# 将输入值转化为数据集。
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
# 混排、重复、批处理样本。
dataset = dataset.shuffle(1000).repeat().batch(batch_size)
# 返回数据集
return dataset
下面我们来对这个函数做更仔细的分析。
参数
这个函数一共需要三个参数。如果一个参数的期望类型是 “array” (数组),那么它将可以接受几乎所有可以用 numpy.array
来转化为数组的值。我们可以看到只有一个例外: tuple
,它对 Datasets
有特殊的含义。
features
:一个形如{'feature_name':array}
的数据字典(或者是DataFrame
),它包含了原始的输入特征。labels
:一个包含每个样本的 label 的数组。batch_size
:一个指示所需批量大小的整数。
在 premade_estimator.py
中,我们使用 iris_data.load_data()
函数来检索虹膜数据。
你可以运行该函数,并按如下方式解压结果:
import iris_data
# 获取数据
train, test = iris_data.load_data()
features, labels = train
然后用像下面这样的一行代码,将数据传递给 input 函数:
batch_size=100
iris_data.train_input_fn(features, labels, batch_size)
让我们来具体看看 train_input_fn()
函数。
(数组)片段
TF Layers 教程:构建卷积神经网络
返回这个 Dataset
的代码如下所示:
train, test = tf.keras.datasets.mnist.load_data()
mnist_x, mnist_y = train
mnist_ds = tf.data.Dataset.from_tensor_slices(mnist_x)
print(mnist_ds)
张量
<TensorSliceDataset shapes: (28,28), types: tf.uint8>
The Dataset
above represents a simple collection of arrays, but datasets are much more powerful than this. A Dataset
can transparently handle any nested combination of dictionaries or tuples (or namedtuple
).
For example after converting the iris features
to a standard python dictionary, you can then convert the dictionary of arrays to a Dataset
of dictionaries as follows:
dataset = tf.data.Dataset.from_tensor_slices(dict(features))
print(dataset)
<TensorSliceDataset
shapes: {
SepalLength: (), PetalWidth: (),
PetalLength: (), SepalWidth: ()},
types: {
SepalLength: tf.float64, PetalWidth: tf.float64,
PetalLength: tf.float64, SepalWidth: tf.float64}
>
张量
[需要翻译]The first line of the iris train_input_fn
uses the same functionality,但是增加了一层结构。它创建了一个包含 (features_dict, label)
数据对的数据集。
以下代码表明,标签是类型为 int64
的标量:
# 将输入转化为数据集。
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
print(dataset)
<TensorSliceDataset
shapes: (
{
SepalLength: (), PetalWidth: (),
PetalLength: (), SepalWidth: ()},
()),
types: (
{
SepalLength: tf.float64, PetalWidth: tf.float64,
PetalLength: tf.float64, SepalWidth: tf.float64},
tf.int64)>
操作
目前, Dataset
会按照固定顺序遍历数据一次,且一次只能生成一个元素。在可以用于训练之前,它需要进一步的处理。幸运的是, tf.data.Dataset
类提供了方法来让数据为训练作出更好的准备。 train_input_fn
的下一行代码就利用了几个这样的方法:
# 样本的混排、重复、批处理。
dataset = dataset.shuffle(1000).repeat().batch(batch_size)
print(mnist_ds.batch(100))
<BatchDataset
shapes: (?, 28, 28),
types: tf.uint8>
注意,因为最后一个批次将会有比较少的元素,因此数据集的批量大小是不确定的。
在 train_input_fn
中,批处理之后, 数据集
包含元素们的一维向量,这些一维向量的前面部分是:
print(dataset)
<TensorSliceDataset
shapes: (
{
SepalLength: (?,), PetalWidth: (?,),
PetalLength: (?,), SepalWidth: (?,)},
(?,)),
types: (
{
SepalLength: tf.float64, PetalWidth: tf.float64,
PetalLength: tf.float64, SepalWidth: tf.float64},
tf.int64)>
返回
[需要翻译]At this point the Dataset
contains (features_dict, labels)
pairs. This is the format expected by the train
and evaluate
methods, so the input_fn
returns the dataset.
The labels
can/should be omitted when using the predict
method.
读取 CSV 文件
如下对 iris_data.maybe_download
函数的调用,将会在必要的时候下载数据,并返回结果文件的路径:
import iris_data
train_path, test_path = iris_data.maybe_download()
iris_data.csv_input_fn
函数包括了一个用 Dataset
解析 csv 文件的替代方案。
让我们来看看如何构建一个兼容 Estimator 的、可以读取本地文件的输入函数。
建立 Dataset
ds = tf.data.TextLineDataset(train_path).skip(1)
建立一个 csv 行解析器
我们从建立一个可以解析一行的函数开始。
# 描述文本列的元数据
COLUMNS = ['SepalLength', 'SepalWidth',
'PetalLength', 'PetalWidth',
'label']
FIELD_DEFAULTS = [[0.0], [0.0], [0.0], [0.0], [0]]
def _parse_line(line):
# 将行解码到 fields 中
fields = tf.decode_csv(line, FIELD_DEFAULTS)
# 将结果打包成字典
features = dict(zip(COLUMNS,fields))
# 将标签从特征中分离
label = features.pop('label')
return features, label
解析多行
这个 map
方法接受一个 map_func
参数,这个参数描述了 Dataset
中的每一个元素应该如何被转化。
因此,为了在多行数据被从 csv 文件中读取出来的时候解析它们,我们为 map
方法提供 _parse_line
函数:
ds = ds.map(_parse_line)
print(ds)
<MapDataset
shapes: (
{SepalLength: (), PetalWidth: (), ...},
()),
types: (
{SepalLength: tf.float32, PetalWidth: tf.float32, ...},
tf.int32)>
现在,数据集中包含的是 (features, label)
数据对,而不是简单的字符串标量了。
iris_data.csv_input_fn
函数的余下部分和 Basic input 中介绍的 iris_data.train_input_fn
函数相同。
实践
这个函数可以作为 iris_data.train_input_fn
的替代。它可以像如下这样,来给 estimator 提供数据:
train_path, test_path = iris_data.maybe_download()
# 所有的输入都是数字
feature_columns = [
tf.feature_column.numeric_column(name)
for name in iris_data.CSV_COLUMN_NAMES[:-1]]
# 构建 estimator
est = tf.estimator.LinearClassifier(feature_columns,
n_classes=3)
# 训练 estimator
batch_size = 100
est.train(
steps=1000,
input_fn=lambda : iris_data.csv_input_fn(train_path, batch_size))
Estimator 期望 input_fn
没有任何参数。要解除这个限制,我们使用 lambda
来捕获参数并提供预期的接口。
总结
为了从不同的数据源中便捷的读取数据, tf.data
模块提供了类和函数的集合。除此之外, tf.data
有简单并且强大的方法,来应用各种标准和自定义转换。
现在你已经基本了解了如何为 Estimator 高效的获取数据。(作为扩展)接下来可以思考如下的文档:
- 创建定制化 Estimator
- 底层 API 编程介绍
- 数据导入
如果您发现本页面存在错误或可以改进,请 点击此处 帮助我们改进。
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论