Utilities
库:torchvision.utils
常用函数
make_grid:将多张图像(通常是一个 Batch 的 Tensor)拼成一张网格图。tensor: ,或者是包含相同大小图像的列表。nrow: 每行显示的图像数量。padding: 相邻图像之间的填充距离。normalize: 如果设为True,它会将图像像素值线性缩放到 范围内。- 公式为:。
- 用
value_range指定数据的范围。例如,如果数据通过Normalize(mean=0.5, std=0.5)缩放到了 ,可以设置value_range=(-1, 1)。
- 输出维度:
make_grid返回的是一个 3D Tensor 。 - 通道处理: 如果输入是黑白图(1 通道),它会自动扩展成 3 通道。
save_image:make_grid的封装,直接将 Tensor 保存为本地图像文件。- 它接受
make_grid的所有参数。 fp: 保存路径(string)或文件对象(file object)。- 内部逻辑
- 它先调用
make_grid将 Tensor 转成网格形式。 - 将 Tensor 转换到 CPU,并映射到
uint8范围 。 - 利用 PIL 将 Tensor 保存。
- 它先调用
- 它接受
示例
import torch
from torchvision.utils import make_grid, save_image
import matplotlib.pyplot as plt
images = torch.randn(64, 3, 32, 32)
grid_img = make_grid(images, nrow=8, normalize=True, value_range=(-1, 1))
# 如果要用 matplotlib 显示,需要调整维度 (C, H, W) -> (H, W, C)
plt.imshow(grid_img.permute(1, 2, 0).cpu().numpy())
plt.show()
save_image(images, 'results.png', nrow=8, normalize=True)Transforms
Note
本词条为
torch的图像处理工具库,请与 Transformer 区分。
库:torchvision.transforms
与侧重于单步操作的工具箱, OpenCV 相比,transforms侧重于 Data Augmentation 和预处理流水线,常与 Dataset 配合使用,在训练时实时处理图像。
常用功能
转换
ToTensor():将PIL Image或numpy.ndarray(H, W, C) 转换为torch.FloatTensor(C, H, W),并自动将像素值从[0, 255]归一化到[0.0, 1.0]Normalize(mean, std):归一化
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ImageNet 的 mean 和 std,如果借用 `torchvision.model` 的模型的话要这样设置。几何变换
Resize(size):
transforms.Resize((224, 224))
# 若为元组则指定长宽,若为单个数字则指定短边,长边按比例缩放CenterCrop(size),RandomCrop(size):中心或随机裁剪大小
CenterCrop(224) # 等价于 CenterCrop((224, 224))
CenterCrop((123, 456))
# 如果裁切大小比原图大小大的话,则报错RandomResizedCrop(size):随机面积、比例裁切后,再 resize 到指定大小RandomHorizontalFlip(p): 以概率p翻转RandomRotation(degrees): 随机旋转[-degrees, degrees]度
色彩变换
ColorJitter(brightness, contrast, saturation, hue): 随机调整亮度、对比度、饱和度GaussianBlur(kernel_size, sigma): 高斯模糊Grayscale(): 转灰度图
组合容器
Compose([...]): 将上述所有步骤串联起来,类似于torch.nn.Sequential()
示例
from torchvision import transforms
from PIL import Image
img = Image.open('input.jpg')
# 定义一个处理流水线
pipeline = transforms.Compose([
# 1. 几何变换
transforms.Resize(256), # 先缩放
transforms.CenterCrop(224), # 中心裁剪出 224x224
transforms.RandomHorizontalFlip(p=0.5), # 50%概率水平翻转
transforms.RandomRotation(degrees=15), # -15到+15度之间随机旋转
# 2. 色彩与噪声
transforms.ColorJitter(brightness=0.2, contrast=0.2), # 亮度对比度随机波动20%
transforms.GaussianBlur(kernel_size=5), # 高斯模糊
# 3. 结构转换
transforms.ToTensor(), # 必须:转为Tensor并归一化到[0,1]
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # 按通道归一
])
# 执行处理
result_tensor = pipeline(img)