基于Tensorflow的MNIST手寫(xiě)數(shù)字識(shí)別分類
本文實(shí)例為大家分享了基于Tensorflow的MNIST手寫(xiě)數(shù)字識(shí)別分類的具體實(shí)現(xiàn)代碼,供大家參考,具體內(nèi)容如下
代碼如下:
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.tensorboard.plugins import projector
import time
IMAGE_PIXELS = 28
hidden_unit = 100
output_nums = 10
learning_rate = 0.001
train_steps = 50000
batch_size = 500
test_data_size = 10000
#日志目錄(這里根據(jù)自己的目錄修改)
logdir = 'D:/Develop_Software/Anaconda3/WorkDirectory/summary/mnist'
#導(dǎo)入mnist數(shù)據(jù)
mnist = input_data.read_data_sets('MNIST_data', one_hot = True)
#全局訓(xùn)練步數(shù)
global_step = tf.Variable(0, name = 'global_step', trainable = False)
with tf.name_scope('input'):
#輸入數(shù)據(jù)
with tf.name_scope('x'):
x = tf.placeholder(
dtype = tf.float32, shape = (None, IMAGE_PIXELS * IMAGE_PIXELS))
#收集x圖像的會(huì)總數(shù)據(jù)
with tf.name_scope('x_summary'):
shaped_image_batch = tf.reshape(
tensor = x,
shape = (-1, IMAGE_PIXELS, IMAGE_PIXELS, 1),
name = 'shaped_image_batch')
tf.summary.image(name = 'image_summary',
tensor = shaped_image_batch,
max_outputs = 10)
with tf.name_scope('y_'):
y_ = tf.placeholder(dtype = tf.float32, shape = (None, 10))
with tf.name_scope('hidden_layer'):
with tf.name_scope('hidden_arg'):
#隱層模型參數(shù)
with tf.name_scope('hid_w'):
hid_w = tf.Variable(
tf.truncated_normal(shape = (IMAGE_PIXELS * IMAGE_PIXELS, hidden_unit)),
name = 'hidden_w')
#添加獲取隱層權(quán)重統(tǒng)計(jì)值匯總數(shù)據(jù)的匯總操作
tf.summary.histogram(name = 'weights', values = hid_w)
with tf.name_scope('hid_b'):
hid_b = tf.Variable(tf.zeros(shape = (1, hidden_unit), dtype = tf.float32),
name = 'hidden_b')
#隱層輸出
with tf.name_scope('relu'):
hid_out = tf.nn.relu(tf.matmul(x, hid_w) + hid_b)
with tf.name_scope('softmax_layer'):
with tf.name_scope('softmax_arg'):
#softmax層參數(shù)
with tf.name_scope('sm_w'):
sm_w = tf.Variable(
tf.truncated_normal(shape = (hidden_unit, output_nums)),
name = 'softmax_w')
#添加獲取softmax層權(quán)重統(tǒng)計(jì)值匯總數(shù)據(jù)的匯總操作
tf.summary.histogram(name = 'weights', values = sm_w)
with tf.name_scope('sm_b'):
sm_b = tf.Variable(tf.zeros(shape = (1, output_nums), dtype = tf.float32),
name = 'softmax_b')
#softmax層的輸出
with tf.name_scope('softmax'):
y = tf.nn.softmax(tf.matmul(hid_out, sm_w) + sm_b)
#梯度裁剪,因?yàn)楦怕嗜≈禐閇0, 1]為避免出現(xiàn)無(wú)意義的log(0),故將y值裁剪到[1e-10, 1]
y_clip = tf.clip_by_value(y, 1.0e-10, 1 - 1.0e-5)
with tf.name_scope('cross_entropy'):
#使用交叉熵代價(jià)函數(shù)
cross_entropy = -tf.reduce_sum(y_ * tf.log(y_clip) + (1 - y_) * tf.log(1 - y_clip))
#添加獲取交叉熵的匯總操作
tf.summary.scalar(name = 'cross_entropy', tensor = cross_entropy)
with tf.name_scope('train'):
#若不使用同步訓(xùn)練機(jī)制,使用Adam優(yōu)化器
optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
#單步訓(xùn)練操作,
train_op = optimizer.minimize(cross_entropy, global_step = global_step)
#加載測(cè)試數(shù)據(jù)
test_image = mnist.test.images
test_label = mnist.test.labels
test_feed = {x:test_image, y_:test_label}
with tf.name_scope('accuracy'):
prediction = tf.equal(tf.argmax(input = y, axis = 1),
tf.argmax(input = y_, axis = 1))
accuracy = tf.reduce_mean(
input_tensor = tf.cast(x = prediction, dtype = tf.float32))
#創(chuàng)建嵌入變量
embedding_var = tf.Variable(test_image, trainable = False, name = 'embedding')
saver = tf.train.Saver({'embedding':embedding_var})
#創(chuàng)建元數(shù)據(jù)文件,將MNIST圖像測(cè)試集對(duì)應(yīng)的標(biāo)簽寫(xiě)入文件
def CreateMedaDataFile():
with open(logdir + '/metadata.tsv', 'w') as f:
label = np.nonzero(test_label)[1]
for i in range(test_data_size):
f.write('%d\n' % label[i])
#創(chuàng)建投影配置參數(shù)
def CreateProjectorConfig():
config = projector.ProjectorConfig()
embeddings = config.embeddings.add()
embeddings.tensor_name = 'embedding:0'
embeddings.metadata_path = logdir + '/metadata.tsv'
projector.visualize_embeddings(writer, config)
#聚集匯總操作
merged = tf.summary.merge_all()
#創(chuàng)建會(huì)話的配置參數(shù)
sess_config = tf.ConfigProto(
allow_soft_placement = True,
log_device_placement = False)
#創(chuàng)建會(huì)話
with tf.Session(config = sess_config) as sess:
#創(chuàng)建FileWriter實(shí)例
writer = tf.summary.FileWriter(logdir = logdir, graph = sess.graph)
#初始化全局變量
sess.run(tf.global_variables_initializer())
time_begin = time.time()
print('Training begin time: %f' % time_begin)
while True:
#加載訓(xùn)練批數(shù)據(jù)
batch_x, batch_y = mnist.train.next_batch(batch_size)
train_feed = {x:batch_x, y_:batch_y}
loss, _, summary= sess.run([cross_entropy, train_op, merged], feed_dict = train_feed)
step = global_step.eval()
#如果step為100的整數(shù)倍
if step % 100 == 0:
now = time.time()
print('%f: global_step = %d, loss = %f' % (
now, step, loss))
#向事件文件中添加匯總數(shù)據(jù)
writer.add_summary(summary = summary, global_step = step)
#若大于等于訓(xùn)練總步數(shù),退出訓(xùn)練
if step >= train_steps:
break
time_end = time.time()
print('Training end time: %f' % time_end)
print('Training time: %f' % (time_end - time_begin))
#測(cè)試模型精度
test_accuracy = sess.run(accuracy, feed_dict = test_feed)
print('accuracy: %f' % test_accuracy)
saver.save(sess = sess, save_path = logdir + '/embedding_var.ckpt')
CreateMedaDataFile()
CreateProjectorConfig()
#關(guān)閉FileWriter
writer.close()

