How to apply custom data augmentation as a preprocessing layer in TensorFlow?

Posted 2025-02-11 02:44:57

I am performing data augmentation on spectrogram images, masking out time and frequency bands as part of the preprocessing layers in TensorFlow. I am encountering the following error:

'tensorflow.python.framework.ops.EagerTensor' object does not support item assignment

Here is the code I use:

def random_mask_time(img):
  MAX_OCCURENCE = 5
  MAX_LENGTH = 10
  nums = random.randint(0,MAX_OCCURENCE) # number of masks
  
  for n in range(nums):
    length = random.randint(0, MAX_LENGTH) # number of columns to mask (up to 20px in time)
    pos = random.randint(0, img.shape[0]-MAX_LENGTH) # position to start masking
    img[:,pos:(pos+length),:] = 0

  return img


def layer_random_mask_time():
  return layers.Lambda(lambda x: random_mask_time(x))

rnd_time = layer_random_mask_time()

data_augmentation = tf.keras.Sequential([
  rnd_time,
  rnd_freq,
  layers.RandomCrop(input_shape[1], input_shape[0]),
])

I then use it as part of my Keras Sequential model.

I get that tensors are immutable, but how can I mask out images?

I used this for reference: https://www.tensorflow.org/tutorials/images/data_augmentation#custom_data_augmentation


Answer by 烟雨扶苏, 2025-02-18 02:44:57

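Since tensors are immutable, you cannot assign into a slice. Instead, rebuild the tensor with tf.concat from three slices, with the middle one zeroed. Try something like this: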

import tensorflow as tf
import pathlib
import matplotlib.pyplot as plt

def random_mask_time(img):
  # img is a batch of shape (batch, height, width, channels); time is masked along the width axis (axis 2)
  MAX_OCCURENCE = 5
  MAX_LENGTH = 10
  nums = tf.random.uniform((), minval = 0, maxval = MAX_OCCURENCE, dtype=tf.int32) # number of masks (maxval is exclusive)
  for n in tf.range(nums):
    length = tf.random.uniform((), minval = 0, maxval = MAX_LENGTH, dtype=tf.int32) # number of columns to mask (fewer than MAX_LENGTH px in time)
    pos = tf.random.uniform((), minval = 0, maxval = img.shape[2]-MAX_LENGTH, dtype=tf.int32) # position to start masking along the width axis
    # replace the slice [pos, pos+length) with zeros by concatenating left part, zeroed middle, right part
    img = tf.concat([img[:, :, :pos,:], img[:, :, pos:(pos+length),:]*0, img[:, :, (pos+length):,:]], axis=2)
  return img

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)
batch_size = 1

ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  seed=123,
  image_size=(128, 128),
  batch_size=batch_size)

data_augmentation = tf.keras.Sequential([
  tf.keras.layers.Lambda(lambda x: random_mask_time(x)),
  tf.keras.layers.RandomCrop(128, 128),
])

image, _ = next(iter(ds.take(1)))
image = data_augmentation(image)
plt.imshow(image[0].numpy() / 255)
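
An alternative that avoids slicing and concatenating is to build a 0/1 mask and multiply it into the image. The sketch below is not part of the original answer: the function name random_mask_axis and its parameters are made up for illustration, and it assumes a batched input of shape (batch, height, width, channels) with frequency along the height axis (axis 1) and time along the width axis (axis 2). It always draws max_occurrence candidate bands and activates only a random number of them, so no Python loop over a tensor is needed.

```python
import tensorflow as tf

def random_mask_axis(img, axis, max_occurrence=5, max_length=10):
    """Zero out a random number of bands along `axis` of a batched image.

    Hypothetical helper: assumes `img` has shape (batch, height, width, channels).
    """
    size = tf.shape(img)[axis]
    # How many of the candidate bands are actually applied (0..max_occurrence).
    nums = tf.random.uniform((), 0, max_occurrence + 1, dtype=tf.int32)
    # Length and start position of every candidate band.
    lengths = tf.random.uniform([max_occurrence], 0, max_length + 1, dtype=tf.int32)
    starts = tf.random.uniform(
        [max_occurrence], 0, tf.maximum(size - max_length, 1), dtype=tf.int32)
    active = tf.range(max_occurrence) < nums           # shape (max_occurrence,)

    idx = tf.range(size)[tf.newaxis, :]                # shape (1, size)
    inside = tf.logical_and(idx >= starts[:, tf.newaxis],
                            idx < (starts + lengths)[:, tf.newaxis])
    inside = tf.logical_and(inside, active[:, tf.newaxis])
    masked = tf.reduce_any(inside, axis=0)             # shape (size,)

    # 1 where the pixel is kept, 0 where it is masked; reshape so it broadcasts.
    keep = tf.cast(tf.logical_not(masked), img.dtype)
    shape = [1, 1, 1, 1]
    shape[axis] = -1
    return img * tf.reshape(keep, shape)

# Assumed axis convention: width = time, height = frequency.
rnd_time = tf.keras.layers.Lambda(lambda x: random_mask_axis(x, axis=2))
rnd_freq = tf.keras.layers.Lambda(lambda x: random_mask_axis(x, axis=1))
```

These two layers could take the place of rnd_time and rnd_freq in the Sequential model from the question. Note that the same bands are applied to every image in a batch, and that a Lambda layer runs the same way at inference time, so you may want to apply it only in the training input pipeline.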
