【grid】pytorch中的Flow_filed,MES,affine_gridHGRID,GRID_SAMPLE详解

发布时间:2024年01月14日

grid in Pytorch

官方链接:
https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html#torch.nn.functional.grid_sample

https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html#torch-nn-functional-grid-sample

更多相关请看我的另一篇文章。

Flow filed 流场

在PyTorch中,流场(flow field)通常用来表示图像中像素的运动或位移信息。它是一个二维矢量场,每个像素都对应一个二维矢量,表示该像素从一个图像到另一个图像的位移。流场通常用于计算光流(optical flow)等计算机视觉任务,用于追踪物体的运动、分析视频序列等。

简单来说,就是在把图像a上的点a_{i, j} 变换到图像b的点b_{x, y}上。

Grid 网格

它是一个4维张量,形状为(N, Hout, Wout, 2)。其中N表示批次大小,Hout和Wout表示输出的高度和宽度,2表示每个像素在新图像上的(x, y)坐标。

TORCH.MESHGRID 生成grid

https://pytorch.org/docs/stable/generated/torch.meshgrid.html

MAKE_GRID 用于一次显示多张图

https://pytorch.org/vision/stable/generated/torchvision.utils.make_grid.html

**affine_grid用于仿射变换**

https://pytorch.org/docs/stable/generated/torch.nn.functional.affine_grid.html#torch.nn.functional.affine_grid

TORCH.NN.FUNCTIONAL.GRID_SAMPLE

定义

For each output location output[n, :, h, w], the size-2 vector grid[n, h, w] specifies input pixel locations x and y, which are used to interpolate the output value output[n, :, h, w].

根据输入值和映射网格(flow-field grid)计算输出。它主要用于在图像处理和计算机视觉任务中,根据给定的网格对输入数据进行采样和插值。

提供一个input的Tensor以及一个对应的flow-field网格(比如光流,体素流等),然后根据grid中每个位置提供的坐标信息(这里指input中pixel的坐标),将input中对应位置的像素值填充到grid指定的位置,得到最终的输出。

区别

grid_sample底层是应用双线性插值,把输入的tensor转换为指定大小。那它和interpolate有啥区别呢?

interpolate是规则采样(uniform),但是grid_sample的转换方式,内部采点的方式并不是规则的,是一种更为灵活的方式。

函数

torch.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=None)

align_corners: 一个可选参数,通常为None。如果为True,则网格中的像素坐标(-1, -1)和(1, 1)将对准输入图像的四个角。如果为False或None,则(-1, -1)对应于输入图像的左上角,(1, 1)对应于右下角。

例子中,我们将一个大小为4x4的tensor 转换为了一个20x20的。grid的大小指定了输出大小,每个grid的位置是一个(x,y)坐标,其值来自于:输入input的(x,y)中 的四邻域插值得到的。

import torch
from torch.nn import functional as F

inp = torch.ones(1, 1, 4, 4)

# 目的是得到一个 长宽为20的tensor
out_h = 20
out_w = 20
 # grid的生成方式等价于用mesh_grid
new_h = torch.linspace(-1, 1, out_h).view(-1, 1).repeat(1, out_w)
new_w = torch.linspace(-1, 1, out_w).repeat(out_h, 1)
grid = torch.cat((new_h.unsqueeze(2), new_w.unsqueeze(2)), dim=2)
grid = grid.unsqueeze(0)

outp = F.grid_sample(inp, grid=grid, mode='bilinear')
print(outp.shape)  #torch.Size([1, 1, 20, 20])

图片来自于SFnet(eccv2020)。flow field是grid, low_resolution是input, high resolution是output。

例子

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

def visualize_affine_transformation(input_image, affine_matrix, padding_mode='zeros'):
    # 使用 grid_sample 进行仿射变换
    output_image = F.grid_sample(input_image, affine_matrix.unsqueeze(0).expand(input_image.size(0), -1, -1), padding_mode=padding_mode)
    
    # 可视化输入和输出图像
    input_image = input_image.squeeze().numpy()
    output_image = output_image.squeeze().numpy()

    plt.figure(figsize=(8, 4))
    plt.subplot(1, 2, 1)
    plt.imshow(input_image, cmap='gray')
    plt.title('Input Image')

    plt.subplot(1, 2, 2)
    plt.imshow(output_image, cmap='gray')
    plt.title('Output Image after Affine Transformation')

    plt.show()

# 创建一个示例输入图像(单通道)
input_image = torch.zeros(1, 1, 5, 5)
input_image[0, 0, :, 2] = 1  # 在中心放置一个白色像素

# 定义不同的仿射变换矩阵和 padding_mode
affine_matrix1 = torch.tensor([[1, 0, 2], [0, 1, 2], [0, 0, 1]], dtype=torch.float32)
affine_matrix2 = torch.tensor([[0.5, 0, 0], [0, 2, 0], [0, 0, 1]], dtype=torch.float32)
padding_mode1 = 'zeros'
padding_mode2 = 'border'

# 调用可视化函数并尝试不同的参数组合
visualize_affine_transformation(input_image, affine_matrix1, padding_mode1)
visualize_affine_transformation(input_image, affine_matrix2, padding_mode1)
visualize_affine_transformation(input_image, affine_matrix1, padding_mode2)
visualize_affine_transformation(input_image, affine_matrix2, padding_mode2)

TORCH.NN.FUNCTIONAL.AFFINE_GRID

AFFINE_GRID用于生成仿射变换所需的矩阵。也就是映射所需的流场。

Generates a 2D or 3D flow field (sampling grid), given a batch of affine matrices theta.

其中的具体参数

torch.nn.functional.affine_grid(theta, size, align_corners=None)

  • theta:一个4x2的张量,表示仿射变换的参数矩阵。这个矩阵通常由用户指定,它包含了仿射变换的缩放、旋转、平移和错切等信息。矩阵的形状应为 (N, 2, 3),其中 N 是批次大小。通常情况下,你可以使用PyTorch的 torch.tensor 创建这个矩阵。

  • size:一个包含两个整数的元组 (H, W),指定生成的仿射变换网格的大小。H 表示输出的高度,W 表示输出的宽度。

  • align_corners:一个布尔值或None,通常影响网格生成的坐标点的精确位置。

    当它为True时,生成的网格的坐标点会与输入图像的四个角对齐,这意味着生成的网格将确切地覆盖输入图像的所有四个角。这种情况下,坐标点的精确位置与输入图像的四个角相吻合,对于某些精确的几何变换可能更合适。

    如果为False或None(默认值),则生成的网格的坐标点会与输入图像的左上角对齐。这种情况下,坐标点的精确位置可能会在输入图像的像素之间,对于一般的仿射变换通常更常见。

Reference

grid_sample()函数及双线性采样

文章来源:https://blog.csdn.net/prinTao/article/details/135586307
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。