您當前的位置:首頁 > 收藏

語義分割之dice loss深度分析(梯度視覺化)

作者:由 皮特潘 發表于 收藏時間:2020-10-28

dice loss 來自文章VNet(V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation),旨在應對語義分割中正負樣本強烈不平衡的場景。本文透過理論推導和實驗驗證的方式對dice loss進行解析,幫助大家去更好的理解和使用。

dice loss 定義

dice loss 來自 dice coefficient,是一種用於評估兩個樣本的相似性的度量函式,取值範圍在0到1之間,取值越大表示越相似。dice coefficient定義如下:

dice=\frac{2|X\bigcap Y|}{|X|+|Y|}

其中其中

|X\bigcap Y|

是X和Y之間的交集,

|X|

|Y|

分表表示X和Y的元素的個數,分子乘2為了保證分母重複計算後取值範圍在

[0,1]

之間。

因此dice loss可以寫為:

L_{dice}=1-\frac{2|X\bigcap Y|}{|X|+|Y|}

對於二分類問題,一般預測值分為以下幾種:

TP: true positive,真陽性,預測是陽性,預測對了,實際也是正例。

TN: true negative,真陰性,預測是陰性,預測對了,實際也是負例。

FP: false positive,假陽性,預測是陽性,預測錯了,實際是負例。

FN: false negative,假陰性,預測是陰性,預測錯了,實際是正例。

語義分割之dice loss深度分析(梯度視覺化)

這裡dice coefficient可以寫成如下形式:

dice = \frac{2TP}{2TP+FP+FN}

而我們知道:

precision = \frac{TP}{TP+FP}, recall = \frac{TP}{TP+FN}

F_{1}-score = \frac{2*precision*recall}{precision+recall}=\frac{2TP}{2TP+FP+FN}=dice

可見dice coefficient是等同

F1 score

,直觀上dice coefficient是計算$X$與$Y$的相似性,本質上則同時隱含precision和recall兩個指標。可見dice loss是直接最佳化

F1 score

這裡考慮通用的實現方式來表達,定義:

I=\sum_1^N t_iy_i

U=\sum_1^N (t_i+y_i)=\sum_1^N t_i+\sum_1^N y_i

其中

y_i

為為網路預測值,是經過sigmoid或softmax的值,取值在

(0,1)

之間。

t_i

為target值,取值非0即1。

dice loss 有以下幾種形式:

形式1

L_{dice}=1-\frac{2I+\varepsilon}{U+\varepsilon}

形式2(原論文形式)

L_{dice}=1-\frac{I+\varepsilon}{U-I+\varepsilon}

形式3

U

為加平方的方式獲取:

U=\sum_1^N (t_i^2+y_i^2)

\varepsilon

為一個極小的數,一般稱為平滑係數,有兩個作用:

防止分母預測為0。值得說明的是,一般分割網路輸出經過sigmoid 或 softmax,是不存在輸出為絕對0的情況。這裡加平滑係數主要防止一些極端情況,輸出位數太小而導致編譯器丟失數位的情況。

平滑係數可以起到平滑loss和梯度的操作。

不同實現形式計算不同,但本質並無太大區別,本文主要討論形式1。下面為pytorch的實現方式:

def

dice_loss

target

predictive

ep

=

1e-8

):

intersection

=

2

*

torch

sum

predictive

*

target

+

ep

union

=

torch

sum

predictive

+

torch

sum

target

+

ep

loss

=

1

-

intersection

/

union

return

loss

梯度分析

從dice loss的定義可以看出,dice loss 是一種

區域相關

的loss。意味著某畫素點的loss以及梯度值不僅和該點的label以及預測值相關,和其他點的label以及預測值也相關,這點和ce (交叉熵cross entropy) loss 不同。因此分析起來比較複雜,這裡我們簡化一下,首先從loss曲線和求導曲線對單點輸出方式分析。然後對於多點輸出的情況,利用模擬預測輸出來分析其梯度。而多分類softmax是sigmoid的一種推廣,本質一樣,所以這裡只考慮sigmoid輸出的二分類問題,首先sigmoid函式定義如下:

y = sigmoid(x)=\frac{1}{1+e^{-x}}

求導:

\frac{dy}{dx}=\frac{e^{-x}}{(1+e^{-x})^2}=y(1-y)

單點輸出的形式

單點輸出的情況是網路輸出的是一個數值而不是一個map,單點輸出的dice loss公式如下:

L_{dice}=1-\frac{2ty+\varepsilon}{t+y+\varepsilon}=\begin{cases}\frac{y}{y+\varepsilon}& \text{t=0}\\\frac{1-y}{1+y+\varepsilon}& \text{t=1} \end{cases}

繪製曲線圖如下,其中藍色的為ce loss,橙色的為dice loss。

語義分割之dice loss深度分析(梯度視覺化)

t=0

時,

x

在一個較大的範圍內,loss的值都很大接近1。只有

x

預測非常小,

y

接近於0(和

\epsilon

量級相近)時loss才會變小,而這種情況出現的機率也較小。一般情況下,在正常範圍內,預測不管為任何值,都無差別對待,loss 都統一非常大。

t=1

時,

x

在0左右較小的範圍內,保持不錯的特性。但隨著

x

遠離0點,loss呈現飽和現象。

計算梯度:

\frac{dL_{dice}}{dy}=-\frac{2t(t+y+\varepsilon)-2ty-\varepsilon}{(t+y+\varepsilon)^2}=\begin{cases} \frac{\varepsilon}{(y+\varepsilon)^2}& \text{t=0}\\-\frac{2+\varepsilon}{(1+y+\varepsilon)^2}& \text{t=1} \end{cases}

