Python實(shí)現(xiàn)雙向RNN與堆疊的雙向RNN的示例代碼
1、雙向RNN
雙向RNN(Bidirectional RNN)的結(jié)構(gòu)如下圖所示。



雙向的 RNN 是同時(shí)考慮“過去”和“未來”的信息。上圖是一個(gè)序列長(zhǎng)度為 4 的雙向RNN 結(jié)構(gòu)。

雙向RNN就像是我們做閱讀理解的時(shí)候從頭向后讀一遍文章,然后又從后往前讀一遍文章,然后再做題。有可能從后往前再讀一遍文章的時(shí)候會(huì)有新的不一樣的理解,最后模型可能會(huì)得到更好的結(jié)果。
2、堆疊的雙向RNN

堆疊的雙向RNN(Stacked Bidirectional RNN)的結(jié)構(gòu)如上圖所示。上圖是一個(gè)堆疊了3個(gè)隱藏層的RNN網(wǎng)絡(luò)。

注意,這里的堆疊的雙向RNN并不是只有雙向的RNN才可以堆疊,其實(shí)任意的RNN都可以堆疊,如SimpleRNN、LSTM和GRU這些循環(huán)神經(jīng)網(wǎng)絡(luò)也可以進(jìn)行堆疊。
堆疊指的是在RNN的結(jié)構(gòu)中疊加多層,類似于BP神經(jīng)網(wǎng)絡(luò)中可以疊加多層,增加網(wǎng)絡(luò)的非線性。
3、雙向LSTM實(shí)現(xiàn)MNIST數(shù)據(jù)集分類
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import LSTM,Dropout,Bidirectional
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
# 載入數(shù)據(jù)集
mnist = tf.keras.datasets.mnist
# 載入數(shù)據(jù),數(shù)據(jù)載入的時(shí)候就已經(jīng)劃分好訓(xùn)練集和測(cè)試集
# 訓(xùn)練集數(shù)據(jù)x_train的數(shù)據(jù)形狀為(60000,28,28)
# 訓(xùn)練集標(biāo)簽y_train的數(shù)據(jù)形狀為(60000)
# 測(cè)試集數(shù)據(jù)x_test的數(shù)據(jù)形狀為(10000,28,28)
# 測(cè)試集標(biāo)簽y_test的數(shù)據(jù)形狀為(10000)
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 對(duì)訓(xùn)練集和測(cè)試集的數(shù)據(jù)進(jìn)行歸一化處理,有助于提升模型訓(xùn)練速度
x_train, x_test = x_train / 255.0, x_test / 255.0
# 把訓(xùn)練集和測(cè)試集的標(biāo)簽轉(zhuǎn)為獨(dú)熱編碼
y_train = tf.keras.utils.to_categorical(y_train,num_classes=10)
y_test = tf.keras.utils.to_categorical(y_test,num_classes=10)
# 數(shù)據(jù)大小-一行有28個(gè)像素
input_size = 28
# 序列長(zhǎng)度-一共有28行
time_steps = 28
# 隱藏層memory block個(gè)數(shù)
cell_size = 50
# 創(chuàng)建模型
# 循環(huán)神經(jīng)網(wǎng)絡(luò)的數(shù)據(jù)輸入必須是3維數(shù)據(jù)
# 數(shù)據(jù)格式為(數(shù)據(jù)數(shù)量,序列長(zhǎng)度,數(shù)據(jù)大小)
# 載入的mnist數(shù)據(jù)的格式剛好符合要求
# 注意這里的input_shape設(shè)置模型數(shù)據(jù)輸入時(shí)不需要設(shè)置數(shù)據(jù)的數(shù)量
model = Sequential([
Bidirectional(LSTM(units=cell_size,input_shape=(time_steps,input_size),return_sequences=True)),
Dropout(0.2),
Bidirectional(LSTM(cell_size)),
Dropout(0.2),
# 50個(gè)memory block輸出的50個(gè)值跟輸出層10個(gè)神經(jīng)元全連接
Dense(10,activation=tf.keras.activations.softmax)
])
# 循環(huán)神經(jīng)網(wǎng)絡(luò)的數(shù)據(jù)輸入必須是3維數(shù)據(jù)
# 數(shù)據(jù)格式為(數(shù)據(jù)數(shù)量,序列長(zhǎng)度,數(shù)據(jù)大小)
# 載入的mnist數(shù)據(jù)的格式剛好符合要求
# 注意這里的input_shape設(shè)置模型數(shù)據(jù)輸入時(shí)不需要設(shè)置數(shù)據(jù)的數(shù)量
# model.add(LSTM(
# units = cell_size,
# input_shape = (time_steps,input_size),
# ))
# 50個(gè)memory block輸出的50個(gè)值跟輸出層10個(gè)神經(jīng)元全連接
# model.add(Dense(10,activation='softmax'))
# 定義優(yōu)化器
adam = Adam(lr=1e-3)
# 定義優(yōu)化器,loss function,訓(xùn)練過程中計(jì)算準(zhǔn)確率 使用交叉熵?fù)p失函數(shù)
model.compile(optimizer=adam,loss='categorical_crossentropy',metrics=['accuracy'])
# 訓(xùn)練模型
history=model.fit(x_train,y_train,batch_size=64,epochs=10,validation_data=(x_test,y_test))
#打印模型摘要
model.summary()
loss=history.history['loss']
val_loss=history.history['val_loss']
accuracy=history.history['accuracy']
val_accuracy=history.history['val_accuracy']
# 繪制loss曲線
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
# 繪制acc曲線
plt.plot(accuracy, label='Training accuracy')
plt.plot(val_accuracy, label='Validation accuracy')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
這個(gè)可能對(duì)文本數(shù)據(jù)比較容易處理,這里用這個(gè)模型有點(diǎn)勉強(qiáng),只是簡(jiǎn)單測(cè)試下。
模型摘要:

acc曲線:

loss曲線:


到此這篇關(guān)于Python實(shí)現(xiàn)雙向RNN與堆疊的雙向RNN的示例代碼的文章就介紹到這了,更多相關(guān)Python 雙向RNN內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
基于python實(shí)現(xiàn)對(duì)文件進(jìn)行切分行
這篇文章主要介紹了基于python實(shí)現(xiàn)對(duì)文件進(jìn)行切分行,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-04-04
Python Tricks 使用 pywinrm 遠(yuǎn)程控制 Windows 主機(jī)的方法
這篇文章主要介紹了Python Tricks 使用 pywinrm 遠(yuǎn)程控制 Windows 主機(jī)的方法,本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-07-07
Pymysql實(shí)現(xiàn)往表中插入數(shù)據(jù)過程解析
這篇文章主要介紹了Pymysql實(shí)現(xiàn)往表中插入數(shù)據(jù)過程解析,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-06-06
Python3訪問并下載網(wǎng)頁內(nèi)容的方法
這篇文章主要介紹了Python3訪問并下載網(wǎng)頁內(nèi)容的方法,實(shí)例分析了Python頁面抓取及寫入文件的實(shí)現(xiàn)技巧,具有一定參考借鑒價(jià)值,需要的朋友可以參考下2015-07-07
python簡(jiǎn)單實(shí)現(xiàn)獲取當(dāng)前時(shí)間
最近項(xiàng)目中經(jīng)常需要python去取當(dāng)前的時(shí)間,雖然不是很難,但是老是忘記,用一次丟一次,為了能夠更好的記住,我今天特意寫下python 當(dāng)前時(shí)間這篇文章,如果你覺的對(duì)你有用的話,可以收藏下。2016-08-08
numpy數(shù)組合并和矩陣拼接的實(shí)現(xiàn)
這篇文章主要介紹了numpy數(shù)組合并和矩陣拼接的實(shí)現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2021-03-03
python實(shí)現(xiàn)一個(gè)簡(jiǎn)單的web應(yīng)用框架
這篇文章主要為大家介紹了使用python寫一個(gè)簡(jiǎn)單的web應(yīng)用框架實(shí)現(xiàn)示例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-04-04

