您好,登錄后才能下訂單哦!
本篇內容介紹了“Pytorch使用tensor特定條件判斷索引的方法”的有關知識,在實際案例的操作過程中,不少人都會遇到這樣的困境,接下來就讓小編帶領大家學習一下如何處理這些情況吧!希望大家仔細閱讀,能夠學有所成!
torch.where() 用于將兩個broadcastable的tensor組合成新的tensor,類似于c++中的三元操作符“?:”
區別于python numpy中的where()直接可以找到特定條件元素的index
想要實現numpy中where()的功能,可以借助nonzero()
對應numpy中的where()操作效果:
補充:Pytorch torch.Tensor.detach()方法的用法及修改指定模塊權重的方法
detach的中文意思是分離,官方解釋是返回一個新的Tensor,從當前的計算圖中分離出來
需要注意的是,返回的Tensor和原Tensor共享相同的存儲空間,但是返回的 Tensor 永遠不會需要梯度
import torch as t a = t.ones(10,) b = a.detach() print(b) tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
–假如A網絡輸出了一個Tensor類型的變量a, a要作為輸入傳入到B網絡中,如果我想通過損失函數反向傳播修改B網絡的參數,但是不想修改A網絡的參數,這個時候就可以使用detcah()方法
a = A(input) a = detach() b = B(a) loss = criterion(b, target) loss.backward()
import torch as t x = t.ones(1, requires_grad=True) x.requires_grad #True y = t.ones(1, requires_grad=True) y.requires_grad #True x = x.detach() #分離之后 x.requires_grad #False y = x+y #tensor([2.]) y.requires_grad #我還是True y.retain_grad() #y不是葉子張量,要加上這一行 z = t.pow(y, 2) z.backward() #反向傳播 y.grad #tensor([4.]) x.grad #None
以上代碼就說明了反向傳播到y就結束了,沒有到達x,所以x的grad屬性為None
–假如A網絡輸出了一個Tensor類型的變量a, a要作為輸入傳入到B網絡中,如果我想通過損失函數反向傳播修改A網絡的參數,但是不想修改B網絡的參數,這個時候又應該怎么辦了?
這時可以使用Tensor.requires_grad屬性,只需要將requires_grad修改為False即可.
for param in B.parameters(): param.requires_grad = False a = A(input) b = B(a) loss = criterion(b, target) loss.backward()
“Pytorch使用tensor特定條件判斷索引的方法”的內容就介紹到這里了,感謝大家的閱讀。如果想了解更多行業相關的知識可以關注億速云網站,小編將為大家輸出更多高質量的實用文章!
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。