中文字幕av专区_日韩电影在线播放_精品国产精品久久一区免费式_av在线免费观看网站

溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

如何使用TensorFlow 載入模型

發布時間:2021-05-20 15:43:21 來源:億速云 閱讀:164 作者:Leah 欄目:開發技術

這篇文章將為大家詳細講解有關如何使用TensorFlow 載入模型,文章內容質量較高,因此小編分享給大家做個參考,希望大家閱讀完這篇文章后對相關知識有一定的了解。

一、TensorFlow常規模型加載方法

保存模型

tf.train.Saver()類,.save(sess, ckpt文件目錄)方法

參數名稱功能說明默認值
var_listSaver中存儲變量集合全局變量集合
reshape加載時是否恢復變量形狀True
sharded是否將變量輪循放在所有設備上True
max_to_keep保留最近檢查點個數5
restore_sequentially是否按順序恢復變量,模型較大時順序恢復內存消耗小True

var_list是字典形式{變量名字符串: 變量符號},相對應的restore也根據同樣形式的字典將ckpt中的字符串對應的變量加載給程序中的符號。

如果Saver給定了字典作為加載方式,則按照字典來,如:saver = tf.train.Saver({"v/ExponentialMovingAverage":v}),否則每個變量尋找自己的name屬性在ckpt中的對應值進行加載。

加載模型

當我們基于checkpoint文件(ckpt)加載參數時,實際上我們使用Saver.restore取代了initializer的初始化

如何使用TensorFlow 載入模型

checkpoint文件會記錄保存信息,通過它可以定位最新保存的模型:

ckpt = tf.train.get_checkpoint_state('./model/')
print(ckpt.model_checkpoint_path)

如何使用TensorFlow 載入模型 

.meta文件保存了當前圖結構

.index文件保存了當前參數名

.data文件保存了當前參數值

tf.train.import_meta_graph函數給出model.ckpt-n.meta的路徑后會加載圖結構,并返回saver對象

ckpt = tf.train.get_checkpoint_state('./model/')

tf.train.Saver函數會返回加載默認圖的saver對象,saver對象初始化時可以指定變量映射方式,根據名字映射變量(『TensorFlow』滑動平均)

saver = tf.train.Saver({"v/ExponentialMovingAverage":v})

saver.restore函數給出model.ckpt-n的路徑后會自動尋找參數名-值文件進行加載

saver.restore(sess,'./model/model.ckpt-0')
saver.restore(sess,ckpt.model_checkpoint_path)

1.不加載圖結構,只加載參數

由于實際上我們參數保存的都是Variable變量的值,所以其他的參數值(例如batch_size)等,我們在restore時可能希望修改,但是圖結構在train時一般就已經確定了,所以我們可以使用tf.Graph().as_default()新建一個默認圖(建議使用上下文環境),利用這個新圖修改和變量無關的參值大小,從而達到目的。

'''
使用原網絡保存的模型加載到自己重新定義的圖上
可以使用python變量名加載模型,也可以使用節點名
'''
import AlexNet as Net
import AlexNet_train as train
import random
import tensorflow as tf
 
IMAGE_PATH = './flower_photos/daisy/5673728_71b8cb57eb.jpg'
 
with tf.Graph().as_default() as g:
 
 x = tf.placeholder(tf.float32, [1, train.INPUT_SIZE[0], train.INPUT_SIZE[1], 3])
 y = Net.inference_1(x, N_CLASS=5, train=False)
 
 with tf.Session() as sess:
  # 程序前面得有 Variable 供 save or restore 才不報錯
  # 否則會提示沒有可保存的變量
  saver = tf.train.Saver()
 
  ckpt = tf.train.get_checkpoint_state('./model/')
  img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read()
  img = sess.run(tf.expand_dims(tf.image.resize_images(
   tf.image.decode_jpeg(img_raw),[224,224],method=random.randint(0,3)),0))
 
  if ckpt and ckpt.model_checkpoint_path:
   print(ckpt.model_checkpoint_path)
   saver.restore(sess,'./model/model.ckpt-0')
   global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
   res = sess.run(y, feed_dict={x: img})
   print(global_step,sess.run(tf.argmax(res,1)))

2.加載圖結構和參數

