CCF BDCI 劇本角色情感識別第二版分享:多工學習開源方案
宣告:歡迎轉載,轉載請註明出處以及連結,碼字不易,歡迎小夥伴們點贊和分享。
一、比賽介紹
比賽背景
資料格式說明
從資料格式來看是一個多標籤多分類的問題,目前我看比賽群裡有一種思路就是透過多工迴歸的方式來做,我這裡有另一種方法來實現情感多工模型。
評估方式
二、資料介紹
訓練集樣例:
id content character emotions
1171_0001_A_1 天空下著暴雨,o2正在給c1穿雨衣,他自己卻只穿著單薄的軍裝,完全暴露在大雨之中。 o2 0,0,0,0,0,0
1171_0001_A_2 天空下著暴雨,o2正在給c1穿雨衣,他自己卻只穿著單薄的軍裝,完全暴露在大雨之中。 c1 0,0,0,0,0,0
1171_0001_A_5 o2停下來接過c1手裡的行李:你媽媽交待我了,等領了軍裝一定要照張相寄回去,讓街坊鄰居都知道你當兵了。 o2 0,0,0,0,0,0
1171_0001_A_6 o2停下來接過c1手裡的行李:你媽媽交待我了,等領了軍裝一定要照張相寄回去,讓街坊鄰居都知道你當兵了。 c1 0,0,0,0,0,0
1171_0001_A_7 c1開心地點了點頭。 c1 0,1,0,0,0,0
content是文字描述的內容,character是情感分析的角色,emotions是情感分類的標籤,資料中很大一部分是什麼情感都沒有的。
三、資料預處理的trick
其實從圖上看出來,這三個數字其實是包含一些意思在裡面的,1171這個數字代表的意思是一個劇本的編號,0001代表的是劇本中場景的編號,1代表的是句子的順序編號。所以瞭解到這些數字的資訊,你就很快會發現資料中存在的一個問題,那就是這些句子的排列並不一定是按照順序排列的。從愛奇藝官方寫的部落格來看,說明句子當前的情感與上下文都有很深的聯絡。所以拼接的上下文一定是需要正確的上下文關係的,之前很多參加比賽的同學反饋說拼接了上下文反而效果更差了,很重要的一點就是拼接錯了上下文。
對資料集進行排序
# Sort the script-role dataset so lines follow script -> scene -> sentence order.
# An id looks like "1171_0001_A_1": script id, scene number, a letter, sentence
# number — so the sort key is the int value of parts 0, 1 and 3 of the id.
# The original grouped scripts/scenes with repeated O(n^2) list scans; a single
# sort on the composite key produces the same ordering in O(n log n).
path1 = 'data_sort.tsv'  # output: sorted dataset
path2 = 'data.tsv'       # input: raw dataset


def _sort_key(line):
    """Return (script, scene, sentence) ints for one raw TSV line."""
    parts = line.split('\t')[0].split('_')
    # parts = [script_id, scene_num, letter, sentence_num]
    return int(parts[0]), int(parts[1]), int(parts[3])


with open(path2, 'r', encoding='utf-8') as f:
    text_list = [l.replace('\n', '') for l in f]

text_list.sort(key=_sort_key)

with open(path1, 'w', encoding='utf-8') as s:
    for line in text_list:
        s.write(line + '\n')
