性无码一区二区三区在线观看,少妇被爽到高潮在线观看,午夜精品一区二区三区,无码中文字幕人妻在线一区二区三区,无码精品国产一区二区三区免费

徐土豆
認(rèn)證:優(yōu)質(zhì)創(chuàng)作者
所在專題目錄 查看專題
圖文多模態(tài)語義融合前的語義對齊——一種單雙混合塔多模態(tài)模型
在多模態(tài)模型訓(xùn)練時(shí),如何合適地融合單模態(tài)損失
FILIP: 一種基于交互的細(xì)粒度圖文預(yù)訓(xùn)練模型
ERNIE VIL 2.0,多模態(tài)模型的一種多視角預(yù)訓(xùn)練范式
VQ-VAE的實(shí)現(xiàn)方法分析——一種基于梯度回調(diào)的方法
【論文極速讀】視頻檢索中的模態(tài)均衡方法
作者動態(tài) 更多
給定計(jì)算預(yù)算下的最佳LLM模型尺寸與預(yù)訓(xùn)練數(shù)據(jù)量分配
1天前
大模型推理時(shí)的尺度擴(kuò)展定律
2天前
世界多胞體與世界模型
1星期前
獎(jiǎng)勵(lì)模型中的尺度擴(kuò)展定律和獎(jiǎng)勵(lì)劫持
1星期前
MeCo——給預(yù)訓(xùn)練數(shù)據(jù)增加源信息,就能減少33%的訓(xùn)練量并且提升效果
2星期前

VQ-VAE的實(shí)現(xiàn)方法分析——一種基于梯度回調(diào)的方法

筆者在前文 [2] 中曾經(jīng)介紹過VQ-VAE模型,如Fig 1.所示,該模型基于最近鄰查找的方式從字典中查找其索引,作為其稀疏化后的令牌,具體細(xì)節(jié)可見博文[2]。

Fig 1. 通過最近鄰方法在字典里面查找稀疏令牌,作為稀疏編碼的結(jié)果,然后通過反查字典可以對feature map進(jìn)行恢復(fù)。整個(gè)框架中有若干參數(shù)需要學(xué)習(xí),分別是encoder,decoder網(wǎng)絡(luò)參數(shù)和Embedding space字典的參數(shù)。然而稀疏編碼的過程由于出現(xiàn)了最近鄰方法,這個(gè)過程顯然是無法傳遞梯度的,為了實(shí)現(xiàn)編碼器的更新,可以考慮將解碼器的梯度直接拷貝到編碼器中。假設(shè)對于編碼后恢復(fù)的而言,其每個(gè)元素表示為,那么對于其中某個(gè)元素的梯度表示為,同理,對于編碼后的而言,同樣有? ,令? 。

那么對于編碼器的梯度就可以表示為 。在詳細(xì)分析代碼實(shí)現(xiàn)邏輯之前,讓我們回顧下其損失函數(shù),如(1-1)所示,其中的為停止梯度函數(shù),表示該函數(shù)無梯度傳導(dǎo)。decoder的參數(shù)通過第一項(xiàng)損失項(xiàng)進(jìn)行更新(這部分損失可通過MSE損失建模),稱之為重建損失。encoder參數(shù)通過第一項(xiàng)和第三項(xiàng)損失進(jìn)行更新,其中第一項(xiàng)是重建損失,第三項(xiàng)是為了encoder編碼產(chǎn)出和embedding space進(jìn)行對齊而設(shè)計(jì)的,由于此時(shí)通過函數(shù)停止了梯度,因此此時(shí)的參數(shù)不會得到更新。Embedding space的參數(shù)通過第二項(xiàng)損失項(xiàng)進(jìn)行更新,通過將encoder編碼結(jié)果進(jìn)行停止梯度,我們只對E \mathcal{E}E進(jìn)行參數(shù)更新。

Fig 2. 通過梯度拷貝,將decoder的梯度拷貝到encoder中。

