筆者在前文 [2] 中曾經(jīng)介紹過VQ-VAE模型,如Fig 1.所示,該模型基于最近鄰查找的方式從字典中查找其索引,作為其稀疏化后的令牌,具體細(xì)節(jié)可見博文[2]。
Fig 1. 通過最近鄰方法在字典里面查找稀疏令牌,作為稀疏編碼的結(jié)果,然后通過反查字典可以對feature map進(jìn)行恢復(fù)。整個(gè)框架中有若干參數(shù)需要學(xué)習(xí),分別是encoder,decoder網(wǎng)絡(luò)參數(shù)和Embedding space字典的參數(shù)。然而稀疏編碼的過程由于出現(xiàn)了最近鄰方法,這個(gè)過程顯然是無法傳遞梯度的,為了實(shí)現(xiàn)編碼器的更新,可以考慮將解碼器的梯度直接拷貝到編碼器中。假設(shè)對于編碼后恢復(fù)的而言,其每個(gè)元素表示為
,那么對于其中某個(gè)元素的梯度表示為
,同理,對于編碼后的
而言,同樣有
? ,令
? 。
那么對于編碼器的梯度就可以表示為 。在詳細(xì)分析代碼實(shí)現(xiàn)邏輯之前,讓我們回顧下其損失函數(shù),如(1-1)所示,其中的
為停止梯度函數(shù),表示該函數(shù)無梯度傳導(dǎo)。decoder的參數(shù)通過第一項(xiàng)損失項(xiàng)進(jìn)行更新(這部分損失可通過MSE損失
建模),稱之為重建損失。encoder參數(shù)通過第一項(xiàng)和第三項(xiàng)損失進(jìn)行更新,其中第一項(xiàng)是重建損失,第三項(xiàng)是為了encoder編碼產(chǎn)出和embedding space進(jìn)行對齊而設(shè)計(jì)的,由于此時(shí)通過
函數(shù)停止了梯度,因此此時(shí)
的參數(shù)不會得到更新。Embedding space的參數(shù)通過第二項(xiàng)損失項(xiàng)進(jìn)行更新,通過將encoder編碼結(jié)果進(jìn)行停止梯度,我們只對E \mathcal{E}E進(jìn)行參數(shù)更新。
Fig 2. 通過梯度拷貝,將decoder的梯度拷貝到encoder中。
那么在代碼中如何實(shí)現(xiàn)這些邏輯呢?我們首先可以參考[3]項(xiàng)目中的實(shí)現(xiàn)。我們首先分析model.py文件中的forward函數(shù),字典定義為一個(gè)nn.Embedding層(Code 1.1),其參數(shù)就是self.dict.weight,那么求最近鄰的操作就如Code 1.2所示。Code 1.3將最近鄰的索引結(jié)果(也即是稀疏化后的視覺令牌),在字典中進(jìn)行查詢,對feature map進(jìn)行恢復(fù)。因此W_j的形狀和Z是一致的。此時(shí)Code 1.4中對Z和W_j進(jìn)行detach,這個(gè)detach的作用之前在博文[4]中闡述過,本文不進(jìn)行累述,其主要作用可視為是停止了該節(jié)點(diǎn)開始的梯度傳導(dǎo),也即是用于實(shí)現(xiàn)公式(1-1)中的和
。
Code 1. model.py的主要邏輯
def __init__(self,...):
...
self.dict = nn.Embedding(k_dim, z_dim) # Code 1.1
def forward(self, x):
h = self.encoder(x) # (?, z_dim*2, 1, 1)
sz = h.size()
# BCWH -> BWHC
org_h = h
h = h.permute(0,2,3,1)
h = h.contiguous()
Z = h.view(-1,self.z_dim)
W = self.dict.weight
# Code 1.2
def L2_dist(a,b):
return ((a - b) ** 2)
# Sample nearest embedding
j = L2_dist(Z[:,None],W[None,:]).sum(2).min(1)[1]
# Code 1.3
W_j = W[j]
# Code 1.4, Stop gradients
Z_sg = Z.detach()
W_j_sg = W_j.detach()
# BWHC -> BCWH
h = W_j.view(sz[0],sz[2],sz[3],sz[1])
h = h.permute(0,3,1,2)
# Code 1.5, gradient hook register
def hook(grad):
nonlocal org_h
self.saved_grad = grad
self.saved_h = org_h
return grad
h.register_hook(hook)
# Code 1.6, losses
return self.decoder(h), L2_dist(Z,W_j_sg).sum(1).mean(), L2_dist(Z_sg,W_j).sum(1).mean()
# Code 1.7, back propagation for encoder
def bwd(self):
self.saved_h.backward(self.saved_grad)
此時(shí)有一個(gè)比較有意思的函數(shù)調(diào)用,如Code 1.5所示,此處的h.register_hook(hook_fn)表示對張量h注冊了個(gè)回調(diào)鉤子函數(shù) hook_fn,我們先看下這個(gè)函數(shù)具體作用是什么,從官網(wǎng)的API信息[5]中可以知道,當(dāng)每次對這個(gè)張量進(jìn)行梯度計(jì)算的時(shí)候,都會調(diào)用這個(gè)回調(diào)函數(shù)hook_fn。hook_fn的輸入是該張量的原始梯度grad_orig,hook_fn會對梯度進(jìn)行變換得到grad_new = hook_fn(grad_orig),并且將grad_orig更新為grad_new。這個(gè)功能可以讓我們實(shí)現(xiàn)將decoder的梯度賦值到encoder中,我們且看是如何實(shí)現(xiàn)的。我們留意到其對h,也即是W_j的結(jié)果進(jìn)行了注冊回調(diào),我們也知道W_j和Z的形狀是一致的,此時(shí)我們希望 ,因此我們需要以某種方式緩存下Z和W_j的梯度,在梯度反向傳播的時(shí)候,將W_j的梯度賦值到Z的梯度上,這也就是回調(diào)hook的目的——緩存下此時(shí)W_j的梯度和原始的Z節(jié)點(diǎn)。 在Code 1.6就開始構(gòu)建decoder的輸出以及
? 和
這兩個(gè)loss了,那么何時(shí)我們對其encoder的梯度進(jìn)行賦值呢?我們繼續(xù)看到solver.py文件~
def hook(grad):
nonlocal org_h
self.saved_grad = grad
self.saved_h = org_h
return grad
在solver.py中,最主要的邏輯如下所示,其中的self.G(x)即是Code 1所示的forward()邏輯,對于其輸出的解碼器輸出out,構(gòu)建重建損失,對重建損失loss_rec和其他倆對齊損失loss_e1和loss_e2進(jìn)行加和后得到loss,對loss進(jìn)行梯度計(jì)算(注意此時(shí)需要將retain_graph設(shè)置為True,以保留葉子節(jié)點(diǎn)的梯度,具體作用見博文[6])。注意到此時(shí)由于最近鄰查表的引入,loss.backward(retain_graph=True)只對decoder進(jìn)行了梯度計(jì)算,此時(shí)為了對encoder也進(jìn)行梯度計(jì)算,還需要進(jìn)行self.G.bwd(),這個(gè)也正是我們剛才提到的,將W_j的梯度賦值到Z的梯度上,我們且看看如何實(shí)現(xiàn)的。如Code 1.7所示,self.G.bwd()的邏輯很簡單,對緩存的Z進(jìn)行梯度『賦值』為緩存下來的W_j梯度,但是準(zhǔn)確的說,此處并不是對Z的梯度賦值,而是制定了計(jì)算Z梯度的前繼梯度為self.saved_grad(梯度計(jì)算是鏈?zhǔn)椒▌t,這意味著梯度計(jì)算勢必有前繼和后續(xù)),我們在附錄里面會舉個(gè)例子說明tensor.backward()和tensor.register_hook()的作用??偠灾?,通過調(diào)用self.G.bwd()我們可以對encoder的梯度也進(jìn)行計(jì)算了,最后調(diào)用optimizer.step()進(jìn)行參數(shù)更新即可了。
def bwd(self):
self.saved_h.backward(self.saved_grad)
Code 2. solver.py的主要邏輯
# ================== Train G ================== #
# Train with real images (VQ-VAE)
out, loss_e1, loss_e2 = self.G(x)
loss_rec = reconst_loss(out, x)
loss = loss_rec + loss_e1 + self.vq_beta * loss_e2
self.g_optimizer.zero_grad()
# For decoder
loss.backward(retain_graph=True)
# For encoder
self.G.bwd()
self.g_optimizer.step()
附錄A. tensor.backward()和tensor.register_hook()的作用
>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
>>> h = v.register_hook(lambda grad: grad * 2) # 梯度翻倍
>>> v.backward(torch.tensor([1., 2., 3.])) # v的梯度前繼為[1, 2, 3]
>>> v.grad # 因此輸出的梯度為[2, 4, 6]
2
4
6
[torch.FloatTensor of size (3,)]
>>> h.remove() # removes the hook
Reference
[1]. Van Den Oord, Aaron, and Oriol Vinyals. “Neural discrete representation learning.” Advances in neural information processing systems 30 (2017).
[2]. https://blog.csdn.net/LoseInVain/article/details/129224424, 【論文極速讀】VQ-VAE:一種稀疏表征學(xué)習(xí)方法
[3]. https://github.com/nakosung/VQ-VAE
[4]. https://blog.csdn.net/LoseInVain/article/details/105461904, 在pytorch中停止梯度流的若干辦法,避免不必要模塊的參數(shù)更新
[5]. https://pytorch.org/do