排序之後會得到一個正確順序的劇本
官方建議資料集按照劇本來切分成訓練集和驗證集。
融合句子上文的語境
透過識別當前角色來拼接上文中該角色所說的一些話,透過簡單的嘗試發現拼接太長或者太短的上文資訊都會帶來模型指標的下降,大概maxlen控制在300-400區間會帶來比較好的語義資訊度。
# Build the training set: prepend each character's earlier utterances in the
# same script as context, then truncate to roughly `maxlen` characters.
# NOTE(review): relies on a module-level `maxlen` defined elsewhere — confirm.
data = {}    # per-script memory: character -> accumulated context
target = ''  # script id of the previous line
with open('train.tsv', 'w', encoding='utf-8') as s, \
        open('train_sort.tsv', 'r', encoding='utf-8') as f:
    for l in f:
        fields = l.split('\t')
        sample_id = fields[0]
        content = fields[1]
        character = fields[2].replace('\n', '')
        label = fields[3]
        script_id = sample_id.split('_')[0]
        # A new script starts: drop all accumulated context.
        if target != script_id:
            data = {}
        if character != '':
            history = data.get(character)
            if history is not None:
                text = history + '(' + '角色是:' + character + ')' + content
                data[character] = history + content
            else:
                text = '(' + '角色是:' + character + ')' + content
                data[character] = content
            if len(text) > maxlen:
                # Keep as many of the most recent sentences as fit in maxlen.
                old_text = ''
                sentences = [i for i in text.split('。') if i != '']
                for sent in reversed(sentences):
                    candidate = sent + '。' + old_text
                    if len(candidate) < maxlen:
                        old_text = candidate
                    else:
                        break
                if old_text != '':
                    text = old_text
                else:
                    # Even one sentence exceeds maxlen: hard-truncate from the
                    # left, keeping the current content plus a bit of context.
                    text = text[(len(text) - 1 - maxlen + len(content)):]
        else:
            text = '(' + '角色是:' + character + ')' + content
        if text == '':
            text = '(' + '角色是:' + character + ')' + content
        # Only keep rows carrying a complete 6-way emotion label
        # (`label` still ends with the original newline, so no '\n' is added).
        if label != '' and len(label.split(',')) == 6:
            s.write(sample_id + '\t' + text + '\t' + label)
        target = script_id
# Build the validation set — same context-fusion logic as the training set.
# (File name 'vaild' keeps the original — misspelled — on-disk name.)
data = {}    # per-script memory: character -> accumulated context
target = ''  # script id of the previous line
with open('vaild.tsv', 'w', encoding='utf-8') as s, \
        open('vaild_sort.tsv', 'r', encoding='utf-8') as f:
    for l in f:
        fields = l.split('\t')
        sample_id = fields[0]
        content = fields[1]
        character = fields[2].replace('\n', '')
        label = fields[3]
        script_id = sample_id.split('_')[0]
        # A new script starts: drop all accumulated context.
        if target != script_id:
            data = {}
        if character != '':
            history = data.get(character)
            if history is not None:
                text = history + '(' + '角色是:' + character + ')' + content
                data[character] = history + content
            else:
                text = '(' + '角色是:' + character + ')' + content
                data[character] = content
            if len(text) > maxlen:
                # Keep as many of the most recent sentences as fit in maxlen.
                old_text = ''
                sentences = [i for i in text.split('。') if i != '']
                for sent in reversed(sentences):
                    candidate = sent + '。' + old_text
                    if len(candidate) < maxlen:
                        old_text = candidate
                    else:
                        break
                if old_text != '':
                    text = old_text
                else:
                    # Even one sentence exceeds maxlen: hard-truncate from the
                    # left, keeping the current content plus a bit of context.
                    text = text[(len(text) - 1 - maxlen + len(content)):]
        else:
            text = '(' + '角色是:' + character + ')' + content
        if text == '':
            text = '(' + '角色是:' + character + ')' + content
        # Only keep rows carrying a complete 6-way emotion label.
        if label != '' and len(label.split(',')) == 6:
            s.write(sample_id + '\t' + text + '\t' + label)
        target = script_id
