您好,登錄后才能下訂單哦!
這篇文章主要講解了如何實現Tensorflow中的圖和會話,內容清晰明了,對此有興趣的小伙伴可以學習一下,相信大家閱讀完之后會有幫助。
Tensorflow編程系統
Tensorflow工具或者說深度學習本身就是一個連貫緊密的系統。一般的系統是一個自治獨立的、能實現復雜功能的整體。系統的主要任務是對輸入進行處理,以得到想要的輸出結果。我們之前見過的很多系統都是線性的,就像汽車生產工廠的流水線一樣,輸入->系統處理->輸出。系統內部由很多單一的基本部件構成,這些單一部件具有特定的功能,且需要穩定的特性;系統設計者通過特殊的連接方式,讓這些簡單部件進行連接,以使它們之間可以進行數據交流和信息互換,來達到相互配合而完成具體工作的目的。
對于任何一個系統來說,都應該擁有穩定、獨立、能處理特殊任務的單一部件;且擁有一套良好的內部溝通機制,以讓系統可以健康安全的運行。
現實中的很多系統都是線性的,被設計好的、不能進行更改的,比如工廠的流水線,這樣的系統并不具備自我調整的能力,無法對外界的環境做出反應,因此也就不具備“智能”。
深度學習(神經網絡)之所以具備智能,就是因為它具有反饋機制。深度學習具有一套對輸出所做的評價函數(損失函數),損失函數在對神經網絡做出評價后,會通過某種方式(梯度下降法)更新網絡的組成參數,以期望系統得到更好的輸出數據。
由此可見,神經網絡的系統主要由以下幾個方面組成:
定義好以上的組成部分,我們就可以用流程化的方式將其組合起來,讓系統對輸入進行學習,調整參數。因為該系統的反饋機制,所以,組成的方式肯定需要循環。
而對于Tensorflow來說,其設計理念肯定離不開神經網絡本身。所以,學習Tensorflow之前,對神經網絡有一個整體、深刻的理解也是必須的。如下圖:Tensorflow的執行示意。
那么對于以上所列的幾點,什么才是最重要的呢?我想肯定是有關系統本身所涉及到的問題。即如何構建、執行一個神經網絡?在Tensorflow中,用計算圖來構建網絡,用會話來具體執行網絡。深入理解了這兩點,我想,對于Tensorflow的設計思路,以及運行機制,也就略知一二了。
圖(tf.Graph):計算圖,主要用于構建網絡,本身不進行任何實際的計算。計算圖的設計啟發是高等數學里面的鏈式求導法則的圖。我們可以將計算圖理解為是一個計算模板或者計劃書。
會話(tf.session):會話,主要用于執行網絡。所有關于神經網絡的計算都在這里進行,它執行的依據是計算圖或者計算圖的一部分,同時,會話也會負責分配計算資源和變量存放,以及維護執行過程中的變量。
接下來,我們主要從計算圖開始,看一看Tensorflow是如何構建、執行網絡的。
計算圖
在開始之前,我們先復習一下Tensorflow的幾種基本數據類型:
tf.constant(value, dtype=None, shape=None, name='Const', verify_shape=False) tf.Variable(initializer, name) tf.placeholder(dtype, shape=None, name=None)
復習完畢。
graph = tf.Graph() with graph.as_default(): img = tf.constant(1.0, shape=[1,5,5,3])
以上代碼中定義了一個計算圖,在該計算圖中定義了一個常量。Tensorflow默認會創建一張計算圖。所以上面代碼中的前兩行,可以省略。默認情況下,計算圖是空的。
在執行完img = tf.constant(1.0, shape=[1,5,5,3])以后,計算圖中生成了一個node,一個node結點由name, op, input, attrs組成,即結點名稱、操作、輸入以及一系列的屬性(類型、形狀、值)等組成,計算圖就是由這樣一個個的node組成的。對于tf.constant()函數,只會生成一個node,但對于有的函數,如tf.Variable(initializer, name)(注意其第一個參數是初始化器)就會生成多個node結點(后面會講到)。
那么執行完img = tf.constant(1.0, shape=[1,5,5,3])后,計算圖中就多一個node結點。(因為每個node的屬性很多,我只表示name,op,input屬性)
繼續添加代碼:
img = tf.constant(1.0, shape=[1,5,5,3]) k = tf.constant(1.0, shape=[3,3,3,1])
代碼執行后的計算圖如下:
需要注意的是,如果沒有對結點進行命名,Tensorflow自動會將其命名為:Const、Const_1、const_2......。其他類型的結點類同。
現在,我們添加一個變量:
img = tf.constant(1.0, shape=[1,5,5,3]) k = tf.constant(1.0, shape=[3,3,3,1]) kernel = tf.Variable(k)
該變量用一個常量作為初始化器。我們先看一下計算圖:
如圖所示:
執行完tf.Variable()函數后,一共產生了三個結點:
圖中只是完成了操作的定義,但并沒有執行操作(如Variable/Assign結點的Assign操作,所以,此時候變量依然不可以使用,這就是為什么要在會話中初始化的原因)。
我們繼續添加代碼:
img = tf.constant(1.0, shape=[1,5,5,3]) k = tf.constant(1.0, shape=[3,3,3,1]) kernel = tf.Variable(k) y = tf.nn.conv2d(img, kernel, strides=[1,2,2,1], padding="SAME")
得到的計算圖如下:
可以看出,變量讀取是通過Variable/read來進行的。
如果在這里我們直接開啟會話,并執行計算圖中的卷積操作,系統就會報錯。
img = tf.constant(1.0, shape=[1,5,5,3]) k = tf.constant(1.0, shape=[3,3,3,1]) kernel = tf.Variable(k) y2 = tf.nn.conv2d(img, kernel, strides=[1,2,2,1], padding="SAME") with tf.Session() as sess: sess.run(y2)
這段代碼錯誤的原因在于,變量并沒有初始化就被使用,而從圖中清晰的可以看到,直接執行卷積,是回溯不到變量的值(Const_1)的(箭頭方向)。
所以,在執行之前,要進行初始化,代碼如下:
img = tf.constant(1.0, shape=[1,5,5,3]) k = tf.constant(1.0, shape=[3,3,3,1]) kernel = tf.Variable(k) y2 = tf.nn.conv2d(img, kernel, strides=[1,2,2,1], padding="SAME") init = tf.global_variables_initializer()
執行完tf.global_variables_initializer()函數以后,計算圖如下:
tf.global_variables_initializer()產生了一個名為init的node,該結點將所有的Variable/Assign結點作為輸入,以達到對整張計算圖中的變量進行初始化。
所以,在開啟會話后,執行的第一步操作,就是變量初始化(當然變量初始化的方式有很多種,我們也可以顯示調用tf.assign()來完成對單個結點的初始化)。
完整代碼如下:
img = tf.constant(1.0, shape=[1,5,5,3]) k = tf.constant(1.0, shape=[3,3,3,1]) kernel = tf.Variable(k) y2 = tf.nn.conv2d(img, kernel, strides=[1,2,2,1], padding="SAME") init = tf.global_variables_initializer() with tf.Session() as sess: sess.run(init) # do someting....
會話
在上述代碼中,我已經使用會話(tf.session())來執行計算圖了。在tf.session()中,我們重點掌握無所不能的sess.run()。
一個session()中包含了Operation被執行,以及Tensor被evaluated的環境。
tf.Session().run()函數的定義:
run( fetches, feed_dict=None, options=None, run_metadata=None )
tf.Session().run()函數的功能為:執行fetches參數所提供的operation操作或計算其所提供的Tensor。
run()函數每執行一步,都會執行與fetches有關的圖中的所有結點的計算,以完成fetches中的任務。其中,feed_dict提供了部分數據輸入的功能。(和tf.Placeholder()搭配使用,很舒服)
參數說明:
當我們把模型的計算圖構建好以后,就可以利用會話來進行執行訓練了。
在明白了計算圖是如何構建的,以及如何被會話正確的執行以后,我們就可以愉快的開始Tensorflow之旅啦。
看完上述內容,是不是對如何實現Tensorflow中的圖和會話有進一步的了解,如果還想學習更多內容,歡迎關注億速云行業資訊頻道。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。