Skip to content

Commit

Permalink
[android][native_app] App example of linking to gradle deps native li…
Browse files Browse the repository at this point in the history
…bs and torchscript CustomOp
  • Loading branch information
IvanKobzarev committed Jun 26, 2020
1 parent 1e36f9e commit 0570874
Show file tree
Hide file tree
Showing 12 changed files with 385 additions and 0 deletions.
54 changes: 54 additions & 0 deletions NativeApp/app/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
cmake_minimum_required(VERSION 3.4.1)
set(TARGET pytorch_nativeapp)
project(${TARGET} CXX)
set(CMAKE_CXX_STANDARD 14)

set(build_DIR ${CMAKE_SOURCE_DIR}/build)

set(pytorch_testapp_cpp_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp)
file(GLOB pytorch_testapp_SOURCES
${pytorch_testapp_cpp_DIR}/pytorch_nativeapp.cpp
)

add_library(${TARGET} SHARED
${pytorch_testapp_SOURCES}
)

file(GLOB PYTORCH_INCLUDE_DIRS "${build_DIR}/pytorch_android*.aar/headers")
file(GLOB PYTORCH_LINK_DIRS "${build_DIR}/pytorch_android*.aar/jni/${ANDROID_ABI}")

target_compile_options(${TARGET} PRIVATE
-fexceptions
)

set(BUILD_SUBDIR ${ANDROID_ABI})

find_library(PYTORCH_LIBRARY pytorch_jni
PATHS ${PYTORCH_LINK_DIRS}
NO_CMAKE_FIND_ROOT_PATH)
find_library(FBJNI_LIBRARY fbjni
PATHS ${PYTORCH_LINK_DIRS}
NO_CMAKE_FIND_ROOT_PATH)

# OpenCV
if(NOT DEFINED ENV{OPENCV_ANDROID_SDK})
message(FATAL_ERROR "Environment var OPENCV_ANDROID_SDK set")
endif()

set(OPENCV_INCLUDE_DIR "$ENV{OPENCV_ANDROID_SDK}/sdk/native/jni/include")

target_include_directories(${TARGET} PRIVATE
"${OPENCV_INCLUDE_DIR}"
${PYTORCH_INCLUDE_DIRS})

set(OPENCV_LIB_DIR "$ENV{OPENCV_ANDROID_SDK}/sdk/native/libs/${ANDROID_ABI}")

find_library(OPENCV_LIBRARY opencv_java4
PATHS ${OPENCV_LIB_DIR}
NO_CMAKE_FIND_ROOT_PATH)

target_link_libraries(${TARGET}
${PYTORCH_LIBRARY}
${FBJNI_LIBRARY}
${OPENCV_LIBRARY}
log)
70 changes: 70 additions & 0 deletions NativeApp/app/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
apply plugin: 'com.android.application'

repositories {
jcenter()
maven {
url "https://oss.sonatype.org/content/repositories/snapshots"
}
}

android {
configurations {
extractForNativeBuild
}
compileSdkVersion 28
buildToolsVersion "29.0.2"
defaultConfig {
applicationId "org.pytorch.nativeapp"
minSdkVersion 21
targetSdkVersion 28
versionCode 1
versionName "1.0"
externalNativeBuild {
cmake {
arguments "-DANDROID_STL=c++_shared"
}
}
}
buildTypes {
release {
minifyEnabled false
}
}
externalNativeBuild {
cmake {
path "CMakeLists.txt"
}
}
sourceSets {
main {
jniLibs.srcDirs = ['src/main/jniLibs']
}
}
}

dependencies {
implementation 'com.android.support:appcompat-v7:28.0.0'

implementation 'org.pytorch:pytorch_android:1.6.0-SNAPSHOT'
extractForNativeBuild 'org.pytorch:pytorch_android:1.6.0-SNAPSHOT'
}