# Build the test set — same context fusion, but there is no label column and
# every row is written (with an explicit trailing newline).
data = {}    # per-script memory: character -> accumulated context
target = ''  # script id of the previous line
with open('test.tsv', 'w', encoding='utf-8') as s, \
        open('test_sort.tsv', 'r', encoding='utf-8') as f:
    for l in f:
        fields = l.split('\t')
        sample_id = fields[0]
        content = fields[1]
        character = fields[2].replace('\n', '')
        script_id = sample_id.split('_')[0]
        # A new script starts: drop all accumulated context.
        if target != script_id:
            data = {}
        if character != '':
            history = data.get(character)
            if history is not None:
                text = history + '(' + '角色是:' + character + ')' + content
                data[character] = history + content
            else:
                text = '(' + '角色是:' + character + ')' + content
                data[character] = content
            if len(text) > maxlen:
                # Keep as many of the most recent sentences as fit in maxlen.
                old_text = ''
                sentences = [i for i in text.split('。') if i != '']
                for sent in reversed(sentences):
                    candidate = sent + '。' + old_text
                    if len(candidate) < maxlen:
                        old_text = candidate
                    else:
                        break
                if old_text != '':
                    text = old_text
                else:
                    # Even one sentence exceeds maxlen: hard-truncate from the
                    # left, keeping the current content plus a bit of context.
                    text = text[(len(text) - 1 - maxlen + len(content)):]
        else:
            text = '(' + '角色是:' + character + ')' + content
        if text == '':
            text = '(' + '角色是:' + character + ')' + content
        s.write(sample_id + '\t' + text + '\n')
        target = script_id
生成訓練集、驗證集、測試集資料
融合後的完整文字
會生成如下三個文字
1、train.tsv
2、vaild.tsv
3、test.tsv
四、模型構建思路和程式碼
先說下我構建模型的思路,我還融合了比賽一開始分享的多標籤分類的做法,這裡要感謝下蘇劍林蘇神寫的bert4keras框架來搭建模型baseline。我再結合多工迴歸的方法,讓模型同時做迴歸和分類並且進行解碼融合,這個方法是我拍腦袋想的,反正本著試一試的心態玩玩。
模型結構
多標籤分類
我上一篇關於愛奇藝的方案中有具體介紹:
多工迴歸
多標籤分類的方法我在上一篇baseline分享中已經講了思路和做法,我這裡來講下多工迴歸的做法,六個全連線層分別對應著愛、樂、驚、怒、恐、哀六類情感,每個全連線層分別負責每一類情感的[0-3]四種等級的預測。
1)
標籤歸一化:
我首先將訓練集標籤區間歸一化,對每個標籤除以三就能將區間對映到0-1區間。
2)
啟用函式選擇:
將標籤歸一化之後,我們希望模型輸出也是在0-1區間的連續值,於是將最後全連線層設定為sigmoid函式,這樣做還有個好處就是避免輸出值太大或者出現負值導致loss不穩定,將輸出的正負無窮的區間對映到0-1區間內。
3)
多工loss組合:
透過多個任務組合的loss進行聯合最佳化。
4)
Bert多層資訊融合:
取Bert最後兩層進行全域性池化融合,Bert每一層學習到的資訊都是不一樣的,融合可以增強模型魯棒性。
5)
還原模型預測:
既然在之前將標籤對映0-1區間,模型輸出結果我們透過sigmoid函式保證輸出在0-1區間內,接下來我們就需要把這個結果映射回0-3區間中,對結果直接乘三。
多工迴歸和分類聯合學習
我把其中六個全連線層進行迴歸任務,然後把剩下的一個全連線層進行多標籤分類,多個任務用同一個Bert編碼,相當於共享Bert權重,讓Bert學習到多個任務隱含的聯絡這是更深層次的語義關係。
在解碼上的做法,首先我發現浮點型結果提交會比整型結果提交要好上不少,多工迴歸它是自然分佈在0-3區間的連續值,所以直接是浮點型結果不需要修改。這裡我需要最佳化的是多標籤分類的結果,我將每個情感型別分解為4個二分類標籤,在這裡需要進行劃窗聚合。如果單純使用argmax函式是無法得到浮點型連續值資料。
我採用的對映做法:
先用argmax得到最大的整數類別之後進行判斷
如果是0標籤的話:機率為 1 - x1[x1.argmax()]
否則: 機率為 x1.argmax() - 1 + x1[x1.argmax()]
模型訓練具體程式碼如下:
import
json
import
numpy
as
np
from
sklearn。metrics
import
mean_squared_error
,
mean_absolute_error
,
f1_score
from
bert4keras。backend
import
keras
,
search_layer
,
K
from
bert4keras。tokenizers
import
Tokenizer
from
bert4keras。models
import
build_transformer_model
from
bert4keras。optimizers
import
Adam
from
bert4keras。snippets
import
sequence_padding
,
DataGenerator
from
keras。layers
import
Lambda
,
Dense
,
SpatialDropout1D
from
tqdm
import
tqdm
from
keras。utils
import
multi_gpu_model
import
os
from
tensorflow。python。ops
import
array_ops
from
keras
import
backend
as
K
from
keras。layers
import
Dense
,
Input
,
concatenate
,
Bidirectional
,
LSTM
,
MaxPool1D
,
MaxPool3D
,
GlobalMaxPooling1D
,
\
GlobalAveragePooling1D
,
Dropout
import
tensorflow
as
tf
from
tfdeterminism
import
patch
# tfdeterminism.patch makes TensorFlow GPU ops deterministic.
patch()


