中文字幕av专区_日韩电影在线播放_精品国产精品久久一区免费式_av在线免费观看网站

溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

TensorFlow如何將ckpt文件固化成pb文件

發布時間:2021-08-13 08:27:38 來源:億速云 閱讀:385 作者:小新 欄目:開發技術

小編給大家分享一下TensorFlow如何將ckpt文件固化成pb文件,相信大部分人都還不怎么了解,因此分享這篇文章給大家參考一下,希望大家閱讀完這篇文章后大有收獲,下面讓我們一起去了解一下吧!

將yolo3目標檢測框架訓練出來的ckpt文件固化成pb文件,主要利用了GitHub上的該項目。

為什么要最終生成pb文件呢?簡單來說就是直接通過tf.saver保存行程的ckpt文件其變量數據和圖是分開的。我們知道TensorFlow是先畫圖,然后通過placeholde往圖里面喂數據。這種解耦形式存在的方法對以后的遷移學習以及對程序進行微小的改動提供了極大的便利性。但是對于訓練好,以后不再改變的話這種存在就不再需要。一方面,ckpt文件儲存的數據都是變量,既然我們不再改動,就應當讓其變成常量,直接‘燒'到圖里面。另一方面,對于線上的模型,我們一般是通過C++或者C語言編寫的程序進行調用。所以一般模型最終形式都是應該寫成pb文件的形式。

由于這次的程序直接從GitHub上下載后改動較小就能夠運行,也就是自己寫了很少一部分程序。因此進行調試的時候還出現了以前根本沒有注意的一些小問題,同時發現自己對TensorFlow還需要更加詳細的去研讀。

首先對程序進行保存的時候,利用 saver = tf.train.Saver(), saver.save(sess,checkpoint_path,global_step=global_step)對訓練的數據進行保存,保存格式為ckpt。但是在恢復的時候一直提示有問題,(其恢復語句為:saver = tf.train.Saver(), saver.restore(sess,ckpt_path),其中,ckpt_path是保存ckpt的文件夾路徑)。出現問題的原因我估計是因為我是按照每50個epoch進行保存,而不是讓其進行固定次數的batch進行保存,這種固定batch次數的保存系統會自動保存最近5次的ckpt文件(該方法的ckpt_path=tf.train,latest_checkpoint('ckpt/')進行回復)。那么如何將利用epoch的次數進行保存呢(這種保存不是近5次的保存,而是每進行一次保存就會留下當時保存的ckpt,而那種按照batch的會在第n次保存,會將n-5次的刪除,n>5)。

我們可以利用:ckpt = tf.train.get_checkpoint_state(ckpt_path),獲取最新的ckptpoint文件,然后利用saver.restore(sess,ckpt.checkpoint_path)進行恢復。當然為了安全起見,應該對ckpt和ckpt.checkpoint_path進行判斷是否存在后,再進行恢復語句的調用,建議打開ckptpoint看一下,里面記錄的最近五次的model的路徑,一目了然。即:

  saver = tf.train.Saver()
  ckpt = tf.train.get_checkpoint_state(model_path)
  if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)