那么在代碼中如何實(shí)現(xiàn)這些邏輯呢?我們首先可以參考[3]項(xiàng)目中的實(shí)現(xiàn)。我們首先分析model.py文件中的forward函數(shù),字典定義為一個(gè)nn.Embedding層(Code 1.1),其參數(shù)就是self.dict.weight,那么求最近鄰的操作就如Code 1.2所示。Code 1.3將最近鄰的索引結(jié)果(也即是稀疏化后的視覺令牌),在字典中進(jìn)行查詢,對feature map進(jìn)行恢復(fù)。因此W_j的形狀和Z是一致的。此時(shí)Code 1.4中對Z和W_j進(jìn)行detach,這個(gè)detach的作用之前在博文[4]中闡述過,本文不進(jìn)行累述,其主要作用可視為是停止了該節(jié)點(diǎn)開始的梯度傳導(dǎo),也即是用于實(shí)現(xiàn)公式(1-1)中的。

Code 1. model.py的主要邏輯

def __init__(self,...):
	...
	self.dict = nn.Embedding(k_dim, z_dim) # Code 1.1
	
def forward(self, x):
     h = self.encoder(x) # (?, z_dim*2, 1, 1)
     sz = h.size()
     
     # BCWH -> BWHC
     org_h = h
     h = h.permute(0,2,3,1)
     h = h.contiguous()
     Z = h.view(-1,self.z_dim)
     W = self.dict.weight
	 
	 # Code 1.2
     def L2_dist(a,b):
         return ((a - b) ** 2)
     # Sample nearest embedding
     j = L2_dist(Z[:,None],W[None,:]).sum(2).min(1)[1]
	 
	 # Code 1.3
     W_j = W[j]

     # Code 1.4, Stop gradients
     Z_sg = Z.detach()
     W_j_sg = W_j.detach()

     # BWHC -> BCWH
     h = W_j.view(sz[0],sz[2],sz[3],sz[1])
     h = h.permute(0,3,1,2)
	 
	 # Code 1.5, gradient hook register
     def hook(grad):
         nonlocal org_h
         self.saved_grad = grad
         self.saved_h = org_h
         return grad

     h.register_hook(hook)
     
     # Code 1.6, losses
     return self.decoder(h), L2_dist(Z,W_j_sg).sum(1).mean(), L2_dist(Z_sg,W_j).sum(1).mean()

# Code 1.7, back propagation for encoder
def bwd(self):
    self.saved_h.backward(self.saved_grad)

此時(shí)有一個(gè)比較有意思的函數(shù)調(diào)用,如Code 1.5所示,此處的h.register_hook(hook_fn)表示對張量h注冊了個(gè)回調(diào)鉤子函數(shù) hook_fn,我們先看下這個(gè)函數(shù)具體作用是什么,從官網(wǎng)的API信息[5]中可以知道,當(dāng)每次對這個(gè)張量進(jìn)行梯度計(jì)算的時(shí)候,都會調(diào)用這個(gè)回調(diào)函數(shù)hook_fn。hook_fn的輸入是該張量的原始梯度grad_orig,hook_fn會對梯度進(jìn)行變換得到grad_new = hook_fn(grad_orig),并且將grad_orig更新為grad_new。這個(gè)功能可以讓我們實(shí)現(xiàn)將decoder的梯度賦值到encoder中,我們且看是如何實(shí)現(xiàn)的。我們留意到其對h,也即是W_j的結(jié)果進(jìn)行了注冊回調(diào),我們也知道W_j和Z的形狀是一致的,此時(shí)我們希望 ,因此我們需要以某種方式緩存下Z和W_j的梯度,在梯度反向傳播的時(shí)候,將W_j的梯度賦值到Z的梯度上,這也就是回調(diào)hook的目的——緩存下此時(shí)W_j的梯度和原始的Z節(jié)點(diǎn)。 在Code 1.6就開始構(gòu)建decoder的輸出以及? 和這兩個(gè)loss了,那么何時(shí)我們對其encoder的梯度進(jìn)行賦值呢?我們繼續(xù)看到solver.py文件~