task extractAARForNativeBuild {
doLast {
configurations.extractForNativeBuild.files.each {
def file = it.absoluteFile
copy {
from zipTree(file)
into "$buildDir/$file.name"
include "headers/**"
include "jni/**"
}
}
}
}

tasks.whenTaskAdded { task ->
if (task.name.contains('externalNativeBuild')) {
task.dependsOn(extractAARForNativeBuild)
}
}
19 changes: 19 additions & 0 deletions NativeApp/app/src/main/AndroidManifest.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="org.pytorch.nativeapp">

<application
android:allowBackup="true"
android:label="PyTorchNativeApp"
android:supportsRtl="true"
android:theme="@style/Theme.AppCompat.Light.DarkActionBar">

<activity android:name=".MainActivity">
<intent-filter>
<action android:name="android.intent.action.MAIN" />

<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
</application>
</manifest>
3 changes: 3 additions & 0 deletions NativeApp/app/src/main/assets/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
*
*/
!.gitignore
98 changes: 98 additions & 0 deletions NativeApp/app/src/main/cpp/pytorch_nativeapp.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#include <android/log.h>
#include <cassert>
#include <cmath>
#include <pthread.h>
#include <unistd.h>
#include <vector>
#define ALOGI(...) \
__android_log_print(ANDROID_LOG_INFO, "PyTorchNativeApp", __VA_ARGS__)
#define ALOGE(...) \
__android_log_print(ANDROID_LOG_ERROR, "PyTorchNativeApp", __VA_ARGS__)

#include "jni.h"

#include <opencv2/opencv.hpp>
#include <torch/script.h>

namespace pytorch_nativeapp {
namespace {
torch::Tensor warp_perspective(torch::Tensor image, torch::Tensor warp) {
cv::Mat image_mat(/*rows=*/image.size(0),
/*cols=*/image.size(1),
/*type=*/CV_32FC1,
/*data=*/image.data_ptr<float>());
cv::Mat warp_mat(/*rows=*/warp.size(0),
/*cols=*/warp.size(1),
/*type=*/CV_32FC1,
/*data=*/warp.data_ptr<float>());

cv::Mat output_mat;
cv::warpPerspective(image_mat, output_mat, warp_mat, /*dsize=*/{8, 8});

torch::Tensor output =
torch::from_blob(output_mat.ptr<float>(), /*sizes=*/{8, 8});
return output.clone();
}

static auto registry =
torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective);

template <typename T> void log(const char *m, T t) {
std::ostringstream os;
os << t << std::endl;
ALOGI("%s %s", m, os.str().c_str());
}

struct JITCallGuard {
torch::autograd::AutoGradMode no_autograd_guard{false};
torch::AutoNonVariableTypeMode non_var_guard{true};
torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
};
} // namespace

static void loadAndForwardModel(JNIEnv *env, jclass, jstring jModelPath) {
const char *modelPath = env->GetStringUTFChars(jModelPath, 0);
assert(modelPath);

// To load torchscript model for mobile we need set these guards,
// because mobile build doesn't support features like autograd for smaller
// build size which is placed in `struct JITCallGuard` in this example. It may
// change in future, you can track the latest changes keeping an eye in
// android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp
JITCallGuard guard;
torch::jit::Module module = torch::jit::load(modelPath);
module.eval();
torch::Tensor x = torch::randn({4, 8});
torch::Tensor y = torch::randn({8, 5});
log("x:", x);
log("y:", y);
c10::IValue t_out = module.forward({x, y});
log("result:", t_out);
env->ReleaseStringUTFChars(jModelPath, modelPath);
}
} // namespace pytorch_nativeapp

JNIEXPORT jint JNI_OnLoad(JavaVM *vm, void *) {
JNIEnv *env;
if (vm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION_1_6) != JNI_OK) {
return JNI_ERR;
}

jclass c = env->FindClass("org/pytorch/nativeapp/NativeClient$NativePeer");
if (c == nullptr) {
return JNI_ERR;
}

