Training an MNIST Handwritten Digit Recognition Model with TensorFlow
This article shares the specific code for training an MNIST handwritten digit recognition model with TensorFlow, for your reference. Note that the code targets the TensorFlow 1.x API. The details are as follows:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

INPUT_NODE = 784   # Input-layer nodes = image pixels = 28x28 = 784
OUTPUT_NODE = 10   # Output-layer nodes = number of image classes
LAYER1_NODE = 500  # Hidden-layer nodes (there is a single hidden layer)

BATCH_SIZE = 100   # Examples per training batch; the smaller it is, the closer
                   # to stochastic gradient descent; the larger, the closer to
                   # (batch) gradient descent

LEARNING_RATE_BASE = 0.8      # Base learning rate
LEARNING_RATE_DECAY = 0.99    # Learning-rate decay rate
REGULARIZATION_RATE = 0.0001  # Coefficient of the regularization term
TRAINING_STEPS = 30000        # Number of training steps
MOVING_AVG_DECAY = 0.99       # Moving-average decay rate
# Helper function: given the network input and all parameters, compute the
# forward-propagation result of the network.
def inference(input_tensor, avg_class, weights1, biases1,
              weights2, biases2):
    # When no moving-average class is provided, use the current parameter values.
    if avg_class is None:
        # Forward propagation through the hidden layer
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1)
        # Forward propagation through the output layer
        return tf.matmul(layer1, weights2) + biases2
    else:
        # First take the moving averages of the variables, then compute the
        # forward-propagation result.
        layer1 = tf.nn.relu(
            tf.matmul(input_tensor, avg_class.average(weights1)) +
            avg_class.average(biases1))
        return tf.matmul(
            layer1, avg_class.average(weights2)) + avg_class.average(biases2)
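# Shape check (illustrative): for a batch of 100 images, input_tensor is
# (100, 784), layer1 is (100, 500), and the returned logits are (100, 10),
# one score per digit class.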
# The training procedure
def train(mnist):
    x = tf.placeholder(tf.float32, [None, INPUT_NODE], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input')

    # Hidden-layer parameters
    weights1 = tf.Variable(
        tf.truncated_normal([INPUT_NODE, LAYER1_NODE], stddev=0.1))
    biases1 = tf.Variable(tf.constant(0.1, shape=[LAYER1_NODE]))
    # Output-layer parameters
    weights2 = tf.Variable(
        tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE], stddev=0.1))
    biases2 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_NODE]))

    # Forward-propagation result without parameter moving averages (avg_class=None)
    y = inference(x, None, weights1, biases1, weights2, biases2)

    # Variable holding the number of training steps, marked non-trainable
    global_step = tf.Variable(0, trainable=False)

    # Initialize the moving-average class with the decay rate and the
    # training-step variable.
    variable_avgs = tf.train.ExponentialMovingAverage(
        MOVING_AVG_DECAY, global_step)
    # Apply moving averages to all trainable variables (the network parameters).
    variables_avgs_op = variable_avgs.apply(tf.trainable_variables())
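    # (Note) Each shadow variable is updated as
    #   shadow = decay * shadow + (1 - decay) * variable,
    # where decay = min(MOVING_AVG_DECAY, (1 + global_step) / (10 + global_step)),
    # so the averaging ramps up quickly early in training.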
    # Forward-propagation result using the moving-averaged parameters
    avg_y = inference(x, variable_avgs, weights1, biases1, weights2, biases2)

    # Cross entropy as the loss function
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=y, labels=tf.argmax(y_, 1))
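    # (Note) sparse_softmax_cross_entropy_with_logits expects integer class
    # labels, so the one-hot y_ is converted to class indices with argmax.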
    cross_entropy_mean = tf.reduce_mean(cross_entropy)

    # L2 regularization loss
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    regularization = regularizer(weights1) + regularizer(weights2)
    loss = cross_entropy_mean + regularization

    # Exponentially decaying learning rate
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,  # current training step
        mnist.train.num_examples / BATCH_SIZE,  # steps per pass over the training data
        LEARNING_RATE_DECAY)
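    # (Note) With staircase=False (the default) the decayed rate is
    #   LEARNING_RATE_BASE * LEARNING_RATE_DECAY ** (global_step / decay_steps);
    # e.g. with 55000 training images decay_steps = 550, so after one epoch
    # the rate is roughly 0.8 * 0.99 = 0.792.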
    # Minimize the loss function
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss, global_step=global_step)

    # Update both the network parameters and their moving averages in one
    # backpropagation step.
    with tf.control_dependencies([train_step, variables_avgs_op]):
        train_op = tf.no_op(name='train')
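    # (Note) tf.group(train_step, variables_avgs_op) would be an equivalent
    # way to bundle the two updates into a single op.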
    # Check whether the forward-propagation result of the moving-average model
    # is correct.
    correct_prediction = tf.equal(tf.argmax(avg_y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Initialize the session and start training.
    with tf.Session() as sess:
        tf.global_variables_initializer().run()

        # Validation data, used to judge the stopping condition and training progress
        validate_feed = {x: mnist.validation.images,
                         y_: mnist.validation.labels}
        # Test data, used as the final criterion of model quality
        test_feed = {x: mnist.test.images, y_: mnist.test.labels}

        # Train the network iteratively.
        for i in range(TRAINING_STEPS):
            if i % 1000 == 0:
                validate_acc = sess.run(accuracy, feed_dict=validate_feed)
                print("After %d training step(s), validation accuracy using average "
                      "model is %g " % (i, validate_acc))

            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            sess.run(train_op, feed_dict={x: xs, y_: ys})

        # After training, measure the final accuracy on the test set.
        test_acc = sess.run(accuracy, feed_dict=test_feed)
        print("After %d training steps, test accuracy using average model "
              "is %g " % (TRAINING_STEPS, test_acc))
# Main entry point
def main(argv=None):
    mnist = input_data.read_data_sets("/tmp/data", one_hot=True)
    train(mnist)

# TensorFlow main entry point
if __name__ == '__main__':
    tf.app.run()
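If you later want to evaluate with the averaged weights loaded directly into the variables (rather than recomputing them through avg_class as above), a common TensorFlow 1.x pattern is to restore each variable from its shadow value via variables_to_restore(). A minimal sketch, assuming the graph above has been rebuilt and that training saved a checkpoint with tf.train.Saver to the hypothetical path /tmp/mnist_model.ckpt (the code above does not save one):

# Map each variable to its moving-average shadow stored in the checkpoint.
variable_avgs = tf.train.ExponentialMovingAverage(MOVING_AVG_DECAY)
saver = tf.train.Saver(variable_avgs.variables_to_restore())

with tf.Session() as sess:
    saver.restore(sess, "/tmp/mnist_model.ckpt")  # hypothetical checkpoint path
    # 'accuracy' and 'test_feed' are the tensor and feed dict built in train().
    print(sess.run(accuracy, feed_dict=test_feed))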
The output is as follows:
Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
After 0 training step(s), validation accuracy using average model is 0.0462
After 1000 training step(s), validation accuracy using average model is 0.9784
After 2000 training step(s), validation accuracy using average model is 0.9806
After 3000 training step(s), validation accuracy using average model is 0.9798
After 4000 training step(s), validation accuracy using average model is 0.9814
After 5000 training step(s), validation accuracy using average model is 0.9826
After 6000 training step(s), validation accuracy using average model is 0.9828
After 7000 training step(s), validation accuracy using average model is 0.9832
After 8000 training step(s), validation accuracy using average model is 0.9838
After 9000 training step(s), validation accuracy using average model is 0.983
After 10000 training step(s), validation accuracy using average model is 0.9836
After 11000 training step(s), validation accuracy using average model is 0.9822
After 12000 training step(s), validation accuracy using average model is 0.983
After 13000 training step(s), validation accuracy using average model is 0.983
After 14000 training step(s), validation accuracy using average model is 0.9844
After 15000 training step(s), validation accuracy using average model is 0.9832
After 16000 training step(s), validation accuracy using average model is 0.9844
After 17000 training step(s), validation accuracy using average model is 0.9842
After 18000 training step(s), validation accuracy using average model is 0.9842
After 19000 training step(s), validation accuracy using average model is 0.9838
After 20000 training step(s), validation accuracy using average model is 0.9834
After 21000 training step(s), validation accuracy using average model is 0.9828
After 22000 training step(s), validation accuracy using average model is 0.9834
After 23000 training step(s), validation accuracy using average model is 0.9844
After 24000 training step(s), validation accuracy using average model is 0.9838
After 25000 training step(s), validation accuracy using average model is 0.9834
After 26000 training step(s), validation accuracy using average model is 0.984
After 27000 training step(s), validation accuracy using average model is 0.984
After 28000 training step(s), validation accuracy using average model is 0.9836
After 29000 training step(s), validation accuracy using average model is 0.9842
After 30000 training steps, test accuracy using average model is 0.9839
That concludes this article. I hope it is helpful for your study, and please continue to support 腳本之家.