對于固化網絡,網上有很多的介紹。之所以再介紹,還是由于是用了別人的網絡而不是自己的網絡遇到的坑。在固化時候我們需要知道輸出tensor的名字,而再恢復的時候我們需要知道placeholder的名字。但是,如果網絡復雜或者別人的網絡命名比較復雜,或者name=,根本就沒有自己命名而用的系統自定義的,這樣捋起來還是比較費勁的。當時在網上查找的一些方法,像打印整個網絡變量的方法(先不管輸出的網路名稱,甚至隨便起一個名字,先固化好pb文件,然后對pb文件進行讀取,最后打印操作的名字:

 graph = tf.get_default_graph()
  input_graph_def = graph.as_graph_def()
 
  output_graph_def = graph_util.convert_variables_to_constants(
    sess,
    input_graph_def,
    ['cls_score/cls_score', 'cls_prob'] # We split on comma for convenience
  )
  with tf.gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
  print ('開始打印節點名字')
  for op in graph.get_operations():
    print(op.name)
  print("%d ops in the final graph." % len(output_graph_def.node))

代碼一

這樣盡然也能打印出來(盡管輸出名字是隨便命名的)。但是打印出來的是所有的節點的名字,簡直不要太多。這樣找的話,一方面可能找不對,另一方面也太費事。

那么怎么辦?答案簡單的讓我也很無語。其實,對ckpt進行數據恢復的時候,直接打印輸出的tensor名字就可以。比如說在saver以及placeholder定義的時候:output = model.yolo_inference(images, config.num_anchors / 3, config.num_classes, is_training),我們在后面跟一句:print output,從打印出來的信息即可查看。placeholder的查看方法同樣如此。

對網絡進行固化:

代碼:

  input_image_shape = tf.placeholder(dtype = tf.int32, shape = (2,))
  input_image = tf.placeholder(shape = [None, 416, 416, 3], dtype = tf.float32)
  predictor = yolo_predictor(config.obj_threshold, config.nms_threshold, config.classes_path, config.anchors_path)
  boxes, scores, classes = predictor.predict(input_image, input_image_shape)
  sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
  saver = tf.train.Saver()
  ckpt = tf.train.get_checkpoint_state(model_path)
  if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)
 
  # 采用meta 結構加載,不需要知道網絡結構
  # saver = tf.train.import_meta_graph(model_path, clear_devices=True) 
  # 這里的model_path是model.ckpt.meta文件的全路徑
  # ckpt_model_path 是保存模型的文件夾路徑
  # saver.restore(sess, tf.train.latest_checkpoint(ckpt_model_path))
 
  graph = tf.get_default_graph()
  input_graph_def = graph.as_graph_def()
  output_graph_def = graph_util.convert_variables_to_constants(
    sess,
    input_graph_def,
    ['concat_11','concat_12','concat_13'] # We split on comma for convenience
  )
  # # Finally we serialize and dump the output graph to the filesystem
  with tf.gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())

由于固化的時候是需要先恢復ckpt網絡的,所以還是在restore前寫了placeholder和輸出tensor的定義(需要注點意的是,我們保存的ckpt文件是訓練階段的graph和變量等,其inference輸出和最終predict的輸出的Tensor不一樣,因此predict與inference的輸出相比,還包括了一些后處理,比如說nms等等,只有這些后處理也是TensorFlow框架內的方法寫的,才能使最終形成的pb文件能夠做到輸入一張圖片,直接輸出最終結果。因此,對于目標檢測任務,把后處理任務也交由TensorFlow內的api來實現,可免去夸平臺讀取pb文件后仍然需要重新進行后處理等相關程序的編寫帶來的不必要麻煩)。然后結合保存變量的那個文件(ckpt),將變量恢復到inference過程所需的變量數據(predict包括inference和eval兩個過程,訓練過程只有inference和loss過程參與,而預測過程多了一個后處理eval過程,eval過程無變量。這樣在生成pb文件的時候也把后處理eval固化進去。喂給網絡數據,即可得到輸出tensor。

由于有讀者在此問到了還是沒有弄明白'concat_11','concat_12','concat_13'是如何得來的,我在這里就在詳細說一下:

是這樣的,在我們恢復網絡的時候肯定需要知道saver這個對象的,在這里介紹兩種方法生成這個對象的方法。

一:

saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True)

其中meta_graph_location就是保存模型時的.meta文件的路徑。保存后有四個文件(checkpoint、.index、.data-00000-of-00001和.meta文件)。.meta文件就是整個TensorFlow的結構圖。

二:

saver = tf.train.Saver()

本文采用的是第二種方法(上面已經有詳細的代碼),由于這種方法得到的saver對象,他不知道具體圖是什么樣的,因此在恢復前我有用如下代碼

predictor = yolo_predictor(config.obj_threshold, config.nms_threshold, config.classes_path, config.anchors_path)
boxes, scores, classes = predictor.predict(input_image, input_image_shape)

把整個結構又加載了一遍。如果采用第一種方法,是不需要在重寫這兩行代碼的。

