forked from pytorch/android-demo-app
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[android][native_app] App example of linking to gradle deps native li…
…bs and torchscript CustomOp
- Loading branch information
1 parent
1e36f9e
commit 0570874
Showing
12 changed files
with
385 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
cmake_minimum_required(VERSION 3.4.1) | ||
set(TARGET pytorch_nativeapp) | ||
project(${TARGET} CXX) | ||
set(CMAKE_CXX_STANDARD 14) | ||
|
||
set(build_DIR ${CMAKE_SOURCE_DIR}/build) | ||
|
||
set(pytorch_testapp_cpp_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp) | ||
file(GLOB pytorch_testapp_SOURCES | ||
${pytorch_testapp_cpp_DIR}/pytorch_nativeapp.cpp | ||
) | ||
|
||
add_library(${TARGET} SHARED | ||
${pytorch_testapp_SOURCES} | ||
) | ||
|
||
file(GLOB PYTORCH_INCLUDE_DIRS "${build_DIR}/pytorch_android*.aar/headers") | ||
file(GLOB PYTORCH_LINK_DIRS "${build_DIR}/pytorch_android*.aar/jni/${ANDROID_ABI}") | ||
|
||
target_compile_options(${TARGET} PRIVATE | ||
-fexceptions | ||
) | ||
|
||
set(BUILD_SUBDIR ${ANDROID_ABI}) | ||
|
||
find_library(PYTORCH_LIBRARY pytorch_jni | ||
PATHS ${PYTORCH_LINK_DIRS} | ||
NO_CMAKE_FIND_ROOT_PATH) | ||
find_library(FBJNI_LIBRARY fbjni | ||
PATHS ${PYTORCH_LINK_DIRS} | ||
NO_CMAKE_FIND_ROOT_PATH) | ||
|
||
# OpenCV | ||
if(NOT DEFINED ENV{OPENCV_ANDROID_SDK}) | ||
message(FATAL_ERROR "Environment var OPENCV_ANDROID_SDK set") | ||
endif() | ||
|
||
set(OPENCV_INCLUDE_DIR "$ENV{OPENCV_ANDROID_SDK}/sdk/native/jni/include") | ||
|
||
target_include_directories(${TARGET} PRIVATE | ||
"${OPENCV_INCLUDE_DIR}" | ||
${PYTORCH_INCLUDE_DIRS}) | ||
|
||
set(OPENCV_LIB_DIR "$ENV{OPENCV_ANDROID_SDK}/sdk/native/libs/${ANDROID_ABI}") | ||
|
||
find_library(OPENCV_LIBRARY opencv_java4 | ||
PATHS ${OPENCV_LIB_DIR} | ||
NO_CMAKE_FIND_ROOT_PATH) | ||
|
||
target_link_libraries(${TARGET} | ||
${PYTORCH_LIBRARY} | ||
${FBJNI_LIBRARY} | ||
${OPENCV_LIBRARY} | ||
log) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
apply plugin: 'com.android.application' | ||
|
||
repositories { | ||
jcenter() | ||
maven { | ||
url "https://oss.sonatype.org/content/repositories/snapshots" | ||
} | ||
} | ||
|
||
android { | ||
configurations { | ||
extractForNativeBuild | ||
} | ||
compileSdkVersion 28 | ||
buildToolsVersion "29.0.2" | ||
defaultConfig { | ||
applicationId "org.pytorch.nativeapp" | ||
minSdkVersion 21 | ||
targetSdkVersion 28 | ||
versionCode 1 | ||
versionName "1.0" | ||
externalNativeBuild { | ||
cmake { | ||
arguments "-DANDROID_STL=c++_shared" | ||
} | ||
} | ||
} | ||
buildTypes { | ||
release { | ||
minifyEnabled false | ||
} | ||
} | ||
externalNativeBuild { | ||
cmake { | ||
path "CMakeLists.txt" | ||
} | ||
} | ||
sourceSets { | ||
main { | ||
jniLibs.srcDirs = ['src/main/jniLibs'] | ||
} | ||
} | ||
} | ||
|
||
dependencies { | ||
implementation 'com.android.support:appcompat-v7:28.0.0' | ||
|
||
implementation 'org.pytorch:pytorch_android:1.6.0-SNAPSHOT' | ||
extractForNativeBuild 'org.pytorch:pytorch_android:1.6.0-SNAPSHOT' | ||
} | ||
|
||
task extractAARForNativeBuild { | ||
doLast { | ||
configurations.extractForNativeBuild.files.each { | ||
def file = it.absoluteFile | ||
copy { | ||
from zipTree(file) | ||
into "$buildDir/$file.name" | ||
include "headers/**" | ||
include "jni/**" | ||
} | ||
} | ||
} | ||
} | ||
|
||
tasks.whenTaskAdded { task -> | ||
if (task.name.contains('externalNativeBuild')) { | ||
task.dependsOn(extractAARForNativeBuild) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
<?xml version="1.0" encoding="utf-8"?> | ||
<manifest xmlns:android="http://schemas.android.com/apk/res/android" | ||
package="org.pytorch.nativeapp"> | ||
|
||
<application | ||
android:allowBackup="true" | ||
android:label="PyTorchNativeApp" | ||
android:supportsRtl="true" | ||
android:theme="@style/Theme.AppCompat.Light.DarkActionBar"> | ||
|
||
<activity android:name=".MainActivity"> | ||
<intent-filter> | ||
<action android:name="android.intent.action.MAIN" /> | ||
|
||
<category android:name="android.intent.category.LAUNCHER" /> | ||
</intent-filter> | ||
</activity> | ||
</application> | ||
</manifest> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
* | ||
*/ | ||
!.gitignore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
#include <android/log.h> | ||
#include <cassert> | ||
#include <cmath> | ||
#include <pthread.h> | ||
#include <unistd.h> | ||
#include <vector> | ||
#define ALOGI(...) \ | ||
__android_log_print(ANDROID_LOG_INFO, "PyTorchNativeApp", __VA_ARGS__) | ||
#define ALOGE(...) \ | ||
__android_log_print(ANDROID_LOG_ERROR, "PyTorchNativeApp", __VA_ARGS__) | ||
|
||
#include "jni.h" | ||
|
||
#include <opencv2/opencv.hpp> | ||
#include <torch/script.h> | ||
|
||
namespace pytorch_nativeapp { | ||
namespace { | ||
torch::Tensor warp_perspective(torch::Tensor image, torch::Tensor warp) { | ||
cv::Mat image_mat(/*rows=*/image.size(0), | ||
/*cols=*/image.size(1), | ||
/*type=*/CV_32FC1, | ||
/*data=*/image.data_ptr<float>()); | ||
cv::Mat warp_mat(/*rows=*/warp.size(0), | ||
/*cols=*/warp.size(1), | ||
/*type=*/CV_32FC1, | ||
/*data=*/warp.data_ptr<float>()); | ||
|
||
cv::Mat output_mat; | ||
cv::warpPerspective(image_mat, output_mat, warp_mat, /*dsize=*/{8, 8}); | ||
|
||
torch::Tensor output = | ||
torch::from_blob(output_mat.ptr<float>(), /*sizes=*/{8, 8}); | ||
return output.clone(); | ||
} | ||
|
||
static auto registry = | ||
torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective); | ||
|
||
template <typename T> void log(const char *m, T t) { | ||
std::ostringstream os; | ||
os << t << std::endl; | ||
ALOGI("%s %s", m, os.str().c_str()); | ||
} | ||
|
||
struct JITCallGuard { | ||
torch::autograd::AutoGradMode no_autograd_guard{false}; | ||
torch::AutoNonVariableTypeMode non_var_guard{true}; | ||
torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false}; | ||
}; | ||
} // namespace | ||
|
||
static void loadAndForwardModel(JNIEnv *env, jclass, jstring jModelPath) { | ||
const char *modelPath = env->GetStringUTFChars(jModelPath, 0); | ||
assert(modelPath); | ||
|
||
// To load torchscript model for mobile we need set these guards, | ||
// because mobile build doesn't support features like autograd for smaller | ||
// build size which is placed in `struct JITCallGuard` in this example. It may | ||
// change in future, you can track the latest changes keeping an eye in | ||
// android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp | ||
JITCallGuard guard; | ||
torch::jit::Module module = torch::jit::load(modelPath); | ||
module.eval(); | ||
torch::Tensor x = torch::randn({4, 8}); | ||
torch::Tensor y = torch::randn({8, 5}); | ||
log("x:", x); | ||
log("y:", y); | ||
c10::IValue t_out = module.forward({x, y}); | ||
log("result:", t_out); | ||
env->ReleaseStringUTFChars(jModelPath, modelPath); | ||
} | ||
} // namespace pytorch_nativeapp | ||
|
||
JNIEXPORT jint JNI_OnLoad(JavaVM *vm, void *) { | ||
JNIEnv *env; | ||
if (vm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION_1_6) != JNI_OK) { | ||
return JNI_ERR; | ||
} | ||
|
||
jclass c = env->FindClass("org/pytorch/nativeapp/NativeClient$NativePeer"); | ||
if (c == nullptr) { | ||
return JNI_ERR; | ||
} | ||
|
||
static const JNINativeMethod methods[] = { | ||
{"loadAndForwardModel", "(Ljava/lang/String;)V", | ||
(void *)pytorch_nativeapp::loadAndForwardModel}, | ||
}; | ||
int rc = env->RegisterNatives(c, methods, | ||
sizeof(methods) / sizeof(JNINativeMethod)); | ||
|
||
if (rc != JNI_OK) { | ||
return rc; | ||
} | ||
|
||
return JNI_VERSION_1_6; | ||
} |
46 changes: 46 additions & 0 deletions
46
NativeApp/app/src/main/java/org/pytorch/nativeapp/MainActivity.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
package org.pytorch.nativeapp; | ||
|
||
import android.content.Context; | ||
import android.os.Bundle; | ||
import android.util.Log; | ||
import androidx.appcompat.app.AppCompatActivity; | ||
import java.io.File; | ||
import java.io.FileOutputStream; | ||
import java.io.IOException; | ||
import java.io.InputStream; | ||
import java.io.OutputStream; | ||
|
||
public class MainActivity extends AppCompatActivity { | ||
|
||
private static final String TAG = "PyTorchNativeApp"; | ||
|
||
public static String assetFilePath(Context context, String assetName) { | ||
File file = new File(context.getFilesDir(), assetName); | ||
if (file.exists() && file.length() > 0) { | ||
return file.getAbsolutePath(); | ||
} | ||
|
||
try (InputStream is = context.getAssets().open(assetName)) { | ||
try (OutputStream os = new FileOutputStream(file)) { | ||
byte[] buffer = new byte[4 * 1024]; | ||
int read; | ||
while ((read = is.read(buffer)) != -1) { | ||
os.write(buffer, 0, read); | ||
} | ||
os.flush(); | ||
} | ||
return file.getAbsolutePath(); | ||
} catch (IOException e) { | ||
Log.e(TAG, "Error process asset " + assetName + " to file path"); | ||
} | ||
return null; | ||
} | ||
|
||
@Override | ||
protected void onCreate(Bundle savedInstanceState) { | ||
super.onCreate(savedInstanceState); | ||
final String modelFileAbsoluteFilePath = | ||
new File(assetFilePath(this, "compute.pt")).getAbsolutePath(); | ||
NativeClient.loadAndForwardModel(modelFileAbsoluteFilePath); | ||
} | ||
} |
16 changes: 16 additions & 0 deletions
16
NativeApp/app/src/main/java/org/pytorch/nativeapp/NativeClient.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
package org.pytorch.nativeapp; | ||
|
||
public final class NativeClient { | ||
|
||
public static void loadAndForwardModel(final String modelPath) { | ||
NativePeer.loadAndForwardModel(modelPath); | ||
} | ||
|
||
private static class NativePeer { | ||
static { | ||
System.loadLibrary("pytorch_nativeapp"); | ||
} | ||
|
||
private static native void loadAndForwardModel(final String modelPath); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
* | ||
*/ | ||
!.gitignore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
buildscript { | ||
repositories { | ||
google() | ||
jcenter() | ||
} | ||
dependencies { | ||
classpath 'com.android.tools.build:gradle:3.5.0' | ||
} | ||
} | ||
|
||
allprojects { | ||
repositories { | ||
google() | ||
jcenter() | ||
} | ||
} | ||
|
||
task clean(type: Delete) { | ||
delete rootProject.buildDir | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
android.useAndroidX=true | ||
android.enableJetifier=true | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import torch | ||
import torch.utils.cpp_extension | ||
|
||
print(torch.version.__version__) | ||
op_source = """ | ||
#include <opencv2/opencv.hpp> | ||
#include <torch/script.h> | ||
torch::Tensor warp_perspective(torch::Tensor image, torch::Tensor warp) { | ||
cv::Mat image_mat(/*rows=*/image.size(0), | ||
/*cols=*/image.size(1), | ||
/*type=*/CV_32FC1, | ||
/*data=*/image.data_ptr<float>()); | ||
cv::Mat warp_mat(/*rows=*/warp.size(0), | ||
/*cols=*/warp.size(1), | ||
/*type=*/CV_32FC1, | ||
/*data=*/warp.data_ptr<float>()); | ||
cv::Mat output_mat; | ||
cv::warpPerspective(image_mat, output_mat, warp_mat, /*dsize=*/{64, 64}); | ||
torch::Tensor output = | ||
torch::from_blob(output_mat.ptr<float>(), /*sizes=*/{64, 64}); | ||
return output.clone(); | ||
} | ||
static auto registry = | ||
torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective); | ||
""" | ||
|
||
torch.utils.cpp_extension.load_inline( | ||
name="warp_perspective", | ||
cpp_sources=op_source, | ||
extra_ldflags=["-lopencv_core", "-lopencv_imgproc"], | ||
is_python_module=False, | ||
verbose=True, | ||
) | ||
|
||
print(torch.ops.my_ops.warp_perspective) | ||
|
||
|
||
@torch.jit.script | ||
def compute(x, y): | ||
if bool(x[0][0] == 42): | ||
z = 5 | ||
else: | ||
z = 10 | ||
x = torch.ops.my_ops.warp_perspective(x, torch.eye(3)) | ||
return x.matmul(y) + z | ||
|
||
|
||
compute.save("app/src/main/assets/compute.pt") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
include ':app' |