您好,登錄后才能下訂單哦!
這篇文章主要為大家展示了“在pytorch中如何對非葉節點的變量進行梯度計算”,內容簡而易懂,條理清晰,希望能夠幫助大家解決疑惑,下面讓小編帶領大家一起研究并學習一下“在pytorch中如何對非葉節點的變量進行梯度計算”這篇文章吧。
在pytorch中一般只對葉節點進行梯度計算,也就是下圖中的d,e節點,而對非葉節點,也即是c,b節點則沒有顯式地去保留其中間計算過程中的梯度(因為一般來說只有葉節點才需要去更新),這樣可以節省很大部分的顯存,但是在調試過程中,有時候我們需要對中間變量梯度進行監控,以確保網絡的有效性,這個時候我們需要打印出非葉節點的梯度,為了實現這個目的,我們可以通過兩種手段進行。
注冊hook函數
Tensor.register_hook[2] 可以注冊一個反向梯度傳導時的hook函數,這個hook函數將會在每次計算 關于該張量 的時候 被調用,經常用于調試的時候打印出非葉節點梯度。當然,通過這個手段,你也可以自定義某一層的梯度更新方法。[3] 具體到這里的打印非葉節點的梯度,代碼如:
def hook_y(grad): print(grad) x = Variable(torch.ones(2, 2), requires_grad=True) y = x + 2 z = y * y * 3 y.register_hook(hook_y) out = z.mean() out.backward()
輸出如:
tensor([[4.5000, 4.5000], [4.5000, 4.5000]])
retain_grad()
Tensor.retain_grad()顯式地保存非葉節點的梯度,當然代價就是會增加顯存的消耗,而用hook函數的方法則是在反向計算時直接打印,因此不會增加顯存消耗,但是使用起來retain_grad()要比hook函數方便一些。代碼如:
x = Variable(torch.ones(2, 2), requires_grad=True) y = x + 2 y.retain_grad() z = y * y * 3 out = z.mean() out.backward() print(y.grad)
輸出如:
tensor([[4.5000, 4.5000], [4.5000, 4.5000]])
以上是“在pytorch中如何對非葉節點的變量進行梯度計算”這篇文章的所有內容,感謝各位的閱讀!相信大家都有了一定的了解,希望分享的內容對大家有所幫助,如果還想學習更多知識,歡迎關注億速云行業資訊頻道!
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。