Skip to content

Commit

Permalink
Merge pull request #5028 from kinaryml:python-holistic-landmarker
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 594995636
  • Loading branch information
copybara-github committed Jan 2, 2024
2 parents 8609e5f + 569c16d commit e23fa53
Show file tree
Hide file tree
Showing 8 changed files with 1,199 additions and 22 deletions.
1 change: 1 addition & 0 deletions mediapipe/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ cc_library(
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
"//mediapipe/tasks/cc/vision/face_stylizer:face_stylizer_graph",
"//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
"//mediapipe/tasks/cc/vision/holistic_landmarker:holistic_landmarker_graph",
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
Expand Down
1 change: 1 addition & 0 deletions mediapipe/tasks/python/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,6 @@ py_library(
"//mediapipe/calculators/core:flow_limiter_calculator_py_pb2",
"//mediapipe/framework:calculator_options_py_pb2",
"//mediapipe/framework:calculator_py_pb2",
"@com_google_protobuf//:protobuf_python",
],
)
49 changes: 28 additions & 21 deletions mediapipe/tasks/python/core/task_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@
"""MediaPipe Tasks' task info data class."""

import dataclasses

from typing import Any, List

from google.protobuf import any_pb2
from mediapipe.calculators.core import flow_limiter_calculator_pb2
from mediapipe.framework import calculator_options_pb2
from mediapipe.framework import calculator_pb2
Expand Down Expand Up @@ -80,21 +79,34 @@ def add_stream_name_prefix(tag_index_name):
raise ValueError(
'`task_options` doesn`t provide `to_pb2()` method to convert itself to be a protobuf object.'
)
task_subgraph_options = calculator_options_pb2.CalculatorOptions()

task_options_proto = self.task_options.to_pb2()
task_subgraph_options.Extensions[task_options_proto.ext].CopyFrom(
task_options_proto)

node_config = calculator_pb2.CalculatorGraphConfig.Node(
calculator=self.task_graph,
input_stream=self.input_streams,
output_stream=self.output_streams,
)

if hasattr(task_options_proto, 'ext'):
# Use the extension mechanism for task_subgraph_options (proto2)
task_subgraph_options = calculator_options_pb2.CalculatorOptions()
task_subgraph_options.Extensions[task_options_proto.ext].CopyFrom(
task_options_proto
)
node_config.options.CopyFrom(task_subgraph_options)
else:
# Use the Any type for task_subgraph_options (proto3)
task_subgraph_options = any_pb2.Any()
task_subgraph_options.Pack(self.task_options.to_pb2())
node_config.node_options.append(task_subgraph_options)

if not enable_flow_limiting:
return calculator_pb2.CalculatorGraphConfig(
node=[
calculator_pb2.CalculatorGraphConfig.Node(
calculator=self.task_graph,
input_stream=self.input_streams,
output_stream=self.output_streams,
options=task_subgraph_options)
],
node=[node_config],
input_stream=self.input_streams,
output_stream=self.output_streams)
output_stream=self.output_streams,
)
# When a FlowLimiterCalculator is inserted to lower the overall graph
# latency, the task doesn't guarantee that each input must have the
# corresponding output.
Expand All @@ -120,13 +132,8 @@ def add_stream_name_prefix(tag_index_name):
],
options=flow_limiter_options)
config = calculator_pb2.CalculatorGraphConfig(
node=[
calculator_pb2.CalculatorGraphConfig.Node(
calculator=self.task_graph,
input_stream=task_subgraph_inputs,
output_stream=self.output_streams,
options=task_subgraph_options), flow_limiter
],
node=[node_config, flow_limiter],
input_stream=self.input_streams,
output_stream=self.output_streams)
output_stream=self.output_streams,
)
return config
21 changes: 21 additions & 0 deletions mediapipe/tasks/python/test/vision/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,27 @@ py_test(
],
)

py_test(
name = "holistic_landmarker_test",
srcs = ["holistic_landmarker_test.py"],
data = [
"//mediapipe/tasks/testdata/vision:test_images",
"//mediapipe/tasks/testdata/vision:test_models",
"//mediapipe/tasks/testdata/vision:test_protos",
],
tags = ["not_run:arm"],
deps = [
"//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_result_py_pb2",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/vision:holistic_landmarker",
"//mediapipe/tasks/python/vision/core:image_processing_options",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
"@com_google_protobuf//:protobuf_python",
],
)

py_test(
name = "face_aligner_test",
srcs = ["face_aligner_test.py"],
Expand Down
Loading

0 comments on commit e23fa53

Please sign in to comment.