手把手教你使用TensorFlow2實(shí)現(xiàn)RNN
概述
RNN (Recurrent Netural Network) 是用于處理序列數(shù)據(jù)的神經(jīng)網(wǎng)絡(luò). 所謂序列數(shù)據(jù), 即前面的輸入和后面的輸入有一定的聯(lián)系.

權(quán)重共享
傳統(tǒng)神經(jīng)網(wǎng)絡(luò):

RNN:

RNN 的權(quán)重共享和 CNN 的權(quán)重共享類似, 不同時(shí)刻共享一個(gè)權(quán)重, 大大減少了參數(shù)數(shù)量.
計(jì)算過程:

計(jì)算狀態(tài) (State)

計(jì)算輸出:

案例
數(shù)據(jù)集
IBIM 數(shù)據(jù)集包含了來(lái)自互聯(lián)網(wǎng)的 50000 條關(guān)于電影的評(píng)論, 分為正面評(píng)價(jià)和負(fù)面評(píng)價(jià).
RNN 層
class RNN(tf.keras.Model):
def __init__(self, units):
super(RNN, self).__init__()
# 初始化 [b, 64] (b 表示 batch_size)
self.state0 = [tf.zeros([batch_size, units])]
self.state1 = [tf.zeros([batch_size, units])]
# [b, 80] => [b, 80, 100]
self.embedding = tf.keras.layers.Embedding(total_words, embedding_len, input_length=max_review_len)
self.rnn_cell0 = tf.keras.layers.SimpleRNNCell(units=units, dropout=0.2)
self.rnn_cell1 = tf.keras.layers.SimpleRNNCell(units=units, dropout=0.2)
# [b, 80, 100] => [b, 64] => [b, 1]
self.out_layer = tf.keras.layers.Dense(1)
def call(self, inputs, training=None):
"""
:param inputs: [b, 80]
:param training:
:return:
"""
state0 = self.state0
state1 = self.state1
x = self.embedding(inputs)
for word in tf.unstack(x, axis=1):
out0, state0 = self.rnn_cell0(word, state0, training=training)
out1, state1 = self.rnn_cell1(out0, state1, training=training)
# [b, 64] -> [b, 1]
x = self.out_layer(out1)
prob = tf.sigmoid(x)
return prob
獲取數(shù)據(jù)
def get_data():
# 獲取數(shù)據(jù)
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=total_words)
# 更改句子長(zhǎng)度
X_train = tf.keras.preprocessing.sequence.pad_sequences(X_train, maxlen=max_review_len)
X_test = tf.keras.preprocessing.sequence.pad_sequences(X_test, maxlen=max_review_len)
# 調(diào)試輸出
print(X_train.shape, y_train.shape) # (25000, 80) (25000,)
print(X_test.shape, y_test.shape) # (25000, 80) (25000,)
# 分割訓(xùn)練集
train_db = tf.data.Dataset.from_tensor_slices((X_train, y_train))
train_db = train_db.shuffle(10000).batch(batch_size, drop_remainder=True)
# 分割測(cè)試集
test_db = tf.data.Dataset.from_tensor_slices((X_test, y_test))
test_db = test_db.batch(batch_size, drop_remainder=True)
return train_db, test_db
完整代碼
import tensorflow as tf
class RNN(tf.keras.Model):
def __init__(self, units):
super(RNN, self).__init__()
# 初始化 [b, 64]
self.state0 = [tf.zeros([batch_size, units])]
self.state1 = [tf.zeros([batch_size, units])]
# [b, 80] => [b, 80, 100]
self.embedding = tf.keras.layers.Embedding(total_words, embedding_len, input_length=max_review_len)
self.rnn_cell0 = tf.keras.layers.SimpleRNNCell(units=units, dropout=0.2)
self.rnn_cell1 = tf.keras.layers.SimpleRNNCell(units=units, dropout=0.2)
# [b, 80, 100] => [b, 64] => [b, 1]
self.out_layer = tf.keras.layers.Dense(1)
def call(self, inputs, training=None):
"""
:param inputs: [b, 80]
:param training:
:return:
"""
state0 = self.state0
state1 = self.state1
x = self.embedding(inputs)
for word in tf.unstack(x, axis=1):
out0, state0 = self.rnn_cell0(word, state0, training=training)
out1, state1 = self.rnn_cell1(out0, state1, training=training)
# [b, 64] -> [b, 1]
x = self.out_layer(out1)
prob = tf.sigmoid(x)
return prob
# 超參數(shù)
total_words = 10000 # 文字?jǐn)?shù)量
max_review_len = 80 # 句子長(zhǎng)度
embedding_len = 100 # 詞維度
batch_size = 1024 # 一次訓(xùn)練的樣本數(shù)目
learning_rate = 0.0001 # 學(xué)習(xí)率
iteration_num = 20 # 迭代次數(shù)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) # 優(yōu)化器
loss = tf.losses.BinaryCrossentropy(from_logits=True) # 損失
model = RNN(64)
# 調(diào)試輸出summary
model.build(input_shape=[None, 64])
print(model.summary())
# 組合
model.compile(optimizer=optimizer, loss=loss, metrics=["accuracy"])
def get_data():
# 獲取數(shù)據(jù)
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=total_words)
# 更改句子長(zhǎng)度
X_train = tf.keras.preprocessing.sequence.pad_sequences(X_train, maxlen=max_review_len)
X_test = tf.keras.preprocessing.sequence.pad_sequences(X_test, maxlen=max_review_len)
# 調(diào)試輸出
print(X_train.shape, y_train.shape) # (25000, 80) (25000,)
print(X_test.shape, y_test.shape) # (25000, 80) (25000,)
# 分割訓(xùn)練集
train_db = tf.data.Dataset.from_tensor_slices((X_train, y_train))
train_db = train_db.shuffle(10000).batch(batch_size, drop_remainder=True)
# 分割測(cè)試集
test_db = tf.data.Dataset.from_tensor_slices((X_test, y_test))
test_db = test_db.batch(batch_size, drop_remainder=True)
return train_db, test_db
if __name__ == "__main__":
# 獲取分割的數(shù)據(jù)集
train_db, test_db = get_data()
# 擬合
model.fit(train_db, epochs=iteration_num, validation_data=test_db, validation_freq=1)
輸出結(jié)果:
Model: "rnn"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
embedding (Embedding) multiple 1000000
_________________________________________________________________
simple_rnn_cell (SimpleRNNCe multiple 10560
_________________________________________________________________
simple_rnn_cell_1 (SimpleRNN multiple 8256
_________________________________________________________________
dense (Dense) multiple 65
=================================================================
Total params: 1,018,881
Trainable params: 1,018,881
Non-trainable params: 0
_________________________________________________________________
None(25000, 80) (25000,)
(25000, 80) (25000,)
Epoch 1/20
2021-07-10 17:59:45.150639: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
24/24 [==============================] - 12s 294ms/step - loss: 0.7113 - accuracy: 0.5033 - val_loss: 0.6968 - val_accuracy: 0.4994
Epoch 2/20
24/24 [==============================] - 7s 292ms/step - loss: 0.6951 - accuracy: 0.5005 - val_loss: 0.6939 - val_accuracy: 0.4994
Epoch 3/20
24/24 [==============================] - 7s 297ms/step - loss: 0.6937 - accuracy: 0.5000 - val_loss: 0.6935 - val_accuracy: 0.4994
Epoch 4/20
24/24 [==============================] - 8s 316ms/step - loss: 0.6934 - accuracy: 0.5001 - val_loss: 0.6933 - val_accuracy: 0.4994
Epoch 5/20
24/24 [==============================] - 7s 301ms/step - loss: 0.6934 - accuracy: 0.4996 - val_loss: 0.6933 - val_accuracy: 0.4994
Epoch 6/20
24/24 [==============================] - 8s 334ms/step - loss: 0.6932 - accuracy: 0.5000 - val_loss: 0.6932 - val_accuracy: 0.4994
Epoch 7/20
24/24 [==============================] - 10s 398ms/step - loss: 0.6931 - accuracy: 0.5006 - val_loss: 0.6932 - val_accuracy: 0.4994
Epoch 8/20
24/24 [==============================] - 9s 382ms/step - loss: 0.6930 - accuracy: 0.5006 - val_loss: 0.6931 - val_accuracy: 0.4994
Epoch 9/20
24/24 [==============================] - 8s 322ms/step - loss: 0.6924 - accuracy: 0.4995 - val_loss: 0.6913 - val_accuracy: 0.5240
Epoch 10/20
24/24 [==============================] - 8s 321ms/step - loss: 0.6812 - accuracy: 0.5501 - val_loss: 0.6655 - val_accuracy: 0.5767
Epoch 11/20
24/24 [==============================] - 8s 318ms/step - loss: 0.6381 - accuracy: 0.6896 - val_loss: 0.6235 - val_accuracy: 0.7399
Epoch 12/20
24/24 [==============================] - 8s 323ms/step - loss: 0.6088 - accuracy: 0.7655 - val_loss: 0.6110 - val_accuracy: 0.7533
Epoch 13/20
24/24 [==============================] - 8s 321ms/step - loss: 0.5949 - accuracy: 0.7956 - val_loss: 0.6111 - val_accuracy: 0.7878
Epoch 14/20
24/24 [==============================] - 8s 324ms/step - loss: 0.5859 - accuracy: 0.8142 - val_loss: 0.5993 - val_accuracy: 0.7904
Epoch 15/20
24/24 [==============================] - 8s 330ms/step - loss: 0.5791 - accuracy: 0.8318 - val_loss: 0.5961 - val_accuracy: 0.7907
Epoch 16/20
24/24 [==============================] - 8s 340ms/step - loss: 0.5739 - accuracy: 0.8421 - val_loss: 0.5942 - val_accuracy: 0.7961
Epoch 17/20
24/24 [==============================] - 9s 378ms/step - loss: 0.5701 - accuracy: 0.8497 - val_loss: 0.5933 - val_accuracy: 0.8014
Epoch 18/20
24/24 [==============================] - 9s 361ms/step - loss: 0.5665 - accuracy: 0.8589 - val_loss: 0.5958 - val_accuracy: 0.8082
Epoch 19/20
24/24 [==============================] - 8s 353ms/step - loss: 0.5630 - accuracy: 0.8681 - val_loss: 0.5931 - val_accuracy: 0.7966
Epoch 20/20
24/24 [==============================] - 8s 314ms/step - loss: 0.5614 - accuracy: 0.8702 - val_loss: 0.5925 - val_accuracy: 0.7959Process finished with exit code 0
到此這篇關(guān)于手把手教你使用TensorFlow2實(shí)現(xiàn)RNN的文章就介紹到這了,更多相關(guān)TensorFlow2實(shí)現(xiàn)RNN內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
python解析HTML并提取span標(biāo)簽中的文本
在網(wǎng)頁(yè)開發(fā)和數(shù)據(jù)抓取過程中,我們經(jīng)常需要從HTML頁(yè)面中提取信息,尤其是span元素中的文本,span標(biāo)簽是一個(gè)行內(nèi)元素,通常用于包裝一小段文本或其他元素,在Python中,我們可以通過使用BeautifulSoup或lxml等庫(kù)來(lái)解析HTML并提取span標(biāo)簽中的文本2024-12-12
Python實(shí)現(xiàn)解析Bit Torrent種子文件內(nèi)容的方法
這篇文章主要介紹了Python實(shí)現(xiàn)解析Bit Torrent種子文件內(nèi)容的方法,結(jié)合實(shí)例形式分析了Python針對(duì)Torrent文件的讀取與解析相關(guān)操作技巧與注意事項(xiàng),需要的朋友可以參考下2017-08-08
python web自制框架之接受url傳遞過來(lái)的參數(shù)實(shí)例
今天小編就為大家分享一篇python web自制框架之接受url傳遞過來(lái)的參數(shù)實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來(lái)看看吧2018-12-12
使用Python進(jìn)行物聯(lián)網(wǎng)設(shè)備的控制與數(shù)據(jù)收集
Python作為一種高效且易于學(xué)習(xí)的編程語(yǔ)言,已經(jīng)成為開發(fā)物聯(lián)網(wǎng)應(yīng)用的首選語(yǔ)言之一,本文將探討如何使用Python進(jìn)行物聯(lián)網(wǎng)設(shè)備的控制與數(shù)據(jù)收集,并提供相應(yīng)的代碼示例,需要的朋友可以參考下2024-05-05
Python代碼一鍵轉(zhuǎn)Jar包及Java調(diào)用Python新姿勢(shì)
這篇文章主要介紹了Python一鍵轉(zhuǎn)Jar包,Java調(diào)用Python新姿勢(shì),本文通過截圖實(shí)例給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-03-03
python多版本工具miniconda的配置優(yōu)化實(shí)現(xiàn)
通過Miniconda,您可以輕松地創(chuàng)建和管理多個(gè)Python環(huán)境,同時(shí)確保每個(gè)環(huán)境具有所需的依賴項(xiàng)和軟件包,本文主要介紹了python多版本工具miniconda的配置優(yōu)化實(shí)現(xiàn),感興趣的可以了解一下2024-01-01
使用Python第三方庫(kù)發(fā)送電子郵件的示例代碼
本文主要介紹了使用Python第三方庫(kù)發(fā)送電子郵件的示例代碼,文中通過示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2021-12-12
python實(shí)現(xiàn)TCP服務(wù)器端與客戶端的方法詳解
這篇文章主要介紹了python實(shí)現(xiàn)TCP服務(wù)器端與客戶端的方法,以實(shí)例形式詳解分析了Python實(shí)現(xiàn)服務(wù)器端與客戶端的技巧,具有一定參考借鑒價(jià)值,需要的朋友可以參考下2015-04-04
Python 實(shí)現(xiàn)12306登錄功能實(shí)例代碼
這篇文章主要介紹了Python 實(shí)現(xiàn)12306登錄功能的完整代碼,需要的朋友可以參考下2018-02-02
Python 數(shù)據(jù)結(jié)構(gòu)之隊(duì)列的實(shí)現(xiàn)
這篇文章主要介紹了Python 數(shù)據(jù)結(jié)構(gòu)之隊(duì)列的實(shí)現(xiàn)的相關(guān)資料,需要的朋友可以參考下2017-01-01

