Contrastive Predictive Coding
Representation Learning with Contrastive Predictive Coding。 2018
1。 Contrastive Learning 的思想
對比學習屬於無監督學習的範疇,透過無監督來學習特徵表示。我們希望特徵表示是一種 high-level 的表示,不需要過多關注影象畫素層面的資訊,而更多關注全域性的、語義的資訊。“對比學習”中的“對比”是 positive樣本 和 negative 樣本的對比。在學習到的“表示空間”內,增大某樣本與其 positive 樣本之間的相似度,減少與negative 樣本的相似度。
形式化的表述如下:
score 是相似度的度量,例如
。在 Contrastive Learning 中用互資訊來衡量相似度。
是我們希望學習到的 representation,表示為一個 encoder,對原始的輸入
進行編碼。
樣本
相對於
是正樣本,
相對於
是負樣本。x可以被稱為是一個 anchor 樣本
隨後我們構建損失函式,目標是最大化
和
的 score,壓低
和
的score。因此損失函式設定如下,是 softmax 損失。
2。 Contrastive Predictive Coding(CPC)
CPC 的特點是,利用序列資料中資料之間的關聯特性來挖掘資料的特徵。許多資料都是關聯的,如文字,影片,語音等。另外,單張影象也可以透過按照某一方向(如對角線)進行切塊來獲得具有內部關聯性的資料。
假設某一段具有序列關係的資料為
。CPC中使用的 positive 樣本可以是從中順序的取樣多個樣本(如8個)來構成:
在該樣本中,使用前面四個樣本來預測後面四個樣本。那麼 negative 樣本可以是
其中
,即待預測的樣本跟輸入的樣本沒有順序關係,而是從資料中隨機取樣的。一般在一個訓練批次的
個樣本中,我們使用 1 個 positive 樣本和
個negative 樣本。
CPC有如下特點
:
不直接在樣本
的層面進行預測,而是透過編碼器
,將原始資料編碼後,使用損失函式進行計算
Anchor樣本的使用。在上面的例子
中,Anchor樣本是
,Positive樣本是
,Negative樣本是
。CPC將Anchor分別經過編碼器得到
之後,再使用一個迴圈網路將這四個樣本編碼為一個context vector,記為
。可以認為,
是一個 high-level 的 feature,在文中被稱為 slow feature,意為可以跨越多個時間步的特徵。
最大化 Anchor 和 Positive 之前的 score 轉化為,最大化 Anchor 編碼為的 context
和 positive 樣本對應的編碼
, 最小化
和隨機取樣的樣本對應的編碼
之間的互資訊。
在訓練之後,
和
均可以作為
的特徵表示,其中
更關注與單個時間步的資訊,而
可以包含多個時間步的資訊。
透過一個例子來說明CPC的損失函式定義。一段語音序列,使用前三個時間步來產生 context,將後面第4個時間步的樣本作為 positive ,將隨機取樣的語音作為 negative,分別計算 score(
) 和 score(
),使用以上損失函式來計算。圖中的Bilinear是指,本文將 score(
) 引數化形式設定為
,其中
需要進行學習。
3。 Connection Between CPC and Mutual-Information
我們下面證明
“最小化CPC定義的損失函式實際上最大化 context c_t 和待預測正樣本 X_t 之間的互資訊”
。互資訊可以挖掘兩個變數之間的非線性依賴關係,因此訓練後 context 可以代表的與之關聯的樣本的特徵。
以下證明過程與原文有所不同,個人認為更好理解。
在CPC的設定中,
有兩種型別。我們用
表示樣本
是 positive 的,用
表示該樣本是 negative 的。在 CPC 中,樣本是 positive 的機率記為
。我們的目的是增加該似然,因此損失函式設定為:
為了求該損失,我們先分析幾個基本量:
先驗
。在訓練中,N個樣本含有一個正樣本和N-1個負樣本。因此先驗
,
。
似然
。我們這裡設定
是正樣本,因此
。而
,原因是二者相互獨立。
根據貝葉斯公式,損失函式中的關鍵機率
可以分解為:
進一步的,損失
其中,log中的第一項,由於
和
不是獨立的,因此
, 兩邊同除以
,從而
。因此上式第一項大於0。因此
最後一行的推導是基於互資訊的定義:
進而,根據
可得
因此 CPC 的損失函式
是 Mutual-Information 的一個下界,最小化
等價於最大化互資訊。
4。 實驗
如果我們透過CPC學到了很好的特徵表示,那麼在後續的任務中,我們可以在特徵空間中進行建模。CPC的優勢在於,特徵的表示學習過程中使用的是無監督資料,語音,影象,影片的無監督資料非常容易獲得,但是為資料進行標註卻非常困難。因此無監督學習具有重要的意義。
4。1 語音
使用語音資料集來學習特徵表示。由於語音的前後關係較大,因此使用 context
作為特徵表示。以下代表了
的分佈,可以看出每個說話人的語音特徵會進行聚類。在特徵
的基礎上,我們只新增線性層就可以進行分類,例如分類說話人的身份。
4。2 影象
單張影象本身沒有序列關係。因此將影象進行分塊,沿著對角線方向分成64*64的影象塊,且每個影象之間有覆蓋部分。因此按照該方向切分開的影象序列是具有序列關係的,可以使用CPC進行訓練。負樣本可以從不是該方向的其餘位置的影象塊中進行取樣。在經過學習後,直接將
編碼器作為特徵提取模組,隨後接線性層就可以進行分類任務。
4。3 強化學習
在A3C中,將CPC作為一個輔助任務。即在學習策略和值函式損失的同時,學習一個CPC來預測未來的觀測。在CPC中,將一個週期的互動樣本當做是一個序列。CPC和A3C共享卷機編碼,因此可以幫助RL智慧體進行快速迭代。
5。 附:程式碼
這裡有一個簡單易懂的Mnist例子實現。
在該例子中,呈現順序label的mnist圖片被當做一個序列,例如lable分別為“0,1,2,3,4,5,6”的圖片,或“6,7,8,9,0,1,2”的圖片,都是positive的序列。而後面被預測部分的順序被打亂的序列,是negative的序列。假設我們用前面“6”張圖片預測後面“4”張圖片,則被用於訓練的資料如下圖,每行代表一個訓練樣本。第一行是 negative 的,因此“0,1,2,3,4,5”之後為“8,1,3,0”,順序被打亂。第二行同樣,“3,4,5,6,7,8”之後為“3,7,2,6”。第四行為positive樣本“1,2,3,4,5,6”預測“7,8,9,0“。
模型結構如下: