文章来源于网络收集而来,版权归原创者所有,如有侵权请及时联系!
5.2 计算机视觉工具包:torchvision
计算机视觉是深度学习中最重要的一类应用,为了方便研究者使用,PyTorch 团队专门开发了一个视觉工具包 torchvion
,这个包独立于 PyTorch,需通过 pip instal torchvision
安装。在之前的例子中我们已经见识到了它的部分功能,这里再做一个系统性的介绍。torchvision 主要包含三部分:
- models:提供深度学习中各种经典网络的网络结构以及预训练好的模型,包括
AlexNet
、VGG 系列、ResNet 系列、Inception 系列等。 - datasets: 提供常用的数据集加载,设计上都是继承
torhc.utils.data.Dataset
,主要包括MNIST
、CIFAR10/100
、ImageNet
、COCO
等。 - transforms:提供常用的数据预处理操作,主要包括对 Tensor 以及 PIL Image 对象的操作。
from torchvision import models
from torch import nn
# 加载预训练好的模型,如果不存在会进行下载
# 预训练好的模型保存在 ~/.torch/models/下面
resnet34 = models.resnet34(pretrained=True, num_classes=1000)
# 修改最后的全连接层为 10 分类问题(默认是 ImageNet 上的 1000 分类)
resnet34.fc=nn.Linear(512, 10)
from torchvision import datasets
# 指定数据集路径为 data,如果数据集不存在则进行下载
# 通过 train=False 获取测试集
dataset = datasets.MNIST('data/', download=True, train=False, transform=transform)
Transforms 中涵盖了大部分对 Tensor 和 PIL Image 的常用处理,这些已在上文提到,这里就不再详细介绍。需要注意的是转换分为两步,第一步:构建转换操作,例如 transf = transforms.Normalize(mean=x, std=y)
,第二步:执行转换操作,例如 output = transf(input)
。另外还可将多个处理操作用 Compose 拼接起来,形成一个处理转换流程。
from torchvision import transforms
to_pil = transforms.ToPILImage()
to_pil(t.randn(3, 64, 64))
<PIL.Image.Image image mode=RGB size=64x64 at 0x7FE16C2484A8>
torchvision 还提供了两个常用的函数。一个是 make_grid
,它能将多张图片拼接成一个网格中;另一个是 save_img
,它能将 Tensor 保存成图片。
len(dataset)
10000
dataloader = DataLoader(dataset, shuffle=True, batch_size=16)
from torchvision.utils import make_grid, save_image
dataiter = iter(dataloader)
img = make_grid(next(dataiter)[0], 4) # 拼成 4*4 网格图片,且会转成3通道
to_img(img)
<PIL.Image.Image image mode=RGB size=906x906 at 0x7FE0A59AF780>
save_image(img, 'a.png')
Image.open('a.png')
<PIL.PngImagePlugin.PngImageFile image mode=RGB size=906x906 at 0x7FE16C248390>
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论