您當前的位置:首頁 > 書法

Focal loss的簡單實現(二分類+多分類)

作者:由 皮特潘 發表于 書法時間:2020-11-25

Focal loss的簡單實現(二分類+多分類)

focal loss提出是為了解決正負樣本不平衡問題和難樣本挖掘的。這裡僅給出公式,不去過多解讀:

Focal loss的簡單實現(二分類+多分類)

p_t

是什麼?

就是預測該類別的機率。在二分類中,就是sigmoid輸出的機率;在多分類中,就是softmax輸出的機率。

原始的CE loss(crossentropy loss):

Focal loss的簡單實現(二分類+多分類)

簡單說明:

在CE loss的基礎上增加動態調整因子,是預測機率高的(易分樣本)的loss拉的更低,預測機率低的(難分樣本)的loss也拉低,不過沒有那麼低。從而達到難樣本挖掘的效果。另外,又額外增加alpha平衡因子,用來處理樣本不平衡的場景。

可以這樣理解

:不管是二分類的sigmoid輸出,還是多分類的softmax輸出。分別找到label對應的輸出機率(onehot中為1的位置對應輸出tensor的地方,再轉化為機率)。該機率就是

p_t

,然後套用上面的公式就可以了。

ps。 其實sigmoid也可理解為多分類(2個輸出值)的情況,負樣本的輸出永遠為0就可以了。

程式碼實現

二分類focal loss

class

BCEFocalLoss

torch

nn

Module

):

def

__init__

self

gamma

=

2

alpha

=

0。25

reduction

=

‘mean’

):

super

BCEFocalLoss

self

__init__

()

self

gamma

=

gamma

self

alpha

=

alpha

self

reduction

=

reduction

def

forward

self

predict

target

):

pt

=

torch

sigmoid

predict

# sigmoide獲取機率

#在原始ce上增加動態權重因子,注意alpha的寫法,下面多類時不能這樣使用

loss

=

-

self

alpha

*

1

-

pt

**

self

gamma

*

target

*

torch

log

pt

-

1

-

self

alpha

*

pt

**

self

gamma

*

1

-

target

*

torch

log

1

-

pt

if

self

reduction

==

‘mean’

loss

=

torch

mean

loss

elif

self

reduction

==

‘sum’

loss

=

torch

sum

loss

return

loss

多分類focal loss

class

MultiCEFocalLoss

torch

nn

Module

):

def

__init__

self

class_num

gamma

=

2

alpha

=

None

reduction

=

‘mean’

):

super

MultiCEFocalLoss

self

__init__

()

if

alpha

is

None

self

alpha

=

Variable

torch

ones

class_num

1

))

else

self

alpha

=

alpha

self

gamma

=

gamma

self

reduction

=

reduction

self

class_num

=

class_num

def

forward

self

predict

target

):

pt

=

F

softmax

predict

dim

=

1

# softmmax獲取預測機率

class_mask

=

F

one_hot

target

self

class_num

#獲取target的one hot編碼

ids

=

target

view

-

1

1

alpha

=

self

alpha

ids

data

view

-

1

)]

# 注意,這裡的alpha是給定的一個list(tensor

#),裡面的元素分別是每一個類的權重因子

probs

=

pt

*

class_mask

sum

1

view

-

1

1

# 利用onehot作為mask,提取對應的pt

log_p

=

probs

log

()

# 同樣,原始ce上增加一個動態權重衰減因子

loss

=

-

alpha

*

torch

pow

((

1

-

probs

),

self

gamma

))

*

log_p

if

self

reduction

==

‘mean’

loss

=

loss

mean

()

elif

self

reduction

==

‘sum’

loss

=

loss

sum

()

return

loss

onehot

也可以用下面三行程式碼自己實現onehot

ids

=

target

view

-

1

1

onehot

=

torch

zeros_like

P

onehot

scatter_

1

ids

data

1。

標簽: loss  self  Alpha  torch  __