pytorch中的model=model.to(device)使用說明
這代表將模型加載到指定設(shè)備上。
其中,device=torch.device("cpu")代表的使用cpu,而device=torch.device("cuda")則代表的使用GPU。
當(dāng)我們指定了設(shè)備之后,就需要將模型加載到相應(yīng)設(shè)備中,此時(shí)需要使用model=model.to(device),將模型加載到相應(yīng)的設(shè)備中。
將由GPU保存的模型加載到CPU上。
將torch.load()函數(shù)中的map_location參數(shù)設(shè)置為torch.device('cpu')
device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
將由GPU保存的模型加載到GPU上。確保對輸入的tensors調(diào)用input = input.to(device)方法。
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
將由CPU保存的模型加載到GPU上。
確保對輸入的tensors調(diào)用input = input.to(device)方法。map_location是將模型加載到GPU上,model.to(torch.device('cuda'))是將模型參數(shù)加載為CUDA的tensor。
最后保證使用.to(torch.device('cuda'))方法將需要使用的參數(shù)放入CUDA。
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want
model.to(device)
補(bǔ)充:pytorch中model.to(device)和map_location=device的區(qū)別
一、簡介
在已訓(xùn)練并保存在CPU上的GPU上加載模型時(shí),加載模型時(shí)經(jīng)常由于訓(xùn)練和保存模型時(shí)設(shè)備不同出現(xiàn)讀取模型時(shí)出現(xiàn)錯(cuò)誤,在對跨設(shè)備的模型讀取時(shí)候涉及到兩個(gè)參數(shù)的使用,分別是model.to(device)和map_location=devicel兩個(gè)參數(shù),簡介一下兩者的不同。
將map_location函數(shù)中的參數(shù)設(shè)置 torch.load()為 cuda:device_id。這會將模型加載到給定的GPU設(shè)備。
調(diào)用model.to(torch.device('cuda'))將模型的參數(shù)張量轉(zhuǎn)換為CUDA張量,無論在cpu上訓(xùn)練還是gpu上訓(xùn)練,保存的模型參數(shù)都是參數(shù)張量不是cuda張量,因此,cpu設(shè)備上不需要使用torch.to(torch.device("cpu"))。
二、實(shí)例
了解了兩者代表的意義,以下介紹兩者的使用。
1、保存在GPU上,在CPU上加載
保存:
torch.save(model.state_dict(), PATH)
加載:
device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
解釋:
在使用GPU訓(xùn)練的CPU上加載模型時(shí),請傳遞 torch.device('cpu')給map_location函數(shù)中的 torch.load()參數(shù),使用map_location參數(shù)將張量下面的存儲器動態(tài)地重新映射到CPU設(shè)備 。
2、保存在GPU上,在GPU上加載
保存:
torch.save(model.state_dict(), PATH)
加載:
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
解釋:
在GPU上訓(xùn)練并保存在GPU上的模型時(shí),只需將初始化model模型轉(zhuǎn)換為CUDA優(yōu)化模型即可model.to(torch.device('cuda'))。
此外,請務(wù)必.to(torch.device('cuda'))在所有模型輸入上使用該 功能來準(zhǔn)備模型的數(shù)據(jù)。
請注意,調(diào)用my_tensor.to(device) 返回my_tensorGPU上的新副本。
它不會覆蓋 my_tensor。
因此,請記住手動覆蓋張量: my_tensor = my_tensor.to(torch.device('cuda'))
3、保存在CPU,在GPU上加載
保存:
torch.save(model.state_dict(), PATH)
加載:
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
解釋:
在已訓(xùn)練并保存在CPU上的GPU上加載模型時(shí),請將map_location函數(shù)中的參數(shù)設(shè)置 torch.load()為 cuda:device_id。
這會將模型加載到給定的GPU設(shè)備。
接下來,請務(wù)必調(diào)用model.to(torch.device('cuda'))將模型的參數(shù)張量轉(zhuǎn)換為CUDA張量。
最后,確保.to(torch.device('cuda'))在所有模型輸入上使用該 函數(shù)來為CUDA優(yōu)化模型準(zhǔn)備數(shù)據(jù)。
請注意,調(diào)用 my_tensor.to(device)返回my_tensorGPU上的新副本。
它不會覆蓋my_tensor。
因此,請記住手動覆蓋張量:my_tensor = my_tensor.to(torch.device('cuda'))
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Pycharm終端顯示PS而不顯示虛擬環(huán)境名的解決
這篇文章主要介紹了Pycharm終端顯示PS而不顯示虛擬環(huán)境名的解決方案,具有很好的參考價(jià)值,希望對大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-06-06
python命令行引導(dǎo)用戶填寫可用的ip地址和端口號實(shí)現(xiàn)
這篇文章主要為大家介紹了python命令行引導(dǎo)用戶填寫可用的ip地址和端口號實(shí)現(xiàn)詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-11-11
python實(shí)現(xiàn)股票歷史數(shù)據(jù)可視化分析案例
股票交易數(shù)據(jù)分析可直觀股市走向,對于如何把握股票行情,快速解讀股票交易數(shù)據(jù)有不可替代的作用,感興趣的可以了解一下2021-06-06
Python一個(gè)簡單的通信程序(客戶端 服務(wù)器)
今天小編就為大家分享一篇關(guān)于Python一個(gè)簡單的通信程序(客戶端 服務(wù)器),小編覺得內(nèi)容挺不錯(cuò)的,現(xiàn)在分享給大家,具有很好的參考價(jià)值,需要的朋友一起跟隨小編來看看吧2019-03-03
關(guān)于Python正則表達(dá)式 findall函數(shù)問題詳解
在寫正則表達(dá)式的時(shí)候總會遇到不少的問題,本文講述了Python正則表達(dá)式中 findall()函數(shù)和多個(gè)表達(dá)式元組相遇的時(shí)候會出現(xiàn)的問題2018-03-03
Pytest實(shí)現(xiàn)setup和teardown的詳細(xì)使用詳解
這篇文章主要介紹了Pytest實(shí)現(xiàn)setup和teardown的詳細(xì)使用詳解,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2021-04-04

