中文字幕av专区_日韩电影在线播放_精品国产精品久久一区免费式_av在线免费观看网站

溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

如何將TensorFlow的模型網絡導出為單個文件

發布時間:2021-08-13 10:30:55 來源:億速云 閱讀:139 作者:小新 欄目:開發技術

這篇文章主要為大家展示了“如何將TensorFlow的模型網絡導出為單個文件”,內容簡而易懂,條理清晰,希望能夠幫助大家解決疑惑,下面讓小編帶領大家一起研究并學習一下“如何將TensorFlow的模型網絡導出為單個文件”這篇文章吧。

有時候,我們需要將TensorFlow的模型導出為單個文件(同時包含模型架構定義與權重),方便在其他地方使用(如在c++中部署網絡)。利用tf.train.write_graph()默認情況下只導出了網絡的定義(沒有權重),而利用tf.train.Saver().save()導出的文件graph_def與權重是分離的,因此需要采用別的方法。

我們知道,graph_def文件中沒有包含網絡中的Variable值(通常情況存儲了權重),但是卻包含了constant值,所以如果我們能把Variable轉換為constant,即可達到使用一個文件同時存儲網絡架構與權重的目標。

我們可以采用以下方式凍結權重并保存網絡:

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants

# 構造網絡
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
# 一定要給輸出tensor取一個名字!!
output = tf.add(a, b, name='out')

# 轉換Variable為constant,并將網絡寫入到文件
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # 這里需要填入輸出tensor的名字
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

當恢復網絡時,可以使用如下方式:

import tensorflow as tf
with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))

輸出結果為:

[array([[ 7.],
       [ 8.]], dtype=float32)]

可以看到之前的權重確實保存了下來!!

問題來了,我們的網絡需要能有一個輸入自定義數據的接口啊!不然這玩意有什么用。。別急,當然有辦法。

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
input_tensor = tf.placeholder(tf.float32, name='input')
output = tf.add((a+b), input_tensor, name='out')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

用上述代碼重新保存網絡至graph.pb,這次我們有了一個輸入placeholder,下面來看看怎么恢復網絡并輸入自定義數據。

import tensorflow as tf

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':4.}, return_elements=['out:0'], name='a') 
    print(sess.run(output))

輸出結果為:

[array([[ 11.],
       [ 12.]], dtype=float32)]

可以看到結果沒有問題,當然在input_map那里可以替換為新的自定義的placeholder,如下所示:

import tensorflow as tf

new_input = tf.placeholder(tf.float32, shape=())

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':new_input}, return_elements=['out:0'], name='a') 
    print(sess.run(output, feed_dict={new_input:4}))

看看輸出,同樣沒有問題。

[array([[ 11.],
       [ 12.]], dtype=float32)]

另外需要說明的一點是,在利用tf.train.write_graph寫網絡架構的時候,如果令as_text=True了,則在導入網絡的時候,需要做一點小修改。

import tensorflow as tf
from google.protobuf import text_format

with tf.Session() as sess:
  # 不使用'rb'模式
  with open('./graph.pb', 'r') as f:
    graph_def = tf.GraphDef()
    # 不使用graph_def.ParseFromString(f.read())
    text_format.Merge(f.read(), graph_def)
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))

以上是“如何將TensorFlow的模型網絡導出為單個文件”這篇文章的所有內容,感謝各位的閱讀!相信大家都有了一定的了解,希望分享的內容對大家有所幫助,如果還想學習更多知識,歡迎關注億速云行業資訊頻道!

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

浦北县| 科尔| 翁牛特旗| 安乡县| 凤翔县| 麻栗坡县| 玛纳斯县| 辰溪县| 龙游县| 秀山| 凤凰县| 万荣县| 六安市| 蓝田县| 广昌县| 额济纳旗| 公安县| 永仁县| 恭城| 宣武区| 磴口县| 平邑县| 屯昌县| 孟州市| 和平县| 郸城县| 龙陵县| 筠连县| 南康市| 九台市| 浪卡子县| 波密县| 岳阳县| 磐石市| 遵义市| 潜江市| 铜陵市| 刚察县| 修水县| 金秀| 塔河县|