def hook(grad):
	nonlocal org_h
	 self.saved_grad = grad
	 self.saved_h = org_h
	 return grad

在solver.py中,最主要的邏輯如下所示,其中的self.G(x)即是Code 1所示的forward()邏輯,對于其輸出的解碼器輸出out,構(gòu)建重建損失,對重建損失loss_rec和其他倆對齊損失loss_e1和loss_e2進(jìn)行加和后得到loss,對loss進(jìn)行梯度計(jì)算(注意此時(shí)需要將retain_graph設(shè)置為True,以保留葉子節(jié)點(diǎn)的梯度,具體作用見博文[6])。注意到此時(shí)由于最近鄰查表的引入,loss.backward(retain_graph=True)只對decoder進(jìn)行了梯度計(jì)算,此時(shí)為了對encoder也進(jìn)行梯度計(jì)算,還需要進(jìn)行self.G.bwd(),這個(gè)也正是我們剛才提到的,將W_j的梯度賦值到Z的梯度上,我們且看看如何實(shí)現(xiàn)的。如Code 1.7所示,self.G.bwd()的邏輯很簡單,對緩存的Z進(jìn)行梯度『賦值』為緩存下來的W_j梯度,但是準(zhǔn)確的說,此處并不是對Z的梯度賦值,而是制定了計(jì)算Z梯度的前繼梯度為self.saved_grad(梯度計(jì)算是鏈?zhǔn)椒▌t,這意味著梯度計(jì)算勢必有前繼和后續(xù)),我們在附錄里面會舉個(gè)例子說明tensor.backward()和tensor.register_hook()的作用??偠灾?,通過調(diào)用self.G.bwd()我們可以對encoder的梯度也進(jìn)行計(jì)算了,最后調(diào)用optimizer.step()進(jìn)行參數(shù)更新即可了。

def bwd(self):
    self.saved_h.backward(self.saved_grad)

Code 2. solver.py的主要邏輯

# ================== Train G ================== #
# Train with real images (VQ-VAE)
out, loss_e1, loss_e2 = self.G(x)
loss_rec = reconst_loss(out, x)

loss = loss_rec + loss_e1 + self.vq_beta * loss_e2
self.g_optimizer.zero_grad()

# For decoder
loss.backward(retain_graph=True)

# For encoder
self.G.bwd()

self.g_optimizer.step()

附錄A. tensor.backward()和tensor.register_hook()的作用

>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
>>> h = v.register_hook(lambda grad: grad * 2)  # 梯度翻倍
>>> v.backward(torch.tensor([1., 2., 3.])) # v的梯度前繼為[1, 2, 3]
>>> v.grad # 因此輸出的梯度為[2, 4, 6]

 2
 4
 6
[torch.FloatTensor of size (3,)]

>>> h.remove()  # removes the hook

Reference

[1]. Van Den Oord, Aaron, and Oriol Vinyals. “Neural discrete representation learning.” Advances in neural information processing systems 30 (2017).

[2]. https://blog.csdn.net/LoseInVain/article/details/129224424, 【論文極速讀】VQ-VAE:一種稀疏表征學(xué)習(xí)方法

[3]. https://github.com/nakosung/VQ-VAE

[4]. https://blog.csdn.net/LoseInVain/article/details/105461904, 在pytorch中停止梯度流的若干辦法,避免不必要模塊的參數(shù)更新

[5]. https://pytorch.org/do

聲明:本內(nèi)容為作者獨(dú)立觀點(diǎn),不代表電子星球立場。未經(jīng)允許不得轉(zhuǎn)載。授權(quán)事宜與稿件投訴,請聯(lián)系:editor@netbroad.com
覺得內(nèi)容不錯(cuò)的朋友,別忘了一鍵三連哦!
贊 1
收藏 2
關(guān)注 52
成為作者 賺取收益
全部留言
0/200
成為第一個(gè)和作者交流的人吧