語義分割之dice loss深度分析(梯度視覺化)
dice loss 來自文章VNet(V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation),旨在應對語義分割中正負樣本強烈不平衡的場景。本文透過理論推導和實驗驗證的方式對dice loss進行解析,幫助大家去更好的理解和使用。
dice loss 定義
dice loss 來自 dice coefficient,是一種用於評估兩個樣本的相似性的度量函式,取值範圍在0到1之間,取值越大表示越相似。dice coefficient定義如下:
其中其中
是X和Y之間的交集,
和
分表表示X和Y的元素的個數,分子乘2為了保證分母重複計算後取值範圍在
之間。
因此dice loss可以寫為:
對於二分類問題,一般預測值分為以下幾種:
TP: true positive,真陽性,預測是陽性,預測對了,實際也是正例。
TN: true negative,真陰性,預測是陰性,預測對了,實際也是負例。
FP: false positive,假陽性,預測是陽性,預測錯了,實際是負例。
FN: false negative,假陰性,預測是陰性,預測錯了,實際是正例。
這裡dice coefficient可以寫成如下形式:
而我們知道:
可見dice coefficient是等同
F1 score
,直觀上dice coefficient是計算$X$與$Y$的相似性,本質上則同時隱含precision和recall兩個指標。可見dice loss是直接最佳化
F1 score
。
這裡考慮通用的實現方式來表達,定義:
其中
為為網路預測值,是經過sigmoid或softmax的值,取值在
之間。
為target值,取值非0即1。
dice loss 有以下幾種形式:
形式1
:
形式2(原論文形式)
:
形式3
:
為加平方的方式獲取:
為一個極小的數,一般稱為平滑係數,有兩個作用:
防止分母預測為0。值得說明的是,一般分割網路輸出經過sigmoid 或 softmax,是不存在輸出為絕對0的情況。這裡加平滑係數主要防止一些極端情況,輸出位數太小而導致編譯器丟失數位的情況。
平滑係數可以起到平滑loss和梯度的操作。
不同實現形式計算不同,但本質並無太大區別,本文主要討論形式1。下面為pytorch的實現方式:
def
dice_loss
(
target
,
predictive
,
ep
=
1e-8
):
intersection
=
2
*
torch
。
sum
(
predictive
*
target
)
+
ep
union
=
torch
。
sum
(
predictive
)
+
torch
。
sum
(
target
)
+
ep
loss
=
1
-
intersection
/
union
return
loss
梯度分析
從dice loss的定義可以看出,dice loss 是一種
區域相關
的loss。意味著某畫素點的loss以及梯度值不僅和該點的label以及預測值相關,和其他點的label以及預測值也相關,這點和ce (交叉熵cross entropy) loss 不同。因此分析起來比較複雜,這裡我們簡化一下,首先從loss曲線和求導曲線對單點輸出方式分析。然後對於多點輸出的情況,利用模擬預測輸出來分析其梯度。而多分類softmax是sigmoid的一種推廣,本質一樣,所以這裡只考慮sigmoid輸出的二分類問題,首先sigmoid函式定義如下:
求導:
單點輸出的形式
單點輸出的情況是網路輸出的是一個數值而不是一個map,單點輸出的dice loss公式如下:
繪製曲線圖如下,其中藍色的為ce loss,橙色的為dice loss。
當
時,
在一個較大的範圍內,loss的值都很大接近1。只有
預測非常小,
接近於0(和
量級相近)時loss才會變小,而這種情況出現的機率也較小。一般情況下,在正常範圍內,預測不管為任何值,都無差別對待,loss 都統一非常大。
當
時,
在0左右較小的範圍內,保持不錯的特性。但隨著
遠離0點,loss呈現飽和現象。
計算梯度:
繪圖如下:
梯度正負符號代表梯度的方向,網路採用梯度下降法更新引數,當梯度為正時,引數更新變小,當梯度為負時引數更新變大。這裡為了討論正負樣本的梯度關係,所以取了絕對值操作。
當
時,同樣在
的正常範圍內,
的梯度值接近0 。實際上,由於平滑係數的存在,該梯度不為0,而是一個非常小的值 。該值過於小,對網路的貢獻也非常有限。
當
時,
在0點附近存在一個峰值,此時
接近0。5。隨著預測值
越接近1或0,梯度越小,出現梯度飽和的現象。
一般神經網路訓練之前都會採取權重初始化,不管是Xavier初始化還是Kaiming初始化(或者其他初始化的方法), 輸出
是接近於0的。再回到上面的圖,可見此時正樣本(
)的監督是遠遠大於負樣本(
)的監督,可以認為網路前期會重點挖掘正樣本。而ce loss 是平等對待兩種樣本的。
多點情況分析
dice loss 是應用於語義分割而不是分類任務,並且是一個區域相關的loss,因此更適合針對多點的情況進行分析。由於多點輸出的情況比較難用曲線呈現,這裡使用模擬預測值的形式觀察梯度的變化。
下圖為原始圖片和對應的label:
為了便於梯度視覺化,這裡對梯度求絕對值操作,因為我們關注的是梯度的大小而非方向。另外梯度值都乘以
保證在容易辨認的範圍。
首先定義如下熱圖,值越大,顏色越亮,反之亦然:
預測值變化(
值,圖上的數字為預測值區間):
dice loss 對應
值的梯度:
ce loss 對應
值的梯度:
可以看出:
一般情況下,dice loss 正樣本的梯度大於背景樣本的; 尤其是剛開始網路預測接近0。5的時候,這點和單點輸出的現象一致。說明 dice loss 更具有指向性,更加偏向於正樣本,保證有較低的FN。
負樣本(背景區域)也會產生梯度。
極端情況下,網路預測接近0或1時,對應點梯度值極小,dice loss 存在梯度飽和現象。此時預測失敗(FN,FP)的情況很難扭轉回來。不過該情況出現的機率較低,因為網路初始化輸出接近0。5,此時具有較大的梯度值。而網路透過梯度下降的方式更新引數,只會逐漸削弱預測失敗的畫素點。
對於ce loss,當前的點的梯度僅和當前預測值與label的距離相關,預測越接近label,梯度越小。當網路預測接近0或1時,梯度依然保持該特性。
對比發現, 訓練前中期,dice loss下正樣本的梯度值相對於ce loss,顏色更亮,值更大。說明dice loss 對挖掘正樣本更加有優勢。
dice loss為何能夠解決正負樣本不平衡問題?
因為dice loss是一個區域相關的loss。區域相關的意思就是,當前畫素的loss不光和當前畫素的預測值相關,和其他點的值也相關。dice loss的求交的形式可以理解為mask掩碼操作,因此不管圖片有多大,固定大小的正樣本的區域計算的loss是一樣的,對網路起到的監督貢獻不會隨著圖片的大小而變化。從上圖視覺化也發現,訓練更傾向於挖掘前景區域,正負樣本不平衡的情況就是前景佔比較小。而ce loss 會公平處理正負樣本,當出現正樣本佔比較小時,就會被更多的負樣本淹沒。
dice loss背景區域能否起到監督作用?
可以的,但是會小於前景區域。和直觀理解不同的是,隨著訓練的進行,背景區域也能產生較為可觀的梯度。這點和單點的情況分析不同。這裡求偏導,當
時:
可以看出, 背景區域的梯度是存在的,只有預測值命中的區域極小時, 背景梯度才會很小。
dice loss 為何訓練會很不穩定?
在使用dice loss時,一般正樣本為小目標時會產生嚴重的震盪。因為在只有前景和背景的情況下,小目標一旦有部分畫素預測錯誤,那麼就會導致loss值大幅度的變動,從而導致梯度變化劇烈。可以假設極端情況,只有一個畫素為正樣本,如果該畫素預測正確了,不管其他畫素預測如何,loss 就接近0,預測錯誤了,loss 接近1。而對於ce loss,loss的值是總體求平均的,更多會依賴負樣本的地方。
總結
dice loss 對正負樣本嚴重不平衡的場景有著不錯的效能,訓練過程中更側重對前景區域的挖掘。但訓練loss容易不穩定,尤其是小目標的情況下。另外極端情況會導致梯度飽和現象。因此有一些改進操作,主要是結合ce loss等改進,比如: dice+ce loss,dice + focal loss等,本文不再論述。