hinge loss是一種常用損失[1],常用于度量學(xué)習(xí)和表征學(xué)習(xí)。對(duì)于一個(gè)模型,如果給定了樣本x的標(biāo)簽y(假設(shè)標(biāo)簽是0/1標(biāo)簽,分別表示負(fù)樣本和正樣本),那么可以有兩種選擇進(jìn)行模型的表征學(xué)習(xí)。第一是pointwise形式的監(jiān)督學(xué)習(xí),通過(guò)交叉熵?fù)p失進(jìn)行模型訓(xùn)練,也即是如式子(1-1)所示。
其中的是softmax函數(shù)。第二種方式是將樣本之間組成如
的pair,通過(guò)hinge loss進(jìn)行pair的偏序關(guān)系學(xué)習(xí),其hinge loss可以描述為式子(1-2):
其中的和
分別表示負(fù)樣本和正樣本的打分,而m mm這是正樣本與負(fù)樣本之間打分的最小間隔。如Fig 1.所示,我們發(fā)現(xiàn)
,而
,從式子(1-2)中可以發(fā)現(xiàn),只有
會(huì)產(chǎn)生loss,而
? 則不會(huì)產(chǎn)生loss,這一點(diǎn)能防止模型過(guò)擬合一些簡(jiǎn)單的負(fù)樣本,而盡量去學(xué)習(xí)難負(fù)例。
從實(shí)現(xiàn)的角度出發(fā),我們通??梢圆捎孟旅娴姆绞綄?shí)現(xiàn),我們簡(jiǎn)單介紹下其實(shí)現(xiàn)邏輯。
import torch
import torch.nn.functional as F
margin = 0.3
for data in dataloader():
inputs, labels = data
score_orig = model(inputs) # score_orig shape (N, 1)
N = score_orig.shape[0]
score_1 = score_orig.expand(1, N) # score_1 shape (N, N)
score_2 = torch.transpose(score_1, 1, 0)
label_1 = label.expand(1, N) # label_1 shape (N, N)
label_2 = label_1.transpose(label_1, 1, 0)
label_diff = F.relu(label_1 - label_2)
score_diff = F.relu(score_2 - score_1 + margin)
hinge_loss = score_diff * label_diff
...
為了實(shí)現(xiàn)充分利用一個(gè)batch內(nèi)的樣本,我們希望對(duì)batch內(nèi)的所有樣本都進(jìn)行組pair,也就是說(shuō)當(dāng)batch size為N的時(shí)候,將會(huì)產(chǎn)出個(gè)pair(樣本自身不產(chǎn)生pair),為了實(shí)現(xiàn)這個(gè)目的,就需要代碼中expand和transpose這兩個(gè)操作,如Fig 2.所示,通過(guò)這兩個(gè)操作產(chǎn)出的score_1和score_2之差就是batch內(nèi)所有樣本之間的打分差,也就可以認(rèn)為是batch內(nèi)兩兩均組了pair。
Fig 2. 對(duì)score的處理流程圖
與此相似的,如Fig 3.所示,我們也對(duì)label進(jìn)行類似的處理,但是考慮到偏序已經(jīng)預(yù)測(cè)對(duì)了的pair不需要產(chǎn)生loss,而只有偏序錯(cuò)誤的pair需要產(chǎn)出loss,因此是label_1-label_2產(chǎn)出label_diff。通過(guò)F.relu()我們替代max()的操作,將不產(chǎn)出loss的pair進(jìn)行屏蔽,將score_diff和label_diff相乘就產(chǎn)出了hinge loss。
Fig 3. 對(duì)label處理的流程圖。
即便我們的label不是0/1標(biāo)簽,而是分檔標(biāo)簽,比如相關(guān)性中的0/1/2/3四個(gè)分檔,只要具有高檔位大于低檔位的這種物理含義(而不是分類標(biāo)簽),同樣也可以采用相同的方法進(jìn)行組pair,不過(guò)此時(shí)label_1-label_2產(chǎn)出的label_diff中會(huì)出現(xiàn)大于1的item,可視為是對(duì)某組pair的loss加權(quán),此時(shí)需要進(jìn)行標(biāo)準(zhǔn)化,代碼將會(huì)改成如下:
import torch
import torch.nn.functional as F
margin = 0.3
epsilon = 1e-6
for data in dataloader():
inputs, labels = data
score_orig = model(inputs) # score_orig shape (N, 1)
N = score_orig.shape[0]
score_1 = score_orig.expand(1, N) # score_1 shape (N, N)
score_2 = torch.transpose(score_1, 1, 0)
label_1 = label.expand(1, N) # label_1 shape (N, N)
label_2 = label_1.transpose(label_1, 1, 0)
label_diff = F.relu(label_1 - label_2)
score_diff = F.relu(score_2 - score_1 + margin)
hinge_loss = torch.sum(score_diff * label_diff) / (torch.sum(label_diff) + epsilon) # 標(biāo)準(zhǔn)化處理,加上epsilon防止溢出
...
Reference
[1]. https://blog.csdn.net/LoseInVain/article/details/103995962, 《一文理解Ranking Loss/Contrastive Loss/Margin Loss/Triplet Loss/Hinge Loss》