\frac{dL_{dice}}{dx}=\frac{dL_{dice}}{dy}\frac{dy}{dx}

繪圖如下:

語義分割之dice loss深度分析(梯度視覺化)

梯度正負符號代表梯度的方向,網路採用梯度下降法更新引數,當梯度為正時,引數更新變小,當梯度為負時引數更新變大。這裡為了討論正負樣本的梯度關係,所以取了絕對值操作。

t=0

時,同樣在

x

的正常範圍內,

x

的梯度值接近0 。實際上,由於平滑係數的存在,該梯度不為0,而是一個非常小的值 。該值過於小,對網路的貢獻也非常有限。

t=1

時,

x

在0點附近存在一個峰值,此時

y

接近0。5。隨著預測值

y

越接近1或0,梯度越小,出現梯度飽和的現象。

一般神經網路訓練之前都會採取權重初始化,不管是Xavier初始化還是Kaiming初始化(或者其他初始化的方法), 輸出

x

是接近於0的。再回到上面的圖,可見此時正樣本(

t=1

)的監督是遠遠大於負樣本(

t=0

)的監督,可以認為網路前期會重點挖掘正樣本。而ce loss 是平等對待兩種樣本的。

多點情況分析

dice loss 是應用於語義分割而不是分類任務,並且是一個區域相關的loss,因此更適合針對多點的情況進行分析。由於多點輸出的情況比較難用曲線呈現,這裡使用模擬預測值的形式觀察梯度的變化。

下圖為原始圖片和對應的label:

語義分割之dice loss深度分析(梯度視覺化)

為了便於梯度視覺化,這裡對梯度求絕對值操作,因為我們關注的是梯度的大小而非方向。另外梯度值都乘以

10^4

保證在容易辨認的範圍。

首先定義如下熱圖,值越大,顏色越亮,反之亦然:

語義分割之dice loss深度分析(梯度視覺化)

預測值變化(

y

值,圖上的數字為預測值區間):

語義分割之dice loss深度分析(梯度視覺化)

dice loss 對應

x

值的梯度:

語義分割之dice loss深度分析(梯度視覺化)

ce loss 對應

x

值的梯度:

語義分割之dice loss深度分析(梯度視覺化)

可以看出:

一般情況下,dice loss 正樣本的梯度大於背景樣本的; 尤其是剛開始網路預測接近0。5的時候,這點和單點輸出的現象一致。說明 dice loss 更具有指向性,更加偏向於正樣本,保證有較低的FN。

負樣本(背景區域)也會產生梯度。

極端情況下,網路預測接近0或1時,對應點梯度值極小,dice loss 存在梯度飽和現象。此時預測失敗(FN,FP)的情況很難扭轉回來。不過該情況出現的機率較低,因為網路初始化輸出接近0。5,此時具有較大的梯度值。而網路透過梯度下降的方式更新引數,只會逐漸削弱預測失敗的畫素點。

對於ce loss,當前的點的梯度僅和當前預測值與label的距離相關,預測越接近label,梯度越小。當網路預測接近0或1時,梯度依然保持該特性。

對比發現, 訓練前中期,dice loss下正樣本的梯度值相對於ce loss,顏色更亮,值更大。說明dice loss 對挖掘正樣本更加有優勢。

dice loss為何能夠解決正負樣本不平衡問題?

因為dice loss是一個區域相關的loss。區域相關的意思就是,當前畫素的loss不光和當前畫素的預測值相關,和其他點的值也相關。dice loss的求交的形式可以理解為mask掩碼操作,因此不管圖片有多大,固定大小的正樣本的區域計算的loss是一樣的,對網路起到的監督貢獻不會隨著圖片的大小而變化。從上圖視覺化也發現,訓練更傾向於挖掘前景區域,正負樣本不平衡的情況就是前景佔比較小。而ce loss 會公平處理正負樣本,當出現正樣本佔比較小時,就會被更多的負樣本淹沒。

dice loss背景區域能否起到監督作用?

可以的,但是會小於前景區域。和直觀理解不同的是,隨著訓練的進行,背景區域也能產生較為可觀的梯度。這點和單點的情況分析不同。這裡求偏導,當

t_i=0

時:

\frac {\partial L_{dice}}{\partial y_i}=\frac{-2\frac{\partial I}{\partial y_i}(U+\epsilon)+\frac{\partial U}{\partial y_i}(2I+\epsilon)}{(U+\epsilon)^2}=\frac{2I+\epsilon}{(U+\epsilon)^2}

可以看出, 背景區域的梯度是存在的,只有預測值命中的區域極小時, 背景梯度才會很小。

dice loss 為何訓練會很不穩定?

在使用dice loss時,一般正樣本為小目標時會產生嚴重的震盪。因為在只有前景和背景的情況下,小目標一旦有部分畫素預測錯誤,那麼就會導致loss值大幅度的變動,從而導致梯度變化劇烈。可以假設極端情況,只有一個畫素為正樣本,如果該畫素預測正確了,不管其他畫素預測如何,loss 就接近0,預測錯誤了,loss 接近1。而對於ce loss,loss的值是總體求平均的,更多會依賴負樣本的地方。

總結

dice loss 對正負樣本嚴重不平衡的場景有著不錯的效能,訓練過程中更側重對前景區域的挖掘。但訓練loss容易不穩定,尤其是小目標的情況下。另外極端情況會導致梯度飽和現象。因此有一些改進操作,主要是結合ce loss等改進,比如: dice+ce loss,dice + focal loss等,本文不再論述。

標簽: loss  DICE  梯度  樣本  預測