pytorch 移動(dòng)端部署之helloworld的使用
開始
安裝Androidstudio 4.1
克隆此項(xiàng)目
git clone https://github.com/pytorch/android-demo-app.git
使用androidstudio 打開 android-demo-app 中的HelloWordApp
打開之后androidstudio 會(huì)自動(dòng)創(chuàng)建依賴 只需要等待即可
這個(gè)代碼已經(jīng)是官方寫好的故而
開一下官方教程中的代碼都在什么位置
這句
repositories {
jcenter()
}
dependencies {
implementation 'org.pytorch:pytorch_android:1.4.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.4.0'
}
位置
HelloWorldApp\app\build.gradle
里面的全部代碼
apply plugin: 'com.android.application'
repositories {
jcenter()
}
android {
compileSdkVersion 28
buildToolsVersion "29.0.2"
defaultConfig {
applicationId "org.pytorch.helloworld"
minSdkVersion 21
targetSdkVersion 28
versionCode 1
versionName "1.0"
}
buildTypes {
release {
minifyEnabled false
}
}
}
dependencies {
implementation 'androidx.appcompat:appcompat:1.1.0'
implementation 'org.pytorch:pytorch_android:1.4.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.4.0'
}
這句
Bitmap bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
Module module = Module.load(assetFilePath(this, "model.pt"));
Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
float[] scores = outputTensor.getDataAsFloatArray();
float maxScore = -Float.MAX_VALUE;
int maxScoreIdx = -1;
for (int i = 0; i < scores.length; i++) {
if (scores[i] > maxScore) {
maxScore = scores[i];
maxScoreIdx = i;
}
}
String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
都在這里
HelloWorldApp\app\src\main\java\org\pytorch\helloworld\MainActivity.java
全部代碼
package org.pytorch.helloworld;
import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Bundle;
import android.util.Log;
import android.widget.ImageView;
import android.widget.TextView;
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import androidx.appcompat.app.AppCompatActivity;
public class MainActivity extends AppCompatActivity {
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
Bitmap bitmap = null;
Module module = null;
try {
// creating bitmap from packaged into app android asset 'image.jpg',
// app/src/main/assets/image.jpg
bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
// loading serialized torchscript module from packaged into app android asset model.pt,
// app/src/model/assets/model.pt
module = Module.load(assetFilePath(this, "model.pt"));
} catch (IOException e) {
Log.e("PytorchHelloWorld", "Error reading assets", e);
finish();
}
// showing image on UI
ImageView imageView = findViewById(R.id.image);
imageView.setImageBitmap(bitmap);
// preparing input tensor
final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
// running the model
final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
// getting tensor content as java array of floats
final float[] scores = outputTensor.getDataAsFloatArray();
// searching for the index with maximum score
float maxScore = -Float.MAX_VALUE;
int maxScoreIdx = -1;
for (int i = 0; i < scores.length; i++) {
if (scores[i] > maxScore) {
maxScore = scores[i];
maxScoreIdx = i;
}
}
String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
// showing className on UI
TextView textView = findViewById(R.id.text);
textView.setText(className);
}
/**
* Copies specified asset to the file in /files app directory and returns this file absolute path.
*
* @return absolute file path
*/
public static String assetFilePath(Context context, String assetName) throws IOException {
File file = new File(context.getFilesDir(), assetName);
if (file.exists() && file.length() > 0) {
return file.getAbsolutePath();
}
try (InputStream is = context.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
}
}
}
在Build 中選擇Build Bundile APK 的 Build APK 就可以了
生成的apk 在
HelloWorldApp\app\build\outputs\apk\debug
中 這個(gè)是可以直接安裝的
安裝后是一個(gè)固定的照片 就是檢測(cè)了一個(gè)固定的照片
這是一個(gè)例子如果你只是想測(cè)試自己的模型調(diào)用能不能成功這個(gè)項(xiàng)目改改模型和模型加載即可
這個(gè)項(xiàng)目模型是一個(gè)resnet18 接著我們將其替換為resnet50
模型轉(zhuǎn)換代碼如下
import torch
import torchvision.models as models
from PIL import Image
import numpy as np
image = Image.open("test.jpg") #圖片發(fā)在了build文件夾下
image = image.resize((224, 224),Image.ANTIALIAS)
image = np.asarray(image)
image = image / 255
image = torch.Tensor(image).unsqueeze_(dim=0)
image = image.permute((0, 3, 1, 2)).float()
model = models.resnet50(pretrained=True)
model = model.eval()
resnet = torch.jit.trace(model, torch.rand(1,3,224,224))
# output=resnet(torch.ones(1,3,224,224))
output = resnet(image)
max_index = torch.max(output, 1)[1].item()
print(max_index) # ImageNet1000類的類別序
resnet.save('model.pt')
if __name__ == '__main__':
pass
將這個(gè)保存的模型 覆蓋掉下面路徑中的模型
(在覆蓋之前最好備份一個(gè)原來(lái)的模型,這里我們選擇修改原來(lái)模型的名字為model_1.pt)
HelloWorldApp\app\src\main\assets\model.pt
成功覆蓋后再一次執(zhí)行打包操作(在Build 中選擇Build Bundile APK 的 Build APK 就可以了
生成的apk 在
HelloWorldApp\app\build\outputs\apk\debug)而后打開文件發(fā)現(xiàn)一個(gè)123M的apk 之前的apk是73M
安裝 并且測(cè)試
完美打開也就是說(shuō)一切resnet 系列的 都可以通過(guò)這個(gè) 項(xiàng)目進(jìn)行演化出來(lái)
到此這篇關(guān)于pytorch 移動(dòng)端部署之helloworld的使用的文章就介紹到這了,更多相關(guān)pytorch 移動(dòng)端部署helloworld內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
python 實(shí)時(shí)調(diào)取攝像頭的示例代碼
這篇文章主要介紹了python 實(shí)時(shí)調(diào)取攝像頭的示例代碼,幫助大家更好的理解和使用python,感興趣的朋友可以了解下2020-11-11
Python實(shí)現(xiàn)迪杰斯特拉算法并生成最短路徑的示例代碼
這篇文章主要介紹了Python實(shí)現(xiàn)迪杰斯特拉算法并生成最短路徑的示例代碼,幫助大家更好的理解和使用python,感興趣的朋友可以了解下2020-12-12
python 實(shí)現(xiàn)selenium斷言和驗(yàn)證的方法
今天小編就為大家分享一篇python 實(shí)現(xiàn)selenium斷言和驗(yàn)證的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-02-02
使用python獲取csv文本的某行或某列數(shù)據(jù)的實(shí)例
下面小編就為大家分享一篇使用python獲取csv文本的某行或某列數(shù)據(jù)的實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-04-04
詳解python uiautomator2 watcher的使用方法
這篇文章主要介紹了python uiautomator2 watcher的使用方法,該方是基于uiautomator2如下版本進(jìn)行驗(yàn)證,本文給大家介紹的非常詳細(xì),需要的朋友可以參考下2019-09-09
解決python調(diào)用自己文件函數(shù)/執(zhí)行函數(shù)找不到包問(wèn)題
這篇文章主要介紹了解決python調(diào)用自己文件函數(shù)/執(zhí)行函數(shù)找不到包問(wèn)題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-06-06
Python實(shí)現(xiàn)模擬時(shí)鐘代碼推薦
本文給大家匯總介紹了下使用Python實(shí)現(xiàn)模擬時(shí)鐘的代碼,一共3個(gè)例子,后兩個(gè)是基于QT實(shí)現(xiàn),有需要的小伙伴可以參考下2015-11-11

