Skip to content

Commit

Permalink
Add general support for PathToResourceAsFile to TfLiteModelLoader
Browse files Browse the repository at this point in the history
Makes sure drishti::PathToResourceAsFile is called for all code paths, including gLinux. This guarantees that the --resource_root_dir flag is properly taken care of. Also simplifies code by deduplicating common code.

PiperOrigin-RevId: 599133782
  • Loading branch information
MediaPipe Team authored and copybara-github committed Jan 17, 2024
1 parent a471b64 commit 38ee0af
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 4 deletions.
30 changes: 29 additions & 1 deletion mediapipe/util/tflite/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
# limitations under the License.
#

load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
load(
"@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl",
"cc_library_with_tflite",
"cc_test_with_tflite",
)

licenses(["notice"])

Expand Down Expand Up @@ -141,3 +145,27 @@ cc_library_with_tflite(
"//mediapipe/util:resource_util",
],
)

cc_test_with_tflite(
name = "tflite_model_loader_test",
srcs = ["tflite_model_loader_test.cc"],
data = [
":testdata/test_model.tflite",
],
tflite_deps = [
":tflite_model_loader",
"@org_tensorflow//tensorflow/lite:test_util",
],
deps = [
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_state",
"//mediapipe/framework:legacy_calculator_support",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/tool:tag_map_helper",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/strings",
],
)
Binary file added mediapipe/util/tflite/testdata/test_model.tflite
Binary file not shown.
9 changes: 6 additions & 3 deletions mediapipe/util/tflite/tflite_model_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

#include "mediapipe/util/tflite/tflite_model_loader.h"

#include <string>
#include <utility>

#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/resource_util.h"

Expand All @@ -26,11 +29,10 @@ absl::StatusOr<api2::Packet<TfLiteModelPtr>> TfLiteModelLoader::LoadFromPath(
std::string model_path = path;

std::string model_blob;
auto status_or_content =
mediapipe::GetResourceContents(model_path, &model_blob);
absl::Status status = mediapipe::GetResourceContents(model_path, &model_blob);
// TODO: get rid of manual resolving with PathToResourceAsFile
// as soon as it's incorporated into GetResourceContents.
if (!status_or_content.ok()) {
if (!status.ok()) {
MP_ASSIGN_OR_RETURN(auto resolved_path,
mediapipe::PathToResourceAsFile(model_path));
VLOG(2) << "Loading the model from " << resolved_path;
Expand All @@ -40,6 +42,7 @@ absl::StatusOr<api2::Packet<TfLiteModelPtr>> TfLiteModelLoader::LoadFromPath(

auto model = FlatBufferModel::VerifyAndBuildFromBuffer(model_blob.data(),
model_blob.size());

RET_CHECK(model) << "Failed to load model from path " << model_path;
return api2::MakePacket<TfLiteModelPtr>(
model.release(),
Expand Down
68 changes: 68 additions & 0 deletions mediapipe/util/tflite/tflite_model_loader_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#include "mediapipe/util/tflite/tflite_model_loader.h"

#include <memory>
#include <string>

#include "absl/flags/declare.h"
#include "absl/flags/flag.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_state.h"
#include "mediapipe/framework/legacy_calculator_support.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/tag_map_helper.h"
#include "tensorflow/lite/test_util.h"

ABSL_DECLARE_FLAG(std::string, resource_root_dir);

namespace mediapipe {
namespace {

constexpr char kModelDir[] = "mediapipe/util/tflite/testdata";
constexpr char kModelFilename[] = "test_model.tflite";

class TfLiteModelLoaderTest : public tflite::testing::Test {
void SetUp() override {
// Create a stub calculator state.
CalculatorGraphConfig::Node config;
calculator_state_ = std::make_unique<CalculatorState>(
"fake_node", 0, "fake_type", config, nullptr);

// Create a stub calculator context.
calculator_context_ = std::make_unique<CalculatorContext>(
calculator_state_.get(), tool::CreateTagMap({}).value(),
tool::CreateTagMap({}).value());
}

protected:
std::unique_ptr<CalculatorState> calculator_state_;
std::unique_ptr<CalculatorContext> calculator_context_;
std::string model_path_ = absl::StrCat(kModelDir, "/", kModelFilename);
};

TEST_F(TfLiteModelLoaderTest, LoadFromPath) {
// TODO: remove LegacyCalculatorSupport usage.
LegacyCalculatorSupport::Scoped<CalculatorContext> scope(
calculator_context_.get());
MP_ASSERT_OK_AND_ASSIGN(api2::Packet<TfLiteModelPtr> model,
TfLiteModelLoader::LoadFromPath(model_path_));
EXPECT_NE(model.Get(), nullptr);
}

TEST_F(TfLiteModelLoaderTest, LoadFromPathRelativeToRootDir) {
absl::SetFlag(&FLAGS_resource_root_dir, kModelDir);

// TODO: remove LegacyCalculatorSupport usage.
LegacyCalculatorSupport::Scoped<CalculatorContext> scope(
calculator_context_.get());
MP_ASSERT_OK_AND_ASSIGN(api2::Packet<TfLiteModelPtr> model,
TfLiteModelLoader::LoadFromPath(kModelFilename));
EXPECT_NE(model.Get(), nullptr);
}

} // namespace
} // namespace mediapipe

0 comments on commit 38ee0af

Please sign in to comment.