您好,登錄后才能下訂單哦!
pytorch保存數據
pytorch保存數據的格式為.t7文件或者.pth文件,t7文件是沿用torch7中讀取模型權重的方式。而pth文件是python中存儲文件的常用格式。而在keras中則是使用.h6文件。
# 保存模型示例代碼 print('===> Saving models...') state = { 'state': model.state_dict(), 'epoch': epoch # 將epoch一并保存 } if not os.path.isdir('checkpoint'): os.mkdir('checkpoint') torch.save(state, './checkpoint/autoencoder.t7')
保存用到torch.save函數,注意該函數第一個參數可以是單個值也可以是字典,字典可以存更多你要保存的參數(不僅僅是權重數據)。
pytorch讀取數據
pytorch讀取數據使用的方法和我們平時使用預訓練參數所用的方法是一樣的,都是使用load_state_dict這個函數。
下方的代碼和上方的保存代碼可以搭配使用。
print('===> Try resume from checkpoint') if os.path.isdir('checkpoint'): try: checkpoint = torch.load('./checkpoint/autoencoder.t7') model.load_state_dict(checkpoint['state']) # 從字典中依次讀取 start_epoch = checkpoint['epoch'] print('===> Load last checkpoint data') except FileNotFoundError: print('Can\'t found autoencoder.t7') else: start_epoch = 0 print('===> Start from scratch')
以上是pytorch讀取的方法匯總,但是要注意,在使用官方的預處理模型進行讀取時,一般使用的格式是pth,使用官方的模型讀取命令會檢查你模型的格式是否正確,如果不是使用官方提供模型通過下面的函數強行讀取模型(將其他模型例如caffe模型轉過來的模型放到指定目錄下)會發生錯誤。
def vgg19(pretrained=False, **kwargs): """VGG 19-layer model (configuration "E") Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ model = VGG(make_layers(cfg['E']), **kwargs) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['vgg19'])) return model
假如我們有從caffe模型轉過來的pytorch模型([0-255,BGR]),我們可以使用:
model_dir = '自己的模型地址' model = VGG() model.load_state_dict(torch.load(model_dir + 'vgg_conv.pth'))
也就是pytorch的讀取函數進行讀取即可。
以上這篇Pytorch之保存讀取模型實例就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持億速云。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。