def seed_everything(seed=0):
    """Seed numpy and TensorFlow for reproducible runs.

    NOTE(review): setting PYTHONHASHSEED at runtime does not affect the
    current interpreter's hash randomization — it only matters if it was
    exported before Python started; kept for parity with the original.
    """
    np.random.seed(seed)
    tf.set_random_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['TF_DETERMINISTIC_OPS'] = '1'


SEED = 1234
seed_everything(SEED)
# ---- Run configuration -------------------------------------------------
is_train = True

# GPU selection: several GPUs for training, a single one for inference.
if is_train == True:
    gpus = '1,2,3,4,5,6,7'
else:
    gpus = '0'

num_classes = 1          # each regression head outputs one value
maxlen = 330             # max token length fed to BERT
n_gpu = len(gpus.split(','))
batch_size = 64
epochs = 5
learn_rating = 2e-5      # Adam learning rate
model_name = 'roberta_pre'
model_save = model_name + '.weights'
os.environ["CUDA_VISIBLE_DEVICES"] = gpus

# Pretrained-model files (config, checkpoint, vocab) keyed by model_name.
# Replaces the original ten-way if-chain with a single lookup.
_MODEL_PATHS = {
    'nezha_base': ('../nezha_base/bert_config.json',
                   '../nezha_base/model.ckpt',
                   '../nezha_base/vocab.txt'),
    'roberta_large': ('../roberta_wwm_large/bert_config.json',
                      '../roberta_wwm_large/bert_model.ckpt',
                      '../roberta_wwm_large/vocab.txt'),
    'roberta_base': ('../roberta_wwm_base/bert_config.json',
                     '../roberta_wwm_base/bert_model.ckpt',
                     '../roberta_wwm_base/vocab.txt'),
    'nezha_large': ('../nezha_large/bert_config.json',
                    '../nezha_large/model.ckpt',
                    '../nezha_large/vocab.txt'),
    'nezha_wwm_large': ('../nezha_wwm_large/bert_config.json',
                        '../nezha_wwm_large/model.ckpt',
                        '../nezha_wwm_large/vocab.txt'),
    'nezha_wwm_base': ('../nezha_wwm_base/bert_config.json',
                       '../nezha_wwm_base/model.ckpt',
                       '../nezha_wwm_base/vocab.txt'),
    'roberta_zh_large': ('../roberta_zh_large/bert_config_large.json',
                         '../roberta_zh_large/roberta_zh_large_model.ckpt',
                         '../roberta_zh_large/vocab.txt'),
    'roberta_pre': ('./roberta_model/bert_config.json',
                    './roberta_model/bert_model.ckpt',
                    './roberta_model/vocab.txt'),
    'nezha_pre': ('./nezha_model/bert_config.json',
                  './nezha_model/model.ckpt',
                  './nezha_model/vocab.txt'),
    'bert_pre': ('./bert_model/bert_config.json',
                 './bert_model/bert_model.ckpt',
                 './bert_model/vocab.txt'),
}
# Like the original if-chain, an unknown model_name leaves the paths unset.
if model_name in _MODEL_PATHS:
    config_path, checkpoint_path, dict_path = _MODEL_PATHS[model_name]
