tensorflow學(xué)習(xí)筆記之簡單的神經(jīng)網(wǎng)絡(luò)訓(xùn)練和測試
本文實例為大家分享了用簡單的神經(jīng)網(wǎng)絡(luò)來訓(xùn)練和測試的具體代碼,供大家參考,具體內(nèi)容如下
剛開始學(xué)習(xí)tf時,我們從簡單的地方開始。卷積神經(jīng)網(wǎng)絡(luò)(CNN)是由簡單的神經(jīng)網(wǎng)絡(luò)(NN)發(fā)展而來的,因此,我們的第一個例子,就從神經(jīng)網(wǎng)絡(luò)開始。
神經(jīng)網(wǎng)絡(luò)沒有卷積功能,只有簡單的三層:輸入層,隱藏層和輸出層。
數(shù)據(jù)從輸入層輸入,在隱藏層進行加權(quán)變換,最后在輸出層進行輸出。輸出的時候,我們可以使用softmax回歸,輸出屬于每個類別的概率值。借用極客學(xué)院的圖表示如下:

其中,x1,x2,x3為輸入數(shù)據(jù),經(jīng)過運算后,得到三個數(shù)據(jù)屬于某個類別的概率值y1,y2,y3. 用簡單的公式表示如下:

在訓(xùn)練過程中,我們將真實的結(jié)果和預(yù)測的結(jié)果相比(交叉熵比較法),會得到一個殘差。公式如下:

y是我們預(yù)測的概率值,y'是實際的值。這個殘差越小越好,我們可以使用梯度下降法,不停地改變W和b的值,使得殘差逐漸變小,最后收斂到最小值。這樣訓(xùn)練就完成了,我們就得到了一個模型(W和b的最優(yōu)化值)。
完整代碼如下:
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
x = tf.placeholder(tf.float32, [None, 784])
y_actual = tf.placeholder(tf.float32, shape=[None, 10])
W = tf.Variable(tf.zeros([784,10])) #初始化權(quán)值W
b = tf.Variable(tf.zeros([10])) #初始化偏置項b
y_predict = tf.nn.softmax(tf.matmul(x,W) + b) #加權(quán)變換并進行softmax回歸,得到預(yù)測概率
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_actual*tf.log(y_predict),reduction_indies=1)) #求交叉熵
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) #用梯度下降法使得殘差最小
correct_prediction = tf.equal(tf.argmax(y_predict,1), tf.argmax(y_actual,1)) #在測試階段,測試準(zhǔn)確度計算
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) #多個批次的準(zhǔn)確度均值
init = tf.initialize_all_variables()
with tf.Session() as sess:
sess.run(init)
for i in range(1000): #訓(xùn)練階段,迭代1000次
batch_xs, batch_ys = mnist.train.next_batch(100) #按批次訓(xùn)練,每批100行數(shù)據(jù)
sess.run(train_step, feed_dict={x: batch_xs, y_actual: batch_ys}) #執(zhí)行訓(xùn)練
if(i%100==0): #每訓(xùn)練100次,測試一次
print "accuracy:",sess.run(accuracy, feed_dict={x: mnist.test.images, y_actual: mnist.test.labels})
每訓(xùn)練100次,測試一次,隨著訓(xùn)練次數(shù)的增加,測試精度也在增加。訓(xùn)練結(jié)束后,1W行數(shù)據(jù)測試的平均精度為91%左右,不是太高,肯定沒有CNN高。
以上就是本文的全部內(nèi)容,希望對大家的學(xué)習(xí)有所幫助,也希望大家多多支持腳本之家。
相關(guān)文章
使用PyQt的QLabel組件實現(xiàn)選定目標(biāo)框功能的方法示例
這篇文章主要介紹了使用PyQt的QLabel組件實現(xiàn)選定目標(biāo)框功能的方法示例,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-05-05
Python使用mmap實現(xiàn)內(nèi)存映射文件操作
內(nèi)存映射通??梢蕴岣逫/O的性能,本文主要介紹了Python使用mmap實現(xiàn)內(nèi)存映射文件操作,分享給大家,感興趣的可以了解一下2021-06-06
Python中使用kitti數(shù)據(jù)集實現(xiàn)自動駕駛(繪制出所有物體的行駛軌跡)
這篇文章主要介紹了Python中使用kitti數(shù)據(jù)集實現(xiàn)自動駕駛——繪制出所有物體的行駛軌跡,本次內(nèi)容主要是畫出kitti車的行駛的軌跡,需要的朋友可以參考下2022-06-06
Python如何使用struct.unpack處理二進制文件
這篇文章主要介紹了Python如何使用struct.unpack處理二進制文件問題,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教2024-02-02
Python利用QQ郵箱發(fā)送郵件的實現(xiàn)方法(分享)
下面小編就為大家?guī)硪黄狿ython利用QQ郵箱發(fā)送郵件的實現(xiàn)方法(分享)。小編覺得挺不錯的,現(xiàn)在就分享給大家,也給大家做個參考。一起跟隨小編過來看看吧2017-06-06

