返回介绍

5.2 计算机视觉工具包:torchvision

发布于 2024-01-28 10:35:46 字数 2783 浏览 0 评论 0 收藏 0

计算机视觉是深度学习中最重要的一类应用,为了方便研究者使用,PyTorch 团队专门开发了一个视觉工具包 torchvion ,这个包独立于 PyTorch,需通过 pip instal torchvision 安装。在之前的例子中我们已经见识到了它的部分功能,这里再做一个系统性的介绍。torchvision 主要包含三部分:

  • models:提供深度学习中各种经典网络的网络结构以及预训练好的模型,包括 AlexNet 、VGG 系列、ResNet 系列、Inception 系列等。
  • datasets: 提供常用的数据集加载,设计上都是继承 torhc.utils.data.Dataset ,主要包括 MNISTCIFAR10/100ImageNetCOCO 等。
  • transforms:提供常用的数据预处理操作,主要包括对 Tensor 以及 PIL Image 对象的操作。
from torchvision import models
from torch import nn
# 加载预训练好的模型,如果不存在会进行下载
# 预训练好的模型保存在 ~/.torch/models/下面
resnet34 = models.resnet34(pretrained=True, num_classes=1000)

# 修改最后的全连接层为 10 分类问题(默认是 ImageNet 上的 1000 分类)
resnet34.fc=nn.Linear(512, 10)
from torchvision import datasets
# 指定数据集路径为 data,如果数据集不存在则进行下载
# 通过 train=False 获取测试集
dataset = datasets.MNIST('data/', download=True, train=False, transform=transform)

Transforms 中涵盖了大部分对 Tensor 和 PIL Image 的常用处理,这些已在上文提到,这里就不再详细介绍。需要注意的是转换分为两步,第一步:构建转换操作,例如 transf = transforms.Normalize(mean=x, std=y) ,第二步:执行转换操作,例如 output = transf(input) 。另外还可将多个处理操作用 Compose 拼接起来,形成一个处理转换流程。

from torchvision import transforms 
to_pil = transforms.ToPILImage()
to_pil(t.randn(3, 64, 64))

<PIL.Image.Image image mode=RGB size=64x64 at 0x7FE16C2484A8>

torchvision 还提供了两个常用的函数。一个是 make_grid ,它能将多张图片拼接成一个网格中;另一个是 save_img ,它能将 Tensor 保存成图片。

len(dataset)
10000
dataloader = DataLoader(dataset, shuffle=True, batch_size=16)
from torchvision.utils import make_grid, save_image
dataiter = iter(dataloader)
img = make_grid(next(dataiter)[0], 4) # 拼成 4*4 网格图片,且会转成3通道
to_img(img)

<PIL.Image.Image image mode=RGB size=906x906 at 0x7FE0A59AF780>
save_image(img, 'a.png')
Image.open('a.png')

<PIL.PngImagePlugin.PngImageFile image mode=RGB size=906x906 at 0x7FE16C248390>

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文