【CIKM 2020】ST-GRAT: A Novel Spatio-temporal Graph Atention Network for Traffic Forecasting
總結:
Code
Paper
此文使用Transformer結構進行交通預測。不僅為self-attention加上了先驗路網的歸納偏置,同時使用sentinel對self-attention做一個資訊選擇。另外此文是沒有放出原始碼的,我自己實現了一下,但是效果和文章還是有差距。
Preliminary
使用self-attention進行交通預測。交通預測任務在STGODE、DGCRN、StemGNN、HGCN博文中已經介紹,不再贅述。
Challenge
此文首先指出了之前基於GCN的文章都使用固定的靜態路網並沒有考慮到路網的動態變化性。雖然GaAN使用注意力機制動態的計算每個時間片的空間關係,但是卻忽略了流量的方向性同時也沒有利用已知的圖結構資訊。對於時間維度,DCRNN之類使用RNN的模型並不能很好的捕捉長距離的時間依賴。
Contribution
此文提出了一個基於Transformer名為ST-GRAT的交通預測模型使用self-attention捕捉時空間依賴。
此文對於self-attention做出改進,首先對於spatial attention加上路網資訊先驗,然後對於spatial和temporal attention都使用sentinel,sentinel可以自適應的選擇保留原始資訊或者獲取新資訊。
此文在與baselines的比較取得了sota的結果。
Method
Overview:
可以看到與原始的Transformer相比,此文多了一個可以捕捉空間依賴的Spatial Attention。輸入首先經過Embedding Layer加上時空嵌入向量增強模型的時空表達能力;然後編碼器使用時空注意力提取時空特徵對未來做出預測;解碼器與編碼器相比使用的是Masked Temporal Attention,因為對於未來來說歷史是不可見的需要滿足因果關係,另外解碼器還用了Encoder-Decoder Attention獲取歷史資訊。
Embedding Layer:
與Transformer相比,因為是多變數預測任務,所以不僅需要時間維度的positional embedding,還需要帶有空間結構資訊的spatial embedding。此文使用圖嵌入演算法中的LINE獲取spatial embedding維度是
,同時使用positional embedding得到temporal embedding維度是
,然後作者將兩個嵌入相加(Broadcasting)再使與輸入拼接,最後作者將拼接的向量對映為模型維度。
LINE演算法可以參考這篇博文LINE,embedding程式碼如下所示:
class
Emb
(
nn
。
Module
):
def
__init__
(
self
,
outfea
,
max_len
=
12
):
super
(
Emb
,
self
)
。
__init__
()
self
。
ff
=
nn
。
Linear
(
2
*
outfea
,
outfea
)
pe
=
torch
。
zeros
(
max_len
,
outfea
)
for
pos
in
range
(
max_len
):
for
i
in
range
(
0
,
outfea
,
2
):
pe
[
pos
,
i
]
=
math
。
sin
(
pos
/
(
10000
**
((
2
*
i
)
/
outfea
)))
pe
[
pos
,
i
+
1
]
=
math
。
cos
(
pos
/
(
10000
**
((
2
*
(
i
+
1
))
/
outfea
)))
pe
=
pe
。
unsqueeze
(
0
)
。
unsqueeze
(
2
)
#[1,T,1,F]
self
。
register_buffer
(
‘pe’
,
pe
)
def
forward
(
self
,
x
,
se
):
se
=
se
。
unsqueeze
(
0
)
。
unsqueeze
(
1
)
。
repeat
(
x
。
shape
[
0
],
1
,
1
,
1
)
# [B,1,N,F]
ste
=
se
+
Variable
(
self
。
pe
[:,:
x
。
shape
[
1
],:,:],
requires_grad
=
False
)
# [B,T,N,F]
x
=
torch
。
cat
([
x
,
ste
],
-
1
)
# [B,T,N,2*F]
x
=
self
。
ff
(
x
)
# [B,T,N,F]
return
x
Spatial Attention:
此文也使用多頭注意力機制,對於第
個節點的第
個頭,首先計算所有節點對間的注意力係數,也即
,此文還另添加了先驗結構知識,如下式:
程式碼表示為:
# 計算入流
e
=
torch
。
matmul
(
query
[:
self
。
k
//
2
],
key
[:
self
。
k
//
2
]
。
transpose
(
-
1
,
-
2
))
/
(
self
。
d
**
0。5
)
+
self
。
dcni
(
A
)
# 計算出流
eo
=
torch
。
matmul
(
query
[
self
。
k
//
2
:],
key
[
self
。
k
//
2
:]
。
transpose
(
-
1
,
-
2
))
/
(
self
。
d
**
0。5
)
+
self
。
dcno
(
AT
)
其中式子的前半部分是普通的注意力機制,後面是新增的先驗。
由鄰接矩陣做擴散過程得到
的矩陣,如果
為奇數是出流的擴散過程,偶數的入流的擴散過程,體現了流量的有向性,如下所示:
是超引數控制擴散的階數,
是可學習的引數對不同階的資訊做自適應選擇。程式碼表示為:
# 圖擴散類
class
DCN
(
nn
。
Module
):
def
__init__
(
self
,
N
,
h
,
k
):
super
(
DCN
,
self
)
。
__init__
()
self
。
h
=
h
self
。
k
=
k
self
。
W
=
[
nn
。
Parameter
(
torch
。
empty
([
N
,
N
],
dtype
=
torch
。
float32
,
device
=
device
),
requires_grad
=
True
)
for
i
in
range
(
k
*
h
)]
for
i
in
range
(
k
*
h
):
torch
。
nn
。
init
。
xavier_normal_
(
self
。
W
[
i
])
def
forward
(
self
,
a
):
p
=
[]
for
i
in
range
(
self
。
h
):
pi
=
self
。
W
[
i
*
self
。
k
]
*
a
[
0
]
for
j
in
range
(
1
,
self
。
k
):
pi
+=
self
。
W
[
i
*
self
。
k
+
j
]
*
a
[
j
]
p
。
append
(
pi
)
p
=
torch
。
stack
(
p
,
0
)
。
unsqueeze
(
1
)
。
unsqueeze
(
1
)
return
p
此文除了計算
以外,還額外計算了自身與自身的相似度
,並稱為sentinel,如下式:
程式碼表示為:
# es=(Q*K)。sum(-1),sigmoid為了保持es為正。
es
=
torch
。
sigmoid
((
query
[:
self
。
k
//
2
]
*
keys
[:
self
。
k
//
2
])
。
sum
(
-
1
,
keepdim
=
True
)
/
(
self
。
d
**
0。5
))
另外在將相關係數
使用softmax歸一化時,此文為分母額外增加了
,也就是讓
,下式是歸一化過程:
# 實現softmax,減去最大值防止溢位
e
=
e
-
torch
。
max
(
e
,
-
1
,
keepdim
=
True
)[
0
]
a
=
torch
。
exp
(
e
)
/
(
es
+
torch
。
exp
(
e
)
。
sum
(
-
1
,
keepdim
=
True
))
其實我覺得這裡是有點問題的,
並不能保證非負也就是說分母是可能為0的。
最後此文使用計算到相關係數進行訊息傳遞,如下所示:
程式碼表示為:
# 入流
valuei
=
(
1
-
a
。
sum
(
-
1
,
keepdim
=
True
))
*
values
[:
self
。
k
//
2
]
+
torch
。
matmul
(
a
,
value
[:
self
。
k
//
2
])
# 出流
valueo
=
(
1
-
ao
。
sum
(
-
1
,
keepdim
=
True
))
*
values
[
self
。
k
//
2
:]
+
torch
。
matmul
(
ao
,
value
[
self
。
k
//
2
:])
其中
是sentinel在歸一化機率中所佔的比例,後面是其它所有節點佔的比例,所以作者說能控制資訊是保持原樣
還是吸收其它節點資訊 (
)。作者也畫了spatial attention的結構,挺難看懂的其實,像電路圖。
Temporal Attention:
Temporal Attention與Spatial Attention相比只是少了新增圖結構先驗和sentinel部分。
Experiments
此文使用了METR-LA和PEMS-BAY兩個交通資料集。
Experimental Results:
如下圖所示,ST-GRAT除了在長期預測任務上略輸GMAN一點以外,其它任務均取得sota。ST-GRAT是動態解碼式的迭代預測,而GMAN是生成式預測不存在累積誤差,所以長期結果GMAN優一點也很正常。
此外此文還將一天分為多個時段對每個時段分別進行預測,可以看到ST-GRAT在每個時段都取得了sota的效果,而Graph WaveNet明顯在高峰時刻的效果更好。此外此文選擇了速度變化較快的區間稱為Impeded Interval進行預測,ST-GRAT與其他模型相比取得了大幅上升。
Ablation Study:
作者進行消融實驗驗證了各個部分對於模型的有效性。
Computation Time:
ST-GRAT的執行速度優於RNN和其它注意力模型,只比使用卷積的生成模型Graph WaveNet慢。