本文轉(zhuǎn)自徐飛翔的“緊致卷積網(wǎng)絡(luò)設(shè)計(jì)——Shift卷積算子”
版權(quán)聲明:本文為博主原創(chuàng)文章,遵循 CC 4.0 BY-SA 版權(quán)協(xié)議,轉(zhuǎn)載請(qǐng)附上原文出處鏈接和本聲明。
卷積計(jì)算及其優(yōu)化
為了討論的連續(xù)性,我們先簡(jiǎn)單回顧下傳統(tǒng)的深度學(xué)習(xí)卷積計(jì)算。給定一個(gè)輸入張量,如Fig 1.1中的藍(lán)色塊所示,其尺寸為 ;給定卷積核
,如Fig 1.1中的藍(lán)色虛線框所示,為了方便起見(jiàn),假定步進(jìn)stride = 1,padding = 1,那么最終得到輸出結(jié)果為
,計(jì)算過(guò)程如式子(1.1)所示:
其中為卷積中心,而
是卷積計(jì)算半徑的索引。不難知道,該卷積操作的參數(shù)量為
。計(jì)算量也容易計(jì)算,考慮到每個(gè)卷積操作都需要對(duì)每個(gè)卷積核中的參數(shù)進(jìn)行乘法計(jì)算,那么有乘法因子
,而考慮到stride = 1而且存在填充,那么容易知道計(jì)算量為
? FLOPs。容易發(fā)現(xiàn),卷積的計(jì)算量和參數(shù)量與卷積核大小
?呈現(xiàn)著二次增長(zhǎng)的關(guān)系,這使得卷積的計(jì)算量和參數(shù)量增長(zhǎng)都隨著網(wǎng)絡(luò)設(shè)計(jì)的加深變得難以控制。
Fig 1.1 經(jīng)典的卷積操作示意圖。
在進(jìn)一步對(duì)傳統(tǒng)卷積計(jì)算進(jìn)行優(yōu)化之前,我們先分析一下卷積計(jì)算到底提取了什么類型的信息。以二維卷積為例子,卷積計(jì)算主要在兩個(gè)維度提取信息,空間域和通道域,不過(guò)從本質(zhì)上說(shuō),通道域的信息可以看成是原始輸入(比如RGB圖片輸入)的層次化特征/信息的層疊,因此本質(zhì)上二維卷積還是提取空間域信息,只不過(guò)在層疊卷積過(guò)程中,使得空間域信息按照層次的特點(diǎn),散布在了通道域中。
知道了這一點(diǎn),我們就可以把卷積過(guò)程中的空間卷積和通道卷積分離開(kāi)了,從而得到了所謂的 通道可分離卷積[4,5]。如Fig 1.2所示,這類型的卷積將空間域和通道域卷積完全分開(kāi),在第一步只考慮空間域卷積,因此對(duì)于每個(gè)輸入張量的通道,只會(huì)有唯一一個(gè)對(duì)應(yīng)的卷積核進(jìn)行卷積。數(shù)學(xué)表達(dá)為:
對(duì)比式子(1.1)和(1.2),我們發(fā)現(xiàn)區(qū)別在于對(duì)卷積核的索引上,通過(guò)式子(1.2)輸出的張量形狀為 ,為了接下來(lái)在通道域進(jìn)行卷積,需要進(jìn)一步應(yīng)用1x1卷積,將通道數(shù)從
變?yōu)?
,如式子(1.3)所示。
其中 為1x1卷積核。通道可分離卷積就是將傳統(tǒng)卷積(1.1)分解為了(1.2)(1.3)兩個(gè)步驟。
通過(guò)這種優(yōu)化,可以知道卷積核參數(shù)量變?yōu)?span>,而計(jì)算量變?yōu)?span>
? FLOPs。雖然理論上,深度可分離網(wǎng)絡(luò)的確減少了計(jì)算量和參數(shù)量,但是實(shí)際上,因?yàn)樯疃瓤煞蛛x網(wǎng)絡(luò)的實(shí)現(xiàn)使得訪存1(memory access)過(guò)程占據(jù)了主導(dǎo),使得實(shí)際計(jì)算占用率過(guò)小,限制了硬件的并行計(jì)算能力。
Fig 1.2 深度可分離卷積,對(duì)于輸入張量的每一個(gè)通道,都有其專有的卷積核進(jìn)行卷積,最后通過(guò)一個(gè)1x1的卷積即可完成通道數(shù)量的放縮。
我們用傳統(tǒng)卷積和深度可分離卷積的計(jì)算/訪存系數(shù)進(jìn)行比較(僅考慮最基本的訪存,即是將每個(gè)操作數(shù)都從內(nèi)存中獲取,而不考慮由于局部性原理[6],而對(duì)重復(fù)的操作數(shù)進(jìn)行訪問(wèn)導(dǎo)致的消耗):
式子(1.4)和(1.5)的比較最終會(huì)化簡(jiǎn)為比較 和
的大小,越小意味著計(jì)算效率越高。我們發(fā)現(xiàn),傳統(tǒng)的卷積反而比深度可分離卷積的計(jì)算效率高得多。這是不利于程序并行計(jì)算的。
為此,文章[1]提出了Shift卷積算子,嘗試解決這種問(wèn)題。Shift卷積算子
在Shift卷積算子中,其基本思路也是類似于深度可分離卷積的設(shè)計(jì),將卷積分為空間域和通道域的卷積,通道域的卷積同樣是通過(guò)1x1卷積實(shí)現(xiàn)的,而在空間域卷積中,引入了shift操作。我們接下來(lái)會(huì)詳細(xì)地探討shift操作的設(shè)計(jì)啟發(fā),細(xì)節(jié)和推導(dǎo)。
Fig 2.1 基于Shift的卷積可以分為Shift卷積算子和1x1卷積操作。
shift卷積算子的數(shù)學(xué)形式表達(dá)如式子(2.1)所示,如圖Fig 2.1所示,shift卷積的每一個(gè)卷積核都是一個(gè)“獨(dú)熱”的算子,其卷積核只有一個(gè)元素為1,其他全部為0,如式子(2.2)所示。類似于深度可分離卷積,對(duì)于輸入的 M個(gè)通道的張量,分別對(duì)應(yīng)了 M個(gè)Shift卷積核,如Fig 2.1的不同顏色的卷積核所示。
我們把其中一個(gè)通道的shift卷積操作拿出來(lái)分析,如Fig 2.2所示。我們發(fā)現(xiàn),shift卷積過(guò)程相當(dāng)于將原輸入的矩陣在某個(gè)方向進(jìn)行平移,這也是為什么該操作稱之為shift的原因。雖然簡(jiǎn)單的平移操作似乎沒(méi)有提取到空間信息,但是考慮到我們之前說(shuō)到的,通道域是空間域信息的層次化擴(kuò)散。因此通過(guò)設(shè)置不同方向的shift卷積核,可以將輸入張量不同通道進(jìn)行平移,隨后配合1x1卷積實(shí)現(xiàn)跨通道的信息融合,即可實(shí)現(xiàn)空間域和通道域的信息提取。
Fig 2.2 shift卷積算子中的卷積操作,經(jīng)過(guò)填充后如(2)所示,我們發(fā)現(xiàn),shift卷積相當(dāng)于將原輸入矩陣在某個(gè)方向進(jìn)行平移。
我們發(fā)現(xiàn)shift卷積的本質(zhì)是特定內(nèi)存的訪問(wèn),可學(xué)習(xí)參數(shù)只是集中在1x1卷積操作中。因此如果實(shí)現(xiàn)得當(dāng),shift卷積是不占用額外的計(jì)算量和參數(shù)量的,結(jié)合shift卷積,只使用1x1卷積即可提取到結(jié)構(gòu)化層次化的空間域信息,因此大大減少了卷積網(wǎng)絡(luò)設(shè)計(jì)的參數(shù)量和計(jì)算量。
然而我們注意到,對(duì)于一個(gè)卷積核大小為,通道數(shù)為M的卷積核而言,其可能的搜索空間為
,在學(xué)習(xí)過(guò)程中窮盡這個(gè)搜索空間是不太現(xiàn)實(shí)的。為了減少搜索空間,[1]采用了一種簡(jiǎn)單的啟發(fā)式設(shè)計(jì):將 M個(gè)通道均勻地分成
個(gè)組,我們將每個(gè)組稱之為 平移組(shift group)。每個(gè)組有
個(gè)通道,這些通道都采用相同的平移方向。當(dāng)然,有可能存在除不盡的情況,這個(gè)時(shí)候?qū)?huì)有一些通道不能被劃分到任意一個(gè)組內(nèi),這些剩下的通道都稱之為“居中”組,如Fig 2.3所示,其中心元素為1,其他為0,也即是對(duì)原輸入不進(jìn)行任何處理。
Fig 2.3 居中組的中心元素為1,其他元素都為0。
雖然通過(guò)這種手段大大縮小了搜索空間,但是仍然需要讓模型學(xué)出如何將第 m個(gè)通道映射到第個(gè)平移組的最佳排列規(guī)則,這仍然是一個(gè)很大的搜索空間。為了解決這個(gè)問(wèn)題,以下需要提出一種方法,其能夠使得shift卷積層的輸出和輸入是關(guān)于通道排序無(wú)關(guān)的。假設(shè)
表示是在以
為通道排序的shift卷積操作,那么公式(2.1)可以表示為
,如果我們?cè)谶M(jìn)行該卷積之前,先后進(jìn)行兩次通道排序,分別是
??和
,那么我們有:
其中
表示算子組合。令
和
分別表示1x1卷積操作,我們有式子(2.4)
這一點(diǎn)不難理解,即便對(duì)1x1卷積的輸入進(jìn)行通道排序重組,在學(xué)習(xí)過(guò)程中,通過(guò)算法去調(diào)整1x1卷積的參數(shù)的順序,就可以通過(guò)構(gòu)造的方式,實(shí)現(xiàn) 和
之間的雙射(bijective)。如式子(2.5)所示,就結(jié)論而言,不需要考慮通道的排序,比如只需要依次按著順序賦值某個(gè)平移組,使得其不重復(fù)即可。通過(guò)用1x1卷積“三明治”夾著shift卷積的操作,從理論上可以等價(jià)于其他任何形式的通道排序后的結(jié)果。這點(diǎn)比較繞,有疑問(wèn)的讀者請(qǐng)?jiān)谠u(píng)論區(qū)留言。
根據(jù)以上討論,根據(jù)shift算子構(gòu)建出來(lái)的卷積模塊類似于Fig 2.4所示,注意到藍(lán)色實(shí)線塊的 1x1 conv -> shift kernel -> 1x1 conv
正是和我們的討論一樣的結(jié)構(gòu),而Identity塊則是考慮到仿照ResNet的設(shè)計(jì)補(bǔ)充的short cut鏈路。藍(lán)色虛線塊的shift塊是實(shí)驗(yàn)補(bǔ)充的一個(gè)設(shè)計(jì),存在虛線部分的shift塊的設(shè)計(jì)稱之為結(jié)構(gòu),只存在實(shí)線部分的設(shè)計(jì)則稱之為
結(jié)構(gòu)。
Fig 2.4 基于shift卷積算子構(gòu)建的ResNet網(wǎng)絡(luò)基本模塊。
shift卷積算子的有效性在文章[1]設(shè)置了很多實(shí)驗(yàn)進(jìn)行對(duì)比,這里只給出證實(shí)其在分類任務(wù)上精度和計(jì)算量/參數(shù)量的一個(gè)比較,如Fig 2.5所示,我們發(fā)現(xiàn)shift算子的確在計(jì)算量和參數(shù)量上有著比較大的優(yōu)勢(shì)。
exp_resultFig 2.5 shift卷積網(wǎng)絡(luò)在CIFAR10/100分類任務(wù)上的表現(xiàn)對(duì)比表。
在[7]中有shift卷積算子前向和反向計(jì)算的cuda代碼,其主要操作就是進(jìn)行卷積輸入張量的訪存選擇。有興趣的讀者可以自行移步去閱讀。
那么我只需要固定某個(gè)特定的索引順序即可,最簡(jiǎn)單的方式如Fig a1所示,按行列排列的順序遍歷設(shè)置即可,并不需要對(duì)其進(jìn)行shuffle,因?yàn)榭梢宰C實(shí)其本身都是可以通過(guò)結(jié)合1x1卷積的方式學(xué)習(xí)出來(lái)的(不同的索引順序?qū)W習(xí)出來(lái)的卷積參數(shù)不同,但是如果看成整體的話,它們是等價(jià)的)。因此文章里面應(yīng)該是不需要進(jìn)行shift組的排序shuffle的。
shift_groupFig a1. 按順序排列的shift算子組示例。
說(shuō)到如何實(shí)現(xiàn)shift的訪存優(yōu)化機(jī)制,我們可以先看看shift-gcn是怎么做實(shí)現(xiàn)的,具體見(jiàn)文章[8]。當(dāng)然,本文并不是采用shift-gcn定義的那種shift圖卷積,我們回到開(kāi)源的[7]中的具體代碼段進(jìn)行分析。我不是很熟悉cuda編程,只能作初步的分析,比如代碼段[9]:
__global__ void shiftnet_cuda_moduloshift3x3_nchw_float32_kernel_tilein16x16_tileout14x14(
float *src,
float *dst,
int num_h_tiles,
int num_w_tiles,
int batch_sz,
int channels,
int height,
int width)
{
__shared__ float cache[256];
const int num_blocks = batch_sz * channels * num_h_tiles * num_w_tiles;
const int num_threads = blockDim.x * num_blocks;
const int rd_chans = (channels / 9) * 9;
for (int idx = threadIdx.x + blockDim.x * blockIdx.x;
idx < num_threads; idx += blockDim.x * gridDim.x)
{
const int w_tile_idx = (idx / 256) % num_w_tiles;
const int h_tile_idx = ((idx / 256) / num_w_tiles) % num_h_tiles;
const int tile_ch = (((idx / 256) / num_w_tiles) / num_h_tiles) % channels;
const int tile_batch_idx = ((((idx / 256) / num_w_tiles) / num_h_tiles) / channels) % batch_sz;
const int w_shift = ((tile_ch % 3) - 1) * (tile_ch < rd_chans);
const int h_shift = (((tile_ch / 3) % 3) - 1) * (tile_ch < rd_chans);
const int w_tile_off = threadIdx.x % 16;
const int h_tile_off = threadIdx.x / 16;
const int w_idx = w_tile_off - 1 + 14 * w_tile_idx;
const int h_idx = h_tile_off - 1 + 14 * h_tile_idx;
const int buf_idx = w_idx + width * (h_idx + height * (tile_ch + channels * tile_batch_idx));
if (w_idx >= 0 && w_idx < width && h_idx >= 0 && h_idx < height) {
cache[threadIdx.x] = src[buf_idx];
} else {
cache[threadIdx.x] = 0.0f;
}
__syncthreads();
if (w_tile_off >= 1 && w_tile_off < 15 && h_tile_off >= 1 && h_tile_off < 15) {
if (w_idx >= 0 && w_idx < width && h_idx >= 0 && h_idx < height) {
const int cache_idx = (w_tile_off + w_shift) + 16 * (h_tile_off + h_shift);
dst[buf_idx] = cache[cache_idx];
}
}
__syncthreads();
}
}
它的實(shí)現(xiàn)并沒(méi)有完全優(yōu)化完全,因?yàn)樗麤](méi)有結(jié)合后續(xù)的1x1卷積進(jìn)行優(yōu)化,他只是進(jìn)行了將某行(或列)置位0,代碼是:
if (w_idx >= 0 && w_idx < width && h_idx >= 0 && h_idx < height) {
cache[threadIdx.x] = src[buf_idx];
} else {
cache[threadIdx.x] = 0.0f;
}
}
然后進(jìn)行移位:
if (w_tile_off >= 1 && w_tile_off < 15 && h_tile_off >= 1 && h_tile_off < 15) {
if (w_idx >= 0 && w_idx < width && h_idx >= 0 && h_idx < height) {
const int cache_idx = (w_tile_off + w_shift) + 16 * (h_tile_off + h_shift);
dst[buf_idx] = cache[cache_idx];
}
}
我覺(jué)得這個(gè)并不是最優(yōu)化后的結(jié)果,最優(yōu)化應(yīng)該是指定某個(gè)方塊(比如不包括第一列的其他所有數(shù)據(jù)),與后續(xù)的1x1卷積聯(lián)合起來(lái),只對(duì)這些方塊進(jìn)行卷積,這才是真正的訪存優(yōu)化,顯然這樣難度太大,因此它的實(shí)現(xiàn)并沒(méi)有這樣做。(正如我所說(shuō)的,我不是很熟悉cuda,有謬誤請(qǐng)指出。)
最后需要指出的是,它并不是one-hot矩陣點(diǎn)乘,而是卷積。
以上。
Reference
[1]. Wu, B., Wan, A., Yue, X., Jin, P., Zhao, S., Golmant, N., … & Keutzer, K. (2018). Shift: A zero flop, zero parameter alternative to spatial convolutions. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 9127-9135).
[2]. Cheng, K., Zhang, Y., He, X., Chen, W., Cheng, J., & Lu, H. (2020). Skeleton-Based Action Recognition With Shift Graph Convolutional Network. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 183-192).
[3]. https://github.com/peterhj/shiftnet_cuda_v2
[4]. Howard, Andrew G., Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang, Tobias Weyand, Marco Andreetto, and Hartwig Adam. “Mobilenets: Efficient convolutional neural networks for mobile vision applications.” arXiv preprint arXiv:1704.04861 (2017).
[5]. Chollet, F. (2017). Xception: Deep learning with depthwise separable convolutions. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 1251-1258).
[6]. https://baike.baidu.com/item/%E5%B1%80%E9%83%A8%E6%80%A7%E5%8E%9F%E7%90%86
[7]. https://github.com/peterhj/shiftnet_cuda_v2/blob/master/src/shiftnet_cuda_kernels.cu
[8]. https://fesian.blog.csdn.net/article/details/109644297
[9]. https://github.com/peterhj/shiftnet_cuda_v2/blob/4d471bd744751ff0fd6cf5acd518e9484cc70a98/src/shiftnet_cuda_kernels.cu#L25