您當前的位置:首頁 > 體育

文字識別中的CTC損失

作者:由 laygin 發表于 體育時間:2020-05-09

文字識別

就是將一張精確裁剪的只包含文字的影象作為輸入,透過模型識別出對應的計算機可處理的文字內容。

傳統的文字識別方法需要將每一個字元單獨切割出來進行多分類。 對於文字的標註和對齊是很困難的,畢竟要花費大量的時間精力,並且即使同一個文字的寬度也不總是一樣大。

而使用

神經網路模型

,可以幫助我們提取非常好的特徵。

使用CTC操作,我們只關心哪些文字在圖中出現了,而不必關心文字具體的位置和寬度,CTC損失會引導NN模型進行訓練,那麼,CTC是如何實現的呢?

文字如何編碼

為了瞭解CTC的編碼方式,得先看看CTC是如何解碼的。 對於一串字元,比如“bbbbbeeeeeeee”,CTC會將重複的字元移除掉,所以這一串字元解碼出來就是“be”。

不知你發現沒有,這會導致一個問題,如果目標文字是bee呢? 被解碼成be就不對了啊。

為了解決這個問題,CTC引入了一個特別的字元,為了方便就用“-”表示(但它不是真的“-”符號,在原始碼中編碼為0),有了這個字元就可以很好的分隔文字 中的重複字元了,

比如要編碼“be”,可以是“bbbbee”或“——-bb-e”或“b-e”等,

那“bee”就可以表示為“——-

bbee-ee

——”或“b-e-eee”或“

bee——-ee

”等 相信你已經知道其中的意思了。

損失如何計算

神經網路模型的輸出是一個序列,比如使用一個

矩陣

(24, 37)來表示,其中24是序列長度或時間步長,37表示字元個數(比如10個數字26個小寫字母加上一個特殊字元)。

以一個略縮版的矩陣為例,其中時間步為2,字元數為3(包括特殊字元),如圖所示:

文字識別中的CTC損失

在每個時間步每個字元的機率之和為1(softmax),即矩陣axis為列的和為1

每一種走向(或者路徑)的得分透過將相應字元的機率相乘得到,比如aa的得分為0。4x0。4=0。16,a-為0。24,以此類推

損失透過對應於標籤文字的所有可能路徑得分之和計算得到。在這個例子中,如果標籤為a,對應a的所有可能路徑為aa,a-,-a,將這三條路徑的得分相加為0。64

如果該例的標籤為空格(使用-表示)呢?對應的路徑只有一條,即——,所以得分為0。6x0。6=0。36

以上計算的是機率值,而不是損失。損失可以透過機率的負對數計算得到(以2為底數,在[0,1]區間遞減且非負,即得分越大,損失越小)

Q:如果標籤是b呢?得分是多少?????

計算出了損失,NN模型就能不斷進行最佳化,輸出一個更好的序列表示矩陣

如何根據模型輸出矩陣進行解碼

透過上面損失的計算,模型的訓練,能夠找到一個很好的矩陣來表示輸入的影象,我們最終的目的是能夠透過這個矩陣得到影象對應的文字,也就是解碼。

最簡單直接的方法就是

最佳路徑解碼

1。 獲取每一個時間步最大得分對應的字元作為該時間步的預測結果

2。 將所有時間步得到的輸出組合成為一串字元,首先移除掉所有重複字元,然後移除特殊分隔字元,與編碼階段的操作正好相反。剩下的字元就是最終識別結果。

這種方法只能得到近似結果,比如上一個例子,假設那個矩陣是模型的輸出,以最佳路徑解碼出來是“——”,也就是空,其機率為0。36,而實際上應該是“a”,因為“a”的機率為0。64

還有其他更佳的解碼方法,比如束搜尋,後面在深入介紹下。

pytorch中的 CTC LOSS 如何使用

假設CNN+RNN模型的輸入影象大小為32x100,輸出

機率矩陣

的大小為24x37(如上所述,24為時間步長,37為字元類別個數),忽略掉batch size。

在最後一維需要執行softmax變成0~1之間的機率值。

criterion

=

torch

nn

CTCLoss

zero_infinity

=

True

loss

=

criterion

preds

text

preds_size

length

CTCLoss的zero_infinity代表是否將無限大的損失和梯度歸零,無限損失主要發生在輸入太短而無法與目標對齊時。

其中,假設模型輸出的shape為(batch size,24, 37),而CTCLoss 輸入的preds要求shape為(24, batch size,37)所以需要將模型輸出的第0,1維交換;

preds_size是一個列表,長度為batch size,元素為該batch內每個輸入序列的長度,在本例中也就是24;

text和length分別是透過對原始label進行編碼得到的文字索引和每個文字的長度(一個batch的文字)

References

標簽: 字元  文字  CTC  解碼  24