您好,登錄后才能下訂單哦!
本篇文章給大家分享的是有關如何用RNN進行分類,小編覺得挺實用的,因此分享給大家學習,希望大家閱讀完這篇文章后可以有所收獲,話不多說,跟著小編一起來看看吧。
今天我們介紹的是RNN是如何玩分類的。
MNIST數據集,我們都已經很熟悉了,是一個手寫數字的數據集,之前我們用它來實戰CNN分類器和機器學習的方法(在公眾號中回復“MNIST”,即可免費下載)。今天我們就用RNN來對MNIST數據集進行一個預測。
這個時候,我們需要將每一張數據圖像當成一個28x28的序列信號(圖像的大小為28x28pixels)。對于整個網絡框架,我們使用一個150個循環神經元外加一個有10個神經元的全連接層(每個類對應一個),最后接一個softmax層。如下: 整個模型的構建階段,也很直接,跟我們前幾期學的dnn構建方法非常類似,這里只是用了沒有展開的RNN代替了之前的隱藏層,需要注意的是最后的全連接層連接的是RNN的狀態tensor,該狀態tensor僅僅包含了RNN的最后一個狀態,并且y是目標類別。
from tensorflow.contrib.layers import fully_connected
n_steps = 28
n_inputs = 28
n_neurons = 150
n_outputs = 10
learning_rate = 0.001
X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
y = tf.placeholder(tf.int32, [None])
basic_cell = tf.contrib.rnn.BasicRNNCell(num_units=n_neurons)
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32)
logits = fully_connected(states, n_outputs, activation_fn=None)
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=y, logits=logits)
loss = tf.reduce_mean(xentropy)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
training_op = optimizer.minimize(loss)
correct = tf.nn.in_top_k(logits, y, 1)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
init = tf.global_variables_initializer()
接下來,我們加載數據集,并對數據集進行reshape,如下:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/")
X_test = mnist.test.images.reshape((-1, n_steps, n_inputs))
y_test = mnist.test.labels
現在,我們將對上面的RNN進行training,在執行階段跟之前的dnn也是非常類似的,如下:
n_epochs = 100
batch_size = 150
with tf.Session() as sess:
init.run()
for epoch in range(n_epochs):
for iteration in range(mnist.train.num_examples // batch_size):
X_batch, y_batch = mnist.train.next_batch(batch_size)
X_batch = X_batch.reshape((-1, n_steps, n_inputs))
sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
acc_train = accuracy.eval(feed_dict={X: X_batch, y: y_batch})
acc_test = accuracy.eval(feed_dict={X: X_test, y: y_test})
print(epoch, "Train accuracy:", acc_train, "Test accuracy:", acc_test)
輸出的結果如下:
0 Train accuracy: 0.713333 Test accuracy: 0.7299
1 Train accuracy: 0.766667 Test accuracy: 0.7977
...
98 Train accuracy: 0.986667 Test accuracy: 0.9777
99 Train accuracy: 0.986667 Test accuracy: 0.9809
最終得到了98%的準確率,還挺不錯的,如果我們調整下超參數或者RNN權重初始化的方式,訓練的更久一些,或者加一些正則化的方法,結果應該還會更好。
以上就是如何用RNN進行分類,小編相信有部分知識點可能是我們日常工作會見到或用到的。希望你能通過這篇文章學到更多知識。更多詳情敬請關注億速云行業資訊頻道。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。