[pytorch]一些有關程式碼的細節問題
作者:由 kunkun0v0 發表于 繪畫時間:2021-04-04
detach的作用:
(研究simsiam時候發現的detach完全不會用orz)
import
torch
a
=
torch
。
tensor
([
1
,
2
,
3。
],
requires_grad
=
True
)
(
a
。
grad
)
# 輸出None
b
=
torch
。
sigmoid
(
a
)
c
=
a
+
b
。
detach
()
c
。
mean
()
。
backward
()
(
a
。
grad
)
# tensor([0。3333, 0。3333, 0。3333])
(
c
)
# tensor([1。7311, 2。8808, 3。9526], grad_fn=
(
b
。
grad
)
# None
這裡可以發現我們在使用detach後,計算c的一個分支就被剝離出來不參與梯度計算
backward的作用:
見上部分程式碼,我們在呼叫backward後會將計算圖釋放,所以中間節點的梯度無法得到。如果想要得到中間節點的梯度,則對中間節點呼叫。retain_grad()
多次呼叫backward會對梯度進行累加