近期在看Shift-GCN的論文[1],該網(wǎng)絡(luò)是基于Shift卷積算子[2]在圖結(jié)構(gòu)數(shù)據(jù)上的延伸。在閱讀源代碼[3]的時候發(fā)現(xiàn)了其對于Non-Local Spatial Shift Graph Convolution有意思的實現(xiàn)方法,在這里簡要記錄一下。 本文轉(zhuǎn)載自徐飛翔的"Shift-GCN中Shift的實現(xiàn)細(xì)節(jié)筆記,通過torch.index_select實現(xiàn)"
版權(quán)聲明:本文為博主原創(chuàng)文章,遵循 CC 4.0 BY-SA 版權(quán)協(xié)議,轉(zhuǎn)載請附上原文出處鏈接和本聲明。
在討論代碼本身之前,簡要介紹下Non-Local Spatial Shift Graph Convolution的操作流程,具體介紹可見博文[1]。對于一個時空骨骼點序列而言,如Fig 1所示,將單幀的骨骼點圖視為是完全圖,因此任何一個節(jié)點都和其他所有節(jié)點有所連接,其shift卷積策略為:
對于一個特征圖
而言,其中
是骨骼點數(shù)量,
是特征通道數(shù)。對于第
個通道的
距離為
。
根據(jù)這種簡單的策略,如Fig 1所示,形成了類似于螺旋上升的特征圖樣。那么我們要如何用代碼描繪這個過程呢?作者公開的源代碼給予了我們一種思路,其主要應(yīng)用了中的
函數(shù)。先簡單介紹一下這個函數(shù)。
是一個用于索引給定張量中某一個維度中某些特定索引元素的方法,其API手冊如:
torch.index_select(input, dim, index, out=None) → Tensor
Parameters:
input (Tensor) – 輸入張量,需要被索引的張量
dim (int) – 在某個維度被索引
index (LongTensor) – 一維張量,用于提供索引信息
out (Tensor, optional) – 輸出張量,可以不填
其作用很簡單,比如我現(xiàn)在的輸入張量為的尺寸大小,其中
為樣本數(shù)量,
為特征數(shù)目,如果我現(xiàn)在需要指定的某些樣本,比如第
,
等等樣本,我可以用一個
進行索引,然后應(yīng)用
就可以索引了,例子如:
>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.1427, 0.0231, -0.5414, -1.0009],
[-0.4664, 0.2647, -0.1228, -1.1068],
[-1.1734, -0.6571, 0.7230, -0.6004]])
>>> indices = torch.tensor([0, 2])
>>> torch.index_select(x, 0, indices) # 按行索引
tensor([[ 0.1427, 0.0231, -0.5414, -1.0009],
[-1.1734, -0.6571, 0.7230, -0.6004]])
>>> torch.index_select(x, 1, indices) # 按列索引
tensor([[ 0.1427, -0.5414],
[-0.4664, -0.1228],
[-1.1734, 0.7230]])
注意到有一個問題是,似乎在使用
的情況下,不檢查
是否會越界,因此如果你的
越界了,但是報錯的地方可能不在使用
的地方,而是在后續(xù)的代碼中,這個似乎就需要留意下你的
了。同時,
是一個
,這個也是要留意的。
我們先貼出主要代碼,看看作者是怎么實現(xiàn)的:
class Shift_gcn(nn.Module):
def __init__(self, in_channels, out_channels, A, coff_embedding=4, num_subset=3):
super(Shift_gcn, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
if in_channels != out_channels:
self.down = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1),
nn.BatchNorm2d(out_channels)
)
else:
self.down = lambda x: x
self.Linear_weight = nn.Parameter(torch.zeros(in_channels, out_channels, requires_grad=True), requires_grad=True)
self.Linear_bias = nn.Parameter(torch.zeros(1,1,out_channels,requires_grad=True),requires_grad=True)
self.Feature_Mask = nn.Parameter(torch.ones(1,25,in_channels, requires_grad=True),requires_grad=True)
self.bn = nn.BatchNorm1d(25*out_channels)
self.relu = nn.ReLU()
index_array = np.empty(25*in_channels).astype(np.int)
for i in range(25):
for j in range(in_channels):
index_array[i*in_channels + j] = (i*in_channels + j + j*in_channels) % (in_channels*25)
self.shift_in = nn.Parameter(torch.from_numpy(index_array),requires_grad=False)
index_array = np.empty(25*out_channels).astype(np.int)
for i in range(25):
for j in range(out_channels):
index_array[i*out_channels + j] = (i*out_channels + j - j*out_channels) % (out_channels*25)
self.shift_out = nn.Parameter(torch.from_numpy(index_array),requires_grad=False)
def forward(self, x0):
n, c, t, v = x0.size()
x = x0.permute(0,2,3,1).contiguous()
# n,t,v,c
# shift1
x = x.view(n*t,v*c)
x = torch.index_select(x, 1, self.shift_in)
x = x.view(n*t,v,c)
x = x * (torch.tanh(self.Feature_Mask)+1)
x = torch.einsum('nwc,cd->nwd', (x, self.Linear_weight)).contiguous() # nt,v,c
x = x + self.Linear_bias
# shift2
x = x.view(n*t,-1)
x = torch.index_select(x, 1, self.shift_out)
x = self.bn(x)
x = x.view(n,t,v,self.out_channels).permute(0,3,1,2) # n,c,t,v
x = x + self.down(x0)
x = self.relu(x)
# print(self.Feature_Mask.shape)
return x
我們把forward()
里面的分為三大部分,分別是:1> shift_in
操作;2> 卷積操作;3> shift_out
操作;其中指的shift_in
和shift_out
只是shift圖卷積算子的不同形式而已,其主要是一致的。整個結(jié)構(gòu)圖如Fig 2(c)所示。
x = torch.einsum('nwc,cd->nwd', (x, self.Linear_weight)).contiguous() # nt,v,c
x = x + self.Linear_bias
x = x * (torch.tanh(self.Feature_Mask)+1)
那么我們著重考慮以下的代碼:
x = x.view(n*t,v*c)
x = torch.index_select(x, 1, self.shift_in)
x = x.view(n*t,v,c)
第一行代碼將特征圖展開,如Fig 3所示,得到了25 × C 25 \times C25×C大小的特征向量。通過torch.index_select
對特征向量的不同分區(qū)進行選擇得到最終的輸出特征向量,選擇的過程如Fig 4所示。
那么可以知道,對于某個關(guān)節(jié)點i ii而言,給定通道j jj,當(dāng)遍歷不同通道時,會存在一個周期,因此是
,比如對于第0號節(jié)點的第1個通道,其需要將
的值移入,如Fig 4的例子所示。而第2個通道則是需要考慮將
的值移入,我們發(fā)現(xiàn)是以C CC為周期的。這個時候假定的是關(guān)節(jié)點都是同一個的時候,當(dāng)遍歷關(guān)節(jié)點時,我們最終的索引規(guī)則是(
,因為考慮到了溢出的問題,因此需要求余,有
。這個對應(yīng)源代碼的第23-32行,如上所示。
在以這個舉個代碼例子,例子如下所示:
import numpy as np
import torch
array = np.arange(0,15).reshape(3,5)
array = torch.tensor(array)
index = np.zeros(15)
for i in range(3):
for j in range(5):
index[i*5+j] = (i*5+j*5+j) % (15)
index = torch.tensor(index).long()
out = torch.index_select(array.view(1,-1), 1, index).view(3,5)
print(array)
print(out)
輸出為:
tensor([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14]])
tensor([[ 0, 6, 12, 3, 9],
[ 5, 11, 2, 8, 14],
[10, 1, 7, 13, 4]])
我們把這種正向移入的稱之為,反過來移入則稱之為
,其索引公式有一點小變化,為:
。代碼例子如下:
import numpy as np
import torch
array = np.arange(0,15).reshape(3,5)
array = torch.tensor(array)
index = np.zeros(15)
for i in range(3):
for j in range(5):
index[i*5+j] = (i*5-j*5+j) % (15)
index = torch.tensor(index).long()
out = torch.index_select(array.view(1,-1), 1, index).view(3,5)
print(array)
print(out)
輸出為:
tensor([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14]])
tensor([[ 0, 11, 7, 3, 14],
[ 5, 1, 12, 8, 4],
[10, 6, 2, 13, 9]])
輸入和只是因為平移方向反過來了而已。
當(dāng)然,進行了特征向量的還不夠,還需要將其
回一個特征矩陣,因此會有:
x = x.view(n*t,v,c)
這樣的代碼段出現(xiàn)。
Reference
[1]. https://fesian.blog.csdn.net/article/details/109563113
[2]. https://fesian.blog.csdn.net/article/details/109474701
[3]. https://github.com/kchengiva/Shift-GCN
[4]. https://blog.csdn.net/LoseInVain/article/details/81143966