def load_data(filename):
    """Load a TSV dataset of (id, text, labels) lines.

    Each line: id \t text \t "a,b,c,d,e,f" where a..f are integer emotion
    levels in 0-3 for (love, joy, fright, anger, fear, sorrow).

    Returns a list of (text, label1, label2):
      label1 -- the six levels scaled to [0, 1] (divided by 3), for the
                regression heads;
      label2 -- a 24-dim multi-hot target: six consecutive one-hot groups
                of 4 classes each, for the classification head.

    The original recovered the integer level via int(float(x)/3*3), which is
    fragile float round-tripping; the level is parsed directly here.
    """
    D = []
    with open(filename, encoding='utf-8') as f:
        for l in f:
            fields = l.split('\t')
            text = fields[1]
            levels = [int(v) for v in fields[2].split(',')]
            label1 = [v / 3 for v in levels]
            label2 = []
            for v in levels:
                one_hot = [0] * 4
                one_hot[v] = 1.0
                label2 += one_hot
            D.append((text, label1, label2))
    return D
# Load the generated datasets ('vaild' matches the misspelled file name
# produced by the preprocessing step).
train_data = load_data('../ccf/train.tsv')
valid_data = load_data('../ccf/vaild.tsv')

# Tokenizer built from the chosen pretrained model's vocabulary.
tokenizer = Tokenizer(dict_path, do_lower_case=True)
class data_generator(DataGenerator):
    """Batch generator.

    Yields ([token_ids, segment_ids], [r1..r6, cls]) where r1..r6 are the
    six per-emotion regression targets (each a length-1 vector) and cls is
    the 24-dim multi-hot classification target.

    The original kept seven hand-unrolled parallel lists; they are managed
    as one list-of-lists here.
    """

    def __iter__(self, random=False):
        batch_token_ids, batch_segment_ids = [], []
        # Index 0-5: regression targets; index 6: classification target.
        batch_labels = [[] for _ in range(7)]
        for is_end, (text, label1, label2) in self.sample(random):
            token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segment_ids)
            for k in range(6):
                batch_labels[k].append([label1[k]])
            batch_labels[6].append(label2)
            if len(batch_token_ids) == self.batch_size or is_end:
                yield (
                    [sequence_padding(batch_token_ids),
                     sequence_padding(batch_segment_ids)],
                    [sequence_padding(b) for b in batch_labels],
                )
                batch_token_ids, batch_segment_ids = [], []
                batch_labels = [[] for _ in range(7)]
# Wrap the loaded datasets in batch generators.
train_generator = data_generator(train_data, batch_size)
valid_generator = data_generator(valid_data, batch_size)
with tf.device('/cpu:0'):
    # Keep the template model's weights on CPU; multi_gpu_model later
    # replicates it across the visible GPUs.
    bert = build_transformer_model(
        config_path=config_path,
        checkpoint_path=checkpoint_path,
        return_keras_model=False,
        model=model_name.split('_')[0],
    )

    # Fuse the last two transformer layers via global max pooling — each
    # layer learns different information, and fusing adds robustness.
    v1 = bert.model.output
    v2 = bert.model.get_layer('Transformer-10-FeedForward-Norm').output
    gp1 = GlobalMaxPooling1D()(v1)
    gp2 = GlobalMaxPooling1D()(v2)
    output = concatenate([gp1, gp2], axis=1)

    def _regression_head(name):
        # One sigmoid unit per emotion; the targets are levels scaled to
        # [0, 1], so sigmoid keeps predictions in range.
        return Dense(
            units=num_classes,
            activation='sigmoid',
            name=name,
            kernel_initializer=bert.initializer,
        )(output)

    love = _regression_head('love')
    joy = _regression_head('joy')
    fright = _regression_head('fright')
    anger = _regression_head('anger')
    fear = _regression_head('fear')
    sorrow = _regression_head('sorrow')

    # 6 emotions x 4 levels, treated as a 24-way multi-label head.
    classification = Dense(
        units=24,
        activation='sigmoid',
        name='classification',
        kernel_initializer=bert.initializer,
    )(output)

    model = keras.models.Model(
        bert.model.input,
        [love, joy, fright, anger, fear, sorrow, classification],
    )
