将batchdataset与keras vgg16 precrocess_input连接
我正在使用 tf.keras.preprocessing.image_dataset_from_directory
获取 batchdataset
,其中数据集有10个类。
我正在尝试将此 batchdataset
与keras vgg16
( docs )网络。从文档中:
注意:每个KERAS应用程序都期望特定的输入预处理。对于VGG16,请致电
tf.keras.applications.vgg16.preprocess_input
在将它们传递给模型之前。
但是,我正在努力获得此 preprocess_input
与 batchdataset一起工作
。 您能帮我弄清楚如何连接这两个点吗?
请参阅以下代码:
train_ds = tf.keras.preprocessing.image_dataset_from_directory(train_data_dir, image_size=(224, 224))
train_ds = tf.keras.applications.vgg16.preprocess_input(train_ds)
这将抛出 typeerror:'batchdataset'对象无法订阅
:
Traceback (most recent call last):
...
File "/path/to/venv/lib/python3.10/site-packages/keras/applications/vgg16.py", line 232, in preprocess_input
return imagenet_utils.preprocess_input(
File "/path/to/venv/lib/python3.10/site-packages/keras/applications/imagenet_utils.py", line 117, in preprocess_input
return _preprocess_symbolic_input(
File "/path/to/venv/lib/python3.10/site-packages/keras/applications/imagenet_utils.py", line 278, in _preprocess_symbolic_input
x = x[..., ::-1]
TypeError: 'BatchDataset' object is not subscriptable
来自 typeError:'datasetV1adapter'对象不可订阅(来自 batchdataset试图将python字典作为表格)提示是使用:
train_ds = tf.keras.applications.vgg16.preprocess_input(
list(train_ds.as_numpy_iterator())
)
但是,这也失败了:这
Traceback (most recent call last):
...
File "/path/to/venv/lib/python3.10/site-packages/keras/applications/vgg16.py", line 232, in preprocess_input
return imagenet_utils.preprocess_input(
File "/path/to/venv/lib/python3.10/site-packages/keras/applications/imagenet_utils.py", line 117, in preprocess_input
return _preprocess_symbolic_input(
File "/path/to/venv/lib/python3.10/site-packages/keras/applications/imagenet_utils.py", line 278, in _preprocess_symbolic_input
x = x[..., ::-1]
TypeError: list indices must be integers or slices, not tuple
一切都失败了:这都是使用 python = 3.10 == 3.10 == 3.10 .3
带有 tensorflow == 2.8.0
。
我该如何工作?先感谢您。
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论
评论(1)
好吧,我弄清楚了。我需要传递
tf.tensor
,而不是tf.data.dataset
。可以通过迭代数据集
来获得张量
。这可以通过几种方式完成:
如果将选项2转换为生成器,则可以直接传递到下游
model.fit
中。干杯!Okay I figured it out. I needed to pass a
tf.Tensor
, not atf.data.Dataset
. One can get aTensor
out by iterating over theDataset
.This can be done in a few ways:
If you convert option 2 into a generator, it can be directly passed into the downstream
model.fit
. Cheers!