Tensorflow `from_tensor_slices` 卡死

发布于 2022-09-12 00:32:55 字数 1686 浏览 36 评论 0

警告memory exceeded 10% of system memory
我load本地的tfrecord成np array

def CUB_load_data():

    ds = tfds.load('caltech_birds2011', download=False, data_dir='../../datasets/')
    train_data = ds['train']
    test_data = ds['test']


    train_x = []
    train_y = []

    test_x = []
    test_y = []

    for i in train_data.__iter__():
        resized = cv2.resize(i['image'].numpy(), dsize=(224,224))
        train_x.append(resized)
        train_y.append(i['label'])

    for i in test_data.__iter__():
        resized = cv2.resize(i['image'].numpy(), dsize=(224,224))
        test_x.append(resized)
        test_y.append(i['label'])
    return (train_x, train_y), (test_x, test_y)

这部分应该没问题

我用from_tensor_slices去创建tf.dataset,我尝试过修改batch,但一样没效果,卡死并且报超过系统10%容量的警告。
图片是CUB200_2011,所有图片也就1.G上下。用tfds.load生成的tfrecord

def load_data():
    (train_x, train_y), (test_x, test_y) =  CUB_load_data()
    SHUFFLE_BUFFER_SIZE = 500
    BATCH_SIZE = 2
    @tf.function
    def _parse_function(img, label):
        feature = {}
        img = tf.cast(img, dtype=tf.float32)
        img = img / 255.0
        feature["img"] = img
        feature["label"] = label
        return feature

    train_dataset_raw = tf.data.Dataset.from_tensor_slices(
        (train_x, train_y)).map(_parse_function)
    test_dataset_raw = tf.data.Dataset.from_tensor_slices(
        (test_x, test_y)).map(_parse_function)
    train_dataset = train_dataset_raw.shuffle(SHUFFLE_BUFFER_SIZE).batch(
        BATCH_SIZE)
    test_dataset = test_dataset_raw.shuffle(SHUFFLE_BUFFER_SIZE).batch(
        BATCH_SIZE)
    return train_dataset, test_dataset

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文