使用PyTorch實(shí)現(xiàn)隨機(jī)搜索策略
1. 隨機(jī)搜索策略
在本節(jié)中,我們將學(xué)習(xí)一種比隨機(jī)選擇動作更復(fù)雜的策略來解決 CartPole 問題——隨機(jī)搜索策略。
一種簡單但有效的方法是將智能體對環(huán)境的觀測值映射到代表兩個(gè)動作的二維向量,然后我們選擇值較高的動作執(zhí)行。映射函數(shù)使用權(quán)重矩陣描述,權(quán)重矩陣的形狀為 4 x 2,因?yàn)樵贑arPole環(huán)境中狀態(tài)是一個(gè) 4 維向量,而動作有 2 個(gè)可能值。在每個(gè)回合中,首先隨機(jī)生成權(quán)重矩陣,并用于計(jì)算此回合中每個(gè)步驟的動作,并在回合結(jié)束時(shí)計(jì)算總獎勵。重復(fù)此過程,最后將能夠得到最高總獎勵的權(quán)重矩陣作為最終的動作選擇策略。由于在每個(gè)回合中我們均會隨機(jī)選擇權(quán)重矩陣,因此稱這種方法為隨機(jī)搜索,期望通過在多個(gè)回合的測試中找到最佳權(quán)重。
2. 使用 PyTorch 實(shí)現(xiàn)隨機(jī)搜索算法
在本節(jié)中,我們使用 PyTorch 實(shí)現(xiàn)隨機(jī)搜索算法。
首先,導(dǎo)入 Gym 和 PyTorch 以及其他所需庫,并創(chuàng)建一個(gè) CartPole 環(huán)境實(shí)例:
import gym
import torch
from matplotlib import pyplot as plt
env = gym.make('CartPole-v0')獲取并打印狀態(tài)空間和行動空間的尺寸:
n_state = env.observation_space.shape[0] print(n_state) # 4 n_action = env.action_space.n print(n_action) # 2
當(dāng)我們在之后定義權(quán)重矩陣時(shí),將會使用這些尺寸,即權(quán)重矩陣尺寸為 (n_state, n_action) = (4 x 2)。
接下來,定義函數(shù)用于使用給定輸入權(quán)重模擬 CartPole 環(huán)境的一個(gè)游戲回合并返回此回合中的總獎勵:
def run_episode(env, weight):
state = env.reset()
total_reward = 0
is_done = False
while not is_done:
state = torch.from_numpy(state).float()
action = torch.argmax(torch.matmul(state, weight))
state, reward, is_done, _ = env.step(action.item())
total_reward += reward
return total_reward在以上代碼中,我們首先將狀態(tài)數(shù)組 state 轉(zhuǎn)換為浮點(diǎn)型張量,然后計(jì)算狀態(tài)數(shù)組和權(quán)重矩陣張量的乘積 torch.matmul(state, weight),以將狀態(tài)數(shù)組進(jìn)行映射映射為動作數(shù)組,使用 torch.argmax() 操作選擇值較高的動作,例如值為 [0.122, 0.333],則應(yīng)選擇動作 1。然后使用 item() 方法獲取操作結(jié)果值,因?yàn)榇颂幍?nbsp;step() 方法需要接受單元素張量,獲取新的狀態(tài)和獎勵。重復(fù)以上過程,直到回合結(jié)束。
指定回合數(shù),并初始化變量用于記錄最佳總獎勵和相應(yīng)權(quán)重矩陣,并初始化數(shù)組用于記錄每個(gè)回合的總獎勵:
n_episode = 1000 best_total_reward = 0 best_weight = None total_rewards = []
接下來,我們運(yùn)行 n_episode 個(gè)回合,在每個(gè)回合中,執(zhí)行以下操作:
- 構(gòu)建隨機(jī)權(quán)重矩陣
- 智能體根據(jù)權(quán)重矩陣將狀態(tài)映射到相應(yīng)的動作
- 回合終止并返回總獎勵
- 更新最佳總獎勵和最佳權(quán)重,并記錄總獎勵
for e in range(n_episode):
weight = torch.rand(n_state, n_action)
total_reward = run_episode(env, weight)
print('Episode {}: {}'.format(e+1, total_reward))
if total_reward > best_total_reward:
best_weight = weight
best_total_reward = total_reward
total_rewards.append(total_reward)運(yùn)行 1000 次隨機(jī)搜索獲得最佳策略,最佳策略由 best_weight 參數(shù)化。在測試最佳策略之前,我們可以計(jì)算通過隨機(jī)搜索獲得的平均總獎勵:
print('Average total reward over {} episode: {}'.format(n_episode, sum(total_rewards) / n_episode))
# Average total reward over 1000 episode: 46.722可以看到,對比使用隨機(jī)動作獲得的結(jié)果 (22.19),使用隨機(jī)搜索獲取的總獎勵是其兩倍以上。
接下來,我們使用隨機(jī)搜索得到的最佳權(quán)重矩陣,在 1000 個(gè)新的回合中測試其表現(xiàn)如何:
n_episode_eval = 1000
total_rewards_eval = []
for episode in range(n_episode_eval):
total_reward = run_episode(env, best_weight)
print('Episode {}: {}'.format(episode+1, total_reward))
total_rewards_eval.append(total_reward)
print('Average total reward over {} episode: {}'.format(n_episode_eval, sum(total_rewards_eval) / n_episode_eval))
# Average total reward over 1000 episode: 114.786隨機(jī)搜索算法的效果能夠獲取較好結(jié)果的主要原因是 CartPole 環(huán)境較為簡單。它的觀察狀態(tài)數(shù)組僅由四個(gè)變量組成。而在 Atari Space Invaders 游戲中的觀察值超過 100000 (即 210 \times 160 \times 3210×160×3)。同樣 CartPole 中動作狀態(tài)的維數(shù)也僅僅為 2。通常,使用簡單算法可以很好地解決簡單問題。
我們也可以注意到,隨機(jī)搜索策略的性能優(yōu)于隨機(jī)選擇動作。這是因?yàn)殡S機(jī)搜索策略將智能體對環(huán)境的當(dāng)前狀態(tài)考慮在內(nèi)。有了關(guān)于環(huán)境的相關(guān)信息,隨機(jī)搜索策略中的動作就可以比完全隨機(jī)的選擇動作更加智能。
我們還可以在訓(xùn)練和測試階段繪制每個(gè)回合的總獎勵:
plt.plot(total_rewards, label='search')
plt.plot(total_rewards_eval, label='eval')
plt.xlabel('episode')
plt.ylabel('total_reward')
plt.legend()
plt.show()
可以看到,每個(gè)回合的總獎勵是非常隨機(jī)的,并且并沒有因?yàn)榛睾蠑?shù)的增加顯示出改善的趨勢。在訓(xùn)練過程中,可以看到在實(shí)現(xiàn)前期有些回合的總獎勵已經(jīng)可以達(dá)到 200,由于智能體的策略并不會因?yàn)榛睾蠑?shù)的增加而改善,因此我們可以在回合總獎勵達(dá)到 200 時(shí)結(jié)束訓(xùn)練:
n_episode = 1000
best_total_reward = 0
best_weight = None
total_rewards = []
for episode in range(n_episode):
weight = torch.rand(n_state, n_action)
total_reward = run_episode(env, weight)
print('Episode {}: {}'.format(episode+1, total_reward))
if total_reward > best_total_reward:
best_weight = weight
best_total_reward = total_reward
total_rewards.append(total_reward)
if best_total_reward == 200:
break由于每回合的權(quán)重都是隨機(jī)生成的,因此獲取最大獎勵的策略出現(xiàn)的回合也并不確定。要計(jì)算所需訓(xùn)練回合的期望,可以重復(fù)以上訓(xùn)練過程 1000 次,并取訓(xùn)練次數(shù)的平均值作為期望:
n_training = 1000
n_episode_training = []
for _ in range(n_training):
for episode in range(n_episode):
weight = torch.rand(n_state, n_action)
total_reward = run_episode(env, weight)
if total_reward == 200:
n_episode_training.append(episode+1)
break
print('Expectation of training episodes needed: ', sum(n_episode_training) / n_training)
# Expectation of training episodes needed: 14.26可以看到,平均而言,我們預(yù)計(jì)大約需要 14 個(gè)回合才能找到最佳策略。
到此這篇關(guān)于使用PyTorch實(shí)現(xiàn)隨機(jī)搜索策略的文章就介紹到這了,更多相關(guān)PyTorch隨機(jī)搜索內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
- PyTorch 隨機(jī)數(shù)生成占用 CPU 過高的解決方法
- Pytorch在dataloader類中設(shè)置shuffle的隨機(jī)數(shù)種子方式
- pytorch實(shí)現(xiàn)保證每次運(yùn)行使用的隨機(jī)數(shù)都相同
- pytorch隨機(jī)采樣操作SubsetRandomSampler()
- Pytorch生成隨機(jī)數(shù)Tensor的方法匯總
- 簡述python&pytorch 隨機(jī)種子的實(shí)現(xiàn)
- PyTorch 如何設(shè)置隨機(jī)數(shù)種子使結(jié)果可復(fù)現(xiàn)
- pytorch通過訓(xùn)練結(jié)果的復(fù)現(xiàn)設(shè)置隨機(jī)種子
相關(guān)文章
Django中在xadmin中集成DjangoUeditor過程詳解
這篇文章主要介紹了Django中在xadmin中集成DjangoUeditor過程詳解,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2019-07-07
Python內(nèi)置函數(shù)round()的用法和注意事項(xiàng)詳解
這篇文章主要介紹了Python中round()函數(shù)的相關(guān)資料,包括其基本語法、使用示例和注意事項(xiàng),文中通過代碼介紹的非常詳細(xì),需要的朋友可以參考下2025-03-03
如何將Pycharm中Terminal使用Powershell作為終端
這篇文章主要介紹了如何將Pycharm中Terminal使用Powershell作為終端問題,具有很好的參考價(jià)值,希望對大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-05-05
Python Web框架Flask中使用百度云存儲BCS實(shí)例
這篇文章主要介紹了Python Web框架Flask中使用百度云存儲BCS實(shí)例,本文調(diào)用了百度云存儲Python SDK中的相關(guān)類,需要的朋友可以參考下2015-02-02
Pycharm安裝第三方庫、安裝位置以及鏡像設(shè)置方法詳解
對于Python開發(fā)用戶來講,安裝第三方庫是家常便飯,下面這篇文章主要給大家介紹了關(guān)于Pycharm安裝第三方庫、安裝位置以及鏡像設(shè)置方法的相關(guān)資料,文中通過實(shí)例代碼介紹的非常詳細(xì),需要的朋友可以參考下2023-05-05
用python實(shí)現(xiàn)五子棋實(shí)例
這篇文章主要為大家詳細(xì)介紹了用python實(shí)現(xiàn)五子棋實(shí)例,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2022-05-05
用Python爬取各大高校并可視化幫弟弟選大學(xué),弟弟直呼牛X
高考結(jié)束了,接下來最重要的就是玩玩玩,然后準(zhǔn)備報(bào)志愿吧.中國教育在線網(wǎng)顯示國內(nèi)目前共有2857所高等院校,報(bào)一個(gè)理想的學(xué)校簡直是千里挑一.正好表弟求著我讓我?guī)退x學(xué)校,我想著十年寒窗苦讀也不容易不如就用python幫幫他.分析一下目前國內(nèi)的大學(xué),需要的朋友可以參考下2021-06-06