以上就是本文的全部?jī)?nèi)容,希望對(duì)大家的學(xué)習(xí)有所幫助,也希望大家多多支持腳本之家。
相關(guān)文章
如何將python中的List轉(zhuǎn)化成dictionary
這篇文章主要介紹在python中如何將list轉(zhuǎn)化成dictionary,通過(guò)提出兩個(gè)問(wèn)題來(lái)告訴大家如何解決,有需要的可以參考借鑒。2016-08-08
python爬蟲(chóng)headers設(shè)置后無(wú)效的解決方法
這篇文章主要為大家詳細(xì)介紹了python爬蟲(chóng)headers設(shè)置后無(wú)效的解決方案,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2017-10-10
python日志通過(guò)不同的等級(jí)打印不同的顏色(示例代碼)
這篇文章主要介紹了python日志通過(guò)不同的等級(jí)打印不同的顏色,本文通過(guò)實(shí)例代碼給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2021-01-01
Pycharm主題切換(禁用)導(dǎo)致UI界面顯示異常的解決方案
這篇文章主要介紹了Pycharm主題切換(禁用)導(dǎo)致UI界面顯示異常的原因分析和解決方案,文中通過(guò)圖文結(jié)合的方式給大家介紹的非常詳細(xì),需要的朋友可以參考下2024-06-06
Python序列的推導(dǎo)式實(shí)現(xiàn)代碼
推導(dǎo)式是可以從一個(gè)數(shù)據(jù)序列構(gòu)建另一個(gè)新的數(shù)據(jù)序列(的一種結(jié)構(gòu)體),是python的一種獨(dú)有特性,在python中共有三種推導(dǎo),列表推導(dǎo)式和字典推導(dǎo)式,集合推導(dǎo)式,本文通過(guò)實(shí)例代碼給大家介紹的非常詳細(xì),需要的朋友參考下吧2021-07-07
windows10系統(tǒng)中安裝python3.x+scrapy教程
本文給大家主要介紹了在windows10系統(tǒng)中安裝python3以及scrapy框架的教程以及有可能會(huì)遇到的問(wèn)題的解決辦法,希望大家能夠喜歡2016-11-11
python orm 框架中sqlalchemy用法實(shí)例詳解
這篇文章主要介紹了python orm 框架中sqlalchemy用法,結(jié)合實(shí)例形式詳細(xì)分析了Python orm 框架基本概念、原理及sqlalchemy相關(guān)使用技巧,需要的朋友可以參考下2020-02-02
Python的json.loads() 方法與json.dumps()方法及使用小結(jié)
json.loads() 是一個(gè)非常有用的方法,它允許你在處理 JSON 數(shù)據(jù)時(shí),將其轉(zhuǎn)換為 Python 數(shù)據(jù)類型,以便于在代碼中進(jìn)行操作和處理,這篇文章給大家介紹Python的json.loads() 方法與json.dumps()方法及使用小結(jié),感興趣的朋友一起看看吧2024-03-03