'''
直接使用使用保存好的圖
無需加載python定義的結構,直接使用節點名稱加載模型
由于節點形狀已經定下來了,所以有不便之處,placeholder定義batch后單張傳會報錯
現階段不推薦使用,以后如果理解深入了可能會找到使用方法
'''
import AlexNet_train as train
import random
import tensorflow as tf
 
IMAGE_PATH = './flower_photos/daisy/5673728_71b8cb57eb.jpg'
 
 
ckpt = tf.train.get_checkpoint_state('./model/')       # 通過檢查點文件鎖定最新的模型
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta') # 載入圖結構,保存在.meta文件中
 
with tf.Session() as sess:
 saver.restore(sess,ckpt.model_checkpoint_path)      # 載入參數,參數保存在兩個文件中,不過restore會自己尋找
 
 img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read()
 img = sess.run(tf.image.resize_images(
  tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0, 3)))
 imgs = []
 for i in range(128):
  imgs.append(img)
 print(sess.run(tf.get_default_graph().get_tensor_by_name('fc3:0'),feed_dict={'Placeholder:0': imgs}))
 
 '''
 img = sess.run(tf.expand_dims(tf.image.resize_images(
  tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0, 3)), 0))
 print(img)
 imgs = []
 for i in range(128):
  imgs.append(img)
 print(sess.run(tf.get_default_graph().get_tensor_by_name('conv1:0'),
     feed_dict={'Placeholder:0':img}))

注意,在所有兩種方式中都可以通過調用節點名稱使用節點輸出張量,節點.name屬性返回節點名稱。

3.簡化版本

# 連同圖結構一同加載
ckpt = tf.train.get_checkpoint_state('./model/')
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta')
with tf.Session() as sess:
 saver.restore(sess,ckpt.model_checkpoint_path)
    
# 只加載數據,不加載圖結構,可以在新圖中改變batch_size等的值
# 不過需要注意,Saver對象實例化之前需要定義好新的圖結構,否則會報錯
saver = tf.train.Saver()
with tf.Session() as sess:
 ckpt = tf.train.get_checkpoint_state('./model/')
 saver.restore(sess,ckpt.model_checkpoint_path)

二、TensorFlow二進制模型加載方法

這種加載方法一般是對應網上各大公司已經訓練好的網絡模型進行修改的工作

# 新建空白圖
self.graph = tf.Graph()
# 空白圖列為默認圖
with self.graph.as_default():
 # 二進制讀取模型文件
 with tf.gfile.FastGFile(os.path.join(model_dir,model_name),'rb') as f:
  # 新建GraphDef文件,用于臨時載入模型中的圖
  graph_def = tf.GraphDef()
  # GraphDef加載模型中的圖
  graph_def.ParseFromString(f.read())
  # 在空白圖中加載GraphDef中的圖
  tf.import_graph_def(graph_def,name='')
  # 在圖中獲取張量需要使用graph.get_tensor_by_name加張量名
  # 這里的張量可以直接用于session的run方法求值了
  # 補充一個基礎知識,形如'conv1'是節點名稱,而'conv1:0'是張量名稱,表示節點的第一個輸出張量
  self.input_tensor = self.graph.get_tensor_by_name(self.input_tensor_name)
  self.layer_tensors = [self.graph.get_tensor_by_name(name + ':0') for name in self.layer_operation_names]

關于如何使用TensorFlow 載入模型就分享到這里了,希望以上內容可以對大家有一定的幫助,可以學到更多知識。如果覺得文章不錯,可以把它分享出去讓更多的人看到。

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

永州市| 东兴市| 怀宁县| 襄垣县| 吉安县| 凯里市| 齐河县| 宁海县| 宝山区| 山丹县| 威信县| 陆川县| 讷河市| 湟中县| 梁山县| 朝阳市| 加查县| 滨海县| 攀枝花市| 鄂尔多斯市| 武安市| 岱山县| 建水县| 丰宁| 乡宁县| 上蔡县| 米林县| 岫岩| 河津市| 灌云县| 云和县| 罗源县| 辽中县| 镇赉县| 宾阳县| 广昌县| 桐乡市| 崇礼县| 四会市| 务川| 洛阳市|