我們要的就是 boxes, scores, classes這三個tensor的結果,并且想知道他們三個tensor的名字。你直接利用print(boxes, scores, classes)打印出來這三個tensor就會出來這三個tensor具體信息(包括名字,和shape,dtype等)。這個只是利用第二種方法得到saver對象,然后恢復ckpt文件,不涉及到固化pb文件問題。固化pb文件是需要知道這三個tensor的名字,所以需要打印看一下。

如果說,我只拿到了保存后的四個文件(checkpoint、.index、.data-00000-of-00001和.meta文件),其相應用代碼寫成的結構圖不清楚,比如說利用這兩行代碼:

predictor = yolo_predictor(config.obj_threshold, config.nms_threshold, config.classes_path, config.anchors_path)
boxes, scores, classes = predictor.predict(input_image, input_image_shape)

畫出的結構圖是什么樣的,我不知道。那么,想要知道具體的placehold和輸出tensor的名字,那只能通過代碼一中,打印出所有的OP操作節點,然后進行人工遍歷了。

讀取pb文件:

代碼:

def pb_detect(image_path, pb_model_path):
 
  os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu_index
  image = Image.open(image_path)
  resize_image = letterbox_image(image, (416, 416))
  image_data = np.array(resize_image, dtype = np.float32)
  image_data /= 255.
  image_data = np.expand_dims(image_data, axis = 0)
  with tf.Graph().as_default():
    output_graph_def = tf.GraphDef()
    with open(pb_model_path, "rb") as f:
      output_graph_def.ParseFromString(f.read())
      tf.import_graph_def(output_graph_def, name="")
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      input_image_tensor = sess.graph.get_tensor_by_name("Placeholder_1:0")
      input_image_tensor_shape = sess.graph.get_tensor_by_name("Placeholder:0")
      # 定義輸出的張量名稱
      #output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")
      boxes = sess.graph.get_tensor_by_name("concat_11:0")
      scores = sess.graph.get_tensor_by_name("concat_12:0")
      classes = sess.graph.get_tensor_by_name("concat_13:0")
      # 讀取測試圖片
      # 測試讀出來的模型是否正確,注意這里傳入的是輸出和輸入節點的tensor的名字(需要在名字后面加:0),不是操作節點的名字
      out_boxes, out_scores, out_classes= sess.run([boxes,scores,classes],
              feed_dict={
                input_image_tensor: image_data,
                input_image_tensor_shape: [image.size[1], image.size[0]]
      })

可以看到讀取pb文件只需要比恢復ckpt文件容易的多,直接將placeholder的名字獲取到,將數據輸入恢復的網絡,以及讀取輸出即可。

小記:

有可能是TensorFlow版本更新或者其他原因,在后來工作中加載pb文件是報錯了:

ValueError: Fetch argument <tf.Tensor 'shuffle_batch:0' shape=(1, 300, 1024) dtype=float32> cannot be interpreted as a Tensor. (tf.Tensor 'shuffle_batch:0' shape=(1, 300, 1024), dtype=float32) is not an element of this graph.)

將上面讀取pb文件的代碼with tf.Graph().as_default():改成

global graph
graph = tf.get_default_graph()
with graph.as_default():

以上是“TensorFlow如何將ckpt文件固化成pb文件”這篇文章的所有內容,感謝各位的閱讀!相信大家都有了一定的了解,希望分享的內容對大家有所幫助,如果還想學習更多知識,歡迎關注億速云行業資訊頻道!

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

吴旗县| 海安县| 罗源县| 大关县| 南部县| 塔河县| 瑞丽市| 濮阳县| 漳州市| 巴楚县| 余干县| 葵青区| 蒙阴县| 夹江县| 阿克| 天气| 咸宁市| 富宁县| 且末县| 岳西县| 青冈县| 绥棱县| 浙江省| 延寿县| 瑞安市| 交城县| 清河县| 定兴县| 云阳县| 昌邑市| 荆门市| 通河县| 嵩明县| 莱州市| 手游| 弥渡县| 阜宁县| 会宁县| 安岳县| 丰镇市| 大方县|