if is_train == True:
    model.summary()
    # Data-parallel replica of the template model for training.
    m_model = multi_gpu_model(model, gpus=n_gpu)

    # MSE for the six regression heads, binary cross-entropy for the
    # multi-label classification head (dicts built once instead of the
    # original duplicated literals).
    _emotions = ['love', 'joy', 'fright', 'anger', 'fear', 'sorrow']
    _losses = {name: 'mean_squared_error' for name in _emotions}
    _losses['classification'] = 'binary_crossentropy'
    _metrics = {name: 'mean_squared_error' for name in _emotions}
    _metrics['classification'] = 'accuracy'
    m_model.compile(
        loss=_losses,
        optimizer=Adam(learn_rating),
        metrics=_metrics,
    )
def adversarial_training(model, embedding_name, epsilon=1):
    """Add FGM-style adversarial training to a Keras model.

    `model` is the model to perturb and `embedding_name` is the name of its
    Embedding layer. Must be called after model.compile().
    """
    if model.train_function is None:  # train function may not exist yet
        model._make_train_function()  # build it manually
    old_train_function = model.train_function  # keep the original step

    # Locate the embedding layer by searching back from each output.
    for output in model.outputs:
        embedding_layer = search_layer(output, embedding_name)
        if embedding_layer is not None:
            break
    if embedding_layer is None:
        raise Exception('Embedding layer not found')

    # Gradient of the total loss w.r.t. the embedding matrix.
    embeddings = embedding_layer.embeddings
    gradients = K.gradients(model.total_loss, [embeddings])
    gradients = K.zeros_like(embeddings) + gradients[0]  # densify

    # Wrap as a callable: all feed inputs -> embedding gradients.
    inputs = (
        model._feed_inputs
        + model._feed_targets
        + model._feed_sample_weights
    )
    embedding_gradients = K.function(
        inputs=inputs,
        outputs=[gradients],
        name='embedding_gradients',
    )

    def train_function(inputs):
        # One adversarial step: nudge the embeddings along the normalized
        # gradient, run the normal train step, then undo the perturbation.
        grads = embedding_gradients(inputs)[0]
        delta = epsilon * grads / (np.sqrt((grads ** 2).sum()) + 1e-8)
        K.set_value(embeddings, K.eval(embeddings) + delta)   # inject
        outputs = old_train_function(inputs)                  # descend
        K.set_value(embeddings, K.eval(embeddings) - delta)   # remove
        return outputs

    model.train_function = train_function  # override the train step
if is_train == True:
    # With the helper defined, enabling adversarial training is one line.
    adversarial_training(m_model, 'Embedding-Token', 0.5)
def evaluate(data):
    """Return the competition score 1 / (1 + mean RMSE) over `data`.

    For each sample the six regression heads (scaled back to [0, 3]) are
    fused with the decoded 24-way classification head, weighted 0.6 / 0.4.
    The original also initialized total/right/f1 accumulators that were
    never used; they are removed here.
    """
    rmse = 0.
    for x_true, y_true, _ in tqdm(data):
        token_ids, segment_ids = tokenizer.encode(x_true, maxlen=maxlen)
        model_pred = m_model.predict([[token_ids], [segment_ids]])
        # Regression heads: sigmoid in [0, 1] -> emotion level in [0, 3].
        y_pred_v1 = [head[0][0] * 3 for head in model_pred[:6]]
        # Classification head: decode each 4-way group to a continuous level
        # so the fused prediction stays a float.
        y_pred_v2 = []
        pred = model_pred[6][0]
        for start in range(0, 24, 4):
            x1 = pred[start:start + 4]
            top = int(x1.argmax())
            if top == 0:
                y_pred_v2.append(1 - x1[top])
            else:
                y_pred_v2.append(top - 1 + x1[top])
        y_pred = [
            0.6 * v1 + 0.4 * v2 for v1, v2 in zip(y_pred_v1, y_pred_v2)
        ]
        y_true = [v * 3 for v in y_true]  # labels were stored divided by 3
        rmse += mean_squared_error(y_true, y_pred) ** 0.5 / len(data)
    return 1 / (1 + rmse)
