Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate to support Android V2 embedding #230

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions android/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ android {
}

dependencies {
compile 'org.tensorflow:tensorflow-lite:+'
compile 'org.tensorflow:tensorflow-lite-gpu:+'
implementation 'org.tensorflow:tensorflow-lite:+'
implementation 'org.tensorflow:tensorflow-lite-gpu:+'
}
}
62 changes: 49 additions & 13 deletions android/src/main/java/sq/flutter/tflite/TflitePlugin.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package sq.flutter.tflite;

import android.app.Activity;
import android.content.Context;
import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;
Expand All @@ -19,11 +20,20 @@
import android.renderscript.Type;
import android.util.Log;

import androidx.annotation.NonNull;

import io.flutter.FlutterInjector;
import io.flutter.embedding.android.FlutterActivity;
import io.flutter.embedding.engine.loader.FlutterLoader;
import io.flutter.embedding.engine.plugins.activity.ActivityAware;
import io.flutter.embedding.engine.plugins.FlutterPlugin;
import io.flutter.embedding.engine.plugins.activity.ActivityPluginBinding;
import io.flutter.plugin.common.BinaryMessenger;
import io.flutter.plugin.common.MethodCall;
import io.flutter.plugin.common.MethodChannel;
import io.flutter.plugin.common.MethodChannel.MethodCallHandler;
import io.flutter.plugin.common.MethodChannel.Result;
import io.flutter.plugin.common.PluginRegistry.Registrar;
import io.flutter.plugin.common.PluginRegistry;

import org.tensorflow.lite.DataType;
import org.tensorflow.lite.Interpreter;
Expand Down Expand Up @@ -52,8 +62,8 @@
import java.util.Vector;