static const JNINativeMethod methods[] = {
{"loadAndForwardModel", "(Ljava/lang/String;)V",
(void *)pytorch_nativeapp::loadAndForwardModel},
};
int rc = env->RegisterNatives(c, methods,
sizeof(methods) / sizeof(JNINativeMethod));

if (rc != JNI_OK) {
return rc;
}

return JNI_VERSION_1_6;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package org.pytorch.nativeapp;

import android.content.Context;
import android.os.Bundle;
import android.util.Log;
import androidx.appcompat.app.AppCompatActivity;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

public class MainActivity extends AppCompatActivity {

private static final String TAG = "PyTorchNativeApp";

public static String assetFilePath(Context context, String assetName) {
File file = new File(context.getFilesDir(), assetName);
if (file.exists() && file.length() > 0) {
return file.getAbsolutePath();
}

try (InputStream is = context.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
} catch (IOException e) {
Log.e(TAG, "Error process asset " + assetName + " to file path");
}
return null;
}

@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
final String modelFileAbsoluteFilePath =
new File(assetFilePath(this, "compute.pt")).getAbsolutePath();
NativeClient.loadAndForwardModel(modelFileAbsoluteFilePath);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package org.pytorch.nativeapp;

public final class NativeClient {

public static void loadAndForwardModel(final String modelPath) {
NativePeer.loadAndForwardModel(modelPath);
}

private static class NativePeer {
static {
System.loadLibrary("pytorch_nativeapp");
}

private static native void loadAndForwardModel(final String modelPath);
}
}
3 changes: 3 additions & 0 deletions NativeApp/app/src/main/jniLibs/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
*
*/
!.gitignore
20 changes: 20 additions & 0 deletions NativeApp/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
buildscript {
repositories {
google()
jcenter()
}
dependencies {
classpath 'com.android.tools.build:gradle:3.5.0'
}
}

allprojects {
repositories {
google()
jcenter()
}
}

task clean(type: Delete) {
delete rootProject.buildDir
}
3 changes: 3 additions & 0 deletions NativeApp/gradle.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
android.useAndroidX=true
android.enableJetifier=true

52 changes: 52 additions & 0 deletions NativeApp/make_warp_perspective_pt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import torch
import torch.utils.cpp_extension

print(torch.version.__version__)
op_source = """
#include <opencv2/opencv.hpp>
#include <torch/script.h>
torch::Tensor warp_perspective(torch::Tensor image, torch::Tensor warp) {
cv::Mat image_mat(/*rows=*/image.size(0),
/*cols=*/image.size(1),
/*type=*/CV_32FC1,
/*data=*/image.data_ptr<float>());
cv::Mat warp_mat(/*rows=*/warp.size(0),
/*cols=*/warp.size(1),
/*type=*/CV_32FC1,
/*data=*/warp.data_ptr<float>());
cv::Mat output_mat;
cv::warpPerspective(image_mat, output_mat, warp_mat, /*dsize=*/{64, 64});
torch::Tensor output =
torch::from_blob(output_mat.ptr<float>(), /*sizes=*/{64, 64});
return output.clone();
}
static auto registry =
torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective);
"""

torch.utils.cpp_extension.load_inline(
name="warp_perspective",
cpp_sources=op_source,
extra_ldflags=["-lopencv_core", "-lopencv_imgproc"],
is_python_module=False,
verbose=True,
)

print(torch.ops.my_ops.warp_perspective)


@torch.jit.script
def compute(x, y):
if bool(x[0][0] == 42):
z = 5
else:
z = 10
x = torch.ops.my_ops.warp_perspective(x, torch.eye(3))
return x.matmul(y) + z


compute.save("app/src/main/assets/compute.pt")
1 change: 1 addition & 0 deletions NativeApp/settings.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include ':app'

0 comments on commit 0570874

Please sign in to comment.