class Evaluator(keras.callbacks.Callback):
    """Evaluate after each epoch and keep the best weights on disk."""

    def __init__(self):
        # Best validation score seen so far.
        self.best_val_score = 0.

    def on_epoch_end(self, epoch, logs=None):
        val_score = evaluate(valid_data)
        if val_score > self.best_val_score:
            self.best_val_score = val_score
            # Save from the multi-GPU replica, then reload and re-save via
            # the single-GPU template so the weights file matches `model`.
            m_model.save_weights(model_save)
            m_model.load_weights(model_save)
            model.save_weights(model_save)
        # Bug fix: the original built this status string as a bare
        # expression and never printed it.
        print(
            u'val_score: %.5f, best_val_score: %.5f\n'
            % (val_score, self.best_val_score)
        )
def model_predict(file, save_path):
    """Predict emotions for each line of `file` and write a submission TSV.

    Output format: header "id\temotion", then one "id\tv1,...,v6" row per
    input line, where each value fuses the regression head (weight 0.6)
    with the decoded classification head (weight 0.4).
    """
    # Both files are context-managed (the original left `s` to an explicit
    # close and redundantly closed `f` after its `with` block).
    with open(save_path, 'w', encoding='utf-8') as s, \
            open(file, 'r', encoding='utf-8') as f:
        s.write('id' + '\t' + 'emotion' + '\n')
        # readlines() gives tqdm a total for its progress bar.
        for l in tqdm(f.readlines()):
            text = l.split('\t')[1]
            token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
            model_pred = model.predict([[token_ids], [segment_ids]])
            # Regression heads: [0, 1] -> [0, 3].
            y_pred_v1 = [head[0][0] * 3 for head in model_pred[:6]]
            # Classification head: decode each 4-way group to a float level.
            y_pred_v2 = []
            pred = model_pred[6][0]
            for start in range(0, 24, 4):
                x1 = pred[start:start + 4]
                top = int(x1.argmax())
                if top == 0:
                    y_pred_v2.append(1 - x1[top])
                else:
                    y_pred_v2.append(top - 1 + x1[top])
            y_pred = [
                0.6 * v1 + 0.4 * v2 for v1, v2 in zip(y_pred_v1, y_pred_v2)
            ]
            s.write(
                l.split('\t')[0] + '\t'
                + ','.join([str(i) for i in y_pred]) + '\n'
            )
if __name__ == '__main__':
    if is_train == True:
        # Training: fit the multi-GPU model, checkpointing the best epoch.
        evaluator = Evaluator()
        m_model.fit(
            train_generator.forfit(),
            steps_per_epoch=len(train_generator),
            epochs=epochs,
            callbacks=[evaluator],
        )
    else:
        # Inference: restore the best weights into the single-GPU template
        # model and produce the submission file.
        model.load_weights(model_save)
        model_predict('test.tsv', 'result.tsv')
五、總結
根據上述的做法可以達到差不多0.70227的水平,這並不是一個好成績,不過我目前也沒時間去做進一步的優化了,畢竟還要打工正事要緊,寫這篇文章的目的主要是給那些迷茫沒讀懂賽題的人分享我做這個比賽的思路,之前一直在玩kaggle,我特別喜歡kaggle的一點就是在比賽期間不斷有人分享自己做比賽的思路給別人帶來更多的靈感,我也是在kaggle上學習了不少東西,說實話國內這種分享的思維還是稍遜一籌,希望這篇文章能夠給一些新手帶來幫助。