您好,登錄后才能下訂單哦!
這篇文章主要為大家展示了“Pytorch如何加載部分預訓練模型的參數”,內容簡而易懂,條理清晰,希望能夠幫助大家解決疑惑,下面讓小編帶領大家一起研究并學習一下“Pytorch如何加載部分預訓練模型的參數”這篇文章吧。
直接加載預選臉模型
如果我們使用的模型和預訓練模型完全一樣,那么我們就可以直接加載別人的模型,還有一種情況,我們在訓練自己模型的過程中,突然中斷了,但只要我們保存了之前的模型的參數也可以使用下面的代碼直接加載我們保存的模型繼續訓練,不用從頭開始。
model=DPN(*args, **kwargs) model.load_state_dict(torch.load("DPN.pth"))
這樣的加載方式是基于Pytorch使用的模型存儲方法:
torch.save(DPN.state_dict(), "DPN.pth")
加載部分預訓練模型參數
其實大多數時候我們根據自己的任物所提出的模型是在一些公開模型的基礎上改變而來,其中公開模型的參數我們沒有必要在從頭開始訓練,只要加載其訓練好的模型參數即可,這樣有助于提高訓練的準確率和我們模型的泛化能力。
model = DPN(num_init_features=64, k_R=96, G=32, k_sec=(3,4,20,3), inc_sec=(16,32,24,128), num_classes=1,decoder=args.decoder) http = {'url': 'http://data.lip6.fr/cadene/pretrainedmodels/dpn92_extra-b040e4a9b.pth'} pretrained_dict=model_zoo.load_url(http['url']) model_dict = model.state_dict() pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}#filter out unnecessary keys model_dict.update(pretrained_dict) model.load_state_dict(model_dict) model = torch.nn.DataParallel(model).cuda()
因為需要刪除預訓練模型中不匹配的的鍵,也就是層的名字。
以上是“Pytorch如何加載部分預訓練模型的參數”這篇文章的所有內容,感謝各位的閱讀!相信大家都有了一定的了解,希望分享的內容對大家有所幫助,如果還想學習更多知識,歡迎關注億速云行業資訊頻道!
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。