How to implement a custom PyTorch Dataset with the map() method
Many PyTorch examples use the Dataset `map()` method. For example, this snippet from https://huggingface.co/voidful/wav2vec2-large-xlsr-53-tw-gpt:
```python
import re

import torch
from datasets import Audio, load_dataset

# `processor`, `model`, `device` and `chars_to_ignore_regex` are not defined in this
# excerpt; they are assumed to be set up beforehand (checkpoint, target device, regex).

ds = load_dataset("common_voice", 'zh-TW', split="test")
ds = ds.cast_column("audio", Audio(sampling_rate=16_000))

def map_to_array(batch):
    # Called once per example (batched=False): `batch` holds a single row.
    audio = batch["audio"]
    batch["speech"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    batch["sampling_rate"] = audio["sampling_rate"]
    batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower().replace("’", "'")
    return batch

ds = ds.map(map_to_array)

def map_to_pred(batch):
    # Called with batched=True: every value in `batch` is a list of up to 3 items.
    features = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0], padding=True, return_tensors="pt")
    input_values = features.input_values.to(device)
    attention_mask = features.attention_mask.to(device)
    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits
    pred_ids = torch.argmax(logits, dim=-1)
    batch["predicted"] = processor.batch_decode(pred_ids)
    batch["target"] = batch["sentence"]
    return batch

result = ds.map(map_to_pred, batched=True, batch_size=3, remove_columns=list(ds.features.keys()))
```
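The two `map()` calls above use two different modes. Note that `map()` here comes from Hugging Face's `datasets.Dataset`, not from `torch.utils.data.Dataset`. As a rough, stdlib-only model of what the snippet relies on (the name `toy_map` and its simplified merge rule are assumptions for illustration, not the library's implementation): without `batched`, the function receives one example as a dict of scalars; with `batched=True` it receives a dict of lists of up to `batch_size` items; and the columns named in `remove_columns` are dropped from the output, which is why `result` should end up with only `predicted` and `target`.

```python
from typing import Any, Callable, Dict, List, Optional

# Column-oriented table: column name -> list of values, one per row.
Table = Dict[str, List[Any]]


def toy_map(
    table: Table,
    fn: Callable[[Dict[str, Any]], Dict[str, Any]],
    batched: bool = False,
    batch_size: int = 1000,
    remove_columns: Optional[List[str]] = None,
) -> Table:
    """Simplified model of a column-oriented map(); not the Hugging Face code."""
    drop = set(remove_columns or [])
    n_rows = len(next(iter(table.values()), []))
    step = batch_size if batched else 1
    out: Table = {}
    for start in range(0, n_rows, step):
        # Slice every input column for this batch (or single row).
        chunk = {k: v[start:start + step] for k, v in table.items()}
        if batched:
            # Batched mode: fn receives a dict of lists, like map_to_pred.
            new_cols = fn(dict(chunk))
        else:
            # Per-example mode: fn receives a dict of scalars, like map_to_array.
            row = fn({k: v[0] for k, v in chunk.items()})
            new_cols = {k: [v] for k, v in row.items()}
        # fn's output overrides same-named inputs; listed columns are then dropped.
        for name, values in {**chunk, **new_cols}.items():
            if name not in drop:
                out.setdefault(name, []).extend(values)
    return out


def add_length(example):
    example["n_chars"] = len(example["sentence"])
    return example


def to_target(batch):
    return {"target": [s.upper() for s in batch["sentence"]]}


table = {"sentence": ["ab", "cde", "f", "gh"]}
per_example = toy_map(table, add_length)
batched_out = toy_map(per_example, to_target, batched=True, batch_size=3,
                      remove_columns=list(per_example.keys()))
print(per_example)  # {'sentence': ['ab', 'cde', 'f', 'gh'], 'n_chars': [2, 3, 1, 2]}
print(batched_out)  # {'target': ['AB', 'CDE', 'F', 'GH']}
```

The real method has more edge cases (for example, it checks that a batched function returns columns of consistent length), so treat this only as a mental model of the data shapes each function sees.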
However, a custom map-style dataset only needs to implement `__len__()` and `__getitem__()`, so it does not come with a `map()` method.
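Because PyTorch's map-style protocol is just `__len__()` and `__getitem__()`, one option when only per-example transforms are needed is to add a lazy `map()` yourself. A minimal sketch: `MapMixin`, `MappedDataset` and `SentenceDataset` are hypothetical names, and the classes skip subclassing `torch.utils.data.Dataset` only so the sketch runs without PyTorch installed (in real code you would subclass it).

```python
from typing import Any, Callable, Dict, List

Example = Dict[str, Any]


class MapMixin:
    """Hypothetical mixin: gives any map-style dataset a lazy, chainable map()."""

    def map(self, fn: Callable[[Example], Example]) -> "MappedDataset":
        return MappedDataset(self, fn)


class MappedDataset(MapMixin):
    """Applies `fn` to each example when it is indexed (nothing is precomputed)."""

    def __init__(self, base: Any, fn: Callable[[Example], Example]) -> None:
        self.base = base
        self.fn = fn

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, idx: int) -> Example:
        # Shallow-copy first so reassigning keys in fn does not touch the base example.
        return self.fn(dict(self.base[idx]))


class SentenceDataset(MapMixin):
    """Hypothetical custom map-style dataset: only __len__ and __getitem__."""

    def __init__(self, sentences: List[str]) -> None:
        self.sentences = sentences

    def __len__(self) -> int:
        return len(self.sentences)

    def __getitem__(self, idx: int) -> Example:
        return {"sentence": self.sentences[idx]}


def lowercase(example: Example) -> Example:
    example["sentence"] = example["sentence"].lower()
    return example


def add_length(example: Example) -> Example:
    example["n_chars"] = len(example["sentence"])
    return example


ds = SentenceDataset(["Hello", "World!"]).map(lowercase).map(add_length)
print(len(ds), ds[1])  # 2 {'sentence': 'world!', 'n_chars': 6}
```

Unlike the Hugging Face `map()`, this version is lazy (the function runs on every access, for example inside a `DataLoader` worker) and has no `batched`, `batch_size` or `remove_columns` support, so it does not make the examples above run unchanged.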
What is the correct way to convert a custom Dataset into one with all the useful methods needed by the examples?
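If the examples must run unchanged, one route is to turn the custom data into a Hugging Face `datasets.Dataset`, which provides `map()`, `cast_column()` and the rest. A minimal sketch, assuming the whole dataset fits in memory: read the custom map-style dataset once through `__len__()`/`__getitem__()` and transpose it into a dict of columns, the shape `datasets.Dataset.from_dict()` accepts. `TranscriptDataset`, `to_columns` and `my_fn` are hypothetical, and the `datasets` calls are left as comments so the sketch runs with the standard library only. Recent versions of `datasets` also offer `Dataset.from_generator()`, which avoids building the whole dict in memory.

```python
from typing import Any, Dict, List


class TranscriptDataset:
    """Hypothetical custom map-style dataset (only __len__ and __getitem__)."""

    def __init__(self, paths: List[str], sentences: List[str]) -> None:
        self.paths = paths
        self.sentences = sentences

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        return {"path": self.paths[idx], "sentence": self.sentences[idx]}


def to_columns(dataset: Any) -> Dict[str, List[Any]]:
    """Read each example once and transpose rows (dicts) into columns (lists).

    Assumes every example has the same keys.
    """
    columns: Dict[str, List[Any]] = {}
    for i in range(len(dataset)):
        for key, value in dataset[i].items():
            columns.setdefault(key, []).append(value)
    return columns


custom = TranscriptDataset(["a.wav", "b.wav"], ["hello", "thanks"])
columns = to_columns(custom)
print(columns)  # {'path': ['a.wav', 'b.wav'], 'sentence': ['hello', 'thanks']}

# With the `datasets` package installed (not executed here), the columns become a
# Hugging Face Dataset that has map(), filter(), cast_column(), and so on:
# from datasets import Dataset
# hf_ds = Dataset.from_dict(columns)
# hf_ds = hf_ds.map(my_fn, batched=True, batch_size=3)  # my_fn is hypothetical
```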