public class TflitePlugin implements MethodCallHandler {
private final Registrar mRegistrar;
public class TflitePlugin implements FlutterPlugin, MethodCallHandler, ActivityAware {
private Activity activity;
private Interpreter tfLite;
private boolean tfLiteBusy = false;
private int inputSize = 0;
Expand Down Expand Up @@ -82,17 +92,41 @@ public class TflitePlugin implements MethodCallHandler {
List<Integer> parentToChildEdges = new ArrayList<>();
List<Integer> childToParentEdges = new ArrayList<>();

public static void registerWith(Registrar registrar) {
final MethodChannel channel = new MethodChannel(registrar.messenger(), "tflite");
channel.setMethodCallHandler(new TflitePlugin(registrar));
private MethodChannel channel;

@Override
public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) {
channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), "tflite");
channel.setMethodCallHandler(this);
}

@Override
public void onAttachedToActivity(@NonNull ActivityPluginBinding binding) {
activity = binding.getActivity();
}

@Override
public void onDetachedFromActivity() {
activity = null;
}

@Override
public void onDetachedFromActivityForConfigChanges() {
this.onDetachedFromActivity();
}

private TflitePlugin(Registrar registrar) {
this.mRegistrar = registrar;
@Override
public void onReattachedToActivityForConfigChanges(@NonNull ActivityPluginBinding binding) {
this.onAttachedToActivity(binding);
}

@Override
public void onDetachedFromEngine(@NonNull FlutterPluginBinding binding) {
channel.setMethodCallHandler(null);
}

@Override
public void onMethodCall(MethodCall call, Result result) {
public void onMethodCall(@NonNull MethodCall call, @NonNull Result result) {
if (call.method.equals("loadModel")) {
try {
String res = loadModel((HashMap) call.arguments);
Expand Down Expand Up @@ -205,8 +239,9 @@ private String loadModel(HashMap args) throws IOException {
String key = null;
AssetManager assetManager = null;
if (isAsset) {
assetManager = mRegistrar.context().getAssets();
key = mRegistrar.lookupKeyForAsset(model);
assetManager = activity.getApplicationContext().getAssets();
FlutterLoader loader = FlutterInjector.instance().flutterLoader();
key = loader.getLookupKeyForAsset(model);
AssetFileDescriptor fileDescriptor = assetManager.openFd(key);
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
Expand Down Expand Up @@ -238,7 +273,8 @@ private String loadModel(HashMap args) throws IOException {

if (labels.length() > 0) {
if (isAsset) {
key = mRegistrar.lookupKeyForAsset(labels);
FlutterLoader loader = FlutterInjector.instance().flutterLoader();
key = loader.getLookupKeyForAsset(labels);
loadLabels(assetManager, key);
} else {
loadLabels(null, labels);
Expand Down Expand Up @@ -411,7 +447,7 @@ ByteBuffer feedInputTensorFrame(List<byte[]> bytesList, int imageHeight, int ima

Bitmap bitmapRaw = Bitmap.createBitmap(imageWidth, imageHeight, Bitmap.Config.ARGB_8888);
Allocation bmData = renderScriptNV21ToRGBA888(
mRegistrar.context(),
activity.getApplicationContext(),
imageWidth,
imageHeight,
data);
Expand Down
104 changes: 52 additions & 52 deletions ios/Classes/TflitePlugin.mm
Original file line number Diff line number Diff line change
Expand Up @@ -1116,13 +1116,13 @@ void runSegmentationOnFrame(NSDictionary* args, FlutterResult result) {
}


NSArray* part_names = @[
NSArray* _tflite_part_names = @[
@"nose", @"leftEye", @"rightEye", @"leftEar", @"rightEar", @"leftShoulder",
@"rightShoulder", @"leftElbow", @"rightElbow", @"leftWrist", @"rightWrist",
@"leftHip", @"rightHip", @"leftKnee", @"rightKnee", @"leftAnkle", @"rightAnkle"
];

NSArray* pose_chain = @[
NSArray* _tflite_pose_chain = @[
@[@"nose", @"leftEye"], @[@"leftEye", @"leftEar"], @[@"nose", @"rightEye"],
@[@"rightEye", @"rightEar"], @[@"nose", @"leftShoulder"],
@[@"leftShoulder", @"leftElbow"], @[@"leftElbow", @"leftWrist"],
Expand All @@ -1133,23 +1133,23 @@ void runSegmentationOnFrame(NSDictionary* args, FlutterResult result) {
@[@"rightKnee", @"rightAnkle"]
];

NSMutableDictionary* parts_ids = [NSMutableDictionary dictionary];
NSMutableArray* parent_to_child_edges = [NSMutableArray array];
NSMutableArray* child_to_parent_edges = [NSMutableArray array];
int local_maximum_radius = 1;
int output_stride = 16;
int height;
int width;
int num_keypoints;
NSMutableDictionary* _tflite_parts_ids = [NSMutableDictionary dictionary];
NSMutableArray* _tflite_parent_to_child_edges = [NSMutableArray array];
NSMutableArray* _tflite_child_to_parent_edges = [NSMutableArray array];
int _tflite_local_maximum_radius = 1;
int _tflite_output_stride = 16;
int _tflite_height;
int _tflite_width;
int _tflite_num_keypoints;

void initPoseNet() {
if ([parts_ids count] == 0) {
for (int i = 0; i < [part_names count]; ++i)
[parts_ids setValue:[NSNumber numberWithInt:i] forKey:part_names[i]];
if ([_tflite_parts_ids count] == 0) {
for (int i = 0; i < [_tflite_part_names count]; ++i)
[_tflite_parts_ids setValue:[NSNumber numberWithInt:i] forKey:_tflite_part_names[i]];

for (int i = 0; i < [pose_chain count]; ++i) {
[parent_to_child_edges addObject:parts_ids[pose_chain[i][1]]];
[child_to_parent_edges addObject:parts_ids[pose_chain[i][0]]];
for (int i = 0; i < [_tflite_pose_chain count]; ++i) {
[_tflite_parent_to_child_edges addObject:_tflite_parts_ids[_tflite_pose_chain[i][1]]];
[_tflite_child_to_parent_edges addObject:_tflite_parts_ids[_tflite_pose_chain[i][0]]];
}
}
}
Expand All @@ -1163,12 +1163,12 @@ bool scoreIsMaximumInLocalWindow(int keypoint_id,
bool local_maxium = true;

int y_start = MAX(heatmap_y - local_maximum_radius, 0);
int y_end = MIN(heatmap_y + local_maximum_radius + 1, height);
int y_end = MIN(heatmap_y + local_maximum_radius + 1, _tflite_height);
for (int y_current = y_start; y_current < y_end; ++y_current) {
int x_start = MAX(heatmap_x - local_maximum_radius, 0);
int x_end = MIN(heatmap_x + local_maximum_radius + 1, width);
int x_end = MIN(heatmap_x + local_maximum_radius + 1, _tflite_width);
for (int x_current = x_start; x_current < x_end; ++x_current) {
if (sigmoid(scores[(y_current * width + x_current) * num_keypoints + keypoint_id]) > score) {
if (sigmoid(scores[(y_current * _tflite_width + x_current) * _tflite_num_keypoints + keypoint_id]) > score) {
local_maxium = false;
break;
}
Expand All @@ -1188,11 +1188,11 @@ PriorityQueue buildPartWithScoreQueue(float* scores,
float threshold,
int local_maximum_radius) {
PriorityQueue pq;
for (int heatmap_y = 0; heatmap_y < height; ++heatmap_y) {
for (int heatmap_x = 0; heatmap_x < width; ++heatmap_x) {
for (int keypoint_id = 0; keypoint_id < num_keypoints; ++keypoint_id) {
float score = sigmoid(scores[(heatmap_y * width + heatmap_x) *
num_keypoints + keypoint_id]);
for (int heatmap_y = 0; heatmap_y < _tflite_height; ++heatmap_y) {
for (int heatmap_x = 0; heatmap_x < _tflite_width; ++heatmap_x) {
for (int keypoint_id = 0; keypoint_id < _tflite_num_keypoints; ++keypoint_id) {
float score = sigmoid(scores[(heatmap_y * _tflite_width + heatmap_x) *
_tflite_num_keypoints + keypoint_id]);
if (score < threshold) continue;

if (scoreIsMaximumInLocalWindow(keypoint_id, score, heatmap_y, heatmap_x,
Expand All @@ -1217,11 +1217,11 @@ void getImageCoords(float* res,
int heatmap_x = [keypoint[@"x"] intValue];
int keypoint_id = [keypoint[@"partId"] intValue];

int offset = (heatmap_y * width + heatmap_x) * num_keypoints * 2 + keypoint_id;
int offset = (heatmap_y * _tflite_width + heatmap_x) * _tflite_num_keypoints * 2 + keypoint_id;
float offset_y = offsets[offset];
float offset_x = offsets[offset + num_keypoints];
res[0] = heatmap_y * output_stride + offset_y;
res[1] = heatmap_x * output_stride + offset_x;
float offset_x = offsets[offset + _tflite_num_keypoints];
res[0] = heatmap_y * _tflite_output_stride + offset_y;
res[1] = heatmap_x * _tflite_output_stride + offset_x;
}


Expand All @@ -1244,19 +1244,19 @@ bool withinNmsRadiusOfCorrespondingPoint(NSMutableArray* poses,
}

void getStridedIndexNearPoint(int* res, float _y, float _x) {
int y_ = round(_y / output_stride);
int x_ = round(_x / output_stride);
int y = y_ < 0 ? 0 : y_ > height - 1 ? height - 1 : y_;
int x = x_ < 0 ? 0 : x_ > width - 1 ? width - 1 : x_;
int y_ = round(_y / _tflite_output_stride);
int x_ = round(_x / _tflite_output_stride);
int y = y_ < 0 ? 0 : y_ > _tflite_height - 1 ? _tflite_height - 1 : y_;
int x = x_ < 0 ? 0 : x_ > _tflite_width - 1 ? _tflite_width - 1 : x_;
res[0] = y;
res[1] = x;
}

void getDisplacement(float* res, int edgeId, int* keypoint, float* displacements) {
int num_edges = (int)[parent_to_child_edges count];
int num_edges = (int)[_tflite_parent_to_child_edges count];
int y = keypoint[0];
int x = keypoint[1];
int offset = (y * width + x) * num_edges * 2 + edgeId;
int offset = (y * _tflite_width + x) * num_edges * 2 + edgeId;
res[0] = displacements[offset];
res[1] = displacements[offset + num_edges];
}
Expand All @@ -1265,7 +1265,7 @@ float getInstanceScore(NSMutableDictionary* keypoints) {
float scores = 0;
for (NSMutableDictionary* keypoint in keypoints.allValues)
scores += [keypoint[@"score"] floatValue];
return scores / num_keypoints;
return scores / _tflite_num_keypoints;
}

NSMutableDictionary* traverseToTargetKeypoint(int edge_id,
Expand Down Expand Up @@ -1298,25 +1298,25 @@ float getInstanceScore(NSMutableDictionary* keypoints) {
int target_keypoint_y = target_keypoint_indices[0];
int target_keypoint_x = target_keypoint_indices[1];

int offset = (target_keypoint_y * width + target_keypoint_x) * num_keypoints * 2 + target_keypoint_id;
int offset = (target_keypoint_y * _tflite_width + target_keypoint_x) * _tflite_num_keypoints * 2 + target_keypoint_id;
float offset_y = offsets[offset];
float offset_x = offsets[offset + num_keypoints];
float offset_x = offsets[offset + _tflite_num_keypoints];

target_keypoint[0] = target_keypoint_y * output_stride + offset_y;
target_keypoint[1] = target_keypoint_x * output_stride + offset_x;
target_keypoint[0] = target_keypoint_y * _tflite_output_stride + offset_y;
target_keypoint[1] = target_keypoint_x * _tflite_output_stride + offset_x;
}

int target_keypoint_indices[2];
getStridedIndexNearPoint(target_keypoint_indices, target_keypoint[0], target_keypoint[1]);

float score = sigmoid(scores[(target_keypoint_indices[0] * width +
target_keypoint_indices[1]) * num_keypoints + target_keypoint_id]);
float score = sigmoid(scores[(target_keypoint_indices[0] * _tflite_width +
target_keypoint_indices[1]) * _tflite_num_keypoints + target_keypoint_id]);

NSMutableDictionary* keypoint = [NSMutableDictionary dictionary];
[keypoint setValue:[NSNumber numberWithFloat:score] forKey:@"score"];
[keypoint setValue:[NSNumber numberWithFloat:target_keypoint[0] / input_size] forKey:@"y"];
[keypoint setValue:[NSNumber numberWithFloat:target_keypoint[1] / input_size] forKey:@"x"];
[keypoint setValue:part_names[target_keypoint_id] forKey:@"part"];
[keypoint setValue:_tflite_part_names[target_keypoint_id] forKey:@"part"];
return keypoint;
}

Expand All @@ -1330,9 +1330,9 @@ float getInstanceScore(NSMutableDictionary* keypoints) {
assert(interpreter->outputs().size() == 4);
TfLiteTensor* scores_tensor = interpreter->tensor(interpreter->outputs()[0]);
#endif
height = scores_tensor->dims->data[1];
width = scores_tensor->dims->data[2];
num_keypoints = scores_tensor->dims->data[3];
_tflite_height = scores_tensor->dims->data[1];
_tflite_width = scores_tensor->dims->data[2];
_tflite_num_keypoints = scores_tensor->dims->data[3];

#ifdef TFLITE2
float* scores = TfLiteInterpreterGetOutputTensor(interpreter, 0)->data.f;
Expand All @@ -1345,9 +1345,9 @@ float getInstanceScore(NSMutableDictionary* keypoints) {
float* displacements_fwd = interpreter->typed_output_tensor<float>(2);
float* displacements_bwd = interpreter->typed_output_tensor<float>(3);
#endif
PriorityQueue pq = buildPartWithScoreQueue(scores, threshold, local_maximum_radius);
PriorityQueue pq = buildPartWithScoreQueue(scores, threshold, _tflite_local_maximum_radius);

int num_edges = (int)[parent_to_child_edges count];
int num_edges = (int)[_tflite_parent_to_child_edges count];
int sqared_nms_radius = nms_radius * nms_radius;

NSMutableArray* results = [NSMutableArray array];
Expand All @@ -1367,14 +1367,14 @@ float getInstanceScore(NSMutableDictionary* keypoints) {
[keypoint setValue:[NSNumber numberWithFloat:[root[@"score"] floatValue]] forKey:@"score"];
[keypoint setValue:[NSNumber numberWithFloat:root_point[0] / input_size] forKey:@"y"];
[keypoint setValue:[NSNumber numberWithFloat:root_point[1] / input_size] forKey:@"x"];
[keypoint setValue:part_names[[root[@"partId"] intValue]] forKey:@"part"];
[keypoint setValue:_tflite_part_names[[root[@"partId"] intValue]] forKey:@"part"];

NSMutableDictionary* keypoints = [NSMutableDictionary dictionary];
[keypoints setObject:keypoint forKey:root[@"partId"]];

for (int edge = num_edges - 1; edge >= 0; --edge) {
int source_keypoint_id = [parent_to_child_edges[edge] intValue];
int target_keypoint_id = [child_to_parent_edges[edge] intValue];
int source_keypoint_id = [_tflite_parent_to_child_edges[edge] intValue];
int target_keypoint_id = [_tflite_child_to_parent_edges[edge] intValue];
if (keypoints[[NSNumber numberWithInt:source_keypoint_id]] &&
!(keypoints[[NSNumber numberWithInt:target_keypoint_id]])) {
keypoint = traverseToTargetKeypoint(edge, keypoints[[NSNumber numberWithInt:source_keypoint_id]],
Expand All @@ -1384,8 +1384,8 @@ float getInstanceScore(NSMutableDictionary* keypoints) {
}

for (int edge = 0; edge < num_edges; ++edge) {
int source_keypoint_id = [child_to_parent_edges[edge] intValue];
int target_keypoint_id = [parent_to_child_edges[edge] intValue];
int source_keypoint_id = [_tflite_child_to_parent_edges[edge] intValue];
int target_keypoint_id = [_tflite_parent_to_child_edges[edge] intValue];
if (keypoints[[NSNumber numberWithInt:source_keypoint_id]] &&
!(keypoints[[NSNumber numberWithInt:target_keypoint_id]])) {
keypoint = traverseToTargetKeypoint(edge, keypoints[[NSNumber numberWithInt:source_keypoint_id]],
Expand Down