pytorch實現從本地加載 .pth 格式模型
可以從官網加載預訓練好的模型:
import torchvision.models as models model = models.vgg16(pretrained = True) print(model)
但是經常會出現因為下載速度太慢而出現requests.exceptions.ConnectionError: ('Connection aborted.', TimeoutError(10060, '由于連接方在一段時間后沒有正確答復或連接的主機沒有反應,連接嘗試失敗。', None, 10060, None))這種錯誤,因此需要我們手動去下載 .pth 文件(百度云也很慢,如果你是SVIP,當我沒說;迅雷的速度也還可以),然后從本地加載。
從本地加載只需要把上面的代碼換成如下:
import torchvision.models as models model = models.vgg16(pretrained=False) pre=torch.load(r'.\kaggle_dog_vs_cat\pretrain\vgg16-397923af.pth') model.load_state_dict(pre)
如果你模型不是用的vgg16,而是用的vgg11或者vgg13,只需要修改語句 model = models.vgg16(pretrained=False) 為對應模型的函數即可。
以上這篇pytorch實現從本地加載 .pth 格式模型就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
pyspark 讀取csv文件創(chuàng)建DataFrame的兩種方法
今天小編就為大家分享一篇pyspark 讀取csv文件創(chuàng)建DataFrame的兩種方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-06-06
pycharm?使用conda虛擬環(huán)境的詳細配置過程
這篇文章主要介紹了pycharm?使用conda虛擬環(huán)境,本文給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下2023-03-03

