tensorflow使用CNN分析mnist手寫體數(shù)字數(shù)據(jù)集
更新時間:2020年06月17日 10:16:40 作者:Dillon2015
這篇文章主要介紹了tensorflow使用CNN分析mnist手寫體數(shù)字數(shù)據(jù)集,文中示例代碼介紹的非常詳細,具有一定的參考價值,感興趣的小伙伴們可以參考一下
本文實例為大家分享了tensorflow使用CNN分析mnist手寫體數(shù)字數(shù)據(jù)集,供大家參考,具體內(nèi)容如下
import tensorflow as tf
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
#把上述trX和teX的形狀變?yōu)閇-1,28,28,1],-1表示不考慮輸入圖片的數(shù)量,28×28是圖片的長和寬的像素數(shù),
# 1是通道(channel)數(shù)量,因為MNIST的圖片是黑白的,所以通道是1,如果是RGB彩色圖像,通道是3。
trX = trX.reshape(-1, 28, 28, 1) # 28x28x1 input img
teX = teX.reshape(-1, 28, 28, 1) # 28x28x1 input img
X = tf.placeholder("float", [None, 28, 28, 1])
Y = tf.placeholder("float", [None, 10])
#初始化權重與定義網(wǎng)絡結構。
# 這里,我們將要構建一個擁有3個卷積層和3個池化層,隨后接1個全連接層和1個輸出層的卷積神經(jīng)網(wǎng)絡
def init_weights(shape):
return tf.Variable(tf.random_normal(shape, stddev=0.01))
w = init_weights([3, 3, 1, 32]) # patch大小為3×3,輸入維度為1,輸出維度為32
w2 = init_weights([3, 3, 32, 64]) # patch大小為3×3,輸入維度為32,輸出維度為64
w3 = init_weights([3, 3, 64, 128]) # patch大小為3×3,輸入維度為64,輸出維度為128
w4 = init_weights([128 * 4 * 4, 625]) # 全連接層,輸入維度為 128 × 4 × 4,是上一層的輸出數(shù)據(jù)又三維的轉變成一維, 輸出維度為625
w_o = init_weights([625, 10]) # 輸出層,輸入維度為 625, 輸出維度為10,代表10類(labels)
# 神經(jīng)網(wǎng)絡模型的構建函數(shù),傳入以下參數(shù)
# X:輸入數(shù)據(jù)
# w:每一層的權重
# p_keep_conv,p_keep_hidden:dropout要保留的神經(jīng)元比例
def model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden):
# 第一組卷積層及池化層,最后dropout一些神經(jīng)元
l1a = tf.nn.relu(tf.nn.conv2d(X, w, strides=[1, 1, 1, 1], padding='SAME'))
# l1a shape=(?, 28, 28, 32)
l1 = tf.nn.max_pool(l1a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
# l1 shape=(?, 14, 14, 32)
l1 = tf.nn.dropout(l1, p_keep_conv)
# 第二組卷積層及池化層,最后dropout一些神經(jīng)元
l2a = tf.nn.relu(tf.nn.conv2d(l1, w2, strides=[1, 1, 1, 1], padding='SAME'))
# l2a shape=(?, 14, 14, 64)
l2 = tf.nn.max_pool(l2a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
# l2 shape=(?, 7, 7, 64)
l2 = tf.nn.dropout(l2, p_keep_conv)
# 第三組卷積層及池化層,最后dropout一些神經(jīng)元
l3a = tf.nn.relu(tf.nn.conv2d(l2, w3, strides=[1, 1, 1, 1], padding='SAME'))
# l3a shape=(?, 7, 7, 128)
l3 = tf.nn.max_pool(l3a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
# l3 shape=(?, 4, 4, 128)
l3 = tf.reshape(l3, [-1, w4.get_shape().as_list()[0]]) # reshape to (?, 2048)
l3 = tf.nn.dropout(l3, p_keep_conv)
# 全連接層,最后dropout一些神經(jīng)元
l4 = tf.nn.relu(tf.matmul(l3, w4))
l4 = tf.nn.dropout(l4, p_keep_hidden)
# 輸出層
pyx = tf.matmul(l4, w_o)
return pyx #返回預測值
#我們定義dropout的占位符——keep_conv,它表示在一層中有多少比例的神經(jīng)元被保留下來。生成網(wǎng)絡模型,得到預測值
p_keep_conv = tf.placeholder("float")
p_keep_hidden = tf.placeholder("float")
py_x = model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden) #得到預測值
#定義損失函數(shù),這里我們?nèi)匀徊捎胻f.nn.softmax_cross_entropy_with_logits來比較預測值和真實值的差異,并做均值處理;
# 定義訓練的操作(train_op),采用實現(xiàn)RMSProp算法的優(yōu)化器tf.train.RMSPropOptimizer,學習率為0.001,衰減值為0.9,使損失最??;
# 定義預測的操作(predict_op)
cost = tf.reduce_mean(tf.nn. softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
predict_op = tf.argmax(py_x, 1)
#定義訓練時的批次大小和評估時的批次大小
batch_size = 128
test_size = 256
#在一個會話中啟動圖,開始訓練和評估
# Launch the graph in a session
with tf.Session() as sess:
# you need to initialize all variables
tf. global_variables_initializer().run()
for i in range(100):
training_batch = zip(range(0, len(trX), batch_size),
range(batch_size, len(trX)+1, batch_size))
for start, end in training_batch:
sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end],
p_keep_conv: 0.8, p_keep_hidden: 0.5})
test_indices = np.arange(len(teX)) # Get A Test Batch
np.random.shuffle(test_indices)
test_indices = test_indices[0:test_size]
print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
sess.run(predict_op, feed_dict={X: teX[test_indices],
p_keep_conv: 1.0,
p_keep_hidden: 1.0})))
以上就是本文的全部內(nèi)容,希望對大家的學習有所幫助,也希望大家多多支持腳本之家。
相關文章
python中getattr函數(shù)使用方法 getattr實現(xiàn)工廠模式
這篇文章主要介紹了python中getattr()這個函數(shù)的一些用法,大家參考使用吧2014-01-01
Python過濾txt文件內(nèi)重復內(nèi)容的方法
今天小編就為大家分享一篇Python過濾txt文件內(nèi)重復內(nèi)容的方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-10-10
Pytorch中的model.train()?和?model.eval()?原理與用法解析
pytorch可以給我們提供兩種方式來切換訓練和評估(推斷)的模式,分別是:model.train()?和?model.eval(),這篇文章主要介紹了Pytorch中的model.train()?和?model.eval()?原理與用法,需要的朋友可以參考下2023-04-04
使用Python實現(xiàn)ELT統(tǒng)計多個服務器下所有數(shù)據(jù)表信息
這篇文章主要介紹了使用Python實現(xiàn)ELT統(tǒng)計多個服務器下所有數(shù)據(jù)表信息,ETL,是英文Extract-Transform-Load的縮寫,用來描述將數(shù)據(jù)從來源端經(jīng)過抽取(extract)、轉換(transform)、加載(load)至目的端的過程,需要的朋友可以參考下2023-07-07

