python循環(huán)神經(jīng)網(wǎng)絡(luò)RNN函數(shù)tf.nn.dynamic_rnn使用
學(xué)習(xí)前言
已經(jīng)完成了RNN網(wǎng)絡(luò)的構(gòu)建,但是我們對于RNN網(wǎng)絡(luò)還有許多疑問,特別是tf.nn.dynamic_rnn函數(shù),其具體的應(yīng)用方式我們并不熟悉,查詢了一下資料,我心里的想法是這樣的。
tf.nn.dynamic_rnn的定義
tf.nn.dynamic_rnn(
cell,
inputs,
sequence_length=None,
initial_state=None,
dtype=None,
parallel_iterations=None,
swap_memory=False,
time_major=False,
scope=None
)
- cell:上文所定義的lstm_cell。
- inputs:RNN輸入。如果time_major==false(默認(rèn)),則必須是如下shape的tensor:[batch_size,max_time,…]或此類元素的嵌套元組。如果time_major==true,則必須是如下形狀的tensor:[max_time,batch_size,…]或此類元素的嵌套元組。
- sequence_length:Int32/Int64矢量大小。用于在超過批處理元素的序列長度時復(fù)制通過狀態(tài)和零輸出。因此,它更多的是為了性能而不是正確性。
- initial_state:上文所定義的_init_state。
- dtype:數(shù)據(jù)類型。
- parallel_iterations:并行運(yùn)行的迭代次數(shù)。那些不具有任何時間依賴性并且可以并行運(yùn)行的操作將是。這個參數(shù)用時間來交換空間。值>>1使用更多的內(nèi)存,但花費(fèi)的時間更少,而較小的值使用更少的內(nèi)存,但計(jì)算需要更長的時間。
- time_major:輸入和輸出tensor的形狀格式。如果為True,這些張量的形狀必須是[max_time,batch_size,depth]。如果為False,這些張量的形狀必須是[batch_size,max_time,depth]。使用time_major=true會更有效率,因?yàn)樗梢员苊庠赗NN計(jì)算的開始和結(jié)束時進(jìn)行換位。但是,大多數(shù)TensorFlow數(shù)據(jù)都是批處理主數(shù)據(jù),因此默認(rèn)情況下,此函數(shù)為False。
- scope:創(chuàng)建的子圖的可變作用域;默認(rèn)為“RNN”。
其返回值為outputs,states。
outputs:RNN的最后一層的輸出,是一個tensor。如果為time_major== False,則它的shape為[batch_size,max_time,cell.output_size]。如果為time_major== True,則它的shape為[max_time,batch_size,cell.output_size]。
states:是每一層的最后一個step的輸出,是一個tensor。state是最終的狀態(tài),也就是序列中最后一個cell輸出的狀態(tài)。一般情況下states的形狀為 [batch_size, cell.output_size],但當(dāng)輸入的cell為BasicLSTMCell時,states的形狀為[2,batch_size, cell.output_size ],其中2也對應(yīng)著LSTM中的cell state和hidden state。
tf.nn.dynamic_rnn的使用舉例
單層實(shí)驗(yàn)
我們首先使用單層的RNN進(jìn)行實(shí)驗(yàn)。
使用的代碼為:
import tensorflow as tf
import numpy as np
n_steps = 2 #兩個step
n_inputs = 3 #每個input是三維
n_nerve = 4 #神經(jīng)元個數(shù)
X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
basic_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=n_nerve)
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32)
init = tf.global_variables_initializer()
X_batch = np.array([[[0, 1, 2], [1, 2, 3]],
[[3, 4, 5], [4, 5, 6]],
[[5, 6, 7], [6, 7, 8]],
[[7, 8, 9], [8, 9, 10]]])
with tf.Session() as sess:
sess.run(init)
outputs_val, states_val = sess.run([outputs, states], feed_dict={X: X_batch})
print("outputs:", outputs_val)
print("states:", states_val)
輸出的log為:
outputs: [[[0.92146313 0.6069534 0.24989243 0.9305415 ] [0.9234855 0.8470011 0.7865616 0.99935764]] [[0.9772771 0.9713368 0.99483156 0.9999987 ] [0.9753329 0.99538314 0.9988139 1. ]] [[0.9901842 0.99558043 0.9998626 1. ] [0.989398 0.9992842 0.9999691 1. ]] [[0.99577546 0.9993256 0.99999636 1. ] [0.9954579 0.9998903 0.99999917 1. ]]] states: [[0.9234855 0.8470011 0.7865616 0.99935764] [0.9753329 0.99538314 0.9988139 1. ] [0.989398 0.9992842 0.9999691 1. ] [0.9954579 0.9998903 0.99999917 1. ]]
- Xin的shape是[batch_size = 4, max_time = 2, depth = 3]。
- outputs的shape是[batch_size = 4, max_time = 2, cell.output_size = 4]。
- states的shape是[batch_size = 4, cell.output_size = 4]
在time_major = False的時候:
- Xin、outputs、states的第一維,都是batch_size,即用于訓(xùn)練的batch的大小。
- Xin、outputs的第二維,都是max_time,在本文中對應(yīng)著RNN的兩個step。
- outputs、states的最后一維指的是每一個RNN的Cell的輸出,本文的RNN的Cell的n_nerve為4,所以cell.output_size = 4。Xin的最后一維指的是每一個輸入樣本的維度。
- outputs對應(yīng)的是RNN的最后一層的輸出,states對應(yīng)的是每一層的最后一個step的輸出。在RNN的層數(shù)僅1層的時候,states的輸出對應(yīng)為outputs最后的step的輸出。
多層實(shí)驗(yàn)
接下來我們使用兩層的RNN進(jìn)行實(shí)驗(yàn)。
使用的代碼為:
import tensorflow as tf
import numpy as np
n_steps = 2 #兩個step
n_inputs = 3 #每個input是三維
n_nerve = 4 #神經(jīng)元個數(shù)
X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
#定義多層
layers = [tf.nn.rnn_cell.BasicRNNCell(num_units=n_nerve) for i in range(2)]
multi_layer_cell = tf.contrib.rnn.MultiRNNCell(layers)
outputs, states = tf.nn.dynamic_rnn(multi_layer_cell, X, dtype=tf.float32)
init = tf.global_variables_initializer()
X_batch = np.array([[[0, 1, 2], [1, 2, 3]],
[[3, 4, 5], [4, 5, 6]],
[[5, 6, 7], [6, 7, 8]],
[[7, 8, 9], [8, 9, 10]]])
with tf.Session() as sess:
sess.run(init)
outputs_val, states_val = sess.run([outputs, states], feed_dict={X: X_batch})
print("outputs:", outputs_val)
print("states:", states_val)
輸出的log為:
outputs: [[[-0.577939 -0.3657474 -0.21074213 0.8188577 ]
[-0.67090076 -0.47001836 -0.40080917 0.6026697 ]]
[[-0.72777444 -0.36500326 -0.7526911 0.86113644]
[-0.7928404 -0.6413429 -0.61007065 0.787065 ]]
[[-0.7537433 -0.35850585 -0.83090436 0.8573037 ]
[-0.82016116 -0.6559162 -0.7360482 0.7915131 ]]
[[-0.7597004 -0.35760364 -0.8450942 0.8567379 ]
[-0.8276395 -0.6573326 -0.7727142 0.7895221 ]]]
states: (array([[-0.71645427, -0.0585744 , 0.95318353, 0.8424729 ],
[-0.99845 , -0.5044571 , 0.9955299 , 0.9750488 ],
[-0.99992913, -0.8408632 , 0.99885863, 0.9932366 ],
[-0.99999577, -0.9672 , 0.9996866 , 0.99814796]],
dtype=float32),
array([[-0.67090076, -0.47001836, -0.40080917, 0.6026697 ],
[-0.7928404 , -0.6413429 , -0.61007065, 0.787065 ],
[-0.82016116, -0.6559162 , -0.7360482 , 0.7915131 ],
[-0.8276395 , -0.6573326 , -0.7727142 , 0.7895221 ]],
dtype=float32))
可以看出來outputs對應(yīng)的是RNN的最后一層的輸出,states對應(yīng)的是每一層的最后一個step的輸出,在完成了兩層的定義后,outputs的shape并沒有變化,而states的內(nèi)容多了一層,分別對應(yīng)RNN的兩層輸出。
state中最后一層輸出對應(yīng)著outputs最后一步的輸出。
以上就是python循環(huán)神經(jīng)網(wǎng)絡(luò)RNN函數(shù)tf.nn.dynamic_rnn使用的詳細(xì)內(nèi)容,更多關(guān)于RNN函數(shù)tf.nn.dynamic_rnn的資料請關(guān)注腳本之家其它相關(guān)文章!
- 深度學(xué)習(xí)TextRNN的tensorflow1.14實(shí)現(xiàn)示例
- python人工智能tensorflow構(gòu)建循環(huán)神經(jīng)網(wǎng)絡(luò)RNN
- Python使用循環(huán)神經(jīng)網(wǎng)絡(luò)解決文本分類問題的方法詳解
- 基于循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)實(shí)現(xiàn)影評情感分類
- 基于循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)的古詩生成器
- TensorFlow實(shí)現(xiàn)RNN循環(huán)神經(jīng)網(wǎng)絡(luò)
- 循環(huán)神經(jīng)網(wǎng)絡(luò)TextRNN實(shí)現(xiàn)情感短文本分類任務(wù)
相關(guān)文章
如何解決Keras載入mnist數(shù)據(jù)集出錯的問題
這篇文章主要介紹了解決Keras載入mnist數(shù)據(jù)集出錯的問題,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2021-05-05
Python解決IndexError: list index out of&nb
IndexError是一種常見的異常類型,它通常發(fā)生在嘗試訪問列表(list)中不存在的索引時,錯誤信息“IndexError: list index out of range”意味著你試圖訪問的列表索引超出了列表的實(shí)際范圍,所以本文給大家介紹了Python成功解決IndexError: list index out of range2024-05-05
Pandas實(shí)現(xiàn)數(shù)據(jù)類型轉(zhuǎn)換的一些小技巧匯總
這篇文章主要給大家匯總介紹了關(guān)于Pandas實(shí)現(xiàn)數(shù)據(jù)類型轉(zhuǎn)換的一些小技巧,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2018-05-05
Pandas實(shí)現(xiàn)轉(zhuǎn)換產(chǎn)生新列的項(xiàng)目實(shí)踐
本文主要介紹了Pandas實(shí)現(xiàn)轉(zhuǎn)換產(chǎn)生新列,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2024-12-12
教女朋友學(xué)Python3(二)簡單的輸入輸出及內(nèi)置函數(shù)查看
這篇文章主要介紹了教女朋友學(xué)Python3(二)簡單的輸入輸出及內(nèi)置函數(shù)查看,涉及Python3簡單的輸入輸出功能實(shí)現(xiàn),以及參看內(nèi)置函數(shù)的功能和用法描述的語句,具有一定參考價值,需要的朋友可了解下。2017-11-11
在Python中處理日期和時間的基本知識點(diǎn)整理匯總
這篇文章主要介紹了在Python中處理日期和時間的基本知識點(diǎn)整理匯總,是Python入門學(xué)習(xí)中的基礎(chǔ)知識,需要的朋友可以參考下2015-05-05
深入理解Python中的*args和**kwargs參數(shù)(示例代碼)
*args和**kwargs是Python函數(shù)編程中極其有用的特性,它們?yōu)楹瘮?shù)參數(shù)的處理提供了極大的靈活性和強(qiáng)大的功能,這篇文章主要介紹了Python中的*args和**kwargs參數(shù),需要的朋友可以參考下2024-06-06

