对于 pytorch
中的函数
torch.gather(
input, # (Tensor) the source tensor
dim, # (int) the axis along which to index
index, # (LongTensor) the indices of elements to gather
*,
sparse_grad=False,
out=None
) → Tensor
有点绕,很多博客画各种图讲各种故事来解释如何从 input
张量中 gather
位置 index
处的值,乱七八糟,我是都没看明白。所以去官网看了文档:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
从这三行看,意思还是很明晰的:输出 out
和输入 input
之间的差别就是,把相应位置(dim
)的下标替换成 index[i][j][k]
,dim=0,1,2
分别对应替换的位置0,1,2
。但这不够直观!
【注】从上面三行代码可以看出,index
和 input
的维度尺寸是一样的,即 len(index.shape) == len(input.shape)
,但不一定是相同的形状:index.shape[dim] ≠ input.shape[dim]
(其他维度的形状必须满足 index.shape <= input.shape
)。
先从简单的一维向量看看:
x = torch.tensor([3, 4, 5, 6, 7])
按规则看,out[i] = input[index[i]] # dim == 0
,即,从向量里选取指定位置 index[i]
处的数字,放到输出向量 out
的 [i]
处。这个很好理解,python 中 numpy 和 pytorch 都有这样的语法:
x = torch.randn(3)
index = torch.randint(low=0, high=3, size=(5,))
y = x[index]
print(x)
print(index)
print(y)
### output ###
tensor([ 0.8797, 0.2459, -0.1312])
tensor([2, 0, 2, 2, 0])
tensor([-0.1312, 0.8797, -0.1312, -0.1312, 0.8797])
用 torch.gather(...)
函数,就是这样的:
x = torch.tensor([3, 4, 5, 6, 7])
index = torch.tensor([4, 4, 1, 1, 0, 3])
out = torch.gather(x, dim=0, index=index)
### output ###
tensor([7, 7, 4, 4, 3, 6])
举例来说,上面的 index[4] = 0
,那么它会寻找 input[index[4]] = input[0] = 3
,然后放入 out[4]
。这就是英文单词 gather 的意思。
index 的长度是不受限制的,即 gather 多少元素都可以。
小结:在一维向量下,out = torch.gather(x, dim=0, index=index)
等价于 out = x[index]
。
往上升一个维度,看看对二维矩阵实施 gather
函数的操作:
x = torch.tensor([[3, 4, 5, 6, 7], [9, 8, 7, 6, 5]])
idx = torch.randint(low=0, high=5, size=(2, 6))
y = torch.gather(x, dim=1, index=idx)
print(x)
print(idx)
print(y)
### output ###
tensor([[3, 4, 5, 6, 7],
[9, 8, 7, 6, 5]])
tensor([[4, 4, 1, 1, 0, 3],
[0, 1, 2, 1, 4, 1]])
tensor([[7, 7, 4, 4, 3, 6],
[9, 8, 7, 8, 5, 8]])
按规则看,out[i][j] = input[i][index[i][j]] # dim == 1
,即,从向量 input[i]
里选取指定位置 index[i][j]
处的数字,放到输出向量 out[i]
的 [j]
处。也许多了一个维度就有点绕了,但仔细观察,我们可以假定 i = 0
,此时:
out[0][j] = input[0][index[0][j]] # 对应上图的左侧
若假定 i = 1
,则:
out[1][j] = input[1][index[1][j]] # 对应上图的右侧
即,输出 out[i]
是对输入 imput[i]
执行了一次与一维向量时一样的操作,其中下标是 index[i]
。在二维矩阵上的 gather 操作,不过是并行地执行了多个一维向量的 gather。
上面是 dim = 1
时的情况,是沿着矩阵的行进行 gather,当 dim = 0
时,就是沿着列进行 gather:
out[i][0] = input[index[i][0]][0] # dim == 0
out[i][1] = input[index[i][1]][1]
...
也就是并行地执行多个列向量的 gather,每列 index
是一个并行分支,并行分支的数量可以小于 input
的列数,但不能超过,超过的话,它 gather 哪一列呢?
弄懂一维到二维的 gather,更高维的操作也就清晰了,就是画图有一点难画。假设
x = tensor([[[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9]],
[[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19]],
[[20, 21, 22, 13, 24],
[25, 26, 27, 28, 29]]])
则当 dim == 0
时,是沿着第一维进行 gather 的,那么 index.shape[0]
(一个并行分支 gather 的元素的数量) 可为任意数,这里设置为 4,其他 index.shape[i≠0] <= input.shape[i≠0]
:
index = tensor([[[1, 2, 2],
[2, 2, 0]],
[[0, 0, 1],
[1, 0, 1]],
[[2, 0, 0],
[0, 1, 2]],
[[1, 1, 0],
[0, 0, 0]]])
index.shape == (4, 2, 3)
,执行:
y = torch.gather(x, dim=0, index=index)
的示意图如下:
只画了看得见的前两列(两个并行 gather 分支)。红色和绿色箭头表示两列下标沿着 dim=0
进行 gather 操作,每一列和一维向量的 gather 是一样的,只不过这里有 2*3
个列。
再往高维拓展,也是一样,都是从基本的一维向量 gather 拓到并行 gather。