中文字幕av专区_日韩电影在线播放_精品国产精品久久一区免费式_av在线免费观看网站

溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

tensorflow如何使用freeze_graph.py將ckpt轉為pb文件

發布時間:2020-08-01 11:01:27 來源:億速云 閱讀:370 作者:小豬 欄目:開發技術

這篇文章主要為大家展示了tensorflow如何使用freeze_graph.py將ckpt轉為pb文件,內容簡而易懂,希望大家可以學習一下,學習完之后肯定會有收獲的,下面讓小編帶大家一起來看看吧。

廢話少說直接上代碼樣例如下

import tensorflow as tf
import os
from tensorflow.python.tools import freeze_graph
# 本來這個model本無需解釋太多,但是這么多人不能耐下心來看,那么我簡單的說一下吧
# network是你們自己定義的模型結構而已
# ps:
# def network(input):
# return tf.layers.max_pooling2d(input, 2, 2)
from model import network


os.environ['CUDA_VISIBLE_DEVICES']='2' #設置GPU


model_path = "path to /model.ckpt-0000" #設置model的路徑,因新版tensorflow會生成三個文件,只需寫到數字前


def main():

 tf.reset_default_graph()

 input_node = tf.placeholder(tf.float32, shape=(228, 304, 3)) #這個是你送入網絡的圖片大小,如果你是其他的大小自行修改
 input_node = tf.expand_dims(input_node, 0)
 flow = network(input_node)
 flow = tf.cast(flow, tf.uint8, 'out') #設置輸出類型以及輸出的接口名字,為了之后的調用pb的時候使用

 saver = tf.train.Saver()
 with tf.Session() as sess:

  saver.restore(sess, model_path)

  #保存圖
  tf.train.write_graph(sess.graph_def, 'output_model/pb_model', 'model.pb')
  #把圖和參數結構一起
  freeze_graph.freeze_graph('output_model/pb_model/model.pb', '', False, model_path, 'out','save/restore_all', 'save/Const:0', 'output_model/pb_model/frozen_model.pb', False, "")

 print("done")

if __name__ == '__main__':
 main()

這節是關于tensorflow的Freezing,字面意思是冷凍,可理解為整合合并;整合什么呢,就是將模型文件和權重文件整合合并為一個文件,主要用途是便于發布。

官方解釋可參考:https://www.tensorflow.org/extend/tool_developers/#freezing 

這里我按我的理解翻譯下,不對的地方請指正:
有一點令我們為比較困惑的是,tensorflow在訓練過程中,通常不會將權重數據保存的格式文件里(這里我理解是模型文件),反而是分開保存在一個叫checkpoint的檢查點文件里,當初始化時,再通過模型文件里的變量Op節點來從checkoupoint文件讀取數據并初始化變量。這種模型和權重數據分開保存的情況,使得發布產品時不是那么方便,所以便有了freeze_graph.py腳本文件用來將這兩文件整合合并成一個文件。
freeze_graph.py是怎么做的呢?首行它先加載模型文件,再從checkpoint文件讀取權重數據初始化到模型里的權重變量,再將權重變量轉換成權重 常量 (因為 常量 能隨模型一起保存在同一個文件里),然后再通過指定的輸出節點將沒用于輸出推理的Op節點從圖中剝離掉,再重新保存到指定的文件里(用write_graphdef或Saver)

文件目錄:tensorflow/python/tools/free_graph.py
測試文件:tensorflow/python/tools/free_graph_test.py 這個測試文件很有學習價值

參數:

總共有11個參數,一個個介紹下(必選: 表示必須有值;可選: 表示可以為空):
1、input_graph:(必選)模型文件,可以是二進制的pb文件,或文本的meta文件,用input_binary來指定區分(見下面說明)
2、input_saver:(可選)Saver解析器。保存模型和權限時,Saver也可以自身序列化保存,以便在加載時應用合適的版本。主要用于版本不兼容時使用。可以為空,為空時用當前版本的Saver。
3、input_binary:(可選)配合input_graph用,為true時,input_graph為二進制,為false時,input_graph為文件。默認False
4、input_checkpoint:(必選)檢查點數據文件。訓練時,給Saver用于保存權重、偏置等變量值。這時用于模型恢復變量值。
5、output_node_names:(必選)輸出節點的名字,有多個時用逗號分開。用于指定輸出節點,將沒有在輸出線上的其它節點剔除。
6、restore_op_name:(可選)從模型恢復節點的名字。升級版中已棄用。默認:save/restore_all
7、filename_tensor_name:(可選)已棄用。默認:save/Const:0
8、output_graph:(必選)用來保存整合后的模型輸出文件。
9、clear_devices:(可選),默認True。指定是否清除訓練時節點指定的運算設備(如cpu、gpu、tpu。cpu是默認)
10、initializer_nodes:(可選)默認空。權限加載后,可通過此參數來指定需要初始化的節點,用逗號分隔多個節點名字。
11、variable_names_blacklist:(可先)默認空。變量黑名單,用于指定不用恢復值的變量,用逗號分隔多個變量名字。

用法:

例:python tensorflow/python/tools/free_graph.py \
–input_graph=some_graph_def.pb \ 注意:這里的pb文件是用tf.train.write_graph方法保存的
–input_checkpoint=model.ckpt.1001 \ 注意:這里若是r12以上的版本,只需給.data-00000….前面的文件名,如:model.ckpt.1001.data-00000-of-00001,只需寫model.ckpt.1001
–output_graph=/tmp/frozen_graph.pb
–output_node_names=softmax

另外,如果模型文件是.meta格式的,也就是說用saver.Save方法和checkpoint一起生成的元模型文件,free_graph.py不適用,但可以改造下:
1、copy free_graph.py為free_graph_meta.py
2、修改free_graph.py,導入meta_graph:from tensorflow.python.framework import meta_graph
3、將91行到97行換成:input_graph_def = meta_graph.read_meta_graph_file(input_graph).graph_def

這樣改即可加載meta文件

以上就是關于tensorflow如何使用freeze_graph.py將ckpt轉為pb文件的內容,如果你們有學習到知識或者技能,可以把它分享出去讓更多的人看到。

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

枞阳县| 闽侯县| 顺平县| 东山县| 茌平县| 文成县| 金寨县| 阳高县| 民和| 太原市| 怀安县| 香格里拉县| 沙雅县| 斗六市| 太湖县| 天柱县| 滨海县| 瓮安县| 康保县| 灵丘县| 达尔| 玉田县| 应用必备| 新丰县| 眉山市| 中山市| 新津县| 汤阴县| 郯城县| 浦江县| 黔西| 西充县| 常熟市| 宁海县| 措勤县| 峡江县| 大洼县| 江达县| 桓台县| 依兰县| 胶南市|