From a67f79a257a8f284a39505842bdb1bd9009c316c Mon Sep 17 00:00:00 2001
From: Jun Jiang
Date: Fri, 13 Dec 2024 13:38:15 -0800
Subject: [PATCH] Sync litert codes to LiteRT repo.

PiperOrigin-RevId: 705987824
---
 tflite/experimental/litert/BUILD | 18 + tflite/experimental/litert/build_common/BUILD | 20 + .../build_common/export_litert_only.lds | 29 + .../litert/build_common/litert_build_defs.bzl | 219 +++ tflite/experimental/litert/c/BUILD | 305 +++++ tflite/experimental/litert/c/litert_any.h | 49 + .../litert/c/litert_c_api_common_test.c | 40 + tflite/experimental/litert/c/litert_common.h | 97 ++ .../litert/c/litert_compiled_model.cc | 104 ++ .../litert/c/litert_compiled_model.h | 106 ++ .../litert/c/litert_compiled_model_options.h | 35 + .../litert/c/litert_compiled_model_test.cc | 166 +++ .../litert/c/litert_dispatch_delegate.h | 84 ++ .../litert/c/litert_environment.cc | 31 + .../litert/c/litert_environment.h | 49 + tflite/experimental/litert/c/litert_event.cc | 43 + tflite/experimental/litert/c/litert_event.h | 45 + tflite/experimental/litert/c/litert_layout.h | 45 + .../experimental/litert/c/litert_logging.cc | 107 ++ tflite/experimental/litert/c/litert_logging.h | 87 ++ .../litert/c/litert_logging_test.cc | 34 + tflite/experimental/litert/c/litert_model.cc | 366 +++++ tflite/experimental/litert/c/litert_model.h | 347 +++++ .../litert/c/litert_model_test.cc | 359 +++++ tflite/experimental/litert/c/litert_op_code.h | 245 ++++ .../experimental/litert/c/litert_options.cc | 253 ++++ tflite/experimental/litert/c/litert_options.h | 174 +++ .../litert/c/litert_options_test.cc | 236 ++++ .../litert/c/litert_tensor_buffer.cc | 313 +++++ .../litert/c/litert_tensor_buffer.h | 190 +++ .../c/litert_tensor_buffer_requirements.cc | 108 ++ .../c/litert_tensor_buffer_requirements.h | 57 + .../litert_tensor_buffer_requirements_test.cc | 92 ++ .../litert/c/litert_tensor_buffer_test.cc | 296 ++++ tflite/experimental/litert/cc/BUILD | 351 +++++ tflite/experimental/litert/cc/litert_any.h | 109 ++ .../experimental/litert/cc/litert_any_test.cc | 110 ++ .../litert/cc/litert_buffer_ref.h | 356 +++++ .../litert/cc/litert_buffer_ref_test.cc | 332 +++++ .../litert/cc/litert_compiled_model.cc | 124 ++ .../litert/cc/litert_compiled_model.h | 133 ++ .../litert/cc/litert_compiled_model_test.cc | 88 ++ tflite/experimental/litert/cc/litert_detail.h | 112 ++ .../litert/cc/litert_element_type.h | 154 +++ .../litert/cc/litert_element_type_test.cc | 48 + .../litert/cc/litert_environment.h | 82 ++ .../experimental/litert/cc/litert_expected.h | 338 +++++ .../litert/cc/litert_expected_test.cc | 191 +++ tflite/experimental/litert/cc/litert_handle.h | 74 + tflite/experimental/litert/cc/litert_layout.h | 102 ++ .../litert/cc/litert_layout_test.cc | 62 + tflite/experimental/litert/cc/litert_macros.h | 67 + tflite/experimental/litert/cc/litert_model.cc | 45 + tflite/experimental/litert/cc/litert_model.h | 490 +++++++ .../litert/cc/litert_model_predicates.cc | 115 ++ .../litert/cc/litert_model_predicates.h | 78 ++ .../litert/cc/litert_model_predicates_test.cc | 207 +++ .../litert/cc/litert_model_test.cc | 338 +++++ .../litert/cc/litert_tensor_buffer.h | 227 ++++ .../cc/litert_tensor_buffer_requirements.h | 106 ++ .../litert_tensor_buffer_requirements_test.cc | 104 ++ .../litert/cc/litert_tensor_buffer_test.cc | 397 ++++++ tflite/experimental/litert/compiler/BUILD | 18 + .../experimental/litert/compiler/plugin/BUILD | 103 ++ .../litert/compiler/plugin/algo.cc | 259 ++++ .../litert/compiler/plugin/algo.h | 39 +
.../litert/compiler/plugin/algo_test.cc | 246 ++++ .../litert/compiler/plugin/compiler_plugin.cc | 409 ++++++ .../litert/compiler/plugin/compiler_plugin.h | 145 ++ .../compiler/plugin/compiler_plugin_test.cc | 160 +++ tflite/experimental/litert/core/BUILD | 141 ++ .../litert/core/byte_code_util.cc | 170 +++ .../experimental/litert/core/byte_code_util.h | 118 ++ .../litert/core/byte_code_util_test.cc | 109 ++ .../litert/core/dynamic_loading.cc | 98 ++ .../litert/core/dynamic_loading.h | 71 + .../litert/core/dynamic_loading_test.cc | 72 + .../experimental/litert/core/environment.cc | 55 + tflite/experimental/litert/core/environment.h | 62 + .../litert/core/environment_test.cc | 70 + tflite/experimental/litert/core/filesystem.cc | 97 ++ tflite/experimental/litert/core/filesystem.h | 48 + .../litert/core/filesystem_test.cc | 39 + tflite/experimental/litert/core/model/BUILD | 303 +++++ .../litert/core/model/flatbuffer_to_litert.cc | 148 ++ .../litert/core/model/flatbuffer_to_litert.h | 43 + .../core/model/flatbuffer_to_litert_test.cc | 110 ++ .../litert/core/model/graph_validation.cc | 114 ++ .../litert/core/model/graph_validation.h | 47 + .../litert/core/model/ir_allocator.h | 109 ++ .../litert/core/model/ir_allocator_test.cc | 108 ++ .../litert/core/model/litert_to_flatbuffer.cc | 126 ++ .../litert/core/model/litert_to_flatbuffer.h | 32 + .../core/model/litert_to_flatbuffer_test.cc | 108 ++ .../experimental/litert/core/model/model.cc | 136 ++ tflite/experimental/litert/core/model/model.h | 827 ++++++++++++ .../litert/core/model/model_buffer.cc | 97 ++ .../litert/core/model/model_buffer.h | 37 + .../litert/core/model/model_buffer_test.cc | 72 + .../litert/core/model/model_file_test.cc | 541 ++++++++ .../litert/core/model/model_file_test_util.cc | 181 +++ .../litert/core/model/model_file_test_util.h | 50 + .../litert/core/model/model_graph.cc | 181 +++ .../litert/core/model/model_graph.h | 105 ++ .../litert/core/model/model_graph_test.cc | 344 +++++ .../litert/core/model/model_load.cc | 321 +++++ .../litert/core/model/model_load.h | 36 + .../litert/core/model/model_serialize.cc | 272 ++++ .../litert/core/model/model_serialize.h | 46 + .../litert/core/model/model_test.cc | 280 ++++ tflite/experimental/litert/core/util/BUILD | 85 ++ .../litert/core/util/flatbuffer_tools.cc | 321 +++++ .../litert/core/util/flatbuffer_tools.h | 283 ++++ .../litert/core/util/flatbuffer_tools_test.cc | 175 +++ .../litert/core/util/tensor_type_util.cc | 59 + .../litert/core/util/tensor_type_util.h | 111 ++ .../litert/core/util/tensor_type_util_test.cc | 69 + .../litert/integration_test/BUILD | 18 + tflite/experimental/litert/runtime/BUILD | 153 +++ .../litert/runtime/ahwb_buffer.cc | 112 ++ .../experimental/litert/runtime/ahwb_buffer.h | 53 + .../litert/runtime/compiled_model.cc | 310 +++++ .../litert/runtime/compiled_model.h | 142 ++ .../litert/runtime/compiled_model_test.cc | 192 +++ .../litert/runtime/compiler/BUILD | 50 + .../compiler/jit_compilation_qualcomm_test.cc | 137 ++ .../litert/runtime/dispatch/BUILD | 180 +++ .../litert/runtime/dispatch/README.md | 20 + .../runtime/dispatch/dispatch_delegate.cc | 161 +++ .../dispatch_delegate_google_tensor_test.cc | 285 ++++ .../dispatch/dispatch_delegate_kernel.cc | 642 +++++++++ .../dispatch/dispatch_delegate_kernel.h | 115 ++ .../dispatch_delegate_mediatek_test.cc | 285 ++++ .../dispatch/dispatch_delegate_options.h | 113 ++ .../dispatch_delegate_qualcomm_test.cc | 284 ++++ .../runtime/dispatch/litert_dispatch.cc | 513 +++++++ .../litert/runtime/dmabuf_buffer.cc | 180 
+++ .../litert/runtime/dmabuf_buffer.h | 35 + tflite/experimental/litert/runtime/event.cc | 71 + tflite/experimental/litert/runtime/event.h | 31 + .../runtime/external_litert_buffer_context.cc | 125 ++ .../runtime/external_litert_buffer_context.h | 115 ++ .../litert/runtime/fastrpc_buffer.cc | 143 ++ .../litert/runtime/fastrpc_buffer.h | 35 + .../experimental/litert/runtime/ion_buffer.cc | 181 +++ .../experimental/litert/runtime/ion_buffer.h | 35 + .../litert/runtime/tensor_buffer.cc | 437 ++++++ .../litert/runtime/tensor_buffer.h | 166 +++ .../experimental/litert/runtime/tfl_utils.cc | 99 ++ .../experimental/litert/runtime/tfl_utils.h | 34 + tflite/experimental/litert/test/BUILD | 122 ++ tflite/experimental/litert/test/common.cc | 65 + tflite/experimental/litert/test/common.h | 85 ++ tflite/experimental/litert/test/test_macros.h | 46 + tflite/experimental/litert/test/test_models.h | 123 ++ .../litert/test/testdata/add_cst.mlir | 7 + .../litert/test/testdata/add_simple.mlir | 6 + .../litert/test/testdata/cos_mul.mlir | 7 + .../test/testdata/dynamic_shape_tensor.mlir | 6 + .../test/testdata/fully_connected_3d.mlir | 6 + .../litert/test/testdata/mul_simple.mlir | 7 + .../litert/test/testdata/multi_subgraph.mlir | 21 + .../test/testdata/multi_subgraph_mul.mlir | 13 + .../litert/test/testdata/one_mul.mlir | 6 + .../litert/test/testdata/rms_norm.mlir | 16 + .../litert/test/testdata/simple_add_op.mlir | 6 + .../test/testdata/simple_batch_matmul_op.mlir | 6 + .../litert/test/testdata/simple_cast_op.mlir | 6 + .../testdata/simple_concatenation_op.mlir | 6 + .../litert/test/testdata/simple_cos_op.mlir | 6 + .../litert/test/testdata/simple_div_op.mlir | 6 + .../testdata/simple_embedding_lookup_op.mlir | 7 + .../test/testdata/simple_floor_mod_op.mlir | 6 + .../testdata/simple_fully_connected_op.mlir | 6 + .../test/testdata/simple_greater_op.mlir | 6 + .../litert/test/testdata/simple_less_op.mlir | 6 + .../test/testdata/simple_logical_and_op.mlir | 6 + .../litert/test/testdata/simple_model.mlir | 6 + .../testdata/simple_model_google_tensor.bin | Bin 0 -> 12288 bytes .../litert/test/testdata/simple_model_mtk.bin | Bin 0 -> 6956 bytes .../test/testdata/simple_model_npu.mlir | 6 + .../test/testdata/simple_model_qualcomm.bin | Bin 0 -> 13800 bytes .../test/testdata/simple_model_test_vectors.h | 67 + .../litert/test/testdata/simple_mul_op.mlir | 6 + .../litert/test/testdata/simple_multi_op.mlir | 9 + .../test/testdata/simple_reshape_op.mlir | 6 + .../litert/test/testdata/simple_rsqrt_op.mlir | 6 + .../test/testdata/simple_select_op.mlir | 6 + .../test/testdata/simple_select_v2_op.mlir | 6 + .../litert/test/testdata/simple_sin_op.mlir | 6 + .../litert/test/testdata/simple_slice_op.mlir | 8 + .../test/testdata/simple_softmax_op.mlir | 6 + .../testdata/simple_stablehlo_scatter_op.mlir | 9 + .../testdata/simple_strided_slice_op.mlir | 6 + .../litert/test/testdata/simple_sub_op.mlir | 6 + .../litert/test/testdata/simple_sum_op.mlir | 7 + .../litert/test/testdata/simple_tanh_op.mlir | 6 + .../test/testdata/simple_transpose_op.mlir | 7 + .../litert/test/testdata/two_partition.mlir | 9 + .../litert/test/testdata/unranked_tensor.mlir | 6 + tflite/experimental/litert/tools/BUILD | 189 +++ .../experimental/litert/tools/apply_plugin.cc | 700 ++++++++++ .../experimental/litert/tools/apply_plugin.h | 177 +++ .../litert/tools/apply_plugin_main.cc | 140 ++ .../litert/tools/apply_plugin_test.cc | 227 ++++ tflite/experimental/litert/tools/dump.cc | 436 ++++++ tflite/experimental/litert/tools/dump.h | 71 + 
tflite/experimental/litert/tools/dump_test.cc | 131 ++ tflite/experimental/litert/tools/outstream.h | 83 ++ .../experimental/litert/tools/tool_display.cc | 86 ++ .../experimental/litert/tools/tool_display.h | 102 ++ .../litert/tools/tool_display_test.cc | 99 ++ tflite/experimental/litert/vendors/c/BUILD | 68 + .../litert/vendors/c/litert_compiler_plugin.h | 102 ++ .../vendors/c/litert_compiler_plugin_api.h | 130 ++ .../litert/vendors/c/litert_dispatch.h | 275 ++++ .../litert/vendors/c/litert_dispatch_api.h | 222 +++ .../c/litert_vendor_c_api_common_test.c | 28 + tflite/experimental/litert/vendors/cc/BUILD | 27 + .../vendors/cc/litert_compiler_plugin.h | 46 + .../litert/vendors/examples/BUILD | 59 + .../litert/vendors/examples/example_plugin.cc | 194 +++ .../vendors/examples/example_plugin_test.cc | 97 ++ .../vendors/google_tensor/dispatch/BUILD | 85 ++ .../google_tensor/dispatch/dispatch_api.cc | 1194 +++++++++++++++++ .../google_tensor/dispatch/dispatch_api.h | 67 + .../dispatch_api_google_tensor_test.cc | 282 ++++ .../litert_dispatch_device_context.cc | 61 + .../dispatch/litert_dispatch_device_context.h | 49 + .../dispatch/litert_dispatch_graph.h | 94 ++ .../litert_dispatch_invocation_context.cc | 84 ++ .../litert_dispatch_invocation_context.h | 66 + .../google_tensor/dispatch/southbound.cc | 146 ++ .../google_tensor/dispatch/southbound.h | 128 ++ .../litert/vendors/mediatek/BUILD | 34 + .../litert/vendors/mediatek/dispatch/BUILD | 85 ++ .../vendors/mediatek/dispatch/README.md | 4 + .../vendors/mediatek/dispatch/dispatch_api.cc | 327 +++++ .../dispatch/dispatch_api_mediatek_test.cc | 331 +++++ .../litert_dispatch_device_context.cc | 137 ++ .../dispatch/litert_dispatch_device_context.h | 87 ++ .../litert_dispatch_invocation_context.cc | 422 ++++++ .../litert_dispatch_invocation_context.h | 96 ++ .../litert/vendors/mediatek/neuron_adapter.cc | 130 ++ .../litert/vendors/mediatek/neuron_adapter.h | 219 +++ .../litert/vendors/qualcomm/BUILD | 136 ++ .../litert/vendors/qualcomm/common.h | 100 ++ .../litert/vendors/qualcomm/compiler/BUILD | 170 +++ .../litert/vendors/qualcomm/compiler/IR/BUILD | 123 ++ .../compiler/IR/op_compatibility_test.cc | 82 ++ .../vendors/qualcomm/compiler/IR/qnn_op.cc | 147 ++ .../vendors/qualcomm/compiler/IR/qnn_op.h | 53 + .../qualcomm/compiler/IR/qnn_op_test.cc | 65 + .../qualcomm/compiler/IR/qnn_tensor.cc | 246 ++++ .../vendors/qualcomm/compiler/IR/qnn_tensor.h | 73 + .../qualcomm/compiler/IR/qnn_tensor_test.cc | 203 +++ .../vendors/qualcomm/compiler/graph_mapper.cc | 163 +++ .../vendors/qualcomm/compiler/graph_mapper.h | 121 ++ .../qualcomm/compiler/legalizations/BUILD | 789 +++++++++++ .../legalizations/add_op_legalization.cc | 51 + .../legalizations/add_op_legalization.h | 49 + .../batch_matmul_op_legalization.cc | 52 + .../batch_matmul_op_legalization.h | 49 + .../legalizations/cast_op_legalization.cc | 49 + .../legalizations/cast_op_legalization.h | 49 + .../concatenation_op_legalization.cc | 101 ++ .../concatenation_op_legalization.h | 51 + .../legalizations/cos_op_legalization.cc | 49 + .../legalizations/cos_op_legalization.h | 49 + .../legalizations/div_op_legalization.cc | 51 + .../legalizations/div_op_legalization.h | 49 + .../embedding_lookup_op_legalization.cc | 104 ++ .../embedding_lookup_op_legalization.h | 51 + .../fully_connected_op_legalization.cc | 52 + .../fully_connected_op_legalization.h | 51 + .../legalizations/greater_op_legalization.cc | 51 + .../legalizations/greater_op_legalization.h | 49 + .../compiler/legalizations/legalization.h | 50 + 
.../legalizations/less_op_legalization.cc | 51 + .../legalizations/less_op_legalization.h | 49 + .../logical_and_op_legalization.cc | 51 + .../logical_and_op_legalization.h | 51 + .../legalizations/mul_op_legalization.cc | 51 + .../legalizations/mul_op_legalization.h | 49 + .../legalizations/reshape_op_legalization.cc | 82 ++ .../legalizations/reshape_op_legalization.h | 49 + .../legalizations/rsqrt_op_legalization.cc | 51 + .../legalizations/rsqrt_op_legalization.h | 49 + .../legalizations/select_op_legalization.cc | 55 + .../legalizations/select_op_legalization.h | 49 + .../legalizations/sin_op_legalization.cc | 49 + .../legalizations/sin_op_legalization.h | 49 + .../legalizations/slice_op_legalization.cc | 153 +++ .../legalizations/slice_op_legalization.h | 47 + .../legalizations/softmax_op_legalization.cc | 100 ++ .../legalizations/softmax_op_legalization.h | 49 + .../legalizations/sub_op_legalization.cc | 51 + .../legalizations/sub_op_legalization.h | 49 + .../legalizations/sum_op_legalization.cc | 139 ++ .../legalizations/sum_op_legalization.h | 49 + .../legalizations/tanh_op_legalization.cc | 51 + .../legalizations/tanh_op_legalization.h | 49 + .../transpose_op_legalization.cc | 121 ++ .../legalizations/transpose_op_legalization.h | 49 + .../qualcomm/compiler/legalizations/util.cc | 87 ++ .../qualcomm/compiler/legalizations/util.h | 39 + .../qualcomm/compiler/qnn_compiler_plugin.cc | 291 ++++ .../compiler/qnn_compiler_plugin_test.cc | 190 +++ .../qualcomm/compiler/qnn_compose_graph.cc | 173 +++ .../qualcomm/compiler/qnn_compose_graph.h | 33 + .../vendors/qualcomm/context_binary_info.cc | 216 +++ .../vendors/qualcomm/context_binary_info.h | 68 + .../litert/vendors/qualcomm/dispatch/BUILD | 93 ++ .../vendors/qualcomm/dispatch/dispatch_api.cc | 296 ++++ .../dispatch/dispatch_api_qualcomm_test.cc | 532 ++++++++ .../litert_dispatch_device_context.cc | 190 +++ .../dispatch/litert_dispatch_device_context.h | 79 ++ .../litert_dispatch_invocation_context.cc | 238 ++++ .../litert_dispatch_invocation_context.h | 82 ++ .../vendors/qualcomm/dispatch/registry.h | 73 + .../litert/vendors/qualcomm/qnn_log.cc | 64 + .../litert/vendors/qualcomm/qnn_log.h | 28 + .../litert/vendors/qualcomm/qnn_manager.cc | 387 ++++++ .../litert/vendors/qualcomm/qnn_manager.h | 226 ++++ .../vendors/qualcomm/qnn_manager_test.cc | 50 + .../litert/vendors/qualcomm/qnn_tensor.cc | 104 ++ .../litert/vendors/qualcomm/qnn_tensor.h | 60 + .../vendors/qualcomm/qualcomm_build_defs.bzl | 118 ++ .../litert/vendors/qualcomm/tools/BUILD | 31 + .../litert/vendors/qualcomm/tools/dump.cc | 88 ++ .../litert/vendors/qualcomm/tools/dump.h | 29 + 331 files changed, 41202 insertions(+) create mode 100644 tflite/experimental/litert/BUILD create mode 100644 tflite/experimental/litert/build_common/BUILD create mode 100644 tflite/experimental/litert/build_common/export_litert_only.lds create mode 100644 tflite/experimental/litert/build_common/litert_build_defs.bzl create mode 100644 tflite/experimental/litert/c/BUILD create mode 100644 tflite/experimental/litert/c/litert_any.h create mode 100644 tflite/experimental/litert/c/litert_c_api_common_test.c create mode 100644 tflite/experimental/litert/c/litert_common.h create mode 100644 tflite/experimental/litert/c/litert_compiled_model.cc create mode 100644 tflite/experimental/litert/c/litert_compiled_model.h create mode 100644 tflite/experimental/litert/c/litert_compiled_model_options.h create mode 100644 tflite/experimental/litert/c/litert_compiled_model_test.cc create mode 100644 
tflite/experimental/litert/c/litert_dispatch_delegate.h create mode 100644 tflite/experimental/litert/c/litert_environment.cc create mode 100644 tflite/experimental/litert/c/litert_environment.h create mode 100644 tflite/experimental/litert/c/litert_event.cc create mode 100644 tflite/experimental/litert/c/litert_event.h create mode 100644 tflite/experimental/litert/c/litert_layout.h create mode 100644 tflite/experimental/litert/c/litert_logging.cc create mode 100644 tflite/experimental/litert/c/litert_logging.h create mode 100644 tflite/experimental/litert/c/litert_logging_test.cc create mode 100644 tflite/experimental/litert/c/litert_model.cc create mode 100644 tflite/experimental/litert/c/litert_model.h create mode 100644 tflite/experimental/litert/c/litert_model_test.cc create mode 100644 tflite/experimental/litert/c/litert_op_code.h create mode 100644 tflite/experimental/litert/c/litert_options.cc create mode 100644 tflite/experimental/litert/c/litert_options.h create mode 100644 tflite/experimental/litert/c/litert_options_test.cc create mode 100644 tflite/experimental/litert/c/litert_tensor_buffer.cc create mode 100644 tflite/experimental/litert/c/litert_tensor_buffer.h create mode 100644 tflite/experimental/litert/c/litert_tensor_buffer_requirements.cc create mode 100644 tflite/experimental/litert/c/litert_tensor_buffer_requirements.h create mode 100644 tflite/experimental/litert/c/litert_tensor_buffer_requirements_test.cc create mode 100644 tflite/experimental/litert/c/litert_tensor_buffer_test.cc create mode 100644 tflite/experimental/litert/cc/BUILD create mode 100644 tflite/experimental/litert/cc/litert_any.h create mode 100644 tflite/experimental/litert/cc/litert_any_test.cc create mode 100644 tflite/experimental/litert/cc/litert_buffer_ref.h create mode 100644 tflite/experimental/litert/cc/litert_buffer_ref_test.cc create mode 100644 tflite/experimental/litert/cc/litert_compiled_model.cc create mode 100644 tflite/experimental/litert/cc/litert_compiled_model.h create mode 100644 tflite/experimental/litert/cc/litert_compiled_model_test.cc create mode 100644 tflite/experimental/litert/cc/litert_detail.h create mode 100644 tflite/experimental/litert/cc/litert_element_type.h create mode 100644 tflite/experimental/litert/cc/litert_element_type_test.cc create mode 100644 tflite/experimental/litert/cc/litert_environment.h create mode 100644 tflite/experimental/litert/cc/litert_expected.h create mode 100644 tflite/experimental/litert/cc/litert_expected_test.cc create mode 100644 tflite/experimental/litert/cc/litert_handle.h create mode 100644 tflite/experimental/litert/cc/litert_layout.h create mode 100644 tflite/experimental/litert/cc/litert_layout_test.cc create mode 100644 tflite/experimental/litert/cc/litert_macros.h create mode 100644 tflite/experimental/litert/cc/litert_model.cc create mode 100644 tflite/experimental/litert/cc/litert_model.h create mode 100644 tflite/experimental/litert/cc/litert_model_predicates.cc create mode 100644 tflite/experimental/litert/cc/litert_model_predicates.h create mode 100644 tflite/experimental/litert/cc/litert_model_predicates_test.cc create mode 100644 tflite/experimental/litert/cc/litert_model_test.cc create mode 100644 tflite/experimental/litert/cc/litert_tensor_buffer.h create mode 100644 tflite/experimental/litert/cc/litert_tensor_buffer_requirements.h create mode 100644 tflite/experimental/litert/cc/litert_tensor_buffer_requirements_test.cc create mode 100644 tflite/experimental/litert/cc/litert_tensor_buffer_test.cc create mode 100644 
tflite/experimental/litert/compiler/BUILD create mode 100644 tflite/experimental/litert/compiler/plugin/BUILD create mode 100644 tflite/experimental/litert/compiler/plugin/algo.cc create mode 100644 tflite/experimental/litert/compiler/plugin/algo.h create mode 100644 tflite/experimental/litert/compiler/plugin/algo_test.cc create mode 100644 tflite/experimental/litert/compiler/plugin/compiler_plugin.cc create mode 100644 tflite/experimental/litert/compiler/plugin/compiler_plugin.h create mode 100644 tflite/experimental/litert/compiler/plugin/compiler_plugin_test.cc create mode 100644 tflite/experimental/litert/core/BUILD create mode 100644 tflite/experimental/litert/core/byte_code_util.cc create mode 100644 tflite/experimental/litert/core/byte_code_util.h create mode 100644 tflite/experimental/litert/core/byte_code_util_test.cc create mode 100644 tflite/experimental/litert/core/dynamic_loading.cc create mode 100644 tflite/experimental/litert/core/dynamic_loading.h create mode 100644 tflite/experimental/litert/core/dynamic_loading_test.cc create mode 100644 tflite/experimental/litert/core/environment.cc create mode 100644 tflite/experimental/litert/core/environment.h create mode 100644 tflite/experimental/litert/core/environment_test.cc create mode 100644 tflite/experimental/litert/core/filesystem.cc create mode 100644 tflite/experimental/litert/core/filesystem.h create mode 100644 tflite/experimental/litert/core/filesystem_test.cc create mode 100644 tflite/experimental/litert/core/model/BUILD create mode 100644 tflite/experimental/litert/core/model/flatbuffer_to_litert.cc create mode 100644 tflite/experimental/litert/core/model/flatbuffer_to_litert.h create mode 100644 tflite/experimental/litert/core/model/flatbuffer_to_litert_test.cc create mode 100644 tflite/experimental/litert/core/model/graph_validation.cc create mode 100644 tflite/experimental/litert/core/model/graph_validation.h create mode 100644 tflite/experimental/litert/core/model/ir_allocator.h create mode 100644 tflite/experimental/litert/core/model/ir_allocator_test.cc create mode 100644 tflite/experimental/litert/core/model/litert_to_flatbuffer.cc create mode 100644 tflite/experimental/litert/core/model/litert_to_flatbuffer.h create mode 100644 tflite/experimental/litert/core/model/litert_to_flatbuffer_test.cc create mode 100644 tflite/experimental/litert/core/model/model.cc create mode 100644 tflite/experimental/litert/core/model/model.h create mode 100644 tflite/experimental/litert/core/model/model_buffer.cc create mode 100644 tflite/experimental/litert/core/model/model_buffer.h create mode 100644 tflite/experimental/litert/core/model/model_buffer_test.cc create mode 100644 tflite/experimental/litert/core/model/model_file_test.cc create mode 100644 tflite/experimental/litert/core/model/model_file_test_util.cc create mode 100644 tflite/experimental/litert/core/model/model_file_test_util.h create mode 100644 tflite/experimental/litert/core/model/model_graph.cc create mode 100644 tflite/experimental/litert/core/model/model_graph.h create mode 100644 tflite/experimental/litert/core/model/model_graph_test.cc create mode 100644 tflite/experimental/litert/core/model/model_load.cc create mode 100644 tflite/experimental/litert/core/model/model_load.h create mode 100644 tflite/experimental/litert/core/model/model_serialize.cc create mode 100644 tflite/experimental/litert/core/model/model_serialize.h create mode 100644 tflite/experimental/litert/core/model/model_test.cc create mode 100644 tflite/experimental/litert/core/util/BUILD 
create mode 100644 tflite/experimental/litert/core/util/flatbuffer_tools.cc create mode 100644 tflite/experimental/litert/core/util/flatbuffer_tools.h create mode 100644 tflite/experimental/litert/core/util/flatbuffer_tools_test.cc create mode 100644 tflite/experimental/litert/core/util/tensor_type_util.cc create mode 100644 tflite/experimental/litert/core/util/tensor_type_util.h create mode 100644 tflite/experimental/litert/core/util/tensor_type_util_test.cc create mode 100644 tflite/experimental/litert/integration_test/BUILD create mode 100644 tflite/experimental/litert/runtime/BUILD create mode 100644 tflite/experimental/litert/runtime/ahwb_buffer.cc create mode 100644 tflite/experimental/litert/runtime/ahwb_buffer.h create mode 100644 tflite/experimental/litert/runtime/compiled_model.cc create mode 100644 tflite/experimental/litert/runtime/compiled_model.h create mode 100644 tflite/experimental/litert/runtime/compiled_model_test.cc create mode 100644 tflite/experimental/litert/runtime/compiler/BUILD create mode 100644 tflite/experimental/litert/runtime/compiler/jit_compilation_qualcomm_test.cc create mode 100644 tflite/experimental/litert/runtime/dispatch/BUILD create mode 100644 tflite/experimental/litert/runtime/dispatch/README.md create mode 100644 tflite/experimental/litert/runtime/dispatch/dispatch_delegate.cc create mode 100644 tflite/experimental/litert/runtime/dispatch/dispatch_delegate_google_tensor_test.cc create mode 100644 tflite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.cc create mode 100644 tflite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.h create mode 100644 tflite/experimental/litert/runtime/dispatch/dispatch_delegate_mediatek_test.cc create mode 100644 tflite/experimental/litert/runtime/dispatch/dispatch_delegate_options.h create mode 100644 tflite/experimental/litert/runtime/dispatch/dispatch_delegate_qualcomm_test.cc create mode 100644 tflite/experimental/litert/runtime/dispatch/litert_dispatch.cc create mode 100644 tflite/experimental/litert/runtime/dmabuf_buffer.cc create mode 100644 tflite/experimental/litert/runtime/dmabuf_buffer.h create mode 100644 tflite/experimental/litert/runtime/event.cc create mode 100644 tflite/experimental/litert/runtime/event.h create mode 100644 tflite/experimental/litert/runtime/external_litert_buffer_context.cc create mode 100644 tflite/experimental/litert/runtime/external_litert_buffer_context.h create mode 100644 tflite/experimental/litert/runtime/fastrpc_buffer.cc create mode 100644 tflite/experimental/litert/runtime/fastrpc_buffer.h create mode 100644 tflite/experimental/litert/runtime/ion_buffer.cc create mode 100644 tflite/experimental/litert/runtime/ion_buffer.h create mode 100644 tflite/experimental/litert/runtime/tensor_buffer.cc create mode 100644 tflite/experimental/litert/runtime/tensor_buffer.h create mode 100644 tflite/experimental/litert/runtime/tfl_utils.cc create mode 100644 tflite/experimental/litert/runtime/tfl_utils.h create mode 100644 tflite/experimental/litert/test/BUILD create mode 100644 tflite/experimental/litert/test/common.cc create mode 100644 tflite/experimental/litert/test/common.h create mode 100644 tflite/experimental/litert/test/test_macros.h create mode 100644 tflite/experimental/litert/test/test_models.h create mode 100644 tflite/experimental/litert/test/testdata/add_cst.mlir create mode 100644 tflite/experimental/litert/test/testdata/add_simple.mlir create mode 100644 tflite/experimental/litert/test/testdata/cos_mul.mlir create mode 100644 
tflite/experimental/litert/test/testdata/dynamic_shape_tensor.mlir create mode 100644 tflite/experimental/litert/test/testdata/fully_connected_3d.mlir create mode 100644 tflite/experimental/litert/test/testdata/mul_simple.mlir create mode 100644 tflite/experimental/litert/test/testdata/multi_subgraph.mlir create mode 100644 tflite/experimental/litert/test/testdata/multi_subgraph_mul.mlir create mode 100644 tflite/experimental/litert/test/testdata/one_mul.mlir create mode 100644 tflite/experimental/litert/test/testdata/rms_norm.mlir create mode 100644 tflite/experimental/litert/test/testdata/simple_add_op.mlir create mode 100644 tflite/experimental/litert/test/testdata/simple_batch_matmul_op.mlir create mode 100644 tflite/experimental/litert/test/testdata/simple_cast_op.mlir create mode 100644 tflite/experimental/litert/test/testdata/simple_concatenation_op.mlir create mode 100644 tflite/experimental/litert/test/testdata/simple_cos_op.mlir create mode 100644 tflite/experimental/litert/test/testdata/simple_div_op.mlir create mode 100644 tflite/experimental/litert/test/testdata/simple_embedding_lookup_op.mlir create mode 100644 tflite/experimental/litert/test/testdata/simple_floor_mod_op.mlir create mode 100644 tflite/experimental/litert/test/testdata/simple_fully_connected_op.mlir create mode 100644 tflite/experimental/litert/test/testdata/simple_greater_op.mlir create mode 100644 tflite/experimental/litert/test/testdata/simple_less_op.mlir create mode 100644 tflite/experimental/litert/test/testdata/simple_logical_and_op.mlir create mode 100644 tflite/experimental/litert/test/testdata/simple_model.mlir create mode 100644 tflite/experimental/litert/test/testdata/simple_model_google_tensor.bin create mode 100644 tflite/experimental/litert/test/testdata/simple_model_mtk.bin create mode 100644 tflite/experimental/litert/test/testdata/simple_model_npu.mlir create mode 100644 tflite/experimental/litert/test/testdata/simple_model_qualcomm.bin create mode 100644 tflite/experimental/litert/test/testdata/simple_model_test_vectors.h create mode 100644 tflite/experimental/litert/test/testdata/simple_mul_op.mlir create mode 100644 tflite/experimental/litert/test/testdata/simple_multi_op.mlir create mode 100644 tflite/experimental/litert/test/testdata/simple_reshape_op.mlir create mode 100644 tflite/experimental/litert/test/testdata/simple_rsqrt_op.mlir create mode 100644 tflite/experimental/litert/test/testdata/simple_select_op.mlir create mode 100644 tflite/experimental/litert/test/testdata/simple_select_v2_op.mlir create mode 100644 tflite/experimental/litert/test/testdata/simple_sin_op.mlir create mode 100644 tflite/experimental/litert/test/testdata/simple_slice_op.mlir create mode 100644 tflite/experimental/litert/test/testdata/simple_softmax_op.mlir create mode 100644 tflite/experimental/litert/test/testdata/simple_stablehlo_scatter_op.mlir create mode 100644 tflite/experimental/litert/test/testdata/simple_strided_slice_op.mlir create mode 100644 tflite/experimental/litert/test/testdata/simple_sub_op.mlir create mode 100644 tflite/experimental/litert/test/testdata/simple_sum_op.mlir create mode 100644 tflite/experimental/litert/test/testdata/simple_tanh_op.mlir create mode 100644 tflite/experimental/litert/test/testdata/simple_transpose_op.mlir create mode 100644 tflite/experimental/litert/test/testdata/two_partition.mlir create mode 100644 tflite/experimental/litert/test/testdata/unranked_tensor.mlir create mode 100644 tflite/experimental/litert/tools/BUILD create mode 100644 
tflite/experimental/litert/tools/apply_plugin.cc create mode 100644 tflite/experimental/litert/tools/apply_plugin.h create mode 100644 tflite/experimental/litert/tools/apply_plugin_main.cc create mode 100644 tflite/experimental/litert/tools/apply_plugin_test.cc create mode 100644 tflite/experimental/litert/tools/dump.cc create mode 100644 tflite/experimental/litert/tools/dump.h create mode 100644 tflite/experimental/litert/tools/dump_test.cc create mode 100644 tflite/experimental/litert/tools/outstream.h create mode 100644 tflite/experimental/litert/tools/tool_display.cc create mode 100644 tflite/experimental/litert/tools/tool_display.h create mode 100644 tflite/experimental/litert/tools/tool_display_test.cc create mode 100644 tflite/experimental/litert/vendors/c/BUILD create mode 100644 tflite/experimental/litert/vendors/c/litert_compiler_plugin.h create mode 100644 tflite/experimental/litert/vendors/c/litert_compiler_plugin_api.h create mode 100644 tflite/experimental/litert/vendors/c/litert_dispatch.h create mode 100644 tflite/experimental/litert/vendors/c/litert_dispatch_api.h create mode 100644 tflite/experimental/litert/vendors/c/litert_vendor_c_api_common_test.c create mode 100644 tflite/experimental/litert/vendors/cc/BUILD create mode 100644 tflite/experimental/litert/vendors/cc/litert_compiler_plugin.h create mode 100644 tflite/experimental/litert/vendors/examples/BUILD create mode 100644 tflite/experimental/litert/vendors/examples/example_plugin.cc create mode 100644 tflite/experimental/litert/vendors/examples/example_plugin_test.cc create mode 100644 tflite/experimental/litert/vendors/google_tensor/dispatch/BUILD create mode 100644 tflite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.cc create mode 100644 tflite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.h create mode 100644 tflite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api_google_tensor_test.cc create mode 100644 tflite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.cc create mode 100644 tflite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.h create mode 100644 tflite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.h create mode 100644 tflite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.cc create mode 100644 tflite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.h create mode 100644 tflite/experimental/litert/vendors/google_tensor/dispatch/southbound.cc create mode 100644 tflite/experimental/litert/vendors/google_tensor/dispatch/southbound.h create mode 100644 tflite/experimental/litert/vendors/mediatek/BUILD create mode 100644 tflite/experimental/litert/vendors/mediatek/dispatch/BUILD create mode 100644 tflite/experimental/litert/vendors/mediatek/dispatch/README.md create mode 100644 tflite/experimental/litert/vendors/mediatek/dispatch/dispatch_api.cc create mode 100644 tflite/experimental/litert/vendors/mediatek/dispatch/dispatch_api_mediatek_test.cc create mode 100644 tflite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.cc create mode 100644 tflite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.h create mode 100644 tflite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.cc create mode 100644 tflite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.h create mode 100644 
tflite/experimental/litert/vendors/mediatek/neuron_adapter.cc create mode 100644 tflite/experimental/litert/vendors/mediatek/neuron_adapter.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/BUILD create mode 100644 tflite/experimental/litert/vendors/qualcomm/common.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/BUILD create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/IR/BUILD create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/IR/op_compatibility_test.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op_test.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor_test.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/BUILD create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/cast_op_legalization.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/cast_op_legalization.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/concatenation_op_legalization.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/concatenation_op_legalization.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/cos_op_legalization.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/cos_op_legalization.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/embedding_lookup_op_legalization.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/embedding_lookup_op_legalization.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/fully_connected_op_legalization.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/fully_connected_op_legalization.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/greater_op_legalization.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/greater_op_legalization.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h create mode 100644 
tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/less_op_legalization.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/less_op_legalization.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/logical_and_op_legalization.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/logical_and_op_legalization.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/select_op_legalization.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/select_op_legalization.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sin_op_legalization.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sin_op_legalization.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/softmax_op_legalization.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/softmax_op_legalization.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/transpose_op_legalization.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/transpose_op_legalization.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin_test.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/context_binary_info.cc create mode 100644 
tflite/experimental/litert/vendors/qualcomm/context_binary_info.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/dispatch/BUILD create mode 100644 tflite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api_qualcomm_test.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/dispatch/registry.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/qnn_log.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/qnn_log.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/qnn_manager.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/qnn_manager.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/qnn_manager_test.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/qnn_tensor.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/qnn_tensor.h create mode 100644 tflite/experimental/litert/vendors/qualcomm/qualcomm_build_defs.bzl create mode 100644 tflite/experimental/litert/vendors/qualcomm/tools/BUILD create mode 100644 tflite/experimental/litert/vendors/qualcomm/tools/dump.cc create mode 100644 tflite/experimental/litert/vendors/qualcomm/tools/dump.h
diff --git a/tflite/experimental/litert/BUILD b/tflite/experimental/litert/BUILD
new file mode 100644
index 00000000..e3809c5e
--- /dev/null
+++ b/tflite/experimental/litert/BUILD
@@ -0,0 +1,18 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+    # copybara:uncomment default_applicable_licenses = ["@org_tensorflow//tensorflow:license"],
+    default_visibility = ["//tflite/experimental/litert:__subpackages__"],
+)
diff --git a/tflite/experimental/litert/build_common/BUILD b/tflite/experimental/litert/build_common/BUILD
new file mode 100644
index 00000000..d8cace19
--- /dev/null
+++ b/tflite/experimental/litert/build_common/BUILD
@@ -0,0 +1,20 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+    # copybara:uncomment default_applicable_licenses = ["@org_tensorflow//tensorflow:license"],
+    default_visibility = ["//tflite/experimental/litert:__subpackages__"],
+)
+
+exports_files(srcs = ["export_litert_only.lds"])
diff --git a/tflite/experimental/litert/build_common/export_litert_only.lds b/tflite/experimental/litert/build_common/export_litert_only.lds
new file mode 100644
index 00000000..97b05c1d
--- /dev/null
+++ b/tflite/experimental/litert/build_common/export_litert_only.lds
@@ -0,0 +1,29 @@
+VERS_1.0 {
+
+  /*
+  Export abi-stable "vendor" implemented symbols.
+
+  TODO: Add all vendor symbols. Also export qnn libc++ symbols
+  (statically linked) as "protected" as needed.
+  */
+
+  global:
+
+    /* Compiler Plugin */
+
+    LiteRt*CompilerPlugin*;
+
+    /* Compiled Result */
+
+    LiteRt*CompiledResult*;
+
+    /* Dispatch */
+
+    LiteRtDispatch*;
+
+  local:
+
+    /* Hide everything else */
+
+    *;
+};
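The version script above keeps only symbols matching the three `LiteRt*` globs in the dynamic symbol table and hides everything else via `local: *`. A minimal standalone sketch of the runtime effect, assuming a hypothetical vendor library and hypothetical symbol names (none of these identifiers come from this patch):

```c
#include <dlfcn.h>
#include <stdio.h>

// Sketch: when a vendor .so is linked with export_litert_only.lds, only the
// LiteRt* entry points remain resolvable via dlsym(); everything caught by
// "local: *" is hidden. Library and symbol names below are illustrative.
int main(void) {
  void* lib = dlopen("libLiteRtVendorExample.so", RTLD_NOW | RTLD_LOCAL);
  if (lib == NULL) {
    fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return 1;
  }
  // Matches the "LiteRtDispatch*" glob, so it stays exported.
  void* exported = dlsym(lib, "LiteRtDispatchExampleEntryPoint");
  // Matches no "global:" pattern, so "local: *" hides it and dlsym fails.
  void* hidden = dlsym(lib, "SomeInternalHelper");
  printf("exported=%p hidden=%p\n", exported, hidden);  // hidden == NULL
  dlclose(lib);
  return 0;
}
```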
diff --git a/tflite/experimental/litert/build_common/litert_build_defs.bzl b/tflite/experimental/litert/build_common/litert_build_defs.bzl
new file mode 100644
index 00000000..1a12fd62
--- /dev/null
+++ b/tflite/experimental/litert/build_common/litert_build_defs.bzl
@@ -0,0 +1,219 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Common LiteRT Build Utilities."""
+
+####################################################################################################
+# Util
+
+_LRT_SO_PREFIX = "libLiteRt"
+_SO_EXT = ".so"
+_SHARED_LIB_SUFFIX = "_so"
+
+# Public
+
+def make_linkopt(opt):
+    return "-Wl,{}".format(opt)
+
+def make_rpaths(rpaths):
+    return make_linkopt("-rpath={}".format(":".join(rpaths)))
+
+def append_rule_kwargs(rule_kwargs, **append):
+    for k, v in append.items():
+        append_to = rule_kwargs.pop(k, [])
+        append_to += v
+        rule_kwargs[k] = append_to
+
+# Private
+
+def _valid_shared_lib_name(name):
+    return name.endswith(_SHARED_LIB_SUFFIX)
+
+def _valid_so_name(name):
+    return name.startswith(_LRT_SO_PREFIX) and name.endswith(_SO_EXT)
+
+def _make_target_ref(name):
+    return ":{}".format(name)
+
+def _make_script_linkopt(script):
+    return make_linkopt("--version-script=$(location {})".format(script))
+
+####################################################################################################
+# Explicitly Link System Libraries ("ungrte")
+
+_SYS_RPATHS_X86_64 = [
+    "/usr/lib/x86_64-linux-gnu",
+    "/lib/x86_64-linux-gnu",
+]
+_SYS_RPATHS_LINKOPT_X86_64 = make_rpaths(_SYS_RPATHS_X86_64)
+
+_SYS_ELF_INTERPRETER_X86_64 = "/lib64/ld-linux-x86-64.so.2"
+_SYS_ELF_INTERPRETER_LINKOPT_X86_64 = make_linkopt("--dynamic-linker={}".format(_SYS_ELF_INTERPRETER_X86_64))
+
+####################################################################################################
+# Symbol Hiding
+
+_EXPORT_LRT_ONLY_SCRIPT = "//tflite/experimental/litert/build_common:export_litert_only.lds"
+_EXPORT_LRT_ONLY_LINKOPT = _make_script_linkopt(_EXPORT_LRT_ONLY_SCRIPT)
+
+####################################################################################################
+# Macros
+
+# Private
+
+def _litert_base(
+        rule,
+        ungrte = False,
+        **cc_rule_kwargs):
+    """
+    Base rule for LiteRT targets.
+
+    Args:
+        rule: The underlying rule to use (e.g., cc_test, cc_library).
+        ungrte: Whether to link against system libraries ("ungrte").
+        **cc_rule_kwargs: Keyword arguments to pass to the underlying rule.
+    """
+    if ungrte:
+        append_rule_kwargs(
+            cc_rule_kwargs,
+            linkopts = select({
+                "@org_tensorflow//tensorflow:linux_x86_64": [_SYS_ELF_INTERPRETER_LINKOPT_X86_64, _SYS_RPATHS_LINKOPT_X86_64],
+                "//conditions:default": [],
+            }),
+        )
+    rule(**cc_rule_kwargs)
+
+# Public
+
+def litert_test(
+        ungrte = False,
+        use_sys_malloc = False,
+        **cc_test_kwargs):
+    """
+    LiteRT test rule.
+
+    Args:
+        ungrte: Whether to link against system libraries ("ungrte").
+        use_sys_malloc: Whether to use the system malloc.
+        **cc_test_kwargs: Keyword arguments to pass to the underlying rule.
+    """
+    if use_sys_malloc:
+        # copybara:uncomment cc_test_kwargs["malloc"] = "//base:system_malloc"
+        pass
+
+    append_rule_kwargs(
+        cc_test_kwargs,
+        deps = ["@com_google_googletest//:gtest_main"],
+    )
+
+    _litert_base(
+        native.cc_test,
+        ungrte,
+        **cc_test_kwargs
+    )
+
+def litert_lib(
+        ungrte = False,
+        **cc_lib_kwargs):
+    """
+    LiteRT library rule.
+
+    Args:
+        ungrte: Whether to link against system libraries ("ungrte").
+        **cc_lib_kwargs: Keyword arguments to pass to the underlying rule.
+    """
+    _litert_base(
+        native.cc_library,
+        ungrte,
+        **cc_lib_kwargs
+    )
+
+def litert_bin(
+        ungrte = False,
+        export_litert_only = False,
+        **cc_bin_kwargs):
+    """
+    LiteRT binary rule.
+
+    Args:
+        ungrte: Whether to link against system libraries ("ungrte").
+        export_litert_only: Whether to export only LiteRT symbols.
+        **cc_bin_kwargs: Keyword arguments to pass to the underlying rule.
+    """
+    if export_litert_only:
+        append_rule_kwargs(
+            cc_bin_kwargs,
+            linkopts = [_EXPORT_LRT_ONLY_LINKOPT],
+            deps = [_EXPORT_LRT_ONLY_SCRIPT],
+        )
+
+    _litert_base(
+        native.cc_binary,
+        ungrte,
+        **cc_bin_kwargs
+    )
+
+def litert_dynamic_lib(
+        name,
+        shared_lib_name,
+        so_name,
+        export_litert_only = False,
+        ungrte = False,
+        **cc_lib_kwargs):
+    """
+    LiteRT dynamic library rule.
+
+    Args:
+        name: The name of the library.
+        shared_lib_name: The name of the shared library.
+        so_name: The name of the shared object file.
+        export_litert_only: Whether to export only LiteRT symbols.
+        ungrte: Whether to link against system libraries ("ungrte").
+        **cc_lib_kwargs: Keyword arguments to pass to the underlying rule.
+    """
+    if not _valid_shared_lib_name(shared_lib_name):
+        fail("\"shared_lib_name\" must end with \"_so\"")
+    if not _valid_so_name(so_name):
+        fail("\"so_name\" must be \"libLiteRt*.so\"")
+
+    lib_name = name
+    cc_lib_kwargs["name"] = lib_name
+
+    lib_target_ref = _make_target_ref(lib_name)
+
+    vis = cc_lib_kwargs.get("visibility", None)
+
+    # Share tags for all targets.
+    tags = cc_lib_kwargs.get("tags", [])
+
+    litert_lib(
+        ungrte = ungrte,
+        **cc_lib_kwargs
+    )
+
+    user_link_flags = []
+    additional_linker_inputs = []
+    if export_litert_only:
+        user_link_flags.append(_EXPORT_LRT_ONLY_LINKOPT)
+        additional_linker_inputs.append(_EXPORT_LRT_ONLY_SCRIPT)
+
+    native.cc_shared_library(
+        name = shared_lib_name,
+        shared_lib_name = so_name,
+        user_link_flags = user_link_flags,
+        additional_linker_inputs = additional_linker_inputs,
+        tags = tags,
+        visibility = vis,
+        deps = [lib_target_ref],
+    )
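A note on the naming contract the two validators above enforce: `_valid_shared_lib_name` requires `shared_lib_name` to end in `_so`, and `_valid_so_name` requires `so_name` to match `libLiteRt*.so`. So a hypothetical invocation would look like `litert_dynamic_lib(name = "foo", shared_lib_name = "foo_so", so_name = "libLiteRtFoo.so")`, expanding to a `cc_library` named `foo` plus a `cc_shared_library` named `foo_so` that emits `libLiteRtFoo.so`, with the export_litert_only.lds version script applied when `export_litert_only = True`.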
diff --git a/tflite/experimental/litert/c/BUILD b/tflite/experimental/litert/c/BUILD
new file mode 100644
index 00000000..3594e6bb
--- /dev/null
+++ b/tflite/experimental/litert/c/BUILD
@@ -0,0 +1,305 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+    # copybara:uncomment default_applicable_licenses = ["@org_tensorflow//tensorflow:license"],
+    default_visibility = ["//tflite/experimental/litert:__subpackages__"],
+)
+
+cc_library(
+    name = "litert_common",
+    hdrs = ["litert_common.h"],
+)
+
+cc_library(
+    name = "litert_any",
+    hdrs = ["litert_any.h"],
+)
+
+cc_library(
+    name = "litert_environment",
+    srcs = ["litert_environment.cc"],
+    hdrs = ["litert_environment.h"],
+    deps = [
+        ":litert_any",
+        ":litert_common",
+        "//tflite/experimental/litert/core:environment",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+cc_library(
+    name = "litert_logging",
+    srcs = [
+        "litert_logging.cc",
+    ],
+    hdrs = [
+        "litert_logging.h",
+    ],
+    deps = [
+        ":litert_common",
+        "//tflite:minimal_logging",
+    ],
+)
+
+cc_test(
+    name = "litert_logging_test",
+    srcs = [
+        "litert_logging_test.cc",
+    ],
+    deps = [
+        ":litert_common",
+        ":litert_logging",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+cc_library(
+    name = "litert_layout",
+    hdrs = ["litert_layout.h"],
+    deps = [
+        ":litert_common",
+        ":litert_op_code",
+        "//tflite/core/c:c_api_types",
+        "//tflite/experimental/litert/cc:litert_buffer_ref",
+        "//tflite/experimental/litert/cc:litert_macros",
+        "@com_google_absl//absl/strings:string_view",
+    ],
+)
+
+cc_library(
+    name = "litert_model",
+    srcs = ["litert_model.cc"],
+    hdrs = ["litert_model.h"],
+    deps = [
+        ":litert_common",
+        ":litert_layout",
+        ":litert_op_code",
+        "//tflite/core/c:c_api_types",
+        "//tflite/experimental/litert/cc:litert_buffer_ref",
+        "//tflite/experimental/litert/core/model",
+        "//tflite/experimental/litert/core/model:model_load",
+        "@com_google_absl//absl/strings:string_view",
+    ],
+)
+
+cc_test(
+    name = "litert_model_test",
+    srcs = ["litert_model_test.cc"],
+    deps = [
+        ":litert_common",
+        ":litert_model",
+        ":litert_op_code",
+        "//tflite/experimental/litert/cc:litert_buffer_ref",
+        "//tflite/experimental/litert/core/model",
+        "//tflite/experimental/litert/core/util:flatbuffer_tools",
+        "//tflite/experimental/litert/test:test_macros",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_absl//absl/types:span",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+cc_library(
+    name = "litert_op_code",
+    hdrs = ["litert_op_code.h"],
+    deps = ["//tflite:builtin_ops"],
+)
+
+cc_library(
+    name = "litert_options",
+    srcs = ["litert_options.cc"],
+    hdrs = [
+        "litert_options.h",
+    ],
+    deps = [
+        ":litert_common",
+        ":litert_op_code",
+        "//tflite/c:c_api_types",
+        "//tflite/experimental/litert/core/model",
+        "@org_tensorflow//tensorflow/compiler/mlir/lite/core:model_builder_base",
+    ],
+)
+
+cc_test(
+    name = "litert_options_test",
+    srcs = ["litert_options_test.cc"],
+    data = [
+        "//tflite/experimental/litert/test:mlir_test_data",
+    ],
+    tags = ["no_oss"],
+    deps = [
+        ":litert_options",
+        "//tflite/experimental/litert/test:common",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+cc_library(
+    name = "litert_tensor_buffer",
+    srcs = [
+        "litert_event.cc",
+        "litert_tensor_buffer.cc",
+        "litert_tensor_buffer_requirements.cc",
+    ],
+    hdrs = [
+        "litert_event.h",
+        "litert_tensor_buffer.h",
+        "litert_tensor_buffer_requirements.h",
+    ],
+    deps = [
+        ":litert_common",
+        ":litert_logging",
+        ":litert_model",
+        "//tflite/experimental/litert/runtime:tensor_buffer",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+cc_test(
+    name = "litert_tensor_buffer_test",
+    srcs = [
+        "litert_tensor_buffer_test.cc",
+    ],
+    linkopts = select({
+        "@org_tensorflow//tensorflow:android": ["-landroid"],
"//conditions:default": [], + }), + deps = [ + ":litert_common", + ":litert_model", + ":litert_tensor_buffer", + "//tflite/experimental/litert/cc:litert_layout", + "//tflite/experimental/litert/runtime:tensor_buffer", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "litert_tensor_buffer_requirements_test", + srcs = [ + "litert_tensor_buffer_requirements_test.cc", + ], + linkopts = select({ + "@org_tensorflow//tensorflow:android": ["-landroid"], + "//conditions:default": [], + }), + deps = [ + ":litert_common", + ":litert_tensor_buffer", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "litert_dispatch_delegate", + hdrs = [ + "litert_dispatch_delegate.h", + ], + deps = [ + "//tflite/c:c_api", + "//tflite/c:c_api_opaque", + "//tflite/c:c_api_types", + "//tflite/c:common", + "//tflite/delegates/utils:simple_opaque_delegate", + "//tflite/experimental/litert/runtime/dispatch:dispatch_delegate", + "//tflite/experimental/litert/vendors/c:litert_dispatch_c_api", + ], +) + +cc_library( + name = "litert_compiled_model_options", + hdrs = [ + "litert_compiled_model_options.h", + ], + deps = [ + ":litert_common", + ], +) + +cc_library( + name = "litert_compiled_model", + srcs = ["litert_compiled_model.cc"], + hdrs = [ + "litert_compiled_model.h", + ], + deps = [ + ":litert_common", + ":litert_compiled_model_options", + ":litert_logging", + ":litert_model", + ":litert_tensor_buffer", + "//tflite/c:c_api_types", + "//tflite/experimental/litert/runtime:compiled_model", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "litert_compiled_model_test", + srcs = [ + "litert_compiled_model_test.cc", + ], + data = [ + "//tflite/experimental/litert/test:testdata/simple_model.tflite", + ], + deps = [ + ":litert_common", + ":litert_compiled_model", + ":litert_compiled_model_options", + ":litert_model", + ":litert_tensor_buffer", + "//tflite/experimental/litert/test:common", + "//tflite/experimental/litert/test:simple_model", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) + +filegroup( + name = "litert_model_srcs", + srcs = ["litert_model.cc"], + visibility = ["//tflite/experimental/litert/core/model:__pkg__"], +) + +filegroup( + name = "litert_model_hdrs", + srcs = ["litert_model.h"], + visibility = ["//tflite/experimental/litert/core/model:__pkg__"], +) + +# This test verifies that the C API header files can build via C compiler. +cc_test( + name = "litert_c_api_common_test", + srcs = ["litert_c_api_common_test.c"], + copts = ["--std=c11"], + linkopts = ["-ldl"], + deps = [ + ":litert_any", + ":litert_common", + ":litert_compiled_model", + ":litert_compiled_model_options", + ":litert_dispatch_delegate", + ":litert_layout", + ":litert_logging", + ":litert_model", + ":litert_op_code", + ":litert_options", + ":litert_tensor_buffer", + ], +) + +exports_files(srcs = glob(["litert_*.h"])) diff --git a/tflite/experimental/litert/c/litert_any.h b/tflite/experimental/litert/c/litert_any.h new file mode 100644 index 00000000..69a2a8d7 --- /dev/null +++ b/tflite/experimental/litert/c/litert_any.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ANY_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ANY_H_
+
+#include <stdbool.h>  // NOLINT: To use bool type in C
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+typedef enum {
+  kLiteRtAnyTypeNone = 0,
+  kLiteRtAnyTypeBool = 1,
+  kLiteRtAnyTypeInt = 2,
+  kLiteRtAnyTypeReal = 3,
+  kLiteRtAnyTypeString = 8,
+  kLiteRtAnyTypeVoidPtr = 9,
+} LiteRtAnyType;
+
+typedef struct {
+  LiteRtAnyType type;
+  union {
+    bool bool_value;
+    int64_t int_value;
+    double real_value;
+    const char* str_value;
+    const void* ptr_value;
+  };
+} LiteRtAny;
+
+#ifdef __cplusplus
+}
+#endif  // __cplusplus
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ANY_H_
diff --git a/tflite/experimental/litert/c/litert_c_api_common_test.c b/tflite/experimental/litert/c/litert_c_api_common_test.c
new file mode 100644
index 00000000..09ef5e79
--- /dev/null
+++ b/tflite/experimental/litert/c/litert_c_api_common_test.c
@@ -0,0 +1,40 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file exists to verify that the below header files can build, link,
+// and run as C code.
+#ifdef __cplusplus
+#error "This file should be compiled as C code, not as C++."
+#endif
+
+// Include all the header files in the litert/c directory.
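+//
+// For reference, a hedged sketch of populating the LiteRtAny tagged union
+// declared in litert_any.h (the values here are arbitrary):
+//
+//   LiteRtAny any = {.type = kLiteRtAnyTypeInt, .int_value = 42};
+//   LiteRtAny str = {.type = kLiteRtAnyTypeString, .str_value = "foo"};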
+#include "tflite/experimental/litert/c/litert_common.h" // NOLINT +#include "tflite/experimental/litert/c/litert_any.h" // NOLINT +#include "tflite/experimental/litert/c/litert_compiled_model.h" // NOLINT +#include "tflite/experimental/litert/c/litert_compiled_model_options.h" // NOLINT +#include "tflite/experimental/litert/c/litert_dispatch_delegate.h" // NOLINT +#include "tflite/experimental/litert/c/litert_event.h" // NOLINT +#include "tflite/experimental/litert/c/litert_layout.h" // NOLINT +#include "tflite/experimental/litert/c/litert_logging.h" // NOLINT +#include "tflite/experimental/litert/c/litert_options.h" // NOLINT +#include "tflite/experimental/litert/c/litert_model.h" // NOLINT +#include "tflite/experimental/litert/c/litert_op_code.h" // NOLINT +#include "tflite/experimental/litert/c/litert_options.h" // NOLINT +#include "tflite/experimental/litert/c/litert_tensor_buffer.h" // NOLINT +#include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h" // NOLINT + +int main(void) { + return 0; +} + diff --git a/tflite/experimental/litert/c/litert_common.h b/tflite/experimental/litert/c/litert_common.h new file mode 100644 index 00000000..b68c0b77 --- /dev/null +++ b/tflite/experimental/litert/c/litert_common.h @@ -0,0 +1,97 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMMON_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMMON_H_ + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Declares canonical opaque type. +#define LITERT_DEFINE_HANDLE(name) typedef struct name##T* name +// Declares an array of references to opaque type. `name` must be +// previously declared opaque type. +#define LITERT_DEFINE_HANDLE_ARRAY(name) typedef name* name##Array + +#if __ANDROID_API__ >= 26 +#define LITERT_HAS_AHWB_SUPPORT 1 +#else +#define LITERT_HAS_AHWB_SUPPORT 0 +#endif // __ANDROID_API__ >= 26 + +#if defined(__linux__) || defined(__ANDROID__) +#define LITERT_HAS_SYNC_FENCE_SUPPORT 1 +#else +#define LITERT_HAS_SYNC_FENCE_SUPPORT 0 +#endif + +#if defined(__ANDROID__) +#define LITERT_HAS_ION_SUPPORT 1 +#define LITERT_HAS_DMABUF_SUPPORT 1 +#define LITERT_HAS_FASTRPC_SUPPORT 1 +#else +#define LITERT_HAS_ION_SUPPORT 0 +#define LITERT_HAS_DMABUF_SUPPORT 0 +#define LITERT_HAS_FASTRPC_SUPPORT 0 +#endif + +#define LITERT_API_VERSION_MAJOR 0 +#define LITERT_API_VERSION_MINOR 1 +#define LITERT_API_VERSION_PATCH 0 + +typedef struct LiteRtApiVersion { + int major; + int minor; + int patch; +} LiteRtApiVersion; + +typedef enum { + kLiteRtStatusOk = 0, + + // Generic errors. + kLiteRtStatusErrorInvalidArgument = 1, + kLiteRtStatusErrorMemoryAllocationFailure = 2, + kLiteRtStatusErrorRuntimeFailure = 3, + kLiteRtStatusErrorMissingInputTensor = 4, + kLiteRtStatusErrorUnsupported = 5, + kLiteRtStatusErrorNotFound = 6, + kLiteRtStatusErrorTimeoutExpired = 7, + + // File and loading related errors. 
+  kLiteRtStatusErrorFileIO = 500,
+  kLiteRtStatusErrorInvalidFlatbuffer = 501,
+  kLiteRtStatusErrorDynamicLoading = 502,
+  kLiteRtStatusErrorSerialization = 503,
+  kLiteRtStatusErrorCompilation = 504,
+
+  // IR related errors.
+  kLiteRtStatusErrorIndexOOB = 1000,
+  kLiteRtStatusErrorInvalidIrType = 1001,
+  kLiteRtStatusErrorInvalidGraphInvariant = 1002,
+  kLiteRtStatusErrorGraphModification = 1003,
+
+  // Tool related errors.
+  kLiteRtStatusErrorInvalidToolConfig = 1500,
+
+  // Legalization related errors.
+  kLiteRtStatusLegalizeNoMatch = 2000,
+  kLiteRtStatusErrorInvalidLegalization = 2001,
+} LiteRtStatus;
+
+#ifdef __cplusplus
+}
+#endif  // __cplusplus
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMMON_H_
diff --git a/tflite/experimental/litert/c/litert_compiled_model.cc b/tflite/experimental/litert/c/litert_compiled_model.cc
new file mode 100644
index 00000000..0f9f6230
--- /dev/null
+++ b/tflite/experimental/litert/c/litert_compiled_model.cc
@@ -0,0 +1,104 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/c/litert_compiled_model.h"
+
+#include <cstddef>
+
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_compiled_model_options.h"
+#include "tflite/experimental/litert/c/litert_logging.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h"
+#include "tflite/experimental/litert/runtime/compiled_model.h"
+
+LiteRtStatus LiteRtCreateCompiledModel(
+    LiteRtModel model, LiteRtCompilationOptions compilation_options,
+    LiteRtCompiledModel* compiled_model) {
+  if (!model || !compiled_model) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+
+  auto created_compiled_model =
+      LiteRtCompiledModelT::Create(model, compilation_options);
+  if (!created_compiled_model) {
+    LITERT_LOG(LITERT_ERROR, "%s",
+               created_compiled_model.Error().Message().data());
+    return created_compiled_model.Error().Status();
+  }
+  *compiled_model = created_compiled_model->release();
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetCompiledModelInputBufferRequirements(
+    LiteRtCompiledModel compiled_model, LiteRtParamIndex signature_index,
+    LiteRtParamIndex input_index,
+    LiteRtTensorBufferRequirements* buffer_requirements) {
+  if (!compiled_model || !buffer_requirements) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+
+  auto res = compiled_model->GetInputBufferRequirementsCApi(signature_index,
+                                                            input_index);
+  if (!res) {
+    LITERT_LOG(LITERT_ERROR, "%s", res.Error().Message().data());
+    return res.Error().Status();
+  }
+  *buffer_requirements = res.Value();
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetCompiledModelOutputBufferRequirements(
+    LiteRtCompiledModel compiled_model, LiteRtParamIndex signature_index,
+    LiteRtParamIndex output_index,
+    LiteRtTensorBufferRequirements* buffer_requirements) {
+  if (!compiled_model || !buffer_requirements) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+
+  auto res = compiled_model->GetOutputBufferRequirementsCApi(signature_index,
+                                                             output_index);
+  if (!res) {
+    LITERT_LOG(LITERT_ERROR, "%s", res.Error().Message().data());
+    return res.Error().Status();
+  }
+  *buffer_requirements = res.Value();
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtRunCompiledModel(LiteRtCompiledModel compiled_model,
+                                    LiteRtParamIndex signature_index,
+                                    size_t num_input_buffers,
+                                    LiteRtTensorBuffer* input_buffers,
+                                    size_t num_output_buffers,
+                                    LiteRtTensorBuffer* output_buffers) {
+  if (!compiled_model || (num_input_buffers > 0 && !input_buffers) ||
+      (num_output_buffers > 0 && !output_buffers)) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+
+  auto res =
+      compiled_model->RunCApi(signature_index, num_input_buffers,
+                              input_buffers, num_output_buffers,
+                              output_buffers);
+  if (!res) {
+    LITERT_LOG(LITERT_ERROR, "%s", res.Error().Message().data());
+    return res.Error().Status();
+  }
+  return kLiteRtStatusOk;
+}
+
+void LiteRtDestroyCompiledModel(LiteRtCompiledModel compiled_model) {
+  delete compiled_model;
+}
diff --git a/tflite/experimental/litert/c/litert_compiled_model.h b/tflite/experimental/litert/c/litert_compiled_model.h
new file mode 100644
index 00000000..87248263
--- /dev/null
+++ b/tflite/experimental/litert/c/litert_compiled_model.h
@@ -0,0 +1,106 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMPILED_MODEL_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMPILED_MODEL_H_
+
+#include <stddef.h>
+
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_compiled_model_options.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// The LiteRtCompiledModel is a higher-level inference API. It is created from
+// a model with compilation options. Internally, it instantiates the runtime
+// and applies the delegates mapped to the compilation options.
+// It also supports getting LiteRtTensorBufferRequirements to create
+// input/output TensorBuffers, and it allows invoking the model with the
+// input/output TensorBuffers.
+//
+// Example user flow:
+//
+// 1. Create LiteRtCompiledModel
+// 2. Query the model input/output LiteRtTensorBufferRequirements
+// 3. Create input/output LiteRtTensorBuffer
+// 4. Fill the input LiteRtTensorBuffer with input data
+// 5. Invoke the model with the input/output LiteRtTensorBuffer
+// 6. Evaluate the output LiteRtTensorBuffer
+
+LITERT_DEFINE_HANDLE(LiteRtCompiledModel);
+
+// Creates a LiteRtCompiledModel from a LiteRtModel object.
+// The model is loaded into memory and the caller takes ownership of the
+// returned object.
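+//
+// A hedged sketch of the user flow above (error checking and tensor buffer
+// setup omitted; `model` is assumed to come from LiteRtCreateModelFromFile):
+//
+//   LiteRtCompiledModel compiled_model;
+//   LiteRtCreateCompiledModel(model, kHwAccelDefault, &compiled_model);
+//   // ...create input/output LiteRtTensorBuffers per the requirements
+//   // APIs declared below...
+//   LiteRtRunCompiledModel(compiled_model, /*signature_index=*/0,
+//                          num_inputs, inputs, num_outputs, outputs);
+//   LiteRtDestroyCompiledModel(compiled_model);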
+LiteRtStatus LiteRtCreateCompiledModel(
+    LiteRtModel model, LiteRtCompilationOptions compilation_options,
+    LiteRtCompiledModel* compiled_model);
+
+// Returns the buffer requirements for the given n-th input tensor. The
+// returned LiteRtTensorBufferRequirements is used to create the input tensor
+// buffer.
+//
+// Parameters:
+// - compiled_model: the target `LiteRtCompiledModel` object.
+// - signature_index: the index of the signature in `LiteRtModel`.
+// - input_index: the index of the input tensor in the signature (subgraph).
+// - buffer_requirements: the returned `LiteRtTensorBufferRequirements`.
+LiteRtStatus LiteRtGetCompiledModelInputBufferRequirements(
+    LiteRtCompiledModel compiled_model, LiteRtParamIndex signature_index,
+    LiteRtParamIndex input_index,
+    LiteRtTensorBufferRequirements* buffer_requirements);
+
+// Returns the buffer requirements for the given n-th output tensor. The
+// returned LiteRtTensorBufferRequirements is used to create the output tensor
+// buffer.
+//
+// Parameters:
+// - compiled_model: the target `LiteRtCompiledModel` object.
+// - signature_index: the index of the signature in `LiteRtModel`.
+// - output_index: the index of the output tensor in the signature (subgraph).
+// - buffer_requirements: the returned `LiteRtTensorBufferRequirements`.
+LiteRtStatus LiteRtGetCompiledModelOutputBufferRequirements(
+    LiteRtCompiledModel compiled_model, LiteRtParamIndex signature_index,
+    LiteRtParamIndex output_index,
+    LiteRtTensorBufferRequirements* buffer_requirements);
+
+// Runs the model of the given n-th signature with the provided input/output
+// LiteRtTensorBuffer.
+//
+// Parameters:
+// - compiled_model: the target `LiteRtCompiledModel` object.
+// - signature_index: the index of the signature in `LiteRtModel`.
+// - num_input_buffers: the number of input `LiteRtTensorBuffer`.
+// - input_buffers: the array of input `LiteRtTensorBuffer`.
+// - num_output_buffers: the number of output `LiteRtTensorBuffer`.
+// - output_buffers: the array of output `LiteRtTensorBuffer`.
+LiteRtStatus LiteRtRunCompiledModel(LiteRtCompiledModel compiled_model,
+                                    LiteRtParamIndex signature_index,
+                                    size_t num_input_buffers,
+                                    LiteRtTensorBuffer* input_buffers,
+                                    size_t num_output_buffers,
+                                    LiteRtTensorBuffer* output_buffers);
+
+// Destroys the given compiled model and releases its resources.
+void LiteRtDestroyCompiledModel(LiteRtCompiledModel compiled_model);
+
+#ifdef __cplusplus
+}
+#endif  // __cplusplus
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMPILED_MODEL_H_
diff --git a/tflite/experimental/litert/c/litert_compiled_model_options.h b/tflite/experimental/litert/c/litert_compiled_model_options.h
new file mode 100644
index 00000000..f9583744
--- /dev/null
+++ b/tflite/experimental/litert/c/litert_compiled_model_options.h
@@ -0,0 +1,35 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
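+
+// The accelerator values below are single-bit flags; presumably they can be
+// OR-ed together when more than one accelerator is acceptable, e.g.:
+//
+//   LiteRtCompilationOptions opts =
+//       (LiteRtCompilationOptions)(kHwAccelCpu | kHwAccelNpu);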
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMPILED_MODEL_OPTIONS_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMPILED_MODEL_OPTIONS_H_
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// The compilation options for the LiteRtCompiledModel.
+// WARNING: This is experimental and subject to change.
+// TODO: b/379317134 - Add GPU support.
+typedef enum LiteRtCompilationOptions : int {
+  kHwAccelDefault = 0,
+  kHwAccelCpu = 1 << 0,
+  kHwAccelNpu = 1 << 1,
+} LiteRtCompilationOptions;
+
+#ifdef __cplusplus
+}
+#endif  // __cplusplus
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMPILED_MODEL_OPTIONS_H_
diff --git a/tflite/experimental/litert/c/litert_compiled_model_test.cc b/tflite/experimental/litert/c/litert_compiled_model_test.cc
new file mode 100644
index 00000000..f44d0779
--- /dev/null
+++ b/tflite/experimental/litert/c/litert_compiled_model_test.cc
@@ -0,0 +1,166 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/c/litert_compiled_model.h"
+
+#include <cstddef>
+#include <cstring>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/log/absl_log.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_compiled_model_options.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h"
+#include "tflite/experimental/litert/test/common.h"
+#include "tflite/experimental/litert/test/testdata/simple_model_test_vectors.h"
+
+using testing::FloatNear;
+using testing::Pointwise;
+
+namespace litert {
+namespace {
+
+TEST(CompiledModelTest, Basic) {
+  auto path = testing::GetTestFilePath(kModelFileName);
+
+  LiteRtModel model;
+  ASSERT_EQ(LiteRtCreateModelFromFile(path.c_str(), &model), kLiteRtStatusOk);
+
+  LiteRtCompiledModel compiled_model;
+  ASSERT_EQ(LiteRtCreateCompiledModel(model, kHwAccelCpu, &compiled_model),
+            kLiteRtStatusOk);
+
+  LiteRtSubgraph subgraph;
+  ASSERT_EQ(LiteRtGetModelSubgraph(model, 0, &subgraph), kLiteRtStatusOk);
+
+  LiteRtParamIndex num_inputs;
+  LiteRtTensorArray input_tensors;
+  ASSERT_EQ(LiteRtGetSubgraphInputs(subgraph, &num_inputs, &input_tensors),
+            kLiteRtStatusOk);
+
+  std::vector<LiteRtTensorBuffer> input_tensor_buffers;
+  input_tensor_buffers.reserve(num_inputs);
+  for (auto i = 0; i < num_inputs; ++i) {
+    LiteRtTensorBufferRequirements tensor_buffer_requirements;
+    ASSERT_EQ(LiteRtGetCompiledModelInputBufferRequirements(
+                  compiled_model, /*signature_index=*/0, i,
+                  &tensor_buffer_requirements),
+              kLiteRtStatusOk);
+    LiteRtTensorBufferType tensor_buffer_type;
+    EXPECT_EQ(
+        LiteRtGetTensorBufferRequirementsSupportedTensorBufferType(
+            tensor_buffer_requirements, /*type_index=*/0, &tensor_buffer_type),
+        kLiteRtStatusOk);
+    size_t tensor_buffer_size;
+    EXPECT_EQ(LiteRtGetTensorBufferRequirementsBufferSize(
+                  tensor_buffer_requirements, &tensor_buffer_size),
+              kLiteRtStatusOk);
+    LiteRtTensorBuffer tensor_buffer;
+    EXPECT_EQ(
+        LiteRtCreateManagedTensorBuffer(tensor_buffer_type, &kInput0TensorType,
+                                        tensor_buffer_size, &tensor_buffer),
+        kLiteRtStatusOk);
+    input_tensor_buffers.push_back(tensor_buffer);
+  }
+
+  LiteRtParamIndex num_outputs;
+  LiteRtTensorArray output_tensors;
+  ASSERT_EQ(LiteRtGetSubgraphOutputs(subgraph, &num_outputs, &output_tensors),
+            kLiteRtStatusOk);
+
+  std::vector<LiteRtTensorBuffer> output_tensor_buffers;
+  output_tensor_buffers.reserve(num_outputs);
+  for (auto i = 0; i < num_outputs; ++i) {
+    LiteRtTensorBufferRequirements tensor_buffer_requirements;
+    ASSERT_EQ(LiteRtGetCompiledModelOutputBufferRequirements(
+                  compiled_model, /*signature_index=*/0, i,
+                  &tensor_buffer_requirements),
+              kLiteRtStatusOk);
+    LiteRtTensorBufferType tensor_buffer_type;
+    EXPECT_EQ(
+        LiteRtGetTensorBufferRequirementsSupportedTensorBufferType(
+            tensor_buffer_requirements, /*type_index=*/0, &tensor_buffer_type),
+        kLiteRtStatusOk);
+    size_t tensor_buffer_size;
+    EXPECT_EQ(LiteRtGetTensorBufferRequirementsBufferSize(
+                  tensor_buffer_requirements, &tensor_buffer_size),
+              kLiteRtStatusOk);
+    LiteRtTensorBuffer tensor_buffer;
+    EXPECT_EQ(
+        LiteRtCreateManagedTensorBuffer(tensor_buffer_type, &kInput0TensorType,
+                                        tensor_buffer_size, &tensor_buffer),
+        kLiteRtStatusOk);
+    output_tensor_buffers.push_back(tensor_buffer);
+  }
+
+  {
+    ABSL_LOG(INFO) << "Filling inputs with data";
+    void* host_mem_addr;
+
+    ASSERT_EQ(LiteRtLockTensorBuffer(input_tensor_buffers[0], &host_mem_addr,
+                                     /*event=*/nullptr),
+              kLiteRtStatusOk);
+    std::memcpy(host_mem_addr, kTestInput0Tensor, sizeof(kTestInput0Tensor));
+    ASSERT_EQ(LiteRtUnlockTensorBuffer(input_tensor_buffers[0]),
+              kLiteRtStatusOk);
+
+    ASSERT_EQ(LiteRtLockTensorBuffer(input_tensor_buffers[1], &host_mem_addr,
+                                     /*event=*/nullptr),
+              kLiteRtStatusOk);
+    std::memcpy(host_mem_addr, kTestInput1Tensor, sizeof(kTestInput1Tensor));
+    ASSERT_EQ(LiteRtUnlockTensorBuffer(input_tensor_buffers[1]),
+              kLiteRtStatusOk);
+  }
+
+  ASSERT_EQ(LiteRtRunCompiledModel(
+                compiled_model, /*signature_index=*/0,
+                input_tensor_buffers.size(), input_tensor_buffers.data(),
+                output_tensor_buffers.size(), output_tensor_buffers.data()),
+            kLiteRtStatusOk);
+
+  {
+    ABSL_LOG(INFO) << "Checking output...";
+    void* host_mem_addr;
+    ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffers[0], &host_mem_addr,
+                                     /*event=*/nullptr),
+              kLiteRtStatusOk);
+    auto output = absl::MakeSpan(static_cast<const float*>(host_mem_addr),
+                                 kTestOutputSize);
+    for (auto i = 0; i < kTestOutputSize; ++i) {
+      ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i];
+    }
+    EXPECT_THAT(output, Pointwise(FloatNear(1e-3), kTestOutputTensor));
+    ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffers[0]),
+              kLiteRtStatusOk);
+  }
+
+  LiteRtDestroyCompiledModel(compiled_model);
+  LiteRtDestroyModel(model);
+
+  for (auto tensor_buffer : input_tensor_buffers) {
+    LiteRtDestroyTensorBuffer(tensor_buffer);
+  }
+  for (auto tensor_buffer : output_tensor_buffers) {
+    LiteRtDestroyTensorBuffer(tensor_buffer);
+  }
+}
+
+}  // namespace
+}  // namespace litert
diff --git a/tflite/experimental/litert/c/litert_dispatch_delegate.h b/tflite/experimental/litert/c/litert_dispatch_delegate.h
new file mode 100644
index 00000000..2a5a1750
--- /dev/null
+++ b/tflite/experimental/litert/c/litert_dispatch_delegate.h
@@ -0,0 +1,84 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_DISPATCH_DELEGATE_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_DISPATCH_DELEGATE_H_
+
+#include <stddef.h>
+
+#include "tflite/c/c_api_opaque.h"
+#include "tflite/c/c_api_types.h"
+#include "tflite/c/common.h"
+#include "tflite/experimental/litert/vendors/c/litert_dispatch.h"
+
+#ifdef __cplusplus
+#include <memory>
+
+#include "tflite/delegates/utils/simple_opaque_delegate.h"
+#endif
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+typedef struct LiteRtDispatchDelegateOptions LiteRtDispatchDelegateOptions;
+
+// Returns DispatchDelegateOptions populated with default values.
+LiteRtDispatchDelegateOptions* LiteRtCreateDefaultDispatchDelegateOptions();
+
+TfLiteStatus LiteRtAddDispatchDelegateOption(
+    LiteRtDispatchDelegateOptions* options, LiteRtDispatchOption option);
+
+void LiteRtDestroyDispatchDelegateOptions(
+    LiteRtDispatchDelegateOptions* options);
+
+// Creates a delegate that uses the Dispatch API for execution. Takes ownership
+// of the passed `options`. The delegate must outlive the TFL interpreter.
+TfLiteOpaqueDelegate* LiteRtCreateDispatchDelegate(
+    LiteRtDispatchDelegateOptions* options);
+
+// Does any needed cleanup and deletes `delegate`.
+void LiteRtDestroyDispatchDelegate(TfLiteOpaqueDelegate* delegate);
+
+//
+// Common option helpers
+//
+
+// Alloc base is the address of the first byte of the flatbuffer model in
+// memory. It is used by ops to find the start of the NPU byte code appended
+// to the file.
+TfLiteStatus LiteRtDispatchDelegateAddAllocBaseOption(
+    LiteRtDispatchDelegateOptions* options, const void* alloc_base);
+
+#ifdef __cplusplus
+}
+#endif  // __cplusplus
+
+#ifdef __cplusplus
+namespace litert {
+
+using DispatchDelegateOptionsPtr =
+    std::unique_ptr<LiteRtDispatchDelegateOptions,
+                    void (*)(LiteRtDispatchDelegateOptions*)>;
+
+using DispatchDelegatePtr = tflite::TfLiteOpaqueDelegateUniquePtr;
+
+DispatchDelegateOptionsPtr CreateDispatchDelegateOptionsPtr();
+
+DispatchDelegatePtr CreateDispatchDelegatePtr(
+    DispatchDelegateOptionsPtr&& options);
+
+}  // namespace litert
+#endif
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_DISPATCH_DELEGATE_H_
diff --git a/tflite/experimental/litert/c/litert_environment.cc b/tflite/experimental/litert/c/litert_environment.cc
new file mode 100644
index 00000000..23cbe147
--- /dev/null
+++ b/tflite/experimental/litert/c/litert_environment.cc
@@ -0,0 +1,31 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
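+
+// A hedged usage sketch for the C API implemented below (the library path is
+// hypothetical):
+//
+//   LiteRtEnvOption options[] = {{
+//       .tag = kLiteRtEnvOptionTagDispatchLibraryPath,
+//       .value = {.type = kLiteRtAnyTypeString,
+//                 .str_value = "/data/local/tmp"},
+//   }};
+//   LiteRtEnvironmentCreate(1, options);
+//   ...
+//   LiteRtEnvironmentDestroy();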
+
+#include "tflite/experimental/litert/c/litert_environment.h"
+
+#include "absl/types/span.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/core/environment.h"
+
+LiteRtStatus LiteRtEnvironmentCreate(int num_options,
+                                     const LiteRtEnvOption* options) {
+  if (auto status = litert::internal::Environment::CreateWithOptions(
+          absl::MakeSpan(options, num_options));
+      !status) {
+    return status.Error().Status();
+  }
+  return kLiteRtStatusOk;
+}
+
+void LiteRtEnvironmentDestroy() { litert::internal::Environment::Destroy(); }
diff --git a/tflite/experimental/litert/c/litert_environment.h b/tflite/experimental/litert/c/litert_environment.h
new file mode 100644
index 00000000..a1b18205
--- /dev/null
+++ b/tflite/experimental/litert/c/litert_environment.h
@@ -0,0 +1,49 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ENVIRONMENT_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ENVIRONMENT_H_
+
+#include "tflite/experimental/litert/c/litert_any.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+typedef enum {
+  kLiteRtEnvOptionTagCompilerPluginLibraryPath = 0,
+  kLiteRtEnvOptionTagDispatchLibraryPath = 1,
+} LiteRtEnvOptionTag;
+
+typedef struct {
+  LiteRtEnvOptionTag tag;
+  LiteRtAny value;
+} LiteRtEnvOption;
+
+// Create a singleton LiteRT environment with options. Returns an error if the
+// instance already exists, in which case the specified options have no
+// effect. If not created explicitly with options, the environment instance
+// will be created (with no options) when needed.
+LiteRtStatus LiteRtEnvironmentCreate(int num_options,
+                                     const LiteRtEnvOption* options);
+
+// Destroy the LiteRT environment instance.
+void LiteRtEnvironmentDestroy();
+
+#ifdef __cplusplus
+}
+#endif  // __cplusplus
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ENVIRONMENT_H_
diff --git a/tflite/experimental/litert/c/litert_event.cc b/tflite/experimental/litert/c/litert_event.cc
new file mode 100644
index 00000000..ec76ab1a
--- /dev/null
+++ b/tflite/experimental/litert/c/litert_event.cc
@@ -0,0 +1,43 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
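+
+// A hedged usage sketch for the event API implemented below (POSIX only;
+// `fd` is a hypothetical sync fence file descriptor):
+//
+//   LiteRtEvent event;
+//   LiteRtCreateEventFromSyncFenceFd(fd, /*owns_fd=*/true, &event);
+//   LiteRtEventWait(event, /*timeout_in_ms=*/-1);  // -1 waits indefinitely.
+//   LiteRtDestroyEvent(event);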
+
+#include "tflite/experimental/litert/c/litert_event.h"
+
+#include <fcntl.h>
+#include <poll.h>
+#include <unistd.h>
+
+#include <cstdint>
+
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/runtime/event.h"
+
+#if LITERT_HAS_SYNC_FENCE_SUPPORT
+LiteRtStatus LiteRtCreateEventFromSyncFenceFd(int sync_fence_fd, bool owns_fd,
+                                              LiteRtEvent* event) {
+  *event = new LiteRtEventT{.fd = sync_fence_fd, .owns_fd = owns_fd};
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetEventSyncFenceFd(LiteRtEvent event, int* sync_fence_fd) {
+  *sync_fence_fd = event->fd;
+  return kLiteRtStatusOk;
+}
+#endif
+
+LiteRtStatus LiteRtEventWait(LiteRtEvent event, int64_t timeout_in_ms) {
+  return event->Wait(timeout_in_ms);
+}
+
+void LiteRtDestroyEvent(LiteRtEvent event) { delete event; }
diff --git a/tflite/experimental/litert/c/litert_event.h b/tflite/experimental/litert/c/litert_event.h
new file mode 100644
index 00000000..45a1644d
--- /dev/null
+++ b/tflite/experimental/litert/c/litert_event.h
@@ -0,0 +1,45 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_EVENT_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_EVENT_H_
+
+#include <stdbool.h>  // NOLINT: To use bool type in C
+#include <stdint.h>
+
+#include "tflite/experimental/litert/c/litert_common.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+LITERT_DEFINE_HANDLE(LiteRtEvent);
+
+#if LITERT_HAS_SYNC_FENCE_SUPPORT
+LiteRtStatus LiteRtCreateEventFromSyncFenceFd(int sync_fence_fd, bool owns_fd,
+                                              LiteRtEvent* event);
+
+LiteRtStatus LiteRtGetEventSyncFenceFd(LiteRtEvent event, int* sync_fence_fd);
+#endif  // LITERT_HAS_SYNC_FENCE_SUPPORT
+
+// Pass -1 for timeout_in_ms for an indefinite wait.
+LiteRtStatus LiteRtEventWait(LiteRtEvent event, int64_t timeout_in_ms);
+
+void LiteRtDestroyEvent(LiteRtEvent event);
+
+#ifdef __cplusplus
+}
+#endif  // __cplusplus
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_EVENT_H_
diff --git a/tflite/experimental/litert/c/litert_layout.h b/tflite/experimental/litert/c/litert_layout.h
new file mode 100644
index 00000000..b641985b
--- /dev/null
+++ b/tflite/experimental/litert/c/litert_layout.h
@@ -0,0 +1,45 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
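+
+// For reference, a hedged sketch of a static rank-2 layout using the struct
+// defined below (the dimension values are arbitrary):
+//
+//   LiteRtLayout layout = {
+//       .rank = 2,
+//       .dimensions = {2, 3},
+//       .strides = NULL,  // No explicit strides.
+//   };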
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_LAYOUT_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_LAYOUT_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Max number of dimensions in any ranked tensor type.
+#define LITERT_TENSOR_MAX_RANK 8
+
+// The shape information for tensor types of fixed rank.
+typedef struct {
+  // The number of dimensions.
+  uint32_t rank;
+
+  // Dimension sizes, array of length `rank`. Dynamic dimensions are any value
+  // less than 0. Everything from [rank, LITERT_TENSOR_MAX_RANK) is undefined.
+  int32_t dimensions[LITERT_TENSOR_MAX_RANK];
+
+  // Strides for a nominal NHWC layout. NULL if unused.
+  const uint32_t* strides;
+} LiteRtLayout;
+
+#ifdef __cplusplus
+}
+#endif  // __cplusplus
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_LAYOUT_H_
diff --git a/tflite/experimental/litert/c/litert_logging.cc b/tflite/experimental/litert/c/litert_logging.cc
new file mode 100644
index 00000000..ff6daa9e
--- /dev/null
+++ b/tflite/experimental/litert/c/litert_logging.cc
@@ -0,0 +1,107 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/c/litert_logging.h"
+
+#include <cstdarg>
+
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/logger.h"
+#include "tflite/minimal_logging.h"
+
+class LiteRtLoggerT {
+ public:
+  LiteRtLogSeverity GetMinSeverity() {
+    return ConvertSeverity(
+        tflite::logging_internal::MinimalLogger::GetMinimumLogSeverity());
+  }
+
+  void SetMinSeverity(LiteRtLogSeverity severity) {
+    tflite::logging_internal::MinimalLogger::SetMinimumLogSeverity(
+        ConvertSeverity(severity));
+  }
+
+  void Log(LiteRtLogSeverity severity, const char* format, va_list args) {
+    tflite::logging_internal::MinimalLogger::LogFormatted(
+        ConvertSeverity(severity), format, args);
+  }
+
+ private:
+  static tflite::LogSeverity ConvertSeverity(LiteRtLogSeverity severity) {
+    return static_cast<tflite::LogSeverity>(severity);
+  }
+
+  static LiteRtLogSeverity ConvertSeverity(tflite::LogSeverity severity) {
+    return static_cast<LiteRtLogSeverity>(severity);
+  }
+};
+
+LiteRtStatus LiteRtCreateLogger(LiteRtLogger* logger) {
+  if (!logger) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *logger = new LiteRtLoggerT;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetMinLoggerSeverity(LiteRtLogger logger,
+                                        LiteRtLogSeverity* min_severity) {
+  if (!logger || !min_severity) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *min_severity = logger->GetMinSeverity();
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtSetMinLoggerSeverity(LiteRtLogger logger,
+                                        LiteRtLogSeverity min_severity) {
+  if (!logger) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  logger->SetMinSeverity(min_severity);
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtLoggerLog(LiteRtLogger logger, LiteRtLogSeverity severity,
+                             const char* format, ...) {
+  if (!logger || !format) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  va_list args;
+  va_start(args, format);
+  logger->Log(severity, format, args);
+  va_end(args);
+  return kLiteRtStatusOk;
+}
+
+void LiteRtDestroyLogger(LiteRtLogger logger) {
+  if (logger != nullptr) {
+    delete logger;
+  }
+}
+
+namespace {
+LiteRtLoggerT StaticLogger;
+LiteRtLogger DefaultLogger = &StaticLogger;
+}  // namespace
+
+LiteRtStatus LiteRtSetDefaultLogger(LiteRtLogger logger) {
+  if (!logger) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  DefaultLogger = logger;
+  return kLiteRtStatusOk;
+}
+
+LiteRtLogger LiteRtGetDefaultLogger() { return DefaultLogger; }
diff --git a/tflite/experimental/litert/c/litert_logging.h b/tflite/experimental/litert/c/litert_logging.h
new file mode 100644
index 00000000..98afbff0
--- /dev/null
+++ b/tflite/experimental/litert/c/litert_logging.h
@@ -0,0 +1,87 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_LOGGING_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_LOGGING_H_
+
+#include <stdarg.h>
+
+#include "tflite/experimental/litert/c/litert_common.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+LITERT_DEFINE_HANDLE(LiteRtLogger);
+
+// WARNING: The values of the following enum are to be kept in sync with
+// tflite::LogSeverity.
+typedef enum {
+  kLiteRtLogSeverityVerbose = 0,
+  kLiteRtLogSeverityInfo = 1,
+  kLiteRtLogSeverityWarning = 2,
+  kLiteRtLogSeverityError = 3,
+  kLiteRtLogSeveritySilent = 4,
+} LiteRtLogSeverity;
+
+#define LITERT_VERBOSE kLiteRtLogSeverityVerbose
+#define LITERT_INFO kLiteRtLogSeverityInfo
+#define LITERT_WARNING kLiteRtLogSeverityWarning
+#define LITERT_ERROR kLiteRtLogSeverityError
+#define LITERT_SILENT kLiteRtLogSeveritySilent
+
+LiteRtStatus LiteRtCreateLogger(LiteRtLogger* logger);
+LiteRtStatus LiteRtGetMinLoggerSeverity(LiteRtLogger logger,
+                                        LiteRtLogSeverity* min_severity);
+LiteRtStatus LiteRtSetMinLoggerSeverity(LiteRtLogger logger,
+                                        LiteRtLogSeverity min_severity);
+LiteRtStatus LiteRtLoggerLog(LiteRtLogger logger, LiteRtLogSeverity severity,
+                             const char* format, ...);
+void LiteRtDestroyLogger(LiteRtLogger logger);
+
+LiteRtLogger LiteRtGetDefaultLogger();
+LiteRtStatus LiteRtSetDefaultLogger(LiteRtLogger logger);
+LiteRtStatus LiteRtDefaultLoggerLog(LiteRtLogSeverity severity,
+                                    const char* format, ...);
+
+#ifdef __cplusplus
+}
+#endif  // __cplusplus
+
+#define LITERT_LOGGER_LOG_PROD(logger, severity, format, ...)            \
+  {                                                                      \
+    LiteRtLogSeverity __min_severity__;                                  \
+    if (LiteRtGetMinLoggerSeverity(logger, &__min_severity__) !=         \
+        kLiteRtStatusOk) {                                               \
+      __min_severity__ = kLiteRtLogSeverityVerbose;                      \
+    }                                                                    \
+    if (severity >= __min_severity__) {                                  \
+      LiteRtLoggerLog(logger, severity, "[%s:%d] " format, __FILE__,     \
+                      __LINE__, ##__VA_ARGS__);                          \
+    }                                                                    \
+  }
+
+#ifndef NDEBUG
+#define LITERT_LOGGER_LOG LITERT_LOGGER_LOG_PROD
+#else
+#define LITERT_LOGGER_LOG(logger, severity, format, ...) \
+  do {                                                                \
+    LITERT_LOGGER_LOG_PROD(logger, severity, format, ##__VA_ARGS__);  \
+  } while (false)
+#endif
+
+#define LITERT_LOG(severity, format, ...) \
+  LITERT_LOGGER_LOG(LiteRtGetDefaultLogger(), severity, format, ##__VA_ARGS__);
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_LOGGING_H_
diff --git a/tflite/experimental/litert/c/litert_logging_test.cc b/tflite/experimental/litert/c/litert_logging_test.cc
new file mode 100644
index 00000000..afdf12a2
--- /dev/null
+++ b/tflite/experimental/litert/c/litert_logging_test.cc
@@ -0,0 +1,34 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/c/litert_logging.h"
+
+#include <gtest/gtest.h>  // NOLINT: Need when ANDROID_API_LEVEL >= 26
+#include "tflite/experimental/litert/c/litert_common.h"
+
+TEST(Logging, Creation) {
+  LiteRtLogger logger;
+  ASSERT_EQ(LiteRtCreateLogger(&logger), kLiteRtStatusOk);
+  LiteRtDestroyLogger(logger);
+}
+
+TEST(Logging, MinLogging) {
+  LiteRtLogger logger;
+  ASSERT_EQ(LiteRtCreateLogger(&logger), kLiteRtStatusOk);
+  ASSERT_EQ(LiteRtSetMinLoggerSeverity(logger, LITERT_SILENT), kLiteRtStatusOk);
+  LiteRtLogSeverity min_severity;
+  ASSERT_EQ(LiteRtGetMinLoggerSeverity(logger, &min_severity),
+            kLiteRtStatusOk);
+  ASSERT_EQ(min_severity, LITERT_SILENT);
+  LiteRtDestroyLogger(logger);
+}
diff --git a/tflite/experimental/litert/c/litert_model.cc b/tflite/experimental/litert/c/litert_model.cc
new file mode 100644
index 00000000..70279a39
--- /dev/null
+++ b/tflite/experimental/litert/c/litert_model.cc
@@ -0,0 +1,366 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+ +#include "tflite/experimental/litert/c/litert_model.h" + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_op_code.h" +#include "tflite/experimental/litert/cc/litert_buffer_ref.h" +#include "tflite/experimental/litert/core/model/model.h" +#include "tflite/experimental/litert/core/model/model_load.h" + +// +// Model +// + +LiteRtStatus LiteRtCreateModelFromFile(const char* filename, + LiteRtModel* model) { + if (!filename || !model) { + return kLiteRtStatusErrorInvalidArgument; + } + + auto new_model = litert::internal::LoadModelFromFile(filename); + if (!new_model) { + return new_model.Error().Status(); + } + *model = new_model->release(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtCreateModelFromBuffer(const void* buffer_addr, + size_t buffer_size, + LiteRtModel* model) { + if (!buffer_addr || !buffer_size || !model) { + return kLiteRtStatusErrorInvalidArgument; + } + + auto new_model = litert::internal::LoadModelFromBuffer( + litert::BufferRef(buffer_addr, buffer_size)); + if (!new_model) { + return new_model.Error().Status(); + } + *model = new_model->release(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetNumModelSubgraphs(LiteRtModel model, + LiteRtParamIndex* num_subgraphs) { + if (model == nullptr) { + return kLiteRtStatusErrorInvalidArgument; + } + *num_subgraphs = model->Subgraphs().size(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetModelSubgraph(LiteRtModel model, + LiteRtParamIndex subgraph_index, + LiteRtSubgraph* subgraph) { + if (model == nullptr) { + return kLiteRtStatusErrorInvalidArgument; + } + if (subgraph_index >= model->Subgraphs().size()) { + return kLiteRtStatusErrorIndexOOB; + } + *subgraph = &model->Subgraph(subgraph_index); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetMainModelSubgraphIndex( + LiteRtModel model, LiteRtParamIndex* main_subgraph_index) { + if (!model || !main_subgraph_index) { + return kLiteRtStatusErrorInvalidArgument; + } + *main_subgraph_index = LiteRtModelT::kMainSubgraphIndex; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetModelMetadata(LiteRtModel model, const char* metadata_key, + const void** metadata_buffer, + size_t* metadata_buffer_size) { + if (!model || !metadata_key || !metadata_buffer || !metadata_buffer_size) { + return kLiteRtStatusErrorInvalidArgument; + } + auto m_buf = model->FindMetadata(metadata_key); + if (!m_buf) { + return m_buf.Error().Status(); + } + *metadata_buffer = m_buf->Data(); + *metadata_buffer_size = m_buf->Size(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetNumModelSignatures(LiteRtModel model, + LiteRtParamIndex* num_signatures) { + if (!model || !num_signatures) { + return kLiteRtStatusErrorInvalidArgument; + } + *num_signatures = model->Signatures().size(); + return kLiteRtStatusOk; +} + +// Get the signature at the given index in the model +LiteRtStatus LiteRtGetModelSignature(LiteRtModel model, + LiteRtParamIndex signature_index, + LiteRtSignature* signature) { + if (!model || !signature) { + return kLiteRtStatusErrorInvalidArgument; + } + if (signature_index >= model->Signatures().size()) { + return kLiteRtStatusErrorIndexOOB; + } + *signature = model->Signatures().at(signature_index); + return kLiteRtStatusOk; +} + +void LiteRtDestroyModel(LiteRtModel model) { delete model; } + +LiteRtStatus LiteRtPushOp(LiteRtOpList op_list, LiteRtOp op) { + if (!op_list || !op) { + return kLiteRtStatusErrorInvalidArgument; + } + 
+  return kLiteRtStatusOk;
+}
+
+//
+// Signature
+//
+
+LiteRtStatus LiteRtGetDefaultSignatureKey(const char** signature_key) {
+  if (!signature_key) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *signature_key = LiteRtSignatureT::kDefaultSignatureKey.data();
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetSignatureKey(LiteRtSignature signature,
+                                   const char** signature_key) {
+  if (!signature || !signature_key) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *signature_key = signature->Key().data();
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetSignatureSubgraph(LiteRtSignature signature,
+                                        LiteRtSubgraph* subgraph) {
+  if (signature == nullptr) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *subgraph = &signature->GetSubgraph();
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetNumSignatureInputs(LiteRtSignature signature,
+                                         LiteRtParamIndex* num_inputs) {
+  if (!signature || !num_inputs) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *num_inputs = signature->InputNames().size();
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetSignatureInputName(LiteRtSignature signature,
+                                         LiteRtParamIndex input_idx,
+                                         const char** input_name) {
+  if (!signature || !input_name) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  if (input_idx >= signature->InputNames().size()) {
+    return kLiteRtStatusErrorIndexOOB;
+  }
+  *input_name = signature->InputNames().at(input_idx).data();
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetNumSignatureOutputs(LiteRtSignature signature,
+                                          LiteRtParamIndex* num_outputs) {
+  if (!signature || !num_outputs) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *num_outputs = signature->OutputNames().size();
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetSignatureOutputName(LiteRtSignature signature,
+                                          LiteRtParamIndex output_idx,
+                                          const char** output_name) {
+  if (!signature || !output_name) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  if (output_idx >= signature->OutputNames().size()) {
+    return kLiteRtStatusErrorIndexOOB;
+  }
+  *output_name = signature->OutputNames().at(output_idx).data();
+  return kLiteRtStatusOk;
+}
+
+//
+// Subgraph
+//
+
+LiteRtStatus LiteRtGetSubgraphInputs(LiteRtSubgraph subgraph,
+                                     LiteRtParamIndex* num_inputs,
+                                     LiteRtTensorArray* inputs) {
+  *num_inputs = subgraph->Inputs().size();
+  *inputs = subgraph->Inputs().data();
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetSubgraphOutputs(LiteRtSubgraph subgraph,
+                                      LiteRtParamIndex* num_outputs,
+                                      LiteRtTensorArray* outputs) {
+  *num_outputs = subgraph->Outputs().size();
+  *outputs = subgraph->Outputs().data();
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetSubgraphOps(LiteRtSubgraph subgraph,
+                                  LiteRtParamIndex* num_ops,
+                                  LiteRtOpArray* ops) {
+  *num_ops = subgraph->Ops().size();
+  *ops = subgraph->Ops().data();
+  return kLiteRtStatusOk;
+}
+
+//
+// Op
+//
+
+LiteRtStatus LiteRtGetOpOutputs(LiteRtOp op, LiteRtParamIndex* num_outputs,
+                                LiteRtTensorArray* outputs) {
+  *num_outputs = op->Outputs().size();
+  *outputs = op->Outputs().data();
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetOpInputs(LiteRtOp op, LiteRtParamIndex* num_inputs,
+                               LiteRtTensorArray* inputs) {
+  *num_inputs = op->Inputs().size();
+  *inputs = op->Inputs().data();
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetOpCode(LiteRtOp op, LiteRtOpCode* code) {
+  *code = op->OpCode();
+  return kLiteRtStatusOk;
+}
+
+//
+// Tensor
+//
+
+LiteRtStatus LiteRtGetWeightsBytes(LiteRtWeights weights, const void** addr,
+                                   size_t* size) {
+  *addr = weights->Buf().Data();
+  *size = weights->Buf().Size();
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetTensorWeights(LiteRtTensor tensor,
+                                    LiteRtWeights* weights) {
+  *weights = &tensor->Weights();
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetTensorUses(LiteRtTensor tensor,
+                                 LiteRtParamIndex* num_uses,
+                                 LiteRtOpArray* use_users,
+                                 LiteRtParamIndex** use_user_arg_inds) {
+  *num_uses = tensor->Users().size();
+  *use_users = tensor->Users().data();
+  *use_user_arg_inds = tensor->UserArgInds().data();
+  return kLiteRtStatusOk;
+}
+
+// Null if subgraph input or constant.
+LiteRtStatus LiteRtGetTensorDefiningOp(LiteRtTensor tensor,
+                                       bool* has_defining_op,
+                                       LiteRtTensorDefiningOp* defining_op) {
+  if (tensor->DefiningOp() != nullptr) {
+    *has_defining_op = true;
+    defining_op->op = tensor->DefiningOp();
+    defining_op->op_output_index = tensor->DefiningOpOutInd();
+  } else {
+    *has_defining_op = false;
+  }
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetTensorTypeId(LiteRtTensor tensor,
+                                   LiteRtTensorTypeId* type_id) {
+  *type_id = tensor->Type().first;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetUnrankedTensorType(
+    LiteRtTensor tensor, LiteRtUnrankedTensorType* unranked_tensor_type) {
+  if (tensor->Type().first != kLiteRtUnrankedTensorType) {
+    return kLiteRtStatusErrorInvalidIrType;
+  }
+  *unranked_tensor_type = tensor->Type().second.unranked_tensor_type;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetRankedTensorType(
+    LiteRtTensor tensor, LiteRtRankedTensorType* ranked_tensor_type) {
+  if (tensor->Type().first != kLiteRtRankedTensorType) {
+    return kLiteRtStatusErrorInvalidIrType;
+  }
+  *ranked_tensor_type = tensor->Type().second.ranked_tensor_type;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetTensorName(LiteRtTensor tensor, const char** name) {
+  *name = tensor->Name().data();
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetQuantizationTypeId(LiteRtTensor tensor,
+                                         LiteRtQuantizationTypeId* q_type_id) {
+  *q_type_id = tensor->Qparams().first;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetPerTensorQuantization(
+    LiteRtTensor tensor, LiteRtQuantizationPerTensor* per_tensor_quantization) {
+  if (tensor->Qparams().first != kLiteRtQuantizationPerTensor) {
+    return kLiteRtStatusErrorInvalidIrType;
+  }
+  auto& per_tensor = tensor->Qparams().second.per_tensor;
+  per_tensor_quantization->scale = per_tensor.scale;
+  per_tensor_quantization->zero_point = per_tensor.zero_point;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetPerChannelQuantization(
+    LiteRtTensor tensor,
+    LiteRtQuantizationPerChannel* per_channel_quantization) {
+  if (tensor->Qparams().first != kLiteRtQuantizationPerChannel) {
+    return kLiteRtStatusErrorInvalidIrType;
+  }
+  auto& per_channel = tensor->Qparams().second.per_channel;
+  per_channel_quantization->scales = per_channel.scales;
+  per_channel_quantization->zero_points = per_channel.zero_points;
+  per_channel_quantization->num_channels = per_channel.num_channels;
+  per_channel_quantization->quantized_dimension =
+      per_channel.quantized_dimension;
+  return kLiteRtStatusOk;
+}
diff --git a/tflite/experimental/litert/c/litert_model.h b/tflite/experimental/litert/c/litert_model.h
new file mode 100644
index 00000000..3d16a85c
--- /dev/null
+++ b/tflite/experimental/litert/c/litert_model.h
@@ -0,0 +1,347 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
diff --git a/tflite/experimental/litert/c/litert_model.h b/tflite/experimental/litert/c/litert_model.h
new file mode 100644
index 00000000..3d16a85c
--- /dev/null
+++ b/tflite/experimental/litert/c/litert_model.h
@@ -0,0 +1,347 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_MODEL_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_MODEL_H_
+
+#include <stdbool.h>  // NOLINT: To use bool type in C
+#include <stddef.h>
+#include <stdint.h>
+
+#include "tflite/core/c/c_api_types.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_layout.h"
+#include "tflite/experimental/litert/c/litert_op_code.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+//
+// Handles + Common
+//
+
+// Constant data behind a tensor stored in the model.
+LITERT_DEFINE_HANDLE(LiteRtWeights);
+
+// Values/edges of the model's graph.
+LITERT_DEFINE_HANDLE(LiteRtTensor);
+LITERT_DEFINE_HANDLE_ARRAY(LiteRtTensor);
+
+// Operations/nodes of the model's graph.
+LITERT_DEFINE_HANDLE(LiteRtOp);
+LITERT_DEFINE_HANDLE_ARRAY(LiteRtOp);
+
+// Fundamental block of a program, i.e. a function body.
+LITERT_DEFINE_HANDLE(LiteRtSubgraph);
+LITERT_DEFINE_HANDLE_ARRAY(LiteRtSubgraph);
+
+// Signature of the model.
+LITERT_DEFINE_HANDLE(LiteRtSignature);
+
+// A collection of subgraphs + metadata + signatures.
+LITERT_DEFINE_HANDLE(LiteRtModel);
+
+// Append-only list of ops.
+LITERT_DEFINE_HANDLE(LiteRtOpList);
+
+// For indexing into litert collections or counting litert things.
+typedef uint64_t LiteRtParamIndex;
+
+//
+// LiteRtTensor + Types
+//
+
+// Get the string name associated with this tensor. This is an optional
+// attribute; if it is not set, a zero-length string is returned.
+LiteRtStatus LiteRtGetTensorName(LiteRtTensor tensor, const char** name);
+
+// TENSOR TYPES
+
+// Primitive types for elements in a tensor.
+typedef enum {
+  kLiteRtElementTypeNone = kTfLiteNoType,
+  kLiteRtElementTypeBool = kTfLiteBool,
+  kLiteRtElementTypeInt4 = kTfLiteInt4,
+  kLiteRtElementTypeInt8 = kTfLiteInt8,
+  kLiteRtElementTypeInt16 = kTfLiteInt16,
+  kLiteRtElementTypeInt32 = kTfLiteInt32,
+  kLiteRtElementTypeInt64 = kTfLiteInt64,
+  kLiteRtElementTypeUInt8 = kTfLiteUInt8,
+  kLiteRtElementTypeUInt16 = kTfLiteUInt16,
+  kLiteRtElementTypeUInt32 = kTfLiteUInt32,
+  kLiteRtElementTypeUInt64 = kTfLiteUInt64,
+  kLiteRtElementTypeFloat16 = kTfLiteFloat16,
+  kLiteRtElementTypeBFloat16 = kTfLiteBFloat16,
+  kLiteRtElementTypeFloat32 = kTfLiteFloat32,
+  kLiteRtElementTypeFloat64 = kTfLiteFloat64,
+  kLiteRtElementTypeComplex64 = kTfLiteComplex64,
+  kLiteRtElementTypeComplex128 = kTfLiteComplex128,
+  kLiteRtElementTypeTfResource = kTfLiteResource,
+  kLiteRtElementTypeTfString = kTfLiteString,
+  kLiteRtElementTypeTfVariant = kTfLiteVariant,
+} LiteRtElementType;
+
+// Tensor whose rank is dynamic.
+typedef struct {
+  // The primitive element type of the constituent data.
+  LiteRtElementType element_type;
+} LiteRtUnrankedTensorType;
+
+// Tensor whose rank is static but whose dimensions may be dynamic.
+typedef struct {
+  // The primitive element type of the constituent data.
+  LiteRtElementType element_type;
+
+  // Shape information.
+  LiteRtLayout layout;
+} LiteRtRankedTensorType;
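Since the layout carries a rank and a dimensions array (the same fields litert_model_test.cc reads through ranked.layout below), a caller can derive an element count. A minimal sketch, assuming every dimension is static and that the field names match those exercised by the tests; `LiteRtNumElements` is an illustrative helper, not part of the API:

#include <stdint.h>

// Element count of a ranked tensor type; assumes no dynamic dimensions.
static int64_t LiteRtNumElements(const LiteRtRankedTensorType* type) {
  int64_t num_elements = 1;
  for (uint32_t i = 0; i < type->layout.rank; ++i) {
    num_elements *= type->layout.dimensions[i];
  }
  return num_elements;
}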
+
+// The identifier for the tensor type union.
+typedef enum {
+  // Type with a fixed rank and possibly dynamic dimensions.
+  kLiteRtRankedTensorType = 0,
+
+  // Type with dynamic rank.
+  kLiteRtUnrankedTensorType = 1,
+} LiteRtTensorTypeId;
+
+// Get type identifier from tensor.
+LiteRtStatus LiteRtGetTensorTypeId(LiteRtTensor tensor,
+                                   LiteRtTensorTypeId* type_id);
+
+// Get unranked tensor type info; returns a bad status if the tensor is not
+// unranked.
+LiteRtStatus LiteRtGetUnrankedTensorType(
+    LiteRtTensor tensor, LiteRtUnrankedTensorType* unranked_tensor_type);
+
+// Get ranked tensor type info; returns a bad status if the tensor is not
+// ranked.
+LiteRtStatus LiteRtGetRankedTensorType(
+    LiteRtTensor tensor, LiteRtRankedTensorType* ranked_tensor_type);
+
+// QUANTIZATION
+
+// Schema for tensors quantized with one set of q-params.
+typedef struct {
+  // Scaling factor.
+  float scale;
+
+  // The value that float:0 maps to in q-space.
+  int64_t zero_point;
+} LiteRtQuantizationPerTensor;
+
+// Schema for tensors quantized with one set of q-params per channel.
+typedef struct {
+  int32_t quantized_dimension;
+  uint64_t num_channels;
+  float* scales;
+  int64_t* zero_points;
+} LiteRtQuantizationPerChannel;
+
+// The identifier for the quantization scheme type union.
+typedef enum {
+  // Tag for tensors without quantization.
+  kLiteRtQuantizationNone = 0,
+
+  // Basic quantization, one set of q-params per tensor.
+  kLiteRtQuantizationPerTensor = 1,
+
+  // [NOT IMPLEMENTED YET] Q-params for each channel across a single dimension.
+  kLiteRtQuantizationPerChannel = 2,
+
+  // [NOT IMPLEMENTED YET] Q-params across blocks of fixed size (e.g. 2048).
+  kLiteRtQuantizationBlockWise = 3,
+} LiteRtQuantizationTypeId;
+
+// Get the identifier for the type of quantization for a given tensor.
+LiteRtStatus LiteRtGetQuantizationTypeId(LiteRtTensor tensor,
+                                         LiteRtQuantizationTypeId* q_type_id);
+
+// Get the per-tensor quantization information for a given tensor if it has it.
+LiteRtStatus LiteRtGetPerTensorQuantization(
+    LiteRtTensor tensor, LiteRtQuantizationPerTensor* per_tensor_quantization);
+
+// Get the per-channel quantization information for a given tensor if it has
+// it.
+LiteRtStatus LiteRtGetPerChannelQuantization(
+    LiteRtTensor tensor,
+    LiteRtQuantizationPerChannel* per_channel_quantization);
+
+// EDGES
+
+// Information about the op that defines a tensor.
+typedef struct LiteRtTensorDefiningOp {
+  // The defining op itself.
+  LiteRtOp op;
+
+  // The op output index that defines the specific tensor.
+  LiteRtParamIndex op_output_index;
+} LiteRtTensorDefiningOp;
+
+// Information about a reference to a tensor in the graph.
+typedef struct LiteRtTensorUserOp {
+  // The referring op itself.
+  LiteRtOp op;
+
+  // The index of the operand at which the op references the tensor.
+  LiteRtParamIndex op_input_index;
+} LiteRtTensorUserOp;
+
+// Get all the ops that reference the given tensor, and at which operand index.
+LiteRtStatus LiteRtGetTensorUses(LiteRtTensor tensor,
+                                 LiteRtParamIndex* num_uses,
+                                 LiteRtOpArray* users,
+                                 LiteRtParamIndex** user_arg_inds);
+
+// Get the op that defines this tensor and the corresponding output index. If
+// the tensor is a subgraph input or constant, has_defining_op will be false.
+LiteRtStatus LiteRtGetTensorDefiningOp(LiteRtTensor tensor,
+                                       bool* has_defining_op,
+                                       LiteRtTensorDefiningOp* defining_op);
+
+// WEIGHTS (constant data)
+
+// Get static weights associated with a given tensor. All tensors have
+// weights; null weights have size = 0.
+LiteRtStatus LiteRtGetTensorWeights(LiteRtTensor tensor,
+                                    LiteRtWeights* weights);
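A short sketch of how the edge APIs above compose into a graph walk: find the op that produces a tensor and every op that consumes it. Status checks are collapsed for brevity, and `VisitTensorEdges` is an illustrative name, not part of the API:

#include <stdbool.h>

#include "tflite/experimental/litert/c/litert_model.h"

static void VisitTensorEdges(LiteRtTensor tensor) {
  bool has_defining_op;
  LiteRtTensorDefiningOp def;
  if (LiteRtGetTensorDefiningOp(tensor, &has_defining_op, &def) ==
          kLiteRtStatusOk &&
      has_defining_op) {
    // def.op produced this tensor as its def.op_output_index-th output.
  }

  LiteRtParamIndex num_uses;
  LiteRtOpArray users;
  LiteRtParamIndex* user_arg_inds;
  if (LiteRtGetTensorUses(tensor, &num_uses, &users, &user_arg_inds) ==
      kLiteRtStatusOk) {
    for (LiteRtParamIndex i = 0; i < num_uses; ++i) {
      // users[i] consumes this tensor as its user_arg_inds[i]-th input.
    }
  }
}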
+
+//
+// LiteRtWeights
+//
+
+// Get the opaque byte array from the given tensor weights.
+LiteRtStatus LiteRtGetWeightsBytes(LiteRtWeights weights, const void** addr,
+                                   size_t* size);
+
+//
+// LiteRtOp
+//
+
+// Get the code corresponding to the operation type of the given op.
+LiteRtStatus LiteRtGetOpCode(LiteRtOp op, LiteRtOpCode* code);
+
+// Get input tensors of given op.
+LiteRtStatus LiteRtGetOpInputs(LiteRtOp op, LiteRtParamIndex* num_inputs,
+                               LiteRtTensorArray* inputs);
+
+// Get output tensors of given op.
+LiteRtStatus LiteRtGetOpOutputs(LiteRtOp op, LiteRtParamIndex* num_outputs,
+                                LiteRtTensorArray* outputs);
+
+//
+// LiteRtSubgraph
+//
+
+// Get input tensors for given subgraph.
+LiteRtStatus LiteRtGetSubgraphInputs(LiteRtSubgraph subgraph,
+                                     LiteRtParamIndex* num_inputs,
+                                     LiteRtTensorArray* inputs);
+
+// Get output tensors for given subgraph.
+LiteRtStatus LiteRtGetSubgraphOutputs(LiteRtSubgraph subgraph,
+                                      LiteRtParamIndex* num_outputs,
+                                      LiteRtTensorArray* outputs);
+
+// Get all ops in the given subgraph in topological order.
+LiteRtStatus LiteRtGetSubgraphOps(LiteRtSubgraph subgraph,
+                                  LiteRtParamIndex* num_ops,
+                                  LiteRtOpArray* ops);
+
+//
+// LiteRtSignature
+//
+
+// Default signature key. This is the key that is used if the model does not
+// define any signatures.
+LiteRtStatus LiteRtGetDefaultSignatureKey(const char** signature_key);
+
+// Get the signature key string defined in the model.
+LiteRtStatus LiteRtGetSignatureKey(LiteRtSignature signature,
+                                   const char** signature_key);
+
+// Get the associated subgraph for the given signature.
+LiteRtStatus LiteRtGetSignatureSubgraph(LiteRtSignature signature,
+                                        LiteRtSubgraph* subgraph);
+
+// Get the number of inputs for the given signature.
+LiteRtStatus LiteRtGetNumSignatureInputs(LiteRtSignature signature,
+                                         LiteRtParamIndex* num_inputs);
+
+// Get the name of the i-th input tensor for the given signature.
+LiteRtStatus LiteRtGetSignatureInputName(LiteRtSignature signature,
+                                         LiteRtParamIndex input_idx,
+                                         const char** input_name);
+
+// Get the number of outputs for the given signature.
+LiteRtStatus LiteRtGetNumSignatureOutputs(LiteRtSignature signature,
+                                          LiteRtParamIndex* num_outputs);
+
+// Get the name of the i-th output tensor for the given signature.
+LiteRtStatus LiteRtGetSignatureOutputName(LiteRtSignature signature,
+                                          LiteRtParamIndex output_idx,
+                                          const char** output_name);
+
+//
+// LiteRtModel
+//
+
+LiteRtStatus LiteRtCreateModelFromFile(const char* filename,
+                                       LiteRtModel* model);
+
+LiteRtStatus LiteRtCreateModelFromBuffer(const void* buffer_addr,
+                                         size_t buffer_size,
+                                         LiteRtModel* model);
+
+// Get the metadata buffer associated with given key if it exists.
+LiteRtStatus LiteRtGetModelMetadata(LiteRtModel model, const char* metadata_key,
+                                    const void** metadata_buffer,
+                                    size_t* metadata_buffer_size);
+
+// Get the index of the entry subgraph.
+// TODO: b/365299994 - Figure out signatures.
+LiteRtStatus LiteRtGetMainModelSubgraphIndex(
+    LiteRtModel model, LiteRtParamIndex* main_subgraph_index);
+
+// Get number of subgraphs in model.
+LiteRtStatus LiteRtGetNumModelSubgraphs(LiteRtModel model,
+                                        LiteRtParamIndex* num_subgraphs);
+
+// Get subgraph at given index in model.
+LiteRtStatus LiteRtGetModelSubgraph(LiteRtModel model,
+                                    LiteRtParamIndex subgraph_index,
+                                    LiteRtSubgraph* subgraph);
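Taken together, the model APIs above support a simple load-inspect-destroy flow (LiteRtDestroyModel is declared just below). A minimal sketch; the file path and metadata key are placeholders and error handling is abbreviated:

#include <stddef.h>

#include "tflite/experimental/litert/c/litert_model.h"

static void InspectModel(const char* path) {
  LiteRtModel model = NULL;
  if (LiteRtCreateModelFromFile(path, &model) != kLiteRtStatusOk) {
    return;
  }

  LiteRtParamIndex main_index;
  if (LiteRtGetMainModelSubgraphIndex(model, &main_index) == kLiteRtStatusOk) {
    LiteRtSubgraph main_subgraph;
    LiteRtGetModelSubgraph(model, main_index, &main_subgraph);
  }

  const void* metadata;
  size_t metadata_size;
  // Fails cleanly if the key is absent; "my_key" is a placeholder.
  (void)LiteRtGetModelMetadata(model, "my_key", &metadata, &metadata_size);

  LiteRtDestroyModel(model);
}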
+
+// Get the number of signatures defined in the model.
+LiteRtStatus LiteRtGetNumModelSignatures(LiteRtModel model,
+                                         LiteRtParamIndex* num_signatures);
+
+// Get the signature at the given index in the model.
+LiteRtStatus LiteRtGetModelSignature(LiteRtModel model,
+                                     LiteRtParamIndex signature_index,
+                                     LiteRtSignature* signature);
+
+// Destroy the given model, freeing any memory it owns.
+void LiteRtDestroyModel(LiteRtModel model);
+
+//
+// Utility Types
+//
+
+// An append-only list of ops.
+LiteRtStatus LiteRtPushOp(LiteRtOpList op_list, LiteRtOp op);
+
+#ifdef __cplusplus
+}
+#endif  // __cplusplus
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_MODEL_H_
diff --git a/tflite/experimental/litert/c/litert_model_test.cc b/tflite/experimental/litert/c/litert_model_test.cc
new file mode 100644
index 00000000..c363cab6
--- /dev/null
+++ b/tflite/experimental/litert/c/litert_model_test.cc
@@ -0,0 +1,359 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/c/litert_model.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <string>
+#include <utility>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_op_code.h"
+#include "tflite/experimental/litert/cc/litert_buffer_ref.h"
+#include "tflite/experimental/litert/core/model/model.h"
+#include "tflite/experimental/litert/core/util/flatbuffer_tools.h"
+#include "tflite/experimental/litert/test/test_macros.h"
+
+namespace {
+
+using ::litert::BufferRef;
+using ::litert::internal::MakeTflBuffer;
+using ::testing::ElementsAreArray;
+
+TEST(LiteRtWeightsTest, GetNullWeights) {
+  LiteRtWeightsT weights = {};
+
+  const void* addr;
+  size_t size;
+  LITERT_ASSERT_STATUS_OK(LiteRtGetWeightsBytes(&weights, &addr, &size));
+
+  EXPECT_EQ(addr, nullptr);
+  EXPECT_EQ(size, 0);
+}
+
+TEST(LiteRtWeightsTest, GetWeights) {
+  LiteRtWeightsT weights;
+  detail::SetTflBuffer(weights, MakeTflBuffer({1, 2, 3}));
+
+  const void* addr;
+  size_t size;
+  LITERT_ASSERT_STATUS_OK(LiteRtGetWeightsBytes(&weights, &addr, &size));
+
+  EXPECT_NE(addr, nullptr);
+  EXPECT_EQ(size, 3 * sizeof(int32_t));
+
+  EXPECT_THAT(absl::MakeConstSpan(reinterpret_cast<const int32_t*>(addr), 3),
+              ElementsAreArray({1, 2, 3}));
+}
+
+TEST(LiteRtTensorTest, GetUnrankedType) {
+  static constexpr auto kElementType = kLiteRtElementTypeFloat32;
+  static constexpr auto kId = kLiteRtUnrankedTensorType;
+
+  TensorType type;
+  type.first = kId;
+  type.second.unranked_tensor_type.element_type = kElementType;
+
+  LiteRtTensorT tensor;
+  tensor.SetType(std::move(type));
+
+  LiteRtTensorTypeId id;
+  LITERT_ASSERT_STATUS_OK(LiteRtGetTensorTypeId(&tensor, &id));
+  ASSERT_EQ(id, kId);
+
+  LiteRtUnrankedTensorType unranked;
+  LITERT_ASSERT_STATUS_OK(LiteRtGetUnrankedTensorType(&tensor, &unranked));
+  EXPECT_EQ(unranked.element_type, kElementType);
+}
+
+TEST(LiteRtTensorTest, GetRankedTensorType) {
+  static constexpr auto kElementType = kLiteRtElementTypeFloat32;
+  static constexpr auto
kId = kLiteRtRankedTensorType; + + LiteRtTensorT tensor; + tensor.SetType(MakeRankedTensorType(kElementType, {3, 3})); + + LiteRtTensorTypeId id; + LITERT_ASSERT_STATUS_OK(LiteRtGetTensorTypeId(&tensor, &id)); + ASSERT_EQ(id, kId); + + LiteRtRankedTensorType ranked; + LITERT_ASSERT_STATUS_OK(LiteRtGetRankedTensorType(&tensor, &ranked)); + EXPECT_EQ(ranked.element_type, kElementType); + ASSERT_EQ(ranked.layout.rank, 2); + EXPECT_THAT(absl::MakeConstSpan(ranked.layout.dimensions, 2), + ElementsAreArray({3, 3})); +} + +TEST(LiteRtTensorTest, GetUses) { + LiteRtTensorT tensor; + + LiteRtOpT user; + tensor.Users().push_back(&user); + tensor.UserArgInds().push_back(0); + + LiteRtOpT other_user; + tensor.Users().push_back(&other_user); + tensor.UserArgInds().push_back(1); + + LiteRtParamIndex num_uses; + LiteRtOpArray actual_users; + LiteRtParamIndex* user_arg_inds; + LITERT_ASSERT_STATUS_OK( + LiteRtGetTensorUses(&tensor, &num_uses, &actual_users, &user_arg_inds)); + + ASSERT_EQ(num_uses, 2); + EXPECT_THAT(absl::MakeConstSpan(actual_users, 2), + ElementsAreArray({&user, &other_user})); + EXPECT_THAT(absl::MakeConstSpan(user_arg_inds, 2), ElementsAreArray({0, 1})); +} + +TEST(LiteRtTensorTest, GetDefiningOp) { + LiteRtTensorT tensor; + + LiteRtOpT def_op; + tensor.SetDefiningOp(def_op, 0); + + LiteRtTensorDefiningOp actual_def_op; + bool has_defining_op; + LITERT_ASSERT_STATUS_OK( + LiteRtGetTensorDefiningOp(&tensor, &has_defining_op, &actual_def_op)); + ASSERT_TRUE(has_defining_op); + EXPECT_EQ(actual_def_op.op, &def_op); + EXPECT_EQ(actual_def_op.op_output_index, 0); +} + +TEST(LiteRtTensorTest, NoDefiningOp) { + LiteRtTensorT tensor; + + LiteRtTensorDefiningOp actual_def_op; + bool has_defining_op; + LITERT_ASSERT_STATUS_OK( + LiteRtGetTensorDefiningOp(&tensor, &has_defining_op, &actual_def_op)); + ASSERT_FALSE(has_defining_op); +} + +TEST(LiteRtTensorTest, Name) { + static constexpr const char kName[] = "foo"; + + LiteRtTensorT tensor; + tensor.SetName(std::string(kName)); + + const char* name; + LITERT_ASSERT_STATUS_OK(LiteRtGetTensorName(&tensor, &name)); + EXPECT_STREQ(name, kName); +} + +TEST(LiteRtTensorTest, QuantizationNone) { + LiteRtTensorT tensor; + + LiteRtQuantizationTypeId q_type_id; + LITERT_ASSERT_STATUS_OK(LiteRtGetQuantizationTypeId(&tensor, &q_type_id)); + EXPECT_EQ(q_type_id, kLiteRtQuantizationNone); + + LiteRtQuantizationPerTensor per_tensor_quantization; + EXPECT_NE(LiteRtGetPerTensorQuantization(&tensor, &per_tensor_quantization), + kLiteRtStatusOk); +} + +TEST(LiteRtTensorTest, QuantizationPerTensor) { + static constexpr auto kScale = 1.0; + static constexpr auto kZeroPoint = 1; + + LiteRtTensorT tensor; + tensor.SetQarams(MakePerTensorQuantization(kScale, kZeroPoint)); + + LiteRtQuantizationTypeId q_type_id; + LITERT_ASSERT_STATUS_OK(LiteRtGetQuantizationTypeId(&tensor, &q_type_id)); + ASSERT_EQ(q_type_id, kLiteRtQuantizationPerTensor); + + LiteRtQuantizationPerTensor per_tensor_quantization; + LITERT_ASSERT_STATUS_OK( + LiteRtGetPerTensorQuantization(&tensor, &per_tensor_quantization)); + + EXPECT_EQ(per_tensor_quantization.scale, kScale); + EXPECT_EQ(per_tensor_quantization.zero_point, kZeroPoint); +} + +TEST(LiteRtTensorTest, QuantizationPerChannel) { + static constexpr size_t kNumChannels = 2; + static constexpr size_t kQuantizedDimension = 0; + static constexpr float kScales[kNumChannels] = {1.0, 2.0}; + static constexpr int64_t kZps[kNumChannels] = {2, 3}; + + LiteRtTensorT tensor; + + { + auto per_channel = + MakePerChannelQuantization(kScales, kZps, 
kQuantizedDimension, tensor); + tensor.SetQarams(per_channel); + } + + LiteRtQuantizationTypeId q_type_id; + LITERT_ASSERT_STATUS_OK(LiteRtGetQuantizationTypeId(&tensor, &q_type_id)); + ASSERT_EQ(q_type_id, kLiteRtQuantizationPerChannel); + + LiteRtQuantizationPerChannel per_channel_quantization; + LITERT_ASSERT_STATUS_OK( + LiteRtGetPerChannelQuantization(&tensor, &per_channel_quantization)); + + EXPECT_THAT( + absl::MakeConstSpan(per_channel_quantization.scales, kNumChannels), + testing::ElementsAreArray(kScales)); + EXPECT_THAT( + absl::MakeConstSpan(per_channel_quantization.zero_points, kNumChannels), + testing::ElementsAreArray(kZps)); + ASSERT_EQ(per_channel_quantization.num_channels, kNumChannels); + ASSERT_EQ(per_channel_quantization.quantized_dimension, kQuantizedDimension); +} + +TEST(LiteRtOpTest, GetOpCode) { + static constexpr auto kCode = kLiteRtOpCodeTflCustom; + + LiteRtOpT op; + op.SetOpCode(kCode); + + LiteRtOpCode code; + LITERT_ASSERT_STATUS_OK(LiteRtGetOpCode(&op, &code)); + EXPECT_EQ(code, kCode); +} + +TEST(LiteRtOpTest, GetInputs) { + LiteRtTensorT input1; + LiteRtTensorT input2; + + LiteRtOpT op; + op.Inputs().push_back(&input1); + op.Inputs().push_back(&input2); + + LiteRtTensorArray inputs; + LiteRtParamIndex num_inputs; + LITERT_ASSERT_STATUS_OK(LiteRtGetOpInputs(&op, &num_inputs, &inputs)); + ASSERT_EQ(num_inputs, 2); + EXPECT_THAT(absl::MakeConstSpan(inputs, num_inputs), + ElementsAreArray({&input1, &input2})); +} + +TEST(LiteRtOpTest, GetOutputs) { + LiteRtTensorT output1; + LiteRtTensorT output2; + + LiteRtOpT op; + op.Outputs().push_back(&output1); + op.Outputs().push_back(&output2); + + LiteRtTensorArray outputs; + LiteRtParamIndex num_outputs; + LITERT_ASSERT_STATUS_OK(LiteRtGetOpOutputs(&op, &num_outputs, &outputs)); + ASSERT_EQ(num_outputs, 2); + EXPECT_THAT(absl::MakeConstSpan(outputs, num_outputs), + ElementsAreArray({&output1, &output2})); +} + +TEST(LiteRtSubgraphTest, GetInputs) { + LiteRtTensorT input1; + LiteRtTensorT input2; + + LiteRtSubgraphT subgraph; + subgraph.Inputs().push_back(&input1); + subgraph.Inputs().push_back(&input2); + + LiteRtTensorArray inputs; + LiteRtParamIndex num_inputs; + LITERT_ASSERT_STATUS_OK( + LiteRtGetSubgraphInputs(&subgraph, &num_inputs, &inputs)); + ASSERT_EQ(num_inputs, 2); + EXPECT_THAT(absl::MakeConstSpan(inputs, num_inputs), + ElementsAreArray({&input1, &input2})); +} + +TEST(LiteRtSubgraphTest, GetOutputs) { + LiteRtTensorT output1; + LiteRtTensorT output2; + + LiteRtSubgraphT subgraph; + subgraph.Outputs().push_back(&output1); + subgraph.Outputs().push_back(&output2); + + LiteRtTensorArray outputs; + LiteRtParamIndex num_outputs; + LITERT_ASSERT_STATUS_OK( + LiteRtGetSubgraphOutputs(&subgraph, &num_outputs, &outputs)); + ASSERT_EQ(num_outputs, 2); + EXPECT_THAT(absl::MakeConstSpan(outputs, num_outputs), + ElementsAreArray({&output1, &output2})); +} + +TEST(LiteRtSubgraphTest, GetOps) { + LiteRtSubgraphT subgraph; + auto& op1 = subgraph.EmplaceOp(); + auto& op2 = subgraph.EmplaceOp(); + + LiteRtOpArray ops; + LiteRtParamIndex num_ops; + LITERT_ASSERT_STATUS_OK(LiteRtGetSubgraphOps(&subgraph, &num_ops, &ops)); + ASSERT_EQ(num_ops, 2); + EXPECT_THAT(absl::MakeConstSpan(ops, num_ops), + ElementsAreArray({&op1, &op2})); +} + +TEST(LiteRtModelTest, GetMetadata) { + static constexpr absl::string_view kKey = "KEY"; + static constexpr absl::string_view kData = "DATA"; + + LiteRtModelT model; + model.PushMetadata(kKey, kData); + + const void* metadata; + size_t metadata_size; + LITERT_ASSERT_STATUS_OK( + 
LiteRtGetModelMetadata(&model, kKey.data(), &metadata, &metadata_size)); + EXPECT_EQ(BufferRef(metadata, metadata_size).StrView(), kData); +} + +TEST(LiteRtModelTest, GetSubgraph) { + LiteRtModelT model; + auto& subgraph = model.EmplaceSubgraph(); + + LiteRtSubgraph actual_subgraph; + LITERT_ASSERT_STATUS_OK(LiteRtGetModelSubgraph(&model, 0, &actual_subgraph)); + EXPECT_EQ(actual_subgraph, &subgraph); +} + +TEST(LiteRtModelTest, GetSubgraphOOB) { + LiteRtModelT model; + + LiteRtSubgraph actual_subgraph; + LITERT_ASSERT_STATUS_HAS_CODE( + LiteRtGetModelSubgraph(&model, 0, &actual_subgraph), + kLiteRtStatusErrorIndexOOB); +} + +TEST(LiteRtOpListTest, PushOps) { + LiteRtOpListT op_list; + LiteRtOpT op; + + LITERT_ASSERT_STATUS_OK(LiteRtPushOp(&op_list, &op)); + auto vec = op_list.Vec(); + ASSERT_EQ(vec.size(), 1); + EXPECT_EQ(vec.front(), &op); +} + +} // namespace diff --git a/tflite/experimental/litert/c/litert_op_code.h b/tflite/experimental/litert/c/litert_op_code.h new file mode 100644 index 00000000..c0a13d82 --- /dev/null +++ b/tflite/experimental/litert/c/litert_op_code.h @@ -0,0 +1,245 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_OP_CODE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_OP_CODE_H_ + +#include "tflite/builtin_ops.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef enum { + kLiteRtOpCodeTflAdd = kTfLiteBuiltinAdd, + kLiteRtOpCodeTflAveragePool2d = kTfLiteBuiltinAveragePool2d, + kLiteRtOpCodeTflConcatenation = kTfLiteBuiltinConcatenation, + kLiteRtOpCodeTflConv2d = kTfLiteBuiltinConv2d, + kLiteRtOpCodeTflDepthwiseConv2d = kTfLiteBuiltinDepthwiseConv2d, + kLiteRtOpCodeTflDepthToSpace = kTfLiteBuiltinDepthToSpace, + kLiteRtOpCodeTflDequantize = kTfLiteBuiltinDequantize, + kLiteRtOpCodeTflEmbeddingLookup = kTfLiteBuiltinEmbeddingLookup, + kLiteRtOpCodeTflFloor = kTfLiteBuiltinFloor, + kLiteRtOpCodeTflFullyConnected = kTfLiteBuiltinFullyConnected, + kLiteRtOpCodeTflHashtableLookup = kTfLiteBuiltinHashtableLookup, + kLiteRtOpCodeTflL2Normalization = kTfLiteBuiltinL2Normalization, + kLiteRtOpCodeTflL2Pool2d = kTfLiteBuiltinL2Pool2d, + kLiteRtOpCodeTflLocalResponseNormalization = + kTfLiteBuiltinLocalResponseNormalization, + kLiteRtOpCodeTflLogistic = kTfLiteBuiltinLogistic, + kLiteRtOpCodeTflLshProjection = kTfLiteBuiltinLshProjection, + kLiteRtOpCodeTflLstm = kTfLiteBuiltinLstm, + kLiteRtOpCodeTflMaxPool2d = kTfLiteBuiltinMaxPool2d, + kLiteRtOpCodeTflMul = kTfLiteBuiltinMul, + kLiteRtOpCodeTflRelu = kTfLiteBuiltinRelu, + kLiteRtOpCodeTflReluN1To1 = kTfLiteBuiltinReluN1To1, + kLiteRtOpCodeTflRelu6 = kTfLiteBuiltinRelu6, + kLiteRtOpCodeTflReshape = kTfLiteBuiltinReshape, + kLiteRtOpCodeTflResizeBilinear = kTfLiteBuiltinResizeBilinear, + kLiteRtOpCodeTflRnn = kTfLiteBuiltinRnn, + kLiteRtOpCodeTflSoftmax = kTfLiteBuiltinSoftmax, + kLiteRtOpCodeTflSpaceToDepth = kTfLiteBuiltinSpaceToDepth, + kLiteRtOpCodeTflSvdf = kTfLiteBuiltinSvdf, + 
kLiteRtOpCodeTflTanh = kTfLiteBuiltinTanh, + kLiteRtOpCodeTflConcatEmbeddings = kTfLiteBuiltinConcatEmbeddings, + kLiteRtOpCodeTflSkipGram = kTfLiteBuiltinSkipGram, + kLiteRtOpCodeTflCall = kTfLiteBuiltinCall, + kLiteRtOpCodeTflCustom = kTfLiteBuiltinCustom, + kLiteRtOpCodeTflEmbeddingLookupSparse = kTfLiteBuiltinEmbeddingLookupSparse, + kLiteRtOpCodeTflPad = kTfLiteBuiltinPad, + kLiteRtOpCodeTflUnidirectionalSequenceRnn = + kTfLiteBuiltinUnidirectionalSequenceRnn, + kLiteRtOpCodeTflGather = kTfLiteBuiltinGather, + kLiteRtOpCodeTflBatchToSpaceNd = kTfLiteBuiltinBatchToSpaceNd, + kLiteRtOpCodeTflSpaceToBatchNd = kTfLiteBuiltinSpaceToBatchNd, + kLiteRtOpCodeTflTranspose = kTfLiteBuiltinTranspose, + kLiteRtOpCodeTflMean = kTfLiteBuiltinMean, + kLiteRtOpCodeTflSub = kTfLiteBuiltinSub, + kLiteRtOpCodeTflDiv = kTfLiteBuiltinDiv, + kLiteRtOpCodeTflSqueeze = kTfLiteBuiltinSqueeze, + kLiteRtOpCodeTflUnidirectionalSequenceLstm = + kTfLiteBuiltinUnidirectionalSequenceLstm, + kLiteRtOpCodeTflStridedSlice = kTfLiteBuiltinStridedSlice, + kLiteRtOpCodeTflBidirectionalSequenceRnn = + kTfLiteBuiltinBidirectionalSequenceRnn, + kLiteRtOpCodeTflExp = kTfLiteBuiltinExp, + kLiteRtOpCodeTflTopkV2 = kTfLiteBuiltinTopkV2, + kLiteRtOpCodeTflSplit = kTfLiteBuiltinSplit, + kLiteRtOpCodeTflLogSoftmax = kTfLiteBuiltinLogSoftmax, + kLiteRtOpCodeTflDelegate = kTfLiteBuiltinDelegate, + kLiteRtOpCodeTflBidirectionalSequenceLstm = + kTfLiteBuiltinBidirectionalSequenceLstm, + kLiteRtOpCodeTflCast = kTfLiteBuiltinCast, + kLiteRtOpCodeTflPrelu = kTfLiteBuiltinPrelu, + kLiteRtOpCodeTflMaximum = kTfLiteBuiltinMaximum, + kLiteRtOpCodeTflArgMax = kTfLiteBuiltinArgMax, + kLiteRtOpCodeTflMinimum = kTfLiteBuiltinMinimum, + kLiteRtOpCodeTflLess = kTfLiteBuiltinLess, + kLiteRtOpCodeTflNeg = kTfLiteBuiltinNeg, + kLiteRtOpCodeTflPadv2 = kTfLiteBuiltinPadv2, + kLiteRtOpCodeTflGreater = kTfLiteBuiltinGreater, + kLiteRtOpCodeTflGreaterEqual = kTfLiteBuiltinGreaterEqual, + kLiteRtOpCodeTflLessEqual = kTfLiteBuiltinLessEqual, + kLiteRtOpCodeTflSelect = kTfLiteBuiltinSelect, + kLiteRtOpCodeTflSlice = kTfLiteBuiltinSlice, + kLiteRtOpCodeTflSin = kTfLiteBuiltinSin, + kLiteRtOpCodeTflTransposeConv = kTfLiteBuiltinTransposeConv, + kLiteRtOpCodeTflSparseToDense = kTfLiteBuiltinSparseToDense, + kLiteRtOpCodeTflTile = kTfLiteBuiltinTile, + kLiteRtOpCodeTflExpandDims = kTfLiteBuiltinExpandDims, + kLiteRtOpCodeTflEqual = kTfLiteBuiltinEqual, + kLiteRtOpCodeTflNotEqual = kTfLiteBuiltinNotEqual, + kLiteRtOpCodeTflLog = kTfLiteBuiltinLog, + kLiteRtOpCodeTflSum = kTfLiteBuiltinSum, + kLiteRtOpCodeTflSqrt = kTfLiteBuiltinSqrt, + kLiteRtOpCodeTflRsqrt = kTfLiteBuiltinRsqrt, + kLiteRtOpCodeTflShape = kTfLiteBuiltinShape, + kLiteRtOpCodeTflPow = kTfLiteBuiltinPow, + kLiteRtOpCodeTflArgMin = kTfLiteBuiltinArgMin, + kLiteRtOpCodeTflFakeQuant = kTfLiteBuiltinFakeQuant, + kLiteRtOpCodeTflReduceProd = kTfLiteBuiltinReduceProd, + kLiteRtOpCodeTflReduceMax = kTfLiteBuiltinReduceMax, + kLiteRtOpCodeTflPack = kTfLiteBuiltinPack, + kLiteRtOpCodeTflLogicalOr = kTfLiteBuiltinLogicalOr, + kLiteRtOpCodeTflOneHot = kTfLiteBuiltinOneHot, + kLiteRtOpCodeTflLogicalAnd = kTfLiteBuiltinLogicalAnd, + kLiteRtOpCodeTflLogicalNot = kTfLiteBuiltinLogicalNot, + kLiteRtOpCodeTflUnpack = kTfLiteBuiltinUnpack, + kLiteRtOpCodeTflReduceMin = kTfLiteBuiltinReduceMin, + kLiteRtOpCodeTflFloorDiv = kTfLiteBuiltinFloorDiv, + kLiteRtOpCodeTflReduceAny = kTfLiteBuiltinReduceAny, + kLiteRtOpCodeTflSquare = kTfLiteBuiltinSquare, + kLiteRtOpCodeTflZerosLike = kTfLiteBuiltinZerosLike, + 
kLiteRtOpCodeTflFill = kTfLiteBuiltinFill, + kLiteRtOpCodeTflFloorMod = kTfLiteBuiltinFloorMod, + kLiteRtOpCodeTflRange = kTfLiteBuiltinRange, + kLiteRtOpCodeTflResizeNearestNeighbor = kTfLiteBuiltinResizeNearestNeighbor, + kLiteRtOpCodeTflLeakyRelu = kTfLiteBuiltinLeakyRelu, + kLiteRtOpCodeTflSquaredDifference = kTfLiteBuiltinSquaredDifference, + kLiteRtOpCodeTflMirrorPad = kTfLiteBuiltinMirrorPad, + kLiteRtOpCodeTflAbs = kTfLiteBuiltinAbs, + kLiteRtOpCodeTflSplitV = kTfLiteBuiltinSplitV, + kLiteRtOpCodeTflUnique = kTfLiteBuiltinUnique, + kLiteRtOpCodeTflCeil = kTfLiteBuiltinCeil, + kLiteRtOpCodeTflReverseV2 = kTfLiteBuiltinReverseV2, + kLiteRtOpCodeTflAddN = kTfLiteBuiltinAddN, + kLiteRtOpCodeTflGatherNd = kTfLiteBuiltinGatherNd, + kLiteRtOpCodeTflCos = kTfLiteBuiltinCos, + kLiteRtOpCodeTflWhere = kTfLiteBuiltinWhere, + kLiteRtOpCodeTflRank = kTfLiteBuiltinRank, + kLiteRtOpCodeTflElu = kTfLiteBuiltinElu, + kLiteRtOpCodeTflReverseSequence = kTfLiteBuiltinReverseSequence, + kLiteRtOpCodeTflMatrixDiag = kTfLiteBuiltinMatrixDiag, + kLiteRtOpCodeTflQuantize = kTfLiteBuiltinQuantize, + kLiteRtOpCodeTflMatrixSetDiag = kTfLiteBuiltinMatrixSetDiag, + kLiteRtOpCodeTflRound = kTfLiteBuiltinRound, + kLiteRtOpCodeTflHardSwish = kTfLiteBuiltinHardSwish, + kLiteRtOpCodeTflIf = kTfLiteBuiltinIf, + kLiteRtOpCodeTflWhile = kTfLiteBuiltinWhile, + kLiteRtOpCodeTflNonMaxSuppressionV4 = kTfLiteBuiltinNonMaxSuppressionV4, + kLiteRtOpCodeTflNonMaxSuppressionV5 = kTfLiteBuiltinNonMaxSuppressionV5, + kLiteRtOpCodeTflScatterNd = kTfLiteBuiltinScatterNd, + kLiteRtOpCodeTflSelectV2 = kTfLiteBuiltinSelectV2, + kLiteRtOpCodeTflDensify = kTfLiteBuiltinDensify, + kLiteRtOpCodeTflSegmentSum = kTfLiteBuiltinSegmentSum, + kLiteRtOpCodeTflBatchMatmul = kTfLiteBuiltinBatchMatmul, + kLiteRtOpCodeTflPlaceholderForGreaterOpCodeTfls = + kTfLiteBuiltinPlaceholderForGreaterOpCodes, + kLiteRtOpCodeTflCumsum = kTfLiteBuiltinCumsum, + kLiteRtOpCodeTflCallOnce = kTfLiteBuiltinCallOnce, + kLiteRtOpCodeTflBroadcastTo = kTfLiteBuiltinBroadcastTo, + kLiteRtOpCodeTflRfft2d = kTfLiteBuiltinRfft2d, + kLiteRtOpCodeTflConv3d = kTfLiteBuiltinConv3d, + kLiteRtOpCodeTflImag = kTfLiteBuiltinImag, + kLiteRtOpCodeTflReal = kTfLiteBuiltinReal, + kLiteRtOpCodeTflComplexAbs = kTfLiteBuiltinComplexAbs, + kLiteRtOpCodeTflHashtable = kTfLiteBuiltinHashtable, + kLiteRtOpCodeTflHashtableFind = kTfLiteBuiltinHashtableFind, + kLiteRtOpCodeTflHashtableImport = kTfLiteBuiltinHashtableImport, + kLiteRtOpCodeTflHashtableSize = kTfLiteBuiltinHashtableSize, + kLiteRtOpCodeTflReduceAll = kTfLiteBuiltinReduceAll, + kLiteRtOpCodeTflConv3dTranspose = kTfLiteBuiltinConv3dTranspose, + kLiteRtOpCodeTflVarHandle = kTfLiteBuiltinVarHandle, + kLiteRtOpCodeTflReadVariable = kTfLiteBuiltinReadVariable, + kLiteRtOpCodeTflAssignVariable = kTfLiteBuiltinAssignVariable, + kLiteRtOpCodeTflBroadcastArgs = kTfLiteBuiltinBroadcastArgs, + kLiteRtOpCodeTflRandomStandardNormal = kTfLiteBuiltinRandomStandardNormal, + kLiteRtOpCodeTflBucketize = kTfLiteBuiltinBucketize, + kLiteRtOpCodeTflRandomUniform = kTfLiteBuiltinRandomUniform, + kLiteRtOpCodeTflMultinomial = kTfLiteBuiltinMultinomial, + kLiteRtOpCodeTflGelu = kTfLiteBuiltinGelu, + kLiteRtOpCodeTflDynamicUpdateSlice = kTfLiteBuiltinDynamicUpdateSlice, + kLiteRtOpCodeTflRelu0To1 = kTfLiteBuiltinRelu0To1, + kLiteRtOpCodeTflUnsortedSegmentProd = kTfLiteBuiltinUnsortedSegmentProd, + kLiteRtOpCodeTflUnsortedSegmentMax = kTfLiteBuiltinUnsortedSegmentMax, + kLiteRtOpCodeTflUnsortedSegmentSum = kTfLiteBuiltinUnsortedSegmentSum, + 
kLiteRtOpCodeTflAtan2 = kTfLiteBuiltinAtan2, + kLiteRtOpCodeTflUnsortedSegmentMin = kTfLiteBuiltinUnsortedSegmentMin, + kLiteRtOpCodeTflSign = kTfLiteBuiltinSign, + kLiteRtOpCodeTflBitcast = kTfLiteBuiltinBitcast, + kLiteRtOpCodeTflBitwiseXor = kTfLiteBuiltinBitwiseXor, + kLiteRtOpCodeTflRightShift = kTfLiteBuiltinRightShift, + kLiteRtOpCodeShloLogistic = kTfLiteBuiltinStablehloLogistic, + kLiteRtOpCodeShloAdd = kTfLiteBuiltinStablehloAdd, + kLiteRtOpCodeShloDivide = kTfLiteBuiltinStablehloDivide, + kLiteRtOpCodeShloMultiply = kTfLiteBuiltinStablehloMultiply, + kLiteRtOpCodeShloMaximum = kTfLiteBuiltinStablehloMaximum, + kLiteRtOpCodeShloReshape = kTfLiteBuiltinStablehloReshape, + kLiteRtOpCodeShloClamp = kTfLiteBuiltinStablehloClamp, + kLiteRtOpCodeShloConcatenate = kTfLiteBuiltinStablehloConcatenate, + kLiteRtOpCodeShloBroadcastInDim = kTfLiteBuiltinStablehloBroadcastInDim, + kLiteRtOpCodeShloConvolution = kTfLiteBuiltinStablehloConvolution, + kLiteRtOpCodeShloSlice = kTfLiteBuiltinStablehloSlice, + kLiteRtOpCodeShloCustomCall = kTfLiteBuiltinStablehloCustomCall, + kLiteRtOpCodeShloReduce = kTfLiteBuiltinStablehloReduce, + kLiteRtOpCodeShloAbs = kTfLiteBuiltinStablehloAbs, + kLiteRtOpCodeShloAnd = kTfLiteBuiltinStablehloAnd, + kLiteRtOpCodeShloCosine = kTfLiteBuiltinStablehloCosine, + kLiteRtOpCodeShloExponential = kTfLiteBuiltinStablehloExponential, + kLiteRtOpCodeShloFloor = kTfLiteBuiltinStablehloFloor, + kLiteRtOpCodeShloLog = kTfLiteBuiltinStablehloLog, + kLiteRtOpCodeShloMinimum = kTfLiteBuiltinStablehloMinimum, + kLiteRtOpCodeShloNegate = kTfLiteBuiltinStablehloNegate, + kLiteRtOpCodeShloOr = kTfLiteBuiltinStablehloOr, + kLiteRtOpCodeShloPower = kTfLiteBuiltinStablehloPower, + kLiteRtOpCodeShloRemainder = kTfLiteBuiltinStablehloRemainder, + kLiteRtOpCodeShloRsqrt = kTfLiteBuiltinStablehloRsqrt, + kLiteRtOpCodeShloSelect = kTfLiteBuiltinStablehloSelect, + kLiteRtOpCodeShloSubtract = kTfLiteBuiltinStablehloSubtract, + kLiteRtOpCodeShloTanh = kTfLiteBuiltinStablehloTanh, + kLiteRtOpCodeShloScatter = kTfLiteBuiltinStablehloScatter, + kLiteRtOpCodeShloCompare = kTfLiteBuiltinStablehloCompare, + kLiteRtOpCodeShloConvert = kTfLiteBuiltinStablehloConvert, + kLiteRtOpCodeShloDynamicSlice = kTfLiteBuiltinStablehloDynamicSlice, + kLiteRtOpCodeShloDynamicUpdateSlice = + kTfLiteBuiltinStablehloDynamicUpdateSlice, + kLiteRtOpCodeShloPad = kTfLiteBuiltinStablehloPad, + kLiteRtOpCodeShloIota = kTfLiteBuiltinStablehloIota, + kLiteRtOpCodeShloGeneral = kTfLiteBuiltinStablehloDotGeneral, + kLiteRtOpCodeShloWindow = kTfLiteBuiltinStablehloReduceWindow, + kLiteRtOpCodeShloSort = kTfLiteBuiltinStablehloSort, + kLiteRtOpCodeShloWhile = kTfLiteBuiltinStablehloWhile, + kLiteRtOpCodeShloGather = kTfLiteBuiltinStablehloGather, + kLiteRtOpCodeShloTranspose = kTfLiteBuiltinStablehloTranspose, + kLiteRtOpCodeTflDilate = kTfLiteBuiltinDilate, + kLiteRtOpCodeShloRngBitGenerator = kTfLiteBuiltinStablehloRngBitGenerator, + kLiteRtOpCodeTflReduceWindow = kTfLiteBuiltinReduceWindow, + kLiteRtOpCodeShloComposite = kTfLiteBuiltinStablehloComposite, +} LiteRtOpCode; + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_OP_CODE_H_ diff --git a/tflite/experimental/litert/c/litert_options.cc b/tflite/experimental/litert/c/litert_options.cc new file mode 100644 index 00000000..c6f776ed --- /dev/null +++ b/tflite/experimental/litert/c/litert_options.cc @@ -0,0 +1,253 @@ +// Copyright 2024 Google LLC. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/c/litert_options.h"
+
+#include <cstdint>
+
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_op_code.h"
+#include "tflite/experimental/litert/core/model/model.h"
+
+//
+// Op Options
+//
+
+LiteRtStatus LiteRtGetAddFusedActivationOption(LiteRtOp op,
+                                               uint32_t* fused_activation) {
+  if (op->OpCode() != kLiteRtOpCodeTflAdd) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *fused_activation =
+      detail::GetTflOptions(*op).AsAddOptions()->fused_activation_function;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetBatchMatmulAdjXOption(LiteRtOp op, bool* adj_x) {
+  if (op->OpCode() != kLiteRtOpCodeTflBatchMatmul) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *adj_x = detail::GetTflOptions(*op).AsBatchMatMulOptions()->adj_x;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetBatchMatmulAdjYOption(LiteRtOp op, bool* adj_y) {
+  if (op->OpCode() != kLiteRtOpCodeTflBatchMatmul) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *adj_y = detail::GetTflOptions(*op).AsBatchMatMulOptions()->adj_y;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetBatchMatmulAsymmetricQuantizeInputOption(
+    LiteRtOp op, bool* asymmetric_quantize_input) {
+  if (op->OpCode() != kLiteRtOpCodeTflBatchMatmul) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *asymmetric_quantize_input = detail::GetTflOptions(*op)
+                                   .AsBatchMatMulOptions()
+                                   ->asymmetric_quantize_inputs;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetConcatenationFusedActivationOption(
+    LiteRtOp op, uint32_t* fused_activation) {
+  if (op->OpCode() != kLiteRtOpCodeTflConcatenation) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *fused_activation = detail::GetTflOptions(*op)
+                          .AsConcatenationOptions()
+                          ->fused_activation_function;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetConcatenationAxisOption(LiteRtOp op, int32_t* axis) {
+  if (op->OpCode() != kLiteRtOpCodeTflConcatenation) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *axis = detail::GetTflOptions(*op).AsConcatenationOptions()->axis;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetDivFusedActivationOption(LiteRtOp op,
+                                               uint32_t* fused_activation) {
+  if (op->OpCode() != kLiteRtOpCodeTflDiv) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *fused_activation =
+      detail::GetTflOptions(*op).AsDivOptions()->fused_activation_function;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetFullyConnectedFusedActivationOption(
+    LiteRtOp op, uint32_t* fused_activation) {
+  if (op->OpCode() != kLiteRtOpCodeTflFullyConnected) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *fused_activation = detail::GetTflOptions(*op)
+                          .AsFullyConnectedOptions()
+                          ->fused_activation_function;
+  return kLiteRtStatusOk;
+}
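These getters deliberately fail with kLiteRtStatusErrorInvalidArgument for an op of the wrong type, so callers are expected to dispatch on the op code first. A minimal sketch of that pattern; `ReadAddOptions` is an illustrative name, and the activation-code comment follows the TFLite schema convention:

#include <stdint.h>

#include "tflite/experimental/litert/c/litert_model.h"
#include "tflite/experimental/litert/c/litert_op_code.h"
#include "tflite/experimental/litert/c/litert_options.h"

static void ReadAddOptions(LiteRtOp op) {
  LiteRtOpCode code;
  if (LiteRtGetOpCode(op, &code) != kLiteRtStatusOk ||
      code != kLiteRtOpCodeTflAdd) {
    return;
  }
  uint32_t fused_activation;
  if (LiteRtGetAddFusedActivationOption(op, &fused_activation) ==
      kLiteRtStatusOk) {
    // fused_activation follows the TFLite schema: 0 == NONE, 1 == RELU, ...
  }
}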
+
+LiteRtStatus LiteRtGetFullyConnectedKeepNumDimsOption(LiteRtOp op,
+                                                      bool* keep_num_dims) {
+  if (op->OpCode() != kLiteRtOpCodeTflFullyConnected) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *keep_num_dims =
+      detail::GetTflOptions(*op).AsFullyConnectedOptions()->keep_num_dims;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtFullyConnectedGetQuantizedBiasTypeOption(
+    LiteRtOp op, uint32_t* quantized_bias_type) {
+  if (op->OpCode() != kLiteRtOpCodeTflFullyConnected) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *quantized_bias_type =
+      detail::GetTflOptions(*op).AsFullyConnectedOptions()->quantized_bias_type;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetFullyConnectedAsymmetricQuantizeInputOption(
+    LiteRtOp op, bool* asymmetric_quantize_input) {
+  if (op->OpCode() != kLiteRtOpCodeTflFullyConnected) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *asymmetric_quantize_input = detail::GetTflOptions(*op)
+                                   .AsFullyConnectedOptions()
+                                   ->asymmetric_quantize_inputs;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetFullyConnectedWeightsFormatOption(
+    LiteRtOp op, uint32_t* weights_format) {
+  if (op->OpCode() != kLiteRtOpCodeTflFullyConnected) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *weights_format =
+      detail::GetTflOptions(*op).AsFullyConnectedOptions()->weights_format;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetMulFusedActivationOption(LiteRtOp op,
+                                               uint32_t* fused_activation) {
+  if (op->OpCode() != kLiteRtOpCodeTflMul) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *fused_activation =
+      detail::GetTflOptions(*op).AsMulOptions()->fused_activation_function;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetSoftmaxBetaOption(LiteRtOp op, float* beta) {
+  if (op->OpCode() != kLiteRtOpCodeTflSoftmax) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *beta = detail::GetTflOptions(*op).AsSoftmaxOptions()->beta;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetStridedSliceBeginMaskOption(LiteRtOp op,
+                                                  int32_t* begin_mask) {
+  if (op->OpCode() != kLiteRtOpCodeTflStridedSlice) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *begin_mask = detail::GetTflOptions(*op).AsStridedSliceOptions()->begin_mask;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetStridedSliceEndMaskOption(LiteRtOp op,
+                                                int32_t* end_mask) {
+  if (op->OpCode() != kLiteRtOpCodeTflStridedSlice) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *end_mask = detail::GetTflOptions(*op).AsStridedSliceOptions()->end_mask;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetStridedSliceEllipsisMaskOption(LiteRtOp op,
+                                                     int32_t* ellipsis_mask) {
+  if (op->OpCode() != kLiteRtOpCodeTflStridedSlice) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *ellipsis_mask =
+      detail::GetTflOptions(*op).AsStridedSliceOptions()->ellipsis_mask;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetStridedSliceNewAxisMaskOption(LiteRtOp op,
+                                                    int32_t* new_axis_mask) {
+  if (op->OpCode() != kLiteRtOpCodeTflStridedSlice) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *new_axis_mask =
+      detail::GetTflOptions(*op).AsStridedSliceOptions()->new_axis_mask;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetStridedSliceShrinkAxisMaskOption(
+    LiteRtOp op, int32_t* shrink_axis_mask) {
+  if (op->OpCode() != kLiteRtOpCodeTflStridedSlice) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *shrink_axis_mask =
+      detail::GetTflOptions(*op).AsStridedSliceOptions()->shrink_axis_mask;
+  return kLiteRtStatusOk;
+}
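The strided-slice getters above return TFLite's raw bit masks. As a rough sketch of the usual interpretation (TFLite convention: bit i of begin_mask set means begin[i] is ignored and the slice starts from the lowest index in dimension i); `StridedSliceBeginIsMasked` is an illustrative helper:

#include <stdbool.h>
#include <stdint.h>

// True if the begin value for `dim` is masked out (the i-th bit of
// begin_mask covers dimension i).
static bool StridedSliceBeginIsMasked(int32_t begin_mask, int32_t dim) {
  return ((begin_mask >> dim) & 1) != 0;
}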
+
+LiteRtStatus LiteRtGetStridedSliceOffsetOption(LiteRtOp op, bool* offset) {
+  if (op->OpCode() != kLiteRtOpCodeTflStridedSlice) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *offset = detail::GetTflOptions(*op).AsStridedSliceOptions()->offset;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetSubFusedActivationOption(LiteRtOp op,
+                                               uint32_t* fused_activation) {
+  if (op->OpCode() != kLiteRtOpCodeTflSub) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *fused_activation =
+      detail::GetTflOptions(*op).AsSubOptions()->fused_activation_function;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetReshapeNewShapeOption(LiteRtOp op,
+                                            const int32_t** new_shape,
+                                            int32_t* new_shape_size) {
+  if (op->OpCode() != kLiteRtOpCodeTflReshape) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  if (detail::GetTflOptions(*op).AsReshapeOptions() == nullptr) {
+    *new_shape_size = -1;
+    return kLiteRtStatusOk;
+  } else {
+    *new_shape =
+        detail::GetTflOptions(*op).AsReshapeOptions()->new_shape.data();
+    *new_shape_size =
+        detail::GetTflOptions(*op).AsReshapeOptions()->new_shape.size();
+  }
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetSumKeepDimsOption(LiteRtOp op, bool* keepdims) {
+  if (op->OpCode() != kLiteRtOpCodeTflSum) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  // Sum op options are stored as ReducerOptions.
+  *keepdims = detail::GetTflOptions(*op).AsReducerOptions()->keep_dims;
+  return kLiteRtStatusOk;
+}
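A sketch of consuming LiteRtGetReshapeNewShapeOption as implemented above: a new_shape_size of -1 signals that the op carries no ReshapeOptions table, in which case the output shape typically comes from the op's second input tensor instead. `ReadReshapeShape` is an illustrative name:

#include <stdint.h>

#include "tflite/experimental/litert/c/litert_options.h"

static void ReadReshapeShape(LiteRtOp op) {
  const int32_t* new_shape = 0;
  int32_t new_shape_size = 0;
  if (LiteRtGetReshapeNewShapeOption(op, &new_shape, &new_shape_size) !=
      kLiteRtStatusOk) {
    return;  // Not a reshape op.
  }
  if (new_shape_size < 0) {
    return;  // No static new_shape attribute on this op.
  }
  for (int32_t i = 0; i < new_shape_size; ++i) {
    // new_shape[i] is the i-th requested output dimension; -1 means "infer".
    (void)new_shape[i];
  }
}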
diff --git a/tflite/experimental/litert/c/litert_options.h b/tflite/experimental/litert/c/litert_options.h
new file mode 100644
index 00000000..63a07512
--- /dev/null
+++ b/tflite/experimental/litert/c/litert_options.h
@@ -0,0 +1,174 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_OPTIONS_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_OPTIONS_H_
+
+#include <stdbool.h>  // NOLINT: To use bool type in C
+#include <stdint.h>
+
+#include "tflite/experimental/litert/c/litert_common.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+LITERT_DEFINE_HANDLE(LiteRtOp);
+
+//==============================================================================
+//
+// Get option APIs for LiteRt ADD op.
+// Options:
+// - FusedActivationOption : uint32_t
+//
+//==============================================================================
+LiteRtStatus LiteRtGetAddFusedActivationOption(LiteRtOp op,
+                                               uint32_t* fused_activation);
+
+//==============================================================================
+//
+// Get option APIs for LiteRt BatchMatmul op.
+// Options:
+// - AdjXOption : bool
+// - AdjYOption : bool
+// - AsymmetricQuantizeInputOption : bool
+//
+//==============================================================================
+LiteRtStatus LiteRtGetBatchMatmulAdjXOption(LiteRtOp op, bool* adj_x);
+LiteRtStatus LiteRtGetBatchMatmulAdjYOption(LiteRtOp op, bool* adj_y);
+LiteRtStatus LiteRtGetBatchMatmulAsymmetricQuantizeInputOption(
+    LiteRtOp op, bool* asymmetric_quantize_input);
+
+//==============================================================================
+//
+// Get option APIs for LiteRt Concatenation op.
+// Options:
+// - FusedActivationOption : uint32_t
+// - AxisOption : int32_t
+//
+//==============================================================================
+LiteRtStatus LiteRtGetConcatenationFusedActivationOption(
+    LiteRtOp op, uint32_t* fused_activation);
+LiteRtStatus LiteRtGetConcatenationAxisOption(LiteRtOp op, int32_t* axis);
+
+//==============================================================================
+//
+// Get option APIs for LiteRt Div op.
+// Options:
+// - FusedActivationOption : uint32_t
+//
+//==============================================================================
+LiteRtStatus LiteRtGetDivFusedActivationOption(LiteRtOp op,
+                                               uint32_t* fused_activation);
+
+//==============================================================================
+//
+// Get option APIs for LiteRt FullyConnected op.
+// Options:
+// - FusedActivationOption : uint32_t
+// - WeightsFormatOption : uint32_t
+// - KeepNumDimsOption : bool
+// - QuantizedBiasTypeOption : uint32_t
+// - AsymmetricQuantizeInputOption : bool
+//
+//==============================================================================
+LiteRtStatus LiteRtGetFullyConnectedFusedActivationOption(
+    LiteRtOp op, uint32_t* fused_activation);
+LiteRtStatus LiteRtGetFullyConnectedWeightsFormatOption(
+    LiteRtOp op, uint32_t* weights_format);
+LiteRtStatus LiteRtGetFullyConnectedKeepNumDimsOption(LiteRtOp op,
+                                                      bool* keep_num_dims);
+LiteRtStatus LiteRtFullyConnectedGetQuantizedBiasTypeOption(
+    LiteRtOp op, uint32_t* quantized_bias_type);
+LiteRtStatus LiteRtGetFullyConnectedAsymmetricQuantizeInputOption(
+    LiteRtOp op, bool* asymmetric_quantize_input);
+
+//==============================================================================
+//
+// Get option APIs for LiteRt Mul op.
+// Options:
+// - FusedActivationOption : uint32_t
+//
+//==============================================================================
+LiteRtStatus LiteRtGetMulFusedActivationOption(LiteRtOp op,
+                                               uint32_t* fused_activation);
+
+//==============================================================================
+//
+// Get option APIs for LiteRt Softmax op.
+// Options:
+// - BetaOption : float
+//
+//==============================================================================
+LiteRtStatus LiteRtGetSoftmaxBetaOption(LiteRtOp op, float* beta);
+
+//==============================================================================
+//
+// Get option APIs for LiteRt StridedSlice op.
+// Options:
+// - BeginMaskOption : int32_t
+// - EndMaskOption : int32_t
+// - EllipsisMaskOption : int32_t
+// - NewAxisMaskOption : int32_t
+// - ShrinkAxisMaskOption : int32_t
+// - OffsetOption : bool
+//
+//==============================================================================
+LiteRtStatus LiteRtGetStridedSliceBeginMaskOption(LiteRtOp op,
+                                                  int32_t* begin_mask);
+LiteRtStatus LiteRtGetStridedSliceEndMaskOption(LiteRtOp op,
+                                                int32_t* end_mask);
+LiteRtStatus LiteRtGetStridedSliceEllipsisMaskOption(LiteRtOp op,
+                                                     int32_t* ellipsis_mask);
+LiteRtStatus LiteRtGetStridedSliceNewAxisMaskOption(LiteRtOp op,
+                                                    int32_t* new_axis_mask);
+LiteRtStatus LiteRtGetStridedSliceShrinkAxisMaskOption(
+    LiteRtOp op, int32_t* shrink_axis_mask);
+LiteRtStatus LiteRtGetStridedSliceOffsetOption(LiteRtOp op, bool* offset);
+
+//==============================================================================
+//
+// Get option APIs for LiteRt Sub op.
+// Options: +// - FusedActivationOption : uint32_t +// - (Not supported) PotScaleInt16Option : bool +// +//============================================================================== +LiteRtStatus LiteRtGetSubFusedActivationOption(LiteRtOp op, + uint32_t* fused_activation); + +//============================================================================== +// +// Get option APIs for LiteRt Reshape op. +// Options: +// - new_shape : int32_t[] +// +//============================================================================== +LiteRtStatus LiteRtGetReshapeNewShapeOption(LiteRtOp op, + const int32_t** new_shape, + int32_t* new_shape_size); + +//============================================================================== +// +// Get option APIs for LiteRt Sum op. +// Options: +// - KeepdimsOption : bool +// +//============================================================================== +LiteRtStatus LiteRtGetSumKeepDimsOption(LiteRtOp op, bool* keepdims); + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_OPTIONS_H_ diff --git a/tflite/experimental/litert/c/litert_options_test.cc b/tflite/experimental/litert/c/litert_options_test.cc new file mode 100644 index 00000000..ac7ebd24 --- /dev/null +++ b/tflite/experimental/litert/c/litert_options_test.cc @@ -0,0 +1,236 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include <cstdint>
+// NOLINTNEXTLINE
+#include <filesystem>  // IWYU pragma: keep
+#include <gtest/gtest.h>
+#include "tflite/experimental/litert/c/litert_options.h"
+#include "tflite/experimental/litert/test/common.h"
+
+namespace {
+TEST(GetOpOptionTest, TestGetAddOptions) {
+  auto model = litert::testing::LoadTestFileModel("simple_add_op.tflite");
+  auto subgraph = model.MainSubgraph();
+  EXPECT_TRUE(subgraph);
+
+  auto ops = subgraph->Ops();
+  auto op = ops.front().Get();
+
+  uint32_t fused_activation;
+  LITERT_ASSERT_STATUS_OK(
+      LiteRtGetAddFusedActivationOption(op, &fused_activation));
+  ASSERT_EQ(fused_activation, 0);
+}
+
+TEST(GetOpOptionTest, TestGetBatchMatmulOptions) {
+  auto model =
+      litert::testing::LoadTestFileModel("simple_batch_matmul_op.tflite");
+  auto subgraph = model.MainSubgraph();
+  EXPECT_TRUE(subgraph);
+
+  auto ops = subgraph->Ops();
+  auto op = ops.front().Get();
+
+  bool adj_x;
+  LITERT_ASSERT_STATUS_OK(LiteRtGetBatchMatmulAdjXOption(op, &adj_x));
+  ASSERT_EQ(adj_x, false);
+
+  bool adj_y;
+  LITERT_ASSERT_STATUS_OK(LiteRtGetBatchMatmulAdjYOption(op, &adj_y));
+  ASSERT_EQ(adj_y, false);
+
+  bool asymmetric_quantize_input;
+  LITERT_ASSERT_STATUS_OK(LiteRtGetBatchMatmulAsymmetricQuantizeInputOption(
+      op, &asymmetric_quantize_input));
+  ASSERT_EQ(asymmetric_quantize_input, false);
+}
+
+TEST(GetOpOptionTest, TestGetConcatenationOptions) {
+  auto model =
+      litert::testing::LoadTestFileModel("simple_concatenation_op.tflite");
+  auto subgraph = model.MainSubgraph();
+  EXPECT_TRUE(subgraph);
+
+  auto ops = subgraph->Ops();
+  auto op = ops.front().Get();
+
+  uint32_t fused_activation;
+  LITERT_ASSERT_STATUS_OK(
+      LiteRtGetConcatenationFusedActivationOption(op, &fused_activation));
+  ASSERT_EQ(fused_activation, 0);
+
+  int32_t axis;
+  LITERT_ASSERT_STATUS_OK(LiteRtGetConcatenationAxisOption(op, &axis));
+  ASSERT_EQ(axis, 2);
+}
+
+TEST(GetOpOptionTest, TestGetDivOptions) {
+  auto model = litert::testing::LoadTestFileModel("simple_div_op.tflite");
+  auto subgraph = model.MainSubgraph();
+  EXPECT_TRUE(subgraph);
+
+  auto ops = subgraph->Ops();
+  auto op = ops.front().Get();
+
+  uint32_t fused_activation;
+  LITERT_ASSERT_STATUS_OK(
+      LiteRtGetDivFusedActivationOption(op, &fused_activation));
+  ASSERT_EQ(fused_activation, 0);
+}
+
+TEST(GetOpOptionTest, TestGetFullyConnectedOptions) {
+  auto model =
+      litert::testing::LoadTestFileModel("simple_fully_connected_op.tflite");
+  auto subgraph = model.MainSubgraph();
+  EXPECT_TRUE(subgraph);
+
+  auto ops = subgraph->Ops();
+  auto op = ops.front().Get();
+
+  uint32_t fused_activation;
+  LITERT_ASSERT_STATUS_OK(
+      LiteRtGetFullyConnectedFusedActivationOption(op, &fused_activation));
+  ASSERT_EQ(fused_activation, 0);
+
+  uint32_t weights_format;
+  LITERT_ASSERT_STATUS_OK(
+      LiteRtGetFullyConnectedWeightsFormatOption(op, &weights_format));
+  ASSERT_EQ(weights_format, 0);
+
+  bool keep_num_dims;
+  LITERT_ASSERT_STATUS_OK(
+      LiteRtGetFullyConnectedKeepNumDimsOption(op, &keep_num_dims));
+  ASSERT_EQ(keep_num_dims, true);
+
+  uint32_t quantized_bias_type;
+  LITERT_ASSERT_STATUS_OK(
+      LiteRtFullyConnectedGetQuantizedBiasTypeOption(op, &quantized_bias_type));
+  ASSERT_EQ(quantized_bias_type, 0);
+
+  bool asymmetric_quantize_input;
+  LITERT_ASSERT_STATUS_OK(LiteRtGetFullyConnectedAsymmetricQuantizeInputOption(
+      op, &asymmetric_quantize_input));
+  ASSERT_EQ(asymmetric_quantize_input, false);
+}
+
+TEST(GetOpOptionTest, TestGetMulOptions) {
+  auto model = litert::testing::LoadTestFileModel("simple_mul_op.tflite");
+  auto subgraph = model.MainSubgraph();
+
EXPECT_TRUE(subgraph); + + auto ops = subgraph->Ops(); + auto op = ops.front().Get(); + + uint32_t fused_activation; + LITERT_ASSERT_STATUS_OK( + LiteRtGetMulFusedActivationOption(op, &fused_activation)); + ASSERT_EQ(fused_activation, 0); +} + +TEST(GetOpOptionTest, TestGetSoftmaxOptions) { + auto model = litert::testing::LoadTestFileModel("simple_softmax_op.tflite"); + auto subgraph = model.MainSubgraph(); + EXPECT_TRUE(subgraph); + + auto ops = subgraph->Ops(); + auto op = ops.front().Get(); + + float beta; + LITERT_ASSERT_STATUS_OK(LiteRtGetSoftmaxBetaOption(op, &beta)); + EXPECT_FLOAT_EQ(beta, 1.0); +} + +TEST(GetOpOptionTest, TestGetStridedSliceOptions) { + auto model = + litert::testing::LoadTestFileModel("simple_strided_slice_op.tflite"); + auto subgraph = model.MainSubgraph(); + EXPECT_TRUE(subgraph); + + auto ops = subgraph->Ops(); + auto op = ops.front().Get(); + + int32_t begin_mask; + LITERT_ASSERT_STATUS_OK( + LiteRtGetStridedSliceBeginMaskOption(op, &begin_mask)); + ASSERT_EQ(begin_mask, 0); + + int32_t end_mask; + LITERT_ASSERT_STATUS_OK(LiteRtGetStridedSliceEndMaskOption(op, &end_mask)); + ASSERT_EQ(end_mask, 0); + + int32_t ellipsis_mask; + LITERT_ASSERT_STATUS_OK( + LiteRtGetStridedSliceEllipsisMaskOption(op, &ellipsis_mask)); + ASSERT_EQ(ellipsis_mask, 0); + + int32_t new_axis_mask; + LITERT_ASSERT_STATUS_OK( + LiteRtGetStridedSliceNewAxisMaskOption(op, &new_axis_mask)); + ASSERT_EQ(new_axis_mask, 0); + + int32_t shrink_axis_mask; + LITERT_ASSERT_STATUS_OK( + LiteRtGetStridedSliceShrinkAxisMaskOption(op, &shrink_axis_mask)); + ASSERT_EQ(shrink_axis_mask, 0); + + bool offset; + LITERT_ASSERT_STATUS_OK(LiteRtGetStridedSliceOffsetOption(op, &offset)); + ASSERT_EQ(offset, false); +} + +TEST(GetOpOptionTest, TestGetSubOptions) { + auto model = litert::testing::LoadTestFileModel("simple_sub_op.tflite"); + auto subgraph = model.MainSubgraph(); + EXPECT_TRUE(subgraph); + + auto ops = subgraph->Ops(); + auto op = ops.front().Get(); + + uint32_t fused_activation; + LITERT_ASSERT_STATUS_OK( + LiteRtGetSubFusedActivationOption(op, &fused_activation)); + ASSERT_EQ(fused_activation, 0); +} + +TEST(GetOpOptionTest, TestGetReshapeOptions) { + auto model = litert::testing::LoadTestFileModel("simple_reshape_op.tflite"); + auto subgraph = model.MainSubgraph(); + EXPECT_TRUE(subgraph); + + auto ops = subgraph->Ops(); + auto op = ops.front().Get(); + + const int32_t* new_shape = nullptr; + int32_t new_shape_size; + LITERT_ASSERT_STATUS_OK( + LiteRtGetReshapeNewShapeOption(op, &new_shape, &new_shape_size)); + ASSERT_EQ(new_shape_size, -1); +} + +TEST(GetOpOptionTest, TestGetSumOptions) { + auto model = litert::testing::LoadTestFileModel("simple_sum_op.tflite"); + auto subgraph = model.MainSubgraph(); + EXPECT_TRUE(subgraph); + + auto ops = subgraph->Ops(); + auto op = ops.front().Get(); + + bool keepdims; + LITERT_ASSERT_STATUS_OK(LiteRtGetSumKeepDimsOption(op, &keepdims)); + ASSERT_EQ(keepdims, true); +} + +} // namespace diff --git a/tflite/experimental/litert/c/litert_tensor_buffer.cc b/tflite/experimental/litert/c/litert_tensor_buffer.cc new file mode 100644 index 00000000..f1997800 --- /dev/null +++ b/tflite/experimental/litert/c/litert_tensor_buffer.cc @@ -0,0 +1,313 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/c/litert_tensor_buffer.h"
+
+#include <cstddef>
+#include <cstdint>
+
+#include "absl/types/span.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_event.h"
+#include "tflite/experimental/litert/c/litert_logging.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/runtime/tensor_buffer.h"
+
+LiteRtStatus LiteRtCreateTensorBufferFromHostMemory(
+    const LiteRtRankedTensorType* tensor_type, void* host_buffer_addr,
+    size_t size, LiteRtHostMemoryDeallocator deallocator,
+    LiteRtTensorBuffer* tensor_buffer) {
+  if (!tensor_type || !host_buffer_addr || !tensor_buffer) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  auto created_tensor_buffer = LiteRtTensorBufferT::CreateFromHostMemory(
+      *tensor_type,
+      absl::MakeSpan(static_cast<uint8_t*>(host_buffer_addr), size),
+      deallocator);
+  if (!created_tensor_buffer) {
+    LITERT_LOG(LITERT_ERROR, "%s",
+               created_tensor_buffer.Error().Message().data());
+    return created_tensor_buffer.Error().Status();
+  }
+  *tensor_buffer = created_tensor_buffer->release();
+  return kLiteRtStatusOk;
+}
+
+#if LITERT_HAS_AHWB_SUPPORT
+LiteRtStatus LiteRtCreateTensorBufferFromAhwb(
+    const LiteRtRankedTensorType* tensor_type, AHardwareBuffer* ahwb,
+    size_t ahwb_offset, LiteRtAhwbDeallocator deallocator,
+    LiteRtTensorBuffer* tensor_buffer) {
+  if (!tensor_type || !ahwb || !tensor_buffer) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  auto created_tensor_buffer = LiteRtTensorBufferT::CreateFromAhwb(
+      *tensor_type, ahwb, ahwb_offset, deallocator);
+  if (!created_tensor_buffer) {
+    LITERT_LOG(LITERT_ERROR, "%s",
+               created_tensor_buffer.Error().Message().data());
+    return created_tensor_buffer.Error().Status();
+  }
+  *tensor_buffer = created_tensor_buffer->release();
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetTensorBufferAhwb(LiteRtTensorBuffer tensor_buffer,
+                                       AHardwareBuffer** ahwb) {
+  if (!tensor_buffer || !ahwb) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+
+  auto ahwb_buffer = tensor_buffer->GetAhwbBuffer();
+  if (!ahwb_buffer) {
+    LITERT_LOG(LITERT_ERROR, "%s", ahwb_buffer.Error().Message().data());
+    return ahwb_buffer.Error().Status();
+  }
+
+  *ahwb = *ahwb_buffer;
+  return kLiteRtStatusOk;
+}
+#endif  // LITERT_HAS_AHWB_SUPPORT
+
+#if LITERT_HAS_ION_SUPPORT
+LiteRtStatus LiteRtCreateTensorBufferFromIonBuffer(
+    const LiteRtRankedTensorType* tensor_type, void* ion_buffer_addr,
+    int ion_buffer_fd, size_t ion_buffer_size, size_t ion_buffer_offset,
+    LiteRtIonDeallocator deallocator, LiteRtTensorBuffer* tensor_buffer) {
+  if (!tensor_type || !tensor_buffer) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  auto created_tensor_buffer = LiteRtTensorBufferT::CreateFromIonBuffer(
+      *tensor_type, ion_buffer_addr, ion_buffer_fd, ion_buffer_size,
+      ion_buffer_offset, deallocator);
+  if (!created_tensor_buffer) {
+    LITERT_LOG(LITERT_ERROR, "%s",
+               created_tensor_buffer.Error().Message().data());
+    return created_tensor_buffer.Error().Status();
+  }
+  *tensor_buffer = created_tensor_buffer->release();
+  return kLiteRtStatusOk;
+} + +LiteRtStatus LiteRtGetTensorBufferIonBuffer(LiteRtTensorBuffer tensor_buffer, + void** ion_buffer_addr, + int* ion_buffer_fd) { + if (!tensor_buffer || !ion_buffer_addr || !ion_buffer_fd) { + return kLiteRtStatusErrorInvalidArgument; + } + + auto ion_buffer = tensor_buffer->GetIonBuffer(); + if (!ion_buffer) { + LITERT_LOG(LITERT_ERROR, "%s", ion_buffer.Error().Message().data()); + return ion_buffer.Error().Status(); + } + + *ion_buffer_addr = ion_buffer->first; + *ion_buffer_fd = ion_buffer->second; + return kLiteRtStatusOk; +} +#endif // LITERT_HAS_ION_SUPPORT + +#if LITERT_HAS_DMABUF_SUPPORT +LiteRtStatus LiteRtCreateTensorBufferFromDmaBufBuffer( + const LiteRtRankedTensorType* tensor_type, void* dmabuf_buffer_addr, + int dmabuf_buffer_fd, size_t dmabuf_buffer_size, + size_t dmabuf_buffer_offset, LiteRtDmaBufDeallocator deallocator, + LiteRtTensorBuffer* tensor_buffer) { + if (!tensor_type || !tensor_buffer) { + return kLiteRtStatusErrorInvalidArgument; + } + auto created_tensor_buffer = LiteRtTensorBufferT::CreateFromDmaBufBuffer( + *tensor_type, dmabuf_buffer_addr, dmabuf_buffer_fd, dmabuf_buffer_size, + dmabuf_buffer_offset, deallocator); + if (!created_tensor_buffer) { + LITERT_LOG(LITERT_ERROR, "%s", + created_tensor_buffer.Error().Message().data()); + return created_tensor_buffer.Error().Status(); + } + *tensor_buffer = created_tensor_buffer->release(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetTensorBufferDmaBufBuffer(LiteRtTensorBuffer tensor_buffer, + void** dmabuf_buffer_addr, + int* dmabuf_buffer_fd) { + if (!tensor_buffer || !dmabuf_buffer_addr || !dmabuf_buffer_fd) { + return kLiteRtStatusErrorInvalidArgument; + } + + auto dmabuf_buffer = tensor_buffer->GetDmaBufBuffer(); + if (!dmabuf_buffer) { + LITERT_LOG(LITERT_ERROR, "%s", dmabuf_buffer.Error().Message().data()); + return dmabuf_buffer.Error().Status(); + } + + *dmabuf_buffer_addr = dmabuf_buffer->first; + *dmabuf_buffer_fd = dmabuf_buffer->second; + return kLiteRtStatusOk; +} +#endif // LITERT_HAS_DMABUF_SUPPORT + +#if LITERT_HAS_FASTRPC_SUPPORT +LiteRtStatus LiteRtCreateTensorBufferFromFastRpcBuffer( + const LiteRtRankedTensorType* tensor_type, void* fastrpc_buffer_addr, + int fastrpc_buffer_fd, size_t fastrpc_buffer_size, + size_t fastrpc_buffer_offset, LiteRtFastRpcDeallocator deallocator, + LiteRtTensorBuffer* tensor_buffer) { + if (!tensor_type || !tensor_buffer) { + return kLiteRtStatusErrorInvalidArgument; + } + auto created_tensor_buffer = LiteRtTensorBufferT::CreateFromFastRpcBuffer( + *tensor_type, fastrpc_buffer_addr, fastrpc_buffer_fd, fastrpc_buffer_size, + fastrpc_buffer_offset, deallocator); + if (!created_tensor_buffer) { + LITERT_LOG(LITERT_ERROR, "%s", + created_tensor_buffer.Error().Message().data()); + return created_tensor_buffer.Error().Status(); + } + *tensor_buffer = created_tensor_buffer->release(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetTensorBufferFastRpcBuffer( + LiteRtTensorBuffer tensor_buffer, void** fastrpc_buffer_addr, + int* fastrpc_buffer_fd) { + if (!tensor_buffer || !fastrpc_buffer_addr || !fastrpc_buffer_fd) { + return kLiteRtStatusErrorInvalidArgument; + } + + auto fastrpc_buffer = tensor_buffer->GetFastRpcBuffer(); + if (!fastrpc_buffer) { + LITERT_LOG(LITERT_ERROR, "%s", fastrpc_buffer.Error().Message().data()); + return fastrpc_buffer.Error().Status(); + } + + *fastrpc_buffer_addr = fastrpc_buffer->first; + *fastrpc_buffer_fd = fastrpc_buffer->second; + return kLiteRtStatusOk; +} +#endif // LITERT_HAS_FASTRPC_SUPPORT + +LiteRtStatus 
LiteRtCreateManagedTensorBuffer( + LiteRtTensorBufferType buffer_type, + const LiteRtRankedTensorType* tensor_type, size_t buffer_size, + LiteRtTensorBuffer* tensor_buffer) { + if (!tensor_type || !tensor_buffer) { + return kLiteRtStatusErrorInvalidArgument; + } + auto created_tensor_buffer = LiteRtTensorBufferT::CreateManaged( + buffer_type, *tensor_type, buffer_size); + if (!created_tensor_buffer) { + LITERT_LOG(LITERT_ERROR, "%s", + created_tensor_buffer.Error().Message().data()); + return created_tensor_buffer.Error().Status(); + } + *tensor_buffer = created_tensor_buffer->release(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtDuplicateTensorBuffer(LiteRtTensorBuffer tensor_buffer) { + if (!tensor_buffer) { + return kLiteRtStatusErrorInvalidArgument; + } + tensor_buffer->Duplicate(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetTensorBufferType(LiteRtTensorBuffer tensor_buffer, + LiteRtTensorBufferType* buffer_type) { + if (!tensor_buffer || !buffer_type) { + return kLiteRtStatusErrorInvalidArgument; + } + *buffer_type = tensor_buffer->buffer_type(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetTensorBufferTensorType( + LiteRtTensorBuffer tensor_buffer, LiteRtRankedTensorType* tensor_type) { + if (!tensor_buffer || !tensor_type) { + return kLiteRtStatusErrorInvalidArgument; + } + *tensor_type = tensor_buffer->tensor_type(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetTensorBufferSize(LiteRtTensorBuffer tensor_buffer, + size_t* buffer_size) { + if (!tensor_buffer || !buffer_size) { + return kLiteRtStatusErrorInvalidArgument; + } + *buffer_size = tensor_buffer->buffer_size(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetTensorBufferOffset(LiteRtTensorBuffer tensor_buffer, + size_t* buffer_offset) { + if (!tensor_buffer || !buffer_offset) { + return kLiteRtStatusErrorInvalidArgument; + } + *buffer_offset = tensor_buffer->buffer_offset(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetTensorBufferHostMemory(LiteRtTensorBuffer tensor_buffer, + void** host_memory_addr) { + if (!tensor_buffer || !host_memory_addr) { + return kLiteRtStatusErrorInvalidArgument; + } + + auto host_buffer = tensor_buffer->GetHostBuffer(); + if (!host_buffer) { + LITERT_LOG(LITERT_ERROR, "%s", host_buffer.Error().Message().data()); + return host_buffer.Error().Status(); + } + + *host_memory_addr = *host_buffer; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtLockTensorBuffer(LiteRtTensorBuffer tensor_buffer, + void** host_mem_addr, LiteRtEvent event) { + if (!tensor_buffer || !host_mem_addr) { + return kLiteRtStatusErrorInvalidArgument; + } + + auto mapped_addr = tensor_buffer->Lock(event); + if (!mapped_addr) { + LITERT_LOG(LITERT_ERROR, "%s", mapped_addr.Error().Message().data()); + return mapped_addr.Error().Status(); + } + + *host_mem_addr = *mapped_addr; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtUnlockTensorBuffer(LiteRtTensorBuffer tensor_buffer) { + if (!tensor_buffer) { + return kLiteRtStatusErrorInvalidArgument; + } + + if (auto status = tensor_buffer->Unlock(); !status) { + LITERT_LOG(LITERT_ERROR, "%s", status.Error().Message().data()); + return status.Error().Status(); + } + + return kLiteRtStatusOk; +} + +void LiteRtDestroyTensorBuffer(LiteRtTensorBuffer tensor_buffer) { + if (tensor_buffer->Unref()) { + delete tensor_buffer; + } +} diff --git a/tflite/experimental/litert/c/litert_tensor_buffer.h b/tflite/experimental/litert/c/litert_tensor_buffer.h new file mode 100644 index 00000000..1f4c4674 --- /dev/null +++ 
b/tflite/experimental/litert/c/litert_tensor_buffer.h
@@ -0,0 +1,190 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_event.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+
+#if LITERT_HAS_AHWB_SUPPORT
+#include <android/hardware_buffer.h>
+#else
+// Define a placeholder AHardwareBuffer struct just to enable compilation.
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+typedef struct AHardwareBuffer AHardwareBuffer;
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+#endif  // LITERT_HAS_AHWB_SUPPORT
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+LITERT_DEFINE_HANDLE(LiteRtTensorBuffer);
+
+#define LITERT_HOST_MEMORY_BUFFER_ALIGNMENT 64
+
+typedef enum {
+  kLiteRtTensorBufferTypeUnknown = 0,
+  kLiteRtTensorBufferTypeHostMemory = 1,
+  kLiteRtTensorBufferTypeAhwb = 2,
+  kLiteRtTensorBufferTypeIon = 3,
+  kLiteRtTensorBufferTypeDmaBuf = 4,
+  kLiteRtTensorBufferTypeFastRpc = 5,
+} LiteRtTensorBufferType;
+
+typedef void (*LiteRtHostMemoryDeallocator)(void* addr);
+typedef void (*LiteRtAhwbDeallocator)(AHardwareBuffer* ahwb);
+typedef void (*LiteRtIonDeallocator)(void* ion_buffer_addr);
+typedef void (*LiteRtDmaBufDeallocator)(void* dmabuf_buffer_addr);
+typedef void (*LiteRtFastRpcDeallocator)(void* fastrpc_buffer_addr);
+
+// /////////////////////////////////////////////////////////////////////////////
+// TensorBuffers.
+// /////////////////////////////////////////////////////////////////////////////
+
+// Create a tensor buffer from an existing host memory buffer of a given size,
+// with an optional host memory buffer deallocator (it can be NULL). Return an
+// error if the passed host memory buffer doesn't satisfy
+// LITERT_HOST_MEMORY_BUFFER_ALIGNMENT alignment.
+LiteRtStatus LiteRtCreateTensorBufferFromHostMemory(
+    const LiteRtRankedTensorType* tensor_type, void* host_buffer_addr,
+    size_t host_buffer_size, LiteRtHostMemoryDeallocator deallocator,
+    LiteRtTensorBuffer* buffer);
+
+// Return an error if the backing buffer is not allocated on the host memory.
+LiteRtStatus LiteRtGetTensorBufferHostMemory(LiteRtTensorBuffer tensor_buffer,
+                                             void** host_memory_addr);
+
+#if LITERT_HAS_AHWB_SUPPORT
+// Create a tensor buffer from an existing AHardwareBuffer, with an optional
+// AHardwareBuffer deallocator (it can be NULL). A non-zero `ahwb_offset` can
+// be used to specify multiple tensor buffers sharing the same underlying AHWB,
+// in which case the provided AHWB must be sufficiently large to accommodate
+// the allocations needed for all tensor buffers sharing it.
+LiteRtStatus LiteRtCreateTensorBufferFromAhwb(
+    const LiteRtRankedTensorType* tensor_type, AHardwareBuffer* ahwb,
+    size_t ahwb_offset, LiteRtAhwbDeallocator deallocator,
+    LiteRtTensorBuffer* buffer);
+
+// Return an error if the backing buffer is not an AHardwareBuffer.
+LiteRtStatus LiteRtGetTensorBufferAhwb(LiteRtTensorBuffer tensor_buffer,
+                                       AHardwareBuffer** ahwb);
+#endif  // LITERT_HAS_AHWB_SUPPORT
+
+#if LITERT_HAS_ION_SUPPORT
+// Create a tensor buffer from an existing ION buffer of a given size, with an
+// optional ION buffer deallocator (it can be NULL). A non-zero
+// `ion_buffer_offset` can be used to specify multiple tensor buffers sharing
+// the same underlying ION buffer, in which case parameter `ion_buffer_size`
+// must be the entire size of the underlying ION memory buffer, including the
+// allocation needed for all tensor buffers sharing it.
+LiteRtStatus LiteRtCreateTensorBufferFromIonBuffer(
+    const LiteRtRankedTensorType* tensor_type, void* ion_buffer_addr,
+    int ion_buffer_fd, size_t ion_buffer_size, size_t ion_buffer_offset,
+    LiteRtIonDeallocator deallocator, LiteRtTensorBuffer* buffer);
+
+// Return an error if the backing buffer is not an ION buffer.
+LiteRtStatus LiteRtGetTensorBufferIonBuffer(LiteRtTensorBuffer buffer,
+                                            void** ion_buffer_addr,
+                                            int* ion_buffer_fd);
+#endif  // LITERT_HAS_ION_SUPPORT
+
+#if LITERT_HAS_DMABUF_SUPPORT
+// Create a tensor buffer from an existing DMA-BUF buffer of a given size, with
+// an optional DMA-BUF buffer deallocator (it can be NULL). A non-zero
+// `dmabuf_buffer_offset` can be used to specify multiple tensor buffers
+// sharing the same underlying DMA-BUF buffer, in which case parameter
+// `dmabuf_buffer_size` must be the entire size of the underlying DMA-BUF
+// memory buffer, including the allocation needed for all tensor buffers
+// sharing it.
+LiteRtStatus LiteRtCreateTensorBufferFromDmaBufBuffer(
+    const LiteRtRankedTensorType* tensor_type, void* dmabuf_buffer_addr,
+    int dmabuf_buffer_fd, size_t dmabuf_buffer_size,
+    size_t dmabuf_buffer_offset, LiteRtDmaBufDeallocator deallocator,
+    LiteRtTensorBuffer* buffer);
+
+// Return an error if the backing buffer is not a DMA-BUF buffer.
+LiteRtStatus LiteRtGetTensorBufferDmaBufBuffer(LiteRtTensorBuffer tensor_buffer,
+                                               void** dmabuf_buffer_addr,
+                                               int* dmabuf_buffer_fd);
+#endif  // LITERT_HAS_DMABUF_SUPPORT
+
+#if LITERT_HAS_FASTRPC_SUPPORT
+// Create a tensor buffer from an existing FastRPC memory buffer of a given
+// size, with an optional FastRPC memory buffer deallocator (it can be NULL).
+// A non-zero `fastrpc_buffer_offset` can be used to specify multiple tensor
+// buffers sharing the same underlying FastRPC memory buffer, in which case
+// parameter `fastrpc_buffer_size` must be the entire size of the underlying
+// FastRPC memory buffer, including the allocation needed for all tensor
+// buffers sharing it.
+LiteRtStatus LiteRtCreateTensorBufferFromFastRpcBuffer(
+    const LiteRtRankedTensorType* tensor_type, void* fastrpc_buffer_addr,
+    int fastrpc_fd, size_t fastrpc_buffer_size, size_t fastrpc_buffer_offset,
+    LiteRtFastRpcDeallocator deallocator, LiteRtTensorBuffer* buffer);
+
+// Return an error if the backing buffer is not a FastRPC memory buffer.
+LiteRtStatus LiteRtGetTensorBufferFastRpcBuffer(
+    LiteRtTensorBuffer tensor_buffer, void** fastrpc_buffer_addr,
+    int* fastrpc_buffer_fd);
+#endif  // LITERT_HAS_FASTRPC_SUPPORT
+
+// Create a buffer backed by managed memory of a given size.
+LiteRtStatus LiteRtCreateManagedTensorBuffer(
+    LiteRtTensorBufferType buffer_type,
+    const LiteRtRankedTensorType* tensor_type, size_t buffer_size,
+    LiteRtTensorBuffer* buffer);
+
+// Create a duplicate of the given tensor buffer, incrementing the reference
+// count of a managed tensor buffer. The reference count is decremented when
+// LiteRtDestroyTensorBuffer() is called.
+LiteRtStatus LiteRtDuplicateTensorBuffer(LiteRtTensorBuffer tensor_buffer);
+
+LiteRtStatus LiteRtGetTensorBufferType(LiteRtTensorBuffer tensor_buffer,
+                                       LiteRtTensorBufferType* buffer_type);
+
+LiteRtStatus LiteRtGetTensorBufferTensorType(
+    LiteRtTensorBuffer tensor_buffer, LiteRtRankedTensorType* tensor_type);
+
+LiteRtStatus LiteRtGetTensorBufferSize(LiteRtTensorBuffer tensor_buffer,
+                                       size_t* size);
+
+LiteRtStatus LiteRtGetTensorBufferOffset(LiteRtTensorBuffer tensor_buffer,
+                                         size_t* offset);
+
+// Lock a tensor buffer and map it to host memory, optionally synchronizing on
+// a given input event (parameter `event` can be NULL).
+LiteRtStatus LiteRtLockTensorBuffer(LiteRtTensorBuffer tensor_buffer,
+                                    void** host_mem_addr, LiteRtEvent event);
+
+// Unlock a tensor buffer and (potentially) unmap it from host memory.
+LiteRtStatus LiteRtUnlockTensorBuffer(LiteRtTensorBuffer buffer);
+
+// Destroy a tensor buffer. If the tensor buffer is managed, its reference
+// count is decremented, and the underlying LiteRtTensorBufferT is released
+// when the last reference is removed.
+void LiteRtDestroyTensorBuffer(LiteRtTensorBuffer buffer);
+
+#ifdef __cplusplus
+}
+#endif  // __cplusplus
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_H_
diff --git a/tflite/experimental/litert/c/litert_tensor_buffer_requirements.cc b/tflite/experimental/litert/c/litert_tensor_buffer_requirements.cc
new file mode 100644
index 00000000..80d1fb15
--- /dev/null
+++ b/tflite/experimental/litert/c/litert_tensor_buffer_requirements.cc
@@ -0,0 +1,108 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <utility>
+#include <vector>
+
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer.h"
+
+class LiteRtTensorBufferRequirementsT {
+ public:
+  LiteRtTensorBufferRequirementsT(
+      int num_supported_tensor_buffer_types,
+      const LiteRtTensorBufferType* supported_tensor_buffer_types,
+      size_t buffer_size, std::vector<uint32_t>&& strides)
+      : supported_buffer_types_(
+            supported_tensor_buffer_types,
+            supported_tensor_buffer_types + num_supported_tensor_buffer_types),
+        buffer_size_(buffer_size),
+        strides_(std::move(strides)) {}
+  const std::vector<LiteRtTensorBufferType>& SupportedBufferTypes() const {
+    return supported_buffer_types_;
+  }
+  size_t BufferSize() const { return buffer_size_; }
+  const std::vector<uint32_t>& Strides() const { return strides_; }
+
+ private:
+  std::vector<LiteRtTensorBufferType> supported_buffer_types_;
+  size_t buffer_size_;
+  // Stride for each dimension.
+  std::vector<uint32_t> strides_;
+};
+
+LiteRtStatus LiteRtCreateTensorBufferRequirements(
+    int num_supported_tensor_buffer_types,
+    const LiteRtTensorBufferType* supported_tensor_buffer_types,
+    size_t buffer_size, int num_strides, const uint32_t* strides,
+    LiteRtTensorBufferRequirements* requirements) {
+  if (num_supported_tensor_buffer_types < 1 ||
+      !supported_tensor_buffer_types || !requirements) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *requirements = new LiteRtTensorBufferRequirementsT(
+      num_supported_tensor_buffer_types, supported_tensor_buffer_types,
+      buffer_size, std::vector<uint32_t>(strides, strides + num_strides));
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes(
+    LiteRtTensorBufferRequirements requirements, int* num_types) {
+  if (!requirements || !num_types) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *num_types = requirements->SupportedBufferTypes().size();
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetTensorBufferRequirementsSupportedTensorBufferType(
+    LiteRtTensorBufferRequirements requirements, int type_index,
+    LiteRtTensorBufferType* type) {
+  if (!requirements || type_index < 0 ||
+      type_index >= requirements->SupportedBufferTypes().size()) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *type = requirements->SupportedBufferTypes()[type_index];
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetTensorBufferRequirementsBufferSize(
+    LiteRtTensorBufferRequirements requirements, size_t* buffer_size) {
+  if (!requirements || !buffer_size) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *buffer_size = requirements->BufferSize();
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetTensorBufferRequirementsStrides(
+    LiteRtTensorBufferRequirements requirements, int* num_strides,
+    const uint32_t** strides) {
+  if (!requirements || !num_strides || !strides) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  auto& s = requirements->Strides();
+  *num_strides = s.size();
+  *strides = s.data();
+  return kLiteRtStatusOk;
+}
+
+void LiteRtDestroyTensorBufferRequirements(
+    LiteRtTensorBufferRequirements requirements) {
+  delete requirements;
+}
diff --git a/tflite/experimental/litert/c/litert_tensor_buffer_requirements.h b/tflite/experimental/litert/c/litert_tensor_buffer_requirements.h
new file mode 100644
index 00000000..a9bd0034
--- /dev/null
+++ b/tflite/experimental/litert/c/litert_tensor_buffer_requirements.h
@@ -0,0 +1,57 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
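For orientation, a minimal round trip through the requirements API implemented above; every call is taken from this file, the wrapper function name is illustrative, and error handling is reduced to a single status check:

#include <cstddef>

#include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h"

// Illustrative round trip: create requirements, query them back, destroy.
void ExampleRequirementsRoundTrip() {
  const LiteRtTensorBufferType kTypes[] = {kLiteRtTensorBufferTypeHostMemory};
  LiteRtTensorBufferRequirements reqs = nullptr;
  if (LiteRtCreateTensorBufferRequirements(
          /*num_supported_tensor_buffer_types=*/1, kTypes,
          /*buffer_size=*/1024, /*num_strides=*/0, /*strides=*/nullptr,
          &reqs) != kLiteRtStatusOk) {
    return;
  }
  size_t buffer_size = 0;
  LiteRtGetTensorBufferRequirementsBufferSize(reqs, &buffer_size);  // 1024
  LiteRtDestroyTensorBufferRequirements(reqs);  // Frees the handle.
}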
+ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_REQUIREMENTS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_REQUIREMENTS_H_ + +#include +#include + +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_tensor_buffer.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +LITERT_DEFINE_HANDLE(LiteRtTensorBufferRequirements); + +LiteRtStatus LiteRtCreateTensorBufferRequirements( + int num_supported_tensor_buffer_types, + const LiteRtTensorBufferType* supported_tensor_buffer_types, + size_t buffer_size, int num_strides, const uint32_t* strides, + LiteRtTensorBufferRequirements* requirements); + +LiteRtStatus LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( + LiteRtTensorBufferRequirements requirements, int* num_types); + +LiteRtStatus LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + LiteRtTensorBufferRequirements requirements, int type_index, + LiteRtTensorBufferType* type); + +LiteRtStatus LiteRtGetTensorBufferRequirementsBufferSize( + LiteRtTensorBufferRequirements requirements, size_t* buffer_size); + +LiteRtStatus LiteRtGetTensorBufferRequirementsStrides( + LiteRtTensorBufferRequirements requirements, int* num_strides, + const uint32_t** strides); + +void LiteRtDestroyTensorBufferRequirements( + LiteRtTensorBufferRequirements requirements); + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_REQUIREMENTS_H_ diff --git a/tflite/experimental/litert/c/litert_tensor_buffer_requirements_test.cc b/tflite/experimental/litert/c/litert_tensor_buffer_requirements_test.cc new file mode 100644 index 00000000..a6420ee1 --- /dev/null +++ b/tflite/experimental/litert/c/litert_tensor_buffer_requirements_test.cc @@ -0,0 +1,92 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h" + +#include +#include +#include + +#include // NOLINT: Need when ANDROID_API_LEVEL >= 26 +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_tensor_buffer.h" + +namespace { + +constexpr const LiteRtTensorBufferType kSupportedTensorBufferTypes[] = { + kLiteRtTensorBufferTypeHostMemory, + kLiteRtTensorBufferTypeAhwb, + kLiteRtTensorBufferTypeIon, + kLiteRtTensorBufferTypeFastRpc, +}; + +constexpr const size_t kNumSupportedTensorBufferTypes = + sizeof(kSupportedTensorBufferTypes) / + sizeof(kSupportedTensorBufferTypes[0]); + +constexpr const size_t kBufferSize = 1234; + +} // namespace + +TEST(TensorBufferRequirements, NoStrides) { + LiteRtTensorBufferRequirements requirements; + ASSERT_EQ(LiteRtCreateTensorBufferRequirements( + kNumSupportedTensorBufferTypes, kSupportedTensorBufferTypes, + kBufferSize, + /*num_strides=*/0, /*strides=*/nullptr, &requirements), + kLiteRtStatusOk); + + int num_types; + ASSERT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( + requirements, &num_types), + kLiteRtStatusOk); + ASSERT_EQ(num_types, kNumSupportedTensorBufferTypes); + + for (auto i = 0; i < num_types; ++i) { + LiteRtTensorBufferType type; + ASSERT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + requirements, i, &type), + kLiteRtStatusOk); + ASSERT_EQ(type, kSupportedTensorBufferTypes[i]); + } + + size_t size; + ASSERT_EQ(LiteRtGetTensorBufferRequirementsBufferSize(requirements, &size), + kLiteRtStatusOk); + ASSERT_EQ(size, kBufferSize); + + LiteRtDestroyTensorBufferRequirements(requirements); +} + +TEST(TensorBufferRequirements, WithStrides) { + constexpr std::array kStrides = {1, 2, 3}; + + LiteRtTensorBufferRequirements requirements; + ASSERT_EQ(LiteRtCreateTensorBufferRequirements( + kNumSupportedTensorBufferTypes, kSupportedTensorBufferTypes, + kBufferSize, kStrides.size(), kStrides.data(), &requirements), + kLiteRtStatusOk); + + int num_strides; + const uint32_t* strides; + ASSERT_EQ(LiteRtGetTensorBufferRequirementsStrides(requirements, &num_strides, + &strides), + kLiteRtStatusOk); + ASSERT_EQ(num_strides, kStrides.size()); + for (auto i = 0; i < kStrides.size(); ++i) { + ASSERT_EQ(strides[i], kStrides[i]); + } + + LiteRtDestroyTensorBufferRequirements(requirements); +} diff --git a/tflite/experimental/litert/c/litert_tensor_buffer_test.cc b/tflite/experimental/litert/c/litert_tensor_buffer_test.cc new file mode 100644 index 00000000..8561f453 --- /dev/null +++ b/tflite/experimental/litert/c/litert_tensor_buffer_test.cc @@ -0,0 +1,296 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "tflite/experimental/litert/c/litert_tensor_buffer.h" + +#include +#include + +#include // NOLINT: Need when ANDROID_API_LEVEL >= 26 +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_model.h" +#include "tflite/experimental/litert/cc/litert_layout.h" +#include "tflite/experimental/litert/runtime/ahwb_buffer.h" // IWYU pragma: keep +#include "tflite/experimental/litert/runtime/dmabuf_buffer.h" // IWYU pragma: keep +#include "tflite/experimental/litert/runtime/fastrpc_buffer.h" // IWYU pragma: keep +#include "tflite/experimental/litert/runtime/ion_buffer.h" // IWYU pragma: keep + +namespace { +constexpr const float kTensorData[] = {10, 20, 30, 40}; + +constexpr const int32_t kTensorDimensions[] = {sizeof(kTensorData) / + sizeof(kTensorData[0])}; + +constexpr const LiteRtRankedTensorType kTensorType = { + /*.element_type=*/kLiteRtElementTypeFloat32, + ::litert::BuildLayout(kTensorDimensions)}; + +} // namespace + +TEST(TensorBuffer, HostMemory) { + constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeHostMemory; + + LiteRtTensorBuffer tensor_buffer; + ASSERT_EQ( + LiteRtCreateManagedTensorBuffer(kTensorBufferType, &kTensorType, + sizeof(kTensorData), &tensor_buffer), + kLiteRtStatusOk); + + LiteRtTensorBufferType buffer_type; + ASSERT_EQ(LiteRtGetTensorBufferType(tensor_buffer, &buffer_type), + kLiteRtStatusOk); + ASSERT_EQ(buffer_type, kTensorBufferType); + + LiteRtRankedTensorType tensor_type; + ASSERT_EQ(LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type), + kLiteRtStatusOk); + ASSERT_EQ(tensor_type.element_type, kLiteRtElementTypeFloat32); + ASSERT_EQ(tensor_type.layout.rank, 1); + ASSERT_EQ(tensor_type.layout.dimensions[0], kTensorType.layout.dimensions[0]); + ASSERT_EQ(tensor_type.layout.strides, nullptr); + + size_t size; + ASSERT_EQ(LiteRtGetTensorBufferSize(tensor_buffer, &size), kLiteRtStatusOk); + ASSERT_EQ(size, sizeof(kTensorData)); + + size_t offset; + ASSERT_EQ(LiteRtGetTensorBufferOffset(tensor_buffer, &offset), + kLiteRtStatusOk); + ASSERT_EQ(offset, 0); + + void* host_mem_addr; + ASSERT_EQ( + LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr, /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTensorData, sizeof(kTensorData)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); + + ASSERT_EQ( + LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr, /*event=*/nullptr), + kLiteRtStatusOk); + ASSERT_EQ(std::memcmp(host_mem_addr, kTensorData, sizeof(kTensorData)), 0); + ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); + + LiteRtDestroyTensorBuffer(tensor_buffer); +} + +TEST(TensorBuffer, Ahwb) { + if (!litert::internal::AhwbBuffer::IsSupported()) { + GTEST_SKIP() << "AHardwareBuffers are not supported on this platform; " + "skipping the test"; + } + + constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeAhwb; + + LiteRtTensorBuffer tensor_buffer; + ASSERT_EQ( + LiteRtCreateManagedTensorBuffer(kTensorBufferType, &kTensorType, + sizeof(kTensorData), &tensor_buffer), + kLiteRtStatusOk); + + LiteRtTensorBufferType buffer_type; + ASSERT_EQ(LiteRtGetTensorBufferType(tensor_buffer, &buffer_type), + kLiteRtStatusOk); + ASSERT_EQ(buffer_type, kTensorBufferType); + + LiteRtRankedTensorType tensor_type; + ASSERT_EQ(LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type), + kLiteRtStatusOk); + ASSERT_EQ(tensor_type.element_type, kLiteRtElementTypeFloat32); + ASSERT_EQ(tensor_type.layout.rank, 1); + ASSERT_EQ(tensor_type.layout.dimensions[0], 
kTensorType.layout.dimensions[0]); + ASSERT_EQ(tensor_type.layout.strides, nullptr); + + size_t size; + ASSERT_EQ(LiteRtGetTensorBufferSize(tensor_buffer, &size), kLiteRtStatusOk); + ASSERT_EQ(size, sizeof(kTensorData)); + + size_t offset; + ASSERT_EQ(LiteRtGetTensorBufferOffset(tensor_buffer, &offset), + kLiteRtStatusOk); + ASSERT_EQ(offset, 0); + + void* host_mem_addr; + ASSERT_EQ( + LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr, /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTensorData, sizeof(kTensorData)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); + + ASSERT_EQ( + LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr, /*event=*/nullptr), + kLiteRtStatusOk); + ASSERT_EQ(std::memcmp(host_mem_addr, kTensorData, sizeof(kTensorData)), 0); + ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); + + LiteRtDestroyTensorBuffer(tensor_buffer); +} + +TEST(TensorBuffer, Ion) { + if (!litert::internal::IonBuffer::IsSupported()) { + GTEST_SKIP() + << "ION buffers are not supported on this platform; skipping the test"; + } + + constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeIon; + + LiteRtTensorBuffer tensor_buffer; + ASSERT_EQ( + LiteRtCreateManagedTensorBuffer(kTensorBufferType, &kTensorType, + sizeof(kTensorData), &tensor_buffer), + kLiteRtStatusOk); + + LiteRtTensorBufferType buffer_type; + ASSERT_EQ(LiteRtGetTensorBufferType(tensor_buffer, &buffer_type), + kLiteRtStatusOk); + ASSERT_EQ(buffer_type, kTensorBufferType); + + LiteRtRankedTensorType tensor_type; + ASSERT_EQ(LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type), + kLiteRtStatusOk); + ASSERT_EQ(tensor_type.element_type, kLiteRtElementTypeFloat32); + ASSERT_EQ(tensor_type.layout.rank, 1); + ASSERT_EQ(tensor_type.layout.dimensions[0], kTensorType.layout.dimensions[0]); + ASSERT_EQ(tensor_type.layout.strides, nullptr); + + size_t size; + ASSERT_EQ(LiteRtGetTensorBufferSize(tensor_buffer, &size), kLiteRtStatusOk); + ASSERT_EQ(size, sizeof(kTensorData)); + + size_t offset; + ASSERT_EQ(LiteRtGetTensorBufferOffset(tensor_buffer, &offset), + kLiteRtStatusOk); + ASSERT_EQ(offset, 0); + + void* host_mem_addr; + ASSERT_EQ( + LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr, /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTensorData, sizeof(kTensorData)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); + + ASSERT_EQ( + LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr, /*event=*/nullptr), + kLiteRtStatusOk); + ASSERT_EQ(std::memcmp(host_mem_addr, kTensorData, sizeof(kTensorData)), 0); + ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); + + LiteRtDestroyTensorBuffer(tensor_buffer); +} + +TEST(TensorBuffer, DmaBuf) { + if (!litert::internal::DmaBufBuffer::IsSupported()) { + GTEST_SKIP() + << "DMA-BUF buffers are not supported on this platform; skipping " + "the test"; + } + + constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeDmaBuf; + + LiteRtTensorBuffer tensor_buffer; + ASSERT_EQ( + LiteRtCreateManagedTensorBuffer(kTensorBufferType, &kTensorType, + sizeof(kTensorData), &tensor_buffer), + kLiteRtStatusOk); + + LiteRtTensorBufferType buffer_type; + ASSERT_EQ(LiteRtGetTensorBufferType(tensor_buffer, &buffer_type), + kLiteRtStatusOk); + ASSERT_EQ(buffer_type, kTensorBufferType); + + LiteRtRankedTensorType tensor_type; + ASSERT_EQ(LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type), + kLiteRtStatusOk); + ASSERT_EQ(tensor_type.element_type, kLiteRtElementTypeFloat32); + 
ASSERT_EQ(tensor_type.layout.rank, 1); + ASSERT_EQ(tensor_type.layout.dimensions[0], kTensorType.layout.dimensions[0]); + ASSERT_EQ(tensor_type.layout.strides, nullptr); + + size_t size; + ASSERT_EQ(LiteRtGetTensorBufferSize(tensor_buffer, &size), kLiteRtStatusOk); + ASSERT_EQ(size, sizeof(kTensorData)); + + size_t offset; + ASSERT_EQ(LiteRtGetTensorBufferOffset(tensor_buffer, &offset), + kLiteRtStatusOk); + ASSERT_EQ(offset, 0); + + void* host_mem_addr; + ASSERT_EQ( + LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr, /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTensorData, sizeof(kTensorData)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); + + ASSERT_EQ( + LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr, /*event=*/nullptr), + kLiteRtStatusOk); + ASSERT_EQ(std::memcmp(host_mem_addr, kTensorData, sizeof(kTensorData)), 0); + ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); + + LiteRtDestroyTensorBuffer(tensor_buffer); +} + +TEST(TensorBuffer, FastRpc) { + if (!litert::internal::FastRpcBuffer::IsSupported()) { + GTEST_SKIP() + << "FastRPC buffers are not supported on this platform; skipping " + "the test"; + } + + constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeFastRpc; + + LiteRtTensorBuffer tensor_buffer; + ASSERT_EQ( + LiteRtCreateManagedTensorBuffer(kTensorBufferType, &kTensorType, + sizeof(kTensorData), &tensor_buffer), + kLiteRtStatusOk); + + LiteRtTensorBufferType buffer_type; + ASSERT_EQ(LiteRtGetTensorBufferType(tensor_buffer, &buffer_type), + kLiteRtStatusOk); + ASSERT_EQ(buffer_type, kTensorBufferType); + + LiteRtRankedTensorType tensor_type; + ASSERT_EQ(LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type), + kLiteRtStatusOk); + ASSERT_EQ(tensor_type.element_type, kLiteRtElementTypeFloat32); + ASSERT_EQ(tensor_type.layout.rank, 1); + ASSERT_EQ(tensor_type.layout.dimensions[0], kTensorType.layout.dimensions[0]); + ASSERT_EQ(tensor_type.layout.strides, nullptr); + + size_t size; + ASSERT_EQ(LiteRtGetTensorBufferSize(tensor_buffer, &size), kLiteRtStatusOk); + ASSERT_EQ(size, sizeof(kTensorData)); + + size_t offset; + ASSERT_EQ(LiteRtGetTensorBufferOffset(tensor_buffer, &offset), + kLiteRtStatusOk); + ASSERT_EQ(offset, 0); + + void* host_mem_addr; + ASSERT_EQ( + LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr, /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTensorData, sizeof(kTensorData)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); + + ASSERT_EQ( + LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr, /*event=*/nullptr), + kLiteRtStatusOk); + ASSERT_EQ(std::memcmp(host_mem_addr, kTensorData, sizeof(kTensorData)), 0); + ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); + + LiteRtDestroyTensorBuffer(tensor_buffer); +} diff --git a/tflite/experimental/litert/cc/BUILD b/tflite/experimental/litert/cc/BUILD new file mode 100644 index 00000000..dcb46d1f --- /dev/null +++ b/tflite/experimental/litert/cc/BUILD @@ -0,0 +1,351 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # copybara:uncomment default_applicable_licenses = ["@org_tensorflow//tensorflow:license"], + default_visibility = ["//tflite/experimental/litert:__subpackages__"], +) + +cc_library( + name = "litert_environment", + hdrs = ["litert_environment.h"], + deps = [ + ":litert_any", + ":litert_expected", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_environment", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "litert_any", + hdrs = ["litert_any.h"], + deps = [ + ":litert_expected", + "//tflite/experimental/litert/c:litert_any", + "//tflite/experimental/litert/c:litert_common", + ], +) + +cc_test( + name = "litert_any_test", + srcs = [ + "litert_any_test.cc", + ], + deps = [ + ":litert_any", + "//tflite/experimental/litert/c:litert_common", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "litert_model", + srcs = ["litert_model.cc"], + hdrs = ["litert_model.h"], + deps = [ + ":litert_buffer_ref", + ":litert_detail", + ":litert_element_type", + ":litert_expected", + ":litert_handle", + ":litert_layout", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_model", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "litert_model_test", + srcs = [ + "litert_model_test.cc", + ], + data = [ + "//tflite/experimental/litert/test:mlir_test_data", + ], + deps = [ + ":litert_element_type", + ":litert_layout", + ":litert_model", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/core/model", + "//tflite/experimental/litert/test:common", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "litert_handle", + hdrs = ["litert_handle.h"], +) + +cc_library( + name = "litert_tensor_buffer", + hdrs = [ + "litert_tensor_buffer.h", + "litert_tensor_buffer_requirements.h", + ], + deps = [ + ":litert_detail", + ":litert_handle", + ":litert_model", + "//tflite/c:c_api_types", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_tensor_buffer", + "//tflite/experimental/litert/cc:litert_expected", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "litert_tensor_buffer_test", + srcs = [ + "litert_tensor_buffer_test.cc", + ], + linkopts = select({ + "@org_tensorflow//tensorflow:android": ["-landroid"], + "//conditions:default": [], + }), + deps = [ + ":litert_layout", + ":litert_model", + ":litert_tensor_buffer", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_tensor_buffer", + "//tflite/experimental/litert/runtime:tensor_buffer", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "litert_tensor_buffer_requirements", + hdrs = [ + "litert_tensor_buffer_requirements.h", + ], + deps = [ + ":litert_detail", + ":litert_handle", + ":litert_macros", + "//tflite/c:c_api_types", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_tensor_buffer", + "//tflite/experimental/litert/cc:litert_expected", + 
"@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "litert_tensor_buffer_requirements_test", + srcs = [ + "litert_tensor_buffer_requirements_test.cc", + ], + deps = [ + ":litert_tensor_buffer", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_tensor_buffer", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "litert_buffer_ref", + hdrs = [ + "litert_buffer_ref.h", + ], + deps = [ + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "litert_macros", + hdrs = ["litert_macros.h"], + deps = [ + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "@com_google_absl//absl/log:absl_check", + ], +) + +cc_library( + name = "litert_expected", + hdrs = ["litert_expected.h"], + deps = [ + ":litert_detail", + "//tflite/experimental/litert/c:litert_common", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "litert_expected_test", + srcs = ["litert_expected_test.cc"], + deps = [ + ":litert_buffer_ref", + ":litert_expected", + ":litert_macros", + "//tflite/experimental/litert/c:litert_common", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "litert_detail", + hdrs = ["litert_detail.h"], + deps = [ + "//tflite/experimental/litert/c:litert_common", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:absl_check", + ], +) + +cc_test( + name = "litert_buffer_ref_test", + srcs = ["litert_buffer_ref_test.cc"], + deps = [ + ":litert_buffer_ref", + "//tflite/experimental/litert/core/util:flatbuffer_tools", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "litert_element_type", + hdrs = ["litert_element_type.h"], + deps = ["//tflite/experimental/litert/c:litert_model"], +) + +cc_test( + name = "litert_element_type_test", + srcs = ["litert_element_type_test.cc"], + deps = [ + ":litert_element_type", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "litert_model_predicates", + srcs = ["litert_model_predicates.cc"], + hdrs = ["litert_model_predicates.h"], + deps = [ + ":litert_detail", + ":litert_model", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "litert_layout", + hdrs = ["litert_layout.h"], + deps = [ + "//tflite/experimental/litert/c:litert_layout", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "litert_model_predicates_test", + srcs = ["litert_model_predicates_test.cc"], + data = [ + "//tflite/experimental/litert/test:mlir_test_data", + ], + deps = [ + ":litert_element_type", + ":litert_model", + ":litert_model_predicates", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/test:common", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "litert_layout_test", + srcs = ["litert_layout_test.cc"], + deps = [ + ":litert_layout", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = 
"litert_compiled_model", + srcs = ["litert_compiled_model.cc"], + hdrs = ["litert_compiled_model.h"], + deps = [ + ":litert_detail", + ":litert_expected", + ":litert_handle", + ":litert_model", + ":litert_tensor_buffer", + "//tflite:framework", + "//tflite/c:c_api_opaque", + "//tflite/c:c_api_types", + "//tflite/c:common", + "//tflite/core:cc_api_stable", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_compiled_model", + "//tflite/experimental/litert/c:litert_compiled_model_options", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_tensor_buffer", + "//tflite/kernels:builtin_ops", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "litert_compiled_model_test", + srcs = ["litert_compiled_model_test.cc"], + data = [ + "//tflite/experimental/litert/test:testdata/simple_model.tflite", + ], + deps = [ + ":litert_compiled_model", + ":litert_model", + ":litert_tensor_buffer", + "//tflite:framework", + "//tflite/c:c_api_opaque", + "//tflite/c:common", + "//tflite/experimental/litert/test:common", + "//tflite/experimental/litert/test:simple_model", + "//tflite/kernels:builtin_ops", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tflite/experimental/litert/cc/litert_any.h b/tflite/experimental/litert/cc/litert_any.h new file mode 100644 index 00000000..07be03f8 --- /dev/null +++ b/tflite/experimental/litert/cc/litert_any.h @@ -0,0 +1,109 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ANY_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ANY_H_
+
+#include <any>
+#include <cstdint>
+
+#include "tflite/experimental/litert/c/litert_any.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+
+namespace litert {
+
+inline std::any ToStdAny(LiteRtAny litert_any) {
+  std::any res;
+  switch (litert_any.type) {
+    case kLiteRtAnyTypeNone:
+      break;
+    case kLiteRtAnyTypeBool:
+      res = litert_any.bool_value;
+      break;
+    case kLiteRtAnyTypeInt:
+      res = litert_any.int_value;
+      break;
+    case kLiteRtAnyTypeReal:
+      res = litert_any.real_value;
+      break;
+    case kLiteRtAnyTypeString:
+      res = litert_any.str_value;
+      break;
+    case kLiteRtAnyTypeVoidPtr:
+      res = litert_any.ptr_value;
+      break;
+  }
+  return res;
+}
+
+inline Expected<LiteRtAny> ToLiteRtAny(const std::any& any) {
+  LiteRtAny result;
+  if (!any.has_value()) {
+    result.type = kLiteRtAnyTypeNone;
+    return result;
+
+  } else if (any.type() == typeid(LiteRtAny::bool_value)) {
+    result.type = kLiteRtAnyTypeBool;
+    result.bool_value = std::any_cast<decltype(LiteRtAny::bool_value)>(any);
+    return result;
+
+  } else if (any.type() == typeid(int8_t)) {
+    result.type = kLiteRtAnyTypeInt;
+    result.int_value = std::any_cast<int8_t>(any);
+    return result;
+
+  } else if (any.type() == typeid(int16_t)) {
+    result.type = kLiteRtAnyTypeInt;
+    result.int_value = std::any_cast<int16_t>(any);
+    return result;
+
+  } else if (any.type() == typeid(int32_t)) {
+    result.type = kLiteRtAnyTypeInt;
+    result.int_value = std::any_cast<int32_t>(any);
+    return result;
+
+  } else if (any.type() == typeid(int64_t)) {
+    result.type = kLiteRtAnyTypeInt;
+    result.int_value = std::any_cast<int64_t>(any);
+    return result;
+
+  } else if (any.type() == typeid(float)) {
+    result.type = kLiteRtAnyTypeReal;
+    result.real_value = std::any_cast<float>(any);
+    return result;
+
+  } else if (any.type() == typeid(double)) {
+    result.type = kLiteRtAnyTypeReal;
+    result.real_value = std::any_cast<double>(any);
+    return result;
+
+  } else if (any.type() == typeid(LiteRtAny::str_value)) {
+    result.type = kLiteRtAnyTypeString;
+    result.str_value = std::any_cast<decltype(LiteRtAny::str_value)>(any);
+    return result;
+
+  } else if (any.type() == typeid(LiteRtAny::ptr_value)) {
+    result.type = kLiteRtAnyTypeVoidPtr;
+    result.ptr_value = std::any_cast<decltype(LiteRtAny::ptr_value)>(any);
+    return result;
+
+  } else {
+    return Error(kLiteRtStatusErrorInvalidArgument);
+  }
+}
+
+}  // namespace litert
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ANY_H_
diff --git a/tflite/experimental/litert/cc/litert_any_test.cc b/tflite/experimental/litert/cc/litert_any_test.cc
new file mode 100644
index 00000000..eaf8d594
--- /dev/null
+++ b/tflite/experimental/litert/cc/litert_any_test.cc
@@ -0,0 +1,110 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <any>
+#include <cstdint>
+
+#include <gtest/gtest.h>  // NOLINT: Need when ANDROID_API_LEVEL >= 26
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/cc/litert_any.h"
+
+TEST(Any, ConversionNone) {
+  EXPECT_FALSE(
+      litert::ToStdAny(LiteRtAny{/*.type=*/kLiteRtAnyTypeNone}).has_value());
+
+  ASSERT_EQ(litert::ToLiteRtAny(std::any())->type, kLiteRtAnyTypeNone);
+}
+
+TEST(Any, ConversionBool) {
+  ASSERT_EQ(std::any_cast<bool>(litert::ToStdAny(LiteRtAny{
+                /*.type=*/kLiteRtAnyTypeBool, {/*.bool_value=*/true}})),
+            true);
+  ASSERT_EQ(std::any_cast<bool>(litert::ToStdAny(LiteRtAny{
+                /*.type=*/kLiteRtAnyTypeBool, {/*.bool_value=*/false}})),
+            false);
+
+  ASSERT_EQ(litert::ToLiteRtAny(std::any(true))->type, kLiteRtAnyTypeBool);
+  ASSERT_EQ(litert::ToLiteRtAny(std::any(true))->bool_value, true);
+  ASSERT_EQ(litert::ToLiteRtAny(std::any(false))->type, kLiteRtAnyTypeBool);
+  ASSERT_EQ(litert::ToLiteRtAny(std::any(false))->bool_value, false);
+}
+
+TEST(Any, ConversionInt) {
+  LiteRtAny litert_any;
+  litert_any.type = kLiteRtAnyTypeInt;
+  litert_any.int_value = 1234;
+  ASSERT_EQ(std::any_cast<int64_t>(litert::ToStdAny(litert_any)), 1234);
+
+  ASSERT_EQ(litert::ToLiteRtAny(std::any(static_cast<int8_t>(12)))->type,
+            kLiteRtAnyTypeInt);
+  ASSERT_EQ(litert::ToLiteRtAny(std::any(static_cast<int8_t>(12)))->int_value,
+            12);
+  ASSERT_EQ(litert::ToLiteRtAny(std::any(static_cast<int16_t>(1234)))->type,
+            kLiteRtAnyTypeInt);
+  ASSERT_EQ(
+      litert::ToLiteRtAny(std::any(static_cast<int16_t>(1234)))->int_value,
+      1234);
+  ASSERT_EQ(litert::ToLiteRtAny(std::any(static_cast<int32_t>(1234)))->type,
+            kLiteRtAnyTypeInt);
+  ASSERT_EQ(
+      litert::ToLiteRtAny(std::any(static_cast<int32_t>(1234)))->int_value,
+      1234);
+  ASSERT_EQ(litert::ToLiteRtAny(std::any(static_cast<int64_t>(1234)))->type,
+            kLiteRtAnyTypeInt);
+  ASSERT_EQ(
+      litert::ToLiteRtAny(std::any(static_cast<int64_t>(1234)))->int_value,
+      1234);
+}
+
+TEST(Any, ConversionReal) {
+  LiteRtAny litert_any;
+  litert_any.type = kLiteRtAnyTypeReal;
+  litert_any.real_value = 123.4;
+  ASSERT_EQ(std::any_cast<double>(litert::ToStdAny(litert_any)), 123.4);
+
+  ASSERT_EQ(litert::ToLiteRtAny(std::any(static_cast<float>(1.2)))->type,
+            kLiteRtAnyTypeReal);
+  EXPECT_NEAR(
+      litert::ToLiteRtAny(std::any(static_cast<float>(1.2)))->real_value, 1.2,
+      1e-7);
+  ASSERT_EQ(litert::ToLiteRtAny(std::any(static_cast<double>(1.2)))->type,
+            kLiteRtAnyTypeReal);
+  EXPECT_NEAR(
+      litert::ToLiteRtAny(std::any(static_cast<double>(1.2)))->real_value, 1.2,
+      1e-7);
+}
+
+TEST(Any, ConversionString) {
+  constexpr const char* kTestString = "test";
+  LiteRtAny litert_any;
+  litert_any.type = kLiteRtAnyTypeString;
+  litert_any.str_value = kTestString;
+  ASSERT_EQ(std::any_cast<const char*>(litert::ToStdAny(litert_any)),
+            kTestString);
+
+  ASSERT_EQ(litert::ToLiteRtAny(std::any("test"))->type, kLiteRtAnyTypeString);
+  EXPECT_STREQ(litert::ToLiteRtAny(std::any("test"))->str_value, "test");
+}
+
+TEST(Any, ConversionPtr) {
+  const void* kTestPtr = reinterpret_cast<const void*>(1234);
+  LiteRtAny litert_any;
+  litert_any.type = kLiteRtAnyTypeVoidPtr;
+  litert_any.ptr_value = kTestPtr;
+  ASSERT_EQ(std::any_cast<const void*>(litert::ToStdAny(litert_any)),
+            kTestPtr);
+
+  ASSERT_EQ(litert::ToLiteRtAny(std::any(kTestPtr))->type,
+            kLiteRtAnyTypeVoidPtr);
+  EXPECT_EQ(litert::ToLiteRtAny(std::any(kTestPtr))->ptr_value, kTestPtr);
+}
diff --git a/tflite/experimental/litert/cc/litert_buffer_ref.h b/tflite/experimental/litert/cc/litert_buffer_ref.h
new file mode 100644
index 00000000..c81b5d12
--- /dev/null
+++ b/tflite/experimental/litert/cc/litert_buffer_ref.h
@@ -0,0 +1,356 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_BUFFER_REF_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_BUFFER_REF_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+#include <ostream>
+#include <string>
+#include <tuple>
+#include <vector>
+
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+
+namespace litert {
+
+//===----------------------------------------------------------------------===//
+//
+// << BUFFER REF >>
+//
+// Read, read/write, and owning views of buffers of arbitrary byte width types.
+//
+// Serialized model artifacts and assets are frequently large strings with
+// (annoyingly) non-standard char types that may be left-padded. The following
+// classes simplify handling such buffers in an efficient, copy-free manner.
+// They also provide read and write left-padding-aware interpretability through
+// standard signed char string types. This is used for making manual edits to
+// flatbuffer metadata or directly to the serialized flatbuffer.
+// NOTE: std::basic_xxx is not supported by our C++ toolchain.
+//
+// Pre-allocated buffers can be transferred to these classes or allocation can
+// be internalized. XBufferRefs can be implicitly upcast to non-owning
+// read/write or read-only views to provide other routines with an appropriate
+// view of the data. E.g.:
+//
+// ```
+// void ReadBuffer(BufferRef<uint8_t> r_buf) { std::cerr << r_buf.StrView(); }
+// void WriteToBuffer(MutableBufferRef<uint8_t> rw_buf) {
+//   rw_buf.WriteInto("SomeData");
+// }
+// ...
+// OwningBufferRef<uint8_t> buf(size);
+// WriteToBuffer(buf);  // Implicitly convert to read/write with no ownership.
+// ReadBuffer(buf);  // Implicitly convert to read-only.
+// ```
+//
+//===----------------------------------------------------------------------===//
+
+// Allocation/Deallocation behavior for owning buffer refs. An allocator is a
+// trivially constructible/destructible object that overrides () for allocating
+// and freeing memory.
+
+// Malloc/free based memory.
+template <typename ByteT>
+struct Mallocator {
+  void operator()(ByteT* d) {
+    if (d != nullptr) {
+      free(d);
+    }
+  }
+
+  ByteT* operator()(size_t bytes) {
+    return reinterpret_cast<ByteT*>(malloc(bytes));
+  }
+};
+
+// New/delete based memory.
+template <typename ByteT>
+struct Newlocator {
+  void operator()(ByteT* d) {
+    if (d != nullptr) {
+      delete[] d;
+    }
+  }
+
+  ByteT* operator()(size_t bytes) { return new ByteT[bytes]; }
+};
+
+//
+// Read-Only Bytes
+//
+
+// Immutable and non-owning view of a buffer.
+template <typename ByteT = uint8_t>
+class BufferRef {
+ public:
+  using TupleT = std::tuple<ByteT*, size_t, size_t>;
+
+  // Null buffer.
+  BufferRef() : size_(0), offset_(0), data_(nullptr) {}
+
+  // Construct from already allocated buffer. Methods will only expose
+  // data[offset, offset + size].
+  BufferRef(const ByteT* data, size_t size, size_t offset = 0)
+      : size_(size), offset_(offset), data_(const_cast<ByteT*>(data)) {}
+  BufferRef(const void* data, size_t size, size_t offset = 0)
+      : size_(size),
+        offset_(offset),
+        data_(const_cast<ByteT*>(reinterpret_cast<const ByteT*>(data))) {}
+  explicit BufferRef(absl::Span<const ByteT> data)
+      : size_(data.size()),
+        offset_(0),
+        data_(const_cast<ByteT*>(data.data())) {}
+
+  // Start of actual data.
+  const ByteT* Data() const { return data_ + offset_; }
+
+  // Size of actual data.
+  size_t Size() const { return size_ - offset_; }
+
+  // Get buffer details in tuple form.
+  TupleT Get() const { return TupleT(data_, size_, offset_); }
+
+  // Start of actual data as signed char. Might not be null terminated.
+  const char* StrData() const { return reinterpret_cast<const char*>(Data()); }
+
+  // Convenience view of actual data as a string. Carries an explicit size, so
+  // it does not rely on null termination.
+  absl::string_view StrView() const {
+    return absl::string_view(StrData(), Size());
+  }
+
+  // Const view of actual data.
+  absl::Span<const ByteT> Span() const {
+    return absl::MakeConstSpan(Data(), Size());
+  }
+
+  // Copy the buffer data to a vector.
+  std::vector<char> ToVec() const {
+    return std::vector<char>(StrData(), StrData() + Size());
+  }
+
+  // Write the string data to a stream.
+  void WriteStr(std::ostream& out) const { out.write(StrData(), Size()); }
+
+  // Print info about this buffer.
+  void Dump(std::ostream& out) const {
+    out << absl::StreamFormat("%s[%lu:%lu]\n", TypeName(), offset_, size_);
+  }
+
+  BufferRef(const BufferRef& other) = default;
+  BufferRef& operator=(const BufferRef& other) = default;
+
+  virtual ~BufferRef() = default;
+
+ protected:
+  size_t size_;
+  size_t offset_;
+  ByteT* data_ = nullptr;
+
+  // Debug name.
+  virtual absl::string_view TypeName() const { return "BufferRef"; }
+};
+template <typename ByteT>
+BufferRef(const ByteT*, size_t, size_t) -> BufferRef<ByteT>;
+
+//
+// Read-Write Non-Owning Bytes
+//
+
+// Writeable (but still non-owning) version of BufferRef.
+template <typename ByteT = uint8_t>
+class MutableBufferRef : public BufferRef<ByteT> {
+ public:
+  using TupleT = std::tuple<ByteT*, size_t, size_t>;
+
+  // Null buffer.
+  MutableBufferRef()
+      : BufferRef<ByteT>((ByteT*)nullptr, /*size*/ 0, /*offset*/ 0) {}
+
+  // Create a mutable view from a pre-allocated non-const buffer.
+  MutableBufferRef(ByteT* data, size_t size, size_t offset = 0)
+      : BufferRef<ByteT>(data, size, offset) {}
+  MutableBufferRef(void* data, size_t size, size_t offset = 0)
+      : BufferRef<ByteT>(data, size, offset) {}
+  explicit MutableBufferRef(absl::Span<ByteT> data) : BufferRef<ByteT>(data) {}
+  explicit MutableBufferRef(absl::Span<const ByteT> data) = delete;
+  MutableBufferRef(const ByteT*, size_t, size_t) = delete;
+  MutableBufferRef(const void*, size_t, size_t) = delete;
+
+  // Mutable start of actual data.
+  ByteT* Data() { return this->data_ + this->offset_; }
+
+  // Get the mutable start of actual data as a char pointer.
+  char* StrData() { return reinterpret_cast<char*>(Data()); }
+
+  // Get buffer info in tuple form.
+  TupleT Get() { return TupleT(this->data_, this->size_, this->offset_); }
+
+  // Mutable span of actual data.
+  absl::Span<ByteT> Span() { return absl::MakeSpan(Data(), this->Size()); }
+
+  // Write a string into the actual buffer at offset. Returns false if the
+  // entire string cannot fit into the actual buffer.
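+  // For example (illustrative only): on a 6-byte view holding "xxxxxx",
+  // WriteInto("ab", /*offset=*/1) succeeds and leaves "xabxxx", while
+  // WriteInto("abcdefg") fails and leaves the buffer untouched.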
+  bool WriteInto(absl::string_view str, size_t offset = 0) {
+    if (str.size() > this->Size() - offset) {
+      return false;
+    }
+    std::memcpy(Data() + offset, str.data(), str.size());
+    return true;
+  }
+
+  MutableBufferRef(const MutableBufferRef& other) = default;
+  MutableBufferRef& operator=(const MutableBufferRef& other) = default;
+
+ protected:
+  // Debug name.
+  absl::string_view TypeName() const override { return "MutableBufferRef"; }
+};
+template <typename ByteT>
+MutableBufferRef(ByteT*, size_t, size_t) -> MutableBufferRef<ByteT>;
+
+//
+// Read-Write Owning Bytes
+//
+
+// Writable and owning buffer reference. Can allocate new buffers internally
+// and take ownership of existing buffers. Does not support resizing.
+template <typename ByteT = uint8_t, class Allocator = Newlocator<ByteT>>
+class OwningBufferRef : public MutableBufferRef<ByteT> {
+ public:
+  using TupleT = std::tuple<ByteT*, size_t, size_t>;
+  using WeakTupleT = std::tuple<ByteT*&, size_t&, size_t&>;
+
+  // Null buffer.
+  OwningBufferRef()
+      : MutableBufferRef<ByteT>(/*data*/ (ByteT*)nullptr, /*size*/ 0,
+                                /*offset*/ 0) {}
+
+  // Initialize a new buffer reference and allocate internally.
+  explicit OwningBufferRef(size_t size)
+      : MutableBufferRef<ByteT>(/*data*/ (ByteT*)nullptr, size,
+                                /*offset*/ 0) {
+    this->data_ = (ByteT*)Allocator()(size);
+  }
+
+  // Take ownership of the given buffer.
+  OwningBufferRef(ByteT* data, size_t size, size_t offset = 0)
+      : MutableBufferRef<ByteT>(data, size, offset) {}
+  OwningBufferRef(void* data, size_t size, size_t offset = 0)
+      : MutableBufferRef<ByteT>(data, size, offset) {}
+  explicit OwningBufferRef(absl::Span<ByteT> data)
+      : MutableBufferRef<ByteT>(data) {}
+
+  // Copy the given buffer.
+  OwningBufferRef(const ByteT* data, size_t size)
+      : MutableBufferRef<ByteT>(/*data*/ (ByteT*)nullptr, size,
+                                /*offset*/ 0) {
+    this->data_ = (ByteT*)Allocator()(size);
+    std::memcpy(this->data_, data, size);
+  }
+  explicit OwningBufferRef(absl::Span<const ByteT> data)
+      : OwningBufferRef(data.data(), data.size()) {}
+
+  // Copy data from the given string.
+  explicit OwningBufferRef(absl::string_view data)
+      : OwningBufferRef(
+            reinterpret_cast<const ByteT*>(data.data()), data.size()) {}
+
+  // Copy data from the given c-style string.
+  explicit OwningBufferRef(const char* data)
+      : OwningBufferRef(absl::string_view(data)) {}
+
+  // Drop reference to any owned memory.
+  void Drop() {
+    this->data_ = nullptr;
+    this->size_ = 0;
+    this->offset_ = 0;
+  }
+
+  // Get the buffer details and drop references to them.
+  TupleT Release() {
+    auto res = std::make_tuple(this->data_, this->size_, this->offset_);
+    Drop();
+    return res;
+  }
+
+  // Get weak references to buffer data. Takes ownership of anything that
+  // is swapped in.
+  WeakTupleT GetWeak() {
+    return WeakTupleT(this->data_, this->size_, this->offset_);
+  }
+
+  // Free any owned memory.
+  void Reset() {
+    Allocator()(this->data_);
+    Drop();
+  }
+
+  // Reset any existing data and copy in the given read-only buffer.
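+  // Note: all `size` bytes are copied and `offset` is then applied to the
+  // copy, mirroring the constructor semantics above.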
+  void Assign(const ByteT* buf, size_t size, size_t offset = 0) {
+    Reset();
+    this->size_ = size;
+    this->data_ = (ByteT*)Allocator()(this->size_);
+    std::memcpy(this->data_, buf, this->size_);
+    this->offset_ = offset;
+  }
+
+  OwningBufferRef(OwningBufferRef&& other)
+      : MutableBufferRef<ByteT>(other.data_, other.size_, other.offset_) {
+    other.Drop();
+  }
+
+  OwningBufferRef& operator=(OwningBufferRef&& other) {
+    if (this != &other) {
+      Reset();
+      this->data_ = other.data_;
+      this->size_ = other.size_;
+      this->offset_ = other.offset_;
+      other.Drop();
+    }
+    return *this;
+  }
+
+  OwningBufferRef(const OwningBufferRef& other)
+      : MutableBufferRef<ByteT>(/*data*/ (ByteT*)nullptr, other.size_,
+                                other.offset_) {
+    Assign(other.data_, other.size_, other.offset_);
+  }
+
+  OwningBufferRef& operator=(const OwningBufferRef& other) {
+    Assign(other.data_, other.size_, other.offset_);
+    return *this;
+  }
+
+  ~OwningBufferRef() override { Reset(); }
+
+ protected:
+  // Debug string.
+  absl::string_view TypeName() const override { return "OwningBufferRef"; }
+};
+
+template <typename ByteT, class Allocator = Newlocator<ByteT>>
+OwningBufferRef(const ByteT*, size_t) -> OwningBufferRef<ByteT, Allocator>;
+
+template <typename ByteT, class Allocator = Newlocator<ByteT>>
+OwningBufferRef(ByteT*, size_t) -> OwningBufferRef<ByteT, Allocator>;
+
+template <class Allocator = Newlocator<char>>
+OwningBufferRef(const char*) -> OwningBufferRef<char, Allocator>;
+
+}  // namespace litert
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_BUFFER_REF_H_
diff --git a/tflite/experimental/litert/cc/litert_buffer_ref_test.cc b/tflite/experimental/litert/cc/litert_buffer_ref_test.cc
new file mode 100644
index 00000000..c031e713
--- /dev/null
+++ b/tflite/experimental/litert/cc/litert_buffer_ref_test.cc
@@ -0,0 +1,332 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "tflite/experimental/litert/cc/litert_buffer_ref.h" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tflite/experimental/litert/core/util/flatbuffer_tools.h" + +using litert::BufferRef; +using litert::Mallocator; +using litert::MutableBufferRef; +using litert::Newlocator; +using litert::OwningBufferRef; +using litert::internal::FbBufToStr; +using testing::ElementsAreArray; +using testing::Eq; +using testing::Pointwise; +using testing::StartsWith; + +namespace { + +static constexpr size_t kOffset = 4; + +static constexpr absl::string_view kData = "SomeRawBuffer"; +static constexpr absl::string_view kOtherData = "SOMERawBuffer"; + +absl::Span MakeConstFbData(absl::string_view data) { + const uint8_t* fb_data = reinterpret_cast(data.data()); + return absl::MakeConstSpan(fb_data, data.size()); +} + +absl::Span MakeFbData(absl::string_view data) { + const uint8_t* c_fb_data = reinterpret_cast(data.data()); + uint8_t* fb_data = const_cast(c_fb_data); + return absl::MakeSpan(fb_data, data.size()); +} + +std::vector MakeFbDataVec(absl::string_view data) { + const uint8_t* c_fb_data = reinterpret_cast(data.data()); + uint8_t* fb_data = const_cast(c_fb_data); + return std::vector(fb_data, fb_data + data.size()); +} + +template , typename ByteT = uint8_t> +absl::Span MakeInternalTestBuffer(absl::string_view data) { + ByteT* buffer = Allocator()(data.size()); + std::memcpy(buffer, data.data(), data.size()); + return absl::MakeSpan(reinterpret_cast(buffer), data.size()); +} + +// +// flatbuffer_tools.h +// + +TEST(FbBufToStringTest, ConstSpan) { + EXPECT_THAT(FbBufToStr(MakeConstFbData(kData)), Pointwise(Eq(), kData)); +} + +TEST(FbBufToStringTest, Span) { + EXPECT_THAT(FbBufToStr(MakeFbData(kData)), Pointwise(Eq(), kData)); +} + +TEST(FbBufToStringTest, ConstPointer) { + auto data = MakeConstFbData(kData); + EXPECT_THAT(FbBufToStr(data.data(), data.size()), Pointwise(Eq(), kData)); +} + +TEST(FbBufToStringTest, Pointer) { + auto data = MakeFbData(kData); + EXPECT_THAT(FbBufToStr(data.data(), data.size()), Pointwise(Eq(), kData)); +} + +// +// BufferRef (read-only) +// + +TEST(BufferRefTest, Dump) { + BufferRef buf(kData.data(), kData.size()); + std::stringstream out; + buf.Dump(out); + EXPECT_THAT(out.str(), StartsWith("BufferRef")); +} + +TEST(BufferRefTest, WithData) { + auto data = MakeConstFbData(kData); + BufferRef buf(data.data(), data.size()); + EXPECT_EQ(buf.Span(), data); + EXPECT_EQ(buf.StrView(), kData); +} + +TEST(BufferRefTest, WithDataAndOffset) { + auto data = MakeConstFbData(kData); + BufferRef buf(data.data(), data.size(), kOffset); + EXPECT_EQ(buf.Span(), data.subspan(kOffset, buf.Size())); + EXPECT_EQ(buf.StrView(), kData.substr(kOffset, buf.Size())); +} + +TEST(BufferRefTest, ToVec) { + auto data = MakeConstFbData(kData); + BufferRef buf(data.data(), data.size()); + EXPECT_THAT(buf.ToVec(), ElementsAreArray(data)); +} + +TEST(BufferRefTest, WriteStr) { + auto data = MakeConstFbData(kData); + BufferRef buf(data.data(), data.size()); + std::stringstream out; + buf.WriteStr(out); + EXPECT_EQ(out.str(), kData); +} + +TEST(BufferRefTest, WriteStrOffset) { + auto data = MakeConstFbData(kData); + BufferRef buf(data.data(), data.size(), kOffset); + std::stringstream out; + buf.WriteStr(out); + EXPECT_EQ(out.str(), kData.substr(kOffset, buf.Size())); +} + +TEST(BufferRefTest, TupleGet) { + auto input = MakeConstFbData(kData); + BufferRef buf(input); + auto [data, 
size, offset] = buf.Get(); + ASSERT_EQ(offset, 0); + EXPECT_EQ(input, buf.Span()); +} + +// +// MutableBufferRef (read/write) +// + +TEST(MutableBufferRefTest, Dump) { + MutableBufferRef buf; + std::stringstream out; + buf.Dump(out); + EXPECT_THAT(out.str(), StartsWith("MutableBufferRef")); +} + +TEST(MutableBufferRefTest, WriteInto) { + auto v_data = MakeFbDataVec(kOtherData); + MutableBufferRef buf(v_data.data(), v_data.size()); + ASSERT_TRUE(buf.WriteInto("Some")); + EXPECT_THAT(buf.Span(), ElementsAreArray(v_data)); + EXPECT_EQ(buf.StrView(), kData); +} + +TEST(MutableBufferRefTest, WriteIntoOffsetBuf) { + auto v_data = MakeFbDataVec(kOtherData); + static constexpr absl::string_view kExpData = "RAWBuffer"; + MutableBufferRef buf(v_data.data(), v_data.size(), kOffset); + ASSERT_TRUE(buf.WriteInto("RAW")); + EXPECT_THAT(buf.Span(), ElementsAreArray(MakeConstFbData(kExpData))); + EXPECT_EQ(buf.StrView(), kExpData); +} + +TEST(MutableBufferRefTest, WriteIntoOffsetData) { + auto v_data = MakeFbDataVec(kOtherData); + static constexpr absl::string_view kExpData = "SOMERAWBuffer"; + MutableBufferRef buf(v_data.data(), v_data.size()); + ASSERT_TRUE(buf.WriteInto("RAW", kOffset)); + EXPECT_THAT(buf.Span(), ElementsAreArray(MakeConstFbData(kExpData))); + EXPECT_EQ(buf.StrView(), kExpData); +} + +TEST(MutableBufferRefTest, TupleGet) { + auto input = MakeInternalTestBuffer("FOO"); + MutableBufferRef buf(input); + auto [data, size, offset] = buf.Get(); + *data = 'b'; + EXPECT_EQ(buf.StrView(), "bOO"); + delete[] input.data(); +} + +// +// OwningBufferRef (read/write with memory management) +// + +TEST(OwningBufferRefTest, Dump) { + OwningBufferRef buf; + std::stringstream out; + buf.Dump(out); + EXPECT_THAT(out.str(), StartsWith("OwningBufferRef")); +} + +TEST(OwningBufferRefTest, MoveCstor) { + auto raw = MakeInternalTestBuffer>(kData); + OwningBufferRef> buf(raw.data(), raw.size()); + OwningBufferRef> other(std::move(buf)); + EXPECT_EQ(other.StrView(), kData); +} + +TEST(OwningBufferRefTest, MoveAssign) { + auto raw = MakeInternalTestBuffer>(kData); + OwningBufferRef> buf(raw.data(), raw.size()); + OwningBufferRef> other = std::move(buf); + EXPECT_EQ(other.StrView(), kData); +} + +TEST(OwningBufferRefTest, CopyCstor) { + auto raw = MakeInternalTestBuffer>(kData); + OwningBufferRef> buf(raw.data(), raw.size()); + OwningBufferRef> other(buf); + other.WriteInto("SOME"); + EXPECT_EQ(buf.StrView(), kData); + EXPECT_EQ(other.StrView(), "SOMERawBuffer"); +} + +TEST(OwningBufferRefTest, CopyAssign) { + auto raw = MakeInternalTestBuffer>(kData); + OwningBufferRef> buf(raw.data(), raw.size()); + OwningBufferRef> other = buf; + other.WriteInto("SOME"); + EXPECT_EQ(buf.StrView(), kData); + EXPECT_EQ(other.StrView(), "SOMERawBuffer"); +} + +TEST(OwningBufferRefTest, InternalMalloc) { + OwningBufferRef> buf(kData.size()); + ASSERT_EQ(buf.Size(), kData.size()); + ASSERT_NE(buf.Data(), nullptr); + + buf.WriteInto(kData); + EXPECT_EQ(buf.StrView(), kData); +} + +TEST(OwningBufferRefTest, InternalNew) { + OwningBufferRef buf(kData.size()); + ASSERT_EQ(buf.Size(), kData.size()); + ASSERT_NE(buf.Data(), nullptr); + + buf.WriteInto(kData); + EXPECT_EQ(buf.StrView(), kData); +} + +TEST(OwningBufferRefTest, TakeOwnershipMalloc) { + auto malloc_buffer = MakeInternalTestBuffer>(kData); + OwningBufferRef> buf(malloc_buffer.data(), + malloc_buffer.size()); + EXPECT_EQ(buf.StrView(), kData); +} + +TEST(OwningBufferRefTest, TakeOwnershipNew) { + auto new_buffer = MakeInternalTestBuffer(kData); + OwningBufferRef 
buf(new_buffer.data(), new_buffer.size()); + EXPECT_EQ(buf.StrView(), kData); +} + +TEST(OwningBufferRefTest, TakeOwnershipOffset) { + auto malloc_buffer = MakeInternalTestBuffer>(kData); + OwningBufferRef> buf(malloc_buffer.data(), + malloc_buffer.size(), + /*offset=*/4); + EXPECT_EQ(buf.StrView(), "RawBuffer"); +} + +TEST(OwningBufferRefTest, CopyBuffer) { + auto const_buf = MakeConstFbData(kData); + OwningBufferRef buf(const_buf.data(), const_buf.size()); + buf.WriteInto("SOME"); + EXPECT_EQ(buf.StrView(), "SOMERawBuffer"); + EXPECT_EQ(FbBufToStr(const_buf), "SomeRawBuffer"); +} + +TEST(OwningBufferRefTest, ImplicitUpCasts) { + OwningBufferRef buf(kData.size()); + BufferRef c_buf = buf; + + buf.WriteInto(kData); + EXPECT_EQ(c_buf.StrView(), buf.StrView()); +} + +TEST(OwningBufferRefTest, TupleGetWeak) { + auto input = MakeInternalTestBuffer("FOO"); + + OwningBufferRef buf; + auto [data, size, offset] = buf.GetWeak(); + + data = input.data(); + size = input.size(); + offset = 0; + + ASSERT_EQ(buf.Size(), input.size()); + ASSERT_EQ(buf.Size(), input.size()); + + buf.WriteInto("BAR"); + + EXPECT_EQ(buf.StrView(), "BAR"); + EXPECT_EQ(buf.Span(), input); +} + +TEST(OwningBufferRefTest, TupleRelease) { + OwningBufferRef buf("BAZ"); + + auto [data, size, offset] = buf.Release(); + + EXPECT_EQ(buf.Size(), 0); + EXPECT_EQ(absl::string_view(data, size), "BAZ"); + + delete[] data; +} + +TEST(OwningBufferRefTest, Assign) { + auto const_buf = MakeConstFbData(kData); + OwningBufferRef buf; + buf.Assign(const_buf.data(), const_buf.size()); + buf.WriteInto("SOME"); + EXPECT_EQ(buf.StrView(), "SOMERawBuffer"); + EXPECT_EQ(FbBufToStr(const_buf), "SomeRawBuffer"); +} + +} // namespace diff --git a/tflite/experimental/litert/cc/litert_compiled_model.cc b/tflite/experimental/litert/cc/litert_compiled_model.cc new file mode 100644 index 00000000..786ef595 --- /dev/null +++ b/tflite/experimental/litert/cc/litert_compiled_model.cc @@ -0,0 +1,124 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "tflite/experimental/litert/cc/litert_compiled_model.h" + +#include +#include +#include +#include + +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_compiled_model.h" +#include "tflite/experimental/litert/c/litert_tensor_buffer.h" +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/experimental/litert/cc/litert_tensor_buffer.h" +#include "tflite/experimental/litert/cc/litert_tensor_buffer_requirements.h" + +namespace litert { + +Expected> CompiledModel::CreateInputBuffers( + size_t signature_index) { + auto signature = model_->GetSignature(signature_index); + if (!signature) { + return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find signature"); + } + auto subgraph = model_->Subgraph(signature->Key()); + if (!subgraph) { + return Unexpected(kLiteRtStatusErrorNotFound, "Failed to get subgraph"); + } + std::vector input_buffers; + auto input_tensors = subgraph->Inputs(); + input_buffers.reserve(input_tensors.size()); + + for (int i = 0; i < input_tensors.size(); ++i) { + auto input_buffer_requirements = + GetInputBufferRequirements(signature_index, i); + if (!input_buffer_requirements) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + input_buffer_requirements.Error().Message()); + } + auto tensor_type = input_tensors[i].RankedTensorType(); + LiteRtTensorBufferType tensor_buffer_type = + (*(*input_buffer_requirements).SupportedTypes())[0]; + auto input_buffer = TensorBuffer::CreateManaged( + tensor_buffer_type, tensor_type, + (*input_buffer_requirements).BufferSize().Value()); + if (!input_buffer) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + input_buffer.Error().Message()); + } + input_buffers.push_back(std::move(*input_buffer)); + } + return std::move(input_buffers); +} + +Expected> CompiledModel::CreateOutputBuffers( + size_t signature_index) { + auto signature = model_->GetSignature(signature_index); + if (!signature) { + return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find signature"); + } + auto subgraph = model_->Subgraph(signature->Key()); + if (!subgraph) { + return Unexpected(kLiteRtStatusErrorNotFound, "Failed to get subgraph"); + } + std::vector output_buffers; + auto output_tensors = subgraph->Outputs(); + output_buffers.reserve(output_tensors.size()); + for (int i = 0; i < output_tensors.size(); ++i) { + auto output_buffer_requirements = + GetOutputBufferRequirements(signature_index, i); + if (!output_buffer_requirements.HasValue()) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + output_buffer_requirements.Error().Message()); + } + auto tensor_type = output_tensors[i].RankedTensorType(); + LiteRtTensorBufferType tensor_buffer_type = + (*(*output_buffer_requirements).SupportedTypes())[0]; + auto output_buffer = TensorBuffer::CreateManaged( + tensor_buffer_type, tensor_type, + (*output_buffer_requirements).BufferSize().Value()); + if (!output_buffer.HasValue()) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + output_buffer.Error().Message()); + } + output_buffers.push_back(std::move(*output_buffer)); + } + return std::move(output_buffers); +} + +Expected CompiledModel::Run( + size_t signature_index, const std::vector& input_buffers, + const std::vector& output_buffers) { + auto input_buffers_ptr = + std::make_unique(input_buffers.size()); + for (int i = 0; i < input_buffers.size(); ++i) { + input_buffers_ptr[i] = input_buffers[i].Get(); + } + auto output_buffers_ptr = + std::make_unique(output_buffers.size()); + for (int i = 0; i < 
output_buffers.size(); ++i) { + output_buffers_ptr[i] = output_buffers[i].Get(); + } + if (auto status = LiteRtRunCompiledModel( + Get(), signature_index, input_buffers.size(), input_buffers_ptr.get(), + output_buffers.size(), output_buffers_ptr.get()); + status != kLiteRtStatusOk) { + return Unexpected(status, "Failed to invoke the compiled model"); + } + return {}; +} + +} // namespace litert diff --git a/tflite/experimental/litert/cc/litert_compiled_model.h b/tflite/experimental/litert/cc/litert_compiled_model.h new file mode 100644 index 00000000..72e8125e --- /dev/null +++ b/tflite/experimental/litert/cc/litert_compiled_model.h @@ -0,0 +1,133 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_COMPILED_MODEL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_COMPILED_MODEL_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_compiled_model.h" +#include "tflite/experimental/litert/c/litert_compiled_model_options.h" +#include "tflite/experimental/litert/c/litert_model.h" +#include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h" +#include "tflite/experimental/litert/cc/litert_detail.h" +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/experimental/litert/cc/litert_handle.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/cc/litert_tensor_buffer.h" +#include "tflite/experimental/litert/cc/litert_tensor_buffer_requirements.h" + +namespace litert { + +// The CompiledModel is a higher level inference API. It is created by +// provided model with compilation options. Internally, it instantiates runtime +// and applies Delegates mapped to the compilation options. +// It also supports getting BufferRequirements to create input/output +// TensorBuffers, and it allows to invoke the model with the input/output +// TensorBuffers. +// +// Example user flow: +// +// 1. Create CompiledModel +// 2. Query the model input/output requirements +// 3. Create input/output TensorBuffers +// 4. Fill the input TensorBuffers with input data +// 5. Invoke the model with the input/output TensorBuffers +// 6. Evaluate the output TensorBuffers + +class CompiledModel + : public internal::Handle { + public: + CompiledModel() = default; + + // Parameter `owned` indicates if the created CompiledModel object should take + // ownership of the provided `compiled_model` handle. + explicit CompiledModel(Model* model, LiteRtCompiledModel compiled_model, + bool owned = true) + : internal::Handle( + compiled_model, owned), + model_(model) {} + + // Creates a CompiledModel from a TFLite file. + // The model is loaded into memory and the caller takes ownership of the + // returned object. 
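+  //
+  // A minimal usage sketch (assuming `model` was already loaded via the
+  // Model API):
+  //
+  //   auto compiled = CompiledModel::Create(model);
+  //   if (!compiled) { /* inspect compiled.Error() */ }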
+  static Expected<CompiledModel> Create(
+      litert::Model& model,
+      LiteRtComplicationOptions compilation_options = kHwAccelDefault) {
+    LiteRtCompiledModel compiled_model;
+    if (auto status = LiteRtCreateCompiledModel(
+            model.Get(), compilation_options, &compiled_model);
+        status != kLiteRtStatusOk) {
+      return Unexpected(status, "Failed to create compiled model");
+    }
+    return CompiledModel(&model, compiled_model);
+  }
+
+  // Returns the buffer requirements for the given n-th input tensor. The
+  // returned TensorBufferRequirements is used to create the input tensor
+  // buffer.
+  litert::Expected<TensorBufferRequirements> GetInputBufferRequirements(
+      size_t signature_index, size_t input_index) {
+    LiteRtTensorBufferRequirements buffer_requirements;
+    if (auto status = LiteRtGetCompiledModelInputBufferRequirements(
+            Get(), signature_index, input_index, &buffer_requirements);
+        status != kLiteRtStatusOk) {
+      return Unexpected(status, "Failed to get input buffer requirements");
+    }
+    return TensorBufferRequirements(buffer_requirements, /*owned=*/false);
+  }
+
+  // Returns the buffer requirements for the given output tensor. The returned
+  // TensorBufferRequirements is used to create the output tensor
+  // buffer.
+  litert::Expected<TensorBufferRequirements> GetOutputBufferRequirements(
+      size_t signature_index, size_t output_index) {
+    LiteRtTensorBufferRequirements buffer_requirements;
+    if (auto status = LiteRtGetCompiledModelOutputBufferRequirements(
+            Get(), signature_index, output_index, &buffer_requirements);
+        status != kLiteRtStatusOk) {
+      return Unexpected(status, "Failed to get output buffer requirements");
+    }
+    return TensorBufferRequirements(buffer_requirements, /*owned=*/false);
+  }
+
+  // A helper function to create the input tensor buffers for the given
+  // signature. It uses BufferRequirements and RankedTensorType to create the
+  // input tensor buffers.
+  Expected<std::vector<TensorBuffer>> CreateInputBuffers(
+      size_t signature_index);
+
+  // A helper function to create the output tensor buffers for the given
+  // signature. It uses BufferRequirements and RankedTensorType to create the
+  // output tensor buffers.
+  Expected<std::vector<TensorBuffer>> CreateOutputBuffers(
+      size_t signature_index);
+
+  // Runs the model of the given signature with the provided input/output
+  // TensorBuffers.
+  Expected<void> Run(size_t signature_index,
+                     const std::vector<TensorBuffer>& input_buffers,
+                     const std::vector<TensorBuffer>& output_buffers);
+
+ private:
+  Model* model_;
+};
+
+}  // namespace litert
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_COMPILED_MODEL_H_
diff --git a/tflite/experimental/litert/cc/litert_compiled_model_test.cc b/tflite/experimental/litert/cc/litert_compiled_model_test.cc
new file mode 100644
index 00000000..2576a332
--- /dev/null
+++ b/tflite/experimental/litert/cc/litert_compiled_model_test.cc
@@ -0,0 +1,88 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "tflite/experimental/litert/cc/litert_compiled_model.h" + +#include + +#include +#include +#include "absl/log/absl_log.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/cc/litert_tensor_buffer.h" +#include "tflite/experimental/litert/test/common.h" +#include "tflite/experimental/litert/test/testdata/simple_model_test_vectors.h" + +using testing::FloatNear; +using testing::Pointwise; + +namespace litert { +namespace { + +TEST(CompiledModelTest, Basic) { + auto model = testing::LoadTestFileModel(kModelFileName); + ASSERT_TRUE(model); + + auto res_compiled_model = CompiledModel::Create(model); + ASSERT_TRUE(res_compiled_model) << "Failed to initialize CompiledModel"; + + auto& compiled_model = *res_compiled_model; + auto signatures = model.GetSignatures().Value(); + EXPECT_EQ(signatures.size(), 1); + + auto signature_key = signatures[0].Key(); + EXPECT_EQ(signature_key, Model::DefaultSignatureKey()); + size_t signature_index = 0; + + auto input_buffers_res = compiled_model.CreateInputBuffers(signature_index); + EXPECT_TRUE(input_buffers_res); + auto& input_buffers = *input_buffers_res; + + auto output_buffers_res = compiled_model.CreateOutputBuffers(signature_index); + EXPECT_TRUE(output_buffers_res); + auto& output_buffers = *output_buffers_res; + + // Fill model inputs. + auto input_names = signatures[0].InputNames(); + EXPECT_EQ(input_names.size(), 2); + EXPECT_EQ(input_names.at(0), "arg0"); + EXPECT_EQ(input_names.at(1), "arg1"); + ASSERT_TRUE(input_buffers[0].Write( + absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); + ASSERT_TRUE(input_buffers[1].Write( + absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); + + // Execute model. + compiled_model.Run(signature_index, input_buffers, output_buffers); + + // Check model output. + auto output_names = signatures[0].OutputNames(); + EXPECT_EQ(output_names.size(), 1); + EXPECT_EQ(output_names.at(0), "tfl.add"); + { + auto lock_and_addr = + litert::TensorBufferScopedLock::Create(output_buffers[0]); + ASSERT_TRUE(lock_and_addr); + auto output = absl::MakeSpan(lock_and_addr->second, kTestOutputSize); + for (auto i = 0; i < kTestOutputSize; ++i) { + ABSL_LOG(INFO) << "Result: " << output[i] << "\t" << kTestOutputTensor[i]; + } + EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); + } +} + +} // namespace +} // namespace litert diff --git a/tflite/experimental/litert/cc/litert_detail.h b/tflite/experimental/litert/cc/litert_detail.h new file mode 100644 index 00000000..39d8be36 --- /dev/null +++ b/tflite/experimental/litert/cc/litert_detail.h @@ -0,0 +1,112 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_DETAIL_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_DETAIL_H_
+
+#include <algorithm>
+#include <cstddef>
+#include <functional>
+#include <optional>
+
+#include "absl/container/inlined_vector.h"
+#include "absl/log/absl_check.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+
+namespace litert {
+
+// Expected size for inlined vectors for things like the input/outputs of ops
+// or subgraphs.
+static constexpr size_t kTensorVecSize = 8;
+template <typename T>
+using SmallVec = absl::InlinedVector<T, kTensorVecSize>;
+
+// See "std::construct_at" from C++20.
+template <class T, class... Args>
+T* ConstructAt(T* p, Args&&... args) {
+  return ::new (static_cast<void*>(p)) T(std::forward<Args>(args)...);
+}
+
+// Reduce all over zipped iters of same size.
+template <class LeftVals, class RightVals>
+bool AllZip(const LeftVals& lhs, const RightVals& rhs,
+            std::function<bool(const typename LeftVals::value_type&,
+                               const typename RightVals::value_type&)>
+                bin_pred) {
+  if (lhs.size() != rhs.size()) {
+    return false;
+  }
+  for (auto i = 0; i < lhs.size(); ++i) {
+    if (!bin_pred(lhs.at(i), rhs.at(i))) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Reduce any over zipped iters of same size.
+template <class LeftVals, class RightVals>
+bool AnyZip(const LeftVals& lhs, const RightVals& rhs,
+            std::function<bool(const typename LeftVals::value_type&,
+                               const typename RightVals::value_type&)>
+                bin_pred) {
+  auto neg = [&](const auto& l, const auto& r) { return !bin_pred(l, r); };
+  return !(AllZip(lhs, rhs, neg));
+}
+
+// Does element exist in range.
+template <class It, typename T>
+bool Contains(It begin, It end, const T& val) {
+  return std::find(begin, end, val) != end;
+}
+
+// Does element exist in range satisfying pred.
+template <class It, class UPred>
+bool ContainsIf(It begin, It end, UPred u_pred) {
+  return std::find_if(begin, end, u_pred) != end;
+}
+
+// Get the index of the given element if it is present.
+template <class It, typename T>
+std::optional<size_t> FindInd(It begin, It end, T val) {
+  auto it = std::find(begin, end, val);
+  return (it == end) ? std::nullopt : std::make_optional(it - begin);
+}
+
+namespace internal {
+
+// Call function "get" and assert it returns value equal to given expected
+// value.
+template <class F, class Expected, typename... Args>
+inline void AssertEq(F get, Expected expected, Args&&... args) {
+  auto status = get(std::forward<Args>(args)...);
+  ABSL_CHECK_EQ(status, expected);
+}
+
+// Call function "get" and assert it returns true.
+template <class F, typename... Args>
+inline void AssertTrue(F get, Args&&... args) {
+  AssertEq(get, true, std::forward<Args>(args)...);
+}
+
+// Call function "get" and assert it returns an OK LiteRtStatus.
+template <class F, typename... Args>
+inline void AssertOk(F get, Args&&... args) {
+  AssertEq(get, kLiteRtStatusOk, std::forward<Args>(args)...);
+}
+
+}  // namespace internal
+}  // namespace litert
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_DETAIL_H_
diff --git a/tflite/experimental/litert/cc/litert_element_type.h b/tflite/experimental/litert/cc/litert_element_type.h
new file mode 100644
index 00000000..78db7be2
--- /dev/null
+++ b/tflite/experimental/litert/cc/litert_element_type.h
@@ -0,0 +1,154 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ELEMENT_TYPE_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ELEMENT_TYPE_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <optional>
+
+#include "tflite/experimental/litert/c/litert_model.h"
+
+namespace litert {
+
+// Data type of tensor elements. C++ equivalent to LiteRtElementType.
+enum class ElementType {
+  None = kLiteRtElementTypeNone,
+  Bool = kLiteRtElementTypeBool,
+  Int4 = kLiteRtElementTypeInt4,
+  Int8 = kLiteRtElementTypeInt8,
+  Int16 = kLiteRtElementTypeInt16,
+  Int32 = kLiteRtElementTypeInt32,
+  Int64 = kLiteRtElementTypeInt64,
+  UInt8 = kLiteRtElementTypeUInt8,
+  UInt16 = kLiteRtElementTypeUInt16,
+  UInt32 = kLiteRtElementTypeUInt32,
+  UInt64 = kLiteRtElementTypeUInt64,
+  Float16 = kLiteRtElementTypeFloat16,
+  BFloat16 = kLiteRtElementTypeBFloat16,
+  Float32 = kLiteRtElementTypeFloat32,
+  Float64 = kLiteRtElementTypeFloat64,
+  Complex64 = kLiteRtElementTypeComplex64,
+  Complex128 = kLiteRtElementTypeComplex128,
+  TfResource = kLiteRtElementTypeTfResource,
+  TfString = kLiteRtElementTypeTfString,
+  TfVariant = kLiteRtElementTypeTfVariant,
+};
+
+// Get number of bytes of a single element of given type.
+inline constexpr std::optional<size_t> GetByteWidth(ElementType ty) {
+  if (ty == ElementType::Bool)
+    return 1;
+  else if (ty == ElementType::Int8)
+    return 1;
+  else if (ty == ElementType::Int16)
+    return 2;
+  else if (ty == ElementType::Int32)
+    return 4;
+  else if (ty == ElementType::Int64)
+    return 8;
+  else if (ty == ElementType::UInt8)
+    return 1;
+  else if (ty == ElementType::UInt16)
+    return 2;
+  else if (ty == ElementType::UInt32)
+    return 4;
+  else if (ty == ElementType::UInt64)
+    return 8;
+  else if (ty == ElementType::Float16)
+    return 2;
+  else if (ty == ElementType::BFloat16)
+    return 2;
+  else if (ty == ElementType::Float32)
+    return 4;
+  else if (ty == ElementType::Float64)
+    return 8;
+  else
+    return std::nullopt;
+}
+
+// Get number of bytes of a single element of given type via template.
+template <ElementType Ty>
+inline constexpr size_t GetByteWidth() {
+  constexpr auto byte_width = GetByteWidth(Ty);
+  static_assert(byte_width.has_value(), "Type does not have byte width");
+  return byte_width.value();
+}
+
+// Get the litert::ElementType associated with given C++ type.
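+// For example, GetElementType<float>() yields ElementType::Float32; only the
+// types specialized below are supported.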
+template <typename T>
+inline constexpr ElementType GetElementType() {
+  static_assert(false, "Unknown C++ type");
+  return ElementType::None;
+}
+
+template <>
+inline constexpr ElementType GetElementType<bool>() {
+  return ElementType::Bool;
+}
+
+template <>
+inline constexpr ElementType GetElementType<int8_t>() {
+  return ElementType::Int8;
+}
+
+template <>
+inline constexpr ElementType GetElementType<uint8_t>() {
+  return ElementType::UInt8;
+}
+
+template <>
+inline constexpr ElementType GetElementType<int16_t>() {
+  return ElementType::Int16;
+}
+
+template <>
+inline constexpr ElementType GetElementType<uint16_t>() {
+  return ElementType::UInt16;
+}
+
+template <>
+inline constexpr ElementType GetElementType<int32_t>() {
+  return ElementType::Int32;
+}
+
+template <>
+inline constexpr ElementType GetElementType<uint32_t>() {
+  return ElementType::UInt32;
+}
+
+template <>
+inline constexpr ElementType GetElementType<int64_t>() {
+  return ElementType::Int64;
+}
+
+template <>
+inline constexpr ElementType GetElementType<uint64_t>() {
+  return ElementType::UInt64;
+}
+
+template <>
+inline constexpr ElementType GetElementType<float>() {
+  return ElementType::Float32;
+}
+
+template <>
+inline constexpr ElementType GetElementType<double>() {
+  return ElementType::Float64;
+}
+
+}  // namespace litert
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ELEMENT_TYPE_H_
diff --git a/tflite/experimental/litert/cc/litert_element_type_test.cc b/tflite/experimental/litert/cc/litert_element_type_test.cc
new file mode 100644
index 00000000..800cfe26
--- /dev/null
+++ b/tflite/experimental/litert/cc/litert_element_type_test.cc
@@ -0,0 +1,48 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/cc/litert_element_type.h"
+
+#include <cstddef>
+#include <cstdint>
+
+#include <gtest/gtest.h>
+
+namespace litert {
+
+namespace {
+
+template <typename T>
+class ElementTypeTest : public ::testing::Test {
+ public:
+  size_t Size() const { return sizeof(T); }
+};
+
+TYPED_TEST_SUITE_P(ElementTypeTest);
+
+TYPED_TEST_P(ElementTypeTest, TypeAndSize) {
+  const size_t size = GetByteWidth<GetElementType<TypeParam>()>();
+  EXPECT_EQ(size, this->Size());
+}
+
+REGISTER_TYPED_TEST_SUITE_P(ElementTypeTest, TypeAndSize);
+
+using Types =
+    ::testing::Types<bool, int8_t, uint8_t, int16_t, uint16_t, int32_t,
+                     uint32_t, int64_t, uint64_t, float, double>;
+
+INSTANTIATE_TYPED_TEST_SUITE_P(ElementTypeTestSuite, ElementTypeTest, Types);
+
+}  // namespace
+}  // namespace litert
diff --git a/tflite/experimental/litert/cc/litert_environment.h b/tflite/experimental/litert/cc/litert_environment.h
new file mode 100644
index 00000000..3ae116f3
--- /dev/null
+++ b/tflite/experimental/litert/cc/litert_environment.h
@@ -0,0 +1,82 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ENVIRONMENT_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ENVIRONMENT_H_
+
+#include <any>
+#include <vector>
+
+#include "absl/types/span.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_environment.h"
+#include "tflite/experimental/litert/cc/litert_any.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+
+namespace litert {
+
+class Environment {
+ public:
+  enum class OptionTag {
+    CompilerPluginLibraryPath = kLiteRtEnvOptionTagCompilerPluginLibraryPath,
+    DispatchLibraryPath = kLiteRtEnvOptionTagDispatchLibraryPath,
+  };
+
+  struct Option {
+    OptionTag tag;
+    std::any value;
+  };
+
+  static Expected<void> Create(absl::Span<const Option> options) {
+    auto c_options = ConvertOptions(options);
+    if (!c_options) {
+      return c_options.Error();
+    }
+    if (auto status =
+            LiteRtEnvironmentCreate(c_options->size(), c_options->data());
+        status != kLiteRtStatusOk) {
+      return Error(status);
+    } else {
+      return {};
+    }
+  }
+
+  static void Destroy() { LiteRtEnvironmentDestroy(); }
+
+ private:
+  static Expected<std::vector<LiteRtEnvOption>> ConvertOptions(
+      absl::Span<const Option> options) {
+    std::vector<LiteRtEnvOption> c_options;
+    c_options.reserve(options.size());
+
+    for (auto& option : options) {
+      auto litert_any = ToLiteRtAny(option.value);
+      if (!litert_any) {
+        return litert_any.Error();
+      }
+
+      LiteRtEnvOption c_option = {
+          /*.tag=*/static_cast<LiteRtEnvOptionTag>(option.tag),
+          /*.value=*/*litert_any,
+      };
+      c_options.push_back(c_option);
+    }
+
+    return c_options;
+  }
+};
+
+}  // namespace litert
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ENVIRONMENT_H_
diff --git a/tflite/experimental/litert/cc/litert_expected.h b/tflite/experimental/litert/cc/litert_expected.h
new file mode 100644
index 00000000..e167a64b
--- /dev/null
+++ b/tflite/experimental/litert/cc/litert_expected.h
@@ -0,0 +1,338 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_EXPECTED_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_EXPECTED_H_
+
+#include <initializer_list>
+#include <memory>
+#include <type_traits>
+#include <utility>
+
+#include "absl/log/absl_check.h"
+#include "absl/strings/string_view.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/cc/litert_detail.h"
+
+namespace litert {
+
+// An "Expected" encapsulates the result of some routine which may have an
+// unexpected result. Unexpected results in this context are a standard
+// LiteRtStatus plus extra usability data such as error messages. This is
+// similar to an absl::StatusOr or std::expected (C++23) but better integrated
+// with LiteRtStatus as the canonical status code.
+
+// C++ wrapper around LiteRtStatus code. Provides a status as well
+// as an error message.
+class Error {
+ public:
+  // Construct an Error from a status and optional error message. NOTE:
+  // kLiteRtStatusOk should not be passed here.
+  explicit Error(LiteRtStatus status, absl::string_view message = "")
+      : status_(status), message_(message) {
+    ABSL_DCHECK(status != kLiteRtStatusOk);
+  }
+
+  // Get the status.
+  constexpr LiteRtStatus Status() const { return status_; }
+
+  // Get the error message, empty string if none was attached.
+  constexpr absl::string_view Message() const { return message_; }
+
+ private:
+  LiteRtStatus status_;
+  absl::string_view message_;
+};
+
+class Unexpected {
+ public:
+  template <class... Args>
+  constexpr explicit Unexpected(Args&&... args)
+      : error_(std::forward<Args>(args)...) {}
+
+  // Allow for implicit conversion from convertible Error value inplace.
+  // NOLINTNEXTLINE
+  Unexpected(class Error&& e) : error_(std::move(e)) {}
+
+  Unexpected(Unexpected&& other) = default;
+  Unexpected(const Unexpected& other) = default;
+  Unexpected& operator=(Unexpected&& other) = default;
+  Unexpected& operator=(const Unexpected& other) = default;
+
+  constexpr const class Error& Error() const& noexcept { return error_; }
+  constexpr class Error& Error() & noexcept { return error_; }
+  constexpr const class Error&& Error() const&& noexcept {
+    return std::move(error_);
+  }
+  constexpr class Error&& Error() && noexcept { return std::move(error_); }
+
+ private:
+  class Error error_;
+};
+
+// Utility for generic return values that may be a statused failure.
+// Expecteds store and own the lifetime of either an Unexpected, or a T.
+// T may be any type, primitive or non-primitive.
+//
+// No dynamic allocations occur during initialization, so the underlying T is
+// only movable (as opposed to something like "release"). Arguments should be
+// constructed inplace at the time of initializing the Expected if possible.
+//
+// Unexpected&& and T&& may be implicitly cast
+// to an Expected. For example,
+//
+// Expected<Foo> Bar() {
+//   bool success = ...
+//   if (!success) { return Unexpected(kLiteRtStatus, "Bad Baz"); }
+//   return Foo();
+// }
+//
+template <class T>
+class Expected {
+ public:
+  // Construct Expected with T inplace.
+
+  // Construct T from initializer list inplace.
+  template <class U>
+  Expected(std::initializer_list<U> il) : has_value_(true), value_(il) {}
+
+  // Construct T from forwarded args inplace.
+  template <typename... Args>
+  explicit Expected(Args&&... args)
+      : has_value_(true), value_(std::forward<Args>(args)...) {}
+
+  // Allow for implicit conversion from convertible T value inplace.
+  // NOLINTNEXTLINE
+  Expected(const T& t) : has_value_(true), value_(t) {}
+  // NOLINTNEXTLINE
+  Expected(T&& t) : has_value_(true), value_(std::move(t)) {}
+
+  // Construct from Unexpected inplace.
+
+  // Allow for implicit conversion from Error.
+  // NOLINTNEXTLINE
+  Expected(const Unexpected& err) : has_value_(false), unexpected_(err) {}
+  // NOLINTNEXTLINE
+  Expected(Unexpected&& err)
+      : has_value_(false), unexpected_(std::move(err)) {}
+  // NOLINTNEXTLINE
+  Expected(const class Error& e) : has_value_(false), unexpected_(e) {}
+
+  // Copy/move
+
+  Expected(Expected&& other) : has_value_(other.HasValue()) {
+    if (HasValue()) {
+      ConstructAt(std::addressof(value_), std::move(other.value_));
+    } else {
+      ConstructAt(std::addressof(unexpected_), std::move(other.unexpected_));
+    }
+  }
+
+  Expected(const Expected& other) : has_value_(other.has_value_) {
+    if (HasValue()) {
+      ConstructAt(std::addressof(value_), other.value_);
+    } else {
+      ConstructAt(std::addressof(unexpected_), other.unexpected_);
+    }
+  }
+
+  Expected& operator=(Expected&& other) {
+    if (this != &other) {
+      this->~Expected();
+      has_value_ = other.has_value_;
+      if (HasValue()) {
+        ConstructAt(std::addressof(value_), std::move(other.value_));
+      } else {
+        ConstructAt(std::addressof(unexpected_), std::move(other.unexpected_));
+      }
+    }
+    return *this;
+  }
+
+  Expected& operator=(const Expected& other) {
+    if (this != &other) {
+      this->~Expected();
+      has_value_ = other.has_value_;
+      if (HasValue()) {
+        ConstructAt(std::addressof(value_), other.value_);
+      } else {
+        ConstructAt(std::addressof(unexpected_), other.unexpected_);
+      }
+    }
+    return *this;
+  }
+
+  ~Expected() {
+    if (has_value_) {
+      if constexpr (std::is_destructible_v<T>) {
+        value_.~T();
+      }
+    } else {
+      unexpected_.~Unexpected();
+    }
+  }
+
+  // Observers for T value, program exits if it doesn't have one.
+  const T& Value() const& {
+    CheckVal();
+    return value_;
+  }
+
+  T& Value() & {
+    CheckVal();
+    return value_;
+  }
+
+  const T&& Value() const&& {
+    CheckVal();
+    return std::move(value_);
+  }
+
+  T&& Value() && {
+    CheckVal();
+    return std::move(value_);
+  }
+
+  const T* operator->() const {
+    CheckVal();
+    return &value_;
+  }
+
+  T* operator->() {
+    CheckVal();
+    return &value_;
+  }
+
+  const T& operator*() const& { return Value(); }
+
+  T& operator*() & { return Value(); }
+
+  const T&& operator*() const&& { return std::move(Value()); }
+
+  T&& operator*() && { return std::move(Value()); }
+
+  // Observer for Unexpected, program exits if it doesn't have one.
+  const class Error& Error() const& {
+    CheckNoVal();
+    return unexpected_.Error();
+  }
+
+  class Error& Error() & {
+    CheckNoVal();
+    return unexpected_.Error();
+  }
+
+  const class Error&& Error() const&& {
+    CheckNoVal();
+    return std::move(unexpected_.Error());
+  }
+
+  class Error&& Error() && {
+    CheckNoVal();
+    return std::move(unexpected_.Error());
+  }
+
+  // Does this expected contain a T Value. It contains an unexpected if not.
+  bool HasValue() const { return has_value_; }
+
+  // Convert to bool for HasValue.
+  explicit operator bool() const { return HasValue(); }
+
+ private:
+  bool has_value_;
+  union {
+    T value_;
+    Unexpected unexpected_;
+  };
+  void CheckNoVal() const { ABSL_CHECK(!HasValue()); }
+  void CheckVal() const { ABSL_CHECK(HasValue()); }
+};
+
+template <>
+class Expected<void> {
+ public:
+  // Implicit construction is used to simplify returning a valid value, e.g.,
+  // in "return {};"
+  Expected() : has_value_(true) {}
+
+  // Construct from Unexpected inplace.
+
+  // Allow for implicit conversion from Error.
+  // NOLINTNEXTLINE
+  Expected(const Unexpected& err) : has_value_(false), unexpected_(err) {}
+  // NOLINTNEXTLINE
+  Expected(Unexpected&& err)
+      : has_value_(false), unexpected_(std::move(err)) {}
+  // NOLINTNEXTLINE
+  Expected(const Error& e) : has_value_(false), unexpected_(e) {}
+
+  ~Expected() {
+    if (!has_value_) {
+      unexpected_.~Unexpected();
+    }
+  }
+
+  Expected& operator=(Expected&& other) {
+    if (this != &other) {
+      this->~Expected();
+      has_value_ = other.has_value_;
+      if (!has_value_) {
+        ConstructAt(std::addressof(unexpected_),
+                    std::move(other.unexpected_));
+      }
+    }
+    return *this;
+  }
+
+  Expected& operator=(const Expected& other) {
+    if (this != &other) {
+      this->~Expected();
+      has_value_ = other.has_value_;
+      if (!has_value_) {
+        ConstructAt(std::addressof(unexpected_), other.unexpected_);
+      }
+    }
+    return *this;
+  }
+
+  // Observer for Unexpected, program exits if it doesn't have one.
+  const class Error& Error() const& {
+    CheckNoVal();
+    return unexpected_.Error();
+  }
+
+  class Error& Error() & {
+    CheckNoVal();
+    return unexpected_.Error();
+  }
+
+  const class Error&& Error() const&& {
+    CheckNoVal();
+    return std::move(unexpected_.Error());
+  }
+
+  class Error&& Error() && {
+    CheckNoVal();
+    return std::move(unexpected_.Error());
+  }
+
+  // Does this expected contain a T Value. It contains an unexpected if not.
+  bool HasValue() const { return has_value_; }
+
+  // Convert to bool for HasValue.
+  explicit operator bool() const { return HasValue(); }
+
+ private:
+  bool has_value_;
+  union {
+    Unexpected unexpected_;
+  };
+  void CheckNoVal() const { ABSL_CHECK(!HasValue()); }
+  void CheckVal() const { ABSL_CHECK(HasValue()); }
+};
+
+}  // namespace litert
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_EXPECTED_H_
diff --git a/tflite/experimental/litert/cc/litert_expected_test.cc b/tflite/experimental/litert/cc/litert_expected_test.cc
new file mode 100644
index 00000000..ac1332d4
--- /dev/null
+++ b/tflite/experimental/litert/cc/litert_expected_test.cc
@@ -0,0 +1,191 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "tflite/experimental/litert/cc/litert_expected.h" + +#include +#include +#include +#include +#include + +#include +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/cc/litert_buffer_ref.h" + +namespace litert { + +namespace { + +static constexpr LiteRtStatus kErrorStatus = kLiteRtStatusErrorInvalidArgument; + +struct TypeWithAllocation { + TypeWithAllocation(std::initializer_list il) : allocated(il) {} + std::vector allocated; +}; + +struct TypeWithFields { + TypeWithFields(int i_, int j_) : i(i_), j(j_) {} + int i; + int j; +}; + +TEST(ExpectedTest, PrimitiveExplicit) { + Expected exp(1.0); + ASSERT_TRUE(exp.HasValue()); +} + +TEST(ExpectedTest, PrimitiveImplicit) { + Expected exp = 1.0; + ASSERT_TRUE(exp.HasValue()); +} + +TEST(ExpectedTest, ClassWithAllocation) { + Expected exp(TypeWithAllocation({1, 2, 3})); + ASSERT_TRUE(exp.HasValue()); +} + +TEST(ExpectedTest, ClassWithFields) { + Expected exp(TypeWithFields(1, 2)); + ASSERT_TRUE(exp.HasValue()); +} + +TEST(ExpectedTest, FromErrorExplicit) { + Expected exp((Unexpected(kErrorStatus, "MESSAGE"))); + ASSERT_FALSE(exp.HasValue()); +} + +TEST(ExpectedTest, FromErrorImplicit) { + Expected exp = Unexpected(kErrorStatus); + ASSERT_FALSE(exp.HasValue()); +} + +TEST(ExpectedTest, CopyCstorError) { + const Expected exp = Unexpected(kErrorStatus); + Expected other(exp); + ASSERT_FALSE(other.HasValue()); + EXPECT_EQ(other.Error().Status(), kErrorStatus); +} + +TEST(ExpectedTest, CopyCstorVal) { + const Expected exp = 2; + Expected other(exp); + ASSERT_TRUE(other.HasValue()); + EXPECT_EQ(other.Value(), 2); +} + +TEST(ExpectedTest, CopyAssignError) { + const Expected exp = Unexpected(kErrorStatus); + ASSERT_FALSE(exp.HasValue()); + Expected other = exp; + ASSERT_FALSE(other.HasValue()); + EXPECT_EQ(other.Error().Status(), kErrorStatus); +} + +TEST(ExpectedTest, CopyAssignVal) { + const Expected exp = 2; + Expected other = exp; + ASSERT_TRUE(other.HasValue()); + EXPECT_EQ(other.Value(), 2); +} + +TEST(ExpectedTest, MoveCstorError) { + Expected exp = Unexpected(kErrorStatus); + Expected other(std::move(exp)); + ASSERT_FALSE(other.HasValue()); + EXPECT_EQ(other.Error().Status(), kErrorStatus); +} + +TEST(ExpectedTest, MoveCstorVal) { + Expected exp = 2; + Expected other(std::move(exp)); + ASSERT_TRUE(other.HasValue()); + EXPECT_EQ(other.Value(), 2); +} + +TEST(ExpectedTest, MoveAssignError) { + Expected exp = Unexpected(kErrorStatus); + Expected other = std::move(exp); + ASSERT_FALSE(other.HasValue()); + EXPECT_EQ(other.Error().Status(), kErrorStatus); +} + +TEST(ExpectedTest, MoveAssignVal) { + Expected exp = 2; + Expected other = std::move(exp); + ASSERT_TRUE(other.HasValue()); + EXPECT_EQ(other.Value(), 2); +} + +TEST(ExpectedTest, Indirection) { + Expected exp(TypeWithFields(1, 2)); + EXPECT_EQ(exp->i, 1); + EXPECT_EQ(exp->j, 2); +} + +TEST(ExpectedTest, Dereference) { + Expected exp(TypeWithFields(1, 2)); + const auto& val = *exp; + EXPECT_EQ(val.i, 1); + EXPECT_EQ(val.j, 2); +} + +TEST(UnexpectedTest, WithStatus) { + Unexpected err(kErrorStatus); + EXPECT_EQ(err.Error().Status(), kErrorStatus); + EXPECT_TRUE(err.Error().Message().empty()); +} + +TEST(UnexpectedTest, WithMessage) { + Unexpected err(kErrorStatus, "MESSAGE"); + EXPECT_EQ(err.Error().Status(), kErrorStatus); + EXPECT_EQ(err.Error().Message(), "MESSAGE"); +} + +Expected> Go() { + std::string data = "21234"; + OwningBufferRef buf(data.c_str()); + return buf; +} + +Expected> Forward() { + auto thing = Go(); + if 
(!thing.HasValue()) { + return thing.Error(); + } + // No copy ellision here. + return thing; +} + +TEST(ExpectedTest, ForwardBufThroughFuncs) { + auto res = Forward(); + EXPECT_TRUE(res.HasValue()); + EXPECT_EQ(res->StrView(), "21234"); +} + +TEST(ExpectedWithNoValue, WithoutError) { + Expected expected = {}; + EXPECT_TRUE(expected.HasValue()); +} + +TEST(ExpectedWithNoValue, WithError) { + Expected expected(Unexpected(kErrorStatus, "MESSAGE")); + EXPECT_FALSE(expected.HasValue()); + EXPECT_EQ(expected.Error().Status(), kErrorStatus); + EXPECT_EQ(expected.Error().Message(), "MESSAGE"); +} + +} // namespace + +} // namespace litert diff --git a/tflite/experimental/litert/cc/litert_handle.h b/tflite/experimental/litert/cc/litert_handle.h new file mode 100644 index 00000000..503eaad3 --- /dev/null +++ b/tflite/experimental/litert/cc/litert_handle.h @@ -0,0 +1,74 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_HANDLE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_HANDLE_H_ + +#include +#include + +namespace litert { +namespace internal { + +template +inline void DummyDeleter(H) {} + +// This class is used to wrap and manage the lifetime of opaque handles from the +// C API into an equivalent C++ object. The class is a wrapper on +// std::unique_ptr<> that has a default constructor and doesn't crash if the +// deleter is null. +template +class Handle { + public: + Handle() = default; + explicit Handle(H handle, bool owned) noexcept + : ptr_(handle, owned ? deleter : DummyDeleter) {} + + Handle(Handle&& other) noexcept { *this = std::move(other); } + + Handle& operator=(Handle&& other) noexcept { + std::swap(ptr_, other.ptr_); + return *this; + } + + // Return true if the underlying LiteRtTensorBuffer handle is valid. + explicit operator bool() const noexcept { return static_cast(ptr_); } + + // Return the underlying LiteRtTensorBuffer handle. + H Get() const noexcept { return ptr_.get(); } + + H Release() noexcept { return ptr_.release(); } + + bool IsOwned() const noexcept { + return ptr_.get_deleter() != DummyDeleter; + } + + private: + std::unique_ptr, void (*)(H)> ptr_ = {nullptr, + DummyDeleter}; +}; + +// This class is similar to Handle, but the managed opaque handle is not owned +// (i.e., it will not be destroyed). +template +class NonOwnedHandle : public Handle> { + public: + explicit NonOwnedHandle(H handle) noexcept + : Handle>(handle, /*owned=*/false) {} +}; + +} // namespace internal +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_HANDLE_H_ diff --git a/tflite/experimental/litert/cc/litert_layout.h b/tflite/experimental/litert/cc/litert_layout.h new file mode 100644 index 00000000..27a24432 --- /dev/null +++ b/tflite/experimental/litert/cc/litert_layout.h @@ -0,0 +1,102 @@ +// Copyright 2024 Google LLC. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_LAYOUT_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_LAYOUT_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <initializer_list>
+#include <iterator>
+#include <optional>
+
+#include "absl/types/span.h"
+#include "tflite/experimental/litert/c/litert_layout.h"
+
+namespace litert {
+
+// Small standalone helper functions for working with the C layout API.
+
+static constexpr size_t kTensorMaxRank = LITERT_TENSOR_MAX_RANK;
+
+// Build layout from given iterator of dimensions.
+template <class Begin, class End>
+inline constexpr LiteRtLayout BuildLayout(Begin begin, End end,
+                                          const uint32_t* strides = nullptr) {
+  LiteRtLayout res(end - begin, {}, strides);
+  auto i = 0;
+
+  for (auto* it = begin; it < end && i < kTensorMaxRank; ++it) {
+    res.dimensions[i] = *it;
+    ++i;
+  }
+
+  return res;
+}
+
+// Build layout from given iterable of dimensions.
+template <class Dims>
+inline constexpr LiteRtLayout BuildLayout(const Dims& dims,
+                                          const uint32_t* strides = nullptr) {
+  return BuildLayout(std::cbegin(dims), std::cend(dims), strides);
+}
+
+// Build layout from literal dimensions.
+inline constexpr LiteRtLayout BuildLayout(std::initializer_list<int32_t> dims,
+                                          const uint32_t* strides = nullptr) {
+  return BuildLayout(dims.begin(), dims.end(), strides);
+}
+
+// Compute the number of elements in dims iterator. Nullopt if there exists
+// a dynamic dimension.
+template <class Begin, class End>
+inline constexpr std::optional<size_t> NumElements(Begin begin, End end) {
+  if (end - begin == 0) {
+    return {};
+  }
+  size_t res = 1;
+  for (auto* it = begin; it < end; ++it) {
+    if (*it < 0) {
+      return {};
+    }
+    res *= *it;
+  }
+  return res;
+}
+
+// Overload for layouts.
+inline constexpr std::optional<size_t> NumElements(const LiteRtLayout& layout) {
+  auto* b = std::cbegin(layout.dimensions);
+  return NumElements(b, b + layout.rank);
+}
+
+// Get dims as span.
+inline constexpr absl::Span<const int32_t> DimsSpan(
+    const LiteRtLayout& layout) {
+  return absl::MakeConstSpan(layout.dimensions, layout.rank);
+}
+
+// Get strides as span if they exist.
+inline constexpr std::optional<absl::Span<const uint32_t>> StridesSpan(
+    const LiteRtLayout& layout) {
+  if (layout.strides) {
+    return absl::MakeConstSpan(layout.strides, layout.rank);
+  }
+  return {};
+}
+
+}  // namespace litert
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_LAYOUT_H_
diff --git a/tflite/experimental/litert/cc/litert_layout_test.cc b/tflite/experimental/litert/cc/litert_layout_test.cc
new file mode 100644
index 00000000..4d0fed88
--- /dev/null
+++ b/tflite/experimental/litert/cc/litert_layout_test.cc
@@ -0,0 +1,62 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/cc/litert_layout.h"
+
+#include <cstdint>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace litert {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+static constexpr int32_t kStaticDims[] = {2, 2};
+static constexpr int32_t kDynDims[] = {-1, 2};
+static constexpr uint32_t kStrides[] = {1, 1};
+
+TEST(LayoutTest, BuildFromDims) {
+  auto layout = BuildLayout(kStaticDims);
+  EXPECT_EQ(layout.rank, 2);
+  EXPECT_THAT(DimsSpan(layout), ElementsAreArray(kStaticDims));
+  EXPECT_EQ(layout.strides, nullptr);
+  EXPECT_FALSE(StridesSpan(layout).has_value());
+}
+
+TEST(LayoutTest, BuildFromDimsWithStrides) {
+  auto layout = BuildLayout(kStaticDims, kStrides);
+  EXPECT_EQ(layout.rank, 2);
+  EXPECT_THAT(DimsSpan(layout), ElementsAreArray(kStaticDims));
+  auto strides = StridesSpan(layout);
+  ASSERT_TRUE(strides.has_value());
+  EXPECT_THAT(*strides, ElementsAreArray(kStrides));
+}
+
+TEST(LayoutTest, NumElements) {
+  auto layout = BuildLayout(kStaticDims);
+  auto num_elements = NumElements(layout);
+  ASSERT_TRUE(num_elements.has_value());
+  EXPECT_EQ(*num_elements, 4);
+}
+
+TEST(LayoutTest, NumElementsDynamic) {
+  auto layout = BuildLayout(kDynDims);
+  auto num_elements = NumElements(layout);
+  ASSERT_FALSE(num_elements.has_value());
+}
+
+}  // namespace
+}  // namespace litert
diff --git a/tflite/experimental/litert/cc/litert_macros.h b/tflite/experimental/litert/cc/litert_macros.h
new file mode 100644
index 00000000..b5f24c60
--- /dev/null
+++ b/tflite/experimental/litert/cc/litert_macros.h
@@ -0,0 +1,67 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
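+
+// Usage sketch for the status-propagation macros defined below (illustrative
+// only; `DoThing` is a hypothetical function returning LiteRtStatus):
+//
+//   LiteRtStatus DoThingTwice() {
+//     LITERT_RETURN_STATUS_IF_NOT_OK(DoThing());
+//     LITERT_RETURN_STATUS_IF_NOT_OK(DoThing());
+//     return kLiteRtStatusOk;
+//   }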
+ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MACROS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MACROS_H_ + +#include "absl/log/absl_check.h" +#include "tflite/experimental/litert/c/litert_common.h" // IWYU pragma: keep +#include "tflite/experimental/litert/c/litert_logging.h" // IWYU pragma: keep + +#define _CONCAT_NAME_IMPL(x, y) x##y + +#define _CONCAT_NAME(x, y) _CONCAT_NAME_IMPL(x, y) + +#define _RETURN_VAL(val) return val + +#define LITERT_CHECK_STATUS_HAS_CODE(expr, code) ABSL_CHECK(expr == code); + +#define LITERT_CHECK_STATUS_OK(expr) \ + LITERT_CHECK_STATUS_HAS_CODE(expr, kLiteRtStatusOk); + +#define LITERT_ENSURE_SUPPORTED(cond, msg) \ + if (!(cond)) { \ + LITERT_LOG(LITERT_ERROR, "%s", msg); \ + return kLiteRtStatusErrorUnsupported; \ + } + +#define LITERT_ENSURE(expr, fail_stat, msg) \ + if (!(expr)) { \ + LITERT_LOG(LITERT_ERROR, "%s", msg); \ + return fail_stat; \ + } + +#define LITERT_RETURN_STATUS_IF_NOT_OK(expr) \ + if (LiteRtStatus status = expr; status != kLiteRtStatusOk) return status; + +#define LITERT_RETURN_STATUS_IF_NOT_OK_OR_NOT_MATCHED(expr) \ + if (LiteRtStatus status = expr; \ + (status != kLiteRtStatusOk && status != kLiteRtStatusLegalizeNoMatch)) \ + return status; + +#define LITERT_RETURN_VAL_IF_NOT_OK(expr, ret_val) \ + if (LiteRtStatus status = expr; status != kLiteRtStatusOk) return ret_val; + +#define LITERT_STACK_ARRAY(ty, var, size, init) \ + ty* var = (ty*)alloca(sizeof(ty) * size); \ + for (ty* e = var; e < var + size; ++e) { \ + *e = init; \ + } + +#define LITERT_EXPECT_OK(status) \ + if (auto stat = (status); stat != kLiteRtStatusOk) { \ + return ::litert::Unexpected(stat); \ + } + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MACROS_H_ diff --git a/tflite/experimental/litert/cc/litert_model.cc b/tflite/experimental/litert/cc/litert_model.cc new file mode 100644 index 00000000..1eecd36e --- /dev/null +++ b/tflite/experimental/litert/cc/litert_model.cc @@ -0,0 +1,45 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
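+
+// The definitions below classify a tensor by its structural role in the
+// graph. A rough sketch of the intent (illustrative only):
+//
+//   - IsConstant():       has weights and no defining op.
+//   - IsSubgraphInput():  has no weights and no defining op.
+//   - IsSubgraphOutput(): has no uses within this subgraph.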
+ +#include "tflite/experimental/litert/cc/litert_model.h" + +#include "tflite/experimental/litert/c/litert_model.h" +#include "tflite/experimental/litert/cc/litert_detail.h" + +namespace litert { + +bool Tensor::IsSubgraphOutput() const { return Uses().empty(); } + +bool Tensor::IsSubgraphInput() const { + return !HasWeights() && !DefiningOp().has_value(); +} + +bool Tensor::IsConstant() const { + return HasWeights() && !DefiningOp().has_value(); +} + +SmallVec Tensor::Uses() const { + LiteRtParamIndex num_uses; + LiteRtOpArray users; + LiteRtParamIndex* user_arg_inds; + litert::internal::AssertOk(LiteRtGetTensorUses, Get(), &num_uses, &users, + &user_arg_inds); + SmallVec res; + for (int i = 0; i < num_uses; ++i) { + res.push_back(Tensor::TensorUse{Op(users[i]), user_arg_inds[i]}); // NOLINT + } + return res; +} + +} // namespace litert diff --git a/tflite/experimental/litert/cc/litert_model.h b/tflite/experimental/litert/cc/litert_model.h new file mode 100644 index 00000000..564fdf4d --- /dev/null +++ b/tflite/experimental/litert/cc/litert_model.h @@ -0,0 +1,490 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MODEL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MODEL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_model.h" +#include "tflite/experimental/litert/cc/litert_buffer_ref.h" +#include "tflite/experimental/litert/cc/litert_detail.h" +#include "tflite/experimental/litert/cc/litert_element_type.h" +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/experimental/litert/cc/litert_handle.h" +#include "tflite/experimental/litert/cc/litert_layout.h" + +namespace litert { + +// Tensor layout. C++ equivalent to LiteRtLayout. +class Layout { + public: + explicit Layout(SmallVec&& dimensions, + SmallVec&& strides = SmallVec()) + : dimensions_(std::move(dimensions)), strides_(std::move(strides)) {} + + explicit Layout(const LiteRtLayout& layout) + : dimensions_(layout.dimensions, layout.dimensions + layout.rank) { + if (layout.strides) { + strides_.reserve(layout.rank); + std::copy(layout.strides, layout.strides + layout.rank, + std::back_inserter(strides_)); + } + } + + explicit operator LiteRtLayout() const { + auto res = BuildLayout(dimensions_); + res.strides = HasStrides() ? strides_.data() : nullptr; + return res; + } + + bool operator==(const Layout& other) const { + return dimensions_ == other.dimensions_ && strides_ == other.strides_; + } + + uint32_t Rank() const { return dimensions_.size(); } + + absl::Span Dimensions() const { + return absl::MakeSpan(dimensions_.data(), dimensions_.size()); + } + + bool HasStrides() const { return !strides_.empty(); } + + absl::Span Strides() const { + const uint32_t* data = HasStrides() ? 
+    auto size = HasStrides() ? Rank() : 0;
+    return absl::MakeSpan(data, size);
+  }
+
+  // Get the number of scalar elements in this tensor type. std::nullopt if
+  // not fully static.
+  std::optional<size_t> NumElements() const {
+    return ::litert::NumElements(dimensions_.cbegin(), dimensions_.cend());
+  }
+
+ private:
+  SmallVec<int32_t> dimensions_;
+  SmallVec<uint32_t> strides_;
+};
+
+// Type for tensors with known dimensions. C++ equivalent to
+// LiteRtRankedTensorType.
+class RankedTensorType {
+ public:
+  RankedTensorType(enum ElementType element_type, class Layout&& layout)
+      : element_type_(element_type), layout_(std::move(layout)) {}
+  explicit RankedTensorType(const LiteRtRankedTensorType& type)
+      : element_type_(static_cast<enum ElementType>(type.element_type)),
+        layout_(type.layout) {}
+
+  explicit operator LiteRtRankedTensorType() const {
+    return LiteRtRankedTensorType{
+        /*.element_type=*/static_cast<LiteRtElementType>(element_type_),
+        /*.layout=*/static_cast<LiteRtLayout>(layout_),
+    };
+  }
+
+  bool operator==(const RankedTensorType& other) const {
+    return ElementType() == other.ElementType() && Layout() == other.Layout();
+  }
+
+  enum ElementType ElementType() const { return element_type_; }
+
+  const class Layout& Layout() const { return layout_; }
+
+ private:
+  enum ElementType element_type_;
+  class Layout layout_;
+};
+
+// Tensor weights. C++ equivalent of LiteRtWeights.
+class Weights : public internal::NonOwnedHandle<LiteRtWeights> {
+ public:
+  Weights() = default;
+  explicit Weights(LiteRtWeights weights)
+      : internal::NonOwnedHandle<LiteRtWeights>(weights) {}
+
+  absl::Span<const uint8_t> Bytes() const {
+    size_t size;
+    const void* addr;
+    internal::AssertOk(LiteRtGetWeightsBytes, Get(), &addr, &size);
+    return absl::MakeSpan(static_cast<const uint8_t*>(addr), size);
+  }
+};
+
+// Tensor. C++ equivalent of LiteRtTensor.
+class Tensor : public internal::NonOwnedHandle<LiteRtTensor> {
+ public:
+  Tensor() = default;
+  explicit Tensor(LiteRtTensor tensor)
+      : internal::NonOwnedHandle<LiteRtTensor>(tensor) {}
+
+  LiteRtTensorTypeId TypeId() const {
+    LiteRtTensorTypeId type_id;
+    internal::AssertOk(LiteRtGetTensorTypeId, Get(), &type_id);
+    return type_id;
+  }
+
+  LiteRtUnrankedTensorType UnrankedTensorType() const {
+    internal::AssertEq([&]() { return TypeId(); }, kLiteRtUnrankedTensorType);
+    LiteRtUnrankedTensorType unranked_tensor_type;
+    internal::AssertOk(LiteRtGetUnrankedTensorType, Get(),
+                       &unranked_tensor_type);
+    return unranked_tensor_type;
+  }
+
+  class RankedTensorType RankedTensorType() const {
+    internal::AssertEq([&]() { return TypeId(); }, kLiteRtRankedTensorType);
+    LiteRtRankedTensorType ranked_tensor_type;
+    internal::AssertOk(LiteRtGetRankedTensorType, Get(), &ranked_tensor_type);
+    return litert::RankedTensorType(ranked_tensor_type);
+  }
+
+  LiteRtQuantizationTypeId QTypeId() const {
+    LiteRtQuantizationTypeId q_type_id;
+    internal::AssertOk(LiteRtGetQuantizationTypeId, Get(), &q_type_id);
+    return q_type_id;
+  }
+
+  bool HasQuantization() const { return QTypeId() != kLiteRtQuantizationNone; }
+
+  LiteRtQuantizationPerTensor PerTensorQuantization() const {
+    internal::AssertEq([&]() { return QTypeId(); },
+                       kLiteRtQuantizationPerTensor);
+    LiteRtQuantizationPerTensor per_tensor_quantization;
+    internal::AssertOk(LiteRtGetPerTensorQuantization, Get(),
+                       &per_tensor_quantization);
+    return per_tensor_quantization;
+  }
+
+  LiteRtQuantizationPerChannel PerChannelQuantization() const {
+    internal::AssertEq([&]() { return QTypeId(); },
+                       kLiteRtQuantizationPerChannel);
+    LiteRtQuantizationPerChannel per_channel_quantization;
+    internal::AssertOk(LiteRtGetPerChannelQuantization, Get(),
+                       &per_channel_quantization);
+    return per_channel_quantization;
+  }
+
+  bool HasWeights() const {
+    auto weights = Weights();
+    return !weights.Bytes().empty();
+  }
+
+  class Weights Weights() const {
+    LiteRtWeights weights;
+    internal::AssertOk(LiteRtGetTensorWeights, Get(), &weights);
+    return litert::Weights(weights);
+  }
+
+  absl::string_view Name() const {
+    const char* name;
+    internal::AssertOk(LiteRtGetTensorName, Get(), &name);
+    return absl::string_view(name);
+  }
+
+  struct TensorUse;
+  SmallVec<TensorUse> Uses() const;
+
+  template <typename T>
+  Expected<absl::Span<const T>> WeightsData() const {
+    const ElementType ty = RankedTensorType().ElementType();
+    if (ty != GetElementType<T>()) {
+      return litert::Unexpected(kLiteRtStatusErrorInvalidArgument);
+    }
+
+    if (!HasWeights()) {
+      return litert::Unexpected(kLiteRtStatusErrorInvalidArgument);
+    }
+    const absl::Span<const uint8_t> weights = Weights().Bytes();
+
+    auto num_elements = RankedTensorType().Layout().NumElements();
+    if (!num_elements.has_value()) {
+      return litert::Unexpected(kLiteRtStatusErrorInvalidArgument);
+    }
+    auto byte_width = GetByteWidth(ty);
+    if (!byte_width.has_value()) {
+      return litert::Unexpected(kLiteRtStatusErrorInvalidArgument);
+    }
+
+    if (byte_width.value() * num_elements.value() != weights.size()) {
+      return litert::Unexpected(kLiteRtStatusErrorInvalidArgument);
+    }
+
+    return absl::MakeConstSpan(reinterpret_cast<const T*>(weights.data()),
+                               num_elements.value());
+  }
+
+  std::optional<LiteRtTensorDefiningOp> DefiningOp() const {
+    bool has_defining_op;
+    LiteRtTensorDefiningOp defining_op;
+    internal::AssertOk(LiteRtGetTensorDefiningOp, Get(), &has_defining_op,
+                       &defining_op);
+    if (has_defining_op) {
+      return defining_op;
+    } else {
+      return std::nullopt;
+    }
+  }
+
+  bool IsSubgraphOutput() const;
+  bool IsSubgraphInput() const;
+  bool IsConstant() const;
+};
+
+// Operator. C++ equivalent of LiteRtOp.
+class Op : public internal::NonOwnedHandle<LiteRtOp> {
+ public:
+  Op() = default;
+  explicit Op(LiteRtOp op) : internal::NonOwnedHandle<LiteRtOp>(op) {}
+
+  LiteRtOpCode Code() const {
+    LiteRtOpCode opcode;
+    internal::AssertOk(LiteRtGetOpCode, Get(), &opcode);
+    return opcode;
+  }
+
+  SmallVec<Tensor> Inputs() const {
+    LiteRtParamIndex num_inputs;
+    LiteRtTensorArray inputs;
+    internal::AssertOk(LiteRtGetOpInputs, Get(), &num_inputs, &inputs);
+    return SmallVec<Tensor>(inputs, inputs + num_inputs);
+  }
+
+  SmallVec<Tensor> Outputs() const {
+    LiteRtParamIndex num_outputs;
+    LiteRtTensorArray outputs;
+    internal::AssertOk(LiteRtGetOpOutputs, Get(), &num_outputs, &outputs);
+    return SmallVec<Tensor>(outputs, outputs + num_outputs);
+  }
+};
+
+struct Tensor::TensorUse {
+  Op user;
+  LiteRtParamIndex user_arg_ind;
+};
+
+// Model subgraph. C++ equivalent of LiteRtSubgraph.
+class Subgraph : public internal::NonOwnedHandle<LiteRtSubgraph> {
+ public:
+  Subgraph() = default;
+  explicit Subgraph(LiteRtSubgraph subgraph)
+      : internal::NonOwnedHandle<LiteRtSubgraph>(subgraph) {}
+
+  SmallVec<Tensor> Inputs() const {
+    LiteRtParamIndex num_inputs;
+    LiteRtTensorArray inputs;
+    internal::AssertOk(LiteRtGetSubgraphInputs, Get(), &num_inputs, &inputs);
+    return SmallVec<Tensor>(inputs, inputs + num_inputs);
+  }
+
+  SmallVec<Tensor> Outputs() const {
+    LiteRtParamIndex num_outputs;
+    LiteRtTensorArray outputs;
+    internal::AssertOk(LiteRtGetSubgraphOutputs, Get(), &num_outputs,
+                       &outputs);
+    return SmallVec<Tensor>(outputs, outputs + num_outputs);
+  }
+
+  std::vector<Op> Ops() const {
+    LiteRtParamIndex num_ops;
+    LiteRtOpArray ops;
+    internal::AssertOk(LiteRtGetSubgraphOps, Get(), &num_ops, &ops);
+    return std::vector<Op>(ops, ops + num_ops);
+  }
+};
+
+// Model signature. C++ equivalent of LiteRtSignature.
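+// Example usage (illustrative; error handling omitted):
+//
+//   auto signatures = model.GetSignatures();
+//   for (const auto& sig : *signatures) {
+//     // sig.Key(), sig.InputNames() and sig.OutputNames() describe one
+//     // callable entry point of the model.
+//   }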
+class Signature : public internal::NonOwnedHandle<LiteRtSignature> {
+ public:
+  Signature() = default;
+  explicit Signature(LiteRtSignature signature)
+      : internal::NonOwnedHandle<LiteRtSignature>(signature) {}
+
+  absl::string_view Key() const {
+    const char* key;
+    internal::AssertOk(LiteRtGetSignatureKey, Get(), &key);
+    return key;
+  }
+
+  LiteRtSubgraph Subgraph() const {
+    LiteRtSubgraph subgraph;
+    internal::AssertOk(LiteRtGetSignatureSubgraph, Get(), &subgraph);
+    return subgraph;
+  }
+
+  std::vector<absl::string_view> InputNames() const {
+    LiteRtParamIndex num_inputs;
+    internal::AssertOk(LiteRtGetNumSignatureInputs, Get(), &num_inputs);
+    std::vector<absl::string_view> input_names;
+    input_names.reserve(num_inputs);
+    for (int i = 0; i < num_inputs; ++i) {
+      const char* input_name;
+      internal::AssertOk(LiteRtGetSignatureInputName, Get(), i, &input_name);
+      input_names.push_back(input_name);
+    }
+    return input_names;
+  }
+
+  std::vector<absl::string_view> OutputNames() const {
+    LiteRtParamIndex num_outputs;
+    internal::AssertOk(LiteRtGetNumSignatureOutputs, Get(), &num_outputs);
+    std::vector<absl::string_view> output_names;
+    output_names.reserve(num_outputs);
+    for (int i = 0; i < num_outputs; ++i) {
+      const char* output_name;
+      internal::AssertOk(LiteRtGetSignatureOutputName, Get(), i, &output_name);
+      output_names.push_back(output_name);
+    }
+    return output_names;
+  }
+};
+
+// Model. C++ equivalent of LiteRtModel.
+class Model : public internal::Handle<LiteRtModel, LiteRtDestroyModel> {
+ public:
+  Model() = default;
+
+  static Model CreateFromOwnedHandle(LiteRtModel model) {
+    return Model(model, /*owned=*/true);
+  }
+
+  static Model CreateFromNonOwnedHandle(LiteRtModel model) {
+    return Model(model, /*owned=*/false);
+  }
+
+  static Expected<Model> CreateFromFile(const std::string& filename) {
+    LiteRtModel model;
+    if (auto status = LiteRtCreateModelFromFile(filename.c_str(), &model);
+        status != kLiteRtStatusOk) {
+      return Unexpected(status, "Failed to load model from file");
+    }
+    return CreateFromOwnedHandle(model);
+  }
+
+  static Expected<Model> CreateFromBuffer(BufferRef<uint8_t> buffer) {
+    LiteRtModel model;
+    if (auto status =
+            LiteRtCreateModelFromBuffer(buffer.Data(), buffer.Size(), &model);
+        status != kLiteRtStatusOk) {
+      return Unexpected(status, "Failed to load model from buffer");
+    }
+    return CreateFromOwnedHandle(model);
+  }
+
+  Expected<absl::Span<const uint8_t>> Metadata(
+      const std::string& metadata_key) const {
+    const void* buffer;
+    size_t buffer_size;
+    if (LiteRtGetModelMetadata(Get(), metadata_key.data(), &buffer,
+                               &buffer_size) != kLiteRtStatusOk) {
+      return Unexpected(kLiteRtStatusErrorNotFound, "Metadata key not found");
+    }
+    return absl::MakeSpan(static_cast<const uint8_t*>(buffer), buffer_size);
+  }
+
+  Expected<class Subgraph> MainSubgraph() {
+    LiteRtParamIndex main_subgraph_index;
+    internal::AssertOk(LiteRtGetMainModelSubgraphIndex, Get(),
+                       &main_subgraph_index);
+    return this->Subgraph(main_subgraph_index);
+  }
+
+  size_t NumSubgraphs() const {
+    LiteRtParamIndex num_subgraphs;
+    internal::AssertOk(LiteRtGetNumModelSubgraphs, Get(), &num_subgraphs);
+    return num_subgraphs;
+  }
+
+  Expected<class Subgraph> Subgraph(size_t subgraph_index) {
+    LiteRtSubgraph subgraph;
+    if (LiteRtGetModelSubgraph(Get(), subgraph_index, &subgraph) !=
+        kLiteRtStatusOk) {
+      return Unexpected(kLiteRtStatusErrorNotFound, "Subgraph not found");
+    }
+    return litert::Subgraph(subgraph);
+  }
+
+  Expected<class Subgraph> Subgraph(absl::string_view signature_key) {
+    auto signature = FindSignature(signature_key);
+    if (!signature) {
+      return Unexpected(kLiteRtStatusErrorNotFound, "Signature not found");
+    }
+    return litert::Subgraph(signature->Subgraph());
+  }
+
+  // Returns the list of signatures defined in the model.
+  Expected<std::vector<Signature>> GetSignatures() const {
+    LiteRtParamIndex num_signatures;
+    internal::AssertOk(LiteRtGetNumModelSignatures, Get(), &num_signatures);
+    std::vector<Signature> signatures;
+    signatures.reserve(num_signatures);
+    for (int i = 0; i < num_signatures; ++i) {
+      LiteRtSignature lite_rt_signature;
+      internal::AssertOk(LiteRtGetModelSignature, Get(), i,
+                         &lite_rt_signature);
+      Signature signature(lite_rt_signature);
+      signatures.push_back(std::move(signature));
+    }
+    return std::move(signatures);
+  }
+
+  // Returns the signature at the given index.
+  Expected<Signature> GetSignature(size_t signature_index) const {
+    LiteRtSignature lite_rt_signature;
+    internal::AssertOk(LiteRtGetModelSignature, Get(), signature_index,
+                       &lite_rt_signature);
+    return Signature(lite_rt_signature);
+  }
+
+  Expected<Signature> FindSignature(
+      absl::string_view signature_key) const {
+    LiteRtParamIndex num_signatures;
+    internal::AssertOk(LiteRtGetNumModelSignatures, Get(), &num_signatures);
+    for (int i = 0; i < num_signatures; ++i) {
+      LiteRtSignature lite_rt_signature;
+      internal::AssertOk(LiteRtGetModelSignature, Get(), i,
+                         &lite_rt_signature);
+      const char* key_cstr;
+      internal::AssertOk(LiteRtGetSignatureKey, lite_rt_signature, &key_cstr);
+      if (absl::string_view(key_cstr) == signature_key) {
+        return Signature(lite_rt_signature);
+      }
+    }
+    return Unexpected(kLiteRtStatusErrorNotFound, "Signature not found");
+  }
+
+  static absl::string_view DefaultSignatureKey() {
+    const char* key;
+    internal::AssertOk(LiteRtGetDefaultSignatureKey, &key);
+    return key;
+  }
+
+ private:
+  // Parameter `owned` indicates if the created Model object should take
+  // ownership of the provided `model` handle.
+  Model(LiteRtModel model, bool owned)
+      : internal::Handle<LiteRtModel, LiteRtDestroyModel>(model, owned) {}
+};
+
+}  // namespace litert
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MODEL_H_
diff --git a/tflite/experimental/litert/cc/litert_model_predicates.cc b/tflite/experimental/litert/cc/litert_model_predicates.cc
new file mode 100644
index 00000000..164424be
--- /dev/null
+++ b/tflite/experimental/litert/cc/litert_model_predicates.cc
@@ -0,0 +1,115 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
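+
+// A sketch of how the matchers in this file compose (illustrative only; `op`
+// is any litert::Op taken from a loaded model):
+//
+//   // Mul(float32[2,2], float32[2,2]) -> float32[2,2]
+//   TensorTypeInfo f32_2x2(ElementType::Float32, {2, 2});
+//   bool is_simple_mul =
+//       op.Code() == kLiteRtOpCodeTflMul &&
+//       MatchOpType(op, {f32_2x2, f32_2x2}, {f32_2x2});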
+ +#include "tflite/experimental/litert/cc/litert_model_predicates.h" + +#include +#include +#include +#include + +#include "absl/types/span.h" +#include "tflite/experimental/litert/cc/litert_detail.h" +#include "tflite/experimental/litert/cc/litert_model.h" + +namespace litert { +namespace { + +template +bool Any(absl::Span vals, std::function unary_pred) { + for (const auto& val : vals) { + if (unary_pred(val)) { + return true; + } + } + return false; +} + +bool UseSoftEqual(const Tensor::TensorUse& actual_use, + const UseInfo& expected_use) { + if (expected_use.user_param_ind.has_value() && + actual_use.user_arg_ind != expected_use.user_param_ind.value()) { + return false; + } + if (expected_use.op_code.has_value() && + actual_use.user.Code() != expected_use.op_code.value()) { + return false; + } + return true; +} + +} // namespace + +// Does given tensor have given type and shape info. Optional values considered +// to be a vacous match. +bool MatchRankedTensorType(const RankedTensorType& tensor_type, + const TensorTypeInfo& expected) { + if (expected.element_type.has_value() && + (tensor_type.ElementType() != expected.element_type.value())) { + return false; + } + + if (expected.dims.has_value()) { + auto actual_dims = tensor_type.Layout().Dimensions(); + auto expected_dims = absl::MakeConstSpan(expected.dims.value()); + return AllZip(actual_dims, expected_dims, + [](auto l, auto r) -> bool { return l == r; }); + } + return true; +} + +// Does given op have signature matching given types. Optional values considered +// to be a vacous match. +bool MatchOpType( + const Op& op, + const std::vector>& expected_inputs, + const std::vector>& expected_outputs) { + auto match = [](const Tensor& actual, + const std::optional& expected) -> bool { + if (!expected.has_value()) { + return true; + } + return MatchRankedTensorType(actual.RankedTensorType(), expected.value()); + }; + + const bool inputs_match = AllZip(absl::MakeConstSpan(op.Inputs()), + absl::MakeConstSpan(expected_inputs), match); + const bool outputs_match = + AllZip(absl::MakeConstSpan(op.Outputs()), + absl::MakeConstSpan(expected_outputs), match); + return inputs_match && outputs_match; +} + +bool MatchUse(const Tensor& tensor, const UseInfo& expected_use) { + auto soft_equal = [&expected_use = std::as_const(expected_use)]( + const Tensor::TensorUse& actual_use) { + return UseSoftEqual(actual_use, expected_use); + }; + return Any(tensor.Uses(), soft_equal); +} + +bool MatchUses(const Tensor& tensor, const std::vector& expected_uses, + bool strict) { + const auto uses = tensor.Uses(); + if (strict && uses.size() != expected_uses.size()) { + return false; + } + auto not_use = [&tensor = + std::as_const(tensor)](const UseInfo& expected_use) { + return !MatchUse(tensor, expected_use); + }; + return !Any(expected_uses, not_use); +} + +} // namespace litert diff --git a/tflite/experimental/litert/cc/litert_model_predicates.h b/tflite/experimental/litert/cc/litert_model_predicates.h new file mode 100644 index 00000000..c5d802c0 --- /dev/null +++ b/tflite/experimental/litert/cc/litert_model_predicates.h @@ -0,0 +1,78 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MODEL_PREDICATES_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MODEL_PREDICATES_H_
+
+#include <cstdint>
+#include <optional>
+#include <vector>
+
+#include "absl/container/inlined_vector.h"
+#include "absl/types/span.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/c/litert_op_code.h"
+#include "tflite/experimental/litert/cc/litert_model.h"
+
+// Predicates used for matching patterns in the graph. NOTE: All optionals in
+// matcher arguments are considered to be a vacuous match.
+
+namespace litert {
+
+struct TensorTypeInfo {
+  std::optional<ElementType> element_type = std::nullopt;
+  std::optional<absl::InlinedVector<int32_t, 4>> dims = std::nullopt;
+
+  explicit TensorTypeInfo(ElementType element_type)
+      : element_type(element_type) {}
+  explicit TensorTypeInfo(absl::InlinedVector<int32_t, 4> dims) : dims(dims) {}
+  TensorTypeInfo(ElementType element_type,
+                 absl::InlinedVector<int32_t, 4> dims)
+      : element_type(element_type), dims(dims) {}
+};
+
+struct UseInfo {
+  std::optional<LiteRtOpCode> op_code = std::nullopt;
+  std::optional<LiteRtParamIndex> user_param_ind = std::nullopt;
+};
+
+// Does this tensor have the given type and shape info?
+bool MatchRankedTensorType(const RankedTensorType& tensor_type,
+                           const TensorTypeInfo& expected);
+
+// Does this op have a signature matching the given types?
+bool MatchOpType(
+    const Op& op,
+    const std::vector<std::optional<TensorTypeInfo>>& expected_inputs,
+    const std::vector<std::optional<TensorTypeInfo>>& expected_outputs);
+
+// Does this tensor contain weights whose values match expected_data?
+template <typename T>
+inline bool MatchWeights(const Tensor& tensor,
+                         absl::Span<const T> expected_data) {
+  auto weights = tensor.WeightsData<T>();
+  return weights.HasValue() && *weights == expected_data;
+}
+
+// Does this tensor have a user with the given information?
+bool MatchUse(const Tensor& tensor, const UseInfo& expected_use);
+
+// Does this tensor have matching users? If `strict` is true, the size of
+// expected_uses must equal the number of actual uses; otherwise it just
+// checks that each expected_use matches some actual use.
+bool MatchUses(const Tensor& tensor, const std::vector<UseInfo>& expected_uses,
+               bool strict = true);
+
+}  // namespace litert
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MODEL_PREDICATES_H_
diff --git a/tflite/experimental/litert/cc/litert_model_predicates_test.cc b/tflite/experimental/litert/cc/litert_model_predicates_test.cc
new file mode 100644
index 00000000..1ed08951
--- /dev/null
+++ b/tflite/experimental/litert/cc/litert_model_predicates_test.cc
@@ -0,0 +1,207 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
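+
+// The tests below all follow one pattern (illustrative sketch): load a
+// flatbuffer fixture, take an op or tensor from the main subgraph, and run a
+// matcher against it. For example:
+//
+//   auto model = LoadTestFileModel("one_mul.tflite");
+//   auto op = model.MainSubgraph()->Ops().front();
+//   EXPECT_TRUE(MatchOpType(op, /*expected_inputs=*/..., /*outputs=*/...));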
+ +#include "tflite/experimental/litert/cc/litert_model_predicates.h" + +#include + +#include +#include "absl/container/inlined_vector.h" +#include "absl/log/absl_check.h" +#include "absl/types/span.h" +#include "tflite/experimental/litert/c/litert_op_code.h" +#include "tflite/experimental/litert/cc/litert_element_type.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/test/common.h" + +namespace litert { + +namespace { + +using ::litert::testing::LoadTestFileModel; + +TEST(MatchRankedTensorTypeTest, HasAll) { + auto litert_model = LoadTestFileModel("one_mul.tflite"); + auto subgraph = litert_model.MainSubgraph(); + ABSL_CHECK(subgraph); + auto ops = subgraph->Ops(); + const auto inputs = ops.front().Inputs(); + const auto& input = inputs.front(); + EXPECT_TRUE(MatchRankedTensorType( + input.RankedTensorType(), TensorTypeInfo(ElementType::Float32, {2, 2}))); +} + +TEST(MatchRankedTensorTypeTest, NoMatch) { + auto litert_model = LoadTestFileModel("one_mul.tflite"); + auto subgraph = litert_model.MainSubgraph(); + ABSL_CHECK(subgraph); + auto ops = subgraph->Ops(); + const auto inputs = ops.front().Inputs(); + const auto& input = inputs.front(); + EXPECT_FALSE(MatchRankedTensorType( + input.RankedTensorType(), TensorTypeInfo(ElementType::Float32, {3, 2}))); +} + +TEST(MatchRankedTensorTypeTest, AnyDims) { + auto litert_model = LoadTestFileModel("one_mul.tflite"); + auto subgraph = litert_model.MainSubgraph(); + ABSL_CHECK(subgraph); + auto ops = subgraph->Ops(); + const auto inputs = ops.front().Inputs(); + const auto& input = inputs.front(); + EXPECT_TRUE(MatchRankedTensorType(input.RankedTensorType(), + TensorTypeInfo(ElementType::Float32))); +} + +TEST(MatchRankedTensorTypeTest, AnyElementType) { + auto litert_model = LoadTestFileModel("one_mul.tflite"); + auto subgraph = litert_model.MainSubgraph(); + ABSL_CHECK(subgraph); + auto ops = subgraph->Ops(); + const auto inputs = ops.front().Inputs(); + const auto& input = inputs.front(); + EXPECT_TRUE( + MatchRankedTensorType(input.RankedTensorType(), TensorTypeInfo({2, 2}))); +} + +TEST(MatchOpTypeTest, HasAll) { + auto litert_model = LoadTestFileModel("one_mul.tflite"); + auto subgraph = litert_model.MainSubgraph(); + ABSL_CHECK(subgraph); + auto ops = subgraph->Ops(); + TensorTypeInfo expected_type(ElementType::Float32, {2, 2}); + EXPECT_TRUE(MatchOpType(ops.front(), {expected_type, expected_type}, + {expected_type})); +} + +TEST(MatchOpTypeTest, NoMatch) { + auto litert_model = LoadTestFileModel("one_mul.tflite"); + auto subgraph = litert_model.MainSubgraph(); + ABSL_CHECK(subgraph); + auto ops = subgraph->Ops(); + TensorTypeInfo expected_type(ElementType::Float32, {2, 2}); + TensorTypeInfo not_expected_type(ElementType::Int32, {2, 2}); + EXPECT_FALSE(MatchOpType(ops.front(), {not_expected_type, expected_type}, + {expected_type})); +} + +TEST(MatchOpTypeTest, AnyInput) { + auto litert_model = LoadTestFileModel("one_mul.tflite"); + auto subgraph = litert_model.MainSubgraph(); + ABSL_CHECK(subgraph); + auto ops = subgraph->Ops(); + TensorTypeInfo expected_type(ElementType::Float32, {2, 2}); + EXPECT_TRUE( + MatchOpType(ops.front(), {std::nullopt, expected_type}, {expected_type})); +} + +TEST(MatchOpTypeTest, AnyOutput) { + auto litert_model = LoadTestFileModel("one_mul.tflite"); + auto subgraph = litert_model.MainSubgraph(); + ABSL_CHECK(subgraph); + auto ops = subgraph->Ops(); + TensorTypeInfo expected_type(ElementType::Float32, {2, 2}); + EXPECT_TRUE( + MatchOpType(ops.front(), {std::nullopt, 
expected_type}, {std::nullopt})); +} + +TEST(MatchWeightsTest, Matches) { + auto litert_model = LoadTestFileModel("add_cst.tflite"); + auto subgraph = litert_model.MainSubgraph(); + ABSL_CHECK(subgraph); + auto ops = subgraph->Ops(); + const auto inputs = ops.front().Inputs(); + const auto& cst = inputs.back(); + EXPECT_TRUE(MatchWeights(cst, absl::Span({1.0, 2.0, 3.0, 4.0}))); +} + +TEST(MatchWeightsTest, NoMatchBadType) { + auto litert_model = LoadTestFileModel("add_cst.tflite"); + auto subgraph = litert_model.MainSubgraph(); + ABSL_CHECK(subgraph); + auto ops = subgraph->Ops(); + const auto inputs = ops.front().Inputs(); + const auto& cst = inputs.back(); + EXPECT_FALSE( + MatchWeights(cst, absl::Span({1.0, 2.0, 3.0, 4.0}))); +} +TEST(MatchWeightsTest, NoMatchBadVals) { + auto litert_model = LoadTestFileModel("add_cst.tflite"); + auto subgraph = litert_model.MainSubgraph(); + ABSL_CHECK(subgraph); + auto ops = subgraph->Ops(); + const auto inputs = ops.front().Inputs(); + const auto& cst = inputs.back(); + EXPECT_FALSE( + MatchWeights(cst, absl::Span({3.0, 2.0, 3.0, 5.0}))); +} + +TEST(MatchUseTest, Match) { + auto litert_model = LoadTestFileModel("add_cst.tflite"); + auto subgraph = litert_model.MainSubgraph(); + ABSL_CHECK(subgraph); + auto ops = subgraph->Ops(); + const auto inputs = ops.front().Inputs(); + EXPECT_TRUE(MatchUse(inputs.back(), UseInfo{kLiteRtOpCodeTflAdd, 1})); +} + +TEST(MatchUseTest, MatchAnyCode) { + auto litert_model = LoadTestFileModel("add_cst.tflite"); + auto subgraph = litert_model.MainSubgraph(); + ABSL_CHECK(subgraph); + auto ops = subgraph->Ops(); + const auto inputs = ops.front().Inputs(); + EXPECT_TRUE(MatchUse(inputs.back(), UseInfo{std::nullopt, 1})); +} + +TEST(MatchUseTest, NoMatch) { + auto litert_model = LoadTestFileModel("add_cst.tflite"); + auto subgraph = litert_model.MainSubgraph(); + ABSL_CHECK(subgraph); + auto ops = subgraph->Ops(); + const auto inputs = ops.front().Inputs(); + EXPECT_FALSE(MatchUse(inputs.back(), UseInfo{std::nullopt, 2})); +} + +TEST(MatchUsesTest, StrictMatch) { + auto litert_model = LoadTestFileModel("add_simple.tflite"); + auto subgraph = litert_model.MainSubgraph(); + ABSL_CHECK(subgraph); + auto subgraph_inputs = subgraph->Inputs(); + const auto& tensor = subgraph_inputs.front(); + EXPECT_TRUE( + MatchUses(tensor, {{kLiteRtOpCodeTflAdd, 0}, {kLiteRtOpCodeTflAdd, 1}})); +} + +TEST(MatchUsesTest, StrictNoMatch) { + auto litert_model = LoadTestFileModel("add_simple.tflite"); + auto subgraph = litert_model.MainSubgraph(); + ABSL_CHECK(subgraph); + auto subgraph_inputs = subgraph->Inputs(); + const auto& tensor = subgraph_inputs.front(); + EXPECT_FALSE(MatchUses(tensor, {{kLiteRtOpCodeTflAdd, 0}})); +} + +TEST(MatchUsesTest, NonStrict) { + auto litert_model = LoadTestFileModel("add_simple.tflite"); + auto subgraph = litert_model.MainSubgraph(); + ABSL_CHECK(subgraph); + auto subgraph_inputs = subgraph->Inputs(); + const auto& tensor = subgraph_inputs.front(); + EXPECT_TRUE(MatchUses(tensor, {{kLiteRtOpCodeTflAdd, 0}}, /*strict=*/false)); +} + +} // namespace + +} // namespace litert diff --git a/tflite/experimental/litert/cc/litert_model_test.cc b/tflite/experimental/litert/cc/litert_model_test.cc new file mode 100644 index 00000000..1ea35444 --- /dev/null +++ b/tflite/experimental/litert/cc/litert_model_test.cc @@ -0,0 +1,338 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tflite/experimental/litert/cc/litert_model.h" + +#include +#include +#include +#include + +#include +#include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_model.h" +#include "tflite/experimental/litert/c/litert_op_code.h" +#include "tflite/experimental/litert/cc/litert_element_type.h" +#include "tflite/experimental/litert/cc/litert_layout.h" +#include "tflite/experimental/litert/core/model/model.h" +#include "tflite/experimental/litert/test/common.h" + +// Tests for CC Wrapper classes around public C api. + +namespace litert { + +namespace { + +static constexpr const int32_t kTensorDimensions[] = {1, 2, 3}; + +static constexpr const auto kRank = + sizeof(kTensorDimensions) / sizeof(kTensorDimensions[0]); + +static constexpr const uint32_t kTensorStrides[] = {6, 3, 1}; + +static constexpr const LiteRtLayout kLayout = BuildLayout(kTensorDimensions); + +static constexpr const LiteRtLayout kLayoutWithStrides = + BuildLayout(kTensorDimensions, kTensorStrides); + +static constexpr const LiteRtRankedTensorType kTensorType = { + /*.element_type=*/kLiteRtElementTypeFloat32, + /*.layout=*/kLayout, +}; + +//===----------------------------------------------------------------------===// +// CC Model // +//===----------------------------------------------------------------------===// + +TEST(CcModelTest, SimpleModel) { + auto model = testing::LoadTestFileModel("one_mul.tflite"); + + LiteRtParamIndex num_subgraphs; + ASSERT_EQ(LiteRtGetNumModelSubgraphs(model.Get(), &num_subgraphs), + kLiteRtStatusOk); + EXPECT_EQ(model.NumSubgraphs(), num_subgraphs); + EXPECT_EQ(model.NumSubgraphs(), 1); + + LiteRtParamIndex main_subgraph_index; + ASSERT_EQ(LiteRtGetMainModelSubgraphIndex(model.Get(), &main_subgraph_index), + kLiteRtStatusOk); + EXPECT_EQ(main_subgraph_index, 0); + + LiteRtSubgraph litert_subgraph_0; + ASSERT_EQ(LiteRtGetModelSubgraph(model.Get(), /*subgraph_index=*/0, + &litert_subgraph_0), + kLiteRtStatusOk); + + auto subgraph_0 = model.Subgraph(0); + ASSERT_TRUE(subgraph_0); + EXPECT_EQ(subgraph_0->Get(), litert_subgraph_0); + + auto main_subgraph = model.MainSubgraph(); + EXPECT_EQ(main_subgraph->Get(), subgraph_0->Get()); +} + +//===----------------------------------------------------------------------===// +// CC Signature // +//===----------------------------------------------------------------------===// + +TEST(CcSignatureTest, Basic) { + auto model = testing::LoadTestFileModel("one_mul.tflite"); + + auto signatures = model.GetSignatures(); + ASSERT_TRUE(signatures); + ASSERT_EQ(signatures->size(), 1); + auto& signature = signatures->at(0); + EXPECT_THAT(signature.Key(), Model::DefaultSignatureKey()); + auto input_names = signature.InputNames(); + EXPECT_THAT(input_names[0], "arg0"); + EXPECT_THAT(input_names[1], "arg1"); + auto output_names = signature.OutputNames(); + EXPECT_THAT(output_names[0], "tfl.mul"); +} + +TEST(CcSignatureTest, Lookup) { + auto model = testing::LoadTestFileModel("one_mul.tflite"); + + { + auto signature = 
model.FindSignature("nonexistent"); + ASSERT_FALSE(signature); + } + auto signature = model.FindSignature(Model::DefaultSignatureKey()); + ASSERT_TRUE(signature); + EXPECT_THAT(signature->Key(), Model::DefaultSignatureKey()); + auto input_names = signature->InputNames(); + EXPECT_THAT(input_names[0], "arg0"); + EXPECT_THAT(input_names[1], "arg1"); + auto output_names = signature->OutputNames(); + EXPECT_THAT(output_names[0], "tfl.mul"); +} + +//===----------------------------------------------------------------------===// +// CC Layout // +//===----------------------------------------------------------------------===// + +TEST(CcLayoutTest, NoStrides) { + Layout layout(kLayout); + + ASSERT_EQ(layout.Rank(), kLayout.rank); + for (auto i = 0; i < layout.Rank(); ++i) { + ASSERT_EQ(layout.Dimensions()[i], kLayout.dimensions[i]); + } + ASSERT_FALSE(layout.HasStrides()); +} + +TEST(CcLayoutTest, WithStrides) { + Layout layout(kLayoutWithStrides); + + ASSERT_EQ(layout.Rank(), kLayoutWithStrides.rank); + for (auto i = 0; i < layout.Rank(); ++i) { + ASSERT_EQ(layout.Dimensions()[i], kLayoutWithStrides.dimensions[i]); + } + ASSERT_TRUE(layout.HasStrides()); + for (auto i = 0; i < layout.Rank(); ++i) { + ASSERT_EQ(layout.Strides()[i], kLayoutWithStrides.strides[i]); + } +} + +TEST(CcLayoutTest, Equal) { + auto&& dims = {2, 2}; + Layout layout1(BuildLayout(dims)); + Layout layout2(BuildLayout({2, 2})); + ASSERT_TRUE(layout1 == layout2); +} + +TEST(CcLayoutTest, NotEqual) { + Layout layout1(BuildLayout({2, 2}, nullptr)); + Layout layout2(BuildLayout({2, 2}, kTensorStrides)); + ASSERT_FALSE(layout1 == layout2); +} + +TEST(CcLayoutTest, NumElements) { + Layout layout(BuildLayout({2, 2, 3})); + auto num_elements = layout.NumElements(); + ASSERT_TRUE(num_elements.has_value()); + EXPECT_EQ(num_elements.value(), 12); +} + +//===----------------------------------------------------------------------===// +// CC Op // +//===----------------------------------------------------------------------===// + +TEST(CcOpTest, SimpleSupportedOp) { + auto litert_model = testing::LoadTestFileModel("one_mul.tflite"); + auto subgraph = litert_model.MainSubgraph(); + const auto ops = subgraph->Ops(); + const auto& op = ops.front(); + + EXPECT_EQ(op.Code(), kLiteRtOpCodeTflMul); + EXPECT_EQ(op.Inputs().size(), 2); + EXPECT_EQ(op.Outputs().size(), 1); +} + +//===----------------------------------------------------------------------===// +// CC RankedTensorType // +//===----------------------------------------------------------------------===// + +TEST(CcRankedTensorTypeTest, Accessors) { + Layout layout(kLayout); + RankedTensorType tensor_type(kTensorType); + ASSERT_EQ(tensor_type.ElementType(), + static_cast(kTensorType.element_type)); + ASSERT_TRUE(tensor_type.Layout() == layout); +} + +//===----------------------------------------------------------------------===// +// CC Tensor // +//===----------------------------------------------------------------------===// + +TEST(CcTensorTest, SimpleModel) { + auto litert_model = testing::LoadTestFileModel("one_mul.tflite"); + auto subgraph = litert_model.MainSubgraph(); + + auto inputs = subgraph->Inputs(); + ASSERT_EQ(inputs.size(), 2); + + { + const Tensor& input_tensor = inputs.front(); + ASSERT_EQ(input_tensor.TypeId(), kLiteRtRankedTensorType); + + auto input_ranked_tensor_type = input_tensor.RankedTensorType(); + ASSERT_EQ(input_ranked_tensor_type.ElementType(), ElementType::Float32); + + EXPECT_FALSE(input_tensor.HasWeights()); + + auto input_weights = input_tensor.Weights(); + 
ASSERT_EQ(input_weights.Bytes().size(), 0); + + ASSERT_EQ(input_tensor.DefiningOp(), std::nullopt); + + const auto uses = input_tensor.Uses(); + ASSERT_EQ(uses.size(), 1); + } + + auto outputs = subgraph->Outputs(); + ASSERT_EQ(outputs.size(), 1); + + { + const Tensor& output_tensor = outputs.front(); + ASSERT_EQ(output_tensor.TypeId(), kLiteRtRankedTensorType); + + auto output_defining_op = output_tensor.DefiningOp(); + EXPECT_TRUE(output_defining_op.has_value()); + + ASSERT_TRUE(output_tensor.Uses().empty()); + } +} + +TEST(CcTensorTest, WeightsData) { + auto litert_model = testing::LoadTestFileModel("add_cst.tflite"); + auto subgraph = litert_model.MainSubgraph(); + + auto data = subgraph->Ops().front().Inputs().back().WeightsData(); + ASSERT_TRUE(data.HasValue()); + EXPECT_THAT(data.Value(), ::testing::ElementsAreArray({1.0, 2.0, 3.0, 4.0})); +} + +TEST(CcTensorTest, Name) { + static constexpr absl::string_view kName = "foo"; + LiteRtTensorT tensor; + tensor.SetName(std::string(kName)); + + Tensor cc_tensor(&tensor); + EXPECT_EQ(cc_tensor.Name(), kName); +} + +TEST(CcTensorTest, QuantizationNone) { + LiteRtTensorT litert_tensor; + litert_tensor.Qparams().first = kLiteRtQuantizationNone; + + Tensor tensor(&litert_tensor); + EXPECT_EQ(tensor.QTypeId(), kLiteRtQuantizationNone); + EXPECT_FALSE(tensor.HasQuantization()); +} + +TEST(CcTensorTest, QuantizationPerTensor) { + static constexpr auto kScale = 1.0; + static constexpr auto kZeroPoint = 1; + + LiteRtTensorT litert_tensor; + litert_tensor.SetQarams(MakePerTensorQuantization(kScale, kZeroPoint)); + + Tensor tensor(&litert_tensor); + ASSERT_EQ(tensor.QTypeId(), kLiteRtQuantizationPerTensor); + ASSERT_TRUE(tensor.HasQuantization()); + + const auto per_tensor_quantization = tensor.PerTensorQuantization(); + EXPECT_EQ(per_tensor_quantization.scale, kScale); + EXPECT_EQ(per_tensor_quantization.zero_point, kZeroPoint); +} + +TEST(CcTensorTest, QuantizationPerChannel) { + static constexpr auto kNumChannels = 2; + static constexpr auto kQuantizedDimension = 0; + static constexpr float kScales[kNumChannels] = {1.0, 2.0}; + static constexpr int64_t kZeroPoints[kNumChannels] = {0, 0}; + + LiteRtTensorT litert_tensor; + auto per_channel = MakePerChannelQuantization( + kScales, kZeroPoints, kQuantizedDimension, litert_tensor); + litert_tensor.SetQarams(per_channel); + + Tensor tensor(&litert_tensor); + ASSERT_EQ(tensor.QTypeId(), kLiteRtQuantizationPerChannel); + ASSERT_TRUE(tensor.HasQuantization()); + + const auto per_channel_quantization = tensor.PerChannelQuantization(); + EXPECT_THAT( + absl::MakeConstSpan(per_channel_quantization.scales, kNumChannels), + ::testing::ElementsAreArray(kScales)); + EXPECT_THAT( + absl::MakeConstSpan(per_channel_quantization.zero_points, kNumChannels), + ::testing::ElementsAreArray(kZeroPoints)); + EXPECT_EQ(per_channel_quantization.num_channels, kNumChannels); + EXPECT_EQ(per_channel_quantization.quantized_dimension, kQuantizedDimension); +} + +//===----------------------------------------------------------------------===// +// CC Subgraph // +//===----------------------------------------------------------------------===// + +TEST(CcSubgraphTest, SimpleModel) { + auto model = testing::LoadTestFileModel("one_mul.tflite"); + auto subgraph = model.MainSubgraph(); + + ASSERT_EQ(subgraph->Inputs().size(), 2); + ASSERT_EQ(subgraph->Outputs().size(), 1); + ASSERT_EQ(subgraph->Ops().size(), 1); +} + +//===----------------------------------------------------------------------===// +// CC ElementType // 
+//===----------------------------------------------------------------------===// + +TEST(CcElementTypeTest, GetByteWidth) { + const size_t width = GetByteWidth(); + EXPECT_EQ(width, 1); +} + +TEST(CcElementTypeTest, GetElementType) { + ElementType ty = GetElementType(); + EXPECT_EQ(ty, ElementType::Float32); +} + +} // namespace +} // namespace litert diff --git a/tflite/experimental/litert/cc/litert_tensor_buffer.h b/tflite/experimental/litert/cc/litert_tensor_buffer.h new file mode 100644 index 00000000..11419297 --- /dev/null +++ b/tflite/experimental/litert/cc/litert_tensor_buffer.h @@ -0,0 +1,227 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_H_ + +#include +#include +#include + +#include "absl/types/span.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_event.h" +#include "tflite/experimental/litert/c/litert_model.h" +#include "tflite/experimental/litert/c/litert_tensor_buffer.h" +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/experimental/litert/cc/litert_handle.h" +#include "tflite/experimental/litert/cc/litert_model.h" + +namespace litert { + +// Tensor and associated backing buffer. C++ equivalent of LiteRtTensorBuffer. +class TensorBuffer + : public internal::Handle { + public: + TensorBuffer() = default; + + // Parameter `owned` indicates if the created TensorBuffer object should take + // ownership of the provided `tensor_buffer` handle. + explicit TensorBuffer(LiteRtTensorBuffer tensor_buffer, bool owned = true) + : internal::Handle( + tensor_buffer, owned) {} + + // Creates a duplicate of the current TensorBuffer object. The returned + // object is reference counted so the underlying LiteRtTensorBuffer handle is + // not released with the destructor until the last reference is removed. 
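+  // Example (illustrative):
+  //
+  //   auto dup = tensor_buffer.Duplicate();
+  //   if (dup) { /* *dup aliases the same underlying buffer */ }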
+ Expected Duplicate() const { + if (!IsOwned()) { + return Unexpected(kLiteRtStatusErrorInvalidArgument, + "Cannot duplicate a non-owned tensor buffer"); + } + if (auto status = LiteRtDuplicateTensorBuffer(Get()); + status != kLiteRtStatusOk) { + return Unexpected(status, "Failed to duplicate managed tensor buffer"); + } + return TensorBuffer(Get()); + } + + static Expected CreateManaged( + LiteRtTensorBufferType buffer_type, const RankedTensorType& tensor_type, + size_t buffer_size) { + LiteRtTensorBuffer tensor_buffer; + auto litert_tensor_type = static_cast(tensor_type); + if (auto status = LiteRtCreateManagedTensorBuffer( + buffer_type, &litert_tensor_type, buffer_size, &tensor_buffer); + status != kLiteRtStatusOk) { + return Unexpected(status, "Failed to create managed tensor buffer"); + } + return TensorBuffer(tensor_buffer); + } + + litert::Expected GetAhwb() const { +#if LITERT_HAS_AHWB_SUPPORT + AHardwareBuffer* ahwb; + if (LiteRtGetTensorBufferAhwb(Get(), &ahwb) == kLiteRtStatusOk) { + return ahwb; + } else { + return litert::Unexpected( + kLiteRtStatusErrorRuntimeFailure, + "Failed to get AHardwareBuffer from tensor buffer"); + } +#else + return litert::Unexpected( + kLiteRtStatusErrorRuntimeFailure, + "AHardwareBuffer is not supported on this platform"); +#endif + } + + Expected BufferType() const { + LiteRtTensorBufferType tensor_buffer_type; + if (auto status = LiteRtGetTensorBufferType(Get(), &tensor_buffer_type); + status != kLiteRtStatusOk) { + return Unexpected(status, "Failed to get tensor buffer type"); + } + return tensor_buffer_type; + } + + Expected TensorType() const { + LiteRtRankedTensorType tensor_type; + if (auto status = LiteRtGetTensorBufferTensorType(Get(), &tensor_type); + status != kLiteRtStatusOk) { + return Unexpected(status, "Failed to get tensor type"); + } + return RankedTensorType(tensor_type); + } + + Expected Size() const { + size_t size; + if (auto status = LiteRtGetTensorBufferSize(Get(), &size); + status != kLiteRtStatusOk) { + return Unexpected(status, "Failed to get tensor size"); + } + return size; + } + + Expected Offset() const { + size_t offset; + if (auto status = LiteRtGetTensorBufferOffset(Get(), &offset); + status != kLiteRtStatusOk) { + return Unexpected(status, "Failed to get tensor offset"); + } + return offset; + } + + Expected Lock(LiteRtEvent event = nullptr) { + void* host_mem_addr; + if (auto status = LiteRtLockTensorBuffer(Get(), &host_mem_addr, event); + status != kLiteRtStatusOk) { + return Unexpected(status, "Failed to lock the tensor buffer"); + } + return host_mem_addr; + } + + Expected Unlock() { + if (auto status = LiteRtUnlockTensorBuffer(Get()); + status != kLiteRtStatusOk) { + return Unexpected(status, "Failed to unlock the tensor buffer"); + } + return {}; + } + + // Writes data from the user provided Span to the tensor buffer. + // It returns an error if the provided buffer is bigger than the size of the + // tensor buffer. + template + Expected Write(absl::Span data) { + auto host_mem_addr = Lock(); + if (!host_mem_addr) { + return host_mem_addr.Error(); + } + auto size = Size(); + if (!size) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "Failed to get TensorBuffer size"); + } + if (*size < data.size() * sizeof(T)) { + return Unexpected( + kLiteRtStatusErrorRuntimeFailure, + "TensorBuffer size is smaller than the given data size"); + } + std::memcpy(*host_mem_addr, data.data(), data.size() * sizeof(T)); + Unlock(); + return {}; + } + + // Reads data into the user provided Span from the tensor buffer. 
+ // If the provided buffer is smaller than the size of the tensor buffer, the + // data will be read up to the size of the provided buffer. + // It returns an error if the provided buffer is bigger than the size of the + // tensor buffer. + template + Expected Read(absl::Span data) { + auto host_mem_addr = Lock(); + if (!host_mem_addr) { + return host_mem_addr.Error(); + } + auto size = Size(); + if (!size) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "Failed to get TensorBuffer size"); + } + size_t total_read_size = data.size() * sizeof(T); + if (*size < total_read_size) { + return Unexpected( + kLiteRtStatusErrorRuntimeFailure, + "TensorBuffer size is smaller than the given data size"); + } + std::memcpy(data.data(), *host_mem_addr, total_read_size); + Unlock(); + return {}; + } +}; + +class TensorBufferScopedLock { + public: + ~TensorBufferScopedLock() { (void)LiteRtUnlockTensorBuffer(tensor_buffer_); } + + template + static Expected> Create( + TensorBuffer& tensor_buffer, LiteRtEvent event = nullptr) { + return Create(tensor_buffer.Get(), event); + } + + template + static Expected> Create( + LiteRtTensorBuffer tensor_buffer, LiteRtEvent event = nullptr) { + void* host_mem_addr; + if (auto status = + LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr, event); + status != kLiteRtStatusOk) { + return Unexpected(status, "Failed to lock the tensor buffer"); + } + return std::make_pair(TensorBufferScopedLock(tensor_buffer), + static_cast(host_mem_addr)); + } + + private: + explicit TensorBufferScopedLock(LiteRtTensorBuffer& tensor_buffer) + : tensor_buffer_(tensor_buffer) {} + + LiteRtTensorBuffer tensor_buffer_; +}; + +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_H_ diff --git a/tflite/experimental/litert/cc/litert_tensor_buffer_requirements.h b/tflite/experimental/litert/cc/litert_tensor_buffer_requirements.h new file mode 100644 index 00000000..ac794f09 --- /dev/null +++ b/tflite/experimental/litert/cc/litert_tensor_buffer_requirements.h @@ -0,0 +1,106 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_REQUIREMENTS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_REQUIREMENTS_H_ + +#include +#include +#include + +#include "absl/types/span.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_tensor_buffer.h" +#include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h" +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/experimental/litert/cc/litert_handle.h" + +namespace litert { + +// Requirements for allocating a TensorBuffer, typically specified by a HW +// accelerator for a given I/O tensor. C++ equivalent to +// LiteRtTensorBufferRequirements. 
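+// Example (illustrative):
+//
+//   constexpr LiteRtTensorBufferType kTypes[] = {
+//       kLiteRtTensorBufferTypeHostMemory};
+//   auto reqs = TensorBufferRequirements::Create(absl::MakeConstSpan(kTypes),
+//                                                /*buffer_size=*/1024);
+//   if (reqs) { auto size = reqs->BufferSize(); }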
+class TensorBufferRequirements
+    : public internal::Handle<LiteRtTensorBufferRequirements,
+                              LiteRtDestroyTensorBufferRequirements> {
+ public:
+  TensorBufferRequirements() = default;
+
+  // Parameter `owned` indicates if the created TensorBufferRequirements object
+  // should take ownership of the provided `requirements` handle.
+  explicit TensorBufferRequirements(LiteRtTensorBufferRequirements requirements,
+                                    bool owned = true)
+      : internal::Handle<LiteRtTensorBufferRequirements,
+                         LiteRtDestroyTensorBufferRequirements>(requirements,
+                                                                owned) {}
+
+  static Expected<TensorBufferRequirements> Create(
+      absl::Span<const LiteRtTensorBufferType> buffer_types, size_t buffer_size,
+      absl::Span<const uint32_t> strides =
+          absl::MakeSpan(static_cast<uint32_t*>(nullptr), 0)) {
+    LiteRtTensorBufferRequirements tensor_buffer_requirements;
+    if (auto status = LiteRtCreateTensorBufferRequirements(
+            buffer_types.size(), buffer_types.data(), buffer_size,
+            strides.size(), strides.data(), &tensor_buffer_requirements);
+        status != kLiteRtStatusOk) {
+      return Unexpected(status, "Failed to create tensor buffer requirements");
+    }
+    return TensorBufferRequirements(tensor_buffer_requirements);
+  }
+
+  Expected<std::vector<LiteRtTensorBufferType>> SupportedTypes() const {
+    int num_types;
+    if (auto status = LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes(
+            Get(), &num_types);
+        status != kLiteRtStatusOk) {
+      return Unexpected(status,
+                        "Failed to get the number of supported tensor types");
+    }
+    std::vector<LiteRtTensorBufferType> types(num_types);
+    for (auto i = 0; i < num_types; ++i) {
+      if (auto status =
+              LiteRtGetTensorBufferRequirementsSupportedTensorBufferType(
+                  Get(), i, &types[i]);
+          status != kLiteRtStatusOk) {
+        return Unexpected(status, "Failed to get supported tensor type");
+      }
+    }
+    return types;
+  }
+
+  Expected<size_t> BufferSize() const {
+    size_t buffer_size;
+    if (auto status =
+            LiteRtGetTensorBufferRequirementsBufferSize(Get(), &buffer_size);
+        status != kLiteRtStatusOk) {
+      return Unexpected(status, "Failed to get tensor buffer size");
+    }
+    return buffer_size;
+  }
+
+  Expected<absl::Span<const uint32_t>> Strides() const {
+    int num_strides;
+    const uint32_t* strides;
+    if (auto status = LiteRtGetTensorBufferRequirementsStrides(
+            Get(), &num_strides, &strides);
+        status != kLiteRtStatusOk) {
+      return Unexpected(status, "Failed to get strides");
+    }
+    return absl::MakeSpan(strides, num_strides);
+  }
+};
+
+}  // namespace litert
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_REQUIREMENTS_H_
diff --git a/tflite/experimental/litert/cc/litert_tensor_buffer_requirements_test.cc b/tflite/experimental/litert/cc/litert_tensor_buffer_requirements_test.cc
new file mode 100644
index 00000000..901fd543
--- /dev/null
+++ b/tflite/experimental/litert/cc/litert_tensor_buffer_requirements_test.cc
@@ -0,0 +1,104 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
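As a quick orientation before the tests, a minimal sketch of how the two classes above combine. Illustrative only: the tensor shape, element data, buffer type, and the helper name TensorBufferRoundTrip are assumptions made for the example, not part of the patch.

    #include <cstdint>

    #include "absl/types/span.h"
    #include "tflite/experimental/litert/cc/litert_layout.h"
    #include "tflite/experimental/litert/cc/litert_model.h"
    #include "tflite/experimental/litert/cc/litert_tensor_buffer.h"
    #include "tflite/experimental/litert/cc/litert_tensor_buffer_requirements.h"

    // Round-trips four floats through a managed tensor buffer whose type and
    // size are taken from a TensorBufferRequirements instance.
    bool TensorBufferRoundTrip() {
      constexpr int32_t kDims[] = {4};
      constexpr float kData[] = {10.f, 20.f, 30.f, 40.f};
      constexpr LiteRtRankedTensorType kLiteRtType = {
          /*.element_type=*/kLiteRtElementTypeFloat32,
          ::litert::BuildLayout(kDims)};
      const litert::RankedTensorType tensor_type(kLiteRtType);

      // Requirements as a hypothetical accelerator might advertise them.
      constexpr LiteRtTensorBufferType kTypes[] = {
          kLiteRtTensorBufferTypeHostMemory};
      auto reqs = litert::TensorBufferRequirements::Create(
          absl::MakeSpan(kTypes, 1), sizeof(kData));
      if (!reqs) return false;
      auto types = reqs->SupportedTypes();
      auto size = reqs->BufferSize();
      if (!types || !size || types->empty()) return false;

      // Write<T>/Read<T> lock the buffer internally and enforce size checks.
      auto buffer =
          litert::TensorBuffer::CreateManaged((*types)[0], tensor_type, *size);
      if (!buffer) return false;
      if (!buffer->Write<float>(absl::MakeSpan(kData, 4))) return false;
      float out[4] = {};
      return static_cast<bool>(buffer->Read<float>(absl::MakeSpan(out, 4)));
    }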
+
+#include <array>
+#include <cstddef>
+#include <cstdint>
+
+#include <gtest/gtest.h>  // NOLINT: Need when ANDROID_API_LEVEL >= 26
+#include "absl/types/span.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h"
+#include "tflite/experimental/litert/cc/litert_tensor_buffer_requirements.h"
+
+namespace {
+
+constexpr const LiteRtTensorBufferType kSupportedTensorBufferTypes[] = {
+    kLiteRtTensorBufferTypeHostMemory,
+    kLiteRtTensorBufferTypeAhwb,
+    kLiteRtTensorBufferTypeIon,
+    kLiteRtTensorBufferTypeFastRpc,
+};
+
+constexpr const size_t kNumSupportedTensorBufferTypes =
+    sizeof(kSupportedTensorBufferTypes) /
+    sizeof(kSupportedTensorBufferTypes[0]);
+
+constexpr const size_t kBufferSize = 1234;
+
+}  // namespace
+
+TEST(TensorBufferRequirements, Owned) {
+  auto requirements = litert::TensorBufferRequirements::Create(
+      absl::MakeSpan(kSupportedTensorBufferTypes,
+                     kNumSupportedTensorBufferTypes),
+      kBufferSize);
+  ASSERT_TRUE(requirements);
+
+  auto supported_types = requirements->SupportedTypes();
+  ASSERT_TRUE(supported_types);
+  ASSERT_EQ(supported_types->size(), kNumSupportedTensorBufferTypes);
+  for (auto i = 0; i < supported_types->size(); ++i) {
+    ASSERT_EQ((*supported_types)[i], kSupportedTensorBufferTypes[i]);
+  }
+
+  auto size = requirements->BufferSize();
+  ASSERT_TRUE(size);
+  ASSERT_EQ(*size, kBufferSize);
+}
+
+TEST(TensorBufferRequirements, NotOwned) {
+  LiteRtTensorBufferRequirements litert_requirements;
+  ASSERT_EQ(LiteRtCreateTensorBufferRequirements(
+                kNumSupportedTensorBufferTypes, kSupportedTensorBufferTypes,
+                kBufferSize, /*num_strides=*/0, /*strides=*/nullptr,
+                &litert_requirements),
+            kLiteRtStatusOk);
+
+  litert::TensorBufferRequirements requirements(litert_requirements,
+                                                /*owned=*/false);
+
+  auto supported_types = requirements.SupportedTypes();
+  ASSERT_TRUE(supported_types);
+  ASSERT_EQ(supported_types->size(), kNumSupportedTensorBufferTypes);
+  for (auto i = 0; i < supported_types->size(); ++i) {
+    ASSERT_EQ((*supported_types)[i], kSupportedTensorBufferTypes[i]);
+  }
+
+  auto size = requirements.BufferSize();
+  ASSERT_TRUE(size);
+  ASSERT_EQ(*size, kBufferSize);
+
+  ASSERT_EQ(requirements.Get(), litert_requirements);
+
+  LiteRtDestroyTensorBufferRequirements(litert_requirements);
+}
+
+TEST(TensorBufferRequirements, WithStrides) {
+  constexpr std::array<uint32_t, 3> kStrides = {1, 2, 3};
+
+  auto requirements = litert::TensorBufferRequirements::Create(
+      absl::MakeSpan(kSupportedTensorBufferTypes,
+                     kNumSupportedTensorBufferTypes),
+      kBufferSize, absl::MakeSpan(kStrides.data(), kStrides.size()));
+  ASSERT_TRUE(requirements);
+
+  auto strides = requirements->Strides();
+  ASSERT_TRUE(strides);
+  ASSERT_EQ(strides->size(), kStrides.size());
+  for (auto i = 0; i < kStrides.size(); ++i) {
+    ASSERT_EQ((*strides)[i], kStrides[i]);
+  }
+}
diff --git a/tflite/experimental/litert/cc/litert_tensor_buffer_test.cc b/tflite/experimental/litert/cc/litert_tensor_buffer_test.cc
new file mode 100644
index 00000000..76ec5818
--- /dev/null
+++ b/tflite/experimental/litert/cc/litert_tensor_buffer_test.cc
@@ -0,0 +1,397 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/c/litert_tensor_buffer.h"
+
+#include <cstdint>
+#include <cstring>
+
+#include <gtest/gtest.h>  // NOLINT: Need when ANDROID_API_LEVEL >= 26
+#include "absl/types/span.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/cc/litert_layout.h"
+#include "tflite/experimental/litert/cc/litert_model.h"
+#include "tflite/experimental/litert/cc/litert_tensor_buffer.h"
+#include "tflite/experimental/litert/runtime/ahwb_buffer.h"  // IWYU pragma: keep
+#include "tflite/experimental/litert/runtime/dmabuf_buffer.h"  // IWYU pragma: keep
+#include "tflite/experimental/litert/runtime/fastrpc_buffer.h"  // IWYU pragma: keep
+#include "tflite/experimental/litert/runtime/ion_buffer.h"  // IWYU pragma: keep
+#include "tflite/experimental/litert/runtime/tensor_buffer.h"
+
+namespace {
+constexpr const float kTensorData[] = {10, 20, 30, 40};
+
+constexpr const int32_t kTensorDimensions[] = {sizeof(kTensorData) /
+                                               sizeof(kTensorData[0])};
+
+constexpr const LiteRtRankedTensorType kTensorType = {
+    /*.element_type=*/kLiteRtElementTypeFloat32,
+    ::litert::BuildLayout(kTensorDimensions)};
+}  // namespace
+
+int GetReferenceCount(const litert::TensorBuffer& tensor_buffer) {
+  LiteRtTensorBufferT* internal_tensor_buffer =
+      static_cast<LiteRtTensorBufferT*>(tensor_buffer.Get());
+  return internal_tensor_buffer->RefCount();
+}
+
+TEST(TensorBuffer, HostMemory) {
+  const litert::RankedTensorType kTensorType(::kTensorType);
+  constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeHostMemory;
+
+  auto tensor_buffer = litert::TensorBuffer::CreateManaged(
+      kTensorBufferType, kTensorType, sizeof(kTensorData));
+  ASSERT_TRUE(tensor_buffer);
+
+  auto tensor_buffer_type = tensor_buffer->BufferType();
+  ASSERT_TRUE(tensor_buffer_type);
+  ASSERT_EQ(*tensor_buffer_type, kTensorBufferType);
+
+  auto tensor_type = tensor_buffer->TensorType();
+  ASSERT_TRUE(tensor_type);
+
+  ASSERT_EQ(tensor_type->ElementType(), litert::ElementType::Float32);
+  ASSERT_EQ(tensor_type->Layout().Rank(), 1);
+  ASSERT_EQ(tensor_type->Layout().Dimensions()[0],
+            kTensorType.Layout().Dimensions()[0]);
+  ASSERT_FALSE(tensor_type->Layout().HasStrides());
+
+  auto size = tensor_buffer->Size();
+  ASSERT_TRUE(size);
+  ASSERT_EQ(*size, sizeof(kTensorData));
+
+  auto offset = tensor_buffer->Offset();
+  ASSERT_TRUE(offset);
+  ASSERT_EQ(*offset, 0);
+
+  {
+    auto lock_and_addr = litert::TensorBufferScopedLock::Create(*tensor_buffer);
+    ASSERT_TRUE(lock_and_addr);
+    std::memcpy(lock_and_addr->second, kTensorData, sizeof(kTensorData));
+  }
+
+  {
+    auto lock_and_addr = litert::TensorBufferScopedLock::Create(*tensor_buffer);
+    ASSERT_TRUE(lock_and_addr);
+    ASSERT_EQ(
+        std::memcmp(lock_and_addr->second, kTensorData, sizeof(kTensorData)),
+        0);
+  }
+}
+
+TEST(TensorBuffer, Ahwb) {
+  if (!litert::internal::AhwbBuffer::IsSupported()) {
+    GTEST_SKIP() << "AHardwareBuffers are not supported on this platform; "
+                    "skipping the test";
+  }
+
+  const litert::RankedTensorType kTensorType(::kTensorType);
+  constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeAhwb;
+
+  auto
tensor_buffer = litert::TensorBuffer::CreateManaged( + kTensorBufferType, kTensorType, sizeof(kTensorData)); + ASSERT_TRUE(tensor_buffer); + + auto tensor_buffer_type = tensor_buffer->BufferType(); + ASSERT_TRUE(tensor_buffer_type); + ASSERT_EQ(*tensor_buffer_type, kTensorBufferType); + + auto tensor_type = tensor_buffer->TensorType(); + ASSERT_TRUE(tensor_type); + + ASSERT_EQ(tensor_type->ElementType(), litert::ElementType::Float32); + ASSERT_EQ(tensor_type->Layout().Rank(), 1); + ASSERT_EQ(tensor_type->Layout().Dimensions()[0], + kTensorType.Layout().Dimensions()[0]); + ASSERT_FALSE(tensor_type->Layout().HasStrides()); + + auto size = tensor_buffer->Size(); + ASSERT_TRUE(size); + ASSERT_EQ(*size, sizeof(kTensorData)); + + auto offset = tensor_buffer->Offset(); + ASSERT_TRUE(offset); + ASSERT_EQ(*offset, 0); + + { + auto lock_and_addr = litert::TensorBufferScopedLock::Create(*tensor_buffer); + ASSERT_TRUE(lock_and_addr); + std::memcpy(lock_and_addr->second, kTensorData, sizeof(kTensorData)); + } + + { + auto lock_and_addr = litert::TensorBufferScopedLock::Create(*tensor_buffer); + ASSERT_TRUE(lock_and_addr); + ASSERT_EQ( + std::memcmp(lock_and_addr->second, kTensorData, sizeof(kTensorData)), + 0); + } +} + +TEST(TensorBuffer, Ion) { + if (!litert::internal::IonBuffer::IsSupported()) { + GTEST_SKIP() + << "ION buffers are not supported on this platform; skipping the test"; + } + + const litert::RankedTensorType kTensorType(::kTensorType); + constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeIon; + + auto tensor_buffer = litert::TensorBuffer::CreateManaged( + kTensorBufferType, kTensorType, sizeof(kTensorData)); + ASSERT_TRUE(tensor_buffer); + + auto tensor_buffer_type = tensor_buffer->BufferType(); + ASSERT_TRUE(tensor_buffer_type); + ASSERT_EQ(*tensor_buffer_type, kTensorBufferType); + + auto tensor_type = tensor_buffer->TensorType(); + ASSERT_TRUE(tensor_type); + + ASSERT_EQ(tensor_type->ElementType(), litert::ElementType::Float32); + ASSERT_EQ(tensor_type->Layout().Rank(), 1); + ASSERT_EQ(tensor_type->Layout().Dimensions()[0], + kTensorType.Layout().Dimensions()[0]); + ASSERT_FALSE(tensor_type->Layout().HasStrides()); + + auto size = tensor_buffer->Size(); + ASSERT_TRUE(size); + ASSERT_EQ(*size, sizeof(kTensorData)); + + auto offset = tensor_buffer->Offset(); + ASSERT_TRUE(offset); + ASSERT_EQ(*offset, 0); + + { + auto lock_and_addr = litert::TensorBufferScopedLock::Create(*tensor_buffer); + ASSERT_TRUE(lock_and_addr); + std::memcpy(lock_and_addr->second, kTensorData, sizeof(kTensorData)); + } + + { + auto lock_and_addr = litert::TensorBufferScopedLock::Create(*tensor_buffer); + ASSERT_TRUE(lock_and_addr); + ASSERT_EQ( + std::memcmp(lock_and_addr->second, kTensorData, sizeof(kTensorData)), + 0); + } +} + +TEST(TensorBuffer, DmaBuf) { + if (!litert::internal::DmaBufBuffer::IsSupported()) { + GTEST_SKIP() + << "DMA-BUF buffers are not supported on this platform; skipping " + "the test"; + } + + const litert::RankedTensorType kTensorType(::kTensorType); + constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeDmaBuf; + + auto tensor_buffer = litert::TensorBuffer::CreateManaged( + kTensorBufferType, kTensorType, sizeof(kTensorData)); + ASSERT_TRUE(tensor_buffer); + + auto tensor_buffer_type = tensor_buffer->BufferType(); + ASSERT_TRUE(tensor_buffer_type); + ASSERT_EQ(*tensor_buffer_type, kTensorBufferType); + + auto tensor_type = tensor_buffer->TensorType(); + ASSERT_TRUE(tensor_type); + + ASSERT_EQ(tensor_type->ElementType(), litert::ElementType::Float32); + 
ASSERT_EQ(tensor_type->Layout().Rank(), 1); + ASSERT_EQ(tensor_type->Layout().Dimensions()[0], + kTensorType.Layout().Dimensions()[0]); + ASSERT_FALSE(tensor_type->Layout().HasStrides()); + + auto size = tensor_buffer->Size(); + ASSERT_TRUE(size); + ASSERT_EQ(*size, sizeof(kTensorData)); + + auto offset = tensor_buffer->Offset(); + ASSERT_TRUE(offset); + ASSERT_EQ(*offset, 0); + + { + auto lock_and_addr = litert::TensorBufferScopedLock::Create(*tensor_buffer); + ASSERT_TRUE(lock_and_addr); + std::memcpy(lock_and_addr->second, kTensorData, sizeof(kTensorData)); + } + + { + auto lock_and_addr = litert::TensorBufferScopedLock::Create(*tensor_buffer); + ASSERT_TRUE(lock_and_addr); + ASSERT_EQ( + std::memcmp(lock_and_addr->second, kTensorData, sizeof(kTensorData)), + 0); + } +} + +TEST(TensorBuffer, FastRpc) { + if (!litert::internal::FastRpcBuffer::IsSupported()) { + GTEST_SKIP() + << "FastRPC buffers are not supported on this platform; skipping " + "the test"; + } + + const litert::RankedTensorType kTensorType(::kTensorType); + constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeFastRpc; + + auto tensor_buffer = litert::TensorBuffer::CreateManaged( + kTensorBufferType, kTensorType, sizeof(kTensorData)); + ASSERT_TRUE(tensor_buffer); + + auto tensor_buffer_type = tensor_buffer->BufferType(); + ASSERT_TRUE(tensor_buffer_type); + ASSERT_EQ(*tensor_buffer_type, kTensorBufferType); + + auto tensor_type = tensor_buffer->TensorType(); + ASSERT_TRUE(tensor_type); + + ASSERT_EQ(tensor_type->ElementType(), litert::ElementType::Float32); + ASSERT_EQ(tensor_type->Layout().Rank(), 1); + ASSERT_EQ(tensor_type->Layout().Dimensions()[0], + kTensorType.Layout().Dimensions()[0]); + ASSERT_FALSE(tensor_type->Layout().HasStrides()); + + auto size = tensor_buffer->Size(); + ASSERT_TRUE(size); + ASSERT_EQ(*size, sizeof(kTensorData)); + + auto offset = tensor_buffer->Offset(); + ASSERT_TRUE(offset); + ASSERT_EQ(*offset, 0); + + { + auto lock_and_addr = litert::TensorBufferScopedLock::Create(*tensor_buffer); + ASSERT_TRUE(lock_and_addr); + std::memcpy(lock_and_addr->second, kTensorData, sizeof(kTensorData)); + } + + { + auto lock_and_addr = litert::TensorBufferScopedLock::Create(*tensor_buffer); + ASSERT_TRUE(lock_and_addr); + ASSERT_EQ( + std::memcmp(lock_and_addr->second, kTensorData, sizeof(kTensorData)), + 0); + } +} + +TEST(TensorBuffer, NotOwned) { + LiteRtTensorBuffer litert_tensor_buffer; + ASSERT_EQ(LiteRtCreateManagedTensorBuffer(kLiteRtTensorBufferTypeHostMemory, + &kTensorType, sizeof(kTensorData), + &litert_tensor_buffer), + kLiteRtStatusOk); + + litert::TensorBuffer tensor_buffer(litert_tensor_buffer, /*owned=*/false); + ASSERT_EQ(tensor_buffer.Get(), litert_tensor_buffer); + + LiteRtDestroyTensorBuffer(litert_tensor_buffer); +} + +TEST(TensorBuffer, Duplicate) { + LiteRtTensorBuffer litert_tensor_buffer; + ASSERT_EQ(LiteRtCreateManagedTensorBuffer(kLiteRtTensorBufferTypeHostMemory, + &kTensorType, sizeof(kTensorData), + &litert_tensor_buffer), + kLiteRtStatusOk); + + litert::TensorBuffer tensor_buffer(litert_tensor_buffer, /*owned=*/true); + ASSERT_EQ(GetReferenceCount(tensor_buffer), 1); + { + auto duplicated_tensor_buffer = tensor_buffer.Duplicate(); + ASSERT_TRUE(duplicated_tensor_buffer); + ASSERT_EQ(GetReferenceCount(*duplicated_tensor_buffer), 2); + // The duplicated tensor buffer should point to the same underlying + // LiteRtTensorBuffer object. + ASSERT_EQ(duplicated_tensor_buffer->Get(), tensor_buffer.Get()); + + // Update tensor buffer using the duplicated tensor buffer. 
+ auto lock_and_addr = + litert::TensorBufferScopedLock::Create(*duplicated_tensor_buffer); + ASSERT_TRUE(lock_and_addr); + std::memcpy(lock_and_addr->second, kTensorData, sizeof(kTensorData)); + + // When the scope ends, the duplicated tensor buffer should be destroyed. + // This should not affect the original tensor buffer. + } + + ASSERT_EQ(GetReferenceCount(tensor_buffer), 1); + // Check that the original tensor buffer is not affected. + { + auto lock_and_addr = litert::TensorBufferScopedLock::Create(tensor_buffer); + ASSERT_TRUE(lock_and_addr); + ASSERT_EQ( + std::memcmp(lock_and_addr->second, kTensorData, sizeof(kTensorData)), + 0); + } +} + +TEST(TensorBuffer, ReadWriteBasic) { + LiteRtTensorBuffer litert_tensor_buffer; + ASSERT_EQ(LiteRtCreateManagedTensorBuffer(kLiteRtTensorBufferTypeHostMemory, + &kTensorType, sizeof(kTensorData), + &litert_tensor_buffer), + kLiteRtStatusOk); + + litert::TensorBuffer tensor_buffer(litert_tensor_buffer, /*owned=*/true); + auto write_success = tensor_buffer.Write(absl::MakeSpan( + kTensorData, sizeof(kTensorData) / sizeof(kTensorData[0]))); + ASSERT_TRUE(write_success); + float read_data[sizeof(kTensorData) / sizeof(kTensorData[0])]; + auto read_success = tensor_buffer.Read(absl::MakeSpan(read_data)); + ASSERT_TRUE(read_success); + ASSERT_EQ(std::memcmp(read_data, kTensorData, sizeof(kTensorData)), 0); +} + +TEST(TensorBuffer, ReadWriteBufferSizeMismatch) { + LiteRtTensorBuffer litert_tensor_buffer; + ASSERT_EQ(LiteRtCreateManagedTensorBuffer(kLiteRtTensorBufferTypeHostMemory, + &kTensorType, sizeof(kTensorData), + &litert_tensor_buffer), + kLiteRtStatusOk); + + litert::TensorBuffer tensor_buffer(litert_tensor_buffer, /*owned=*/true); + { + // Write with smaller size of data. + auto write_success = + tensor_buffer.Write(absl::MakeSpan(kTensorData, 1)); + ASSERT_TRUE(write_success); + } + { + constexpr const float big_data[] = {10, 20, 30, 40, 50}; + // Write with larger size of data. + auto write_success = + tensor_buffer.Write(absl::MakeSpan(big_data, 5)); + ASSERT_FALSE(write_success); + } + auto write_success = tensor_buffer.Write(absl::MakeSpan( + kTensorData, sizeof(kTensorData) / sizeof(kTensorData[0]))); + ASSERT_TRUE(write_success); + { + // Read with smaller size of buffer. + float read_data[1]; + auto read_success = tensor_buffer.Read(absl::MakeSpan(read_data, 1)); + ASSERT_TRUE(read_success); + ASSERT_EQ(read_data[0], kTensorData[0]); + } + { + // Read with larger size of buffer. + float read_data[5]; + auto read_success = tensor_buffer.Read(absl::MakeSpan(read_data, 5)); + ASSERT_FALSE(read_success); + } +} diff --git a/tflite/experimental/litert/compiler/BUILD b/tflite/experimental/litert/compiler/BUILD new file mode 100644 index 00000000..e3809c5e --- /dev/null +++ b/tflite/experimental/litert/compiler/BUILD @@ -0,0 +1,18 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +package( + # copybara:uncomment default_applicable_licenses = ["@org_tensorflow//tensorflow:license"], + default_visibility = ["//tflite/experimental/litert:__subpackages__"], +) diff --git a/tflite/experimental/litert/compiler/plugin/BUILD b/tflite/experimental/litert/compiler/plugin/BUILD new file mode 100644 index 00000000..3c322ef5 --- /dev/null +++ b/tflite/experimental/litert/compiler/plugin/BUILD @@ -0,0 +1,103 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # copybara:uncomment default_applicable_licenses = ["@org_tensorflow//tensorflow:license"], + default_visibility = ["//tflite/experimental/litert:__subpackages__"], +) + +cc_library( + name = "compiler_plugin", + srcs = ["compiler_plugin.cc"], + hdrs = ["compiler_plugin.h"], + deps = [ + ":algo", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/cc:litert_buffer_ref", + "//tflite/experimental/litert/cc:litert_detail", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/core:byte_code_util", + "//tflite/experimental/litert/core:dynamic_loading", + "//tflite/experimental/litert/core:filesystem", + "//tflite/experimental/litert/core/model", + "//tflite/experimental/litert/vendors/c:litert_compiler_plugin", + "//tflite/experimental/litert/vendors/c:litert_compiler_plugin_api", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +# copybara:uncomment_begin(no OSS for unique-test-directory) +# cc_test( +# name = "compiler_plugin_test", +# srcs = ["compiler_plugin_test.cc"], +# data = [ +# "//tflite/experimental/litert/test:mlir_test_data", +# "//tflite/experimental/litert/vendors/examples:example_plugin_so", +# ], +# tags = [ +# # Sanitizer runtimes are incompatible with RTLD_DEEPBIND. 
+# "noasan", +# "nomsan", +# "nosan", +# ], +# deps = [ +# ":compiler_plugin", +# "@com_google_googletest//:gtest_main", +# "//testing/base/public:unique-test-directory", +# "@com_google_absl//absl/strings:string_view", +# "//tflite/experimental/litert/c:litert_op_code", +# "//tflite/experimental/litert/core:filesystem", +# "//tflite/experimental/litert/test:common", +# "//tflite/experimental/litert/tools:dump", +# ], +# ) +# copybara:uncomment_end + +cc_library( + name = "algo", + srcs = ["algo.cc"], + hdrs = ["algo.h"], + deps = [ + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/core/model", + "//tflite/experimental/litert/core/model:model_graph", + "//tflite/schema:schema_fbs", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_check", + "@llvm-project//llvm:Support", + ], +) + +cc_test( + name = "algo_test", + srcs = ["algo_test.cc"], + data = [ + "//tflite/experimental/litert/test:mlir_test_data", + ], + deps = [ + ":algo", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/core/model", + "//tflite/experimental/litert/core/model:graph_validation", + "//tflite/experimental/litert/test:common", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tflite/experimental/litert/compiler/plugin/algo.cc b/tflite/experimental/litert/compiler/plugin/algo.cc new file mode 100644 index 00000000..d3bee3f2 --- /dev/null +++ b/tflite/experimental/litert/compiler/plugin/algo.cc @@ -0,0 +1,259 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "tflite/experimental/litert/compiler/plugin/algo.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/absl_check.h"
+#include "llvm/ADT/MapVector.h"
+#include "tflite/experimental/litert/c/litert_logging.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/c/litert_op_code.h"
+#include "tflite/experimental/litert/core/model/model.h"
+#include "tflite/experimental/litert/core/model/model_graph.h"
+#include "tflite/schema/schema_generated.h"
+
+namespace litert::internal {
+namespace {
+
+void MakeDispatchOp(LiteRtOpT& op) {
+  ABSL_DCHECK(op.Inputs().empty());
+  ABSL_DCHECK(op.Outputs().empty());
+  op.SetOpCode(kLiteRtOpCodeTflCustom);
+  detail::SetTflOpCodeInd(op, detail::kDispatchOpCodeTflInd);
+  op.ClearCustomOptions();
+}
+
+//
+// flatlist to partition(s)
+//===----------------------------------------------------------------------===//
+
+class DisjointSets {
+ public:
+  static std::vector<std::vector<LiteRtOp>> GetPartitionsFromFlatList(
+      const std::vector<LiteRtOp>& flat_op_list);
+
+ private:
+  void Insert(LiteRtOp op, LiteRtOp parent);
+  std::vector<std::vector<LiteRtOp>> GetBuckets();
+  LiteRtOp GetBucket(LiteRtOp op);
+  // NOLINTBEGIN
+  llvm::MapVector<LiteRtOp, LiteRtOp> map_;
+  // NOLINTEND
+};
+
+std::vector<std::vector<LiteRtOp>> DisjointSets::GetPartitionsFromFlatList(
+    const std::vector<LiteRtOp>& flat_op_list) {
+  DisjointSets disjoint_sets;
+  for (auto* op : flat_op_list) {
+    disjoint_sets.map_[op] = op;
+  }
+
+  for (auto* op : flat_op_list) {
+    for (auto* output : op->Outputs()) {
+      for (auto* user : output->Users()) {
+        if (disjoint_sets.map_.count(user) == 0) {
+          continue;
+        }
+        disjoint_sets.Insert(op, user);
+      }
+    }
+  }
+
+  return disjoint_sets.GetBuckets();
+}
+
+void DisjointSets::Insert(LiteRtOp op, LiteRtOp parent) {
+  auto* parent_bucket = GetBucket(parent);
+  auto* op_bucket = GetBucket(op);
+  if (op_bucket == parent_bucket) {
+    return;
+  }
+  map_[op_bucket] = parent_bucket;
+}
+
+// Get all disjoint sets.
+std::vector<std::vector<LiteRtOp>> DisjointSets::GetBuckets() {
+  // NOLINTBEGIN
+  std::unordered_map<LiteRtOp, std::vector<LiteRtOp>> invert_map;
+  // NOLINTEND
+  for (const auto& entry : map_) {
+    auto* bucket = GetBucket(entry.first);
+
+    if (invert_map.find(bucket) == invert_map.end()) {
+      invert_map.insert_or_assign(bucket, std::vector<LiteRtOp>{});
+    }
+
+    invert_map[bucket].push_back(entry.first);
+  }
+
+  std::vector<std::vector<LiteRtOp>> res;
+  res.reserve(invert_map.size());
+
+  for (auto& entry : invert_map) {
+    res.push_back(std::move(entry.second));
+  }
+
+  return res;
+}
+
+// Gets the pointer which serves as the key for a given op's bucket. Collapses
+// paths to amortize.
+LiteRtOp DisjointSets::GetBucket(LiteRtOp op) {
+  auto* parent = map_[op];
+  if (op != parent) {
+    parent = GetBucket(parent);
+    map_[op] = parent;
+  }
+  return parent;
+}
+
+//
+// slice partitions out of a subgraph (into new subgraphs)
+//===----------------------------------------------------------------------===//
+
+class GraphSlicer {
+ public:
+  // Slices "partitions" from "root" into the empty subgraph "slice". Assumes
+  // the partition is a valid sub-DAG, and replaces it with a single
+  // tfl.custom_op in "root". A reference to that op is returned.
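+  // The slicing below proceeds in five steps: (1) clone the partition's ops
+  // (and any tensors they reference) into "slice", (2) drop the original ops
+  // from "root", (3) turn the last op of the partition into the dispatch op
+  // in place, (4) reroute tensors that cross the partition boundary through
+  // that op, and (5) run DCE on "root".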
+  static LiteRtOp SlicePartitionFromGraph(LiteRtSubgraphT& root,
+                                          LiteRtSubgraph slice,
+                                          std::vector<LiteRtOp>& partition);
+
+ private:
+  explicit GraphSlicer(LiteRtSubgraph slice) : slice_(slice) {}
+
+  void CloneInto(const LiteRtOpT& op);
+
+  void RerouteTensorsThroughCustomOp(const LiteRtSubgraphT& root);
+
+  LiteRtSubgraph slice_;
+  // Maps tensor in old subgraph to tensor in new subgraph.
+  // NOLINTBEGIN
+  llvm::MapVector<LiteRtTensor, LiteRtTensor> tensor_map_;
+  // NOLINTEND
+  LiteRtOp dispatch_op_ = nullptr;
+};
+
+LiteRtOp GraphSlicer::SlicePartitionFromGraph(
+    LiteRtSubgraphT& root, LiteRtSubgraph slice,
+    std::vector<LiteRtOp>& partition) {
+  GraphSlicer slicer(slice);
+
+  // Register input tensors of the sliced partition with respect to their
+  // original order in the root subgraph. This ensures the order of input
+  // tensors of the later outlined custom op is the same as the order of the
+  // graph inputs.
+  absl::flat_hash_set<LiteRtTensor> used_tensors;
+
+  // Get all tensors used in the partition.
+  for (auto* op : partition) {
+    used_tensors.insert(op->Inputs().cbegin(), op->Inputs().cend());
+  }
+  for (auto* old_input : root.Inputs()) {
+    if (used_tensors.contains(old_input)) {
+      auto* new_input = &MakeClone(*slicer.slice_, *old_input);
+      slicer.slice_->Inputs().push_back(new_input);
+      slicer.tensor_map_.insert({old_input, new_input});
+    }
+  }
+
+  for (auto* op : partition) {
+    slicer.CloneInto(*op);
+  }
+
+  for (auto* op : partition) {
+    Drop(*op);
+  }
+
+  // Reuse the storage from the last op in the partition to maintain
+  // topological order.
+  slicer.dispatch_op_ = partition.back();
+  MakeDispatchOp(*slicer.dispatch_op_);
+  slicer.RerouteTensorsThroughCustomOp(root);
+
+  DCE(root);
+
+  return slicer.dispatch_op_;
+}
+
+void GraphSlicer::RerouteTensorsThroughCustomOp(const LiteRtSubgraphT& root) {
+  for (auto& [old_tensor, new_tensor] : tensor_map_) {
+    // Reroute tensors which need to be passed into the scope of the new
+    // subgraph to inputs of the custom op.
+    if (new_tensor->DefiningOp() == nullptr) {
+      AttachInput(old_tensor, *dispatch_op_);
+      continue;
+    }
+
+    // Reroute custom op as the definer of tensors within the removed partition
+    // and referenced later in the root graph.
+    if (!old_tensor->Users().empty() || FindOutput(root, *old_tensor)) {
+      AttachOutput(old_tensor, *dispatch_op_);
+      slice_->Outputs().push_back(new_tensor);
+    }
+  }
+}
+
+void GraphSlicer::CloneInto(const LiteRtOpT& old_op) {
+  auto& new_op = MakeClone(*slice_, old_op);
+
+  for (auto i = 0; i < old_op.NumInputs(); ++i) {
+    auto* old_input = old_op.Inputs().at(i);
+    LiteRtTensor new_input;
+    if (tensor_map_.contains(old_input)) {
+      // If old_input is already in the map then map[input] is its cloned
+      // counterpart in the new graph.
+      new_input = tensor_map_[old_input];
+    } else {
+      // Otherwise, it must be a new subgraph input.
+      new_input = &MakeClone(*slice_, *old_input);
+      slice_->Inputs().push_back(new_input);
+      tensor_map_.insert({old_input, new_input});
+    }
+
+    AttachInput(new_input, new_op);
+  }
+
+  for (int i = 0; i < old_op.NumOutputs(); ++i) {
+    auto* old_output = old_op.Outputs().at(i);
+    auto* new_output = &MakeClone(*slice_, *old_output);
+    AttachOutput(new_output, new_op);
+
+    // Update the values defined in scope of the new subgraph.
+    tensor_map_.insert({old_output, new_output});
+  }
+}
+
+}  // namespace
+
+std::vector<std::vector<LiteRtOp>> GroupPartitions(
+    const std::vector<LiteRtOp>& ops) {
+  return DisjointSets::GetPartitionsFromFlatList(ops);
+}
+
+LiteRtOp OutlinePartition(LiteRtSubgraphT& root, LiteRtSubgraph slice,
+                          std::vector<LiteRtOp>& partition) {
+  return GraphSlicer::SlicePartitionFromGraph(root, slice, partition);
+}
+
+}  // namespace litert::internal
diff --git a/tflite/experimental/litert/compiler/plugin/algo.h b/tflite/experimental/litert/compiler/plugin/algo.h
new file mode 100644
index 00000000..5ba5df5b
--- /dev/null
+++ b/tflite/experimental/litert/compiler/plugin/algo.h
@@ -0,0 +1,39 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_COMPILER_PLUGIN_ALGO_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_COMPILER_PLUGIN_ALGO_H_
+
+#include <vector>
+
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/core/model/model.h"
+
+namespace litert::internal {
+
+// Identifies sub-DAGs of ops connected w.r.t. the use-def chain. Expects all
+// "ops" to belong to the same Subgraph. The ops in the input
+// and output will always be the same.
+std::vector<std::vector<LiteRtOp>> GroupPartitions(
+    const std::vector<LiteRtOp>& ops);
+
+// Outlines "partition" from "root" into the empty subgraph "slice". Assumes
+// the partition is a valid sub-DAG, and replaces it with a single
+// tfl.custom_op in "root". A reference to that op is returned.
+LiteRtOp OutlinePartition(LiteRtSubgraphT& root, LiteRtSubgraph slice,
+                          std::vector<LiteRtOp>& partition);
+
+}  // namespace litert::internal
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_COMPILER_PLUGIN_ALGO_H_
diff --git a/tflite/experimental/litert/compiler/plugin/algo_test.cc b/tflite/experimental/litert/compiler/plugin/algo_test.cc
new file mode 100644
index 00000000..fae70201
--- /dev/null
+++ b/tflite/experimental/litert/compiler/plugin/algo_test.cc
@@ -0,0 +1,246 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
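Before the tests, a short sketch of how the two entry points declared above compose, mirroring the pattern ApplyPlugin uses later in this patch. The helper name OutlineAll and its calling convention are illustrative assumptions; EmplaceSubgraph is the same LiteRtModelT member used by ApplyPlugin below.

    #include <vector>

    #include "tflite/experimental/litert/compiler/plugin/algo.h"
    #include "tflite/experimental/litert/core/model/model.h"

    // Outlines every connected group of "selected" ops in "root" into its own
    // freshly emplaced subgraph of "model", returning the dispatch ops that
    // replace them in "root".
    std::vector<LiteRtOp> OutlineAll(LiteRtModelT& model, LiteRtSubgraphT& root,
                                     const std::vector<LiteRtOp>& selected) {
      std::vector<LiteRtOp> dispatch_ops;
      auto partitions = litert::internal::GroupPartitions(selected);
      for (auto& partition : partitions) {
        dispatch_ops.push_back(litert::internal::OutlinePartition(
            root, &model.EmplaceSubgraph(), partition));
      }
      return dispatch_ops;
    }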
+ +#include "tflite/experimental/litert/compiler/plugin/algo.h" + +#include + +#include +#include "tflite/experimental/litert/c/litert_model.h" +#include "tflite/experimental/litert/c/litert_op_code.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/cc/litert_model_predicates.h" +#include "tflite/experimental/litert/core/model/graph_validation.h" +#include "tflite/experimental/litert/core/model/model.h" +#include "tflite/experimental/litert/test/common.h" + +namespace litert::internal { +namespace { + +TEST(TestPartitionsFromFlatList, SimpleMultiOp) { + auto model = litert::testing::LoadTestFileModel("simple_multi_op.tflite"); + auto subgraph = model.MainSubgraph(); + EXPECT_TRUE(subgraph); + + auto ops = subgraph->Ops(); + + // func.func @main(arg0) + // 0 = tfl.add arg0, arg0 + // 1 = tfl.mul 0, 0 + // 2 = tfl.mul 1, 1 + // 3 = tfl.add 2, 2 + // return 3 + + { + std::vector partition; + partition.push_back(ops.at(1).Get()); + partition.push_back(ops.at(2).Get()); + + auto partitions = GroupPartitions(partition); + ASSERT_EQ(partitions.size(), 1); + ASSERT_EQ(partitions.front().size(), 2); + + EXPECT_EQ(partitions.front().at(0), partition.at(0)); + EXPECT_EQ(partitions.front().at(1), partition.at(1)); + } + + { + std::vector partition; + partition.push_back(ops.at(1).Get()); + partition.push_back(ops.at(3).Get()); + + auto partitions = GroupPartitions(partition); + ASSERT_EQ(partitions.size(), 2); + ASSERT_EQ(partitions.front().size(), 1); + ASSERT_EQ(partitions.back().size(), 1); + + auto p1_op_code = partitions.front().front()->OpCode(); + auto p2_op_code = partitions.back().front()->OpCode(); + + ASSERT_TRUE((p1_op_code == kLiteRtOpCodeTflMul && + p2_op_code == kLiteRtOpCodeTflAdd) || + (p1_op_code == kLiteRtOpCodeTflAdd && + p2_op_code == kLiteRtOpCodeTflMul)); + } + + { + std::vector partition; + + auto partitions = GroupPartitions(partition); + ASSERT_EQ(partitions.size(), 0); + } + + { + std::vector partition; + partition.push_back(ops.at(0).Get()); + partition.push_back(ops.at(1).Get()); + partition.push_back(ops.at(2).Get()); + partition.push_back(ops.at(3).Get()); + + auto partitions = GroupPartitions(partition); + ASSERT_EQ(partitions.size(), 1); + ASSERT_EQ(partitions.front().size(), 4); + + EXPECT_EQ(partitions.front().at(0), partition.at(0)); + EXPECT_EQ(partitions.front().at(1), partition.at(1)); + EXPECT_EQ(partitions.front().at(2), partition.at(2)); + EXPECT_EQ(partitions.front().at(3), partition.at(3)); + } +} + +TEST(TestSliceSubgraphSimpleMultiOp, OnePartition) { + auto model = litert::testing::LoadTestFileModel("simple_multi_op.tflite"); + auto subgraph = model.MainSubgraph(); + EXPECT_TRUE(subgraph); + + auto ops = subgraph->Ops(); + + // func.func @main(arg0) + // 0 = tfl.add arg0, arg0 + // 1 = tfl.mul 0, 0 + // 2 = tfl.mul 1, 1 + // 3 = tfl.add 2, 2 + // return 3 + + std::vector partition; + partition.push_back(ops.at(1).Get()); + partition.push_back(ops.at(2).Get()); + + auto sliced_graph = litert::Subgraph(&model.Get()->EmplaceSubgraph()); + auto* dispatch_op = + OutlinePartition(*subgraph->Get(), sliced_graph.Get(), partition); + + const auto& internal_sliced = *sliced_graph.Get(); + ASSERT_TRUE(ValidateSubgraphIO(internal_sliced)); + ASSERT_TRUE(ValidateLocalTopology(internal_sliced.Ops().cbegin(), + internal_sliced.Ops().cend())); + + auto edited_subgraph_ops = subgraph->Ops(); + + ASSERT_EQ(edited_subgraph_ops.size(), 3); + ASSERT_EQ(edited_subgraph_ops.at(0).Code(), kLiteRtOpCodeTflAdd); + 
ASSERT_EQ(edited_subgraph_ops.at(1).Code(), kLiteRtOpCodeTflCustom); + ASSERT_EQ(edited_subgraph_ops.at(2).Code(), kLiteRtOpCodeTflAdd); + + auto sliced_subgraph_ops = sliced_graph.Ops(); + + ASSERT_EQ(sliced_subgraph_ops.size(), 2); + ASSERT_EQ(sliced_subgraph_ops[0].Code(), kLiteRtOpCodeTflMul); + ASSERT_EQ(sliced_subgraph_ops[1].Code(), kLiteRtOpCodeTflMul); + + ASSERT_EQ(dispatch_op, edited_subgraph_ops.at(1).Get()); + const Op hal_call(dispatch_op); + + { + const auto dispatch_op_ins = hal_call.Inputs(); + + ASSERT_EQ(dispatch_op_ins.size(), 1); + + auto hal_input_defining_op = dispatch_op_ins.front().DefiningOp(); + ASSERT_EQ(hal_input_defining_op->op, edited_subgraph_ops.at(0).Get()); + ASSERT_EQ(hal_input_defining_op->op_output_index, 0); + + const auto sliced_subgraph_inputs = sliced_graph.Inputs(); + + ASSERT_EQ(sliced_subgraph_inputs.size(), 1); + + ASSERT_TRUE(MatchUses(sliced_subgraph_inputs.front(), + {UseInfo{sliced_subgraph_ops.front().Code(), 0}, + UseInfo{sliced_subgraph_ops.front().Code(), 0}})); + ASSERT_TRUE(sliced_subgraph_inputs.front().IsSubgraphInput()); + } + + { + const auto hal_call_outs = hal_call.Outputs(); + ASSERT_EQ(hal_call_outs.size(), 1); + const auto& hal_call_out = hal_call_outs.front(); + + ASSERT_TRUE(MatchUses(hal_call_out, + {UseInfo{edited_subgraph_ops.back().Code(), 0}, + UseInfo{edited_subgraph_ops.back().Code(), 1}})); + + auto sliced_subgraph_outputs = sliced_graph.Outputs(); + + ASSERT_EQ(sliced_subgraph_outputs.size(), 1); + + const auto defining_op = sliced_subgraph_outputs.front().DefiningOp(); + ASSERT_EQ(defining_op->op, sliced_subgraph_ops.back().Get()); + ASSERT_EQ(defining_op->op_output_index, 0); + + ASSERT_TRUE(sliced_subgraph_outputs.front().Uses().empty()); + } +} + +TEST(TestSliceSubgraphSimpleMultiOp, TwoPartitions) { + auto model = litert::testing::LoadTestFileModel("simple_multi_op.tflite"); + auto subgraph = model.MainSubgraph(); + EXPECT_TRUE(subgraph); + + auto ops = subgraph->Ops(); + + // func.func @main(arg0) + // 0 = tfl.add arg0, arg0 + // 1 = tfl.mul 0, 0 + // 2 = tfl.mul 1, 1 + // 3 = tfl.add 2, 2 + // return 3 + + std::vector partition_1; + partition_1.push_back(ops.at(0).Get()); + + auto sliced_graph_1 = litert::Subgraph(&model.Get()->EmplaceSubgraph()); + OutlinePartition(*(subgraph->Get()), sliced_graph_1.Get(), partition_1); + + const auto& internal_slice_1 = *sliced_graph_1.Get(); + ASSERT_TRUE(ValidateSubgraphIO(internal_slice_1)); + ASSERT_TRUE(ValidateLocalTopology(internal_slice_1.Ops().cbegin(), + internal_slice_1.Ops().cend())); + + std::vector partition_2; + partition_2.push_back(ops.at(2).Get()); + partition_2.push_back(ops.at(3).Get()); + + auto sliced_graph_2 = litert::Subgraph(&model.Get()->EmplaceSubgraph()); + OutlinePartition(*(subgraph->Get()), sliced_graph_2.Get(), partition_2); + + const auto& internal_slice_2 = *sliced_graph_2.Get(); + ASSERT_TRUE(ValidateSubgraphIO(internal_slice_2)); + ASSERT_TRUE(ValidateLocalTopology(internal_slice_2.Ops().cbegin(), + internal_slice_2.Ops().cend())); + + auto edited_subgraph_ops = subgraph->Ops(); + + ASSERT_EQ(edited_subgraph_ops.size(), 3); + ASSERT_EQ(edited_subgraph_ops.at(0).Code(), kLiteRtOpCodeTflCustom); + ASSERT_EQ(edited_subgraph_ops.at(1).Code(), kLiteRtOpCodeTflMul); + ASSERT_EQ(edited_subgraph_ops.at(2).Code(), kLiteRtOpCodeTflCustom); + + { + auto sliced_ops = sliced_graph_1.Ops(); + + ASSERT_EQ(sliced_ops.size(), 1); + ASSERT_EQ(sliced_ops.at(0).Code(), kLiteRtOpCodeTflAdd); + } + + { + auto sliced_ops = sliced_graph_2.Ops(); + + 
ASSERT_EQ(sliced_ops.size(), 2); + ASSERT_EQ(sliced_ops.at(0).Code(), kLiteRtOpCodeTflMul); + ASSERT_EQ(sliced_ops.at(1).Code(), kLiteRtOpCodeTflAdd); + } +} + +} // namespace +} // namespace litert::internal diff --git a/tflite/experimental/litert/compiler/plugin/compiler_plugin.cc b/tflite/experimental/litert/compiler/plugin/compiler_plugin.cc new file mode 100644 index 00000000..b360adf9 --- /dev/null +++ b/tflite/experimental/litert/compiler/plugin/compiler_plugin.cc @@ -0,0 +1,409 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tflite/experimental/litert/compiler/plugin/compiler_plugin.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_logging.h" +#include "tflite/experimental/litert/c/litert_model.h" +#include "tflite/experimental/litert/cc/litert_buffer_ref.h" +#include "tflite/experimental/litert/cc/litert_detail.h" +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/experimental/litert/cc/litert_macros.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/compiler/plugin/algo.h" +#include "tflite/experimental/litert/core/byte_code_util.h" +#include "tflite/experimental/litert/core/dynamic_loading.h" +#include "tflite/experimental/litert/core/filesystem.h" +#include "tflite/experimental/litert/core/model/model.h" +#include "tflite/experimental/litert/vendors/c/litert_compiler_plugin.h" +#include "tflite/experimental/litert/vendors/c/litert_compiler_plugin_api.h" + +namespace litert::internal { + +// +// CompiledResult +// + +Expected> CompiledResult::ByteCode() const { + const void* data; + size_t size; + LITERT_EXPECT_OK(allocating_plugin_api_.get_compiled_result_byte_code( + compiled_result_handle_, &data, &size)); + return BufferRef(data, size); +} + +Expected CompiledResult::NumCalls() const { + LiteRtParamIndex call_idx; + LITERT_EXPECT_OK(allocating_plugin_api_.get_compiled_result_num_calls( + compiled_result_handle_, &call_idx)); + return call_idx; +} + +Expected CompiledResult::CallInfo( + LiteRtParamIndex call_idx) const { + const void* data; + size_t size; + LITERT_EXPECT_OK(allocating_plugin_api_.get_compiled_result_call_info( + compiled_result_handle_, call_idx, &data, &size)); + return std::string(reinterpret_cast(data), size); +} + +CompiledResult::~CompiledResult() { + allocating_plugin_api_.destroy_compiled_result(compiled_result_handle_); +} + +// +// CompilerPlugin +// + +namespace { + +#define RESOLVE_API_FUNC(name, dest) \ + LITERT_RETURN_STATUS_IF_NOT_OK( \ + ResolveLibSymbol(lib_handle, name, &dest)); + +LiteRtStatus ResolvePluginApi(void* lib_handle, + LiteRtCompilerPluginApi& result) { + RESOLVE_API_FUNC(kLiteRtGetCompilerPluginVersion, + result.get_compiler_plugin_version); + 
RESOLVE_API_FUNC(kLiteRtGetCompilerPluginSocManufacturer, + result.get_compiler_plugin_soc_manufacturer); + RESOLVE_API_FUNC(kLiteRtGetNumCompilerPluginSupportedSocModels, + result.get_num_compiler_plugin_supported_models); + RESOLVE_API_FUNC(kLiteRtGetCompilerPluginSupportedSocModel, + result.get_compiler_plugin_supported_soc_model); + + RESOLVE_API_FUNC(kLiteRtCreateCompilerPlugin, result.create_compiler_plugin); + RESOLVE_API_FUNC(kLiteRtDestroyCompilerPlugin, + result.destroy_compiler_plugin); + + RESOLVE_API_FUNC(kLiteRtCompilerPluginPartition, + result.compiler_plugin_partition); + RESOLVE_API_FUNC(kLiteRtCompilerPluginCompile, + result.compiler_plugin_compile); + + RESOLVE_API_FUNC(kLiteRtDestroyCompiledResult, + result.destroy_compiled_result); + RESOLVE_API_FUNC(kLiteRtGetCompiledResultByteCode, + result.get_compiled_result_byte_code); + RESOLVE_API_FUNC(kLiteRtGetCompiledResultCallInfo, + result.get_compiled_result_call_info); + RESOLVE_API_FUNC(kLiteRtGetNumCompiledResultCalls, + result.get_compiled_result_num_calls); + + return kLiteRtStatusOk; +} + +Expected> GetSocModels( + const LiteRtCompilerPluginApi& api, LiteRtCompilerPlugin plugin_handle) { + SmallVec soc_models; + + LiteRtParamIndex num_models; + LITERT_EXPECT_OK( + api.get_num_compiler_plugin_supported_models(plugin_handle, &num_models)); + + for (LiteRtParamIndex i = 0; i < num_models; ++i) { + const char* model; + if (api.get_compiler_plugin_supported_soc_model(plugin_handle, i, &model) != + kLiteRtStatusOk) { + continue; + } + soc_models.push_back(std::string(model)); + } + + return soc_models; +} + +} // namespace + +Expected CompilerPlugin::LoadPlugin( + const absl::string_view lib_path) { + CompilerPlugin plugin; + LITERT_LOG(LITERT_INFO, "Loading plugin at: %s", lib_path.data()); + + LITERT_EXPECT_OK(OpenLib(lib_path, &plugin.lib_handle_)); + LITERT_LOG(LITERT_INFO, "Loaded plugin at: %s", lib_path.data()); + + LITERT_EXPECT_OK(ResolvePluginApi(plugin.lib_handle_, plugin.plugin_api_)); + LITERT_LOG(LITERT_INFO, "Resolved plugin api at: %s", lib_path.data()); + + LITERT_EXPECT_OK( + plugin.plugin_api_.create_compiler_plugin(&plugin.plugin_handle_)); + LITERT_LOG(LITERT_INFO, "Initialize plugin at: %s", lib_path.data()); + + auto api_version = plugin.ApiVersion(); + if (!api_version) { + return api_version.Error(); + } + + if (api_version->major != LITERT_API_VERSION_MAJOR) { + LITERT_LOG( + LITERT_ERROR, + "Unsupported Compiler Plugin version, found version %d.%d.%d and " + "expected version %d.%d.%d", + api_version.Value().major, api_version.Value().minor, + api_version.Value().patch, LITERT_API_VERSION_MAJOR, + LITERT_API_VERSION_MINOR, LITERT_API_VERSION_PATCH); + return Unexpected(kLiteRtStatusErrorRuntimeFailure); + } + + // This should never change throughout the lifetime of the compiler + // plugin so save to avoid recalling. + auto soc_models = GetSocModels(plugin.plugin_api_, plugin.plugin_handle_); + if (!soc_models) { + return soc_models.Error(); + } + plugin.soc_models_ = *soc_models; + + return plugin; +} + +Expected> CompilerPlugin::LoadPlugins( + absl::Span lib_search_paths) { + std::vector plugin_lib_paths; + for (auto search_path : lib_search_paths) { + // Skip paths that are not valid. 
+ if (Exists(search_path)) { + LITERT_EXPECT_OK(FindLiteRtSharedLibs(search_path, plugin_lib_paths)); + } + } + + SmallVec loaded_plugins; + loaded_plugins.reserve(lib_search_paths.size()); + + for (const auto& lib_path : plugin_lib_paths) { + LITERT_LOG(LITERT_INFO, "Loading plugin at: %s", lib_path.c_str()); + auto plugin = LoadPlugin(lib_path); + if (!plugin.HasValue()) { + continue; + } + loaded_plugins.push_back(std::move(plugin.Value())); + } + + return loaded_plugins; +} + +Expected CompilerPlugin::LoadPlugin( + absl::Span lib_search_paths, + absl::string_view soc_manufacturer) { + auto compiler_plugins = LoadPlugins(lib_search_paths); + if (!compiler_plugins) { + return compiler_plugins.Error(); + } + + for (auto& plugin : *compiler_plugins) { + if (plugin.SocManufacturer() == soc_manufacturer) { + return std::move(plugin); + } + } + + return Error(kLiteRtStatusErrorNotFound); +} + +CompilerPlugin::CompilerPlugin(CompilerPlugin&& other) + : soc_models_(std::move(other.soc_models_)), + lib_handle_(other.lib_handle_), + plugin_api_(std::move(other.plugin_api_)), + plugin_handle_(other.plugin_handle_) { + other.soc_models_ = {}; + other.plugin_api_ = {}; + other.lib_handle_ = nullptr; + other.plugin_handle_ = nullptr; +} + +CompilerPlugin& CompilerPlugin::operator=(CompilerPlugin&& other) { + if (this != &other) { + soc_models_ = std::move(other.soc_models_); + other.soc_models_ = {}; + + lib_handle_ = other.lib_handle_; + other.lib_handle_ = nullptr; + + plugin_api_ = std::move(other.plugin_api_); + other.plugin_api_ = {}; + + plugin_handle_ = other.plugin_handle_; + other.plugin_handle_ = nullptr; + } + return *this; +} + +CompilerPlugin::~CompilerPlugin() { + if (plugin_handle_ != nullptr) { + plugin_api_.destroy_compiler_plugin(plugin_handle_); + } + if (lib_handle_ != nullptr) { + if (kLiteRtStatusOk != CloseLib(lib_handle_)) { + LITERT_LOG(LITERT_WARNING, "%s", "Failed to close shared library\n"); + } + } +} + +Expected CompilerPlugin::ApiVersion() const { + LiteRtApiVersion api_version; + LITERT_EXPECT_OK(plugin_api_.get_compiler_plugin_version(&api_version)); + return api_version; +} + +Expected> CompilerPlugin::Partition( + const Subgraph& subgraph) { + LiteRtOpListT ops; + LITERT_EXPECT_OK(plugin_api_.compiler_plugin_partition(plugin_handle_, + subgraph.Get(), &ops)); + return ops.Vec(); +} + +LiteRtStatus CompilerPlugin::Compile( + std::optional soc_model, + const std::vector& partitions, std::ostream& byte_code_out, + std::vector& call_info_out) { + CompiledResult result = MakeResult(); + + const char* soc_model_str = soc_model ? soc_model->data() : nullptr; + + // Compile given partitions into result. + // TODO: Use const where appropriate in the C compiler plugin api. + LiteRtSubgraphArray partitions_arr = + const_cast(partitions.data()); + if (auto stat = plugin_api_.compiler_plugin_compile( + plugin_handle_, soc_model_str, partitions_arr, partitions.size(), + &result.compiled_result_handle_); + stat != kLiteRtStatusOk) { + return stat; + } + + // Parse call info from the result. 
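+  // The plugin is expected to return exactly one call-info entry per
+  // compiled partition; any mismatch is surfaced as a runtime failure below.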
+ { + auto num_call = result.NumCalls(); + if (!num_call) { + return num_call.Error().Status(); + } + if (num_call.Value() != partitions.size()) { + LITERT_LOG( + LITERT_ERROR, "%s", + "Plugin didn't return call info for each partition compiled.\n"); + return kLiteRtStatusErrorRuntimeFailure; + } + for (int i = 0; i < num_call.Value(); ++i) { + auto call_info = result.CallInfo(i); + if (!call_info) { + return call_info.Error().Status(); + } + call_info_out.emplace_back() = *call_info; + } + } + + // Parse byte code from result. + { + auto byte_code = result.ByteCode(); + if (!byte_code) { + return byte_code.Error().Status(); + } + LITERT_LOG(LITERT_INFO, "Compiled %d partitions in %lu bytes", + partitions.size(), byte_code->Size()); + byte_code->WriteStr(byte_code_out); + } + + return kLiteRtStatusOk; +} + +Expected> ApplyPlugin( + CompilerPlugin& compiler_plugin, Model& model, + std::optional soc_model) { + if (model.NumSubgraphs() != 1) { + // TODO(@lukeboyer) Finish support for multi-subgraph. + LITERT_LOG(LITERT_ERROR, "Apply currently supported for 1 subgraph"); + return Error(kLiteRtStatusErrorUnsupported); + } + + // Get selected ops from plugin. + auto partition = compiler_plugin.Partition(*model.Subgraph(0)); + if (!partition) { + LITERT_LOG(LITERT_ERROR, "Failed to get partitions from plugin"); + return Error(kLiteRtStatusErrorRuntimeFailure); + } + + // Group selected ops into partitions. + auto grouped_partitions = GroupPartitions(*partition); + if (grouped_partitions.empty()) { + LITERT_LOG(LITERT_ERROR, "Failed to group partitions"); + return Error(kLiteRtStatusErrorRuntimeFailure); + } + + if (grouped_partitions.size() > 1) { + LITERT_LOG(LITERT_ERROR, "Apply on multiple partitions not supported yet."); + return Error(kLiteRtStatusErrorUnsupported); + } + + // Outline the partitions into new subgraphs. + std::vector custom_ops; + for (auto& partition : grouped_partitions) { + auto custom_op = + OutlinePartition(*model.Get()->Subgraphs().front(), + &model.Get()->EmplaceSubgraph(), partition); + custom_ops.push_back(custom_op); + } + + // Pass new subgraphs to the plugin for compilation. + std::vector compilation_input; + auto begin = model.Get()->Subgraphs().begin(); + auto end = model.Get()->Subgraphs().end(); + for (auto it = begin + 1; it < end; ++it) { + compilation_input.push_back(*it); + } + + // Compile partitions with plugin. + std::stringstream byte_code; + std::vector exec_info; + if (auto status = compiler_plugin.Compile(soc_model, compilation_input, + byte_code, exec_info); + status != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, "Failed to compile partitions."); + return Error(status); + } + + if (exec_info.size() != custom_ops.size()) { + LITERT_LOG(LITERT_ERROR, + "Compilation did not return exec_info for every partition"); + return Error(kLiteRtStatusErrorRuntimeFailure); + } + + // Attach entry point info to the custom ops. 
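+  // Each dispatch op created during outlining receives the opaque exec_info
+  // string for its partition as custom options; the byte code for all
+  // partitions is returned as a single buffer.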
+  auto custom_op_it = custom_ops.begin();
+  auto exec_info_it = exec_info.begin();
+  for (; custom_op_it < custom_ops.end(); custom_op_it++, exec_info_it++) {
+    LiteRtOp custom_op = *custom_op_it;
+    const auto& exec_info = *exec_info_it;
+    custom_op->SetCustomOptions(exec_info.data());
+  }
+
+  const auto byte_code_str = byte_code.str();
+  return OwningBufferRef<uint8_t>(
+      reinterpret_cast<const uint8_t*>(byte_code_str.data()),
+      byte_code_str.size());
+}
+
+}  // namespace litert::internal
diff --git a/tflite/experimental/litert/compiler/plugin/compiler_plugin.h b/tflite/experimental/litert/compiler/plugin/compiler_plugin.h
new file mode 100644
index 00000000..5043486c
--- /dev/null
+++ b/tflite/experimental/litert/compiler/plugin/compiler_plugin.h
@@ -0,0 +1,145 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_COMPILER_PLUGIN_COMPILER_PLUGIN_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_COMPILER_PLUGIN_COMPILER_PLUGIN_H_
+
+#include <optional>
+#include <ostream>
+#include <string>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/cc/litert_buffer_ref.h"
+#include "tflite/experimental/litert/cc/litert_detail.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/cc/litert_model.h"
+#include "tflite/experimental/litert/vendors/c/litert_compiler_plugin.h"
+#include "tflite/experimental/litert/vendors/c/litert_compiler_plugin_api.h"
+
+namespace litert::internal {
+
+class CompiledResult {
+  friend class CompilerPlugin;
+  // Get the single module of compiled byte code. This contains the
+  // compilation result for all entry points.
+  Expected<BufferRef<uint8_t>> ByteCode() const;
+
+  // Get information regarding the "ith" entry point in the compiled module.
+  // There will be one entry point for each subgraph compiled for.
+  Expected<std::string> CallInfo(LiteRtParamIndex call_idx) const;
+
+  // Get the number of entry points in the compiled module. This will be equal
+  // to the number of subgraphs passed to the compilation step.
+  Expected<LiteRtParamIndex> NumCalls() const;
+
+  explicit CompiledResult(const LiteRtCompilerPluginApi& allocating_plugin_api)
+      : allocating_plugin_api_(allocating_plugin_api) {}
+
+  CompiledResult(CompiledResult&& other) = default;
+  CompiledResult& operator=(CompiledResult&& other) = default;
+  CompiledResult(const CompiledResult& other) = delete;
+  CompiledResult& operator=(const CompiledResult& other) = delete;
+
+  ~CompiledResult();
+
+  LiteRtCompilerPluginApi allocating_plugin_api_;
+  LiteRtCompiledResult compiled_result_handle_ = nullptr;
+};
+
+// Syntactic sugar around dynamically loaded LiteRtCompilerPlugin libraries.
+// TODO turn this into a general C++ wrapper for the whole compiler plugin api.
+class CompilerPlugin {
+ public:
+  // Get the compiler plugin's API version.
+  Expected<LiteRtApiVersion> ApiVersion() const;
+
+  // Get the manufacturer associated with this plugin. NOTE: The
+  // SocManufacturer string returned by the underlying plugin is expected to
+  // have static lifetime.
+  absl::string_view SocManufacturer() const {
+    return plugin_api_.get_compiler_plugin_soc_manufacturer();
+  }
+
+  // Get the list of unique SoC models targetable by this plugin.
+  const SmallVec<std::string>& SocModels() const { return soc_models_; }
+
+  // Selects ops for the plugin to compile.
+  Expected<std::vector<LiteRtOp>> Partition(const Subgraph& subgraph);
+
+  // Compile given LiteRtSubgraphs. Write compiled byte code to the given
+  // stream. For each given subgraph, write opaque data about the corresponding
+  // entry point to the given "call_info_out". Parameter "soc_model" is
+  // optional and can be set to specify the target SoC; for on-device
+  // compilation it should be left unspecified so as to let the underlying
+  // logic pick the architecture that matches the SoC on the user device.
+  LiteRtStatus Compile(std::optional<absl::string_view> soc_model,
+                       const std::vector<LiteRtSubgraph>& partitions,
+                       std::ostream& byte_code_out,
+                       std::vector<std::string>& call_info_out);
+
+  // Search for shared library files with prefix "libLiteRtCompilerPlugin" in
+  // the directories passed through "lib_search_paths". Returns a loaded
+  // plugin for each found library that can be successfully loaded.
+  // Additionally initializes the compiler plugin instances and stores their
+  // handles.
+  static Expected<SmallVec<CompilerPlugin>> LoadPlugins(
+      absl::Span<const absl::string_view> lib_search_paths);
+
+  // Search for shared library files with prefix "libLiteRtCompilerPlugin" in
+  // the directories passed through "lib_search_paths" and return a compiler
+  // plugin instance for a given manufacturer, if one is found.
+  static Expected<CompilerPlugin> LoadPlugin(
+      absl::Span<const absl::string_view> lib_search_paths,
+      absl::string_view soc_manufacturer);
+
+  CompilerPlugin(CompilerPlugin&& other);
+  CompilerPlugin& operator=(CompilerPlugin&& other);
+  CompilerPlugin(const CompilerPlugin& other) = delete;
+  CompilerPlugin& operator=(const CompilerPlugin& other) = delete;
+
+  // Destroys any living `LiteRtCompilerPlugin` and frees the reference
+  // to the dynamically loaded library.
+  ~CompilerPlugin();
+
+ private:
+  static Expected<CompilerPlugin> LoadPlugin(absl::string_view lib_path);
+  CompilerPlugin() = default;
+
+  SmallVec<std::string> soc_models_;
+  void* lib_handle_ = nullptr;
+  LiteRtCompilerPluginApi plugin_api_ = {};
+  LiteRtCompilerPlugin plugin_handle_ = nullptr;
+
+  // Internal LiteRtCompiledResult wrapper.
+  CompiledResult MakeResult() const { return CompiledResult(plugin_api_); }
+};
+
+// Applies the plugin's "partition" and "compile" steps to the given model,
+// outlining compiled partitions into custom ops. Returns the compiled NPU
+// byte code. Parameter "soc_model" is optional and can be set to specify the
+// target SoC; for on-device compilation it should be left unspecified so as
+// to let the underlying logic pick the architecture that matches the SoC on
+// the user device.
+Expected<OwningBufferRef<uint8_t>> ApplyPlugin(
+    CompilerPlugin& compiler_plugin, Model& model,
+    std::optional<absl::string_view> soc_model = std::nullopt);
+
+}  // namespace litert::internal
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_COMPILER_PLUGIN_COMPILER_PLUGIN_H_
diff --git a/tflite/experimental/litert/compiler/plugin/compiler_plugin_test.cc b/tflite/experimental/litert/compiler/plugin/compiler_plugin_test.cc
new file mode 100644
index 00000000..b83cd9fe
--- /dev/null
+++ b/tflite/experimental/litert/compiler/plugin/compiler_plugin_test.cc
@@ -0,0 +1,160 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/compiler/plugin/compiler_plugin.h"
+
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "testing/base/public/unique-test-directory.h"
+#include "absl/strings/string_view.h"
+#include "tflite/experimental/litert/c/litert_op_code.h"
+#include "tflite/experimental/litert/core/filesystem.h"
+#include "tflite/experimental/litert/test/common.h"
+#include "tflite/experimental/litert/tools/dump.h"
+
+namespace litert::internal {
+namespace {
+
+using ::testing::UniqueTestDirectory;
+
+constexpr absl::string_view kTestPluginSearchPath =
+    "tflite/experimental/litert/vendors/examples";
+
+constexpr absl::string_view kTestManufacturer = "ExampleSocManufacturer";
+constexpr absl::string_view kTestModels = "ExampleSocModel";
+
+TEST(CompilerPluginTest, LoadTestPlugin) {
+  auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath});
+
+  ASSERT_EQ(plugins->size(), 1);
+  EXPECT_EQ(plugins->front().SocManufacturer(), kTestManufacturer);
+  ASSERT_EQ(plugins->front().SocModels().size(), 1);
+  EXPECT_EQ(plugins->front().SocModels().front(), kTestModels);
+}
+
+TEST(CompilerPluginTest, LoadTestPluginWithMalformed) {
+  const auto dir = UniqueTestDirectory();
+  Touch(Join({dir, "notLibLiteRt.so"}));
+
+  auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath});
+
+  ASSERT_EQ(plugins->size(), 1);
+  EXPECT_EQ(plugins->front().SocManufacturer(), kTestManufacturer);
+}
+
+TEST(CompilerPluginTest, MultipleValidPlugins) {
+  auto plugins = CompilerPlugin::LoadPlugins(
+      {kTestPluginSearchPath, kTestPluginSearchPath});
+
+  ASSERT_EQ(plugins->size(), 2);
+  EXPECT_EQ(plugins->front().SocManufacturer(), kTestManufacturer);
+  EXPECT_EQ(plugins->back().SocManufacturer(), kTestManufacturer);
+}
+
+TEST(CompilerPluginTest, MoveAssign) {
+  auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath});
+
+  ASSERT_EQ(plugins->size(), 1);
+  EXPECT_EQ(plugins->front().SocManufacturer(), kTestManufacturer);
+
+  CompilerPlugin other = std::move(plugins->front());
+
+  EXPECT_EQ(other.SocManufacturer(), kTestManufacturer);
+}
+
+TEST(CompilerPluginTest, MoveConstruct) {
+  auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath});
+
+  ASSERT_EQ(plugins->size(), 1);
+  EXPECT_EQ(plugins->front().SocManufacturer(), kTestManufacturer);
+
+  CompilerPlugin other(std::move(plugins->front()));
+
+  EXPECT_EQ(other.SocManufacturer(), kTestManufacturer);
+}
+
+TEST(CompilerPluginTest, SocModels) {
+  auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath});
+  ASSERT_EQ(plugins->size(), 1);
+  EXPECT_EQ(plugins->front().SocManufacturer(), kTestManufacturer);
+
+  EXPECT_THAT(plugins->front().SocModels(),
+              ::testing::ElementsAreArray({kTestModels}));
+}
+
+TEST(CompilerPluginTest, Partition) {
+  auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath});
+  ASSERT_EQ(plugins->size(), 1);
+  EXPECT_EQ(plugins->front().SocManufacturer(), kTestManufacturer);
+
+  auto model = testing::LoadTestFileModel("mul_simple.tflite");
+  auto subgraph = model.MainSubgraph();
+  auto ops = plugins->front().Partition(*subgraph);
+  ASSERT_TRUE(ops);
+
+  EXPECT_EQ(ops->size(), 2);
+}
+
+TEST(CompilerPluginTest, CompileModel) {
+  auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath});
+  ASSERT_EQ(plugins->size(), 1);
+  EXPECT_EQ(plugins->front().SocManufacturer(), kTestManufacturer);
+
+  auto model = testing::LoadTestFileModel("mul_simple.tflite");
+  auto subgraph = model.MainSubgraph();
+
+  std::ostringstream byte_code_out;
+  std::vector<std::string> call_info_out;
+  LITERT_ASSERT_STATUS_OK(plugins->front().Compile(
+      kTestModels, {subgraph->Get()}, byte_code_out, call_info_out));
+
+  EXPECT_GT(byte_code_out.str().size(), 0);
+  EXPECT_EQ(call_info_out.size(), 1);
+}
+
+TEST(CompilerPluginTest, Dump) {
+  auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath});
+  ASSERT_EQ(plugins->size(), 1);
+
+  std::stringstream dump;
+  Dump(plugins->front(), dump);
+
+  ASSERT_EQ(dump.view(),
+            "SocManufacturer: ExampleSocManufacturer\nSocModels: { "
+            "ExampleSocModel }\n");
+}
+
+TEST(ApplyPluginTest, ApplyPlugin) {
+  auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath});
+  ASSERT_EQ(plugins->size(), 1);
+  auto model = testing::LoadTestFileModel("mul_simple.tflite");
+  ASSERT_TRUE(model);
+
+  auto npu_code = ApplyPlugin(plugins->front(), model);
+  ASSERT_TRUE(npu_code);
+  EXPECT_GT(npu_code->Size(), 0);
+
+  auto ops = model.MainSubgraph()->Ops();
+  ASSERT_EQ(ops.size(), 1);
+  EXPECT_EQ(ops.front().Code(), kLiteRtOpCodeTflCustom);
+  EXPECT_EQ(ops.front().Get()->CustomOptions().StrView(), "Partition_0");
+}
+
+}  // namespace
+}  // namespace litert::internal
diff --git a/tflite/experimental/litert/core/BUILD b/tflite/experimental/litert/core/BUILD
new file mode 100644
index 00000000..01b8e865
--- /dev/null
+++ b/tflite/experimental/litert/core/BUILD
@@ -0,0 +1,141 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
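+
+# Internal LiteRT core utilities: byte code packing, dynamic loading of
+# compiler plugins, the environment singleton, and filesystem helpers.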
+
+package(
+    # copybara:uncomment default_applicable_licenses = ["@org_tensorflow//tensorflow:license"],
+    default_visibility = ["//tflite/experimental/litert:__subpackages__"],
+)
+
+cc_library(
+    name = "byte_code_util",
+    srcs = ["byte_code_util.cc"],
+    hdrs = ["byte_code_util.h"],
+    deps = [
+        "//tflite/experimental/litert/c:litert_common",
+        "//tflite/experimental/litert/c:litert_logging",
+        "//tflite/experimental/litert/cc:litert_buffer_ref",
+        "//tflite/experimental/litert/cc:litert_expected",
+        "//tflite/experimental/litert/cc:litert_macros",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/strings:string_view",
+    ],
+)
+
+cc_test(
+    name = "byte_code_util_test",
+    srcs = ["byte_code_util_test.cc"],
+    data = [
+        "//tflite/experimental/litert/test:tflite_test_data",
+    ],
+    deps = [
+        ":byte_code_util",
+        "//tflite/experimental/litert/c:litert_common",
+        "//tflite/experimental/litert/cc:litert_buffer_ref",
+        "//tflite/experimental/litert/cc:litert_expected",
+        "//tflite/experimental/litert/test:common",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+cc_library(
+    name = "dynamic_loading",
+    srcs = ["dynamic_loading.cc"],
+    hdrs = ["dynamic_loading.h"],
+    linkopts = ["-ldl"],
+    deps = [
+        "//tflite/experimental/litert/c:litert_common",
+        "//tflite/experimental/litert/c:litert_logging",  # buildcleaner: keep
+        "//tflite/experimental/litert/cc:litert_macros",
+        "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/strings:string_view",
+    ],
+)
+
+cc_library(
+    name = "environment",
+    srcs = ["environment.cc"],
+    hdrs = [
+        "environment.h",
+        "//tflite/experimental/litert/c:litert_environment.h",
+    ],
+    deps = [
+        "//tflite/experimental/litert/c:litert_any",
+        "//tflite/experimental/litert/c:litert_common",
+        "//tflite/experimental/litert/c:litert_logging",
+        "//tflite/experimental/litert/cc:litert_any",
+        "//tflite/experimental/litert/cc:litert_expected",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+cc_test(
+    name = "environment_test",
+    srcs = ["environment_test.cc"],
+    deps = [
+        ":environment",
+        "//tflite/experimental/litert/c:litert_any",
+        "//tflite/experimental/litert/cc:litert_any",
+        "@com_google_absl//absl/types:span",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+cc_library(
+    name = "filesystem",
+    srcs = ["filesystem.cc"],
+    hdrs = ["filesystem.h"],
+    deps = [
+        "//tflite/experimental/litert/c:litert_common",
+        "//tflite/experimental/litert/cc:litert_buffer_ref",
+        "//tflite/experimental/litert/cc:litert_detail",
+        "//tflite/experimental/litert/cc:litert_expected",
+        "//tflite/experimental/litert/cc:litert_macros",
+        "@com_google_absl//absl/strings:string_view",
+    ],
+)
+
+cc_test(
+    name = "filesystem_test",
+    srcs = ["filesystem_test.cc"],
+    deps = [
+        ":filesystem",
+        "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+# copybara:uncomment_begin(no OSS for unique-test-directory)
+# cc_test(
+#     name = "dynamic_loading_test",
+#     srcs = ["dynamic_loading_test.cc"],
+#     tags = [
+#         # Sanitizer runtimes are incompatible with RTLD_DEEPBIND.
+# "noasan", +# "nomsan", +# "nosan", +# ], +# deps = [ +# ":dynamic_loading", +# ":filesystem", +# "@com_google_googletest//:gtest_main", +# "//testing/base/public:unique-test-directory", +# "@com_google_absl//absl/strings:string_view", +# "//tflite/experimental/litert/c:litert_logging", # buildcleaner: keep +# "//tflite/experimental/litert/test:common", +# ], +# ) +# copybara:uncomment_end diff --git a/tflite/experimental/litert/core/byte_code_util.cc b/tflite/experimental/litert/core/byte_code_util.cc new file mode 100644 index 00000000..71ad480c --- /dev/null +++ b/tflite/experimental/litert/core/byte_code_util.cc @@ -0,0 +1,170 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tflite/experimental/litert/core/byte_code_util.h" + +#include +#include +#include +#include + +#include "absl/strings/numbers.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_logging.h" +#include "tflite/experimental/litert/cc/litert_buffer_ref.h" +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/experimental/litert/cc/litert_macros.h" + +namespace litert::internal { + +namespace { +// Simple metadata added to the flatbuffer related to compiler plugin. +struct BuildStamp { + char soc_manufacturer[kSocManufacturerMaxLen + 1] = {}; + char soc_model[kSocModelMaxLen + 1] = {}; + Serialization serialization = kUnknown; +}; + +// Structure of serialized byte code placeholder. +struct ByteCodePlaceholder { + char offset_str[kByteCodeOffsetStrMaxLen + 1] = {}; + char size_str[kByteCodeSizeStrMaxLen + 1] = {}; +}; + +// Structure of serialized per-custom op data. +struct ExecInfo { + char entrypoint_name[kEntryPointNameMaxLen + 1] = {}; + char metadata_key[kMetadataKeyMaxLen + 1] = {}; +}; + +static constexpr size_t kByteCodePlaceholderBufSize = + sizeof(ByteCodePlaceholder) + kByteCodePrefix.size(); +} // namespace + +Expected> MakeBuildStamp( + absl::string_view soc_manufacturer, absl::string_view soc_model, + Serialization serialization) { + if (soc_manufacturer.size() >= kSocManufacturerMaxLen || + soc_model.size() >= kSocModelMaxLen) { + LITERT_LOG(LITERT_ERROR, "%s", "Soc Make/Model strings too large\n"); + return Unexpected(kLiteRtStatusErrorInvalidArgument); + } + BuildStamp stamp; + soc_manufacturer.copy(stamp.soc_manufacturer, soc_manufacturer.size()); + soc_model.copy(stamp.soc_model, soc_model.size()); + stamp.serialization = serialization; + return OwningBufferRef(reinterpret_cast(&stamp), + sizeof(stamp)); +} + +// Parse a serialized build stamp from the given buf. 
+Expected<std::tuple<absl::string_view, absl::string_view, Serialization>>
+ParseBuildStamp(BufferRef<uint8_t> buf) {
+  if (buf.Size() != sizeof(BuildStamp)) {
+    LITERT_LOG(LITERT_ERROR, "%s", "Build stamp size mismatch\n");
+    return Unexpected(kLiteRtStatusErrorInvalidArgument);
+  }
+  const BuildStamp* stamp = reinterpret_cast<const BuildStamp*>(buf.Data());
+  return std::make_tuple(absl::string_view(stamp->soc_manufacturer),
+                         absl::string_view(stamp->soc_model),
+                         stamp->serialization);
+}
+
+OwningBufferRef<uint8_t> MakeByteCodePlaceholder() {
+  OwningBufferRef<uint8_t> buf(kByteCodePlaceholderBufSize);
+  buf.WriteInto(kByteCodePrefix);
+  ByteCodePlaceholder* placeholder = reinterpret_cast<ByteCodePlaceholder*>(
+      buf.Data() + kByteCodePrefix.size());
+  *placeholder = ByteCodePlaceholder();
+  return buf;
+}
+
+Expected<std::pair<size_t, size_t>> ParseByteCodePlaceholder(
+    BufferRef<uint8_t> buf) {
+  if (buf.Size() != kByteCodePlaceholderBufSize ||
+      buf.StrView().compare(0, kByteCodePrefix.size(), kByteCodePrefix) != 0) {
+    LITERT_LOG(LITERT_ERROR, "%s", "Byte code placeholder size mismatch\n");
+    return Unexpected(kLiteRtStatusErrorInvalidArgument);
+  }
+
+  const ByteCodePlaceholder* placeholder =
+      reinterpret_cast<const ByteCodePlaceholder*>(buf.Data() +
+                                                   kByteCodePrefix.size());
+  const absl::string_view offset_str(placeholder->offset_str);
+  const absl::string_view size_str(placeholder->size_str);
+
+  size_t offset, size;
+  if (!absl::SimpleAtoi(offset_str, &offset) ||
+      !absl::SimpleAtoi(size_str, &size)) {
+    LITERT_LOG(LITERT_ERROR, "%s",
+               "Byte code placeholder offset/size invalid\n");
+    return Unexpected(kLiteRtStatusErrorInvalidArgument);
+  }
+
+  return std::make_pair(offset, size);
+}
+
+LiteRtStatus FinishByteCodePlaceholders(
+    MutableBufferRef<uint8_t> serialized_model, size_t byte_code_size) {
+  const size_t placeholder_start =
+      serialized_model.StrView().rfind(kByteCodePrefix);
+  LITERT_ENSURE(placeholder_start != absl::string_view::npos,
+                kLiteRtStatusErrorInvalidArgument,
+                "Cannot find any bytecode placeholders in the model");
+
+  ByteCodePlaceholder* placeholder = reinterpret_cast<ByteCodePlaceholder*>(
+      serialized_model.Data() + kByteCodePrefix.size() + placeholder_start);
+
+  const int offset_written =
+      absl::SNPrintF(placeholder->offset_str, kByteCodeOffsetStrMaxLen, "%lu",
+                     serialized_model.Size());
+  LITERT_ENSURE(
+      offset_written > -1 && offset_written <= kByteCodeOffsetStrMaxLen,
+      kLiteRtStatusErrorInvalidArgument, "Offset too large");
+
+  const int size_written = absl::SNPrintF(
+      placeholder->size_str, kByteCodeSizeStrMaxLen, "%lu", byte_code_size);
+  LITERT_ENSURE(size_written > -1 && size_written <= kByteCodeSizeStrMaxLen,
+                kLiteRtStatusErrorInvalidArgument, "Size too large");
+  return kLiteRtStatusOk;
+}
+
+Expected<std::pair<absl::string_view, absl::string_view>> ParseExecInfo(
+    BufferRef<uint8_t> buf) {
+  if (buf.Size() != sizeof(ExecInfo)) {
+    LITERT_LOG(LITERT_ERROR, "%s", "Exec info size mismatch\n");
+    return Unexpected(kLiteRtStatusErrorInvalidArgument);
+  }
+  const ExecInfo* exec_info = reinterpret_cast<const ExecInfo*>(buf.Data());
+  return std::make_pair(absl::string_view(exec_info->entrypoint_name),
+                        absl::string_view(exec_info->metadata_key));
+}
+
+Expected<OwningBufferRef<uint8_t>> MakeExecInfo(
+    absl::string_view entrypoint_name, absl::string_view metadata_key) {
+  if (entrypoint_name.size() >= kEntryPointNameMaxLen ||
+      metadata_key.size() >= kMetadataKeyMaxLen) {
+    LITERT_LOG(LITERT_ERROR, "%s", "Exec info strings too large\n");
+    return Unexpected(kLiteRtStatusErrorInvalidArgument);
+  }
+  ExecInfo exec_info;
+  entrypoint_name.copy(exec_info.entrypoint_name, entrypoint_name.size());
+  metadata_key.copy(exec_info.metadata_key, metadata_key.size());
+  return OwningBufferRef<uint8_t>(reinterpret_cast<const uint8_t*>(&exec_info),
+                                  sizeof(exec_info));
+}
+
+}  // namespace litert::internal
diff --git a/tflite/experimental/litert/core/byte_code_util.h b/tflite/experimental/litert/core/byte_code_util.h
new file mode 100644
index 00000000..db6568a3
--- /dev/null
+++ b/tflite/experimental/litert/core/byte_code_util.h
@@ -0,0 +1,118 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_BYTE_CODE_UTIL_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_BYTE_CODE_UTIL_H_
+
+#include <stddef.h>
+
+#include <cstdint>
+#include <tuple>
+#include <utility>
+
+#include "absl/strings/string_view.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/cc/litert_buffer_ref.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/cc/litert_macros.h"
+
+namespace litert::internal {
+
+// Shared "custom_code" for all dispatch ops.
+static constexpr absl::string_view kLiteRtDispatchOpCustomCode = "DISPATCH_OP";
+
+//
+// Build Stamp
+//
+
+// Maximum size of string for soc_manufacturer.
+static constexpr size_t kSocManufacturerMaxLen = 124;
+
+// Maximum size of string for soc_model.
+static constexpr size_t kSocModelMaxLen = 124;
+
+// The method used for packing byte code with the flatbuffer.
+enum Serialization : uint8_t {
+  kUnknown = 0,
+  // Byte code is appended to the back of the .tflite.
+  kAppend = 1,
+  // Byte code is stored in a metadata buffer [FOR TESTING ONLY].
+  kMetadata = 2
+};
+
+// Metadata key to look up the build stamp.
+static constexpr absl::string_view kLiteRtBuildStampKey = "LiteRtStamp";
+
+// Make a serialized build stamp that can go directly in the flatbuffer.
+Expected<OwningBufferRef<uint8_t>> MakeBuildStamp(
+    absl::string_view soc_manufacturer, absl::string_view soc_model,
+    Serialization serialization);
+
+// Parse a serialized build stamp from the given buf.
+Expected<std::tuple<absl::string_view, absl::string_view, Serialization>>
+ParseBuildStamp(BufferRef<uint8_t> buf);
+
+//
+// METADATA
+//
+
+// Metadata key for looking up byte code that is directly packed.
+static constexpr absl::string_view kByteCodeMetadataKey = "NPU_BYTE_CODE";
+
+//
+// APPEND: Placeholder for bytecode offset and size.
+//
+
+// Maximum number of base-10 digits in the byte code size.
+static constexpr size_t kByteCodeSizeStrMaxLen = 10;
+
+// Maximum number of base-10 digits in the byte code offset.
+static constexpr size_t kByteCodeOffsetStrMaxLen = 10;
+
+// Prefix before serialized [offset, size, function name].
+static constexpr absl::string_view kByteCodePrefix = "<npu_byte_code>";
+
+// Get a new serialized byte code placeholder buffer with prefix.
+OwningBufferRef<uint8_t> MakeByteCodePlaceholder();
+
+// Parse byte code offset and size serialized as a ByteCodePlaceholder in buf.
+Expected<std::pair<size_t, size_t>> ParseByteCodePlaceholder(
+    BufferRef<uint8_t> buf);
+
+// Replace all byte code placeholders with actual values. This happens directly
+// on a serialized model without changing its size.
+LiteRtStatus FinishByteCodePlaceholders(
+    MutableBufferRef<uint8_t> serialized_model, size_t byte_code_size);
+
+//
+// APPEND: ExecInfo for per-custom op info.
+//
+
+// Maximum length of string for the entry point name.
+static constexpr size_t kEntryPointNameMaxLen = 124;
+
+// Maximum length of a metadata key stored per custom op.
+static constexpr size_t kMetadataKeyMaxLen = 124;
+
+// Make a serialized exec info from the given values.
+Expected<OwningBufferRef<uint8_t>> MakeExecInfo(
+    absl::string_view entrypoint_name, absl::string_view metadata_key);
+
+// Parse serialized exec info from buffer.
+Expected<std::pair<absl::string_view, absl::string_view>> ParseExecInfo(
+    BufferRef<uint8_t> buf);
+
+}  // namespace litert::internal
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_BYTE_CODE_UTIL_H_
diff --git a/tflite/experimental/litert/core/byte_code_util_test.cc b/tflite/experimental/litert/core/byte_code_util_test.cc
new file mode 100644
index 00000000..c277d766
--- /dev/null
+++ b/tflite/experimental/litert/core/byte_code_util_test.cc
@@ -0,0 +1,109 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/core/byte_code_util.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <limits>
+#include <string>
+#include <tuple>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/cc/litert_buffer_ref.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/test/common.h"
+
+namespace litert::internal {
+
+namespace {
+
+using ::testing::StartsWith;
+
+static constexpr absl::string_view kSocModel = "TestSocModel";
+static constexpr absl::string_view kSocMan = "TestSocMan";
+static constexpr Serialization kSerialization = Serialization::kAppend;
+
+TEST(TestBuildStamp, MakeBuildStampInputsTooLarge) {
+  // NOLINTNEXTLINE
+  std::string long_manufacturer(256, 'a');
+  auto res = MakeBuildStamp(long_manufacturer, kSocModel, kSerialization);
+  EXPECT_EQ(res.Error().Status(), kLiteRtStatusErrorInvalidArgument);
+}
+
+TEST(TestBuildStamp, MakeBuildStamp) {
+  auto stamp = MakeBuildStamp(kSocMan, kSocModel, kSerialization);
+  auto pstamp = ParseBuildStamp(*stamp);
+  auto [man, model, serial] = *pstamp;
+  EXPECT_EQ(man, kSocMan);
+  EXPECT_EQ(model, kSocModel);
+  EXPECT_EQ(serial, kSerialization);
+}
+
+TEST(TestByteCodePlaceholder, ParseBadPlaceholder) {
+  OwningBufferRef<uint8_t> placeholder;
+  auto res = ParseByteCodePlaceholder(placeholder);
+  EXPECT_EQ(res.Error().Status(), kLiteRtStatusErrorInvalidArgument);
+}
+
+TEST(TestByteCodePlaceholder, BuildAndParseEmptyInvalid) {
+  auto placeholder = MakeByteCodePlaceholder();
+  ASSERT_THAT(placeholder.StrView(), StartsWith(kByteCodePrefix));
+  auto res = ParseByteCodePlaceholder(placeholder);
+  EXPECT_EQ(res.Error().Status(), kLiteRtStatusErrorInvalidArgument);
+}
+
+TEST(TestByteCodePlaceholder, BuildAndFinishByteCodePlaceholder) {
+  auto placeholder = MakeByteCodePlaceholder();
+
+  static constexpr size_t kByteCodeSize = 200;
+  LITERT_ASSERT_STATUS_OK(
+      FinishByteCodePlaceholders(placeholder, kByteCodeSize));
+
+  auto p_placeholder = ParseByteCodePlaceholder(placeholder);
+  auto [offset, size] = *p_placeholder;
+  EXPECT_EQ(offset, placeholder.Size());
+  EXPECT_EQ(size, kByteCodeSize);
+}
+
+TEST(TestByteCodePlaceholder, BuildAndFinishByteCodePlaceholderTooLarge) {
+  auto placeholder = MakeByteCodePlaceholder();
+
+  static constexpr size_t kByteCodeSize = std::numeric_limits<size_t>::max();
+  LITERT_ASSERT_STATUS_HAS_CODE(
+      FinishByteCodePlaceholders(placeholder, kByteCodeSize),
+      kLiteRtStatusErrorInvalidArgument);
+}
+
+TEST(TestExecInfo, ExecInfo) {
+  auto exec_info = MakeExecInfo("entry_point", "key");
+  auto p_exec_info = ParseExecInfo(*exec_info);
+  auto [entry_point, key] = *p_exec_info;
+  EXPECT_EQ(entry_point, "entry_point");
+  EXPECT_EQ(key, "key");
+}
+
+TEST(TestExecInfo, ExecInfoTooLarge) {
+  // NOLINTNEXTLINE
+  std::string long_entry_point(256, 'a');
+  auto res = MakeExecInfo(long_entry_point, "key");
+  EXPECT_EQ(res.Error().Status(), kLiteRtStatusErrorInvalidArgument);
+}
+
+}  // namespace
+
+}  // namespace litert::internal
diff --git a/tflite/experimental/litert/core/dynamic_loading.cc b/tflite/experimental/litert/core/dynamic_loading.cc
new file mode 100644
index 00000000..d2b4c05c
--- /dev/null
+++ b/tflite/experimental/litert/core/dynamic_loading.cc
@@ -0,0 +1,98 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "tflite/experimental/litert/core/dynamic_loading.h" + +#include + +#ifndef __ANDROID__ +#if __has_include() +#include +#endif +#endif + +#include +#include // NOLINT +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_logging.h" +#include "tflite/experimental/litert/cc/litert_macros.h" + +namespace litert::internal { + +LiteRtStatus OpenLib(absl::string_view so_path, void** lib_handle) { +#ifdef RTLD_DEEPBIND + void* res = ::dlopen(so_path.data(), RTLD_NOW | RTLD_LOCAL | RTLD_DEEPBIND); +#else + void* res = ::dlopen(so_path.data(), RTLD_NOW | RTLD_LOCAL); +#endif + + if (res == nullptr) { + LITERT_LOG(LITERT_ERROR, "Failed to load .so at path: %s\n", + so_path.data()); + LogDlError(); + + return kLiteRtStatusErrorDynamicLoading; + } + *lib_handle = res; + return kLiteRtStatusOk; +} + +LiteRtStatus CloseLib(void* lib_handle) { + if (0 != ::dlclose(lib_handle)) { + LITERT_LOG(LITERT_ERROR, "Failed to close .so with error: %s", ::dlerror()); + return kLiteRtStatusErrorDynamicLoading; + } + return kLiteRtStatusOk; +} + +namespace { + +LiteRtStatus FindLiteRtSharedLibsHelper(const std::string& search_path, + std::vector& results) { + if (!std::filesystem::exists(search_path)) { + return kLiteRtStatusErrorInvalidArgument; + } + + const std::string compiler_plugin_lib_pattern = + absl::StrFormat("%s%s", kLiteRtSharedLibPrefix, "CompilerPlugin"); + for (const auto& entry : std::filesystem::directory_iterator(search_path)) { + const auto& path = entry.path(); + if (entry.is_regular_file()) { + auto stem = path.stem().string(); + auto ext = path.extension().string(); + if (stem.find(compiler_plugin_lib_pattern) == 0 && ext == ".so") { + results.push_back(path); + } + } else if (entry.is_directory()) { + FindLiteRtSharedLibsHelper(path, results); + } + } + + return kLiteRtStatusOk; +} + +} // namespace + +LiteRtStatus FindLiteRtSharedLibs(absl::string_view search_path, + std::vector& results) { + std::string root(search_path.data()); + return FindLiteRtSharedLibsHelper(root, results); +} + +} // namespace litert::internal diff --git a/tflite/experimental/litert/core/dynamic_loading.h b/tflite/experimental/litert/core/dynamic_loading.h new file mode 100644 index 00000000..6138ea47 --- /dev/null +++ b/tflite/experimental/litert/core/dynamic_loading.h @@ -0,0 +1,71 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_DYNAMIC_LOADING_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_DYNAMIC_LOADING_H_ + +#include +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_logging.h" + +namespace litert::internal { + +constexpr absl::string_view kLiteRtSharedLibPrefix = "libLiteRt"; + +// Check for null and print the last dlerror. 
+inline void LogDlError() {
+  char* err = ::dlerror();
+  if (err == nullptr) {
+    return;
+  }
+  LITERT_LOG(LITERT_WARNING, "::dlerror() : %s", err);
+}
+
+// Loads shared library at given path.
+LiteRtStatus OpenLib(absl::string_view so_path, void** lib_handle);
+
+// Closes reference to loaded shared library held by lib_handle.
+LiteRtStatus CloseLib(void* lib_handle);
+
+// Resolves a named symbol of type Sym from the given lib handle.
+template <typename Sym>
+inline static LiteRtStatus ResolveLibSymbol(void* lib_handle,
+                                            absl::string_view sym_name,
+                                            Sym* sym_handle) {
+  Sym ptr = (Sym)::dlsym(lib_handle, sym_name.data());
+  if (ptr == nullptr) {
+    LITERT_LOG(LITERT_ERROR, "Failed to resolve symbol: %s\n",
+               sym_name.data());
+    LogDlError();
+    return kLiteRtStatusErrorDynamicLoading;
+  }
+  *sym_handle = ptr;
+  return kLiteRtStatusOk;
+}
+
+// Find all LiteRT shared libraries in "search_path" and return
+// kLiteRtStatusErrorInvalidArgument if the provided search_path doesn't
+// exist. All internal dynamically linked dependencies for LiteRT should be
+// prefixed with "libLiteRtCompilerPlugin".
+LiteRtStatus FindLiteRtSharedLibs(absl::string_view search_path,
+                                  std::vector<std::string>& results);
+
+}  // namespace litert::internal
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_DYNAMIC_LOADING_H_
diff --git a/tflite/experimental/litert/core/dynamic_loading_test.cc b/tflite/experimental/litert/core/dynamic_loading_test.cc
new file mode 100644
index 00000000..bd888aaf
--- /dev/null
+++ b/tflite/experimental/litert/core/dynamic_loading_test.cc
@@ -0,0 +1,72 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "tflite/experimental/litert/core/dynamic_loading.h" + +#include +#include + +#include +#include +#include "testing/base/public/unique-test-directory.h" +#include "absl/strings/string_view.h" +#include "tflite/experimental/litert/core/filesystem.h" +#include "tflite/experimental/litert/test/common.h" + +namespace litert::internal { +namespace { + +using ::testing::Contains; +using ::testing::HasSubstr; +using ::testing::UniqueTestDirectory; + +constexpr absl::string_view kNotLiteRtSo = "notLibLiteRt.so"; +constexpr absl::string_view kLiteRtSo1 = "libLiteRtCompilerPlugin_1.so"; +constexpr absl::string_view kLiteRtSo2 = "libLiteRtCompilerPlugin_2.so"; + +TEST(TestDynamicLoading, GlobNoMatch) { + const auto dir = UniqueTestDirectory(); + Touch(Join({dir, kNotLiteRtSo})); + + std::vector results; + LITERT_ASSERT_STATUS_OK(litert::internal::FindLiteRtSharedLibs(dir, results)); + EXPECT_EQ(results.size(), 0); +} + +TEST(TestDynamicLoading, GlobOneMatch) { + const auto dir = UniqueTestDirectory(); + Touch(Join({dir, kLiteRtSo1})); + Touch(Join({dir, kNotLiteRtSo})); + + std::vector results; + LITERT_ASSERT_STATUS_OK(litert::internal::FindLiteRtSharedLibs(dir, results)); + ASSERT_EQ(results.size(), 1); + EXPECT_TRUE(absl::string_view(results.front()).ends_with(kLiteRtSo1)); +} + +TEST(TestDynamicLoading, GlobMultiMatch) { + const auto dir = UniqueTestDirectory(); + Touch(Join({dir, kLiteRtSo1})); + Touch(Join({dir, kLiteRtSo2})); + Touch(Join({dir, kNotLiteRtSo})); + + std::vector results; + LITERT_ASSERT_STATUS_OK(litert::internal::FindLiteRtSharedLibs(dir, results)); + ASSERT_EQ(results.size(), 2); + EXPECT_THAT(results, Contains(HasSubstr(kLiteRtSo1))); + EXPECT_THAT(results, Contains(HasSubstr(kLiteRtSo2))); +} + +} // namespace +} // namespace litert::internal diff --git a/tflite/experimental/litert/core/environment.cc b/tflite/experimental/litert/core/environment.cc new file mode 100644 index 00000000..b796befa --- /dev/null +++ b/tflite/experimental/litert/core/environment.cc @@ -0,0 +1,55 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "tflite/experimental/litert/core/environment.h" + +#include "absl/types/span.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_environment.h" +#include "tflite/experimental/litert/c/litert_logging.h" +#include "tflite/experimental/litert/cc/litert_expected.h" + +namespace litert::internal { + +Environment* Environment::the_instance_ = nullptr; + +Expected Environment::CreateWithOptions( + absl::Span options) { + if (the_instance_) { + return Error(kLiteRtStatusErrorRuntimeFailure, + "LiteRT environment cannot be created with options, it has " + "already been created"); + } + LITERT_LOG(LITERT_INFO, "Creating LiteRT environment with options"); + the_instance_ = new Environment(); + for (auto& option : options) { + the_instance_->options_[option.tag] = option.value; + } + return {}; +} + +void Environment::Destroy() { + delete the_instance_; + the_instance_ = nullptr; +} + +Expected Environment::Instance() { + if (!the_instance_) { + LITERT_LOG(LITERT_INFO, "Creating LiteRT environment with no options"); + the_instance_ = new Environment(); + } + return the_instance_; +} + +} // namespace litert::internal diff --git a/tflite/experimental/litert/core/environment.h b/tflite/experimental/litert/core/environment.h new file mode 100644 index 00000000..32d5a889 --- /dev/null +++ b/tflite/experimental/litert/core/environment.h @@ -0,0 +1,62 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_ENVIRONMENT_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_ENVIRONMENT_H_ + +#include +#include +#include + +#include "absl/types/span.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_environment.h" +#include "tflite/experimental/litert/cc/litert_expected.h" + +namespace litert::internal { + +// A singleton class that contains global LiteRT environment options. +class Environment { + public: + // Create the singleton environment instance with options. Returns an error if + // the instance already exists, in which case the specified options have no + // effect. + static Expected CreateWithOptions( + absl::Span options); + + // Return the envirnment instance and, if not yet created, creates one with no + // options. + static Expected Instance(); + + // Destroy the environment instance. 
+  static void Destroy();
+
+  std::optional<LiteRtAny> GetOption(LiteRtEnvOptionTag tag) const {
+    auto i = options_.find(tag);
+    if (i != options_.end()) {
+      return i->second;
+    } else {
+      return std::nullopt;
+    }
+  }
+
+ private:
+  std::map<LiteRtEnvOptionTag, LiteRtAny> options_;
+
+  static Environment* the_instance_;
+};
+
+}  // namespace litert::internal
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_ENVIRONMENT_H_
diff --git a/tflite/experimental/litert/core/environment_test.cc b/tflite/experimental/litert/core/environment_test.cc
new file mode 100644
index 00000000..9e8ac54c
--- /dev/null
+++ b/tflite/experimental/litert/core/environment_test.cc
@@ -0,0 +1,70 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/core/environment.h"
+
+#include <any>
+#include <array>
+
+#include <gtest/gtest.h>
+#include "tflite/experimental/litert/c/litert_any.h"
+#include "tflite/experimental/litert/c/litert_environment.h"
+#include "tflite/experimental/litert/cc/litert_any.h"
+
+namespace litert::internal {
+namespace {
+
+TEST(Environment, CreateWithNoOption) {
+  ASSERT_TRUE(Environment::Instance());
+  Environment::Destroy();
+}
+
+TEST(Environment, CreateWithOptions) {
+  const std::array environment_options = {
+      LiteRtEnvOption{
+          kLiteRtEnvOptionTagCompilerPluginLibraryPath,
+          *ToLiteRtAny(std::any("sample path")),
+      },
+  };
+  ASSERT_TRUE(Environment::CreateWithOptions(environment_options));
+
+  auto env = Environment::Instance();
+  ASSERT_TRUE(env);
+
+  auto option = (*env)->GetOption(kLiteRtEnvOptionTagCompilerPluginLibraryPath);
+  ASSERT_TRUE(option.has_value());
+  ASSERT_EQ(option->type, kLiteRtAnyTypeString);
+  ASSERT_STREQ(option->str_value, "sample path");
+
+  Environment::Destroy();
+}
+
+TEST(Environment, CreateWithOptionsFailure) {
+  // This will create an environment without options.
+  auto env = Environment::Instance();
+  ASSERT_TRUE(env);
+
+  const std::array environment_options = {
+      LiteRtEnvOption{
+          kLiteRtEnvOptionTagCompilerPluginLibraryPath,
+          *ToLiteRtAny(std::any("sample path")),
+      },
+  };
+  ASSERT_FALSE(Environment::CreateWithOptions(environment_options));
+
+  Environment::Destroy();
+}
+
+}  // namespace
+}  // namespace litert::internal
diff --git a/tflite/experimental/litert/core/filesystem.cc b/tflite/experimental/litert/core/filesystem.cc
new file mode 100644
index 00000000..ab8ec475
--- /dev/null
+++ b/tflite/experimental/litert/core/filesystem.cc
@@ -0,0 +1,97 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "tflite/experimental/litert/core/filesystem.h" + +#include +#include +#include // NOLINT +#include + +#include "absl/strings/string_view.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/cc/litert_buffer_ref.h" +#include "tflite/experimental/litert/cc/litert_detail.h" +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/experimental/litert/cc/litert_macros.h" + +namespace litert::internal { + +namespace { + +using StdPath = std::filesystem::path; + +StdPath MakeStdPath(absl::string_view path) { + return StdPath(std::string(path.begin(), path.end())); +} + +bool StdExists(const StdPath& std_path) { + return std::filesystem::exists(std_path); +} + +size_t StdSize(const StdPath& std_path) { + return std::filesystem::file_size(std_path); +} + +LiteRtStatus StdIFRead(const StdPath& std_path, char* data, size_t size) { + std::ifstream in_file_stream(std_path, std::ifstream::binary); + if (!in_file_stream) { + return kLiteRtStatusErrorFileIO; + } + + in_file_stream.read(data, size); + if (!in_file_stream) { + return kLiteRtStatusErrorFileIO; + } + + in_file_stream.close(); + return kLiteRtStatusOk; +} + +} // namespace + +void Touch(absl::string_view path) { std::ofstream(MakeStdPath(path)); } + +std::string Join(const SmallVec& paths) { + StdPath std_path; + for (auto subpath : paths) { + std_path /= MakeStdPath(subpath); + } + return std_path.generic_string(); +} + +bool Exists(absl::string_view path) { return StdExists(MakeStdPath(path)); } + +Expected Size(absl::string_view path) { + auto std_path = MakeStdPath(path); + if (!StdExists(std_path)) { + return Error(kLiteRtStatusErrorNotFound); + } + return StdSize(std_path); +} + +Expected> LoadBinaryFile(absl::string_view path) { + auto std_path = MakeStdPath(path); + + if (!StdExists(std_path)) { + return Error(kLiteRtStatusErrorFileIO); + } + + OwningBufferRef buf(StdSize(std_path)); + LITERT_EXPECT_OK(StdIFRead(std_path, buf.StrData(), buf.Size())); + + return buf; +} + +} // namespace litert::internal diff --git a/tflite/experimental/litert/core/filesystem.h b/tflite/experimental/litert/core/filesystem.h new file mode 100644 index 00000000..abc6334e --- /dev/null +++ b/tflite/experimental/litert/core/filesystem.h @@ -0,0 +1,48 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_FILESYSTEM_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_FILESYSTEM_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "tflite/experimental/litert/cc/litert_buffer_ref.h" +#include "tflite/experimental/litert/cc/litert_detail.h" +#include "tflite/experimental/litert/cc/litert_expected.h" + +// Generic file operations. Try to encapsulate the std filesystem header as much +// as possible because its technically unapproved. + +namespace litert::internal { + +// Append all given subpaths together (e.g. os.path.join). 
+std::string Join(const SmallVec<absl::string_view>& paths);
+
+// Make a new empty file at the given path.
+void Touch(absl::string_view path);
+
+// Does this file exist?
+bool Exists(absl::string_view path);
+
+// Get the size of the file.
+Expected<size_t> Size(absl::string_view path);
+
+// Load the bytes of the file at the given path.
+Expected<OwningBufferRef<uint8_t>> LoadBinaryFile(absl::string_view path);
+
+}  // namespace litert::internal
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_FILESYSTEM_H_
diff --git a/tflite/experimental/litert/core/filesystem_test.cc b/tflite/experimental/litert/core/filesystem_test.cc
new file mode 100644
index 00000000..a19f7d23
--- /dev/null
+++ b/tflite/experimental/litert/core/filesystem_test.cc
@@ -0,0 +1,39 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/core/filesystem.h"
+
+#include <gtest/gtest.h>
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+
+namespace litert::internal {
+namespace {
+
+static constexpr absl::string_view kPrefix = "a/prefix";
+static constexpr absl::string_view kInfix = "an/infix";
+static constexpr absl::string_view kSuffix = "suffix.ext";
+
+TEST(FilesystemTest, JoinTwo) {
+  const auto path = Join({kPrefix, kSuffix});
+  EXPECT_EQ(path, absl::StrFormat("%s/%s", kPrefix, kSuffix));
+}
+
+TEST(FilesystemTest, JoinMany) {
+  const auto path = Join({kPrefix, kInfix, kSuffix});
+  EXPECT_EQ(path, absl::StrFormat("%s/%s/%s", kPrefix, kInfix, kSuffix));
+}
+
+}  // namespace
+}  // namespace litert::internal
diff --git a/tflite/experimental/litert/core/model/BUILD b/tflite/experimental/litert/core/model/BUILD
new file mode 100644
index 00000000..feee35c9
--- /dev/null
+++ b/tflite/experimental/litert/core/model/BUILD
@@ -0,0 +1,303 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
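+
+# In-memory IR for LiteRT models, plus conversions to and from the TFLite
+# flatbuffer format.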
+
+package(
+    # copybara:uncomment default_applicable_licenses = ["@org_tensorflow//tensorflow:license"],
+    default_visibility = ["//tflite/experimental/litert:__subpackages__"],
+)
+
+cc_library(
+    name = "model",
+    srcs = ["model.cc"],
+    hdrs = [
+        "model.h",
+        "//tflite/experimental/litert/c:litert_model_hdrs",
+    ],
+    deps = [
+        ":ir_allocator",
+        "//tflite/core/c:c_api_types",
+        "//tflite/experimental/litert/c:litert_common",
+        "//tflite/experimental/litert/c:litert_layout",
+        "//tflite/experimental/litert/c:litert_op_code",
+        "//tflite/experimental/litert/cc:litert_buffer_ref",
+        "//tflite/experimental/litert/cc:litert_expected",
+        "//tflite/experimental/litert/core/util:flatbuffer_tools",
+        "//tflite/schema:schema_fbs",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/log:absl_check",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_absl//absl/types:span",
+        "@org_tensorflow//tensorflow/compiler/mlir/lite/core:model_builder_base",
+    ],
+)
+
+cc_test(
+    name = "model_test",
+    srcs = ["model_test.cc"],
+    data = [
+        "//tflite/experimental/litert/test:testdata/simple_model.tflite",
+    ],
+    deps = [
+        ":model",
+        "//tflite/experimental/litert/c:litert_op_code",
+        "//tflite/experimental/litert/cc:litert_buffer_ref",
+        "//tflite/experimental/litert/core/util:flatbuffer_tools",
+        "//tflite/experimental/litert/test:test_macros",
+        "//tflite/schema:schema_fbs",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_absl//absl/types:span",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+cc_library(
+    name = "model_load",
+    srcs = ["model_load.cc"],
+    hdrs = ["model_load.h"],
+    deps = [
+        ":flatbuffer_to_litert",
+        ":model",
+        ":model_graph",
+        "//tflite/experimental/litert/c:litert_common",
+        "//tflite/experimental/litert/c:litert_logging",
+        "//tflite/experimental/litert/c:litert_op_code",
+        "//tflite/experimental/litert/cc:litert_buffer_ref",
+        "//tflite/experimental/litert/cc:litert_expected",
+        "//tflite/experimental/litert/cc:litert_macros",
+        "//tflite/experimental/litert/core/util:flatbuffer_tools",
+        "//tflite/schema:schema_fbs",
+        "@com_google_absl//absl/strings:string_view",
+        "@org_tensorflow//tensorflow/compiler/mlir/lite/core:model_builder_base",
+    ],
+)
+
+cc_test(
+    name = "model_file_test",
+    srcs = ["model_file_test.cc"],
+    data = [
+        "//tflite/experimental/litert/test:mlir_test_data",
+        "//tflite/experimental/litert/test:tflite_test_data",
+    ],
+    deps = [
+        ":graph_validation",
+        ":model",
+        ":model_file_test_util",
+        ":model_load",
+        ":model_serialize",
+        "//tflite/experimental/litert/c:litert_common",
+        "//tflite/experimental/litert/c:litert_op_code",
+        "//tflite/experimental/litert/cc:litert_buffer_ref",
+        "//tflite/experimental/litert/cc:litert_element_type",
+        "//tflite/experimental/litert/cc:litert_expected",
+        "//tflite/experimental/litert/cc:litert_macros",
+        "//tflite/experimental/litert/cc:litert_model",
+        "//tflite/experimental/litert/cc:litert_model_predicates",
+        "//tflite/experimental/litert/core/util:flatbuffer_tools",
+        "//tflite/experimental/litert/test:common",
+        "//tflite/experimental/litert/test:test_macros",
+        "//tflite/experimental/litert/test:test_models",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_absl//absl/types:span",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+cc_library(
+    name = "model_serialize",
+    srcs = ["model_serialize.cc"],
+    hdrs = ["model_serialize.h"],
+    deps = [
+        ":litert_to_flatbuffer",
+        ":model",
+        "//tflite/experimental/litert/c:litert_common",
"//tflite/experimental/litert/cc:litert_buffer_ref", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/core:byte_code_util", + "//tflite/experimental/litert/core/util:flatbuffer_tools", + "//tflite/schema:schema_fbs", + "@com_google_absl//absl/container:flat_hash_map", + ], +) + +cc_library( + name = "flatbuffer_to_litert", + srcs = ["flatbuffer_to_litert.cc"], + hdrs = ["flatbuffer_to_litert.h"], + deps = [ + ":model", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_layout", + "//tflite/experimental/litert/core/util:flatbuffer_tools", + "//tflite/schema:schema_fbs", + ], +) + +cc_test( + name = "flatbuffer_to_litert_test", + srcs = ["flatbuffer_to_litert_test.cc"], + deps = [ + ":flatbuffer_to_litert", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/core/util:flatbuffer_tools", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "litert_to_flatbuffer", + srcs = ["litert_to_flatbuffer.cc"], + hdrs = ["litert_to_flatbuffer.h"], + deps = [ + ":model", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/core/util:flatbuffer_tools", + "//tflite/schema:schema_fbs", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "litert_to_flatbuffer_test", + srcs = ["litert_to_flatbuffer_test.cc"], + deps = [ + ":litert_to_flatbuffer", + ":model", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/cc:litert_layout", + "//tflite/experimental/litert/core/util:flatbuffer_tools", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "model_buffer", + srcs = ["model_buffer.cc"], + hdrs = ["model_buffer.h"], + deps = [ + ":model", + ":model_load", + ":model_serialize", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/cc:litert_buffer_ref", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/core:byte_code_util", + "//tflite/experimental/litert/core:filesystem", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_library( + name = "model_file_test_util", + testonly = 1, + srcs = ["model_file_test_util.cc"], + hdrs = ["model_file_test_util.h"], + deps = [ + ":flatbuffer_to_litert", + ":model", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/cc:litert_detail", + "//tflite/experimental/litert/core/util:flatbuffer_tools", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "ir_allocator", + hdrs = ["ir_allocator.h"], + deps = ["@com_google_absl//absl/types:span"], +) + +cc_test( + name = "ir_allocator_test", + srcs = ["ir_allocator_test.cc"], + deps = [ + ":ir_allocator", + ":model", + "//tflite/experimental/litert/c:litert_op_code", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "model_graph", + srcs = ["model_graph.cc"], + hdrs = ["model_graph.h"], + deps = [ + ":model", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_op_code", + 
"//tflite/experimental/litert/cc:litert_buffer_ref", + "//tflite/experimental/litert/cc:litert_detail", + "//tflite/experimental/litert/cc:litert_expected", + "@com_google_absl//absl/log:absl_check", + ], +) + +cc_library( + name = "graph_validation", + srcs = ["graph_validation.cc"], + hdrs = ["graph_validation.h"], + deps = [ + ":model", + ":model_graph", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/cc:litert_detail", + ], +) + +cc_test( + name = "model_graph_test", + srcs = ["model_graph_test.cc"], + deps = [ + ":graph_validation", + ":model", + ":model_graph", + "//tflite/experimental/litert/c:litert_op_code", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "model_buffer_test", + srcs = ["model_buffer_test.cc"], + deps = [ + ":model", + ":model_buffer", + ":model_load", + "//tflite:framework", + "//tflite:model_builder", + "//tflite/c:c_api_opaque", + "//tflite/c:common", + "//tflite/core:cc_api_stable", + "//tflite/experimental/litert/core:byte_code_util", + "//tflite/experimental/litert/test:common", + "//tflite/experimental/litert/test:simple_model_npu", + "//tflite/kernels:builtin_ops", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@org_tensorflow//tensorflow/compiler/mlir/lite:allocation", + ], +) diff --git a/tflite/experimental/litert/core/model/flatbuffer_to_litert.cc b/tflite/experimental/litert/core/model/flatbuffer_to_litert.cc new file mode 100644 index 00000000..f8c6b2ef --- /dev/null +++ b/tflite/experimental/litert/core/model/flatbuffer_to_litert.cc @@ -0,0 +1,148 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tflite/experimental/litert/core/model/flatbuffer_to_litert.h" + +#include + +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_logging.h" +#include "tflite/experimental/litert/c/litert_model.h" +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/experimental/litert/cc/litert_layout.h" +#include "tflite/experimental/litert/core/model/model.h" +#include "tflite/experimental/litert/core/util/flatbuffer_tools.h" +#include "tflite/schema/schema_generated.h" + +namespace litert::internal { + +LiteRtStatus IsOpSupported(const tflite::OperatorT& op) { + // TODO: b/365299994 - Check for supported options. + + if (!op.intermediates.empty()) { + // TODO: b/365299994 - Support intermediates. + LITERT_LOG(LITERT_ERROR, "Intermediate tensors not yet supported."); + return kLiteRtStatusErrorUnsupported; + } + + if (op.large_custom_options_size != 0) { + // TODO: b/365299994 - Support large custom options. + LITERT_LOG(LITERT_ERROR, "Large custom options not yet supported."); + return kLiteRtStatusErrorUnsupported; + } + + for (auto m_input : op.mutating_variable_inputs) { + if (m_input) { + // TODO: b/365299994 - Support mutating variable inputs. 
+ LITERT_LOG(LITERT_ERROR, "Mutating variable inputs not yet supported."); + return kLiteRtStatusErrorUnsupported; + } + } + + return kLiteRtStatusOk; +} + +LiteRtStatus IsBufferSupported(const tflite::BufferT& buffer) { + if (buffer.offset != 0) { + // TODO: b/365299994 - Support buffer with offset. + LITERT_LOG(LITERT_ERROR, "Buffers with offset not yet supported."); + return kLiteRtStatusErrorUnsupported; + } + + return kLiteRtStatusOk; +} + +LiteRtStatus IsTensorSupported(const TflTensor& tensor) { + if (tensor.is_variable) { + // TODO: b/365299994 - Support variable tensors. + LITERT_LOG(LITERT_ERROR, "Variable tensors not yet supported."); + return kLiteRtStatusErrorUnsupported; + } + + if (!tensor.variant_tensors.empty()) { + // TODO: b/365299994 - Support variant tensors. + LITERT_LOG(LITERT_ERROR, "Variant tensors not yet supported."); + return kLiteRtStatusErrorUnsupported; + } + + if (tensor.sparsity) { + // TODO: b/365299994 - Support sparsity tensors. + LITERT_LOG(LITERT_ERROR, "Sparsity tensors not yet supported."); + return kLiteRtStatusErrorUnsupported; + } + + return kLiteRtStatusOk; +} + +LiteRtElementType MapElementType(TflElementType type) { + switch (type) { + case tflite::TensorType_FLOAT32: + return kLiteRtElementTypeFloat32; + case tflite::TensorType_FLOAT16: + return kLiteRtElementTypeFloat16; + case tflite::TensorType_INT32: + return kLiteRtElementTypeInt32; + case tflite::TensorType_BOOL: + return kLiteRtElementTypeBool; + case tflite::TensorType_INT16: + return kLiteRtElementTypeInt16; + case tflite::TensorType_INT8: + return kLiteRtElementTypeInt8; + default: + return kLiteRtElementTypeNone; + } +} + +Expected MapTensorType(const TflTensorType& tfl_tensor_type) { + const auto& [element_type, shape] = tfl_tensor_type; + auto ranked_shape = AsDynamicShape(shape); + if (!ranked_shape) { + LITERT_LOG(LITERT_ERROR, "Only ranked tensors currently supported"); + return Error(kLiteRtStatusErrorUnsupported); + } + + auto litert_element_type = MapElementType(element_type); + if (litert_element_type == kLiteRtElementTypeNone) { + LITERT_LOG(LITERT_ERROR, "Element type not currently supported"); + return Error(kLiteRtStatusErrorUnsupported); + } + + TensorTypeDetail detail; + detail.ranked_tensor_type.element_type = litert_element_type; + detail.ranked_tensor_type.layout = BuildLayout(*ranked_shape); + + return std::make_pair(kLiteRtRankedTensorType, detail); +} + +Expected MapQuantization(const TflQuantization* tfl_quantization, + BufferProvider buffer_provider) { + if (!IsQuantized(tfl_quantization)) { + return MakeEmptyQuantization(); + } + + if (auto tfl_qparams = AsPerTensorQparams(tfl_quantization)) { + return MakePerTensorQuantization(tfl_qparams->second, tfl_qparams->first); + } + + if (auto tfl_qparams = AsPerChannelQparams(tfl_quantization)) { + [[maybe_unused]] const auto& [quantized_dimension, num_channels, + zero_points, scales] = *tfl_qparams; + return MakePerChannelQuantization(scales, zero_points, quantized_dimension, + buffer_provider); + } + + LITERT_LOG(LITERT_ERROR, "Uknown tfl quantization type"); + return Error(kLiteRtStatusErrorUnsupported); +} +} // namespace litert::internal diff --git a/tflite/experimental/litert/core/model/flatbuffer_to_litert.h b/tflite/experimental/litert/core/model/flatbuffer_to_litert.h new file mode 100644 index 00000000..2950ca29 --- /dev/null +++ b/tflite/experimental/litert/core/model/flatbuffer_to_litert.h @@ -0,0 +1,43 @@ +// Copyright 2024 Google LLC. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_FLATBUFFER_TO_LITERT_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_FLATBUFFER_TO_LITERT_H_
+
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/core/model/model.h"
+#include "tflite/experimental/litert/core/util/flatbuffer_tools.h"
+
+namespace litert::internal {
+
+LiteRtStatus IsOpSupported(const TflOp& op);
+
+LiteRtStatus IsBufferSupported(const TflBuffer& buffer);
+
+// Checks if the misc non-type, non-quantization parts of this tensor are
+// supported in the litert model api.
+LiteRtStatus IsTensorSupported(const TflTensor& tensor);
+
+LiteRtElementType MapElementType(TflElementType element_type);
+
+Expected<TensorType> MapTensorType(const TflTensorType& tfl_tensor_type);
+
+Expected<Quantization> MapQuantization(const TflQuantization* tfl_quantization,
+                                       BufferProvider buffer_provider);
+
+}  // namespace litert::internal
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_FLATBUFFER_TO_LITERT_H_
diff --git a/tflite/experimental/litert/core/model/flatbuffer_to_litert_test.cc b/tflite/experimental/litert/core/model/flatbuffer_to_litert_test.cc
new file mode 100644
index 00000000..7daccc32
--- /dev/null
+++ b/tflite/experimental/litert/core/model/flatbuffer_to_litert_test.cc
@@ -0,0 +1,110 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
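As a quick orientation before the test file, here is a minimal, hypothetical call sketch for the conversions declared above; the `Expected<TensorType>` result shape is inferred from the definitions in flatbuffer_to_litert.cc, and the `SketchMapTensorType` helper name is illustrative only, not part of the patch:

// Hypothetical usage sketch (not part of the patch): map a static 2x2 INT32
// TFL tensor type into the LiteRt (type-id, detail) union.
#include <cstdint>
#include <utility>

#include "absl/types/span.h"
#include "tflite/experimental/litert/core/model/flatbuffer_to_litert.h"

namespace litert::internal {

inline void SketchMapTensorType() {
  static constexpr int32_t kDims[] = {2, 2};
  auto mapped = MapTensorType(
      std::make_pair(TflElementType::TensorType_INT32,
                     TflShapeInfo(absl::MakeConstSpan(kDims))));
  if (!mapped) {
    // Unranked shapes and unmapped element types yield
    // kLiteRtStatusErrorUnsupported.
    return;
  }
  // mapped->first == kLiteRtRankedTensorType; the layout was built from kDims.
  const auto& ranked = mapped->second.ranked_tensor_type;
  (void)ranked;  // ranked.element_type == kLiteRtElementTypeInt32.
}

}  // namespace litert::internal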
+ +#include "tflite/experimental/litert/core/model/flatbuffer_to_litert.h" + +#include +#include +#include + +#include +#include +#include "absl/types/span.h" +#include "tflite/experimental/litert/c/litert_model.h" +#include "tflite/experimental/litert/core/util/flatbuffer_tools.h" + +namespace litert::internal { +namespace { + +using ::testing::ElementsAreArray; + +TEST(FlatbufferToLiteRtTest, MapStaticTensorType) { + static constexpr int32_t kDims[] = {2, 2}; + static constexpr auto kDimsSpan = absl::MakeConstSpan(kDims); + + auto t = MapTensorType(std::make_pair(TflElementType::TensorType_INT32, + TflShapeInfo(kDimsSpan))); + ASSERT_TRUE(t); + + ASSERT_EQ(t->first, kLiteRtRankedTensorType); + auto& ranked = t->second.ranked_tensor_type; + EXPECT_EQ(ranked.element_type, kLiteRtElementTypeInt32); + EXPECT_EQ(absl::MakeSpan(ranked.layout.dimensions, ranked.layout.rank), + kDimsSpan); +} + +TEST(FlatbufferToLiteRtTest, MapDynamicTensorType) { + static constexpr int32_t kDims[] = {-1, 2}; + static constexpr auto kDimsSpan = absl::MakeConstSpan(kDims); + + auto t = MapTensorType(std::make_pair(TflElementType::TensorType_INT32, + TflShapeInfo(kDimsSpan))); + ASSERT_TRUE(t); + + ASSERT_EQ(t->first, kLiteRtRankedTensorType); + auto& ranked = t->second.ranked_tensor_type; + EXPECT_EQ(ranked.element_type, kLiteRtElementTypeInt32); + EXPECT_EQ(absl::MakeSpan(ranked.layout.dimensions, ranked.layout.rank), + kDimsSpan); +} + +TEST(FlatbufferToLiteRtTest, MapNoQuantization) { + LiteRtTensorT tensor; + auto q = MapQuantization(nullptr, tensor); + ASSERT_TRUE(q); + ASSERT_EQ(q->first, kLiteRtQuantizationNone); +} + +TEST(FlatbufferToLiteRtTest, MapPerTensorQuantization) { + static constexpr float kScale = 1.0; + static constexpr int64_t kZp = 2; + + TflQuantization tfl_q; + tfl_q.scale.assign({kScale}); + tfl_q.zero_point.assign({kZp}); + + LiteRtTensorT tensor; + auto q = MapQuantization(&tfl_q, tensor); + ASSERT_TRUE(q); + ASSERT_EQ(q->first, kLiteRtQuantizationPerTensor); + EXPECT_EQ(q->second.per_tensor.scale, kScale); + EXPECT_EQ(q->second.per_tensor.zero_point, kZp); +} + +TEST(FlatbufferToLiteRtTest, MapPerChannelQuantization) { + static constexpr size_t kRank = 2; + static constexpr float kScales[kRank] = {1.0, 2.0}; + static constexpr int64_t kZps[kRank] = {2, 3}; + static constexpr size_t kQDim = 1; + + TflQuantization tfl_q; + tfl_q.scale.assign(kScales, kScales + kRank); + tfl_q.zero_point.assign(kZps, kZps + kRank); + tfl_q.quantized_dimension = kQDim; + + LiteRtTensorT tensor; + auto q = MapQuantization(&tfl_q, tensor); + ASSERT_TRUE(q); + ASSERT_EQ(q->first, kLiteRtQuantizationPerChannel); + EXPECT_THAT(absl::MakeConstSpan(q->second.per_channel.scales, kRank), + ElementsAreArray(kScales)); + + EXPECT_THAT(absl::MakeConstSpan(q->second.per_channel.zero_points, kRank), + ElementsAreArray(kZps)); + EXPECT_EQ(q->second.per_channel.quantized_dimension, kQDim); + EXPECT_EQ(q->second.per_channel.num_channels, kRank); +} + +} // namespace +} // namespace litert::internal diff --git a/tflite/experimental/litert/core/model/graph_validation.cc b/tflite/experimental/litert/core/model/graph_validation.cc new file mode 100644 index 00000000..e9c1f490 --- /dev/null +++ b/tflite/experimental/litert/core/model/graph_validation.cc @@ -0,0 +1,114 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/core/model/graph_validation.h"
+
+#include "tflite/experimental/litert/c/litert_logging.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/cc/litert_detail.h"
+#include "tflite/experimental/litert/core/model/model.h"
+#include "tflite/experimental/litert/core/model/model_graph.h"
+
+namespace litert::internal {
+
+bool ValidateLocalTopology(const LiteRtOpT& litert_op) {
+  // Check number of in edges equals number of inputs and each input index
+  // appears on an in edge.
+  for (auto i = 0; i < litert_op.Inputs().size(); ++i) {
+    const auto& litert_tensor = litert_op.Input(i);
+
+    auto input_use =
+        GetTensorUses(litert_tensor, FindUseInds(litert_tensor, litert_op));
+
+    if (!ContainsIf(input_use.cbegin(), input_use.cend(),
+                    [i](auto u) { return u.second == i; })) {
+      LITERT_LOG(LITERT_WARNING,
+                 "Input tensor %d not connected to op on correct index.", i);
+      return false;
+    }
+  }
+
+  // Similar to above for outputs.
+  for (auto i = 0; i < litert_op.Outputs().size(); ++i) {
+    const auto& litert_tensor = litert_op.Output(i);
+
+    if (litert_tensor.DefiningOp() != &litert_op) {
+      LITERT_LOG(LITERT_WARNING, "Output back edge doesn't refer to this op.");
+      return false;
+    }
+
+    if (litert_tensor.DefiningOpOutInd() != i) {
+      LITERT_LOG(LITERT_WARNING, "Output back edge ind is incorrect.");
+      return false;
+    }
+  }
+
+  return true;
+}
+
+bool ValidateSubgraphIO(const LiteRtSubgraphT& litert_subgraph) {
+  auto num_implied_inputs = 0;
+  auto num_implied_outputs = 0;
+  for (auto* tensor : litert_subgraph.Tensors()) {
+    const auto implied_out = tensor->NumUses() == 0;
+    const auto implied_in =
+        !IsConstant(*tensor) && tensor->DefiningOp() == nullptr;
+
+    if (implied_out && implied_in) {
+      LITERT_LOG(LITERT_WARNING, "Graph contains a dead tensor");
+      return false;
+    }
+
+    const auto is_io = IsIO(litert_subgraph, *tensor);
+
+    if (implied_in) {
+      if (!is_io) {
+        LITERT_LOG(LITERT_WARNING,
+                   "Implied input not reflected in subgraph io %ld",
+                   tensor - litert_subgraph.Tensors().at(0));
+        return false;
+      }
+      ++num_implied_inputs;
+    }
+
+    if (implied_out) {
+      if (!is_io) {
+        LITERT_LOG(LITERT_WARNING,
+                   "Implied output not reflected in subgraph io");
+        return false;
+      }
+      ++num_implied_outputs;
+    }
+  }
+
+  if (num_implied_inputs != litert_subgraph.NumInputs()) {
+    LITERT_LOG(
+        LITERT_WARNING,
+        "Number of implied inputs %d not equal to number of actual inputs %lu",
+        num_implied_inputs, litert_subgraph.NumInputs());
+    return false;
+  }
+
+  if (num_implied_outputs != litert_subgraph.NumOutputs()) {
+    LITERT_LOG(LITERT_WARNING,
+               "Number of implied outputs %d not equal to number of actual "
+               "outputs %lu",
+               num_implied_outputs, litert_subgraph.NumOutputs());
+    return false;
+  }
+
+  return true;
+}
+
+}  // namespace litert::internal
diff --git a/tflite/experimental/litert/core/model/graph_validation.h b/tflite/experimental/litert/core/model/graph_validation.h
new file mode 100644
index 00000000..c917df69
--- /dev/null
+++ b/tflite/experimental/litert/core/model/graph_validation.h
@@ -0,0 +1,47 @@
+// 
Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_GRAPH_VALIDATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_GRAPH_VALIDATION_H_ + +#include + +#include "tflite/experimental/litert/c/litert_logging.h" +#include "tflite/experimental/litert/core/model/model.h" +#include "tflite/experimental/litert/core/model/model_graph.h" + +// Helper functions for validating the structure of IR graphs. + +namespace litert::internal { + +// Checks the double-linked edges to immediate neighbors are valid. +bool ValidateLocalTopology(const LiteRtOpT& litert_op); + +// Runs ValidateLocalTopology across given LiteRtOp iterator. +template +bool ValidateLocalTopology(OpIt start, OpIt end) { + return std::all_of(start, end, + [](const auto* op) { return ValidateLocalTopology(*op); }); +} + +// Checks the following are bijections: +// * non-const tensor with no defining op <-> subgraph input +// * tensor with no users <-> subgraph output (assuming no side effect ops) +// These are used to figure out the i/o signatures when building a subgraph +// from scratch. +bool ValidateSubgraphIO(const LiteRtSubgraphT& litert_subgraph); + +} // namespace litert::internal + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_GRAPH_VALIDATION_H_ diff --git a/tflite/experimental/litert/core/model/ir_allocator.h b/tflite/experimental/litert/core/model/ir_allocator.h new file mode 100644 index 00000000..4e0a575a --- /dev/null +++ b/tflite/experimental/litert/core/model/ir_allocator.h @@ -0,0 +1,109 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_IR_ALLOCATOR_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_IR_ALLOCATOR_H_ + +#include +#include +#include +#include +#include + +#include "absl/types/span.h" + +namespace litert::internal { + +// A list of IR objects scoped to the same block (subgraph) that provides +// pointer stability. Facilitates management of memory and c-like access +// to elements. +template +class IrAllocator { + private: + using Storage = std::list; + using Refs = std::vector; + + public: + // Emplace a new element onto the list. + template + Ir& EmplaceBack(Args&&... args) { + auto& emp = storage_.emplace_back(std::forward(args)...); + refs_->push_back(&emp); + return emp; + } + + // Get the array of (stable) pointers to underlying elements. 
Suitable
+  // for passing through a c-like interface. Constituent pointers are always
+  // guaranteed to be stable (unless explicitly erased). The array of pointers
+  // itself is guaranteed to be stable so long as no length-changing operations
+  // occur; moving this class does not invalidate the pointers or the array.
+  absl::Span Elements() const {
+    return absl::MakeSpan(refs_->data(), refs_->size());
+  }
+
+  // Remove elements from the allocator if they match the predicate.
+  // Returns the number of elements removed.
+  size_t RemoveIf(std::function pred) {
+    auto ref_it = refs_->begin();
+    for (auto it = storage_.begin(); it != storage_.end();) {
+      if (!pred(*it)) {
+        *ref_it = &*it;
+        ++ref_it;
+        ++it;
+        continue;
+      }
+      it = storage_.erase(it);
+    }
+    const size_t removed = refs_->end() - ref_it;
+    refs_->resize(refs_->size() - removed);
+    return removed;
+  }
+
+  // Cuts all but the first `size` elements from storage. Does nothing if
+  // `size` is greater than or equal to the current size.
+  void ResizeDown(size_t size) {
+    if (size >= Size()) {
+      return;
+    }
+    storage_.resize(size);
+    refs_->resize(size);
+  }
+
+  // Transfers ownership of the given allocator's elements to this one.
+  void Transfer(IrAllocator&& other) {
+    storage_.splice(storage_.cend(), other.storage_);
+    refs_->insert(refs_->end(), other.refs_->cbegin(), other.refs_->cend());
+  }
+
+  // Number of elements stored by this allocator.
+  size_t Size() const { return storage_.size(); }
+
+  IrAllocator() { refs_ = std::make_unique(); }
+
+  // IR is generally semantically movable (without reference invalidation)
+  // but not copyable. IrAllocators reflect that; note that moving a std::list
+  // does not invalidate references.
+  IrAllocator(const IrAllocator& other) = delete;
+  IrAllocator& operator=(const IrAllocator& other) = delete;
+  IrAllocator(IrAllocator&& other) = default;
+  IrAllocator& operator=(IrAllocator&& other) = default;
+
+ private:
+  Storage storage_;
+  std::unique_ptr refs_;
+};
+
+}  // namespace litert::internal
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_IR_ALLOCATOR_H_
diff --git a/tflite/experimental/litert/core/model/ir_allocator_test.cc b/tflite/experimental/litert/core/model/ir_allocator_test.cc
new file mode 100644
index 00000000..3c568714
--- /dev/null
+++ b/tflite/experimental/litert/core/model/ir_allocator_test.cc
@@ -0,0 +1,108 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
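The tests that follow pin down this behavior; as a minimal sketch of the pointer-stability contract, consider the following. The `IrAllocator<LiteRtOpT>` instantiation and the `SketchPointerStability` helper are assumptions for illustration, mirroring how the tests use the allocator:

// Illustrative sketch (not part of the patch): element pointers stay valid as
// the allocator grows, since storage is a std::list rather than a vector.
#include "tflite/experimental/litert/core/model/ir_allocator.h"
#include "tflite/experimental/litert/core/model/model.h"

namespace litert::internal {

inline bool SketchPointerStability() {
  IrAllocator<LiteRtOpT> ops;
  LiteRtOpT* first = &ops.EmplaceBack();  // Stable unless explicitly erased.
  for (int i = 0; i < 1000; ++i) {
    ops.EmplaceBack();  // Growth never relocates existing elements.
  }
  // Elements() exposes the same stable pointers as a c-style span.
  return ops.Elements().at(0) == first;  // True by design.
}

}  // namespace litert::internal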
+ +#include "tflite/experimental/litert/core/model/ir_allocator.h" + +#include + +#include +#include +#include "tflite/experimental/litert/c/litert_op_code.h" +#include "tflite/experimental/litert/core/model/model.h" + +namespace litert::internal { +namespace { + +using ::testing::ElementsAreArray; + +static constexpr auto kCustomOpCode = kLiteRtOpCodeTflCustom; +static constexpr auto kNonCustomOpCode = kLiteRtOpCodeTflSoftmax; + +TEST(IrAllocatorTest, EmplaceBack) { + IrAllocator ops; + + LiteRtOpT my_op; + my_op.SetOpCode(kCustomOpCode); + + ops.EmplaceBack(std::move(my_op)); + ASSERT_EQ(ops.Elements().size(), 1); + EXPECT_EQ(ops.Elements().at(0)->OpCode(), kCustomOpCode); +} + +TEST(IrAllocatorTest, RemoveIf) { + IrAllocator ops; + + LiteRtOpT my_op; + my_op.SetOpCode(kNonCustomOpCode); + ops.EmplaceBack(std::move(my_op)); + + LiteRtOpT my_op2; + my_op2.SetOpCode(kCustomOpCode); + ops.EmplaceBack(std::move(my_op2)); + + LiteRtOpT my_op3; + my_op3.SetOpCode(kCustomOpCode); + ops.EmplaceBack(std::move(my_op3)); + + LiteRtOpT my_op4; + my_op4.SetOpCode(kNonCustomOpCode); + ops.EmplaceBack(std::move(my_op4)); + + auto pred = [](const auto& op) { return op.OpCode() != kCustomOpCode; }; + ASSERT_EQ(ops.RemoveIf(pred), 2); + + ASSERT_EQ(ops.Elements().size(), 2); + ASSERT_EQ(ops.Elements().at(0)->OpCode(), kCustomOpCode); + ASSERT_EQ(ops.Elements().at(1)->OpCode(), kCustomOpCode); +} + +TEST(IrAllocatorTest, ResizeDown) { + IrAllocator ops; + + LiteRtOp op1 = nullptr; + { + LiteRtOpT my_op; + my_op.SetOpCode(kNonCustomOpCode); + op1 = &ops.EmplaceBack(std::move(my_op)); + } + + { + LiteRtOpT my_op2; + my_op2.SetOpCode(kCustomOpCode); + ops.EmplaceBack(std::move(my_op2)); + } + + ops.ResizeDown(1); + + ASSERT_EQ(ops.Size(), 1); + EXPECT_EQ(ops.Elements().at(0), op1); +} + +TEST(IrAllocatorTest, Transfer) { + IrAllocator ops; + auto& op1 = ops.EmplaceBack(); + auto& op2 = ops.EmplaceBack(); + + IrAllocator other_ops; + auto& other_op1 = other_ops.EmplaceBack(); + auto& other_op2 = other_ops.EmplaceBack(); + + ops.Transfer(std::move(other_ops)); + + EXPECT_THAT(ops.Elements(), + ElementsAreArray({&op1, &op2, &other_op1, &other_op2})); +} + +} // namespace +} // namespace litert::internal diff --git a/tflite/experimental/litert/core/model/litert_to_flatbuffer.cc b/tflite/experimental/litert/core/model/litert_to_flatbuffer.cc new file mode 100644 index 00000000..c74d003c --- /dev/null +++ b/tflite/experimental/litert/core/model/litert_to_flatbuffer.cc @@ -0,0 +1,126 @@ + +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "tflite/experimental/litert/core/model/litert_to_flatbuffer.h" + +#include +#include + +#include "absl/types/span.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_model.h" +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/experimental/litert/core/model/model.h" +#include "tflite/experimental/litert/core/util/flatbuffer_tools.h" +#include "tflite/schema/schema_generated.h" + +namespace litert::internal { + +namespace { + +Expected MapElementType(LiteRtElementType litert_element_type) { + switch (litert_element_type) { + case kLiteRtElementTypeFloat32: + return tflite::TensorType_FLOAT32; + case kLiteRtElementTypeFloat16: + return tflite::TensorType_FLOAT16; + case kLiteRtElementTypeInt32: + return tflite::TensorType_INT32; + case kLiteRtElementTypeBool: + return tflite::TensorType_BOOL; + case kLiteRtElementTypeInt16: + return tflite::TensorType_INT16; + case kLiteRtElementTypeInt8: + return tflite::TensorType_INT8; + default: + return Error(kLiteRtStatusErrorUnsupported); + } +} + +template +Expected MapTensorTypeDetail( + const LiteRtTenzorType& litert_tensor_type) { + return Error(kLiteRtStatusErrorUnsupported); +} + +template <> +Expected MapTensorTypeDetail( + const LiteRtRankedTensorType& litert_tensor_type) { + auto tfl_element_type = MapElementType(litert_tensor_type.element_type); + if (!tfl_element_type) { + return tfl_element_type.Error(); + } + + auto litert_shape = absl::MakeConstSpan(litert_tensor_type.layout.dimensions, + litert_tensor_type.layout.rank); + return std::make_pair(*tfl_element_type, TflShapeInfo(litert_shape)); +} + +template +Expected MapQuantizationDetail( + const LiteRtQuantDetail& litert_quantization) { + return Error(kLiteRtStatusErrorUnsupported); +} + +template <> +Expected MapQuantizationDetail( + const LiteRtQuantizationPerTensor& litert_quantization) { + auto tfl_quantization = std::make_unique(); + tfl_quantization->scale.assign({litert_quantization.scale}); + tfl_quantization->zero_point.assign({litert_quantization.zero_point}); + return tfl_quantization; +} + +template <> +Expected +MapQuantizationDetail( + const LiteRtQuantizationPerChannel& litert_quantization) { + auto tfl_quantization = std::make_unique(); + + for (int i = 0; i < litert_quantization.num_channels; ++i) { + tfl_quantization->scale.push_back(litert_quantization.scales[i]); + tfl_quantization->zero_point.push_back(litert_quantization.zero_points[i]); + } + tfl_quantization->quantized_dimension = + litert_quantization.quantized_dimension; + return tfl_quantization; +} + +} // namespace + +Expected MapTensorType(const TensorType& litert_tensor_type) { + switch (litert_tensor_type.first) { + case kLiteRtRankedTensorType: + return MapTensorTypeDetail(litert_tensor_type.second.ranked_tensor_type); + default: + return Error(kLiteRtStatusErrorUnsupported); + } +} + +Expected MapQuantization( + const Quantization& litert_quantization) { + switch (litert_quantization.first) { + case kLiteRtQuantizationNone: + return TflQuantizationPtr(nullptr); + case kLiteRtQuantizationPerTensor: + return MapQuantizationDetail(litert_quantization.second.per_tensor); + case kLiteRtQuantizationPerChannel: + return MapQuantizationDetail(litert_quantization.second.per_channel); + default: + return Error(kLiteRtStatusErrorUnsupported); + } +} + +} // namespace litert::internal diff --git a/tflite/experimental/litert/core/model/litert_to_flatbuffer.h b/tflite/experimental/litert/core/model/litert_to_flatbuffer.h new 
file mode 100644 index 00000000..ae98955d --- /dev/null +++ b/tflite/experimental/litert/core/model/litert_to_flatbuffer.h @@ -0,0 +1,32 @@ + +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_LITERT_TO_FLATBUFFER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_LITERT_TO_FLATBUFFER_H_ + +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/experimental/litert/core/model/model.h" +#include "tflite/experimental/litert/core/util/flatbuffer_tools.h" + +namespace litert::internal { + +Expected MapTensorType(const TensorType& litert_tensor_type); + +Expected MapQuantization( + const Quantization& litert_quantization); + +} // namespace litert::internal + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_LITERT_TO_FLATBUFFER_H_ diff --git a/tflite/experimental/litert/core/model/litert_to_flatbuffer_test.cc b/tflite/experimental/litert/core/model/litert_to_flatbuffer_test.cc new file mode 100644 index 00000000..2463acf8 --- /dev/null +++ b/tflite/experimental/litert/core/model/litert_to_flatbuffer_test.cc @@ -0,0 +1,108 @@ + +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
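Before the tests, a small hypothetical sketch of the lowering direction; the `Expected<TflQuantizationPtr>` result shape (a null pointer for the none case) is inferred from litert_to_flatbuffer.cc above, and `SketchLowerPerTensorQuant` is an illustrative name only:

// Illustrative sketch (not part of the patch): lower LiteRt per-tensor
// qparams into the flatbuffer's QuantizationParametersT.
#include "tflite/experimental/litert/core/model/litert_to_flatbuffer.h"
#include "tflite/experimental/litert/core/model/model.h"

namespace litert::internal {

inline void SketchLowerPerTensorQuant() {
  Quantization q = MakePerTensorQuantization(/*scale=*/0.5f, /*zero_point=*/8);
  auto tfl_q = MapQuantization(q);
  if (tfl_q && *tfl_q) {
    // (*tfl_q)->scale == {0.5f} and (*tfl_q)->zero_point == {8};
    // kLiteRtQuantizationNone would instead produce a null pointer.
  }
}

}  // namespace litert::internal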
+ +#include "tflite/experimental/litert/core/model/litert_to_flatbuffer.h" + +#include +#include +#include + +#include +#include +#include "tflite/experimental/litert/c/litert_model.h" +#include "tflite/experimental/litert/cc/litert_layout.h" +#include "tflite/experimental/litert/core/model/model.h" +#include "tflite/experimental/litert/core/util/flatbuffer_tools.h" + +namespace litert::internal { +namespace { + +using ::testing::ElementsAreArray; + +TEST(LiteRtToFlatbufferTest, MapNoQuantization) { + Quantization q; + auto tfl_q = MapQuantization(q); + ASSERT_TRUE(tfl_q); + EXPECT_EQ(tfl_q.Value(), nullptr); +} + +TEST(LiteRtToFlatbufferTest, MapPerTensorQuantization) { + static constexpr float kScale = 1.0; + static constexpr int64_t kZp = 2; + + Quantization q; + q.first = kLiteRtQuantizationPerTensor; + q.second.per_tensor.scale = kScale; + q.second.per_tensor.zero_point = kZp; + + auto tfl_q = MapQuantization(q); + ASSERT_TRUE(tfl_q); + EXPECT_THAT(tfl_q->get()->scale, ElementsAreArray({kScale})); + EXPECT_THAT(tfl_q->get()->zero_point, ElementsAreArray({kZp})); +} + +TEST(LiteRtToFlatbufferTest, MapPerChannelQuantization) { + static constexpr size_t kRank = 2; + static constexpr size_t kQuantizedDimension = 1; + static constexpr float kScales[kRank] = {1.0, 2.0}; + static constexpr int64_t kZps[kRank] = {2, 3}; + + Quantization q; + q.first = kLiteRtQuantizationPerChannel; + q.second.per_channel.scales = const_cast(kScales); + q.second.per_channel.zero_points = const_cast(kZps); + q.second.per_channel.num_channels = kRank; + q.second.per_channel.quantized_dimension = kQuantizedDimension; + + auto tfl_q = MapQuantization(q); + ASSERT_TRUE(tfl_q); + EXPECT_THAT(tfl_q->get()->scale, ElementsAreArray(kScales)); + EXPECT_THAT(tfl_q->get()->zero_point, ElementsAreArray(kZps)); +} + +TEST(LiteRtToFlatbufferTest, MapDynamicTensorType) { + static constexpr int32_t kDims[] = {-1, 2}; + + TensorType t; + t.first = kLiteRtRankedTensorType; + t.second.ranked_tensor_type.element_type = kLiteRtElementTypeFloat32; + t.second.ranked_tensor_type.layout = BuildLayout(kDims); + + auto tfl_t = MapTensorType(t); + ASSERT_TRUE(tfl_t); + EXPECT_EQ(tfl_t->first, TflElementType::TensorType_FLOAT32); + EXPECT_TRUE(tfl_t->second.has_rank); + EXPECT_THAT(tfl_t->second.shape, ElementsAreArray({1, 2})); + EXPECT_THAT(tfl_t->second.shape_signature, ElementsAreArray(kDims)); +} + +TEST(LiteRtToFlatbufferTest, MapStaticTensorType) { + static constexpr int32_t kDims[] = {2, 2}; + + TensorType t; + t.first = kLiteRtRankedTensorType; + t.second.ranked_tensor_type.element_type = kLiteRtElementTypeFloat32; + t.second.ranked_tensor_type.layout = BuildLayout(kDims); + + auto tfl_t = MapTensorType(t); + ASSERT_TRUE(tfl_t); + EXPECT_EQ(tfl_t->first, TflElementType::TensorType_FLOAT32); + EXPECT_TRUE(tfl_t->second.has_rank); + EXPECT_THAT(tfl_t->second.shape, ElementsAreArray({2, 2})); + EXPECT_TRUE(tfl_t->second.shape_signature.empty()); +} + +} // namespace +} // namespace litert::internal diff --git a/tflite/experimental/litert/core/model/model.cc b/tflite/experimental/litert/core/model/model.cc new file mode 100644 index 00000000..2ad732b6 --- /dev/null +++ b/tflite/experimental/litert/core/model/model.cc @@ -0,0 +1,136 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tflite/experimental/litert/core/model/model.h" + +#include +#include +#include +#include +#include + +#include "absl/log/absl_check.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tflite/experimental/litert/c/litert_layout.h" +#include "tflite/experimental/litert/c/litert_model.h" +#include "tflite/experimental/litert/cc/litert_buffer_ref.h" +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/experimental/litert/core/util/flatbuffer_tools.h" + +using ::litert::BufferRef; +using ::litert::internal::TflBuffer; +using ::litert::internal::TflBufferPtr; +using ::litert::internal::TflOpCode; +using ::litert::internal::TflOpCodePtr; +using ::litert::internal::TflOptions; + +TensorType MakeRankedTensorType(LiteRtElementType element_type, + absl::Span dims) { + TensorType tensor_type; + tensor_type.first = kLiteRtRankedTensorType; + auto& ranked = tensor_type.second.ranked_tensor_type; + ranked.element_type = element_type; + ABSL_DCHECK_LE(dims.size(), LITERT_TENSOR_MAX_RANK); + ranked.layout.rank = dims.size(); + std::copy(dims.begin(), dims.end(), ranked.layout.dimensions); + // Strides not yet supported. + ranked.layout.strides = nullptr; + return tensor_type; +} + +Quantization MakePerTensorQuantization(float scale, int64_t zero_point) { + Quantization quantization; + quantization.first = kLiteRtQuantizationPerTensor; + quantization.second.per_tensor.scale = scale; + quantization.second.per_tensor.zero_point = zero_point; + return quantization; +} + +LiteRtSignatureT MakeDefaultSignature(LiteRtSubgraph subgraph) { + auto tensor_name = [](auto* tensor) { return std::string(tensor->Name()); }; + + auto in_start = subgraph->Inputs().cbegin(); + auto in_end = subgraph->Inputs().cend(); + std::vector input_names(subgraph->NumInputs()); + std::transform(in_start, in_end, input_names.begin(), tensor_name); + + auto out_start = subgraph->Outputs().cbegin(); + auto out_end = subgraph->Outputs().cend(); + std::vector output_names(subgraph->NumOutputs()); + std::transform(out_start, out_end, output_names.begin(), tensor_name); + + std::string name(LiteRtSignatureT::kDefaultSignatureKey); + return LiteRtSignatureT(subgraph, std::move(input_names), + std::move(output_names), std::move(name)); +} + +::litert::Expected LookupSubgraph( + const LiteRtModelT& model, absl::string_view signature_key) { + auto sig = model.FindSignature(signature_key); + if (!sig) { + return sig.Error(); + } + return &sig->get().GetSubgraph(); +} + +namespace detail { + +void SetTflOpCodeInd(LiteRtOpT& litert_op, int32_t tfl_op_code_ind) { + litert_op.tfl_op_code_ind_ = tfl_op_code_ind; +} + +int32_t GetTflOpCodeInd(const LiteRtOpT& litert_op) { + return litert_op.tfl_op_code_ind_; +} + +const TflOptions& GetTflOptions(const LiteRtOpT& litert_op) { + return litert_op.tfl_option_; +} + +TflOptions&& TakeTflOptions(LiteRtOpT& litert_op) { + return std::move(litert_op.tfl_option_); +} + +const TflBuffer& GetTflBuffer(const LiteRtWeightsT& litert_weights) { + return *litert_weights.tfl_buf_; +} + +TflBufferPtr TakeTflBuffer(LiteRtWeightsT& 
litert_weights) {
+  return std::move(litert_weights.tfl_buf_);
+}
+
+void SetTflBuffer(LiteRtWeightsT& litert_weights, TflBufferPtr tfl_buffer) {
+  litert_weights.tfl_buf_ = std::move(tfl_buffer);
+}
+
+const std::vector& GetTflOpCodes(
+    const LiteRtModelT& litert_model) {
+  return litert_model.tfl_operator_codes_;
+}
+
+std::vector&& TakeTflOpCodes(LiteRtModelT& litert_model) {
+  return std::move(litert_model.tfl_operator_codes_);
+}
+
+void SetTflInitFlatbuffer(LiteRtModelT& litert_model,
+                          BufferRef init_flatbuffer) {
+  litert_model.tfl_init_flatbuffer_ = init_flatbuffer;
+}
+
+BufferRef GetTflInitFlatbuffer(const LiteRtModelT& litert_model) {
+  return litert_model.tfl_init_flatbuffer_;
+}
+
+}  // namespace detail
diff --git a/tflite/experimental/litert/core/model/model.h b/tflite/experimental/litert/core/model/model.h
new file mode 100644
index 00000000..3f01b790
--- /dev/null
+++ b/tflite/experimental/litert/core/model/model.h
@@ -0,0 +1,827 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_H_
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/log/absl_check.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_model.h"  // IWYU pragma: export
+#include "tflite/experimental/litert/c/litert_op_code.h"
+#include "tflite/experimental/litert/cc/litert_buffer_ref.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/core/model/ir_allocator.h"
+#include "tflite/experimental/litert/core/util/flatbuffer_tools.h"
+#include "tflite/schema/schema_generated.h"
+
+////////////////////////////////////////////////////////////////////////////////
+// Internal LiteRtIR
+//
+// These are the backing definitions for the opaque types in the c api
+// (c/litert_model.h).
+//
+// < STORAGE DETAIL >
+//
+// Unless deleted as a result of c api client calls, all "IR Objects"
+// (definitions of opaque types) are designed to be transitively owned by the
+// LiteRtModelT, which is generally the longest-living object. See various
+// "Emplace" methods.
+//
+// Since c api clients interface with pointers to IR Objects, a form of pointer
+// stability is desirable. Classes in this file enforce that pointers to IR
+// Objects are valid for their entire lifetime. Thus a c api client may store
+// pointers and depend on referential equality of IR Objects throughout
+// different calls. This also facilitates storing edge/parent-references as
+// pointers within IR Objects.
+//
+// Direct copying is generally not allowed for IR Objects since copying
+// instances of mutually recursive types is not entirely well-defined.
+//
+// IR Objects are generally default constructible to facilitate stable storage
+// and iterative construction.
+//
+// < EXPOSING TFLITE SCHEMA >
+//
+// Direct access to tflite schema types is limited to the "detail" namespace.
+// This indicates that encapsulating all the details of the flatbuffer is a
+// WIP. Future implementations may use different data forms (new litert
+// serialized format, tflite runtime types, etc.).
+//
+// < USAGE NOTE >
+//
+// The classes here contain only simple getters & setters. Care should be taken
+// to leave the IR in a valid state when using setters since the graph is
+// doubly-linked. Higher-level functionality for correct graph mutation can be
+// found in "model_graph.h".
+////////////////////////////////////////////////////////////////////////////////
+
+// All tflite schema type usage.
+namespace detail {
+
+// OP
+
+// Placeholder for the ind of the dispatch op code added during serialization.
+static constexpr auto kDispatchOpCodeTflInd = -1;
+
+void SetTflOpCodeInd(LiteRtOpT& litert_op, int32_t tfl_op_code_ind);
+
+int32_t GetTflOpCodeInd(const LiteRtOpT& litert_op);
+
+template <class Arg>
+void SetTflOptions(LiteRtOpT& litert_op, Arg&& arg);
+
+const ::litert::internal::TflOptions& GetTflOptions(const LiteRtOpT& litert_op);
+
+::litert::internal::TflOptions&& TakeTflOptions(LiteRtOpT& litert_op);
+
+// WEIGHT
+
+const ::litert::internal::TflBuffer& GetTflBuffer(
+    const LiteRtWeightsT& litert_weights);
+
+litert::internal::TflBufferPtr TakeTflBuffer(LiteRtWeightsT& litert_weights);
+
+void SetTflBuffer(LiteRtWeightsT& litert_weights,
+                  litert::internal::TflBufferPtr tfl_buffer);
+
+// MODEL
+
+const std::vector<::litert::internal::TflOpCodePtr>& GetTflOpCodes(
+    const LiteRtModelT& litert_model);
+
+template <class Arg>
+void SetTflOpCodes(LiteRtModelT& litert_model, Arg&& arg);
+
+std::vector<::litert::internal::TflOpCodePtr>&& TakeTflOpCodes(
+    LiteRtModelT& litert_model);
+
+void SetTflInitFlatbuffer(LiteRtModelT& litert_model,
+                          ::litert::BufferRef init_flatbuffer);
+
+::litert::BufferRef GetTflInitFlatbuffer(
+    const LiteRtModelT& litert_model);
+
+}  // namespace detail
+
+//
+// Helpers for conceptual unions from C api.
+//
+
+// For requesting opaque data stored within IR.
+using BufferProvider = std::function;
+
+// TENSOR TYPE
+
+// Detail convenience type for tensor type union.
+typedef union {
+  LiteRtUnrankedTensorType unranked_tensor_type;
+  LiteRtRankedTensorType ranked_tensor_type;
+} TensorTypeDetail;
+
+// Union and identifier for tensor types.
+using TensorType = std::pair;
+
+// Construct tensor type union as ranked tensor. NOTE: Copies data in `dims`.
+TensorType MakeRankedTensorType(LiteRtElementType element_type,
+                                absl::Span dims);
+
+// QUANTIZATION TYPE
+
+// Detail convenience type for quantization type union.
+typedef union {
+  LiteRtQuantizationPerTensor per_tensor;
+  LiteRtQuantizationPerChannel per_channel;
+} QuantizationDetail;
+
+// Union and identifier for quantization types.
+using Quantization = std::pair;
+
+// Make the default quantization type, with empty info.
+inline Quantization MakeEmptyQuantization() {
+  return Quantization(kLiteRtQuantizationNone, QuantizationDetail());
+}
+
+// Construct quantization type as per tensor.
+Quantization MakePerTensorQuantization(float scale, int64_t zero_point);
+
+// Construct quantization type as per channel; requires a buffer callback to
+// store data.
+template <class Scales, class ZeroPoints>
+Quantization MakePerChannelQuantization(const Scales& scales,
+                                        const ZeroPoints& zero_points,
+                                        int32_t quantized_dim,
+                                        BufferProvider buffer_provider) {
+  const auto size = std::size(scales);
+  ABSL_DCHECK_EQ(size, std::size(zero_points));
+
+  Quantization res;
+  res.first = kLiteRtQuantizationPerChannel;
+
+  res.second.per_channel.num_channels = size;
+  res.second.per_channel.quantized_dimension = quantized_dim;
+
+  const size_t scales_buf_size = size * sizeof(float);
+  const size_t zeros_buf_size = size * sizeof(int64_t);
+  auto* scales_buf =
+      reinterpret_cast<float*>(buffer_provider(scales_buf_size));
+  auto* zeros_buf =
+      reinterpret_cast<int64_t*>(buffer_provider(zeros_buf_size));
+  std::copy(std::cbegin(scales), std::cend(scales), scales_buf);
+  std::copy(std::cbegin(zero_points), std::cend(zero_points), zeros_buf);
+
+  res.second.per_channel.scales = scales_buf;
+  res.second.per_channel.zero_points = zeros_buf;
+
+  return res;
+}
+
+//
+// Tensor
+//
+
+// Constant data associated with a tensor.
+class LiteRtWeightsT {
+ private:
+  using OwnedBuffer = ::litert::OwningBufferRef;
+
+ public:
+  // Underlying data.
+  ::litert::BufferRef Buf() const {
+    return ::litert::BufferRef(tfl_buf_->data.data(),
+                               tfl_buf_->data.size());
+  }
+
+  // Set weights via copied data.
+  void SetFromBuf(::litert::BufferRef buf) {
+    tfl_buf_->data.assign(buf.Data(), buf.Data() + buf.Size());
+  }
+
+  // Set via copied vec.
+  void SetFromVec(const std::vector& vec) { tfl_buf_->data = vec; }
+
+  // IR is generally default constructible and movable, but not copyable.
+  LiteRtWeightsT()
+      : tfl_buf_(std::make_unique<::litert::internal::TflBuffer>()) {}
+  LiteRtWeightsT(const LiteRtWeightsT&) = delete;
+  LiteRtWeightsT(LiteRtWeightsT&&) = default;
+  LiteRtWeightsT& operator=(const LiteRtWeightsT&) = delete;
+  LiteRtWeightsT& operator=(LiteRtWeightsT&&) = default;
+
+  // Friendship for internal tflite details.
+  friend const ::litert::internal::TflBuffer& detail::GetTflBuffer(
+      const LiteRtWeightsT& litert_weights);
+
+  friend litert::internal::TflBufferPtr detail::TakeTflBuffer(
+      LiteRtWeightsT& litert_weights);
+
+  friend void detail::SetTflBuffer(LiteRtWeightsT& litert_weights,
+                                   litert::internal::TflBufferPtr tfl_buffer);
+
+ private:
+  // TFLITE
+  ::litert::internal::TflBufferPtr tfl_buf_;
+};
+
+// Fundamental value in a litert program; the "edges" in the graph.
+class LiteRtTensorT {
+ private:
+  using UserData = std::unique_ptr;
+
+ public:
+  using Ref = std::reference_wrapper;
+  using Use = std::pair;
+  using UseVec = std::vector;
+  using Alloc = ::litert::internal::IrAllocator;
+
+  // The ops that take this tensor as input.
+  const std::vector& Users() const { return users_; }
+  std::vector& Users() { return users_; }
+
+  // The operand index at which each user takes this tensor; respects the
+  // ordering of Users().
+  const std::vector& UserArgInds() const {
+    return user_arg_inds_;
+  }
+  std::vector& UserArgInds() { return user_arg_inds_; }
+
+  // Number of uses; same as the number of user arg inds.
+  size_t NumUses() const { return users_.size(); }
+
+  // Get the ith use.
+  Use GetUse(size_t ind) const {
+    return {users_.at(ind), user_arg_inds_.at(ind)};
+  }
+
+  // Remove the use at the given index.
+  void RemoveUse(size_t ind) {
+    users_.erase(users_.begin() + ind);
+    user_arg_inds_.erase(user_arg_inds_.begin() + ind);
+  }
+
+  // Get the op that outputs this tensor, null if constant or subgraph input.
+  LiteRtOp DefiningOp() const { return defining_op_; }
+
+  // Get the output index of the op that defines this tensor, only meaningful
+  // if it has a defining op.
+  LiteRtParamIndex DefiningOpOutInd() const { return defining_op_out_ind_; }
+
+  // Update the defining op of this tensor. The caller is required to update
+  // the given op's output if not already correct.
+  void SetDefiningOp(LiteRtOpT& defining_op, LiteRtParamIndex out_ind) {
+    defining_op_ = &defining_op;
+    defining_op_out_ind_ = out_ind;
+  }
+
+  // Set the defining op to none.
+  void ClearDefiningOp() {
+    defining_op_ = nullptr;
+    defining_op_out_ind_ = 0;
+  }
+
+  // Any constant data associated with this tensor.
+  const LiteRtWeightsT& Weights() const { return weights_; }
+  LiteRtWeightsT& Weights() { return weights_; }
+
+  // Authored name associated with this tensor. May be empty.
+  absl::string_view Name() const { return name_; }
+
+  // Update the name associated with this tensor.
+  void SetName(std::string name) { name_ = std::move(name); }
+
+  // Get quantization information for this tensor.
+  const Quantization& Qparams() const { return quantization_; }
+  Quantization& Qparams() { return quantization_; }
+
+  // Set quantization information.
+  template <class Arg>
+  void SetQarams(Arg&& arg) {
+    quantization_ = std::forward<Arg>(arg);
+  }
+
+  // Get the tensor type of this tensor.
+  const TensorType& Type() const { return tensor_type_; }
+  TensorType& Type() { return tensor_type_; }
+
+  // Set the tensor type.
+  template <class Arg>
+  void SetType(Arg&& arg) {
+    tensor_type_ = std::forward<Arg>(arg);
+  }
+
+  // Get a new buffer that will live as long as this tensor. Used for storing
+  // various buffers passed through c-api (dims, quantization, etc.).
+  uint8_t* RequestBuffer(size_t size) {
+    user_data_.push_back(std::make_unique(size));
+    return user_data_.back().get();
+  }
+
+  // Allow for implicit conversion to buffer provider.
+  // NOLINTNEXTLINE
+  operator BufferProvider() & {
+    return [this](auto s) { return this->RequestBuffer(s); };
+  }
+
+  // IR is generally default constructible and movable, but not copyable.
+  LiteRtTensorT() = default;
+  LiteRtTensorT(const LiteRtTensorT&) = delete;
+  LiteRtTensorT(LiteRtTensorT&&) = default;
+  LiteRtTensorT& operator=(const LiteRtTensorT&) = delete;
+  LiteRtTensorT& operator=(LiteRtTensorT&&) = default;
+
+ private:
+  std::vector users_;
+  std::vector user_arg_inds_;
+
+  LiteRtOp defining_op_ = nullptr;
+  LiteRtParamIndex defining_op_out_ind_;
+
+  LiteRtWeightsT weights_;
+  Quantization quantization_;
+  TensorType tensor_type_;
+
+  std::string name_;
+
+  std::vector user_data_;
+};
+
+// Helper to get multiple uses at once.
+template <class Inds>
+LiteRtTensorT::UseVec GetTensorUses(const LiteRtTensorT& tensor,
+                                    const Inds& inds) {
+  auto start = std::cbegin(inds);
+  auto end = std::cend(inds);
+  LiteRtTensorT::UseVec uses(end - start);
+  auto get = [&tensor = std::as_const(tensor)](auto i) {
+    return tensor.GetUse(i);
+  };
+  std::transform(start, end, uses.begin(), get);
+  return uses;
+}
+
+//
+// Op
+//
+
+// Fundamental unit of compute of a litert program; the "nodes" in the graph.
+class LiteRtOpT {
+ public:
+  using Ref = std::reference_wrapper;
+  using Alloc = ::litert::internal::IrAllocator;
+
+  // Input tensors for this op.
+  const std::vector& Inputs() const { return inputs_; }
+  std::vector& Inputs() { return inputs_; }
+
+  // Access input at given ind.
+  LiteRtTensorT& Input(size_t ind) { return *Inputs().at(ind); }
+  const LiteRtTensorT& Input(size_t ind) const { return *Inputs().at(ind); }
+
+  // Number of input tensors.
+  size_t NumInputs() const { return inputs_.size(); }
+
+  // Output tensors for this op.
+  const std::vector& Outputs() const { return outputs_; }
+  std::vector& Outputs() { return outputs_; }
+
+  // Number of output tensors.
+  size_t NumOutputs() const { return outputs_.size(); }
+
+  // Access output at given ind.
+  LiteRtTensorT& Output(size_t ind) { return *Outputs().at(ind); }
+  const LiteRtTensorT& Output(size_t ind) const { return *Outputs().at(ind); }
+
+  // Remove the ith entry of input list.
+  void RemoveInput(size_t ind) { inputs_.erase(inputs_.begin() + ind); }
+
+  // Remove the ith entry of output list.
+  void RemoveOutput(size_t ind) { outputs_.erase(outputs_.begin() + ind); }
+
+  // Get any custom options attached to this op. Empty if there are none.
+  litert::BufferRef CustomOptions() const { return custom_options_; }
+
+  // Attach custom opaque options to this op.
+  template <class... Args>
+  void SetCustomOptions(Args&&... args) {
+    custom_options_ =
+        ::litert::OwningBufferRef(std::forward<Args>(args)...);
+  }
+
+  // Sets the custom options to a zero-length buffer.
+  void ClearCustomOptions() { custom_options_.Reset(); }
+
+  // Get the op code.
+  LiteRtOpCode OpCode() const { return litert_op_code_; }
+
+  // Set the op code.
+  void SetOpCode(LiteRtOpCode litert_op_code) {
+    litert_op_code_ = litert_op_code;
+  }
+
+  // IR is generally default constructible and movable, but not copyable.
+  LiteRtOpT() = default;
+  LiteRtOpT(const LiteRtOpT&) = delete;
+  LiteRtOpT(LiteRtOpT&&) = default;
+  LiteRtOpT& operator=(const LiteRtOpT&) = delete;
+  LiteRtOpT& operator=(LiteRtOpT&&) = default;
+
+  // Friendship for internal tflite details.
+  friend void detail::SetTflOpCodeInd(LiteRtOpT& litert_op,
+                                      int32_t tfl_op_code_ind);
+
+  friend int32_t detail::GetTflOpCodeInd(const LiteRtOpT& litert_op);
+
+  template <class Arg>
+  friend void detail::SetTflOptions(LiteRtOpT& litert_op, Arg&& arg);
+
+  friend const ::litert::internal::TflOptions& detail::GetTflOptions(
+      const LiteRtOpT& litert_op);
+
+  friend ::litert::internal::TflOptions&& detail::TakeTflOptions(
+      LiteRtOpT& litert_op);
+
+ private:
+  LiteRtOpCode litert_op_code_;
+
+  ::litert::OwningBufferRef custom_options_;
+
+  std::vector inputs_;
+  std::vector outputs_;
+
+  // TFLITE
+  int32_t tfl_op_code_ind_ = detail::kDispatchOpCodeTflInd;
+  ::litert::internal::TflOptions tfl_option_;
+};
+
+//
+// Subgraph
+//
+
+// Fundamental block of a litert program. Manages the storage of all
+// ops and tensors within.
+class LiteRtSubgraphT {
+ public:
+  using Ref = std::reference_wrapper;
+  using Alloc = ::litert::internal::IrAllocator;
+
+  // Get a stable pointer for all of the tensors in this subgraph.
+  absl::Span Tensors() { return tensors_.Elements(); }
+  absl::Span Tensors() const { return tensors_.Elements(); }
+
+  // Access the tensor at given ind.
+  LiteRtTensorT& Tensor(size_t ind) { return *Tensors().at(ind); }
+  const LiteRtTensorT& Tensor(size_t ind) const { return *Tensors().at(ind); }
+
+  // Get a stable pointer for all of the ops in this subgraph. Will
+  // be a valid topological order.
+  absl::Span Ops() { return ops_.Elements(); }
+  absl::Span Ops() const { return ops_.Elements(); }
+
+  // Access op at the given ind.
+  LiteRtOpT& Op(size_t ind) { return *Ops().at(ind); }
+  const LiteRtOpT& Op(size_t ind) const { return *Ops().at(ind); }
+
+  // All the subgraph input tensors; these also exist in Tensors.
+  const std::vector& Inputs() const { return inputs_; }
+  std::vector& Inputs() { return inputs_; }
+
+  // Number of input tensors.
+  size_t NumInputs() const { return inputs_.size(); }
+
+  // Access the subgraph input at given ind.
+  LiteRtTensorT& Input(size_t ind) { return *Inputs().at(ind); }
+  const LiteRtTensorT& Input(size_t ind) const { return *Inputs().at(ind); }
+
+  // All the subgraph output tensors; these also exist in Tensors.
+  const std::vector& Outputs() const { return outputs_; }
+  std::vector& Outputs() { return outputs_; }
+
+  // Number of output tensors.
+  size_t NumOutputs() const { return outputs_.size(); }
+
+  // Access the subgraph output at given ind.
+  LiteRtTensorT& Output(size_t ind) { return *Outputs().at(ind); }
+  const LiteRtTensorT& Output(size_t ind) const { return *Outputs().at(ind); }
+
+  // Clear the entry for the ith input.
+  void ClearInput(size_t ind) { inputs_.erase(inputs_.begin() + ind); }
+
+  // Clear the entry for the ith output.
+  void ClearOutput(size_t ind) { outputs_.erase(outputs_.begin() + ind); }
+
+  // Construct a new tensor which will be owned by this subgraph and get a
+  // reference to it.
+  template <class... Args>
+  LiteRtTensorT& EmplaceTensor(Args&&... args) {
+    return tensors_.EmplaceBack(std::forward<Args>(args)...);
+  }
+
+  // Construct a new op which will be owned by this subgraph and get a
+  // reference to it.
+  template <class... Args>
+  LiteRtOpT& EmplaceOp(Args&&... args) {
+    return ops_.EmplaceBack(std::forward<Args>(args)...);
+  }
+
+  // De-allocates ops that pass given predicate. Returns number of ops removed.
+  size_t RemoveOpIf(std::function pred) {
+    return ops_.RemoveIf(pred);
+  }
+
+  // De-allocates tensors that pass given predicate. Returns number of tensors
+  // removed.
+  size_t RemoveTensorIf(std::function pred) {
+    return tensors_.RemoveIf(pred);
+  }
+
+  // IR is generally default constructible and movable, but not copyable.
+  LiteRtSubgraphT() = default;
+  LiteRtSubgraphT(const LiteRtSubgraphT&) = delete;
+  LiteRtSubgraphT(LiteRtSubgraphT&&) = default;
+  LiteRtSubgraphT& operator=(const LiteRtSubgraphT&) = delete;
+  LiteRtSubgraphT& operator=(LiteRtSubgraphT&&) = default;
+
+ private:
+  LiteRtTensorT::Alloc tensors_;
+
+  LiteRtOpT::Alloc ops_;
+
+  std::vector inputs_;
+  std::vector outputs_;
+};
+
+//
+// Signature
+//
+
+class LiteRtSignatureT {
+ private:
+  using StrVec = std::vector;
+
+ public:
+  using Ptr = std::unique_ptr;
+  using Ref = std::reference_wrapper;
+  using Alloc = ::litert::internal::IrAllocator;
+
+  static constexpr absl::string_view kDefaultSignatureKey =
+      "";
+
+  LiteRtSignatureT(LiteRtSubgraph subgraph, StrVec input_names,
+                   StrVec output_names, std::string key)
+      : key_(std::move(key)),
+        subgraph_(subgraph),
+        input_names_(std::move(input_names)),
+        output_names_(std::move(output_names)) {}
+
+  // String named inputs for called subgraph.
+  const StrVec& InputNames() const { return input_names_; }
+
+  // String named outputs for called subgraph.
+  const StrVec& OutputNames() const { return output_names_; }
+
+  // Get the callable subgraph.
+  const LiteRtSubgraphT& GetSubgraph() const { return *subgraph_; }
+  LiteRtSubgraphT& GetSubgraph() { return *subgraph_; }
+
+  // Name of the callable signature.
+  absl::string_view Key() const { return key_; }
+
+  bool operator==(const LiteRtSignatureT& other) const {
+    const auto key_eq = key_ == other.key_;
+    const auto subgraph_eq = subgraph_ == other.subgraph_;
+    const auto input_names_eq = input_names_ == other.input_names_;
+    const auto output_names_eq = output_names_ == other.output_names_;
+    return key_eq && subgraph_eq && input_names_eq && output_names_eq;
+  }
+
+  // IR is generally default constructible and movable, but not copyable.
+  LiteRtSignatureT() = default;
+  LiteRtSignatureT(const LiteRtSignatureT&) = delete;
+  LiteRtSignatureT(LiteRtSignatureT&&) = default;
+  LiteRtSignatureT& operator=(const LiteRtSignatureT&) = delete;
+  LiteRtSignatureT& operator=(LiteRtSignatureT&&) = default;
+
+ private:
+  std::string key_;
+
+  LiteRtSubgraph subgraph_;
+
+  StrVec input_names_;
+  StrVec output_names_;
+};
+
+// Make a basic signature from information in the given subgraph. Used with the
+// main subgraph when no explicit signatures have been authored.
+LiteRtSignatureT MakeDefaultSignature(LiteRtSubgraph subgraph);
+
+//
+// Model
+//
+
+// Root-level graph object for litert programs. Manages the storage
+// of all litert graph objects within.
+class LiteRtModelT {
+ private:
+  using MetadataMap =
+      absl::flat_hash_map>;
+
+ public:
+  using Ref = std::reference_wrapper;
+  using Ptr = std::unique_ptr;
+  using TflOpCodes = std::vector;
+
+  // TODO: replace this with the index of the default signature.
+  static constexpr const size_t kMainSubgraphIndex = 0;
+
+  // OBSERVERS
+
+  // Get a stable pointer for all of the subgraphs within this model.
+  absl::Span Subgraphs() { return subgraphs_.Elements(); }
+  absl::Span Subgraphs() const {
+    return subgraphs_.Elements();
+  }
+
+  // Access subgraph at given ind.
+  LiteRtSubgraphT& Subgraph(size_t ind) { return *Subgraphs().at(ind); }
+  const LiteRtSubgraphT& Subgraph(size_t ind) const {
+    return *Subgraphs().at(ind);
+  }
+
+  // Number of subgraphs.
+  size_t NumSubgraphs() const { return subgraphs_.Elements().size(); }
+
+  // Default entry point of this model.
+  const LiteRtSubgraphT* MainSubgraph() const {
+    return &Subgraph(kMainSubgraphIndex);
+  }
+  LiteRtSubgraph MainSubgraph() { return &Subgraph(kMainSubgraphIndex); }
+
+  // Look up signature by key.
+  litert::Expected FindSignature(
+      absl::string_view signature_key) const {
+    for (LiteRtSignature sig : signatures_.Elements()) {
+      if (sig->Key() == signature_key) {
+        return std::ref(*sig);
+      }
+    }
+    return ::litert::Error(kLiteRtStatusErrorNotFound, "Signature not found");
+  }
+
+  // All signatures registered with this model.
+  absl::Span Signatures() const {
+    return signatures_.Elements();
+  }
+
+  // Look up metadata by key, getting a view of its buffer as a string
+  // if it exists.
+  litert::Expected> FindMetadata(
+      absl::string_view key) const {
+    if (auto it = metadata_.find(key); it != metadata_.end()) {
+      return it->second;
+    }
+    return ::litert::Error(kLiteRtStatusErrorNotFound);
+  }
+
+  // Metadata key-val pair iterator.
+  MetadataMap::iterator MetadataBegin() { return metadata_.begin(); }
+  MetadataMap::iterator MetadataEnd() { return metadata_.end(); }
+
+  // Remove and take ownership of the metadata under the given key if it
+  // exists.
+  litert::Expected> PopMetadata(
+      absl::string_view key) {
+    if (auto it = metadata_.find(key); it != metadata_.end()) {
+      return metadata_.extract(it).mapped();
+    }
+    return ::litert::Error(kLiteRtStatusErrorNotFound);
+  }
+
+  // BUILDERS
+
+  // Build a new subgraph and get a stable reference to it.
+  template <class... Args>
+  LiteRtSubgraphT& EmplaceSubgraph(Args&&... args) {
+    return subgraphs_.EmplaceBack(std::forward<Args>(args)...);
+  }
+
+  // Transfers given subgraphs into this model.
+  void TransferSubgraphs(LiteRtSubgraphT::Alloc&& subgraphs) {
+    subgraphs_.Transfer(std::move(subgraphs));
+  }
+
+  // Cut all but the first `size` subgraphs. Does nothing if the given size is
+  // greater than or equal to the current number of subgraphs.
+  void ResizeSubgraphsDown(size_t size) { subgraphs_.ResizeDown(size); }
+
+  // Adds a new metadata buffer to the model. Fails if it already exists.
+  template <class... Args>
+  LiteRtStatus PushMetadata(absl::string_view key, Args&&... args) {
+    if (metadata_.contains(key)) {
+      return kLiteRtStatusErrorInvalidArgument;
+    }
+    metadata_.insert(
+        {std::string(key.begin(), key.end()),
+         ::litert::OwningBufferRef<uint8_t>(std::forward<Args>(args)...)});
+    return kLiteRtStatusOk;
+  }
+
+  // Construct a new signature for this model.
+  template <class... Args>
+  LiteRtSignatureT& EmplaceSignature(Args&&... args) {
+    return signatures_.EmplaceBack(std::forward<Args>(args)...);
+  }
+
+  // IR is generally default constructible and movable, but not copyable.
+  LiteRtModelT() = default;
+  LiteRtModelT(const LiteRtModelT&) = delete;
+  LiteRtModelT(LiteRtModelT&&) = default;
+  LiteRtModelT& operator=(const LiteRtModelT&) = delete;
+  LiteRtModelT& operator=(LiteRtModelT&&) = default;
+
+  // Friendship for internal tflite details.
+  friend const TflOpCodes& detail::GetTflOpCodes(
+      const LiteRtModelT& litert_model);
+
+  template <class Arg>
+  friend void detail::SetTflOpCodes(LiteRtModelT& litert_model, Arg&& arg);
+
+  friend TflOpCodes&& detail::TakeTflOpCodes(LiteRtModelT& litert_model);
+
+  friend void detail::SetTflInitFlatbuffer(
+      LiteRtModelT& litert_model,
+      ::litert::BufferRef<uint8_t> init_flatbuffer);
+
+  friend ::litert::BufferRef<uint8_t> detail::GetTflInitFlatbuffer(
+      const LiteRtModelT& litert_model);
+
+ private:
+  LiteRtSubgraphT::Alloc subgraphs_;
+  LiteRtSignatureT::Alloc signatures_;
+
+  MetadataMap metadata_;
+
+  // TFLITE
+  TflOpCodes tfl_operator_codes_;
+  litert::BufferRef<uint8_t> tfl_init_flatbuffer_;
+};
+
+// Lookup subgraph by signature name.
+::litert::Expected<LiteRtSubgraph> LookupSubgraph(
+    const LiteRtModelT& model, absl::string_view signature_key);
+
+//
+// Utils
+//
+
+// Used for communicating selections of ops.
+class LiteRtOpListT {
+ public:
+  void Push(LiteRtOp op) { ops_.push_back(op); }
+
+  std::vector<LiteRtOp> Vec() const {
+    std::vector<LiteRtOp> res;
+    res.reserve(ops_.size());
+    res.assign(ops_.begin(), ops_.end());
+    return res;
+  }
+
+ private:
+  // TODO: Investigate if this is possible with std::vector (hit some issues).
+  std::list<LiteRtOp> ops_;
+};
+
+namespace detail {
+
+template <class Arg>
+void SetTflOptions(LiteRtOpT& litert_op, Arg&& arg) {
+  litert_op.tfl_option_ = std::forward<Arg>(arg);
+}
+
+template <class Arg>
+void SetTflOpCodes(LiteRtModelT& litert_model, Arg&& arg) {
+  litert_model.tfl_operator_codes_ = std::forward<Arg>(arg);
+}
+
+}  // namespace detail
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_H_
diff --git a/tflite/experimental/litert/core/model/model_buffer.cc b/tflite/experimental/litert/core/model/model_buffer.cc
new file mode 100644
index 00000000..9a46357b
--- /dev/null
+++ b/tflite/experimental/litert/core/model/model_buffer.cc
@@ -0,0 +1,97 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/core/model/model_buffer.h"
+
+#include <cstdint>
+#include <cstring>
+#include <filesystem>  // NOLINT
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/log/absl_check.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_op_code.h"
+#include "tflite/experimental/litert/cc/litert_buffer_ref.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/cc/litert_macros.h"
+#include "tflite/experimental/litert/core/byte_code_util.h"
+#include "tflite/experimental/litert/core/filesystem.h"
+#include "tflite/experimental/litert/core/model/model.h"
+#include "tflite/experimental/litert/core/model/model_load.h"
+#include "tflite/experimental/litert/core/model/model_serialize.h"
+
+namespace litert {
+namespace internal {
+
+Expected<OwningBufferRef<uint8_t>> GetModelBufWithByteCode(
+    LiteRtModelT&& model, BufferRef<uint8_t> npu_byte_code) {
+  LITERT_EXPECT_OK(
+      model.PushMetadata(kByteCodeMetadataKey, MakeByteCodePlaceholder()));
+
+  for (auto* subgraph : model.Subgraphs()) {
+    for (auto* op : subgraph->Ops()) {
+      if (op->OpCode() != kLiteRtOpCodeTflCustom) {
+        continue;
+      }
+      auto exec_info =
+          MakeExecInfo(op->CustomOptions().StrView(), kByteCodeMetadataKey);
+      if (!exec_info) {
+        return exec_info.Error();
+      }
+      op->SetCustomOptions(std::move(*exec_info));
+    }
+  }
+
+  auto serialized = SerializeModel(std::move(model));
+  if (!serialized) {
+    return serialized;
+  }
+
+  LITERT_EXPECT_OK(
+      FinishByteCodePlaceholders(*serialized, npu_byte_code.Size()));
+
+  OwningBufferRef<uint8_t> with_append(serialized->Size() +
+                                       npu_byte_code.Size());
+
+  uint8_t* write = with_append.Data();
+  std::memcpy(write, serialized->Data(), serialized->Size());
+  write += serialized->Size();
+  std::memcpy(write, npu_byte_code.Data(), npu_byte_code.Size());
+
+  return with_append;
+}
+
+Expected<OwningBufferRef<uint8_t>> GetModelBufWithByteCode(
+    absl::string_view tfl_file, absl::string_view npu_file) {
+  auto model = LoadModelFromFile(tfl_file);
+  if (!model) {
+    return model.Error();
+  }
+
+  auto npu_file_buf = LoadBinaryFile(npu_file);
+  if (!npu_file_buf) {
+    return npu_file_buf.Error();
+  }
+
+  return GetModelBufWithByteCode(std::move(**model), std::move(*npu_file_buf));
+}
+
+}  // namespace internal
+}  // namespace litert
diff --git a/tflite/experimental/litert/core/model/model_buffer.h b/tflite/experimental/litert/core/model/model_buffer.h
new file mode 100644
index 00000000..5c196908
--- /dev/null
+++ b/tflite/experimental/litert/core/model/model_buffer.h
@@ -0,0 +1,37 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_BUFFER_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_BUFFER_H_
+
+#include "absl/strings/string_view.h"
+#include "tflite/experimental/litert/cc/litert_buffer_ref.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/core/model/model.h"
+
+namespace litert::internal {
+
+// Get a buffer that is the concatenation of given tflite file and
+// npu byte code file. Adds metadata containing the offset/size of npu byte
+// code.
+Expected<OwningBufferRef<uint8_t>> GetModelBufWithByteCode(
+    absl::string_view tfl_file, absl::string_view npu_file);
+
+// Same as above, but takes the litert model and npu byte code in memory.
+Expected<OwningBufferRef<uint8_t>> GetModelBufWithByteCode(
+    LiteRtModelT&& model, BufferRef<uint8_t> npu_byte_code);
+
+}  // namespace litert::internal
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_BUFFER_H_
diff --git a/tflite/experimental/litert/core/model/model_buffer_test.cc b/tflite/experimental/litert/core/model/model_buffer_test.cc
new file mode 100644
index 00000000..be13eeaf
--- /dev/null
+++ b/tflite/experimental/litert/core/model/model_buffer_test.cc
@@ -0,0 +1,72 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
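+//
+// Smoke tests for GetModelBufWithByteCode: the returned buffer must still
+// parse as a regular tflite model (interpreter creation succeeds) and must
+// carry the byte code metadata entry added during concatenation.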
+ +#include "tflite/experimental/litert/core/model/model_buffer.h" + +#include + +#include +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/mlir/lite/allocation.h" +#include "tflite/experimental/litert/core/byte_code_util.h" +#include "tflite/experimental/litert/core/model/model.h" +#include "tflite/experimental/litert/core/model/model_load.h" +#include "tflite/experimental/litert/test/common.h" +#include "tflite/experimental/litert/test/testdata/simple_model_test_vectors.h" +#include "tflite/interpreter.h" +#include "tflite/interpreter_builder.h" +#include "tflite/kernels/register.h" +#include "tflite/model_builder.h" +#include "tflite/stderr_reporter.h" + +namespace litert::internal { +namespace { + +static constexpr absl::string_view kNpuFile = kGoogleTensorModelFileName; +static constexpr absl::string_view kTfliteFile = "simple_model_npu.tflite"; + +TEST(GetModelBufWithByteCode, CreateInterpreter) { + auto model_with_byte_code = + GetModelBufWithByteCode(testing::GetTestFilePath(kTfliteFile), + testing::GetTestFilePath(kNpuFile)); + ASSERT_TRUE(model_with_byte_code); + + auto alloc = std::make_unique( + model_with_byte_code->Data(), model_with_byte_code->Size(), + tflite::DefaultErrorReporter()); + + auto fb_model = tflite::FlatBufferModel::BuildFromBuffer( + reinterpret_cast(alloc->base()), alloc->bytes()); + ASSERT_NE(fb_model, nullptr); + + tflite::ops::builtin::BuiltinOpResolver resolver; + std::unique_ptr interpreter; + tflite::InterpreterBuilder(*fb_model, resolver)(&interpreter); + EXPECT_NE(interpreter, nullptr); +} + +TEST(GetModelBufWithByteCode, CheckMetadata) { + auto model_with_byte_code = + GetModelBufWithByteCode(testing::GetTestFilePath(kTfliteFile), + testing::GetTestFilePath(kNpuFile)); + ASSERT_TRUE(model_with_byte_code); + + auto model = LoadModelFromBuffer(*model_with_byte_code); + + auto byte_code_buffer = model->get()->FindMetadata(kByteCodeMetadataKey); + ASSERT_TRUE(byte_code_buffer); +} + +} // namespace +} // namespace litert::internal diff --git a/tflite/experimental/litert/core/model/model_file_test.cc b/tflite/experimental/litert/core/model/model_file_test.cc new file mode 100644 index 00000000..31d8873a --- /dev/null +++ b/tflite/experimental/litert/core/model/model_file_test.cc @@ -0,0 +1,541 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include <cstddef>
+#include <filesystem>  // NOLINT
+#include <fstream>
+#include <functional>
+#include <memory>
+#include <string>
+#include <string_view>
+
+#include <gmock/gmock.h>  // IWYU pragma: keep
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/c/litert_op_code.h"
+#include "tflite/experimental/litert/cc/litert_buffer_ref.h"
+#include "tflite/experimental/litert/cc/litert_element_type.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/cc/litert_macros.h"
+#include "tflite/experimental/litert/cc/litert_model.h"
+#include "tflite/experimental/litert/cc/litert_model_predicates.h"
+#include "tflite/experimental/litert/core/model/graph_validation.h"
+#include "tflite/experimental/litert/core/model/model.h"
+#include "tflite/experimental/litert/core/model/model_file_test_util.h"
+#include "tflite/experimental/litert/core/model/model_load.h"
+#include "tflite/experimental/litert/core/model/model_serialize.h"
+#include "tflite/experimental/litert/core/util/flatbuffer_tools.h"
+#include "tflite/experimental/litert/test/common.h"
+#include "tflite/experimental/litert/test/test_macros.h"
+#include "tflite/experimental/litert/test/test_models.h"
+
+namespace litert::internal {
+namespace {
+
+using ::litert::testing::GetTestFilePath;
+using ::testing::Each;
+using ::testing::ElementsAreArray;
+using ::testing::FloatEq;
+using ::testing::Values;
+
+using ModelFactory = std::function<Expected<Model>()>;
+
+static constexpr absl::string_view kAddSimple = "add_simple.tflite";
+static constexpr absl::string_view kAddCst = "add_cst.tflite";
+static constexpr absl::string_view kDynamicShapeModel =
+    "dynamic_shape_tensor.tflite";
+static constexpr absl::string_view kSimpleMultiOp = "simple_multi_op.tflite";
+static constexpr absl::string_view kOneMul = "one_mul.tflite";
+static constexpr absl::string_view kSimpleMultiSubgraph =
+    "multi_subgraph.tflite";
+
+// Load a model, then serialize and re-load. Used to test serialization.
+Expected<Model> LoadModelThroughRoundTrip(absl::string_view filename) {
+  auto model = Model::CreateFromFile(GetTestFilePath(filename));
+  if (!model) {
+    return model.Error();
+  }
+
+  OwningBufferRef<uint8_t> buf;
+  auto [data, size, offset] = buf.GetWeak();
+
+  LITERT_EXPECT_OK(
+      LiteRtSerializeModel(model->Release(), &data, &size, &offset));
+
+  // Reload model.
+  LiteRtModel result = nullptr;
+  LITERT_EXPECT_OK(
+      LiteRtCreateModelFromBuffer(buf.Data(), buf.Size(), &result));
+
+  return Model::CreateFromOwnedHandle(result);
+}
+
+ModelFactory MakeRoundTripFactory(absl::string_view filename) {
+  return [=]() { return LoadModelThroughRoundTrip(filename); };
+}
+
+ModelFactory MakeLoadFactory(absl::string_view filename) {
+  return [=]() { return Model::CreateFromFile(GetTestFilePath(filename)); };
+}
+
+// Test fixture parameterized by a file path to a test model.
+class TestWithModelPath : public ::testing::TestWithParam<absl::string_view> {
+ protected:
+  std::string GetTestModelPath() const {
+    return testing::GetTestFilePath(GetParam());
+  }
+};
+
+// Test fixture parameterized by a function that loads a model.
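+// (Each suite built on this fixture is instantiated twice below: once with
+//  MakeLoadFactory to exercise the loader alone, and once with
+//  MakeRoundTripFactory to check that serialization preserves the structure.)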
+class TestWithModelFactory : public ::testing::TestWithParam<ModelFactory> {
+ protected:
+  Expected<Model> LoadModel() { return GetParam()(); }
+};
+
+// Simple tests
+//===---------------------------------------------------------------------------
+
+TEST(ModelLoadTest, BadFilepath) {
+  LiteRtModel model = nullptr;
+  LITERT_ASSERT_STATUS_HAS_CODE(LiteRtCreateModelFromFile("bad_path", &model),
+                                kLiteRtStatusErrorFileIO);
+}
+
+TEST(ModelLoadTest, BadFileData) {
+  // NOLINTBEGIN
+#ifndef NDEBUG
+  // In debug mode, flatbuffers will `assert` while verifying. This will
+  // cause this test to crash (as expected).
+  GTEST_SKIP();
+#endif
+  std::filesystem::path test_file_path(::testing::TempDir());
+  test_file_path.append("bad_file.txt");
+
+  std::ofstream bad_file;
+  bad_file.open(test_file_path.c_str());
+  bad_file << "not_tflite";
+  bad_file.close();
+
+  LiteRtModel model = nullptr;
+  LITERT_ASSERT_STATUS_HAS_CODE(
+      LiteRtCreateModelFromFile(test_file_path.c_str(), &model),
+      kLiteRtStatusErrorInvalidFlatbuffer);
+  // NOLINTEND
+}
+
+TEST(ModelLoadTest, WithMetadata) {
+  constexpr static std::string_view kMetadataName = "an_soc_manufacturer";
+  constexpr static std::string_view kMetadataData = "My_Meta_Data";
+
+  auto flatbuffer =
+      FlatbufferWrapper::CreateFromTflFile(GetTestFilePath(kAddSimple));
+  auto tfl_model = flatbuffer->get()->Unpack();
+  PushMetadata(kMetadataName, *tfl_model,
+               BufferRef<uint8_t>(kMetadataData.data(), kMetadataData.size()));
+  auto serialized = SerializeFlatbuffer(*tfl_model);
+
+  auto litert_model = LoadModelFromBuffer(serialized);
+  ASSERT_TRUE(litert_model);
+
+  auto metadata = litert_model->get()->FindMetadata(kMetadataName);
+  ASSERT_TRUE(metadata);
+  EXPECT_EQ(metadata->StrView(), kMetadataData);
+}
+
+TEST(ModelSerializeTest, WithMetadata) {
+  auto model = litert::testing::LoadTestFileModel(kAddSimple);
+
+  constexpr static absl::string_view kMetadataName = "an_soc_manufacturer";
+  constexpr static absl::string_view kMetadataData = "My_Meta_Data";
+
+  LITERT_ASSERT_STATUS_OK(model.Get()->PushMetadata(
+      kMetadataName, OwningBufferRef<uint8_t>(kMetadataData)));
+
+  auto serialized = SerializeModel(std::move(*model.Get()));
+  EXPECT_TRUE(VerifyFlatbuffer(serialized->Span()));
+
+  auto re_loaded = LoadModelFromBuffer(*serialized);
+  auto metadata = re_loaded->get()->FindMetadata(kMetadataName);
+  EXPECT_EQ(metadata->StrView(), kMetadataData);
+}
+
+TEST(ModelLoadTest, WithSignature) {
+  auto model = litert::testing::LoadTestFileModel(kAddSimple);
+  auto& litert_model = *model.Get();
+
+  auto signature =
+      litert_model.FindSignature(LiteRtSignatureT::kDefaultSignatureKey);
+  ASSERT_TRUE(signature);
+
+  EXPECT_EQ(signature->get().InputNames().size(), 1);
+  EXPECT_EQ(signature->get().OutputNames().size(), 1);
+  EXPECT_EQ(&signature->get().GetSubgraph(), litert_model.MainSubgraph());
+}
+
+TEST(ModelSerializeTest, WithSignature) {
+  auto model = litert::testing::LoadTestFileModel(kAddSimple);
+  auto& litert_model = *model.Get();
+
+  static constexpr char kInput[] = "foo";
+  static constexpr char kOutput[] = "bar";
+  static constexpr char kKey[] = "newKey";
+
+  LiteRtSignatureT signature(litert_model.MainSubgraph(), {kInput}, {kOutput},
+                             kKey);
+  litert_model.EmplaceSignature(std::move(signature));
+
+  auto serialized = SerializeModel(std::move(*model.Get()));
+  EXPECT_TRUE(VerifyFlatbuffer(serialized->Span()));
+
+  auto re_loaded = LoadModelFromBuffer(*serialized);
+  auto re_loaded_signature = re_loaded->get()->FindSignature(kKey);
+  ASSERT_TRUE(re_loaded_signature);
+  const auto& sig = re_loaded_signature->get();
re_loaded_signature->get(); + + const auto& inputs = sig.InputNames(); + const auto& outputs = sig.OutputNames(); + EXPECT_THAT(inputs, ElementsAreArray({kInput})); + EXPECT_THAT(outputs, ElementsAreArray({kOutput})); + EXPECT_EQ(&sig.GetSubgraph(), re_loaded->get()->MainSubgraph()); +} + +// Tests that explicitly check litert graph structure. +//===--------------------------------------------------------------------------- + +using AddSimpleTest = TestWithModelFactory; + +TEST_P(AddSimpleTest, CheckGraph) { + auto model = LoadModel(); + ASSERT_TRUE(model); + + // func(arg0) + // output = tfl.add(arg0, arg0) + // return(output) + // + + auto subgraph = model->MainSubgraph(); + const auto subgraph_inputs = subgraph->Inputs(); + const auto subgraph_outputs = subgraph->Outputs(); + const auto ops = subgraph->Ops(); + + ASSERT_EQ(subgraph_inputs.size(), 1); + ASSERT_EQ(subgraph_outputs.size(), 1); + + const auto& internal_ops = subgraph->Get()->Ops(); + ASSERT_TRUE( + ValidateLocalTopology(internal_ops.cbegin(), internal_ops.cend())); + ASSERT_TRUE(ValidateSubgraphIO(*subgraph->Get())); + + ASSERT_EQ(ops.size(), 1); + const auto& op = ops.front(); + + const TensorTypeInfo float_2by2_type(ElementType::Float32, {2, 2}); + ASSERT_TRUE( + MatchOpType(op, {float_2by2_type, float_2by2_type}, {float_2by2_type})); + EXPECT_EQ(op.Code(), kLiteRtOpCodeTflAdd); + + const auto op_inputs = op.Inputs(); + ASSERT_EQ(op_inputs.size(), 2); + ASSERT_EQ(op_inputs.front().Get(), subgraph_inputs.front().Get()); + ASSERT_EQ(op_inputs.front().Get(), op_inputs.back().Get()); + + const auto op_outputs = op.Outputs(); + ASSERT_EQ(op_outputs.size(), 1); + ASSERT_EQ(op_outputs.front().Get(), subgraph_outputs.front().Get()); + + ASSERT_FALSE(subgraph_outputs.front().IsConstant()); + ASSERT_FALSE(subgraph_inputs.front().IsConstant()); +} + +INSTANTIATE_TEST_SUITE_P(ModelLoadTests, AddSimpleTest, + Values(MakeLoadFactory(kAddSimple))); + +INSTANTIATE_TEST_SUITE_P(ModelSerializeTests, AddSimpleTest, + Values(MakeRoundTripFactory(kAddSimple))); + +using AddCstTest = TestWithModelFactory; + +TEST_P(AddCstTest, CheckGraph) { + auto model = LoadModel(); + ASSERT_TRUE(model); + + // func(arg0) + // cst = ConstantTensor([1, 2, 3, 4]) + // output = tfl.add(arg0, cst) + // return(output) + // + + auto subgraph = model->MainSubgraph(); + const auto subgraph_inputs = subgraph->Inputs(); + const auto subgraph_outputs = subgraph->Outputs(); + const auto ops = subgraph->Ops(); + + ASSERT_EQ(subgraph_inputs.size(), 1); + ASSERT_EQ(subgraph_outputs.size(), 1); + + const auto& internal_ops = subgraph->Get()->Ops(); + ASSERT_TRUE( + ValidateLocalTopology(internal_ops.cbegin(), internal_ops.cend())); + ASSERT_TRUE(ValidateSubgraphIO(*subgraph->Get())); + + ASSERT_EQ(ops.size(), 1); + const auto& op = ops.front(); + + const TensorTypeInfo float_by4_type(ElementType::Float32, {4}); + ASSERT_TRUE( + MatchOpType(op, {float_by4_type, float_by4_type}, {float_by4_type})); + EXPECT_EQ(op.Code(), kLiteRtOpCodeTflAdd); + + const auto op_inputs = op.Inputs(); + ASSERT_EQ(op_inputs.size(), 2); + ASSERT_EQ(op_inputs.front().Get(), subgraph_inputs.front().Get()); + ASSERT_TRUE(MatchWeights(op_inputs.back(), + absl::Span({1.0, 2.0, 3.0, 4.0}))); + + const auto op_outputs = op.Outputs(); + ASSERT_EQ(op_outputs.size(), 1); + ASSERT_EQ(op_outputs.front().Get(), subgraph_outputs.front().Get()); + + ASSERT_FALSE(subgraph_outputs.front().IsConstant()); + ASSERT_FALSE(subgraph_inputs.front().IsConstant()); +} + +INSTANTIATE_TEST_SUITE_P(ModelLoadTests, AddCstTest, 
+                         Values(MakeLoadFactory(kAddCst)));
+
+INSTANTIATE_TEST_SUITE_P(ModelSerializeTests, AddCstTest,
+                         Values(MakeRoundTripFactory(kAddCst)));
+
+using SimpleMultiOpTest = TestWithModelFactory;
+
+TEST_P(SimpleMultiOpTest, CheckGraph) {
+  auto model = LoadModel();
+  ASSERT_TRUE(model);
+
+  // func.func @main(arg0)
+  //   0 = tfl.add arg0, arg0
+  //   1 = tfl.mul 0, 0
+  //   2 = tfl.mul 1, 1
+  //   3 = tfl.add 2, 2
+  //   return 3
+
+  auto subgraph = model->MainSubgraph();
+  const auto subgraph_inputs = subgraph->Inputs();
+  const auto subgraph_outputs = subgraph->Outputs();
+  const auto ops = subgraph->Ops();
+
+  ASSERT_EQ(subgraph_inputs.size(), 1);
+  ASSERT_EQ(subgraph_outputs.size(), 1);
+
+  const auto& internal_ops = subgraph->Get()->Ops();
+  ASSERT_TRUE(
+      ValidateLocalTopology(internal_ops.cbegin(), internal_ops.cend()));
+  ASSERT_TRUE(ValidateSubgraphIO(*subgraph->Get()));
+
+  ASSERT_EQ(ops.size(), 4);
+
+  for (const auto& op : ops) {
+    const auto inputs = op.Inputs();
+    ASSERT_EQ(inputs.size(), 2);
+    ASSERT_EQ(inputs.front().Get(), inputs.back().Get());
+  }
+
+  const TensorTypeInfo float_2by2_type(ElementType::Float32, {2, 2});
+
+  ASSERT_TRUE(MatchOpType(ops.at(2), {float_2by2_type, float_2by2_type},
+                          {float_2by2_type}));
+  EXPECT_EQ(ops.at(2).Code(), kLiteRtOpCodeTflMul);
+}
+
+INSTANTIATE_TEST_SUITE_P(ModelLoadTests, SimpleMultiOpTest,
+                         Values(MakeLoadFactory(kSimpleMultiOp)));
+
+INSTANTIATE_TEST_SUITE_P(ModelSerializeTests, SimpleMultiOpTest,
+                         Values(MakeRoundTripFactory(kSimpleMultiOp)));
+
+using SimpleMultiSubgraphTest = TestWithModelFactory;
+
+TEST_P(SimpleMultiSubgraphTest, CheckGraph) {
+  auto model_wrap = LoadModel();
+  ASSERT_TRUE(model_wrap);
+  auto& model = *model_wrap->Get();
+
+  ASSERT_EQ(model.NumSubgraphs(), 3);
+
+  {
+    auto& main = *model.MainSubgraph();
+    EXPECT_EQ(main.NumInputs(), 1);
+    EXPECT_EQ(main.NumOutputs(), 1);
+    EXPECT_EQ(main.Ops().size(), 1);
+    EXPECT_EQ(main.Tensors().size(), 3);
+    auto& op = main.Op(0);
+    auto* cst = op.Inputs().back();
+    auto data = Tensor(cst).WeightsData<float>();
+    ASSERT_TRUE(data);
+    EXPECT_THAT(*data, Each(FloatEq(-1.0)));
+    EXPECT_TRUE(ValidateLocalTopology(main.Ops().cbegin(), main.Ops().cend()));
+    EXPECT_TRUE(ValidateSubgraphIO(main));
+  }
+
+  {
+    auto& func1 = model.Subgraph(1);
+    EXPECT_EQ(func1.NumInputs(), 1);
+    EXPECT_EQ(func1.NumOutputs(), 1);
+    EXPECT_EQ(func1.Ops().size(), 1);
+    EXPECT_EQ(func1.Tensors().size(), 3);
+    auto& op = func1.Op(0);
+    auto* cst = op.Inputs().back();
+    auto data = Tensor(cst).WeightsData<float>();
+    ASSERT_TRUE(data);
+    EXPECT_THAT(*data, Each(FloatEq(1.0)));
+    EXPECT_TRUE(
+        ValidateLocalTopology(func1.Ops().cbegin(), func1.Ops().cend()));
+    EXPECT_TRUE(ValidateSubgraphIO(func1));
+  }
+
+  {
+    auto& func2 = model.Subgraph(2);
+    EXPECT_EQ(func2.NumInputs(), 1);
+    EXPECT_EQ(func2.NumOutputs(), 1);
+    EXPECT_EQ(func2.Ops().size(), 1);
+    EXPECT_EQ(func2.Tensors().size(), 3);
+    auto& op = func2.Op(0);
+    auto* cst = op.Inputs().back();
+    auto data = Tensor(cst).WeightsData<float>();
+    ASSERT_TRUE(data);
+    EXPECT_THAT(*data, Each(FloatEq(2.0)));
+    EXPECT_TRUE(
+        ValidateLocalTopology(func2.Ops().cbegin(), func2.Ops().cend()));
+    EXPECT_TRUE(ValidateSubgraphIO(func2));
+  }
+}
+
+INSTANTIATE_TEST_SUITE_P(ModelLoadTests, SimpleMultiSubgraphTest,
+                         Values(MakeLoadFactory(kSimpleMultiSubgraph)));
+
+INSTANTIATE_TEST_SUITE_P(ModelSerializeTests, SimpleMultiSubgraphTest,
+                         Values(MakeRoundTripFactory(kSimpleMultiSubgraph)));
+
+// Tests that programmatically check litert against tflite models.
+//===--------------------------------------------------------------------------- + +using ModelLoadOpCheckTest = TestWithModelPath; + +TEST_P(ModelLoadOpCheckTest, CheckOps) { + const auto model_path = GetTestModelPath(); + + auto flatbuffer = FlatbufferWrapper::CreateFromTflFile(model_path); + ASSERT_TRUE(flatbuffer); + auto expected_fb = flatbuffer->get()->Unpack(); + + auto model = LoadModelFromFile(model_path); + ASSERT_TRUE(model); + + const auto* subgraph = model->get()->MainSubgraph(); + const auto& ops = subgraph->Ops(); + + const auto& fb_subgraph = *expected_fb->subgraphs.front(); + const auto& fb_ops = fb_subgraph.operators; + const auto& fb_tensors = fb_subgraph.tensors; + + ASSERT_EQ(ops.size(), fb_ops.size()); + + auto get_tfl_tensor = [&](uint32_t ind) -> const TflTensor& { + return *fb_tensors.at(ind); + }; + + for (auto i = 0; i < ops.size(); ++i) { + ASSERT_TRUE(EqualsFbOp(*ops.at(i), *fb_ops.at(i), get_tfl_tensor)); + } +} + +INSTANTIATE_TEST_SUITE_P(ModelLoadQuantizedOpCheckTest, ModelLoadOpCheckTest, + ::testing::ValuesIn(kAllQModels)); + +INSTANTIATE_TEST_SUITE_P(ModelLoadDynamicOpCheckTest, ModelLoadOpCheckTest, + ::testing::ValuesIn({kDynamicShapeModel})); + +using ModelSerializeOpCheckTest = TestWithModelPath; + +TEST_P(ModelSerializeOpCheckTest, CheckOps) { + const auto model_path = GetTestModelPath(); + + // Save the initial fb for comparison. + auto expected_fb_data = FlatbufferWrapper::CreateFromTflFile(model_path); + ASSERT_TRUE(expected_fb_data); + auto expected_fb = expected_fb_data->get()->Unpack(); + + // Round trip the model. + auto model = LoadModelFromFile(model_path); + ASSERT_TRUE(model); + auto serialized = SerializeModel(std::move(**model)); + + auto actual_fb_data = FlatbufferWrapper::CreateFromBuffer(*serialized); + ASSERT_TRUE(actual_fb_data); + auto actual_fb = actual_fb_data->get()->Unpack(); + + const auto& expected_fb_subgraph = *expected_fb->subgraphs.front(); + const auto& expected_fb_ops = expected_fb_subgraph.operators; + const auto& expected_fb_tensors = expected_fb_subgraph.tensors; + + const auto& actual_fb_subgraph = *actual_fb->subgraphs.front(); + const auto& actual_fb_ops = actual_fb_subgraph.operators; + const auto& actual_fb_tensors = actual_fb_subgraph.tensors; + + ASSERT_EQ(expected_fb_ops.size(), actual_fb_ops.size()); + for (auto i = 0; i < actual_fb_ops.size(); ++i) { + const auto& expected = *expected_fb_ops.at(i); + const auto& actual = *actual_fb_ops.at(i); + EXPECT_EQ(expected.inputs.size(), actual.inputs.size()); + EXPECT_EQ(expected.outputs.size(), actual.outputs.size()); + } + + ASSERT_EQ(expected_fb_tensors.size(), actual_fb_tensors.size()); + for (auto i = 0; i < actual_fb_tensors.size(); ++i) { + const auto& expected = *expected_fb_tensors.at(i); + const auto& actual = *actual_fb_tensors.at(i); + + EXPECT_EQ(actual.type, expected.type); + EXPECT_EQ(actual.shape, expected.shape); + EXPECT_EQ(actual.shape_signature, expected.shape_signature); + + const auto expected_q_params = expected.quantization.get(); + const auto actual_q_params = actual.quantization.get(); + + const auto neither_quantized = + !IsQuantized(expected_q_params) && !IsQuantized(actual_q_params); + const auto both_per_tensor = IsPerTensorQuantized(expected_q_params) && + IsPerTensorQuantized(actual_q_params); + ASSERT_TRUE(neither_quantized || both_per_tensor); + + if (both_per_tensor) { + const auto expected_per_tensor = AsPerTensorQparams(expected_q_params); + const auto actual_per_tensor = AsPerTensorQparams(actual_q_params); + 
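+      // AsPerTensorQparams yields a (zero_point, scale) pair, so a single
+      // EXPECT_EQ below compares both values at once.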
+      EXPECT_EQ(*expected_per_tensor, *actual_per_tensor);
+    }
+  }
+}
+
+INSTANTIATE_TEST_SUITE_P(ModelSerializeOpCheckTest, ModelSerializeOpCheckTest,
+                         ::testing::ValuesIn({kOneMul, kDynamicShapeModel}));
+
+INSTANTIATE_TEST_SUITE_P(ModelSerializeQuantizedOpCheckTest,
+                         ModelSerializeOpCheckTest,
+                         ::testing::ValuesIn(kAllQModels));
+
+}  // namespace
+}  // namespace litert::internal
diff --git a/tflite/experimental/litert/core/model/model_file_test_util.cc b/tflite/experimental/litert/core/model/model_file_test_util.cc
new file mode 100644
index 00000000..f7ec6a1c
--- /dev/null
+++ b/tflite/experimental/litert/core/model/model_file_test_util.cc
@@ -0,0 +1,181 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/core/model/model_file_test_util.h"
+
+#include <algorithm>
+
+#include "absl/types/span.h"
+#include "tflite/experimental/litert/c/litert_logging.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/cc/litert_detail.h"
+#include "tflite/experimental/litert/core/model/flatbuffer_to_litert.h"
+#include "tflite/experimental/litert/core/model/model.h"
+#include "tflite/experimental/litert/core/util/flatbuffer_tools.h"
+
+namespace litert::internal {
+
+namespace {
+
+template <class LiteRtQType>
+bool EqualsFbQuantizationDetail(LiteRtQType litert_quantization,
+                                const TflQuantization* tfl_quantization) {
+  return false;
+}
+
+template <>
+bool EqualsFbQuantizationDetail(
+    LiteRtQuantizationPerTensor litert_quantization,
+    const TflQuantization* tfl_quantization) {
+  auto tfl_q_params = AsPerTensorQparams(tfl_quantization);
+  if (!tfl_q_params) return false;
+  return litert_quantization.zero_point == tfl_q_params->first &&
+         litert_quantization.scale == tfl_q_params->second;
+}
+
+template <>
+bool EqualsFbQuantizationDetail(
+    LiteRtQuantizationPerChannel litert_quantization,
+    const TflQuantization* tfl_quantization) {
+  auto tfl_q_params = AsPerChannelQparams(tfl_quantization);
+  if (!tfl_q_params) return false;
+  const auto& [quantized_dimension, num_channels, zero_points, scales] =
+      *tfl_q_params;
+  const auto qd_eq =
+      litert_quantization.quantized_dimension == quantized_dimension;
+  const auto num_chan_eq = litert_quantization.num_channels == num_channels;
+  const auto zeros_eq = std::equal(zero_points.begin(), zero_points.end(),
+                                   litert_quantization.zero_points);
+  const auto scales_eq =
+      std::equal(scales.begin(), scales.end(), litert_quantization.scales);
+  return qd_eq && num_chan_eq && zeros_eq && scales_eq;
+}
+
+template <class LiteRtTensorType>
+bool EqualsFbTensorTypeDetail(LiteRtTensorType litert_tensor_type,
+                              const TflTensorType& tfl_tensor) {
+  LITERT_LOG(LITERT_ERROR, "LiteRtTensorType not supported");
+  return false;
+}
+
+template <>
+bool EqualsFbTensorTypeDetail(
+    LiteRtRankedTensorType litert_tensor_type,
+    const TflTensorType& tfl_tensor_type) {
+  auto tfl_shape = AsDynamicShape(tfl_tensor_type.second);
+  if (!tfl_shape) {
+    LITERT_LOG(LITERT_ERROR, "Not ranked shape");
+    return false;
+  }
+
+  if (MapElementType(tfl_tensor_type.first) !=
+      static_cast<LiteRtElementType>(litert_tensor_type.element_type)) {
+    LITERT_LOG(LITERT_ERROR, "Element type not equal");
+    return false;
+  }
+
+  auto same_or_both_dyn = [](auto l, auto r) {
+    const auto same_static = l >= 0 && l == r;
+    const auto both_dyn = l < 0 && r < 0;
+    return same_static || both_dyn;
+  };
+
+  auto& layout = litert_tensor_type.layout;
+  const bool shape_eq =
+      AllZip(*tfl_shape, absl::MakeConstSpan(layout.dimensions, layout.rank),
+             same_or_both_dyn);
+  if (!shape_eq) {
+    LITERT_LOG(LITERT_ERROR, "Shapes are not equal");
+    return false;
+  }
+
+  return true;
+}
+
+}  // namespace
+
+bool EqualsFbQuantization(const Quantization& litert_quantization,
+                          const TflQuantization* tfl_quantization) {
+  switch (litert_quantization.first) {
+    case kLiteRtQuantizationPerTensor:
+      return EqualsFbQuantizationDetail(litert_quantization.second.per_tensor,
+                                        tfl_quantization);
+    case kLiteRtQuantizationPerChannel:
+      return EqualsFbQuantizationDetail(litert_quantization.second.per_channel,
+                                        tfl_quantization);
+    case kLiteRtQuantizationNone:
+      return !IsQuantized(tfl_quantization);
+    default:
+      // Not implemented yet.
+      return false;
+  }
+}
+
+// Compare tensor type within litert tensor to the type within flatbuffer
+// tensor.
+bool EqualsFbTensorType(const TensorType& litert_tensor_type,
+                        const TflTensorType& tfl_tensor_type) {
+  switch (litert_tensor_type.first) {
+    case kLiteRtRankedTensorType:
+      return EqualsFbTensorTypeDetail(
+          litert_tensor_type.second.ranked_tensor_type, tfl_tensor_type);
+    default:
+      LITERT_LOG(LITERT_ERROR, "Tensor kind not supported");
+      // Not implemented yet.
+      return false;
+  }
+}
+
+bool EqualsFbTensor(const LiteRtTensorT& litert_tensor,
+                    const TflTensor& tfl_tensor) {
+  if (!EqualsFbTensorType(litert_tensor.Type(),
+                          {tfl_tensor.type, TflShapeInfo(tfl_tensor)})) {
+    LITERT_LOG(LITERT_ERROR, "Tensor not same type");
+    return false;
+  }
+
+  if (!EqualsFbQuantization(litert_tensor.Qparams(),
+                            tfl_tensor.quantization.get())) {
+    LITERT_LOG(LITERT_ERROR, "Tensor not same quantization");
+    return false;
+  }
+
+  return true;
+}
+
+bool EqualsFbOp(const LiteRtOpT& litert_op, const TflOp& tfl_op,
+                GetTflTensor get_tfl_tensor) {
+  auto check_tensors = [&](auto& litert_tensors, auto& tfl_tensors) {
+    if (litert_tensors.size() != tfl_tensors.size()) {
+      LITERT_LOG(LITERT_ERROR, "Tensors not same size");
+      return false;
+    }
+
+    for (auto i = 0; i < litert_tensors.size(); ++i) {
+      const auto& fb_tensor = get_tfl_tensor(tfl_tensors.at(i)).get();
+      const auto& litert_tensor = *litert_tensors.at(i);
+
+      if (!EqualsFbTensor(litert_tensor, fb_tensor)) {
+        LITERT_LOG(LITERT_ERROR, "Tensor %d not same", i);
+        return false;
+      }
+    }
+
+    return true;
+  };
+
+  return check_tensors(litert_op.Inputs(), tfl_op.inputs) &&
+         check_tensors(litert_op.Outputs(), tfl_op.outputs);
+}
+
+}  // namespace litert::internal
diff --git a/tflite/experimental/litert/core/model/model_file_test_util.h b/tflite/experimental/litert/core/model/model_file_test_util.h
new file mode 100644
index 00000000..5590f039
--- /dev/null
+++ b/tflite/experimental/litert/core/model/model_file_test_util.h
@@ -0,0 +1,50 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_FILE_TEST_UTIL_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_FILE_TEST_UTIL_H_
+
+#include <functional>
+
+#include "tflite/experimental/litert/core/model/model.h"
+#include "tflite/experimental/litert/core/util/flatbuffer_tools.h"
+
+namespace litert::internal {
+
+// Callback to get a tfl tensor from its index.
+using GetTflTensor =
+    std::function<std::reference_wrapper<const TflTensor>(uint32_t ind)>;
+
+// Compare q-params for having the same type and values.
+bool EqualsFbQuantization(const Quantization& litert_quantization,
+                          const TflQuantization* tfl_quantization);
+
+// Compare tensor types for having the same shape and element type.
+bool EqualsFbTensorType(const TensorType& litert_tensor_type,
+                        const TflTensorType& tfl_tensor_type);
+
+// Compare litert op to flatbuffer op along with their input/output tensor
+// types and quantization. Takes a callback to look up tfl tensors from the
+// indices within the tfl op.
+bool EqualsFbOp(const LiteRtOpT& litert_op, const TflOp& tfl_op,
+                GetTflTensor get_tfl_tensor);
+
+// Compare litert tensor to flatbuffer tensor for having same types and
+// quantization.
+bool EqualsFbTensor(const LiteRtTensorT& litert_tensor,
+                    const TflTensor& tfl_tensor);
+
+}  // namespace litert::internal
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_FILE_TEST_UTIL_H_
diff --git a/tflite/experimental/litert/core/model/model_graph.cc b/tflite/experimental/litert/core/model/model_graph.cc
new file mode 100644
index 00000000..479a0a4a
--- /dev/null
+++ b/tflite/experimental/litert/core/model/model_graph.cc
@@ -0,0 +1,181 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/core/model/model_graph.h"
+
+#include <optional>
+#include <utility>
+
+#include "absl/log/absl_check.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/cc/litert_detail.h"
+#include "tflite/experimental/litert/core/model/model.h"
+
+namespace litert::internal {
+
+namespace {
+
+bool IsOpDead(const LiteRtOpT& op) {
+  return op.Inputs().empty() && op.Outputs().empty();
+}
+
+bool IsTensorDead(const LiteRtTensorT& tensor) {
+  return tensor.DefiningOp() == nullptr && tensor.NumUses() == 0;
+}
+
+}  // namespace
+
+void CloneTo(const LiteRtTensorT& src, LiteRtTensorT& dest) {
+  dest.SetName({src.Name().cbegin(), src.Name().cend()});
+  dest.SetQarams(src.Qparams());
+  dest.SetType(src.Type());
+  // TODO: b/383906683 Avoid copying for better performance.
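+  // Note: this deep-copies the weight buffer (see the TODO above); for
+  // non-constant tensors the source buffer is empty, so the copy is trivial.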
+  dest.Weights().SetFromBuf(src.Weights().Buf());
+}
+
+void CloneTo(const LiteRtOpT& src, LiteRtOpT& dest) {
+  dest.SetCustomOptions(src.CustomOptions().Data(), src.CustomOptions().Size());
+  detail::SetTflOptions(dest, detail::GetTflOptions(src));
+  detail::SetTflOpCodeInd(dest, detail::GetTflOpCodeInd(src));
+  dest.SetOpCode(src.OpCode());
+}
+
+LiteRtTensorT& MakeClone(LiteRtSubgraphT& parent, const LiteRtTensorT& src) {
+  auto& new_tensor = parent.EmplaceTensor();
+  CloneTo(src, new_tensor);
+  return new_tensor;
+}
+
+LiteRtOpT& MakeClone(LiteRtSubgraphT& parent, const LiteRtOpT& src) {
+  auto& new_op = parent.EmplaceOp();
+  CloneTo(src, new_op);
+  return new_op;
+}
+
+std::optional<LiteRtParamIndex> FindInput(const LiteRtOpT& op,
+                                          const LiteRtTensorT& tensor) {
+  return FindInd(op.Inputs().cbegin(), op.Inputs().cend(), &tensor);
+}
+
+std::optional<LiteRtParamIndex> FindOutput(const LiteRtOpT& op,
+                                           const LiteRtTensorT& tensor) {
+  return FindInd(op.Outputs().cbegin(), op.Outputs().cend(), &tensor);
+}
+
+std::optional<LiteRtParamIndex> FindInput(const LiteRtSubgraphT& subgraph,
+                                          const LiteRtTensorT& tensor) {
+  return FindInd(subgraph.Inputs().cbegin(), subgraph.Inputs().cend(),
+                 &tensor);
+}
+
+std::optional<LiteRtParamIndex> FindOutput(const LiteRtSubgraphT& subgraph,
+                                           const LiteRtTensorT& tensor) {
+  return FindInd(subgraph.Outputs().cbegin(), subgraph.Outputs().cend(),
+                 &tensor);
+}
+
+SmallVec<LiteRtParamIndex> FindUseInds(const LiteRtTensorT& tensor,
+                                       const LiteRtOpT& op) {
+  SmallVec<LiteRtParamIndex> res;
+  for (auto i = 0; i < tensor.NumUses(); ++i) {
+    if (tensor.Users().at(i) == &op) {
+      res.push_back(i);
+    }
+  }
+  return res;
+}
+
+bool IsConstant(const LiteRtTensorT& tensor) {
+  const auto is_const = tensor.Weights().Buf().Size() > 0;
+  ABSL_DCHECK(!is_const || tensor.DefiningOp() == nullptr)
+      << "Constant tensors should not be defined by an op";
+  return is_const;
+}
+
+void AttachInput(LiteRtTensor tensor, LiteRtOpT& op) {
+  op.Inputs().push_back(tensor);
+  tensor->Users().push_back(&op);
+  tensor->UserArgInds().push_back(op.Inputs().size() - 1);
+}
+
+void AttachOutput(LiteRtTensor tensor, LiteRtOpT& op) {
+  ABSL_DCHECK(tensor->DefiningOp() == nullptr)
+      << "Cannot add an already defined tensor as op output";
+  op.Outputs().push_back(tensor);
+  tensor->SetDefiningOp(op, op.Outputs().size() - 1);
+}
+
+LiteRtTensor DisconnectInput(LiteRtOpT& op, LiteRtParamIndex input_ind) {
+  ABSL_DCHECK(input_ind < op.Inputs().size()) << "Removing tensor index oob";
+  auto& input = op.Input(input_ind);
+
+  // Find the index of the use for the given in edge.
+  auto target_use_ind = -1;
+  for (auto i = 0; i < input.NumUses(); ++i) {
+    if (input.Users().at(i) == &op && input.UserArgInds().at(i) == input_ind) {
+      target_use_ind = i;
+    }
+  }
+  ABSL_DCHECK_GE(target_use_ind, 0) << "Malformed graph";
+
+  // Slide latter input use arg inds to the left.
+  for (auto i = input_ind + 1; i < op.Inputs().size(); ++i) {
+    auto& r_in = op.Input(i);
+    for (auto u = 0; u < r_in.NumUses(); ++u) {
+      auto& r_arg_ind = r_in.UserArgInds().at(u);
+      if (r_in.Users().at(u) == &op && r_arg_ind > input_ind) {
+        r_arg_ind -= 1;
+      }
+    }
+  }
+
+  // Update the edges.
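+  // (RemoveUse drops the tensor-side use record; RemoveInput drops the
+  //  op-side slot. After both, neither object references the other.)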
+  input.RemoveUse(target_use_ind);
+  op.RemoveInput(input_ind);
+
+  return &input;
+}
+
+bool IsIO(const LiteRtSubgraphT& subgraph, const LiteRtTensorT& tensor) {
+  return FindInput(subgraph, tensor) || FindOutput(subgraph, tensor);
+}
+
+LiteRtTensor DisconnectOutput(LiteRtOpT& op, LiteRtParamIndex output_ind) {
+  ABSL_DCHECK(output_ind < op.Outputs().size()) << "Removing tensor index oob";
+  auto& output = op.Output(output_ind);
+  output.ClearDefiningOp();
+  op.RemoveOutput(output_ind);
+  return &output;
+}
+
+void Drop(LiteRtOpT& litert_op) {
+  while (!litert_op.Inputs().empty()) {
+    DisconnectInput(litert_op, 0);
+  }
+  while (!litert_op.Outputs().empty()) {
+    DisconnectOutput(litert_op, 0);
+  }
+}
+
+bool DCE(LiteRtSubgraphT& subgraph) {
+  const auto ops_removed = subgraph.RemoveOpIf(IsOpDead);
+
+  auto rm_tensor = [&subgraph = std::as_const(subgraph)](const auto& t) {
+    return IsTensorDead(t) && !IsIO(subgraph, t);
+  };
+  const auto tensors_removed = subgraph.RemoveTensorIf(rm_tensor);
+
+  return (ops_removed + tensors_removed) > 0;
+}
+
+}  // namespace litert::internal
diff --git a/tflite/experimental/litert/core/model/model_graph.h b/tflite/experimental/litert/core/model/model_graph.h
new file mode 100644
index 00000000..35506669
--- /dev/null
+++ b/tflite/experimental/litert/core/model/model_graph.h
@@ -0,0 +1,105 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_GRAPH_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_GRAPH_H_
+
+#include <optional>
+
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/c/litert_op_code.h"
+#include "tflite/experimental/litert/cc/litert_buffer_ref.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/core/model/model.h"
+
+namespace litert::internal {
+
+// using IrMapping = absl::flat_hash_map<LiteRtTensor, LiteRtTensor>;
+
+// CLONING
+
+// Clones the basic data between tensors (like name, type, and weights) but
+// nothing relating to incoming/outgoing edges (users, defining op).
+void CloneTo(const LiteRtTensorT& src, LiteRtTensorT& dest);
+
+// Clones the basic data between ops (like op code and options) but nothing
+// relating to incoming/outgoing edges (input/output tensors).
+void CloneTo(const LiteRtOpT& src, LiteRtOpT& dest);
+
+// Same as CloneTo, but allocates the dest tensor into the given subgraph.
+LiteRtTensorT& MakeClone(LiteRtSubgraphT& parent, const LiteRtTensorT& src);
+
+// Same as CloneTo, but allocates the dest op into the given subgraph.
+LiteRtOpT& MakeClone(LiteRtSubgraphT& parent, const LiteRtOpT& src);
+
+// OBSERVERS
+
+// Checks if tensor is input to given op; returns its index if so.
+std::optional<LiteRtParamIndex> FindInput(const LiteRtOpT& op,
+                                          const LiteRtTensorT& tensor);
+
+// Checks if tensor is output to given op; returns its index if so.
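+// (For example, FindOutput(op, t) yields 0 when `t` is op's first output and
+//  std::nullopt when `t` is not an output of `op` at all.)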
+std::optional<LiteRtParamIndex> FindOutput(const LiteRtOpT& op,
+                                           const LiteRtTensorT& tensor);
+
+// Checks if tensor is input to given subgraph; returns its index if so.
+std::optional<LiteRtParamIndex> FindInput(const LiteRtSubgraphT& subgraph,
+                                          const LiteRtTensorT& tensor);
+
+// Checks if tensor is output to given subgraph; returns its index if so.
+std::optional<LiteRtParamIndex> FindOutput(const LiteRtSubgraphT& subgraph,
+                                           const LiteRtTensorT& tensor);
+
+// Check if tensor is part of subgraph IO.
+bool IsIO(const LiteRtSubgraphT& subgraph, const LiteRtTensorT& tensor);
+
+// Checks if tensor is used by op; returns the use indices for each use of the
+// tensor by the op (there may be multiple). These are the indices to call
+// LiteRtTensorT::GetUse with.
+SmallVec<LiteRtParamIndex> FindUseInds(const LiteRtTensorT& tensor,
+                                       const LiteRtOpT& op);
+
+// Is this tensor a constant tensor?
+bool IsConstant(const LiteRtTensorT& tensor);
+
+// MUTATORS
+
+// Attaches the pre-allocated tensor to be an input of given op.
+void AttachInput(LiteRtTensor tensor, LiteRtOpT& op);
+
+// Attaches the pre-allocated tensor to be an output of given op.
+void AttachOutput(LiteRtTensor tensor, LiteRtOpT& op);
+
+// Remove the input edge from an op. Return the disconnected tensor.
+LiteRtTensor DisconnectInput(LiteRtOpT& op, LiteRtParamIndex input_ind);
+
+// Remove an output edge from an op. Return the disconnected tensor.
+LiteRtTensor DisconnectOutput(LiteRtOpT& op, LiteRtParamIndex output_ind);
+
+// Remove all incoming and outgoing edges from this op. This can prep nodes
+// for removal in DCE.
+void Drop(LiteRtOpT& litert_op);
+
+// Run very naive dead code elimination. Removes only ops/tensors that have no
+// in/out edges. Ops are handled first. Ignores subgraph IO. Not recursive and
+// does only one pass. Returns whether the graph was modified.
+// NOTE: This de-allocates removed objects; only use when references to these
+// objects will not be used.
+// TODO: Update this with complete work-list based approach.
+bool DCE(LiteRtSubgraphT& subgraph);
+
+}  // namespace litert::internal
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_GRAPH_H_
diff --git a/tflite/experimental/litert/core/model/model_graph_test.cc b/tflite/experimental/litert/core/model/model_graph_test.cc
new file mode 100644
index 00000000..788d85cc
--- /dev/null
+++ b/tflite/experimental/litert/core/model/model_graph_test.cc
@@ -0,0 +1,344 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/core/model/model_graph.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <iterator>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/c/litert_op_code.h"
+#include "tflite/experimental/litert/core/model/graph_validation.h"
+#include "tflite/experimental/litert/core/model/model.h"
+
+namespace litert::internal {
+namespace {
+
+using ::testing::UnorderedElementsAreArray;
+
+// Custom matcher; example:
+// ```
+// LiteRtTensor tensor ...
+// EXPECT_THAT(tensor, HasRankedType(kLiteRtInt, absl::MakeSpan({2, 2})));
+// ```
+// TODO: Update to use dumping API directly and move to shared header.
+MATCHER_P2(HasRankedType, element_type, shape, "") {
+  if (arg.Type().first != kLiteRtRankedTensorType) {
+    *result_listener << "Not ranked tensor type";
+    return false;
+  }
+  const auto& ranked_tensor_type = arg.Type().second.ranked_tensor_type;
+  const auto& layout = ranked_tensor_type.layout;
+
+  const auto element_type_eq = ranked_tensor_type.element_type == element_type;
+  const auto rank_eq = layout.rank == std::size(shape);
+
+  auto actual_shape = absl::MakeConstSpan(layout.dimensions, layout.rank);
+  auto expected_shape =
+      absl::MakeConstSpan(std::cbegin(shape), std::cend(shape));
+  const auto shape_eq = actual_shape == expected_shape;
+
+  if (shape_eq && element_type_eq && rank_eq) {
+    return true;
+  }
+
+  *result_listener << "\n";
+  if (!shape_eq) {
+    *result_listener << "Not correct shape\n";
+  }
+  if (!element_type_eq) {
+    *result_listener << "Not correct element type\n";
+  }
+  if (!rank_eq) {
+    *result_listener << "Not correct rank\n";
+  }
+
+  *result_listener << absl::StreamFormat("Actual ElementType is: %d\n",
+                                         ranked_tensor_type.element_type);
+  *result_listener << absl::StreamFormat("Actual Rank is: %lu\n", layout.rank);
+  *result_listener << "Actual shape is: { ";
+  for (const auto d : actual_shape) {
+    *result_listener << absl::StreamFormat("%d, ", d);
+  }
+  *result_listener << "}\n";
+
+  return false;
+}
+
+using ::testing::ElementsAreArray;
+
+static constexpr size_t kRank = 1;
+static constexpr int32_t kDims[] = {2};
+static constexpr absl::Span<const int32_t> kDimsSpan(kDims);
+static constexpr auto kType = kLiteRtElementTypeInt32;
+static constexpr absl::string_view kCustomOptions = "OPTIONS";
+static constexpr auto kOpCode = kLiteRtOpCodeTflMul;
+
+LiteRtTensorT TestTensor() {
+  LiteRtTensorT tensor;
+  tensor.Type().first = kLiteRtRankedTensorType;
+  tensor.Type().second.ranked_tensor_type.element_type = kType;
+  tensor.Type().second.ranked_tensor_type.layout.dimensions[0] = kDims[0];
+  tensor.Type().second.ranked_tensor_type.layout.rank = kRank;
+  return tensor;
+}
+
+LiteRtOpT TestOp() {
+  LiteRtOpT op;
+  op.SetOpCode(kOpCode);
+  op.SetCustomOptions(kCustomOptions);
+  return op;
+}
+
+TEST(ModelGraphTest, CloneTensor) {
+  LiteRtTensorT dest;
+  CloneTo(TestTensor(), dest);
+  EXPECT_THAT(dest, HasRankedType(kType, kDimsSpan));
+}
+
+TEST(ModelGraphTest, MakeCloneTensor) {
+  LiteRtSubgraphT subgraph;
+  auto& dest = MakeClone(subgraph, TestTensor());
+  EXPECT_THAT(dest, HasRankedType(kType, kDimsSpan));
+}
+
+TEST(ModelGraphTest, CloneOp) {
+  LiteRtOpT dest;
+  CloneTo(TestOp(), dest);
+  EXPECT_EQ(dest.OpCode(), kOpCode);
+  EXPECT_EQ(dest.CustomOptions().StrView(), kCustomOptions);
+}
+
+TEST(ModelGraphTest, MakeCloneOp) {
+  LiteRtSubgraphT subgraph;
+  auto& dest = MakeClone(subgraph, TestOp());
+  EXPECT_EQ(dest.OpCode(), kOpCode);
+  EXPECT_EQ(dest.CustomOptions().StrView(), kCustomOptions);
+}
+
+TEST(ModelGraphTest, OpFindInput) {
+  auto op = TestOp();
+  auto tensor = TestTensor();
+  AttachInput(&tensor, op);
+  auto input = FindInput(op, tensor);
+  ASSERT_TRUE(input);
+  EXPECT_EQ(*input, 0);
+}
+
+TEST(ModelGraphTest, OpFindOutput) {
+  auto op = TestOp();
+  auto tensor = TestTensor();
+  AttachOutput(&tensor, op);
+  auto output = FindOutput(op, tensor);
+  ASSERT_TRUE(output);
+  EXPECT_EQ(*output, 0);
+}
+
+TEST(ModelGraphTest, SubgraphFindInput) {
+  LiteRtSubgraphT subgraph;
+  auto tensor = TestTensor();
+
subgraph.Inputs().push_back(&tensor); + auto input = FindInput(subgraph, tensor); + ASSERT_TRUE(input); + EXPECT_EQ(*input, 0); +} + +TEST(ModelGraphTest, SubgraphFindOutput) { + LiteRtSubgraphT subgraph; + auto tensor = TestTensor(); + subgraph.Outputs().push_back(&tensor); + auto output = FindOutput(subgraph, tensor); + ASSERT_TRUE(output); + EXPECT_EQ(*output, 0); +} + +TEST(ModelGraphTest, TensorFindUseInds) { + auto op1 = TestOp(); + auto op2 = TestOp(); + auto tensor = TestTensor(); + + AttachInput(&tensor, op1); + AttachInput(&tensor, op2); + AttachInput(&tensor, op1); + + auto use_inds = FindUseInds(tensor, op1); + auto uses = GetTensorUses(tensor, use_inds); + ASSERT_EQ(uses.size(), 2); + + LiteRtTensorT::UseVec expected = {{&op1, 0}, {&op1, 1}}; + EXPECT_THAT(uses, UnorderedElementsAreArray(expected)); +} + +TEST(ModelGraphTest, OpAttachInput) { + auto op = TestOp(); + auto tensor = TestTensor(); + AttachInput(&tensor, op); + EXPECT_THAT(op.Inputs(), ElementsAreArray({&tensor})); + EXPECT_THAT(tensor.Users(), ElementsAreArray({&op})); + EXPECT_THAT(tensor.UserArgInds(), ElementsAreArray({0})); +} + +TEST(ModelGraphTest, OpAttachOutput) { + auto op = TestOp(); + auto tensor = TestTensor(); + AttachOutput(&tensor, op); + EXPECT_THAT(op.Outputs(), ElementsAreArray({&tensor})); + EXPECT_EQ(tensor.DefiningOp(), &op); + EXPECT_EQ(tensor.DefiningOpOutInd(), 0); +} + +TEST(ModelGraphTest, DisconnectInputOp) { + auto op = TestOp(); + auto tensor = TestTensor(); + AttachInput(&tensor, op); + auto disconnected = DisconnectInput(op, 0); + EXPECT_EQ(disconnected, &tensor); + EXPECT_TRUE(op.Inputs().empty()); + EXPECT_TRUE(tensor.Users().empty()); + EXPECT_TRUE(tensor.UserArgInds().empty()); +} + +TEST(ModelGraphTest, DisconnectMiddleInputOp) { + auto op = TestOp(); + + auto tensor1 = TestTensor(); + auto tensor2 = TestTensor(); + auto tensor3 = TestTensor(); + + AttachInput(&tensor1, op); + AttachInput(&tensor2, op); + AttachInput(&tensor3, op); + + auto disconnected = DisconnectInput(op, 1); + + EXPECT_EQ(disconnected, &tensor2); + ASSERT_EQ(op.Inputs().size(), 2); + EXPECT_EQ(op.Inputs().front(), &tensor1); + EXPECT_EQ(op.Inputs().back(), &tensor3); + ASSERT_TRUE(tensor2.Users().empty()); + ASSERT_TRUE(tensor2.UserArgInds().empty()); + + ASSERT_TRUE(ValidateLocalTopology(op)); +} + +TEST(ModelGraphTest, DisconnectOutputOp) { + auto op = TestOp(); + auto tensor = TestTensor(); + AttachOutput(&tensor, op); + auto disconnected = DisconnectOutput(op, 0); + EXPECT_EQ(disconnected, &tensor); + EXPECT_EQ(tensor.DefiningOp(), nullptr); + EXPECT_TRUE(op.Outputs().empty()); +} + +TEST(ModelGraphTest, DropOp) { + LiteRtOpT op; + + LiteRtTensorT input1; + LiteRtTensorT input2; + LiteRtTensorT output; + + AttachInput(&input1, op); + AttachInput(&input2, op); + AttachOutput(&output, op); + + Drop(op); + + EXPECT_TRUE(op.Inputs().empty()); + EXPECT_TRUE(op.Outputs().empty()); + EXPECT_TRUE(input1.Users().empty()); + EXPECT_TRUE(input2.Users().empty()); + EXPECT_EQ(output.DefiningOp(), nullptr); +} + +TEST(ModelGraphTestDCE, NoDeadCode) { + LiteRtSubgraphT subgraph; + + auto& input = subgraph.EmplaceTensor(); + auto& output = subgraph.EmplaceTensor(); + + auto& op = subgraph.EmplaceOp(); + + AttachInput(&input, op); + AttachOutput(&output, op); + + subgraph.Inputs().push_back(&input); + subgraph.Outputs().push_back(&output); + + ASSERT_FALSE(DCE(subgraph)); + EXPECT_EQ(subgraph.Ops().size(), 1); + EXPECT_EQ(subgraph.Tensors().size(), 2); + + ASSERT_TRUE( + ValidateLocalTopology(subgraph.Ops().cbegin(), 
subgraph.Ops().cend()));
+  ASSERT_TRUE(ValidateSubgraphIO(subgraph));
+}
+
+TEST(ModelGraphTestDCE, DeadTensor) {
+  LiteRtSubgraphT subgraph;
+  subgraph.EmplaceTensor();
+
+  ASSERT_TRUE(DCE(subgraph));
+  EXPECT_TRUE(subgraph.Tensors().empty());
+
+  ASSERT_TRUE(
+      ValidateLocalTopology(subgraph.Ops().cbegin(), subgraph.Ops().cend()));
+  ASSERT_TRUE(ValidateSubgraphIO(subgraph));
+}
+
+TEST(ModelGraphTestDCE, DeadOp) {
+  LiteRtSubgraphT subgraph;
+  subgraph.EmplaceOp();
+
+  ASSERT_TRUE(DCE(subgraph));
+  EXPECT_TRUE(subgraph.Ops().empty());
+
+  ASSERT_TRUE(
+      ValidateLocalTopology(subgraph.Ops().cbegin(), subgraph.Ops().cend()));
+  ASSERT_TRUE(ValidateSubgraphIO(subgraph));
+}
+
+TEST(ModelGraphTestDCE, SomeDead) {
+  LiteRtSubgraphT subgraph;
+
+  auto& input = subgraph.EmplaceTensor();
+  auto& output = subgraph.EmplaceTensor();
+
+  auto& op = subgraph.EmplaceOp();
+
+  AttachInput(&input, op);
+  AttachOutput(&output, op);
+
+  // Dead
+  subgraph.EmplaceTensor();
+  subgraph.EmplaceOp();
+
+  subgraph.Inputs().push_back(&input);
+  subgraph.Outputs().push_back(&output);
+
+  ASSERT_TRUE(DCE(subgraph));
+  EXPECT_EQ(subgraph.Ops().size(), 1);
+  EXPECT_EQ(subgraph.Tensors().size(), 2);
+
+  ASSERT_TRUE(
+      ValidateLocalTopology(subgraph.Ops().cbegin(), subgraph.Ops().cend()));
+  ASSERT_TRUE(ValidateSubgraphIO(subgraph));
+}
+
+}  // namespace
+}  // namespace litert::internal
diff --git a/tflite/experimental/litert/core/model/model_load.cc b/tflite/experimental/litert/core/model/model_load.cc
new file mode 100644
index 00000000..49ea8fe4
--- /dev/null
+++ b/tflite/experimental/litert/core/model/model_load.cc
@@ -0,0 +1,321 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/core/model/model_load.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_logging.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/c/litert_op_code.h"
+#include "tflite/experimental/litert/cc/litert_buffer_ref.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/cc/litert_macros.h"
+#include "tflite/experimental/litert/core/model/flatbuffer_to_litert.h"
+#include "tflite/experimental/litert/core/model/model.h"
+#include "tflite/experimental/litert/core/model/model_graph.h"
+#include "tflite/experimental/litert/core/util/flatbuffer_tools.h"
+#include "tflite/schema/schema_generated.h"
+
+namespace litert::internal {
+namespace {
+
+// Provides a view of model-level resources when constructing litert graph.
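+// (The flow below, roughly: UnpackSubgraph unpacks every tensor and then
+//  every op, ops resolve their op codes and buffers through this context,
+//  and UnpackSignatures/UnpackMetadata finish the model-level state.)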
+class FlatbufferContext {
+ public:
+  explicit FlatbufferContext(TflModel& tfl_model) : tfl_model_(tfl_model) {}
+
+  void SetOpCode(LiteRtOpT& litert_op, uint32_t ind) {
+    auto tfl_op_code = GetTflOpCode(tfl_model_, ind);
+    litert_op.SetOpCode(static_cast<LiteRtOpCode>(*tfl_op_code));
+    detail::SetTflOpCodeInd(litert_op, ind);
+  }
+
+  // Take ownership of the tfl buffer under the given index if it exists.
+  Expected<TflBufferPtr> TakeTflBuffer(uint32_t ind) {
+    return TakeBuffer(tfl_model_, ind);
+  }
+
+ private:
+  TflModel& tfl_model_;
+};
+
+LiteRtStatus UnpackOp(FlatbufferContext& context, LiteRtSubgraphT& parent,
+                      TflOpPtr tfl_op, LiteRtOpT& litert_op) {
+  // I/O TENSORS
+
+  if (!tfl_op->intermediates.empty()) {
+    // TODO: b/365299994 - Support intermediates.
+    LITERT_LOG(LITERT_ERROR, "Intermediate tensors not yet supported.");
+    return kLiteRtStatusErrorUnsupported;
+  }
+
+  for (auto m_input : tfl_op->mutating_variable_inputs) {
+    if (m_input) {
+      // TODO: b/365299994 - Support mutating variable inputs.
+      LITERT_LOG(LITERT_ERROR, "Mutating variable inputs not yet supported.");
+      return kLiteRtStatusErrorUnsupported;
+    }
+  }
+
+  for (auto input_ind : tfl_op->inputs) {
+    // Skipping optional input tensor.
+    if (input_ind == -1) {
+      continue;
+    }
+    AttachInput(&parent.Tensor(input_ind), litert_op);
+  }
+
+  for (auto output_ind : tfl_op->outputs) {
+    AttachOutput(&parent.Tensor(output_ind), litert_op);
+  }
+
+  // OPTIONS
+
+  if (tfl_op->large_custom_options_size != 0) {
+    // TODO: b/365299994 - Support large custom options.
+    LITERT_LOG(LITERT_ERROR, "Large custom options not yet supported.");
+    return kLiteRtStatusErrorUnsupported;
+  }
+
+  const auto& tfl_custom_opts = tfl_op->custom_options;
+  litert_op.SetCustomOptions(tfl_custom_opts.data(), tfl_custom_opts.size());
+  detail::SetTflOptions(litert_op, std::move(tfl_op->builtin_options));
+
+  // OP CODE
+
+  context.SetOpCode(litert_op, tfl_op->opcode_index);
+
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus UnpackTensor(FlatbufferContext& context, TflTensorPtr tfl_tensor,
+                          LiteRtTensorT& litert_tensor) {
+  // WEIGHTS
+
+  const auto buffer_ind = tfl_tensor->buffer;
+  if (buffer_ind != 0) {
+    auto buffer = context.TakeTflBuffer(buffer_ind);
+    if (!buffer) {
+      return buffer.Error().Status();
+    }
+
+    if (buffer->get()->offset != 0) {
+      // TODO: b/365299994 - Support buffer with offset.
+      LITERT_LOG(LITERT_ERROR, "Buffers with offset not yet supported.");
+      return kLiteRtStatusErrorUnsupported;
+    }
+    detail::SetTflBuffer(litert_tensor.Weights(), std::move(*buffer));
+  }
+
+  // TENSOR TYPE
+
+  TflTensorType tfl_tensor_type(tfl_tensor->type, TflShapeInfo(*tfl_tensor));
+  auto tensor_type = MapTensorType(tfl_tensor_type);
+  if (!tensor_type) {
+    return tensor_type.Error().Status();
+  }
+
+  litert_tensor.SetType(std::move(*tensor_type));
+
+  // QUANTIZATION
+
+  auto quantization =
+      MapQuantization(tfl_tensor->quantization.get(), litert_tensor);
+  if (!quantization) {
+    return quantization.Error().Status();
+  }
+
+  litert_tensor.SetQarams(std::move(*quantization));
+
+  // MISC
+
+  litert_tensor.SetName(tfl_tensor->name);
+
+  if (tfl_tensor->is_variable) {
+    // TODO: b/365299994 - Support variable tensors.
+    LITERT_LOG(LITERT_ERROR, "Variable tensors not yet supported.");
+    return kLiteRtStatusErrorUnsupported;
+  }
+
+  if (!tfl_tensor->variant_tensors.empty()) {
+    // TODO: b/365299994 - Support variant tensors.
+    LITERT_LOG(LITERT_ERROR, "Variant tensors not yet supported.");
+    return kLiteRtStatusErrorUnsupported;
+  }
+
+  if (tfl_tensor->sparsity) {
+    // TODO: b/365299994 - Support sparsity tensors.
+    LITERT_LOG(LITERT_ERROR, "Sparsity tensors not yet supported.");
+    return kLiteRtStatusErrorUnsupported;
+  }
+
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus UnpackSubgraph(FlatbufferContext& context,
+                            TflSubgraphPtr tfl_subgraph,
+                            LiteRtSubgraphT& litert_subgraph) {
+  // Unpack tensors.
+  for (auto& tfl_tensor : tfl_subgraph->tensors) {
+    LITERT_RETURN_STATUS_IF_NOT_OK(UnpackTensor(
+        context, std::move(tfl_tensor), litert_subgraph.EmplaceTensor()));
+  }
+
+  // Unpack ops, pass litert_subgraph so they can look up the new litert
+  // tensors.
+  for (auto& tfl_op : tfl_subgraph->operators) {
+    LITERT_RETURN_STATUS_IF_NOT_OK(UnpackOp(context, litert_subgraph,
+                                            std::move(tfl_op),
+                                            litert_subgraph.EmplaceOp()));
+  }
+
+  // Update subgraph I/O.
+  for (auto tfl_input_ind : tfl_subgraph->inputs) {
+    litert_subgraph.Inputs().push_back(&litert_subgraph.Tensor(tfl_input_ind));
+  }
+  for (auto tfl_output_ind : tfl_subgraph->outputs) {
+    litert_subgraph.Outputs().push_back(
+        &litert_subgraph.Tensor(tfl_output_ind));
+  }
+
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus UnpackSignatures(std::vector<TflSignaturePtr>& tfl_signatures,
+                              LiteRtModelT& parent) {
+  for (auto& tfl_signature : tfl_signatures) {
+    auto* litert_subgraph =
+        parent.Subgraphs().at(tfl_signature->subgraph_index);
+
+    auto& tfl_inputs = tfl_signature->inputs;
+    auto& tfl_outputs = tfl_signature->outputs;
+
+#ifndef NDEBUG
+    // Tflite signatures map a tensor index to a name. We just assume
+    // that the indexes are exactly those of the subgraph inputs. Check
+    // this in debug mode.
+    if (tfl_inputs.size() != litert_subgraph->Inputs().size() ||
+        tfl_outputs.size() != litert_subgraph->Outputs().size()) {
+      LITERT_LOG(LITERT_ERROR,
+                 "Signature has incorrect number of inputs/outputs");
+    }
+
+    for (auto i = 0; i < tfl_inputs.size(); ++i) {
+      const auto& tfl_input = tfl_inputs.at(i);
+      const auto* litert_input = litert_subgraph->Inputs().at(i);
+      const auto* index_litert_input =
+          litert_subgraph->Tensors().at(tfl_input->tensor_index);
+      if (litert_input != index_litert_input) {
+        LITERT_LOG(LITERT_ERROR,
+                   "Signature inputs reference tensors not in subgraph i/o");
+      }
+    }
+
+    for (auto i = 0; i < tfl_outputs.size(); ++i) {
+      const auto& tfl_output = tfl_outputs.at(i);
+      const auto* litert_output = litert_subgraph->Outputs().at(i);
+      const auto* index_litert_output =
+          litert_subgraph->Tensors().at(tfl_output->tensor_index);
+      if (litert_output != index_litert_output) {
+        LITERT_LOG(LITERT_ERROR,
+                   "Signature outputs reference tensors not in subgraph i/o");
+      }
+    }
+#endif
+
+    auto get_name = [](const auto& tfl_tensor) { return tfl_tensor->name; };
+
+    std::vector<std::string> input_names(tfl_inputs.size());
+    std::transform(tfl_inputs.cbegin(), tfl_inputs.cend(), input_names.begin(),
+                   get_name);
+
+    std::vector<std::string> output_names(tfl_outputs.size());
+    std::transform(tfl_outputs.cbegin(), tfl_outputs.cend(),
+                   output_names.begin(), get_name);
+
+    parent.EmplaceSignature(litert_subgraph, std::move(input_names),
+                            std::move(output_names),
+                            tfl_signature->signature_key);
+  }
+
+  if (tfl_signatures.empty()) {
+    parent.EmplaceSignature(MakeDefaultSignature(parent.MainSubgraph()));
+  }
+
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus UnpackMetadata(FlatbufferContext& context,
+                            std::vector<TflMetadataPtr>& tfl_metadata,
+                            LiteRtModelT& parent) {
+  for (auto& tfl_m_data : tfl_metadata) {
+    auto tfl_buffer = context.TakeTflBuffer(tfl_m_data->buffer);
+    if (!tfl_buffer) {
+      return tfl_buffer.Error().Status();
+    }
+
+    const auto& tfl_vec = tfl_buffer->get()->data;
+    parent.PushMetadata(tfl_m_data->name, tfl_vec.data(), tfl_vec.size());
+  }
+
+  return kLiteRtStatusOk;
+}
+
+Expected<std::unique_ptr<LiteRtModelT>> UnpackModel(TflModelPtr tfl_model) {
+  auto litert_model = std::make_unique<LiteRtModelT>();
+  FlatbufferContext context(*tfl_model);
+
+  for (auto& tfl_subgraph : tfl_model->subgraphs) {
+    LITERT_EXPECT_OK(UnpackSubgraph(context, std::move(tfl_subgraph),
+                                    litert_model->EmplaceSubgraph()));
+  }
+
+  LITERT_EXPECT_OK(UnpackSignatures(tfl_model->signature_defs, *litert_model));
+  LITERT_EXPECT_OK(UnpackMetadata(context, tfl_model->metadata, *litert_model));
+  detail::SetTflOpCodes(*litert_model, std::move(tfl_model->operator_codes));
+
+  return litert_model;
+}
+
+}  // namespace
+
+Expected<std::unique_ptr<LiteRtModelT>> LoadModelFromBuffer(
+    BufferRef<uint8_t> buffer) {
+  auto flatbuffer = FlatbufferWrapper::CreateFromBuffer(buffer);
+  if (!flatbuffer) {
+    return flatbuffer.Error();
+  }
+  auto litert_model = UnpackModel(flatbuffer->get()->Unpack());
+  if (litert_model) {
+    // Save the original FB pointer to use it later on CompiledModel.
+    detail::SetTflInitFlatbuffer(**litert_model, buffer);
+  }
+  return litert_model;
+}
+
+Expected<std::unique_ptr<LiteRtModelT>> LoadModelFromFile(
+    absl::string_view filename) {
+  auto flatbuffer = FlatbufferWrapper::CreateFromTflFile(filename);
+  if (!flatbuffer) {
+    return flatbuffer.Error();
+  }
+  return UnpackModel(flatbuffer->get()->Unpack());
+}
+
+}  // namespace litert::internal
diff --git a/tflite/experimental/litert/core/model/model_load.h b/tflite/experimental/litert/core/model/model_load.h
new file mode 100644
index 00000000..6f72408e
--- /dev/null
+++ b/tflite/experimental/litert/core/model/model_load.h
@@ -0,0 +1,36 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_LOAD_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_LOAD_H_
+
+#include <cstdint>
+#include <memory>
+
+#include "absl/strings/string_view.h"
+#include "tflite/experimental/litert/cc/litert_buffer_ref.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/core/model/model.h"
+
+namespace litert::internal {
+
+Expected<std::unique_ptr<LiteRtModelT>> LoadModelFromFile(
+    absl::string_view filename);
+
+Expected<std::unique_ptr<LiteRtModelT>> LoadModelFromBuffer(
+    BufferRef<uint8_t> buffer);
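+
+// A minimal usage sketch of the loaders above (illustrative only; the file
+// name is a placeholder and real callers may want richer error handling):
+//
+//   auto model = litert::internal::LoadModelFromFile("model.tflite");
+//   if (!model) {
+//     return model.Error().Status();
+//   }
+//   std::unique_ptr<LiteRtModelT> litert_model = std::move(*model);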
+
+}  // namespace litert::internal
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_LOAD_H_
diff --git a/tflite/experimental/litert/core/model/model_serialize.cc b/tflite/experimental/litert/core/model/model_serialize.cc
new file mode 100644
index 00000000..e9e5c934
--- /dev/null
+++ b/tflite/experimental/litert/core/model/model_serialize.cc
@@ -0,0 +1,272 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/core/model/model_serialize.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/cc/litert_buffer_ref.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/cc/litert_macros.h"
+#include "tflite/experimental/litert/core/byte_code_util.h"
+#include "tflite/experimental/litert/core/model/litert_to_flatbuffer.h"
+#include "tflite/experimental/litert/core/model/model.h"
+#include "tflite/experimental/litert/core/util/flatbuffer_tools.h"
+#include "tflite/schema/schema_generated.h"
+
+namespace litert::internal {
+namespace {
+
+using TensorMap = absl::flat_hash_map<LiteRtTensor, int32_t>;
+
+// This is expected to be used to serialize the dispatch op custom code.
+TflOpCodePtr MakeCustomOpCode(std::string custom_code_name) {
+  auto custom_code = std::make_unique<TflOpCode>();
+  custom_code->builtin_code = ::tflite::BuiltinOperator_CUSTOM;
+  custom_code->custom_code = std::move(custom_code_name);
+  custom_code->version = 1;
+  return custom_code;
+}
+
+// Utility for accessing flatbuffer state.
+class FlatbufferBuilder {
+ public:
+  explicit FlatbufferBuilder(uint32_t dispatch_op_code_ind)
+      : tfl_model_(std::make_unique<TflModel>()),
+        dispatch_op_code_ind_(dispatch_op_code_ind) {
+    // Tfl expects empty buffer 0.
+    tfl_model_->buffers.push_back(std::make_unique<TflBuffer>());
+  }
+
+  TflModel& Model() { return *tfl_model_.get(); }
+
+  TflModelPtr Release() && { return std::move(tfl_model_); }
+
+  // Move given buffer into tfl model and get its index.
+  uint32_t SubmitBuffer(TflBufferPtr tfl_buffer) {
+    tfl_model_->buffers.push_back(std::move(tfl_buffer));
+    return tfl_model_->buffers.size() - 1;
+  }
+
+  // Add to tfl model metadata.
+  void PushMetadata(std::string key, BufferRef<uint8_t> data) {
+    auto tfl_buffer = std::make_unique<TflBuffer>();
+    tfl_buffer->data.assign(data.Data(), data.Data() + data.Size());
+    auto tfl_buffer_ind = SubmitBuffer(std::move(tfl_buffer));
+    tfl_model_->metadata_buffer.push_back(tfl_buffer_ind);
+    auto tfl_metadata = std::make_unique<TflMetadata>();
+    tfl_metadata->name = key;
+    tfl_metadata->buffer = tfl_buffer_ind;
+    tfl_model_->metadata.push_back(std::move(tfl_metadata));
+  }
+
+  // Get the index in the tfl op codes for the dispatch custom code.
+  // This should be the only new custom code added after loading the initial
+  // tfl.
+  uint32_t DispatchOpCodeInd() const { return dispatch_op_code_ind_; }
+
+ private:
+  TflModelPtr tfl_model_;
+  uint32_t dispatch_op_code_ind_;
+};
+
+void SetOptions(const LiteRtOpT& litert_op, TflOp& tfl_op) {
+  tfl_op.builtin_options = detail::GetTflOptions(litert_op);
+  if (litert_op.CustomOptions().Size() != 0) {
+    tfl_op.custom_options = litert_op.CustomOptions().ToVec();
+    tfl_op.custom_options_format = tflite::CustomOptionsFormat_FLEXBUFFERS;
+  }
+}
+
+LiteRtStatus PackOp(FlatbufferBuilder& builder, LiteRtOpT& litert_op,
+                    TflOp& tfl_op, const TensorMap& tensor_map) {
+  auto tfl_op_code_ind = detail::GetTflOpCodeInd(litert_op);
+  if (tfl_op_code_ind < 0) {
+    tfl_op_code_ind = builder.DispatchOpCodeInd();
+  }
+  tfl_op.opcode_index = tfl_op_code_ind;
+
+  for (auto* in : litert_op.Inputs()) {
+    tfl_op.inputs.push_back(tensor_map.at(in));
+  }
+
+  for (auto* out : litert_op.Outputs()) {
+    tfl_op.outputs.push_back(tensor_map.at(out));
+  }
+
+  SetOptions(litert_op, tfl_op);
+
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus PackTensor(FlatbufferBuilder& builder,
+                        LiteRtTensorT& litert_tensor, TflTensor& tfl_tensor) {
+  auto tfl_tensor_type = MapTensorType(litert_tensor.Type());
+  if (!tfl_tensor_type) {
+    return tfl_tensor_type.Error().Status();
+  }
+  auto [tfl_elem_type, tfl_shape] = *tfl_tensor_type;
+
+  tfl_tensor.type = tfl_elem_type;
+  tfl_tensor.shape.assign(tfl_shape.shape.begin(), tfl_shape.shape.end());
+  tfl_tensor.has_rank = tfl_shape.has_rank;
+  tfl_tensor.shape_signature.assign(tfl_shape.shape_signature.begin(),
+                                    tfl_shape.shape_signature.end());
+
+  auto tfl_quantization = MapQuantization(litert_tensor.Qparams());
+  if (!tfl_quantization) {
+    return tfl_quantization.Error().Status();
+  }
+  tfl_tensor.quantization = std::move(*tfl_quantization);
+
+  tfl_tensor.buffer =
+      builder.SubmitBuffer(detail::TakeTflBuffer(litert_tensor.Weights()));
+  tfl_tensor.name = std::string(litert_tensor.Name());
+
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus PackSubgraph(FlatbufferBuilder& builder,
+                          LiteRtSubgraphT& litert_subgraph,
+                          TflSubgraph& tfl_subgraph, TensorMap& tensor_map) {
+  for (auto* tensor : litert_subgraph.Tensors()) {
+    tfl_subgraph.tensors.push_back(std::make_unique<TflTensor>());
+    tensor_map.insert({tensor, tfl_subgraph.tensors.size() - 1});
+    LITERT_RETURN_STATUS_IF_NOT_OK(
+        PackTensor(builder, *tensor, *tfl_subgraph.tensors.back()));
+  }
+
+  for (auto* op : litert_subgraph.Ops()) {
+    tfl_subgraph.operators.push_back(std::make_unique<TflOp>());
+    LITERT_RETURN_STATUS_IF_NOT_OK(
+        PackOp(builder, *op, *tfl_subgraph.operators.back(), tensor_map));
+  }
+
+  for (auto* in : litert_subgraph.Inputs()) {
+    tfl_subgraph.inputs.push_back(tensor_map.at(in));
+  }
+
+  for (auto* out : litert_subgraph.Outputs()) {
+    tfl_subgraph.outputs.push_back(tensor_map.at(out));
+  }
+
+  return kLiteRtStatusOk;
+}
+
+Expected<TflModelPtr> PackAsTflite(LiteRtModelT& litert_model) {
+  // Pass the op code list through that was saved during loading. Add one more
+  // op code for the dispatch ops.
+  auto tfl_op_codes = detail::TakeTflOpCodes(litert_model);
+  tfl_op_codes.push_back(
+      MakeCustomOpCode(std::string(kLiteRtDispatchOpCustomCode)));
+
+  FlatbufferBuilder builder(tfl_op_codes.size() - 1);
+  builder.Model().operator_codes = std::move(tfl_op_codes);
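+
+  // At this point the op-code table is, schematically:
+  //   [0 .. n-1]  op codes carried over from the original flatbuffer
+  //   [n]         the CUSTOM code for dispatch ops (DispatchOpCodeInd() == n)
+  // PackOp above resolves any op whose saved op-code index is negative, i.e.
+  // an op created after loading rather than carried over, to index n.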
+
+  // Pack litert subgraphs into tfl subgraphs and save the mapping of tensors.
+  TensorMap tensor_map;
+  for (auto* litert_subgraph : litert_model.Subgraphs()) {
+    auto& tfl_subgraph = *builder.Model().subgraphs.emplace_back(
+        std::make_unique<TflSubgraph>());
+    LITERT_EXPECT_OK(
+        PackSubgraph(builder, *litert_subgraph, tfl_subgraph, tensor_map));
+  }
+
+  // Serialize the signatures using saved tensor mapping.
+  for (auto* litert_signature : litert_model.Signatures()) {
+    auto* litert_subgraph = &litert_signature->GetSubgraph();
+
+    auto& tfl_signature = *builder.Model().signature_defs.emplace_back(
+        std::make_unique<TflSignature>());
+    tfl_signature.signature_key = std::string(litert_signature->Key());
+
+    auto begin = litert_model.Subgraphs().cbegin();
+    auto end = litert_model.Subgraphs().cend();
+    const auto litert_subgraph_ind =
+        std::find(begin, end, litert_subgraph) - begin;
+    tfl_signature.subgraph_index = litert_subgraph_ind;
+
+    auto input_ind = 0;
+    for (const auto& litert_name : litert_signature->InputNames()) {
+      auto& tfl_input = *tfl_signature.inputs.emplace_back(
+          std::make_unique<::tflite::TensorMapT>());
+      tfl_input.name = litert_name;
+      tfl_input.tensor_index =
+          tensor_map.find(litert_subgraph->Inputs().at(input_ind))->second;
+      ++input_ind;
+    }
+
+    auto output_ind = 0;
+    for (const auto& litert_name : litert_signature->OutputNames()) {
+      auto& tfl_output = *tfl_signature.outputs.emplace_back(
+          std::make_unique<::tflite::TensorMapT>());
+      tfl_output.name = litert_name;
+      tfl_output.tensor_index =
+          tensor_map.find(litert_subgraph->Outputs().at(output_ind))->second;
+      ++output_ind;
+    }
+  }
+
+  // Serialize metadata.
+  for (auto it = litert_model.MetadataBegin(); it != litert_model.MetadataEnd();
+       ++it) {
+    builder.PushMetadata(it->first, it->second);
+  }
+
+  return std::move(builder).Release();
+}
+
+}  // namespace
+
+Expected<OwningBufferRef<uint8_t>> SerializeModel(LiteRtModelT&& model) {
+  auto tfl_model = PackAsTflite(model);
+  if (!tfl_model) {
+    return tfl_model.Error();
+  }
+
+  // TODO(@lukeboyer) Figure out what to do with fb versions.
+  tfl_model->get()->version = 3;
+
+  auto serialized_tfl = SerializeFlatbuffer(**tfl_model);
+  if (!VerifyFlatbuffer(serialized_tfl.Span())) {
+    return Error(kLiteRtStatusErrorInvalidFlatbuffer);
+  }
+
+  return serialized_tfl;
+}
+
+}  // namespace litert::internal
+
+LiteRtStatus LiteRtSerializeModel(LiteRtModel model, uint8_t** buf,
+                                  size_t* size, size_t* offset,
+                                  bool destroy_model) {
+  auto serialized = litert::internal::SerializeModel(std::move(*model));
+  if (destroy_model) {
+    delete model;
+  }
+  if (!serialized) {
+    return serialized.Error().Status();
+  }
+  std::tie(*buf, *size, *offset) = serialized->Release();
+  return kLiteRtStatusOk;
+}
diff --git a/tflite/experimental/litert/core/model/model_serialize.h b/tflite/experimental/litert/core/model/model_serialize.h
new file mode 100644
index 00000000..52b466e3
--- /dev/null
+++ b/tflite/experimental/litert/core/model/model_serialize.h
@@ -0,0 +1,46 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_SERIALIZE_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_SERIALIZE_H_
+
+#include "tflite/experimental/litert/c/litert_model.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Serializes model to bytes.
+// NOTE: This destroys the model before it returns unless destroy_model is
+// false.
+// NOTE: Caller takes ownership of `buf`. Flatbuffers are packed into their
+// arrays back to front, so the valid flatbuffer is buf[offset, size].
+LiteRtStatus LiteRtSerializeModel(LiteRtModel model, uint8_t** buf,
+                                  size_t* size, size_t* offset,
+                                  bool destroy_model = true);
+
+#ifdef __cplusplus
+}
+
+#include "tflite/experimental/litert/cc/litert_buffer_ref.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+
+namespace litert::internal {
+
+Expected<OwningBufferRef<uint8_t>> SerializeModel(LiteRtModelT&& model);
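+
+// Sketch of how a caller consumes the serialized buffer (illustrative; it
+// assumes a loaded `model` whose ownership passes to the call, and elides
+// how `buf` is eventually freed):
+//
+//   uint8_t* buf = nullptr;
+//   size_t size = 0;
+//   size_t offset = 0;
+//   if (LiteRtSerializeModel(model, &buf, &size, &offset) == kLiteRtStatusOk) {
+//     // Per the note above, the valid flatbuffer sits at the back of the
+//     // array.
+//     const uint8_t* fb_begin = buf + offset;
+//     const size_t fb_size = size - offset;
+//   }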
+
+}  // namespace litert::internal
+
+#endif  // __cplusplus
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_SERIALIZE_H_
diff --git a/tflite/experimental/litert/core/model/model_test.cc b/tflite/experimental/litert/core/model/model_test.cc
new file mode 100644
index 00000000..0f4011d4
--- /dev/null
+++ b/tflite/experimental/litert/core/model/model_test.cc
@@ -0,0 +1,280 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/core/model/model.h"
+
+#include <array>
+#include <cstdint>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/c/litert_op_code.h"
+#include "tflite/experimental/litert/cc/litert_buffer_ref.h"
+#include "tflite/experimental/litert/core/util/flatbuffer_tools.h"
+#include "tflite/experimental/litert/test/test_macros.h"
+#include "tflite/schema/schema_generated.h"
+
+namespace litert::internal {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+//
+// Model
+//
+
+TEST(ModelTest, GetMetadata) {
+  static constexpr absl::string_view kMetadata = "VALUE";
+  static constexpr absl::string_view kKey = "KEY";
+
+  LiteRtModelT model;
+  LITERT_ASSERT_STATUS_OK(model.PushMetadata(kKey, kMetadata));
+  auto found_metadata = model.FindMetadata(kKey);
+  ASSERT_TRUE(found_metadata);
+  EXPECT_EQ(found_metadata->StrView(), kMetadata);
+}
+
+TEST(ModelTest, MetadataDNE) {
+  LiteRtModelT model;
+  auto res = model.FindMetadata("FOO");
+  ASSERT_FALSE(res.HasValue());
+}
+
+TEST(ModelTest, PopMetadata) {
+  static constexpr absl::string_view kMetadata = "VALUE";
+  static constexpr absl::string_view kKey = "KEY";
+
+  LiteRtModelT model;
+  LITERT_ASSERT_STATUS_OK(model.PushMetadata(kKey, kMetadata));
+
+  auto popped_metadata = model.PopMetadata(kKey);
+  ASSERT_TRUE(popped_metadata);
+  EXPECT_EQ(popped_metadata->StrView(), kMetadata);
+
+  EXPECT_FALSE(model.FindMetadata(kKey));
+}
+
+TEST(ModelTest, EmplaceSubgraph) {
+  LiteRtModelT model;
+  model.EmplaceSubgraph();
+  EXPECT_EQ(model.Subgraphs().size(), 1);
+}
+
+TEST(ModelTest, Signature) {
+  static constexpr absl::string_view kSignatureName = "MY_SIGNATURE";
+
+  const std::vector<std::string> inputs = {"input_1", "input_2"};
+  const std::vector<std::string> outputs = {"output_1"};
+
+  LiteRtModelT model;
+  auto& subgraph = model.EmplaceSubgraph();
+
+  auto& signature = model.EmplaceSignature(&subgraph, inputs, outputs,
+                                           std::string(kSignatureName));
+
+  auto found_signature = model.FindSignature(kSignatureName);
+  ASSERT_TRUE(found_signature);
+  EXPECT_EQ(found_signature->get(), signature);
+}
+
+TEST(ModelTest, SignatureDNE) {
+  static constexpr absl::string_view kSignatureName = "MY_SIGNATURE";
+  LiteRtModelT model;
+  auto found_signature = model.FindSignature(kSignatureName);
+  EXPECT_FALSE(found_signature);
+}
+
+//
+// Subgraph
+//
+
+TEST(ModelSubgraphTest, Input) {
+  LiteRtTensorT tensor;
+  LiteRtSubgraphT subgraph;
+  subgraph.Inputs().push_back(&tensor);
+  EXPECT_EQ(&subgraph.Input(0), subgraph.Inputs().front());
+}
+
+TEST(ModelSubgraphTest, Output) {
+  LiteRtTensorT tensor;
+  LiteRtSubgraphT subgraph;
+  subgraph.Outputs().push_back(&tensor);
+  EXPECT_EQ(&subgraph.Output(0), subgraph.Outputs().front());
+}
+
+TEST(ModelSubgraphTest, EmplaceTensor) {
+  LiteRtSubgraphT subgraph;
+  auto& tensor = subgraph.EmplaceTensor();
+  ASSERT_EQ(subgraph.Tensors().size(), 1);
+  EXPECT_THAT(subgraph.Tensors(), ElementsAreArray({&tensor}));
+}
+
+TEST(ModelSubgraphTest, EmplaceOp) {
+  LiteRtSubgraphT subgraph;
+  auto& op = subgraph.EmplaceOp();
+  ASSERT_EQ(subgraph.Ops().size(), 1);
+  EXPECT_THAT(subgraph.Ops(), ElementsAreArray({&op}));
+}
+
+//
+// Op
+//
+
+TEST(ModelOpTest, Input) {
+  LiteRtOpT op;
+  LiteRtTensorT tensor;
+  op.Inputs().push_back(&tensor);
+  EXPECT_EQ(&op.Input(0), op.Inputs().front());
+}
+
+TEST(ModelOpTest, Output) {
+  LiteRtOpT op;
+  LiteRtTensorT tensor;
+  op.Outputs().push_back(&tensor);
+ EXPECT_EQ(&op.Output(0), op.Outputs().front()); +} + +TEST(ModelOpTest, CustomOptions) { + static constexpr absl::string_view kOpts = "OPTIONS"; + + LiteRtOpT op; + op.SetCustomOptions(kOpts); + EXPECT_EQ(op.CustomOptions().StrView(), kOpts); +} + +TEST(ModelOpTest, Options) { + static constexpr auto kOptsType = ::tflite::BuiltinOptions_AddOptions; + + TflOptions options; + options.type = kOptsType; + options.Set(::tflite::AddOptionsT()); + + LiteRtOpT op; + detail::SetTflOptions(op, std::move(options)); + + ASSERT_EQ(detail::GetTflOptions(op).type, kOptsType); +} + +TEST(ModelOpTest, OpCode) { + constexpr static auto kOpCode = kLiteRtOpCodeTflMul; + + LiteRtOpT op; + op.SetOpCode(kOpCode); + EXPECT_EQ(op.OpCode(), kOpCode); +} + +// +// Tensor +// + +TEST(ModelTensorTypeTest, MakeRankedTensorType) { + static constexpr const int32_t kDims[] = {2, 2}; + static constexpr auto kDimsSpan = absl::MakeConstSpan(kDims); + static constexpr auto kElementType = kLiteRtElementTypeFloat32; + const auto tensor_type = MakeRankedTensorType(kElementType, kDimsSpan); + ASSERT_EQ(tensor_type.first, kLiteRtRankedTensorType); + EXPECT_EQ(tensor_type.second.ranked_tensor_type.element_type, kElementType); + const auto& layout = tensor_type.second.ranked_tensor_type.layout; + ASSERT_EQ(layout.rank, kDimsSpan.size()); + EXPECT_THAT(absl::MakeConstSpan(layout.dimensions, kDimsSpan.size()), + ElementsAreArray(kDimsSpan)); +} + +TEST(ModelQuantizationTypeTest, MakePerTensor) { + static constexpr auto kScale = 1.0f; + static constexpr auto kZero = 1L; + const auto quant = MakePerTensorQuantization(kScale, kZero); + ASSERT_EQ(quant.first, kLiteRtQuantizationPerTensor); + const auto& per_tensor = quant.second.per_tensor; + EXPECT_EQ(per_tensor.scale, kScale); + EXPECT_EQ(per_tensor.zero_point, kZero); +} + +TEST(ModelQuantizationTypeTest, MakePerChannel) { + static constexpr std::array kScale = {1.0f, 2.0f}; + static constexpr std::array kZero = {1L, 2L}; + static constexpr int32_t kQdim = 0; + + LiteRtTensorT tensor; + const auto quant = MakePerChannelQuantization( + kScale, kZero, kQdim, + [&tensor](auto s) { return tensor.RequestBuffer(s); }); + + ASSERT_EQ(quant.first, kLiteRtQuantizationPerChannel); + const auto& per_channel = quant.second.per_channel; + + const auto size = per_channel.num_channels; + ASSERT_EQ(size, 2); + EXPECT_EQ(per_channel.quantized_dimension, 0); + + auto scales = absl::MakeConstSpan(per_channel.scales, size); + auto zeros = absl::MakeConstSpan(per_channel.zero_points, size); + + EXPECT_THAT(scales, ElementsAreArray(kScale)); + EXPECT_THAT(zeros, ElementsAreArray(kZero)); +} + +TEST(ModelWeightsTest, WeightsFromBuf) { + static constexpr absl::string_view kData = "some_data"; + + LiteRtWeightsT weights; + weights.SetFromBuf(BufferRef(kData.data(), kData.size())); + EXPECT_EQ(weights.Buf().StrView(), kData); +} + +TEST(ModelTensorTest, Name) { + static constexpr absl::string_view kName = "TENSOR_NAME"; + + LiteRtTensorT tensor; + tensor.SetName(std::string(kName.begin(), kName.end())); + EXPECT_EQ(tensor.Name(), kName); +} + +TEST(ModelTensorTest, Use) { + LiteRtTensorT tensor; + tensor.Users().emplace_back(); + tensor.UserArgInds().push_back(0); + auto [user, ind] = tensor.GetUse(0); + EXPECT_EQ(user, tensor.Users().front()); + EXPECT_EQ(ind, 0); +} + +TEST(ModelTensorTest, DefiningOp) { + LiteRtTensorT tensor; + LiteRtOpT op; + tensor.SetDefiningOp(op, 0); + EXPECT_EQ(tensor.DefiningOp(), &op); + EXPECT_EQ(tensor.DefiningOpOutInd(), 0); +} + +// +// Util +// + +TEST(ModelOpListTest, Push) 
{ + LiteRtOpListT op_list; + LiteRtOpT op; + op_list.Push(&op); + auto vec = op_list.Vec(); + EXPECT_EQ(vec.front(), &op); +} + +} // namespace +} // namespace litert::internal diff --git a/tflite/experimental/litert/core/util/BUILD b/tflite/experimental/litert/core/util/BUILD new file mode 100644 index 00000000..41f0dc9d --- /dev/null +++ b/tflite/experimental/litert/core/util/BUILD @@ -0,0 +1,85 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # copybara:uncomment default_applicable_licenses = ["@org_tensorflow//tensorflow:license"], + default_visibility = ["//tflite/experimental/litert:__subpackages__"], +) + +cc_library( + name = "flatbuffer_tools", + srcs = ["flatbuffer_tools.cc"], + hdrs = [ + "flatbuffer_tools.h", + ], + deps = [ + "//tflite:model_builder", + "//tflite:stderr_reporter", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/cc:litert_buffer_ref", + "//tflite/experimental/litert/cc:litert_detail", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/core:filesystem", + "//tflite/schema:schema_fbs", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@flatbuffers//:runtime_cc", + "@org_tensorflow//tensorflow/compiler/mlir/lite:allocation", + ], +) + +cc_test( + name = "flatbuffer_tools_test", + srcs = ["flatbuffer_tools_test.cc"], + data = [ + "//tflite/experimental/litert/test:mlir_test_data", + "//tflite/experimental/litert/test:tflite_test_data", + ], + deps = [ + ":flatbuffer_tools", + "//tflite/experimental/litert/cc:litert_buffer_ref", + "//tflite/experimental/litert/test:common", + "//tflite/experimental/litert/test:test_macros", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "tensor_type_util", + srcs = [ + "tensor_type_util.cc", + ], + hdrs = [ + "tensor_type_util.h", + ], + deps = [ + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/cc:litert_expected", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "tensor_type_util_test", + srcs = ["tensor_type_util_test.cc"], + deps = [ + ":tensor_type_util", + "//tflite/experimental/litert/c:litert_model", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tflite/experimental/litert/core/util/flatbuffer_tools.cc b/tflite/experimental/litert/core/util/flatbuffer_tools.cc new file mode 100644 index 00000000..6225eca2 --- /dev/null +++ b/tflite/experimental/litert/core/util/flatbuffer_tools.cc @@ -0,0 +1,321 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tflite/experimental/litert/core/util/flatbuffer_tools.h" + +#include +#include +#include +#include +#include + +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/allocation.h" +#include "tflite/experimental/litert/core/filesystem.h" + +#ifndef NDEBUG +// Make flatbuffers verifier `assert` in debug mode. +#define FLATBUFFERS_DEBUG_VERIFICATION_FAILURE + +#include "flatbuffers/flatbuffers.h" // from @flatbuffers // IWYU pragma: keep +#endif + +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "flatbuffers/verifier.h" // from @flatbuffers +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/cc/litert_buffer_ref.h" +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/model_builder.h" +#include "tflite/schema/schema_generated.h" +#include "tflite/stderr_reporter.h" + +namespace litert::internal { + +using ::flatbuffers::Verifier; +using ::tflite::VerifyModelBuffer; + +namespace { + +Expected FindMetadataInd(const TflModel& model, + absl::string_view key) { + tflite::MetadataT* fb_metadata = nullptr; + for (auto& m : model.metadata) { + if (m->name == key) { + fb_metadata = m.get(); + break; + } + } + if (fb_metadata == nullptr) { + return Error(kLiteRtStatusErrorNotFound); + } + return fb_metadata->buffer; +} + +} // namespace + +absl::string_view FbBufToStr(const uint8_t* fb_data, size_t size) { + auto fb_buf_raw = reinterpret_cast(fb_data); + return absl::string_view(fb_buf_raw, size); +} + +absl::string_view FbBufToStr(absl::Span fb_buf) { + auto fb_buf_raw = reinterpret_cast(fb_buf.data()); + const size_t fb_buf_size = fb_buf.size(); + return absl::string_view(fb_buf_raw, fb_buf_size); +} + +absl::Span FbBufToStr(absl::Span fb_buf) { + return absl::MakeSpan(reinterpret_cast(fb_buf.data()), fb_buf.size()); +} + +absl::Span FbBufToStr(uint8_t* fb_data, size_t size) { + return absl::MakeSpan(reinterpret_cast(fb_data), size); +} + +bool VerifyFlatbuffer(absl::Span buf) { + return VerifyFlatbuffer(buf.data(), buf.size()); +} + +bool VerifyFlatbuffer(const uint8_t* buf, size_t buf_size) { + flatbuffers::Verifier::Options options; +#ifndef NDEBUG + options.assert = true; +#endif + flatbuffers::Verifier verifier(buf, buf_size, options); + return VerifyModelBuffer(verifier); +} + +Expected> GetMetadata(absl::string_view key, + TflModel& model) { + auto buffer_ind = FindMetadataInd(model, key); + if (!buffer_ind) { + // Metadata key already has value. + return buffer_ind.Error(); + } + auto& fb_vec = model.buffers.at(*buffer_ind)->data; + return MutableBufferRef(fb_vec.data(), fb_vec.size()); +} + +Expected> GetMetadata(absl::string_view key, + const TflModel& model) { + auto metadata = GetMetadata(key, const_cast(model)); + if (!metadata) { + return metadata.Error(); + } + return *metadata; +} + +LiteRtStatus PushMetadata(absl::string_view key, TflModel& model, + BufferRef metadata) { + auto buffer_ind = FindMetadataInd(model, key); + if (buffer_ind) { + // Metadata key already has value. 
+ return kLiteRtStatusErrorInvalidArgument; + } + + auto& new_metadata = + model.metadata.emplace_back(std::make_unique()); + new_metadata->name.assign(key.data(), key.size()); + + const auto new_m_buffer_ind = model.buffers.size(); + new_metadata->buffer = new_m_buffer_ind; + + auto& new_buffer = model.buffers.emplace_back(std::make_unique()); + new_buffer->data.assign(metadata.Data(), metadata.Data() + metadata.Size()); + + return kLiteRtStatusOk; +} + +Expected> GetTflBuffer(TflModel& tfl_model, + uint32_t buffer_ind) { + if (buffer_ind >= tfl_model.buffers.size()) { + return Error(kLiteRtStatusErrorIndexOOB); + } + auto& tfl_data = tfl_model.buffers.at(buffer_ind)->data; + return MutableBufferRef(tfl_data.data(), tfl_data.size()); +} + +Expected> GetTflBuffer(const TflModel& tfl_model, + uint32_t buffer_ind) { + auto buffer = GetTflBuffer(const_cast(tfl_model), buffer_ind); + if (!buffer) { + return buffer.Error(); + } + return *buffer; +} + +Expected TakeBuffer(TflModel& tfl_model, uint32_t buffer_ind) { + if (buffer_ind >= tfl_model.buffers.size()) { + return Error(kLiteRtStatusErrorIndexOOB); + } + return std::move(tfl_model.buffers.at(buffer_ind)); +} + +Expected PushTflBuffer(TflModel& tfl_model, + BufferRef buffer) { + tfl_model.buffers.emplace_back(std::make_unique<::tflite::BufferT>()) + ->data.assign(buffer.Data(), buffer.Data() + buffer.Size()); + return tfl_model.buffers.size() - 1; +} + +Expected GetTflOpCode(const TflModel& tfl_model, + uint32_t op_code_ind) { + if (op_code_ind >= tfl_model.operator_codes.size()) { + return Error(kLiteRtStatusErrorIndexOOB); + } + return std::move(tfl_model.operator_codes.at(op_code_ind)->builtin_code); +} + +bool IsRankedTensorType(const TflShapeInfo& tfl_shape) { + return tfl_shape.has_rank; +} + +bool IsStaticTensorType(const TflShapeInfo& tfl_shape) { + return !IsRankedTensorType(tfl_shape) || + std::none_of(tfl_shape.shape_signature.begin(), + tfl_shape.shape_signature.end(), + [](auto d) { return d < 0; }); +} + +Expected> AsStaticShape( + const TflShapeInfo& tfl_shape) { + if (!IsStaticTensorType(tfl_shape)) { + return Error(kLiteRtStatusErrorInvalidArgument); + } + return absl::MakeConstSpan(tfl_shape.shape.data(), tfl_shape.shape.size()); +} + +Expected> AsDynamicShape( + const TflShapeInfo& tfl_shape) { + auto static_shape = AsStaticShape(tfl_shape); + if (static_shape) { + return static_shape; + } + if (!IsRankedTensorType(tfl_shape)) { + return Error(kLiteRtStatusErrorInvalidArgument); + } + return absl::MakeConstSpan(tfl_shape.shape_signature.data(), + tfl_shape.shape_signature.size()); +} + +bool IsQuantized(const TflQuantization* tfl_quantization) { + return tfl_quantization && + (!tfl_quantization->scale.empty() || + tfl_quantization->details.type != tflite::QuantizationDetails_NONE); +} + +bool IsPerChannelQuantized(const TflQuantization* tfl_quantization) { + return tfl_quantization && tfl_quantization->scale.size() > 1; +} + +bool IsPerTensorQuantized(const TflQuantization* tfl_quantization) { + return tfl_quantization && tfl_quantization->scale.size() == 1; +} + +bool IsBlockwiseQuantized(const TflQuantization* tfl_quantization) { + return tfl_quantization && + tfl_quantization->details.type == + tflite::QuantizationDetails_BlockwiseQuantization; +} + +bool IsCustomQuantized(const TflQuantization* tfl_quantization) { + return tfl_quantization && tfl_quantization->details.type == + tflite::QuantizationDetails_CustomQuantization; +} + +Expected AsPerTensorQparams( + const TflQuantization* tfl_quantization) { + if 
(!IsPerTensorQuantized(tfl_quantization)) { + return Error(kLiteRtStatusErrorInvalidArgument); + } + return std::make_pair(tfl_quantization->zero_point.front(), + tfl_quantization->scale.front()); +} + +Expected AsPerChannelQparams( + const TflQuantization* tfl_quantization) { + if (!IsPerChannelQuantized(tfl_quantization)) { + return Error(kLiteRtStatusErrorInvalidArgument); + } + return TflPerChannelQParams(tfl_quantization->quantized_dimension, + tfl_quantization->zero_point.size(), + tfl_quantization->zero_point, + tfl_quantization->scale); +} + +::tflite::Allocation::Ptr MakeAllocation(BufferRef buf) { + return std::make_unique<::tflite::MemoryAllocation>( + buf.Data(), buf.Size(), ::tflite::DefaultErrorReporter()); +} + +Expected FlatbufferWrapper::CreateFromBuffer( + OwningBufferRef&& buffer) { + if (!VerifyFlatbuffer(buffer.Data(), buffer.Size())) { + return Error(kLiteRtStatusErrorInvalidFlatbuffer); + } + + auto alloc = MakeAllocation(buffer); + + if (alloc == nullptr) { + return Error(kLiteRtStatusErrorFileIO); + } + + auto fb_model = ::tflite::FlatBufferModel::BuildFromBuffer( + reinterpret_cast(alloc->base()), alloc->bytes()); + if (fb_model == nullptr) { + return Error(kLiteRtStatusErrorFileIO); + } + + return FlatbufferWrapper::Ptr(new FlatbufferWrapper( + std::move(fb_model), std::move(alloc), std::move(buffer))); +} + +Expected FlatbufferWrapper::CreateFromBuffer( + BufferRef buffer) { + return FlatbufferWrapper::CreateFromBuffer( + OwningBufferRef(buffer.Data(), buffer.Size())); +} + +Expected FlatbufferWrapper::CreateFromTflFile( + absl::string_view path) { + auto buf = LoadBinaryFile(path); + if (!buf) { + return buf.Error(); + } + return FlatbufferWrapper::CreateFromBuffer(std::move(*buf)); +} + +OwningBufferRef SerializeFlatbuffer(const TflModel& tfl_model) { + flatbuffers::FlatBufferBuilder b; + auto model_offset = tflite::Model::Pack(b, &tfl_model); + tflite::FinishModelBuffer(b, model_offset); + + OwningBufferRef buffer; + auto [new_buf, new_size, new_offset] = buffer.GetWeak(); + new_buf = b.ReleaseRaw(new_size, new_offset); + + return buffer; +} + +OwningBufferRef SerializeFlatbuffer( + const FlatbufferWrapper& flatbuffer) { + auto tfl_model = flatbuffer.Unpack(); + return SerializeFlatbuffer(*tfl_model); +} + +} // namespace litert::internal diff --git a/tflite/experimental/litert/core/util/flatbuffer_tools.h b/tflite/experimental/litert/core/util/flatbuffer_tools.h new file mode 100644 index 00000000..f84f504f --- /dev/null +++ b/tflite/experimental/litert/core/util/flatbuffer_tools.h @@ -0,0 +1,283 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTIL_FLATBUFFER_TOOLS_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTIL_FLATBUFFER_TOOLS_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <initializer_list>
+#include <memory>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "tensorflow/compiler/mlir/lite/allocation.h"
+#include "tflite/experimental/litert/cc/litert_buffer_ref.h"
+#include "tflite/experimental/litert/cc/litert_detail.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/model_builder.h"
+#include "tflite/schema/schema_generated.h"
+
+namespace litert::internal {
+
+// Flatbuffer IR
+
+using TflTensor = ::tflite::TensorT;
+using TflOp = ::tflite::OperatorT;
+using TflBuffer = ::tflite::BufferT;
+using TflSubgraph = ::tflite::SubGraphT;
+using TflModel = ::tflite::ModelT;
+using TflOpCodeEnum = ::tflite::BuiltinOperator;
+using TflOpCode = ::tflite::OperatorCodeT;
+using TflQuantization = ::tflite::QuantizationParametersT;
+using TflElementType = ::tflite::TensorType;
+using TflOptions = ::tflite::BuiltinOptionsUnion;
+using TflSignature = ::tflite::SignatureDefT;
+using TflMetadata = ::tflite::MetadataT;
+
+using TflBufferPtr = std::unique_ptr<TflBuffer>;
+using TflModelPtr = std::unique_ptr<TflModel>;
+using TflQuantizationPtr = std::unique_ptr<TflQuantization>;
+using TflOpCodePtr = std::unique_ptr<TflOpCode>;
+using TflSubgraphPtr = std::unique_ptr<TflSubgraph>;
+using TflTensorPtr = std::unique_ptr<TflTensor>;
+using TflOpPtr = std::unique_ptr<TflOp>;
+using TflSignaturePtr = std::unique_ptr<TflSignature>;
+using TflMetadataPtr = std::unique_ptr<TflMetadata>;
+
+// Code and version.
+using TflOpCodeDetail = std::pair<TflOpCodeEnum, int32_t>;
+
+// Zero-point, scale.
+using TflPerTensorQParams = std::pair<int64_t, float>;
+
+// Quantized dim, num channels, zero-points, scales.
+using TflPerChannelQParams =
+    std::tuple<int32_t, uint64_t, std::vector<int64_t>, std::vector<float>>;
+
+// Mirror of all the tensor type related fields in flatbuffer tensor
+// definition.
+struct TflShapeInfo {
+  // Fixed or dynamic rank.
+  bool has_rank;
+
+  // Basic shape, all elements are non-negative (even if this is a dynamic
+  // shape).
+  SmallVec<int32_t> shape;
+
+  // Dynamic dim info. If this is not empty, then its length is equal to
+  // shape. If i is a dynamic dim, then shape[i] == 1 and
+  // shape_signature[i] < 0. Otherwise shape_signature[i] == shape[i].
+  SmallVec<int32_t> shape_signature;
+
+  // Convert from a single dims array. Will detect if array is static/dynamic
+  // and populate fields accordingly.
+  explicit TflShapeInfo(absl::Span<const int32_t> shape_data)
+      : has_rank(true) {
+    bool is_dyn = false;
+    shape.reserve(shape_data.size());
+    shape_signature.reserve(shape_data.size());
+    for (auto d : shape_data) {
+      if (d >= 0) {
+        shape.push_back(d);
+        shape_signature.push_back(d);
+      } else {
+        is_dyn = true;
+        shape.push_back(1);
+        shape_signature.push_back(-1);
+      }
+    }
+    if (!is_dyn) {
+      shape_signature.clear();
+    }
+  }
+
+  // Convert from tensor.
+  explicit TflShapeInfo(const TflTensor& tfl_tensor)
+      : has_rank(tfl_tensor.has_rank),
+        shape(SmallVec<int32_t>(tfl_tensor.shape.begin(),
+                                tfl_tensor.shape.end())),
+        shape_signature(
+            SmallVec<int32_t>(tfl_tensor.shape_signature.begin(),
+                              tfl_tensor.shape_signature.end())) {}
+};
+
+using TflTensorType = std::pair<TflElementType, TflShapeInfo>;
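+
+// An illustrative instantiation: for a dims array {2, -1, 4}, the
+// TflShapeInfo constructor above yields
+//   has_rank        == true
+//   shape           == {2, 1, 4}   (dynamic dim clamped to 1)
+//   shape_signature == {2, -1, 4}  (negative marks the dynamic dim)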
+
+// Flatbuffer bytes util.
+
+// Convenience method to get string view from native flatbuffer chars.
+absl::string_view FbBufToStr(const uint8_t* fb_data, size_t size);
+
+// Span version.
+absl::string_view FbBufToStr(absl::Span<const uint8_t> fb_buf);
+
+// Convenience method to get mutable signed char span from native flatbuffer
+// chars.
+absl::Span<char> FbBufToStr(uint8_t* fb_data, size_t size);
+
+// Span to span version.
+absl::Span<char> FbBufToStr(absl::Span<uint8_t> fb_buf);
+
+// Flatbuffer verifiers.
+
+// Verifies given serialized flatbuffer.
+bool VerifyFlatbuffer(const uint8_t* buf, size_t buf_size);
+
+// Override of above with view input.
+bool VerifyFlatbuffer(absl::Span<const uint8_t> buf);
+
+// TFL flatbuffer IR helpers.
+
+// Get the metadata buffer under given key if it exists.
+Expected<BufferRef<uint8_t>> GetMetadata(absl::string_view key,
+                                         const TflModel& model);
+
+// Get the metadata buffer under given key if it exists that can be written
+// to.
+Expected<MutableBufferRef<uint8_t>> GetMutableMetadata(absl::string_view key,
+                                                       TflModel& model);
+
+// Push the given metadata to the given key if the key does not already
+// exist.
+LiteRtStatus PushMetadata(absl::string_view key, TflModel& model,
+                          BufferRef<uint8_t> metadata);
+
+// Get the buffer object at the given index if it exists.
+Expected<BufferRef<uint8_t>> GetTflBuffer(const TflModel& tfl_model,
+                                          uint32_t buffer_ind);
+
+// Get the buffer object at the given index if it exists that can be written
+// to.
+Expected<MutableBufferRef<uint8_t>> GetMutableTflBuffer(TflModel& tfl_model,
+                                                        uint32_t buffer_ind);
+
+// Move and take ownership of the buffer object at given index if it exists.
+Expected<TflBufferPtr> TakeBuffer(TflModel& tfl_model, uint32_t buffer_ind);
+
+// Add a new buffer to the tflite model, returning its index.
+Expected<uint32_t> PushTflBuffer(TflModel& tfl_model,
+                                 BufferRef<uint8_t> buffer);
+
+// Make a tflite buffer from data.
+template <typename T>
+TflBufferPtr MakeTflBuffer(std::initializer_list<T> data) {
+  auto res = std::make_unique<TflBuffer>();
+  const auto byte_size = data.size() * sizeof(T);
+  res->data.resize(byte_size);
+  for (auto it = data.begin(); it != data.end(); ++it) {
+    auto* write_to =
+        reinterpret_cast<T*>(res->data.data()) + (it - data.begin());
+    *write_to = *it;
+  }
+  res->size = res->data.size();
+  res->offset = 0;
+  return res;
+}
+
+// Get the op code from the model at the given index if it exists.
+Expected<TflOpCodeEnum> GetTflOpCode(const TflModel& tfl_model,
+                                     uint32_t op_code_ind);
+
+// Is tensor fixed rank, with possible dynamic dims.
+bool IsRankedTensorType(const TflShapeInfo& tfl_shape);
+
+// Is ranked tensor type with static shape.
+bool IsStaticTensorType(const TflShapeInfo& tfl_shape);
+
+// Get static shape info if given is indeed a static shape.
+Expected<absl::Span<const int32_t>> AsStaticShape(
+    const TflShapeInfo& tfl_shape);
+
+// Get ranked dynamic shape info if given is indeed ranked. Still works with
+// static shapes.
+Expected<absl::Span<const int32_t>> AsDynamicShape(
+    const TflShapeInfo& tfl_shape);
+
+// Is the tensor quantized.
+bool IsQuantized(const TflQuantization* tfl_quantization);
+
+// Is the tensor per-tensor quantized.
+bool IsPerTensorQuantized(const TflQuantization* tfl_quantization);
+
+// Is the tensor per-channel quantized.
+bool IsPerChannelQuantized(const TflQuantization* tfl_quantization);
+
+// Is the tensor block-wise quantized.
+bool IsBlockwiseQuantized(const TflQuantization* tfl_quantization);
+
+// Does tensor have custom quantization.
+bool IsCustomQuantized(const TflQuantization* tfl_quantization);
+
+// Get the per-tensor tensor q-params if given tensor has them.
+Expected<TflPerTensorQParams> AsPerTensorQparams(
+    const TflQuantization* tfl_quantization);
+
+// Get the per-channel tensor q-params if given tensor has them.
+Expected<TflPerChannelQParams> AsPerChannelQparams(
+    const TflQuantization* tfl_quantization);
+
+// Flatbuffer management helpers.
+
+// Make a tfl allocation from buffer.
+::tflite::Allocation::Ptr MakeAllocation(BufferRef<uint8_t> buf);
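+
+// A minimal round-trip sketch with the wrapper declared below (illustrative;
+// the path is a placeholder):
+//
+//   auto fb = FlatbufferWrapper::CreateFromTflFile("model.tflite");
+//   if (fb) {
+//     TflModelPtr unpacked = (*fb)->Unpack();       // mutable deep copy
+//     auto bytes = SerializeFlatbuffer(*unpacked);  // re-pack to bytes
+//   }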
+
+// Wrapper around a tflite model buffer.
+class FlatbufferWrapper {
+ public:
+  using Ptr = std::unique_ptr<FlatbufferWrapper>;
+
+  // Load flatbuffer from file.
+  static Expected<Ptr> CreateFromTflFile(absl::string_view path);
+
+  // Load flatbuffer from allocated buffer that will be copied.
+  static Expected<Ptr> CreateFromBuffer(BufferRef<uint8_t> buffer);
+
+  // Load flatbuffer from allocated buffer and take ownership.
+  static Expected<Ptr> CreateFromBuffer(OwningBufferRef<uint8_t>&& buffer);
+
+  // Underlying buffer.
+  BufferRef<uint8_t> Buf() const {
+    return BufferRef<uint8_t>(alloc_->base(), alloc_->bytes());
+  }
+
+  // Underlying model object.
+  const ::tflite::FlatBufferModel& FlatbufferModel() const {
+    return *fb_model_;
+  }
+
+  // Unpack the contained flatbuffer.
+  TflModelPtr Unpack() const {
+    return TflModelPtr(fb_model_->GetModel()->UnPack());
+  }
+
+ private:
+  FlatbufferWrapper(::tflite::FlatBufferModel::Ptr fb_model,
+                    ::tflite::Allocation::Ptr alloc,
+                    OwningBufferRef<uint8_t>&& model_buf)
+      : fb_model_(std::move(fb_model)),
+        alloc_(std::move(alloc)),
+        model_buf_(std::forward<OwningBufferRef<uint8_t>>(model_buf)) {}
+
+  ::tflite::FlatBufferModel::Ptr fb_model_;
+  ::tflite::Allocation::Ptr alloc_;
+  OwningBufferRef<uint8_t> model_buf_;
+};
+
+// Re-serialize the unpacked model from flatbuffer wrapper.
+OwningBufferRef<uint8_t> SerializeFlatbuffer(
+    const FlatbufferWrapper& flatbuffer);
+OwningBufferRef<uint8_t> SerializeFlatbuffer(const TflModel& tfl_model);
+
+}  // namespace litert::internal
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTIL_FLATBUFFER_TOOLS_H_
diff --git a/tflite/experimental/litert/core/util/flatbuffer_tools_test.cc b/tflite/experimental/litert/core/util/flatbuffer_tools_test.cc
new file mode 100644
index 00000000..04419f4a
--- /dev/null
+++ b/tflite/experimental/litert/core/util/flatbuffer_tools_test.cc
@@ -0,0 +1,175 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "tflite/experimental/litert/core/util/flatbuffer_tools.h" + +#include + +#include +#include +#include "absl/strings/string_view.h" +#include "tflite/experimental/litert/cc/litert_buffer_ref.h" +#include "tflite/experimental/litert/test/common.h" +#include "tflite/experimental/litert/test/test_macros.h" + +namespace litert::internal { +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; +using ::testing::Lt; + +FlatbufferWrapper::Ptr TestFlatbuffer( + absl::string_view filename = "one_mul.tflite") { + const auto tfl_path = testing::GetTestFilePath(filename); + return *FlatbufferWrapper::CreateFromTflFile(tfl_path); +} + +static const absl::string_view kKey = "MyKey"; +static const absl::string_view kData = "MyData"; + +TEST(FlatbufferToolsTest, Metadata) { + auto flatbuffer = TestFlatbuffer(); + ASSERT_NE(flatbuffer, nullptr); + auto tfl_model = flatbuffer->Unpack(); + + LITERT_ASSERT_STATUS_OK(PushMetadata( + kKey, *tfl_model, BufferRef(kData.data(), kData.size()))); + + auto metadata = GetMetadata(kKey, *tfl_model); + ASSERT_TRUE(metadata); + EXPECT_EQ(metadata->StrView(), kData); +} + +TEST(FlatbufferToolsTest, GetMetadataNotFound) { + auto flatbuffer = TestFlatbuffer(); + auto tfl_model = flatbuffer->Unpack(); + ASSERT_NE(flatbuffer, nullptr); + EXPECT_FALSE(GetMetadata(kKey, *tfl_model)); +} + +TEST(FlatbufferToolsTest, TflBuffer) { + auto flatbuffer = TestFlatbuffer(); + ASSERT_NE(flatbuffer, nullptr); + auto tfl_model = flatbuffer->Unpack(); + + auto ind = PushTflBuffer((*tfl_model), + BufferRef(kData.data(), kData.size())); + ASSERT_TRUE(ind); + + auto buf = GetTflBuffer((*tfl_model), *ind); + ASSERT_TRUE(buf); + ASSERT_EQ(buf->StrView(), kData); +} + +TEST(FlatbufferToolsTest, GetTflBufferNotFound) { + auto flatbuffer = TestFlatbuffer(); + ASSERT_NE(flatbuffer, nullptr); + auto tfl_model = flatbuffer->Unpack(); + + auto buf = GetTflBuffer((*tfl_model), 100); + ASSERT_FALSE(buf); +} + +TEST(FlatbufferToolsTest, GetTflOpCode) { + auto flatbuffer = TestFlatbuffer(); + ASSERT_NE(flatbuffer, nullptr); + auto tfl_model = flatbuffer->Unpack(); + + auto op_code = GetTflOpCode((*tfl_model), 0); + ASSERT_TRUE(op_code); +} + +TEST(FlatbufferToolsTest, GetTflOpCodeNotFound) { + auto flatbuffer = TestFlatbuffer(); + ASSERT_NE(flatbuffer, nullptr); + auto tfl_model = flatbuffer->Unpack(); + + auto op_code = GetTflOpCode((*tfl_model), 100); + ASSERT_FALSE(op_code); +} + +TEST(FlatbufferToolsTest, StaticTensorTypeTest) { + auto flatbuffer = TestFlatbuffer(); + auto tfl_model = flatbuffer->Unpack(); + auto& tensor = tfl_model->subgraphs.front()->tensors.front(); + + TflShapeInfo shape(*tensor); + + ASSERT_TRUE(IsRankedTensorType(shape)); + ASSERT_TRUE(IsStaticTensorType(shape)); + + auto static_shape = AsStaticShape(shape); + + ASSERT_TRUE(static_shape); + ASSERT_THAT(*static_shape, ElementsAreArray({2, 2})); +} + +TEST(FlatbufferToolsTest, UnrankedTensorTypeTest) { + auto flatbuffer = TestFlatbuffer("unranked_tensor.tflite"); + auto tfl_model = flatbuffer->Unpack(); + auto& tensor = tfl_model->subgraphs.front()->tensors.front(); + + TflShapeInfo shape(*tensor); + + ASSERT_FALSE(IsRankedTensorType(shape)); +} + +TEST(FlatbufferToolsTest, RankedDynamicTensorTypeTest) { + auto flatbuffer = TestFlatbuffer("dynamic_shape_tensor.tflite"); + auto tfl_model = flatbuffer->Unpack(); + auto& tensor = tfl_model->subgraphs.front()->tensors.front(); + + TflShapeInfo shape(*tensor); + + ASSERT_TRUE(IsRankedTensorType(shape)); + ASSERT_FALSE(IsStaticTensorType(shape)); + + 
auto dyn_shape = AsDynamicShape(shape); + + ASSERT_TRUE(dyn_shape); + ASSERT_THAT(*dyn_shape, ElementsAre(Lt(0), 2)); +} + +TEST(FlatbufferToolsTest, PerTensorQuantizedTest) { + auto flatbuffer = + TestFlatbuffer("single_add_default_a16w8_recipe_quantized.tflite"); + auto tfl_model = flatbuffer->Unpack(); + auto& tensor = tfl_model->subgraphs.front()->tensors.front(); + + const auto* const q_parms = tensor->quantization.get(); + + ASSERT_TRUE(IsQuantized(q_parms)); + EXPECT_TRUE(IsPerTensorQuantized(q_parms)); + + auto per_tensor = AsPerTensorQparams(q_parms); + ASSERT_TRUE(per_tensor); +} + +TEST(FlatbufferToolsTest, PerChannelQuantizedTest) { + auto flatbuffer = TestFlatbuffer("static_w8_a16_quantized_k_einsum.tflite"); + auto tfl_model = flatbuffer->Unpack(); + auto& tensor = tfl_model->subgraphs.front()->tensors[1]; + + const auto* const q_parms = tensor->quantization.get(); + + ASSERT_TRUE(IsQuantized(q_parms)); + EXPECT_TRUE(IsPerChannelQuantized(q_parms)); + + auto per_channel = AsPerChannelQparams(q_parms); + ASSERT_TRUE(per_channel); +} + +} // namespace +} // namespace litert::internal diff --git a/tflite/experimental/litert/core/util/tensor_type_util.cc b/tflite/experimental/litert/core/util/tensor_type_util.cc new file mode 100644 index 00000000..2d9811cd --- /dev/null +++ b/tflite/experimental/litert/core/util/tensor_type_util.cc @@ -0,0 +1,59 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tflite/experimental/litert/core/util/tensor_type_util.h" + +#include + +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_model.h" +#include "tflite/experimental/litert/cc/litert_expected.h" + +namespace litert { +namespace internal { + +Expected GetElementSize(LiteRtElementType element_type) { + switch (element_type) { + case kLiteRtElementTypeInt4: + return Ratio{1, 2}; + case kLiteRtElementTypeBool: + return Ratio{1, 1}; + case kLiteRtElementTypeInt8: + case kLiteRtElementTypeUInt8: + return Ratio{1, 1}; + case kLiteRtElementTypeInt16: + case kLiteRtElementTypeUInt16: + case kLiteRtElementTypeFloat16: + case kLiteRtElementTypeBFloat16: + return Ratio{2, 1}; + case kLiteRtElementTypeInt32: + case kLiteRtElementTypeUInt32: + case kLiteRtElementTypeFloat32: + return Ratio{4, 1}; + case kLiteRtElementTypeInt64: + case kLiteRtElementTypeUInt64: + case kLiteRtElementTypeFloat64: + return Ratio{8, 1}; + case kLiteRtElementTypeComplex64: + return Ratio{16, 1}; + case kLiteRtElementTypeComplex128: + return Ratio{32, 1}; + default: + return Unexpected(kLiteRtStatusErrorInvalidArgument, + "Unexpected element type"); + } +} + +} // namespace internal +} // namespace litert diff --git a/tflite/experimental/litert/core/util/tensor_type_util.h b/tflite/experimental/litert/core/util/tensor_type_util.h new file mode 100644 index 00000000..aaa15d5b --- /dev/null +++ b/tflite/experimental/litert/core/util/tensor_type_util.h @@ -0,0 +1,111 @@ +// Copyright 2024 Google LLC. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTIL_TENSOR_TYPE_UTIL_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTIL_TENSOR_TYPE_UTIL_H_
+
+#include <cstddef>
+
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+
+namespace litert::internal {
+
+struct Ratio {
+  using Type = int;
+  Type num;
+  Type denom;
+  std::string ToString() const { return absl::StrCat(num, "/", denom); }
+};
+
+Expected<Ratio> GetElementSize(LiteRtElementType element_type);
+
+// Get the number of elements in a tensor with given dimensions.
+template <typename T>
+Expected<size_t> GetNumElements(absl::Span<T> dimensions) {
+  size_t num_elements = 1;
+  for (size_t i = 0; i < dimensions.size(); ++i) {
+    auto dim = dimensions[i];
+    if (dim < 0) {
+      return Unexpected(kLiteRtStatusErrorInvalidArgument,
+                        "Unexpected negative dimension");
+    } else if (dim == 0) {
+      return Unexpected(kLiteRtStatusErrorInvalidArgument,
+                        "Unexpected 0 dimension");
+    }
+    num_elements *= dim;
+  }
+  return num_elements;
+}
+
+inline Expected<size_t> GetNumElements(
+    const LiteRtRankedTensorType& tensor_type) {
+  return GetNumElements(
+      absl::MakeSpan(tensor_type.layout.dimensions, tensor_type.layout.rank));
+}
+
+// Get the minimum number of bytes necessary to represent a packed tensor with
+// a given element type and dimensions.
+template <typename T>
+Expected<size_t> GetNumPackedBytes(LiteRtElementType element_type,
+                                   absl::Span<T> dimensions) {
+  auto element_size = GetElementSize(element_type);
+  if (!element_size) {
+    return element_size.Error();
+  }
+  auto num_elements = GetNumElements(dimensions);
+  if (!num_elements) {
+    return num_elements.Error();
+  }
+  return ((*num_elements * element_size->num) + (element_size->denom - 1)) /
+         element_size->denom;
+}
+
+// Get the number of bytes necessary to represent a packed tensor type,
+// ignoring any stride information.
+inline Expected<size_t> GetNumPackedBytes(
+    const LiteRtRankedTensorType& tensor_type) {
+  return GetNumPackedBytes(
+      tensor_type.element_type,
+      absl::MakeSpan(tensor_type.layout.dimensions, tensor_type.layout.rank));
+}
+
+// Get the minimum number of bytes necessary to represent a possibly unpacked
+// tensor with a given element type, dimensions, and strides.
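+//
+// Worked example: for dimensions {2, 3} with row-major strides {3, 1}, the
+// loop below computes num_elements = 1 + (2 - 1) * 3 + (3 - 1) * 1 = 6, which
+// matches the packed element count. With int32 elements (Ratio 4/1) this
+// yields (6 * 4 + 0) / 1 = 24 bytes; with int4 elements (Ratio 1/2), 6
+// elements pack into (6 * 1 + 1) / 2 = 3 bytes, the "+ (denom - 1)" term
+// implementing the ceiling division for sub-byte types.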
+template <typename T>
+Expected<size_t> GetNumBytes(LiteRtElementType element_type,
+                             absl::Span<T> dimensions, absl::Span<T> strides) {
+  if (dimensions.size() != strides.size()) {
+    return Unexpected(
+        kLiteRtStatusErrorInvalidArgument,
+        "Dimensions and strides have different number of elements");
+  }
+  auto element_size = GetElementSize(element_type);
+  if (!element_size) {
+    return element_size.Error();
+  }
+  auto rank = dimensions.size();
+  size_t num_elements = 1;
+  for (size_t i = 0; i < rank; ++i) {
+    num_elements += (dimensions[i] - 1) * strides[i];
+  }
+  return ((num_elements * element_size->num) + (element_size->denom - 1)) /
+         element_size->denom;
+}
+
+}  // namespace litert::internal
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTIL_TENSOR_TYPE_UTIL_H_
diff --git a/tflite/experimental/litert/core/util/tensor_type_util_test.cc b/tflite/experimental/litert/core/util/tensor_type_util_test.cc
new file mode 100644
index 00000000..2ccab3b4
--- /dev/null
+++ b/tflite/experimental/litert/core/util/tensor_type_util_test.cc
@@ -0,0 +1,69 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/core/util/tensor_type_util.h"
+
+#include <array>
+#include <cstdint>
+
+#include <gtest/gtest.h>  // NOLINT: Need when ANDROID_API_LEVEL >= 26
+#include "absl/types/span.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+
+using litert::internal::GetNumBytes;
+using litert::internal::GetNumElements;
+using litert::internal::GetNumPackedBytes;
+
+TEST(TensorTypeUtil, GetNumElements) {
+  constexpr std::array<int32_t, 3> dimensions = {3, 2, 1};
+  auto num_elements = GetNumElements(absl::MakeSpan(dimensions));
+  EXPECT_TRUE(num_elements);
+  EXPECT_EQ(*num_elements, 6);
+}
+
+TEST(TensorTypeUtil, GetNumElementsWithUnknownDimension) {
+  constexpr std::array<int32_t, 3> dimensions = {3, -1, 1};
+  auto num_elements = GetNumElements(absl::MakeSpan(dimensions));
+  EXPECT_FALSE(num_elements);
+}
+
+TEST(TensorTypeUtil, GetNumElementsWithZeroDimension) {
+  constexpr std::array<int32_t, 3> dimensions = {3, 0, 1};
+  auto num_elements = GetNumElements(absl::MakeSpan(dimensions));
+  EXPECT_FALSE(num_elements);
+}
+
+TEST(TensorTypeUtil, GetNumPackedBytes) {
+  LiteRtElementType element_type = kLiteRtElementTypeInt32;
+  constexpr std::array<int32_t, 3> dimensions = {3, 2, 1};
+  auto num_bytes = GetNumPackedBytes(element_type, absl::MakeSpan(dimensions));
+  EXPECT_TRUE(num_bytes);
+  EXPECT_EQ(*num_bytes, sizeof(int32_t) * 6);
+}
+
+TEST(TensorTypeUtil, GetNumBytes) {
+  LiteRtElementType element_type = kLiteRtElementTypeInt32;
+  constexpr std::array<int32_t, 3> dimensions = {3, 2, 1};
+  constexpr std::array<int32_t, 3> strides = {1, 4, 8};
+  // The data should be allocated as follows (where 'X' is a used cell and 'o'
+  // is an unused/padding cell):
+  //
+  //   XXXo XXX
+  //
+  // The total is 4 + 3 = 7 cells.
+  auto num_bytes = GetNumBytes(element_type, absl::MakeSpan(dimensions),
+                               absl::MakeSpan(strides));
+  EXPECT_TRUE(num_bytes);
+  EXPECT_EQ(*num_bytes, sizeof(int32_t) * 7);
+}
diff --git a/tflite/experimental/litert/integration_test/BUILD
b/tflite/experimental/litert/integration_test/BUILD new file mode 100644 index 00000000..e3809c5e --- /dev/null +++ b/tflite/experimental/litert/integration_test/BUILD @@ -0,0 +1,18 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # copybara:uncomment default_applicable_licenses = ["@org_tensorflow//tensorflow:license"], + default_visibility = ["//tflite/experimental/litert:__subpackages__"], +) diff --git a/tflite/experimental/litert/runtime/BUILD b/tflite/experimental/litert/runtime/BUILD new file mode 100644 index 00000000..118045a6 --- /dev/null +++ b/tflite/experimental/litert/runtime/BUILD @@ -0,0 +1,153 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # copybara:uncomment default_applicable_licenses = ["@org_tensorflow//tensorflow:license"], + default_visibility = ["//tflite/experimental/litert:__subpackages__"], +) + +cc_library( + name = "tensor_buffer", + srcs = [ + "ahwb_buffer.cc", + "dmabuf_buffer.cc", + "event.cc", + "fastrpc_buffer.cc", + "ion_buffer.cc", + "tensor_buffer.cc", + ], + hdrs = [ + "ahwb_buffer.h", + "dmabuf_buffer.h", + "event.h", + "fastrpc_buffer.h", + "ion_buffer.h", + "tensor_buffer.h", + "//tflite/experimental/litert/c:litert_event.h", + "//tflite/experimental/litert/c:litert_tensor_buffer.h", + "//tflite/experimental/litert/c:litert_tensor_buffer_requirements.h", + ], + deps = [ + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/core/util:tensor_type_util", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "tfl_utils", + srcs = [ + "tfl_utils.cc", + ], + hdrs = [ + "tfl_utils.h", + ], + deps = [ + "//tflite/c:c_api", + "//tflite/c:c_api_opaque", + "//tflite/c:c_api_types", + "//tflite/c:common", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/cc:litert_detail", + "//tflite/experimental/litert/cc:litert_element_type", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_model", + ], +) + +cc_library( + name = "external_litert_buffer_context", + srcs = ["external_litert_buffer_context.cc"], + hdrs = ["external_litert_buffer_context.h"], + deps = [ + 
":tfl_utils", + "//tflite/c:c_api", + "//tflite/c:c_api_opaque", + "//tflite/c:c_api_types", + "//tflite/c:common", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_tensor_buffer", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_tensor_buffer", + "//tflite/experimental/litert/cc:litert_tensor_buffer_requirements", + "@com_google_absl//absl/container:flat_hash_map", + ], +) + +cc_library( + name = "compiled_model", + srcs = ["compiled_model.cc"], + hdrs = ["compiled_model.h"], + deps = [ + ":external_litert_buffer_context", + ":tensor_buffer", + "//tflite:framework", + "//tflite:model_builder", + "//tflite/c:c_api_opaque", + "//tflite/c:common", + "//tflite/core:cc_api_stable", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_compiled_model_options", + "//tflite/experimental/litert/c:litert_dispatch_delegate", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_tensor_buffer", + "//tflite/experimental/litert/cc:litert_buffer_ref", + "//tflite/experimental/litert/cc:litert_detail", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_tensor_buffer", + "//tflite/experimental/litert/cc:litert_tensor_buffer_requirements", + "//tflite/experimental/litert/core/model", + "//tflite/experimental/litert/core/model:model_serialize", + "//tflite/kernels:builtin_ops", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings:string_view", + "@org_tensorflow//tensorflow/compiler/mlir/lite:allocation", + ], +) + +cc_test( + name = "compiled_model_test", + srcs = ["compiled_model_test.cc"], + data = [ + "//tflite/experimental/litert/test:testdata/simple_model.tflite", + ], + deps = [ + ":compiled_model", + "//tflite:framework", + "//tflite/c:c_api_opaque", + "//tflite/c:common", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_compiled_model_options", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_tensor_buffer", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_tensor_buffer", + "//tflite/experimental/litert/core/model", + "//tflite/experimental/litert/test:common", + "//tflite/experimental/litert/test:simple_model", + "//tflite/kernels:builtin_ops", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tflite/experimental/litert/runtime/ahwb_buffer.cc b/tflite/experimental/litert/runtime/ahwb_buffer.cc new file mode 100644 index 00000000..a7c9e60e --- /dev/null +++ b/tflite/experimental/litert/runtime/ahwb_buffer.cc @@ -0,0 +1,112 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/runtime/ahwb_buffer.h"
+
+#include <cstdint>
+
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_event.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+
+namespace litert {
+namespace internal {
+
+bool AhwbBuffer::IsSupported() {
+#if LITERT_HAS_AHWB_SUPPORT
+  return true;
+#else
+  return false;
+#endif
+}
+
+Expected<AhwbBuffer> AhwbBuffer::Alloc(size_t size) {
+#if LITERT_HAS_AHWB_SUPPORT
+  AHardwareBuffer* ahwb;
+  AHardwareBuffer_Desc ahwb_desc = {
+      .width = static_cast<uint32_t>(size),
+      .height = 1,
+      .layers = 1,
+      .format = AHARDWAREBUFFER_FORMAT_BLOB,
+      .usage = AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY |
+               AHARDWAREBUFFER_USAGE_CPU_READ_RARELY |
+               AHARDWAREBUFFER_USAGE_GPU_DATA_BUFFER};
+  if (AHardwareBuffer_allocate(&ahwb_desc, &ahwb) != 0) {
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Failed to allocate AHWB");
+  }
+  return AhwbBuffer{/*.ahwb=*/ahwb};
+#else
+  return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                    "AHardwareBuffers are not supported on this platform");
+#endif  // LITERT_HAS_AHWB_SUPPORT
+}
+
+void AhwbBuffer::Free(AHardwareBuffer* ahwb) {
+#if LITERT_HAS_AHWB_SUPPORT
+  AHardwareBuffer_release(ahwb);
+#endif
+}
+
+Expected<size_t> AhwbBuffer::GetSize(AHardwareBuffer* ahwb) {
+#if LITERT_HAS_AHWB_SUPPORT
+  AHardwareBuffer_Desc ahwb_desc;
+  AHardwareBuffer_describe(ahwb, &ahwb_desc);
+  return static_cast<size_t>(ahwb_desc.width) * ahwb_desc.height *
+         ahwb_desc.layers;
+#else
+  return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                    "AHardwareBuffers are not supported on this platform");
+#endif  // LITERT_HAS_AHWB_SUPPORT
+}
+
+Expected<void*> AhwbBuffer::Lock(AHardwareBuffer* ahwb, LiteRtEvent event) {
+#if LITERT_HAS_AHWB_SUPPORT
+  int fence = -1;
+  if (event) {
+    if (auto status = LiteRtGetEventSyncFenceFd(event, &fence);
+        status != kLiteRtStatusOk) {
+      return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                        "Failed to get sync fence fd from event");
+    }
+  }
+  void* host_addr;
+  if (AHardwareBuffer_lock(ahwb,
+                           AHARDWAREBUFFER_USAGE_CPU_READ_RARELY |
+                               AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY,
+                           fence, /*rect=*/nullptr, &host_addr) != 0) {
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure, "Failed to lock AHWB");
+  }
+  return host_addr;
+#else
+  return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                    "AHardwareBuffers are not supported on this platform");
+#endif
+}
+
+Expected<void> AhwbBuffer::Unlock(AHardwareBuffer* ahwb) {
+#if LITERT_HAS_AHWB_SUPPORT
+  if (AHardwareBuffer_unlock(ahwb, /*fence=*/nullptr) != 0) {
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Failed to unlock AHWB");
+  }
+  return {};
+#else
+  return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                    "AHardwareBuffers are not supported on this platform");
+#endif
+}
+
+}  // namespace internal
+}  // namespace litert
diff --git a/tflite/experimental/litert/runtime/ahwb_buffer.h b/tflite/experimental/litert/runtime/ahwb_buffer.h
new file mode 100644
index 00000000..987102e6
--- /dev/null
+++ b/tflite/experimental/litert/runtime/ahwb_buffer.h
@@ -0,0 +1,53 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_AHWB_BUFFER_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_AHWB_BUFFER_H_
+
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_event.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+
+#if LITERT_HAS_AHWB_SUPPORT
+#include <android/hardware_buffer.h>
+#else
+// Define a placeholder AHardwareBuffer struct just to enable compilation.
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+typedef struct AHardwareBuffer AHardwareBuffer;
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+#endif  // LITERT_HAS_AHWB_SUPPORT
+
+namespace litert {
+namespace internal {
+
+struct AhwbBuffer {
+  AHardwareBuffer* ahwb;
+
+  static bool IsSupported();
+  static Expected<AhwbBuffer> Alloc(size_t size);
+  static void Free(AHardwareBuffer* ahwb);
+  static Expected<size_t> GetSize(AHardwareBuffer* ahwb);
+  static Expected<void*> Lock(AHardwareBuffer* ahwb,
+                              LiteRtEvent event = nullptr);
+  static Expected<void> Unlock(AHardwareBuffer* ahwb);
+};
+
+}  // namespace internal
+}  // namespace litert
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_AHWB_BUFFER_H_
diff --git a/tflite/experimental/litert/runtime/compiled_model.cc b/tflite/experimental/litert/runtime/compiled_model.cc
new file mode 100644
index 00000000..3cd26a90
--- /dev/null
+++ b/tflite/experimental/litert/runtime/compiled_model.cc
@@ -0,0 +1,310 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "tflite/experimental/litert/runtime/compiled_model.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/mlir/lite/allocation.h" +#include "tflite/c/common.h" +#include "tflite/core/interpreter_builder.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_compiled_model_options.h" +#include "tflite/experimental/litert/c/litert_dispatch_delegate.h" +#include "tflite/experimental/litert/c/litert_model.h" +#include "tflite/experimental/litert/c/litert_tensor_buffer.h" +#include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h" +#include "tflite/experimental/litert/cc/litert_buffer_ref.h" +#include "tflite/experimental/litert/cc/litert_detail.h" +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/experimental/litert/cc/litert_tensor_buffer.h" +#include "tflite/experimental/litert/cc/litert_tensor_buffer_requirements.h" +#include "tflite/experimental/litert/core/model/model.h" +#include "tflite/experimental/litert/core/model/model_serialize.h" +#include "tflite/experimental/litert/runtime/external_litert_buffer_context.h" +#include "tflite/experimental/litert/runtime/tensor_buffer.h" +#include "tflite/interpreter.h" +#include "tflite/kernels/register.h" +#include "tflite/model_builder.h" +#include "tflite/stderr_reporter.h" + +using litert::Expected; +using litert::SmallVec; +using litert::TensorBuffer; +using litert::Unexpected; +using litert::internal::ExternalLiteRtBufferContext; + +Expected LiteRtCompiledModelT::Initialize() { + // Use BuiltinOpResolverWithoutDefaultDelegates to avoid auto applying of + // Xnnpack delegate with GetSignatureRunner() API. + tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver; + tflite::InterpreterBuilder(*fb_model_, resolver)(&interp_); + if (interp_ == nullptr) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure); + } + + signature_keys_ = interp_->signature_keys(); + if (signature_keys_.empty()) { + static auto* default_signature_key = + new std::string(LiteRtSignatureT::kDefaultSignatureKey); + signature_keys_.push_back(default_signature_key); + } + // Register the ExternalLiteRtBufferContext for TensorBuffer handshaking. + buffer_context_ = + std::make_unique(); + interp_->SetExternalContext(kTfLiteLiteRtBufferContext, + buffer_context_.get()); + + return {}; +} + +Expected LiteRtCompiledModelT::Create( + LiteRtModel model, LiteRtComplicationOptions complication_options) { + auto runtime = std::make_unique(); + + const char* model_buffer = nullptr; + size_t model_buffer_size = 0; + // The following code gets the original FB pointer from LiteRtModel. + // TODO b/383120429 - Use a better way of getting the FB pointer. + auto init_model_buffer = detail::GetTflInitFlatbuffer(*model); + if (init_model_buffer.Size() != 0) { + // Use the saved the original FB pointer when the LiteRtModel was created + // from a buffer. + model_buffer = init_model_buffer.StrData(); + model_buffer_size = init_model_buffer.Size(); + } else { + // TODO b/383120429 - Once LiteRtModel provide tflite::Model object, switch + // to use it to initialize Interpreter instead of serializing LiteRtModel. 
+    auto [data, size, offset] = runtime->model_buf_.GetWeak();
+    if (LiteRtSerializeModel(model, &data, &size, &offset,
+                             /*destroy_model=*/false) != kLiteRtStatusOk) {
+      return Unexpected(kLiteRtStatusErrorRuntimeFailure);
+    }
+    runtime->alloc_ = std::make_unique<tflite::MemoryAllocation>(
+        runtime->model_buf_.Data(), runtime->model_buf_.Size(),
+        tflite::DefaultErrorReporter());
+    model_buffer = reinterpret_cast<const char*>(runtime->alloc_->base());
+    model_buffer_size = runtime->alloc_->bytes();
+  }
+  runtime->fb_model_ =
+      tflite::FlatBufferModel::BuildFromBuffer(model_buffer, model_buffer_size);
+  if (runtime->fb_model_ == nullptr) {
+    return Unexpected(kLiteRtStatusErrorFileIO);
+  }
+
+  if (auto res = runtime->Initialize(); !res.HasValue()) {
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure);
+  }
+
+  // TODO: b/379317134 - Support other delegates with compilation options.
+  if (complication_options & kHwAccelNpu) {
+    auto dispatch_delegate_options = litert::CreateDispatchDelegateOptionsPtr();
+    LiteRtDispatchDelegateAddAllocBaseOption(dispatch_delegate_options.get(),
+                                             model_buffer);
+    auto dispatch_delegate =
+        litert::CreateDispatchDelegatePtr(std::move(dispatch_delegate_options));
+    if (auto status =
+            runtime->interp_->ModifyGraphWithDelegate(dispatch_delegate.get());
+        status != kTfLiteOk) {
+      return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                        "Failed to modify graph with delegate");
+    }
+  }
+
+  return runtime;
+}
+
+litert::Expected<LiteRtTensorBufferRequirements>
+LiteRtCompiledModelT::GetTensorBufferRequirements(const TfLiteTensor* tensor) {
+  auto requirements = buffer_context_->GetBufferRequirement(tensor);
+  if (requirements) {
+    return (*requirements)->Get();
+  }
+  LiteRtTensorBufferRequirements litert_cpu_buffer_requirements;
+  LiteRtTensorBufferType cpu_buffer_type[] = {
+      kLiteRtTensorBufferTypeHostMemory};
+  uint32_t cpu_buffer_strides[] = {0};
+  auto res = LiteRtCreateTensorBufferRequirements(
+      /*num_supported_tensor_buffer_types=*/1, cpu_buffer_type, tensor->bytes,
+      /*num_strides=*/1, cpu_buffer_strides, &litert_cpu_buffer_requirements);
+  if (res != kLiteRtStatusOk) {
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Failed to create CPU buffer requirements");
+  }
+  cpu_buffer_requirements_[tensor] =
+      litert::TensorBufferRequirements(litert_cpu_buffer_requirements);
+  return litert_cpu_buffer_requirements;
+}
+
+Expected<LiteRtTensorBufferRequirements>
+LiteRtCompiledModelT::GetInputBufferRequirements(
+    absl::string_view signature_key, size_t input_index) {
+  auto runner = GetSignatureRunner(signature_key);
+  if (runner == nullptr) {
+    return Unexpected(kLiteRtStatusErrorNotFound,
+                      "Failed to get signature runner");
+  }
+  auto input_names = runner->input_names();
+  if (input_index >= input_names.size()) {
+    return Unexpected(kLiteRtStatusErrorIndexOOB, "Input index out of range");
+  }
+  auto input_name = input_names[input_index];
+  auto* input_tensor = runner->input_tensor(input_name);
+  if (input_tensor == nullptr) {
+    return Unexpected(kLiteRtStatusErrorNotFound, "Failed to get input tensor");
+  }
+
+  return GetTensorBufferRequirements(input_tensor);
+}
+
+Expected<LiteRtTensorBufferRequirements>
+LiteRtCompiledModelT::GetOutputBufferRequirements(
+    absl::string_view signature_key, size_t output_index) {
+  auto runner = GetSignatureRunner(signature_key);
+  if (runner == nullptr) {
+    return Unexpected(kLiteRtStatusErrorNotFound,
+                      "Failed to get signature runner");
+  }
+  auto output_names = runner->output_names();
+  if (output_index >= output_names.size()) {
+    return Unexpected(kLiteRtStatusErrorIndexOOB, "Output index out of range");
+  }
+  auto output_name = output_names[output_index];
+  auto* output_tensor = runner->output_tensor(output_name);
+  if (output_tensor == nullptr) {
+    return Unexpected(kLiteRtStatusErrorNotFound,
+                      "Failed to get output tensor");
+  }
+
+  return GetTensorBufferRequirements(output_tensor);
+}
+
+tflite::SignatureRunner* LiteRtCompiledModelT::GetSignatureRunner(
+    absl::string_view signature_key) {
+  if (signature_runners_.contains(signature_key)) {
+    return signature_runners_[signature_key];
+  }
+  auto runner = interp_->GetSignatureRunner(
+      signature_key == LiteRtSignatureT::kDefaultSignatureKey
+          ? nullptr
+          : std::string(signature_key).c_str());
+  signature_runners_[signature_key] = runner;
+  return runner;
+}
+
+Expected<void> LiteRtCompiledModelT::Run(
+    absl::string_view signature_key,
+    std::vector<LiteRtTensorBuffer>& input_buffers,
+    std::vector<LiteRtTensorBuffer>& output_buffers) {
+  auto runner = GetSignatureRunner(signature_key);
+  if (runner == nullptr) {
+    return Unexpected(kLiteRtStatusErrorNotFound,
+                      "Failed to get signature runner");
+  }
+  if (input_buffers.size() != runner->input_names().size()) {
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Input buffer size mismatch");
+  }
+  if (output_buffers.size() != runner->output_names().size()) {
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Output buffer size mismatch");
+  }
+
+  for (int i = 0; i < runner->input_names().size(); ++i) {
+    const auto& input_name = runner->input_names()[i];
+    auto* input_tensor = runner->input_tensor(input_name);
+    if (input_buffers[i]->buffer_type() == kLiteRtTensorBufferTypeHostMemory) {
+      // Assign CPU buffer via CustomAllocation.
+      TensorBuffer cpu_buffer(input_buffers[i], /*owned=*/false);
+      auto lock_and_addr = litert::TensorBufferScopedLock::Create(cpu_buffer);
+      TfLiteCustomAllocation custom_allocation{lock_and_addr->second,
+                                               input_tensor->bytes};
+      runner->SetCustomAllocationForInputTensor(input_name, custom_allocation,
+                                                /*flags=*/0);
+    } else {
+      // Register tensor buffer for non-CPU buffers.
+      input_buffers[i]->Duplicate();
+      TensorBuffer duplicated_buffer(input_buffers[i]);
+      if (auto status = buffer_context_->RegisterTensorBuffer(
+              input_tensor, std::move(duplicated_buffer));
+          status != kLiteRtStatusOk) {
+        return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                          "Failed to register input tensor buffer");
+      }
+    }
+  }
+
+  for (int i = 0; i < runner->output_names().size(); ++i) {
+    const auto& output_name = runner->output_names()[i];
+    auto* output_tensor = runner->output_tensor(output_name);
+    if (output_buffers[i]->buffer_type() == kLiteRtTensorBufferTypeHostMemory) {
+      // Assign CPU buffer via CustomAllocation.
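+      // The custom allocation points the interpreter's tensor directly at the
+      // caller-provided host memory, so Invoke() reads and writes in place
+      // with no extra copy into the TFLite arena.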
+      TensorBuffer cpu_buffer(output_buffers[i], /*owned=*/false);
+      auto lock_and_addr = litert::TensorBufferScopedLock::Create(cpu_buffer);
+      TfLiteCustomAllocation custom_allocation{lock_and_addr->second,
+                                               output_tensor->bytes};
+      runner->SetCustomAllocationForOutputTensor(output_name, custom_allocation,
+                                                 /*flags=*/0);
+    } else {
+      output_buffers[i]->Duplicate();
+      TensorBuffer duplicated_buffer(output_buffers[i]);
+      if (auto status = buffer_context_->RegisterTensorBuffer(
+              output_tensor, std::move(duplicated_buffer));
+          status != kLiteRtStatusOk) {
+        return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                          "Failed to register output tensor buffer");
+      }
+    }
+  }
+
+  if (auto res = runner->AllocateTensors(); res != kTfLiteOk) {
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Failed to allocate tensors");
+  }
+
+  if (auto res = runner->Invoke(); res != kTfLiteOk) {
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure, "Failed to invoke");
+  }
+
+  return {};
+}
+
+litert::Expected<void> LiteRtCompiledModelT::RunCApi(
+    size_t signature_index, size_t num_input_buffers,
+    LiteRtTensorBuffer* input_buffers, size_t num_output_buffers,
+    LiteRtTensorBuffer* output_buffers) {
+  if (signature_index >= signature_keys_.size()) {
+    return litert::Unexpected(
+        kLiteRtStatusErrorIndexOOB,
+        "Signature index is out of range of signature keys");
+  }
+  std::vector<LiteRtTensorBuffer> input_buffers_vec;
+  input_buffers_vec.reserve(num_input_buffers);
+  for (int i = 0; i < num_input_buffers; ++i) {
+    input_buffers_vec.push_back(std::move(input_buffers[i]));
+  }
+  std::vector<LiteRtTensorBuffer> output_buffers_vec;
+  output_buffers_vec.reserve(num_output_buffers);
+  for (int i = 0; i < num_output_buffers; ++i) {
+    output_buffers_vec.push_back(std::move(output_buffers[i]));
+  }
+  return Run(*signature_keys_[signature_index], input_buffers_vec,
+             output_buffers_vec);
+}
diff --git a/tflite/experimental/litert/runtime/compiled_model.h b/tflite/experimental/litert/runtime/compiled_model.h
new file mode 100644
index 00000000..6cdea9f2
--- /dev/null
+++ b/tflite/experimental/litert/runtime/compiled_model.h
@@ -0,0 +1,142 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_COMPILED_MODEL_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_COMPILED_MODEL_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/compiler/mlir/lite/allocation.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_compiled_model_options.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h"
+#include "tflite/experimental/litert/cc/litert_buffer_ref.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/cc/litert_tensor_buffer_requirements.h"
+#include "tflite/experimental/litert/runtime/external_litert_buffer_context.h"
+#include "tflite/interpreter.h"
+#include "tflite/model_builder.h"
+
+// The LiteRtCompiledModelT is the internal implementation of the CompiledModel
+// C++ API.
+class LiteRtCompiledModelT {
+ public:
+  using Ptr = std::unique_ptr<LiteRtCompiledModelT>;
+
+  LiteRtCompiledModelT() = default;
+  ~LiteRtCompiledModelT() = default;
+
+  // Creates a LiteRtCompiledModelT from a LiteRtModel object.
+  // The model is loaded into memory and the caller takes ownership of the
+  // returned object.
+  static litert::Expected<Ptr> Create(
+      LiteRtModel model, LiteRtComplicationOptions complication_options);
+
+  // Returns the buffer requirements for the n-th input tensor. The returned
+  // LiteRtTensorBufferRequirements is used to create the input tensor
+  // buffer.
+  litert::Expected<LiteRtTensorBufferRequirements> GetInputBufferRequirements(
+      absl::string_view signature_key, size_t input_index);
+
+  // The same as GetInputBufferRequirements() for the C API.
+  litert::Expected<LiteRtTensorBufferRequirements>
+  GetInputBufferRequirementsCApi(size_t signature_index, size_t input_index) {
+    if (signature_index >= signature_keys_.size()) {
+      return litert::Unexpected(
+          kLiteRtStatusErrorIndexOOB,
+          "Signature index is out of range of signature keys");
+    }
+    return GetInputBufferRequirements(*signature_keys_[signature_index],
+                                      input_index);
+  }
+
+  // Returns the buffer requirements for the n-th output tensor. The returned
+  // LiteRtTensorBufferRequirements is used to create the output tensor
+  // buffer.
+  litert::Expected<LiteRtTensorBufferRequirements> GetOutputBufferRequirements(
+      absl::string_view signature_key, size_t output_index);
+
+  // The same as GetOutputBufferRequirements() for the C API.
+  litert::Expected<LiteRtTensorBufferRequirements>
+  GetOutputBufferRequirementsCApi(size_t signature_index, size_t output_index) {
+    if (signature_index >= signature_keys_.size()) {
+      return litert::Unexpected(
+          kLiteRtStatusErrorIndexOOB,
+          "Signature index is out of range of signature keys");
+    }
+    return GetOutputBufferRequirements(*signature_keys_[signature_index],
+                                       output_index);
+  }
+
+  // Runs the model of the given signature with the provided input/output
+  // litert::TensorBuffers.
+  litert::Expected<void> Run(absl::string_view signature_key,
+                             std::vector<LiteRtTensorBuffer>& input_buffers,
+                             std::vector<LiteRtTensorBuffer>& output_buffers);
+
+  // The same as Run() for the C API.
+  litert::Expected<void> RunCApi(size_t signature_index,
+                                 size_t num_input_buffers,
+                                 LiteRtTensorBuffer* input_buffers,
+                                 size_t num_output_buffers,
+                                 LiteRtTensorBuffer* output_buffers);
+
+ private:
+  // Processes the model and initializes the internal states.
+  // This is called in the public Create*() methods.
+  litert::Expected<void> Initialize();
+
+  // Returns the buffer requirements for the given tensor.
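+  // If no delegate has registered requirements for the tensor, a host-memory
+  // requirement sized to tensor->bytes is created and cached in
+  // cpu_buffer_requirements_.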
+  litert::Expected<LiteRtTensorBufferRequirements> GetTensorBufferRequirements(
+      const TfLiteTensor* tensor);
+
+  // Returns the SignatureRunner for the given signature key.
+  // If the signature key is not found, returns nullptr.
+  tflite::SignatureRunner* GetSignatureRunner(absl::string_view signature_key);
+
+  // Map from signature key to SignatureRunner. This caches the results of
+  // GetSignatureRunner(), which is expensive to call.
+  absl::flat_hash_map<absl::string_view, tflite::SignatureRunner*>
+      signature_runners_;
+
+  // The buffer requirements map for CPU buffers. Delegates that use CPU
+  // buffers don't register TensorBufferRequirements; instead, the
+  // CompiledModel creates the TensorBufferRequirements and stores them
+  // in this map.
+  absl::flat_hash_map<const TfLiteTensor*, litert::TensorBufferRequirements>
+      cpu_buffer_requirements_;
+
+  // The Interpreter and related objects used to run the model.
+  std::unique_ptr<::tflite::Interpreter> interp_;
+  std::unique_ptr<::tflite::FlatBufferModel> fb_model_;
+  std::unique_ptr<::tflite::Allocation> alloc_;
+  litert::OwningBufferRef<uint8_t> model_buf_;
+  std::vector<const std::string*> signature_keys_;
+
+  // The ExternalLiteRtBufferContext used to register tensor buffers with
+  // Delegates.
+  // Note: The ExternalLiteRtBufferContext must be destroyed after the
+  // Interpreter.
+  std::unique_ptr<litert::internal::ExternalLiteRtBufferContext>
+      buffer_context_;
+};
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_COMPILED_MODEL_H_
diff --git a/tflite/experimental/litert/runtime/compiled_model_test.cc b/tflite/experimental/litert/runtime/compiled_model_test.cc
new file mode 100644
index 00000000..49914279
--- /dev/null
+++ b/tflite/experimental/litert/runtime/compiled_model_test.cc
@@ -0,0 +1,192 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "tflite/experimental/litert/runtime/compiled_model.h" + +#include +#include +#include +#include +#include + +#include +#include +#include "absl/log/absl_log.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_compiled_model_options.h" +#include "tflite/experimental/litert/c/litert_model.h" +#include "tflite/experimental/litert/c/litert_tensor_buffer.h" +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/experimental/litert/cc/litert_tensor_buffer.h" +#include "tflite/experimental/litert/cc/litert_tensor_buffer_requirements.h" +#include "tflite/experimental/litert/core/model/model.h" +#include "tflite/experimental/litert/test/common.h" +#include "tflite/experimental/litert/test/testdata/simple_model_test_vectors.h" + +namespace litert { +namespace { + +using ::testing::FloatNear; +using ::testing::Pointwise; + +Expected> CreateInputBuffers( + LiteRtModel& model, LiteRtCompiledModelT& compiled_model, + absl::string_view signature_key) { + std::vector input_buffers; + auto* subgraph = *LookupSubgraph(*model, signature_key); + auto& input_tensors = subgraph->Inputs(); + const size_t num_inputs = subgraph->NumInputs(); + input_buffers.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + auto litert_input_buffer_requirements = + compiled_model.GetInputBufferRequirements(signature_key, i); + if (!litert_input_buffer_requirements.HasValue()) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + litert_input_buffer_requirements.Error().Message()); + } + TensorBufferRequirements input_buffer_requirements = + TensorBufferRequirements(*litert_input_buffer_requirements, + /*owned=*/false); + const auto& ranked_tensor_type = + input_tensors[i]->Type().second.ranked_tensor_type; + LiteRtTensorBufferType tensor_buffer_type = + input_buffer_requirements.SupportedTypes()->at(0); + LiteRtTensorBuffer input_buffer; + if (auto status = LiteRtCreateManagedTensorBuffer( + tensor_buffer_type, &ranked_tensor_type, + input_buffer_requirements.BufferSize().Value(), &input_buffer); + status != kLiteRtStatusOk) { + return Unexpected(status, "Failed to create input tensor buffer"); + } + input_buffers.push_back(input_buffer); + } + + return std::move(input_buffers); +} + +Expected> CreateOutputBuffers( + LiteRtModel& model, LiteRtCompiledModelT& compiled_model, + absl::string_view signature_key) { + std::vector output_buffers; + auto* subgraph = *LookupSubgraph(*model, signature_key); + auto& output_tensors = subgraph->Outputs(); + size_t num_outputs = subgraph->NumOutputs(); + output_buffers.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + auto litert_output_buffer_requirements = + compiled_model.GetOutputBufferRequirements(signature_key, i); + if (!litert_output_buffer_requirements.HasValue()) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + litert_output_buffer_requirements.Error().Message()); + } + TensorBufferRequirements output_buffer_requirements = + TensorBufferRequirements(*litert_output_buffer_requirements, + /*owned=*/false); + auto ranked_tensor_type = + output_tensors[i]->Type().second.ranked_tensor_type; + LiteRtTensorBufferType tensor_buffer_type = + output_buffer_requirements.SupportedTypes()->at(0); + LiteRtTensorBuffer output_buffer; + if (auto status = LiteRtCreateManagedTensorBuffer( + tensor_buffer_type, &ranked_tensor_type, + output_buffer_requirements.BufferSize().Value(), &output_buffer); + 
+        status != kLiteRtStatusOk) {
+      return Unexpected(status, "Failed to create output tensor buffer");
+    }
+    output_buffers.push_back(output_buffer);
+  }
+
+  return std::move(output_buffers);
+}
+
+TEST(CompiledModelTest, Basic) {
+  auto path = testing::GetTestFilePath(kModelFileName);
+
+  LiteRtModel model;
+  ASSERT_EQ(LiteRtCreateModelFromFile(path.c_str(), &model), kLiteRtStatusOk);
+
+  auto res_compiled_model = LiteRtCompiledModelT::Create(model, kHwAccelCpu);
+  ASSERT_TRUE(res_compiled_model) << "Failed to initialize CompiledModel: "
+                                  << res_compiled_model.Error().Message();
+  auto& compiled_model = **res_compiled_model;
+
+  auto signatures = model->Signatures();
+  ASSERT_EQ(signatures.size(), 1);
+  auto signature_key = signatures[0]->Key();
+  EXPECT_EQ(signature_key, LiteRtSignatureT::kDefaultSignatureKey);
+
+  auto input_buffers_res =
+      CreateInputBuffers(model, compiled_model, signature_key);
+  EXPECT_TRUE(input_buffers_res);
+  auto input_buffers = std::move(*input_buffers_res);
+
+  auto output_buffers_res =
+      CreateOutputBuffers(model, compiled_model, signature_key);
+  EXPECT_TRUE(output_buffers_res);
+  auto output_buffers = std::move(*output_buffers_res);
+
+  // Fill model inputs.
+  auto& input_names = signatures[0]->InputNames();
+  EXPECT_EQ(input_names.size(), 2);
+  EXPECT_EQ(input_names.at(0), "arg0");
+  EXPECT_EQ(input_names.at(1), "arg1");
+  auto& input_0_buffer = input_buffers[0];
+  {
+    TensorBuffer cpu_buffer(input_0_buffer, /*owned=*/false);
+    cpu_buffer.Write<float>(
+        absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size));
+  }
+  auto& input_1_buffer = input_buffers[1];
+  {
+    TensorBuffer cpu_buffer(input_1_buffer, /*owned=*/false);
+    cpu_buffer.Write<float>(
+        absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size));
+  }
+
+  // Execute model.
+  compiled_model.Run(signature_key, input_buffers, output_buffers);
+
+  // Check model output.
+  auto output_names = signatures[0]->OutputNames();
+  EXPECT_EQ(output_names.size(), 1);
+  EXPECT_EQ(output_names.at(0), "tfl.add");
+  {
+    void* host_mem_addr;
+    ASSERT_EQ(LiteRtLockTensorBuffer(output_buffers[0], &host_mem_addr,
+                                     /*event=*/nullptr),
+              kLiteRtStatusOk);
+    auto output = absl::MakeSpan(static_cast<const float*>(host_mem_addr),
+                                 kTestOutputSize);
+    for (auto i = 0; i < kTestOutputSize; ++i) {
+      ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i];
+    }
+    EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor));
+    ASSERT_EQ(LiteRtUnlockTensorBuffer(output_buffers[0]), kLiteRtStatusOk);
+  }
+
+  // The LiteRtTensorBuffers are raw handles, so they need to be destroyed
+  // explicitly.
+  for (auto& input_buffer : input_buffers) {
+    LiteRtDestroyTensorBuffer(input_buffer);
+  }
+  for (auto& output_buffer : output_buffers) {
+    LiteRtDestroyTensorBuffer(output_buffer);
+  }
+
+  LiteRtDestroyModel(model);
+}
+
+}  // namespace
+}  // namespace litert
diff --git a/tflite/experimental/litert/runtime/compiler/BUILD b/tflite/experimental/litert/runtime/compiler/BUILD
new file mode 100644
index 00000000..384f665e
--- /dev/null
+++ b/tflite/experimental/litert/runtime/compiler/BUILD
@@ -0,0 +1,50 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # copybara:uncomment default_applicable_licenses = ["@org_tensorflow//tensorflow:license"], + default_visibility = ["//tflite/experimental/litert:__subpackages__"], +) + +cc_test( + name = "jit_compilation_qualcomm_test", + srcs = ["jit_compilation_qualcomm_test.cc"], + data = [ + "//tflite/experimental/litert/test:simple_model", + "//tflite/experimental/litert/vendors/qualcomm/compiler:qnn_compiler_plugin_so", + "//tflite/experimental/litert/vendors/qualcomm/dispatch:dispatch_api_so", + ], + linkopts = select({ + "@org_tensorflow//tensorflow:android": ["-landroid"], + "//conditions:default": [], + }), + deps = [ + "//tflite:framework", + "//tflite/c:c_api_opaque", + "//tflite/c:common", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_dispatch_delegate", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/compiler/plugin:compiler_plugin", + "//tflite/experimental/litert/core/model:model_buffer", + "//tflite/experimental/litert/runtime:external_litert_buffer_context", + "//tflite/experimental/litert/test:common", + "//tflite/experimental/litert/test:simple_model_npu", + "//tflite/kernels:builtin_ops", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tflite/experimental/litert/runtime/compiler/jit_compilation_qualcomm_test.cc b/tflite/experimental/litert/runtime/compiler/jit_compilation_qualcomm_test.cc new file mode 100644 index 00000000..5291b52a --- /dev/null +++ b/tflite/experimental/litert/runtime/compiler/jit_compilation_qualcomm_test.cc @@ -0,0 +1,137 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include <array>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include <gtest/gtest.h>
+#include "absl/log/absl_log.h"
+#include "absl/log/log.h"
+#include "absl/strings/string_view.h"
+#include "tflite/c/c_api_opaque.h"
+#include "tflite/c/common.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_dispatch_delegate.h"
+#include "tflite/experimental/litert/cc/litert_model.h"
+#include "tflite/experimental/litert/compiler/plugin/compiler_plugin.h"
+#include "tflite/experimental/litert/core/model/model_buffer.h"
+#include "tflite/experimental/litert/runtime/external_litert_buffer_context.h"
+#include "tflite/experimental/litert/test/common.h"
+#include "tflite/experimental/litert/test/testdata/simple_model_test_vectors.h"
+#include "tflite/interpreter.h"
+#include "tflite/kernels/register.h"
+#include "tflite/model_builder.h"
+#include "tflite/signature_runner.h"
+
+constexpr const char* kCompilerPluginLibSearchPath = "/data/local/tmp";
+
+TEST(JitCompilation, Qualcomm) {
+  auto model_path = litert::testing::GetTestFilePath(kModelFileName);
+  auto model = litert::Model::CreateFromFile(model_path);
+  ASSERT_TRUE(model);
+
+#if !defined(__ANDROID__)
+  GTEST_SKIP() << "The rest of this test is specific to Android devices with a "
+                  "Qualcomm HTP";
+#endif
+
+  constexpr const std::array<absl::string_view, 1>
+      compiler_plugin_lib_search_paths = {kCompilerPluginLibSearchPath};
+  auto compiler_plugin = litert::internal::CompilerPlugin::LoadPlugin(
+      compiler_plugin_lib_search_paths, "Qualcomm");
+  ASSERT_TRUE(compiler_plugin);
+
+  auto api_version = compiler_plugin->ApiVersion();
+  ASSERT_TRUE(api_version);
+
+  ABSL_LOG(INFO) << "Found compiler plugin with version " << api_version->major
+                 << "." << api_version->minor << "." << api_version->patch;
+
+  auto npu_bytecode = ApplyPlugin(*compiler_plugin, *model);
+  EXPECT_TRUE(npu_bytecode);
+  EXPECT_GT(npu_bytecode->Size(), 0);
+
+  auto serialized_model = litert::internal::GetModelBufWithByteCode(
+      std::move(*model->Get()), *npu_bytecode);
+  EXPECT_TRUE(serialized_model);
+
+  model = litert::Model::CreateFromBuffer(*serialized_model);
+
+  auto flatbuffer_model = tflite::FlatBufferModel::BuildFromBuffer(
+      reinterpret_cast<const char*>(serialized_model->Data()),
+      serialized_model->Size());
+
+  EXPECT_TRUE(flatbuffer_model != nullptr);
+
+  tflite::Interpreter::Ptr interpreter = nullptr;
+  tflite::ops::builtin::BuiltinOpResolver resolver;
+  tflite::InterpreterBuilder(*flatbuffer_model, resolver)(&interpreter);
+  EXPECT_TRUE(interpreter != nullptr);
+
+  EXPECT_EQ(interpreter->nodes_size(), 1);
+  EXPECT_EQ(interpreter->inputs().size(), 2);
+  EXPECT_EQ(interpreter->outputs().size(), 1);
+  ASSERT_EQ(interpreter->execution_plan().size(), 1);
+
+  litert::internal::ExternalLiteRtBufferContext buffer_context;
+  interpreter->SetExternalContext(kTfLiteLiteRtBufferContext, &buffer_context);
+
+  auto dispatch_delegate_options = litert::CreateDispatchDelegateOptionsPtr();
+  LiteRtDispatchDelegateAddAllocBaseOption(
+      dispatch_delegate_options.get(), flatbuffer_model->allocation()->base());
+  auto dispatch_delegate =
+      litert::CreateDispatchDelegatePtr(std::move(dispatch_delegate_options));
+
+  ASSERT_EQ(interpreter->ModifyGraphWithDelegate(dispatch_delegate.get()),
+            kTfLiteOk);
+
+  // Get the list of signatures and check it.
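+  // The recompiled NPU model carries no signature defs, so the default
+  // (nullptr) signature runner is used below.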
+ auto signature_defs = interpreter->signature_keys(); + ASSERT_EQ(signature_defs.size(), 0); + + tflite::impl::SignatureRunner* runner = + interpreter->GetSignatureRunner(/*signature_key=*/nullptr); + ASSERT_NE(runner, nullptr); + + EXPECT_EQ(runner->AllocateTensors(), kTfLiteOk); + + // Fill model inputs. + ASSERT_STREQ(runner->input_names()[0], "arg0"); + auto input_0_tensor = runner->input_tensor("arg0"); + ASSERT_NE(input_0_tensor, nullptr); + auto* input_0 = input_0_tensor->data.f; + std::memcpy(input_0, kTestInput0Tensor, sizeof(kTestInput0Tensor)); + + ASSERT_STREQ(runner->input_names()[1], "arg1"); + auto input_1_tensor = runner->input_tensor("arg1"); + ASSERT_NE(input_1_tensor, nullptr); + auto* input_1 = input_1_tensor->data.f; + std::memcpy(input_1, kTestInput1Tensor, sizeof(kTestInput1Tensor)); + + EXPECT_EQ(runner->Invoke(), kTfLiteOk); + + // Check model output. + auto output_tensor = runner->output_tensor(runner->output_names()[0]); + ASSERT_NE(output_tensor, nullptr); + auto* output = output_tensor->data.f; + for (auto i = 0; i < kTestOutputSize; ++i) { + ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; + } + for (auto i = 0; i < kTestOutputSize; ++i) { + EXPECT_NEAR(output[i], kTestOutputTensor[i], 1e-5); + } +} diff --git a/tflite/experimental/litert/runtime/dispatch/BUILD b/tflite/experimental/litert/runtime/dispatch/BUILD new file mode 100644 index 00000000..4910de49 --- /dev/null +++ b/tflite/experimental/litert/runtime/dispatch/BUILD @@ -0,0 +1,180 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +package( + # copybara:uncomment default_applicable_licenses = ["@org_tensorflow//tensorflow:license"], + default_visibility = ["//tflite/experimental/litert:__subpackages__"], +) + +cc_library( + name = "dispatch", + srcs = [ + "litert_dispatch.cc", + ], + hdrs = [ + "//tflite/experimental/litert/vendors/c:litert_dispatch.h", + "//tflite/experimental/litert/vendors/c:litert_dispatch_api.h", + ], + deps = [ + "//tflite/experimental/litert/c:litert_any", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_tensor_buffer", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_library( + name = "dispatch_delegate", + srcs = [ + "dispatch_delegate.cc", + "dispatch_delegate_kernel.cc", + ], + hdrs = [ + "dispatch_delegate_kernel.h", + "dispatch_delegate_options.h", + "//tflite/experimental/litert/c:litert_dispatch_delegate.h", + ], + deps = [ + "//tflite/c:c_api", + "//tflite/c:c_api_opaque", + "//tflite/c:c_api_types", + "//tflite/c:common", + "//tflite/core/c:c_api_opaque_without_op_resolver", + "//tflite/delegates/utils:simple_opaque_delegate", + "//tflite/experimental/litert/c:litert_any", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_tensor_buffer", + "//tflite/experimental/litert/cc:litert_any", + "//tflite/experimental/litert/cc:litert_buffer_ref", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_tensor_buffer", + "//tflite/experimental/litert/core:byte_code_util", + "//tflite/experimental/litert/core:environment", + "//tflite/experimental/litert/runtime:external_litert_buffer_context", + "//tflite/experimental/litert/runtime:tfl_utils", + "//tflite/experimental/litert/vendors/c:litert_dispatch_c_api", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "dispatch_delegate_google_tensor_test", + srcs = ["dispatch_delegate_google_tensor_test.cc"], + data = [ + "//tflite/experimental/litert/vendors/google_tensor/dispatch:dispatch_api_so", + ], + linkopts = select({ + "@org_tensorflow//tensorflow:android": ["-landroid"], + "//conditions:default": [], + }), + deps = [ + "//tflite:framework", + "//tflite/c:c_api_opaque", + "//tflite/c:common", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_compiled_model_options", + "//tflite/experimental/litert/c:litert_dispatch_delegate", + "//tflite/experimental/litert/c:litert_tensor_buffer", + "//tflite/experimental/litert/cc:litert_compiled_model", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_tensor_buffer", + "//tflite/experimental/litert/core/model:model_buffer", + "//tflite/experimental/litert/core/util:flatbuffer_tools", + "//tflite/experimental/litert/runtime:external_litert_buffer_context", + "//tflite/experimental/litert/test:common", + "//tflite/experimental/litert/test:simple_model_npu", + "//tflite/kernels:builtin_ops", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = 
"dispatch_delegate_qualcomm_test", + srcs = ["dispatch_delegate_qualcomm_test.cc"], + data = [ + "//tflite/experimental/litert/vendors/qualcomm/dispatch:dispatch_api_so", + ], + linkopts = select({ + "@org_tensorflow//tensorflow:android": ["-landroid"], + "//conditions:default": [], + }), + deps = [ + "//tflite:framework", + "//tflite/c:c_api_opaque", + "//tflite/c:common", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_compiled_model_options", + "//tflite/experimental/litert/c:litert_dispatch_delegate", + "//tflite/experimental/litert/c:litert_tensor_buffer", + "//tflite/experimental/litert/cc:litert_compiled_model", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_tensor_buffer", + "//tflite/experimental/litert/core/model:model_buffer", + "//tflite/experimental/litert/runtime:external_litert_buffer_context", + "//tflite/experimental/litert/test:common", + "//tflite/experimental/litert/test:simple_model_npu", + "//tflite/kernels:builtin_ops", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "dispatch_delegate_mediatek_test", + srcs = ["dispatch_delegate_mediatek_test.cc"], + data = [ + "//tflite/experimental/litert/vendors/mediatek/dispatch:dispatch_api_so", + ], + linkopts = select({ + "@org_tensorflow//tensorflow:android": ["-landroid"], + "//conditions:default": [], + }), + deps = [ + "//tflite:framework", + "//tflite/c:c_api_opaque", + "//tflite/c:common", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_compiled_model_options", + "//tflite/experimental/litert/c:litert_dispatch_delegate", + "//tflite/experimental/litert/c:litert_tensor_buffer", + "//tflite/experimental/litert/cc:litert_compiled_model", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_tensor_buffer", + "//tflite/experimental/litert/core/model:model_buffer", + "//tflite/experimental/litert/core/util:flatbuffer_tools", + "//tflite/experimental/litert/runtime:external_litert_buffer_context", + "//tflite/experimental/litert/test:common", + "//tflite/experimental/litert/test:simple_model_npu", + "//tflite/kernels:builtin_ops", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tflite/experimental/litert/runtime/dispatch/README.md b/tflite/experimental/litert/runtime/dispatch/README.md new file mode 100644 index 00000000..7cfb2195 --- /dev/null +++ b/tflite/experimental/litert/runtime/dispatch/README.md @@ -0,0 +1,20 @@ +## Google Tensor + +Test case can dispatch_delegate_google_tensor_test can be run on a device with a +Pixel 9 device with the following comands + +$ ../../../google/run_test_on_android.sh dispatch_delegate_google_tensor_test + +## Qualcomm + +Test case can dispatch_delegate_qualcomm_test can be run on a Samsung S24 device +with the following comands + +$ ../../../google/run_test_on_android.sh dispatch_delegate_qualcomm_test + +## MediaTek + +Test case can dispatch_delegate_mediatek_test can be run on a device with a +MetiaTek mt6989 SoC with the following comands + +$ ../../../google/run_test_on_android.sh dispatch_delegate_mediatek_test diff --git 
a/tflite/experimental/litert/runtime/dispatch/dispatch_delegate.cc b/tflite/experimental/litert/runtime/dispatch/dispatch_delegate.cc new file mode 100644 index 00000000..ad856aba --- /dev/null +++ b/tflite/experimental/litert/runtime/dispatch/dispatch_delegate.cc @@ -0,0 +1,161 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tflite/c/c_api_opaque.h" +#include "tflite/c/c_api_types.h" +#include "tflite/c/common.h" +#include "tflite/delegates/utils/simple_opaque_delegate.h" +#include "tflite/experimental/litert/c/litert_dispatch_delegate.h" +#include "tflite/experimental/litert/c/litert_logging.h" +#include "tflite/experimental/litert/core/byte_code_util.h" +#include "tflite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.h" +#include "tflite/experimental/litert/runtime/dispatch/dispatch_delegate_options.h" +#include "tflite/experimental/litert/vendors/c/litert_dispatch.h" + +namespace { + +using ::litert::internal::kLiteRtDispatchOpCustomCode; + +// A TFL Delegate that can recognize subgraphs that run on Dispatch API capable +// accelerators, e.g. TPU, DSP, ... It replaces such subgraphs and offloads +// their work through the Dispatch API. 
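+// A minimal usage sketch (illustrative only; it relies on the factory
+// helpers defined later in this file and declared in
+// litert_dispatch_delegate.h -- "model_flatbuffer_base" is a hypothetical
+// pointer to the loaded model flatbuffer):
+//
+//   auto options = litert::CreateDispatchDelegateOptionsPtr();
+//   LiteRtDispatchDelegateAddAllocBaseOption(options.get(),
+//                                            model_flatbuffer_base);
+//   auto delegate = litert::CreateDispatchDelegatePtr(std::move(options));
+//   interpreter.ModifyGraphWithDelegate(delegate.get());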
+class DispatchDelegate : public tflite::SimpleOpaqueDelegateInterface {
+ public:
+  static TfLiteOpaqueDelegate* Create(LiteRtDispatchDelegateOptions* options_) {
+    litert::DispatchDelegateOptionsPtr options(
+        options_, LiteRtDestroyDispatchDelegateOptions);
+    if (!options) {
+      LITERT_LOG(LITERT_ERROR, "Null input");
+      return nullptr;
+    }
+
+    std::unique_ptr<DispatchDelegate> managed_sb_delegate(
+        new DispatchDelegate(std::move(options)));
+    return tflite::TfLiteOpaqueDelegateFactory::CreateSimpleDelegate(
+        std::move(managed_sb_delegate),
+        kTfLiteDelegateFlagsAllowDynamicTensors);
+  }
+
+  bool IsNodeSupportedByDelegate(const TfLiteOperator* op,
+                                 const TfLiteOpaqueNode* node,
+                                 TfLiteOpaqueContext* context) const override;
+
+  TfLiteStatus Initialize(TfLiteOpaqueContext* context) override;
+
+  const char* Name() const override;
+
+  std::unique_ptr<tflite::SimpleOpaqueDelegateKernelInterface>
+  CreateDelegateKernelInterface() override;
+
+ private:
+  static constexpr absl::string_view kDelegateName = "DispatchDelegate";
+
+  explicit DispatchDelegate(litert::DispatchDelegateOptionsPtr&& options)
+      : options_(std::move(options)) {}
+
+  litert::DispatchDelegateOptionsPtr options_;
+  int dispatch_graph_name_id_ = 0;
+};
+
+bool DispatchDelegate::IsNodeSupportedByDelegate(
+    const TfLiteOperator* op, const TfLiteOpaqueNode* node,
+    TfLiteOpaqueContext* context) const {
+  auto custom_code = absl::string_view(TfLiteOperatorGetCustomName(op));
+  return custom_code == kLiteRtDispatchOpCustomCode;
+}
+
+TfLiteStatus DispatchDelegate::Initialize(TfLiteOpaqueContext* context) {
+  return kTfLiteOk;
+}
+
+const char* DispatchDelegate::Name() const { return kDelegateName.data(); }
+
+std::unique_ptr<tflite::SimpleOpaqueDelegateKernelInterface>
+DispatchDelegate::CreateDelegateKernelInterface() {
+  std::string dispatch_graph_name =
+      absl::StrFormat("DispatchGraph_%d", dispatch_graph_name_id_++);
+
+  auto kernel = litert::internal::DispatchDelegateKernel::Create(
+      std::move(dispatch_graph_name), *options_);
+  if (kernel) {
+    return std::move(*kernel);
+  } else {
+    LITERT_LOG(LITERT_ERROR, "Failed to create a dispatch delegate kernel: %s",
+               kernel.Error().Message().data());
+    return nullptr;
+  }
+}
+
+}  // namespace
+
+LiteRtDispatchDelegateOptions* LiteRtCreateDefaultDispatchDelegateOptions() {
+  return new LiteRtDispatchDelegateOptions;
+}
+
+TfLiteStatus LiteRtAddDispatchDelegateOption(
+    LiteRtDispatchDelegateOptions* options, LiteRtDispatchOption option) {
+  if (!options) {
+    LITERT_LOG(LITERT_ERROR, "Null input");
+    return kTfLiteError;
+  }
+
+  options->AddOption(option);
+  return kTfLiteOk;
+}
+
+TfLiteStatus LiteRtDispatchDelegateAddAllocBaseOption(
+    LiteRtDispatchDelegateOptions* options, const void* alloc_base) {
+  AddAllocBaseOption(alloc_base, *options);
+  return kTfLiteOk;
+}
+
+void LiteRtDestroyDispatchDelegateOptions(
+    LiteRtDispatchDelegateOptions* options) {
+  delete options;
+}
+
+TfLiteDelegate* LiteRtCreateDispatchDelegate(
+    LiteRtDispatchDelegateOptions* options) {
+  if (!options) {
+    options = LiteRtCreateDefaultDispatchDelegateOptions();
+  }
+  return DispatchDelegate::Create(options);
+}
+
+void LiteRtDestroyDispatchDelegate(TfLiteOpaqueDelegate* delegate) {
+  tflite::TfLiteOpaqueDelegateFactory::DeleteSimpleDelegate(delegate);
+}
+
+namespace litert {
+
+DispatchDelegateOptionsPtr CreateDispatchDelegateOptionsPtr() {
+  return {LiteRtCreateDefaultDispatchDelegateOptions(),
+          LiteRtDestroyDispatchDelegateOptions};
+}
+
+DispatchDelegatePtr CreateDispatchDelegatePtr(
+    DispatchDelegateOptionsPtr&& options) {
+  return
DispatchDelegatePtr(LiteRtCreateDispatchDelegate(options.release()), + LiteRtDestroyDispatchDelegate); +} +} // namespace litert diff --git a/tflite/experimental/litert/runtime/dispatch/dispatch_delegate_google_tensor_test.cc b/tflite/experimental/litert/runtime/dispatch/dispatch_delegate_google_tensor_test.cc new file mode 100644 index 00000000..6635ec26 --- /dev/null +++ b/tflite/experimental/litert/runtime/dispatch/dispatch_delegate_google_tensor_test.cc @@ -0,0 +1,285 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include +#include +#include "absl/log/absl_log.h" +#include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tflite/c/c_api_opaque.h" +#include "tflite/c/common.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_compiled_model_options.h" +#include "tflite/experimental/litert/c/litert_dispatch_delegate.h" +#include "tflite/experimental/litert/c/litert_tensor_buffer.h" +#include "tflite/experimental/litert/cc/litert_compiled_model.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/cc/litert_tensor_buffer.h" +#include "tflite/experimental/litert/core/model/model_buffer.h" +#include "tflite/experimental/litert/core/util/flatbuffer_tools.h" +#include "tflite/experimental/litert/runtime/external_litert_buffer_context.h" +#include "tflite/experimental/litert/test/common.h" +#include "tflite/experimental/litert/test/testdata/simple_model_test_vectors.h" +#include "tflite/interpreter.h" +#include "tflite/signature_runner.h" + +namespace litert { +namespace { + +using ::litert::testing::MakeRuntimeFromTestFileWithNpuModel; +using ::testing::FloatNear; +using ::testing::Pointwise; + +static constexpr absl::string_view kNpuFile = kGoogleTensorModelFileName; +static constexpr absl::string_view kTfliteFile = "simple_model_npu.tflite"; + +TEST(DispatchDelegate, GoogleTensorCpuBuffer) { + auto runtime = MakeRuntimeFromTestFileWithNpuModel(kTfliteFile, kNpuFile); + ASSERT_TRUE(runtime) << "Failed to initialize tflite interpreter"; + auto& rt = **runtime; + auto& interpreter = rt.Interpreter(); + + internal::ExternalLiteRtBufferContext buffer_context; + interpreter.SetExternalContext(kTfLiteLiteRtBufferContext, &buffer_context); + + EXPECT_EQ(interpreter.nodes_size(), 1); + EXPECT_EQ(interpreter.inputs().size(), 2); + EXPECT_EQ(interpreter.outputs().size(), 1); + ASSERT_EQ(interpreter.execution_plan().size(), 1); + + auto dispatch_delegate_options = CreateDispatchDelegateOptionsPtr(); + LiteRtDispatchDelegateAddAllocBaseOption(dispatch_delegate_options.get(), + rt.Flatbuffer().Buf().Data()); + auto dispatch_delegate = + CreateDispatchDelegatePtr(std::move(dispatch_delegate_options)); + +#if !defined(__ANDROID__) + GTEST_SKIP() << "The rest of this test is specific to Android devices with a " + "GoogleTensor eTPU"; +#endif + + 
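+  // Applying the delegate replaces the single dispatch custom op with a
+  // delegate kernel; everything invoked below runs through the Dispatch API.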
ASSERT_EQ(interpreter.ModifyGraphWithDelegate(dispatch_delegate.get()), + kTfLiteOk); + + // Get the list of signatures and check it. + auto signature_defs = interpreter.signature_keys(); + ASSERT_EQ(signature_defs.size(), 1); + + tflite::impl::SignatureRunner* runner = + interpreter.GetSignatureRunner(/*signature_key=*/nullptr); + ASSERT_NE(runner, nullptr); + + EXPECT_EQ(runner->AllocateTensors(), kTfLiteOk); + + // Fill model inputs. + ASSERT_STREQ(runner->input_names()[0], "arg0"); + auto input_0_tensor = runner->input_tensor("arg0"); + ASSERT_NE(input_0_tensor, nullptr); + auto* input_0 = input_0_tensor->data.f; + std::memcpy(input_0, kTestInput0Tensor, sizeof(kTestInput0Tensor)); + + ASSERT_STREQ(runner->input_names()[1], "arg1"); + auto input_1_tensor = runner->input_tensor("arg1"); + ASSERT_NE(input_1_tensor, nullptr); + auto* input_1 = input_1_tensor->data.f; + std::memcpy(input_1, kTestInput1Tensor, sizeof(kTestInput1Tensor)); + + EXPECT_EQ(runner->Invoke(), kTfLiteOk); + + // Check model output. + ASSERT_STREQ(runner->output_names()[0], "tfl.custom"); + auto output_tensor = runner->output_tensor("tfl.custom"); + ASSERT_NE(output_tensor, nullptr); + auto output = absl::MakeSpan(output_tensor->data.f, kTestOutputSize); + for (auto i = 0; i < kTestOutputSize; ++i) { + ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; + } + EXPECT_THAT(output, Pointwise(::testing::FloatNear(1e-5), kTestOutputTensor)); +} + +TEST(DispatchDelegate, GoogleTensorHwBuffer) { + auto runtime = MakeRuntimeFromTestFileWithNpuModel(kTfliteFile, kNpuFile); + ASSERT_TRUE(runtime) << "Failed to initialize tflite interpreter"; + auto& rt = **runtime; + auto& interpreter = rt.Interpreter(); + + internal::ExternalLiteRtBufferContext buffer_context; + interpreter.SetExternalContext(kTfLiteLiteRtBufferContext, &buffer_context); + + EXPECT_EQ(interpreter.nodes_size(), 1); + EXPECT_EQ(interpreter.inputs().size(), 2); + EXPECT_EQ(interpreter.outputs().size(), 1); + ASSERT_EQ(interpreter.execution_plan().size(), 1); + + auto dispatch_delegate_options = CreateDispatchDelegateOptionsPtr(); + LiteRtDispatchDelegateAddAllocBaseOption(dispatch_delegate_options.get(), + rt.Flatbuffer().Buf().Data()); + auto dispatch_delegate = + CreateDispatchDelegatePtr(std::move(dispatch_delegate_options)); + +#if !defined(__ANDROID__) + GTEST_SKIP() << "The rest of this test is specific to Android devices with a " + "GoogleTensor eTPU"; +#endif + + ASSERT_EQ(interpreter.ModifyGraphWithDelegate(dispatch_delegate.get()), + kTfLiteOk); + + // Create and register tensor buffers for all inputs and outputs. 
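+  // The two loops below query the delegate-reported buffer requirements for
+  // each I/O tensor (AHWB-backed on device), create a matching TensorBuffer,
+  // and register a duplicate with the buffer context so the delegate and the
+  // test share the same underlying hardware buffers.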
+ + std::vector input_buffers; + for (int i = 0; i < interpreter.inputs().size(); ++i) { + auto input_buffer_requirements = + buffer_context.GetBufferRequirement(interpreter.input_tensor(i)); + ASSERT_TRUE(input_buffer_requirements); + ASSERT_EQ((*input_buffer_requirements)->SupportedTypes().Value()[0], + kLiteRtTensorBufferTypeAhwb); + auto input_buffer = + buffer_context.CreateBufferForTensor(interpreter.input_tensor(i)); + ASSERT_TRUE(input_buffer); + ASSERT_TRUE(input_buffer->IsOwned()); + ASSERT_EQ(*input_buffer->BufferType(), kLiteRtTensorBufferTypeAhwb); + auto duplicate_buffer = (*input_buffer).Duplicate(); + ASSERT_TRUE(duplicate_buffer); + auto status = buffer_context.RegisterTensorBuffer( + interpreter.input_tensor(i), std::move(*duplicate_buffer)); + ASSERT_EQ(status, kLiteRtStatusOk); + input_buffers.push_back(std::move(*input_buffer)); + } + + std::vector output_buffers; + for (int i = 0; i < interpreter.outputs().size(); ++i) { + auto output_buffer_requirements = + buffer_context.GetBufferRequirement(interpreter.output_tensor(i)); + ASSERT_TRUE(output_buffer_requirements); + ASSERT_EQ((*output_buffer_requirements)->SupportedTypes().Value()[0], + kLiteRtTensorBufferTypeAhwb); + auto output_buffer = + buffer_context.CreateBufferForTensor(interpreter.output_tensor(i)); + ASSERT_TRUE(output_buffer.HasValue()); + ASSERT_TRUE(output_buffer->IsOwned()); + ASSERT_EQ(*output_buffer->BufferType(), kLiteRtTensorBufferTypeAhwb); + auto duplicate_buffer = (*output_buffer).Duplicate(); + ASSERT_TRUE(duplicate_buffer); + auto status = buffer_context.RegisterTensorBuffer( + interpreter.output_tensor(i), std::move(*duplicate_buffer)); + ASSERT_EQ(status, kLiteRtStatusOk); + output_buffers.push_back(std::move(*output_buffer)); + } + + // Get the list of signatures and check it. + auto signature_defs = interpreter.signature_keys(); + ASSERT_EQ(signature_defs.size(), 1); + + tflite::impl::SignatureRunner* runner = + interpreter.GetSignatureRunner(/*signature_key=*/nullptr); + ASSERT_NE(runner, nullptr); + + EXPECT_EQ(runner->AllocateTensors(), kTfLiteOk); + + // Fill model inputs. + ASSERT_STREQ(runner->input_names()[0], "arg0"); + auto& input_0_buffer = input_buffers[0]; + input_0_buffer.Write( + absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size)); + + ASSERT_STREQ(runner->input_names()[1], "arg1"); + auto& input_1_buffer = input_buffers[1]; + input_1_buffer.Write( + absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size)); + + EXPECT_EQ(runner->Invoke(), kTfLiteOk); + + // Check model output. 
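+  // Outputs land in the AHWB-backed buffers registered above; they are read
+  // back into a host-side span before comparison.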
+ ASSERT_STREQ(runner->output_names()[0], "tfl.custom"); + auto& output_buffer = output_buffers[0]; + float output_buffer_data[kTestOutputSize]; + auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize); + auto read_success = output_buffer.Read(output_span); + ASSERT_TRUE(read_success); + for (auto i = 0; i < kTestOutputSize; ++i) { + ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" + << kTestOutputTensor[i]; + } + EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); +} + +TEST(DispatchDelegate, CompiledModel) { + auto model_with_byte_code = + internal::GetModelBufWithByteCode(testing::GetTestFilePath(kTfliteFile), + testing::GetTestFilePath(kNpuFile)); + ASSERT_TRUE(model_with_byte_code); + auto model = Model::CreateFromBuffer(*model_with_byte_code); + ASSERT_TRUE(model); + +#if !defined(__ANDROID__) + GTEST_SKIP() << "The rest of this test is specific to Android devices with a " + "GoogleTensor eTPU"; +#endif + + auto res_compiled_model = CompiledModel::Create(*model, kHwAccelNpu); + ASSERT_TRUE(res_compiled_model) << "Failed to initialize CompiledModel"; + auto& compiled_model = *res_compiled_model; + + auto signatures = model->GetSignatures(); + ASSERT_TRUE(signatures); + EXPECT_EQ(signatures->size(), 1); + auto& signature = signatures->at(0); + auto signature_key = signature.Key(); + EXPECT_EQ(signature_key, Model::DefaultSignatureKey()); + size_t signature_index = 0; + + auto input_buffers_res = compiled_model.CreateInputBuffers(signature_index); + EXPECT_TRUE(input_buffers_res); + auto& input_buffers = *input_buffers_res; + + auto output_buffers_res = compiled_model.CreateOutputBuffers(signature_index); + EXPECT_TRUE(output_buffers_res); + auto& output_buffers = *output_buffers_res; + + // Fill model inputs. + auto input_names = signature.InputNames(); + EXPECT_EQ(input_names.size(), 2); + EXPECT_EQ(input_names.at(0), "arg0"); + EXPECT_EQ(input_names.at(1), "arg1"); + ASSERT_TRUE(input_buffers[0].Write( + absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); + ASSERT_TRUE(input_buffers[1].Write( + absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); + + // Execute model. + compiled_model.Run(signature_index, input_buffers, output_buffers); + + // Check model output. + auto output_names = signature.OutputNames(); + EXPECT_EQ(output_names.size(), 1); + EXPECT_EQ(output_names.at(0), "tfl.custom"); + float output_buffer_data[kTestOutputSize]; + auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize); + ASSERT_TRUE(output_buffers[0].Read(output_span)); + for (auto i = 0; i < kTestOutputSize; ++i) { + ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" + << kTestOutputTensor[i]; + } + EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); +} + +} // namespace +} // namespace litert diff --git a/tflite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.cc b/tflite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.cc new file mode 100644 index 00000000..88c8ba8d --- /dev/null +++ b/tflite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.cc @@ -0,0 +1,642 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "absl/strings/string_view.h"
+#include "tflite/c/c_api_opaque.h"
+#include "tflite/c/c_api_types.h"
+#include "tflite/c/common.h"
+#include "tflite/core/c/c_api_opaque.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_dispatch_delegate.h"
+#include "tflite/experimental/litert/c/litert_logging.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h"
+#include "tflite/experimental/litert/cc/litert_buffer_ref.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/cc/litert_model.h"
+#include "tflite/experimental/litert/cc/litert_tensor_buffer.h"
+#include "tflite/experimental/litert/cc/litert_tensor_buffer_requirements.h"
+#include "tflite/experimental/litert/core/byte_code_util.h"
+#include "tflite/experimental/litert/runtime/dispatch/dispatch_delegate_options.h"
+#include "tflite/experimental/litert/runtime/external_litert_buffer_context.h"
+#include "tflite/experimental/litert/runtime/tfl_utils.h"
+#include "tflite/experimental/litert/vendors/c/litert_dispatch.h"
+
+namespace litert {
+namespace internal {
+
+namespace {
+
+// Get the bytecode and function name from the given custom op options data.
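+// The custom op's initial data holds a (function_name, metadata_key) pair;
+// the named metadata entry in turn holds the offset and size of the NPU
+// bytecode relative to the "alloc_base" delegate option (the base address of
+// the model flatbuffer). See byte_code_util.h for the packing helpers used
+// here (ParseExecInfo, ParseByteCodePlaceholder).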
+Expected<std::pair<absl::string_view, BufferRef<uint8_t>>> ResolveExecInfo(
+    BufferRef<uint8_t> custom_opts, TfLiteOpaqueContext* context,
+    const LiteRtDispatchDelegateOptions& options) {
+  auto exec_info = ParseExecInfo(custom_opts);
+  if (!exec_info) {
+    LITERT_LOG(LITERT_ERROR, "Failed to parse custom initial data", "");
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure);
+  }
+
+  auto [function_name, metadata_key] = *exec_info;
+
+  const char* metadata;
+  size_t bytes;
+  if (auto stat = TfLiteOpaqueContextGetMetadata(context, metadata_key.data(),
+                                                 &metadata, &bytes);
+      stat != kTfLiteOk) {
+    LITERT_LOG(LITERT_ERROR, "Failed to get metadata for dispatch op: %d",
+               stat);
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure);
+  }
+
+  BufferRef<uint8_t> metadata_buf(metadata, bytes);
+
+  auto bytecode_loc = ParseByteCodePlaceholder(metadata_buf);
+  if (!bytecode_loc) {
+    LITERT_LOG(LITERT_ERROR, "Failed to parse metadata for dispatch op", "");
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure);
+  }
+
+  auto [bytecode_offset, bytecode_size] = *bytecode_loc;
+
+  LITERT_LOG(
+      LITERT_INFO,
+      "Initializing invocation context for dispatch op\n\tfunction_name: "
+      "%s\n\tbyte_code_offset: %lu \n\tbyte_code_size: %lu",
+      function_name.data(), bytecode_offset, bytecode_size);
+  if (bytecode_size == 0) {
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Found zero-size bytecode");
+  }
+
+  auto alloc_base = FindAllocBase(options);
+  if (!alloc_base) {
+    LITERT_LOG(LITERT_ERROR,
+               "Could not find required delegate option \"alloc_base\"", "");
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure);
+  }
+
+  const void* alloc = std::any_cast<const void*>(*alloc_base);
+  const void* bytecode =
+      reinterpret_cast<const uint8_t*>(alloc) + bytecode_offset;
+  return std::make_pair(function_name,
+                        BufferRef<uint8_t>(bytecode, bytecode_size));
+}
+}  // namespace
+
+DispatchDelegateKernel::~DispatchDelegateKernel() {
+  for (size_t i = 0; i < input_tensor_buffer_handles_.size(); ++i) {
+    (void)LiteRtDispatchDetachInput(invocation_context_, i,
+                                    input_tensor_buffer_handles_[i]);
+  }
+
+  for (size_t i = 0; i < output_tensor_buffer_handles_.size(); ++i) {
+    (void)LiteRtDispatchDetachOutput(invocation_context_, i,
+                                     output_tensor_buffer_handles_[i]);
+  }
+
+  if (invocation_context_) {
+    (void)LiteRtDispatchInvocationContextDestroy(invocation_context_);
+  }
+
+  for (auto& buffer_handle : input_tensor_buffer_handles_) {
+    (void)LiteRtDispatchUnregisterTensorBuffer(device_context_, buffer_handle);
+  }
+
+  for (auto& buffer_handle : output_tensor_buffer_handles_) {
+    (void)LiteRtDispatchUnregisterTensorBuffer(device_context_, buffer_handle);
+  }
+
+  if (device_context_) {
+    (void)LiteRtDispatchDeviceContextDestroy(device_context_);
+  }
+
+  input_tensor_buffers_.clear();
+  output_tensor_buffers_.clear();
+}
+
+Expected<DispatchDelegateKernel::Ptr> DispatchDelegateKernel::Create(
+    std::string&& graph_name, const LiteRtDispatchDelegateOptions& options) {
+  auto dispatch_options = options.GetDispatchOptions();
+  if (auto status = LiteRtDispatchInitialize(dispatch_options.data(),
+                                             dispatch_options.size());
+      status != kLiteRtStatusOk) {
+    LITERT_LOG(LITERT_ERROR, "Failed to initialize Dispatch API: %d", status);
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Failed to initialize Dispatch API");
+  }
+
+  const char* vendor_id;
+  if (auto status = LiteRtDispatchGetVendorId(&vendor_id);
+      status != kLiteRtStatusOk) {
+    LITERT_LOG(LITERT_ERROR, "Failed to get Dispatch API vendor ID: %d",
+               status);
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Failed to get Dispatch API vendor ID");
+  }
+  LITERT_LOG(LITERT_INFO,
"Dispatch API vendor ID: %s", vendor_id); + + const char* build_id; + if (auto status = LiteRtDispatchGetBuildId(&build_id); + status != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, "Failed to get Dispatch API build ID: %d", status); + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "Failed to get Dispatch API build ID"); + } + LITERT_LOG(LITERT_INFO, "Dispatch API build ID: %s", build_id); + + LiteRtApiVersion api_version; + if (auto status = LiteRtDispatchGetApiVersion(&api_version); + status != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, "Failed to get LiteRT Dispatch API version: %d", + status); + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "Failed to get LiteRT Dispatch API version"); + } + LITERT_LOG(LITERT_INFO, "Dispatch API version: %d.%d.%d", api_version.major, + api_version.minor, api_version.patch); + // Check if the versions mach. + if (api_version.major != LITERT_API_VERSION_MAJOR || + api_version.minor < LITERT_API_VERSION_MINOR) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "Found Dispatch API with an unsupported version"); + } + + int capabilities; + if (auto status = LiteRtDispatchGetCapabilities(&capabilities); + status != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, "Failed to get Dispatch API capabilities: %d", + status); + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "Failed to get Dispatch API capabilities"); + } + LITERT_LOG(LITERT_INFO, "Dispatch API capabilities: %d", capabilities); + + if (!(capabilities & kLiteRtDispatchCapabilitiesBasic)) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "Dispatch API has insufficient capabilities"); + } + + LiteRtDispatchDeviceContext device_context; + if (auto status = LiteRtDispatchDeviceContextCreate(&device_context); + status != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, "Failed to get Dispatch API device context: %d", + status); + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "Failed to create Dispatch API device context"); + } + + return Ptr(new DispatchDelegateKernel(options, std::move(graph_name), + device_context)); +} + +TfLiteStatus DispatchDelegateKernel::Init( + TfLiteOpaqueContext* context, const TfLiteOpaqueDelegateParams* params) { + if (params->nodes_to_replace->size != 1) { + LITERT_LOG(LITERT_ERROR, + "Models with more than one dispatch node are not yet supported"); + return kTfLiteError; + } + + auto node_id = params->nodes_to_replace->data[0]; + TfLiteOpaqueNode* node; + TfLiteOperator* op; + if (auto status = TfLiteOpaqueContextGetNodeAndRegistration(context, node_id, + &node, &op); + status != kTfLiteOk) { + LITERT_LOG(LITERT_ERROR, "Failed to get node and registration: %d", status); + return status; + } + + const void* init_data; + int init_data_size; + if (auto status = TfLiteOpaqueNodeGetCustomInitialData(node, &init_data, + &init_data_size); + status != kTfLiteOk) { + LITERT_LOG(LITERT_ERROR, "Failed to get custom initial data: %d", status); + return status; + } + if (!init_data || !init_data_size) { + LITERT_LOG(LITERT_ERROR, "Found custom op with missing initial data"); + return kTfLiteError; + } + + BufferRef custom_opts(init_data, init_data_size); + auto exec_info = ResolveExecInfo(custom_opts, context, options_); + if (!exec_info) { + LITERT_LOG(LITERT_ERROR, "Failed to parse custom options"); + return kTfLiteError; + } + auto [function_name, bytecode] = *exec_info; + + const int num_inputs = params->input_tensors->size; + const int num_outputs = params->output_tensors->size; + + if (auto status = LiteRtDispatchInvocationContextCreate( 
+          device_context_, kLiteRtDispatchExecutableTypeMlModel,
+          bytecode.Data(), bytecode.Size(), function_name.data(), num_inputs,
+          num_outputs, &invocation_context_);
+      status != kLiteRtStatusOk) {
+    LITERT_LOG(LITERT_ERROR, "Failed to create invocation context: %d", status);
+    return kTfLiteError;
+  }
+
+  input_tensor_buffers_require_cpu_sync_.resize(num_inputs);
+  input_tensor_buffers_.resize(num_inputs);
+  input_tensor_buffer_handles_.resize(num_inputs);
+  input_tensor_buffer_used_size_.resize(num_inputs);
+
+  output_tensor_buffers_require_cpu_sync_.resize(num_outputs);
+  output_tensor_buffers_.resize(num_outputs);
+  output_tensor_buffer_handles_.resize(num_outputs);
+  output_tensor_buffer_used_size_.resize(num_outputs);
+
+  void* external_context;
+  TfLiteOpaqueContextGetExternalContext(context, &external_context,
+                                        kTfLiteLiteRtBufferContext);
+  if (!external_context) {
+    LITERT_LOG(LITERT_ERROR, "External context not found");
+    return kTfLiteError;
+  }
+
+  auto* buffer_context =
+      reinterpret_cast<litert::internal::ExternalLiteRtBufferContext*>(
+          external_context);
+
+  // Register input and output buffer requirements.
+  size_t num_node_inputs = TfLiteOpaqueNodeNumberOfInputs(node);
+  for (size_t i = 0; i < num_node_inputs; ++i) {
+    auto* tfl_opaque_tensor = TfLiteOpaqueNodeGetInput(context, node, i);
+    if (!tfl_opaque_tensor) {
+      LITERT_LOG(LITERT_ERROR, "Failed to get TFL node input %d", i);
+      return kTfLiteError;
+    }
+    auto tensor_type = ConvertTensorType(tfl_opaque_tensor);
+    if (!tensor_type) {
+      LITERT_LOG(LITERT_ERROR, "%s", tensor_type.Error().Message().data());
+      return kTfLiteError;
+    }
+    auto input_buffer_requirements =
+        GetBufferRequirements(*tensor_type, i, /*is_input=*/true);
+    if (auto res = buffer_context->RegisterBufferRequirement(
+            tfl_opaque_tensor, std::move(*input_buffer_requirements));
+        res != kLiteRtStatusOk) {
+      LITERT_LOG(LITERT_ERROR, "Failed to register buffer requirement");
+      return kTfLiteError;
+    }
+  }
+
+  size_t num_node_outputs = TfLiteOpaqueNodeNumberOfOutputs(node);
+  for (size_t i = 0; i < num_node_outputs; ++i) {
+    auto* tfl_opaque_tensor = TfLiteOpaqueNodeGetOutput(context, node, i);
+    if (!tfl_opaque_tensor) {
+      LITERT_LOG(LITERT_ERROR, "Failed to get TFL node output %d", i);
+      return kTfLiteError;
+    }
+    auto tensor_type = ConvertTensorType(tfl_opaque_tensor);
+    if (!tensor_type) {
+      LITERT_LOG(LITERT_ERROR, "%s", tensor_type.Error().Message().data());
+      return kTfLiteError;
+    }
+    auto output_buffer_requirements =
+        GetBufferRequirements(*tensor_type, i, /*is_input=*/false);
+    if (auto res = buffer_context->RegisterBufferRequirement(
+            tfl_opaque_tensor, std::move(*output_buffer_requirements));
+        res != kLiteRtStatusOk) {
+      LITERT_LOG(LITERT_ERROR, "Failed to register buffer requirement");
+      return kTfLiteError;
+    }
+  }
+
+  return kTfLiteOk;
+}
+
+Expected<TensorBufferRequirements>
+DispatchDelegateKernel::GetBufferRequirements(
+    const RankedTensorType& tensor_type, int io_tensor_index,
+    bool is_input) const {
+  auto litert_tensor_type = static_cast<LiteRtRankedTensorType>(tensor_type);
+  LiteRtTensorBufferRequirements tensor_buffer_requirements;
+  if (is_input) {
+    if (auto status = LiteRtDispatchGetInputRequirements(
+            invocation_context_, /*input_index=*/io_tensor_index,
+            &litert_tensor_type, &tensor_buffer_requirements);
+        status != kLiteRtStatusOk) {
+      LITERT_LOG(LITERT_ERROR,
+                 "Failed to get tensor buffer requirements for input %d: %d",
+                 io_tensor_index, status);
+      return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                        "Failed to get tensor buffer requirements for input");
+    }
+
+  } else {
+    if (auto status =
LiteRtDispatchGetOutputRequirements(
+            invocation_context_, /*output_index=*/io_tensor_index,
+            &litert_tensor_type, &tensor_buffer_requirements);
+        status != kLiteRtStatusOk) {
+      LITERT_LOG(LITERT_ERROR,
+                 "Failed to get tensor buffer requirements for output %d: %d",
+                 io_tensor_index, status);
+      return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                        "Failed to get tensor buffer requirements for output");
+    }
+  }
+
+  return TensorBufferRequirements(tensor_buffer_requirements,
+                                  /*owned=*/true);
+}
+
+TfLiteStatus DispatchDelegateKernel::CreateAndSetBuffer(
+    const TfLiteOpaqueTensor* tfl_opaque_tensor, int buffer_index,
+    bool is_input) {
+  auto& cached_tensor_buffer = is_input ? input_tensor_buffers_[buffer_index]
+                                        : output_tensor_buffers_[buffer_index];
+
+  auto tensor_type = ConvertTensorType(tfl_opaque_tensor);
+  if (!tensor_type) {
+    LITERT_LOG(LITERT_ERROR, "%s", tensor_type.Error().Message().data());
+    return kTfLiteError;
+  }
+
+  // Check whether we can reuse a cached tensor buffer or need to create a new
+  // one.
+  if (static_cast<bool>(cached_tensor_buffer)) {
+    if (auto cached_tensor_type = cached_tensor_buffer.TensorType();
+        !cached_tensor_type) {
+      LITERT_LOG(LITERT_ERROR, "%s",
+                 cached_tensor_type.Error().Message().data());
+      return kTfLiteError;
+    }
+
+    if (tensor_type->Layout() == cached_tensor_buffer.TensorType()->Layout()) {
+      // We can reuse the cached tensor buffer.
+      return kTfLiteOk;
+    }
+
+    // We cannot reuse the cached tensor buffer; proceed below.
+  }
+
+  auto tensor_buffer_requirements =
+      GetBufferRequirements(*tensor_type, buffer_index, is_input);
+  if (!tensor_buffer_requirements) {
+    LITERT_LOG(LITERT_ERROR, "%s",
+               tensor_buffer_requirements.Error().Message().data());
+    return kTfLiteError;
+  }
+
+  auto supported_tensor_buffer_types =
+      tensor_buffer_requirements->SupportedTypes();
+  if (!supported_tensor_buffer_types) {
+    LITERT_LOG(LITERT_ERROR, "%s",
+               supported_tensor_buffer_types.Error().Message().data());
+    return kTfLiteError;
+  }
+
+  if (supported_tensor_buffer_types->empty()) {
+    LITERT_LOG(LITERT_ERROR,
+               "Insufficient number of supported tensor buffer types");
+    return kTfLiteError;
+  }
+
+  // For now we simply pick the first buffer type that's supported.
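+  // (Index 0 is assumed here to be the vendor's preferred type; a smarter
+  // policy could negotiate against the types the other delegates support.)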
+  LiteRtTensorBufferType tensor_buffer_type =
+      (*supported_tensor_buffer_types)[0];
+
+  auto tensor_buffer_size = tensor_buffer_requirements->BufferSize();
+  if (!tensor_buffer_size) {
+    LITERT_LOG(LITERT_ERROR, "%s", tensor_buffer_size.Error().Message().data());
+    return kTfLiteError;
+  }
+
+  auto litert_tensor_type = static_cast<LiteRtRankedTensorType>(*tensor_type);
+  LiteRtTensorBuffer litert_tensor_buffer;
+  if (auto status = LiteRtCreateManagedTensorBuffer(
+          tensor_buffer_type, &litert_tensor_type, *tensor_buffer_size,
+          &litert_tensor_buffer);
+      status != kLiteRtStatusOk) {
+    LITERT_LOG(LITERT_ERROR, "Failed to create managed tensor buffer: %d",
+               status);
+    return kTfLiteError;
+  }
+
+  return RegisterLiteRtTensorBuffer(TensorBuffer(litert_tensor_buffer),
+                                    *tensor_buffer_size, buffer_index,
+                                    is_input);
+}
+
+TfLiteStatus DispatchDelegateKernel::RegisterLiteRtTensorBuffer(
+    TensorBuffer&& tensor_buffer, size_t buffer_used_size, int buffer_index,
+    bool is_input) {
+  LiteRtTensorBufferHandle buffer_handle;
+  if (auto status = LiteRtDispatchRegisterTensorBuffer(
+          device_context_, tensor_buffer.Get(), &buffer_handle);
+      status != kLiteRtStatusOk) {
+    LITERT_LOG(LITERT_ERROR, "Failed to register tensor buffer: %d", status);
+    return kTfLiteError;
+  }
+
+  if (is_input) {
+    if (auto status = LiteRtDispatchAttachInput(invocation_context_,
+                                                buffer_index, buffer_handle);
+        status != kLiteRtStatusOk) {
+      (void)LiteRtDispatchUnregisterTensorBuffer(device_context_,
+                                                 buffer_handle);
+      LITERT_LOG(LITERT_ERROR, "Failed to attach tensor buffer to input %d: %d",
+                 buffer_index, status);
+      return kTfLiteError;
+    }
+  } else {
+    if (auto status = LiteRtDispatchAttachOutput(invocation_context_,
+                                                 buffer_index, buffer_handle);
+        status != kLiteRtStatusOk) {
+      (void)LiteRtDispatchUnregisterTensorBuffer(device_context_,
+                                                 buffer_handle);
+      LITERT_LOG(LITERT_ERROR,
+                 "Failed to attach tensor buffer to output %d: %d",
+                 buffer_index, status);
+      return kTfLiteError;
+    }
+  }
+
+  if (is_input) {
+    input_tensor_buffers_[buffer_index] = std::move(tensor_buffer);
+    input_tensor_buffer_handles_[buffer_index] = buffer_handle;
+    input_tensor_buffer_used_size_[buffer_index] = buffer_used_size;
+  } else {
+    output_tensor_buffers_[buffer_index] = std::move(tensor_buffer);
+    output_tensor_buffer_handles_[buffer_index] = buffer_handle;
+    output_tensor_buffer_used_size_[buffer_index] = buffer_used_size;
+  }
+  return kTfLiteOk;
+}
+
+TfLiteStatus DispatchDelegateKernel::Prepare(TfLiteOpaqueContext* context,
+                                             TfLiteOpaqueNode* node) {
+  return kTfLiteOk;
+}
+
+TfLiteStatus DispatchDelegateKernel::RegisterLiteRtTensorBuffers(
+    TfLiteOpaqueContext* context, TfLiteOpaqueNode* node) {
+  void* external_context;
+  TfLiteOpaqueContextGetExternalContext(context, &external_context,
+                                        kTfLiteLiteRtBufferContext);
+  auto* buffer_context =
+      reinterpret_cast<litert::internal::ExternalLiteRtBufferContext*>(
+          external_context);
+
+  size_t num_node_inputs = TfLiteOpaqueNodeNumberOfInputs(node);
+  for (size_t i = 0; i < num_node_inputs; ++i) {
+    auto* tfl_opaque_tensor = TfLiteOpaqueNodeGetInput(context, node, i);
+    auto tensor_buffer = buffer_context->GetTensorBuffer(tfl_opaque_tensor);
+    if (tensor_buffer.HasValue()) {
+      // TODO - b/379176766: If the provided TensorBuffer is not a supported
+      // type, we need to create a new one and convert the data from the
+      // provided TensorBuffer.
+      auto buffer_size = tensor_buffer->Size();
+      if (!buffer_size) {
+        LITERT_LOG(LITERT_ERROR, "%s", buffer_size.Error().Message().data());
+        return kTfLiteError;
+      }
+      if (auto status = RegisterLiteRtTensorBuffer(std::move(*tensor_buffer),
+                                                   *buffer_size, i,
+                                                   /*is_input=*/true);
+          status != kTfLiteOk) {
+        return status;
+      }
+      input_tensor_buffers_require_cpu_sync_[i] = false;
+    } else {
+      LITERT_LOG(LITERT_INFO,
+                 "Input#%d TensorBuffer is not registered. Create a new one",
+                 i);
+      if (auto status =
+              CreateAndSetBuffer(tfl_opaque_tensor, i, /*is_input=*/true);
+          status != kTfLiteOk) {
+        return status;
+      }
+      input_tensor_buffers_require_cpu_sync_[i] = true;
+    }
+  }
+
+  size_t num_node_outputs = TfLiteOpaqueNodeNumberOfOutputs(node);
+  for (size_t i = 0; i < num_node_outputs; ++i) {
+    auto* tfl_opaque_tensor = TfLiteOpaqueNodeGetOutput(context, node, i);
+    auto tensor_buffer = buffer_context->GetTensorBuffer(tfl_opaque_tensor);
+    if (tensor_buffer.HasValue()) {
+      // TODO - b/379176766: If the provided TensorBuffer is not a supported
+      // type, we need to create a new one and convert the data back to the
+      // provided TensorBuffer.
+      auto buffer_size = tensor_buffer->Size();
+      if (!buffer_size) {
+        LITERT_LOG(LITERT_ERROR, "%s", buffer_size.Error().Message().data());
+        return kTfLiteError;
+      }
+      if (auto status = RegisterLiteRtTensorBuffer(std::move(*tensor_buffer),
+                                                   *buffer_size, i,
+                                                   /*is_input=*/false);
+          status != kTfLiteOk) {
+        return status;
+      }
+      output_tensor_buffers_require_cpu_sync_[i] = false;
+    } else {
+      LITERT_LOG(LITERT_INFO,
+                 "Output#%d TensorBuffer is not registered. Create a new one",
+                 i);
+      if (auto status =
+              CreateAndSetBuffer(tfl_opaque_tensor, i, /*is_input=*/false);
+          status != kTfLiteOk) {
+        return status;
+      }
+      output_tensor_buffers_require_cpu_sync_[i] = true;
+    }
+  }
+
+  return kTfLiteOk;
+}
+
+TfLiteStatus DispatchDelegateKernel::Eval(TfLiteOpaqueContext* context,
+                                          TfLiteOpaqueNode* node) {
+  if (auto status = RegisterLiteRtTensorBuffers(context, node);
+      status != kTfLiteOk) {
+    LITERT_LOG(LITERT_ERROR, "Failed to register tensor buffers: %d", status);
+    return kTfLiteError;
+  }
+
+  size_t num_node_inputs = TfLiteOpaqueNodeNumberOfInputs(node);
+  if (num_node_inputs != input_tensor_buffers_.size()) {
+    LITERT_LOG(LITERT_ERROR, "Invalid number of inputs");
+    return kTfLiteError;
+  }
+
+  for (size_t i = 0; i < num_node_inputs; ++i) {
+    if (!input_tensor_buffers_require_cpu_sync_[i]) {
+      continue;
+    }
+    auto* tfl_opaque_tensor = TfLiteOpaqueNodeGetInput(context, node, i);
+    void* tensor_data = TfLiteOpaqueTensorData(tfl_opaque_tensor);
+    auto& tensor_buffer = input_tensor_buffers_[i];
+
+    auto lock_and_addr = TensorBufferScopedLock::Create(tensor_buffer);
+    if (!lock_and_addr) {
+      LITERT_LOG(LITERT_ERROR, "%s", lock_and_addr.Error().Message().data());
+      return kTfLiteError;
+    }
+
+    size_t buffer_size = input_tensor_buffer_used_size_[i];
+    std::memcpy(lock_and_addr->second, tensor_data, buffer_size);
+  }
+
+  if (auto status = LiteRtDispatchInvoke(invocation_context_);
+      status != kLiteRtStatusOk) {
+    LITERT_LOG(LITERT_ERROR, "Failed to invoke context: %d", status);
+    return kTfLiteError;
+  }
+
+  size_t num_node_outputs = TfLiteOpaqueNodeNumberOfOutputs(node);
+  if (num_node_outputs != output_tensor_buffers_.size()) {
+    LITERT_LOG(LITERT_ERROR, "Invalid number of outputs");
+    return kTfLiteError;
+  }
+
+  for (size_t i = 0; i < num_node_outputs; ++i) {
+    if (!output_tensor_buffers_require_cpu_sync_[i]) {
+      continue;
+    }
+    auto* tfl_opaque_tensor =
TfLiteOpaqueNodeGetOutput(context, node, i);
+    void* tensor_data = TfLiteOpaqueTensorData(tfl_opaque_tensor);
+    auto& tensor_buffer = output_tensor_buffers_[i];
+
+    auto lock_and_addr = TensorBufferScopedLock::Create(tensor_buffer);
+    if (!lock_and_addr) {
+      LITERT_LOG(LITERT_ERROR, "%s", lock_and_addr.Error().Message().data());
+      return kTfLiteError;
+    }
+
+    size_t buffer_size = output_tensor_buffer_used_size_[i];
+    std::memcpy(tensor_data, lock_and_addr->second, buffer_size);
+  }
+
+  return kTfLiteOk;
+}
+
+}  // namespace internal
+}  // namespace litert diff --git a/tflite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.h b/tflite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.h new file mode 100644 index 00000000..8d864de8 --- /dev/null +++ b/tflite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.h @@ -0,0 +1,115 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_DISPATCH_DISPATCH_DELEGATE_KERNEL_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_DISPATCH_DISPATCH_DELEGATE_KERNEL_H_
+
+#include
+#include
+#include
+#include
+#include
+
+#include "tflite/c/c_api_types.h"
+#include "tflite/c/common.h"
+#include "tflite/delegates/utils/simple_opaque_delegate.h"
+#include "tflite/experimental/litert/c/litert_dispatch_delegate.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/cc/litert_model.h"
+#include "tflite/experimental/litert/cc/litert_tensor_buffer.h"
+#include "tflite/experimental/litert/cc/litert_tensor_buffer_requirements.h"
+#include "tflite/experimental/litert/vendors/c/litert_dispatch.h"
+
+namespace litert {
+namespace internal {
+
+// A TFL kernel that the interpreter calls to dispatch execution through the
+// Dispatch API.
+class DispatchDelegateKernel
+    : public tflite::SimpleOpaqueDelegateKernelInterface {
+ public:
+  using Ptr = std::unique_ptr<DispatchDelegateKernel>;
+
+  ~DispatchDelegateKernel() override;
+
+  static Expected<Ptr> Create(std::string&& graph_name,
+                              const LiteRtDispatchDelegateOptions& options);
+
+  TfLiteStatus Init(TfLiteOpaqueContext* context,
+                    const TfLiteOpaqueDelegateParams* params) override;
+
+  TfLiteStatus Prepare(TfLiteOpaqueContext* context,
+                       TfLiteOpaqueNode* node) override;
+
+  TfLiteStatus Eval(TfLiteOpaqueContext* context,
+                    TfLiteOpaqueNode* node) override;
+
+ private:
+  DispatchDelegateKernel(const LiteRtDispatchDelegateOptions& options,
+                         std::string&& graph_name,
+                         LiteRtDispatchDeviceContext device_context)
+      : options_(options),
+        graph_name_(std::move(graph_name)),
+        device_context_(device_context) {}
+
+  Expected<TensorBufferRequirements> GetBufferRequirements(
+      const RankedTensorType& tensor_type, int io_tensor_index,
+      bool is_input) const;
+
+  // Creates a new tensor buffer for the given tensor. The created tensor
+  // buffer is then registered with RegisterLiteRtTensorBuffer().
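+  // A cached buffer is reused when the tensor's layout is unchanged;
+  // otherwise a new managed buffer is allocated and registered in its place
+  // (see the implementation in dispatch_delegate_kernel.cc).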
+  TfLiteStatus CreateAndSetBuffer(const TfLiteOpaqueTensor* tfl_opaque_tensor,
+                                  int buffer_index, bool is_input);
+
+  // Registers the given LiteRtTensorBuffer (and its size) with the Dispatch
+  // API.
+  // Also updates the internal state (input_tensor_buffers_, etc.) to keep
+  // track of the registered tensor buffers.
+  TfLiteStatus RegisterLiteRtTensorBuffer(TensorBuffer&& tensor_buffer,
+                                          size_t used_size, int buffer_index,
+                                          bool is_input);
+
+  // Registers LiteRtTensorBuffers for all inputs and outputs of the given
+  // node.
+  // Also updates the internal state (input_tensor_buffers_, etc.) to keep
+  // track of the registered tensor buffers.
+  TfLiteStatus RegisterLiteRtTensorBuffers(TfLiteOpaqueContext* context,
+                                           TfLiteOpaqueNode* node);
+
+  const LiteRtDispatchDelegateOptions& options_;
+  std::string graph_name_;
+  LiteRtDispatchDeviceContext device_context_;
+  LiteRtDispatchInvocationContext invocation_context_ = nullptr;
+
+  // Indicates whether the input tensor buffer requires a CPU sync before
+  // invoking the Dispatch API.
+  std::vector<bool> input_tensor_buffers_require_cpu_sync_;
+
+  std::vector<TensorBuffer> input_tensor_buffers_;
+  std::vector<LiteRtTensorBufferHandle> input_tensor_buffer_handles_;
+  std::vector<size_t> input_tensor_buffer_used_size_;
+
+  // Indicates whether the output tensor buffer requires a CPU sync after
+  // invoking the Dispatch API.
+  std::vector<bool> output_tensor_buffers_require_cpu_sync_;
+
+  std::vector<TensorBuffer> output_tensor_buffers_;
+  std::vector<LiteRtTensorBufferHandle> output_tensor_buffer_handles_;
+  std::vector<size_t> output_tensor_buffer_used_size_;
+};
+
+}  // namespace internal
+}  // namespace litert
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_DISPATCH_DISPATCH_DELEGATE_KERNEL_H_ diff --git a/tflite/experimental/litert/runtime/dispatch/dispatch_delegate_mediatek_test.cc b/tflite/experimental/litert/runtime/dispatch/dispatch_delegate_mediatek_test.cc new file mode 100644 index 00000000..45240f90 --- /dev/null +++ b/tflite/experimental/litert/runtime/dispatch/dispatch_delegate_mediatek_test.cc @@ -0,0 +1,285 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include +#include +#include +#include + +#include +#include +#include "absl/log/absl_log.h" +#include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tflite/c/c_api_opaque.h" +#include "tflite/c/common.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_compiled_model_options.h" +#include "tflite/experimental/litert/c/litert_dispatch_delegate.h" +#include "tflite/experimental/litert/c/litert_tensor_buffer.h" +#include "tflite/experimental/litert/cc/litert_compiled_model.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/cc/litert_tensor_buffer.h" +#include "tflite/experimental/litert/core/model/model_buffer.h" +#include "tflite/experimental/litert/core/util/flatbuffer_tools.h" +#include "tflite/experimental/litert/runtime/external_litert_buffer_context.h" +#include "tflite/experimental/litert/test/common.h" +#include "tflite/experimental/litert/test/testdata/simple_model_test_vectors.h" +#include "tflite/interpreter.h" +#include "tflite/signature_runner.h" + +namespace litert { +namespace { + +using ::litert::testing::MakeRuntimeFromTestFileWithNpuModel; +using ::testing::FloatNear; +using ::testing::Pointwise; + +static constexpr absl::string_view kNpuFile = kMediaTekModelFileName; +static constexpr absl::string_view kTfliteFile = "simple_model_npu.tflite"; + +TEST(DispatchDelegate, MediaTekCpuBuffer) { + auto runtime = MakeRuntimeFromTestFileWithNpuModel(kTfliteFile, kNpuFile); + ASSERT_TRUE(runtime) << "Failed to initialize tflite interpreter"; + auto& rt = **runtime; + auto& interpreter = rt.Interpreter(); + + litert::internal::ExternalLiteRtBufferContext buffer_context; + interpreter.SetExternalContext(kTfLiteLiteRtBufferContext, &buffer_context); + + EXPECT_EQ(interpreter.nodes_size(), 1); + EXPECT_EQ(interpreter.inputs().size(), 2); + EXPECT_EQ(interpreter.outputs().size(), 1); + ASSERT_EQ(interpreter.execution_plan().size(), 1); + + auto dispatch_delegate_options = CreateDispatchDelegateOptionsPtr(); + LiteRtDispatchDelegateAddAllocBaseOption(dispatch_delegate_options.get(), + rt.Flatbuffer().Buf().Data()); + auto dispatch_delegate = + CreateDispatchDelegatePtr(std::move(dispatch_delegate_options)); + +#if !defined(__ANDROID__) + GTEST_SKIP() << "The rest of this test is specific to Android devices with a " + "MediaTek NPU"; +#endif + + ASSERT_EQ(interpreter.ModifyGraphWithDelegate(dispatch_delegate.get()), + kTfLiteOk); + + // Get the list of signatures and check it. + auto signature_defs = interpreter.signature_keys(); + ASSERT_EQ(signature_defs.size(), 1); + + tflite::impl::SignatureRunner* runner = + interpreter.GetSignatureRunner(/*signature_key=*/nullptr); + ASSERT_NE(runner, nullptr); + + EXPECT_EQ(runner->AllocateTensors(), kTfLiteOk); + + // Fill model inputs. + ASSERT_STREQ(runner->input_names()[0], "arg0"); + auto input_0_tensor = runner->input_tensor("arg0"); + ASSERT_NE(input_0_tensor, nullptr); + auto* input_0 = input_0_tensor->data.f; + std::memcpy(input_0, kTestInput0Tensor, sizeof(kTestInput0Tensor)); + + ASSERT_STREQ(runner->input_names()[1], "arg1"); + auto input_1_tensor = runner->input_tensor("arg1"); + ASSERT_NE(input_1_tensor, nullptr); + auto* input_1 = input_1_tensor->data.f; + std::memcpy(input_1, kTestInput1Tensor, sizeof(kTestInput1Tensor)); + + EXPECT_EQ(runner->Invoke(), kTfLiteOk); + + // Check model output. 
+ ASSERT_STREQ(runner->output_names()[0], "tfl.custom"); + auto output_tensor = runner->output_tensor("tfl.custom"); + ASSERT_NE(output_tensor, nullptr); + auto output = absl::MakeSpan(output_tensor->data.f, kTestOutputSize); + for (auto i = 0; i < kTestOutputSize; ++i) { + ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; + } + EXPECT_THAT(output, Pointwise(::testing::FloatNear(1e-5), kTestOutputTensor)); +} + +TEST(DispatchDelegate, MediaTekHwBuffer) { + auto runtime = MakeRuntimeFromTestFileWithNpuModel(kTfliteFile, kNpuFile); + ASSERT_TRUE(runtime) << "Failed to initialize tflite interpreter"; + auto& rt = **runtime; + auto& interpreter = rt.Interpreter(); + + litert::internal::ExternalLiteRtBufferContext buffer_context; + interpreter.SetExternalContext(kTfLiteLiteRtBufferContext, &buffer_context); + + EXPECT_EQ(interpreter.nodes_size(), 1); + EXPECT_EQ(interpreter.inputs().size(), 2); + EXPECT_EQ(interpreter.outputs().size(), 1); + ASSERT_EQ(interpreter.execution_plan().size(), 1); + + auto dispatch_delegate_options = CreateDispatchDelegateOptionsPtr(); + LiteRtDispatchDelegateAddAllocBaseOption(dispatch_delegate_options.get(), + rt.Flatbuffer().Buf().Data()); + auto dispatch_delegate = + CreateDispatchDelegatePtr(std::move(dispatch_delegate_options)); + +#if !defined(__ANDROID__) + GTEST_SKIP() << "The rest of this test is specific to Android devices with a " + "MediaTek NPU"; +#endif + + ASSERT_EQ(interpreter.ModifyGraphWithDelegate(dispatch_delegate.get()), + kTfLiteOk); + + // Create and register tensor buffers for all inputs and outputs. + + std::vector input_buffers; + for (int i = 0; i < interpreter.inputs().size(); ++i) { + auto input_buffer_requirements = + buffer_context.GetBufferRequirement(interpreter.input_tensor(i)); + ASSERT_TRUE(input_buffer_requirements); + ASSERT_EQ((*input_buffer_requirements)->SupportedTypes().Value()[0], + kLiteRtTensorBufferTypeAhwb); + auto input_buffer = + buffer_context.CreateBufferForTensor(interpreter.input_tensor(i)); + ASSERT_TRUE(input_buffer); + ASSERT_TRUE(input_buffer->IsOwned()); + ASSERT_EQ(*input_buffer->BufferType(), kLiteRtTensorBufferTypeAhwb); + auto duplicate_buffer = (*input_buffer).Duplicate(); + ASSERT_TRUE(duplicate_buffer); + auto status = buffer_context.RegisterTensorBuffer( + interpreter.input_tensor(i), std::move(*duplicate_buffer)); + ASSERT_EQ(status, kLiteRtStatusOk); + input_buffers.push_back(std::move(*input_buffer)); + } + + std::vector output_buffers; + for (int i = 0; i < interpreter.outputs().size(); ++i) { + auto output_buffer_requirements = + buffer_context.GetBufferRequirement(interpreter.output_tensor(i)); + ASSERT_TRUE(output_buffer_requirements); + ASSERT_EQ((*output_buffer_requirements)->SupportedTypes().Value()[0], + kLiteRtTensorBufferTypeAhwb); + auto output_buffer = + buffer_context.CreateBufferForTensor(interpreter.output_tensor(i)); + ASSERT_TRUE(output_buffer.HasValue()); + ASSERT_TRUE(output_buffer->IsOwned()); + ASSERT_EQ(*output_buffer->BufferType(), kLiteRtTensorBufferTypeAhwb); + auto duplicate_buffer = (*output_buffer).Duplicate(); + ASSERT_TRUE(duplicate_buffer); + auto status = buffer_context.RegisterTensorBuffer( + interpreter.output_tensor(i), std::move(*duplicate_buffer)); + ASSERT_EQ(status, kLiteRtStatusOk); + output_buffers.push_back(std::move(*output_buffer)); + } + + // Get the list of signatures and check it. 
+ auto signature_defs = interpreter.signature_keys(); + ASSERT_EQ(signature_defs.size(), 1); + + tflite::impl::SignatureRunner* runner = + interpreter.GetSignatureRunner(/*signature_key=*/nullptr); + ASSERT_NE(runner, nullptr); + + EXPECT_EQ(runner->AllocateTensors(), kTfLiteOk); + + // Fill model inputs. + ASSERT_STREQ(runner->input_names()[0], "arg0"); + auto& input_0_buffer = input_buffers[0]; + input_0_buffer.Write( + absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size)); + + ASSERT_STREQ(runner->input_names()[1], "arg1"); + auto& input_1_buffer = input_buffers[1]; + input_1_buffer.Write( + absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size)); + + EXPECT_EQ(runner->Invoke(), kTfLiteOk); + + // Check model output. + ASSERT_STREQ(runner->output_names()[0], "tfl.custom"); + auto& output_buffer = output_buffers[0]; + float output_buffer_data[kTestOutputSize]; + auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize); + auto read_success = output_buffer.Read(output_span); + ASSERT_TRUE(read_success); + for (auto i = 0; i < kTestOutputSize; ++i) { + ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" + << kTestOutputTensor[i]; + } + EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); +} + +TEST(DispatchDelegate, CompiledModel) { + auto model_with_byte_code = + internal::GetModelBufWithByteCode(testing::GetTestFilePath(kTfliteFile), + testing::GetTestFilePath(kNpuFile)); + ASSERT_TRUE(model_with_byte_code); + auto model = Model::CreateFromBuffer(*model_with_byte_code); + ASSERT_TRUE(model); + +#if !defined(__ANDROID__) + GTEST_SKIP() << "The rest of this test is specific to Android devices with a " + "MediaTek NPU"; +#endif + + auto res_compiled_model = CompiledModel::Create(*model, kHwAccelNpu); + ASSERT_TRUE(res_compiled_model) << "Failed to initialize CompiledModel"; + auto& compiled_model = *res_compiled_model; + + auto signatures = model->GetSignatures(); + ASSERT_TRUE(signatures); + EXPECT_EQ(signatures->size(), 1); + auto& signature = signatures->at(0); + auto signature_key = signature.Key(); + EXPECT_EQ(signature_key, Model::DefaultSignatureKey()); + size_t signature_index = 0; + + auto input_buffers_res = compiled_model.CreateInputBuffers(signature_index); + EXPECT_TRUE(input_buffers_res); + auto& input_buffers = *input_buffers_res; + + auto output_buffers_res = compiled_model.CreateOutputBuffers(signature_index); + EXPECT_TRUE(output_buffers_res); + auto& output_buffers = *output_buffers_res; + + // Fill model inputs. + auto input_names = signature.InputNames(); + EXPECT_EQ(input_names.size(), 2); + EXPECT_EQ(input_names.at(0), "arg0"); + EXPECT_EQ(input_names.at(1), "arg1"); + ASSERT_TRUE(input_buffers[0].Write( + absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); + ASSERT_TRUE(input_buffers[1].Write( + absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); + + // Execute model. + compiled_model.Run(signature_index, input_buffers, output_buffers); + + // Check model output. 
+  auto output_names = signature.OutputNames();
+  EXPECT_EQ(output_names.size(), 1);
+  EXPECT_EQ(output_names.at(0), "tfl.custom");
+  float output_buffer_data[kTestOutputSize];
+  auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize);
+  ASSERT_TRUE(output_buffers[0].Read(output_span));
+  for (auto i = 0; i < kTestOutputSize; ++i) {
+    ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t"
+                   << kTestOutputTensor[i];
+  }
+  EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor));
+}
+
+}  // namespace
+}  // namespace litert diff --git a/tflite/experimental/litert/runtime/dispatch/dispatch_delegate_options.h b/tflite/experimental/litert/runtime/dispatch/dispatch_delegate_options.h new file mode 100644 index 00000000..1df2e1e6 --- /dev/null +++ b/tflite/experimental/litert/runtime/dispatch/dispatch_delegate_options.h @@ -0,0 +1,113 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_DISPATCH_DISPATCH_DELEGATE_OPTIONS_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_DISPATCH_DISPATCH_DELEGATE_OPTIONS_H_
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "tflite/experimental/litert/c/litert_any.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_dispatch_delegate.h"
+#include "tflite/experimental/litert/c/litert_environment.h"
+#include "tflite/experimental/litert/c/litert_logging.h"
+#include "tflite/experimental/litert/cc/litert_any.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/core/environment.h"
+#include "tflite/experimental/litert/vendors/c/litert_dispatch.h"
+
+class LiteRtDispatchDelegateOptions {
+ public:
+  LiteRtDispatchDelegateOptions() {
+    auto environment = litert::internal::Environment::Instance();
+    if (!environment) {
+      LITERT_LOG(LITERT_WARNING, "LiteRT environment not found");
+      return;
+    }
+
+    auto option =
+        (*environment)->GetOption(kLiteRtEnvOptionTagDispatchLibraryPath);
+    if (!option.has_value()) {
+      return;
+    }
+
+    if (option->type != kLiteRtAnyTypeString) {
+      LITERT_LOG(LITERT_WARNING,
+                 "Ignoring option kLiteRtEnvOptionTagDispatchLibraryPath due "
+                 "to invalid value");
+      return;
+    }
+
+    LiteRtDispatchOption dispatch_option = {
+        /*.name=*/kDispatchOptionSharedLibraryDir,
+        /*.value=*/*option,
+    };
+    AddOption(dispatch_option);
+  }
+
+  // Push a new dispatch option.
+  void AddOption(LiteRtDispatchOption option) { options_.push_back(option); }
+
+  // Get all dispatch options.
+  const std::vector<LiteRtDispatchOption>& GetDispatchOptions() const {
+    return options_;
+  }
+
+  // Find a dispatch option under the given name if it exists.
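+  // (Example, illustrative: after AddAllocBaseOption(base, opts) below,
+  // FindDispatchOption("alloc_base") yields an Expected<std::any> wrapping
+  // that pointer.)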
+  litert::Expected<std::any> FindDispatchOption(absl::string_view name) const {
+    for (const auto& option : options_) {
+      if (option.name != name) {
+        continue;
+      }
+      return litert::ToStdAny(option.value);
+    }
+    return litert::Unexpected(kLiteRtStatusErrorInvalidArgument);
+  }
+
+ private:
+  std::vector<LiteRtDispatchOption> options_;
+};
+
+//
+// Common options
+//
+
+static constexpr absl::string_view kAllocBase = "alloc_base";
+
+inline void AddAllocBaseOption(const void* alloc_base,
+                               LiteRtDispatchDelegateOptions& opts) {
+  LiteRtAny opt;
+  opt.type = kLiteRtAnyTypeVoidPtr;
+  opt.ptr_value = alloc_base;
+  opts.AddOption(LiteRtDispatchOption{kAllocBase.data(), opt});
+}
+
+inline litert::Expected<const void*> FindAllocBase(
+    const LiteRtDispatchDelegateOptions& opts) {
+  auto alloc_base = opts.FindDispatchOption(kAllocBase);
+  if (!alloc_base) {
+    return alloc_base.Error();
+  }
+  return std::any_cast<const void*>(*alloc_base);
+}
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_DISPATCH_DISPATCH_DELEGATE_OPTIONS_H_
diff --git a/tflite/experimental/litert/runtime/dispatch/dispatch_delegate_qualcomm_test.cc b/tflite/experimental/litert/runtime/dispatch/dispatch_delegate_qualcomm_test.cc
new file mode 100644
index 00000000..67e30fb3
--- /dev/null
+++ b/tflite/experimental/litert/runtime/dispatch/dispatch_delegate_qualcomm_test.cc
@@ -0,0 +1,284 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
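A minimal usage sketch for the options API defined above; the `model_buffer`
pointer is a stand-in for the flatbuffer base address that the tests below
obtain from the runtime, and everything else comes from this header:

    LiteRtDispatchDelegateOptions options;
    const void* model_buffer = nullptr;  // stand-in: flatbuffer base address
    AddAllocBaseOption(model_buffer, options);
    // A dispatch backend can later recover the stored pointer:
    if (auto alloc_base = FindAllocBase(options)) {
      const void* base = *alloc_base;  // the pointer stored above
      (void)base;
    }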
+ +#include +#include +#include +#include + +#include +#include +#include "absl/log/absl_log.h" +#include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tflite/c/c_api_opaque.h" +#include "tflite/c/common.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_compiled_model_options.h" +#include "tflite/experimental/litert/c/litert_dispatch_delegate.h" +#include "tflite/experimental/litert/c/litert_tensor_buffer.h" +#include "tflite/experimental/litert/cc/litert_compiled_model.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/cc/litert_tensor_buffer.h" +#include "tflite/experimental/litert/core/model/model_buffer.h" +#include "tflite/experimental/litert/runtime/external_litert_buffer_context.h" +#include "tflite/experimental/litert/test/common.h" +#include "tflite/experimental/litert/test/testdata/simple_model_test_vectors.h" +#include "tflite/interpreter.h" +#include "tflite/signature_runner.h" + +namespace litert { +namespace { + +using ::litert::testing::MakeRuntimeFromTestFileWithNpuModel; +using ::testing::FloatNear; +using ::testing::Pointwise; + +static constexpr absl::string_view kNpuFile = kQualcommModelFileName; +static constexpr absl::string_view kTfliteFile = "simple_model_npu.tflite"; + +TEST(DispatchDelegate, QualcommCpuBuffer) { + auto runtime = MakeRuntimeFromTestFileWithNpuModel(kTfliteFile, kNpuFile); + ASSERT_TRUE(runtime) << "Failed to initialize tflite interpreter"; + auto& rt = **runtime; + auto& interpreter = rt.Interpreter(); + + litert::internal::ExternalLiteRtBufferContext buffer_context; + interpreter.SetExternalContext(kTfLiteLiteRtBufferContext, &buffer_context); + + EXPECT_EQ(interpreter.nodes_size(), 1); + EXPECT_EQ(interpreter.inputs().size(), 2); + EXPECT_EQ(interpreter.outputs().size(), 1); + ASSERT_EQ(interpreter.execution_plan().size(), 1); + + auto dispatch_delegate_options = CreateDispatchDelegateOptionsPtr(); + LiteRtDispatchDelegateAddAllocBaseOption(dispatch_delegate_options.get(), + rt.Flatbuffer().Buf().Data()); + auto dispatch_delegate = + CreateDispatchDelegatePtr(std::move(dispatch_delegate_options)); + +#if !defined(__ANDROID__) + GTEST_SKIP() << "The rest of this test is specific to Android devices with a " + "Qualcomm HTP"; +#endif + + ASSERT_EQ(interpreter.ModifyGraphWithDelegate(dispatch_delegate.get()), + kTfLiteOk); + + // Get the list of signatures and check it. + auto signature_defs = interpreter.signature_keys(); + ASSERT_EQ(signature_defs.size(), 1); + + tflite::impl::SignatureRunner* runner = + interpreter.GetSignatureRunner(/*signature_key=*/nullptr); + ASSERT_NE(runner, nullptr); + + EXPECT_EQ(runner->AllocateTensors(), kTfLiteOk); + + // Fill model inputs. + ASSERT_STREQ(runner->input_names()[0], "arg0"); + auto input_0_tensor = runner->input_tensor("arg0"); + ASSERT_NE(input_0_tensor, nullptr); + auto* input_0 = input_0_tensor->data.f; + std::memcpy(input_0, kTestInput0Tensor, sizeof(kTestInput0Tensor)); + + ASSERT_STREQ(runner->input_names()[1], "arg1"); + auto input_1_tensor = runner->input_tensor("arg1"); + ASSERT_NE(input_1_tensor, nullptr); + auto* input_1 = input_1_tensor->data.f; + std::memcpy(input_1, kTestInput1Tensor, sizeof(kTestInput1Tensor)); + + EXPECT_EQ(runner->Invoke(), kTfLiteOk); + + // Check model output. 
+ ASSERT_STREQ(runner->output_names()[0], "tfl.custom"); + auto output_tensor = runner->output_tensor("tfl.custom"); + ASSERT_NE(output_tensor, nullptr); + auto output = absl::MakeSpan(output_tensor->data.f, kTestOutputSize); + for (auto i = 0; i < kTestOutputSize; ++i) { + ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; + } + EXPECT_THAT(output, Pointwise(::testing::FloatNear(1e-5), kTestOutputTensor)); +} + +TEST(DispatchDelegate, QualcommHwBuffer) { + auto runtime = MakeRuntimeFromTestFileWithNpuModel(kTfliteFile, kNpuFile); + ASSERT_TRUE(runtime) << "Failed to initialize tflite interpreter"; + auto& rt = **runtime; + auto& interpreter = rt.Interpreter(); + + litert::internal::ExternalLiteRtBufferContext buffer_context; + interpreter.SetExternalContext(kTfLiteLiteRtBufferContext, &buffer_context); + + EXPECT_EQ(interpreter.nodes_size(), 1); + EXPECT_EQ(interpreter.inputs().size(), 2); + EXPECT_EQ(interpreter.outputs().size(), 1); + ASSERT_EQ(interpreter.execution_plan().size(), 1); + + auto dispatch_delegate_options = CreateDispatchDelegateOptionsPtr(); + LiteRtDispatchDelegateAddAllocBaseOption(dispatch_delegate_options.get(), + rt.Flatbuffer().Buf().Data()); + auto dispatch_delegate = + CreateDispatchDelegatePtr(std::move(dispatch_delegate_options)); + +#if !defined(__ANDROID__) + GTEST_SKIP() << "The rest of this test is specific to Android devices with a " + "Qualcomm HTP"; +#endif + + ASSERT_EQ(interpreter.ModifyGraphWithDelegate(dispatch_delegate.get()), + kTfLiteOk); + + // Create and register tensor buffers for all inputs and outputs. + + std::vector input_buffers; + for (int i = 0; i < interpreter.inputs().size(); ++i) { + auto input_buffer_requirements = + buffer_context.GetBufferRequirement(interpreter.input_tensor(i)); + ASSERT_TRUE(input_buffer_requirements); + ASSERT_EQ((*input_buffer_requirements)->SupportedTypes().Value()[0], + kLiteRtTensorBufferTypeFastRpc); + auto input_buffer = + buffer_context.CreateBufferForTensor(interpreter.input_tensor(i)); + ASSERT_TRUE(input_buffer); + ASSERT_TRUE(input_buffer->IsOwned()); + ASSERT_EQ(*input_buffer->BufferType(), kLiteRtTensorBufferTypeFastRpc); + auto duplicate_buffer = (*input_buffer).Duplicate(); + ASSERT_TRUE(duplicate_buffer); + auto status = buffer_context.RegisterTensorBuffer( + interpreter.input_tensor(i), std::move(*duplicate_buffer)); + ASSERT_EQ(status, kLiteRtStatusOk); + input_buffers.push_back(std::move(*input_buffer)); + } + + std::vector output_buffers; + for (int i = 0; i < interpreter.outputs().size(); ++i) { + auto output_buffer_requirements = + buffer_context.GetBufferRequirement(interpreter.output_tensor(i)); + ASSERT_TRUE(output_buffer_requirements); + ASSERT_EQ((*output_buffer_requirements)->SupportedTypes().Value()[0], + kLiteRtTensorBufferTypeFastRpc); + auto output_buffer = + buffer_context.CreateBufferForTensor(interpreter.output_tensor(i)); + ASSERT_TRUE(output_buffer.HasValue()); + ASSERT_TRUE(output_buffer->IsOwned()); + ASSERT_EQ(*output_buffer->BufferType(), kLiteRtTensorBufferTypeFastRpc); + auto duplicate_buffer = (*output_buffer).Duplicate(); + ASSERT_TRUE(duplicate_buffer); + auto status = buffer_context.RegisterTensorBuffer( + interpreter.output_tensor(i), std::move(*duplicate_buffer)); + ASSERT_EQ(status, kLiteRtStatusOk); + output_buffers.push_back(std::move(*output_buffer)); + } + + // Get the list of signatures and check it. 
+ auto signature_defs = interpreter.signature_keys(); + ASSERT_EQ(signature_defs.size(), 1); + + tflite::impl::SignatureRunner* runner = + interpreter.GetSignatureRunner(/*signature_key=*/nullptr); + ASSERT_NE(runner, nullptr); + + EXPECT_EQ(runner->AllocateTensors(), kTfLiteOk); + + // Fill model inputs. + ASSERT_STREQ(runner->input_names()[0], "arg0"); + auto& input_0_buffer = input_buffers[0]; + input_0_buffer.Write( + absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size)); + + ASSERT_STREQ(runner->input_names()[1], "arg1"); + auto& input_1_buffer = input_buffers[1]; + input_1_buffer.Write( + absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size)); + + EXPECT_EQ(runner->Invoke(), kTfLiteOk); + + // Check model output. + ASSERT_STREQ(runner->output_names()[0], "tfl.custom"); + auto& output_buffer = output_buffers[0]; + float output_buffer_data[kTestOutputSize]; + auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize); + auto read_success = output_buffer.Read(output_span); + ASSERT_TRUE(read_success); + for (auto i = 0; i < kTestOutputSize; ++i) { + ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" + << kTestOutputTensor[i]; + } + EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); +} + +TEST(DispatchDelegate, CompiledModel) { + auto model_with_byte_code = + internal::GetModelBufWithByteCode(testing::GetTestFilePath(kTfliteFile), + testing::GetTestFilePath(kNpuFile)); + ASSERT_TRUE(model_with_byte_code); + auto model = Model::CreateFromBuffer(*model_with_byte_code); + ASSERT_TRUE(model); + +#if !defined(__ANDROID__) + GTEST_SKIP() << "The rest of this test is specific to Android devices with a " + "Qualcomm HTP"; +#endif + + auto res_compiled_model = CompiledModel::Create(*model, kHwAccelNpu); + ASSERT_TRUE(res_compiled_model) << "Failed to initialize CompiledModel"; + auto& compiled_model = *res_compiled_model; + + auto signatures = model->GetSignatures(); + ASSERT_TRUE(signatures); + EXPECT_EQ(signatures->size(), 1); + auto& signature = signatures->at(0); + auto signature_key = signature.Key(); + EXPECT_EQ(signature_key, Model::DefaultSignatureKey()); + size_t signature_index = 0; + + auto input_buffers_res = compiled_model.CreateInputBuffers(signature_index); + EXPECT_TRUE(input_buffers_res); + auto& input_buffers = *input_buffers_res; + + auto output_buffers_res = compiled_model.CreateOutputBuffers(signature_index); + EXPECT_TRUE(output_buffers_res); + auto& output_buffers = *output_buffers_res; + + // Fill model inputs. + auto input_names = signature.InputNames(); + EXPECT_EQ(input_names.size(), 2); + EXPECT_EQ(input_names.at(0), "arg0"); + EXPECT_EQ(input_names.at(1), "arg1"); + ASSERT_TRUE(input_buffers[0].Write( + absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); + ASSERT_TRUE(input_buffers[1].Write( + absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); + + // Execute model. + compiled_model.Run(signature_index, input_buffers, output_buffers); + + // Check model output. 
+ auto output_names = signature.OutputNames(); + EXPECT_EQ(output_names.size(), 1); + EXPECT_EQ(output_names.at(0), "tfl.custom"); + float output_buffer_data[kTestOutputSize]; + auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize); + ASSERT_TRUE(output_buffers[0].Read(output_span)); + for (auto i = 0; i < kTestOutputSize; ++i) { + ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" + << kTestOutputTensor[i]; + } + EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); +} + +} // namespace +} // namespace litert diff --git a/tflite/experimental/litert/runtime/dispatch/litert_dispatch.cc b/tflite/experimental/litert/runtime/dispatch/litert_dispatch.cc new file mode 100644 index 00000000..fa729f95 --- /dev/null +++ b/tflite/experimental/litert/runtime/dispatch/litert_dispatch.cc @@ -0,0 +1,513 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tflite/experimental/litert/vendors/c/litert_dispatch.h" + +#include + +#include +#include +#include + +#include "absl/strings/str_format.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_event.h" +#include "tflite/experimental/litert/c/litert_logging.h" +#include "tflite/experimental/litert/c/litert_model.h" +#include "tflite/experimental/litert/c/litert_tensor_buffer.h" +#include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h" +#include "tflite/experimental/litert/vendors/c/litert_dispatch_api.h" + +#define INVOKE_FUNC(function, ...) \ + if (!TheApi.interface) { \ + LITERT_LOG(LITERT_ERROR, "Dispatch API interface not found"); \ + return kLiteRtStatusErrorRuntimeFailure; \ + } \ + if (!TheApi.interface->function) { \ + LITERT_LOG(LITERT_ERROR, #function " not found"); \ + return kLiteRtStatusErrorRuntimeFailure; \ + } \ + return TheApi.interface->function(__VA_ARGS__); + +#define INVOKE_ASYNC_FUNC(function, ...) \ + if (!TheApi.async_interface) { \ + LITERT_LOG(LITERT_ERROR, "Dispatch API async interface not found"); \ + return kLiteRtStatusErrorRuntimeFailure; \ + } \ + if (!TheApi.async_interface->function) { \ + LITERT_LOG(LITERT_ERROR, #function " not found"); \ + return kLiteRtStatusErrorRuntimeFailure; \ + } \ + return TheApi.async_interface->function(__VA_ARGS__); + +#define INVOKE_GRAPH_FUNC(function, ...) 
\
+  if (!TheApi.graph_interface) {                                        \
+    LITERT_LOG(LITERT_ERROR, "Dispatch API graph interface not found"); \
+    return kLiteRtStatusErrorRuntimeFailure;                            \
+  }                                                                     \
+  if (!TheApi.graph_interface->function) {                              \
+    LITERT_LOG(LITERT_ERROR, #function " not found");                   \
+    return kLiteRtStatusErrorRuntimeFailure;                            \
+  }                                                                     \
+  return TheApi.graph_interface->function(__VA_ARGS__);
+
+namespace {
+
+constexpr const char* kSharedLibName = "libLiteRtDispatch.so";
+
+bool IsTheApiInitialized = false;
+LiteRtDispatchApi TheApi = {
+    /*.version=*/{/*.major=*/0, /*.minor=*/0, /*.patch=*/0},
+    /*.interface=*/nullptr,
+    /*.async_interface=*/nullptr,
+    /*.graph_interface=*/nullptr,
+};
+
+LiteRtStatus Initialize(const LiteRtDispatchOption* options, int num_options) {
+  INVOKE_FUNC(initialize, options, num_options);
+}
+
+std::string GetSharedLibraryPath(const LiteRtDispatchOption* options,
+                                 int num_options) {
+  for (auto i = 0; i < num_options; ++i) {
+    auto& option = options[i];
+    if (!strcmp(option.name, kDispatchOptionSharedLibraryDir)) {
+      return absl::StrFormat("%s/%s", option.value.str_value, kSharedLibName);
+    }
+  }
+  return kSharedLibName;
+}
+
+}  // namespace
+
+// /////////////////////////////////////////////////////////////////////////////
+// Basic Execution API
+// /////////////////////////////////////////////////////////////////////////////
+
+LiteRtStatus LiteRtDispatchInitialize(const LiteRtDispatchOption* options,
+                                      int num_options) {
+  if (IsTheApiInitialized) {
+    return kLiteRtStatusOk;
+  }
+
+  auto shared_lib_path = GetSharedLibraryPath(options, num_options);
+  void* lib_handle = ::dlopen(shared_lib_path.data(), RTLD_NOW | RTLD_LOCAL);
+  if (!lib_handle) {
+    LITERT_LOG(LITERT_ERROR, "Failed to load dispatch library: %s",
+               ::dlerror());
+    return kLiteRtStatusErrorRuntimeFailure;
+  }
+
+  using LiteRtDispatchGetApi_t = LiteRtStatus (*)(LiteRtDispatchApi*);
+  auto LiteRtDispatchGetApi = reinterpret_cast<LiteRtDispatchGetApi_t>(
+      ::dlsym(lib_handle, "LiteRtDispatchGetApi"));
+  if (!LiteRtDispatchGetApi) {
+    ::dlclose(lib_handle);
+    LITERT_LOG(LITERT_ERROR, "LiteRtDispatchGetApi not found");
+    return kLiteRtStatusErrorRuntimeFailure;
+  }
+
+  if (auto status = LiteRtDispatchGetApi(&TheApi); status != kLiteRtStatusOk) {
+    ::dlclose(lib_handle);
+    return status;
+  }
+
+  if (TheApi.version.major != LITERT_API_VERSION_MAJOR) {
+    ::dlclose(lib_handle);
+    LITERT_LOG(
+        LITERT_ERROR,
+        "Unsupported Dispatch API runtime version, found version %d.%d.%d and "
+        "expected version %d.%d.%d",
+        TheApi.version.major, TheApi.version.minor, TheApi.version.patch,
+        LITERT_API_VERSION_MAJOR, LITERT_API_VERSION_MINOR,
+        LITERT_API_VERSION_PATCH);
+    return kLiteRtStatusErrorRuntimeFailure;
+  }
+
+  auto status = Initialize(options, num_options);
+  if (status == kLiteRtStatusOk) {
+    IsTheApiInitialized = true;
+  }
+  return status;
+}
+
+LiteRtStatus LiteRtDispatchGetApiVersion(LiteRtApiVersion* api_version) {
+  if (!api_version) {
+    LITERT_LOG(LITERT_ERROR, "Null input");
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *api_version = TheApi.version;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtDispatchGetVendorId(const char** vendor_id) {
+  if (!vendor_id) {
+    LITERT_LOG(LITERT_ERROR, "Null input");
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  INVOKE_FUNC(get_vendor_id, vendor_id);
+}
+
+LiteRtStatus LiteRtDispatchGetBuildId(const char** build_id) {
+  if (!build_id) {
+    LITERT_LOG(LITERT_ERROR, "Null input");
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  INVOKE_FUNC(get_build_id, build_id);
+}
+
+LiteRtStatus
LiteRtDispatchGetCapabilities(int* capabilities) { + if (!capabilities) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_FUNC(get_capabilities, capabilities); +} + +LiteRtStatus LiteRtDispatchDeviceContextCreate( + LiteRtDispatchDeviceContext* device_context) { + if (!device_context) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_FUNC(device_context_create, device_context); +} + +LiteRtStatus LiteRtDispatchDeviceContextDestroy( + LiteRtDispatchDeviceContext device_context) { + if (!device_context) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_FUNC(device_context_destroy, device_context); +} + +LiteRtStatus LiteRtDispatchGetInputRequirements( + LiteRtDispatchInvocationContext invocation_context, int input_index, + const LiteRtRankedTensorType* tensor_type, + LiteRtTensorBufferRequirements* tensor_buffer_requirements) { + if (!invocation_context || !tensor_type || !tensor_buffer_requirements) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_FUNC(get_input_requirements, invocation_context, input_index, + tensor_type, tensor_buffer_requirements); +} + +LiteRtStatus LiteRtDispatchGetOutputRequirements( + LiteRtDispatchInvocationContext invocation_context, int output_index, + const LiteRtRankedTensorType* tensor_type, + LiteRtTensorBufferRequirements* tensor_buffer_requirements) { + if (!invocation_context || !tensor_type || !tensor_buffer_requirements) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_FUNC(get_output_requirements, invocation_context, output_index, + tensor_type, tensor_buffer_requirements); +} + +LiteRtStatus LiteRtDispatchRegisterTensorBuffer( + LiteRtDispatchDeviceContext device_context, + LiteRtTensorBuffer tensor_buffer, + LiteRtTensorBufferHandle* tensor_buffer_handle) { + if (!device_context || !tensor_buffer || !tensor_buffer_handle) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_FUNC(register_tensor_buffer, device_context, tensor_buffer, + tensor_buffer_handle); +} + +LiteRtStatus LiteRtDispatchUnregisterTensorBuffer( + LiteRtDispatchDeviceContext device_context, + LiteRtTensorBufferHandle tensor_buffer_handle) { + if (!device_context) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_FUNC(unregister_tensor_buffer, device_context, tensor_buffer_handle); +} + +LiteRtStatus LiteRtDispatchInvocationContextCreate( + LiteRtDispatchDeviceContext device_context, + LiteRtDispatchExecutableType exec_type, const void* exec_bytecode_ptr, + size_t exec_bytecode_size, const char* function_name, int num_inputs, + int num_outputs, LiteRtDispatchInvocationContext* invocation_context) { + if (!device_context || !exec_bytecode_ptr || !invocation_context) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_FUNC(invocation_context_create, device_context, exec_type, + exec_bytecode_ptr, exec_bytecode_size, function_name, num_inputs, + num_outputs, invocation_context); +} + +LiteRtStatus LiteRtDispatchInvocationContextDestroy( + LiteRtDispatchInvocationContext invocation_context) { + if (!invocation_context) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_FUNC(invocation_context_destroy, invocation_context); +} + +LiteRtStatus 
LiteRtDispatchAttachInput(
+    LiteRtDispatchInvocationContext invocation_context, int graph_input_index,
+    LiteRtTensorBufferHandle tensor_buffer_handle) {
+  if (!invocation_context) {
+    LITERT_LOG(LITERT_ERROR, "Null input");
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  INVOKE_FUNC(attach_input, invocation_context, graph_input_index,
+              tensor_buffer_handle);
+}
+
+LiteRtStatus LiteRtDispatchAttachOutput(
+    LiteRtDispatchInvocationContext invocation_context, int graph_output_index,
+    LiteRtTensorBufferHandle tensor_buffer_handle) {
+  if (!invocation_context) {
+    LITERT_LOG(LITERT_ERROR, "Null input");
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  if (!TheApi.interface) {
+    LITERT_LOG(LITERT_ERROR, "Dispatch API interface not found");
+    return kLiteRtStatusErrorRuntimeFailure;
+  }
+  if (!TheApi.interface->attach_output) {
+    LITERT_LOG(LITERT_ERROR, "attach_output not found");
+    return kLiteRtStatusErrorRuntimeFailure;
+  }
+  INVOKE_FUNC(attach_output, invocation_context, graph_output_index,
+              tensor_buffer_handle);
+}
+
+LiteRtStatus LiteRtDispatchDetachInput(
+    LiteRtDispatchInvocationContext invocation_context, int graph_input_index,
+    LiteRtTensorBufferHandle tensor_buffer_handle) {
+  if (!invocation_context) {
+    LITERT_LOG(LITERT_ERROR, "Null input");
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  INVOKE_FUNC(detach_input, invocation_context, graph_input_index,
+              tensor_buffer_handle);
+}
+
+LiteRtStatus LiteRtDispatchDetachOutput(
+    LiteRtDispatchInvocationContext invocation_context, int graph_output_index,
+    LiteRtTensorBufferHandle tensor_buffer_handle) {
+  if (!invocation_context) {
+    LITERT_LOG(LITERT_ERROR, "Null input");
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  INVOKE_FUNC(detach_output, invocation_context, graph_output_index,
+              tensor_buffer_handle);
+}
+
+LiteRtStatus LiteRtDispatchInvoke(
+    LiteRtDispatchInvocationContext invocation_context) {
+  if (!invocation_context) {
+    LITERT_LOG(LITERT_ERROR, "Null input");
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  INVOKE_FUNC(invoke, invocation_context);
+}
+
+// /////////////////////////////////////////////////////////////////////////////
+// Async Execution API
+// /////////////////////////////////////////////////////////////////////////////
+
+LiteRtStatus LiteRtDispatchAttachInputEvent(
+    LiteRtDispatchInvocationContext invocation_context, int graph_input_index,
+    LiteRtEvent input_event) {
+  if (!invocation_context || !input_event) {
+    LITERT_LOG(LITERT_ERROR, "Null input");
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  INVOKE_ASYNC_FUNC(attach_input_event, invocation_context, graph_input_index,
+                    input_event);
+}
+
+LiteRtStatus LiteRtDispatchInvokeAsync(
+    LiteRtDispatchInvocationContext invocation_context, int num_output_events,
+    LiteRtEvent* output_events) {
+  if (!invocation_context || !output_events) {
+    LITERT_LOG(LITERT_ERROR, "Null input");
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  INVOKE_ASYNC_FUNC(invoke_async, invocation_context, num_output_events,
+                    output_events);
+}
+
+// /////////////////////////////////////////////////////////////////////////////
+// Graph Execution API
+// /////////////////////////////////////////////////////////////////////////////
+
+LiteRtStatus LiteRtDispatchGraphCreate(
+    LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph* graph) {
+  if (!device_context || !graph) {
+    LITERT_LOG(LITERT_ERROR, "Null input");
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  INVOKE_GRAPH_FUNC(graph_create, device_context, graph);
+}
+
+LiteRtStatus LiteRtDispatchGraphDestroy(LiteRtDispatchGraph graph) { + if (!graph) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_GRAPH_FUNC(graph_destroy, graph); +} + +LiteRtStatus LiteRtDispatchAddNode(LiteRtDispatchGraph graph, + LiteRtDispatchNodeId node_id, + LiteRtDispatchNodeType node_type) { + if (!graph) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_GRAPH_FUNC(add_node, graph, node_id, node_type); +} + +LiteRtStatus LiteRtDispatchAddEdge(LiteRtDispatchGraph graph, + LiteRtDispatchEdgeId edge_id) { + if (!graph) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_GRAPH_FUNC(add_edge, graph, edge_id); +} + +LiteRtStatus LiteRtDispatchConnectNodeInput(LiteRtDispatchGraph graph, + LiteRtDispatchNodeId node_id, + int input_index, + LiteRtDispatchEdgeId edge_id) { + if (!graph) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_GRAPH_FUNC(connect_node_input, graph, node_id, input_index, edge_id); +} + +LiteRtStatus LiteRtDispatchConnectNodeOutput(LiteRtDispatchGraph graph, + LiteRtDispatchNodeId node_id, + int output_index, + LiteRtDispatchEdgeId edge_id) { + if (!graph) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_GRAPH_FUNC(connect_node_output, graph, node_id, output_index, edge_id); +} + +LiteRtStatus LiteRtDispatchConnectGraphInput(LiteRtDispatchGraph graph, + int input_index, + LiteRtDispatchEdgeId edge_id) { + if (!graph) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_GRAPH_FUNC(connect_graph_input, graph, input_index, edge_id); +} + +LiteRtStatus LiteRtDispatchConnectGraphOutput(LiteRtDispatchGraph graph, + int output_index, + LiteRtDispatchEdgeId edge_id) { + if (!graph) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_GRAPH_FUNC(connect_graph_output, graph, output_index, edge_id); +} + +LiteRtStatus LiteRtDispatchLoadExecutable( + LiteRtDispatchDeviceContext device_context, + LiteRtDispatchExecutableType type, const void* bytecode, + size_t bytecode_size, LiteRtDispatchExecutableHandle* exec_handle) { + if (!device_context || !bytecode || !exec_handle) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + if (!TheApi.graph_interface) { + LITERT_LOG(LITERT_ERROR, "Dispatch API graph interface not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + if (!TheApi.graph_interface->load_executable) { + LITERT_LOG(LITERT_ERROR, "load_executable not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + INVOKE_GRAPH_FUNC(load_executable, device_context, type, bytecode, + bytecode_size, exec_handle); +} + +LiteRtStatus LiteRtDispatchUnloadExecutable( + LiteRtDispatchDeviceContext device_context, + LiteRtDispatchExecutableHandle exec_handle) { + if (!device_context) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_GRAPH_FUNC(unload_executable, device_context, exec_handle); +} + +LiteRtStatus LiteRtDispatchAssignNodeFunction( + LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id, + LiteRtDispatchExecutableHandle exec_handle, const char* function_name) { + if (!graph) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_GRAPH_FUNC(assign_node_function, graph, node_id, 
exec_handle, + function_name); +} + +LiteRtStatus LiteRtDispatchAnnotateGraph(LiteRtDispatchGraph graph, + const char* key, const char* value) { + if (!graph) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_GRAPH_FUNC(annotate_graph, graph, key, value); +} + +LiteRtStatus LiteRtDispatchAnnotateNode(LiteRtDispatchGraph graph, + LiteRtDispatchNodeId node_id, + const char* key, const char* value) { + if (!graph) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_GRAPH_FUNC(annotate_node, graph, node_id, key, value); +} + +LiteRtStatus LiteRtDispatchAnnotateEdge(LiteRtDispatchGraph graph, + LiteRtDispatchEdgeId edge_id, + const char* key, const char* value) { + if (!graph) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_GRAPH_FUNC(annotate_edge, graph, edge_id, key, value); +} + +LiteRtStatus LiteRtDispatchInvocationContextCreateFromGraph( + LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph graph, + LiteRtDispatchInvocationContext* invocation_context) { + if (!device_context || !graph || !invocation_context) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_GRAPH_FUNC(invocation_context_create_from_graph, device_context, graph, + invocation_context); +} diff --git a/tflite/experimental/litert/runtime/dmabuf_buffer.cc b/tflite/experimental/litert/runtime/dmabuf_buffer.cc new file mode 100644 index 00000000..0e0b4eed --- /dev/null +++ b/tflite/experimental/litert/runtime/dmabuf_buffer.cc @@ -0,0 +1,180 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
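Each thunk in litert_dispatch.cc above follows the shape that the INVOKE_FUNC
family of macros expands to: verify the vendor interface table was loaded,
verify the specific entry point was resolved, then forward the call. A
simplified, self-contained restatement of that pattern (all names here are
illustrative, not part of the patch):

    struct VendorInterface {
      int (*do_work)(int);  // entry point resolved from the vendor library
    };
    VendorInterface* g_interface = nullptr;

    int CallDoWork(int arg) {
      if (!g_interface) return -1;           // interface table never loaded
      if (!g_interface->do_work) return -1;  // symbol missing from the library
      return g_interface->do_work(arg);      // forward to the vendor entry point
    }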
+ +#include "tflite/experimental/litert/runtime/dmabuf_buffer.h" + +#include +#include + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/const_init.h" +#include "absl/container/node_hash_map.h" +#include "absl/synchronization/mutex.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/cc/litert_expected.h" + +namespace litert { +namespace internal { + +namespace { + +class DmaBufLibrary { + public: + using Ptr = std::unique_ptr; + + ~DmaBufLibrary() { + if (allocator_) { + free_allocator_(allocator_); + } + } + + static Expected Create() { + DlHandle dlhandle(::dlopen("libdmabufheap.so", RTLD_LAZY | RTLD_LOCAL), + ::dlclose); + if (!dlhandle) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "libdmabufheap.so not found"); + } + + auto create_allocator = reinterpret_cast( + ::dlsym(dlhandle.get(), "CreateDmabufHeapBufferAllocator")); + if (!create_allocator) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "CreateDmabufHeapBufferAllocator not found"); + } + + auto free_allocator = reinterpret_cast( + ::dlsym(dlhandle.get(), "FreeDmabufHeapBufferAllocator")); + if (!free_allocator) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "FreeDmabufHeapBufferAllocator not found"); + } + + auto alloc_buffer = reinterpret_cast( + ::dlsym(dlhandle.get(), "DmabufHeapAlloc")); + if (!alloc_buffer) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "DmabufHeapAlloc not found"); + } + + void* allocator = create_allocator(); + if (!allocator) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "CreateDmabufHeapBufferAllocator failed"); + } + + return Ptr(new DmaBufLibrary(std::move(dlhandle), allocator, free_allocator, + alloc_buffer)); + } + + Expected Alloc(size_t size) { + int fd = alloc_buffer_(allocator_, kDmaBufHeap, size, /*flags=*/0, + /*legacy_align=*/0); + if (fd < 0) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "Failed to allocate DMA-BUF buffer"); + } + void* addr = + ::mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if (addr == MAP_FAILED) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "Failed to mem-map DMA-BUF buffer"); + } + records_[addr] = Record{.fd = fd, .addr = addr, .size = size}; + return DmaBufBuffer{.fd = fd, .addr = addr}; + } + + void Free(void* addr) { + auto iter = records_.find(addr); + if (iter == records_.end()) { + return; + } + auto& record = iter->second; + ::munmap(record.addr, record.size); + ::close(record.fd); + records_.erase(iter); + } + + private: + static constexpr const char* kDmaBufHeap = "system"; + + struct Record { + int fd; + void* addr; + size_t size; + }; + + using DlHandle = std::unique_ptr; + using CreateAllocator = void* (*)(); + using FreeAllocator = void (*)(void*); + using AllocBuffer = int (*)(void*, const char*, size_t, unsigned int, size_t); + + DmaBufLibrary(DlHandle&& dlhandle, void* allocator, + FreeAllocator free_allocator, AllocBuffer alloc_buffer) + : dlhandle_(std::move(dlhandle)) { + allocator_ = allocator; + free_allocator_ = free_allocator; + alloc_buffer_ = alloc_buffer; + } + + DlHandle dlhandle_; + void* allocator_; + FreeAllocator free_allocator_; + AllocBuffer alloc_buffer_; + absl::node_hash_map records_; +}; + +DmaBufLibrary* TheDmaBufLibrary; +ABSL_CONST_INIT absl::Mutex TheMutex(absl::kConstInit); + +Expected InitLibraryIfNeededUnlocked() { + if (!TheDmaBufLibrary) { + if (auto library = DmaBufLibrary::Create(); library) { + TheDmaBufLibrary = library->release(); + } 
else { + return Unexpected(library.Error()); + } + } + return {}; +} + +} // namespace + +bool DmaBufBuffer::IsSupported() { + absl::MutexLock lock(&TheMutex); + auto status = InitLibraryIfNeededUnlocked(); + return static_cast(status); +} + +Expected DmaBufBuffer::Alloc(size_t size) { + absl::MutexLock lock(&TheMutex); + if (auto status = InitLibraryIfNeededUnlocked(); !status) { + return Unexpected(status.Error()); + } + return TheDmaBufLibrary->Alloc(size); +} + +void DmaBufBuffer::Free(void* addr) { + absl::MutexLock lock(&TheMutex); + if (TheDmaBufLibrary) { + TheDmaBufLibrary->Free(addr); + } +} + +} // namespace internal +} // namespace litert diff --git a/tflite/experimental/litert/runtime/dmabuf_buffer.h b/tflite/experimental/litert/runtime/dmabuf_buffer.h new file mode 100644 index 00000000..0d3387c5 --- /dev/null +++ b/tflite/experimental/litert/runtime/dmabuf_buffer.h @@ -0,0 +1,35 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_DMABUF_BUFFER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_DMABUF_BUFFER_H_ + +#include "tflite/experimental/litert/cc/litert_expected.h" + +namespace litert { +namespace internal { + +struct DmaBufBuffer { + int fd; + void* addr; + + static bool IsSupported(); + static Expected Alloc(size_t size); + static void Free(void* addr); +}; + +} // namespace internal +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_DMABUF_BUFFER_H_ diff --git a/tflite/experimental/litert/runtime/event.cc b/tflite/experimental/litert/runtime/event.cc new file mode 100644 index 00000000..b5b72ff9 --- /dev/null +++ b/tflite/experimental/litert/runtime/event.cc @@ -0,0 +1,71 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
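DmaBufLibrary above, like FastRpcMemLibrary and IonLibrary later in this patch,
lazily binds a vendor library: dlopen once under a mutex, dlsym each required
symbol, and keep the handle in a unique_ptr whose deleter is dlclose so the
library stays loaded for the process lifetime. A distilled sketch of that
loader pattern, assuming only the POSIX dlfcn API:

    #include <dlfcn.h>

    #include <memory>

    using DlHandle = std::unique_ptr<void, int (*)(void*)>;

    // Opens a shared library; the handle calls ::dlclose on destruction.
    inline DlHandle OpenLibrary(const char* name) {
      return DlHandle(::dlopen(name, RTLD_LAZY | RTLD_LOCAL), ::dlclose);
    }

    // Resolves one symbol into a typed function pointer (nullptr if absent).
    template <typename Fn>
    Fn ResolveSymbol(const DlHandle& handle, const char* symbol) {
      return reinterpret_cast<Fn>(::dlsym(handle.get(), symbol));
    }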
+
+#include "tflite/experimental/litert/runtime/event.h"
+
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_logging.h"
+
+LiteRtStatus LiteRtEventT::Wait(int64_t timeout_in_ms) {
+#if LITERT_HAS_SYNC_FENCE_SUPPORT
+  struct pollfd fds = {
+      .fd = fd,
+      .events = POLLIN,
+  };
+
+  int ret;
+  do {
+    ret = ::poll(&fds, 1, timeout_in_ms);
+    if (ret == 1) {
+      break;
+    } else if (ret == 0) {
+      LITERT_LOG(LITERT_WARNING, "Timeout expired: %ld",
+                 static_cast<long>(timeout_in_ms));
+      return kLiteRtStatusErrorTimeoutExpired;
+    }
+  } while (ret == -1 && (errno == EINTR || errno == EAGAIN));
+
+  if (ret < 0) {
+    LITERT_LOG(LITERT_ERROR, "Error waiting for fence: %s", ::strerror(errno));
+    return kLiteRtStatusErrorRuntimeFailure;
+  }
+
+  return kLiteRtStatusOk;
+
+#else
+  LITERT_LOG(LITERT_ERROR, "LiteRtEventWait not implemented for this platform");
+  return kLiteRtStatusErrorUnsupported;
+#endif
+}
+
+namespace {
+inline bool IsFdValid(int fd) {
+  return ::fcntl(fd, F_GETFD) != -1 || errno != EBADF;
+}
+}  // namespace
+
+LiteRtEventT::~LiteRtEventT() {
+#if LITERT_HAS_SYNC_FENCE_SUPPORT
+  if (owns_fd && IsFdValid(fd)) {
+    ::close(fd);
+  }
+#endif
+}
diff --git a/tflite/experimental/litert/runtime/event.h b/tflite/experimental/litert/runtime/event.h
new file mode 100644
index 00000000..bb4e8ce2
--- /dev/null
+++ b/tflite/experimental/litert/runtime/event.h
@@ -0,0 +1,31 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_EVENT_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_EVENT_H_
+
+#include
+
+#include "tflite/experimental/litert/c/litert_common.h"
+
+struct LiteRtEventT {
+#if LITERT_HAS_SYNC_FENCE_SUPPORT
+  int fd;
+  bool owns_fd;
+#endif
+  ~LiteRtEventT();
+  LiteRtStatus Wait(int64_t timeout_in_ms);
+};
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_EVENT_H_
diff --git a/tflite/experimental/litert/runtime/external_litert_buffer_context.cc b/tflite/experimental/litert/runtime/external_litert_buffer_context.cc
new file mode 100644
index 00000000..a021c86a
--- /dev/null
+++ b/tflite/experimental/litert/runtime/external_litert_buffer_context.cc
@@ -0,0 +1,125 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
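LiteRtEventT::Wait above retries poll() while it is interrupted by a signal;
reduced to its core, the wait loop looks like the sketch below (a
simplification that folds timeout and error into one failure path, whereas the
real implementation distinguishes them):

    #include <poll.h>

    #include <cerrno>

    // Returns true once the fd is readable; false on timeout or error.
    bool WaitReadable(int fd, int timeout_ms) {
      pollfd p{};
      p.fd = fd;
      p.events = POLLIN;
      int ret;
      do {
        ret = ::poll(&p, 1, timeout_ms);
      } while (ret == -1 && (errno == EINTR || errno == EAGAIN));
      return ret == 1;
    }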
+
+#include "tflite/experimental/litert/runtime/external_litert_buffer_context.h"
+
+#include
+
+#include "tflite/c/c_api_opaque.h"
+#include "tflite/c/c_api_types.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_logging.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/cc/litert_tensor_buffer.h"
+#include "tflite/experimental/litert/cc/litert_tensor_buffer_requirements.h"
+#include "tflite/experimental/litert/runtime/tfl_utils.h"
+
+namespace litert {
+namespace internal {
+
+LiteRtStatus ExternalLiteRtBufferContext::RegisterBufferRequirement(
+    const TfLiteOpaqueTensor* tensor,
+    TensorBufferRequirements&& buffer_requirements) {
+  if (buffer_requirements_.find(tensor) != buffer_requirements_.end()) {
+    LITERT_LOG(LITERT_ERROR,
+               "A buffer requirement is already registered for tensor: %p",
+               tensor);
+    return kLiteRtStatusErrorRuntimeFailure;
+  }
+  buffer_requirements_[tensor] = std::move(buffer_requirements);
+  return kLiteRtStatusOk;
+}
+
+litert::Expected<TensorBufferRequirements*>
+ExternalLiteRtBufferContext::GetBufferRequirement(
+    const TfLiteOpaqueTensor* tensor) {
+  auto it = buffer_requirements_.find(tensor);
+  if (it == buffer_requirements_.end()) {
+    return litert::Unexpected(kLiteRtStatusErrorNotFound,
+                              "Buffer requirement not found");
+  }
+  return &(it->second);
+}
+
+LiteRtStatus ExternalLiteRtBufferContext::RegisterTensorBuffer(
+    const TfLiteOpaqueTensor* tensor, TensorBuffer&& tensor_buffer) {
+  tensor_buffers_[tensor] = std::move(tensor_buffer);
+  return kLiteRtStatusOk;
+}
+
+litert::Expected<TensorBuffer> ExternalLiteRtBufferContext::GetTensorBuffer(
+    const TfLiteOpaqueTensor* tensor) {
+  auto it = tensor_buffers_.find(tensor);
+  if (it == tensor_buffers_.end()) {
+    return litert::Unexpected(kLiteRtStatusErrorNotFound,
+                              "Tensor buffer not found");
+  }
+
+  auto duplicate_tensor_buffer = it->second.Duplicate();
+  if (!duplicate_tensor_buffer) {
+    return litert::Unexpected(duplicate_tensor_buffer.Error());
+  }
+  return std::move(duplicate_tensor_buffer.Value());
+}
+
+litert::Expected<TensorBuffer>
+ExternalLiteRtBufferContext::CreateBufferForTensor(
+    const TfLiteOpaqueTensor* tensor) {
+  auto tensor_buffer_requirements = GetBufferRequirement(tensor);
+  if (!tensor_buffer_requirements) {
+    return litert::Unexpected(tensor_buffer_requirements.Error());
+  }
+
+  auto tensor_type = litert::internal::ConvertTensorType(tensor);
+  if (!tensor_type) {
+    return litert::Unexpected(tensor_type.Error());
+  }
+
+  auto supported_tensor_buffer_types =
+      (*tensor_buffer_requirements)->SupportedTypes();
+  if (!supported_tensor_buffer_types) {
+    return litert::Unexpected(supported_tensor_buffer_types.Error());
+  }
+  if (supported_tensor_buffer_types->empty()) {
+    return litert::Unexpected(
+        kLiteRtStatusErrorRuntimeFailure,
+        "Insufficient number of supported tensor buffer types");
+  }
+
+  // For now we simply pick the first buffer type that's supported.
+  LiteRtTensorBufferType tensor_buffer_type =
+      (*supported_tensor_buffer_types)[0];
+
+  auto tensor_buffer_size = (*tensor_buffer_requirements)->BufferSize();
+  if (!tensor_buffer_size) {
+    return litert::Unexpected(tensor_buffer_size.Error());
+  }
+  auto litert_tensor_type = static_cast<LiteRtRankedTensorType>(*tensor_type);
+
+  LiteRtTensorBuffer litert_tensor_buffer;
+  if (auto status = LiteRtCreateManagedTensorBuffer(
+          tensor_buffer_type, &litert_tensor_type, *tensor_buffer_size,
+          &litert_tensor_buffer);
+      status != kLiteRtStatusOk) {
+    return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                              "Failed to create managed tensor buffer");
+  }
+
+  return TensorBuffer(litert_tensor_buffer, /*owned=*/true);
+}
+
+}  // namespace internal
+}  // namespace litert
diff --git a/tflite/experimental/litert/runtime/external_litert_buffer_context.h b/tflite/experimental/litert/runtime/external_litert_buffer_context.h
new file mode 100644
index 00000000..b3e58824
--- /dev/null
+++ b/tflite/experimental/litert/runtime/external_litert_buffer_context.h
@@ -0,0 +1,115 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_EXTERNAL_LITERT_BUFFER_CONTEXT_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_EXTERNAL_LITERT_BUFFER_CONTEXT_H_
+
+#include
+#include
+
+#include "absl/container/flat_hash_map.h"
+#include "tflite/c/c_api_opaque.h"
+#include "tflite/c/c_api_types.h"
+#include "tflite/c/common.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/cc/litert_tensor_buffer.h"
+#include "tflite/experimental/litert/cc/litert_tensor_buffer_requirements.h"
+
+namespace litert {
+namespace internal {
+
+class ExternalLiteRtBufferContext : public TfLiteExternalContext {
+ public:
+  ExternalLiteRtBufferContext() = default;
+  ~ExternalLiteRtBufferContext() = default;
+
+  // Registers tensor buffer requirements for the given tensor.
+  // The registered TensorBufferRequirements object is owned by
+  // ExternalLiteRtBufferContext.
+  // Note: Currently, the system pre-registers tensor buffer requirements
+  // before they're actually used. A more efficient approach would be to query
+  // DelegateKernel only when these requirements are needed.
+  LiteRtStatus RegisterBufferRequirement(
+      const TfLiteOpaqueTensor* tensor,
+      TensorBufferRequirements&& buffer_requirements);
+
+  inline LiteRtStatus RegisterBufferRequirement(
+      const TfLiteTensor* tensor,
+      TensorBufferRequirements&& buffer_requirements) {
+    return RegisterBufferRequirement(
+        reinterpret_cast<const TfLiteOpaqueTensor*>(tensor),
+        std::move(buffer_requirements));
+  }
+
+  // Gets the registered tensor buffer requirements for the given tensor.
+  // The returned TensorBufferRequirements object is still owned by
+  // ExternalLiteRtBufferContext.
+  litert::Expected<TensorBufferRequirements*> GetBufferRequirement(
+      const TfLiteOpaqueTensor* tensor);
+
+  inline litert::Expected<TensorBufferRequirements*> GetBufferRequirement(
+      const TfLiteTensor* tensor) {
+    return GetBufferRequirement(
+        reinterpret_cast<const TfLiteOpaqueTensor*>(tensor));
+  }
+
+  // Registers a tensor buffer for the given tensor.
+  // The registered TensorBuffer object is owned by
+  // ExternalLiteRtBufferContext.
+  LiteRtStatus RegisterTensorBuffer(const TfLiteOpaqueTensor* tensor,
+                                    TensorBuffer&& tensor_buffer);
+
+  inline LiteRtStatus RegisterTensorBuffer(const TfLiteTensor* tensor,
+                                           TensorBuffer&& tensor_buffer) {
+    return RegisterTensorBuffer(
+        reinterpret_cast<const TfLiteOpaqueTensor*>(tensor),
+        std::move(tensor_buffer));
+  }
+
+  // Gets a registered tensor buffer for the given tensor.
+  // The returned TensorBuffer object is a duplicate (reference counted) of
+  // the registered TensorBuffer.
+  litert::Expected<TensorBuffer> GetTensorBuffer(
+      const TfLiteOpaqueTensor* tensor);
+
+  inline litert::Expected<TensorBuffer> GetTensorBuffer(
+      const TfLiteTensor* tensor) {
+    return GetTensorBuffer(reinterpret_cast<const TfLiteOpaqueTensor*>(tensor));
+  }
+
+  // Creates a tensor buffer for the given tensor.
+  // The caller takes ownership of the returned TensorBuffer object.
+  litert::Expected<TensorBuffer> CreateBufferForTensor(
+      const TfLiteOpaqueTensor* tensor);
+
+  inline litert::Expected<TensorBuffer> CreateBufferForTensor(
+      const TfLiteTensor* tensor) {
+    return CreateBufferForTensor(
+        reinterpret_cast<const TfLiteOpaqueTensor*>(tensor));
+  }
+
+ private:
+  absl::flat_hash_map<const TfLiteOpaqueTensor*, TensorBufferRequirements>
+      buffer_requirements_;
+  absl::flat_hash_map<const TfLiteOpaqueTensor*, TensorBuffer> tensor_buffers_;
+
+  ExternalLiteRtBufferContext(const ExternalLiteRtBufferContext&) = delete;
+  ExternalLiteRtBufferContext& operator=(const ExternalLiteRtBufferContext&) =
+      delete;
+};
+
+}  // namespace internal
+}  // namespace litert
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_EXTERNAL_LITERT_BUFFER_CONTEXT_H_
diff --git a/tflite/experimental/litert/runtime/fastrpc_buffer.cc b/tflite/experimental/litert/runtime/fastrpc_buffer.cc
new file mode 100644
index 00000000..df52f5d6
--- /dev/null
+++ b/tflite/experimental/litert/runtime/fastrpc_buffer.cc
@@ -0,0 +1,143 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
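The QualcommHwBuffer test earlier in this patch drives
ExternalLiteRtBufferContext in the order the class comments describe; a
condensed sketch of that flow (the tensor and its requirements are produced by
the delegate kernel in practice):

    void RegisterAndAllocate(litert::internal::ExternalLiteRtBufferContext& ctx,
                             const TfLiteOpaqueTensor* tensor,
                             litert::TensorBufferRequirements&& reqs) {
      // 1. The delegate kernel registers the tensor's buffer requirements.
      ctx.RegisterBufferRequirement(tensor, std::move(reqs));
      // 2. Allocate a buffer matching those requirements; the caller owns it.
      auto buffer = ctx.CreateBufferForTensor(tensor);
      if (!buffer) return;
      // 3. Hand the runtime a reference-counted duplicate so both sides share
      //    the same backing store.
      if (auto dup = buffer->Duplicate()) {
        ctx.RegisterTensorBuffer(tensor, std::move(*dup));
      }
    }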
+ +#include "tflite/experimental/litert/runtime/fastrpc_buffer.h" + +#include + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/const_init.h" +#include "absl/synchronization/mutex.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/cc/litert_expected.h" + +namespace litert { +namespace internal { + +namespace { + +class FastRpcMemLibrary { + public: + using Ptr = std::unique_ptr; + + static Expected Create() { + DlHandle dlhandle(::dlopen("libcdsprpc.so", RTLD_NOW | RTLD_LOCAL), + ::dlclose); + if (!dlhandle) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "libcdsprpc.so not found"); + } + + auto rpcmem_alloc = + reinterpret_cast(::dlsym(dlhandle.get(), "rpcmem_alloc")); + if (!rpcmem_alloc) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "rpcmem_alloc not found"); + } + + auto rpcmem_free = + reinterpret_cast(::dlsym(dlhandle.get(), "rpcmem_free")); + if (!rpcmem_free) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "rpcmem_free not found"); + } + + auto rpcmem_to_fd = + reinterpret_cast(::dlsym(dlhandle.get(), "rpcmem_to_fd")); + if (!rpcmem_to_fd) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "rpcmem_to_fd not found"); + } + + return Ptr(new FastRpcMemLibrary(std::move(dlhandle), rpcmem_alloc, + rpcmem_free, rpcmem_to_fd)); + } + + void* Alloc(size_t size) const { + return rpcmem_alloc_(kRpcmemHeapIdSystem, kRpcmemDefaultFlags, size); + } + + void Free(void* buffer) const { return rpcmem_free_(buffer); } + + int ToFd(void* buffer) const { return rpcmem_to_fd_(buffer); } + + private: + static constexpr int kRpcmemHeapIdSystem = 25; + static constexpr uint32_t kRpcmemDefaultFlags = 1; + + using DlHandle = std::unique_ptr; + using RpcMemAlloc = void* (*)(int, uint32_t, int); + using RpcMemFree = void (*)(void*); + using RpcMemToFd = int (*)(void*); + + FastRpcMemLibrary(DlHandle&& dlhandle, RpcMemAlloc rpcmem_alloc, + RpcMemFree rpcmem_free, RpcMemToFd rpcmem_to_fd) + : dlhandle_(std::move(dlhandle)) { + rpcmem_alloc_ = rpcmem_alloc; + rpcmem_free_ = rpcmem_free; + rpcmem_to_fd_ = rpcmem_to_fd; + } + + DlHandle dlhandle_; + RpcMemAlloc rpcmem_alloc_; + RpcMemFree rpcmem_free_; + RpcMemToFd rpcmem_to_fd_; +}; + +FastRpcMemLibrary* TheFastRpcMemLibrary; +ABSL_CONST_INIT absl::Mutex TheMutex(absl::kConstInit); + +Expected InitLibraryIfNeededUnlocked() { + if (!TheFastRpcMemLibrary) { + if (auto library = FastRpcMemLibrary::Create(); library) { + TheFastRpcMemLibrary = library->release(); + } else { + return Unexpected(library.Error()); + } + } + return {}; +} + +} // namespace + +bool FastRpcBuffer::IsSupported() { + absl::MutexLock lock(&TheMutex); + auto status = InitLibraryIfNeededUnlocked(); + return static_cast(status); +} + +Expected FastRpcBuffer::Alloc(size_t size) { + absl::MutexLock lock(&TheMutex); + if (auto status = InitLibraryIfNeededUnlocked(); !status) { + return status.Error(); + } + void* addr = TheFastRpcMemLibrary->Alloc(size); + int fd = TheFastRpcMemLibrary->ToFd(addr); + return FastRpcBuffer{.fd = fd, .addr = addr}; +} + +void FastRpcBuffer::Free(void* addr) { + absl::MutexLock lock(&TheMutex); + if (TheFastRpcMemLibrary) { + TheFastRpcMemLibrary->Free(addr); + } +} + +} // namespace internal +} // namespace litert diff --git a/tflite/experimental/litert/runtime/fastrpc_buffer.h b/tflite/experimental/litert/runtime/fastrpc_buffer.h new file mode 100644 index 00000000..7de4e780 --- /dev/null +++ 
b/tflite/experimental/litert/runtime/fastrpc_buffer.h @@ -0,0 +1,35 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_FASTRPC_BUFFER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_FASTRPC_BUFFER_H_ + +#include "tflite/experimental/litert/cc/litert_expected.h" + +namespace litert { +namespace internal { + +struct FastRpcBuffer { + int fd; + void* addr; + + static bool IsSupported(); + static Expected Alloc(size_t size); + static void Free(void* addr); +}; + +} // namespace internal +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_FASTRPC_BUFFER_H_ diff --git a/tflite/experimental/litert/runtime/ion_buffer.cc b/tflite/experimental/litert/runtime/ion_buffer.cc new file mode 100644 index 00000000..16fae0f7 --- /dev/null +++ b/tflite/experimental/litert/runtime/ion_buffer.cc @@ -0,0 +1,181 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "tflite/experimental/litert/runtime/ion_buffer.h"
+
+#include <dlfcn.h>
+#include <sys/mman.h>
+#include <unistd.h>
+
+#include <cstddef>
+#include <memory>
+#include <utility>
+
+#include "absl/base/attributes.h"
+#include "absl/base/const_init.h"
+#include "absl/container/node_hash_map.h"
+#include "absl/synchronization/mutex.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+
+namespace litert {
+namespace internal {
+
+namespace {
+
+class IonLibrary {
+ public:
+  using Ptr = std::unique_ptr<IonLibrary>;
+
+  ~IonLibrary() {
+    if (client_fd_ > 0) {
+      ion_close_(client_fd_);
+    }
+  }
+
+  static Expected<Ptr> Create() {
+    DlHandle dlhandle(::dlopen("libion.so", RTLD_NOW | RTLD_LOCAL),
+                      ::dlclose);
+    if (!dlhandle) {
+      return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                        "libion.so not found");
+    }
+
+    auto ion_open =
+        reinterpret_cast<IonOpen>(::dlsym(dlhandle.get(), "ion_open"));
+    if (!ion_open) {
+      return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                        "ion_open not found");
+    }
+
+    auto ion_close =
+        reinterpret_cast<IonClose>(::dlsym(dlhandle.get(), "ion_close"));
+    if (!ion_close) {
+      return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                        "ion_close not found");
+    }
+
+    auto ion_alloc_fd =
+        reinterpret_cast<IonAllocFd>(::dlsym(dlhandle.get(), "ion_alloc_fd"));
+    if (!ion_alloc_fd) {
+      return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                        "ion_alloc_fd not found");
+    }
+
+    int client_fd = ion_open();
+    if (client_fd < 0) {
+      return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                        "Failed to open ion device");
+    }
+
+    return Ptr(new IonLibrary(std::move(dlhandle), client_fd, ion_close,
+                              ion_alloc_fd));
+  }
+
+  Expected<IonBuffer> Alloc(size_t size, size_t alignment) {
+    int heap_id_mask = 1 << kIonHeapId;
+    int fd;
+    if (auto status = ion_alloc_fd_(client_fd_, size, alignment, heap_id_mask,
+                                    kIonFlags, &fd);
+        status != 0) {
+      return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                        "Failed to allocate ION buffer");
+    }
+    void* addr =
+        ::mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
+    if (addr == MAP_FAILED) {
+      return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                        "Failed to mem-map ION buffer");
+    }
+    records_[addr] = Record{.fd = fd, .addr = addr, .size = size};
+    return IonBuffer{.fd = fd, .addr = addr};
+  }
+
+  void Free(void* addr) {
+    auto iter = records_.find(addr);
+    if (iter == records_.end()) {
+      return;
+    }
+    auto& record = iter->second;
+    ::munmap(record.addr, record.size);
+    ::close(record.fd);
+    records_.erase(iter);
+  }
+
+ private:
+  static constexpr const int kIonHeapId = 25;
+  static constexpr const int kIonFlags = 1;
+
+  struct Record {
+    int fd;
+    void* addr;
+    size_t size;
+  };
+
+  using DlHandle = std::unique_ptr<void, int (*)(void*)>;
+  using IonOpen = int (*)();
+  using IonClose = int (*)(int);
+  using IonAllocFd = int (*)(int, size_t, size_t, unsigned int, unsigned int,
+                             int*);
+
+  IonLibrary(DlHandle&& dlhandle, int client_fd, IonClose ion_close,
+             IonAllocFd ion_alloc_fd)
+      : dlhandle_(std::move(dlhandle)),
+        client_fd_(client_fd),
+        ion_close_(ion_close),
+        ion_alloc_fd_(ion_alloc_fd) {}
+
+  DlHandle dlhandle_;
+  int client_fd_;
+  IonClose ion_close_;
+  IonAllocFd ion_alloc_fd_;
+  absl::node_hash_map<void*, Record> records_;
+};
+
+IonLibrary* TheIonLibrary;
+ABSL_CONST_INIT absl::Mutex TheMutex(absl::kConstInit);
+
+Expected<void> InitLibraryIfNeededUnlocked() {
+  if (!TheIonLibrary) {
+    if (auto library = IonLibrary::Create(); library) {
+      TheIonLibrary = library->release();
+    } else {
+      return Unexpected(library.Error());
+    }
+  }
+  return {};
+}
+
+}  // namespace
+
+bool IonBuffer::IsSupported() {
+  absl::MutexLock lock(&TheMutex);
+  auto status = InitLibraryIfNeededUnlocked();
+  return static_cast<bool>(status);
+}
+
+Expected<IonBuffer> IonBuffer::Alloc(size_t size, size_t alignment) {
+  absl::MutexLock lock(&TheMutex);
+  if (auto status = InitLibraryIfNeededUnlocked(); !status) {
+    return status.Error();
+  }
+  return TheIonLibrary->Alloc(size, alignment);
+}
+
+void IonBuffer::Free(void* addr) {
+  absl::MutexLock lock(&TheMutex);
+  if (TheIonLibrary) {
+    TheIonLibrary->Free(addr);
+  }
+}
+
+}  // namespace internal
+}  // namespace litert
diff --git a/tflite/experimental/litert/runtime/ion_buffer.h b/tflite/experimental/litert/runtime/ion_buffer.h
new file mode 100644
index 00000000..e981a992
--- /dev/null
+++ b/tflite/experimental/litert/runtime/ion_buffer.h
@@ -0,0 +1,35 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ION_BUFFER_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ION_BUFFER_H_
+
+#include <cstddef>
+
+#include "tflite/experimental/litert/cc/litert_expected.h"
+
+namespace litert {
+namespace internal {
+
+struct IonBuffer {
+  int fd;
+  void* addr;
+
+  static bool IsSupported();
+  static Expected<IonBuffer> Alloc(size_t size, size_t alignment);
+  static void Free(void* addr);
+};
+
+}  // namespace internal
+}  // namespace litert
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ION_BUFFER_H_
diff --git a/tflite/experimental/litert/runtime/tensor_buffer.cc b/tflite/experimental/litert/runtime/tensor_buffer.cc
new file mode 100644
index 00000000..b94101e7
--- /dev/null
+++ b/tflite/experimental/litert/runtime/tensor_buffer.cc
@@ -0,0 +1,437 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
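+//
+// LiteRtTensorBufferT wraps the backing storage of a tensor behind a single
+// reference-counted type, hiding the differences between host memory,
+// AHardwareBuffer, ION, DMA-BUF and FastRPC buffers.
+//
+// Illustrative usage sketch (error handling elided; `tensor_type` stands in
+// for a caller-provided LiteRtRankedTensorType):
+//
+//   auto buffer = LiteRtTensorBufferT::CreateManaged(
+//       kLiteRtTensorBufferTypeHostMemory, tensor_type, /*buffer_size=*/1024);
+//   if (buffer) {
+//     auto host_addr = (*buffer)->Lock();  // Expected<void*>
+//     // ... read/write through *host_addr ...
+//     (*buffer)->Unlock();
+//   }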
+
+#include "tflite/experimental/litert/runtime/tensor_buffer.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <cstdlib>
+#include <iterator>
+#include <utility>
+
+#include "absl/types/span.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_event.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/core/util/tensor_type_util.h"
+#include "tflite/experimental/litert/runtime/ahwb_buffer.h"
+#include "tflite/experimental/litert/runtime/dmabuf_buffer.h"
+#include "tflite/experimental/litert/runtime/event.h"
+#include "tflite/experimental/litert/runtime/fastrpc_buffer.h"
+#include "tflite/experimental/litert/runtime/ion_buffer.h"
+
+using litert::Expected;
+using litert::Unexpected;
+
+namespace {
+
+template <typename T>
+void Copy(size_t array_size, const T* array, std::vector<T>& vec) {
+  vec.clear();
+  vec.reserve(array_size);
+  std::copy(array, array + array_size, std::back_inserter(vec));
+}
+
+}  // namespace
+
+LiteRtTensorBufferT::LiteRtTensorBufferT(
+    const LiteRtRankedTensorType& tensor_type,
+    LiteRtTensorBufferType buffer_type, size_t buffer_size,
+    size_t buffer_offset)
+    : tensor_type_(tensor_type),
+      buffer_type_(buffer_type),
+      buffer_size_(buffer_size),
+      buffer_offset_(buffer_offset),
+      ref_(1) {
+  // Copy the caller-provided dimension/stride arrays into owned storage and
+  // point the layout at the copies, so the tensor type remains valid even if
+  // the caller's arrays go away.
+  Copy(tensor_type_.layout.rank, tensor_type_.layout.dimensions, dimensions_);
+  tensor_type_.layout.dimensions = dimensions_.data();
+  if (tensor_type_.layout.strides) {
+    Copy(tensor_type_.layout.rank, tensor_type_.layout.strides, strides_);
+    tensor_type_.layout.strides = strides_.data();
+  }
+}
+
+LiteRtTensorBufferT::~LiteRtTensorBufferT() {
+  switch (buffer_type()) {
+    case kLiteRtTensorBufferTypeUnknown:
+      // Nothing to do.
+      break;
+    case kLiteRtTensorBufferTypeHostMemory:
+      if (auto& buffer = std::get<HostBuffer>(buffer_); buffer.deallocator) {
+        buffer.deallocator(buffer.addr);
+      }
+      break;
+    case kLiteRtTensorBufferTypeAhwb:
+      if (auto& buffer = std::get<AhwbBuffer>(buffer_); buffer.deallocator) {
+        buffer.deallocator(buffer.ahwb);
+      }
+      break;
+    case kLiteRtTensorBufferTypeIon:
+      if (auto& buffer = std::get<IonBuffer>(buffer_); buffer.deallocator) {
+        buffer.deallocator(buffer.addr);
+      }
+      break;
+    case kLiteRtTensorBufferTypeDmaBuf:
+      if (auto& buffer = std::get<DmaBufBuffer>(buffer_); buffer.deallocator) {
+        buffer.deallocator(buffer.addr);
+      }
+      break;
+    case kLiteRtTensorBufferTypeFastRpc:
+      if (auto& buffer = std::get<FastRpcBuffer>(buffer_);
+          buffer.deallocator) {
+        buffer.deallocator(buffer.addr);
+      }
+      break;
+  }
+}
+
+Expected<LiteRtTensorBufferT::Ptr> LiteRtTensorBufferT::CreateFromHostMemory(
+    const LiteRtRankedTensorType& tensor_type, absl::Span<uint8_t> host_memory,
+    LiteRtHostMemoryDeallocator deallocator) {
+  Ptr tensor_buffer(new LiteRtTensorBufferT(
+      tensor_type, kLiteRtTensorBufferTypeHostMemory, host_memory.size()));
+  tensor_buffer->buffer_ = HostBuffer{
+      .addr = host_memory.data(),
+      .deallocator = deallocator,
+  };
+
+  if (auto status = tensor_buffer->IsValid(); !status) {
+    return Unexpected(status.Error());
+  }
+
+  return tensor_buffer;
+}
+
+Expected<LiteRtTensorBufferT::Ptr>
+LiteRtTensorBufferT::CreateManagedOnHostMemory(
+    const LiteRtRankedTensorType& tensor_type, size_t buffer_size) {
+  void* host_memory_ptr;
+  if (auto rc = ::posix_memalign(
+          &host_memory_ptr, LITERT_HOST_MEMORY_BUFFER_ALIGNMENT, buffer_size);
+      rc) {
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Failed to allocate aligned memory");
+  }
+
+  LiteRtHostMemoryDeallocator deallocator = ::free;
+  auto tensor_buffer = CreateFromHostMemory(
+      tensor_type,
+      absl::MakeSpan(static_cast<uint8_t*>(host_memory_ptr), buffer_size),
+      deallocator);
+  if (!tensor_buffer) {
+    free(host_memory_ptr);
+    return Unexpected(tensor_buffer.Error());
+  }
+
+  return std::move(*tensor_buffer);
+}
+
+Expected<LiteRtTensorBufferT::Ptr> LiteRtTensorBufferT::CreateFromAhwb(
+    const LiteRtRankedTensorType& tensor_type, AHardwareBuffer* ahwb,
+    size_t ahwb_offset, LiteRtAhwbDeallocator deallocator) {
+  auto buffer_size = litert::internal::AhwbBuffer::GetSize(ahwb);
+  if (!buffer_size) {
+    return Unexpected(buffer_size.Error());
+  }
+
+  Ptr tensor_buffer(new LiteRtTensorBufferT(
+      tensor_type, kLiteRtTensorBufferTypeAhwb, *buffer_size, ahwb_offset));
+  tensor_buffer->buffer_ = AhwbBuffer{
+      .ahwb = ahwb,
+      .deallocator = deallocator,
+  };
+
+  if (auto status = tensor_buffer->IsValid(); !status) {
+    return Unexpected(status.Error());
+  }
+
+  return tensor_buffer;
+}
+
+Expected<LiteRtTensorBufferT::Ptr> LiteRtTensorBufferT::CreateManagedAhwbBuffer(
+    const LiteRtRankedTensorType& tensor_type, size_t buffer_size) {
+  auto buffer = litert::internal::AhwbBuffer::Alloc(buffer_size);
+  if (!buffer) {
+    return Unexpected(buffer.Error());
+  }
+  return CreateFromAhwb(tensor_type, buffer->ahwb, /*ahwb_offset=*/0,
+                        /*deallocator=*/litert::internal::AhwbBuffer::Free);
+}
+
+Expected<LiteRtTensorBufferT::Ptr> LiteRtTensorBufferT::CreateFromIonBuffer(
+    const LiteRtRankedTensorType& tensor_type, void* ion_buffer_addr,
+    int ion_buffer_fd, size_t ion_buffer_size, size_t ion_buffer_offset,
+    LiteRtIonDeallocator deallocator) {
+  if (!ion_buffer_addr) {
+    return Unexpected(kLiteRtStatusErrorInvalidArgument,
+                      "Invalid ION buffer address");
+  }
+  if (ion_buffer_fd < 0) {
+    return Unexpected(kLiteRtStatusErrorInvalidArgument,
+                      "Invalid ION buffer fd");
+  }
+
+  Ptr tensor_buffer(
+      new LiteRtTensorBufferT(tensor_type, kLiteRtTensorBufferTypeIon,
+                              ion_buffer_size, ion_buffer_offset));
+  tensor_buffer->buffer_ = IonBuffer{
+      .addr = ion_buffer_addr,
+      .fd = ion_buffer_fd,
+      .deallocator = deallocator,
+  };
+
+  if (auto status = tensor_buffer->IsValid(); !status) {
+    return Unexpected(status.Error());
+  }
+
+  return tensor_buffer;
+}
+
+Expected<LiteRtTensorBufferT::Ptr> LiteRtTensorBufferT::CreateManagedIonBuffer(
+    const LiteRtRankedTensorType& tensor_type, size_t buffer_size) {
+  auto buffer = litert::internal::IonBuffer::Alloc(
+      buffer_size, /*alignment=*/LITERT_HOST_MEMORY_BUFFER_ALIGNMENT);
+  if (!buffer) {
+    return Unexpected(buffer.Error());
+  }
+  return CreateFromIonBuffer(tensor_type, buffer->addr, buffer->fd,
+                             buffer_size, /*ion_buffer_offset=*/0,
+                             litert::internal::IonBuffer::Free);
+}
+
+Expected<LiteRtTensorBufferT::Ptr> LiteRtTensorBufferT::CreateFromDmaBufBuffer(
+    const LiteRtRankedTensorType& tensor_type, void* dmabuf_buffer_addr,
+    int dmabuf_buffer_fd, size_t dmabuf_buffer_size,
+    size_t dmabuf_buffer_offset, LiteRtDmaBufDeallocator deallocator) {
+  if (!dmabuf_buffer_addr) {
+    return Unexpected(kLiteRtStatusErrorInvalidArgument,
+                      "Invalid DMA-BUF buffer address");
+  }
+  if (dmabuf_buffer_fd < 0) {
+    return Unexpected(kLiteRtStatusErrorInvalidArgument,
+                      "Invalid DMA-BUF buffer fd");
+  }
+
+  Ptr tensor_buffer(
+      new LiteRtTensorBufferT(tensor_type, kLiteRtTensorBufferTypeDmaBuf,
+                              dmabuf_buffer_size, dmabuf_buffer_offset));
+  tensor_buffer->buffer_ = DmaBufBuffer{
+      .addr = dmabuf_buffer_addr,
+      .fd = dmabuf_buffer_fd,
+      .deallocator = deallocator,
+  };
+
+  if (auto status = tensor_buffer->IsValid(); !status) {
+    return Unexpected(status.Error());
+  }
+
+  return tensor_buffer;
+}
+
+Expected<LiteRtTensorBufferT::Ptr>
+LiteRtTensorBufferT::CreateManagedDmaBufBuffer(
+    const LiteRtRankedTensorType& tensor_type, size_t buffer_size) {
+  auto buffer = litert::internal::DmaBufBuffer::Alloc(buffer_size);
+  if (!buffer) {
+    return Unexpected(buffer.Error());
+  }
+  return CreateFromDmaBufBuffer(tensor_type, buffer->addr, buffer->fd,
+                                buffer_size, /*dmabuf_buffer_offset=*/0,
+                                litert::internal::DmaBufBuffer::Free);
+}
+
+Expected<LiteRtTensorBufferT::Ptr> LiteRtTensorBufferT::CreateFromFastRpcBuffer(
+    const LiteRtRankedTensorType& tensor_type, void* fastrpc_buffer_addr,
+    int fastrpc_buffer_fd, size_t fastrpc_buffer_size,
+    size_t fastrpc_buffer_offset, LiteRtFastRpcDeallocator deallocator) {
+  if (!fastrpc_buffer_addr) {
+    return Unexpected(kLiteRtStatusErrorInvalidArgument,
+                      "Invalid FastRPC buffer address");
+  }
+  if (fastrpc_buffer_fd < 0) {
+    return Unexpected(kLiteRtStatusErrorInvalidArgument,
+                      "Invalid FastRPC buffer fd");
+  }
+
+  Ptr tensor_buffer(
+      new LiteRtTensorBufferT(tensor_type, kLiteRtTensorBufferTypeFastRpc,
+                              fastrpc_buffer_size, fastrpc_buffer_offset));
+  tensor_buffer->buffer_ = FastRpcBuffer{
+      .addr = fastrpc_buffer_addr,
+      .fd = fastrpc_buffer_fd,
+      .deallocator = deallocator,
+  };
+
+  if (auto status = tensor_buffer->IsValid(); !status) {
+    return Unexpected(status.Error());
+  }
+
+  return tensor_buffer;
+}
+
+Expected<LiteRtTensorBufferT::Ptr>
+LiteRtTensorBufferT::CreateManagedFastRpcBuffer(
+    const LiteRtRankedTensorType& tensor_type, size_t buffer_size) {
+  auto buffer = litert::internal::FastRpcBuffer::Alloc(buffer_size);
+  if (!buffer) {
+    return Unexpected(buffer.Error());
+  }
+  return CreateFromFastRpcBuffer(tensor_type, buffer->addr, buffer->fd,
+                                 buffer_size, /*fastrpc_buffer_offset=*/0,
+                                 litert::internal::FastRpcBuffer::Free);
+}
+
+Expected<LiteRtTensorBufferT::Ptr> LiteRtTensorBufferT::CreateManaged(
+    LiteRtTensorBufferType buffer_type,
+    const LiteRtRankedTensorType& tensor_type, size_t buffer_size) {
+  switch (buffer_type) {
+    case kLiteRtTensorBufferTypeHostMemory:
+      return CreateManagedOnHostMemory(tensor_type, buffer_size);
+    case kLiteRtTensorBufferTypeAhwb:
+      return CreateManagedAhwbBuffer(tensor_type, buffer_size);
+    case kLiteRtTensorBufferTypeIon:
+      return CreateManagedIonBuffer(tensor_type, buffer_size);
+    case kLiteRtTensorBufferTypeDmaBuf:
+      return CreateManagedDmaBufBuffer(tensor_type, buffer_size);
+    case kLiteRtTensorBufferTypeFastRpc:
+      return CreateManagedFastRpcBuffer(tensor_type, buffer_size);
+    default:
+      return Unexpected(kLiteRtStatusErrorInvalidArgument,
+                        "Unexpected tensor buffer type");
+  }
+}
+
+Expected<void> LiteRtTensorBufferT::IsValid() {
+  // Check for static dimensions.
+  for (size_t i = 0; i < tensor_type_.layout.rank; ++i) {
+    if (tensor_type_.layout.dimensions[i] <= 0) {
+      return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                        "TensorBuffer must have all static dimensions");
+    }
+  }
+
+  // Check for valid offset.
+  if (buffer_offset() >= buffer_size()) {
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Invalid buffer offset");
+  }
+
+  // Check for sufficient size.
+  if (auto num_bytes = litert::internal::GetNumPackedBytes(tensor_type_);
+      !num_bytes) {
+    return Unexpected(num_bytes.Error());
+  } else if (*num_bytes > buffer_size() - buffer_offset()) {
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Insufficient buffer size");
+  }
+
+  // Check for proper alignment.
+  if (buffer_type() == kLiteRtTensorBufferTypeHostMemory) {
+    auto host_buffer = GetHostBuffer();
+    if (!host_buffer) {
+      return Unexpected(host_buffer.Error());
+    }
+    if (reinterpret_cast<uintptr_t>(*host_buffer) %
+        LITERT_HOST_MEMORY_BUFFER_ALIGNMENT) {
+      return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                        "Unaligned host memory pointer");
+    }
+  }
+
+  return {};
+}
+
+Expected<void*> LiteRtTensorBufferT::GetHostBuffer() {
+  if (buffer_type_ != kLiteRtTensorBufferTypeHostMemory) {
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Unexpected tensor buffer type");
+  }
+  return std::get<HostBuffer>(buffer_).addr;
+}
+
+Expected<AHardwareBuffer*> LiteRtTensorBufferT::GetAhwbBuffer() {
+  if (buffer_type_ != kLiteRtTensorBufferTypeAhwb) {
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Unexpected tensor buffer type");
+  }
+  return std::get<AhwbBuffer>(buffer_).ahwb;
+}
+
+Expected<std::pair<void*, int>> LiteRtTensorBufferT::GetIonBuffer() {
+  if (buffer_type_ != kLiteRtTensorBufferTypeIon) {
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Unexpected tensor buffer type");
+  }
+  auto buffer = std::get<IonBuffer>(buffer_);
+  return std::make_pair(buffer.addr, buffer.fd);
+}
+
+Expected<std::pair<void*, int>> LiteRtTensorBufferT::GetDmaBufBuffer() {
+  if (buffer_type_ != kLiteRtTensorBufferTypeDmaBuf) {
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Unexpected tensor buffer type");
+  }
+  auto buffer = std::get<DmaBufBuffer>(buffer_);
+  return std::make_pair(buffer.addr, buffer.fd);
+}
+
+Expected<std::pair<void*, int>> LiteRtTensorBufferT::GetFastRpcBuffer() {
+  if (buffer_type_ != kLiteRtTensorBufferTypeFastRpc) {
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Unexpected tensor buffer type");
+  }
+  auto buffer = std::get<FastRpcBuffer>(buffer_);
+  return std::make_pair(buffer.addr, buffer.fd);
+}
+
+Expected<void*> LiteRtTensorBufferT::Lock(LiteRtEvent event) {
+  if (event) {
+    // Only AHWB supports waiting on an input sync fence when locking the
+    // buffer. For all other buffer types we wait here.
+    if (buffer_type() != kLiteRtTensorBufferTypeAhwb) {
+      if (auto status = event->Wait(/*timeout_in_ms=*/-1);
+          status != kLiteRtStatusOk) {
+        return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                          "Failed to wait on input event");
+      }
+    }
+  }
+
+  switch (buffer_type()) {
+    case kLiteRtTensorBufferTypeHostMemory:
+      return *GetHostBuffer();
+    case kLiteRtTensorBufferTypeAhwb:
+      return litert::internal::AhwbBuffer::Lock(*GetAhwbBuffer(), event);
+    case kLiteRtTensorBufferTypeIon:
+      return GetIonBuffer()->first;
+    case kLiteRtTensorBufferTypeDmaBuf:
+      return GetDmaBufBuffer()->first;
+    case kLiteRtTensorBufferTypeFastRpc:
+      return GetFastRpcBuffer()->first;
+    default:
+      return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                        "Unexpected tensor buffer type");
+  }
+}
+
+Expected<void> LiteRtTensorBufferT::Unlock() {
+  if (buffer_type() == kLiteRtTensorBufferTypeAhwb) {
+    auto ahwb = std::get<AhwbBuffer>(buffer_).ahwb;
+    return litert::internal::AhwbBuffer::Unlock(ahwb);
+  }
+
+  return {};
+}
diff --git a/tflite/experimental/litert/runtime/tensor_buffer.h b/tflite/experimental/litert/runtime/tensor_buffer.h
new file mode 100644
index 00000000..d2b69752
--- /dev/null
+++ b/tflite/experimental/litert/runtime/tensor_buffer.h
@@ -0,0 +1,166 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TENSOR_BUFFER_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TENSOR_BUFFER_H_
+
+#include <atomic>
+#include <cstddef>
+#include <memory>
+#include <type_traits>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/types/span.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+
+class LiteRtTensorBufferT {
+ public:
+  using Ptr = std::unique_ptr<LiteRtTensorBufferT>;
+
+  ~LiteRtTensorBufferT();
+
+  // Make this class non-copyable and non-movable, because it holds raw
+  // pointers and resource handles.
+  LiteRtTensorBufferT(const LiteRtTensorBufferT&) = delete;
+  LiteRtTensorBufferT(LiteRtTensorBufferT&&) = delete;
+  LiteRtTensorBufferT& operator=(const LiteRtTensorBufferT&) = delete;
+  LiteRtTensorBufferT& operator=(LiteRtTensorBufferT&&) = delete;
+
+  static litert::Expected<Ptr> CreateFromHostMemory(
+      const LiteRtRankedTensorType& tensor_type,
+      absl::Span<uint8_t> host_memory,
+      LiteRtHostMemoryDeallocator deallocator = nullptr);
+
+  static litert::Expected<Ptr> CreateFromAhwb(
+      const LiteRtRankedTensorType& tensor_type, AHardwareBuffer* ahwb,
+      size_t ahwb_offset, LiteRtAhwbDeallocator deallocator = nullptr);
+
+  static litert::Expected<Ptr> CreateFromIonBuffer(
+      const LiteRtRankedTensorType& tensor_type, void* ion_buffer_addr,
+      int ion_buffer_fd, size_t ion_buffer_size, size_t ion_buffer_offset,
+      LiteRtIonDeallocator deallocator = nullptr);
+
+  static litert::Expected<Ptr> CreateFromDmaBufBuffer(
+      const LiteRtRankedTensorType& tensor_type, void* dmabuf_buffer_addr,
+      int dmabuf_buffer_fd, size_t dmabuf_buffer_size,
+      size_t dmabuf_buffer_offset,
+      LiteRtDmaBufDeallocator deallocator = nullptr);
+
+  static litert::Expected<Ptr> CreateFromFastRpcBuffer(
+      const LiteRtRankedTensorType& tensor_type, void* fastrpc_buffer_addr,
+      int fastrpc_buffer_fd, size_t fastrpc_buffer_size,
+      size_t fastrpc_buffer_offset,
+      LiteRtFastRpcDeallocator deallocator = nullptr);
+
+  static litert::Expected<Ptr> CreateManaged(
+      LiteRtTensorBufferType buffer_type,
+      const LiteRtRankedTensorType& tensor_type, size_t buffer_size);
+
+  LiteRtRankedTensorType tensor_type() const { return tensor_type_; }
+  LiteRtTensorBufferType buffer_type() const { return buffer_type_; }
+  size_t buffer_size() const { return buffer_size_; }
+  size_t buffer_offset() const { return buffer_offset_; }
+
+  litert::Expected<void*> GetHostBuffer();
+  litert::Expected<AHardwareBuffer*> GetAhwbBuffer();
+  litert::Expected<std::pair<void*, int>> GetIonBuffer();
+  litert::Expected<std::pair<void*, int>> GetDmaBufBuffer();
+  litert::Expected<std::pair<void*, int>> GetFastRpcBuffer();
+
+  litert::Expected<void*> Lock(LiteRtEvent event = nullptr);
+  litert::Expected<void> Unlock();
+
+  // Used to duplicate the current tensor buffer. Internally it increases the
+  // reference count to the underlying buffer.
+  void Duplicate() const { Ref(); }
+
+  // Increments the reference count by one.
+  void Ref() const { ref_.fetch_add(1, std::memory_order_relaxed); }
+
+  // Decrements the reference count by one and returns true iff the count
+  // reached zero; otherwise returns false.
+  bool Unref() const {
+    return ref_.fetch_sub(1, std::memory_order_acq_rel) == 1;
+  }
+
+  // Gets the current reference count.
+  int RefCount() const { return ref_.load(std::memory_order_relaxed); }
+
+ private:
+  struct HostBuffer {
+    void* addr;
+    LiteRtHostMemoryDeallocator deallocator;
+  };
+
+  struct AhwbBuffer {
+    AHardwareBuffer* ahwb;
+    LiteRtAhwbDeallocator deallocator;
+  };
+
+  struct IonBuffer {
+    void* addr;
+    int fd;
+    LiteRtIonDeallocator deallocator;
+  };
+
+  struct DmaBufBuffer {
+    void* addr;
+    int fd;
+    LiteRtDmaBufDeallocator deallocator;
+  };
+
+  struct FastRpcBuffer {
+    void* addr;
+    int fd;
+    LiteRtFastRpcDeallocator deallocator;
+  };
+
+  LiteRtTensorBufferT(const LiteRtRankedTensorType& tensor_type,
+                      LiteRtTensorBufferType buffer_type, size_t buffer_size,
+                      size_t buffer_offset = 0);
+
+  static litert::Expected<Ptr> CreateManagedOnHostMemory(
+      const LiteRtRankedTensorType& tensor_type, size_t buffer_size);
+
+  static litert::Expected<Ptr> CreateManagedAhwbBuffer(
+      const LiteRtRankedTensorType& tensor_type, size_t buffer_size);
+
+  static litert::Expected<Ptr> CreateManagedIonBuffer(
+      const LiteRtRankedTensorType& tensor_type, size_t buffer_size);
+
+  static litert::Expected<Ptr> CreateManagedDmaBufBuffer(
+      const LiteRtRankedTensorType& tensor_type, size_t buffer_size);
+
+  static litert::Expected<Ptr> CreateManagedFastRpcBuffer(
+      const LiteRtRankedTensorType& tensor_type, size_t buffer_size);
+
+  litert::Expected<void> IsValid();
+
+  LiteRtRankedTensorType tensor_type_;
+  std::vector<std::decay_t<decltype(LiteRtLayout::dimensions[0])>> dimensions_;
+  std::vector<std::decay_t<decltype(LiteRtLayout::strides[0])>> strides_;
+  LiteRtTensorBufferType buffer_type_;
+  size_t buffer_size_;
+  size_t buffer_offset_;
+  std::variant<HostBuffer, AhwbBuffer, IonBuffer, DmaBufBuffer, FastRpcBuffer>
+      buffer_;
+  mutable std::atomic_int_fast32_t ref_;
+};
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TENSOR_BUFFER_H_
diff --git a/tflite/experimental/litert/runtime/tfl_utils.cc b/tflite/experimental/litert/runtime/tfl_utils.cc
new file mode 100644
index 00000000..dcee870d
--- /dev/null
+++ b/tflite/experimental/litert/runtime/tfl_utils.cc
@@ -0,0 +1,99 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
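+//
+// Helpers that translate TFLite C API tensor metadata (TfLiteType and opaque
+// tensor dimensions) into the corresponding LiteRT C++ types.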
+
+#include "tflite/experimental/litert/runtime/tfl_utils.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include "tflite/c/c_api_opaque.h"
+#include "tflite/c/c_api_types.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/cc/litert_detail.h"
+#include "tflite/experimental/litert/cc/litert_element_type.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/cc/litert_model.h"
+
+namespace litert {
+namespace internal {
+
+Expected<ElementType> ConvertElementType(TfLiteType tfl_type) {
+  switch (tfl_type) {
+    case kTfLiteNoType:
+      return ElementType::None;
+    case kTfLiteBool:
+      return ElementType::Bool;
+    case kTfLiteInt4:
+      return ElementType::Int4;
+    case kTfLiteInt8:
+      return ElementType::Int8;
+    case kTfLiteInt16:
+      return ElementType::Int16;
+    case kTfLiteInt32:
+      return ElementType::Int32;
+    case kTfLiteInt64:
+      return ElementType::Int64;
+    case kTfLiteUInt8:
+      return ElementType::UInt8;
+    case kTfLiteUInt16:
+      return ElementType::UInt16;
+    case kTfLiteUInt32:
+      return ElementType::UInt32;
+    case kTfLiteUInt64:
+      return ElementType::UInt64;
+    case kTfLiteFloat16:
+      return ElementType::Float16;
+    case kTfLiteBFloat16:
+      return ElementType::BFloat16;
+    case kTfLiteFloat32:
+      return ElementType::Float32;
+    case kTfLiteFloat64:
+      return ElementType::Float64;
+    case kTfLiteComplex64:
+      return ElementType::Complex64;
+    case kTfLiteComplex128:
+      return ElementType::Complex128;
+    case kTfLiteResource:
+      return ElementType::TfResource;
+    case kTfLiteString:
+      return ElementType::TfString;
+    case kTfLiteVariant:
+      return ElementType::TfVariant;
+    default:
+      return Unexpected(kLiteRtStatusErrorInvalidArgument,
+                        "Unsupported TfLiteType");
+  }
+}
+
+Expected<RankedTensorType> ConvertTensorType(
+    const TfLiteOpaqueTensor* tfl_opaque_tensor) {
+  auto tfl_type = TfLiteOpaqueTensorType(tfl_opaque_tensor);
+  auto element_type = ConvertElementType(tfl_type);
+  if (!element_type) {
+    return Unexpected(element_type.Error());
+  }
+
+  size_t rank = TfLiteOpaqueTensorNumDims(tfl_opaque_tensor);
+  SmallVec<int32_t> dimensions(rank);
+  for (size_t i = 0; i < rank; ++i) {
+    dimensions[i] = TfLiteOpaqueTensorDim(tfl_opaque_tensor, i);
+  }
+
+  return RankedTensorType(*element_type, Layout(std::move(dimensions)));
+}
+
+}  // namespace internal
+}  // namespace litert
diff --git a/tflite/experimental/litert/runtime/tfl_utils.h b/tflite/experimental/litert/runtime/tfl_utils.h
new file mode 100644
index 00000000..723c6390
--- /dev/null
+++ b/tflite/experimental/litert/runtime/tfl_utils.h
@@ -0,0 +1,34 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
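+//
+// Declarations for the TFLite-to-LiteRT type conversion helpers defined in
+// tfl_utils.cc.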
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TFL_UTILS_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TFL_UTILS_H_
+
+#include "tflite/c/c_api_types.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/cc/litert_model.h"
+
+struct TfLiteOpaqueTensor;
+
+namespace litert {
+namespace internal {
+
+Expected<ElementType> ConvertElementType(TfLiteType tfl_type);
+
+Expected<RankedTensorType> ConvertTensorType(
+    const TfLiteOpaqueTensor* tfl_opaque_tensor);
+
+}  // namespace internal
+}  // namespace litert
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TFL_UTILS_H_
diff --git a/tflite/experimental/litert/test/BUILD b/tflite/experimental/litert/test/BUILD
new file mode 100644
index 00000000..a5c2d132
--- /dev/null
+++ b/tflite/experimental/litert/test/BUILD
@@ -0,0 +1,122 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+    # copybara:uncomment default_applicable_licenses = ["@org_tensorflow//tensorflow:license"],
+    default_visibility = ["//tflite/experimental/litert:__subpackages__"],
+)
+
+# TODO: b/365295276 - Make custom rule and move to `.sh`.
+OUT_DIR = "$(RULEDIR)"
+
+CONVERTER = "@org_tensorflow//tensorflow/compiler/mlir/lite:tf_tfl_translate"
+
+CMD = """
+for mlir_file in $(SRCS); do
+    $(location {converter}) --input-mlir $$mlir_file --o={out_dir}/testdata/$$(basename $$mlir_file .mlir).tflite
+done
+""".format(
+    converter = CONVERTER,
+    out_dir = OUT_DIR,
+)
+
+genrule(
+    name = "mlir_test_data",
+    srcs = glob(["testdata/*.mlir"]),
+    outs = [s.removesuffix(".mlir") + ".tflite" for s in glob(["testdata/*.mlir"])],
+    cmd = CMD,
+    tools = [CONVERTER],
+)
+
+filegroup(
+    name = "tflite_test_data",
+    srcs = glob(["testdata/*.tflite"]),
+)
+
+cc_library(
+    name = "common",
+    testonly = 1,
+    srcs = [
+        "common.cc",
+    ],
+    hdrs = [
+        "common.h",
+    ],
+    deps = [
+        ":test_macros",
+        "//tflite:framework",
+        "//tflite/c:c_api_opaque",
+        "//tflite/c:common",
+        "//tflite/core:cc_api_stable",
+        "//tflite/experimental/litert/c:litert_common",
+        "//tflite/experimental/litert/cc:litert_expected",
+        "//tflite/experimental/litert/cc:litert_model",
+        "//tflite/experimental/litert/cc:litert_model_predicates",
+        "//tflite/experimental/litert/core:filesystem",
+        "//tflite/experimental/litert/core/model:model_buffer",
+        "//tflite/experimental/litert/core/util:flatbuffer_tools",
+        "//tflite/kernels:builtin_ops",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform",
+    ],
+)
+
+cc_library(
+    name = "simple_model",
+    testonly = 1,
+    hdrs = [
+        "testdata/simple_model_test_vectors.h",
+    ],
+    data = [
+        "testdata/simple_model.tflite",
+    ],
+    deps = [
+        "//tflite/experimental/litert/c:litert_model",
+        "//tflite/experimental/litert/cc:litert_layout",
+    ],
+)
+
+cc_library(
+    name = "simple_model_npu",
+    testonly = 1,
+    srcs = [],
+    hdrs = [
+        "testdata/simple_model_test_vectors.h",
+    ],
+    data = [
+        "testdata/simple_model_google_tensor.bin",
+        "testdata/simple_model_mtk.bin",
+        "testdata/simple_model_npu.tflite",
"testdata/simple_model_qualcomm.bin", + ], + deps = [ + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/cc:litert_layout", + ], +) + +cc_library( + name = "test_models", + hdrs = ["test_models.h"], + deps = [ + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "test_macros", + hdrs = ["test_macros.h"], + deps = ["//tflite/experimental/litert/c:litert_common"], +) diff --git a/tflite/experimental/litert/test/common.cc b/tflite/experimental/litert/test/common.cc new file mode 100644 index 00000000..b50b9ebc --- /dev/null +++ b/tflite/experimental/litert/test/common.cc @@ -0,0 +1,65 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tflite/experimental/litert/test/common.h" + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/cc/litert_model_predicates.h" +#include "tflite/experimental/litert/core/filesystem.h" +#include "tflite/experimental/litert/core/util/flatbuffer_tools.h" +#include "tflite/interpreter.h" +#include "tflite/kernels/register.h" +#include "tsl/platform/platform.h" + +namespace litert { +namespace testing { + +std::string GetTestFilePath(absl::string_view filename) { + static constexpr absl::string_view kTestDataDir = + "tensorflow/lite/experimental/litert/" + "test/testdata/"; + + if constexpr (!tsl::kIsOpenSource) { + return internal::Join({"third_party", kTestDataDir, filename}); + } else { + return internal::Join({kTestDataDir, filename}); + } +} + +Model LoadTestFileModel(absl::string_view filename) { + return *Model::CreateFromFile(GetTestFilePath(filename)); +} + +Expected TflRuntime::CreateFromFlatBuffer( + internal::FlatbufferWrapper::Ptr flatbuffer) { + ::tflite::Interpreter::Ptr interp; + tflite::ops::builtin::BuiltinOpResolver resolver; + tflite::InterpreterBuilder(flatbuffer->FlatbufferModel(), resolver)(&interp); + if (interp == nullptr) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure); + } + return TflRuntime::Ptr( + new TflRuntime(std::move(flatbuffer), std::move(interp))); +} + +} // namespace testing +} // namespace litert diff --git a/tflite/experimental/litert/test/common.h b/tflite/experimental/litert/test/common.h new file mode 100644 index 00000000..02122918 --- /dev/null +++ b/tflite/experimental/litert/test/common.h @@ -0,0 +1,85 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_COMMON_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_COMMON_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/core/model/model_buffer.h" +#include "tflite/experimental/litert/core/util/flatbuffer_tools.h" +#include "tflite/experimental/litert/test/test_macros.h" // IWYU pragma: keep +#include "tflite/interpreter.h" + +namespace litert { +namespace testing { + +std::string GetTestFilePath(absl::string_view filename); + +Model LoadTestFileModel(absl::string_view filename); + +class TflRuntime { + public: + using Ptr = std::unique_ptr; + + static Expected CreateFromFlatBuffer( + internal::FlatbufferWrapper::Ptr flatbuffer); + + ::tflite::Interpreter& Interpreter() { return *interpreter_; } + + const internal::FlatbufferWrapper& Flatbuffer() const { return *flatbuffer_; } + + private: + TflRuntime(internal::FlatbufferWrapper::Ptr flatbuffer, + ::tflite::Interpreter::Ptr interpreter) + : flatbuffer_(std::move(flatbuffer)), + interpreter_(std::move(interpreter)) {} + + internal::FlatbufferWrapper::Ptr flatbuffer_; + ::tflite::Interpreter::Ptr interpreter_; +}; + +inline Expected MakeRuntimeFromTestFile( + absl::string_view filename) { + auto flatbuffer = + internal::FlatbufferWrapper::CreateFromTflFile(GetTestFilePath(filename)); + if (!flatbuffer) { + return flatbuffer.Error(); + } + return TflRuntime::CreateFromFlatBuffer(std::move(*flatbuffer)); +} + +inline Expected MakeRuntimeFromTestFileWithNpuModel( + absl::string_view filename, absl::string_view npu_filename) { + auto buf = internal::GetModelBufWithByteCode(GetTestFilePath(filename), + GetTestFilePath(npu_filename)); + if (!buf) { + return buf.Error(); + } + auto flatbuffer = + internal::FlatbufferWrapper::CreateFromBuffer(std::move(*buf)); + if (!flatbuffer) { + return flatbuffer.Error(); + } + return TflRuntime::CreateFromFlatBuffer(std::move(*flatbuffer)); +} + +} // namespace testing +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_COMMON_H_ diff --git a/tflite/experimental/litert/test/test_macros.h b/tflite/experimental/litert/test/test_macros.h new file mode 100644 index 00000000..c3420707 --- /dev/null +++ b/tflite/experimental/litert/test/test_macros.h @@ -0,0 +1,46 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
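+//
+// Illustrative usage in a gtest test body (the callee names below are
+// placeholders, not real APIs):
+//
+//   LITERT_ASSERT_STATUS_OK(LiteRtSomeCApiCall(arg));
+//   LITERT_ASSERT_RESULT_OK_ASSIGN(auto value, SomeExpectedReturningCall());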
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_TEST_MACROS_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_TEST_MACROS_H_
+
+#include "tflite/experimental/litert/c/litert_common.h"  // IWYU pragma: keep
+
+// Token-pasting helpers used to generate unique local variable names below.
+#define _CONCAT_NAME_IMPL(x, y) x##y
+#define _CONCAT_NAME(x, y) _CONCAT_NAME_IMPL(x, y)
+
+#define _LITERT_ASSERT_RESULT_OK_ASSIGN(decl, expr, result) \
+  auto result = (expr);                                     \
+  ASSERT_TRUE(result.HasValue());                           \
+  decl = result.Value();
+
+#define LITERT_ASSERT_RESULT_OK_ASSIGN(decl, expr) \
+  _LITERT_ASSERT_RESULT_OK_ASSIGN(decl, expr,      \
+                                  _CONCAT_NAME(_result, __COUNTER__))
+
+#define _LITERT_ASSERT_RESULT_OK_MOVE(decl, expr, result) \
+  auto result = (expr);                                   \
+  ASSERT_TRUE(result.HasValue());                         \
+  decl = std::move(result.Value());
+
+#define LITERT_ASSERT_RESULT_OK_MOVE(decl, expr) \
+  _LITERT_ASSERT_RESULT_OK_MOVE(decl, expr, _CONCAT_NAME(_result, __COUNTER__))
+
+#define LITERT_ASSERT_STATUS_HAS_CODE(expr, code) \
+  {                                               \
+    LiteRtStatus status = (expr);                 \
+    ASSERT_EQ(status, code);                      \
+  }
+
+#define LITERT_ASSERT_STATUS_OK(expr) \
+  LITERT_ASSERT_STATUS_HAS_CODE(expr, kLiteRtStatusOk);
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_TEST_MACROS_H_
diff --git a/tflite/experimental/litert/test/test_models.h b/tflite/experimental/litert/test/test_models.h
new file mode 100644
index 00000000..4dd8f0a1
--- /dev/null
+++ b/tflite/experimental/litert/test/test_models.h
@@ -0,0 +1,123 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_TEST_MODELS_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_TEST_MODELS_H_
+
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+
+// ///////////////////////////////////////////////////////////////////////////
+// FP32 models.
+// ///////////////////////////////////////////////////////////////////////////
+
+// Attention sub-module of a toy model.
+static constexpr absl::string_view kAttentionModel = "attention.tflite";
+
+// Attention vector einsum sub-module of a toy LLM.
+static constexpr absl::string_view kAttnVecEinsumModel =
+    "attn_vec_einsum.tflite";
+
+// Feed-forward sub-module of a toy LLM.
+static constexpr absl::string_view kFeedForwardModel = "ff.tflite";
+
+// Key einsum sub-module of a toy LLM.
+static constexpr absl::string_view kKeyEinsumModel = "k_einsum.tflite";
+
+// Value einsum sub-module of a toy LLM.
+static constexpr absl::string_view kValueEinsumModel = "v_einsum.tflite";
+
+// Query einsum sub-module of a toy LLM.
+static constexpr absl::string_view kQueryEinsumModel = "q_einsum.tflite";
+
+// RMS normalization sub-module of a toy LLM.
+static constexpr absl::string_view kRMSNormModel = "norm.tflite";
+
+// ROPE sub-module of a toy LLM.
+static constexpr absl::string_view kROPEModel = "rope.tflite";
+
+// ROPE sub-module of a toy LLM; uses the embedding_lookup op for sin/cos.
+static constexpr absl::string_view kLookUpROPEModel = "lookup_rope.tflite";
+
+// Scaled dot-product attention sub-module of a toy LLM.
+static constexpr absl::string_view kSDPAModel = "sdpa.tflite";
+
+// Transformer block sub-module of a toy LLM.
+static constexpr absl::string_view kTransformerBlockModel =
+    "transformer.tflite";
+
+// ///////////////////////////////////////////////////////////////////////////
+// Quantized models.
+// ///////////////////////////////////////////////////////////////////////////
+
+// Quantized model with a single mul op.
+// Mul: <8x100x32x4xint16>, <8x100x32x4xint16> -> <8x100x32x4xint16>
+static constexpr absl::string_view kQSimpleMul16x16Model = "mul_quant.tflite";
+
+// Quantized model with a mul op and an add op.
+// Mul: <8x100x32x4xint16>, <8x100x32x4xint16> -> <8x100x32x4xint16>
+// Add: <8x100x32x4xint16>, <8x100x32x4xint16> -> <8x100x32x4xint16>
+static constexpr absl::string_view kQMulAdd16x16Model =
+    "simple_quantized_ops.tflite";
+
+// Single add op, i16 activations, i8 weights, dynamic shape.
+static constexpr absl::string_view kQSingleDynAdd16x8Model =
+    "single_add_default_a16w8_recipe_quantized.tflite";
+
+// Single add op, i8 activations, i8 weights, dynamic shape.
+static constexpr absl::string_view kQSingleDynAdd8x8Model =
+    "single_add_default_a8w8_recipe_quantized.tflite";
+
+// Single mul op, i16 activations, i8 weights, dynamic shape.
+static constexpr absl::string_view kQSingleDynMul16x8Model =
+    "single_mul_default_a16w8_recipe_quantized.tflite";
+
+// Single mul op, i8 activations, i8 weights, dynamic shape.
+static constexpr absl::string_view kQSingleDynMul8x8Model =
+    "single_mul_default_a8w8_recipe_quantized.tflite";
+
+// Single rsqrt op, i16 activations, i8 weights, dynamic shape.
+static constexpr absl::string_view kQSingleDynRsqrt16x8Model =
+    "single_rsqrt_default_a16w8_recipe_quantized.tflite";
+
+// Single rsqrt op, i8 activations, i8 weights, dynamic shape.
+static constexpr absl::string_view kQSingleDynRsqrt8x8Model =
+    "single_rsqrt_default_a8w8_recipe_quantized.tflite";
+
+// Quantized einsum models with i16 activations and i8 weights.
+static constexpr absl::string_view kQQueryEinsum16x8Model =
+    "static_w8_a16_quantized_q_einsum.tflite";
+
+static constexpr absl::string_view kQKeyEinsum16x8Model =
+    "static_w8_a16_quantized_k_einsum.tflite";
+
+static constexpr absl::string_view kQVauleEinsum16x8Model =
+    "static_w8_a16_quantized_v_einsum.tflite";
+
+static constexpr absl::string_view kQAttnVecEinsum16x8Model =
+    "static_w8_a16_quantized_attn_vec_einsum.tflite";
+
+// All the quantized test models.
+static constexpr auto kAllQModels = absl::MakeConstSpan((absl::string_view[]){ + kQSimpleMul16x16Model, kQMulAdd16x16Model, kQSingleDynAdd16x8Model, + kQSingleDynAdd8x8Model, kQSingleDynMul16x8Model, kQSingleDynMul8x8Model, + kQSingleDynRsqrt16x8Model, kQSingleDynRsqrt8x8Model}); + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_TEST_MODELS_H_ diff --git a/tflite/experimental/litert/test/testdata/add_cst.mlir b/tflite/experimental/litert/test/testdata/add_cst.mlir new file mode 100644 index 00000000..502a32a7 --- /dev/null +++ b/tflite/experimental/litert/test/testdata/add_cst.mlir @@ -0,0 +1,7 @@ +module { +func.func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> { + %cst = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32> + %0 = tfl.add %arg0, %cst {fused_activation_function = "NONE"} : tensor<4xf32> + return %0 : tensor<4xf32> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/add_simple.mlir b/tflite/experimental/litert/test/testdata/add_simple.mlir new file mode 100644 index 00000000..32945b4c --- /dev/null +++ b/tflite/experimental/litert/test/testdata/add_simple.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + %0 = tfl.add %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/cos_mul.mlir b/tflite/experimental/litert/test/testdata/cos_mul.mlir new file mode 100644 index 00000000..e6f996a7 --- /dev/null +++ b/tflite/experimental/litert/test/testdata/cos_mul.mlir @@ -0,0 +1,7 @@ +module { +func.func @main(%arg0: tensor<8x100x32x2xf32>, %arg1: tensor<8x100x1x2xf32>) -> tensor<8x100x32x2xf32> { + %0 = "tfl.cos"(%arg1) : (tensor<8x100x1x2xf32>) -> tensor<8x100x1x2xf32> + %1 = tfl.mul(%arg0, %0) <{fused_activation_function = "NONE"}> : (tensor<8x100x32x2xf32>, tensor<8x100x1x2xf32>) -> tensor<8x100x32x2xf32> + return %1 : tensor<8x100x32x2xf32> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/dynamic_shape_tensor.mlir b/tflite/experimental/litert/test/testdata/dynamic_shape_tensor.mlir new file mode 100644 index 00000000..7024ce18 --- /dev/null +++ b/tflite/experimental/litert/test/testdata/dynamic_shape_tensor.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor) -> tensor { + %0 = tfl.add %arg0, %arg0 {fused_activation_function = "NONE"} : tensor + return %0 : tensor +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/fully_connected_3d.mlir b/tflite/experimental/litert/test/testdata/fully_connected_3d.mlir new file mode 100644 index 00000000..a3db1d9a --- /dev/null +++ b/tflite/experimental/litert/test/testdata/fully_connected_3d.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<8x100x128xf32>, %arg1: tensor<128x128xf32>, %arg2: none) -> tensor<8x100x128xf32> { + %0 = "tfl.fully_connected"(%arg0, %arg1, %arg2) <{asymmetric_quantize_inputs = false, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"}> : (tensor<8x100x128xf32>, tensor<128x128xf32>, none) -> tensor<8x100x128xf32> + return %0 : tensor<8x100x128xf32> +} +} diff --git a/tflite/experimental/litert/test/testdata/mul_simple.mlir b/tflite/experimental/litert/test/testdata/mul_simple.mlir new file mode 100644 index 00000000..dd02656c --- /dev/null +++ b/tflite/experimental/litert/test/testdata/mul_simple.mlir @@ -0,0 +1,7 @@ +module { +func.func @main(%arg0: tensor<2x2xf32>, 
%arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { + %0 = tfl.mul %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<2x2xf32> + %1 = tfl.mul %0, %arg1 {fused_activation_function = "NONE"} : tensor<2x2xf32> + return %1 : tensor<2x2xf32> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/multi_subgraph.mlir b/tflite/experimental/litert/test/testdata/multi_subgraph.mlir new file mode 100644 index 00000000..7c1f0fe4 --- /dev/null +++ b/tflite/experimental/litert/test/testdata/multi_subgraph.mlir @@ -0,0 +1,21 @@ +module { + +func.func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> { + %cst = arith.constant dense<[-1.0, -1.0, -1.0, -1.0]> : tensor<4xf32> + %0 = tfl.add %arg0, %cst {fused_activation_function = "NONE"} : tensor<4xf32> + return %0 : tensor<4xf32> +} + +func.func @func1(%arg0: tensor<4xf32>) -> tensor<4xf32> { + %cst = arith.constant dense<[1.0, 1.0, 1.0, 1.0]> : tensor<4xf32> + %0 = tfl.add %arg0, %cst {fused_activation_function = "NONE"} : tensor<4xf32> + return %0 : tensor<4xf32> +} + +func.func @func2(%arg0: tensor<4xf32>) -> tensor<4xf32> { + %cst = arith.constant dense<[2.0, 2.0, 2.0, 2.0]> : tensor<4xf32> + %0 = tfl.add %arg0, %cst {fused_activation_function = "NONE"} : tensor<4xf32> + return %0 : tensor<4xf32> +} + +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/multi_subgraph_mul.mlir b/tflite/experimental/litert/test/testdata/multi_subgraph_mul.mlir new file mode 100644 index 00000000..607100db --- /dev/null +++ b/tflite/experimental/litert/test/testdata/multi_subgraph_mul.mlir @@ -0,0 +1,13 @@ +module { + +func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { + %0 = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +func.func @func1(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<4x4xf32> + return %0 : tensor<4x4xf32> +} + +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/one_mul.mlir b/tflite/experimental/litert/test/testdata/one_mul.mlir new file mode 100644 index 00000000..afabf190 --- /dev/null +++ b/tflite/experimental/litert/test/testdata/one_mul.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { + %0 = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/rms_norm.mlir b/tflite/experimental/litert/test/testdata/rms_norm.mlir new file mode 100644 index 00000000..476c9829 --- /dev/null +++ b/tflite/experimental/litert/test/testdata/rms_norm.mlir @@ -0,0 +1,16 @@ +module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.metadata = {CONVERSION_METADATA = "\10\00\00\00\00\00\00\00\08\00\0E\00\08\00\04\00\08\00\00\00\10\00\00\00$\00\00\00\00\00\06\00\08\00\04\00\06\00\00\00\04\00\00\00\00\00\00\00\0C\00\18\00\14\00\10\00\0C\00\04\00\0C\00\00\00zs\F5|\1F\CE)\0D\01\00\00\00\02\00\00\00\04\00\00\00\06\00\00\002.19.0\00\00", min_runtime_version = "1.10.0\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} { + func.func @main(%arg0: tensor<8x128x1024xf32> {tf_saved_model.index_path = ["args_0"]}) -> (tensor<8x128x1024xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "serving_default_args_0:0", outputs 
= "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = tfl.mul %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<8x128x1024xf32> + %1 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32> + %2 = "tfl.sum"(%0, %1) <{keep_dims = false}> : (tensor<8x128x1024xf32>, tensor<1xi32>) -> tensor<8x128xf32> + %3 = "tfl.pseudo_const"() <{value = dense<1.024000e+03> : tensor}> : () -> tensor + %4 = tfl.div(%2, %3) <{fused_activation_function = "NONE"}> : (tensor<8x128xf32>, tensor) -> tensor<8x128xf32> + %5 = "tfl.pseudo_const"() <{value = dense<9.99999997E-7> : tensor}> : () -> tensor + %6 = tfl.add(%4, %5) <{fused_activation_function = "NONE"}> : (tensor<8x128xf32>, tensor) -> tensor<8x128xf32> + %7 = "tfl.pseudo_const"() <{value = dense<[8, 128, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> + %8 = "tfl.reshape"(%6, %7) : (tensor<8x128xf32>, tensor<3xi32>) -> tensor<8x128x1xf32> + %9 = "tfl.rsqrt"(%8) : (tensor<8x128x1xf32>) -> tensor<8x128x1xf32> + %10 = tfl.mul(%arg0, %9) <{fused_activation_function = "NONE"}> : (tensor<8x128x1024xf32>, tensor<8x128x1xf32>) -> tensor<8x128x1024xf32> + return %10 : tensor<8x128x1024xf32> + } +} diff --git a/tflite/experimental/litert/test/testdata/simple_add_op.mlir b/tflite/experimental/litert/test/testdata/simple_add_op.mlir new file mode 100644 index 00000000..0902f596 --- /dev/null +++ b/tflite/experimental/litert/test/testdata/simple_add_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<1x128x1xf32>, %arg1: tensor<1x128x1xf32>) -> tensor<1x128x1xf32> { + %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1x128x1xf32> + return %0 : tensor<1x128x1xf32> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/simple_batch_matmul_op.mlir b/tflite/experimental/litert/test/testdata/simple_batch_matmul_op.mlir new file mode 100644 index 00000000..e756a0da --- /dev/null +++ b/tflite/experimental/litert/test/testdata/simple_batch_matmul_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<1x4x256x128xf32>, %arg1: tensor<1x4x128x128xf32>) -> tensor<1x4x256x128xf32> { + %0 = "tfl.batch_matmul"(%arg0, %arg1) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<1x4x256x128xf32>, tensor<1x4x128x128xf32>) -> tensor<1x4x256x128xf32> + return %0 : tensor<1x4x256x128xf32> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/simple_cast_op.mlir b/tflite/experimental/litert/test/testdata/simple_cast_op.mlir new file mode 100644 index 00000000..6066c665 --- /dev/null +++ b/tflite/experimental/litert/test/testdata/simple_cast_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<8x100x1xi32>) -> tensor<8x100x1xf32> { + %0 = "tfl.cast"(%arg0) : (tensor<8x100x1xi32>) -> tensor<8x100x1xf32> + return %0 : tensor<8x100x1xf32> +} +} diff --git a/tflite/experimental/litert/test/testdata/simple_concatenation_op.mlir b/tflite/experimental/litert/test/testdata/simple_concatenation_op.mlir new file mode 100644 index 00000000..e1e9bd36 --- /dev/null +++ b/tflite/experimental/litert/test/testdata/simple_concatenation_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<128x4x1x256xf32>, %arg1: tensor<128x4x1x256xf32>) -> tensor<128x4x2x256xf32> { + %0 = "tfl.concatenation"(%arg0, %arg1) <{axis = 2 : i32, fused_activation_function = "NONE"}> : (tensor<128x4x1x256xf32>, tensor<128x4x1x256xf32>) -> tensor<128x4x2x256xf32> + return %0 : tensor<128x4x2x256xf32> +} +} 
\ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/simple_cos_op.mlir b/tflite/experimental/litert/test/testdata/simple_cos_op.mlir new file mode 100644 index 00000000..70ea46c1 --- /dev/null +++ b/tflite/experimental/litert/test/testdata/simple_cos_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<8x100x1x2xf32>) -> tensor<8x100x1x2xf32> { + %0 = "tfl.cos"(%arg0) : (tensor<8x100x1x2xf32>) -> tensor<8x100x1x2xf32> + return %0 : tensor<8x100x1x2xf32> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/simple_div_op.mlir b/tflite/experimental/litert/test/testdata/simple_div_op.mlir new file mode 100644 index 00000000..3748d45b --- /dev/null +++ b/tflite/experimental/litert/test/testdata/simple_div_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<1x128x8x128xf32>, %arg1: tensor<1x128x8x128xf32>) -> tensor<1x128x8x128xf32> { + %0 = tfl.div %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1x128x8x128xf32> + return %0 : tensor<1x128x8x128xf32> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/simple_embedding_lookup_op.mlir b/tflite/experimental/litert/test/testdata/simple_embedding_lookup_op.mlir new file mode 100644 index 00000000..75b8000b --- /dev/null +++ b/tflite/experimental/litert/test/testdata/simple_embedding_lookup_op.mlir @@ -0,0 +1,7 @@ +module { +func.func @main(%arg0: tensor<5xi32>) -> tensor<5x1x2xf32> { + %table = "tfl.pseudo_const"() <{value = dense<"0x00010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001"> : tensor<20x1x2xf32>}> : () -> tensor<20x1x2xf32> + %0 = "tfl.embedding_lookup"(%arg0, %table) : (tensor<5xi32>, tensor<20x1x2xf32>) -> tensor<5x1x2xf32> + return %0 : tensor<5x1x2xf32> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/simple_floor_mod_op.mlir b/tflite/experimental/litert/test/testdata/simple_floor_mod_op.mlir new file mode 100644 index 00000000..6bd3f1fa --- /dev/null +++ b/tflite/experimental/litert/test/testdata/simple_floor_mod_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> tensor<5xf32> { + %0 = "tfl.floor_mod"(%arg0, %arg1) : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32> + return %0 : tensor<5xf32> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/simple_fully_connected_op.mlir b/tflite/experimental/litert/test/testdata/simple_fully_connected_op.mlir new file mode 100644 index 00000000..5cad1206 --- /dev/null +++ b/tflite/experimental/litert/test/testdata/simple_fully_connected_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<128x2048xf32>, %arg1: tensor<2304x2048xf32>, %arg2: none) -> tensor<128x2304xf32> { + %0 = "tfl.fully_connected"(%arg0, %arg1, %arg2) <{asymmetric_quantize_inputs = false, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"}> : (tensor<128x2048xf32>, tensor<2304x2048xf32>, none) -> tensor<128x2304xf32> + return %0 : tensor<128x2304xf32> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/simple_greater_op.mlir b/tflite/experimental/litert/test/testdata/simple_greater_op.mlir new file mode 100644 index 
00000000..b368def1 --- /dev/null +++ b/tflite/experimental/litert/test/testdata/simple_greater_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<1x1x64xi32>, %arg1: tensor<1x1x64xi32>) -> tensor<1x1x64xi1> { + %0 = "tfl.greater"(%arg0, %arg1) : (tensor<1x1x64xi32>, tensor<1x1x64xi32>) -> tensor<1x1x64xi1> + return %0 : tensor<1x1x64xi1> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/simple_less_op.mlir b/tflite/experimental/litert/test/testdata/simple_less_op.mlir new file mode 100644 index 00000000..06370a18 --- /dev/null +++ b/tflite/experimental/litert/test/testdata/simple_less_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<1x1x64xi32>, %arg1: tensor<1x1x64xi32>) -> tensor<1x1x64xi1> { + %0 = "tfl.less"(%arg0, %arg1) : (tensor<1x1x64xi32>, tensor<1x1x64xi32>) -> tensor<1x1x64xi1> + return %0 : tensor<1x1x64xi1> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/simple_logical_and_op.mlir b/tflite/experimental/litert/test/testdata/simple_logical_and_op.mlir new file mode 100644 index 00000000..e58307ca --- /dev/null +++ b/tflite/experimental/litert/test/testdata/simple_logical_and_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<1x64x64xi1>, %arg1: tensor<1x64x64xi1>) -> tensor<1x64x64xi1> { + %0 = "tfl.logical_and"(%arg0, %arg1) : (tensor<1x64x64xi1>, tensor<1x64x64xi1>) -> tensor<1x64x64xi1> + return %0 : tensor<1x64x64xi1> +} +} diff --git a/tflite/experimental/litert/test/testdata/simple_model.mlir b/tflite/experimental/litert/test/testdata/simple_model.mlir new file mode 100644 index 00000000..d88a5d59 --- /dev/null +++ b/tflite/experimental/litert/test/testdata/simple_model.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { + %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<2xf32> + return %0 : tensor<2xf32> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/simple_model_google_tensor.bin b/tflite/experimental/litert/test/testdata/simple_model_google_tensor.bin new file mode 100644 index 0000000000000000000000000000000000000000..208cb983671510eedbb7b31f3491b9a0d6d14b9a GIT binary patch literal 12288 zcmeHMU2I%O75?V#ubut5PMop{P*!TvxDXnX29zKR-;$K%AqtUNl^`K)oY-z%9NTgn znpQ}?g%>}81urXJh*E*9_QgtGtn$FB50#0cYFjF_Nc@;64{o295-JCoEZ>>?v-f&; zZBuNk5_?v&GiT13GjqroCQvtJl=~ov@OuKK-&Us3$!iJwm{ngZ40z5(6&I^ z0&NSlEzq{Wr_uu5=s^@ojxB{8@-WfOAF#xH+&Bs+KZY;*J6h}3VB4Lm z=DJYo>RjkTr3)Xpi(|2WIhwR%n2F&RFEqKf%ya~2`nbClt49s zs>Lj-S=c7^P}E@7#GL5SO+6}FiEE;Vq8{^IC~G~I zx;%dbq}U*zb-tszWl{ai-C9lZ#fKmxEiv=K{+cdY{O^^R4%}>ZqbD60pu+H{gn`Z<$f5K zfQr&<5<4_F_>Z~JeLRMM4y5nK!Y<|YKzu*SPrT}~OMw0=c6t~DmpJ8t{c(Q`Tr(8E zT;tuY?Ge2Eet)&bxM}?_`hUX{Mj;>fOZSXePJeZZQ?mAOh1OxGEoCbjTagh`M@v+h zkje(Iq{jCD99J|1Uj;w(@nkgEnP1(k;0rR@d!iY~2I}mVJY0wx?AMsl&;0IMwr~kn zN1s!&k-V34iriorJY;gSzq<3WW3x%xfWBa zcqw0&b$VmG{A`Hc7%$}$w%)!mUVc7A-}-p@aL~T&I}Dp2FL|2X)_9p;OTW@T`+2F- zEXsw=5Wp-sARYxilJojqIh-%lC}gJgihFv{yb<~j&Ig)pLD9wo{z6ca^JSj;DE>YH zEJh$-3#|?E%bz~R4KMGw@Wwi@40&--bj(N%wfy**=)mn|8$TfasmgIBHuGM&iugM$-2PB7?X5f84SZ|QZ z1{)ri)zRbd+R^0u?!kg{*y{U%@p&Id)$;sJ`u=t3kFG^;72ka$tUAxS{5O4@w4!?k zX+>wfhL-e}`@?~tKDH~y&o+O5=%H+jsn!1QScu-ZKYTGnZ`>cg6rwlo4+lf^t=}KM z9JDX{4#Vc}50A5h?L1E(Vt@X5+74?6Y>NK`gcjdrtO-Enk6~P|gHLe(6j(p{{>ckk ztY7Q?snnr;IWVAAMca#PZoTX;&;i}+VO4`Bs1tD?lQY7-1pC`Qe|G%!_r5(G{2xI)rmrB} 
z8q#;;on5~35Ep+{p3#Y~<*!!KZ>a^AvZ3!?zuW(&Aol^z3y`zdYZpg`(C<&@<9V|1ys6*0M8<9=kuf!Kr7%7`JUTuwSQxnc(%41Ti*AI36sLWgyQd-d%}?I@A_R^Vq94GkV27#;n(j7w}S{Rz(>*rpn2vx{|?hbDh8GWJ|yaA{t#IP5q(p^fOPMIZ=4>^taDERk%2Mv~ z*jn&xck4l;AO3}P4n3-ECHAT?%`lx~GMV@)%RI?>cI?6mL*`p}_cO_w3#m6Uc6q9B za&r92M4{LFkx%-Ocdw#ej|sJpxOgqmTB~`KHZqsa8V3x|4I2UC*fZqD;}&;RD+t%-*eBo@A~@l(R?|P!USLj^fd6pvaIY1=*vL% z8MUm^v`tb7g!F$JG}}p#Yf1h|H_=C6Jhq_5L9?Iq<#-AYZj~wEq9u~{%shxaN6EK$ z-6KBIOy{x&1b|J|0`?JCb}y%3BlBk;i?;ol=9mml8?F~hOKX8r4>AKjP+UV^+5*HY zB$HL53?Z|Tp1i7DnP((!fs`@xpxJPY{A^ezlkjHN*xiZhrEWujv__L7}=M0?L z*a>80%s;+tc<)LL#v}-5fn-s-*#O880dlI&)WKRHx4?{(mJgepow`D?ojIf<=f-pT z%b0Q=fP>7{X;)~SLUv9cCoSX4S;K=+(4QYT`=PUPZz6it2Il90?*cfs<&R+kG@J0g zC&0&S-l{Q#lQz$5PlDHMo=^8R2fqyb9C+Q<^BNk0Xg1Feas<58gMF-ktyp`@*)L=) zPhvAj)p}MvRAiol(AqCH!F&u1@7G+xMr0-}C5SJj?Jh&e`@ag{I6h!e9?$h~>cQp! z`^A8qABX37%Yht=b=#ipZE0?9YP13uP&5#)kfYdTd9*XI%(buK_vVLV8~J zUV4tJ1KFR~l>2m^e~2?+olkH}t;grp887AYSiOmUQhMXu^O-Oif@^$wRXQH$YWx~6 z7uRN1&0UL;Nu!>Ll<`h`U~c*=^zYOtO{I+U6RVE=!F#xC_C=qT@|!QNjq$iTf^ngg zm@YM$rVAKHmC2N0UZxW#7nCc^yOJqmjP`7rU2uH_O&txTM%~vyU8I53g8L?H>Xf}0 zFDl<++SSCgtA=U!ozzuYHWzP;mim=H)5$;j724G`8`t$xqf+$@iW8ogsH2kk@!iRC zqw?d8B#yZuQNBFCW{+NOjCJcI7OubeH8V%z~czjCXZ$3S#9 z#_Py1-l(rI)E%^1?4*DgSKZ`Sx~Ur4-?O$&FEZ2{p!%sntU2K(Lmi-JM1nd+{AX?t zcIGkEjr6jpf_}V6g?9+v@vp0}EYkF8Wn4XNsHzs4HqR)h{if-fI$|g#+TirjQVJRB zCpEg*U)6^7rZ+^8b${3=?o?^u6Ty%eKmO0fyYH$0GSpMys8|`|{l5{6iSeU>9 zjAN^^Bpe zt(p+ZkC&gZqJQ5|epGWo99{3JHq^mzMBK$`GSq!RT_nOGP<)#Dg5LC5Ig{lu zh=sV$w}L*wC(BShAx*p_EBK>UFj&d|Tz=5$1ECGvcNx`2XkO{Jou#*U5xG&gJF2_4)r( z9v)%d3GhEgr{N!ol_x#vGb{B;Tnl-4g!vrkj2t%?50ukfN5weZ5@mhH3n4Zz*P_k;DEW!)@F`th6(P94)S_Wkzs zAo@kbm23YHG~aJC>T4MB-yHa=MG?J%8GqX|-z6Grn`-r1eJ@1v69V5XoT3jXuSxlN zhIRa=lcTdg#K;>dN0hdtsI0cxWimdvJIlG1PnT@Bv^;ds};NeQRSqprN_3t)aE8sihV^{{_IKV)_68 literal 0 HcmV?d00001 diff --git a/tflite/experimental/litert/test/testdata/simple_model_npu.mlir b/tflite/experimental/litert/test/testdata/simple_model_npu.mlir new file mode 100644 index 00000000..f4959fb6 --- /dev/null +++ b/tflite/experimental/litert/test/testdata/simple_model_npu.mlir @@ -0,0 +1,6 @@ +module { + func.func @main(%x: tensor<2xf32>, %y: tensor<2xf32>) -> tensor<2xf32> { + %out = "tfl.custom"(%x, %y) {custom_code = "DISPATCH_OP", custom_option = #tfl} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + func.return %out : tensor<2xf32> + } +} diff --git a/tflite/experimental/litert/test/testdata/simple_model_qualcomm.bin b/tflite/experimental/litert/test/testdata/simple_model_qualcomm.bin new file mode 100644 index 0000000000000000000000000000000000000000..a66f76296d7698031d0d8df30aaefe093aa5d04d GIT binary patch literal 13800 zcmeHM%a0sK8Lv&8*m=r?9GnECCGi0{W>r1YGdru!M7y#R=Yj1A1Ch9Ndb(?7THDjz z?#Jwo5yFDpxa7bI{(&4o0>pts@QDLL;u81+xNs#LA|WBb{C-u{^Vml)$4Jw5Z+CxH z_0?D3$Pv>%J@`zRzO)7xBck!8MRy3(TOl@AKWo zec3+oz-R$}LcD_eU;jZb^rIZV$2%VS;yB9l1>ElhMHN?Zma6DVkR(3m{1kmZz(Bli zu-}8um+0xtd_6NBqm9hu^F`dHDyDHd@*_12>ZI~#m@z`%i&zQe7SJ{A@$8Bws0TUM z!%-b@t<(dbyr?hocmp^6h*lzROrv%2=LeMS5VD=o|1cOA{f2XY*7ZCZr>(74XG?Ci zT)ExqInv#6q`$Mh?Kq1-p+1(cUswX==0&;BTDkGLK{_Miyk5&`mM&DA2 zny9ooh)d;{^+fyDrqTVzCfbE~9=4)&3?Irw|EDz6BmJHHI1hc0EzO*qgBQOk(`mlk z`0*zCEas!sCgPUoBOl%mGT&*wd=5HZw%aE7abBZ${Q;esm2yQ~XS0U?wZ@a?q>nxc zlp?AUr_9PWHhSRO*!&r*V``Z=MLPg7BSZQ8 z|2->i-O;%D%bOvgGzCU~9W`48Z^5PUo8YF(HZr3Ue zRnd=zUNnU74;G|6EBq+;q|9=^%!YN_1{FF^RZ#+)`-QT& z>4n*35Cg1z@~y(U@`p)w(09Bt9D|yxV(1q^I#PZTPvWY-)Ahn2^@kYamw5t;^i&!Q z5`4}_V7PR{tgQS&P^t)BoK6i^Rh%e%NH?gftQr?8U|XaH^~g`uR3-hdIo|C$NUQjc z@<&CGkNd+QDV3)VD^;W`y3p$J!vM4TVWNT*tfo~s@e8U8b64?1Wi)b=z}=I^abY0Qrjdu 
z6*J$QXjU+ZQ$HxealhN6((?e;N#YP{8wa_vQ^lDdO#*)~%Y(9n7O60Qtm4tQ>i4`s zJn}JLHMW{>b-G>>6tK-8s6tSowZ7Bpd1(e?1Q>dthnALm$H4;1Bg8bQ;wdJ9Is&JA z;F{$P;4%A^USXl4I#m5GY@pyDa34P};;e|PS>M@ddGtG#z@QDkQVGY9-sw!et_S+x?Ls)%K7b${I zyNzta)QSQc=W2J&SjQdEm!Opn7CP=}E%G($ zU8_3*$NnpF%jsNR{qrg$R6&unyvCpBj99p4GnprnmJ`vHa`sOzl_5r+>ayg9Voly1O;SQ(B8zi zAqf*)Nl`T+()+F_-UhvG@v7&z)ayLF z3}2mEYzR25QJN)j&_smPPE8NYlK=px=d&L^LL7iiA=5ceE9)r9224Z<_c^P)P{ec~ zgQ06)0|XFI89;^@^|K<<3EU^f6C#k%HG27JFi~FDfz`9T2FMxbXp)KbVs(gAtJMx& z#S;iX>-mE?DE$dGg>lIEvteLZ20+%SG4CloG(w%yrVF1jjA-f` z-&9-HVt`pVk3w3u`eYL4D*b6()bm7^jWFpV5w5leSl(E|rETc<8zd zNR-|r8NIV8j`!0A37%>Pg;Ts#f9VBjn-bQglwr>HM zO~_me>#{=TBE^`y=bYT%w1xvj?TP$My9J7rYGn~*RF8TxIAn>8*pFFNn;Oy#86vku z{=&=3mBx&+o3gXk^|%?-cz-eot=!uGaIQn`YP;!avSJd<^J|3OwxnC`g-9)-@e)7Y(38)(w8+6SVWVoJVdLYNe!eU@gmmL)IsDP1qA? zo3Ubn-71_y&=Oc002>>tc{ig?K>mKCFS%`oVpxgLf!dC_nvnor$UOpb_9*6dir(94 zBMCalnT@i51YC5fW8DRKus$-TGSN?K#VE$7!$zhCR9V)}>|!OH$rR8AH0HtVJN>o` zb?>$MJJfs2q2+X`&7SQYgdFQR&)I9<+71erIRjnV4%$nXPUJ2&jZqv?^+RMhmH?|m zbTusBoM(d$am#EybXr+Nx70gDAQE>|%gDzf0^OJt7rRO6GMO}rB75daiG3ida+M;B zqX6KSu_v&3A(vyl@jAm0dtKyj2StZ*NNEM&QgHTKWcnC#0Xb%SvXn@T|IveN1rZK1 z>}9HP)_Jou$dFG5TU3yAjh^FyKaVXjm&QIOf$94?JL!V z+|xdaeI_EQP$(C5wFhIPTJPH!_7-b8fy0%>Fau{A>vT4?izk7gD5-jGFMrwD!G@M&j?>-lahhyV zz!T`sAk0dV0M zS-J-}+CzneEC_u-ogEzr7b@;#CPue>iqRj_&Wvf zk3P7Py@cmk-2DCG7EeUQr8iTc(=DDV?!G?UD&0%nEwOub5?mRpa6ixBJhxPJgz|0> zS7I-UFyc~MT)L?dZkb0LJZ{a>@MeyQ#Kzr6ApzvQ%ljd`+uwoie)OX&6@N$KZ$`X_ zLs7co^&wux`)Bk|9GmZ-**O2pCvJSocKdrSo5j7$dw(Kr`TBY58-Hxp`|Epe@3HNl z_;2Rh)q2ps7J$Hh4~c)@=Wj>6@O>9E?l#|+*0ilWt^9Kuc|C=jzvJ-z_TL`n&pzi5 ze)#bZ9>$_cJ9_1|&tIHB{G{o_e$7oeO55k&Uq61o>C^CGJNp*sgXbo{pSN>SKHmIo zoxcb2$-5b|i;=bE{o7M`qr0ef zo*C$ncb8o+SUPR{7;gG&(~n#Z(U)_w{RQ;plf3C0q&I8-u9bo9FOe}W_Pb+cAbk^b z(s+Llbm(dmui`6~&iAMO{>eQ)*VkGDYYnV5@OjYyaqFkA{O +#include + +#include "tflite/experimental/litert/c/litert_model.h" +#include "tflite/experimental/litert/cc/litert_layout.h" + +constexpr const char* kModelFileName = "simple_model.tflite"; +constexpr const char* kQualcommModelFileName = "simple_model_qualcomm.bin"; +constexpr const char* kGoogleTensorModelFileName = + "simple_model_google_tensor.bin"; +constexpr const char* kMediaTekModelFileName = "simple_model_mtk.bin"; + +constexpr const int32_t kTestInput0Dimensions[] = {2}; +constexpr const int32_t kNumTestInput0Dimensions = + sizeof(kTestInput0Dimensions) / sizeof(kTestInput0Dimensions[0]); +constexpr const int32_t kTestInput1Dimensions[] = {2}; +constexpr const int32_t kNumTestInput1Dimensions = + sizeof(kTestInput1Dimensions) / sizeof(kTestInput1Dimensions[0]); +constexpr const int32_t kTestOutputDimensions[] = {2}; +constexpr const int32_t kNumTestOutputDimensions = + sizeof(kTestOutputDimensions) / sizeof(kTestOutputDimensions[0]); + +constexpr const float kTestInput0Tensor[] = {1, 2}; +constexpr const float kTestInput1Tensor[] = {10, 20}; +constexpr const float kTestOutputTensor[] = {11, 22}; + +constexpr const float kTestInput0Tensor_2[] = {10, 20}; +constexpr const float kTestInput1Tensor_2[] = {100, 200}; +constexpr const float kTestOutputTensor_2[] = {110, 220}; + +constexpr const size_t kTestInput0Size = + sizeof(kTestInput0Tensor) / sizeof(kTestInput0Tensor[0]); +constexpr const size_t kTestInput1Size = + sizeof(kTestInput1Tensor) / sizeof(kTestInput1Tensor[0]); +constexpr const size_t kTestOutputSize = + sizeof(kTestOutputTensor) / sizeof(kTestOutputTensor[0]); + +constexpr const LiteRtRankedTensorType kInput0TensorType = { + /*.element_type=*/kLiteRtElementTypeFloat32, + 
::litert::BuildLayout(kTestInput0Dimensions)}; + +constexpr const LiteRtRankedTensorType kInput1TensorType = { + /*.element_type=*/kLiteRtElementTypeFloat32, + ::litert::BuildLayout(kTestInput1Dimensions)}; + +constexpr const LiteRtRankedTensorType kOutputTensorType = { + /*.element_type=*/kLiteRtElementTypeFloat32, + ::litert::BuildLayout(kTestOutputDimensions)}; + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_TESTDATA_SIMPLE_MODEL_TEST_VECTORS_H_ diff --git a/tflite/experimental/litert/test/testdata/simple_mul_op.mlir b/tflite/experimental/litert/test/testdata/simple_mul_op.mlir new file mode 100644 index 00000000..7fb5ac2d --- /dev/null +++ b/tflite/experimental/litert/test/testdata/simple_mul_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<1x128x2304xf32>, %arg1: tensor<1x128x2304xf32>) -> tensor<1x128x2304xf32> { + %0 = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1x128x2304xf32> + return %0 : tensor<1x128x2304xf32> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/simple_multi_op.mlir b/tflite/experimental/litert/test/testdata/simple_multi_op.mlir new file mode 100644 index 00000000..07757fdd --- /dev/null +++ b/tflite/experimental/litert/test/testdata/simple_multi_op.mlir @@ -0,0 +1,9 @@ +module { +func.func @main(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + %0 = tfl.add %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<2x2xf32> + %1 = tfl.mul %0, %0 {fused_activation_function = "NONE"} : tensor<2x2xf32> + %2 = tfl.mul %1, %1 {fused_activation_function = "NONE"} : tensor<2x2xf32> + %3 = tfl.add %2, %2 {fused_activation_function = "NONE"} : tensor<2x2xf32> + return %3 : tensor<2x2xf32> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/simple_reshape_op.mlir b/tflite/experimental/litert/test/testdata/simple_reshape_op.mlir new file mode 100644 index 00000000..515db6e4 --- /dev/null +++ b/tflite/experimental/litert/test/testdata/simple_reshape_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<1x128x4x256xf32>, %arg1: tensor<4xi32>) -> tensor<128x4x1x256xf32> { + %0 = "tfl.reshape"(%arg0, %arg1) : (tensor<1x128x4x256xf32>, tensor<4xi32>) -> tensor<128x4x1x256xf32> + return %0 : tensor<128x4x1x256xf32> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/simple_rsqrt_op.mlir b/tflite/experimental/litert/test/testdata/simple_rsqrt_op.mlir new file mode 100644 index 00000000..5083f3f3 --- /dev/null +++ b/tflite/experimental/litert/test/testdata/simple_rsqrt_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<1x128x1xf32>) -> tensor<1x128x1xf32> { + %0 = "tfl.rsqrt"(%arg0) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32> + return %0 : tensor<1x128x1xf32> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/simple_select_op.mlir b/tflite/experimental/litert/test/testdata/simple_select_op.mlir new file mode 100644 index 00000000..2405e5d3 --- /dev/null +++ b/tflite/experimental/litert/test/testdata/simple_select_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<1x128x8x128xi1>, %arg1: tensor<1x128x8x128xf32>, %arg2: tensor<1x128x8x128xf32>) -> tensor<1x128x8x128xf32> { + %0 = "tfl.select"(%arg0, %arg1, %arg2) : (tensor<1x128x8x128xi1>, tensor<1x128x8x128xf32>, tensor<1x128x8x128xf32>) -> tensor<1x128x8x128xf32> + return %0 : tensor<1x128x8x128xf32> +} +} \ No newline at end of file diff --git 
a/tflite/experimental/litert/test/testdata/simple_select_v2_op.mlir b/tflite/experimental/litert/test/testdata/simple_select_v2_op.mlir new file mode 100644 index 00000000..a8d80ecc --- /dev/null +++ b/tflite/experimental/litert/test/testdata/simple_select_v2_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<8x1x1x100xi1>, %arg1: tensor<8x100x32x100xf32>, %arg2: tensor<8x100x32x100xf32>) -> tensor<8x100x32x100xf32> { + %0 = "tfl.select_v2"(%arg0, %arg1, %arg2) : (tensor<8x1x1x100xi1>, tensor<8x100x32x100xf32>, tensor<8x100x32x100xf32>) -> tensor<8x100x32x100xf32> + return %0 : tensor<8x100x32x100xf32> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/simple_sin_op.mlir b/tflite/experimental/litert/test/testdata/simple_sin_op.mlir new file mode 100644 index 00000000..431d3b93 --- /dev/null +++ b/tflite/experimental/litert/test/testdata/simple_sin_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<8x100x1x2xf32>) -> tensor<8x100x1x2xf32> { + %0 = "tfl.sin"(%arg0) : (tensor<8x100x1x2xf32>) -> tensor<8x100x1x2xf32> + return %0 : tensor<8x100x1x2xf32> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/simple_slice_op.mlir b/tflite/experimental/litert/test/testdata/simple_slice_op.mlir new file mode 100644 index 00000000..117b9feb --- /dev/null +++ b/tflite/experimental/litert/test/testdata/simple_slice_op.mlir @@ -0,0 +1,8 @@ +module { +func.func @main(%arg0: tensor<1x128x8x256xf32>) -> tensor<1x128x8x128xf32> { + %cst_0 = "tfl.pseudo_const"() <{value = dense<0> : tensor<4xi32>}> : () -> tensor<4xi32> + %cst_1 = "tfl.pseudo_const"() <{value = dense<[1, 128, 4, 128]> : tensor<4xi32>}> : () -> tensor<4xi32> + %0 = "tfl.slice"(%arg0, %cst_0, %cst_1) : (tensor<1x128x8x256xf32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x128x8x128xf32> + return %0 : tensor<1x128x8x128xf32> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/simple_softmax_op.mlir b/tflite/experimental/litert/test/testdata/simple_softmax_op.mlir new file mode 100644 index 00000000..bb3a83a3 --- /dev/null +++ b/tflite/experimental/litert/test/testdata/simple_softmax_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = "tfl.softmax"(%arg0) <{beta = 1.000000e+00 : f32}> : (tensor<8x128xf32>) -> tensor<8x128xf32> + return %0 : tensor<8x128xf32> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/simple_stablehlo_scatter_op.mlir b/tflite/experimental/litert/test/testdata/simple_stablehlo_scatter_op.mlir new file mode 100644 index 00000000..9d098eb0 --- /dev/null +++ b/tflite/experimental/litert/test/testdata/simple_stablehlo_scatter_op.mlir @@ -0,0 +1,9 @@ +module { +func.func @main(%arg0: tensor<1x128x4x256xf32>, %arg1: tensor<131072x4xi32>, %arg2: tensor<131072xf32>) -> tensor<1x128x4x256xf32> { + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{scatter_dimension_numbers = #stablehlo.scatter}> ({ + ^bb0(%arg3: tensor, %arg4: tensor): + stablehlo.return %arg4 : tensor + }) : (tensor<1x128x4x256xf32>, tensor<131072x4xi32>, tensor<131072xf32>) -> tensor<1x128x4x256xf32> + return %0 : tensor<1x128x4x256xf32> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/simple_strided_slice_op.mlir b/tflite/experimental/litert/test/testdata/simple_strided_slice_op.mlir new file mode 100644 index 00000000..373eff80 --- /dev/null +++ 
b/tflite/experimental/litert/test/testdata/simple_strided_slice_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<1x128x4x256xf32>, %arg1: tensor<4xi32>, %arg2: tensor<4xi32>, %arg3: tensor<4xi32>) -> tensor<1x128x4x128xf32> { + %0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) <{begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32}> : (tensor<1x128x4x256xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x128x4x128xf32> + return %0 : tensor<1x128x4x128xf32> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/simple_sub_op.mlir b/tflite/experimental/litert/test/testdata/simple_sub_op.mlir new file mode 100644 index 00000000..e1483fed --- /dev/null +++ b/tflite/experimental/litert/test/testdata/simple_sub_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<1x128x4x128xf32>, %arg1: tensor<1x128x4x128xf32>) -> tensor<1x128x4x128xf32> { + %0 = tfl.sub %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1x128x4x128xf32> + return %0 : tensor<1x128x4x128xf32> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/simple_sum_op.mlir b/tflite/experimental/litert/test/testdata/simple_sum_op.mlir new file mode 100644 index 00000000..bb4613d5 --- /dev/null +++ b/tflite/experimental/litert/test/testdata/simple_sum_op.mlir @@ -0,0 +1,7 @@ +module { +func.func @main(%arg0: tensor<1x128x2304xf32>) -> tensor<1x128x1xf32> { + %cst = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32> + %0 = "tfl.sum"(%arg0, %cst) <{keep_dims = true}> : (tensor<1x128x2304xf32>, tensor<1xi32>) -> tensor<1x128x1xf32> + return %0 : tensor<1x128x1xf32> +} +} diff --git a/tflite/experimental/litert/test/testdata/simple_tanh_op.mlir b/tflite/experimental/litert/test/testdata/simple_tanh_op.mlir new file mode 100644 index 00000000..ce1d0302 --- /dev/null +++ b/tflite/experimental/litert/test/testdata/simple_tanh_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<1x128x8x128xf32>) -> tensor<1x128x8x128xf32> { + %0 = "tfl.tanh"(%arg0) : (tensor<1x128x8x128xf32>) -> tensor<1x128x8x128xf32> + return %0 : tensor<1x128x8x128xf32> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/simple_transpose_op.mlir b/tflite/experimental/litert/test/testdata/simple_transpose_op.mlir new file mode 100644 index 00000000..f24d7221 --- /dev/null +++ b/tflite/experimental/litert/test/testdata/simple_transpose_op.mlir @@ -0,0 +1,7 @@ +module { +func.func @main(%arg0: tensor<128x4x2x128xf32>) -> tensor<128x2x4x128xf32> { + %cst = "tfl.pseudo_const"() <{value = dense<[0, 2, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> + %0 = "tfl.transpose"(%arg0, %cst) : (tensor<128x4x2x128xf32>, tensor<4xi32>) -> tensor<128x2x4x128xf32> + return %0 : tensor<128x2x4x128xf32> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/two_partition.mlir b/tflite/experimental/litert/test/testdata/two_partition.mlir new file mode 100644 index 00000000..738c8309 --- /dev/null +++ b/tflite/experimental/litert/test/testdata/two_partition.mlir @@ -0,0 +1,9 @@ +module { +func.func @main(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + %0 = tfl.add %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<2x2xf32> + %1 = tfl.mul %0, %0 {fused_activation_function = "NONE"} : tensor<2x2xf32> + %2 = tfl.add %1, %1 {fused_activation_function = "NONE"} : tensor<2x2xf32> + %3 = tfl.mul 
%2, %2 {fused_activation_function = "NONE"} : tensor<2x2xf32> + return %3 : tensor<2x2xf32> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/test/testdata/unranked_tensor.mlir b/tflite/experimental/litert/test/testdata/unranked_tensor.mlir new file mode 100644 index 00000000..4e2403a7 --- /dev/null +++ b/tflite/experimental/litert/test/testdata/unranked_tensor.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %0 = tfl.add %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<*xf32> + return %0 : tensor<*xf32> +} +} \ No newline at end of file diff --git a/tflite/experimental/litert/tools/BUILD b/tflite/experimental/litert/tools/BUILD new file mode 100644 index 00000000..a24397f4 --- /dev/null +++ b/tflite/experimental/litert/tools/BUILD @@ -0,0 +1,189 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//tflite/experimental/litert/vendors/qualcomm:qualcomm_build_defs.bzl", "litert_cc_bin_with_qnn") + +package( + # copybara:uncomment default_applicable_licenses = ["@org_tensorflow//tensorflow:license"], + default_visibility = ["//tflite/experimental/litert:__subpackages__"], +) + +cc_library( + name = "apply_plugin", + testonly = 1, + srcs = ["apply_plugin.cc"], + hdrs = ["apply_plugin.h"], + deps = [ + ":dump", + ":outstream", + ":tool_display", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/cc:litert_buffer_ref", + "//tflite/experimental/litert/cc:litert_detail", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/compiler/plugin:algo", + "//tflite/experimental/litert/compiler/plugin:compiler_plugin", + "//tflite/experimental/litert/core:byte_code_util", + "//tflite/experimental/litert/core/model:model_graph", + "//tflite/experimental/litert/core/model:model_serialize", + "//tflite/experimental/litert/core/util:flatbuffer_tools", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "apply_plugin_test", + srcs = ["apply_plugin_test.cc"], + data = [ + "//tflite/experimental/litert/test:mlir_test_data", + "//tflite/experimental/litert/vendors/examples:example_plugin_so", + ], + tags = [ + "noasan", + "nomsan", + "nosan", + ], + deps = [ + ":apply_plugin", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/cc:litert_buffer_ref", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/core:byte_code_util", + "//tflite/experimental/litert/core/model", + "//tflite/experimental/litert/test:common", + "@com_google_absl//absl/log:absl_check", + 
"@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + ], +) + +litert_cc_bin_with_qnn( + name = "apply_plugin_main", + testonly = 1, + srcs = ["apply_plugin_main.cc"], + data = [ + "//tflite/experimental/litert/vendors/examples:example_plugin_so", + "//tflite/experimental/litert/vendors/qualcomm/compiler:qnn_compiler_plugin_so", + ], + export_litert_only = 1, + include_system = 1, + linkstatic = 1, + # copybara:uncomment malloc = "//base:system_malloc", + tags = [ + "noasan", + "nobuilder", + "nomsan", + "nosan", + ], + ungrte = True, + deps = [ + ":apply_plugin", + ":outstream", + "//tflite/experimental/litert/core:byte_code_util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", + ], +) + +# Fork of "apply_plugin_main" without the "ungrte" so this tool can be used as part of larger +# integration test pipelines with example_plugin. +cc_binary( + name = "apply_plugin_main_for_test", + testonly = 1, + srcs = ["apply_plugin_main.cc"], + data = [ + "//tflite/experimental/litert/vendors/examples:example_plugin_so", + ], + linkstatic = 1, + tags = [ + "noasan", + "nomsan", + "nosan", + ], + deps = [ + ":apply_plugin", + ":outstream", + "//tflite/experimental/litert/core:byte_code_util", + "@com_google_absl//absl/strings:str_format", + "@llvm-project//llvm:Support", + ], +) + +cc_library( + name = "tool_display", + srcs = ["tool_display.cc"], + hdrs = ["tool_display.h"], + deps = [ + ":outstream", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "tool_display_test", + srcs = ["tool_display_test.cc"], + data = ["//tflite/experimental/litert/test:mlir_test_data"], + deps = [ + ":tool_display", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "dump", + srcs = ["dump.cc"], + hdrs = ["dump.h"], + deps = [ + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/compiler/plugin:compiler_plugin", + "//tflite/experimental/litert/core/model", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "dump_test", + srcs = ["dump_test.cc"], + data = ["//tflite/experimental/litert/test:mlir_test_data"], + deps = [ + ":dump", + "//tflite/experimental/litert/core/model", + "//tflite/experimental/litert/test:common", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "outstream", + hdrs = ["outstream.h"], + deps = [ + "//tflite/experimental/litert/c:litert_logging", + "@com_google_absl//absl/strings:string_view", + ], +) diff --git a/tflite/experimental/litert/tools/apply_plugin.cc b/tflite/experimental/litert/tools/apply_plugin.cc new file mode 100644 index 00000000..0ca6ce75 --- /dev/null +++ b/tflite/experimental/litert/tools/apply_plugin.cc @@ -0,0 +1,700 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tflite/experimental/litert/tools/apply_plugin.h" + +#include <cstddef> +#include <cstdint> +#include <cstring> +#include <memory> +#include <ostream> +#include <sstream> +#include <string> +#include <utility> +#include <vector> + +#include "absl/log/absl_check.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_model.h" +#include "tflite/experimental/litert/cc/litert_buffer_ref.h" +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/experimental/litert/cc/litert_macros.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/compiler/plugin/algo.h" +#include "tflite/experimental/litert/compiler/plugin/compiler_plugin.h" +#include "tflite/experimental/litert/core/byte_code_util.h" +#include "tflite/experimental/litert/core/model/model_graph.h" +#include "tflite/experimental/litert/core/model/model_serialize.h" +#include "tflite/experimental/litert/core/util/flatbuffer_tools.h" +#include "tflite/experimental/litert/tools/dump.h" +#include "tflite/experimental/litert/tools/tool_display.h" + +namespace litert::tools { + +using ::litert::BufferRef; +using ::litert::OwningBufferRef; +using ::litert::internal::CompilerPlugin; +using ::litert::internal::Dump; +using ::litert::internal::FinishByteCodePlaceholders; +using ::litert::internal::GroupPartitions; +using ::litert::internal::IsConstant; +using ::litert::internal::kByteCodeMetadataKey; +using ::litert::internal::kLiteRtBuildStampKey; +using ::litert::internal::kLiteRtDispatchOpCustomCode; +using ::litert::internal::MakeBuildStamp; +using ::litert::internal::MakeByteCodePlaceholder; +using ::litert::internal::MakeExecInfo; +using ::litert::internal::OutlinePartition; +using ::litert::internal::Serialization; +using ::litert::internal::SerializeModel; +using ::litert::internal::VerifyFlatbuffer; +using ::litert::tools::ApplyPluginRun; + +#define LITERT_ENSURE_CONFIG(expr) \ + if (!(expr)) { \ + return kLiteRtStatusErrorInvalidToolConfig; \ + } + +namespace { + +class Context { + public: + using Ptr = std::unique_ptr<Context>; + + explicit Context(ApplyPluginRun::Ptr run) + : run_(std::move(run)), + display_(ToolDisplay(std::move(run_->dump_out), + Context::CmdStr(run_->cmd))) {} + + ApplyPluginRun::Cmd Cmd() const { return run_->cmd; } + + absl::Span<const absl::string_view> LibSearchPaths() const { + return absl::MakeConstSpan(run_->lib_search_paths.data(), + run_->lib_search_paths.size()); + } + + absl::string_view SocModelTarget() const { + ABSL_CHECK_EQ(run_->soc_models.size(), 1); + return run_->soc_models.front(); + } + + absl::string_view SocManufacturer() const { + return run_->soc_manufacturer.value(); + } + + std::ostream& Out() { + ABSL_CHECK_EQ(run_->outs.size(), 1); + return run_->outs.front(); + } + + OutStream SwapOut(OutStream out) { + ABSL_CHECK_EQ(run_->outs.size(), 1); + auto res = run_->outs.front(); + run_->outs.at(0) = out; + return res; + } + + Serialization Serialization() const { return run_->serialization; } + + const ApplyPluginRun& Run() const { return *run_; } + ApplyPluginRun&
Run() { return *run_; } + + ToolDisplay& Dump() { return display_; } + + static absl::string_view CmdStr(ApplyPluginRun::Cmd cmd); + + private: + ApplyPluginRun::Ptr run_; + ToolDisplay display_; +}; + +absl::string_view Context::CmdStr(ApplyPluginRun::Cmd cmd) { + switch (cmd) { + case ApplyPluginRun::Cmd::INFO: + return "INFO"; + case ApplyPluginRun::Cmd::NOOP: + return "NOOP"; + case ApplyPluginRun::Cmd::PARTITION: + return "PARTITION"; + case ApplyPluginRun::Cmd::COMPILE: + return "COMPILE"; + case ApplyPluginRun::Cmd::APPLY: + return "APPLY"; + } +} + +void DumpModelStats(Context& ctx, BufferRef<uint8_t> buf) { + ctx.Dump().Labeled() << absl::StreamFormat( + "Serialized a model of size %lu bytes\n", buf.Size()); +} + +Expected<std::vector<CompilerPlugin>> LoadAllPlugins(Context& ctx) { + ctx.Dump().Start("Load Plugins"); + ctx.Dump().Labeled() << "Loading plugins from: "; + const auto paths = ctx.LibSearchPaths(); + for (auto it = paths.begin(); it < paths.end(); ++it) { + ctx.Dump().Display() << *it; + if (it < paths.end() - 1) { + ctx.Dump().Display() << ", "; + } + } + ctx.Dump().Display() << "\n"; + + auto plugins = CompilerPlugin::LoadPlugins(ctx.LibSearchPaths()); + if (!plugins.HasValue()) { + ctx.Dump().Fail(); + return plugins; + } + ctx.Dump().Labeled() << "Found plugins\n"; + ctx.Dump().Labeled() << absl::StreamFormat("Loaded %lu plugins\n", + plugins.Value().size()); + + ctx.Dump().Done(); + return plugins; +} + +Expected<CompilerPlugin> LoadPlugin(Context& ctx) { + auto plugins = LoadAllPlugins(ctx); + if (!plugins) { + return plugins.Error(); + } + + ctx.Dump().Start("Select Plugin"); + + for (auto& plugin : *plugins) { + if (plugin.SocManufacturer() == ctx.Run().soc_manufacturer) { + ctx.Dump().Labeled() << absl::StreamFormat("Selected plugin for: %s\n", + plugin.SocManufacturer()); + ctx.Dump().Done(); + return std::move(plugin); + } + } + + ctx.Dump().Fail(); + return Unexpected(kLiteRtStatusErrorNotFound); +} + +Expected<Model> LoadModel(Context& ctx) { + ctx.Dump().Start("Load Model"); + ctx.Dump().Labeled() << absl::StreamFormat("Loading model from: %s\n", + ctx.Run().model.value()); + auto model_result = Model::CreateFromFile(ctx.Run().model->data()); + if (!model_result.HasValue()) { + ctx.Dump().Labeled() << "Failed to load model from file."; + ctx.Dump().Fail(); + return model_result; + } + + ctx.Dump().Labeled(); + Dump(*model_result.Value().Get(), ctx.Dump().Display()); + ctx.Dump().Done(); + + return model_result; +} + +std::vector<LiteRtOp> ApplyPartition(Context& ctx, const Model& model, + CompilerPlugin& plugin) { + ctx.Dump().Start("Partition Model"); + + ctx.Dump().Labeled() << "Input model: \n"; + for (auto it = model.Get()->Subgraphs().begin(); + it < model.Get()->Subgraphs().end(); ++it) { + ctx.Dump().Labeled(); + ctx.Dump().Indented() << "(input graph) "; + Dump(**it, ctx.Dump().Display()); + } + + if (model.NumSubgraphs() != 1) { + ctx.Dump().Fail(); + // TODO(@lukeboyer) Finish multi-subgraph support.
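+ // Note: partitioning currently assumes exactly one input subgraph, so + // multi-subgraph models fall through to this empty result, which the + // callers below treat as a partitioning failure.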
+ return {}; + } + auto partition = plugin.Partition(Subgraph(&model.Get()->Subgraph(0))); + if (!partition.HasValue()) { + return {}; + } + auto grouped_partitions = GroupPartitions(partition.Value()); + if (grouped_partitions.empty()) { + return {}; + } + ctx.Dump().Labeled() << absl::StreamFormat( + "Plugin selected %lu ops, yielding %lu partitions\n", + partition.Value().size(), grouped_partitions.size()); + + std::vector<LiteRtOp> res; + for (auto& partition : grouped_partitions) { + LiteRtOp custom_op = + OutlinePartition(*model.Get()->Subgraphs().front(), + &model.Get()->EmplaceSubgraph(), partition); + res.push_back(custom_op); + } + + ctx.Dump().Labeled() << "Partitioned model: \n"; + ctx.Dump().Labeled(); + ctx.Dump().Indented() << "(initial graph) "; + Dump(model.Get()->Subgraph(0), ctx.Dump().Display()); + for (auto it = model.Get()->Subgraphs().begin() + 1; + it < model.Get()->Subgraphs().end(); ++it) { + ctx.Dump().Labeled(); + ctx.Dump().Indented() << "(new graph) "; + Dump(**it, ctx.Dump().Display()); + } + + ctx.Dump().Done(); + return res; +} + +Expected<Model> PartitionModel(Context& ctx, Model&& model, + CompilerPlugin& plugin) { + auto custom_ops = ApplyPartition(ctx, model, plugin); + if (custom_ops.empty()) { + return Unexpected(kLiteRtStatusErrorGraphModification); + } + return std::move(model); +} + +Expected<std::vector<std::string>> CompilePartitions( + Context& ctx, std::vector<LiteRtSubgraph>& partitions, + CompilerPlugin& plugin) { + ctx.Dump().Start("Compile Model"); + ctx.Dump().Labeled() << absl::StreamFormat( + "Requesting compilation for target \"%s\" on %lu subgraphs\n", + ctx.SocModelTarget(), partitions.size()); + + std::vector<std::string> call_info_out; + if (plugin.Compile(ctx.SocModelTarget(), partitions, ctx.Out(), + call_info_out) != kLiteRtStatusOk) { + ctx.Dump().Fail(); + return Unexpected(kLiteRtStatusErrorCompilation); + } + + ctx.Dump().Labeled() << "Entry point info: "; + for (auto it = call_info_out.begin(); it < call_info_out.end(); ++it) { + ctx.Dump().Display() << absl::StreamFormat("\"%s\"", *it); + if (it < call_info_out.end() - 1) { + ctx.Dump().Display() << ", "; + } + } + ctx.Dump().Display() << "\n"; + + ctx.Dump().Done(); + return std::move(call_info_out); +} + +// +// INFO Command +// + +LiteRtStatus ValidateInfoRun(const ApplyPluginRun& run) { + LITERT_ENSURE_CONFIG(!run.lib_search_paths.empty()); + LITERT_ENSURE_CONFIG(run.outs.size() == 1); + return kLiteRtStatusOk; +} + +LiteRtStatus Info(Context& ctx) { + auto plugins = LoadAllPlugins(ctx); + if (!plugins) { + return plugins.Error().Status(); + } + + for (auto& plugin : *plugins) { + ctx.Out() << absl::StreamFormat("< LiteRtCompilerPlugin > \"%s\" | ", + plugin.SocManufacturer()); + const auto& models = plugin.SocModels(); + for (auto it = models.begin(); it < models.end(); ++it) { + ctx.Out() << absl::StreamFormat("\"%s\"", *it); + if (it < models.end() - 1) { + ctx.Out() << ", "; + } + } + ctx.Out() << "\n"; + } + return kLiteRtStatusOk; +} + +// +// NOOP Command +// + +LiteRtStatus ValidateNoopRun(const ApplyPluginRun& run) { + LITERT_ENSURE_CONFIG(run.model.has_value()); + LITERT_ENSURE_CONFIG(run.outs.size() == 1); + return kLiteRtStatusOk; +} + +LiteRtStatus Noop(Context& ctx) { + auto model = LoadModel(ctx); + if (!model) { + return model.Error().Status(); + } + + auto serialized = SerializeModel(std::move(*model->Get())); + if (!serialized) { + return serialized.Error().Status(); + } + LITERT_ENSURE(VerifyFlatbuffer(serialized->Span()), + kLiteRtStatusErrorInvalidFlatbuffer, + "Failed to validate flatbuffer"); +
serialized->WriteStr(ctx.Out()); + return kLiteRtStatusOk; +} + +// +// PARTITION Command +// + +LiteRtStatus ValidatePartitionRun(const ApplyPluginRun& run) { + LITERT_ENSURE_CONFIG(!run.lib_search_paths.empty()); + LITERT_ENSURE_CONFIG(run.model.has_value() && !run.model.value().empty()); + LITERT_ENSURE_CONFIG(run.soc_manufacturer.has_value()); + LITERT_ENSURE_CONFIG(!run.outs.empty()); + return kLiteRtStatusOk; +} + +LiteRtStatus Partition(Context& ctx) { + auto plugin = LoadPlugin(ctx); + if (!plugin) { + return plugin.Error().Status(); + } + + auto model = LoadModel(ctx); + if (!model) { + return model.Error().Status(); + } + + auto partitioned_model = PartitionModel(ctx, std::move(*model), *plugin); + if (!partitioned_model) { + return partitioned_model.Error().Status(); + } + + ctx.Dump().Start("Serializing model"); + auto serialized = SerializeModel(std::move(*partitioned_model->Get())); + DumpModelStats(ctx, *serialized); + ctx.Dump().Done(); + + ctx.Dump().Start("Verifying flatbuffer"); + LITERT_ENSURE(VerifyFlatbuffer(serialized->Span()), + kLiteRtStatusErrorInvalidFlatbuffer, + "Failed to validate flatbuffer"); + ctx.Dump().Done(); + + ctx.Dump().Start("Writing to out"); + serialized->WriteStr(ctx.Out()); + ctx.Dump().Done(); + + return kLiteRtStatusOk; +} + +// +// COMPILE Command +// + +LiteRtStatus ValidateCompileRun(const ApplyPluginRun& run) { + LITERT_ENSURE_CONFIG(!run.lib_search_paths.empty()); + LITERT_ENSURE_CONFIG(run.model.has_value()); + LITERT_ENSURE_CONFIG(run.soc_manufacturer.has_value()); + LITERT_ENSURE_CONFIG(run.outs.size() == run.soc_models.size()); + // TODO: implement multi target compilation. + LITERT_ENSURE_SUPPORTED(run.soc_models.size() == 1, + "Multi target compilation not implemented."); + return kLiteRtStatusOk; +} + +LiteRtStatus Compile(Context& ctx) { + auto model = LoadModel(ctx); + if (!model) { + return model.Error().Status(); + } + + auto plugin = LoadPlugin(ctx); + if (!plugin) { + return plugin.Error().Status(); + } + + std::vector<LiteRtSubgraph> compilation_input; + compilation_input.reserve(model->Get()->NumSubgraphs()); + for (auto* subgraph : model->Get()->Subgraphs()) { + compilation_input.push_back(subgraph); + } + + auto entry_points = CompilePartitions(ctx, compilation_input, *plugin); + if (!entry_points) { + return entry_points.Error().Status(); + } + + return kLiteRtStatusOk; +} + +// +// APPLY Command +// + +LiteRtStatus StampModel(Context& ctx, LiteRtModel model) { + auto stamp = MakeBuildStamp(ctx.SocManufacturer(), ctx.SocModelTarget(), + ctx.Serialization()); + if (!stamp) { + return stamp.Error().Status(); + } + ctx.Dump().Labeled() << absl::StreamFormat("Stamping model: %s\n", + stamp->StrView()); + return model->PushMetadata(kLiteRtBuildStampKey, *stamp); +} + +Expected<OwningBufferRef<uint8_t>> DoMetadataSerialization( + Context& ctx, std::vector<LiteRtOp>& custom_ops, + std::vector<std::string>& call_info, BufferRef<uint8_t> compilation_out, + Model&& model) { + ctx.Dump().Start("Serializing with bytecode in METADATA"); + + { + auto call_it = call_info.begin(); + auto custom_op_it = custom_ops.begin(); + for (; call_it < call_info.end() && custom_op_it < custom_ops.end(); + ++call_it, ++custom_op_it) { + auto& custom_op = **custom_op_it; + custom_op.SetCustomOptions(call_it->c_str()); + } + } + + { + ctx.Dump().Labeled() << absl::StreamFormat( + "Adding metadata byte code of size: %lu bytes\n", + compilation_out.Size()); + + LITERT_EXPECT_OK(model.Get()->PushMetadata( + kByteCodeMetadataKey, compilation_out.Data(), compilation_out.Size())); + } + + auto serialized = 
SerializeModel(std::move(*model.Get())); + if (!serialized) { + return serialized.Error(); + } + + ctx.Dump().Labeled() << absl::StreamFormat( + "Serialized model of size: %lu bytes\n", serialized->Size()); + if (!VerifyFlatbuffer(serialized->Span())) { + ctx.Dump().Fail(); + return Unexpected(kLiteRtStatusErrorInvalidFlatbuffer); + } + ctx.Dump().Done(); + + return serialized; +} + +Expected<OwningBufferRef<uint8_t>> DoAppendSerialization( + Context& ctx, std::vector<LiteRtOp>& custom_ops, + std::vector<std::string>& call_info, BufferRef<uint8_t> compilation_out, + Model&& model) { + ctx.Dump().Start("Serializing with bytecode APPEND"); + + // This need not be the same for all custom ops. + static constexpr absl::string_view kSharedByteCodePlaceholderName = + kByteCodeMetadataKey; + LITERT_EXPECT_OK(model.Get()->PushMetadata(kSharedByteCodePlaceholderName, + MakeByteCodePlaceholder())); + + { + auto call_it = call_info.begin(); + auto custom_op_it = custom_ops.begin(); + for (; call_it < call_info.end() && custom_op_it < custom_ops.end(); + ++call_it, ++custom_op_it) { + auto exec_info = MakeExecInfo(*call_it, kSharedByteCodePlaceholderName); + if (!exec_info) { + return exec_info; + } + auto& custom_op = **custom_op_it; + custom_op.SetCustomOptions(std::move(*exec_info)); + } + } + + auto serialized = SerializeModel(std::move(*model.Get())); + if (!serialized) { + return serialized; + } + + ctx.Dump().Labeled() << absl::StreamFormat( + "Serialized model of size: %lu bytes\n", serialized->Size()); + LITERT_EXPECT_OK( + FinishByteCodePlaceholders(*serialized, compilation_out.Size())); + + OwningBufferRef<uint8_t> with_append(serialized->Size() + + compilation_out.Size()); + + uint8_t* write = with_append.Data(); + std::memcpy(write, serialized->Data(), serialized->Size()); + write += serialized->Size(); + std::memcpy(write, compilation_out.Data(), compilation_out.Size()); + + ctx.Dump().Labeled() << absl::StreamFormat("Appended byte code of size %lu\n", + compilation_out.Size()); + + ctx.Dump().Done(); + return with_append; +} + +LiteRtStatus ValidateApplyRun(const ApplyPluginRun& run) { + LITERT_ENSURE_CONFIG(!run.lib_search_paths.empty()); + LITERT_ENSURE_CONFIG(run.model.has_value()); + LITERT_ENSURE_CONFIG(run.soc_manufacturer.has_value()); + LITERT_ENSURE_CONFIG(run.outs.size() == run.soc_models.size()); + // TODO: implement multi target compilation. + LITERT_ENSURE_SUPPORTED(run.soc_models.size() == 1, + "Multi target compilation not implemented."); + LITERT_ENSURE_SUPPORTED(run.serialization != Serialization::kUnknown, + "No serialization strategy supported."); + return kLiteRtStatusOk; +} + +LiteRtStatus Apply(Context& ctx) { + auto model = LoadModel(ctx); + if (!model) { + return model.Error().Status(); + } + + auto plugin = LoadPlugin(ctx); + if (!plugin) { + return plugin.Error().Status(); + } + + static constexpr size_t kNumInputSubgraphs = 1; + LITERT_ENSURE_SUPPORTED(model->Get()->NumSubgraphs() == kNumInputSubgraphs, + "Only single subgraph models currently supported."); + + // Query plugin for compilable ops and slice partitions out of the graph, + // replacing each with a single custom op. + auto custom_ops = ApplyPartition(ctx, *model, *plugin); + LITERT_ENSURE(!custom_ops.empty(), kLiteRtStatusErrorGraphModification, + "Failed to partition graph."); + ABSL_DCHECK_EQ(custom_ops.size(), + model->Get()->NumSubgraphs() - kNumInputSubgraphs); + + // All new subgraphs to be compiled are appended to the model's subgraphs.
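+ // Illustrative layout (assuming one input subgraph and two partitions): + // subgraphs == [original, partition_0, partition_1]; everything from index + // kNumInputSubgraphs onward is collected below as plugin compilation input.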
+ auto new_sg_start = model->Get()->Subgraphs().begin() + kNumInputSubgraphs; + auto new_sg_end = model->Get()->Subgraphs().end(); + std::vector<LiteRtSubgraph> compilation_input; + for (auto it = new_sg_start; it < new_sg_end; ++it) { + compilation_input.push_back(*it); + } + + // Call compilation method on the plugin. + std::stringstream compilation_out; + OutStream out = ctx.SwapOut(compilation_out); + + auto call_info = CompilePartitions(ctx, compilation_input, *plugin); + if (!call_info) { + return call_info.Error().Status(); + } + + // Update each custom op with its respective entry point info from the plugin. + LITERT_ENSURE(call_info->size() == custom_ops.size(), + kLiteRtStatusErrorCompilation, + "Failed to verify entry point information."); + + model->Get()->ResizeSubgraphsDown(kNumInputSubgraphs); + + LITERT_RETURN_STATUS_IF_NOT_OK(StampModel(ctx, model->Get())); + + BufferRef<uint8_t> compiled_buffer(compilation_out.view().data(), + compilation_out.view().size()); + + // For each custom op, if the input tensor is a constant, it should be removed + // from the input list. + // TODO(@lukeboyer) Move this to algo, use model_graph api, and test behavior. + for (auto& custom_op : custom_ops) { + std::vector<LiteRtTensor> new_inputs; + for (auto* input : custom_op->Inputs()) { + if (!IsConstant(*input)) { + new_inputs.push_back(input); + } + } + custom_op->Inputs() = new_inputs; + } + + ctx.SwapOut(out); + if (ctx.Serialization() == Serialization::kMetadata) { + auto serialized = DoMetadataSerialization( + ctx, custom_ops, *call_info, compiled_buffer, std::move(*model)); + if (!serialized) { + return serialized.Error().Status(); + } + serialized->WriteStr(ctx.Out()); + + } else if (ctx.Serialization() == Serialization::kAppend) { + auto serialized = DoAppendSerialization(ctx, custom_ops, *call_info, + compiled_buffer, std::move(*model)); + if (!serialized) { + return serialized.Error().Status(); + } + serialized->WriteStr(ctx.Out()); + + } else { + return kLiteRtStatusErrorUnsupported; + } + return kLiteRtStatusOk; +} + +} // namespace + +LiteRtStatus ApplyPlugin(ApplyPluginRun::Ptr run) { + Context context(std::move(run)); + DumpPreamble(context.Dump()); + + switch (context.Cmd()) { + case ApplyPluginRun::Cmd::INFO: + if (auto stat = ValidateInfoRun(context.Run()); stat != kLiteRtStatusOk) { + context.Dump().Labeled() << "Invalid arguments for INFO command\n"; + return stat; + } + return Info(context); + + case ApplyPluginRun::Cmd::PARTITION: + if (auto stat = ValidatePartitionRun(context.Run()); + stat != kLiteRtStatusOk) { + context.Dump().Labeled() << "Invalid arguments for PARTITION command\n"; + return stat; + } + return Partition(context); + + case ApplyPluginRun::Cmd::COMPILE: + if (auto stat = ValidateCompileRun(context.Run()); + stat != kLiteRtStatusOk) { + context.Dump().Labeled() << "Invalid arguments for COMPILE command\n"; + return stat; + } + return Compile(context); + + case ApplyPluginRun::Cmd::APPLY: + if (auto stat = ValidateApplyRun(context.Run()); + stat != kLiteRtStatusOk) { + context.Dump().Labeled() << "Invalid arguments for APPLY command\n"; + return stat; + } + return Apply(context); + + case ApplyPluginRun::Cmd::NOOP: + if (auto stat = ValidateNoopRun(context.Run()); stat != kLiteRtStatusOk) { + context.Dump().Labeled() << "Invalid arguments for NOOP command\n"; + return stat; + } + return Noop(context); + + default: + return kLiteRtStatusErrorInvalidArgument; + } + + return kLiteRtStatusOk; +} + +} // namespace litert::tools diff --git a/tflite/experimental/litert/tools/apply_plugin.h b/tflite/experimental/litert/tools/apply_plugin.h new file mode
100644 index 00000000..fcb1af03 --- /dev/null +++ b/tflite/experimental/litert/tools/apply_plugin.h @@ -0,0 +1,177 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_APPLY_PLUGIN_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_APPLY_PLUGIN_H_ + +#include <iostream> +#include <memory> +#include <optional> +#include <ostream> +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/cc/litert_detail.h" +#include "tflite/experimental/litert/core/byte_code_util.h" +#include "tflite/experimental/litert/tools/outstream.h" + +namespace litert::tools { + +using ::litert::internal::Serialization; + +// TODO remove these usings other than Ptr and OutStreamT + +struct ApplyPluginRun { + // NOTE: All StrFlagT are expected to have static storage duration. + using Ptr = std::unique_ptr<ApplyPluginRun>; + + // A specific command implemented by the tool to run. + enum class Cmd { + // Displays info about all plugins found in given search paths. + // + // FLAG SEMANTICS: + // "lib_search_paths": Required, at least one. + // "model": Ignored. + // "soc_manufacturer": Optional, filters plugins to display. + // "soc_models": Ignored. + // "outs": Required, must be size one. + // "dump_out": Optional. + // "serialization": Ignored. + INFO, + + // Does nothing and simply de-serializes and re-serializes the given model. + // This is intended for testing and internal debugging only. + // + // FLAG SEMANTICS: + // "lib_search_paths": Ignored. + // "model": Required. + // "soc_manufacturer": Ignored. + // "soc_models": Ignored. + // "outs": Required, must be size one. + // "dump_out": Optional. + // "serialization": Ignored. + NOOP, + + // Runs the entire end-to-end flow. This is the standard compiler plugin + // usage. A separate compilation step will occur for each soc_model tag that + // is supported by the loaded plugin, and a new output model will be + // generated for each. Partitioning is invariant across different soc_model + // targets from the same manufacturer, so only one partitioning step will + // occur even if multiple targets are requested. + // + // FLAG SEMANTICS: + // "lib_search_paths": Required, at least one. + // "model": Required. + // "soc_manufacturer": Required. + // "soc_models": Required, at least one. + // "outs": Required, must be size equal to "soc_models". + // "dump_out": Optional. + // "serialization": Required. + // + // TODO: Support multi target compilation. + APPLY, + + // Only run the partition step and skip compilation. Writes a ".tflite" model + // to "out" where selected partitions are manifested as new standard + // flatbuffer subgraphs added to the input model. + // The partitions' original locations are replaced with a single custom op + // that contains an identifier to the corresponding partition (new subgraph). + // This is intended for testing and development. + // + // FLAG SEMANTICS: + // "lib_search_paths": Required, at least one.
+ // "model": Required. + // "soc_manufacturer": Required. + // "soc_models": Ignored. + // "outs": Required, must be size one. + // "dump_out": Optional. + // "serialization": Ignored. + PARTITION, + + // Skip partitioning and run the entire input model through compilation + // directly. Fails if any ops in the input model are unsupported by the + // plugin. Writes the raw compiled result to the "out" stream without any + // wrapping flatbuffer. Runs multi-target compilation as in "APPLY", + // Intended for testing and development. + // + // FLAG SEMANTICS: + // "lib_search_paths": Required, at least one. + // "model": Required. + // "soc_manufacturer": Required. + // "soc_models": Required, at least one. + // "out": Required, must be size equal to "soc_models". + // "dump_out": Optional. + // "serialization": Ignored. + // + // TODO: Support multi target compilation. + COMPILE, + }; + + // A command to run, see above. + Cmd cmd; + + // Collection of paths on local files system dictating where the tool should + // look for suitable LiteRtCompilerPlugin shared libraries. The tool will + // select the first ".so" file found with prefix "libLiteRtPlugin" that has + // the "soc_manufacturer" tag passed. Providing more than one plugin shared + // library for the same manufacturer results in an error. + SmallVec lib_search_paths = {}; + + // Path to ".tflite" model the tool should operated on. + std::optional model = {}; + + // A tag representing a manufacturer the tool should target for compilation. + // This is used to select the appropriate plugin if multiple plugins are found + // in "lib_search_paths". + std::optional soc_manufacturer = {}; + + // Collection of soc models tags the tool should target for compilation. + SmallVec soc_models = {}; + + // Where the tool should write its result file(s) to. If the command runs + // compilation, an "out" stream should be passed for each "soc_model" target + // requested for compilation. Output for the "ith" target will be written to + // the "ith" outs stream. + SmallVec outs = {std::cout}; + + // Where to direct logging for this run. Passing nullopt here indicates + // "silent" behavior and should only be used when this tool is part of a + // larger pipeline like an end2end test. + UserStream dump_out; + + // Dictates how the final model with compiled assets should be serialized. + // Only relevant to the "apply" function. + // + // [METADATA] Write the compiled module into a metadata buffer using the + // soc_manufacturer as a key. This is for testing and debugging as it allows + // the contents of the byte code to be rendered by exisitng flatbuffer + // tooling. Custom op options will contain only a string identifying the + // respective entry point. + // + // [APPEND] Appends the compiled byte code to the end of the ".tflite" file. + // Custom options will contain both an entry point name, and an optional + // metadata lookup key. This facilitates per-op metadata while allowing + // multiple ops to share the same metadata if needed. Any instances of this + // metadata are pairs indicating the offset into the file where the byte code + // starts as well as the size of the byte code. 
+ Serialization serialization = Serialization::kMetadata; +}; + +LiteRtStatus ApplyPlugin(ApplyPluginRun::Ptr run); + +} // namespace litert::tools + +#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_APPLY_PLUGIN_H_ diff --git a/tflite/experimental/litert/tools/apply_plugin_main.cc b/tflite/experimental/litert/tools/apply_plugin_main.cc new file mode 100644 index 00000000..27d49a86 --- /dev/null +++ b/tflite/experimental/litert/tools/apply_plugin_main.cc @@ -0,0 +1,140 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <memory> +#include <string> +#include <utility> + +#include "absl/strings/str_format.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/CommandLine.h" +#include "tflite/experimental/litert/core/byte_code_util.h" +#include "tflite/experimental/litert/tools/apply_plugin.h" +#include "tflite/experimental/litert/tools/outstream.h" + +using ::litert::internal::Serialization; +using ::litert::tools::ApplyPlugin; +using ::litert::tools::ApplyPluginRun; +using ::litert::tools::UserStream; + +// NOLINTNEXTLINE +static llvm::cl::opt<std::string> cmd( + llvm::cl::Positional, + llvm::cl::desc("Routine to run (apply, partition, compile, info, noop)."), + llvm::cl::init("partition")); + +// NOLINTNEXTLINE +static llvm::cl::opt<std::string> model( + "model", llvm::cl::desc("Path to flatbuffer file."), llvm::cl::init("")); + +// TODO: b/366821557 - Support path to pre-compiled plugin in flags. +// NOLINTNEXTLINE +static llvm::cl::opt<std::string> soc_manufacturer( + "soc_man", + llvm::cl::desc("String identifier of SoC manufacturer (e.g., GoogleTensor, " + "Qualcomm)."), + llvm::cl::init("ExampleSocManufacturer")); + +// TODO: Support multi target compilation.
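+// Example invocation (hypothetical paths; flags as defined in this file): +//   apply_plugin_main apply --model=/tmp/one_mul.tflite \ +//     --soc_man=ExampleSocManufacturer --soc_model=ExampleSocModel \ +//     -o=/tmp/one_mul_compiled.tflite --serialization=METADATA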
+// NOLINTNEXTLINE +static llvm::cl::opt<std::string> soc_model("soc_model", + llvm::cl::desc("Target SoC model."), + llvm::cl::init("ExampleSocModel")); + +// NOLINTNEXTLINE +static llvm::cl::list<std::string> libs( + "libs", + llvm::cl::desc("List of directories in which to search for suitable " + "compiler plugin shared libraries."), + llvm::cl::list_init(llvm::ArrayRef<std::string>{ + "third_party/tensorflow/lite/experimental/litert/vendors/examples", + "third_party/tensorflow/lite/experimental/litert/vendors/qualcomm/" + "compiler"})); + +// NOLINTNEXTLINE +static llvm::cl::opt<std::string> out( + "o", + llvm::cl::desc("Path to file for output, \"-\" indicates standard out, " + "\"--\" for standard err, \"none\" for null stream."), + llvm::cl::init("-")); + +// NOLINTNEXTLINE +static llvm::cl::opt<std::string> err( + "err", + llvm::cl::desc("Path to file for err output, \"-\" indicates standard out, " + "\"--\" for standard err, \"none\" for null stream."), + llvm::cl::init("--")); + +// NOLINTNEXTLINE +static llvm::cl::opt<std::string> serialization( + "serialization", llvm::cl::desc("Serialization strategy to use."), + llvm::cl::init("METADATA")); + +ApplyPluginRun::Ptr ParseFlags() { + auto res = std::make_unique<ApplyPluginRun>(); + + if (!model.empty()) { + res->model = model; + } + + res->soc_manufacturer = soc_manufacturer; + res->soc_models.push_back(soc_model); + + res->lib_search_paths.assign(libs.begin(), libs.end()); + + if (cmd == "apply") { + res->cmd = ApplyPluginRun::Cmd::APPLY; + } else if (cmd == "partition") { + res->cmd = ApplyPluginRun::Cmd::PARTITION; + } else if (cmd == "compile") { + res->cmd = ApplyPluginRun::Cmd::COMPILE; + } else if (cmd == "info") { + res->cmd = ApplyPluginRun::Cmd::INFO; + } else if (cmd == "noop") { + res->cmd = ApplyPluginRun::Cmd::NOOP; + } else { + return nullptr; + } + + if (serialization == "METADATA") { + res->serialization = Serialization::kMetadata; + } else if (serialization == "APPEND") { + res->serialization = Serialization::kAppend; + } else { + res->serialization = Serialization::kUnknown; + } + + return res; +} + +int main(int argc, char* argv[]) { + llvm::cl::ParseCommandLineOptions(argc, argv); + + auto run = ParseFlags(); + if (run == nullptr) { + return 1; + } + + auto out_stream = UserStream::MakeFromFlag(out); + run->outs.clear(); + run->outs.push_back(out_stream.Get()); + + run->dump_out = UserStream::MakeFromFlag(err); + + run->dump_out.Get() << absl::StreamFormat( + "CMD: %s\nMODEL: %s\nSOC_MANUFACTURER: %s\nSOC_MODEL: %s\n", cmd, model, + soc_manufacturer, soc_model); + + return ApplyPlugin(std::move(run)); +} diff --git a/tflite/experimental/litert/tools/apply_plugin_test.cc b/tflite/experimental/litert/tools/apply_plugin_test.cc new file mode 100644 index 00000000..1503aabb --- /dev/null +++ b/tflite/experimental/litert/tools/apply_plugin_test.cc @@ -0,0 +1,227 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
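+ +// These tests drive each ApplyPluginRun::Cmd end-to-end against the example +// plugin: each builds a run config with MakeBaseRun, tweaks it, then asserts +// on ApplyPlugin's status and captured output.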
+ +#include "tflite/experimental/litert/tools/apply_plugin.h" + +#include +#include +#include +#include + +#include +#include +#include "absl/log/absl_check.h" +#include "absl/strings/string_view.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_model.h" +#include "tflite/experimental/litert/c/litert_op_code.h" +#include "tflite/experimental/litert/cc/litert_buffer_ref.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/core/byte_code_util.h" +#include "tflite/experimental/litert/core/model/model.h" +#include "tflite/experimental/litert/test/common.h" + +namespace litert::tools { +namespace { + +using ::litert::internal::kByteCodeMetadataKey; +using ::litert::internal::kLiteRtBuildStampKey; +using ::litert::internal::ParseBuildStamp; +using ::litert::internal::ParseByteCodePlaceholder; +using ::litert::internal::ParseExecInfo; +using ::litert::internal::Serialization; +using ::testing::HasSubstr; + +static constexpr absl::string_view kPluginSearchPath = + "tflite/experimental/litert/vendors/examples"; + +static constexpr absl::string_view kSocManufacturer = "ExampleSocManufacturer"; + +static constexpr absl::string_view kSocModel = "ExampleSocModel"; + +absl::string_view TestModelPath() { + static char kModelPath[512] = {}; + if (kModelPath[0] == '\0') { + const auto model_path = + ::litert::testing::GetTestFilePath("one_mul.tflite"); + ABSL_CHECK(model_path.size() < 512); + model_path.copy(kModelPath, model_path.size(), 0); + } + return kModelPath; +} + +ApplyPluginRun::Ptr MakeBaseRun(ApplyPluginRun::Cmd cmd) { + auto run = std::make_unique(); + run->cmd = cmd; + run->lib_search_paths.push_back(kPluginSearchPath); + run->model.emplace(TestModelPath()); + run->soc_manufacturer.emplace(kSocManufacturer); + run->soc_models.push_back(kSocModel); + run->outs.clear(); + return run; +} + +TEST(TestApplyPluginTool, TestInfoBadConfig) { + auto run = MakeBaseRun(ApplyPluginRun::Cmd::INFO); + run->lib_search_paths.clear(); + LITERT_ASSERT_STATUS_HAS_CODE(ApplyPlugin(std::move(run)), + kLiteRtStatusErrorInvalidToolConfig); +} + +TEST(TestApplyPluginTool, TestInfo) { + auto run = MakeBaseRun(ApplyPluginRun::Cmd::INFO); + std::stringstream out; + run->outs.push_back(out); + LITERT_ASSERT_STATUS_OK(ApplyPlugin(std::move(run))); + EXPECT_THAT(out.str(), + ::testing::HasSubstr( + "< LiteRtCompilerPlugin > \"ExampleSocManufacturer\" | " + "\"ExampleSocModel\"")); +} + +TEST(TestApplyPluginTool, TestNoopBadConfig) { + auto run = MakeBaseRun(ApplyPluginRun::Cmd::NOOP); + run->model.reset(); + LITERT_ASSERT_STATUS_HAS_CODE(ApplyPlugin(std::move(run)), + kLiteRtStatusErrorInvalidToolConfig); +} + +TEST(TestApplyPluginTool, TestNoop) { + auto run = MakeBaseRun(ApplyPluginRun::Cmd::NOOP); + std::stringstream out; + run->outs.push_back(out); + LITERT_ASSERT_STATUS_OK(ApplyPlugin(std::move(run))); + + auto model = Model::CreateFromBuffer( + BufferRef(out.view().data(), out.view().size())); + EXPECT_EQ(model->Get()->NumSubgraphs(), 1); +} + +TEST(TestApplyPluginTool, TestPartitionBadConfig) { + auto run = MakeBaseRun(ApplyPluginRun::Cmd::PARTITION); + run->model.reset(); + LITERT_ASSERT_STATUS_HAS_CODE(ApplyPlugin(std::move(run)), + kLiteRtStatusErrorInvalidToolConfig); +} + +TEST(TestApplyPluginTool, TestPartition) { + auto run = MakeBaseRun(ApplyPluginRun::Cmd::PARTITION); + std::stringstream out; + run->outs.push_back(out); + LITERT_ASSERT_STATUS_OK(ApplyPlugin(std::move(run))); + 
EXPECT_FALSE(out.str().empty()); +} + +TEST(TestApplyPluginTool, TestCompileBadConfig) { + auto run = MakeBaseRun(ApplyPluginRun::Cmd::COMPILE); + run->model.reset(); + LITERT_ASSERT_STATUS_HAS_CODE(ApplyPlugin(std::move(run)), + kLiteRtStatusErrorInvalidToolConfig); +} + +TEST(TestApplyPluginTool, TestCompile) { + auto run = MakeBaseRun(ApplyPluginRun::Cmd::COMPILE); + std::stringstream out; + run->outs.push_back(out); + LITERT_ASSERT_STATUS_OK(ApplyPlugin(std::move(run))); + EXPECT_FALSE(out.str().empty()); + EXPECT_THAT(out.str(), HasSubstr("Partition_0_with_1_muls")); +} + +TEST(TestApplyPluginTool, TestApplyBadConfig) { + auto run = MakeBaseRun(ApplyPluginRun::Cmd::APPLY); + run->model.reset(); + LITERT_ASSERT_STATUS_HAS_CODE(ApplyPlugin(std::move(run)), + kLiteRtStatusErrorInvalidToolConfig); +} + +TEST(TestApplyPluginTool, TestApply) { + auto run = MakeBaseRun(ApplyPluginRun::Cmd::APPLY); + std::stringstream out; + run->outs.push_back(out); + LITERT_ASSERT_STATUS_OK(ApplyPlugin(std::move(run))); + + auto model = Model::CreateFromBuffer( + BufferRef(out.str().data(), out.str().size())); + EXPECT_EQ(model->Get()->NumSubgraphs(), 1); + + { + auto stamp_buffer = model->Get()->FindMetadata(kLiteRtBuildStampKey); + auto stamp = ParseBuildStamp(*stamp_buffer); + auto [man, soc_model, serial] = *stamp; + EXPECT_EQ(man, kSocManufacturer); + EXPECT_EQ(soc_model, kSocModel); + EXPECT_EQ(serial, Serialization::kMetadata); + } + + { + const auto& custom_op = model->Get()->Subgraph(0).Op(0); + ASSERT_EQ(custom_op.OpCode(), kLiteRtOpCodeTflCustom); + EXPECT_EQ(custom_op.CustomOptions().StrView(), "Partition_0"); + } + + { + auto byte_code_buffer = model->Get()->FindMetadata(kByteCodeMetadataKey); + EXPECT_THAT(byte_code_buffer->StrView(), + HasSubstr("Partition_0_with_1_muls")); + } +} + +// NOLINTBEGIN +TEST(TestApplyPluginTool, TestApplyWithAppendSerialization) { +#ifndef NDEBUG + GTEST_SKIP() << "Flatbuffers assertion will fail in append mode\n"; +#endif + std::stringstream out; + { + auto run = MakeBaseRun(ApplyPluginRun::Cmd::APPLY); + run->serialization = Serialization::kAppend; + run->outs.push_back(out); + LITERT_ASSERT_STATUS_OK(ApplyPlugin(std::move(run))); + } + + BufferRef serialized(out.str().data(), out.str().size()); + + auto model = Model::CreateFromBuffer(serialized); + EXPECT_EQ(model->Get()->NumSubgraphs(), 1); + + { + auto stamp_buffer = model->Get()->FindMetadata(kLiteRtBuildStampKey); + auto stamp = ParseBuildStamp(*stamp_buffer); + auto [man, model, serial] = *stamp; + EXPECT_EQ(man, kSocManufacturer); + EXPECT_EQ(model, kSocModel); + EXPECT_EQ(serial, Serialization::kAppend); + } + + { + const auto& custom_op = model->Get()->Subgraph(0).Op(0); + ASSERT_EQ(custom_op.OpCode(), kLiteRtOpCodeTflCustom); + + auto options = ParseExecInfo(custom_op.CustomOptions()); + auto [entry_point, metadata_key] = *options; + EXPECT_EQ(entry_point, "Partition_0"); + + auto metadata = model->Get()->FindMetadata(metadata_key); + auto byte_code_info = ParseByteCodePlaceholder(*metadata); + auto [offset, size] = *byte_code_info; + + EXPECT_EQ(serialized.StrView().substr(offset, size), + "Partition_0_with_1_muls:"); + } +} +// NOLINTEND + +} // namespace +} // namespace litert::tools diff --git a/tflite/experimental/litert/tools/dump.cc b/tflite/experimental/litert/tools/dump.cc new file mode 100644 index 00000000..f6854aa8 --- /dev/null +++ b/tflite/experimental/litert/tools/dump.cc @@ -0,0 +1,436 @@ +// Copyright 2024 Google LLC. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/tools/dump.h"
+
+#include <dlfcn.h>
+
+#ifndef __ANDROID__
+#include <link.h>
+#endif
+
+#include <cstdint>
+#include <iostream>
+#include <ostream>
+#include <string>
+#include <vector>
+
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/c/litert_op_code.h"
+#include "tflite/experimental/litert/compiler/plugin/compiler_plugin.h"
+#include "tflite/experimental/litert/core/model/model.h"
+
+namespace litert::internal {
+
+namespace {
+
+static constexpr int kMaxDisplayCount = 16;
+
+void DumpNode(const LiteRtTensorT& tensor, std::ostream& out) {
+  switch (tensor.Type().first) {
+    case kLiteRtRankedTensorType:
+      Dump(tensor.Type().second.ranked_tensor_type, out);
+      break;
+    case kLiteRtUnrankedTensorType:
+      Dump(tensor.Type().second.unranked_tensor_type.element_type, out);
+      break;
+    default:
+      out << "UNKNOWN_TENSOR_TYPE" << tensor.Type().first;
+  }
+  Dump(tensor.Qparams(), out);
+}
+
+void DumpNode(const LiteRtOpT& op, std::ostream& out) {
+  Dump(op.OpCode(), out);
+}
+
+void DumpSignature(const std::vector<LiteRtTensor>& ins,
+                   const std::vector<LiteRtTensor>& outs, std::ostream& out) {
+  out << "(";
+  for (auto it = ins.begin(); it < ins.end(); ++it) {
+    DumpNode(**it, out);
+    if (it != ins.end() - 1) {
+      out << ", ";
+    }
+  }
+  out << ")";
+
+  out << " -> ";
+  const bool paren_outs = outs.size() != 1;
+  if (paren_outs) {
+    out << "(";
+  }
+  for (auto it = outs.begin(); it < outs.end(); ++it) {
+    DumpNode(**it, out);
+    if (it != outs.end() - 1) {
+      out << ", ";
+    }
+  }
+  if (paren_outs) {
+    out << ")";
+  }
+}
+
+} // namespace
+
+void Dump(LiteRtOpCode code, std::ostream& out) {
+  switch (code) {
+    case kLiteRtOpCodeTflAdd:
+      out << "TFL_ADD";
+      break;
+    case kLiteRtOpCodeTflMul:
+      out << "TFL_MUL";
+      break;
+    case kLiteRtOpCodeTflCustom:
+      out << "TFL_CUSTOM_OP";
+      break;
+    case kLiteRtOpCodeTflSlice:
+      out << "TFL_SLICE";
+      break;
+    case kLiteRtOpCodeTflDiv:
+      out << "TFL_DIV";
+      break;
+    case kLiteRtOpCodeTflRsqrt:
+      out << "TFL_RSQRT";
+      break;
+    case kLiteRtOpCodeTflTanh:
+      out << "TFL_TANH";
+      break;
+    case kLiteRtOpCodeTflSub:
+      out << "TFL_SUB";
+      break;
+    case kLiteRtOpCodeTflReshape:
+      out << "TFL_RESHAPE";
+      break;
+    case kLiteRtOpCodeTflBatchMatmul:
+      out << "TFL_BATCH_MATMUL";
+      break;
+    case kLiteRtOpCodeTflSum:
+      out << "TFL_SUM";
+      break;
+    case kLiteRtOpCodeTflConcatenation:
+      out << "TFL_CONCATENATION";
+      break;
+    case kLiteRtOpCodeTflSoftmax:
+      out << "TFL_SOFTMAX";
+      break;
+    case kLiteRtOpCodeTflCast:
+      out << "TFL_CAST";
+      break;
+    case kLiteRtOpCodeTflTranspose:
+      out << "TFL_TRANSPOSE";
+      break;
+    case kLiteRtOpCodeTflSin:
+      out << "TFL_SIN";
+      break;
+    case kLiteRtOpCodeTflCos:
+      out << "TFL_COS";
+      break;
+    case kLiteRtOpCodeTflSelect:
+      out << "TFL_SELECT";
+      break;
+    case kLiteRtOpCodeTflSelectV2:
+      out << "TFL_SELECT_V2";
+      break;
+    case kLiteRtOpCodeTflFullyConnected:
+      out << "TFL_FULLY_CONNECTED";
+      break;
+    case kLiteRtOpCodeTflEmbeddingLookup:
+      out << "TFL_EMBEDDING_LOOKUP";
+      break;
+    case kLiteRtOpCodeTflLogicalAnd:
+      out << "TFL_LOGICAL_AND";
+      break;
+    case kLiteRtOpCodeTflLess:
+      out << "TFL_LESS";
+      break;
+    case kLiteRtOpCodeTflGreater:
+      out << "TFL_GREATER";
+      break;
+    default:
+      out << "UNKNOWN_OP_CODE: " << code;
+      break;
+  }
+}
+
+// Dump details about the given LiteRtElementType to the given stream.
+void Dump(LiteRtElementType type, std::ostream& out) {
+  switch (type) {
+    case kLiteRtElementTypeFloat32:
+      out << "f32";
+      break;
+    case kLiteRtElementTypeInt32:
+      out << "i32";
+      break;
+    case kLiteRtElementTypeFloat64:
+      out << "f64";
+      break;
+    case kLiteRtElementTypeInt64:
+      out << "i64";
+      break;
+    case kLiteRtElementTypeFloat16:
+      out << "f16";
+      break;
+    case kLiteRtElementTypeInt16:
+      out << "i16";
+      break;
+    case kLiteRtElementTypeInt8:
+      out << "i8";
+      break;
+    case kLiteRtElementTypeUInt8:
+      out << "ui8";
+      break;
+    case kLiteRtElementTypeBool:
+      out << "i1";
+      break;
+    default:
+      out << "UNKNOWN_ELEMENT_TYPE: " << type;
+  }
+}
+
+void Dump(const LiteRtRankedTensorType& type, std::ostream& out) {
+  out << "<";
+  for (int i = 0; i < type.layout.rank; ++i) {
+    out << type.layout.dimensions[i] << "x";
+  }
+  Dump(type.element_type, out);
+  out << ">";
+}
+
+void Dump(const LiteRtTensorT& tensor, std::ostream& out) {
+  out << "LiteRtTensor : ";
+  DumpNode(tensor, out);
+  out << " [ ";
+  if (tensor.DefiningOp() == nullptr) {
+    out << "*";
+  } else {
+    DumpNode(*tensor.DefiningOp(), out);
+  }
+  out << " ] ";
+
+  out << "(";
+  for (auto it = tensor.Users().begin(); it < tensor.Users().end(); ++it) {
+    DumpNode(**it, out);
+    if (it != tensor.Users().end() - 1) {
+      out << ", ";
+    }
+  }
+  out << ")";
+  out << "\n";
+}
+
+void Dump(const LiteRtOpT& op, std::ostream& out) {
+  out << "LiteRtOp : [ ";
+  DumpNode(op, out);
+  out << " ] ";
+  DumpSignature(op.Inputs(), op.Outputs(), out);
+  out << "\n";
+}
+
+void Dump(const LiteRtSubgraphT& subgraph, std::ostream& out) {
+  constexpr absl::string_view kSubgraphTpl =
+      "LiteRtSubgraph : [ #ops=%d #tensors=%d ] ";
+  out << absl::StreamFormat(kSubgraphTpl, subgraph.Ops().size(),
+                            subgraph.Tensors().size());
+  DumpSignature(subgraph.Inputs(), subgraph.Outputs(), out);
+  out << "\n";
+}
+
+void Dump(const CompilerPlugin& plugin, std::ostream& out) {
+  constexpr absl::string_view kPluginDumpTpl =
+      "SocManufacturer: %s\nSocModels: { ";
+  out << absl::StreamFormat(kPluginDumpTpl, plugin.SocManufacturer());
+
+  for (auto it = plugin.SocModels().begin(); it < plugin.SocModels().end();
+       ++it) {
+    out << *it;
+    if (it != plugin.SocModels().end() - 1) {
+      out << ",";
+    }
+    out << " ";
+  }
+
+  out << "}\n";
+}
+
+void DumpDLL(void* lib_handle, std::ostream& out) {
+#ifndef __ANDROID__
+  out << "\n--- Lib Info ---\n";
+  if (lib_handle == nullptr) {
+    out << "Handle is nullptr\n";
+    return;
+  }
+
+  Lmid_t dl_ns_idx;
+  if (0 != ::dlinfo(lib_handle, RTLD_DI_LMID, &dl_ns_idx)) {
+    return;
+  }
+
+  std::string dl_origin;
+  dl_origin.resize(512);
+  if (0 != ::dlinfo(lib_handle, RTLD_DI_ORIGIN, dl_origin.data())) {
+    return;
+  }
+
+  link_map* lm;
+  if (0 != ::dlinfo(lib_handle, RTLD_DI_LINKMAP, &lm)) {
+    return;
+  }
+
+  out << "Lib Namespace: " << dl_ns_idx << "\n";
+  out << "Lib Origin: " << dl_origin << "\n";
+
+  out << "loaded objects:\n";
+
+  auto* forward = lm->l_next;
+  auto* backward = lm->l_prev;
+
+  while (forward != nullptr) {
+    out << " " << forward->l_name << "\n";
+    forward = forward->l_next;
+  }
+
+  out << "***" << lm->l_name << "\n";
+
+  while
(backward != nullptr) { + out << " " << backward->l_name << "\n"; + backward = backward->l_prev; + } + + out << "\n"; +#endif +} + +void Dump(const LiteRtModelT& model, std::ostream& out) { + out << absl::StreamFormat("LiteRtModel : [ #subgraphs=%d ]\n", + model.Subgraphs().size()); +} + +void DumpOptions(const LiteRtOpT& op, std::ostream& out) { + auto& opts = detail::GetTflOptions(op); + if (opts.value == nullptr) { + out << "null options\n"; + return; + } + switch (op.OpCode()) { + case kLiteRtOpCodeTflAdd: + out << "fused_activation_function: " + << opts.AsAddOptions()->fused_activation_function << "\n"; + break; + case kLiteRtOpCodeTflMul: + out << "fused_activation_function: " + << opts.AsMulOptions()->fused_activation_function << "\n"; + break; + case kLiteRtOpCodeTflBatchMatmul: + out << "adj_x: " << opts.AsBatchMatMulOptions()->adj_x << "\n"; + out << "adj_y: " << opts.AsBatchMatMulOptions()->adj_y << "\n"; + out << "asymmetric_quantize_input: " + << opts.AsBatchMatMulOptions()->asymmetric_quantize_inputs << "\n"; + break; + case kLiteRtOpCodeTflConcatenation: + out << "axis: " << opts.AsConcatenationOptions()->axis << "\n"; + out << "fused_activation_function: " + << opts.AsConcatenationOptions()->fused_activation_function << "\n"; + break; + case kLiteRtOpCodeTflDiv: + out << "fused_activation_function: " + << opts.AsDivOptions()->fused_activation_function << "\n"; + break; + case kLiteRtOpCodeTflFullyConnected: + out << "weights_format: " + << opts.AsFullyConnectedOptions()->weights_format << "\n"; + out << "keep_num_dims: " << opts.AsFullyConnectedOptions()->keep_num_dims + << "\n"; + out << "quantized_bias_type: " + << opts.AsFullyConnectedOptions()->quantized_bias_type << "\n"; + out << "asymmetric_quantize_input: " + << opts.AsFullyConnectedOptions()->asymmetric_quantize_inputs << "\n"; + out << "fused_activation_function: " + << opts.AsFullyConnectedOptions()->fused_activation_function << "\n"; + break; + case kLiteRtOpCodeTflSoftmax: + out << "beta: " << opts.AsSoftmaxOptions()->beta << "\n"; + break; + case kLiteRtOpCodeTflStridedSlice: + out << "begin_mask: " << opts.AsStridedSliceOptions()->begin_mask << "\n"; + out << "end_mask: " << opts.AsStridedSliceOptions()->end_mask << "\n"; + out << "ellipsis_mask: " << opts.AsStridedSliceOptions()->ellipsis_mask + << "\n"; + out << "new_axis_mask: " << opts.AsStridedSliceOptions()->new_axis_mask + << "\n"; + out << "shrink_axis_mask: " + << opts.AsStridedSliceOptions()->shrink_axis_mask << "\n"; + out << "offset: " << opts.AsStridedSliceOptions()->offset << "\n"; + break; + case kLiteRtOpCodeTflSub: + out << "fused_activation_function: " + << opts.AsSubOptions()->fused_activation_function << "\n"; + break; + case kLiteRtOpCodeTflReshape: + out << "new_shape: "; + if (opts.AsReshapeOptions() != nullptr) { + const int32_t* new_shape = opts.AsReshapeOptions()->new_shape.data(); + int32_t new_shape_size = opts.AsReshapeOptions()->new_shape.size(); + for (int i = 0; i < new_shape_size; ++i) { + out << new_shape[i] << " "; + } + } + break; + case kLiteRtOpCodeTflSum: + out << "keepdims: " << opts.AsReducerOptions()->keep_dims << "\n"; + break; + default: + out << "No options for op code: " << op.OpCode(); + break; + } +} + +void Dump(Quantization quantization, std::ostream& out) { + int max_display_count; + switch (quantization.first) { + case kLiteRtQuantizationNone: + return; + case kLiteRtQuantizationPerTensor: + out << absl::StreamFormat(" ", + quantization.second.per_tensor.zero_point, + quantization.second.per_tensor.scale); + 
return; + case kLiteRtQuantizationPerChannel: + max_display_count = + kMaxDisplayCount < quantization.second.per_channel.num_channels + ? kMaxDisplayCount + : quantization.second.per_channel.num_channels; + out << absl::StreamFormat(" ", quantization.second.per_channel.quantized_dimension); + return; + default: + out << " "; + return; + } +} + +} // namespace litert::internal diff --git a/tflite/experimental/litert/tools/dump.h b/tflite/experimental/litert/tools/dump.h new file mode 100644 index 00000000..41e37fb1 --- /dev/null +++ b/tflite/experimental/litert/tools/dump.h @@ -0,0 +1,71 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_DUMP_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_DUMP_H_ + +#include +#include +#include + +#include "tflite/experimental/litert/c/litert_model.h" +#include "tflite/experimental/litert/compiler/plugin/compiler_plugin.h" +#include "tflite/experimental/litert/core/model/model.h" + +namespace litert::internal { + +// +// LiteRt IR +// + +// Dump details about the given LiteRtOpT to the given stream. +void Dump(const LiteRtOpT& op, std::ostream& out = std::cerr); + +// Dump details about the given LiteRtSubgraphT to the given stream. +void Dump(const LiteRtSubgraphT& subgraph, std::ostream& out = std::cerr); + +// Dump details about the given LiteRtTensorT to the given stream. +void Dump(const LiteRtTensorT& tensor, std::ostream& out = std::cerr); + +// Dump details about the given LiteRtOpCode to the given stream. +void Dump(LiteRtOpCode code, std::ostream& out = std::cerr); + +// Dump details about the given LiteRtElementType to the given stream. +void Dump(LiteRtElementType type, std::ostream& out = std::cerr); + +// Dump details about the given LiteRtRankedTensorType to the given stream. +void Dump(const LiteRtRankedTensorType& type, std::ostream& out = std::cerr); + +// Dump details about the given LiteRtModel to the given stream. +void Dump(const LiteRtModelT& model, std::ostream& out = std::cerr); + +// Dump details about the given quantization params. +void Dump(Quantization quantization, std::ostream& out = std::cerr); + +// Dump details about options +void DumpOptions(const LiteRtOpT& op, std::ostream& out = std::cerr); + +// +// Library Utilities +// + +// Dumps details about the loaded LiteRtCompilerPlugin library. +void Dump(const CompilerPlugin& plugin, std::ostream& out = std::cerr); + +// Dumps details about the dynamic library (see "dlinfo"). +void DumpDLL(void* lib_handle, std::ostream& out = std::cerr); + +} // namespace litert::internal + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_DUMP_H_ diff --git a/tflite/experimental/litert/tools/dump_test.cc b/tflite/experimental/litert/tools/dump_test.cc new file mode 100644 index 00000000..f1760625 --- /dev/null +++ b/tflite/experimental/litert/tools/dump_test.cc @@ -0,0 +1,131 @@ +// Copyright 2024 Google LLC. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/tools/dump.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <sstream>
+#include <utility>
+
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/core/model/model.h"
+#include "tflite/experimental/litert/test/common.h"
+
+namespace {
+
+using ::litert::internal::Dump;
+using ::litert::internal::DumpOptions;
+using ::litert::testing::LoadTestFileModel;
+
+TEST(DumpTest, TestDump) {
+  auto model = LoadTestFileModel("one_mul.tflite");
+
+  {
+    std::ostringstream model_dump;
+    Dump(*model.Get(), model_dump);
+    EXPECT_EQ(model_dump.view(), "LiteRtModel : [ #subgraphs=1 ]\n");
+  }
+
+  {
+    const LiteRtTensorT& in_tensor = model.Get()->Subgraph(0).Input(0);
+    std::ostringstream in_tensor_dump;
+    Dump(in_tensor, in_tensor_dump);
+    EXPECT_EQ(in_tensor_dump.view(),
+              "LiteRtTensor : <2x2xf32> [ * ] (TFL_MUL)\n");
+  }
+
+  {
+    const LiteRtTensorT& out_tensor = model.Get()->Subgraph(0).Output(0);
+    std::ostringstream out_tensor_dump;
+    Dump(out_tensor, out_tensor_dump);
+    EXPECT_EQ(out_tensor_dump.view(),
+              "LiteRtTensor : <2x2xf32> [ TFL_MUL ] ()\n");
+  }
+
+  {
+    const LiteRtOpT& op = model.Get()->Subgraph(0).Op(0);
+    std::ostringstream op_dump;
+    Dump(op, op_dump);
+    EXPECT_EQ(op_dump.view(),
+              "LiteRtOp : [ TFL_MUL ] (<2x2xf32>, <2x2xf32>) -> <2x2xf32>\n");
+  }
+
+  {
+    const LiteRtSubgraphT& subgraph = model.Get()->Subgraph(0);
+    std::ostringstream subgraph_dump;
+    Dump(subgraph, subgraph_dump);
+    EXPECT_EQ(
+        subgraph_dump.view(),
+        "LiteRtSubgraph : [ #ops=1 #tensors=3 ] (<2x2xf32>, <2x2xf32>) -> "
+        "<2x2xf32>\n");
+  }
+}
+
+TEST(DumpTest, TestDumpOptions) {
+  auto model = LoadTestFileModel("simple_strided_slice_op.tflite");
+  const LiteRtOpT& op = model.Get()->Subgraph(0).Op(0);
+  std::ostringstream op_dump;
+  DumpOptions(op, op_dump);
+  EXPECT_EQ(op_dump.view(),
+            "begin_mask: 0\n"
+            "end_mask: 0\n"
+            "ellipsis_mask: 0\n"
+            "new_axis_mask: 0\n"
+            "shrink_axis_mask: 0\n"
+            "offset: 0\n");
+}
+
+TEST(DumpTest, TestDumpPerTensorQuantization) {
+  QuantizationDetail per_tensor_detail;
+  per_tensor_detail.per_tensor.scale = 1.0;
+  per_tensor_detail.per_tensor.zero_point = 2;
+  std::ostringstream q_dump;
+  Dump(std::make_pair(kLiteRtQuantizationPerTensor, per_tensor_detail), q_dump);
+  EXPECT_EQ(q_dump.view(), " ");
+}
+
+TEST(DumpTest, TestDumpPerChannelQuantization) {
+  static constexpr size_t kRank = 2;
+  static constexpr size_t kQuantizedDimension = 1;
+  static constexpr float kScales[kRank] = {1.0, 2.0};
+  static constexpr int64_t kZps[kRank] = {2, 3};
+  QuantizationDetail per_channel_detail;
+  per_channel_detail.per_channel.scales = const_cast<float*>(kScales);
+  per_channel_detail.per_channel.zero_points = const_cast<int64_t*>(kZps);
+  per_channel_detail.per_channel.quantized_dimension = kQuantizedDimension;
+  per_channel_detail.per_channel.num_channels = kRank;
+  std::ostringstream q_dump;
+  Dump(std::make_pair(kLiteRtQuantizationPerChannel, per_channel_detail),
+       q_dump);
+  EXPECT_FALSE(q_dump.view().empty());
+}
+
+TEST(DumpTest, TestDumpNoQuantization) {
+  QuantizationDetail none_detail;
+  std::ostringstream q_dump;
+  Dump(std::make_pair(kLiteRtQuantizationNone, none_detail), q_dump);
+  EXPECT_TRUE(q_dump.view().empty());
+}
+
+TEST(DumpTest, TestDumpUnknownQuantization) {
+  QuantizationDetail detail;
+  std::ostringstream q_dump;
+  Dump(std::make_pair(kLiteRtQuantizationBlockWise, detail), q_dump);
+  EXPECT_EQ(q_dump.view(), " ");
+}
+
+} // namespace
diff --git a/tflite/experimental/litert/tools/outstream.h b/tflite/experimental/litert/tools/outstream.h
new file mode 100644
index 00000000..c7e7e064
--- /dev/null
+++ b/tflite/experimental/litert/tools/outstream.h
@@ -0,0 +1,83 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_OUTSTREAM_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_OUTSTREAM_H_
+
+#include <fstream>
+#include <functional>
+#include <iostream>
+#include <memory>
+#include <ostream>
+#include <utility>
+
+#include "absl/strings/string_view.h"
+#include "tflite/experimental/litert/c/litert_logging.h"
+
+namespace litert::tools {
+
+using OutStream = std::reference_wrapper<std::ostream>;
+using OutStreamPtr = std::unique_ptr<std::ostream>;
+
+// Output stream configured by the user via a flag.
+class UserStream {
+ public:
+  // Parse the flag and get a configured stream.
+  static UserStream MakeFromFlag(absl::string_view flag) {
+    if (flag == kCerr) {
+      LITERT_LOG(LITERT_INFO, "Setup cerr stream\n", "");
+      return UserStream(std::cerr);
+    } else if (flag == kCout) {
+      LITERT_LOG(LITERT_INFO, "Setup cout stream\n", "");
+      return UserStream(std::cout);
+    } else if (flag == kNone) {
+      LITERT_LOG(LITERT_INFO, "Setup null stream\n", "");
+      return UserStream();
+    } else {
+      // File stream.
+      LITERT_LOG(LITERT_INFO, "Setup file stream\n", "");
+      auto ofstream = std::make_unique<std::ofstream>();
+      ofstream->open(flag.data());
+      return UserStream(std::move(ofstream));
+    }
+  }
+
+  // Get the actual stream to write to.
+  OutStream Get() { return used_; }
+
+  // Silent stream.
+  UserStream()
+      : stored_(std::make_unique<std::ostream>(nullptr)), used_(*stored_) {}
+  // From reference to external stream (cerr, cout)
+  explicit UserStream(OutStream ostream) : stored_(nullptr), used_(ostream) {}
+  // From stream to internalize.
+  explicit UserStream(OutStreamPtr ostream)
+      : stored_(std::move(ostream)), used_(*stored_) {}
+
+  UserStream(UserStream&&) = default;
+  UserStream& operator=(UserStream&&) = default;
+
+ private:
+  // These are used by the various CLI flags that configure output streams.
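+  // For example (a usage sketch; any other string is treated as a path to a
+  // file to open):
+  //
+  //   auto errs = UserStream::MakeFromFlag("--");   // std::cerr
+  //   auto outs = UserStream::MakeFromFlag("-");    // std::cout
+  //   auto none = UserStream::MakeFromFlag("none"); // silent stream
+  //   outs.Get() << "hello\n";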
+ static constexpr absl::string_view kCerr = "--"; + static constexpr absl::string_view kCout = "-"; + static constexpr absl::string_view kNone = "none"; + + OutStreamPtr stored_; + OutStream used_; +}; + +} // namespace litert::tools + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_OUTSTREAM_H_ diff --git a/tflite/experimental/litert/tools/tool_display.cc b/tflite/experimental/litert/tools/tool_display.cc new file mode 100644 index 00000000..cd535fef --- /dev/null +++ b/tflite/experimental/litert/tools/tool_display.cc @@ -0,0 +1,86 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tflite/experimental/litert/tools/tool_display.h" + +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "tflite/experimental/litert/tools/outstream.h" + +namespace litert::tools { + +std::string ToolDisplay::MakeLabel(absl::string_view tool_label) { + return absl::StrFormat( + "[LITERT_TOOLS%s] ", + tool_label.empty() ? tool_label : absl::StrFormat(":%s", tool_label)); +} + +std::ostream& ToolDisplay::Display() { return ostream_.Get(); } + +std::ostream& ToolDisplay::Labeled() { + Display() << label_; + return Display(); +} + +std::ostream& ToolDisplay::Indented() { + Display() << "\t"; + return Display(); +} + +void ToolDisplay::Start(const absl::string_view scope_name) { + static constexpr absl::string_view kStartFmt = "Starting %s...\n"; + Labeled() << absl::StreamFormat(kStartFmt, scope_name); +} + +void ToolDisplay::Done(const absl::string_view scope_name) { + static constexpr absl::string_view kDoneFmt = "%s Done!\n"; + Labeled() << ""; + Indented() << absl::StreamFormat(kDoneFmt, scope_name); +} + +void ToolDisplay::Fail() { + Labeled() << ""; + Indented() << "Failed\n"; +} + +ToolDisplay::LoggedScope ToolDisplay::StartS(absl::string_view scope_name) { + return LoggedScope(*this, scope_name); +} + +void ToolDisplay::LoggedScope::Start() { parent_.Start(scope_name_); } + +void ToolDisplay::LoggedScope::Done() { parent_.Done(scope_name_); } + +ToolDisplay::LoggedScope::~LoggedScope() { Done(); } + +ToolDisplay::LoggedScope::LoggedScope(ToolDisplay& parent, + absl::string_view scope_name) + : parent_(parent), scope_name_(scope_name) { + Start(); +} + +static constexpr absl::string_view kArt = R"( + __ _ __ ____ __ + / / (_/ /____ / __ \/ /_ + / / / / __/ _ \/ /_/ / __/ + / /___/ / /_/ __/ _, _/ /_ +/_____/_/\__/\___/_/ |_|\__/ +)"; + +void DumpPreamble(ToolDisplay& display) { display.Display() << kArt << "\n"; } + +} // namespace litert::tools diff --git a/tflite/experimental/litert/tools/tool_display.h b/tflite/experimental/litert/tools/tool_display.h new file mode 100644 index 00000000..fd92a2aa --- /dev/null +++ b/tflite/experimental/litert/tools/tool_display.h @@ -0,0 +1,102 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_TOOL_DISPLAY_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_TOOL_DISPLAY_H_
+
+#include <memory>
+#include <ostream>
+#include <string>
+#include <utility>
+
+#include "absl/strings/string_view.h"
+#include "tflite/experimental/litert/tools/outstream.h"
+
+namespace litert::tools {
+
+// Utility class for interactive logging, for use in command line tools only.
+// Allows the user to explicitly set the target stream.
+class ToolDisplay {
+ public:
+  using Ptr = std::unique_ptr<ToolDisplay>;
+  // Construct configured ToolDisplay. Label is used for prefixing dumps
+  // in "LabeledStream".
+  explicit ToolDisplay(UserStream&& ostream, absl::string_view tool_label = "")
+      : label_(MakeLabel(tool_label)),
+        ostream_(std::forward<UserStream>(ostream)) {}
+  explicit ToolDisplay(OutStream ostream, absl::string_view tool_label = "")
+      : label_(MakeLabel(tool_label)), ostream_(UserStream(ostream)) {}
+
+  ToolDisplay(const ToolDisplay&) = delete;
+  ToolDisplay& operator=(const ToolDisplay&) = delete;
+  ToolDisplay(ToolDisplay&&) = delete;
+  ToolDisplay& operator=(ToolDisplay&&) = delete;
+
+  // Get out stream.
+  std::ostream& Display();
+
+  // Get Display with label prefix.
+  std::ostream& Labeled();
+
+  // Get Display with indent.
+  std::ostream& Indented();
+
+  // Log string indicating a subroutine is beginning.
+  void Start(absl::string_view scope_name);
+
+  // Log string indicating a subroutine is done and succeeded.
+  void Done(absl::string_view scope_name = "");
+
+  // Log string indicating a subroutine is done and failed.
+  void Fail();
+
+  // Logs "start/finish" messages automatically.
+  class LoggedScope {
+    friend class ToolDisplay;
+
+   public:
+    LoggedScope(const LoggedScope&) = delete;
+    LoggedScope& operator=(const LoggedScope&) = delete;
+    LoggedScope(LoggedScope&&) = delete;
+    LoggedScope& operator=(LoggedScope&&) = delete;
+
+    ~LoggedScope();
+
+   private:
+    explicit LoggedScope(ToolDisplay& parent, absl::string_view scope_name);
+
+    void Start();
+    void Done();
+
+    ToolDisplay& parent_;
+    // These should all be from literals.
+    absl::string_view scope_name_;
+  };
+
+  // Get object that prints a start message and an exit message
+  // automatically when it goes out of scope.
+  [[maybe_unused]] LoggedScope StartS(absl::string_view scope_name);
+
+ private:
+  static std::string MakeLabel(absl::string_view tool_label);
+  std::string label_;
+  UserStream ostream_;
+};
+
+// Print art and info at cli startup.
+void DumpPreamble(ToolDisplay& display);
+
+} // namespace litert::tools
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_TOOL_DISPLAY_H_
diff --git a/tflite/experimental/litert/tools/tool_display_test.cc b/tflite/experimental/litert/tools/tool_display_test.cc
new file mode 100644
index 00000000..b7c7eb92
--- /dev/null
+++ b/tflite/experimental/litert/tools/tool_display_test.cc
@@ -0,0 +1,99 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tflite/experimental/litert/tools/tool_display.h" + +#include + +#include +#include +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" + +namespace { + +using ::litert::tools::ToolDisplay; +using ::testing::EndsWith; +using ::testing::StartsWith; + +static constexpr absl::string_view kToolName = "test-tool"; +static constexpr absl::string_view kLabel = "[LITERT_TOOLS:test-tool]"; +static constexpr absl::string_view kStartLabel = "Test Routine"; +static constexpr absl::string_view kDisplayInfo = "info"; + +TEST(TestToolDisplay, Display) { + std::stringstream out; + ToolDisplay display(out, kToolName); + display.Display() << kDisplayInfo; + EXPECT_EQ(out.view(), kDisplayInfo); +} + +TEST(TestToolDisplay, Indented) { + std::stringstream out; + ToolDisplay display(out, kToolName); + display.Indented() << kDisplayInfo; + EXPECT_EQ(out.view(), absl::StrFormat("\t%s", kDisplayInfo)); +} + +TEST(TestToolDisplay, Labeled) { + std::stringstream out; + ToolDisplay display(out, kToolName); + display.Labeled() << kDisplayInfo; + EXPECT_EQ(out.view(), absl::StrFormat("%s %s", kLabel, kDisplayInfo)); +} + +TEST(TestToolDisplay, LabeledNoToolName) { + std::stringstream out; + ToolDisplay display(out); + display.Labeled() << kDisplayInfo; + EXPECT_EQ(out.view(), + absl::StrFormat("%s %s", "[LITERT_TOOLS]", kDisplayInfo)); +} + +TEST(TestToolDisplay, Start) { + std::stringstream out; + ToolDisplay display(out, kToolName); + display.Start(kStartLabel); + EXPECT_EQ(out.view(), + absl::StrFormat("%s Starting %s...\n", kLabel, kStartLabel)); +} + +TEST(TestToolDisplay, Done) { + std::stringstream out; + ToolDisplay display(out, kToolName); + display.Done(kStartLabel); + EXPECT_EQ(out.view(), + absl::StrFormat("%s \t%s Done!\n", kLabel, kStartLabel)); +} + +TEST(TestToolDisplay, Fail) { + std::stringstream out; + ToolDisplay display(out, kToolName); + display.Fail(); + EXPECT_EQ(out.view(), absl::StrFormat("%s \tFailed\n", kLabel)); +} + +TEST(TestLoggedScope, EnterExit) { + std::stringstream out; + ToolDisplay display(out, kToolName); + { + auto s = display.StartS(kStartLabel); + } + EXPECT_THAT(out.view(), StartsWith(absl::StrFormat("%s Starting %s...\n", + kLabel, kStartLabel))); + EXPECT_THAT(out.view(), EndsWith(absl::StrFormat("%s \t%s Done!\n", kLabel, + kStartLabel))); +} + +} // namespace diff --git a/tflite/experimental/litert/vendors/c/BUILD b/tflite/experimental/litert/vendors/c/BUILD new file mode 100644 index 00000000..9337d1d7 --- /dev/null +++ b/tflite/experimental/litert/vendors/c/BUILD @@ -0,0 +1,68 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # copybara:uncomment default_applicable_licenses = ["@org_tensorflow//tensorflow:license"], + default_visibility = ["//tflite/experimental/litert:__subpackages__"], +) + +cc_library( + name = "litert_compiler_plugin", + hdrs = ["litert_compiler_plugin.h"], + deps = [ + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_model", + ], +) + +cc_library( + name = "litert_compiler_plugin_api", + hdrs = ["litert_compiler_plugin_api.h"], + deps = [ + ":litert_compiler_plugin", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_model", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_library( + name = "litert_dispatch_c_api", + hdrs = [ + "litert_dispatch.h", + "litert_dispatch_api.h", + ], + deps = [ + "//tflite/experimental/litert/c:litert_any", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_tensor_buffer", + "//tflite/experimental/litert/runtime/dispatch", + ], +) + +# This test verifies that the C API header files can build via C compiler. +cc_test( + name = "litert_vendor_c_api_common_test", + srcs = ["litert_vendor_c_api_common_test.c"], + copts = ["--std=c11"], + linkopts = ["-ldl"], + deps = [ + ":litert_compiler_plugin", + ":litert_compiler_plugin_api", + ":litert_dispatch_c_api", + ], +) + +exports_files(srcs = glob(["litert_*.h"])) diff --git a/tflite/experimental/litert/vendors/c/litert_compiler_plugin.h b/tflite/experimental/litert/vendors/c/litert_compiler_plugin.h new file mode 100644 index 00000000..259dd5ee --- /dev/null +++ b/tflite/experimental/litert/vendors/c/litert_compiler_plugin.h @@ -0,0 +1,102 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_COMPILER_PLUGIN_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_COMPILER_PLUGIN_H_ + +#include + +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_model.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +LITERT_DEFINE_HANDLE(LiteRtCompilerPlugin); + +// Artifact produced from compiling a selected partition of ops. +LITERT_DEFINE_HANDLE(LiteRtCompiledResult); + +// +// Plugin +// + +LiteRtStatus LiteRtGetCompilerPluginVersion(LiteRtApiVersion* api_version); + +// Name associated with the manufacturer this plugin relates to (e.g, +// GoogleTensor, Qualcomm). +const char* LiteRtGetCompilerPluginSocManufacturer(); + +LiteRtStatus LiteRtCreateCompilerPlugin(LiteRtCompilerPlugin* compiler_plugin); + +void LiteRtDestroyCompilerPlugin(LiteRtCompilerPlugin compiler_plugin); + +// Number of SoC models supported by this plugin. 
+LiteRtStatus LiteRtGetNumCompilerPluginSupportedSocModels(
+    LiteRtCompilerPlugin compiler_plugin,
+    LiteRtParamIndex* num_supported_soc_models);
+
+// Gets the name of the SoC model at the given index. The memory
+// associated with the returned name is owned by the plugin.
+LiteRtStatus LiteRtGetCompilerPluginSupportedSocModel(
+    LiteRtCompilerPlugin compiler_plugin, LiteRtParamIndex soc_model_idx,
+    const char** soc_model_name);
+
+// Select desired ops for compilation. This will only be called once per
+// subgraph; plugins should select all supportable ops.
+LiteRtStatus LiteRtCompilerPluginPartition(LiteRtCompilerPlugin compiler_plugin,
+                                           LiteRtSubgraph subgraph,
+                                           LiteRtOpList selected_ops);
+
+// Prepare result to pass to the runtime for a given partition and, optionally,
+// for a given SoC model (parameter `soc_model` can be NULL to specify a default
+// SoC model). The given subgraphs are valid sub-DAGs within the ops selected in
+// the partition step.
+LiteRtStatus LiteRtCompilerPluginCompile(LiteRtCompilerPlugin compiler_plugin,
+                                         const char* soc_model,
+                                         LiteRtSubgraphArray partitions,
+                                         LiteRtParamIndex num_partitions,
+                                         LiteRtCompiledResult* compiled_result);
+
+//
+// Compiled Partition
+//
+
+void LiteRtDestroyCompiledResult(LiteRtCompiledResult result);
+
+// Get the serialized result of the compiled modules, available to all custom
+// ops. This could be one module with multiple entry points or multiple modules
+// concatenated together.
+LiteRtStatus LiteRtGetCompiledResultByteCode(
+    LiteRtCompiledResult compiled_result, const void** byte_code,
+    size_t* byte_code_size);
+
+// Get info to embed in a particular custom op. This could be any opaque data
+// parsed by the custom op.
+LiteRtStatus LiteRtGetCompiledResultCallInfo(
+    LiteRtCompiledResult compiled_result, LiteRtParamIndex call_idx,
+    const void** call_info, size_t* call_info_size);
+
+// Get the number of calls that will be made to the HAL for this graph.
+// This should equal the number of partitions given for compilation, which
+// is equal to the number of custom ops in the final model.
+LiteRtStatus LiteRtGetNumCompiledResultCalls(
+    LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_calls);
+
+#ifdef __cplusplus
+}
+#endif // __cplusplus
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_COMPILER_PLUGIN_H_
diff --git a/tflite/experimental/litert/vendors/c/litert_compiler_plugin_api.h b/tflite/experimental/litert/vendors/c/litert_compiler_plugin_api.h
new file mode 100644
index 00000000..03aab6b7
--- /dev/null
+++ b/tflite/experimental/litert/vendors/c/litert_compiler_plugin_api.h
@@ -0,0 +1,130 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
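+
+// A rough usage sketch of the C API above (error handling elided;
+// `partitions` and `num_partitions` are assumed to come from the partition
+// step):
+//
+//   LiteRtCompilerPlugin plugin;
+//   LiteRtCreateCompilerPlugin(&plugin);
+//   LiteRtCompiledResult result;
+//   LiteRtCompilerPluginCompile(plugin, /*soc_model=*/NULL, partitions,
+//                               num_partitions, &result);
+//   const void* byte_code;
+//   size_t byte_code_size;
+//   LiteRtGetCompiledResultByteCode(result, &byte_code, &byte_code_size);
+//   LiteRtDestroyCompiledResult(result);
+//   LiteRtDestroyCompilerPlugin(plugin);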
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_COMPILER_PLUGIN_API_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_COMPILER_PLUGIN_API_H_
+
+#include <stddef.h>
+
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/vendors/c/litert_compiler_plugin.h"
+
+// Wrapper for dynamically loaded LiteRtCompilerPlugin library. See
+// "litert_compiler_plugin.h".
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+//
+// Api Interface
+//
+
+typedef LiteRtStatus (*LiteRtGetCompilerPluginVersionT)(LiteRtApiVersion*);
+
+typedef const char* (*LiteRtGetCompilerPluginSocManufacturerT)();
+
+typedef LiteRtStatus (*LiteRtCreateCompilerPluginT)(LiteRtCompilerPlugin*);
+
+typedef void (*LiteRtDestroyCompilerPluginT)(LiteRtCompilerPlugin);
+
+typedef LiteRtStatus (*LiteRtGetNumCompilerPluginSupportedSocModelsT)(
+    LiteRtCompilerPlugin, LiteRtParamIndex*);
+
+typedef LiteRtStatus (*LiteRtGetCompilerPluginSupportedSocModelT)(
+    LiteRtCompilerPlugin, LiteRtParamIndex soc_model_idx,
+    const char** soc_model_name);
+
+typedef LiteRtStatus (*LiteRtCompilerPluginPartitionT)(
+    LiteRtCompilerPlugin, LiteRtSubgraph subgraph, LiteRtOpList selected_ops);
+
+typedef LiteRtStatus (*LiteRtCompilerPluginCompileT)(
+    LiteRtCompilerPlugin, const char* soc_model, LiteRtSubgraphArray partitions,
+    LiteRtParamIndex num_partitions, LiteRtCompiledResult* compiled_result);
+
+typedef void (*LiteRtDestroyCompiledResultT)(LiteRtCompiledResult);
+
+typedef LiteRtStatus (*LiteRtGetCompiledResultByteCodeT)(
+    LiteRtCompiledResult, const void** byte_code, size_t* byte_code_size);
+
+typedef LiteRtStatus (*LiteRtGetCompiledResultCallInfoT)(
+    LiteRtCompiledResult, LiteRtParamIndex call_idx, const void** call_info,
+    size_t* call_info_size);
+
+typedef LiteRtStatus (*LiteRtGetNumCompiledResultCallsT)(
+    LiteRtCompiledResult, LiteRtParamIndex* num_calls);
+
+//
+// Function Pointer Container
+//
+
+// Wraps all resolved functions from the api interface.
+struct LiteRtCompilerPluginApi { + LiteRtGetCompilerPluginVersionT get_compiler_plugin_version; + LiteRtGetCompilerPluginSocManufacturerT get_compiler_plugin_soc_manufacturer; + LiteRtCreateCompilerPluginT create_compiler_plugin; + LiteRtDestroyCompilerPluginT destroy_compiler_plugin; + + LiteRtGetNumCompilerPluginSupportedSocModelsT + get_num_compiler_plugin_supported_models; + LiteRtGetCompilerPluginSupportedSocModelT + get_compiler_plugin_supported_soc_model; + + LiteRtCompilerPluginPartitionT compiler_plugin_partition; + LiteRtCompilerPluginCompileT compiler_plugin_compile; + + LiteRtDestroyCompiledResultT destroy_compiled_result; + LiteRtGetCompiledResultByteCodeT get_compiled_result_byte_code; + LiteRtGetCompiledResultCallInfoT get_compiled_result_call_info; + LiteRtGetNumCompiledResultCallsT get_compiled_result_num_calls; +}; + +#ifdef __cplusplus +} + +#include "absl/strings/string_view.h" + +static constexpr absl::string_view kLiteRtGetCompilerPluginVersion = + "LiteRtGetCompilerPluginVersion"; +static constexpr absl::string_view kLiteRtGetCompilerPluginSocManufacturer = + "LiteRtGetCompilerPluginSocManufacturer"; +static constexpr absl::string_view + kLiteRtGetNumCompilerPluginSupportedSocModels = + "LiteRtGetNumCompilerPluginSupportedSocModels"; +static constexpr absl::string_view kLiteRtGetCompilerPluginSupportedSocModel = + "LiteRtGetCompilerPluginSupportedSocModel"; + +static constexpr absl::string_view kLiteRtCreateCompilerPlugin = + "LiteRtCreateCompilerPlugin"; +static constexpr absl::string_view kLiteRtDestroyCompilerPlugin = + "LiteRtDestroyCompilerPlugin"; + +static constexpr absl::string_view kLiteRtCompilerPluginPartition = + "LiteRtCompilerPluginPartition"; +static constexpr absl::string_view kLiteRtCompilerPluginCompile = + "LiteRtCompilerPluginCompile"; + +static constexpr absl::string_view kLiteRtDestroyCompiledResult = + "LiteRtDestroyCompiledResult"; +static constexpr absl::string_view kLiteRtGetCompiledResultByteCode = + "LiteRtGetCompiledResultByteCode"; +static constexpr absl::string_view kLiteRtGetCompiledResultCallInfo = + "LiteRtGetCompiledResultCallInfo"; +static constexpr absl::string_view kLiteRtGetNumCompiledResultCalls = + "LiteRtGetNumCompiledResultCalls"; + +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_COMPILER_PLUGIN_API_H_ diff --git a/tflite/experimental/litert/vendors/c/litert_dispatch.h b/tflite/experimental/litert/vendors/c/litert_dispatch.h new file mode 100644 index 00000000..3f992d60 --- /dev/null +++ b/tflite/experimental/litert/vendors/c/litert_dispatch.h @@ -0,0 +1,275 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
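+
+// A high-level sketch of the basic execution flow declared below (error
+// checks elided; `bytecode`, `bytecode_size`, and the registered tensor
+// buffer handles are assumed to be set up by the caller):
+//
+//   LiteRtDispatchInitialize(/*options=*/NULL, /*num_options=*/0);
+//   LiteRtDispatchDeviceContext device_context;
+//   LiteRtDispatchDeviceContextCreate(&device_context);
+//   LiteRtDispatchInvocationContext invocation_context;
+//   LiteRtDispatchInvocationContextCreate(
+//       device_context, kLiteRtDispatchExecutableTypeMlModel, bytecode,
+//       bytecode_size, /*function_name=*/NULL, /*num_inputs=*/1,
+//       /*num_outputs=*/1, &invocation_context);
+//   LiteRtDispatchAttachInput(invocation_context, 0, input_handle);
+//   LiteRtDispatchAttachOutput(invocation_context, 0, output_handle);
+//   LiteRtDispatchInvoke(invocation_context);
+//   LiteRtDispatchInvocationContextDestroy(invocation_context);
+//   LiteRtDispatchDeviceContextDestroy(device_context);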
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_DISPATCH_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_DISPATCH_H_
+
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#include "tflite/experimental/litert/c/litert_any.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_event.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// /////////////////////////////////////////////////////////////////////////////
+// Basic Execution API
+// /////////////////////////////////////////////////////////////////////////////
+
+LITERT_DEFINE_HANDLE(LiteRtDispatchDeviceContext);
+LITERT_DEFINE_HANDLE(LiteRtDispatchInvocationContext);
+
+typedef uint64_t LiteRtTensorBufferHandle;
+
+typedef enum LiteRtDispatchCapabilities {
+  kLiteRtDispatchCapabilitiesNone = 0,
+  kLiteRtDispatchCapabilitiesBasic = 1, // The vendor supports the Basic API
+  kLiteRtDispatchCapabilitiesAsync = 2, // The vendor supports the Async API
+  kLiteRtDispatchCapabilitiesGraph = 4, // The vendor supports the Graph API
+} LiteRtDispatchCapabilities;
+
+// Types of executable that can run on the HW accelerators.
+typedef enum LiteRtDispatchExecutableType {
+  kLiteRtDispatchExecutableTypeUnknown = 0,
+  kLiteRtDispatchExecutableTypeDspLibrary = 1, // DSP library
+  kLiteRtDispatchExecutableTypeMlModel = 2, // Vendor-specific ML model
+} LiteRtDispatchExecutableType;
+
+typedef struct LiteRtDispatchOption {
+  const char* name;
+  LiteRtAny value;
+} LiteRtDispatchOption;
+
+// This option can be used to specify a directory from where to load shared
+// libraries.
+static const char* kDispatchOptionSharedLibraryDir = "shared_library_dir";
+
+// Initialize the Dispatch API runtime.
+//
+// This function should be called before calling any other Dispatch API
+// functions.
+LiteRtStatus LiteRtDispatchInitialize(const LiteRtDispatchOption* options,
+                                      int num_options);
+
+// Return the version of the Dispatch API runtime.
+LiteRtStatus LiteRtDispatchGetApiVersion(LiteRtApiVersion* api_version);
+
+// Return the vendor id of the Dispatch API runtime.
+//
+// This function returns a pointer to a statically allocated string that is the
+// ID of the vendor providing the Dispatch API runtime.
+LiteRtStatus LiteRtDispatchGetVendorId(const char** vendor_id);
+
+// Return the build ID of the Dispatch API runtime.
+//
+// This function returns a pointer to a statically allocated string that is the
+// ID of the Dispatch API runtime build.
+LiteRtStatus LiteRtDispatchGetBuildId(const char** build_id);
+
+// Return the capabilities supported by the Dispatch API runtime as a set of
+// the values specified in LiteRtDispatchCapabilities.
+LiteRtStatus LiteRtDispatchGetCapabilities(int* capabilities);
+
+// Create a `LiteRtDispatchDeviceContext` object.
+//
+// The returned object is used to talk with the underlying HW. The caller owns
+// the memory associated with the context and should call
+// LiteRtDispatchDeviceContextDestroy() to release it. Returns an error status
+// in case of failure.
+LiteRtStatus LiteRtDispatchDeviceContextCreate(
+    LiteRtDispatchDeviceContext* device_context);
+
+// Release a `LiteRtDispatchDeviceContext` object.
+//
+// The given context should be released only after releasing all associated
+// objects.
+LiteRtStatus LiteRtDispatchDeviceContextDestroy(
+    LiteRtDispatchDeviceContext device_context);
+
+// Given a tensor type for an invocation context input, obtain the attributes
+// the HW requires for the associated tensor buffer. The returned
+// `tensor_buffer_requirements` object is owned by the caller.
+LiteRtStatus LiteRtDispatchGetInputRequirements(
+    LiteRtDispatchInvocationContext invocation_context, int input_index,
+    const LiteRtRankedTensorType* tensor_type,
+    LiteRtTensorBufferRequirements* tensor_buffer_requirements);
+
+// Given a tensor type for an invocation context output, obtain the attributes
+// the HW requires for the associated tensor buffer. The returned
+// `tensor_buffer_requirements` object is owned by the caller.
+LiteRtStatus LiteRtDispatchGetOutputRequirements(
+    LiteRtDispatchInvocationContext invocation_context, int output_index,
+    const LiteRtRankedTensorType* tensor_type,
+    LiteRtTensorBufferRequirements* tensor_buffer_requirements);
+
+// Registers a buffer with the given device context.
+// Note: The memory backing the buffer should be valid until
+// `LiteRtDispatchUnregisterTensorBuffer` is called.
+LiteRtStatus LiteRtDispatchRegisterTensorBuffer(
+    LiteRtDispatchDeviceContext device_context,
+    LiteRtTensorBuffer tensor_buffer,
+    LiteRtTensorBufferHandle* tensor_buffer_handle);
+
+// Unregisters the registered buffer associated with the given
+// `LiteRtTensorBufferHandle`.
+// Note: The registered `LiteRtTensorBufferHandle` is supposed to be
+// unregistered with this function before the associated device context is
+// destroyed by calling `LiteRtDispatchDeviceContextDestroy`.
+LiteRtStatus LiteRtDispatchUnregisterTensorBuffer(
+    LiteRtDispatchDeviceContext device_context,
+    LiteRtTensorBufferHandle tensor_buffer_handle);
+
+// Create an invocation context to run a given function from a given
+// executable. Parameter `function_name` is required if the provided executable
+// includes multiple functions.
+LiteRtStatus LiteRtDispatchInvocationContextCreate(
+    LiteRtDispatchDeviceContext device_context,
+    LiteRtDispatchExecutableType exec_type, const void* exec_bytecode_ptr,
+    size_t exec_bytecode_size, const char* function_name, int num_inputs,
+    int num_outputs, LiteRtDispatchInvocationContext* invocation_context);
+
+LiteRtStatus LiteRtDispatchInvocationContextDestroy(
+    LiteRtDispatchInvocationContext invocation_context);
+
+LiteRtStatus LiteRtDispatchAttachInput(
+    LiteRtDispatchInvocationContext invocation_context, int graph_input_index,
+    LiteRtTensorBufferHandle tensor_buffer_handle);
+
+LiteRtStatus LiteRtDispatchAttachOutput(
+    LiteRtDispatchInvocationContext invocation_context, int graph_output_index,
+    LiteRtTensorBufferHandle tensor_buffer_handle);
+
+LiteRtStatus LiteRtDispatchDetachInput(
+    LiteRtDispatchInvocationContext invocation_context, int graph_input_index,
+    LiteRtTensorBufferHandle tensor_buffer_handle);
+
+LiteRtStatus LiteRtDispatchDetachOutput(
+    LiteRtDispatchInvocationContext invocation_context, int graph_output_index,
+    LiteRtTensorBufferHandle tensor_buffer_handle);
+
+LiteRtStatus LiteRtDispatchInvoke(
+    LiteRtDispatchInvocationContext invocation_context);
+
+// /////////////////////////////////////////////////////////////////////////////
+// Async Execution API
+// /////////////////////////////////////////////////////////////////////////////
+
+LiteRtStatus LiteRtDispatchAttachInputEvent(
+    LiteRtDispatchInvocationContext invocation_context, int graph_input_index,
+    LiteRtEvent input_event);
+
+LiteRtStatus LiteRtDispatchInvokeAsync(
+    LiteRtDispatchInvocationContext invocation_context, int num_output_events,
+    LiteRtEvent* output_events);
+
+// /////////////////////////////////////////////////////////////////////////////
+// Graph Execution API
+// /////////////////////////////////////////////////////////////////////////////
+
+typedef uint64_t LiteRtDispatchNodeId;
+typedef uint64_t LiteRtDispatchEdgeId;
+typedef uint64_t LiteRtDispatchExecutableHandle;
+
+LITERT_DEFINE_HANDLE(LiteRtDispatchGraph);
+
+// Types of graph nodes.
+typedef enum LiteRtDispatchNodeType {
+  kLiteRtDispatchNodeTypeUnknown = 0,
+  kLiteRtDispatchNodeTypeDsp =
+      1, // Can execute both ML models and Dsp libraries
+  kLiteRtDispatchNodeTypeNpu = 2, // Can execute only ML models
+} LiteRtDispatchNodeType;
+
+LiteRtStatus LiteRtDispatchGraphCreate(
+    LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph** graph);
+
+LiteRtStatus LiteRtDispatchGraphDestroy(LiteRtDispatchGraph* graph);
+
+// Add a compute node to a given graph. Parameter node_id should be unique to
+// the graph.
+LiteRtStatus LiteRtDispatchAddNode(LiteRtDispatchGraph* graph,
+                                   LiteRtDispatchNodeId node_id,
+                                   LiteRtDispatchNodeType node_type);
+
+// Add an edge to a given graph. Parameter edge_id should be unique to the
+// graph.
+LiteRtStatus LiteRtDispatchAddEdge(LiteRtDispatchGraph* graph,
+                                   LiteRtDispatchEdgeId edge_id);
+
+// Connect a given node's input.
+LiteRtStatus LiteRtDispatchConnectNodeInput(LiteRtDispatchGraph* graph,
+                                            LiteRtDispatchNodeId node_id,
+                                            int input_index,
+                                            LiteRtDispatchEdgeId edge_id);
+
+// Connect a given node's output.
+LiteRtStatus LiteRtDispatchConnectNodeOutput(LiteRtDispatchGraph* graph,
+                                             LiteRtDispatchNodeId node_id,
+                                             int output_index,
+                                             LiteRtDispatchEdgeId edge_id);
+
+// Connect a given graph's input.
+LiteRtStatus LiteRtDispatchConnectGraphInput(LiteRtDispatchGraph* graph,
+                                             int input_index,
+                                             LiteRtDispatchEdgeId edge_id);
+
+// Connect a given graph's output.
+
+// Connect a given graph's output.
+LiteRtStatus LiteRtDispatchConnectGraphOutput(LiteRtDispatchGraph graph,
+                                              int output_index,
+                                              LiteRtDispatchEdgeId edge_id);
+
+LiteRtStatus LiteRtDispatchLoadExecutable(
+    LiteRtDispatchDeviceContext device_context,
+    LiteRtDispatchExecutableType type, const void* bytecode,
+    size_t bytecode_size, LiteRtDispatchExecutableHandle* exec_handle);
+
+LiteRtStatus LiteRtDispatchUnloadExecutable(
+    LiteRtDispatchDeviceContext device_context,
+    LiteRtDispatchExecutableHandle exec_handle);
+
+// Assign an executable function to a graph node. Parameter `function_name` is
+// mandatory if the given executable includes multiple functions.
+LiteRtStatus LiteRtDispatchAssignNodeFunction(
+    LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id,
+    LiteRtDispatchExecutableHandle exec_handle, const char* function_name);
+
+// Add an annotation to an entire graph.
+LiteRtStatus LiteRtDispatchAnnotateGraph(LiteRtDispatchGraph graph,
+                                         const char* key, const char* value);
+
+// Add an annotation to a specified node.
+LiteRtStatus LiteRtDispatchAnnotateNode(LiteRtDispatchGraph graph,
+                                        LiteRtDispatchNodeId node_id,
+                                        const char* key, const char* value);
+
+// Add an annotation to a specified edge.
+LiteRtStatus LiteRtDispatchAnnotateEdge(LiteRtDispatchGraph graph,
+                                        LiteRtDispatchEdgeId edge_id,
+                                        const char* key, const char* value);
+
+LiteRtStatus LiteRtDispatchInvocationContextCreateFromGraph(
+    LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph graph,
+    LiteRtDispatchInvocationContext* invocation_context);
+
+#ifdef __cplusplus
+}
+#endif  // __cplusplus
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_DISPATCH_H_
diff --git a/tflite/experimental/litert/vendors/c/litert_dispatch_api.h b/tflite/experimental/litert/vendors/c/litert_dispatch_api.h
new file mode 100644
index 00000000..c4605ba1
--- /dev/null
+++ b/tflite/experimental/litert/vendors/c/litert_dispatch_api.h
@@ -0,0 +1,222 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
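+
+// This header declares the function-pointer interface that a vendor Dispatch
+// API shared library implements. A vendor fills in the interface structs
+// below and hands them back through `LiteRtDispatchGetApi`. As an
+// illustrative sketch only (the entry-point names stand for vendor-defined
+// functions; a real library also populates the async and graph interfaces
+// when it advertises those capabilities):
+//
+//   static LiteRtDispatchInterface TheInterface = {
+//       /*.initialize=*/Initialize,
+//       /*.get_vendor_id=*/GetVendorId,
+//       /* ... one entry per field declared below ... */
+//   };
+//
+//   LiteRtStatus LiteRtDispatchGetApi(LiteRtDispatchApi* api) {
+//     api->version.major = 0;
+//     api->version.minor = 1;
+//     api->version.patch = 0;
+//     api->interface = &TheInterface;
+//     return kLiteRtStatusOk;
+//   }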
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_DISPATCH_API_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_DISPATCH_API_H_
+
+#include <stddef.h>
+
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_event.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h"
+#include "tflite/experimental/litert/vendors/c/litert_dispatch.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// /////////////////////////////////////////////////////////////////////////////
+
+typedef LiteRtStatus (*LiteRtDispatchInitializeT)(
+    const LiteRtDispatchOption* options, int num_options);
+
+typedef LiteRtStatus (*LiteRtDispatchGetVendorIdT)(const char** vendor_id);
+
+typedef LiteRtStatus (*LiteRtDispatchGetBuildIdT)(const char** build_id);
+
+typedef LiteRtStatus (*LiteRtDispatchGetCapabilitiesT)(int* capabilities);
+
+typedef LiteRtStatus (*LiteRtDispatchDeviceContextCreateT)(
+    LiteRtDispatchDeviceContext* device_context);
+
+typedef LiteRtStatus (*LiteRtDispatchDeviceContextDestroyT)(
+    LiteRtDispatchDeviceContext device_context);
+
+typedef LiteRtStatus (*LiteRtDispatchGetInputRequirementsT)(
+    LiteRtDispatchInvocationContext invocation_context, int input_index,
+    const LiteRtRankedTensorType* tensor_type,
+    LiteRtTensorBufferRequirements* tensor_buffer_requirements);
+
+typedef LiteRtStatus (*LiteRtDispatchGetOutputRequirementsT)(
+    LiteRtDispatchInvocationContext invocation_context, int output_index,
+    const LiteRtRankedTensorType* tensor_type,
+    LiteRtTensorBufferRequirements* tensor_buffer_requirements);
+
+typedef LiteRtStatus (*LiteRtDispatchRegisterTensorBufferT)(
+    LiteRtDispatchDeviceContext device_context,
+    LiteRtTensorBuffer tensor_buffer,
+    LiteRtTensorBufferHandle* tensor_buffer_handle);
+
+typedef LiteRtStatus (*LiteRtDispatchUnregisterTensorBufferT)(
+    LiteRtDispatchDeviceContext device_context,
+    LiteRtTensorBufferHandle handle);
+
+typedef LiteRtStatus (*LiteRtDispatchInvocationContextCreateT)(
+    LiteRtDispatchDeviceContext device_context,
+    LiteRtDispatchExecutableType exec_type, const void* exec_bytecode,
+    size_t exec_bytecode_size, const char* function_name, int num_inputs,
+    int num_outputs, LiteRtDispatchInvocationContext* invocation_context);
+
+typedef LiteRtStatus (*LiteRtDispatchInvocationContextDestroyT)(
+    LiteRtDispatchInvocationContext invocation_context);
+
+typedef LiteRtStatus (*LiteRtDispatchAttachInputT)(
+    LiteRtDispatchInvocationContext invocation_context, int graph_input_index,
+    LiteRtTensorBufferHandle tensor_buffer_handle);
+
+typedef LiteRtStatus (*LiteRtDispatchAttachOutputT)(
+    LiteRtDispatchInvocationContext invocation_context, int graph_output_index,
+    LiteRtTensorBufferHandle tensor_buffer_handle);
+
+typedef LiteRtStatus (*LiteRtDispatchDetachInputT)(
+    LiteRtDispatchInvocationContext invocation_context, int graph_input_index,
+    LiteRtTensorBufferHandle tensor_buffer_handle);
+
+typedef LiteRtStatus (*LiteRtDispatchDetachOutputT)(
+    LiteRtDispatchInvocationContext invocation_context, int graph_output_index,
+    LiteRtTensorBufferHandle tensor_buffer_handle);
+
+typedef LiteRtStatus (*LiteRtDispatchInvokeT)(
+    LiteRtDispatchInvocationContext invocation_context);
+
+typedef struct LiteRtDispatchInterface {
+  LiteRtDispatchInitializeT initialize;
+  LiteRtDispatchGetVendorIdT get_vendor_id;
+  LiteRtDispatchGetBuildIdT get_build_id;
+
LiteRtDispatchGetCapabilitiesT get_capabilities; + LiteRtDispatchDeviceContextCreateT device_context_create; + LiteRtDispatchDeviceContextDestroyT device_context_destroy; + LiteRtDispatchGetInputRequirementsT get_input_requirements; + LiteRtDispatchGetOutputRequirementsT get_output_requirements; + LiteRtDispatchRegisterTensorBufferT register_tensor_buffer; + LiteRtDispatchUnregisterTensorBufferT unregister_tensor_buffer; + LiteRtDispatchInvocationContextCreateT invocation_context_create; + LiteRtDispatchInvocationContextDestroyT invocation_context_destroy; + LiteRtDispatchAttachInputT attach_input; + LiteRtDispatchAttachOutputT attach_output; + LiteRtDispatchDetachInputT detach_input; + LiteRtDispatchDetachOutputT detach_output; + LiteRtDispatchInvokeT invoke; +} LiteRtDispatchInterface; + +// ///////////////////////////////////////////////////////////////////////////// + +typedef LiteRtStatus (*LiteRtDispatchAttachInputEventT)( + LiteRtDispatchInvocationContext invocation_context, int graph_input_index, + LiteRtEvent input_event); + +typedef LiteRtStatus (*LiteRtDispatchInvokeAsyncT)( + LiteRtDispatchInvocationContext invocation_context, int num_output_events, + LiteRtEvent* output_events); + +typedef struct LiteRtDispatchAsyncInterface { + LiteRtDispatchAttachInputEventT attach_input_event; + LiteRtDispatchInvokeAsyncT invoke_async; +} LiteRtDispatchAsyncInterface; + +// ///////////////////////////////////////////////////////////////////////////// + +typedef LiteRtStatus (*LiteRtDispatchGraphCreateT)( + LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph* graph); + +typedef LiteRtStatus (*LiteRtDispatchGraphDestroyT)(LiteRtDispatchGraph graph); + +typedef LiteRtStatus (*LiteRtDispatchAddNodeT)( + LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id, + LiteRtDispatchNodeType node_type); + +typedef LiteRtStatus (*LiteRtDispatchAddEdgeT)(LiteRtDispatchGraph graph, + LiteRtDispatchEdgeId edge_id); + +typedef LiteRtStatus (*LiteRtDispatchConnectNodeInputT)( + LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id, int input_index, + LiteRtDispatchEdgeId edge_id); + +typedef LiteRtStatus (*LiteRtDispatchConnectNodeOutputT)( + LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id, int output_index, + LiteRtDispatchEdgeId edge_id); + +typedef LiteRtStatus (*LiteRtDispatchConnectGraphInputT)( + LiteRtDispatchGraph graph, int input_index, LiteRtDispatchEdgeId edge_id); + +typedef LiteRtStatus (*LiteRtDispatchConnectGraphOutputT)( + LiteRtDispatchGraph graph, int output_index, LiteRtDispatchEdgeId edge_id); + +typedef LiteRtStatus (*LiteRtDispatchLoadExecutableT)( + LiteRtDispatchDeviceContext device_context, + LiteRtDispatchExecutableType type, const void* bytecode_ptr, + size_t bytecode_size, LiteRtDispatchExecutableHandle* exec_handle); + +typedef LiteRtStatus (*LiteRtDispatchUnloadExecutableT)( + LiteRtDispatchDeviceContext device_context, + LiteRtDispatchExecutableHandle exec_handle); + +typedef LiteRtStatus (*LiteRtDispatchAssignNodeFunctionT)( + LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id, + LiteRtDispatchExecutableHandle exec_handle, const char* function_name); + +typedef LiteRtStatus (*LiteRtDispatchInvocationContextCreateFromGraphT)( + LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph graph, + LiteRtDispatchInvocationContext* invocation_context); + +typedef LiteRtStatus (*LiteRtDispatchAnnotateGraphT)(LiteRtDispatchGraph graph, + const char* key, + const char* value); + +typedef LiteRtStatus (*LiteRtDispatchAnnotateNodeT)( + LiteRtDispatchGraph 
graph, LiteRtDispatchNodeId node_id, const char* key,
+    const char* value);
+
+typedef LiteRtStatus (*LiteRtDispatchAnnotateEdgeT)(
+    LiteRtDispatchGraph graph, LiteRtDispatchEdgeId edge_id, const char* key,
+    const char* value);
+
+typedef struct LiteRtDispatchGraphInterface {
+  LiteRtDispatchGraphCreateT graph_create;
+  LiteRtDispatchGraphDestroyT graph_destroy;
+  LiteRtDispatchAddNodeT add_node;
+  LiteRtDispatchAddEdgeT add_edge;
+  LiteRtDispatchConnectNodeInputT connect_node_input;
+  LiteRtDispatchConnectNodeOutputT connect_node_output;
+  LiteRtDispatchConnectGraphInputT connect_graph_input;
+  LiteRtDispatchConnectGraphOutputT connect_graph_output;
+  LiteRtDispatchLoadExecutableT load_executable;
+  LiteRtDispatchUnloadExecutableT unload_executable;
+  LiteRtDispatchAssignNodeFunctionT assign_node_function;
+  LiteRtDispatchAnnotateGraphT annotate_graph;
+  LiteRtDispatchAnnotateNodeT annotate_node;
+  LiteRtDispatchAnnotateEdgeT annotate_edge;
+  LiteRtDispatchInvocationContextCreateFromGraphT
+      invocation_context_create_from_graph;
+} LiteRtDispatchGraphInterface;
+
+// /////////////////////////////////////////////////////////////////////////////
+
+// FIXME See Vulkan and OpenCL extensions.
+typedef struct LiteRtDispatchApi {
+  LiteRtApiVersion version;
+  LiteRtDispatchInterface* interface;
+  LiteRtDispatchAsyncInterface* async_interface;
+  LiteRtDispatchGraphInterface* graph_interface;
+} LiteRtDispatchApi;
+
+LiteRtStatus LiteRtDispatchGetApi(LiteRtDispatchApi* api);
+
+#ifdef __cplusplus
+}
+#endif  // __cplusplus
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_DISPATCH_API_H_
diff --git a/tflite/experimental/litert/vendors/c/litert_vendor_c_api_common_test.c b/tflite/experimental/litert/vendors/c/litert_vendor_c_api_common_test.c
new file mode 100644
index 00000000..03a7c103
--- /dev/null
+++ b/tflite/experimental/litert/vendors/c/litert_vendor_c_api_common_test.c
@@ -0,0 +1,28 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// This file exists to verify that the below header files can build, link,
+// and run as C code.
+#ifdef __cplusplus
+#error "This file should be compiled as C code, not as C++."
+#endif
+
+// Include all the header files in the litert/vendors/c directory.
+#include "tflite/experimental/litert/vendors/c/litert_compiler_plugin.h"  // NOLINT
+#include "tflite/experimental/litert/vendors/c/litert_compiler_plugin_api.h"  // NOLINT
+#include "tflite/experimental/litert/vendors/c/litert_dispatch.h"  // NOLINT
+#include "tflite/experimental/litert/vendors/c/litert_dispatch_api.h"  // NOLINT
+
+int main(void) {
+  return 0;
+}
diff --git a/tflite/experimental/litert/vendors/cc/BUILD b/tflite/experimental/litert/vendors/cc/BUILD
new file mode 100644
index 00000000..eeb264fb
--- /dev/null
+++ b/tflite/experimental/litert/vendors/cc/BUILD
@@ -0,0 +1,27 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+    # copybara:uncomment default_applicable_licenses = ["@org_tensorflow//tensorflow:license"],
+    default_visibility = ["//tflite/experimental/litert:__subpackages__"],
+)
+
+cc_library(
+    name = "litert_compiler_plugin",
+    hdrs = ["litert_compiler_plugin.h"],
+    deps = [
+        "//tflite/experimental/litert/cc:litert_macros",
+        "//tflite/experimental/litert/vendors/c:litert_compiler_plugin",
+    ],
+)
diff --git a/tflite/experimental/litert/vendors/cc/litert_compiler_plugin.h b/tflite/experimental/litert/vendors/cc/litert_compiler_plugin.h
new file mode 100644
index 00000000..ab1332cf
--- /dev/null
+++ b/tflite/experimental/litert/vendors/cc/litert_compiler_plugin.h
@@ -0,0 +1,46 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_LITERT_COMPILER_PLUGIN_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_LITERT_COMPILER_PLUGIN_H_
+
+#include <memory>
+
+#include "tflite/experimental/litert/cc/litert_macros.h"
+#include "tflite/experimental/litert/vendors/c/litert_compiler_plugin.h"
+
+namespace litert {
+
+// Deleter for incomplete compiler plugin type.
+struct LiteRtCompilerPluginDeleter {
+  void operator()(LiteRtCompilerPlugin plugin) {
+    if (plugin != nullptr) {
+      LiteRtDestroyCompilerPlugin(plugin);
+    }
+  }
+};
+
+// Smart pointer wrapper for incomplete plugin type.
+using PluginPtr =
+    std::unique_ptr<LiteRtCompilerPluginT, LiteRtCompilerPluginDeleter>;
+
+// Initialize a plugin via c-api and wrap result in smart pointer.
+inline PluginPtr CreatePlugin() {
+  LiteRtCompilerPlugin plugin;
+  LITERT_CHECK_STATUS_OK(LiteRtCreateCompilerPlugin(&plugin));
+  return PluginPtr(plugin);
+}
+
+}  // namespace litert
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_LITERT_COMPILER_PLUGIN_H_
diff --git a/tflite/experimental/litert/vendors/examples/BUILD b/tflite/experimental/litert/vendors/examples/BUILD
new file mode 100644
index 00000000..7524c59a
--- /dev/null
+++ b/tflite/experimental/litert/vendors/examples/BUILD
@@ -0,0 +1,59 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +load("//tflite/experimental/litert/build_common:litert_build_defs.bzl", "litert_dynamic_lib") + +package( + # copybara:uncomment default_applicable_licenses = ["@org_tensorflow//tensorflow:license"], + default_visibility = ["//visibility:private"], +) + +litert_dynamic_lib( + name = "example_plugin", + srcs = ["example_plugin.cc"], + hdrs = ["//tflite/experimental/litert/vendors/c:litert_compiler_plugin.h"], + export_litert_only = True, + linkstatic = 1, + shared_lib_name = "example_plugin_so", + so_name = "libLiteRtCompilerPlugin_Example.so", + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_model_predicates", + ], +) + +cc_test( + name = "example_plugin_test", + srcs = [ + "example_plugin_test.cc", + ], + data = ["//tflite/experimental/litert/test:mlir_test_data"], + deps = [ + ":example_plugin", # buildcleaner: keep + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/core/model", + "//tflite/experimental/litert/test:common", + "//tflite/experimental/litert/vendors/cc:litert_compiler_plugin", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tflite/experimental/litert/vendors/examples/example_plugin.cc b/tflite/experimental/litert/vendors/examples/example_plugin.cc new file mode 100644 index 00000000..137d6f46 --- /dev/null +++ b/tflite/experimental/litert/vendors/examples/example_plugin.cc @@ -0,0 +1,194 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
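+
+// This example plugin demonstrates the compiler-plugin contract end to end:
+// it claims every TFL Mul op during partitioning and "compiles" each
+// partition into a human-readable byte-code string. As an illustration only
+// (not normative), a client can drive it through the C++ wrapper in
+// tflite/experimental/litert/vendors/cc/litert_compiler_plugin.h, where
+// `subgraph`, `selected_ops`, `partitions`, `num_partitions` and
+// `compiled_result` stand for caller-provided values:
+//
+//   auto plugin = litert::CreatePlugin();  // wraps LiteRtCreateCompilerPlugin
+//   LiteRtCompilerPluginPartition(plugin.get(), subgraph, selected_ops);
+//   LiteRtCompilerPluginCompile(plugin.get(), /*soc_model=*/nullptr,
+//                               partitions, num_partitions, &compiled_result);
+//
+// See example_plugin_test.cc below for a complete, working sequence.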
+
+#include <stdio.h>
+
+#include <cstddef>
+#include <cstdlib>
+#include <string>
+#include <vector>
+
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/c/litert_op_code.h"
+#include "tflite/experimental/litert/cc/litert_macros.h"
+#include "tflite/experimental/litert/cc/litert_model.h"
+#include "tflite/experimental/litert/vendors/c/litert_compiler_plugin.h"
+
+//
+// Configurations
+//
+
+namespace {
+
+constexpr char kPluginManufacturer[] = "ExampleSocManufacturer";
+constexpr char kPluginSocModel[] = "ExampleSocModel";
+
+}  // namespace
+
+LiteRtStatus LiteRtGetCompilerPluginVersion(LiteRtApiVersion* api_version) {
+  if (!api_version) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  api_version->major = LITERT_API_VERSION_MAJOR;
+  api_version->minor = LITERT_API_VERSION_MINOR;
+  api_version->patch = LITERT_API_VERSION_PATCH;
+  return kLiteRtStatusOk;
+}
+
+const char* LiteRtGetCompilerPluginSocManufacturer() {
+  return kPluginManufacturer;
+}
+
+LiteRtStatus LiteRtGetNumCompilerPluginSupportedSocModels(
+    LiteRtCompilerPlugin compiler_plugin,
+    LiteRtParamIndex* num_supported_soc_models) {
+  if (!compiler_plugin || !num_supported_soc_models) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *num_supported_soc_models = 1;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetCompilerPluginSupportedSocModel(
+    LiteRtCompilerPlugin compiler_plugin, LiteRtParamIndex soc_model_idx,
+    const char** soc_model_name) {
+  if (!compiler_plugin || !soc_model_name) {
+    return kLiteRtStatusErrorInvalidArgument;
+  } else if (soc_model_idx != 0) {
+    return kLiteRtStatusErrorUnsupported;
+  }
+  *soc_model_name = kPluginSocModel;
+  return kLiteRtStatusOk;
+}
+
+//
+// Compiled Result Definition
+//
+
+struct LiteRtCompiledResultT {
+  std::string byte_code;
+  std::vector<std::string> per_op_data;
+};
+
+LiteRtStatus LiteRtGetCompiledResultByteCode(
+    LiteRtCompiledResult compiled_result, const void** byte_code,
+    size_t* byte_code_size) {
+  *byte_code = compiled_result->byte_code.data();
+  *byte_code_size = compiled_result->byte_code.size();
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetCompiledResultCallInfo(
+    LiteRtCompiledResult compiled_result, LiteRtParamIndex call_idx,
+    const void** call_info, size_t* call_info_size) {
+  if (call_idx >= compiled_result->per_op_data.size()) {
+    return kLiteRtStatusErrorIndexOOB;
+  }
+
+  *call_info = compiled_result->per_op_data.at(call_idx).data();
+  *call_info_size = compiled_result->per_op_data.at(call_idx).size();
+
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetNumCompiledResultCalls(
+    LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_calls) {
+  *num_calls = compiled_result->per_op_data.size();
+  return kLiteRtStatusOk;
+}
+
+void LiteRtDestroyCompiledResult(LiteRtCompiledResult compiled_result) {
+  delete compiled_result;
+}
+
+//
+// Plugin Definition
+//
+
+// Plugins can hold state.
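+// This example keeps none, but a production plugin might cache, for example,
+// the selected target and a vendor SDK session. A hypothetical sketch (the
+// `vendor_sdk` types below are placeholders, not part of LiteRT):
+//
+//   struct LiteRtCompilerPluginT {
+//     std::string soc_model;          // target selected at compile time
+//     vendor_sdk::Session* session;   // opaque vendor compiler state
+//   };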
+struct LiteRtCompilerPluginT {}; + +LiteRtStatus LiteRtCreateCompilerPlugin(LiteRtCompilerPlugin* compiler_plugin) { + *compiler_plugin = new LiteRtCompilerPluginT; + return kLiteRtStatusOk; +} + +void LiteRtDestroyCompilerPlugin(LiteRtCompilerPlugin compiler_plugin) { + delete compiler_plugin; +} + +LiteRtStatus LiteRtCompilerPluginPartition(LiteRtCompilerPlugin compiler_plugin, + LiteRtSubgraph subgraph, + LiteRtOpList selected_ops) { + ::litert::Subgraph main_subgraph(subgraph); + for (const auto& op : main_subgraph.Ops()) { + if (op.Code() != kLiteRtOpCodeTflMul) { + continue; + } + LITERT_RETURN_STATUS_IF_NOT_OK(LiteRtPushOp(selected_ops, op.Get())); + } + return kLiteRtStatusOk; +} + +namespace { + +LiteRtStatus CompileSinglePartition(LiteRtParamIndex partition_index, + LiteRtSubgraph subgraph, + LiteRtCompiledResultT& result) { + const litert::Subgraph sg(subgraph); + int num_muls_in_partition = 0; + for (const auto& op : sg.Ops()) { + if (op.Code() != kLiteRtOpCodeTflMul) { + return kLiteRtStatusErrorUnsupported; + } + ++num_muls_in_partition; + } + + { + char* byte_code_append; + (void)asprintf(&byte_code_append, + "Partition_%lu_with_%d_muls:", partition_index, + num_muls_in_partition); + result.byte_code.append(byte_code_append); + free(byte_code_append); + } + + { + char* per_op_data; + (void)asprintf(&per_op_data, "Partition_%lu", partition_index); + result.per_op_data.push_back(per_op_data); + free(per_op_data); + } + + return kLiteRtStatusOk; +} + +} // namespace + +LiteRtStatus LiteRtCompilerPluginCompile( + LiteRtCompilerPlugin compiler_plugin, const char* soc_model, + LiteRtSubgraphArray partitions, LiteRtParamIndex num_partitions, + LiteRtCompiledResult* compiled_result) { + LiteRtCompiledResult result = new LiteRtCompiledResultT; + + for (auto i = 0; i < num_partitions; ++i) { + LITERT_RETURN_STATUS_IF_NOT_OK( + CompileSinglePartition(i, partitions[i], *result)); + } + + *compiled_result = result; + + return kLiteRtStatusOk; +} diff --git a/tflite/experimental/litert/vendors/examples/example_plugin_test.cc b/tflite/experimental/litert/vendors/examples/example_plugin_test.cc new file mode 100644 index 00000000..90a56649 --- /dev/null +++ b/tflite/experimental/litert/vendors/examples/example_plugin_test.cc @@ -0,0 +1,97 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
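+
+// Exercises the example plugin through the raw C API: partitioning on a
+// multi-op model and compilation of a Mul-only subgraph. Assuming the Bazel
+// target layout declared in this package's BUILD file, the test can be run
+// with, e.g.:
+//
+//   bazel test //tflite/experimental/litert/vendors/examples:example_plugin_test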
+
+#include <cstddef>
+#include <string>
+
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/c/litert_op_code.h"
+#include "tflite/experimental/litert/cc/litert_model.h"
+#include "tflite/experimental/litert/core/model/model.h"
+#include "tflite/experimental/litert/test/common.h"
+#include "tflite/experimental/litert/vendors/c/litert_compiler_plugin.h"
+#include "tflite/experimental/litert/vendors/cc/litert_compiler_plugin.h"
+
+namespace litert {
+namespace {
+
+TEST(TestDummyPlugin, GetConfigInfo) {
+  ASSERT_STREQ(LiteRtGetCompilerPluginSocManufacturer(),
+               "ExampleSocManufacturer");
+
+  auto plugin = CreatePlugin();
+
+  LiteRtParamIndex num_supported_soc_models;
+  LITERT_ASSERT_STATUS_OK(LiteRtGetNumCompilerPluginSupportedSocModels(
+      plugin.get(), &num_supported_soc_models));
+  ASSERT_EQ(num_supported_soc_models, 1);
+
+  const char* soc_model_name;
+  LITERT_ASSERT_STATUS_OK(LiteRtGetCompilerPluginSupportedSocModel(
+      plugin.get(), 0, &soc_model_name));
+  ASSERT_STREQ(soc_model_name, "ExampleSocModel");
+}
+
+TEST(TestCallDummyPlugin, PartitionSimpleMultiAdd) {
+  auto plugin = CreatePlugin();
+  auto model = testing::LoadTestFileModel("simple_multi_op.tflite");
+
+  LiteRtOpListT selected_op_list;
+  LITERT_ASSERT_STATUS_OK(LiteRtCompilerPluginPartition(
+      plugin.get(), model.Subgraph(0)->Get(), &selected_op_list));
+  const auto selected_ops = selected_op_list.Vec();
+
+  ASSERT_EQ(selected_ops.size(), 2);
+  ASSERT_EQ(selected_ops[0]->OpCode(), kLiteRtOpCodeTflMul);
+  ASSERT_EQ(selected_ops[1]->OpCode(), kLiteRtOpCodeTflMul);
+}
+
+TEST(TestCallDummyPlugin, CompileMulSubgraph) {
+  auto plugin = CreatePlugin();
+  auto model = testing::LoadTestFileModel("mul_simple.tflite");
+
+  auto main_subgraph = model.MainSubgraph();
+  LiteRtSubgraph litert_subgraph = main_subgraph->Get();
+
+  LiteRtCompiledResult compiled;
+  LITERT_ASSERT_STATUS_OK(LiteRtCompilerPluginCompile(
+      plugin.get(), /*soc_model=*/nullptr, &litert_subgraph, 1, &compiled));
+
+  const void* byte_code;
+  size_t byte_code_size;
+
+  LITERT_ASSERT_STATUS_OK(
+      LiteRtGetCompiledResultByteCode(compiled, &byte_code, &byte_code_size));
+
+  absl::string_view byte_code_string(reinterpret_cast<const char*>(byte_code),
+                                     byte_code_size);
+  ASSERT_EQ(byte_code_string, "Partition_0_with_2_muls:");
+
+  const void* op_data;
+  size_t op_data_size;
+
+  LITERT_ASSERT_STATUS_OK(
+      LiteRtGetCompiledResultCallInfo(compiled, 0, &op_data, &op_data_size));
+
+  absl::string_view op_data_string(reinterpret_cast<const char*>(op_data),
+                                   op_data_size);
+  ASSERT_EQ(op_data_string, "Partition_0");
+
+  LiteRtDestroyCompiledResult(compiled);
+}
+
+}  // namespace
+}  // namespace litert
diff --git a/tflite/experimental/litert/vendors/google_tensor/dispatch/BUILD b/tflite/experimental/litert/vendors/google_tensor/dispatch/BUILD
new file mode 100644
index 00000000..dab1f58c
--- /dev/null
+++ b/tflite/experimental/litert/vendors/google_tensor/dispatch/BUILD
@@ -0,0 +1,85 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +load("//tflite/experimental/litert/build_common:litert_build_defs.bzl", "litert_dynamic_lib") + +package( + # copybara:uncomment default_applicable_licenses = ["@org_tensorflow//tensorflow:license"], + default_visibility = ["//tflite/experimental/litert:__subpackages__"], +) + +litert_dynamic_lib( + name = "dispatch_api", + srcs = [ + "dispatch_api.cc", + "litert_dispatch_device_context.cc", + "litert_dispatch_invocation_context.cc", + "southbound.cc", + ], + hdrs = [ + "dispatch_api.h", + "litert_dispatch_device_context.h", + "litert_dispatch_graph.h", + "litert_dispatch_invocation_context.h", + "southbound.h", + # copybara:uncomment "@org_tensorflow//third_party/odml/infra/southbound:sb_api.h", + ], + export_litert_only = True, + linkopts = select({ + "@org_tensorflow//tensorflow:android": ["-landroid"], + "//conditions:default": [], + }), + linkstatic = 1, + shared_lib_name = "dispatch_api_so", + so_name = "libLiteRtDispatch.so", + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_tensor_buffer", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/core/util:tensor_type_util", + "//tflite/experimental/litert/vendors/c:litert_dispatch_c_api", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:node_hash_set", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "dispatch_api_google_tensor_test", + srcs = [ + "dispatch_api_google_tensor_test.cc", + ], + data = [ + ":dispatch_api_so", + ], + linkopts = select({ + "@org_tensorflow//tensorflow:android": ["-landroid"], + "//conditions:default": [], + }), + deps = [ + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_tensor_buffer", + "//tflite/experimental/litert/core:filesystem", + "//tflite/experimental/litert/test:common", + "//tflite/experimental/litert/test:simple_model_npu", + "//tflite/experimental/litert/vendors/c:litert_dispatch_c_api", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tflite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.cc b/tflite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.cc new file mode 100644 index 00000000..cbd719cb --- /dev/null +++ b/tflite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.cc @@ -0,0 +1,1194 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
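+
+// Implements the LiteRT Dispatch API for Google Tensor on top of the
+// Southbound (THR) API. Every entry point follows the same defensive
+// pattern, illustrated here with a generic `thr_foo` (a sketch only;
+// `thr_foo` is a placeholder, not a real Southbound symbol):
+//
+//   auto thr_foo = TheSouthbound->api().thr_foo;  // resolved at runtime
+//   if (!thr_foo) return kLiteRtStatusErrorRuntimeFailure;
+//   if (auto status = thr_foo(/*...*/); status != kThrStatusSuccess) {
+//     return kLiteRtStatusErrorRuntimeFailure;
+//   }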
+
+#include "tflite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.h"
+
+#include <cstdio>
+#include <cstring>
+#include <optional>
+#include <set>
+#include <string>
+
+#if LITERT_HAS_AHWB_SUPPORT
+#include <android/hardware_buffer.h>
+#endif
+
+#include "absl/strings/string_view.h"
+#include "third_party/odml/infra/southbound/sb_api.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_event.h"
+#include "tflite/experimental/litert/c/litert_logging.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/vendors/c/litert_dispatch.h"
+#include "tflite/experimental/litert/vendors/c/litert_dispatch_api.h"
+#include "tflite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.h"
+#include "tflite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.h"
+#include "tflite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.h"
+#include "tflite/experimental/litert/vendors/google_tensor/dispatch/southbound.h"
+
+namespace {
+
+constexpr const int VERSION_MAJOR = 0;
+constexpr const int VERSION_MINOR = 1;
+constexpr const int VERSION_PATCH = 0;
+
+// We store THR names in a global set as a workaround to b/369144429.
+std::set<std::string> ThrNames;
+
+absl::string_view ThrNodeIdStr(LiteRtDispatchNodeId node_id) {
+  auto str = "node_" + std::to_string(node_id);
+  auto iter = ThrNames.find(str);
+  if (iter == ThrNames.end()) {
+    iter = ThrNames.insert(iter, str);
+  }
+  return *iter;
+}
+
+absl::string_view ThrEdgeIdStr(LiteRtDispatchEdgeId edge_id) {
+  auto str = "edge_" + std::to_string(edge_id);
+  auto iter = ThrNames.find(str);
+  if (iter == ThrNames.end()) {
+    iter = ThrNames.insert(iter, str);
+  }
+  return *iter;
+}
+
+litert::google_tensor::Southbound* TheSouthbound;
+char BuildId[256];
+
+}  // namespace
+
+namespace litert {
+namespace google_tensor {
+
+// /////////////////////////////////////////////////////////////////////////////
+// Basic Execution API
+// /////////////////////////////////////////////////////////////////////////////
+
+const char* GetSharedLibraryDir(const LiteRtDispatchOption* options,
+                                int num_options) {
+  for (auto i = 0; i < num_options; ++i) {
+    auto& option = options[i];
+    if (!strcmp(option.name, kDispatchOptionSharedLibraryDir)) {
+      return option.value.str_value;
+    }
+  }
+  return nullptr;
+}
+
+LiteRtStatus Initialize(const LiteRtDispatchOption* options, int num_options) {
+  auto* shared_library_dir = GetSharedLibraryDir(options, num_options);
+  std::optional<std::string> shared_library_dir_opt =
+      shared_library_dir ?
std::make_optional(std::string(shared_library_dir)) + : std::nullopt; + + if (auto southbound = + litert::google_tensor::Southbound::Create(shared_library_dir_opt); + !southbound) { + LITERT_LOG(LITERT_INFO, "Initialization failure: %s", + southbound.Error().Message().data()); + return southbound.Error().Status(); + } else { + TheSouthbound = southbound->release(); + } + + auto thr_initialize = TheSouthbound->api().thr_initialize; + if (!thr_initialize) { + LITERT_LOG(LITERT_INFO, "thr_initialize not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + if (auto status = thr_initialize(); status != kThrStatusSuccess) { + LITERT_LOG(LITERT_INFO, "thr_initialize failed: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + auto thr_get_vendor_api_version = + TheSouthbound->api().thr_get_vendor_api_version; + const char* sb_api_version = + thr_get_vendor_api_version ? thr_get_vendor_api_version() : "N.A."; + auto thr_get_vendor_id = TheSouthbound->api().thr_get_vendor_id; + const char* sb_vendor_id = thr_get_vendor_id ? thr_get_vendor_id() : "N.A."; + snprintf( + BuildId, sizeof(BuildId), + "GoogleTensor Dispatch API version %d.%d.%d, Darwinn API version %s, " + "vendor id: %s", + VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH, sb_api_version, + sb_vendor_id); + BuildId[sizeof(BuildId) - 1] = 0; + + return kLiteRtStatusOk; +} + +LiteRtStatus GetVendorId(const char** vendor_id) { + *vendor_id = "Google"; + return kLiteRtStatusOk; +} + +LiteRtStatus GetBuildId(const char** build_id) { + *build_id = BuildId; + return kLiteRtStatusOk; +} + +LiteRtStatus GetCapabilities(int* capabilities) { + *capabilities = kLiteRtDispatchCapabilitiesBasic | + kLiteRtDispatchCapabilitiesAsync | + kLiteRtDispatchCapabilitiesGraph; + return kLiteRtStatusOk; +} + +LiteRtStatus DeviceContextCreate(LiteRtDispatchDeviceContext* device_context) { + if (auto context = LiteRtDispatchDeviceContextT::Create(*TheSouthbound); + context) { + *device_context = context->release(); + return kLiteRtStatusOk; + } else { + LITERT_LOG(LITERT_ERROR, "Failed to create device context: %s", + context.Error().Message().data()); + return context.Error().Status(); + } + return kLiteRtStatusOk; +} + +LiteRtStatus DeviceContextDestroy(LiteRtDispatchDeviceContext device_context) { + delete device_context; + return kLiteRtStatusOk; +} + +LiteRtStatus GetInputRequirements( + LiteRtDispatchInvocationContext invocation_context, int input_index, + const LiteRtRankedTensorType* tensor_type, + LiteRtTensorBufferRequirements* tensor_buffer_requirements) { + if (auto requirements = + invocation_context->GetInputRequirements(input_index, *tensor_type); + requirements) { + *tensor_buffer_requirements = *requirements; + return kLiteRtStatusOk; + } else { + LITERT_LOG(LITERT_ERROR, "Failed to get tensor buffer requirements: %s", + requirements.Error().Message().data()); + return requirements.Error().Status(); + } +} + +LiteRtStatus GetOutputRequirements( + LiteRtDispatchInvocationContext invocation_context, int output_index, + const LiteRtRankedTensorType* tensor_type, + LiteRtTensorBufferRequirements* tensor_buffer_requirements) { + if (auto requirements = + invocation_context->GetOutputRequirements(output_index, *tensor_type); + requirements) { + *tensor_buffer_requirements = *requirements; + return kLiteRtStatusOk; + } else { + LITERT_LOG(LITERT_ERROR, "Failed to get tensor buffer requirements: %s", + requirements.Error().Message().data()); + return requirements.Error().Status(); + } +} + +LiteRtStatus RegisterTensorBuffer( + 
LiteRtDispatchDeviceContext device_context, + LiteRtTensorBuffer tensor_buffer, + LiteRtTensorBufferHandle* tensor_buffer_handle) { + LiteRtTensorBufferType tensor_buffer_type; + if (auto status = + LiteRtGetTensorBufferType(tensor_buffer, &tensor_buffer_type); + status != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, "Failed to get buffer type: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + if (tensor_buffer_type != kLiteRtTensorBufferTypeAhwb) { + LITERT_LOG(LITERT_ERROR, "Unsupported buffer type: %d", tensor_buffer_type); + return kLiteRtStatusErrorUnsupported; + } + + size_t tensor_buffer_size; + if (auto status = + LiteRtGetTensorBufferSize(tensor_buffer, &tensor_buffer_size); + status != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, "Failed to get buffer size: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + size_t tensor_buffer_offset; + if (auto status = + LiteRtGetTensorBufferOffset(tensor_buffer, &tensor_buffer_offset); + status != kLiteRtStatusOk) { + if (status == kLiteRtStatusErrorNotFound) { + tensor_buffer_offset = 0; + } else { + LITERT_LOG(LITERT_ERROR, "Failed to get buffer offset: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + } + + LiteRtRankedTensorType tensor_type; + if (auto status = + LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type); + status != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, "Failed to get tensor buffer type: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + auto* tensor_strides = tensor_type.layout.strides; + if (tensor_strides != nullptr) { + LITERT_LOG(LITERT_ERROR, "Tensor strides are not supported"); + return kLiteRtStatusErrorRuntimeFailure; + } + + AHardwareBuffer* ahwb; +#if LITERT_HAS_AHWB_SUPPORT + if (auto status = LiteRtGetTensorBufferAhwb(tensor_buffer, &ahwb); + status != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, "Failed to get AHWB: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } +#else + LITERT_LOG(LITERT_ERROR, "AHardwareBuffer is not supported on this platform"); + return kLiteRtStatusErrorRuntimeFailure; +#endif + + ThrContext* thr_context = device_context->thr_context(); + ThrBufferHandle thr_buffer_handle; + + if (tensor_buffer_offset == 0) { + auto thr_register_buffer = TheSouthbound->api().thr_register_buffer; + if (!thr_register_buffer) { + LITERT_LOG(LITERT_ERROR, "thr_register_buffer not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + if (auto status = thr_register_buffer( + thr_context, ThrBufferType::kThrBufferTypeAHardwareBuffer, ahwb, + tensor_buffer_size, &thr_buffer_handle); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_register_buffer failed: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + } else { + auto thr_register_buffer_with_offset = + TheSouthbound->api().thr_register_buffer_with_offset; + if (!thr_register_buffer_with_offset) { + LITERT_LOG(LITERT_ERROR, "thr_register_buffer_with_offset not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + if (auto status = thr_register_buffer_with_offset( + thr_context, ThrBufferType::kThrBufferTypeAHardwareBuffer, ahwb, + tensor_buffer_offset, tensor_buffer_size, &thr_buffer_handle); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_register_buffer_with_offset failed: %d", + status); + return kLiteRtStatusErrorRuntimeFailure; + } + } + + *tensor_buffer_handle = thr_buffer_handle; + return kLiteRtStatusOk; +} + +LiteRtStatus UnregisterTensorBuffer( + LiteRtDispatchDeviceContext device_context, + 
LiteRtTensorBufferHandle tensor_buffer_handle) {
+  auto thr_unregister_buffer = TheSouthbound->api().thr_unregister_buffer;
+  if (!thr_unregister_buffer) {
+    LITERT_LOG(LITERT_ERROR, "thr_unregister_buffer not found");
+    return kLiteRtStatusErrorRuntimeFailure;
+  }
+
+  ThrBufferHandle thr_buffer_handle = tensor_buffer_handle;
+  if (auto status = thr_unregister_buffer(device_context->thr_context(),
+                                          thr_buffer_handle);
+      status != kThrStatusSuccess) {
+    LITERT_LOG(LITERT_ERROR, "thr_unregister_buffer failed: %d", status);
+    return kLiteRtStatusErrorRuntimeFailure;
+  }
+
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus InvocationContextCreate(
+    LiteRtDispatchDeviceContext device_context,
+    LiteRtDispatchExecutableType exec_type, const void* exec_bytecode,
+    size_t exec_bytecode_size, const char* function_name, int num_inputs,
+    int num_outputs, LiteRtDispatchInvocationContext* invocation_context) {
+  LiteRtDispatchGraph graph = nullptr;
+  if (auto status = GraphCreate(device_context, &graph);
+      status != kLiteRtStatusOk) {
+    return status;
+  }
+
+  LiteRtDispatchNodeId node_id = 0;
+  LiteRtDispatchNodeType node_type;
+  switch (exec_type) {
+    case kLiteRtDispatchExecutableTypeDspLibrary:
+      node_type = kLiteRtDispatchNodeTypeDsp;
+      break;
+    case kLiteRtDispatchExecutableTypeMlModel:
+      node_type = kLiteRtDispatchNodeTypeNpu;
+      break;
+    default:
+      LITERT_LOG(LITERT_ERROR, "Unexpected executable type: %d", exec_type);
+      return kLiteRtStatusErrorInvalidArgument;
+  }
+
+  if (auto status = AddNode(graph, node_id, node_type);
+      status != kLiteRtStatusOk) {
+    return status;
+  }
+
+  LiteRtDispatchExecutableHandle exec_handle;
+  if (auto status = LoadExecutable(device_context, exec_type, exec_bytecode,
+                                   exec_bytecode_size, &exec_handle);
+      status != kLiteRtStatusOk) {
+    return status;
+  }
+
+  if (auto status =
+          AssignNodeFunction(graph, node_id, exec_handle, function_name);
+      status != kLiteRtStatusOk) {
+    return status;
+  }
+
+  LiteRtDispatchEdgeId next_edge_id = 0;
+
+  for (auto input_index = 0; input_index < num_inputs; ++input_index) {
+    LiteRtDispatchEdgeId edge_id = next_edge_id++;
+    if (auto status = AddEdge(graph, edge_id); status != kLiteRtStatusOk) {
+      return status;
+    }
+    if (auto status = ConnectGraphInput(graph, input_index, edge_id);
+        status != kLiteRtStatusOk) {
+      return status;
+    }
+    if (auto status = ConnectNodeInput(graph, node_id, input_index, edge_id);
+        status != kLiteRtStatusOk) {
+      return status;
+    }
+  }
+
+  for (auto output_index = 0; output_index < num_outputs; ++output_index) {
+    LiteRtDispatchEdgeId edge_id = next_edge_id++;
+    if (auto status = AddEdge(graph, edge_id); status != kLiteRtStatusOk) {
+      return status;
+    }
+    if (auto status = ConnectNodeOutput(graph, node_id, output_index, edge_id);
+        status != kLiteRtStatusOk) {
+      return status;
+    }
+    if (auto status = ConnectGraphOutput(graph, output_index, edge_id);
+        status != kLiteRtStatusOk) {
+      return status;
+    }
+  }
+
+  if (auto status = InvocationContextCreateFromGraph(device_context, graph,
+                                                     invocation_context);
+      status != kLiteRtStatusOk) {
+    return status;
+  }
+
+  (*invocation_context)->AttachExecutable(exec_handle);
+
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus InvocationContextDestroy(
+    LiteRtDispatchInvocationContext invocation_context) {
+  auto thr_invocation_context_delete =
+      TheSouthbound->api().thr_invocation_context_delete;
+  if (!thr_invocation_context_delete) {
+    LITERT_LOG(LITERT_ERROR, "thr_invocation_context_delete not found");
+    return kLiteRtStatusErrorRuntimeFailure;
+  }
+
+  ThrGraph* thr_graph =
invocation_context->graph()->thr_graph(); + ThrInvocationContext* thr_icontext = + invocation_context->thr_invocation_context(); + if (auto status = thr_invocation_context_delete(thr_graph, thr_icontext); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_invocation_context_delete failed: %d", + status); + return kLiteRtStatusErrorRuntimeFailure; + } + + delete invocation_context; + + return kLiteRtStatusOk; +} + +LiteRtStatus AttachBufferHelper( + LiteRtDispatchInvocationContext invocation_context, + LiteRtDispatchEdgeId edge_id, + LiteRtTensorBufferHandle tensor_buffer_handle) { + auto thr_invocation_context_attach_buffer = + TheSouthbound->api().thr_invocation_context_attach_buffer; + if (!thr_invocation_context_attach_buffer) { + LITERT_LOG(LITERT_ERROR, "thr_invocation_context_attach_buffer not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrInvocationContext* thr_icontext = + invocation_context->thr_invocation_context(); + ThrContext* thr_context = invocation_context->device_context()->thr_context(); + auto thr_edge_id = ThrEdgeIdStr(edge_id); + ThrBufferHandle thr_buffer_handle = tensor_buffer_handle; + if (auto status = thr_invocation_context_attach_buffer( + thr_icontext, thr_context, thr_edge_id.data(), thr_buffer_handle); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_invocation_context_attach_buffer failed: %d", + status); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +LiteRtStatus AttachInput(LiteRtDispatchInvocationContext invocation_context, + int graph_input_index, + LiteRtTensorBufferHandle tensor_buffer_handle) { + if (auto status_or = + invocation_context->graph()->InputEdge(graph_input_index); + !status_or) { + LITERT_LOG(LITERT_ERROR, "Unexpected graph input index: %d", + graph_input_index); + return kLiteRtStatusErrorInvalidArgument; + } else { + auto edge_id = *status_or; + return AttachBufferHelper(invocation_context, edge_id, + tensor_buffer_handle); + } +} + +LiteRtStatus AttachOutput(LiteRtDispatchInvocationContext invocation_context, + int graph_output_index, + LiteRtTensorBufferHandle tensor_buffer_handle) { + if (auto status = invocation_context->graph()->OutputEdge(graph_output_index); + !status) { + LITERT_LOG(LITERT_ERROR, "Unexpected graph output index: %d", + graph_output_index); + return kLiteRtStatusErrorInvalidArgument; + } else { + auto edge_id = *status; + return AttachBufferHelper(invocation_context, edge_id, + tensor_buffer_handle); + } +} + +LiteRtStatus DetachTensorBufferHelper( + LiteRtDispatchInvocationContext invocation_context, + LiteRtDispatchEdgeId edge_id, + LiteRtTensorBufferHandle tensor_buffer_handle) { + auto thr_invocation_context_detach_buffer = + TheSouthbound->api().thr_invocation_context_detach_buffer; + if (!thr_invocation_context_detach_buffer) { + LITERT_LOG(LITERT_ERROR, "thr_invocation_context_detach_buffer not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrInvocationContext* thr_icontext = + invocation_context->thr_invocation_context(); + ThrContext* thr_context = invocation_context->device_context()->thr_context(); + auto thr_edge_id = ThrEdgeIdStr(edge_id); + ThrBufferHandle thr_buffer_handle = tensor_buffer_handle; + if (auto status = thr_invocation_context_detach_buffer( + thr_icontext, thr_context, thr_edge_id.data(), thr_buffer_handle); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_invocation_context_detach_buffer failed: %d", + status); + return kLiteRtStatusErrorRuntimeFailure; + } + + return 
kLiteRtStatusOk; +} + +LiteRtStatus DetachInput(LiteRtDispatchInvocationContext invocation_context, + int graph_input_index, + LiteRtTensorBufferHandle tensor_buffer_handle) { + if (auto status_or = + invocation_context->graph()->InputEdge(graph_input_index); + !status_or) { + LITERT_LOG(LITERT_ERROR, "Unexpected graph input index: %d", + graph_input_index); + return kLiteRtStatusErrorInvalidArgument; + } else { + auto edge_id = *status_or; + return DetachTensorBufferHelper(invocation_context, edge_id, + tensor_buffer_handle); + } +} + +LiteRtStatus DetachOutput(LiteRtDispatchInvocationContext invocation_context, + int graph_output_index, + LiteRtTensorBufferHandle tensor_buffer_handle) { + if (auto status = invocation_context->graph()->OutputEdge(graph_output_index); + !status) { + LITERT_LOG(LITERT_ERROR, "Unexpected graph output index: %d", + graph_output_index); + return kLiteRtStatusErrorInvalidArgument; + } else { + auto edge_id = *status; + return DetachTensorBufferHelper(invocation_context, edge_id, + tensor_buffer_handle); + } +} + +LiteRtStatus PrepareForInvoke( + LiteRtDispatchInvocationContext invocation_context, + bool create_output_sync_fence) { + auto thr_invocation_context_prepare_for_invoke = + TheSouthbound->api().thr_invocation_context_prepare_for_invoke; + if (!thr_invocation_context_prepare_for_invoke) { + LITERT_LOG(LITERT_ERROR, + "thr_invocation_context_prepare_for_invoke not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrInvocationContext* thr_icontext = + invocation_context->thr_invocation_context(); + if (auto status = thr_invocation_context_prepare_for_invoke( + thr_icontext, create_output_sync_fence); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, + "thr_invocation_context_prepare_for_invoke failed: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +LiteRtStatus InvokeOnce(LiteRtDispatchInvocationContext invocation_context) { + auto thr_invocation_context_invoke_once = + TheSouthbound->api().thr_invocation_context_invoke_once; + if (!thr_invocation_context_invoke_once) { + LITERT_LOG(LITERT_ERROR, "thr_invocation_context_invoke_once not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrInvocationContext* thr_icontext = + invocation_context->thr_invocation_context(); + if (auto status = thr_invocation_context_invoke_once(thr_icontext); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_invocation_context_invoke_once failed: %d", + status); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +LiteRtStatus Wait(LiteRtDispatchInvocationContext invocation_context) { + auto thr_invocation_context_wait = + TheSouthbound->api().thr_invocation_context_wait; + if (!thr_invocation_context_wait) { + LITERT_LOG(LITERT_ERROR, "thr_invocation_context_wait not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrInvocationContext* thr_icontext = + invocation_context->thr_invocation_context(); + if (auto status = thr_invocation_context_wait(thr_icontext); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_invocation_context_wait failed: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +LiteRtStatus Invoke(LiteRtDispatchInvocationContext invocation_context) { + if (auto status = PrepareForInvoke(invocation_context, + /*create_output_sync_fence=*/false); + status != kLiteRtStatusOk) { + return status; + } + + if (auto status = InvokeOnce(invocation_context); status != kLiteRtStatusOk) { 
+ return status; + } + return Wait(invocation_context); +} + +// ///////////////////////////////////////////////////////////////////////////// +// Async Execution API +// ///////////////////////////////////////////////////////////////////////////// + +LiteRtStatus AttachInputEvent( + LiteRtDispatchInvocationContext invocation_context, int graph_input_index, + LiteRtEvent input_event) { + auto status_or = invocation_context->graph()->InputEdge(graph_input_index); + if (!status_or) { + LITERT_LOG(LITERT_ERROR, "Unexpected graph input index: %d", + graph_input_index); + return kLiteRtStatusErrorInvalidArgument; + } + auto edge_id = *status_or; + + auto thr_invocation_context_attach_input_buffer_sync_fence = + TheSouthbound->api() + .thr_invocation_context_attach_input_buffer_sync_fence; + if (!thr_invocation_context_attach_input_buffer_sync_fence) { + LITERT_LOG( + LITERT_ERROR, + "thr_invocation_context_attach_input_buffer_sync_fence not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + int input_fence_fd; + if (auto status = LiteRtGetEventSyncFenceFd(input_event, &input_fence_fd); + status != kLiteRtStatusOk) { + return status; + } + + ThrInvocationContext* thr_icontext = + invocation_context->thr_invocation_context(); + auto thr_edge_id = ThrEdgeIdStr(edge_id); + if (auto status = thr_invocation_context_attach_input_buffer_sync_fence( + thr_icontext, thr_edge_id.data(), input_fence_fd); + status != kThrStatusSuccess) { + LITERT_LOG( + LITERT_ERROR, + "thr_invocation_context_attach_input_buffer_sync_fence failed: %d", + status); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +namespace { + +LiteRtStatus GetOutputEvent(LiteRtDispatchInvocationContext invocation_context, + int graph_output_index, LiteRtEvent* output_event) { + auto status_or = invocation_context->graph()->OutputEdge(graph_output_index); + if (!status_or) { + LITERT_LOG(LITERT_ERROR, "Unexpected graph output index: %d", + graph_output_index); + return kLiteRtStatusErrorInvalidArgument; + } + auto edge_id = *status_or; + + auto thr_invocation_context_get_output_buffer_sync_fence = + TheSouthbound->api().thr_invocation_context_get_output_buffer_sync_fence; + if (!thr_invocation_context_get_output_buffer_sync_fence) { + LITERT_LOG(LITERT_ERROR, + "thr_invocation_context_get_output_buffer_sync_fence not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrInvocationContext* thr_icontext = + invocation_context->thr_invocation_context(); + auto thr_edge_id = ThrEdgeIdStr(edge_id); + int output_fence_fd; + if (auto status = thr_invocation_context_get_output_buffer_sync_fence( + thr_icontext, thr_edge_id.data(), &output_fence_fd); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, + "thr_invocation_context_get_output_buffer_sync_fence failed: %d", + status); + return kLiteRtStatusErrorRuntimeFailure; + } + + if (auto status = LiteRtCreateEventFromSyncFenceFd( + output_fence_fd, /*owns_fd=*/false, output_event); + status != kLiteRtStatusOk) { + return status; + } + + return kLiteRtStatusOk; +} + +} // namespace + +LiteRtStatus InvokeAsync(LiteRtDispatchInvocationContext invocation_context, + int num_output_events, LiteRtEvent* output_events) { + if (num_output_events != invocation_context->graph()->NumOutputs()) { + LITERT_LOG(LITERT_ERROR, "Unexpected number of output events: %d", + num_output_events); + return kLiteRtStatusErrorInvalidArgument; + } + + if (auto status = PrepareForInvoke(invocation_context, + /*create_output_sync_fence=*/true); + status != kLiteRtStatusOk) 
{
+    return status;
+  }
+
+  if (auto status = InvokeOnce(invocation_context); status != kLiteRtStatusOk) {
+    return status;
+  }
+
+  for (auto graph_output_index = 0; graph_output_index < num_output_events;
+       ++graph_output_index) {
+    if (auto status = GetOutputEvent(invocation_context, graph_output_index,
+                                     &output_events[graph_output_index]);
+        status != kLiteRtStatusOk) {
+      LITERT_LOG(LITERT_ERROR, "Failed to get event for output %d: %d",
+                 graph_output_index, status);
+      return kLiteRtStatusErrorRuntimeFailure;
+    }
+  }
+
+  return kLiteRtStatusOk;
+}
+
+// /////////////////////////////////////////////////////////////////////////////
+// Graph Execution API
+// /////////////////////////////////////////////////////////////////////////////
+
+LiteRtStatus GraphCreate(LiteRtDispatchDeviceContext device_context,
+                         LiteRtDispatchGraph* graph) {
+  auto thr_graph_create = TheSouthbound->api().thr_graph_create;
+  if (!thr_graph_create) {
+    LITERT_LOG(LITERT_ERROR, "thr_graph_create not found");
+    return kLiteRtStatusErrorRuntimeFailure;
+  }
+
+  ThrGraph* thr_graph = thr_graph_create(device_context->thr_context());
+  if (!thr_graph) {
+    LITERT_LOG(LITERT_ERROR, "thr_graph_create failed");
+    return kLiteRtStatusErrorRuntimeFailure;
+  }
+
+  *graph = new LiteRtDispatchGraphT(thr_graph, device_context);
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus GraphDestroy(LiteRtDispatchGraph graph) {
+  auto thr_graph_delete = TheSouthbound->api().thr_graph_delete;
+  if (!thr_graph_delete) {
+    LITERT_LOG(LITERT_ERROR, "thr_graph_delete not found");
+    return kLiteRtStatusErrorRuntimeFailure;
+  }
+
+  graph->device_context()->remove_graph(graph->thr_graph());
+
+  ThrGraph* thr_graph = graph->thr_graph();
+  if (auto status = thr_graph_delete(thr_graph); status != kThrStatusSuccess) {
+    LITERT_LOG(LITERT_ERROR, "thr_graph_delete failed: %d", status);
+    return kLiteRtStatusErrorRuntimeFailure;
+  }
+
+  delete graph;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus AddNode(LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id,
+                     LiteRtDispatchNodeType node_type) {
+  auto thr_graph_add_sq_node = TheSouthbound->api().thr_graph_add_sq_node;
+  if (!thr_graph_add_sq_node) {
+    LITERT_LOG(LITERT_ERROR, "thr_graph_add_sq_node not found");
+    return kLiteRtStatusErrorRuntimeFailure;
+  }
+
+  ThrGraph* thr_graph = graph->thr_graph();
+  auto thr_node_id = ThrNodeIdStr(node_id);
+  ThrNodeType thr_node_type;
+  switch (node_type) {
+    case kLiteRtDispatchNodeTypeDsp:
+      thr_node_type = kThrNodeTypeDsp;
+      break;
+    case kLiteRtDispatchNodeTypeNpu:
+      thr_node_type = kThrNodeTypeNpu;
+      break;
+    default:
+      LITERT_LOG(LITERT_ERROR, "Unexpected node type: %d", node_type);
+      return kLiteRtStatusErrorInvalidArgument;
+  }
+
+  if (auto status =
+          thr_graph_add_sq_node(thr_graph, thr_node_id.data(), thr_node_type);
+      status != kThrStatusSuccess) {
+    LITERT_LOG(LITERT_ERROR, "thr_graph_add_sq_node failed: %d", status);
+    return kLiteRtStatusErrorRuntimeFailure;
+  }
+
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus AddEdge(LiteRtDispatchGraph graph, LiteRtDispatchEdgeId edge_id) {
+  auto thr_graph_add_edge = TheSouthbound->api().thr_graph_add_edge;
+  if (!thr_graph_add_edge) {
+    LITERT_LOG(LITERT_ERROR, "thr_graph_add_edge not found");
+    return kLiteRtStatusErrorRuntimeFailure;
+  }
+
+  ThrGraph* thr_graph = graph->thr_graph();
+  auto thr_edge_id = ThrEdgeIdStr(edge_id);
+  ThrEdgeType thr_edge_type = kThrEdgeNoType;
+  if (auto status =
+          thr_graph_add_edge(thr_graph, thr_edge_id.data(), thr_edge_type);
+      status != kThrStatusSuccess) {
+
LITERT_LOG(LITERT_ERROR, "thr_graph_add_edge failed: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +LiteRtStatus ConnectNodeInput(LiteRtDispatchGraph graph, + LiteRtDispatchNodeId node_id, int input_index, + LiteRtDispatchEdgeId edge_id) { + auto thr_graph_connect_node_input = + TheSouthbound->api().thr_graph_connect_node_input; + if (!thr_graph_connect_node_input) { + LITERT_LOG(LITERT_ERROR, "thr_graph_connect_node_input not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + int next_input_index = graph->NextNodeInputIndex(node_id); + if (input_index != next_input_index) { + LITERT_LOG(LITERT_ERROR, "Unexpected input index %d, expected %d", + input_index, next_input_index); + return kLiteRtStatusErrorInvalidArgument; + } + + ThrGraph* thr_graph = graph->thr_graph(); + auto thr_node_id = ThrNodeIdStr(node_id); + auto thr_edge_id = ThrEdgeIdStr(edge_id); + if (auto status = thr_graph_connect_node_input(thr_graph, thr_node_id.data(), + thr_edge_id.data()); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_graph_set_input_edge failed: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + graph->AddInputEdge(input_index, edge_id); + return kLiteRtStatusOk; +} + +LiteRtStatus ConnectNodeOutput(LiteRtDispatchGraph graph, + LiteRtDispatchNodeId node_id, int output_index, + LiteRtDispatchEdgeId edge_id) { + auto thr_graph_connect_node_output = + TheSouthbound->api().thr_graph_connect_node_output; + if (!thr_graph_connect_node_output) { + LITERT_LOG(LITERT_ERROR, "thr_graph_connect_node_output not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + int next_output_index = graph->NextNodeOutputIndex(node_id); + if (output_index != next_output_index) { + LITERT_LOG(LITERT_ERROR, "Unexpected output index %d, expected %d", + output_index, next_output_index); + return kLiteRtStatusErrorInvalidArgument; + } + + ThrGraph* thr_graph = graph->thr_graph(); + auto thr_node_id = ThrNodeIdStr(node_id); + auto thr_edge_id = ThrEdgeIdStr(edge_id); + if (auto status = thr_graph_connect_node_output(thr_graph, thr_node_id.data(), + thr_edge_id.data()); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_graph_set_output_edge failed: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + graph->AddOutputEdge(output_index, edge_id); + return kLiteRtStatusOk; +} + +LiteRtStatus ConnectGraphInput(LiteRtDispatchGraph graph, int input_index, + LiteRtDispatchEdgeId edge_id) { + int next_input_index = graph->NextGraphInputIndex(); + if (input_index != next_input_index) { + LITERT_LOG(LITERT_ERROR, "Unexpected input index %d, expected %d", + input_index, next_input_index); + return kLiteRtStatusErrorInvalidArgument; + } + + auto thr_graph_set_input_edge = TheSouthbound->api().thr_graph_set_input_edge; + if (!thr_graph_set_input_edge) { + LITERT_LOG(LITERT_ERROR, "thr_graph_set_input_edge not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrGraph* thr_graph = graph->thr_graph(); + auto thr_edge_id = ThrEdgeIdStr(edge_id); + if (auto status = thr_graph_set_input_edge(thr_graph, thr_edge_id.data()); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_graph_set_input_edge failed: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +LiteRtStatus ConnectGraphOutput(LiteRtDispatchGraph graph, int output_index, + LiteRtDispatchEdgeId edge_id) { + int next_output_index = graph->NextGraphOutputIndex(); + if (output_index != next_output_index) { + 
LITERT_LOG(LITERT_ERROR, "Unexpected output index %d, expected %d", + output_index, next_output_index); + return kLiteRtStatusErrorInvalidArgument; + } + + auto thr_graph_set_output_edge = + TheSouthbound->api().thr_graph_set_output_edge; + if (!thr_graph_set_output_edge) { + LITERT_LOG(LITERT_ERROR, "thr_graph_set_output_edge not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrGraph* thr_graph = graph->thr_graph(); + auto thr_edge_id = ThrEdgeIdStr(edge_id); + if (auto status = thr_graph_set_output_edge(thr_graph, thr_edge_id.data()); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_graph_set_output_edge failed: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +LiteRtStatus LoadExecutable(LiteRtDispatchDeviceContext device_context, + LiteRtDispatchExecutableType type, + const void* bytecode, size_t bytecode_size, + LiteRtDispatchExecutableHandle* exec_handle) { + auto thr_load_sq_container = TheSouthbound->api().thr_load_sq_container; + if (!thr_load_sq_container) { + LITERT_LOG(LITERT_ERROR, "thr_load_sq_container not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrSqContainerType thr_type; + switch (type) { + case kLiteRtDispatchExecutableTypeDspLibrary: + thr_type = kThrSqContainerTypeFunctionLibrary; + break; + case kLiteRtDispatchExecutableTypeMlModel: + thr_type = kThrSqContainerTypeMlModel; + break; + default: + LITERT_LOG(LITERT_ERROR, "Unexpected executable type: %d", type); + return kLiteRtStatusErrorInvalidArgument; + } + + ThrContext* thr_context = device_context->thr_context(); + ThrSqContainerHandle sq_handle; + if (auto status = thr_load_sq_container(thr_context, thr_type, bytecode, + bytecode_size, &sq_handle); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_load_sq_container failed: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + *exec_handle = sq_handle; + return kLiteRtStatusOk; +} + +LiteRtStatus UnloadExecutable(LiteRtDispatchDeviceContext device_context, + LiteRtDispatchExecutableHandle exec_handle) { + auto thr_unload_sq_container = TheSouthbound->api().thr_unload_sq_container; + if (!thr_unload_sq_container) { + LITERT_LOG(LITERT_ERROR, "thr_unload_sq_container not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrContext* thr_context = device_context->thr_context(); + ThrSqContainerHandle sq_handle = exec_handle; + if (auto status = thr_unload_sq_container(thr_context, sq_handle); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_unload_sq_container failed: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +LiteRtStatus AssignNodeFunction(LiteRtDispatchGraph graph, + LiteRtDispatchNodeId node_id, + LiteRtDispatchExecutableHandle exec_handle, + const char* function_name) { + auto thr_graph_assign_sq = TheSouthbound->api().thr_graph_assign_sq; + if (!thr_graph_assign_sq) { + LITERT_LOG(LITERT_ERROR, "thr_graph_assign_sq not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrGraph* thr_graph = graph->thr_graph(); + auto thr_node_id = ThrNodeIdStr(node_id); + ThrSqContainerHandle sq_handle = exec_handle; + // An empty function name represent no function name being provided and + // therefore we must pass a nullptr to the call below, otherwise the SB API + // will expect a model with a signature. See b/378913220. + const char* function_name_ptr = + absl::string_view(function_name).empty() ? 
nullptr : function_name; + if (auto status = thr_graph_assign_sq(thr_graph, thr_node_id.data(), + sq_handle, function_name_ptr); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_graph_assign_sq failed: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +LiteRtStatus AnnotateGraph(LiteRtDispatchGraph graph, const char* key, + const char* value) { + auto thr_graph_annotate_graph = TheSouthbound->api().thr_graph_annotate_graph; + if (!thr_graph_annotate_graph) { + LITERT_LOG(LITERT_ERROR, "thr_graph_annotate_graph not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrGraph* thr_graph = graph->thr_graph(); + if (auto status = thr_graph_annotate_graph(thr_graph, key, value); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_graph_annotate_graph failed"); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +LiteRtStatus AnnotateNode(LiteRtDispatchGraph graph, + LiteRtDispatchNodeId node_id, const char* key, + const char* value) { + auto thr_graph_annotate_node = TheSouthbound->api().thr_graph_annotate_node; + if (!thr_graph_annotate_node) { + LITERT_LOG(LITERT_ERROR, "thr_graph_annotate_node not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrGraph* thr_graph = graph->thr_graph(); + auto thr_node_id = ThrNodeIdStr(node_id); + if (auto status = + thr_graph_annotate_node(thr_graph, thr_node_id.data(), key, value); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_graph_annotate_node failed"); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +LiteRtStatus AnnotateEdge(LiteRtDispatchGraph graph, + LiteRtDispatchEdgeId edge_id, const char* key, + const char* value) { + auto thr_graph_annotate_edge = TheSouthbound->api().thr_graph_annotate_edge; + if (!thr_graph_annotate_edge) { + LITERT_LOG(LITERT_ERROR, "thr_graph_annotate_edge not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrGraph* thr_graph = graph->thr_graph(); + auto thr_edge_id = ThrEdgeIdStr(edge_id); + if (auto status = + thr_graph_annotate_edge(thr_graph, thr_edge_id.data(), key, value); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_graph_annotate_edge failed"); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +LiteRtStatus InvocationContextCreateFromGraph( + LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph graph, + LiteRtDispatchInvocationContext* invocation_context) { + auto thr_invocation_context_get = + TheSouthbound->api().thr_invocation_context_get; + if (!thr_invocation_context_get) { + LITERT_LOG(LITERT_ERROR, "thr_invocation_context_get not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrGraph* thr_graph = graph->thr_graph(); + auto thr_icontext = + thr_invocation_context_get(thr_graph, device_context->thr_context()); + if (!thr_icontext) { + LITERT_LOG(LITERT_ERROR, "thr_invocation_context_get failed"); + return kLiteRtStatusErrorRuntimeFailure; + } + + device_context->add_graph(thr_graph); + *invocation_context = + new LiteRtDispatchInvocationContextT(thr_icontext, device_context, graph); + + return kLiteRtStatusOk; +} + +} // namespace google_tensor +} // namespace litert + +// ///////////////////////////////////////////////////////////////////////////// + +namespace { + +LiteRtDispatchInterface TheInterface = { + .initialize = litert::google_tensor::Initialize, + .get_vendor_id = litert::google_tensor::GetVendorId, + .get_build_id = 
litert::google_tensor::GetBuildId, + .get_capabilities = litert::google_tensor::GetCapabilities, + .device_context_create = litert::google_tensor::DeviceContextCreate, + .device_context_destroy = litert::google_tensor::DeviceContextDestroy, + .get_input_requirements = litert::google_tensor::GetInputRequirements, + .get_output_requirements = litert::google_tensor::GetOutputRequirements, + .register_tensor_buffer = litert::google_tensor::RegisterTensorBuffer, + .unregister_tensor_buffer = litert::google_tensor::UnregisterTensorBuffer, + .invocation_context_create = litert::google_tensor::InvocationContextCreate, + .invocation_context_destroy = + litert::google_tensor::InvocationContextDestroy, + .attach_input = litert::google_tensor::AttachInput, + .attach_output = litert::google_tensor::AttachOutput, + .detach_input = litert::google_tensor::DetachInput, + .detach_output = litert::google_tensor::DetachOutput, + .invoke = litert::google_tensor::Invoke, +}; + +LiteRtDispatchAsyncInterface TheAsyncInterface = { + .attach_input_event = litert::google_tensor::AttachInputEvent, + .invoke_async = litert::google_tensor::InvokeAsync, +}; + +LiteRtDispatchGraphInterface TheGraphInterface = { + .graph_create = litert::google_tensor::GraphCreate, + .graph_destroy = litert::google_tensor::GraphDestroy, + .add_node = litert::google_tensor::AddNode, + .add_edge = litert::google_tensor::AddEdge, + .connect_node_input = litert::google_tensor::ConnectNodeInput, + .connect_node_output = litert::google_tensor::ConnectNodeOutput, + .connect_graph_input = litert::google_tensor::ConnectGraphInput, + .connect_graph_output = litert::google_tensor::ConnectGraphOutput, + .load_executable = litert::google_tensor::LoadExecutable, + .unload_executable = litert::google_tensor::UnloadExecutable, + .assign_node_function = litert::google_tensor::AssignNodeFunction, + .annotate_graph = litert::google_tensor::AnnotateGraph, + .annotate_node = litert::google_tensor::AnnotateNode, + .annotate_edge = litert::google_tensor::AnnotateEdge, + .invocation_context_create_from_graph = + litert::google_tensor::InvocationContextCreateFromGraph, +}; + +LiteRtDispatchApi TheApi = { + .version = {.major = VERSION_MAJOR, + .minor = VERSION_MINOR, + .patch = VERSION_PATCH}, + .interface = &TheInterface, + .async_interface = &TheAsyncInterface, + .graph_interface = &TheGraphInterface, +}; + +} // namespace + +LiteRtStatus LiteRtDispatchGetApi(LiteRtDispatchApi* api) { + *api = TheApi; + return kLiteRtStatusOk; +} diff --git a/tflite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.h b/tflite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.h new file mode 100644 index 00000000..7c7672d4 --- /dev/null +++ b/tflite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.h @@ -0,0 +1,67 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
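+
+// Typical call order for the graph-construction API declared in this header
+// (an illustrative sketch only: node/edge ids are caller-chosen values, and
+// the single-node topology below is just an example, not a required shape):
+//
+//   LiteRtDispatchGraph graph;
+//   GraphCreate(device_context, &graph);
+//   AddNode(graph, /*node_id=*/0, kLiteRtDispatchNodeTypeNpu);
+//   AddEdge(graph, /*edge_id=*/0);  // input edge
+//   AddEdge(graph, /*edge_id=*/1);  // output edge
+//   ConnectNodeInput(graph, /*node_id=*/0, /*input_index=*/0, /*edge_id=*/0);
+//   ConnectNodeOutput(graph, /*node_id=*/0, /*output_index=*/0, /*edge_id=*/1);
+//   ConnectGraphInput(graph, /*input_index=*/0, /*edge_id=*/0);
+//   ConnectGraphOutput(graph, /*output_index=*/0, /*edge_id=*/1);
+//   LoadExecutable(device_context, kLiteRtDispatchExecutableTypeMlModel,
+//                  bytecode, bytecode_size, &exec_handle);
+//   AssignNodeFunction(graph, /*node_id=*/0, exec_handle, function_name);
+//   InvocationContextCreateFromGraph(device_context, graph,
+//                                    &invocation_context);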
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_DISPATCH_API_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_DISPATCH_API_H_
+
+#include <cstddef>
+
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/vendors/c/litert_dispatch.h"
+
+namespace litert {
+namespace google_tensor {
+
+LiteRtStatus GraphCreate(LiteRtDispatchDeviceContext device_context,
+                         LiteRtDispatchGraph* graph);
+LiteRtStatus GraphDestroy(LiteRtDispatchGraph graph);
+LiteRtStatus AddNode(LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id,
+                     LiteRtDispatchNodeType node_type);
+LiteRtStatus AddEdge(LiteRtDispatchGraph graph, LiteRtDispatchEdgeId edge_id);
+LiteRtStatus ConnectNodeInput(LiteRtDispatchGraph graph,
+                              LiteRtDispatchNodeId node_id, int input_index,
+                              LiteRtDispatchEdgeId edge_id);
+LiteRtStatus ConnectNodeOutput(LiteRtDispatchGraph graph,
+                               LiteRtDispatchNodeId node_id, int output_index,
+                               LiteRtDispatchEdgeId edge_id);
+LiteRtStatus ConnectGraphInput(LiteRtDispatchGraph graph, int input_index,
+                               LiteRtDispatchEdgeId edge_id);
+LiteRtStatus ConnectGraphOutput(LiteRtDispatchGraph graph, int output_index,
+                                LiteRtDispatchEdgeId edge_id);
+LiteRtStatus LoadExecutable(LiteRtDispatchDeviceContext device_context,
+                            LiteRtDispatchExecutableType type,
+                            const void* bytecode, size_t bytecode_size,
+                            LiteRtDispatchExecutableHandle* exec_handle);
+LiteRtStatus UnloadExecutable(LiteRtDispatchDeviceContext device_context,
+                              LiteRtDispatchExecutableHandle exec_handle);
+LiteRtStatus AssignNodeFunction(LiteRtDispatchGraph graph,
+                                LiteRtDispatchNodeId node_id,
+                                LiteRtDispatchExecutableHandle exec_handle,
+                                const char* function_name);
+LiteRtStatus AnnotateGraph(LiteRtDispatchGraph graph, const char* key,
+                           const char* value);
+LiteRtStatus AnnotateNode(LiteRtDispatchGraph graph,
+                          LiteRtDispatchNodeId node_id, const char* key,
+                          const char* value);
+LiteRtStatus AnnotateEdge(LiteRtDispatchGraph graph,
+                          LiteRtDispatchEdgeId edge_id, const char* key,
+                          const char* value);
+LiteRtStatus InvocationContextCreateFromGraph(
+    LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph graph,
+    LiteRtDispatchInvocationContext* invocation_context);
+
+}  // namespace google_tensor
+}  // namespace litert
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_DISPATCH_API_H_
diff --git a/tflite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api_google_tensor_test.cc b/tflite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api_google_tensor_test.cc
new file mode 100644
index 00000000..7d8420c7
--- /dev/null
+++ b/tflite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api_google_tensor_test.cc
@@ -0,0 +1,282 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cstddef>
+#include <cstring>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/log/absl_log.h"
+#include "absl/log/log.h"
+#include "absl/types/span.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h"
+#include "tflite/experimental/litert/core/filesystem.h"
+#include "tflite/experimental/litert/test/common.h"
+#include "tflite/experimental/litert/test/testdata/simple_model_test_vectors.h"
+#include "tflite/experimental/litert/vendors/c/litert_dispatch.h"
+
+using ::testing::Pointwise;
+
+TEST(DispatchApi, GoogleTensor) {
+#if !defined(__ANDROID__)
+  GTEST_SKIP()
+      << "This test is specific to Android devices with a GoogleTensor eTPU";
+#endif
+
+  EXPECT_EQ(LiteRtDispatchInitialize(/*options=*/nullptr, /*num_options=*/0),
+            kLiteRtStatusOk);
+
+  const char* vendor_id;
+  EXPECT_EQ(LiteRtDispatchGetVendorId(&vendor_id), kLiteRtStatusOk);
+  ABSL_LOG(INFO) << "vendor_id: " << vendor_id;
+
+  const char* build_id;
+  EXPECT_EQ(LiteRtDispatchGetBuildId(&build_id), kLiteRtStatusOk);
+  ABSL_LOG(INFO) << "build_id: " << build_id;
+
+  LiteRtApiVersion api_version;
+  EXPECT_EQ(LiteRtDispatchGetApiVersion(&api_version), kLiteRtStatusOk);
+  ABSL_LOG(INFO) << "api_version: " << api_version.major << "."
+                 << api_version.minor << "." << api_version.patch;
+
+  int capabilities;
+  EXPECT_EQ(LiteRtDispatchGetCapabilities(&capabilities), kLiteRtStatusOk);
+  ABSL_LOG(INFO) << "capabilities: " << capabilities;
+
+  LiteRtDispatchDeviceContext device_context = nullptr;
+  EXPECT_EQ(LiteRtDispatchDeviceContextCreate(&device_context),
+            kLiteRtStatusOk);
+  ABSL_LOG(INFO) << "device_context: " << device_context;
+
+  auto model_file_name = kGoogleTensorModelFileName;
+  auto model = litert::internal::LoadBinaryFile(model_file_name);
+  EXPECT_TRUE(model);
+  ABSL_LOG(INFO) << "Loaded model " << model_file_name << ", " << model->Size()
+                 << " bytes";
+
+  // ///////////////////////////////////////////////////////////////////////////
+  // Set up an invocation context for a given model.
+  // ///////////////////////////////////////////////////////////////////////////
+
+  LiteRtDispatchInvocationContext invocation_context = nullptr;
+  EXPECT_EQ(LiteRtDispatchInvocationContextCreate(
+                device_context, kLiteRtDispatchExecutableTypeMlModel,
+                model->Data(), model->Size(), /*function_name=*/nullptr,
+                /*num_inputs=*/2, /*num_outputs=*/1, &invocation_context),
+            kLiteRtStatusOk);
+  ABSL_LOG(INFO) << "Invocation context: " << invocation_context;
+
+  // ///////////////////////////////////////////////////////////////////////////
+  // Determine tensor buffer requirements.
+ // /////////////////////////////////////////////////////////////////////////// + + int num_tensor_buffer_types; + LiteRtTensorBufferRequirements input_0_tensor_buffer_requirements; + EXPECT_EQ(LiteRtDispatchGetInputRequirements( + invocation_context, /*input_index=*/0, &kInput0TensorType, + &input_0_tensor_buffer_requirements), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( + input_0_tensor_buffer_requirements, &num_tensor_buffer_types), + kLiteRtStatusOk); + EXPECT_GE(num_tensor_buffer_types, 1); + LiteRtTensorBufferType input_0_tensor_buffer_type; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + input_0_tensor_buffer_requirements, /*type_index=*/0, + &input_0_tensor_buffer_type), + kLiteRtStatusOk); + EXPECT_EQ(input_0_tensor_buffer_type, kLiteRtTensorBufferTypeAhwb); + size_t input_0_tensor_buffer_size; + EXPECT_EQ( + LiteRtGetTensorBufferRequirementsBufferSize( + input_0_tensor_buffer_requirements, &input_0_tensor_buffer_size), + kLiteRtStatusOk); + EXPECT_GE(input_0_tensor_buffer_size, sizeof(kTestInput0Tensor)); + + LiteRtTensorBufferRequirements input_1_tensor_buffer_requirements; + EXPECT_EQ(LiteRtDispatchGetInputRequirements( + invocation_context, /*input_index=*/1, &kInput1TensorType, + &input_1_tensor_buffer_requirements), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( + input_1_tensor_buffer_requirements, &num_tensor_buffer_types), + kLiteRtStatusOk); + EXPECT_GE(num_tensor_buffer_types, 1); + LiteRtTensorBufferType input_1_tensor_buffer_type; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + input_1_tensor_buffer_requirements, /*type_index=*/0, + &input_1_tensor_buffer_type), + kLiteRtStatusOk); + EXPECT_EQ(input_1_tensor_buffer_type, kLiteRtTensorBufferTypeAhwb); + size_t input_1_tensor_buffer_size; + EXPECT_EQ( + LiteRtGetTensorBufferRequirementsBufferSize( + input_1_tensor_buffer_requirements, &input_1_tensor_buffer_size), + kLiteRtStatusOk); + EXPECT_GE(input_1_tensor_buffer_size, sizeof(kTestInput1Tensor)); + + LiteRtTensorBufferRequirements output_tensor_buffer_requirements; + EXPECT_EQ(LiteRtDispatchGetOutputRequirements( + invocation_context, /*output_index=*/0, &kOutputTensorType, + &output_tensor_buffer_requirements), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( + output_tensor_buffer_requirements, &num_tensor_buffer_types), + kLiteRtStatusOk); + EXPECT_GE(num_tensor_buffer_types, 1); + LiteRtTensorBufferType output_tensor_buffer_type; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + output_tensor_buffer_requirements, /*type_index=*/0, + &output_tensor_buffer_type), + kLiteRtStatusOk); + EXPECT_EQ(output_tensor_buffer_type, kLiteRtTensorBufferTypeAhwb); + size_t output_tensor_buffer_size; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsBufferSize( + output_tensor_buffer_requirements, &output_tensor_buffer_size), + kLiteRtStatusOk); + EXPECT_GE(output_tensor_buffer_size, sizeof(kTestOutputTensor)); + + // /////////////////////////////////////////////////////////////////////////// + // Allocate tensor buffers. 
+ // /////////////////////////////////////////////////////////////////////////// + + LiteRtTensorBuffer input_0_tensor_buffer; + EXPECT_EQ(LiteRtCreateManagedTensorBuffer( + input_0_tensor_buffer_type, &kInput0TensorType, + input_0_tensor_buffer_size, &input_0_tensor_buffer), + kLiteRtStatusOk); + + LiteRtTensorBuffer input_1_tensor_buffer; + EXPECT_EQ(LiteRtCreateManagedTensorBuffer( + input_1_tensor_buffer_type, &kInput1TensorType, + input_1_tensor_buffer_size, &input_1_tensor_buffer), + kLiteRtStatusOk); + + LiteRtTensorBuffer output_tensor_buffer; + EXPECT_EQ(LiteRtCreateManagedTensorBuffer( + output_tensor_buffer_type, &kOutputTensorType, + output_tensor_buffer_size, &output_tensor_buffer), + kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Register tensor buffers. + // /////////////////////////////////////////////////////////////////////////// + + LiteRtTensorBufferHandle input_1_handle; + EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( + device_context, input_1_tensor_buffer, &input_1_handle), + kLiteRtStatusOk); + + LiteRtTensorBufferHandle input_0_handle; + EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( + device_context, input_0_tensor_buffer, &input_0_handle), + kLiteRtStatusOk); + + LiteRtTensorBufferHandle output_handle; + EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( + device_context, output_tensor_buffer, &output_handle), + kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Attach tensor buffers. + // /////////////////////////////////////////////////////////////////////////// + + EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, + /*graph_input_index=*/0, input_0_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, + /*graph_input_index=*/1, input_1_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchAttachOutput(invocation_context, + /*graph_output_index=*/0, output_handle), + kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Fill the input buffers with data. + // /////////////////////////////////////////////////////////////////////////// + + { + ABSL_LOG(INFO) << "Filling inputs with data"; + void* host_mem_addr; + + ASSERT_EQ(LiteRtLockTensorBuffer(input_0_tensor_buffer, &host_mem_addr, + /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTestInput0Tensor, sizeof(kTestInput0Tensor)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(input_0_tensor_buffer), kLiteRtStatusOk); + + ASSERT_EQ(LiteRtLockTensorBuffer(input_1_tensor_buffer, &host_mem_addr, + /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTestInput1Tensor, sizeof(kTestInput1Tensor)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(input_1_tensor_buffer), kLiteRtStatusOk); + } + + // /////////////////////////////////////////////////////////////////////////// + // Execute model. + // /////////////////////////////////////////////////////////////////////////// + + ABSL_LOG(INFO) << "Invoking execution..."; + EXPECT_EQ(LiteRtDispatchInvoke(invocation_context), kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Check output for correctness. 
+  // ///////////////////////////////////////////////////////////////////////////
+
+  {
+    ABSL_LOG(INFO) << "Checking output...";
+    void* host_mem_addr;
+    ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffer, &host_mem_addr,
+                                     /*event=*/nullptr),
+              kLiteRtStatusOk);
+    auto output = absl::MakeSpan(static_cast<const float*>(host_mem_addr),
+                                 kTestOutputSize);
+    for (auto i = 0; i < kTestOutputSize; ++i) {
+      ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i];
+    }
+    EXPECT_THAT(output, Pointwise(testing::FloatNear(1e-3), kTestOutputTensor));
+    ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffer), kLiteRtStatusOk);
+  }
+
+  // ///////////////////////////////////////////////////////////////////////////
+  // Clean up resources.
+  // ///////////////////////////////////////////////////////////////////////////
+
+  EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context,
+                                      /*graph_input_index=*/0, input_0_handle),
+            kLiteRtStatusOk);
+  EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context,
+                                      /*graph_input_index=*/1, input_1_handle),
+            kLiteRtStatusOk);
+  EXPECT_EQ(LiteRtDispatchDetachOutput(invocation_context,
+                                       /*graph_output_index=*/0, output_handle),
+            kLiteRtStatusOk);
+  EXPECT_EQ(LiteRtDispatchUnregisterTensorBuffer(device_context, output_handle),
+            kLiteRtStatusOk);
+  EXPECT_EQ(
+      LiteRtDispatchUnregisterTensorBuffer(device_context, input_1_handle),
+      kLiteRtStatusOk);
+  EXPECT_EQ(
+      LiteRtDispatchUnregisterTensorBuffer(device_context, input_0_handle),
+      kLiteRtStatusOk);
+  LiteRtDestroyTensorBuffer(output_tensor_buffer);
+  LiteRtDestroyTensorBuffer(input_1_tensor_buffer);
+  LiteRtDestroyTensorBuffer(input_0_tensor_buffer);
+  EXPECT_EQ(LiteRtDispatchInvocationContextDestroy(invocation_context),
+            kLiteRtStatusOk);
+  EXPECT_EQ(LiteRtDispatchDeviceContextDestroy(device_context),
+            kLiteRtStatusOk);
+}
diff --git a/tflite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.cc b/tflite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.cc
new file mode 100644
index 00000000..1c566410
--- /dev/null
+++ b/tflite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.cc
@@ -0,0 +1,61 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
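+
+// Lifetime note: any ThrGraph still registered with this device context (via
+// add_graph) is deleted in the destructor below before the underlying
+// ThrContext itself is deleted, so graphs never outlive the context that
+// created them.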
+ +#include "tflite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.h" + +#include + +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_logging.h" +#include "tflite/experimental/litert/cc/litert_expected.h" + +using litert::Expected; +using litert::Unexpected; + +LiteRtDispatchDeviceContextT::~LiteRtDispatchDeviceContextT() { + if (!thr_graphs_.empty()) { + auto thr_graph_delete = southbound_.api().thr_graph_delete; + if (!thr_graph_delete) { + LITERT_LOG(LITERT_ERROR, "thr_graph_delete not found"); + } else { + for (auto* thr_graph : thr_graphs_) { + thr_graph_delete(thr_graph); + } + } + } + + if (thr_context_) { + auto thr_context_delete = southbound_.api().thr_context_delete; + if (!thr_context_delete) { + LITERT_LOG(LITERT_ERROR, "thr_context_delete not found"); + } else { + thr_context_delete(thr_context_); + } + } +} + +Expected +LiteRtDispatchDeviceContextT::Create( + const litert::google_tensor::Southbound& southbound) { + Ptr device_context(new LiteRtDispatchDeviceContextT(southbound)); + + auto thr_context_create = southbound.api().thr_context_create; + if (!thr_context_create) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "thr_context_create not found"); + } + + device_context->thr_context_ = thr_context_create(); + return device_context; +} diff --git a/tflite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.h b/tflite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.h new file mode 100644 index 00000000..7e006f3c --- /dev/null +++ b/tflite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_
+
+#include <memory>
+
+#include "absl/container/flat_hash_set.h"
+#include "third_party/odml/infra/southbound/sb_api.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/vendors/c/litert_dispatch.h"
+#include "tflite/experimental/litert/vendors/google_tensor/dispatch/southbound.h"
+
+class LiteRtDispatchDeviceContextT {
+ public:
+  using Ptr = std::unique_ptr<LiteRtDispatchDeviceContextT>;
+
+  ~LiteRtDispatchDeviceContextT();
+
+  static litert::Expected<Ptr> Create(
+      const litert::google_tensor::Southbound& southbound);
+
+  ThrContext* thr_context() { return thr_context_; }
+  void add_graph(ThrGraph* graph) { thr_graphs_.insert(graph); }
+  void remove_graph(ThrGraph* graph) { thr_graphs_.erase(graph); }
+
+ private:
+  explicit LiteRtDispatchDeviceContextT(
+      const litert::google_tensor::Southbound& southbound)
+      : southbound_(southbound) {}
+
+  const litert::google_tensor::Southbound& southbound_;
+  ThrContext* thr_context_ = nullptr;
+  absl::flat_hash_set<ThrGraph*> thr_graphs_;
+};
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_
diff --git a/tflite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.h b/tflite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.h
new file mode 100644
index 00000000..e7cb8bc1
--- /dev/null
+++ b/tflite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.h
@@ -0,0 +1,94 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
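+
+// Bookkeeping sketch: the class below hands out monotonically increasing I/O
+// indices per node, which is what lets ConnectNodeInput/ConnectNodeOutput in
+// dispatch_api.cc reject out-of-order connections. For a fresh node:
+//
+//   graph->NextNodeInputIndex(node_id);  // == 0
+//   graph->NextNodeInputIndex(node_id);  // == 1, and so on per node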
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_GRAPH_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_GRAPH_H_
+
+#include <cstddef>
+#include <map>
+
+#include "third_party/odml/infra/southbound/sb_api.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/vendors/c/litert_dispatch.h"
+
+class LiteRtDispatchGraphT {
+ public:
+  LiteRtDispatchGraphT(ThrGraph* thr_graph,
+                       LiteRtDispatchDeviceContext device_context)
+      : thr_graph_(thr_graph), device_context_(device_context) {}
+
+  ThrGraph* thr_graph() { return thr_graph_; }
+
+  LiteRtDispatchDeviceContext device_context() { return device_context_; }
+
+  int NextNodeInputIndex(LiteRtDispatchNodeId node_id) {
+    return NextNodeIoIndex(node_id, next_node_input_index_);
+  }
+
+  int NextNodeOutputIndex(LiteRtDispatchNodeId node_id) {
+    return NextNodeIoIndex(node_id, next_node_output_index_);
+  }
+
+  int NextGraphInputIndex() { return next_graph_input_index_++; }
+
+  int NextGraphOutputIndex() { return next_graph_output_index_++; }
+
+  void AddInputEdge(int input_index, LiteRtDispatchEdgeId edge_id) {
+    input_edges_[input_index] = edge_id;
+  }
+
+  void AddOutputEdge(int output_index, LiteRtDispatchEdgeId edge_id) {
+    output_edges_[output_index] = edge_id;
+  }
+
+  litert::Expected<LiteRtDispatchEdgeId> InputEdge(int input_index) const {
+    return IoEdge(input_index, input_edges_);
+  }
+
+  litert::Expected<LiteRtDispatchEdgeId> OutputEdge(int output_index) const {
+    return IoEdge(output_index, output_edges_);
+  }
+
+  size_t NumOutputs() const { return output_edges_.size(); }
+
+ private:
+  using NextNodeIoIndexMap = std::map<LiteRtDispatchNodeId, int>;
+  using IoIndexToEdgeIdMap = std::map<int, LiteRtDispatchEdgeId>;
+
+  int NextNodeIoIndex(LiteRtDispatchNodeId node_id, NextNodeIoIndexMap& map) {
+    return map[node_id]++;
+  }
+
+  litert::Expected<LiteRtDispatchEdgeId> IoEdge(
+      int io_index, const IoIndexToEdgeIdMap& map) const {
+    auto iter = map.find(io_index);
+    if (iter == map.end()) {
+      return litert::Unexpected(kLiteRtStatusErrorNotFound,
+                                "Unexpected graph input/output index");
+    }
+    return iter->second;
+  }
+
+  ThrGraph* thr_graph_;
+  LiteRtDispatchDeviceContext device_context_;
+  NextNodeIoIndexMap next_node_input_index_;
+  NextNodeIoIndexMap next_node_output_index_;
+  int next_graph_input_index_ = 0;
+  int next_graph_output_index_ = 0;
+  IoIndexToEdgeIdMap input_edges_;
+  IoIndexToEdgeIdMap output_edges_;
+};
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_GRAPH_H_
diff --git a/tflite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.cc b/tflite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.cc
new file mode 100644
index 00000000..343bfb80
--- /dev/null
+++ b/tflite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.cc
@@ -0,0 +1,84 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
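+
+// Sizing sketch: buffer sizes below are rounded up to the next multiple of
+// kEdgeTpuPadding (64 bytes) by Pad(). For example, a packed tensor of
+// 100 bytes is padded to Pad(100, 64) == 128, while 128 stays at 128.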
+ +#include "tflite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.h" + +#include + +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_model.h" +#include "tflite/experimental/litert/c/litert_tensor_buffer.h" +#include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h" +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/experimental/litert/core/util/tensor_type_util.h" +#include "tflite/experimental/litert/vendors/c/litert_dispatch.h" + +using litert::Expected; +using litert::Unexpected; + +namespace { + +constexpr const size_t kEdgeTpuPadding = 64; + +inline constexpr auto Pad(auto x, auto align) { + return ((x + align - 1) / align) * align; +} + +Expected GetTensorBufferRequirements( + const LiteRtRankedTensorType& tensor_type) { + auto* tensor_strides = tensor_type.layout.strides; + if (tensor_strides != nullptr) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "Tensor strides are not supported on GoogleTensor"); + } + + LiteRtTensorBufferType supported_tensor_buffer_types[] = { + kLiteRtTensorBufferTypeAhwb, + }; + int num_supported_tensor_buffer_types = + sizeof(supported_tensor_buffer_types) / + sizeof(supported_tensor_buffer_types[0]); + + auto buffer_size = litert::internal::GetNumPackedBytes(tensor_type); + if (!buffer_size) { + return Unexpected(buffer_size.Error()); + } + + size_t padded_buffer_size = Pad(*buffer_size, kEdgeTpuPadding); + + LiteRtTensorBufferRequirements requirements; + if (auto status = LiteRtCreateTensorBufferRequirements( + num_supported_tensor_buffer_types, supported_tensor_buffer_types, + padded_buffer_size, /*num_strides=*/0, /*strides=*/nullptr, + &requirements); + status != kLiteRtStatusOk) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "Failed to create tensor buffer requirements"); + } + + return requirements; +} +} // namespace + +Expected +LiteRtDispatchInvocationContextT::GetInputRequirements( + int input_index, const LiteRtRankedTensorType& tensor_type) { + return GetTensorBufferRequirements(tensor_type); +} + +Expected +LiteRtDispatchInvocationContextT::GetOutputRequirements( + int output_index, const LiteRtRankedTensorType& tensor_type) { + return GetTensorBufferRequirements(tensor_type); +} diff --git a/tflite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.h b/tflite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.h new file mode 100644 index 00000000..7cc32764 --- /dev/null +++ b/tflite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.h @@ -0,0 +1,66 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_
+
+#include <optional>
+
+#include "third_party/odml/infra/southbound/sb_api.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/vendors/c/litert_dispatch.h"
+#include "tflite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.h"
+
+class LiteRtDispatchInvocationContextT {
+ public:
+  LiteRtDispatchInvocationContextT(ThrInvocationContext* thr_invocation_context,
+                                   LiteRtDispatchDeviceContext device_context,
+                                   LiteRtDispatchGraph graph)
+      : thr_invocation_context_(thr_invocation_context),
+        device_context_(device_context),
+        graph_(graph) {}
+
+  ~LiteRtDispatchInvocationContextT() {
+    if (exec_handle_) {
+      litert::google_tensor::UnloadExecutable(device_context_, *exec_handle_);
+    }
+  }
+
+  litert::Expected<LiteRtTensorBufferRequirements> GetInputRequirements(
+      int input_index, const LiteRtRankedTensorType& tensor_type);
+  litert::Expected<LiteRtTensorBufferRequirements> GetOutputRequirements(
+      int output_index, const LiteRtRankedTensorType& tensor_type);
+
+  ThrInvocationContext* thr_invocation_context() {
+    return thr_invocation_context_;
+  }
+
+  LiteRtDispatchDeviceContext device_context() { return device_context_; }
+
+  LiteRtDispatchGraph graph() { return graph_; }
+
+  void AttachExecutable(LiteRtDispatchExecutableHandle exec_handle) {
+    exec_handle_ = exec_handle;
+  }
+
+ private:
+  ThrInvocationContext* thr_invocation_context_;
+  LiteRtDispatchDeviceContext device_context_;
+  LiteRtDispatchGraph graph_;
+  std::optional<LiteRtDispatchExecutableHandle> exec_handle_;
+};
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_
diff --git a/tflite/experimental/litert/vendors/google_tensor/dispatch/southbound.cc b/tflite/experimental/litert/vendors/google_tensor/dispatch/southbound.cc
new file mode 100644
index 00000000..a70775fd
--- /dev/null
+++ b/tflite/experimental/litert/vendors/google_tensor/dispatch/southbound.cc
@@ -0,0 +1,146 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
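+
+// Symbol-binding sketch: the Load() macro defined below binds each SouthBound
+// symbol via dlsym(). For example, Load(api_->thr_initialize, thrInitialize)
+// expands roughly to:
+//
+//   api_->thr_initialize = reinterpret_cast<decltype(api_->thr_initialize)>(
+//       ::dlsym(dlib_handle_, "thrInitialize"));
+//   if (!api_->thr_initialize) { /* LITERT_LOG warning */ }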
+ +#include "tflite/experimental/litert/vendors/google_tensor/dispatch/southbound.h" + +#include + +#include +#include +#include + +#include "third_party/odml/infra/southbound/sb_api.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_logging.h" +#include "tflite/experimental/litert/cc/litert_expected.h" + +#define Load(H, S) \ + H = reinterpret_cast(::dlsym(dlib_handle_, #S)); \ + if (!H) { \ + LITERT_LOG(LITERT_WARNING, "Failed to load symbol %s: %s", #S, \ + ::dlerror()); \ + } + +namespace litert { +namespace google_tensor { + +namespace { +// Currently the SouthBound implementation is bundled inside the Edge TPU +// runtime shared library. +constexpr const char* kSouthBoundLibPath = "/vendor/lib64/libedgetpu_util.so"; +} // namespace + +Southbound::Southbound() : api_(new ThrFunctions) {} + +Southbound::~Southbound() { + if (dlib_handle_) { + ::dlclose(dlib_handle_); + } +} + +Expected Southbound::Create( + std::optional shared_library_dir) { + Ptr southbound(new Southbound); + if (auto status = southbound->LoadSymbols(shared_library_dir); !status) { + return Unexpected(status.Error()); + } + + return southbound; +} + +Expected Southbound::LoadSymbols( + std::optional shared_library_dir) { + // Always load the Southbound API library from the vendor partition. + (void)shared_library_dir; + + dlib_handle_ = ::dlopen(kSouthBoundLibPath, RTLD_NOW | RTLD_LOCAL); + if (!dlib_handle_) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "Failed to load Southbound shared library"); + } + + // Binds all supported symbols from the shared library to the function + // pointers. + Load(api_->thr_initialize, thrInitialize); + + Load(api_->thr_get_vendor_api_version, thrGetVendorApiVersion); + Load(api_->thr_get_vendor_id, thrGetVendorId); + + Load(api_->thr_context_create, thrContextCreate); + Load(api_->thr_context_delete, thrContextDelete); + + Load(api_->thr_graph_create, thrGraphCreate); + Load(api_->thr_graph_delete, thrGraphDelete); + + Load(api_->thr_graph_add_edge, thrGraphAddEdge); + Load(api_->thr_graph_add_sq_node, thrGraphAddSqNode); + + Load(api_->thr_graph_connect_node_input, thrGraphConnectNodeInput); + Load(api_->thr_graph_connect_node_output, thrGraphConnectNodeOutput); + + Load(api_->thr_graph_set_input_edge, thrGraphSetInputEdge); + Load(api_->thr_graph_set_output_edge, thrGraphSetOutputEdge); + + Load(api_->thr_graph_annotate_graph, thrGraphAnnotateGraph); + Load(api_->thr_graph_annotate_edge, thrGraphAnnotateEdge); + Load(api_->thr_graph_annotate_node, thrGraphAnnotateNode); + + Load(api_->thr_load_sq_container, thrLoadSqContainer); + Load(api_->thr_load_sq_container_fd, thrLoadSqContainerFd); + Load(api_->thr_load_sq_container_file, thrLoadSqContainerFile); + Load(api_->thr_unload_sq_container, thrUnloadSqContainer); + + Load(api_->thr_graph_assign_sq, thrGraphAssignSq); + Load(api_->thr_sq_query_scratch_pad, thrSqQueryScratchPad); + Load(api_->thr_sq_attach_scratch_pad_buffer, thrSqAttachScratchPadBuffer); + + Load(api_->thr_register_buffer, thrRegisterBuffer); + Load(api_->thr_register_buffer_with_offset, thrRegisterBufferWithOffset); + Load(api_->thr_unregister_buffer, thrUnregisterBuffer); + + Load(api_->thr_invocation_context_get, thrInvocationContextGet); + Load(api_->thr_invocation_context_delete, thrInvocationContextDelete); + + Load(api_->thr_invocation_context_attach_buffer, + thrInvocationContextAttachBuffer); + Load(api_->thr_invocation_context_detach_buffer, + thrInvocationContextDetachBuffer); + + 
+  Load(api_->thr_invocation_context_prepare_for_invoke,
+       thrInvocationContextPrepareForInvoke);
+  Load(api_->thr_invocation_context_invoke_once,
+       thrInvocationContextInvokeOnce);
+  Load(api_->thr_invocation_context_wait, thrInvocationContextWait);
+
+  Load(api_->thr_invocation_context_attach_input_buffer_sync_fence,
+       thrInvocationContextAttachInputBufferSyncFence);
+  Load(api_->thr_invocation_context_get_output_buffer_sync_fence,
+       thrInvocationContextGetOutputBufferSyncFence);
+
+  Load(api_->thr_invocation_context_query_node_scratch_pad,
+       thrInvocationContextQueryNodeScratchPad);
+  Load(api_->thr_invocation_context_attach_scratch_pad_buffer,
+       thrInvocationContextAttachScratchPadBuffer);
+
+  Load(api_->thr_vendor_set_system_attribute_str,
+       thrVendorSetSystemAttributeStr);
+  Load(api_->thr_vendor_set_system_attribute_int64,
+       thrVendorSetSystemAttributeInt64);
+
+  LITERT_LOG(LITERT_INFO, "SouthBound symbols loaded");
+  return {};
+}
+
+}  // namespace google_tensor
+}  // namespace litert
diff --git a/tflite/experimental/litert/vendors/google_tensor/dispatch/southbound.h b/tflite/experimental/litert/vendors/google_tensor/dispatch/southbound.h
new file mode 100644
index 00000000..026a5e78
--- /dev/null
+++ b/tflite/experimental/litert/vendors/google_tensor/dispatch/southbound.h
@@ -0,0 +1,128 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_SOUTHBOUND_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_SOUTHBOUND_H_
+
+#include <memory>
+#include <optional>
+#include <string>
+
+#include "third_party/odml/infra/southbound/sb_api.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+
+namespace litert {
+namespace google_tensor {
+
+class Southbound {
+ public:
+  using Ptr = std::unique_ptr<Southbound>;
+  struct ThrFunctions;
+
+  Southbound(Southbound&) = delete;
+  Southbound(Southbound&&) = delete;
+  Southbound& operator=(const Southbound&) = delete;
+  Southbound& operator=(Southbound&&) = delete;
+
+  ~Southbound();
+
+  static Expected<Ptr> Create(std::optional<std::string> shared_library_dir);
+
+  const ThrFunctions& api() const { return *api_; }
+
+ private:
+  Southbound();
+  Expected<void> LoadSymbols(std::optional<std::string> shared_library_dir);
+
+  void* dlib_handle_ = nullptr;
+  std::unique_ptr<ThrFunctions> api_;
+};
+
+// A convenient struct for holding function pointers to SouthBound symbols.
+// These function pointers are resolved from the shared library on the device
+// at runtime.
+struct Southbound::ThrFunctions { + decltype(&thrInitialize) thr_initialize = nullptr; + + decltype(&thrGetVendorApiVersion) thr_get_vendor_api_version = nullptr; + decltype(&thrGetVendorId) thr_get_vendor_id = nullptr; + + decltype(&thrContextCreate) thr_context_create = nullptr; + decltype(&thrContextDelete) thr_context_delete = nullptr; + + decltype(&thrGraphCreate) thr_graph_create = nullptr; + decltype(&thrGraphDelete) thr_graph_delete = nullptr; + + decltype(&thrGraphAddEdge) thr_graph_add_edge = nullptr; + decltype(&thrGraphAddSqNode) thr_graph_add_sq_node = nullptr; + + decltype(&thrGraphConnectNodeInput) thr_graph_connect_node_input = nullptr; + decltype(&thrGraphConnectNodeOutput) thr_graph_connect_node_output = nullptr; + + decltype(&thrGraphSetInputEdge) thr_graph_set_input_edge = nullptr; + decltype(&thrGraphSetOutputEdge) thr_graph_set_output_edge = nullptr; + + decltype(&thrGraphAnnotateGraph) thr_graph_annotate_graph = nullptr; + decltype(&thrGraphAnnotateEdge) thr_graph_annotate_edge = nullptr; + decltype(&thrGraphAnnotateNode) thr_graph_annotate_node = nullptr; + + decltype(&thrLoadSqContainer) thr_load_sq_container = nullptr; + decltype(&thrLoadSqContainerFd) thr_load_sq_container_fd = nullptr; + decltype(&thrLoadSqContainerFile) thr_load_sq_container_file = nullptr; + decltype(&thrUnloadSqContainer) thr_unload_sq_container = nullptr; + + decltype(&thrGraphAssignSq) thr_graph_assign_sq = nullptr; + decltype(&thrSqQueryScratchPad) thr_sq_query_scratch_pad = nullptr; + decltype(&thrSqAttachScratchPadBuffer) thr_sq_attach_scratch_pad_buffer = + nullptr; + + decltype(&thrRegisterBuffer) thr_register_buffer = nullptr; + decltype(&thrRegisterBufferWithOffset) thr_register_buffer_with_offset = + nullptr; + decltype(&thrUnregisterBuffer) thr_unregister_buffer = nullptr; + + decltype(&thrInvocationContextGet) thr_invocation_context_get = nullptr; + decltype(&thrInvocationContextDelete) thr_invocation_context_delete = nullptr; + + decltype(&thrInvocationContextAttachBuffer) + thr_invocation_context_attach_buffer = nullptr; + decltype(&thrInvocationContextDetachBuffer) + thr_invocation_context_detach_buffer = nullptr; + + decltype(&thrInvocationContextPrepareForInvoke) + thr_invocation_context_prepare_for_invoke = nullptr; + decltype(&thrInvocationContextInvokeOnce) thr_invocation_context_invoke_once = + nullptr; + decltype(&thrInvocationContextWait) thr_invocation_context_wait = nullptr; + + decltype(&thrInvocationContextAttachInputBufferSyncFence) + thr_invocation_context_attach_input_buffer_sync_fence = nullptr; + decltype(&thrInvocationContextGetOutputBufferSyncFence) + thr_invocation_context_get_output_buffer_sync_fence = nullptr; + + decltype(&thrInvocationContextQueryNodeScratchPad) + thr_invocation_context_query_node_scratch_pad = nullptr; + decltype(&thrInvocationContextAttachScratchPadBuffer) + thr_invocation_context_attach_scratch_pad_buffer = nullptr; + + decltype(&thrVendorSetSystemAttributeStr) + thr_vendor_set_system_attribute_str = nullptr; + decltype(&thrVendorSetSystemAttributeInt64) + thr_vendor_set_system_attribute_int64 = nullptr; +}; + +} // namespace google_tensor +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_SOUTHBOUND_H_ diff --git a/tflite/experimental/litert/vendors/mediatek/BUILD b/tflite/experimental/litert/vendors/mediatek/BUILD new file mode 100644 index 00000000..632126af --- /dev/null +++ b/tflite/experimental/litert/vendors/mediatek/BUILD @@ -0,0 +1,34 @@ +# Copyright 2024 Google LLC. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # copybara:uncomment default_applicable_licenses = ["@org_tensorflow//tensorflow:license"], + default_visibility = ["//tflite/experimental/litert:__subpackages__"], +) + +cc_library( + name = "neuron_adapter", + srcs = [ + "neuron_adapter.cc", + ], + hdrs = [ + "neuron_adapter.h", + ], + deps = [ + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/core:dynamic_loading", + ], +) diff --git a/tflite/experimental/litert/vendors/mediatek/dispatch/BUILD b/tflite/experimental/litert/vendors/mediatek/dispatch/BUILD new file mode 100644 index 00000000..acc89511 --- /dev/null +++ b/tflite/experimental/litert/vendors/mediatek/dispatch/BUILD @@ -0,0 +1,85 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("//tflite/experimental/litert/build_common:litert_build_defs.bzl", "litert_dynamic_lib") + +package( + # copybara:uncomment default_applicable_licenses = ["@org_tensorflow//tensorflow:license"], + default_visibility = ["//tflite/experimental/litert:__subpackages__"], +) + +litert_dynamic_lib( + name = "dispatch_api", + srcs = [ + "dispatch_api.cc", + "litert_dispatch_device_context.cc", + "litert_dispatch_invocation_context.cc", + ], + hdrs = [ + "litert_dispatch_device_context.h", + "litert_dispatch_invocation_context.h", + ], + export_litert_only = True, + linkopts = select({ + "@org_tensorflow//tensorflow:android": ["-landroid"], + "//conditions:default": [], + }), + linkstatic = 1, + shared_lib_name = "dispatch_api_so", + so_name = "libLiteRtDispatch.so", + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_tensor_buffer", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_tensor_buffer", + "//tflite/experimental/litert/core:dynamic_loading", + "//tflite/experimental/litert/core/util:tensor_type_util", + "//tflite/experimental/litert/vendors/c:litert_dispatch_c_api", + "//tflite/experimental/litert/vendors/mediatek:neuron_adapter", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:node_hash_set", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "dispatch_api_mediatek_test", + srcs = [ + "dispatch_api_mediatek_test.cc", + ], + data = [ + ":dispatch_api_so", + ], + linkopts = select({ + "@org_tensorflow//tensorflow:android": ["-landroid"], + "//conditions:default": [], + }), + deps = [ + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_tensor_buffer", + "//tflite/experimental/litert/core:filesystem", + "//tflite/experimental/litert/test:simple_model_npu", + "//tflite/experimental/litert/vendors/c:litert_dispatch_c_api", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tflite/experimental/litert/vendors/mediatek/dispatch/README.md b/tflite/experimental/litert/vendors/mediatek/dispatch/README.md new file mode 100644 index 00000000..35a6130c --- /dev/null +++ b/tflite/experimental/litert/vendors/mediatek/dispatch/README.md @@ -0,0 +1,4 @@ +Test case can dispatch_api_mediatek_test can be run on a device with a MetiaTek +mt6989 SoC with the following comands + +$ ../../../google/run_test_on_android.sh dispatch_api_mediatek_test diff --git a/tflite/experimental/litert/vendors/mediatek/dispatch/dispatch_api.cc b/tflite/experimental/litert/vendors/mediatek/dispatch/dispatch_api.cc new file mode 100644 index 00000000..738f714b --- /dev/null +++ b/tflite/experimental/litert/vendors/mediatek/dispatch/dispatch_api.cc @@ -0,0 +1,327 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cstdio>
+#include <cstring>
+#include <optional>
+#include <string>
+
+#if LITERT_HAS_AHWB_SUPPORT
+#include <android/hardware_buffer.h>
+#endif
+
+#include "absl/strings/string_view.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_logging.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/cc/litert_tensor_buffer.h"
+#include "tflite/experimental/litert/vendors/c/litert_dispatch.h"
+#include "tflite/experimental/litert/vendors/c/litert_dispatch_api.h"
+#include "tflite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.h"
+#include "tflite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.h"
+#include "tflite/experimental/litert/vendors/mediatek/neuron_adapter.h"
+
+namespace {
+
+litert::mediatek::NeuronAdapter* TheNeuronAdapter;
+char BuildId[256];
+
+}  // namespace
+
+namespace litert {
+namespace mediatek {
+
+// /////////////////////////////////////////////////////////////////////////////
+// Basic Execution API
+// /////////////////////////////////////////////////////////////////////////////
+
+const char* GetSharedLibraryDir(const LiteRtDispatchOption* options,
+                                int num_options) {
+  for (auto i = 0; i < num_options; ++i) {
+    auto& option = options[i];
+    if (!strcmp(option.name, kDispatchOptionSharedLibraryDir)) {
+      return option.value.str_value;
+    }
+  }
+  return nullptr;
+}
+
+LiteRtStatus LiteRtInitialize(const LiteRtDispatchOption* options,
+                              int num_options) {
+  auto* shared_library_dir = GetSharedLibraryDir(options, num_options);
+  std::optional<std::string> shared_library_dir_opt =
+      shared_library_dir ? std::make_optional(std::string(shared_library_dir))
+                         : std::nullopt;
+
+  if (auto neuron_adapter =
+          litert::mediatek::NeuronAdapter::Create(shared_library_dir_opt);
+      neuron_adapter) {
+    TheNeuronAdapter = neuron_adapter->release();
+  } else {
+    LITERT_LOG(LITERT_ERROR, "Initialization failure: %s",
+               neuron_adapter.Error().Message().data());
+    return neuron_adapter.Error().Status();
+  }
+
+  auto get_version = TheNeuronAdapter->api().get_version;
+  if (!get_version) {
+    LITERT_LOG(LITERT_ERROR, "get_version not found");
+    return kLiteRtStatusErrorRuntimeFailure;
+  }
+
+  NeuronRuntimeVersion version;
+  if (get_version(&version) != NEURON_NO_ERROR) {
+    LITERT_LOG(LITERT_ERROR, "Failed to get version");
+    return kLiteRtStatusErrorRuntimeFailure;
+  }
+  LITERT_LOG(LITERT_INFO, "Neuron SDK version: %d.%d.%d", version.major,
+             version.minor, version.patch);
+
+  snprintf(BuildId, sizeof(BuildId),
+           "MediaTek Dispatch API version %d.%d.%d, NeuronAdapter API version "
+           "%d.%d.%d",
+           LITERT_API_VERSION_MAJOR, LITERT_API_VERSION_MINOR,
+           LITERT_API_VERSION_PATCH, version.major, version.minor,
+           version.patch);
+  BuildId[sizeof(BuildId) - 1] = 0;
+
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetVendorId(const char** vendor_id) {
+  *vendor_id = "MediaTek";
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetBuildId(const char** build_id) {
+  *build_id = BuildId;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetCapabilities(int* capabilities) {
+  *capabilities = kLiteRtDispatchCapabilitiesBasic;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtDeviceContextCreate(
+    LiteRtDispatchDeviceContext* device_context) {
+  if (auto context = LiteRtDispatchDeviceContextT::Create(*TheNeuronAdapter);
+      context) {
+    *device_context = context->release();
+    return kLiteRtStatusOk;
+  } else {
+    LITERT_LOG(LITERT_ERROR, "Failed to create device context: %s",
+               context.Error().Message().data());
+    return context.Error().Status();
+  }
+}
+
+LiteRtStatus LiteRtDeviceContextDestroy(
+    LiteRtDispatchDeviceContext device_context) {
+  delete device_context;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetInputRequirements(
+    LiteRtDispatchInvocationContext invocation_context, int input_index,
+    const LiteRtRankedTensorType* tensor_type,
+    LiteRtTensorBufferRequirements* tensor_buffer_requirements) {
+  if (auto requirements =
+          invocation_context->GetInputRequirements(input_index, *tensor_type);
+      requirements) {
+    *tensor_buffer_requirements = *requirements;
+    return kLiteRtStatusOk;
+  } else {
+    LITERT_LOG(LITERT_ERROR, "Failed to get tensor buffer requirements: %s",
+               requirements.Error().Message().data());
+    return requirements.Error().Status();
+  }
+}
+
+LiteRtStatus LiteRtGetOutputRequirements(
+    LiteRtDispatchInvocationContext invocation_context, int output_index,
+    const LiteRtRankedTensorType* tensor_type,
+    LiteRtTensorBufferRequirements* tensor_buffer_requirements) {
+  if (auto requirements =
+          invocation_context->GetOutputRequirements(output_index, *tensor_type);
+      requirements) {
+    *tensor_buffer_requirements = *requirements;
+    return kLiteRtStatusOk;
+  } else {
+    LITERT_LOG(LITERT_ERROR, "Failed to get tensor buffer requirements: %s",
+               requirements.Error().Message().data());
+    return requirements.Error().Status();
+  }
+}
+
+LiteRtStatus LiteRtRegisterTensorBuffer(
+    LiteRtDispatchDeviceContext device_context,
+    LiteRtTensorBuffer tensor_buffer,
+    LiteRtTensorBufferHandle* tensor_buffer_handle) {
+  litert::TensorBuffer tensor_buffer_(tensor_buffer, /*owned=*/false);
+  if (auto result =
device_context->RegisterTensorBuffer(tensor_buffer_); + result) { + *tensor_buffer_handle = *result; + return kLiteRtStatusOk; + } else { + LITERT_LOG(LITERT_ERROR, "Failed to register tensor buffer: %s", + result.Error().Message().data()); + return result.Error().Status(); + } +} + +LiteRtStatus LiteRtUnregisterTensorBuffer( + LiteRtDispatchDeviceContext device_context, + LiteRtTensorBufferHandle tensor_buffer_handle) { + if (auto status = + device_context->UnregisterTensorBuffer(tensor_buffer_handle); + !status) { + LITERT_LOG(LITERT_ERROR, "Failed to unregister tensor buffer: %s", + status.Error().Message().data()); + return status.Error().Status(); + } + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtInvocationContextCreate( + LiteRtDispatchDeviceContext device_context, + LiteRtDispatchExecutableType exec_type, const void* exec_bytecode_ptr, + size_t exec_bytecode_size, const char* function_name, int num_inputs, + int num_outputs, LiteRtDispatchInvocationContext* invocation_context) { + auto context = LiteRtDispatchInvocationContextT::Create( + *TheNeuronAdapter, device_context, exec_type, exec_bytecode_ptr, + exec_bytecode_size, function_name, num_inputs, num_outputs); + if (!context) { + LITERT_LOG(LITERT_ERROR, "Failed to create context from context binary: %s", + context.Error().Message().data()); + return context.Error().Status(); + } + *invocation_context = context->release(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtInvocationContextDestroy( + LiteRtDispatchInvocationContext invocation_context) { + delete invocation_context; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtAttachInput( + LiteRtDispatchInvocationContext invocation_context, int graph_input_index, + LiteRtTensorBufferHandle tensor_buffer_handle) { + if (auto status = invocation_context->AttachInput(graph_input_index, + tensor_buffer_handle); + !status) { + LITERT_LOG(LITERT_ERROR, "Failed to attach input: %s", + status.Error().Message().data()); + return status.Error().Status(); + } + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtAttachOutput( + LiteRtDispatchInvocationContext invocation_context, int graph_output_index, + LiteRtTensorBufferHandle tensor_buffer_handle) { + if (auto status = invocation_context->AttachOutput(graph_output_index, + tensor_buffer_handle); + !status) { + LITERT_LOG(LITERT_ERROR, "Failed to attach output: %s", + status.Error().Message().data()); + return status.Error().Status(); + } + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtDetachInput( + LiteRtDispatchInvocationContext invocation_context, int graph_input_index, + LiteRtTensorBufferHandle tensor_buffer_handle) { + if (auto status = invocation_context->DetachInput(graph_input_index, + tensor_buffer_handle); + !status) { + LITERT_LOG(LITERT_ERROR, "Failed to detach input: %s", + status.Error().Message().data()); + return status.Error().Status(); + } + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtDetachOutput( + LiteRtDispatchInvocationContext invocation_context, int graph_output_index, + LiteRtTensorBufferHandle tensor_buffer_handle) { + if (auto status = invocation_context->DetachOutput(graph_output_index, + tensor_buffer_handle); + !status) { + LITERT_LOG(LITERT_ERROR, "Failed to detach output: %s", + status.Error().Message().data()); + return status.Error().Status(); + } + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtInvoke(LiteRtDispatchInvocationContext invocation_context) { + if (auto status = invocation_context->Invoke(); !status) { + LITERT_LOG(LITERT_ERROR, "Failed to invoke context: %s", + 
status.Error().Message().data()); + return status.Error().Status(); + } + return kLiteRtStatusOk; +} + +} // namespace mediatek +} // namespace litert + +// ///////////////////////////////////////////////////////////////////////////// + +namespace { + +LiteRtDispatchInterface TheInterface = { + .initialize = litert::mediatek::LiteRtInitialize, + .get_vendor_id = litert::mediatek::LiteRtGetVendorId, + .get_build_id = litert::mediatek::LiteRtGetBuildId, + .get_capabilities = litert::mediatek::LiteRtGetCapabilities, + .device_context_create = litert::mediatek::LiteRtDeviceContextCreate, + .device_context_destroy = litert::mediatek::LiteRtDeviceContextDestroy, + .get_input_requirements = litert::mediatek::LiteRtGetInputRequirements, + .get_output_requirements = litert::mediatek::LiteRtGetOutputRequirements, + .register_tensor_buffer = litert::mediatek::LiteRtRegisterTensorBuffer, + .unregister_tensor_buffer = litert::mediatek::LiteRtUnregisterTensorBuffer, + .invocation_context_create = + litert::mediatek::LiteRtInvocationContextCreate, + .invocation_context_destroy = + litert::mediatek::LiteRtInvocationContextDestroy, + .attach_input = litert::mediatek::LiteRtAttachInput, + .attach_output = litert::mediatek::LiteRtAttachOutput, + .detach_input = litert::mediatek::LiteRtDetachInput, + .detach_output = litert::mediatek::LiteRtDetachOutput, + .invoke = litert::mediatek::LiteRtInvoke, +}; + +LiteRtDispatchApi TheApi = { + .version = {.major = LITERT_API_VERSION_MAJOR, + .minor = LITERT_API_VERSION_MINOR, + .patch = LITERT_API_VERSION_PATCH}, + .interface = &TheInterface, + .async_interface = nullptr, + .graph_interface = nullptr, +}; + +} // namespace + +LiteRtStatus LiteRtDispatchGetApi(LiteRtDispatchApi* api) { + *api = TheApi; + return kLiteRtStatusOk; +} diff --git a/tflite/experimental/litert/vendors/mediatek/dispatch/dispatch_api_mediatek_test.cc b/tflite/experimental/litert/vendors/mediatek/dispatch/dispatch_api_mediatek_test.cc new file mode 100644 index 00000000..111f9663 --- /dev/null +++ b/tflite/experimental/litert/vendors/mediatek/dispatch/dispatch_api_mediatek_test.cc @@ -0,0 +1,331 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include <cstddef>
+#include <cstring>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/log/absl_log.h"
+#include "absl/log/log.h"
+#include "absl/types/span.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h"
+#include "tflite/experimental/litert/core/filesystem.h"
+#include "tflite/experimental/litert/test/testdata/simple_model_test_vectors.h"
+#include "tflite/experimental/litert/vendors/c/litert_dispatch.h"
+
+using ::testing::Pointwise;
+
+TEST(DispatchApi, MediaTek) {
+#if !defined(__ANDROID__)
+  GTEST_SKIP()
+      << "This test is specific to Android devices with a MediaTek NPU";
+#endif
+
+  EXPECT_EQ(LiteRtDispatchInitialize(/*options=*/nullptr, /*num_options=*/0),
+            kLiteRtStatusOk);
+
+  const char* vendor_id;
+  EXPECT_EQ(LiteRtDispatchGetVendorId(&vendor_id), kLiteRtStatusOk);
+  ABSL_LOG(INFO) << "vendor_id: " << vendor_id;
+
+  const char* build_id;
+  EXPECT_EQ(LiteRtDispatchGetBuildId(&build_id), kLiteRtStatusOk);
+  ABSL_LOG(INFO) << "build_id: " << build_id;
+
+  LiteRtApiVersion api_version;
+  EXPECT_EQ(LiteRtDispatchGetApiVersion(&api_version), kLiteRtStatusOk);
+  ABSL_LOG(INFO) << "api_version: " << api_version.major << "."
+                 << api_version.minor << "." << api_version.patch;
+
+  int capabilities;
+  EXPECT_EQ(LiteRtDispatchGetCapabilities(&capabilities), kLiteRtStatusOk);
+  ABSL_LOG(INFO) << "capabilities: " << capabilities;
+
+  LiteRtDispatchDeviceContext device_context = nullptr;
+  EXPECT_EQ(LiteRtDispatchDeviceContextCreate(&device_context),
+            kLiteRtStatusOk);
+  ABSL_LOG(INFO) << "device_context: " << device_context;
+
+  auto model_file_name = kMediaTekModelFileName;
+  auto model = litert::internal::LoadBinaryFile(model_file_name);
+  EXPECT_TRUE(model);
+  ABSL_LOG(INFO) << "Loaded model " << model_file_name << ", " << model->Size()
+                 << " bytes";
+
+  // ///////////////////////////////////////////////////////////////////////////
+  // Set up an invocation context for a given model.
+  // ///////////////////////////////////////////////////////////////////////////
+
+  LiteRtDispatchInvocationContext invocation_context = nullptr;
+  EXPECT_EQ(LiteRtDispatchInvocationContextCreate(
+                device_context, kLiteRtDispatchExecutableTypeMlModel,
+                model->Data(), model->Size(), /*function_name=*/nullptr,
+                /*num_inputs=*/2, /*num_outputs=*/1, &invocation_context),
+            kLiteRtStatusOk);
+  ABSL_LOG(INFO) << "Invocation context: " << invocation_context;
+
+  // ///////////////////////////////////////////////////////////////////////////
+  // Determine tensor buffer requirements.
+ // /////////////////////////////////////////////////////////////////////////// + + int num_tensor_buffer_types; + LiteRtTensorBufferRequirements input_0_tensor_buffer_requirements; + EXPECT_EQ(LiteRtDispatchGetInputRequirements( + invocation_context, /*input_index=*/0, &kInput0TensorType, + &input_0_tensor_buffer_requirements), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( + input_0_tensor_buffer_requirements, &num_tensor_buffer_types), + kLiteRtStatusOk); + EXPECT_GE(num_tensor_buffer_types, 1); + LiteRtTensorBufferType input_0_tensor_buffer_type; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + input_0_tensor_buffer_requirements, /*type_index=*/0, + &input_0_tensor_buffer_type), + kLiteRtStatusOk); + EXPECT_EQ(input_0_tensor_buffer_type, kLiteRtTensorBufferTypeAhwb); + size_t input_0_tensor_buffer_size; + EXPECT_EQ( + LiteRtGetTensorBufferRequirementsBufferSize( + input_0_tensor_buffer_requirements, &input_0_tensor_buffer_size), + kLiteRtStatusOk); + EXPECT_GE(input_0_tensor_buffer_size, sizeof(kTestInput0Tensor)); + + LiteRtTensorBufferRequirements input_1_tensor_buffer_requirements; + EXPECT_EQ(LiteRtDispatchGetInputRequirements( + invocation_context, /*input_index=*/1, &kInput1TensorType, + &input_1_tensor_buffer_requirements), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( + input_1_tensor_buffer_requirements, &num_tensor_buffer_types), + kLiteRtStatusOk); + EXPECT_GE(num_tensor_buffer_types, 1); + LiteRtTensorBufferType input_1_tensor_buffer_type; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + input_1_tensor_buffer_requirements, /*type_index=*/0, + &input_1_tensor_buffer_type), + kLiteRtStatusOk); + EXPECT_EQ(input_1_tensor_buffer_type, kLiteRtTensorBufferTypeAhwb); + size_t input_1_tensor_buffer_size; + EXPECT_EQ( + LiteRtGetTensorBufferRequirementsBufferSize( + input_1_tensor_buffer_requirements, &input_1_tensor_buffer_size), + kLiteRtStatusOk); + EXPECT_GE(input_1_tensor_buffer_size, sizeof(kTestInput1Tensor)); + + LiteRtTensorBufferRequirements output_tensor_buffer_requirements; + EXPECT_EQ(LiteRtDispatchGetOutputRequirements( + invocation_context, /*output_index=*/0, &kOutputTensorType, + &output_tensor_buffer_requirements), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( + output_tensor_buffer_requirements, &num_tensor_buffer_types), + kLiteRtStatusOk); + EXPECT_GE(num_tensor_buffer_types, 1); + LiteRtTensorBufferType output_tensor_buffer_type; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + output_tensor_buffer_requirements, /*type_index=*/0, + &output_tensor_buffer_type), + kLiteRtStatusOk); + EXPECT_EQ(output_tensor_buffer_type, kLiteRtTensorBufferTypeAhwb); + size_t output_tensor_buffer_size; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsBufferSize( + output_tensor_buffer_requirements, &output_tensor_buffer_size), + kLiteRtStatusOk); + EXPECT_GE(output_tensor_buffer_size, sizeof(kTestOutputTensor)); + + // /////////////////////////////////////////////////////////////////////////// + // Allocate tensor buffers. 
+ // /////////////////////////////////////////////////////////////////////////// + + LiteRtTensorBuffer input_0_tensor_buffer; + EXPECT_EQ(LiteRtCreateManagedTensorBuffer( + input_0_tensor_buffer_type, &kInput0TensorType, + input_0_tensor_buffer_size, &input_0_tensor_buffer), + kLiteRtStatusOk); + + LiteRtTensorBuffer input_1_tensor_buffer; + EXPECT_EQ(LiteRtCreateManagedTensorBuffer( + input_1_tensor_buffer_type, &kInput1TensorType, + input_1_tensor_buffer_size, &input_1_tensor_buffer), + kLiteRtStatusOk); + + LiteRtTensorBuffer output_tensor_buffer; + EXPECT_EQ(LiteRtCreateManagedTensorBuffer( + output_tensor_buffer_type, &kOutputTensorType, + output_tensor_buffer_size, &output_tensor_buffer), + kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Register tensor buffers. + // /////////////////////////////////////////////////////////////////////////// + + LiteRtTensorBufferHandle input_1_handle; + EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( + device_context, input_1_tensor_buffer, &input_1_handle), + kLiteRtStatusOk); + + LiteRtTensorBufferHandle input_0_handle; + EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( + device_context, input_0_tensor_buffer, &input_0_handle), + kLiteRtStatusOk); + + LiteRtTensorBufferHandle output_handle; + EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( + device_context, output_tensor_buffer, &output_handle), + kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Attach tensor buffers. + // /////////////////////////////////////////////////////////////////////////// + + EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, + /*graph_input_index=*/0, input_0_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, + /*graph_input_index=*/1, input_1_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchAttachOutput(invocation_context, + /*graph_output_index=*/0, output_handle), + kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Fill the input buffers with data. + // /////////////////////////////////////////////////////////////////////////// + + { + ABSL_LOG(INFO) << "Filling inputs with data"; + void* host_mem_addr; + + ASSERT_EQ(LiteRtLockTensorBuffer(input_0_tensor_buffer, &host_mem_addr, + /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTestInput0Tensor, sizeof(kTestInput0Tensor)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(input_0_tensor_buffer), kLiteRtStatusOk); + + ASSERT_EQ(LiteRtLockTensorBuffer(input_1_tensor_buffer, &host_mem_addr, + /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTestInput1Tensor, sizeof(kTestInput1Tensor)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(input_1_tensor_buffer), kLiteRtStatusOk); + } + + // /////////////////////////////////////////////////////////////////////////// + // Execute model. + // /////////////////////////////////////////////////////////////////////////// + + ABSL_LOG(INFO) << "Invoking execution..."; + EXPECT_EQ(LiteRtDispatchInvoke(invocation_context), kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Check output for correctness. 
+  // ///////////////////////////////////////////////////////////////////////////
+
+  {
+    ABSL_LOG(INFO) << "Checking output...";
+    void* host_mem_addr;
+    ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffer, &host_mem_addr,
+                                     /*event=*/nullptr),
+              kLiteRtStatusOk);
+    auto output = absl::MakeSpan(static_cast<float*>(host_mem_addr),
+                                 kTestOutputSize);
+    for (auto i = 0; i < kTestOutputSize; ++i) {
+      ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i];
+    }
+    EXPECT_THAT(output, Pointwise(testing::FloatNear(1e-3), kTestOutputTensor));
+    ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffer), kLiteRtStatusOk);
+  }
+
+  // ///////////////////////////////////////////////////////////////////////////
+  // Fill the input buffers with more data.
+  // ///////////////////////////////////////////////////////////////////////////
+
+  {
+    ABSL_LOG(INFO) << "Filling inputs with data";
+    void* host_mem_addr;
+
+    ASSERT_EQ(LiteRtLockTensorBuffer(input_0_tensor_buffer, &host_mem_addr,
+                                     /*event=*/nullptr),
+              kLiteRtStatusOk);
+    std::memcpy(host_mem_addr, kTestInput0Tensor_2,
+                sizeof(kTestInput0Tensor_2));
+    ASSERT_EQ(LiteRtUnlockTensorBuffer(input_0_tensor_buffer), kLiteRtStatusOk);
+
+    ASSERT_EQ(LiteRtLockTensorBuffer(input_1_tensor_buffer, &host_mem_addr,
+                                     /*event=*/nullptr),
+              kLiteRtStatusOk);
+    std::memcpy(host_mem_addr, kTestInput1Tensor_2,
+                sizeof(kTestInput1Tensor_2));
+    ASSERT_EQ(LiteRtUnlockTensorBuffer(input_1_tensor_buffer), kLiteRtStatusOk);
+  }
+
+  // ///////////////////////////////////////////////////////////////////////////
+  // Execute model once more.
+  // ///////////////////////////////////////////////////////////////////////////
+
+  ABSL_LOG(INFO) << "Invoking second execution...";
+  EXPECT_EQ(LiteRtDispatchInvoke(invocation_context), kLiteRtStatusOk);
+
+  // ///////////////////////////////////////////////////////////////////////////
+  // Check output for correctness.
+  // ///////////////////////////////////////////////////////////////////////////
+
+  {
+    ABSL_LOG(INFO) << "Checking output...";
+    void* host_mem_addr;
+    ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffer, &host_mem_addr,
+                                     /*event=*/nullptr),
+              kLiteRtStatusOk);
+    auto output = absl::MakeSpan(static_cast<float*>(host_mem_addr),
+                                 kTestOutputSize);
+    for (auto i = 0; i < kTestOutputSize; ++i) {
+      ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor_2[i];
+    }
+    EXPECT_THAT(output,
+                Pointwise(testing::FloatNear(1e-3), kTestOutputTensor_2));
+    ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffer), kLiteRtStatusOk);
+  }
+
+  // ///////////////////////////////////////////////////////////////////////////
+  // Clean up resources.
+ // /////////////////////////////////////////////////////////////////////////// + + EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, + /*graph_input_index=*/0, input_0_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, + /*graph_input_index=*/1, input_1_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchDetachOutput(invocation_context, + /*graph_output_index=*/0, output_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchUnregisterTensorBuffer(device_context, output_handle), + kLiteRtStatusOk); + EXPECT_EQ( + LiteRtDispatchUnregisterTensorBuffer(device_context, input_1_handle), + kLiteRtStatusOk); + EXPECT_EQ( + LiteRtDispatchUnregisterTensorBuffer(device_context, input_0_handle), + kLiteRtStatusOk); + LiteRtDestroyTensorBuffer(output_tensor_buffer); + LiteRtDestroyTensorBuffer(input_1_tensor_buffer); + LiteRtDestroyTensorBuffer(input_0_tensor_buffer); + EXPECT_EQ(LiteRtDispatchInvocationContextDestroy(invocation_context), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchDeviceContextDestroy(device_context), + kLiteRtStatusOk); +} diff --git a/tflite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.cc b/tflite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.cc new file mode 100644 index 00000000..b04ee266 --- /dev/null +++ b/tflite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.cc @@ -0,0 +1,137 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "tflite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.h"
+
+#include <cstddef>
+#include <memory>
+
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_logging.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/cc/litert_tensor_buffer.h"
+#include "tflite/experimental/litert/vendors/c/litert_dispatch.h"
+#include "tflite/experimental/litert/vendors/mediatek/neuron_adapter.h"
+
+// NOLINTNEXTLINE
+using litert::mediatek::NEURON_NO_ERROR;
+using litert::mediatek::NeuronMemory;
+
+LiteRtDispatchDeviceContextT::~LiteRtDispatchDeviceContextT() = default;
+
+litert::Expected<LiteRtDispatchDeviceContextT::Ptr>
+LiteRtDispatchDeviceContextT::Create(
+    const litert::mediatek::NeuronAdapter& neuron_adapter) {
+  return std::unique_ptr<LiteRtDispatchDeviceContextT>(
+      new LiteRtDispatchDeviceContextT(neuron_adapter));
+}
+
+litert::Expected<LiteRtTensorBufferHandle>
+LiteRtDispatchDeviceContextT::RegisterTensorBuffer(
+    const litert::TensorBuffer& tensor_buffer) {
+  auto tensor_buffer_type = tensor_buffer.BufferType();
+  if (!tensor_buffer_type) {
+    return tensor_buffer_type.Error();
+  }
+
+  if (*tensor_buffer_type != kLiteRtTensorBufferTypeAhwb) {
+    LITERT_LOG(LITERT_ERROR, "Unsupported buffer type: %d",
+               *tensor_buffer_type);
+    return litert::Unexpected(kLiteRtStatusErrorUnsupported);
+  }
+
+  auto tensor_buffer_size = tensor_buffer.Size();
+  if (!tensor_buffer_size) {
+    return tensor_buffer_size.Error();
+  }
+
+  auto tensor_buffer_offset = tensor_buffer.Offset();
+  if (!tensor_buffer_offset) {
+    return tensor_buffer_offset.Error();
+  }
+
+  auto ahwb = tensor_buffer.GetAhwb();
+  if (!ahwb) {
+    return ahwb.Error();
+  }
+
+#ifdef __ANDROID__
+  NeuronMemory* neuron_memory;
+  if (neuron_adapter_.api().memory_create_from_ahwb(*ahwb, &neuron_memory) !=
+      NEURON_NO_ERROR) {
+    return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                              "Failed to create NeuronMemory from AHWB");
+  }
+  return neuron_memory_registry_.Register(neuron_memory, *tensor_buffer_size,
+                                          *tensor_buffer_offset);
+#else
+  (void)neuron_adapter_;
+  return litert::Unexpected(
+      kLiteRtStatusErrorRuntimeFailure,
+      "AHardwareBuffer is not supported on this platform");
+#endif
+}
+
+LiteRtDispatchDeviceContextT::NeuronMemoryRegistry::~NeuronMemoryRegistry() {
+  for (auto i = 0; i < records_.size(); ++i) {
+    auto& record = records_[i];
+    if (record.neuron_memory != nullptr) {
+      neuron_adapter_.api().memory_free(record.neuron_memory);
+    }
+  }
+}
+
+LiteRtTensorBufferHandle
+LiteRtDispatchDeviceContextT::NeuronMemoryRegistry::Register(
+    NeuronMemory* neuron_memory, size_t size, size_t offset) {
+  int dest_index = -1;
+  for (auto i = 0; i < records_.size(); ++i) {
+    if (!records_[i].neuron_memory) {
+      dest_index = i;
+      break;
+    }
+  }
+  if (dest_index < 0) {
+    dest_index = records_.size();
+    records_.push_back({});
+  }
+  auto& dest = records_[dest_index];
+  dest = {neuron_memory, size, offset};
+  return dest_index;
+}
+
+litert::Expected<void>
+LiteRtDispatchDeviceContextT::NeuronMemoryRegistry::Unregister(
+    LiteRtTensorBufferHandle tensor_buffer_handle) {
+  auto record = Find(tensor_buffer_handle);
+  if (!record) {
+    return record.Error();
+  } else {
+    auto& mem = (*record)->neuron_memory;
+    neuron_adapter_.api().memory_free(mem);
+    mem = nullptr;
+    return {};
+  }
+}
+
+litert::Expected<LiteRtDispatchDeviceContextT::NeuronMemoryInfo*>
+LiteRtDispatchDeviceContextT::NeuronMemoryRegistry::Find(
+    LiteRtTensorBufferHandle tensor_buffer_handle) {
+  if (tensor_buffer_handle < 0 ||
+      tensor_buffer_handle >= records_.size()) {
+    return litert::Unexpected(kLiteRtStatusErrorInvalidArgument,
+                              "Invalid tensor buffer handle");
+  }
+  return &records_[tensor_buffer_handle];
+}
diff --git a/tflite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.h b/tflite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.h
new file mode 100644
index 00000000..1cc58abf
--- /dev/null
+++ b/tflite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.h
@@ -0,0 +1,87 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_
+
+#include <cstddef>
+#include <memory>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/cc/litert_tensor_buffer.h"
+#include "tflite/experimental/litert/vendors/c/litert_dispatch.h"
+#include "tflite/experimental/litert/vendors/mediatek/neuron_adapter.h"
+
+class LiteRtDispatchDeviceContextT {
+ public:
+  using Ptr = std::unique_ptr<LiteRtDispatchDeviceContextT>;
+  struct NeuronMemoryInfo {
+    litert::mediatek::NeuronMemory* neuron_memory;
+    size_t size;
+    size_t offset;
+  };
+
+  ~LiteRtDispatchDeviceContextT();
+
+  static litert::Expected<Ptr> Create(
+      const litert::mediatek::NeuronAdapter& neuron_adapter);
+
+  litert::Expected<LiteRtTensorBufferHandle> RegisterTensorBuffer(
+      const litert::TensorBuffer& tensor_buffer);
+
+  litert::Expected<void> UnregisterTensorBuffer(
+      LiteRtTensorBufferHandle tensor_buffer_handle) {
+    return neuron_memory_registry_.Unregister(tensor_buffer_handle);
+  }
+
+  litert::Expected<NeuronMemoryInfo> GetNeuronMemoryInfo(
+      LiteRtTensorBufferHandle tensor_buffer_handle) {
+    auto record = neuron_memory_registry_.Find(tensor_buffer_handle);
+    if (!record) {
+      return record.Error();
+    } else {
+      return NeuronMemoryInfo(**record);
+    }
+  }
+
+ private:
+  class NeuronMemoryRegistry {
+   public:
+    explicit NeuronMemoryRegistry(
+        const litert::mediatek::NeuronAdapter& neuron_adapter)
+        : neuron_adapter_(neuron_adapter) {}
+    ~NeuronMemoryRegistry();
+    LiteRtTensorBufferHandle Register(
+        litert::mediatek::NeuronMemory* neuron_memory, size_t size,
+        size_t offset);
+    litert::Expected<void> Unregister(
+        LiteRtTensorBufferHandle tensor_buffer_handle);
+    litert::Expected<NeuronMemoryInfo*> Find(
+        LiteRtTensorBufferHandle tensor_buffer_handle);
+
+   private:
+    const litert::mediatek::NeuronAdapter& neuron_adapter_;
+    std::vector<NeuronMemoryInfo> records_;
+  };
+
+  explicit LiteRtDispatchDeviceContextT(
+      const litert::mediatek::NeuronAdapter& neuron_adapter)
+      : neuron_adapter_(neuron_adapter),
+        neuron_memory_registry_(neuron_adapter) {}
+
+  const litert::mediatek::NeuronAdapter& neuron_adapter_;
+  NeuronMemoryRegistry neuron_memory_registry_;
+};
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_
diff --git a/tflite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.cc b/tflite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.cc
new file mode 100644
index 00000000..091a1f07
--- /dev/null
+++ b/tflite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.cc
@@ -0,0 +1,422 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.h"
+
+#include <array>
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_logging.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/vendors/c/litert_dispatch.h"
+#include "tflite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.h"
+#include "tflite/experimental/litert/vendors/mediatek/neuron_adapter.h"
+
+using litert::mediatek::NEURON_NO_ERROR;
+using litert::mediatek::NEURON_PREFER_SUSTAINED_SPEED;
+using litert::mediatek::NEURON_PRIORITY_HIGH;
+using litert::mediatek::NEURON_TENSOR_FLOAT32;
+using litert::mediatek::NeuronCompilation;
+using litert::mediatek::NeuronExecution;
+using litert::mediatek::NeuronModel;
+using litert::mediatek::NeuronOperandType;
+using litert::mediatek::NeuronOperationType;
+using litert::mediatek::NeuronRuntimeVersion;
+
+namespace {
+
+bool LoadFromCachedNetwork(
+    const litert::mediatek::NeuronAdapter& neuron_adapter, NeuronModel*& model,
+    NeuronCompilation*& compilation, const void* bytecode_addr,
+    size_t bytecode_size) {
+  return neuron_adapter.api().model_restore_from_compiled_network(
+             &model, &compilation, bytecode_addr, bytecode_size) ==
+         NEURON_NO_ERROR;
+}
+
+uint16_t GetRestoreDlaExtensionOperandType(
+    const litert::mediatek::NeuronAdapter& neuron_adapter) {
+  NeuronRuntimeVersion version;
+  neuron_adapter.api().get_version(&version);
+  // The values below were suggested by MTK.
+  if (version.major >= 8) {
+    return 0x0200;
+  } else {
+    return 0x0100;
+  }
+}
+
+bool LoadFromDlaBytecode(const litert::mediatek::NeuronAdapter& neuron_adapter,
+                         NeuronModel*& model, NeuronCompilation*& compilation,
+                         const void* bytecode_addr, size_t bytecode_size,
+                         int num_inputs, int num_outputs,
+                         const std::string& options) {
+  LITERT_LOG(LITERT_INFO, "Creating model...");
+  if (neuron_adapter.api().model_create(&model) != NEURON_NO_ERROR) {
+    LITERT_LOG(LITERT_ERROR, "Failed to create model");
+    return false;
+  }
+
+  // Fake I/O operand type; the real operand types are restored from the
+  // compiled network.
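+  // The restored network is wired up as a single extension operation: the
+  // fake operands declared below stand in for the real graph inputs and
+  // outputs, and one extra input operand carries the compiled-network
+  // bytecode itself (see kExtensionRestoreCompiledNetwork below).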
+  constexpr const NeuronOperandType fake_io_operand_type{
+      .type = NEURON_TENSOR_FLOAT32,
+      .dimensionCount = 0,
+      .scale = 0.0f,
+      .zeroPoint = 0,
+  };
+
+  std::vector<uint32_t> input_op_number;
+  input_op_number.reserve(num_inputs);
+  for (auto i = 0; i < num_inputs; i++) {
+    if (neuron_adapter.api().model_add_operand(model, &fake_io_operand_type) !=
+        NEURON_NO_ERROR) {
+      LITERT_LOG(LITERT_ERROR, "Failed to add input operand %d", i);
+      return false;
+    }
+    input_op_number.emplace_back(i);
+  }
+
+  const uint16_t kNetworkOperandRestoreData =
+      GetRestoreDlaExtensionOperandType(neuron_adapter);
+  constexpr const uint16_t kRestoreDlaExtensionOperationType = 0;
+  constexpr const char* kExtensionRestoreCompiledNetwork =
+      "com.mediatek.compiled_network";
+
+  int32_t operand_type;
+  if (neuron_adapter.api().model_get_extension_operand_type(
+          model, kExtensionRestoreCompiledNetwork, kNetworkOperandRestoreData,
+          &operand_type) != NEURON_NO_ERROR) {
+    LITERT_LOG(LITERT_ERROR, "Failed to get extension operand");
+    return false;
+  }
+
+  const NeuronOperandType extension_operand_type{
+      .type = operand_type,
+      .dimensionCount = 0,
+      .scale = 0.0f,
+      .zeroPoint = 0,
+  };
+  if (neuron_adapter.api().model_add_operand(model, &extension_operand_type) !=
+      NEURON_NO_ERROR) {
+    LITERT_LOG(LITERT_ERROR, "Failed to add extension operand");
+    return false;
+  }
+  input_op_number.emplace_back(input_op_number.size());
+  if (neuron_adapter.api().model_set_operand_value(
+          model, input_op_number.back(), bytecode_addr, bytecode_size) !=
+      NEURON_NO_ERROR) {
+    LITERT_LOG(LITERT_ERROR, "Failed to set extension operand value");
+    return false;
+  }
+
+  std::vector<uint32_t> output_op_number;
+  for (auto i = 0; i < num_outputs; i++) {
+    if (neuron_adapter.api().model_add_operand(model, &fake_io_operand_type) !=
+        NEURON_NO_ERROR) {
+      LITERT_LOG(LITERT_ERROR, "Failed to add output operand %d", i);
+      return false;
+    }
+    output_op_number.emplace_back(input_op_number.size() + i);
+  }
+
+  int32_t operation_type;
+  if (neuron_adapter.api().model_get_extension_operation_type(
+          model, kExtensionRestoreCompiledNetwork,
+          kRestoreDlaExtensionOperationType,
+          &operation_type) != NEURON_NO_ERROR) {
+    LITERT_LOG(LITERT_ERROR, "Failed to get extension operation");
+    return false;
+  }
+
+  // Add extension operation
+  if (neuron_adapter.api().model_add_operation(
+          model, static_cast<NeuronOperationType>(operation_type),
+          input_op_number.size(), input_op_number.data(),
+          output_op_number.size(),
+          output_op_number.data()) != NEURON_NO_ERROR) {
+    LITERT_LOG(LITERT_ERROR, "Failed to add extension operation");
+    return false;
+  }
+
+  if (neuron_adapter.api().model_identify_inputs_and_outputs(
+          model, input_op_number.size() - 1, input_op_number.data(),
+          output_op_number.size(),
+          output_op_number.data()) != NEURON_NO_ERROR) {
+    LITERT_LOG(LITERT_ERROR, "Failed to identify I/Os");
+    return false;
+  }
+
+  if (neuron_adapter.api().model_finish(model) != NEURON_NO_ERROR) {
+    LITERT_LOG(LITERT_ERROR, "Failed to finish model");
+    return false;
+  }
+
+  if (neuron_adapter.api().compilation_create_with_options(
+          model, &compilation, options.c_str()) != NEURON_NO_ERROR) {
+    LITERT_LOG(LITERT_ERROR, "Failed to create compilation");
+    return false;
+  }
+
+  if (neuron_adapter.api().compilation_set_priority(
+          compilation, NEURON_PRIORITY_HIGH) != NEURON_NO_ERROR) {
+    LITERT_LOG(LITERT_ERROR, "Failed to set compilation priority");
+    return false;
+  }
+
+  if (neuron_adapter.api().compilation_set_preference(
+          compilation, NEURON_PREFER_SUSTAINED_SPEED) != NEURON_NO_ERROR) {
+    LITERT_LOG(LITERT_ERROR, "Failed to set compilation preference");
+    return false;
+  }
+
+  if (!options.empty()) {
+    if (neuron_adapter.api().compilation_set_optimization_string(
+            compilation, options.c_str()) != NEURON_NO_ERROR) {
+      LITERT_LOG(LITERT_ERROR, "Failed to set optimization string");
+      return false;
+    }
+  }
+
+  if (neuron_adapter.api().compilation_finish(compilation) !=
+      NEURON_NO_ERROR) {
+    LITERT_LOG(LITERT_ERROR, "Failed to finish compilation");
+    return false;
+  }
+
+  return true;
+}
+
+bool LoadModelAndCompilation(
+    const litert::mediatek::NeuronAdapter& neuron_adapter, NeuronModel*& model,
+    NeuronCompilation*& compilation, const void* bytecode_addr,
+    size_t bytecode_size, int num_inputs, int num_outputs) {
+  // Option `import_forever` has been recommended by MediaTek to reduce memory
+  // footprint when using the same I/O buffers across multiple invocations.
+  constexpr const char* kOptions =
+      "--apusys-config \"{ \\\"import_forever\\\": true }\"";
+  if (!LoadFromDlaBytecode(neuron_adapter, model, compilation, bytecode_addr,
+                           bytecode_size, num_inputs, num_outputs, kOptions)) {
+    return LoadFromCachedNetwork(neuron_adapter, model, compilation,
+                                 bytecode_addr, bytecode_size);
+  }
+  return true;
+}
+
+}  // namespace
+
+litert::Expected<LiteRtDispatchInvocationContextT::Ptr>
+LiteRtDispatchInvocationContextT::Create(
+    litert::mediatek::NeuronAdapter& neuron_adapter,
+    LiteRtDispatchDeviceContext device_context,
+    LiteRtDispatchExecutableType exec_type, const void* bytecode_ptr,
+    size_t bytecode_size, const char* function_name, int num_inputs,
+    int num_outputs) {
+  NeuronModel* model;
+  NeuronCompilation* compilation;
+  if (!LoadModelAndCompilation(neuron_adapter, model, compilation,
+                               bytecode_ptr, bytecode_size, num_inputs,
+                               num_outputs)) {
+    return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                              "Failed to load compiled model");
+  }
+
+  NeuronExecution* execution;
+  if (neuron_adapter.api().execution_create(compilation, &execution) !=
+      NEURON_NO_ERROR) {
+    if (compilation) {
+      neuron_adapter.api().compilation_free(compilation);
+    }
+    if (model) {
+      neuron_adapter.api().model_free(model);
+    }
+    return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                              "Failed to create execution");
+  }
+
+  if (neuron_adapter.api().execution_set_boost_hint(execution, 100) !=
+      NEURON_NO_ERROR) {
+    if (execution) {
+      neuron_adapter.api().execution_free(execution);
+    }
+    if (compilation) {
+      neuron_adapter.api().compilation_free(compilation);
+    }
+    if (model) {
+      neuron_adapter.api().model_free(model);
+    }
+    return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                              "Failed to set execution boost hint");
+  }
+
+  return Ptr(new LiteRtDispatchInvocationContextT(
+      neuron_adapter, device_context, model, compilation, execution,
+      num_inputs, num_outputs));
+}
+
+LiteRtDispatchInvocationContextT::~LiteRtDispatchInvocationContextT() {
+  if (execution_) {
+    neuron_adapter_.api().execution_free(execution_);
+  }
+  if (compilation_) {
+    neuron_adapter_.api().compilation_free(compilation_);
+  }
+  if (model_) {
+    neuron_adapter_.api().model_free(model_);
+  }
+}
+
+LiteRtDispatchInvocationContextT::IoRequirementsBuilder::IoRequirementsBuilder(
+    size_t buffer_size, const std::vector<uint32_t>& padded_dimensions)
+    : buffer_size_(buffer_size) {
+  auto rank = padded_dimensions.size();
+  strides_.resize(rank);
+  strides_[0] = 1;
+  for (auto i = 1; i < rank; ++i) {
+    strides_[i] = padded_dimensions[i - 1];
+  }
+}
+
+litert::Expected<LiteRtTensorBufferRequirements>
+LiteRtDispatchInvocationContextT::IoRequirementsBuilder::Create() {
+  static constexpr std::array<LiteRtTensorBufferType, 1>
+      kSupportedTensorBufferTypes = {
+          kLiteRtTensorBufferTypeAhwb,
+      };
+
+  LiteRtTensorBufferRequirements requirements;
+  if (auto status = LiteRtCreateTensorBufferRequirements(
+          kSupportedTensorBufferTypes.size(),
+          kSupportedTensorBufferTypes.data(), buffer_size_, strides_.size(),
+          strides_.data(), &requirements);
+      status != kLiteRtStatusOk) {
+    return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                              "Failed to create tensor buffer requirements");
+  }
+
+  return requirements;
+}
+
+litert::Expected<LiteRtTensorBufferRequirements>
+LiteRtDispatchInvocationContextT::GetInputRequirements(
+    int input_index, const LiteRtRankedTensorType& tensor_type) {
+  if (!input_requirements_builders_[input_index]) {
+    size_t buffer_size;
+    if (neuron_adapter_.api().compilation_get_input_padded_size(
+            compilation_, input_index, &buffer_size) != NEURON_NO_ERROR) {
+      return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                                "Failed to get input padded size");
+    }
+
+    std::vector<uint32_t> padded_dimensions(tensor_type.layout.rank);
+    if (neuron_adapter_.api().compilation_get_input_padded_dimensions(
+            compilation_, input_index, padded_dimensions.data()) !=
+        NEURON_NO_ERROR) {
+      return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                                "Failed to get input padded dimensions");
+    }
+
+    input_requirements_builders_[input_index] =
+        std::make_unique<IoRequirementsBuilder>(buffer_size,
+                                                padded_dimensions);
+  }
+
+  return input_requirements_builders_[input_index]->Create();
+}
+
+litert::Expected<LiteRtTensorBufferRequirements>
+LiteRtDispatchInvocationContextT::GetOutputRequirements(
+    int output_index, const LiteRtRankedTensorType& tensor_type) {
+  if (!output_requirements_builders_[output_index]) {
+    size_t buffer_size;
+    if (neuron_adapter_.api().compilation_get_output_padded_size(
+            compilation_, output_index, &buffer_size) != NEURON_NO_ERROR) {
+      return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                                "Failed to get output padded size");
+    }
+
+    std::vector<uint32_t> padded_dimensions(tensor_type.layout.rank);
+    if (neuron_adapter_.api().compilation_get_output_padded_dimensions(
+            compilation_, output_index, padded_dimensions.data()) !=
+        NEURON_NO_ERROR) {
+      return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                                "Failed to get output padded dimensions");
+    }
+
+    output_requirements_builders_[output_index] =
+        std::make_unique<IoRequirementsBuilder>(buffer_size,
+                                                padded_dimensions);
+  }
+
+  return output_requirements_builders_[output_index]->Create();
+}
+
+litert::Expected<void> LiteRtDispatchInvocationContextT::AttachInput(
+    int graph_input_index, LiteRtTensorBufferHandle tensor_buffer_handle) {
+  auto neuron_memory_info =
+      device_context_->GetNeuronMemoryInfo(tensor_buffer_handle);
+  if (!neuron_memory_info) {
+    return litert::Unexpected(neuron_memory_info.Error());
+  }
+
+  if (neuron_adapter_.api().execution_set_input_from_memory(
+          execution_, graph_input_index, nullptr,
+          neuron_memory_info->neuron_memory, neuron_memory_info->offset,
+          neuron_memory_info->size) != NEURON_NO_ERROR) {
+    return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                              "Failed to set execution input from memory");
+  }
+  return {};
+}
+
+litert::Expected<void> LiteRtDispatchInvocationContextT::AttachOutput(
+    int graph_output_index, LiteRtTensorBufferHandle tensor_buffer_handle) {
+  auto neuron_memory_info =
+      device_context_->GetNeuronMemoryInfo(tensor_buffer_handle);
+  if (!neuron_memory_info) {
+    return litert::Unexpected(neuron_memory_info.Error());
+  }
+
+  if (neuron_adapter_.api().execution_set_output_from_memory(
+          execution_, graph_output_index, nullptr,
+          neuron_memory_info->neuron_memory, neuron_memory_info->offset,
+          neuron_memory_info->size) != NEURON_NO_ERROR) {
+    return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                              "Failed to set execution output from memory");
+  }
+  return {};
+}
+
+litert::Expected<void> LiteRtDispatchInvocationContextT::DetachInput(
+    int graph_input_index, LiteRtTensorBufferHandle tensor_buffer_handle) {
+  // Nothing to do.
+  return {};
+}
+
+litert::Expected<void> LiteRtDispatchInvocationContextT::DetachOutput(
+    int graph_output_index, LiteRtTensorBufferHandle tensor_buffer_handle) {
+  // Nothing to do.
+  return {};
+}
+
+litert::Expected<void> LiteRtDispatchInvocationContextT::Invoke() {
+  if (neuron_adapter_.api().execution_compute(execution_) !=
+      NEURON_NO_ERROR) {
+    return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                              "Failed to execute network");
+  }
+  return {};
+}
diff --git a/tflite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.h b/tflite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.h
new file mode 100644
index 00000000..d49aa29c
--- /dev/null
+++ b/tflite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.h
@@ -0,0 +1,96 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_
+
+#include <cstddef>
+#include <memory>
+#include <vector>
+
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/cc/litert_tensor_buffer_requirements.h"
+#include "tflite/experimental/litert/vendors/c/litert_dispatch.h"
+#include "tflite/experimental/litert/vendors/mediatek/neuron_adapter.h"
+
+class LiteRtDispatchInvocationContextT {
+ public:
+  using Ptr = std::unique_ptr<LiteRtDispatchInvocationContextT>;
+
+  static litert::Expected<Ptr> Create(
+      litert::mediatek::NeuronAdapter& neuron_adapter,
+      LiteRtDispatchDeviceContext device_context,
+      LiteRtDispatchExecutableType exec_type, const void* exec_bytecode_ptr,
+      size_t exec_bytecode_size, const char* function_name, int num_inputs,
+      int num_outputs);
+
+  ~LiteRtDispatchInvocationContextT();
+
+  litert::Expected<LiteRtTensorBufferRequirements> GetInputRequirements(
+      int input_index, const LiteRtRankedTensorType& tensor_type);
+
+  litert::Expected<LiteRtTensorBufferRequirements> GetOutputRequirements(
+      int output_index, const LiteRtRankedTensorType& tensor_type);
+
+  litert::Expected<void> AttachInput(
+      int graph_input_index, LiteRtTensorBufferHandle tensor_buffer_handle);
+  litert::Expected<void> AttachOutput(
+      int graph_output_index, LiteRtTensorBufferHandle tensor_buffer_handle);
+
+  litert::Expected<void> DetachInput(
+      int graph_input_index, LiteRtTensorBufferHandle tensor_buffer_handle);
+  litert::Expected<void> DetachOutput(
+      int graph_output_index, LiteRtTensorBufferHandle tensor_buffer_handle);
+
+  litert::Expected<void> Invoke();
+
+ private:
+  class IoRequirementsBuilder {
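+    // Builds LiteRtTensorBufferRequirements for one I/O tensor from the
+    // padded buffer size and padded dimensions reported by the Neuron
+    // compilation; strides are derived from the padded dimensions once, at
+    // construction time.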
+   public:
+    IoRequirementsBuilder(size_t buffer_size,
+                          const std::vector<uint32_t>& padded_dimensions);
+    litert::Expected<LiteRtTensorBufferRequirements> Create();
+
+   private:
+    size_t buffer_size_;
+    std::vector<uint32_t> strides_;
+  };
+
+  LiteRtDispatchInvocationContextT(
+      const litert::mediatek::NeuronAdapter& neuron_adapter,
+      LiteRtDispatchDeviceContext device_context,
+      litert::mediatek::NeuronModel* model,
+      litert::mediatek::NeuronCompilation* compilation,
+      litert::mediatek::NeuronExecution* execution, int num_inputs,
+      int num_outputs)
+      : neuron_adapter_(neuron_adapter),
+        device_context_(device_context),
+        model_(model),
+        compilation_(compilation),
+        execution_(execution),
+        input_requirements_builders_(num_inputs),
+        output_requirements_builders_(num_outputs) {}
+
+  const litert::mediatek::NeuronAdapter& neuron_adapter_;
+  LiteRtDispatchDeviceContext device_context_;
+  litert::mediatek::NeuronModel* model_;
+  litert::mediatek::NeuronCompilation* compilation_;
+  litert::mediatek::NeuronExecution* execution_;
+  std::vector<std::unique_ptr<IoRequirementsBuilder>>
+      input_requirements_builders_;
+  std::vector<std::unique_ptr<IoRequirementsBuilder>>
+      output_requirements_builders_;
+};
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_
diff --git a/tflite/experimental/litert/vendors/mediatek/neuron_adapter.cc b/tflite/experimental/litert/vendors/mediatek/neuron_adapter.cc
new file mode 100644
index 00000000..a369132e
--- /dev/null
+++ b/tflite/experimental/litert/vendors/mediatek/neuron_adapter.cc
@@ -0,0 +1,130 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/vendors/mediatek/neuron_adapter.h"
+
+#include <dlfcn.h>
+
+#include <memory>
+#include <optional>
+#include <string>
+
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_logging.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/core/dynamic_loading.h"
+
+#define LOAD_SYMB(S, H)                                              \
+  H = reinterpret_cast<decltype(H)>(::dlsym(dlib_handle_, #S));      \
+  if (!H) {                                                          \
+    LITERT_LOG(LITERT_WARNING, "Failed to load symbol %s: %s", #S,   \
+               ::dlerror());                                         \
+  }
+
+namespace litert {
+namespace mediatek {
+
+NeuronAdapter::NeuronAdapter() : api_(new Api) {}
+
+NeuronAdapter::~NeuronAdapter() {
+  if (dlib_handle_) {
+    litert::internal::CloseLib(dlib_handle_);
+  }
+}
+
+litert::Expected<NeuronAdapter::Ptr> NeuronAdapter::Create(
+    std::optional<std::string> shared_library_dir) {
+  std::unique_ptr<NeuronAdapter> neuron_adapter(new NeuronAdapter);
+  if (auto status = neuron_adapter->LoadSymbols(shared_library_dir); !status) {
+    return status.Error();
+  }
+
+  return neuron_adapter;
+}
+
+litert::Expected<void> NeuronAdapter::LoadSymbols(
+    std::optional<std::string> shared_library_dir) {
+  // The following preinstalled library is for system partition applications.
+  if (litert::internal::OpenLib("libneuronusdk_adapter.mtk.so",
+                                &dlib_handle_) != kLiteRtStatusOk) {
+    // The next preinstalled library is in the vendor partition.
+    if (litert::internal::OpenLib("libneuron_adapter_mgvi.so",
+                                  &dlib_handle_) != kLiteRtStatusOk) {
+      // Finally, the app may want to provide its own version of the library.
+      constexpr auto kLibNeuronAdapterLib = "libneuron_adapter.so";
+      std::string library_path =
+          shared_library_dir.has_value()
+              ? *shared_library_dir + kLibNeuronAdapterLib
+              : kLibNeuronAdapterLib;
+      if (litert::internal::OpenLib(library_path, &dlib_handle_) !=
+          kLiteRtStatusOk) {
+        return litert::Unexpected(
+            kLiteRtStatusErrorRuntimeFailure,
+            "Failed to load NeuronAdapter shared library");
+      }
+    }
+  }
+
+  // Binds all supported symbols from the shared library to the function
+  // pointers.
+  LOAD_SYMB(NeuronCompilation_create, api_->compilation_create);
+  LOAD_SYMB(NeuronCompilation_createWithOptions,
+            api_->compilation_create_with_options);
+  LOAD_SYMB(NeuronCompilation_finish, api_->compilation_finish);
+  LOAD_SYMB(NeuronCompilation_free, api_->compilation_free);
+  LOAD_SYMB(NeuronCompilation_getInputPaddedDimensions,
+            api_->compilation_get_input_padded_dimensions);
+  LOAD_SYMB(NeuronCompilation_getInputPaddedSize,
+            api_->compilation_get_input_padded_size);
+  LOAD_SYMB(NeuronCompilation_getOutputPaddedDimensions,
+            api_->compilation_get_output_padded_dimensions);
+  LOAD_SYMB(NeuronCompilation_getOutputPaddedSize,
+            api_->compilation_get_output_padded_size);
+  LOAD_SYMB(NeuronCompilation_setOptimizationString,
+            api_->compilation_set_optimization_string);
+  LOAD_SYMB(NeuronCompilation_setPreference, api_->compilation_set_preference);
+  LOAD_SYMB(NeuronCompilation_setPriority, api_->compilation_set_priority);
+  LOAD_SYMB(NeuronExecution_compute, api_->execution_compute);
+  LOAD_SYMB(NeuronExecution_create, api_->execution_create);
+  LOAD_SYMB(NeuronExecution_free, api_->execution_free);
+  LOAD_SYMB(NeuronExecution_setBoostHint, api_->execution_set_boost_hint);
+  LOAD_SYMB(NeuronExecution_setInputFromMemory,
+            api_->execution_set_input_from_memory);
+  LOAD_SYMB(NeuronExecution_setOutputFromMemory,
+            api_->execution_set_output_from_memory);
+  LOAD_SYMB(NeuronMemory_createFromAHardwareBuffer,
+            api_->memory_create_from_ahwb);
+  LOAD_SYMB(NeuronMemory_free, api_->memory_free);
+  LOAD_SYMB(NeuronModel_addOperand, api_->model_add_operand);
+  LOAD_SYMB(NeuronModel_addOperation, api_->model_add_operation);
+  LOAD_SYMB(NeuronModel_create, api_->model_create);
+  LOAD_SYMB(NeuronModel_finish, api_->model_finish);
+  LOAD_SYMB(NeuronModel_free, api_->model_free);
+  LOAD_SYMB(NeuronModel_getExtensionOperandType,
+            api_->model_get_extension_operand_type);
+  LOAD_SYMB(NeuronModel_getExtensionOperationType,
+            api_->model_get_extension_operation_type);
+  LOAD_SYMB(NeuronModel_identifyInputsAndOutputs,
+            api_->model_identify_inputs_and_outputs);
+  LOAD_SYMB(NeuronModel_restoreFromCompiledNetwork,
+            api_->model_restore_from_compiled_network);
+  LOAD_SYMB(NeuronModel_setOperandValue, api_->model_set_operand_value);
+  LOAD_SYMB(Neuron_getVersion, api_->get_version);
+
+  LITERT_LOG(LITERT_INFO, "NeuronAdapter symbols loaded");
+  return {};
+}
+
+}  // namespace mediatek
+}  // namespace litert
diff --git a/tflite/experimental/litert/vendors/mediatek/neuron_adapter.h b/tflite/experimental/litert/vendors/mediatek/neuron_adapter.h
new file mode 100644
index 00000000..7e5b28cd
--- /dev/null
+++ b/tflite/experimental/litert/vendors/mediatek/neuron_adapter.h
@@ -0,0 +1,219 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_NEURON_ADAPTER_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_NEURON_ADAPTER_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <string>
+
+#include "tflite/experimental/litert/cc/litert_expected.h"
+
+#if LITERT_HAS_AHWB_SUPPORT
+#include <android/hardware_buffer.h>
+#else
+struct AHardwareBuffer {};
+#endif
+
+namespace litert::mediatek {
+
+// /////////////////////////////////////////////////////////////////////////////
+//
+// A minimal set of definitions for the NeuronAdapter API, from public domain
+// sources.
+//
+// /////////////////////////////////////////////////////////////////////////////
+
+struct NeuronRuntimeVersion {
+  uint8_t major;
+  uint8_t minor;
+  uint8_t patch;
+};
+
+enum NeuronOperationType {
+  NEURON_ADD = 0,
+};
+
+struct NeuronOperandType {
+  int32_t type;
+  // NOLINTNEXTLINE
+  uint32_t dimensionCount;
+  const uint32_t* dimensions;
+  float scale;
+  // NOLINTNEXTLINE
+  int32_t zeroPoint;
+};
+
+struct NeuronModel;
+struct NeuronCompilation;
+struct NeuronExecution;
+struct NeuronMemory;
+
+static constexpr int NEURON_NO_ERROR = 0;
+static constexpr int NEURON_FLOAT32 = 0;
+static constexpr int NEURON_TENSOR_FLOAT32 = 3;
+static constexpr int NEURON_PRIORITY_HIGH = 110;
+static constexpr int NEURON_PREFER_SUSTAINED_SPEED = 2;
+
+int NeuronCompilation_create(NeuronModel* model,
+                             NeuronCompilation** compilation);
+int NeuronCompilation_createWithOptions(NeuronModel* model,
+                                        NeuronCompilation** compilation,
+                                        const char* options);
+int NeuronCompilation_finish(NeuronCompilation* compilation);
+int NeuronCompilation_getInputPaddedDimensions(NeuronCompilation* compilation,
+                                               int32_t index,
+                                               uint32_t* dimensions);
+int NeuronCompilation_getInputPaddedSize(NeuronCompilation* compilation,
+                                         int32_t index, size_t* size);
+int NeuronCompilation_getOutputPaddedDimensions(NeuronCompilation* compilation,
+                                                int32_t index,
+                                                uint32_t* dimensions);
+int NeuronCompilation_getOutputPaddedSize(NeuronCompilation* compilation,
+                                          int32_t index, size_t* size);
+int NeuronCompilation_setOptimizationString(NeuronCompilation* compilation,
+                                            const char* optimizationString);
+int NeuronCompilation_setPreference(NeuronCompilation* compilation,
+                                    int32_t preference);
+int NeuronCompilation_setPriority(NeuronCompilation* compilation, int priority);
+int NeuronExecution_compute(NeuronExecution* execution);
+int NeuronExecution_create(NeuronCompilation* compilation,
+                           NeuronExecution** execution);
+int NeuronExecution_setBoostHint(NeuronExecution* execution,
+                                 uint8_t boostValue);
+int NeuronExecution_setInputFromMemory(NeuronExecution* execution,
+                                       uint32_t index,
+                                       const NeuronOperandType* type,
+                                       const NeuronMemory* memory,
+                                       size_t offset, size_t length);
+int NeuronExecution_setOutputFromMemory(NeuronExecution* execution,
+                                        uint32_t index,
+                                        const NeuronOperandType* type,
+                                        const NeuronMemory* memory,
+                                        size_t offset, size_t length);
+int NeuronMemory_createFromAHardwareBuffer(const AHardwareBuffer* ahwb,
+                                           NeuronMemory** memory);
+int NeuronModel_addOperand(NeuronModel* model, const NeuronOperandType* type);
+int NeuronModel_addOperation(NeuronModel* model, NeuronOperationType type,
+                             uint32_t inputCount, const uint32_t* inputs,
+                             uint32_t outputCount, const uint32_t* outputs);
+int NeuronModel_create(NeuronModel** model);
+int NeuronModel_finish(NeuronModel* model);
+int NeuronModel_getExtensionOperandType(NeuronModel* model,
+                                        const char* extensionName,
+                                        uint16_t operandCodeWithinExtension,
+                                        int32_t* type);
+int NeuronModel_getExtensionOperationType(NeuronModel* model,
+                                          const char* extensionName,
+                                          uint16_t operationCodeWithinExtension,
+                                          int32_t* type);
+int NeuronModel_identifyInputsAndOutputs(NeuronModel* model,
+                                         uint32_t inputCount,
+                                         const uint32_t* inputs,
+                                         uint32_t outputCount,
+                                         const uint32_t* outputs);
+int NeuronModel_restoreFromCompiledNetwork(NeuronModel** model,
+                                           NeuronCompilation** compilation,
+                                           const void* buffer, size_t size);
+int NeuronModel_setOperandValue(NeuronModel* model, int32_t index,
+                                const void* buffer, size_t length);
+int Neuron_getVersion(NeuronRuntimeVersion* version);
+void NeuronCompilation_free(NeuronCompilation* compilation);
+void NeuronExecution_free(NeuronExecution* execution);
+void NeuronMemory_free(NeuronMemory* memory);
+void NeuronModel_free(NeuronModel* model);
+
+// /////////////////////////////////////////////////////////////////////////////
+
+class NeuronAdapter {
+ public:
+  using Ptr = std::unique_ptr<NeuronAdapter>;
+  struct Api;
+
+  NeuronAdapter(const NeuronAdapter&) = delete;
+  NeuronAdapter(NeuronAdapter&&) = delete;
+  NeuronAdapter& operator=(const NeuronAdapter&) = delete;
+  NeuronAdapter& operator=(NeuronAdapter&&) = delete;
+
+  ~NeuronAdapter();
+
+  static litert::Expected<Ptr> Create(
+      std::optional<std::string> shared_library_dir);
+
+  const Api& api() const { return *api_; }
+
+ private:
+  NeuronAdapter();
+  litert::Expected<void> LoadSymbols(
+      std::optional<std::string> shared_library_dir);
+
+  void* dlib_handle_ = nullptr;
+  std::unique_ptr<Api> api_;
+};
+
+// A convenient struct for holding function pointers to NeuronAdapter API
+// symbols. These function pointers are resolved from the shared library on
+// the device at runtime.
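+//
+// Illustrative use (a sketch, not exercised by this change): symbols that are
+// missing from the device library remain nullptr, so callers should check a
+// pointer before invoking it.
+//
+//   auto adapter = litert::mediatek::NeuronAdapter::Create(std::nullopt);
+//   if (adapter) {
+//     NeuronRuntimeVersion version;
+//     if ((*adapter)->api().get_version != nullptr &&
+//         (*adapter)->api().get_version(&version) == NEURON_NO_ERROR) {
+//       // version.major/minor/patch now describe the loaded runtime.
+//     }
+//   }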
+struct NeuronAdapter::Api { + decltype(&NeuronCompilation_create) compilation_create = nullptr; + decltype(&NeuronCompilation_createWithOptions) + compilation_create_with_options = nullptr; + decltype(&NeuronCompilation_finish) compilation_finish = nullptr; + decltype(&NeuronCompilation_free) compilation_free = nullptr; + decltype(&NeuronCompilation_getInputPaddedDimensions) + compilation_get_input_padded_dimensions = nullptr; + decltype(&NeuronCompilation_getInputPaddedSize) + compilation_get_input_padded_size = nullptr; + decltype(&NeuronCompilation_getOutputPaddedDimensions) + compilation_get_output_padded_dimensions = nullptr; + decltype(&NeuronCompilation_getOutputPaddedSize) + compilation_get_output_padded_size = nullptr; + decltype(&NeuronCompilation_setOptimizationString) + compilation_set_optimization_string = nullptr; + decltype(&NeuronCompilation_setPreference) compilation_set_preference = + nullptr; + decltype(&NeuronCompilation_setPriority) compilation_set_priority = nullptr; + decltype(&NeuronExecution_compute) execution_compute = nullptr; + decltype(&NeuronExecution_create) execution_create = nullptr; + decltype(&NeuronExecution_free) execution_free = nullptr; + decltype(&NeuronExecution_setBoostHint) execution_set_boost_hint = nullptr; + decltype(&NeuronExecution_setInputFromMemory) + execution_set_input_from_memory = nullptr; + decltype(&NeuronExecution_setOutputFromMemory) + execution_set_output_from_memory = nullptr; + decltype(&NeuronMemory_createFromAHardwareBuffer) memory_create_from_ahwb = + nullptr; + decltype(&NeuronMemory_free) memory_free = nullptr; + decltype(&NeuronModel_addOperand) model_add_operand = nullptr; + decltype(&NeuronModel_addOperation) model_add_operation = nullptr; + decltype(&NeuronModel_create) model_create = nullptr; + decltype(&NeuronModel_finish) model_finish = nullptr; + decltype(&NeuronModel_free) model_free = nullptr; + decltype(&NeuronModel_getExtensionOperandType) + model_get_extension_operand_type = nullptr; + decltype(&NeuronModel_getExtensionOperationType) + model_get_extension_operation_type = nullptr; + decltype(&NeuronModel_identifyInputsAndOutputs) + model_identify_inputs_and_outputs = nullptr; + decltype(&NeuronModel_restoreFromCompiledNetwork) + model_restore_from_compiled_network = nullptr; + decltype(&NeuronModel_setOperandValue) model_set_operand_value = nullptr; + decltype(&Neuron_getVersion) get_version = nullptr; +}; + +} // namespace litert::mediatek + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_NEURON_ADAPTER_H_ diff --git a/tflite/experimental/litert/vendors/qualcomm/BUILD b/tflite/experimental/litert/vendors/qualcomm/BUILD new file mode 100644 index 00000000..0e570f66 --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/BUILD @@ -0,0 +1,136 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("//tflite/experimental/litert/build_common:litert_build_defs.bzl", "litert_lib", "litert_test") +load("//tflite/experimental/litert/vendors/qualcomm:qualcomm_build_defs.bzl", "litert_cc_lib_with_qnn") + +package( + # copybara:uncomment default_applicable_licenses = ["@org_tensorflow//tensorflow:license"], + default_visibility = ["//tflite/experimental/litert:__subpackages__"], +) + +cc_library( + name = "common", + hdrs = ["common.h"], + deps = [ + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/cc:litert_model", + ], +) + +litert_lib( + name = "qnn_log", + srcs = ["qnn_log.cc"], + hdrs = ["qnn_log.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + deps = [ + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + ], +) + +cc_library( + name = "qnn_manager_hdr", + hdrs = ["qnn_manager.h"], + deps = [ + ":common", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + ], +) + +litert_cc_lib_with_qnn( + name = "qnn_manager", + srcs = [ + "qnn_manager.cc", + ], + hdrs = ["qnn_manager.h"], + include_system = True, + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + ungrte = True, + deps = [ + ":common", + ":qnn_log", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/core:dynamic_loading", + ], +) + +litert_test( + name = "qnn_manager_test", + srcs = ["qnn_manager_test.cc"], + linkstatic = True, + tags = [ + # Tests with ungrte deps do not currently work on forge. + "no-remote-exec", + "notap", + # Don't build/test in OS until qnn is available. + "nobuilder", + "no_oss", + # Sanitizer runtime doesn't work with anything that loads libQnnHtp.so. 
+ "nosan", + ], + deps = [ + ":qnn_manager", + "//tflite/experimental/litert/test:common", + "//tflite/experimental/litert/vendors/qualcomm/tools:dump", + ], +) + +cc_library( + name = "context_binary_info", + srcs = ["context_binary_info.cc"], + hdrs = ["context_binary_info.h"], + deps = [ + ":qnn_manager", + ":qnn_tensor", + "@com_google_absl//absl/strings:string_view", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/cc:litert_expected", + ], +) + +cc_library( + name = "qnn_tensor", + srcs = ["qnn_tensor.cc"], + hdrs = ["qnn_tensor.h"], + deps = [ + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings:string_view", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/cc:litert_expected", + ], +) diff --git a/tflite/experimental/litert/vendors/qualcomm/common.h b/tflite/experimental/litert/vendors/qualcomm/common.h new file mode 100644 index 00000000..a327d085 --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/common.h @@ -0,0 +1,100 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMMON_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMMON_H_ + +#include "third_party/qairt/latest/include/QNN/QnnCommon.h" +#include "third_party/qairt/latest/include/QNN/QnnInterface.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "third_party/qairt/latest/include/QNN/System/QnnSystemInterface.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_model.h" +#include "tflite/experimental/litert/cc/litert_model.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +#define LITERT_RETURN_STATUS_IF_QNN_NOT_OK(expr) \ + if (QNN_SUCCESS != (expr)) { \ + return kLiteRtStatusErrorNotFound; \ + } + +// Pointers to functions of a dynamically loaded QNN library. +typedef QNN_INTERFACE_VER_TYPE QnnApi; + +// Pointers to functions of a dynamically loaded QNN system library. +typedef QNN_SYSTEM_INTERFACE_VER_TYPE QnnSystemApi; + +// QNN backend library should be on DT_RUNPATH (-rpath). +static const char kLibQnnHtpSo[] = "libQnnHtp.so"; + +// QNN backend library should be on DT_RUNPATH (-rpath). +static const char kLibQnnSystemSo[] = "libQnnSystem.so"; + +// Map LiteRT element type to Qnn counterpart. 
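+// For example (an illustrative call, not part of this change):
+//
+//   Qnn_DataType_t qnn_type;
+//   if (LegalizeElementType(litert::ElementType::Float32, &qnn_type) ==
+//       kLiteRtStatusOk) {
+//     // qnn_type now holds QNN_DATATYPE_FLOAT_32.
+//   }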
+inline LiteRtStatus LegalizeElementType(litert::ElementType litert_type, + Qnn_DataType_t* qnn_type) { + switch (litert_type) { + case litert::ElementType::Bool: + *qnn_type = QNN_DATATYPE_BOOL_8; + break; + case litert::ElementType::Int4: + *qnn_type = QNN_DATATYPE_SFIXED_POINT_4; + break; + case litert::ElementType::Int8: + *qnn_type = QNN_DATATYPE_INT_8; + break; + case litert::ElementType::Int16: + *qnn_type = QNN_DATATYPE_INT_16; + break; + case litert::ElementType::Int32: + *qnn_type = QNN_DATATYPE_INT_32; + break; + case litert::ElementType::Int64: + *qnn_type = QNN_DATATYPE_INT_64; + break; + case litert::ElementType::UInt8: + *qnn_type = QNN_DATATYPE_UINT_8; + break; + case litert::ElementType::UInt16: + *qnn_type = QNN_DATATYPE_UINT_16; + break; + case litert::ElementType::UInt32: + *qnn_type = QNN_DATATYPE_UINT_32; + break; + case litert::ElementType::UInt64: + *qnn_type = QNN_DATATYPE_UINT_64; + break; + case litert::ElementType::Float16: + *qnn_type = QNN_DATATYPE_FLOAT_16; + break; + case litert::ElementType::Float32: + *qnn_type = QNN_DATATYPE_FLOAT_32; + break; + case litert::ElementType::Float64: + *qnn_type = QNN_DATATYPE_FLOAT_64; + break; + default: + return kLiteRtStatusErrorUnsupported; + } + return kLiteRtStatusOk; +} + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMMON_H_ diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/BUILD b/tflite/experimental/litert/vendors/qualcomm/compiler/BUILD new file mode 100644 index 00000000..173030ad --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/BUILD @@ -0,0 +1,170 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//tflite/experimental/litert/build_common:litert_build_defs.bzl", "litert_dynamic_lib", "litert_lib", "litert_test") + +package( + # copybara:uncomment default_applicable_licenses = ["@org_tensorflow//tensorflow:license"], + default_visibility = ["//visibility:private"], +) + +litert_dynamic_lib( + name = "qnn_compiler_plugin", + srcs = ["qnn_compiler_plugin.cc"], + hdrs = ["//tflite/experimental/litert/vendors/c:litert_compiler_plugin.h"], + export_litert_only = True, + shared_lib_name = "qnn_compiler_plugin_so", + so_name = "libLiteRtCompilerPlugin_Qualcomm.so", + tags = [ + # Don't build/test in OS until qnn is available. 
+ "nobuilder", + ], + ungrte = True, + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + ":qnn_compose_graph", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/vendors/qualcomm:qnn_manager", + ], +) + +litert_test( + name = "qnn_compiler_plugin_test", + srcs = [ + "qnn_compiler_plugin_test.cc", + ], + data = [ + "//tflite/experimental/litert/test:mlir_test_data", + "//tflite/experimental/litert/test:tflite_test_data", + ], + linkstatic = True, + tags = [ + # Tests with ungrte deps do not currently work on forge. + "no-remote-exec", + "notap", + # Don't build/test in OS until qnn is available. + "nobuilder", + "no_oss", + # Sanitizer runtime doesn't work with anything that loads libQnnHtp.so. + "nosan", + ], + use_sys_malloc = True, + deps = [ + ":qnn_compiler_plugin", # buildcleaner: keep + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings:string_view", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/core/model", + "//tflite/experimental/litert/test:common", + "//tflite/experimental/litert/test:test_macros", + "//tflite/experimental/litert/test:test_models", + "//tflite/experimental/litert/vendors/cc:litert_compiler_plugin", + "//tflite/experimental/litert/vendors/qualcomm:qnn_manager", + ], +) + +litert_lib( + name = "qnn_compose_graph", + srcs = ["qnn_compose_graph.cc"], + hdrs = ["qnn_compose_graph.h"], + tags = [ + # Don't build/test in OS until qnn is available. 
+ "nobuilder", + ], + deps = [ + ":graph_mapper", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/vendors/qualcomm:common", + "//tflite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + "//tflite/experimental/litert/vendors/qualcomm/compiler/legalizations:add_op_legalization", + "//tflite/experimental/litert/vendors/qualcomm/compiler/legalizations:batch_matmul_op_legalization", + "//tflite/experimental/litert/vendors/qualcomm/compiler/legalizations:cast_op_legalization", + "//tflite/experimental/litert/vendors/qualcomm/compiler/legalizations:concatenation_op_legalization", + "//tflite/experimental/litert/vendors/qualcomm/compiler/legalizations:cos_op_legalization", + "//tflite/experimental/litert/vendors/qualcomm/compiler/legalizations:div_op_legalization", + "//tflite/experimental/litert/vendors/qualcomm/compiler/legalizations:embedding_lookup_op_legalization", + "//tflite/experimental/litert/vendors/qualcomm/compiler/legalizations:fully_connected_op_legalization", + "//tflite/experimental/litert/vendors/qualcomm/compiler/legalizations:greater_op_legalization", + "//tflite/experimental/litert/vendors/qualcomm/compiler/legalizations:legalization", + "//tflite/experimental/litert/vendors/qualcomm/compiler/legalizations:less_op_legalization", + "//tflite/experimental/litert/vendors/qualcomm/compiler/legalizations:logical_and_op_legalization", + "//tflite/experimental/litert/vendors/qualcomm/compiler/legalizations:mul_op_legalization", + "//tflite/experimental/litert/vendors/qualcomm/compiler/legalizations:reshape_op_legalization", + "//tflite/experimental/litert/vendors/qualcomm/compiler/legalizations:rsqrt_op_legalization", + "//tflite/experimental/litert/vendors/qualcomm/compiler/legalizations:select_op_legalization", + "//tflite/experimental/litert/vendors/qualcomm/compiler/legalizations:sin_op_legalization", + "//tflite/experimental/litert/vendors/qualcomm/compiler/legalizations:slice_op_legalization", + "//tflite/experimental/litert/vendors/qualcomm/compiler/legalizations:softmax_op_legalization", + "//tflite/experimental/litert/vendors/qualcomm/compiler/legalizations:sub_op_legalization", + "//tflite/experimental/litert/vendors/qualcomm/compiler/legalizations:sum_op_legalization", + "//tflite/experimental/litert/vendors/qualcomm/compiler/legalizations:tanh_op_legalization", + "//tflite/experimental/litert/vendors/qualcomm/compiler/legalizations:transpose_op_legalization", + ], +) + +litert_lib( + name = "graph_mapper", + srcs = [ + "graph_mapper.cc", + ], + hdrs = ["graph_mapper.h"], + tags = [ + # Don't build/test in OS until qnn is available. 
+ "nobuilder", + ], + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/cc:litert_element_type", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/vendors/qualcomm:common", + "//tflite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/IR/BUILD b/tflite/experimental/litert/vendors/qualcomm/compiler/IR/BUILD new file mode 100644 index 00000000..fd4859aa --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/IR/BUILD @@ -0,0 +1,123 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # copybara:uncomment default_applicable_licenses = ["@org_tensorflow//tensorflow:license"], + default_visibility = ["//tflite/experimental/litert/vendors/qualcomm/compiler:__subpackages__"], +) + +cc_library( + name = "qnn_tensor", + srcs = ["qnn_tensor.cc"], + hdrs = ["qnn_tensor.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + deps = [ + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/types:span", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/vendors/qualcomm:common", + ], +) + +cc_test( + name = "qnn_tensor_test", + srcs = ["qnn_tensor_test.cc"], + data = [ + "//tflite/experimental/litert/test:mlir_test_data", + "//tflite/experimental/litert/test:tflite_test_data", + ], + tags = [ + # Don't build/test in OS until qnn is available. 
+ "nobuilder", + "no_oss", + ], + deps = [ + ":qnn_tensor", + "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/types:span", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/test:common", + "//tflite/experimental/litert/test:test_macros", + "//tflite/experimental/litert/test:test_models", + ], +) + +cc_library( + name = "qnn_op", + srcs = ["qnn_op.cc"], + hdrs = ["qnn_op.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + deps = [ + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/cc:litert_model", + ], +) + +cc_test( + name = "qnn_op_test", + srcs = ["qnn_op_test.cc"], + data = ["//tflite/experimental/litert/test:mlir_test_data"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + "no_oss", + ], + deps = [ + ":qnn_op", + "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/strings", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/test:common", + ], +) + +cc_test( + name = "op_compatibility_test", + srcs = ["op_compatibility_test.cc"], + data = ["//tflite/experimental/litert/test:mlir_test_data"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + "no_oss", + ], + deps = [ + ":qnn_op", + "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/test:common", + ], +) diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/IR/op_compatibility_test.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/IR/op_compatibility_test.cc new file mode 100644 index 00000000..e8c6ed67 --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/IR/op_compatibility_test.cc @@ -0,0 +1,82 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
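+
+// Parameterized compatibility test: verifies that every op listed in
+// kSupportedOps below legalizes to a QNN op carrying the expected
+// "qti.aisw" package and type name.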
+
+#include <string>
+
+#include <gtest/gtest.h>
+#include "absl/strings/match.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
+#include "tflite/experimental/litert/cc/litert_model_predicates.h"
+#include "tflite/experimental/litert/test/common.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h"
+
+namespace {
+
+static constexpr absl::string_view kOpTpl = "simple_%s_op.tflite";
+struct OpInfo {
+  std::string op_name;
+  std::string expected_type_name;
+};
+
+// TODO: b/365299994 - Add "stablehlo_scatter" once multiple subgraphs are
+// supported.
+// clang-format off
+const auto kSupportedOps = testing::Values(
+    OpInfo{"add", "ElementWiseAdd"},
+    OpInfo{"mul", "ElementWiseMultiply"},
+    OpInfo{"batch_matmul", "MatMul"},
+    OpInfo{"concatenation", "Concat"},
+    OpInfo{"div", "ElementWiseDivide"},
+    OpInfo{"fully_connected", "FullyConnected"},
+    OpInfo{"reshape", "Reshape"},
+    OpInfo{"rsqrt", "ElementWiseRsqrt"},
+    OpInfo{"select_v2", "ElementWiseSelect"},
+    OpInfo{"select", "ElementWiseSelect"},
+    OpInfo{"strided_slice", "StridedSlice"},
+    OpInfo{"slice", "StridedSlice"},
+    OpInfo{"softmax", "Softmax"},
+    OpInfo{"sub", "ElementWiseSubtract"},
+    OpInfo{"tanh", "Tanh"},
+    OpInfo{"transpose", "Transpose"});
+// clang-format on
+
+class OpCompatibilityTest : public ::testing::TestWithParam<OpInfo> {};
+
+TEST_P(OpCompatibilityTest, SupportedOpsTest) {
+  auto test_params = GetParam();
+  std::string model_path = absl::StrFormat(kOpTpl, test_params.op_name);
+  auto model = litert::testing::LoadTestFileModel(model_path);
+  auto subgraph = model.MainSubgraph();
+  EXPECT_TRUE(subgraph);
+  auto ops = subgraph->Ops();
+
+  Qnn_OpConfig_t qnn_op = litert::qnn::BuildDefaultOp();
+  LITERT_ASSERT_STATUS_OK(litert::qnn::LegalizeOp(ops.front().Get(), qnn_op));
+
+  EXPECT_TRUE(absl::StrContains(qnn_op.v1.name, test_params.op_name));
+  EXPECT_STREQ(qnn_op.v1.packageName, "qti.aisw");
+  EXPECT_STREQ(qnn_op.v1.typeName, test_params.expected_type_name.c_str());
+
+  EXPECT_EQ(qnn_op.v1.numOfInputs, 0);
+  EXPECT_EQ(qnn_op.v1.numOfOutputs, 0);
+  EXPECT_EQ(qnn_op.v1.numOfParams, 0);
+
+  litert::qnn::ResetOp(qnn_op);
+}
+
+INSTANTIATE_TEST_SUITE_P(SupportedOpsTest, OpCompatibilityTest, kSupportedOps);
+
+}  // namespace
diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.cc
new file mode 100644
index 00000000..d36fc971
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.cc
@@ -0,0 +1,147 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_model.h" +#include "tflite/experimental/litert/c/litert_op_code.h" +#include "tflite/experimental/litert/cc/litert_model.h" + +// A macro dance to create a unique literal string given a prefix. +#define STRINGIFY(x) #x +#define QNN_OP_NAME(prefix) STRINGIFY(prefix##__COUNTER) + +namespace litert::qnn { + +namespace { + +// Maps "op-code" related information (name, packageName, typeName) from src +// to dest. +LiteRtStatus LegalizeOpType(const Op& src, Qnn_OpConfig_t& dest) { + switch (src.Code()) { + case kLiteRtOpCodeTflMul: + dest.v1.name = QNN_OP_NAME(mul_); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "ElementWiseMultiply"; + break; + case kLiteRtOpCodeTflAdd: + dest.v1.name = QNN_OP_NAME("add"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "ElementWiseAdd"; + break; + case kLiteRtOpCodeTflBatchMatmul: + dest.v1.name = QNN_OP_NAME("batch_matmul"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "MatMul"; + break; + case kLiteRtOpCodeTflConcatenation: + dest.v1.name = QNN_OP_NAME("concatenation"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "Concat"; + break; + case kLiteRtOpCodeTflDiv: + dest.v1.name = QNN_OP_NAME("div"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "ElementWiseDivide"; + break; + case kLiteRtOpCodeTflFullyConnected: + dest.v1.name = QNN_OP_NAME("fully_connected"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "FullyConnected"; + break; + case kLiteRtOpCodeTflReshape: + dest.v1.name = QNN_OP_NAME("reshape"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "Reshape"; + break; + case kLiteRtOpCodeTflRsqrt: + dest.v1.name = QNN_OP_NAME("rsqrt"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "ElementWiseRsqrt"; + break; + case kLiteRtOpCodeTflSelectV2: + dest.v1.name = QNN_OP_NAME("select_v2"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "ElementWiseSelect"; + break; + case kLiteRtOpCodeTflSelect: + dest.v1.name = QNN_OP_NAME("select"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "ElementWiseSelect"; + break; + case kLiteRtOpCodeTflStridedSlice: + dest.v1.name = QNN_OP_NAME("strided_slice"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "StridedSlice"; + break; + case kLiteRtOpCodeTflSlice: + dest.v1.name = QNN_OP_NAME("slice"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "StridedSlice"; + break; + case kLiteRtOpCodeTflSoftmax: + dest.v1.name = QNN_OP_NAME("softmax"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "Softmax"; + break; + case kLiteRtOpCodeTflSub: + dest.v1.name = QNN_OP_NAME("sub"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "ElementWiseSubtract"; + break; + case kLiteRtOpCodeTflTanh: + dest.v1.name = QNN_OP_NAME("tanh"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "Tanh"; + break; + case kLiteRtOpCodeTflTranspose: + dest.v1.name = QNN_OP_NAME("transpose"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "Transpose"; + break; + default: + return kLiteRtStatusErrorUnsupported; + } + return kLiteRtStatusOk; +} + +} // namespace + +Qnn_OpConfig_t BuildDefaultOp() { + Qnn_OpConfig_t op = QNN_OPCONFIG_INIT; + ResetOp(op); + return op; +} +Qnn_Param_t BuildDefaultParam() { + Qnn_Param_t param = QNN_PARAM_INIT; + ResetParam(param); + return 
param;
+}
+
+void ResetOp(Qnn_OpConfig_t& op) {
+  op = QNN_OPCONFIG_INIT;
+  op.version = QNN_OPCONFIG_VERSION_1;
+  op.v1 = QNN_OPCONFIG_V1_INIT;
+}
+
+void ResetParam(Qnn_Param_t& param) { param = QNN_PARAM_INIT; }
+
+LiteRtStatus LegalizeOp(LiteRtOp src, Qnn_OpConfig_t& dest) {
+  ResetOp(dest);
+  Op op(src);
+  return LegalizeOpType(op, dest);
+}
+
+}  // namespace litert::qnn
diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h b/tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h
new file mode 100644
index 00000000..d15a6cbf
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h
@@ -0,0 +1,53 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_IR_QNN_OP_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_IR_QNN_OP_H_
+
+#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+
+namespace litert::qnn {
+
+//
+// Initialize QNN Op.
+//
+
+// NOTE: Any referential data within a QNN Op
+// is allocated with "new" and must be explicitly cleaned up with ResetOp.
+
+// Construct a "blank" QNN Op.
+Qnn_OpConfig_t BuildDefaultOp();
+
+// Construct a "blank" QNN Param.
+Qnn_Param_t BuildDefaultParam();
+
+// Reset the given op, deallocating anything on the heap that it points to.
+void ResetOp(Qnn_OpConfig_t& op);
+
+// Reset the given param, deallocating anything on the heap that it points to.
+void ResetParam(Qnn_Param_t& param);
+
+//
+// Legalize LiteRt Op to Analogous QNN Construct.
+//
+
+// Map src op onto dest. Resets dest before doing anything. This only handles
+// attribute-like info. It does not set edges (in/out tensors).
+LiteRtStatus LegalizeOp(LiteRtOp src, Qnn_OpConfig_t& dest);
+
+}  // namespace litert::qnn
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_IR_QNN_OP_H_
diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op_test.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op_test.cc
new file mode 100644
index 00000000..0bd28ebd
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op_test.cc
@@ -0,0 +1,65 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
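+
+// Unit tests for the QNN op helpers: default construction via BuildDefaultOp
+// and legalization via LegalizeOp, covering one supported and one
+// unsupported op.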
+
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h"
+
+#include <gtest/gtest.h>
+#include "absl/strings/match.h"
+#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/cc/litert_model_predicates.h"
+#include "tflite/experimental/litert/test/common.h"
+
+namespace {
+
+TEST(TestInitQnnOp, BuildDefaultOp) {
+  Qnn_OpConfig_t op = litert::qnn::BuildDefaultOp();
+  ASSERT_EQ(op.version, QNN_OPCONFIG_VERSION_1);
+}
+
+TEST(TestLegalizeOp, SimpleSupportedOp) {
+  auto model = litert::testing::LoadTestFileModel("one_mul.tflite");
+  auto subgraph = model.MainSubgraph();
+  EXPECT_TRUE(subgraph);
+  auto ops = subgraph->Ops();
+
+  Qnn_OpConfig_t qnn_op = litert::qnn::BuildDefaultOp();
+  LITERT_ASSERT_STATUS_OK(litert::qnn::LegalizeOp(ops.front().Get(), qnn_op));
+
+  EXPECT_TRUE(absl::StrContains(qnn_op.v1.name, "mul"));
+  EXPECT_STREQ(qnn_op.v1.packageName, "qti.aisw");
+  EXPECT_STREQ(qnn_op.v1.typeName, "ElementWiseMultiply");
+
+  EXPECT_EQ(qnn_op.v1.numOfInputs, 0);
+  EXPECT_EQ(qnn_op.v1.numOfOutputs, 0);
+  EXPECT_EQ(qnn_op.v1.numOfParams, 0);
+
+  litert::qnn::ResetOp(qnn_op);
+}
+
+TEST(TestLegalizeOp, UnsupportedOp) {
+  auto model = litert::testing::LoadTestFileModel("simple_floor_mod_op.tflite");
+  auto subgraph = model.MainSubgraph();
+  EXPECT_TRUE(subgraph);
+  auto ops = subgraph->Ops();
+
+  Qnn_OpConfig_t qnn_op = litert::qnn::BuildDefaultOp();
+  LITERT_ASSERT_STATUS_HAS_CODE(
+      litert::qnn::LegalizeOp(ops.front().Get(), qnn_op),
+      kLiteRtStatusErrorUnsupported);
+
+  litert::qnn::ResetOp(qnn_op);
+}
+
+}  // namespace
diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.cc
new file mode 100644
index 00000000..51aa87be
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.cc
@@ -0,0 +1,246 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h"
+
+#include <cstdint>
+
+#include "absl/log/absl_check.h"
+#include "absl/types/span.h"
+#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_logging.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/cc/litert_macros.h"
+#include "tflite/experimental/litert/cc/litert_model.h"
+#include "tflite/experimental/litert/vendors/qualcomm/common.h"
+
+namespace litert::qnn {
+
+namespace {
+
+LiteRtStatus LegalizeShapeInfo(const litert::Layout& src, Qnn_Tensor_t& dest) {
+  LITERT_ENSURE_SUPPORTED(!src.HasStrides(), "Strides not yet supported");
+
+  dest.v2.rank = src.Rank();
+  // Ad-hoc fix: rank 0 tensor needs to be single element 1D tensor in QNN.
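+  // (For instance, a LiteRT scalar of shape {} becomes a QNN tensor of shape
+  // {1}; the dimension array allocated here is reclaimed by ResetTensor.)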
+ if (dest.v2.rank == 0) { + LITERT_LOG(LITERT_INFO, "Setting rank 0 tensor to single element tensor"); + dest.v2.rank = 1; + dest.v2.dimensions = new uint32_t[1]; + dest.v2.dimensions[0] = 1; + return kLiteRtStatusOk; + } + + dest.v2.dimensions = new uint32_t[dest.v2.rank]; + for (int i = 0; i < dest.v2.rank; ++i) { + const auto src_dim = src.Dimensions()[i]; + LITERT_ENSURE(src_dim >= 1, kLiteRtStatusErrorInvalidArgument, + "Cannot pass dim < 1 to QNN Tensor."); + + dest.v2.dimensions[i] = src.Dimensions()[i]; + } + return kLiteRtStatusOk; +} + +void FreeTensorDims(Qnn_Tensor_t& tensor) { + if (tensor.version == QNN_TENSOR_VERSION_2 && + tensor.v2.dimensions != nullptr) { + delete[] tensor.v2.dimensions; + tensor.v2.dimensions = nullptr; + tensor.v2.rank = 0; + } +} + +void FreePerChannelQuantization(Qnn_Tensor_t& tensor) { + if (tensor.v2.quantizeParams.quantizationEncoding == + QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { + delete[] tensor.v2.quantizeParams.axisScaleOffsetEncoding.scaleOffset; + tensor.v2.quantizeParams.axisScaleOffsetEncoding.scaleOffset = nullptr; + tensor.v2.quantizeParams.axisScaleOffsetEncoding.numScaleOffsets = 0; + } +} + +} // namespace + +void SetInputTensorAttrs(Qnn_Tensor_t& tensor) { + ABSL_DCHECK(tensor.version == QNN_TENSOR_VERSION_2); + tensor.v2.type = QNN_TENSOR_TYPE_APP_WRITE; + tensor.v2.memType = QNN_TENSORMEMTYPE_RAW; + tensor.v2.clientBuf = QNN_CLIENT_BUFFER_INIT; +} + +void SetOutputTensorAttrs(Qnn_Tensor_t& tensor) { + ABSL_DCHECK(tensor.version == QNN_TENSOR_VERSION_2); + tensor.v2.type = QNN_TENSOR_TYPE_APP_READ; +} + +void SetResultTensorAttrs(Qnn_Tensor_t& tensor) { + ABSL_DCHECK(tensor.version == QNN_TENSOR_VERSION_2); + tensor.v2.memType = QNN_TENSORMEMTYPE_RAW; + tensor.v2.type = QNN_TENSOR_TYPE_NATIVE; +} + +void ResetTensor(Qnn_Tensor_t& tensor) { + FreeTensorDims(tensor); + FreePerChannelQuantization(tensor); + tensor = QNN_TENSOR_INIT; + tensor.version = QNN_TENSOR_VERSION_2; + tensor.v2 = QNN_TENSOR_V2_INIT; + tensor.v2.dataFormat = QNN_TENSOR_DATA_FORMAT_DENSE; + tensor.v2.memType = QNN_TENSORMEMTYPE_RAW; +} + +Qnn_Tensor_t BuildDefaultTensor(uint32_t id) { + Qnn_Tensor_t tensor = QNN_TENSOR_INIT; + ResetTensor(tensor); + tensor.v2.id = id; + return tensor; +} + +Qnn_Tensor_t BuildDefaultTensor() { return BuildDefaultTensor(0); } + +Qnn_Tensor_t BuildInputTensor() { + auto tensor = BuildDefaultTensor(); + SetInputTensorAttrs(tensor); + return tensor; +} + +Qnn_ClientBuffer_t BuildDefaultClientBuffer() { + Qnn_ClientBuffer_t client_buf = QNN_CLIENT_BUFFER_INIT; + client_buf.data = nullptr; + client_buf.dataSize = 0; + return client_buf; +} + +Qnn_Tensor_t BuildOutputTensor() { + Qnn_Tensor_t tensor = BuildDefaultTensor(); + SetOutputTensorAttrs(tensor); + return tensor; +} + +uint32_t MoveToId(Qnn_Tensor_t& tensor) { + const auto id = tensor.v2.id; + ResetTensor(tensor); + tensor.v2.id = id; + return id; +} + +void SetPerChannelQuantization( + Qnn_Tensor_t& tensor, + const LiteRtQuantizationPerChannel& lite_rt_quantization_per_channel) { + tensor.v2.quantizeParams.quantizationEncoding = + QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET; + + tensor.v2.quantizeParams.axisScaleOffsetEncoding = QNN_AXIS_SCALE_OFFSET_INIT; + tensor.v2.quantizeParams.axisScaleOffsetEncoding.axis = + lite_rt_quantization_per_channel.quantized_dimension; + tensor.v2.quantizeParams.axisScaleOffsetEncoding.numScaleOffsets = + lite_rt_quantization_per_channel.num_channels; + + // Allocates memory for scaleOffset array. 
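+  // (The array allocated below is owned by the tensor and released by
+  // FreePerChannelQuantization(), which ResetTensor() invokes before
+  // reinitializing the tensor.)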
+  tensor.v2.quantizeParams.axisScaleOffsetEncoding.scaleOffset =
+      new Qnn_ScaleOffset_t[lite_rt_quantization_per_channel.num_channels];
+
+  for (int i = 0; i < lite_rt_quantization_per_channel.num_channels; ++i) {
+    tensor.v2.quantizeParams.axisScaleOffsetEncoding.scaleOffset[i].scale =
+        lite_rt_quantization_per_channel.scales[i];
+    tensor.v2.quantizeParams.axisScaleOffsetEncoding.scaleOffset[i].offset =
+        lite_rt_quantization_per_channel.zero_points[i];
+  }
+}
+
+void SetPerTensorQuantization(
+    Qnn_Tensor_t& tensor,
+    const LiteRtQuantizationPerTensor& lite_rt_quantization_per_tensor) {
+  tensor.v2.quantizeParams.quantizationEncoding =
+      QNN_QUANTIZATION_ENCODING_SCALE_OFFSET;
+  tensor.v2.quantizeParams.scaleOffsetEncoding.scale =
+      lite_rt_quantization_per_tensor.scale;
+  tensor.v2.quantizeParams.scaleOffsetEncoding.offset =
+      lite_rt_quantization_per_tensor.zero_point;
+}
+
+LiteRtStatus LegalizeQuantizationParameter(const litert::Tensor& src,
+                                           Qnn_Tensor_t& dest) {
+  LiteRtQuantizationTypeId lite_rt_quantization_type_id = src.QTypeId();
+  switch (lite_rt_quantization_type_id) {
+    case kLiteRtQuantizationPerTensor:
+      SetPerTensorQuantization(dest, src.PerTensorQuantization());
+      return kLiteRtStatusOk;
+    case kLiteRtQuantizationPerChannel:
+      SetPerChannelQuantization(dest, src.PerChannelQuantization());
+      return kLiteRtStatusOk;
+    default:
+      LITERT_LOG(LITERT_ERROR, "Unsupported quantization type.");
+      return kLiteRtStatusErrorInvalidArgument;
+  }
+}
+
+LiteRtStatus LegalizeTensor(const litert::Tensor& src, Qnn_Tensor_t& dest) {
+  if (src.TypeId() != kLiteRtRankedTensorType) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+
+  ResetTensor(dest);
+
+  if (src.HasQuantization()) {
+    LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeQuantizationParameter(src, dest));
+  }
+
+  Qnn_DataType_t* qnn_data_type = &dest.v2.dataType;
+  LITERT_RETURN_STATUS_IF_NOT_OK(
+      LegalizeElementType(src.RankedTensorType().ElementType(), qnn_data_type));
+
+  LITERT_RETURN_STATUS_IF_NOT_OK(
+      LegalizeShapeInfo(src.RankedTensorType().Layout(), dest));
+
+  const bool is_subgraph_in = src.IsSubgraphInput();
+  const bool is_subgraph_out = src.IsSubgraphOutput();
+  const bool is_constant = src.IsConstant();
+
+  LITERT_ENSURE(!(is_subgraph_in && is_subgraph_out),
+                kLiteRtStatusErrorInvalidArgument,
+                "Malformed tensor, cannot be both subgraph in and out.");
+  if (is_constant) {
+    LITERT_LOG(LITERT_INFO, "Adding constant tensor %s to qnn graph",
+               dest.v2.name);
+    LITERT_ENSURE(src.HasWeights(), kLiteRtStatusErrorInvalidLegalization,
+                  "Empty weights for constant tensor.");
+    Qnn_ClientBuffer_t client_buf = BuildDefaultClientBuffer();
+    client_buf.data = (void*)src.Weights().Bytes().data();
+    client_buf.dataSize = src.Weights().Bytes().size();
+    dest.v2.clientBuf = client_buf;
+    dest.v2.memType = QNN_TENSORMEMTYPE_RAW;
+    dest.v2.type = QNN_TENSOR_TYPE_STATIC;
+    dest.v2.isDynamicDimensions = nullptr;
+  }
+
+  if (is_subgraph_in) {
+    LITERT_LOG(LITERT_INFO, "Adding subgraph input tensor to qnn graph");
+    SetInputTensorAttrs(dest);
+  }
+  if (is_subgraph_out) {
+    LITERT_LOG(LITERT_INFO, "Adding subgraph output tensor to qnn graph");
+    SetOutputTensorAttrs(dest);
+  }
+  if (!is_constant && !is_subgraph_in && !is_subgraph_out) {
+    LITERT_LOG(LITERT_INFO, "Adding result tensor to qnn graph");
+    SetResultTensorAttrs(dest);
+  }
+
+  return kLiteRtStatusOk;
+}
+
+}  // namespace litert::qnn
diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h b/tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h
new file mode 100644
index 00000000..a98ebbf0
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h
@@ -0,0 +1,73 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_IR_QNN_TENSOR_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_IR_QNN_TENSOR_H_
+
+#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/cc/litert_model.h"
+
+namespace litert::qnn {
+
+//
+// Initialize QNN Tensors.
+//
+
+// NOTE: Within LiteRt land, all Qnn Tensors are treated as "v2". Any
+// referential data (like dimensions : uint32_t*) within a QNN Tensor
+// is allocated with "new" and must be explicitly cleaned up with ResetTensor.
+
+// Construct a "blank" QNN Tensor.
+Qnn_Tensor_t BuildDefaultTensor();
+
+// Construct a "blank" QNN Tensor with the given id.
+Qnn_Tensor_t BuildDefaultTensor(uint32_t id);
+
+// Construct a "blank" QNN Tensor meant to be used as a graph input.
+Qnn_Tensor_t BuildInputTensor();
+
+// Construct a "blank" QNN Tensor meant to be used as a graph output.
+Qnn_Tensor_t BuildOutputTensor();
+
+Qnn_ClientBuffer_t BuildDefaultClientBuffer();
+
+// Adds attributes to the given tensor making it amenable for use as a graph
+// input.
+void SetInputTensorAttrs(Qnn_Tensor_t& tensor);
+
+// Adds attributes to the given tensor making it amenable for use as a graph
+// output.
+void SetOutputTensorAttrs(Qnn_Tensor_t& tensor);
+
+// Adds attributes to the given tensor making it amenable for use as an
+// intermediate output.
+void SetResultTensorAttrs(Qnn_Tensor_t& tensor);
+
+// Reset the given tensor, deallocating anything on the heap that it points to.
+void ResetTensor(Qnn_Tensor_t& tensor);
+
+// Resets all fields other than id in the given tensor and returns the id for
+// convenience. Only the id is needed to traffic QNN Tensors after they have
+// been registered with the context.
+uint32_t MoveToId(Qnn_Tensor_t& tensor);
+
+//
+// Legalize LiteRt Tensors to Analogous QNN Construct.
+//
+
+// Map src tensor onto dest. Resets dest before doing anything.
+LiteRtStatus LegalizeTensor(const litert::Tensor& src, Qnn_Tensor_t& dest);
+
+}  // namespace litert::qnn
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_IR_QNN_TENSOR_H_
diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor_test.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor_test.cc
new file mode 100644
index 00000000..0f4b9a59
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor_test.cc
@@ -0,0 +1,203 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/types/span.h"
+#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/cc/litert_model.h"
+#include "tflite/experimental/litert/test/common.h"
+#include "tflite/experimental/litert/test/test_macros.h"
+#include "tflite/experimental/litert/test/test_models.h"
+
+namespace {
+
+constexpr float kSimpleMulQuantModelOutputScale = 0.00028621565f;
+constexpr float kSimpleMulQuantModelOutputOffset = 0;
+
+TEST(TestInitQnnTensor, BuildDefaultTensor) {
+  Qnn_Tensor_t tensor = litert::qnn::BuildDefaultTensor();
+  ASSERT_EQ(tensor.version, QNN_TENSOR_VERSION_2);
+  EXPECT_EQ(tensor.v2.dataFormat, QNN_TENSOR_DATA_FORMAT_DENSE);
+  EXPECT_EQ(tensor.v2.rank, 0);
+  EXPECT_EQ(tensor.v2.dimensions, nullptr);
+  EXPECT_EQ(tensor.v2.id, 0);
+}
+
+TEST(TestInitQnnTensor, BuildDefaultTensorWithId) {
+  Qnn_Tensor_t tensor = litert::qnn::BuildDefaultTensor(2);
+  ASSERT_EQ(tensor.version, QNN_TENSOR_VERSION_2);
+  EXPECT_EQ(tensor.v2.dataFormat, QNN_TENSOR_DATA_FORMAT_DENSE);
+  EXPECT_EQ(tensor.v2.rank, 0);
+  EXPECT_EQ(tensor.v2.dimensions, nullptr);
+  EXPECT_EQ(tensor.v2.id, 2);
+}
+
+TEST(TestInitQnnTensor, BuildDefaultInputTensor) {
+  Qnn_Tensor_t tensor = litert::qnn::BuildInputTensor();
+  ASSERT_EQ(tensor.version, QNN_TENSOR_VERSION_2);
+  EXPECT_EQ(tensor.v2.type, QNN_TENSOR_TYPE_APP_WRITE);
+  EXPECT_EQ(tensor.v2.memType, QNN_TENSORMEMTYPE_RAW);
+  EXPECT_EQ(tensor.v2.clientBuf.dataSize, 0);
+}
+
+TEST(TestInitQnnTensor, SetInputTensor) {
+  Qnn_Tensor_t tensor = litert::qnn::BuildDefaultTensor();
+  litert::qnn::SetInputTensorAttrs(tensor);
+  ASSERT_EQ(tensor.version, QNN_TENSOR_VERSION_2);
+  EXPECT_EQ(tensor.v2.type, QNN_TENSOR_TYPE_APP_WRITE);
+  EXPECT_EQ(tensor.v2.memType, QNN_TENSORMEMTYPE_RAW);
+  EXPECT_EQ(tensor.v2.clientBuf.dataSize, 0);
+}
+
+TEST(TestInitQnnTensor, BuildDefaultOutputTensor) {
+  Qnn_Tensor_t tensor = litert::qnn::BuildOutputTensor();
+  ASSERT_EQ(tensor.version, QNN_TENSOR_VERSION_2);
+  EXPECT_EQ(tensor.v2.type, QNN_TENSOR_TYPE_APP_READ);
+}
+
+TEST(TestInitQnnTensor, SetOutputTensor) {
+  Qnn_Tensor_t tensor = litert::qnn::BuildDefaultTensor();
+  litert::qnn::SetOutputTensorAttrs(tensor);
+  ASSERT_EQ(tensor.version, QNN_TENSOR_VERSION_2);
+  EXPECT_EQ(tensor.v2.type, QNN_TENSOR_TYPE_APP_READ);
+}
+
+TEST(TestInitQnnTensor, MoveToId) {
+  Qnn_Tensor_t tensor = litert::qnn::BuildDefaultTensor(2);
+
+  litert::qnn::SetOutputTensorAttrs(tensor);
+  ASSERT_EQ(tensor.version, QNN_TENSOR_VERSION_2);
+  EXPECT_EQ(tensor.v2.type, QNN_TENSOR_TYPE_APP_READ);
+
+  EXPECT_EQ(litert::qnn::MoveToId(tensor), 2);
+  EXPECT_EQ(tensor.v2.id, 2);
+  EXPECT_EQ(tensor.v2.type, QNN_TENSOR_TYPE_UNDEFINED);
+}
+
+TEST(TestLegalizeTensor, SimpleSupportedTensorSubgraphInput) {
+  auto model = litert::testing::LoadTestFileModel("one_mul.tflite");
+  auto subgraph = model.MainSubgraph();
+  EXPECT_TRUE(subgraph);
+  auto outputs = subgraph->Outputs();
+
+  auto qnn_tensor =
litert::qnn::BuildDefaultTensor(); + const auto& output_tensor = outputs.front(); + LITERT_ASSERT_STATUS_OK( + litert::qnn::LegalizeTensor(output_tensor, qnn_tensor)); + + ASSERT_EQ(qnn_tensor.version, QNN_TENSOR_VERSION_2); + EXPECT_EQ(qnn_tensor.v2.dataType, QNN_DATATYPE_FLOAT_32); + EXPECT_EQ(qnn_tensor.v2.type, QNN_TENSOR_TYPE_APP_READ); + + ASSERT_EQ(qnn_tensor.v2.rank, 2); + ASSERT_NE(qnn_tensor.v2.dimensions, nullptr); + EXPECT_THAT(absl::MakeConstSpan(qnn_tensor.v2.dimensions, 2), + ::testing::ElementsAreArray({2, 2})); + + litert::qnn::ResetTensor(qnn_tensor); +} + +TEST(TestLegalizeTensor, SimpleSupportedTensor) { + auto model = litert::testing::LoadTestFileModel("simple_multi_op.tflite"); + + auto subgraph = model.MainSubgraph(); + EXPECT_TRUE(subgraph); + auto ops = subgraph->Ops(); + auto op_outs = ops.at(1).Outputs(); + + auto qnn_tensor = litert::qnn::BuildDefaultTensor(); + const auto& op_out = op_outs.front(); + LITERT_ASSERT_STATUS_OK(litert::qnn::LegalizeTensor(op_out, qnn_tensor)); + + ASSERT_EQ(qnn_tensor.version, QNN_TENSOR_VERSION_2); + EXPECT_EQ(qnn_tensor.v2.dataType, QNN_DATATYPE_FLOAT_32); + EXPECT_EQ(qnn_tensor.v2.type, QNN_TENSOR_TYPE_NATIVE); + + ASSERT_EQ(qnn_tensor.v2.rank, 2); + ASSERT_NE(qnn_tensor.v2.dimensions, nullptr); + EXPECT_THAT(absl::MakeConstSpan(qnn_tensor.v2.dimensions, 2), + ::testing::ElementsAreArray({2, 2})); + + litert::qnn::ResetTensor(qnn_tensor); +} + +TEST(TestLegalizeTensor, SimpleQuantizedTensor) { + auto model = litert::testing::LoadTestFileModel(kQSimpleMul16x16Model); + + auto subgraph = model.MainSubgraph(); + EXPECT_TRUE(subgraph); + auto ops = subgraph->Ops(); + auto op_outs = ops.at(0).Outputs(); + + auto qnn_tensor = litert::qnn::BuildDefaultTensor(); + const auto& op_out = op_outs.front(); + LITERT_ASSERT_STATUS_OK(litert::qnn::LegalizeTensor(op_out, qnn_tensor)); + + ASSERT_EQ(qnn_tensor.version, QNN_TENSOR_VERSION_2); + EXPECT_EQ(qnn_tensor.v2.dataType, QNN_DATATYPE_INT_16); + EXPECT_EQ(qnn_tensor.v2.type, QNN_TENSOR_TYPE_APP_READ); + + ASSERT_EQ(qnn_tensor.v2.quantizeParams.quantizationEncoding, + QNN_QUANTIZATION_ENCODING_SCALE_OFFSET); + ASSERT_FLOAT_EQ(qnn_tensor.v2.quantizeParams.scaleOffsetEncoding.scale, + kSimpleMulQuantModelOutputScale); + + ASSERT_FLOAT_EQ(qnn_tensor.v2.quantizeParams.scaleOffsetEncoding.offset, + kSimpleMulQuantModelOutputOffset); + litert::qnn::ResetTensor(qnn_tensor); +} + +TEST(TestLegalizeTensor, PerChannelQuantizedTensor) { + auto model = litert::testing::LoadTestFileModel(kQKeyEinsum16x8Model); + + auto subgraph = model.MainSubgraph(); + EXPECT_TRUE(subgraph); + auto ops = subgraph->Ops(); + auto op_ins = ops.at(1).Inputs(); + + auto qnn_tensor = litert::qnn::BuildDefaultTensor(); + const auto& per_channel_quant_tensor = op_ins[1]; + LITERT_ASSERT_STATUS_OK( + litert::qnn::LegalizeTensor(per_channel_quant_tensor, qnn_tensor)); + + EXPECT_EQ(qnn_tensor.v2.dataType, QNN_DATATYPE_INT_8); + + LiteRtQuantizationPerChannel per_channel_quant_params = + per_channel_quant_tensor.PerChannelQuantization(); + + ASSERT_EQ(qnn_tensor.v2.quantizeParams.quantizationEncoding, + QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET); + EXPECT_EQ(qnn_tensor.v2.quantizeParams.axisScaleOffsetEncoding.axis, + per_channel_quant_params.quantized_dimension); + EXPECT_EQ( + qnn_tensor.v2.quantizeParams.axisScaleOffsetEncoding.numScaleOffsets, + per_channel_quant_params.num_channels); + for (int i = 0; i < per_channel_quant_params.num_channels; ++i) { + ASSERT_FLOAT_EQ( + 
qnn_tensor.v2.quantizeParams.axisScaleOffsetEncoding.scaleOffset[i]
+            .scale,
+        per_channel_quant_params.scales[i]);
+    ASSERT_EQ(
+        qnn_tensor.v2.quantizeParams.axisScaleOffsetEncoding.scaleOffset[i]
+            .offset,
+        per_channel_quant_params.zero_points[i]);
+  }
+  litert::qnn::ResetTensor(qnn_tensor);
+}
+
+}  // namespace
diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.cc
new file mode 100644
index 00000000..8c05b4cb
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.cc
@@ -0,0 +1,163 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
+
+#include <stdio.h>
+
+#include <cstdint>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "third_party/qairt/latest/include/QNN/HTP/QnnHtpGraph.h"
+#include "third_party/qairt/latest/include/QNN/QnnCommon.h"
+#include "third_party/qairt/latest/include/QNN/QnnGraph.h"
+#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_logging.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/cc/litert_element_type.h"
+#include "tflite/experimental/litert/cc/litert_macros.h"
+#include "tflite/experimental/litert/cc/litert_model.h"
+#include "tflite/experimental/litert/vendors/qualcomm/common.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h"
+#include "tflite/experimental/litert/vendors/qualcomm/qnn_manager.h"
+
+namespace litert::qnn {
+
+// Graph configs to apply when building the QNN graph. Subgraphs with fp32
+// inputs or outputs are built at fp16 precision on HTP.
+inline absl::Span<const QnnGraph_Config_t*> GetFp32GraphConfigs() {
+  static QnnHtpGraph_CustomConfig_t htp_graph_config =
+      QNN_HTP_GRAPH_CUSTOM_CONFIG_INIT;
+  htp_graph_config.option = QNN_HTP_GRAPH_CONFIG_OPTION_PRECISION;
+  htp_graph_config.precision = QNN_PRECISION_FLOAT16;
+
+  static QnnGraph_Config_t graph_config = QNN_GRAPH_CONFIG_INIT;
+  graph_config.option = QNN_GRAPH_CONFIG_OPTION_CUSTOM;
+  graph_config.customConfig = &htp_graph_config;
+
+  static const QnnGraph_Config_t* configs[2] = {&graph_config, nullptr};
+  return absl::MakeSpan(configs);
+}
+
+inline absl::Span<const QnnGraph_Config_t*> GetDefaultGraphConfigs() {
+  static const QnnGraph_Config_t* configs[] = {nullptr};
+  return absl::MakeSpan(configs);
+}
+
+absl::Span<const QnnGraph_Config_t*> GraphMapper::PickGraphConfigHeuristic() {
+  for (const auto& input : subgraph_.Inputs()) {
+    if (input.RankedTensorType().ElementType() == ElementType::Float32) {
+      return GetFp32GraphConfigs();
+    }
+  }
+  for (const auto& output : subgraph_.Outputs()) {
+    if (output.RankedTensorType().ElementType() == ElementType::Float32) {
+      return GetFp32GraphConfigs();
+    }
+  }
+  return GetDefaultGraphConfigs();
+}
+
+LiteRtStatus GraphMapper::AssignTensorName(Qnn_Tensor_t& qnn_tensor) {
+  char* name = nullptr;
+  const int written = asprintf(&name, "Tensor_%d", cur_tensor_num_++);
+  LITERT_ENSURE(written != -1 && name != nullptr, kLiteRtStatusErrorNotFound,
+                "Failed to make tensor name");
+  qnn_tensor.v2.name = name;
+  return kLiteRtStatusOk;
+}
+
+absl::flat_hash_map<LiteRtTensor, uint32_t>& GraphMapper::CurrentScope() {
+  return current_scope_;
+}
+
+LiteRtStatus GraphMapper::LookupInScope(LiteRtTensor litert_tensor,
+                                        Qnn_Tensor_t& qnn_tensor) {
+  const auto qnn_id = CurrentScope().find(litert_tensor);
+  // A lookup miss means the tensor is a constant tensor that has not yet been
+  // added to the QNN graph; legalize and register it on the fly.
+  if (qnn_id == CurrentScope().end()) {
+    LITERT_LOG(LITERT_INFO, "Adding constant tensor %s to qnn graph",
+               qnn_tensor.v2.name);
+    LITERT_RETURN_STATUS_IF_NOT_OK(
+        LegalizeAndRegister(litert_tensor, qnn_tensor));
+    LITERT_RETURN_STATUS_IF_NOT_OK(PushToScope(litert_tensor, qnn_tensor));
+    return kLiteRtStatusOk;
+  }
+  LITERT_LOG(LITERT_INFO, "Found tensor %d in current_scope.", qnn_id->second);
+  ResetTensor(qnn_tensor);
+  qnn_tensor.v2.id = qnn_id->second;
+
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus GraphMapper::PushToScope(LiteRtTensor litert_tensor,
+                                      Qnn_Tensor_t& qnn_tensor) {
+  CurrentScope()[litert_tensor] = MoveToId(qnn_tensor);
+  return kLiteRtStatusOk;
+}
+
+QnnManager& GraphMapper::Qnn() { return qnn_; }
+
+Qnn_GraphHandle_t& GraphMapper::QnnGraph() { return qnn_graph_; }
+
+LiteRtStatus GraphMapper::LegalizeAndRegister(LiteRtTensor litert_tensor,
+                                              Qnn_Tensor_t& qnn_tensor) {
+  litert::Tensor tensor(litert_tensor);
+  LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeTensor(tensor, qnn_tensor));
+  LITERT_RETURN_STATUS_IF_NOT_OK(AssignTensorName(qnn_tensor));
+
+  // Mark the tensor as a graph output if it is one of the subgraph's outputs.
+  if (graph_outputs_.contains(litert_tensor)) {
+    LITERT_LOG(LITERT_INFO, "Setting tensor %d as Graph output",
+               qnn_tensor.v2.id);
+    qnn_tensor.v2.type = QNN_TENSOR_TYPE_APP_READ;
+  }
+
+  LITERT_RETURN_STATUS_IF_QNN_NOT_OK(
+      qnn_.Api()->tensorCreateGraphTensor(QnnGraph(), &qnn_tensor));
+
+  LITERT_LOG(LITERT_INFO, "Legalized and registered tensor %d",
+             qnn_tensor.v2.id);
+
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus GraphMapper::IsLiteRtSubgraphSupported() {
+  // For now, we assume all LiteRt subgraphs are supported.
+  // TODO: b/381133565: Implement or remove this function.
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus GraphMapper::InitQnnGraph(absl::string_view qnn_graph_name) {
+  LITERT_RETURN_STATUS_IF_QNN_NOT_OK(
+      qnn_.Api()->graphCreate(context_handle_, qnn_graph_name.data(),
+                              PickGraphConfigHeuristic().data(), &QnnGraph()));
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus GraphMapper::Finalize() {
+  LITERT_RETURN_STATUS_IF_QNN_NOT_OK(
+      qnn_.Api()->graphFinalize(QnnGraph(), nullptr, nullptr));
+  return kLiteRtStatusOk;
+}
+
+}  // namespace litert::qnn
diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h b/tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h
new file mode 100644
index 00000000..61b8b7e3
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h
@@ -0,0 +1,121 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_GRAPH_MAPPER_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_GRAPH_MAPPER_H_
+
+#include <cstdint>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "third_party/qairt/latest/include/QNN/QnnCommon.h"
+#include "third_party/qairt/latest/include/QNN/QnnGraph.h"
+#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/cc/litert_model.h"
+#include "tflite/experimental/litert/vendors/qualcomm/qnn_manager.h"
+
+namespace litert::qnn {
+
+// Algorithm class for managing "scope" when mapping litert Subgraphs
+// to QNN Graphs.
+class GraphMapper {
+ public:
+  GraphMapper(LiteRtSubgraph subgraph, QnnManager& qnn,
+              Qnn_ContextHandle_t context_handle)
+      : subgraph_(Subgraph(subgraph)),
+        qnn_(qnn),
+        context_handle_(context_handle) {}
+
+  // Legalize the given LiteRtTensor's attributes into a QNN Tensor and
+  // register it with the QNN context. The resulting QNN Tensor is empty
+  // except for the canonical id assigned by the QNN API.
+  LiteRtStatus LegalizeAndRegister(LiteRtTensor litert_tensor,
+                                   Qnn_Tensor_t& qnn_tensor);
+
+  // Find the ID associated with an evaluated litert Tensor and add it to the
+  // given QNN Tensor.
+  LiteRtStatus LookupInScope(LiteRtTensor litert_tensor,
+                             Qnn_Tensor_t& qnn_tensor);
+
+  // Adds a new mapping to scope. All fields other than the ID in the given
+  // QNN Tensor are cleared and its ID is added to "current_scope". Expects
+  // the QNN Tensor to have already been registered with the context.
+  LiteRtStatus PushToScope(LiteRtTensor litert_tensor,
+                           Qnn_Tensor_t& qnn_tensor);
+
+  // NOTE: QNN Tensors must be created with a unique name. This ensures
+  // uniqueness, though we will want more meaningful names in the future.
+  LiteRtStatus AssignTensorName(Qnn_Tensor_t& qnn_tensor);
+
+  // QNN Sdk Accessors
+  QnnManager& Qnn();
+  Qnn_GraphHandle_t& QnnGraph();
+
+  // CC Convenience Accessors
+  const Subgraph& Graph() const { return subgraph_; }
+
+  // Accessor for the current scope: since each QNN Tensor needs a globally
+  // unique name within a QNN context, we maintain a map from evaluated
+  // LiteRtTensors to their resolved QNN Tensor IDs. E.g., after
+  // LegalizeAndRegister(t, qnn_t) and PushToScope(t, qnn_t),
+  // CurrentScope()[t] holds the id that qnn_t was assigned.
+  absl::flat_hash_map<LiteRtTensor, uint32_t>& CurrentScope();
+
+  // Whether this implementation can handle the given LiteRtSubgraph topology.
+  // Currently assumes all subgraphs are supported (see the TODO in the
+  // implementation).
+  LiteRtStatus IsLiteRtSubgraphSupported();
+
+  // Initialize the QNN Graph with the given name. Call this after parsing the
+  // LiteRtSubgraph.
+  LiteRtStatus InitQnnGraph(absl::string_view qnn_graph_name);
+
+  // Finalize the QNN Graph. Call this after all ops have been mapped.
+  LiteRtStatus Finalize();
+
+  // Pick graph configs based on the subgraph.
+  absl::Span<const QnnGraph_Config_t*> PickGraphConfigHeuristic();
+
+  inline void RegisterOutput(LiteRtTensor litert_tensor) {
+    graph_outputs_.insert(litert_tensor);
+  }
+
+ private:
+  const Subgraph subgraph_;
+
+  // Set of all outputs of the graph.
+  absl::flat_hash_set<LiteRtTensor> graph_outputs_;
+
+  // Maps evaluated tensors to their resolved QNN Tensor ID.
+  absl::flat_hash_map<LiteRtTensor, uint32_t> current_scope_;
+
+  //
+  // QNN Sdk State
+  //
+  QnnManager& qnn_;
+  Qnn_ContextHandle_t context_handle_;
+  Qnn_GraphHandle_t qnn_graph_ = nullptr;
+
+  //
+  // Tensor Naming
+  //
+
+  uint32_t cur_tensor_num_ = 0;
+};
+
+}  // namespace litert::qnn
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_GRAPH_MAPPER_H_
diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/BUILD b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/BUILD
new file mode 100644
index 00000000..51cefec7
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/BUILD
@@ -0,0 +1,789 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
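+
+# Note on the target pattern below (an illustrative sketch, not an additional
+# target): every legalization in this package is declared with the same
+# litert_lib shape and dependency set, so a hypothetical new
+# `foo_op_legalization` would follow:
+#
+#   litert_lib(
+#       name = "foo_op_legalization",
+#       srcs = ["foo_op_legalization.cc"],
+#       hdrs = ["foo_op_legalization.h"],
+#       tags = ["nobuilder"],  # Don't build/test in OS until qnn is available.
+#       visibility = ["//tflite/experimental/litert:__subpackages__"],
+#       deps = [...],  # Same deps as the concrete targets below.
+#   )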
+ +load("//tflite/experimental/litert/build_common:litert_build_defs.bzl", "litert_lib") + +package( + # copybara:uncomment default_applicable_licenses = ["@org_tensorflow//tensorflow:license"], + default_visibility = ["//visibility:private"], +) + +litert_lib( + name = "legalization", + hdrs = ["legalization.h"], + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + ], +) + +litert_lib( + name = "add_op_legalization", + srcs = ["add_op_legalization.cc"], + hdrs = ["add_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/c:litert_options", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/vendors/qualcomm:common", + "//tflite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tflite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "batch_matmul_op_legalization", + srcs = ["batch_matmul_op_legalization.cc"], + hdrs = ["batch_matmul_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. 
+ "nobuilder", + ], + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/c:litert_options", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/vendors/qualcomm:common", + "//tflite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tflite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "cast_op_legalization", + srcs = ["cast_op_legalization.cc"], + hdrs = ["cast_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/c:litert_options", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/vendors/qualcomm:common", + "//tflite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tflite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "concatenation_op_legalization", + srcs = ["concatenation_op_legalization.cc"], + hdrs = ["concatenation_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. 
+ "nobuilder", + ], + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/c:litert_options", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/vendors/qualcomm:common", + "//tflite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tflite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "cos_op_legalization", + srcs = ["cos_op_legalization.cc"], + hdrs = ["cos_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/c:litert_options", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/vendors/qualcomm:common", + "//tflite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tflite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "div_op_legalization", + srcs = ["div_op_legalization.cc"], + hdrs = ["div_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. 
+ "nobuilder", + ], + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/c:litert_options", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/vendors/qualcomm:common", + "//tflite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tflite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "embedding_lookup_op_legalization", + srcs = ["embedding_lookup_op_legalization.cc"], + hdrs = ["embedding_lookup_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/c:litert_options", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/vendors/qualcomm:common", + "//tflite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tflite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "transpose_op_legalization", + srcs = ["transpose_op_legalization.cc"], + hdrs = ["transpose_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. 
+ "nobuilder", + ], + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/c:litert_options", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/vendors/qualcomm:common", + "//tflite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tflite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "fully_connected_op_legalization", + srcs = ["fully_connected_op_legalization.cc"], + hdrs = ["fully_connected_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/c:litert_options", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/vendors/qualcomm:common", + "//tflite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tflite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "greater_op_legalization", + srcs = ["greater_op_legalization.cc"], + hdrs = ["greater_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. 
+ "nobuilder", + ], + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/c:litert_options", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/vendors/qualcomm:common", + "//tflite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tflite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "less_op_legalization", + srcs = ["less_op_legalization.cc"], + hdrs = ["less_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/c:litert_options", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/vendors/qualcomm:common", + "//tflite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tflite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "logical_and_op_legalization", + srcs = ["logical_and_op_legalization.cc"], + hdrs = ["logical_and_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. 
+ "nobuilder", + ], + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/c:litert_options", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/vendors/qualcomm:common", + "//tflite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tflite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "mul_op_legalization", + srcs = ["mul_op_legalization.cc"], + hdrs = ["mul_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/c:litert_options", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/vendors/qualcomm:common", + "//tflite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tflite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "reshape_op_legalization", + srcs = ["reshape_op_legalization.cc"], + hdrs = ["reshape_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. 
+ "nobuilder", + ], + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/c:litert_options", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/vendors/qualcomm:common", + "//tflite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tflite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "rsqrt_op_legalization", + srcs = ["rsqrt_op_legalization.cc"], + hdrs = ["rsqrt_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/c:litert_options", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/vendors/qualcomm:common", + "//tflite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tflite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "sin_op_legalization", + srcs = ["sin_op_legalization.cc"], + hdrs = ["sin_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. 
+ "nobuilder", + ], + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/c:litert_options", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/vendors/qualcomm:common", + "//tflite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tflite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "select_op_legalization", + srcs = ["select_op_legalization.cc"], + hdrs = ["select_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/c:litert_options", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/vendors/qualcomm:common", + "//tflite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tflite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "slice_op_legalization", + srcs = ["slice_op_legalization.cc"], + hdrs = ["slice_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. 
+ "nobuilder", + ], + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/c:litert_options", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/vendors/qualcomm:common", + "//tflite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tflite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "sum_op_legalization", + srcs = ["sum_op_legalization.cc"], + hdrs = ["sum_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/c:litert_options", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/vendors/qualcomm:common", + "//tflite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tflite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "sub_op_legalization", + srcs = ["sub_op_legalization.cc"], + hdrs = ["sub_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. 
+ "nobuilder", + ], + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/c:litert_options", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/vendors/qualcomm:common", + "//tflite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tflite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "softmax_op_legalization", + srcs = ["softmax_op_legalization.cc"], + hdrs = ["softmax_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/c:litert_options", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/vendors/qualcomm:common", + "//tflite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tflite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "tanh_op_legalization", + srcs = ["tanh_op_legalization.cc"], + hdrs = ["tanh_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. 
+ "nobuilder", + ], + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_op_code", + "//tflite/experimental/litert/c:litert_options", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/vendors/qualcomm:common", + "//tflite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tflite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "util", + srcs = ["util.cc"], + hdrs = ["util.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/cc:litert_macros", + "//tflite/experimental/litert/cc:litert_model", + "//tflite/experimental/litert/cc:litert_model_predicates", + "//tflite/experimental/litert/tools:dump", + "//tflite/experimental/litert/vendors/qualcomm:common", + "//tflite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tflite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tflite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.cc new file mode 100644 index 00000000..d4905686 --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.cc @@ -0,0 +1,51 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
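+//
+// Usage sketch (illustrative only; `op`, `op_config`, and `graph_mapper` are
+// placeholder names, not part of this patch): a legalization is tried against
+// an op and either claims it or defers:
+//
+//   auto legalization = litert::qnn::AddOpLegalization::Create();
+//   Qnn_OpConfig_t op_config = QNN_OPCONFIG_INIT;
+//   const LiteRtStatus status =
+//       legalization->LegalizeOp(op, op_config, graph_mapper);
+//   // kLiteRtStatusLegalizeNoMatch: `op` is not a TFL add; try another
+//   // legalization. kLiteRtStatusOk: `op_config` now names an
+//   // "ElementWiseAdd" op in package "qti.aisw" with its tensors registered
+//   // on the graph.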
+
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.h"
+
+#include <string>
+
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_logging.h"
+#include "tflite/experimental/litert/c/litert_op_code.h"
+#include "tflite/experimental/litert/cc/litert_macros.h"
+#include "tflite/experimental/litert/cc/litert_model.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h"
+
+namespace litert::qnn {
+
+static constexpr absl::string_view kQnnAddOpTypeName = "ElementWiseAdd";
+static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw";
+static constexpr absl::string_view kAddOpFmt = "add_%d";
+
+LiteRtStatus AddOpLegalization::LegalizeOp(const litert::Op& src,
+                                           Qnn_OpConfig_t& dest,
+                                           GraphMapper& graph_mapper) {
+  if (src.Code() != kLiteRtOpCodeTflAdd) {
+    return kLiteRtStatusLegalizeNoMatch;
+  }
+  std::string op_name = absl::StrFormat(kAddOpFmt, op_counter_++);
+  LITERT_RETURN_STATUS_IF_NOT_OK(SetOpInfo(op_name.c_str(),
+                                           kDefaultQnnOpPackageName.data(),
+                                           kQnnAddOpTypeName.data(), dest));
+  LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeSimpleOp(src, dest, graph_mapper));
+  LITERT_LOG(LITERT_INFO, "Legalized add op", "");
+  return kLiteRtStatusOk;
+}
+
+}  // namespace litert::qnn
diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.h b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.h
new file mode 100644
index 00000000..e3893a5b
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.h
@@ -0,0 +1,49 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_ADD_OP_LEGALIZATION_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_ADD_OP_LEGALIZATION_H_
+
+#include <cstdint>
+#include <memory>
+
+#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/cc/litert_model.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h"
+
+namespace litert::qnn {
+
+class AddOpLegalization : public Legalization {
+ public:
+  AddOpLegalization() = default;
+  ~AddOpLegalization() = default;
+  using Ptr = std::unique_ptr<AddOpLegalization>;
+  static Ptr Create() { return std::make_unique<AddOpLegalization>(); }
+
+  LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest,
+                          GraphMapper& graph_mapper);
+
+ private:
+  // Counter to ensure unique op names.
+  uint32_t op_counter_ = 0;
+};
+
+}  // namespace litert::qnn
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_ADD_OP_LEGALIZATION_H_
diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.cc
new file mode 100644
index 00000000..a7ccea44
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.cc
@@ -0,0 +1,52 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.h"
+
+#include <string>
+
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_logging.h"
+#include "tflite/experimental/litert/c/litert_op_code.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/cc/litert_macros.h"
+#include "tflite/experimental/litert/cc/litert_model.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h"
+
+namespace litert::qnn {
+
+static constexpr absl::string_view kQnnBatchMatmulOpTypeName = "MatMul";
+static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw";
+static constexpr absl::string_view kBatchMatmulOpFmt = "batch_matmul_%d";
+
+LiteRtStatus BatchMatmulOpLegalization::LegalizeOp(const litert::Op& src,
+                                                   Qnn_OpConfig_t& dest,
+                                                   GraphMapper& graph_mapper) {
+  if (src.Code() != kLiteRtOpCodeTflBatchMatmul) {
+    return kLiteRtStatusLegalizeNoMatch;
+  }
+  std::string op_name = absl::StrFormat(kBatchMatmulOpFmt, op_counter_++);
+  LITERT_RETURN_STATUS_IF_NOT_OK(
+      SetOpInfo(op_name.c_str(), kDefaultQnnOpPackageName.data(),
+                kQnnBatchMatmulOpTypeName.data(), dest));
+  LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeSimpleOp(src, dest, graph_mapper));
+  LITERT_LOG(LITERT_INFO, "Legalized batch_matmul op", "");
+  return kLiteRtStatusOk;
+}
+
+}  // namespace litert::qnn
diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.h b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.h
new file mode 100644
index 00000000..a9b2b539
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.h
@@ -0,0 +1,49 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_BATCH_MATMUL_OP_LEGALIZATION_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_BATCH_MATMUL_OP_LEGALIZATION_H_
+
+#include <cstdint>
+#include <memory>
+
+#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/cc/litert_model.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h"
+
+namespace litert::qnn {
+
+class BatchMatmulOpLegalization : public Legalization {
+ public:
+  BatchMatmulOpLegalization() = default;
+  ~BatchMatmulOpLegalization() = default;
+  using Ptr = std::unique_ptr<BatchMatmulOpLegalization>;
+  static Ptr Create() { return std::make_unique<BatchMatmulOpLegalization>(); }
+
+  LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest,
+                          GraphMapper& graph_mapper);
+
+ private:
+  // Counter to ensure unique op names.
+  uint32_t op_counter_ = 0;
+};
+
+}  // namespace litert::qnn
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_BATCH_MATMUL_OP_LEGALIZATION_H_
diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/cast_op_legalization.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/cast_op_legalization.cc
new file mode 100644
index 00000000..9e554016
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/cast_op_legalization.cc
@@ -0,0 +1,49 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
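+//
+// Dispatch sketch (illustrative only; `op`, `op_config`, and `graph_mapper`
+// are placeholder names): because every LegalizeOp returns
+// kLiteRtStatusLegalizeNoMatch on an op-code mismatch, legalizations can be
+// tried in sequence until one claims the op:
+//
+//   auto cast = litert::qnn::CastOpLegalization::Create();
+//   Qnn_OpConfig_t op_config = QNN_OPCONFIG_INIT;
+//   LiteRtStatus status = cast->LegalizeOp(op, op_config, graph_mapper);
+//   if (status == kLiteRtStatusLegalizeNoMatch) {
+//     // Not a TFL cast; fall through to the next legalization.
+//   }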
+ +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/cast_op_legalization.h" + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_logging.h" +#include "tflite/experimental/litert/c/litert_op_code.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" + +namespace litert::qnn { + +static constexpr absl::string_view kQnnCastOpTypeName = "Cast"; +static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; +static constexpr absl::string_view kCastOpFmt = "cast_%d"; + +LiteRtStatus CastOpLegalization::LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) { + if (src.Code() != kLiteRtOpCodeTflCast) { + return kLiteRtStatusLegalizeNoMatch; + } + std::string op_name = absl::StrFormat(kCastOpFmt, op_counter_++); + LITERT_RETURN_STATUS_IF_NOT_OK(SetOpInfo(op_name.c_str(), + kDefaultQnnOpPackageName.data(), + kQnnCastOpTypeName.data(), dest)); + LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeSimpleOp(src, dest, graph_mapper)); + LITERT_LOG(LITERT_INFO, "Legalized cast op", ""); + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/cast_op_legalization.h b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/cast_op_legalization.h new file mode 100644 index 00000000..c4c3eecf --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/cast_op_legalization.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_CAST_OP_LEGALIZATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_CAST_OP_LEGALIZATION_H_ + +#include +#include + +#include +#include + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" + +namespace litert::qnn { + +class CastOpLegalization : public Legalization { + public: + CastOpLegalization() = default; + ~CastOpLegalization() = default; + using Ptr = std::unique_ptr; + static Ptr Create() { return std::make_unique(); } + + LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper); + + private: + // Counter to ensure unique op names. 
+ uint32_t op_counter_ = 0; +}; + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_CAST_OP_LEGALIZATION_H_ diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/concatenation_op_legalization.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/concatenation_op_legalization.cc new file mode 100644 index 00000000..76191035 --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/concatenation_op_legalization.cc @@ -0,0 +1,101 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/concatenation_op_legalization.h" + +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_op_code.h" +#include "tflite/experimental/litert/c/litert_options.h" +#include "tflite/experimental/litert/cc/litert_macros.h" +#include "tflite/experimental/litert/vendors/qualcomm/common.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" + +namespace litert::qnn { + +static constexpr absl::string_view kQnnConcatenationOpTypeName = "Concat"; +static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; +static constexpr absl::string_view kConcatenationOpFmt = "concatenation_%d"; + +static constexpr int kReduceConcatenationOpOutputSize = 1; +static constexpr int kReduceConcatenationOpParamSize = 1; + +LiteRtStatus ConcatenationOpLegalization::LegalizeOp( + const Op& src, Qnn_OpConfig_t& dest, GraphMapper& graph_mapper) { + if (src.Code() != kLiteRtOpCodeTflConcatenation) { + return kLiteRtStatusLegalizeNoMatch; + } + DumpLegalization(*src.Get()); + std::string op_name = absl::StrFormat(kConcatenationOpFmt, op_counter_++); + LITERT_RETURN_STATUS_IF_NOT_OK( + SetOpInfo(op_name.c_str(), kDefaultQnnOpPackageName.data(), + kQnnConcatenationOpTypeName.data(), dest)); + + // Look up op input tensors in scope. + const auto op_ins = src.Inputs(); + LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, op_ins.size(), QNN_TENSOR_INIT); + + Qnn_Tensor_t* cur_qnn_op_in = qnn_op_ins; + for (const auto& op_in : op_ins) { + LITERT_RETURN_STATUS_IF_NOT_OK( + graph_mapper.LookupInScope(op_in.Get(), *cur_qnn_op_in)); + ++cur_qnn_op_in; + } + + // QNN concatenation op expects 1 output tensor. 
+  const auto op_outs = src.Outputs();
+  LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs,
+                     kReduceConcatenationOpOutputSize, QNN_TENSOR_INIT);
+  LITERT_RETURN_STATUS_IF_NOT_OK(
+      graph_mapper.LegalizeAndRegister(op_outs.front().Get(), qnn_op_outs[0]));
+  LITERT_RETURN_STATUS_IF_NOT_OK(
+      graph_mapper.PushToScope(op_outs.front().Get(), qnn_op_outs[0]));
+
+  // Extract axis option from concatenation op.
+  int32_t axis;
+  LITERT_RETURN_STATUS_IF_NOT_OK(
+      LiteRtGetConcatenationAxisOption(src.Get(), &axis));
+
+  // Construct the scalar "axis" param. The scalar is declared as unsigned
+  // 32-bit, so write the matching union member.
+  Qnn_Param_t axis_param = BuildDefaultParam();
+  axis_param.paramType = QNN_PARAMTYPE_SCALAR;
+  axis_param.name = "axis";
+  Qnn_Scalar_t axis_scalar = QNN_SCALAR_INIT;
+  axis_scalar.dataType = QNN_DATATYPE_UINT_32;
+  axis_scalar.uint32Value = axis;
+  axis_param.scalarParam = axis_scalar;
+
+  Qnn_Param_t concatenation_params[] = {axis_param};
+  dest.v1.inputTensors = qnn_op_ins;
+  dest.v1.numOfInputs = op_ins.size();
+  dest.v1.outputTensors = qnn_op_outs;
+  dest.v1.numOfOutputs = kReduceConcatenationOpOutputSize;
+  dest.v1.numOfParams = kReduceConcatenationOpParamSize;
+  dest.v1.params = concatenation_params;
+
+  LITERT_RETURN_STATUS_IF_QNN_NOT_OK(
+      graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), dest));
+
+  LITERT_LOG(LITERT_INFO, "Legalized concatenation op", "");
+  return kLiteRtStatusOk;
+}
+
+}  // namespace litert::qnn
diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/concatenation_op_legalization.h b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/concatenation_op_legalization.h
new file mode 100644
index 00000000..0f567e0f
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/concatenation_op_legalization.h
@@ -0,0 +1,51 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_CONCATENATION_OP_LEGALIZATION_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_CONCATENATION_OP_LEGALIZATION_H_
+
+#include <cstdint>
+#include <memory>
+
+#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/cc/litert_model.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h"
+
+namespace litert::qnn {
+
+class ConcatenationOpLegalization : public Legalization {
+ public:
+  ConcatenationOpLegalization() = default;
+  ~ConcatenationOpLegalization() = default;
+  using Ptr = std::unique_ptr<ConcatenationOpLegalization>;
+  static Ptr Create() {
+    return std::make_unique<ConcatenationOpLegalization>();
+  }
+
+  LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest,
+                          GraphMapper& graph_mapper);
+
+ private:
+  // Counter to ensure unique op names.
+ uint32_t op_counter_ = 0; +}; + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_CONCATENATION_OP_LEGALIZATION_H_ diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/cos_op_legalization.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/cos_op_legalization.cc new file mode 100644 index 00000000..814d6f20 --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/cos_op_legalization.cc @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/cos_op_legalization.h" + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_logging.h" +#include "tflite/experimental/litert/c/litert_op_code.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" + +namespace litert::qnn { + +static constexpr absl::string_view kQnnCosOpTypeName = "ElementWiseCos"; +static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; +static constexpr absl::string_view kCosOpFmt = "cos_%d"; + +LiteRtStatus CosOpLegalization::LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) { + if (src.Code() != kLiteRtOpCodeTflCos) { + return kLiteRtStatusLegalizeNoMatch; + } + std::string op_name = absl::StrFormat(kCosOpFmt, op_counter_++); + LITERT_RETURN_STATUS_IF_NOT_OK(SetOpInfo(op_name.c_str(), + kDefaultQnnOpPackageName.data(), + kQnnCosOpTypeName.data(), dest)); + LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeSimpleOp(src, dest, graph_mapper)); + LITERT_LOG(LITERT_INFO, "Legalized cos op", ""); + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/cos_op_legalization.h b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/cos_op_legalization.h new file mode 100644 index 00000000..817b1b65 --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/cos_op_legalization.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_COS_OP_LEGALIZATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_COS_OP_LEGALIZATION_H_ + +#include +#include + +#include +#include + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" + +namespace litert::qnn { + +class CosOpLegalization : public Legalization { + public: + CosOpLegalization() = default; + ~CosOpLegalization() = default; + using UniquePtr = std::unique_ptr; + static UniquePtr Create() { return std::make_unique(); } + + LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper); + + private: + // Counter to ensure unique op names. + uint32_t op_counter_ = 0; +}; + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_COS_OP_LEGALIZATION_H_ diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.cc new file mode 100644 index 00000000..23a0f0cb --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.cc @@ -0,0 +1,51 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
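+// The element-wise legalizations in this patch (cos, div, greater, less,
+// logical_and, mul, rsqrt) delegate all tensor wiring to LegalizeSimpleOp
+// from util.h, whose definition is outside this hunk. A hedged sketch of
+// what it plausibly does, inferred from the hand-rolled concatenation and
+// reshape legalizations in this directory (names and details are
+// assumptions):
+//
+//   LiteRtStatus LegalizeSimpleOp(const litert::Op& src, Qnn_OpConfig_t& dest,
+//                                 GraphMapper& graph_mapper) {
+//     // Map every TFLite input to a QNN tensor already in scope.
+//     const auto op_ins = src.Inputs();
+//     LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, op_ins.size(),
+//                        QNN_TENSOR_INIT);
+//     Qnn_Tensor_t* cur_in = qnn_op_ins;
+//     for (const auto& op_in : op_ins) {
+//       LITERT_RETURN_STATUS_IF_NOT_OK(
+//           graph_mapper.LookupInScope(op_in.Get(), *cur_in++));
+//     }
+//     // Legalize each output and publish it for downstream ops.
+//     const auto op_outs = src.Outputs();
+//     LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs, op_outs.size(),
+//                        QNN_TENSOR_INIT);
+//     Qnn_Tensor_t* cur_out = qnn_op_outs;
+//     for (const auto& op_out : op_outs) {
+//       LITERT_RETURN_STATUS_IF_NOT_OK(
+//           graph_mapper.LegalizeAndRegister(op_out.Get(), *cur_out));
+//       LITERT_RETURN_STATUS_IF_NOT_OK(
+//           graph_mapper.PushToScope(op_out.Get(), *cur_out++));
+//     }
+//     dest.v1.inputTensors = qnn_op_ins;
+//     dest.v1.numOfInputs = op_ins.size();
+//     dest.v1.outputTensors = qnn_op_outs;
+//     dest.v1.numOfOutputs = op_outs.size();
+//     LITERT_RETURN_STATUS_IF_QNN_NOT_OK(
+//         graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(),
+//                                                dest));
+//     return kLiteRtStatusOk;
+//   }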
+ +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.h" + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_op_code.h" +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/experimental/litert/cc/litert_macros.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" + +namespace litert::qnn { + +static constexpr absl::string_view kQnnDivOpTypeName = "ElementWiseDivide"; +static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; +static constexpr absl::string_view kDivOpFmt = "div_%d"; + +LiteRtStatus DivOpLegalization::LegalizeOp(const litert::Op& src, + Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) { + if (src.Code() != kLiteRtOpCodeTflDiv) { + return kLiteRtStatusLegalizeNoMatch; + } + std::string op_name = absl::StrFormat(kDivOpFmt, op_counter_++); + LITERT_RETURN_STATUS_IF_NOT_OK(SetOpInfo(op_name.c_str(), + kDefaultQnnOpPackageName.data(), + kQnnDivOpTypeName.data(), dest)); + LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeSimpleOp(src, dest, graph_mapper)); + LITERT_LOG(LITERT_INFO, "Legalized div op", ""); + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.h b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.h new file mode 100644 index 00000000..348ed495 --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_DIV_OP_LEGALIZATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_DIV_OP_LEGALIZATION_H_ + +#include +#include + +#include +#include + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" + +namespace litert::qnn { + +class DivOpLegalization : public Legalization { + public: + DivOpLegalization() = default; + ~DivOpLegalization() = default; + using Ptr = std::unique_ptr; + static Ptr Create() { return std::make_unique(); } + + LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper); + + private: + // Counter to ensure unique op names. 
+  uint32_t op_counter_ = 0;
+};
+
+}  // namespace litert::qnn
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_DIV_OP_LEGALIZATION_H_
diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/embedding_lookup_op_legalization.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/embedding_lookup_op_legalization.cc
new file mode 100644
index 00000000..7e151b69
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/embedding_lookup_op_legalization.cc
@@ -0,0 +1,104 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/embedding_lookup_op_legalization.h"
+
+#include <string>
+
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_logging.h"
+#include "tflite/experimental/litert/c/litert_op_code.h"
+#include "tflite/experimental/litert/cc/litert_macros.h"
+#include "tflite/experimental/litert/cc/litert_model.h"
+#include "tflite/experimental/litert/vendors/qualcomm/common.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h"
+
+namespace litert::qnn {
+
+static constexpr absl::string_view kQnnEmbeddingLookupOpTypeName = "Gather";
+static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw";
+static constexpr absl::string_view kEmbeddingLookupOpFmt =
+    "embedding_lookup_%d";
+
+static constexpr int kReduceEmbeddingLookupOpOutputSize = 1;
+static constexpr int kReduceEmbeddingLookupOpParamSize = 1;
+
+static constexpr int kEmbeddingLookupOpTableInputIndex = 1;
+static constexpr int kEmbeddingLookupOpLookupInputIndex = 0;
+static constexpr int kQnnGatherOpTableInputIndex = 0;
+static constexpr int kQnnGatherOpLookupInputIndex = 1;
+
+LiteRtStatus EmbeddingLookupOpLegalization::LegalizeOp(
+    const Op& src, Qnn_OpConfig_t& dest, GraphMapper& graph_mapper) {
+  if (src.Code() != kLiteRtOpCodeTflEmbeddingLookup) {
+    return kLiteRtStatusLegalizeNoMatch;
+  }
+  DumpLegalization(*src.Get());
+  std::string op_name = absl::StrFormat(kEmbeddingLookupOpFmt, op_counter_++);
+  LITERT_RETURN_STATUS_IF_NOT_OK(
+      SetOpInfo(op_name.c_str(), kDefaultQnnOpPackageName.data(),
+                kQnnEmbeddingLookupOpTypeName.data(), dest));
+
+  // Look up op input tensors in scope.
+  const auto op_ins = src.Inputs();
+  LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, op_ins.size(), QNN_TENSOR_INIT);
+
+  LITERT_RETURN_STATUS_IF_NOT_OK(graph_mapper.LookupInScope(
+      op_ins[kEmbeddingLookupOpLookupInputIndex].Get(),
+      qnn_op_ins[kQnnGatherOpLookupInputIndex]));
+  LITERT_RETURN_STATUS_IF_NOT_OK(graph_mapper.LookupInScope(
+      op_ins[kEmbeddingLookupOpTableInputIndex].Get(),
+      qnn_op_ins[kQnnGatherOpTableInputIndex]));
+
+  // QNN embedding_lookup op expects 1 output tensor.
+  const auto op_outs = src.Outputs();
+  LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs,
+                     kReduceEmbeddingLookupOpOutputSize, QNN_TENSOR_INIT);
+  LITERT_RETURN_STATUS_IF_NOT_OK(
+      graph_mapper.LegalizeAndRegister(op_outs.front().Get(), qnn_op_outs[0]));
+  LITERT_RETURN_STATUS_IF_NOT_OK(
+      graph_mapper.PushToScope(op_outs.front().Get(), qnn_op_outs[0]));
+
+  // Construct the scalar "axis" param.
+  Qnn_Param_t axis_param = BuildDefaultParam();
+  axis_param.paramType = QNN_PARAMTYPE_SCALAR;
+  axis_param.name = "axis";
+  Qnn_Scalar_t axis_scalar = QNN_SCALAR_INIT;
+  axis_scalar.dataType = QNN_DATATYPE_INT_32;
+  // Embedding lookup op expects axis to always be 0.
+  axis_scalar.int32Value = 0;
+  axis_param.scalarParam = axis_scalar;
+
+  Qnn_Param_t embedding_lookup_params[] = {axis_param};
+  dest.v1.inputTensors = qnn_op_ins;
+  dest.v1.numOfInputs = op_ins.size();
+  dest.v1.outputTensors = qnn_op_outs;
+  dest.v1.numOfOutputs = kReduceEmbeddingLookupOpOutputSize;
+  dest.v1.numOfParams = kReduceEmbeddingLookupOpParamSize;
+  dest.v1.params = embedding_lookup_params;
+
+  LITERT_RETURN_STATUS_IF_QNN_NOT_OK(
+      graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), dest));
+
+  LITERT_LOG(LITERT_INFO, "Legalized embedding_lookup op", "");
+  return kLiteRtStatusOk;
+}
+
+}  // namespace litert::qnn
diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/embedding_lookup_op_legalization.h b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/embedding_lookup_op_legalization.h
new file mode 100644
index 00000000..87c28687
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/embedding_lookup_op_legalization.h
@@ -0,0 +1,51 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
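+// The four index constants in the .cc above encode an operand swap: TFLite's
+// embedding_lookup takes {lookup_ids, table} while QNN's Gather takes
+// {table, lookup_ids} plus a scalar axis of 0. A compile-time restatement of
+// that mapping (illustrative only, since the constants are local to the .cc):
+//
+//   static_assert(kEmbeddingLookupOpLookupInputIndex == 0 &&
+//                     kEmbeddingLookupOpTableInputIndex == 1 &&
+//                     kQnnGatherOpTableInputIndex == 0 &&
+//                     kQnnGatherOpLookupInputIndex == 1,
+//                 "embedding_lookup -> Gather operand remap");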
+ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_EMBEDDING_LOOKUP_OP_LEGALIZATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_EMBEDDING_LOOKUP_OP_LEGALIZATION_H_ + +#include +#include + +#include +#include + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" + +namespace litert::qnn { + +class EmbeddingLookupOpLegalization : public Legalization { + public: + EmbeddingLookupOpLegalization() = default; + ~EmbeddingLookupOpLegalization() = default; + using Ptr = std::unique_ptr; + static Ptr Create() { + return std::make_unique(); + } + + LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper); + + private: + // Counter to ensure unique op names. + uint32_t op_counter_ = 0; +}; + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_EMBEDDING_LOOKUP_OP_LEGALIZATION_H_ diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/fully_connected_op_legalization.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/fully_connected_op_legalization.cc new file mode 100644 index 00000000..386f3418 --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/fully_connected_op_legalization.cc @@ -0,0 +1,52 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/fully_connected_op_legalization.h" + +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_logging.h" +#include "tflite/experimental/litert/c/litert_op_code.h" +#include "tflite/experimental/litert/cc/litert_macros.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" + +namespace litert::qnn { + +static constexpr absl::string_view kQnnFullyConnectedOpTypeName = + "FullyConnected"; +static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; +static constexpr absl::string_view kFullyConnectedOpFmt = "fully_connected_%d"; + +LiteRtStatus FullyConnectedOpLegalization::LegalizeOp( + const Op& src, Qnn_OpConfig_t& dest, GraphMapper& graph_mapper) { + if (src.Code() != kLiteRtOpCodeTflFullyConnected) { + return kLiteRtStatusLegalizeNoMatch; + } + std::string op_name = absl::StrFormat(kFullyConnectedOpFmt, op_counter_++); + LITERT_RETURN_STATUS_IF_NOT_OK( + SetOpInfo(op_name.c_str(), kDefaultQnnOpPackageName.data(), + kQnnFullyConnectedOpTypeName.data(), dest)); + LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeSimpleOp(src, dest, graph_mapper)); + + LITERT_LOG(LITERT_INFO, "Legalized fully_connected op", ""); + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/fully_connected_op_legalization.h b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/fully_connected_op_legalization.h new file mode 100644 index 00000000..e1319388 --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/fully_connected_op_legalization.h @@ -0,0 +1,51 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_FULLY_CONNECTED_OP_LEGALIZATION_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_FULLY_CONNECTED_OP_LEGALIZATION_H_
+
+#include <cstdint>
+#include <memory>
+
+#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/cc/litert_model.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h"
+
+namespace litert::qnn {
+
+class FullyConnectedOpLegalization : public Legalization {
+ public:
+  FullyConnectedOpLegalization() = default;
+  ~FullyConnectedOpLegalization() = default;
+  using Ptr = std::unique_ptr<FullyConnectedOpLegalization>;
+  static Ptr Create() {
+    return std::make_unique<FullyConnectedOpLegalization>();
+  }
+
+  LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest,
+                          GraphMapper& graph_mapper);
+
+ private:
+  // Counter to ensure unique op names.
+  uint32_t op_counter_ = 0;
+};
+
+}  // namespace litert::qnn
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_FULLY_CONNECTED_OP_LEGALIZATION_H_
diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/greater_op_legalization.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/greater_op_legalization.cc
new file mode 100644
index 00000000..280d31f9
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/greater_op_legalization.cc
@@ -0,0 +1,51 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/greater_op_legalization.h" + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_logging.h" +#include "tflite/experimental/litert/c/litert_op_code.h" +#include "tflite/experimental/litert/cc/litert_macros.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" + +namespace litert::qnn { + +static constexpr absl::string_view kQnnGreaterOpTypeName = "ElementWiseGreater"; +static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; +static constexpr absl::string_view kGreaterOpFmt = "greater_%d"; + +LiteRtStatus GreaterOpLegalization::LegalizeOp(const litert::Op& src, + Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) { + if (src.Code() != kLiteRtOpCodeTflGreater) { + return kLiteRtStatusLegalizeNoMatch; + } + std::string op_name = absl::StrFormat(kGreaterOpFmt, op_counter_++); + LITERT_RETURN_STATUS_IF_NOT_OK(SetOpInfo(op_name.c_str(), + kDefaultQnnOpPackageName.data(), + kQnnGreaterOpTypeName.data(), dest)); + LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeSimpleOp(src, dest, graph_mapper)); + LITERT_LOG(LITERT_INFO, "Legalized greater op", ""); + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/greater_op_legalization.h b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/greater_op_legalization.h new file mode 100644 index 00000000..aab5d761 --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/greater_op_legalization.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_GREATER_OP_LEGALIZATION_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_GREATER_OP_LEGALIZATION_H_
+
+#include <cstdint>
+#include <memory>
+
+#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/cc/litert_model.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h"
+
+namespace litert::qnn {
+
+class GreaterOpLegalization : public Legalization {
+ public:
+  GreaterOpLegalization() = default;
+  ~GreaterOpLegalization() = default;
+  using Ptr = std::unique_ptr<GreaterOpLegalization>;
+  static Ptr Create() { return std::make_unique<GreaterOpLegalization>(); }
+
+  LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest,
+                          GraphMapper& graph_mapper);
+
+ private:
+  // Counter to ensure unique op names.
+  uint32_t op_counter_ = 0;
+};
+
+}  // namespace litert::qnn
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_GREATER_OP_LEGALIZATION_H_
diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h
new file mode 100644
index 00000000..2e35e83b
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h
@@ -0,0 +1,50 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LEGALIZATION_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LEGALIZATION_H_
+
+#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/cc/litert_model.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
+
+// Two levels of expansion are needed so that __COUNTER__ is evaluated before
+// pasting and stringizing; a single-level STRINGIFY(prefix##__COUNTER__)
+// would paste the literal token __COUNTER__ into the name instead.
+#define STRINGIFY_(x) #x
+#define STRINGIFY(x) STRINGIFY_(x)
+#define QNN_OP_CONCAT_(a, b) a##b
+#define QNN_OP_CONCAT(a, b) QNN_OP_CONCAT_(a, b)
+#define QNN_OP_NAME(prefix) STRINGIFY(QNN_OP_CONCAT(prefix, __COUNTER__))
+
+namespace litert::qnn {
+
+class Legalization {
+ public:
+  Legalization() = default;
+  virtual ~Legalization() = default;
+
+  virtual LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest,
+                                  GraphMapper& graph_mapper) = 0;
+
+  // Sets the op name, package name, and type.
+  // Note: none of the argument strings may be deallocated until the op has
+  // been registered with the QNN API, i.e. graphAddNode().
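+  // For example (usage sketch mirroring the legalizations in this
+  // directory), the formatted name lives in a std::string that outlives the
+  // node registration:
+  //
+  //   std::string op_name = absl::StrFormat("mul_%d", op_counter_++);
+  //   SetOpInfo(op_name.c_str(), "qti.aisw", "ElementWiseMultiply", dest);
+  //   // ... op_name must stay alive until graphAddNode(graph, dest).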
+ inline LiteRtStatus SetOpInfo(const char* name, const char* op_package_name, + const char* op_type, Qnn_OpConfig_t& op) { + op.v1.name = name; + op.v1.packageName = op_package_name; + op.v1.typeName = op_type; + return kLiteRtStatusOk; + } +}; + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LEGALIZATION_H_ diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/less_op_legalization.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/less_op_legalization.cc new file mode 100644 index 00000000..d3d60f87 --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/less_op_legalization.cc @@ -0,0 +1,51 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/less_op_legalization.h" + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_logging.h" +#include "tflite/experimental/litert/c/litert_op_code.h" +#include "tflite/experimental/litert/cc/litert_macros.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" + +namespace litert::qnn { + +static constexpr absl::string_view kQnnLessOpTypeName = "ElementWiseLess"; +static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; +static constexpr absl::string_view kLessOpFmt = "less_%d"; + +LiteRtStatus LessOpLegalization::LegalizeOp(const litert::Op& src, + Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) { + if (src.Code() != kLiteRtOpCodeTflLess) { + return kLiteRtStatusLegalizeNoMatch; + } + std::string op_name = absl::StrFormat(kLessOpFmt, op_counter_++); + LITERT_RETURN_STATUS_IF_NOT_OK(SetOpInfo(op_name.c_str(), + kDefaultQnnOpPackageName.data(), + kQnnLessOpTypeName.data(), dest)); + LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeSimpleOp(src, dest, graph_mapper)); + LITERT_LOG(LITERT_INFO, "Legalized less op", ""); + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/less_op_legalization.h b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/less_op_legalization.h new file mode 100644 index 00000000..baaa524c --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/less_op_legalization.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LESS_OP_LEGALIZATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LESS_OP_LEGALIZATION_H_ + +#include +#include + +#include +#include + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" + +namespace litert::qnn { + +class LessOpLegalization : public Legalization { + public: + LessOpLegalization() = default; + ~LessOpLegalization() = default; + using Ptr = std::unique_ptr; + static Ptr Create() { return std::make_unique(); } + + LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper); + + private: + // Counter to ensure unique op names. + uint32_t op_counter_ = 0; +}; + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LESS_OP_LEGALIZATION_H_ diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/logical_and_op_legalization.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/logical_and_op_legalization.cc new file mode 100644 index 00000000..0d44f140 --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/logical_and_op_legalization.cc @@ -0,0 +1,51 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/logical_and_op_legalization.h" + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_logging.h" +#include "tflite/experimental/litert/c/litert_op_code.h" +#include "tflite/experimental/litert/cc/litert_macros.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" + +namespace litert::qnn { + +static constexpr absl::string_view kQnnLogicalAndOpTypeName = "ElementWiseAnd"; +static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; +static constexpr absl::string_view kLogicalAndOpFmt = "logical_and_%d"; + +LiteRtStatus LogicalAndOpLegalization::LegalizeOp(const Op& src, + Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) { + if (src.Code() != kLiteRtOpCodeTflLogicalAnd) { + return kLiteRtStatusLegalizeNoMatch; + } + std::string op_name = absl::StrFormat(kLogicalAndOpFmt, op_counter_++); + LITERT_RETURN_STATUS_IF_NOT_OK( + SetOpInfo(op_name.c_str(), kDefaultQnnOpPackageName.data(), + kQnnLogicalAndOpTypeName.data(), dest)); + LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeSimpleOp(src, dest, graph_mapper)); + LITERT_LOG(LITERT_INFO, "Legalized logical_and op", ""); + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/logical_and_op_legalization.h b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/logical_and_op_legalization.h new file mode 100644 index 00000000..0052ea88 --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/logical_and_op_legalization.h @@ -0,0 +1,51 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LOGICAL_AND_OP_LEGALIZATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LOGICAL_AND_OP_LEGALIZATION_H_ + +#include +#include + +#include +#include + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" + +namespace litert::qnn { + +class LogicalAndOpLegalization : public Legalization { + public: + LogicalAndOpLegalization() = default; + ~LogicalAndOpLegalization() = default; + using UniquePtr = std::unique_ptr; + static UniquePtr Create() { + return std::make_unique(); + } + + LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper); + + private: + // Counter to ensure unique op names. + uint32_t op_counter_ = 0; +}; + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LOGICAL_AND_OP_LEGALIZATION_H_ diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.cc new file mode 100644 index 00000000..153115e8 --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.cc @@ -0,0 +1,51 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
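+// Naming note: op_counter_ is a per-instance member, so each legalization
+// numbers its ops independently, e.g.
+//
+//   absl::StrFormat(kMulOpFmt, 0)  // "mul_0", first mul in the graph
+//   absl::StrFormat(kMulOpFmt, 1)  // "mul_1", second mul
+//
+// Names stay unique as long as a single legalization instance serves one
+// graph build, which appears to be how the Create() factories are used.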
+ +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.h" + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_logging.h" +#include "tflite/experimental/litert/c/litert_op_code.h" +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/experimental/litert/cc/litert_macros.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" + +namespace litert::qnn { + +static constexpr absl::string_view kQnnMulOpTypeName = "ElementWiseMultiply"; +static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; +static constexpr absl::string_view kMulOpFmt = "mul_%d"; + +LiteRtStatus MulOpLegalization::LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) { + if (src.Code() != kLiteRtOpCodeTflMul) { + return kLiteRtStatusLegalizeNoMatch; + } + std::string op_name = absl::StrFormat(kMulOpFmt, op_counter_++); + LITERT_RETURN_STATUS_IF_NOT_OK(SetOpInfo(op_name.c_str(), + kDefaultQnnOpPackageName.data(), + kQnnMulOpTypeName.data(), dest)); + LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeSimpleOp(src, dest, graph_mapper)); + LITERT_LOG(LITERT_INFO, "Legalized mul op", ""); + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.h b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.h new file mode 100644 index 00000000..bbfa0463 --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_MUL_OP_LEGALIZATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_MUL_OP_LEGALIZATION_H_ + +#include +#include + +#include +#include + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" + +namespace litert::qnn { + +class MulOpLegalization : public Legalization { + public: + MulOpLegalization() = default; + ~MulOpLegalization() = default; + using Ptr = std::unique_ptr; + static Ptr Create() { return std::make_unique(); } + + LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper); + + private: + // Counter to ensure unique op names. 
+ uint32_t op_counter_ = 0; +}; + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_MUL_OP_LEGALIZATION_H_ diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.cc new file mode 100644 index 00000000..3a7a350b --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.cc @@ -0,0 +1,82 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.h" + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_op_code.h" +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/experimental/litert/cc/litert_macros.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/cc/litert_model_predicates.h" +#include "tflite/experimental/litert/vendors/qualcomm/common.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" + +namespace litert::qnn { + +static constexpr absl::string_view kQnnReshapeOpTypeName = "Reshape"; +static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; +static constexpr absl::string_view kReshapeOpFmt = "reshape_%d"; + +static constexpr int kReshapeOpInputSize = 1; +static constexpr int kReshapeOpOutputSize = 1; + +LiteRtStatus ReshapeOpLegalization::LegalizeOp(const litert::Op& src, + Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) { + if (src.Code() != kLiteRtOpCodeTflReshape) { + return kLiteRtStatusLegalizeNoMatch; + } + std::string op_name = absl::StrFormat(kReshapeOpFmt, op_counter_++); + LITERT_RETURN_STATUS_IF_NOT_OK(SetOpInfo(op_name.c_str(), + kDefaultQnnOpPackageName.data(), + kQnnReshapeOpTypeName.data(), dest)); + DumpLegalization(*src.Get()); + // Look up op input tensors in scope. + const auto op_ins = src.Inputs(); + LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, kReshapeOpInputSize, + QNN_TENSOR_INIT); + LITERT_RETURN_STATUS_IF_NOT_OK( + graph_mapper.LookupInScope(op_ins.front().Get(), qnn_op_ins[0])); + + // Legalize op outputs and update scope. 
+ + const auto op_outs = src.Outputs(); + LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs, kReshapeOpOutputSize, + QNN_TENSOR_INIT); + LITERT_RETURN_STATUS_IF_NOT_OK( + graph_mapper.LegalizeAndRegister(op_outs.front().Get(), qnn_op_outs[0])); + LITERT_RETURN_STATUS_IF_NOT_OK( + graph_mapper.PushToScope(op_outs.front().Get(), qnn_op_outs[0])); + + dest.v1.numOfInputs = kReshapeOpInputSize; + dest.v1.inputTensors = qnn_op_ins; + + dest.v1.numOfOutputs = kReshapeOpOutputSize; + dest.v1.outputTensors = qnn_op_outs; + + LITERT_RETURN_STATUS_IF_QNN_NOT_OK( + graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), dest)); + + LITERT_LOG(LITERT_INFO, "Legalized reshape op", ""); + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.h b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.h new file mode 100644 index 00000000..e4f6cf6d --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_RESHAPE_OP_LEGALIZATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_RESHAPE_OP_LEGALIZATION_H_ + +#include +#include + +#include +#include + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" + +namespace litert::qnn { + +class ReshapeOpLegalization : public Legalization { + public: + ReshapeOpLegalization() = default; + ~ReshapeOpLegalization() = default; + using Ptr = std::unique_ptr; + static Ptr Create() { return std::make_unique(); } + + LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper); + + private: + // Counter to ensure unique op names. + uint32_t op_counter_ = 0; +}; + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_RESHAPE_OP_LEGALIZATION_H_ diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.cc new file mode 100644 index 00000000..52c8769c --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.cc @@ -0,0 +1,51 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.h" + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_op_code.h" +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/experimental/litert/cc/litert_macros.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" + +namespace litert::qnn { + +static constexpr absl::string_view kQnnRsqrtOpTypeName = "ElementWiseRsqrt"; +static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; +static constexpr absl::string_view kRsqrtOpFmt = "rsqrt_%d"; + +LiteRtStatus RsqrtOpLegalization::LegalizeOp(const litert::Op& src, + Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) { + if (src.Code() != kLiteRtOpCodeTflRsqrt) { + return kLiteRtStatusLegalizeNoMatch; + } + std::string op_name = absl::StrFormat(kRsqrtOpFmt, op_counter_++); + LITERT_RETURN_STATUS_IF_NOT_OK(SetOpInfo(op_name.c_str(), + kDefaultQnnOpPackageName.data(), + kQnnRsqrtOpTypeName.data(), dest)); + LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeSimpleOp(src, dest, graph_mapper)); + LITERT_LOG(LITERT_INFO, "Legalized rsqrt op", ""); + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.h b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.h new file mode 100644 index 00000000..077bba60 --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
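Every legalization in this directory follows one pattern: a Create() factory returning a smart pointer, and a LegalizeOp that first filters on the op code, returning kLiteRtStatusLegalizeNoMatch so a caller can try the next candidate. A caller-side sketch of that dispatch, illustrative only: it assumes the Legalization base class (legalization.h) exposes LegalizeOp virtually, as the subclasses here suggest, that the code sits inside namespace litert::qnn, and that the status codes come from litert_common.h.

    #include <memory>
    #include <vector>

    // Try each registered legalization until one claims the op.
    LiteRtStatus LegalizeWithAny(
        const Op& op, GraphMapper& mapper,
        std::vector<std::unique_ptr<Legalization>>& legalizations) {
      for (auto& legalization : legalizations) {
        Qnn_OpConfig_t config = QNN_OPCONFIG_INIT;
        const LiteRtStatus status = legalization->LegalizeOp(op, config, mapper);
        if (status == kLiteRtStatusLegalizeNoMatch) continue;  // Next candidate.
        return status;  // Either added to the QNN graph, or a hard error.
      }
      return kLiteRtStatusErrorUnsupported;  // No legalization matched.
    }

Usage would be along the lines of `legalizations.push_back(RsqrtOpLegalization::Create());` followed by one LegalizeWithAny call per op in the partition.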
+ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_RSQRT_OP_LEGALIZATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_RSQRT_OP_LEGALIZATION_H_ + +#include +#include + +#include +#include + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" + +namespace litert::qnn { + +class RsqrtOpLegalization : public Legalization { + public: + RsqrtOpLegalization() = default; + ~RsqrtOpLegalization() = default; + using Ptr = std::unique_ptr; + static Ptr Create() { return std::make_unique(); } + + LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper); + + private: + // Counter to ensure unique op names. + uint32_t op_counter_ = 0; +}; + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_RSQRT_OP_LEGALIZATION_H_ diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/select_op_legalization.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/select_op_legalization.cc new file mode 100644 index 00000000..8a757d6d --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/select_op_legalization.cc @@ -0,0 +1,55 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/select_op_legalization.h" + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_logging.h" +#include "tflite/experimental/litert/c/litert_op_code.h" +#include "tflite/experimental/litert/cc/litert_macros.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" + +namespace litert::qnn { + +static constexpr absl::string_view kQnnSelectOpTypeName = "ElementWiseSelect"; +static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; +static constexpr absl::string_view kSelectOpFmt = "select_%d"; + +LiteRtStatus SelectOpLegalization::LegalizeOp(const Op& src, + Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) { + if (src.Code() != kLiteRtOpCodeTflSelect && + src.Code() != kLiteRtOpCodeTflSelectV2) { + return kLiteRtStatusLegalizeNoMatch; + } + std::string op_name = absl::StrFormat(kSelectOpFmt, op_counter_++); + LITERT_RETURN_STATUS_IF_NOT_OK(SetOpInfo(op_name.c_str(), + kDefaultQnnOpPackageName.data(), + kQnnSelectOpTypeName.data(), dest)); + LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeSimpleOp(src, dest, graph_mapper)); + + return kLiteRtStatusOk; + LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeSimpleOp(src, dest, graph_mapper)); + LITERT_LOG(LITERT_INFO, "Legalized select op", ""); + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/select_op_legalization.h b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/select_op_legalization.h new file mode 100644 index 00000000..e1bd02c0 --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/select_op_legalization.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SELECT_OP_LEGALIZATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SELECT_OP_LEGALIZATION_H_ + +#include +#include + +#include +#include + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" + +namespace litert::qnn { + +class SelectOpLegalization : public Legalization { + public: + SelectOpLegalization() = default; + ~SelectOpLegalization() = default; + using Ptr = std::unique_ptr; + static Ptr Create() { return std::make_unique(); } + + LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper); + + private: + // Counter to ensure unique op names. + uint32_t op_counter_ = 0; +}; + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SELECT_OP_LEGALIZATION_H_ diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sin_op_legalization.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sin_op_legalization.cc new file mode 100644 index 00000000..e26ffed5 --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sin_op_legalization.cc @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sin_op_legalization.h" + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_logging.h" +#include "tflite/experimental/litert/c/litert_op_code.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" + +namespace litert::qnn { + +static constexpr absl::string_view kQnnSinOpTypeName = "ElementWiseSin"; +static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; +static constexpr absl::string_view kSinOpFmt = "sin_%d"; + +LiteRtStatus SinOpLegalization::LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) { + if (src.Code() != kLiteRtOpCodeTflSin) { + return kLiteRtStatusLegalizeNoMatch; + } + std::string op_name = absl::StrFormat(kSinOpFmt, op_counter_++); + LITERT_RETURN_STATUS_IF_NOT_OK(SetOpInfo(op_name.c_str(), + kDefaultQnnOpPackageName.data(), + kQnnSinOpTypeName.data(), dest)); + LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeSimpleOp(src, dest, graph_mapper)); + LITERT_LOG(LITERT_INFO, "Legalized sin op", ""); + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sin_op_legalization.h b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sin_op_legalization.h new file mode 100644 index 00000000..4b036d62 --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sin_op_legalization.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SIN_OP_LEGALIZATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SIN_OP_LEGALIZATION_H_ + +#include +#include + +#include +#include + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" + +namespace litert::qnn { + +class SinOpLegalization : public Legalization { + public: + SinOpLegalization() = default; + ~SinOpLegalization() = default; + using UniquePtr = std::unique_ptr; + static UniquePtr Create() { return std::make_unique(); } + + LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper); + + private: + // Counter to ensure unique op names. 
+ uint32_t op_counter_ = 0; +}; + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SIN_OP_LEGALIZATION_H_ diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.cc new file mode 100644 index 00000000..13b6d16c --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.cc @@ -0,0 +1,153 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.h" + +#include +#include + +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_op_code.h" +#include "tflite/experimental/litert/cc/litert_macros.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/cc/litert_model_predicates.h" +#include "tflite/experimental/litert/vendors/qualcomm/common.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" +#include "tflite/experimental/litert/vendors/qualcomm/qnn_manager.h" + +namespace litert::qnn { + +static constexpr absl::string_view kQnnSliceOpTypeName = "StridedSlice"; +static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; +static constexpr absl::string_view kSliceOpFmt = "slice_%d"; + +static constexpr int kSliceOpInputSize = 1; +static constexpr int kSliceOpOutputSize = 1; +static constexpr int kSliceOpParamSize = 1; +// QNN StridedSlice op packs "start", "end", and "stride" into a single tensor +// param "ranges". +static constexpr int kRangesParamArgSize = 3; +static constexpr int kRangesParamRank = 2; + +LiteRtStatus SliceOpLegalization::LegalizeOp(const Op& src, + Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) { + if (src.Code() != kLiteRtOpCodeTflSlice) { + return kLiteRtStatusLegalizeNoMatch; + } + DumpLegalization(*src.Get()); + std::string op_name = absl::StrFormat(kSliceOpFmt, op_counter_++); + LITERT_RETURN_STATUS_IF_NOT_OK(SetOpInfo(op_name.c_str(), + kDefaultQnnOpPackageName.data(), + kQnnSliceOpTypeName.data(), dest)); + + // QNN strided slice op expects 1 input tensor. + const auto op_ins = src.Inputs(); + LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, kSliceOpInputSize, + QNN_TENSOR_INIT); + LITERT_RETURN_STATUS_IF_NOT_OK( + graph_mapper.LookupInScope(op_ins.front().Get(), qnn_op_ins[0])); + + // QNN strided slice op expects 1 output tensor. 
+  const auto op_outs = src.Outputs();
+  LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs, kSliceOpOutputSize,
+                     QNN_TENSOR_INIT);
+  LITERT_RETURN_STATUS_IF_NOT_OK(
+      graph_mapper.LegalizeAndRegister(op_outs.front().Get(), qnn_op_outs[0]));
+  LITERT_RETURN_STATUS_IF_NOT_OK(
+      graph_mapper.PushToScope(op_outs.front().Get(), qnn_op_outs[0]));
+
+  const auto& src_input_tensor = op_ins.front();
+  auto src_input_tensor_rank =
+      src_input_tensor.RankedTensorType().Layout().Rank();
+
+  // Prepare qnn strided slice parameters.
+
+  auto src_begin_indices = op_ins.at(1).WeightsData<int32_t>();
+  if (!src_begin_indices) {
+    return src_begin_indices.Error().Status();
+  }
+
+  auto src_size_indices = op_ins.at(2).WeightsData<int32_t>();
+  if (!src_size_indices) {
+    return src_size_indices.Error().Status();
+  }
+
+  // Both begin and size must be static (weights) tensors with one entry per
+  // input dimension.
+  if (src_begin_indices->empty() || src_size_indices->empty()) {
+    return kLiteRtStatusErrorInvalidLegalization;
+  }
+
+  LITERT_STACK_ARRAY(int32_t, range_tensor_data,
+                     src_input_tensor_rank * kRangesParamArgSize,
+                     /*init value*/ 0);
+  for (int i = 0; i < src_input_tensor_rank; ++i) {
+    // Each "ranges" row is (begin, end, stride). TFL slice carries begin and
+    // size, so end = begin + size; stride is always 1. For example, begin
+    // [1, 0] with size [2, 3] yields ranges [[1, 3, 1], [0, 3, 1]].
+    range_tensor_data[i * kRangesParamArgSize] = src_begin_indices->at(i);
+    range_tensor_data[i * kRangesParamArgSize + 1] =
+        src_begin_indices->at(i) + src_size_indices->at(i);
+    range_tensor_data[i * kRangesParamArgSize + 2] = 1;
+  }
+
+  Qnn_ClientBuffer_t range_tensor_client_buf = BuildDefaultClientBuffer();
+  range_tensor_client_buf.data = range_tensor_data;
+  range_tensor_client_buf.dataSize =
+      src_input_tensor_rank * kRangesParamArgSize * sizeof(int32_t);
+
+  // Construct the const tensor "ranges".
+  Qnn_Tensor_t range_tensor = BuildDefaultTensor();
+  graph_mapper.AssignTensorName(range_tensor);
+  range_tensor.v2.dataType = QNN_DATATYPE_INT_32;
+  range_tensor.v2.type = QNN_TENSOR_TYPE_STATIC;
+  range_tensor.v2.rank = kRangesParamRank;
+  range_tensor.v2.dimensions = new uint32_t[kRangesParamRank];
+  range_tensor.v2.dimensions[0] = src_input_tensor_rank;
+  range_tensor.v2.dimensions[1] = kRangesParamArgSize;
+  range_tensor.v2.memType = QNN_TENSORMEMTYPE_RAW;
+  range_tensor.v2.clientBuf = range_tensor_client_buf;
+  range_tensor.v2.isDynamicDimensions = nullptr;
+  LITERT_RETURN_STATUS_IF_QNN_NOT_OK(
+      graph_mapper.Qnn().Api()->tensorCreateGraphTensor(graph_mapper.QnnGraph(),
+                                                        &range_tensor));
+
+  Qnn_Param_t range_param = BuildDefaultParam();
+  range_param.paramType = QNN_PARAMTYPE_TENSOR;
+  range_param.name = "ranges";
+  range_param.tensorParam = range_tensor;
+
+  Qnn_Param_t strided_slice_params[] = {range_param};
+  dest.v1.inputTensors = qnn_op_ins;
+  dest.v1.numOfInputs = kSliceOpInputSize;
+  dest.v1.outputTensors = qnn_op_outs;
+  dest.v1.numOfOutputs = kSliceOpOutputSize;
+  dest.v1.numOfParams = kSliceOpParamSize;
+  dest.v1.params = strided_slice_params;
+
+  LITERT_RETURN_STATUS_IF_QNN_NOT_OK(
+      graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), dest));
+
+  LITERT_LOG(LITERT_INFO, "Legalized slice op", "");
+
+  return kLiteRtStatusOk;
+}
+
+}  // namespace litert::qnn
diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.h b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.h
new file mode 100644
index 00000000..9cf9fa77
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.h
@@ -0,0 +1,47 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SLICE_OP_LEGALIZATION_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SLICE_OP_LEGALIZATION_H_
+
+#include <cstdint>
+#include <memory>
+
+#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/cc/litert_model.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h"
+
+namespace litert::qnn {
+
+class SliceOpLegalization : public Legalization {
+ public:
+  SliceOpLegalization() = default;
+  ~SliceOpLegalization() = default;
+  using Ptr = std::unique_ptr<SliceOpLegalization>;
+  static Ptr Create() { return std::make_unique<SliceOpLegalization>(); }
+
+  LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest,
+                          GraphMapper& graph_mapper);
+
+ private:
+  // Counter to ensure unique op names.
+  uint32_t op_counter_ = 0;
+};
+
+}  // namespace litert::qnn
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SLICE_OP_LEGALIZATION_H_
diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/softmax_op_legalization.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/softmax_op_legalization.cc
new file mode 100644
index 00000000..1a2bd65e
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/softmax_op_legalization.cc
@@ -0,0 +1,100 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
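TFLite's softmax carries a "beta" option that scales the logits before normalization, and the legalization below forwards it unchanged as QNN's scalar "beta" param: softmax(x; beta)_i = exp(beta * x_i) / sum_j exp(beta * x_j). A small standalone reference implementation for intuition, not part of the delegate; with beta = 1.0f it reduces to the usual softmax:

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Reference softmax with the `beta` scaling forwarded to QNN below.
    std::vector<float> Softmax(const std::vector<float>& x, float beta) {
      const float max_x = *std::max_element(x.begin(), x.end());
      std::vector<float> out(x.size());
      float sum = 0.f;
      for (std::size_t i = 0; i < x.size(); ++i) {
        out[i] = std::exp(beta * (x[i] - max_x));  // Shift by max for stability.
        sum += out[i];
      }
      for (float& v : out) v /= sum;
      return out;
    }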
+ +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/softmax_op_legalization.h" + +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_logging.h" +#include "tflite/experimental/litert/c/litert_op_code.h" +#include "tflite/experimental/litert/c/litert_options.h" +#include "tflite/experimental/litert/cc/litert_macros.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/common.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" + +namespace litert::qnn { + +static constexpr absl::string_view kQnnSoftmaxOpTypeName = "Softmax"; +static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; +static constexpr absl::string_view kSoftmaxOpFmt = "softmax_%d"; + +static constexpr int kSoftmaxOpInputSize = 1; +static constexpr int kSoftmaxOpOutputSize = 1; +static constexpr int kSoftmaxOpParamSize = 1; + +LiteRtStatus SoftmaxOpLegalization::LegalizeOp(const Op& src, + Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) { + if (src.Code() != kLiteRtOpCodeTflSoftmax) { + return kLiteRtStatusLegalizeNoMatch; + } + DumpLegalization(*src.Get()); + std::string op_name = absl::StrFormat(kSoftmaxOpFmt, op_counter_++); + LITERT_RETURN_STATUS_IF_NOT_OK(SetOpInfo(op_name.c_str(), + kDefaultQnnOpPackageName.data(), + kQnnSoftmaxOpTypeName.data(), dest)); + + // QNN reduce softmax op expects 1 input tensor. + const auto op_ins = src.Inputs(); + LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, kSoftmaxOpInputSize, + QNN_TENSOR_INIT); + LITERT_RETURN_STATUS_IF_NOT_OK( + graph_mapper.LookupInScope(op_ins.front().Get(), qnn_op_ins[0])); + + // QNN softmax op expects 1 output tensor. + const auto op_outs = src.Outputs(); + LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs, kSoftmaxOpOutputSize, + QNN_TENSOR_INIT); + LITERT_RETURN_STATUS_IF_NOT_OK( + graph_mapper.LegalizeAndRegister(op_outs.front().Get(), qnn_op_outs[0])); + LITERT_RETURN_STATUS_IF_NOT_OK( + graph_mapper.PushToScope(op_outs.front().Get(), qnn_op_outs[0])); + + // Prepare QNN reduce softmax parameters. + + // Extract beta option from softmax op. 
+ float beta; + LITERT_RETURN_STATUS_IF_NOT_OK(LiteRtGetSoftmaxBetaOption(src.Get(), &beta)); + Qnn_Param_t beta_param = BuildDefaultParam(); + beta_param.paramType = QNN_PARAMTYPE_SCALAR; + beta_param.name = "beta"; + Qnn_Scalar_t keep_dims_scalar = QNN_SCALAR_INIT; + keep_dims_scalar.dataType = QNN_DATATYPE_FLOAT_32; + keep_dims_scalar.floatValue = beta; + beta_param.scalarParam = keep_dims_scalar; + + Qnn_Param_t reduce_softmax_params[] = {beta_param}; + dest.v1.inputTensors = qnn_op_ins; + dest.v1.numOfInputs = kSoftmaxOpInputSize; + dest.v1.outputTensors = qnn_op_outs; + dest.v1.numOfOutputs = kSoftmaxOpOutputSize; + dest.v1.numOfParams = kSoftmaxOpParamSize; + dest.v1.params = reduce_softmax_params; + + LITERT_RETURN_STATUS_IF_QNN_NOT_OK( + graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), dest)); + + LITERT_LOG(LITERT_INFO, "Legalized softmax op", ""); + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/softmax_op_legalization.h b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/softmax_op_legalization.h new file mode 100644 index 00000000..d18b1adb --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/softmax_op_legalization.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SOFTMAX_OP_LEGALIZATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SOFTMAX_OP_LEGALIZATION_H_ + +#include +#include + +#include +#include + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" + +namespace litert::qnn { + +class SoftmaxOpLegalization : public Legalization { + public: + SoftmaxOpLegalization() = default; + ~SoftmaxOpLegalization() = default; + using Ptr = std::unique_ptr; + static Ptr Create() { return std::make_unique(); } + + LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper); + + private: + // Counter to ensure unique op names. + uint32_t op_counter_ = 0; +}; + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SOFTMAX_OP_LEGALIZATION_H_ diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.cc new file mode 100644 index 00000000..44490932 --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.cc @@ -0,0 +1,51 @@ +// Copyright 2024 Google LLC. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.h" + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_op_code.h" +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/experimental/litert/cc/litert_macros.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" + +namespace litert::qnn { + +static constexpr absl::string_view kQnnSubOpTypeName = "ElementWiseSubtract"; +static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; +static constexpr absl::string_view kSubOpFmt = "sub_%d"; + +LiteRtStatus SubOpLegalization::LegalizeOp(const litert::Op& src, + Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) { + if (src.Code() != kLiteRtOpCodeTflSub) { + return kLiteRtStatusLegalizeNoMatch; + } + std::string op_name = absl::StrFormat(kSubOpFmt, op_counter_++); + LITERT_RETURN_STATUS_IF_NOT_OK(SetOpInfo(op_name.c_str(), + kDefaultQnnOpPackageName.data(), + kQnnSubOpTypeName.data(), dest)); + LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeSimpleOp(src, dest, graph_mapper)); + LITERT_LOG(LITERT_INFO, "Legalized sub op", ""); + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.h b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.h new file mode 100644 index 00000000..526ea8ec --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SUB_OP_LEGALIZATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SUB_OP_LEGALIZATION_H_ + +#include +#include + +#include +#include + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" + +namespace litert::qnn { + +class SubOpLegalization : public Legalization { + public: + SubOpLegalization() = default; + ~SubOpLegalization() = default; + using Ptr = std::unique_ptr; + static Ptr Create() { return std::make_unique(); } + + LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper); + + private: + // Counter to ensure unique op names. + uint32_t op_counter_ = 0; +}; + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SUB_OP_LEGALIZATION_H_ diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.cc new file mode 100644 index 00000000..2cede68f --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.cc @@ -0,0 +1,139 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
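The sum legalization below hands QNN's ReduceSum two pieces of information: a static "axes" tensor and a scalar "keep_dims" flag. Their only effect on shapes is that reduced axes either collapse to size 1 (keep_dims true) or disappear; for example, input shape [2, 3, 4] with axes = {1} gives [2, 1, 4] with keep_dims and [2, 4] without. A standalone sketch of that rule:

    #include <cstdint>
    #include <set>
    #include <vector>

    // Output shape of a reduce-sum over `axes`, mirroring keep_dims semantics.
    std::vector<int32_t> ReducedShape(const std::vector<int32_t>& dims,
                                      const std::set<int32_t>& axes,
                                      bool keep_dims) {
      std::vector<int32_t> out;
      for (int32_t d = 0; d < static_cast<int32_t>(dims.size()); ++d) {
        if (axes.count(d) == 0) {
          out.push_back(dims[d]);
        } else if (keep_dims) {
          out.push_back(1);  // Reduced axes collapse to size 1 when kept.
        }
      }
      return out;
    }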
+ +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.h" + +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_logging.h" +#include "tflite/experimental/litert/c/litert_op_code.h" +#include "tflite/experimental/litert/c/litert_options.h" +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/experimental/litert/cc/litert_macros.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/cc/litert_model_predicates.h" +#include "tflite/experimental/litert/vendors/qualcomm/common.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" + +namespace litert::qnn { + +static constexpr absl::string_view kQnnSumOpTypeName = "ReduceSum"; +static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; +static constexpr absl::string_view kSumOpFmt = "sum_%d"; + +static constexpr int kReduceSumOpInputSize = 1; +static constexpr int kReduceSumOpOutputSize = 1; +static constexpr int kReduceSumOpParamSize = 1; +static constexpr int kReduceSumOpParamRank = 1; + +LiteRtStatus SumOpLegalization::LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) { + if (src.Code() != kLiteRtOpCodeTflSum) { + return kLiteRtStatusLegalizeNoMatch; + } + DumpLegalization(*src.Get()); + std::string op_name = absl::StrFormat(kSumOpFmt, op_counter_++); + LITERT_RETURN_STATUS_IF_NOT_OK(SetOpInfo(op_name.c_str(), + kDefaultQnnOpPackageName.data(), + kQnnSumOpTypeName.data(), dest)); + + // QNN reduce sum op expects 1 input tensor. + LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, kReduceSumOpInputSize, + QNN_TENSOR_INIT); + LITERT_RETURN_STATUS_IF_NOT_OK( + graph_mapper.LookupInScope(src.Inputs().front().Get(), qnn_op_ins[0])); + + // QNN sum op expects 1 output tensor. + LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs, kReduceSumOpOutputSize, + QNN_TENSOR_INIT); + LITERT_RETURN_STATUS_IF_NOT_OK(graph_mapper.LegalizeAndRegister( + src.Outputs().front().Get(), qnn_op_outs[0])); + LITERT_RETURN_STATUS_IF_NOT_OK( + graph_mapper.PushToScope(src.Outputs().front().Get(), qnn_op_outs[0])); + + // Prepare QNN reduce sum parameters. + const auto inputs = src.Inputs(); + const auto& src_axes = inputs.at(1); + + // Check if src_axes are weights tensors. + if (!src_axes.HasWeights()) { + LITERT_LOG(LITERT_ERROR, "Sum op axes are not weights tensors"); + return kLiteRtStatusErrorInvalidLegalization; + } + int32_t dest_axes_size = src_axes.RankedTensorType().Layout().Dimensions()[0]; + auto src_axes_data = src_axes.Weights().Bytes(); + Qnn_ClientBuffer_t axes_tensor_client_buf = BuildDefaultClientBuffer(); + axes_tensor_client_buf.data = (void*)src_axes_data.data(); + axes_tensor_client_buf.dataSize = src_axes_data.size(); + + // Extract keepdims option from sum op. + bool keep_dims; + LITERT_RETURN_STATUS_IF_NOT_OK( + LiteRtGetSumKeepDimsOption(src.Get(), &keep_dims)); + + // Construct the scalar "keep_dims" param. 
+ if (keep_dims) { + Qnn_Param_t range_param = BuildDefaultParam(); + range_param.paramType = QNN_PARAMTYPE_SCALAR; + range_param.name = "keep_dims"; + Qnn_Scalar_t keep_dims_scalar = QNN_SCALAR_INIT; + keep_dims_scalar.dataType = QNN_DATATYPE_BOOL_8; + keep_dims_scalar.bool8Value = true; + range_param.scalarParam = keep_dims_scalar; + } + + // Construct the const tensor "axes". + Qnn_Tensor_t range_tensor = BuildDefaultTensor(); + graph_mapper.AssignTensorName(range_tensor); + range_tensor.v2.dataType = QNN_DATATYPE_INT_32; + range_tensor.v2.type = QNN_TENSOR_TYPE_STATIC; + range_tensor.v2.rank = kReduceSumOpParamRank; + range_tensor.v2.dimensions = new uint32_t[kReduceSumOpParamRank]; + range_tensor.v2.dimensions[0] = dest_axes_size; + range_tensor.v2.memType = QNN_TENSORMEMTYPE_RAW; + range_tensor.v2.clientBuf = axes_tensor_client_buf; + range_tensor.v2.isDynamicDimensions = nullptr; + LITERT_RETURN_STATUS_IF_QNN_NOT_OK( + graph_mapper.Qnn().Api()->tensorCreateGraphTensor(graph_mapper.QnnGraph(), + &range_tensor)); + + Qnn_Param_t range_param = BuildDefaultParam(); + range_param.paramType = QNN_PARAMTYPE_TENSOR; + range_param.name = "axes"; + range_param.tensorParam = range_tensor; + + Qnn_Param_t reduce_sum_params[] = {range_param}; + dest.v1.inputTensors = qnn_op_ins; + dest.v1.numOfInputs = kReduceSumOpInputSize; + dest.v1.outputTensors = qnn_op_outs; + dest.v1.numOfOutputs = kReduceSumOpOutputSize; + dest.v1.numOfParams = kReduceSumOpParamSize; + dest.v1.params = reduce_sum_params; + + LITERT_RETURN_STATUS_IF_QNN_NOT_OK( + graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), dest)); + + LITERT_LOG(LITERT_INFO, "Legalized sum op", ""); + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.h b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.h new file mode 100644 index 00000000..c4e407dc --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SUM_OP_LEGALIZATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SUM_OP_LEGALIZATION_H_ + +#include +#include + +#include +#include + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" + +namespace litert::qnn { + +class SumOpLegalization : public Legalization { + public: + SumOpLegalization() = default; + ~SumOpLegalization() = default; + using Ptr = std::unique_ptr; + static Ptr Create() { return std::make_unique(); } + + LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper); + + private: + // Counter to ensure unique op names. + uint32_t op_counter_ = 0; +}; + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SUM_OP_LEGALIZATION_H_ diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.cc new file mode 100644 index 00000000..92c1e0b3 --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.cc @@ -0,0 +1,51 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
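The tanh legalization below follows the same shape as rsqrt, sin, sub, and the other elementwise ops: match the op code, mint a unique name, set the QNN op identity, then let LegalizeSimpleOp wire the 1:1 inputs and outputs. As a sketch of how one more elementwise op could be added, here is a hypothetical floor legalization; FloorOpLegalization, kLiteRtOpCodeTflFloor, and the "ElementWiseFloor" type name are assumptions patterned on the sibling files, not part of this patch:

    static constexpr absl::string_view kQnnFloorOpTypeName = "ElementWiseFloor";

    LiteRtStatus FloorOpLegalization::LegalizeOp(const litert::Op& src,
                                                 Qnn_OpConfig_t& dest,
                                                 GraphMapper& graph_mapper) {
      // 1. Match on the op code; "no match" lets the caller try other rules.
      if (src.Code() != kLiteRtOpCodeTflFloor) {
        return kLiteRtStatusLegalizeNoMatch;
      }
      // 2. Give the node a unique name.
      std::string op_name = absl::StrFormat("floor_%d", op_counter_++);
      // 3. Fill in the QNN op identity (name, package, type).
      LITERT_RETURN_STATUS_IF_NOT_OK(SetOpInfo(op_name.c_str(),
                                               kDefaultQnnOpPackageName.data(),
                                               kQnnFloorOpTypeName.data(), dest));
      // 4. Wire 1:1 inputs/outputs and add the node to the graph.
      LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeSimpleOp(src, dest, graph_mapper));
      return kLiteRtStatusOk;
    }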
+ +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.h" + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_op_code.h" +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/experimental/litert/cc/litert_macros.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" + +namespace litert::qnn { + +static constexpr absl::string_view kQnnTanhOpTypeName = "Tanh"; +static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; +static constexpr absl::string_view kTanhOpFmt = "tanh_%d"; + +LiteRtStatus TanhOpLegalization::LegalizeOp(const litert::Op& src, + Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) { + if (src.Code() != kLiteRtOpCodeTflTanh) { + return kLiteRtStatusLegalizeNoMatch; + } + std::string op_name = absl::StrFormat(kTanhOpFmt, op_counter_++); + LITERT_RETURN_STATUS_IF_NOT_OK(SetOpInfo(op_name.c_str(), + kDefaultQnnOpPackageName.data(), + kQnnTanhOpTypeName.data(), dest)); + LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeSimpleOp(src, dest, graph_mapper)); + LITERT_LOG(LITERT_INFO, "Legalized tanh op", ""); + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.h b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.h new file mode 100644 index 00000000..c20b6e0e --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_TANH_OP_LEGALIZATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_TANH_OP_LEGALIZATION_H_ + +#include +#include + +#include +#include + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" + +namespace litert::qnn { + +class TanhOpLegalization : public Legalization { + public: + TanhOpLegalization() = default; + ~TanhOpLegalization() = default; + using Ptr = std::unique_ptr; + static Ptr Create() { return std::make_unique(); } + + LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper); + + private: + // Counter to ensure unique op names. 
+ uint32_t op_counter_ = 0; +}; + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_TANH_OP_LEGALIZATION_H_ diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/transpose_op_legalization.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/transpose_op_legalization.cc new file mode 100644 index 00000000..2575b4ef --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/transpose_op_legalization.cc @@ -0,0 +1,121 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/transpose_op_legalization.h" + +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_logging.h" +#include "tflite/experimental/litert/c/litert_op_code.h" +#include "tflite/experimental/litert/cc/litert_macros.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/common.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" + +namespace litert::qnn { + +static constexpr absl::string_view kQnnTransposeOpTypeName = "Transpose"; +static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; +static constexpr absl::string_view kTransposeOpFmt = "transpose_%d"; + +static constexpr int kTransposeOpInputSize = 1; +static constexpr int kTransposeOpOutputSize = 1; +static constexpr int kTransposeOpParamSize = 1; +static constexpr int kTransposeOpParamRank = 1; + +LiteRtStatus TransposeOpLegalization::LegalizeOp(const Op& src, + Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) { + if (src.Code() != kLiteRtOpCodeTflTranspose) { + return kLiteRtStatusLegalizeNoMatch; + } + DumpLegalization(*src.Get()); + std::string op_name = absl::StrFormat(kTransposeOpFmt, op_counter_++); + LITERT_RETURN_STATUS_IF_NOT_OK( + SetOpInfo(op_name.c_str(), kDefaultQnnOpPackageName.data(), + kQnnTransposeOpTypeName.data(), dest)); + + // QNN transpose op expects 1 input tensor. + const auto op_ins = src.Inputs(); + LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, kTransposeOpInputSize, + QNN_TENSOR_INIT); + LITERT_RETURN_STATUS_IF_NOT_OK( + graph_mapper.LookupInScope(op_ins.front().Get(), qnn_op_ins[0])); + + // QNN transpose op expects 1 output tensor. 
+  const auto op_outs = src.Outputs();
+  LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs, kTransposeOpOutputSize,
+                     QNN_TENSOR_INIT);
+  LITERT_RETURN_STATUS_IF_NOT_OK(
+      graph_mapper.LegalizeAndRegister(op_outs.front().Get(), qnn_op_outs[0]));
+  LITERT_RETURN_STATUS_IF_NOT_OK(
+      graph_mapper.PushToScope(op_outs.front().Get(), qnn_op_outs[0]));
+
+  // Prepare QNN transpose parameters.
+  auto perm = Tensor(op_ins.at(1).Get());
+
+  // The permutation must be a static (weights) tensor.
+  if (!perm.HasWeights()) {
+    return kLiteRtStatusErrorInvalidLegalization;
+  }
+  auto perm_data = perm.Weights().Bytes();
+  const int32_t perm_size_bytes = perm_data.size();
+  // The buffer holds int32 values, so the element count (and hence the tensor
+  // dimension) is the byte size divided by sizeof(int32_t).
+  const int32_t perm_num_elements = perm_size_bytes / sizeof(int32_t);
+  Qnn_ClientBuffer_t perm_tensor_client_buf = BuildDefaultClientBuffer();
+  perm_tensor_client_buf.data = (void*)perm_data.data();
+  perm_tensor_client_buf.dataSize = perm_size_bytes;
+
+  // Construct the const tensor "perm".
+  Qnn_Tensor_t perm_tensor = BuildDefaultTensor();
+  graph_mapper.AssignTensorName(perm_tensor);
+  perm_tensor.v2.dataType = QNN_DATATYPE_INT_32;
+  perm_tensor.v2.type = QNN_TENSOR_TYPE_STATIC;
+  perm_tensor.v2.rank = kTransposeOpParamRank;
+  perm_tensor.v2.dimensions = new uint32_t[kTransposeOpParamRank];
+  perm_tensor.v2.dimensions[0] = perm_num_elements;
+  perm_tensor.v2.memType = QNN_TENSORMEMTYPE_RAW;
+  perm_tensor.v2.clientBuf = perm_tensor_client_buf;
+  perm_tensor.v2.isDynamicDimensions = nullptr;
+  LITERT_RETURN_STATUS_IF_QNN_NOT_OK(
+      graph_mapper.Qnn().Api()->tensorCreateGraphTensor(graph_mapper.QnnGraph(),
+                                                        &perm_tensor));
+
+  Qnn_Param_t perm_param = BuildDefaultParam();
+  perm_param.paramType = QNN_PARAMTYPE_TENSOR;
+  perm_param.name = "perm";
+  perm_param.tensorParam = perm_tensor;
+
+  Qnn_Param_t transpose_params[] = {perm_param};
+  dest.v1.inputTensors = qnn_op_ins;
+  dest.v1.numOfInputs = kTransposeOpInputSize;
+  dest.v1.outputTensors = qnn_op_outs;
+  dest.v1.numOfOutputs = kTransposeOpOutputSize;
+  dest.v1.numOfParams = kTransposeOpParamSize;
+  dest.v1.params = transpose_params;
+
+  LITERT_RETURN_STATUS_IF_QNN_NOT_OK(
+      graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), dest));
+
+  LITERT_LOG(LITERT_INFO, "Legalized transpose op", "");
+  return kLiteRtStatusOk;
+}
+
+}  // namespace litert::qnn
diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/transpose_op_legalization.h b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/transpose_op_legalization.h
new file mode 100644
index 00000000..8dcfbdd1
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/transpose_op_legalization.h
@@ -0,0 +1,49 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
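The "perm" tensor built above is the permutation of input dimensions, so output dimension i is input dimension perm[i]; for example dims [1, 2, 3] with perm = [0, 2, 1] yield [1, 3, 2]. A standalone sketch of that shape rule:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Apply a transpose permutation to a shape: out[i] = dims[perm[i]].
    std::vector<int32_t> PermuteShape(const std::vector<int32_t>& dims,
                                      const std::vector<int32_t>& perm) {
      std::vector<int32_t> out(perm.size());
      for (std::size_t i = 0; i < perm.size(); ++i) {
        out[i] = dims[perm[i]];
      }
      return out;
    }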
+ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_TRANSPOSE_OP_LEGALIZATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_TRANSPOSE_OP_LEGALIZATION_H_ + +#include +#include + +#include +#include + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" + +namespace litert::qnn { + +class TransposeOpLegalization : public Legalization { + public: + TransposeOpLegalization() = default; + ~TransposeOpLegalization() = default; + using Ptr = std::unique_ptr; + static Ptr Create() { return std::make_unique(); } + + LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper); + + private: + // Counter to ensure unique op names. + uint32_t op_counter_ = 0; +}; + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_TRANSPOSE_OP_LEGALIZATION_H_ diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.cc new file mode 100644 index 00000000..56ea9c1d --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.cc @@ -0,0 +1,87 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" + +#include + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_logging.h" +#include "tflite/experimental/litert/c/litert_model.h" +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/experimental/litert/cc/litert_macros.h" +#include "tflite/experimental/litert/cc/litert_model.h" +#include "tflite/experimental/litert/cc/litert_model_predicates.h" +#include "tflite/experimental/litert/tools/dump.h" +#include "tflite/experimental/litert/vendors/qualcomm/common.h" +#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tflite/experimental/litert/vendors/qualcomm/qnn_manager.h" + +namespace litert::qnn { + +using ::litert::internal::Dump; +using ::litert::internal::DumpOptions; + +// Dump source Op details. +void DumpLegalization(const LiteRtOpT& op) { + std::ostringstream dump; + // TODO Make dump tools part of stable api. + Dump(op, dump); + DumpOptions(op, dump); + std::string s = dump.str(); + LITERT_LOG(LITERT_INFO, "%s", s.data()); +} + +LiteRtStatus LegalizeSimpleOp(const Op& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) { + DumpLegalization(*src.Get()); + // Look up op input tensors in scope. 
+  const auto op_ins = src.Inputs();
+  LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, op_ins.size(), QNN_TENSOR_INIT);
+
+  Qnn_Tensor_t* cur_qnn_op_in = qnn_op_ins;
+  for (const auto& op_in : op_ins) {
+    LITERT_RETURN_STATUS_IF_NOT_OK(
+        graph_mapper.LookupInScope(op_in.Get(), *cur_qnn_op_in));
+    ++cur_qnn_op_in;
+  }
+
+  // Legalize op outputs and update scope.
+
+  const auto op_outs = src.Outputs();
+  LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs, op_outs.size(),
+                     QNN_TENSOR_INIT);
+
+  Qnn_Tensor_t* cur_qnn_op_out = qnn_op_outs;
+  for (const auto& op_out : op_outs) {
+    LITERT_RETURN_STATUS_IF_NOT_OK(
+        graph_mapper.LegalizeAndRegister(op_out.Get(), *cur_qnn_op_out));
+    LITERT_RETURN_STATUS_IF_NOT_OK(
+        graph_mapper.PushToScope(op_out.Get(), *cur_qnn_op_out));
+    ++cur_qnn_op_out;
+  }
+
+  dest.v1.numOfInputs = op_ins.size();
+  dest.v1.inputTensors = qnn_op_ins;
+
+  dest.v1.numOfOutputs = op_outs.size();
+  dest.v1.outputTensors = qnn_op_outs;
+
+  LITERT_RETURN_STATUS_IF_QNN_NOT_OK(
+      graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), dest));
+
+  return kLiteRtStatusOk;
+}
+
+}  // namespace litert::qnn
diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h
new file mode 100644
index 00000000..ea287527
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h
@@ -0,0 +1,39 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_UTIL_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_UTIL_H_
+
+#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/cc/litert_model.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
+
+namespace litert::qnn {
+
+// Use this function to legalize a LiteRtOp to a QNN op when:
+// 1. Source input/output tensors and destination input/output tensors are
+//    mapped 1:1.
+// 2. Assigning params to the destination op does not depend on the input
+//    tensors of the source op.
+LiteRtStatus LegalizeSimpleOp(const Op& src, Qnn_OpConfig_t& dest,
+                              GraphMapper& graph_mapper);
+
+// Dump source Op details.
+void DumpLegalization(const LiteRtOpT& op);
+
+}  // namespace litert::qnn
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_UTIL_H_
diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin.cc
new file mode 100644
index 00000000..73cfc72d
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin.cc
@@ -0,0 +1,291 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdio.h>
+
+#include <cstddef>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "third_party/qairt/latest/include/QNN/HTP/QnnHtpDevice.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_logging.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/c/litert_op_code.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/cc/litert_macros.h"
+#include "tflite/experimental/litert/cc/litert_model.h"
+#include "tflite/experimental/litert/vendors/c/litert_compiler_plugin.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.h"
+#include "tflite/experimental/litert/vendors/qualcomm/qnn_manager.h"
+
+using ::litert::qnn::QnnManager;
+
+//
+// Configurations
+//
+
+namespace {
+
+constexpr char kPluginManufacturer[] = "Qualcomm";
+
+// clang-format off
+constexpr std::pair<const char*, QnnHtpDevice_Arch_t> kPluginSocModels[] = {
+    {"V68", QNN_HTP_DEVICE_ARCH_V68},
+    {"V69", QNN_HTP_DEVICE_ARCH_V69},
+    {"V73", QNN_HTP_DEVICE_ARCH_V73},
+    {"V75", QNN_HTP_DEVICE_ARCH_V75},
+    {"V79", QNN_HTP_DEVICE_ARCH_V79},
+};
+
+constexpr LiteRtOpCode kSupportedOps[] = {
+    kLiteRtOpCodeTflAdd,
+    kLiteRtOpCodeTflDiv,
+    kLiteRtOpCodeTflMul,
+    kLiteRtOpCodeTflRsqrt,
+    kLiteRtOpCodeTflSlice,
+    kLiteRtOpCodeTflSelect,
+    kLiteRtOpCodeTflSelectV2,
+    kLiteRtOpCodeTflSub,
+    kLiteRtOpCodeTflTanh,
+    kLiteRtOpCodeTflBatchMatmul,
+    kLiteRtOpCodeTflReshape,
+    kLiteRtOpCodeTflSum,
+    kLiteRtOpCodeTflConcatenation,
+    kLiteRtOpCodeTflSoftmax,
+    kLiteRtOpCodeTflCast,
+    kLiteRtOpCodeTflTranspose,
+    kLiteRtOpCodeTflSin,
+    kLiteRtOpCodeTflCos,
+    kLiteRtOpCodeTflFullyConnected,
+    kLiteRtOpCodeTflEmbeddingLookup,
+    kLiteRtOpCodeTflLogicalAnd,
+    kLiteRtOpCodeTflLess,
+    kLiteRtOpCodeTflGreater,
+};
+// clang-format on
+
+constexpr auto kNumPluginSocModels =
+    sizeof(kPluginSocModels) / sizeof(kPluginSocModels[0]);
+
+std::optional<QnnHtpDevice_Arch_t> FindSocModel(
+    absl::string_view soc_model_name) {
+  std::optional<QnnHtpDevice_Arch_t> soc_model;
+  for (auto i = 0; i < kNumPluginSocModels; ++i) {
+    if (soc_model_name == kPluginSocModels[i].first) {
+      soc_model = kPluginSocModels[i].second;
+      break;
+    }
+  }
+  return soc_model;
+}
+
+}  // namespace
+
+LiteRtStatus LiteRtGetCompilerPluginVersion(LiteRtApiVersion* api_version) {
+  if (api_version == nullptr) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  api_version->major = LITERT_API_VERSION_MAJOR;
+  api_version->minor = LITERT_API_VERSION_MINOR;
+  api_version->patch = LITERT_API_VERSION_PATCH;
+  return kLiteRtStatusOk;
+}
+
+const char* LiteRtGetCompilerPluginSocManufacturer() {
+  return kPluginManufacturer;
+}
+
+LiteRtStatus LiteRtGetNumCompilerPluginSupportedSocModels(
+    LiteRtCompilerPlugin compiler_plugin,
+    LiteRtParamIndex* num_supported_soc_models) {
+  if (!compiler_plugin || !num_supported_soc_models) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *num_supported_soc_models = kNumPluginSocModels;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetCompilerPluginSupportedSocModel(
+    LiteRtCompilerPlugin compiler_plugin, LiteRtParamIndex soc_model_idx,
+    const char** soc_model_name) {
+  if (!compiler_plugin || !soc_model_name) {
+    return kLiteRtStatusErrorInvalidArgument;
+  } else if (soc_model_idx < 0 || soc_model_idx >= kNumPluginSocModels) {
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+  *soc_model_name = kPluginSocModels[soc_model_idx].first;
+  return kLiteRtStatusOk;
+}
+
+//
+// Compiled Result Definition
+//
+
+struct LiteRtCompiledResultT {
+  std::vector<char> context_bin;
+  std::vector<std::string> graph_names;
+};
+
+LiteRtStatus LiteRtGetCompiledResultByteCode(
+    LiteRtCompiledResult compiled_result, const void** byte_code,
+    size_t* byte_code_size) {
+  *byte_code = compiled_result->context_bin.data();
+  *byte_code_size = compiled_result->context_bin.size();
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetCompiledResultCallInfo(
+    LiteRtCompiledResult compiled_result, LiteRtParamIndex call_idx,
+    const void** call_info, size_t* call_info_size) {
+  if (call_idx >= compiled_result->graph_names.size()) {
+    return kLiteRtStatusErrorIndexOOB;
+  }
+
+  *call_info = compiled_result->graph_names.at(call_idx).data();
+  *call_info_size = compiled_result->graph_names.at(call_idx).size();
+
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtGetNumCompiledResultCalls(
+    LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_calls) {
+  *num_calls = compiled_result->graph_names.size();
+  return kLiteRtStatusOk;
+}
+
+void LiteRtDestroyCompiledResult(LiteRtCompiledResult compiled_result) {
+  delete compiled_result;
+}
+
+//
+// Plugin Definition
+//
+
+// Plugins can hold state.
+struct LiteRtCompilerPluginT {};
+
+LiteRtStatus LiteRtCreateCompilerPlugin(LiteRtCompilerPlugin* compiler_plugin) {
+  auto* plugin = new LiteRtCompilerPluginT;
+  *compiler_plugin = plugin;
+  return kLiteRtStatusOk;
+}
+
+void LiteRtDestroyCompilerPlugin(LiteRtCompilerPlugin compiler_plugin) {
+  delete compiler_plugin;
+}
+
+namespace {
+
+// TODO: Update this function to match the new legalizations.
+bool IsOpSupported(const litert::Op& op) {
+  // NOTE: Currently we are demoing by just mapping simple f32 mul ops.
+  // In the limit this function will want to leverage the QNN SDK's
+  // getSupportedOps feature (along with our op/type mappings).
+  // Use a very loose guard for now -- only checking if the op code is
+  // supported.
+  for (auto supported_op : kSupportedOps) {
+    if (op.Code() == supported_op) {
+      return true;
+    }
+  }
+  return false;
+}
+
+}  // namespace
+
+LiteRtStatus LiteRtCompilerPluginPartition(LiteRtCompilerPlugin compiler_plugin,
+                                           LiteRtSubgraph subgraph,
+                                           LiteRtOpList selected_ops) {
+  ::litert::Subgraph graph(subgraph);
+  for (const auto& op : graph.Ops()) {
+    if (!IsOpSupported(op)) {
+      continue;
+    }
+
+    LITERT_RETURN_STATUS_IF_NOT_OK(LiteRtPushOp(selected_ops, op.Get()));
+  }
+
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus LiteRtCompilerPluginCompile(
+    LiteRtCompilerPlugin compiler_plugin, const char* soc_model,
+    LiteRtSubgraphArray partitions, LiteRtParamIndex num_partitions,
+    LiteRtCompiledResult* compiled_result) {
+  LITERT_LOG(LITERT_INFO,
+             "Starting QNN Compilation for %d subgraphs, soc_model=%s",
+             num_partitions, soc_model);
+
+  auto opt_soc_model = soc_model ? FindSocModel(soc_model) : std::nullopt;
+  if (opt_soc_model) {
+    LITERT_LOG(LITERT_ERROR, "Compiling QNN architecture: %d", *opt_soc_model);
+  } else if (soc_model) {
+    LITERT_LOG(LITERT_ERROR, "Unexpected SoC model: %s", soc_model);
+    return kLiteRtStatusErrorInvalidArgument;
+  }
+
+  // Initialize SDK and load qnn shared libraries.
+
+  LITERT_LOG(LITERT_INFO, "%s", "Creating QNN manager");
+  auto backend_configs = QnnManager::DefaultBackendConfigs();
+  auto qnn_manager = QnnManager::Create(
+      backend_configs, /*shared_library_dir=*/std::nullopt, opt_soc_model);
+  if (!qnn_manager) {
+    LITERT_LOG(LITERT_ERROR, "%s", qnn_manager.Error().Message().data());
+    return qnn_manager.Error().Status();
+  }
+  LITERT_LOG(LITERT_INFO, "%s", "QNN manager created");
+
+  // Initialize context.
+
+  LITERT_LOG(LITERT_INFO, "%s", "Creating context handle");
+  auto context_configs = QnnManager::DefaultContextConfigs();
+  auto context_handle = (*qnn_manager)->CreateContextHandle(context_configs);
+  if (!context_handle) {
+    LITERT_LOG(LITERT_ERROR, "%s", context_handle.Error().Message().data());
+    return context_handle.Error().Status();
+  }
+  LITERT_LOG(LITERT_INFO, "%s", "Context handle created");
+
+  auto result = std::make_unique<LiteRtCompiledResultT>();
+
+  // Compose graphs.
+
+  LITERT_LOG(LITERT_INFO, "%s", "Composing graph(s)");
+  // TODO: Support multiple partitions in QCC plugin compile.
+  LITERT_ENSURE_SUPPORTED(num_partitions, 1);
+  {
+    std::string& entry_point_name = result->graph_names.emplace_back();
+    entry_point_name = "qnn_partition_0";
+    LITERT_RETURN_STATUS_IF_NOT_OK(litert::qnn::ComposeGraph(
+        **qnn_manager, context_handle->get(), partitions[0], entry_point_name));
+  }
+  LITERT_LOG(LITERT_INFO, "%s", "Graph composed");
+
+  // Generate context binary.
+
+  LITERT_LOG(LITERT_INFO, "%s", "Generating context binary");
+  LITERT_RETURN_STATUS_IF_NOT_OK(
+      (*qnn_manager)
+          ->GenerateContextBinary(context_handle->get(), result->context_bin));
+  LITERT_LOG(LITERT_INFO, "%s", "Context binary generated");
+
+  *compiled_result = result.release();
+
+  return kLiteRtStatusOk;
+}
diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin_test.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin_test.cc
new file mode 100644
index 00000000..9c5f800a
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin_test.cc
@@ -0,0 +1,190 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
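Taken together, the partition and compile entry points above form the whole plugin contract. As a hedged sketch of how a host might drive them for a single subgraph, assuming the subgraph comes from an already-loaded model (the helper name CompileSubgraphToBytes is hypothetical, not part of this change):

#include <cstddef>
#include <vector>

#include "tflite/experimental/litert/vendors/c/litert_compiler_plugin.h"

// Hypothetical helper: compiles one subgraph for a given SoC and copies the
// context binary out of the compiled result before releasing it.
LiteRtStatus CompileSubgraphToBytes(LiteRtSubgraph subgraph,
                                    const char* soc_model,
                                    std::vector<char>& bytes_out) {
  LiteRtCompilerPlugin plugin = nullptr;
  LiteRtStatus status = LiteRtCreateCompilerPlugin(&plugin);
  if (status != kLiteRtStatusOk) return status;

  LiteRtCompiledResult compiled = nullptr;
  status = LiteRtCompilerPluginCompile(plugin, soc_model, &subgraph,
                                       /*num_partitions=*/1, &compiled);
  if (status == kLiteRtStatusOk) {
    const void* data = nullptr;
    size_t size = 0;
    status = LiteRtGetCompiledResultByteCode(compiled, &data, &size);
    if (status == kLiteRtStatusOk) {
      const char* begin = static_cast<const char*>(data);
      bytes_out.assign(begin, begin + size);
    }
    LiteRtDestroyCompiledResult(compiled);
  }
  LiteRtDestroyCompilerPlugin(plugin);
  return status;
}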
+ +#include +#include + +#include +#include "absl/strings/string_view.h" +#include "tflite/experimental/litert/c/litert_logging.h" +#include "tflite/experimental/litert/c/litert_model.h" +#include "tflite/experimental/litert/c/litert_op_code.h" +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/experimental/litert/cc/litert_macros.h" +#include "tflite/experimental/litert/core/model/model.h" +#include "tflite/experimental/litert/test/common.h" +#include "tflite/experimental/litert/test/test_macros.h" +#include "tflite/experimental/litert/test/test_models.h" +#include "tflite/experimental/litert/vendors/c/litert_compiler_plugin.h" +#include "tflite/experimental/litert/vendors/cc/litert_compiler_plugin.h" + +namespace litert { +namespace { + +using ::testing::Values; + +// clang-format off +const auto kSupportedOps = + Values( + "simple_add_op.tflite", + "simple_div_op.tflite", + "simple_mul_op.tflite", + "simple_rsqrt_op.tflite", + "simple_slice_op.tflite", + "simple_sub_op.tflite", + "simple_sum_op.tflite", + "simple_tanh_op.tflite", + "simple_reshape_op.tflite", + "simple_batch_matmul_op.tflite", + "rms_norm.tflite", + "simple_concatenation_op.tflite", + "simple_softmax_op.tflite", + "simple_cast_op.tflite", + "simple_transpose_op.tflite", + "simple_sin_op.tflite", + "simple_cos_op.tflite", + "simple_select_op.tflite", + "simple_select_v2_op.tflite", + "simple_fully_connected_op.tflite", + "fully_connected_3d.tflite", + "simple_embedding_lookup_op.tflite", + "simple_logical_and_op.tflite", + "simple_less_op.tflite", + "simple_greater_op.tflite", + kFeedForwardModel, + kKeyEinsumModel, + kQueryEinsumModel, + kValueEinsumModel, + kAttnVecEinsumModel, + kROPEModel, + kLookUpROPEModel, + kRMSNormModel, + kSDPAModel, + kAttentionModel, + kTransformerBlockModel, + kQSimpleMul16x16Model, + kQMulAdd16x16Model, + kQQueryEinsum16x8Model, + kQKeyEinsum16x8Model, + kQVauleEinsum16x8Model, + kQAttnVecEinsum16x8Model + ); +// clang-format on + +TEST(TestQnnPlugin, GetConfigInfo) { + EXPECT_STREQ(LiteRtGetCompilerPluginSocManufacturer(), "Qualcomm"); + + auto plugin = CreatePlugin(); + + LiteRtParamIndex num_supported_soc_models; + LITERT_ASSERT_STATUS_OK(LiteRtGetNumCompilerPluginSupportedSocModels( + plugin.get(), &num_supported_soc_models)); + ASSERT_EQ(num_supported_soc_models, 5); + + const char* config_id; + LITERT_CHECK_STATUS_OK( + LiteRtGetCompilerPluginSupportedSocModel(plugin.get(), 0, &config_id)); + EXPECT_STREQ(config_id, "V68"); +} + +TEST(TestQnnPlugin, PartitionMulOps) { + auto plugin = CreatePlugin(); + auto model = testing::LoadTestFileModel("one_mul.tflite"); + + LiteRtOpListT selected_op_list; + LITERT_ASSERT_STATUS_OK(LiteRtCompilerPluginPartition( + plugin.get(), model.Subgraph(0)->Get(), &selected_op_list)); + const auto selected_ops = selected_op_list.Vec(); + + ASSERT_EQ(selected_ops.size(), 1); + EXPECT_EQ(selected_ops[0]->OpCode(), kLiteRtOpCodeTflMul); +} + +TEST(TestQnnPlugin, CompileMulSubgraph) { + auto plugin = CreatePlugin(); + auto model = testing::LoadTestFileModel("one_mul.tflite"); + + const auto subgraph = model.MainSubgraph(); + LiteRtSubgraph litert_subgraph = subgraph->Get(); + + LiteRtCompiledResult compiled; + LITERT_ASSERT_STATUS_OK(LiteRtCompilerPluginCompile( + plugin.get(), "V75", &litert_subgraph, 1, &compiled)); + + const void* byte_code; + size_t byte_code_size; + + LITERT_ASSERT_STATUS_OK( + LiteRtGetCompiledResultByteCode(compiled, &byte_code, &byte_code_size)); + + absl::string_view 
byte_code_string(reinterpret_cast<const char*>(byte_code),
+                                     byte_code_size);
+  ASSERT_FALSE(byte_code_string.empty());
+
+  const void* op_data;
+  size_t op_data_size;
+
+  LITERT_ASSERT_STATUS_OK(
+      LiteRtGetCompiledResultCallInfo(compiled, 0, &op_data, &op_data_size));
+
+  absl::string_view op_data_string(reinterpret_cast<const char*>(op_data),
+                                   op_data_size);
+  ASSERT_EQ("qnn_partition_0", op_data_string);
+
+  LiteRtDestroyCompiledResult(compiled);
+}
+
+class QnnPluginOpCompatibilityTest
+    : public ::testing::TestWithParam<std::string> {};
+
+TEST_P(QnnPluginOpCompatibilityTest, SupportedOpsTest) {
+  LITERT_LOG(LITERT_INFO, "Testing TFLite model: %s", GetParam().c_str());
+  auto plugin = CreatePlugin();
+  auto model = testing::LoadTestFileModel(GetParam());
+
+  const auto subgraph = model.MainSubgraph();
+  LiteRtSubgraph litert_subgraph = subgraph->Get();
+
+  LiteRtCompiledResult compiled;
+  LITERT_ASSERT_STATUS_OK(LiteRtCompilerPluginCompile(
+      plugin.get(), "V75", &litert_subgraph, 1, &compiled));
+
+  const void* byte_code;
+  size_t byte_code_size;
+
+  LITERT_ASSERT_STATUS_OK(
+      LiteRtGetCompiledResultByteCode(compiled, &byte_code, &byte_code_size));
+
+  absl::string_view byte_code_string(reinterpret_cast<const char*>(byte_code),
+                                     byte_code_size);
+  ASSERT_FALSE(byte_code_string.empty());
+
+  const void* op_data;
+  size_t op_data_size;
+
+  LITERT_ASSERT_STATUS_OK(
+      LiteRtGetCompiledResultCallInfo(compiled, 0, &op_data, &op_data_size));
+
+  absl::string_view op_data_string(reinterpret_cast<const char*>(op_data),
+                                   op_data_size);
+  ASSERT_EQ("qnn_partition_0", op_data_string);
+
+  LiteRtDestroyCompiledResult(compiled);
+}
+
+INSTANTIATE_TEST_SUITE_P(SupportedOpsTest, QnnPluginOpCompatibilityTest,
+                         kSupportedOps);
+
+}  // namespace
+}  // namespace litert
diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.cc b/tflite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.cc
new file mode 100644
index 00000000..24e5547d
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.cc
@@ -0,0 +1,173 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
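The tests above always read call index 0, but a compiled result can expose several entry points (one per partition in the general case). A hedged sketch of walking the full call-info list, assuming the hypothetical helper name GetEntryPoints:

#include <cstddef>
#include <string>
#include <vector>

#include "tflite/experimental/litert/vendors/c/litert_compiler_plugin.h"

std::vector<std::string> GetEntryPoints(LiteRtCompiledResult compiled) {
  std::vector<std::string> names;
  LiteRtParamIndex num_calls = 0;
  if (LiteRtGetNumCompiledResultCalls(compiled, &num_calls) !=
      kLiteRtStatusOk) {
    return names;
  }
  for (LiteRtParamIndex i = 0; i < num_calls; ++i) {
    const void* info = nullptr;
    size_t size = 0;
    if (LiteRtGetCompiledResultCallInfo(compiled, i, &info, &size) ==
        kLiteRtStatusOk) {
      // Call info is the entry point (graph) name, e.g. "qnn_partition_0".
      names.emplace_back(static_cast<const char*>(info), size);
    }
  }
  return names;
}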
+
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.h"
+
+#include <memory>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "third_party/qairt/latest/include/QNN/QnnCommon.h"
+#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_logging.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/cc/litert_macros.h"
+#include "tflite/experimental/litert/cc/litert_model.h"
+#include "tflite/experimental/litert/vendors/qualcomm/common.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/cast_op_legalization.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/concatenation_op_legalization.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/cos_op_legalization.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/embedding_lookup_op_legalization.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/fully_connected_op_legalization.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/greater_op_legalization.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/less_op_legalization.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/logical_and_op_legalization.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/select_op_legalization.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sin_op_legalization.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/softmax_op_legalization.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.h"
+#include "tflite/experimental/litert/vendors/qualcomm/compiler/legalizations/transpose_op_legalization.h"
+#include "tflite/experimental/litert/vendors/qualcomm/qnn_manager.h"
+
+namespace litert::qnn {
+
+namespace {
+
+LiteRtStatus RegisterAllLegalizations(
+    std::vector<std::unique_ptr<Legalization>>& legalizations) {
+  legalizations.push_back(MulOpLegalization::Create());
+  legalizations.push_back(BatchMatmulOpLegalization::Create());
+  legalizations.push_back(SliceOpLegalization::Create());
+  legalizations.push_back(AddOpLegalization::Create());
+  legalizations.push_back(DivOpLegalization::Create());
+  legalizations.push_back(RsqrtOpLegalization::Create());
+  legalizations.push_back(TanhOpLegalization::Create());
+  legalizations.push_back(SubOpLegalization::Create());
+  legalizations.push_back(ReshapeOpLegalization::Create());
+  legalizations.push_back(SumOpLegalization::Create());
+  legalizations.push_back(ConcatenationOpLegalization::Create());
+  legalizations.push_back(SoftmaxOpLegalization::Create());
+  legalizations.push_back(CastOpLegalization::Create());
+  legalizations.push_back(TransposeOpLegalization::Create());
+  legalizations.push_back(SinOpLegalization::Create());
+  legalizations.push_back(CosOpLegalization::Create());
+  legalizations.push_back(SelectOpLegalization::Create());
+  legalizations.push_back(FullyConnectedOpLegalization::Create());
+  legalizations.push_back(EmbeddingLookupOpLegalization::Create());
+  legalizations.push_back(LogicalAndOpLegalization::Create());
+  legalizations.push_back(LessOpLegalization::Create());
+  legalizations.push_back(GreaterOpLegalization::Create());
+  LITERT_LOG(LITERT_INFO, "Scheduling %lu legalizations", legalizations.size());
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus MapGraph(QnnManager& qnn, Qnn_ContextHandle_t context_handle,
+                      LiteRtSubgraph subgraph,
+                      absl::string_view qnn_graph_name) {
+  // Register legalizations.
+  std::vector<std::unique_ptr<Legalization>> legalizations;
+  LITERT_RETURN_STATUS_IF_NOT_OK(RegisterAllLegalizations(legalizations));
+
+  GraphMapper graph_mapper(subgraph, qnn, context_handle);
+  LITERT_RETURN_STATUS_IF_NOT_OK(graph_mapper.IsLiteRtSubgraphSupported());
+  LITERT_RETURN_STATUS_IF_NOT_OK(graph_mapper.InitQnnGraph(qnn_graph_name));
+
+  //
+  // Legalize subgraph inputs and update tensors in scope
+  //
+
+  for (const auto& subgraph_input : graph_mapper.Graph().Inputs()) {
+    Qnn_Tensor_t qnn_subgraph_input = BuildInputTensor();
+
+    LITERT_RETURN_STATUS_IF_NOT_OK(graph_mapper.LegalizeAndRegister(
+        subgraph_input.Get(), qnn_subgraph_input));
+
+    LITERT_RETURN_STATUS_IF_NOT_OK(
+        graph_mapper.PushToScope(subgraph_input.Get(), qnn_subgraph_input));
+  }
+
+  for (const auto& subgraph_output : graph_mapper.Graph().Outputs()) {
+    graph_mapper.RegisterOutput(subgraph_output.Get());
+  }
+
+  //
+  // Topologically traverse graph, legalizing and updating tensors in scope
+  //
+
+  for (const auto& op : graph_mapper.Graph().Ops()) {
+    Qnn_OpConfig_t qnn_op = BuildDefaultOp();
+    for (auto it = legalizations.begin(); it != legalizations.end(); ++it) {
+      LITERT_RETURN_STATUS_IF_NOT_OK_OR_NOT_MATCHED(
+          (*it)->LegalizeOp(op, qnn_op, graph_mapper));
+    }
+  }
+
+  LITERT_RETURN_STATUS_IF_QNN_NOT_OK(graph_mapper.Finalize());
+
+  return kLiteRtStatusOk;
+}
+
+}  // namespace
+
+//===----------------------------------------------------------------------===//
+//
+// [WIP] LiteRT SUBGRAPH -> QNN GRAPH
+//
+// Core driver for IR translation. Traverses LiteRt Subgraph, iteratively
+// "legalizing" (mapping) LiteRt entities to their QNN counterpart.
+//
+// APPROACH:
+//
+// To support the general case we will need a driver loop that either
+// traverses input recursively through edges or just iterates topologically.
+//
+// The algorithm is pretty straightforward:
+// * Store mapping between already evaluated LiteRtTensors and their
+//   newly constructed Qnn Tensor counterpart.
+// * Look up QNN Tensors when setting QNN Op inputs.
+// * Add new QNN Tensor when setting QNN Op outputs. +// +// NOTES ON QNN API: +// +// After QNN Tensors are registered in the context, they need only +// be stored as their ID. QNN Tensor and "id" : uint32_t are used +// interchangeably. +// +//===----------------------------------------------------------------------===// + +LiteRtStatus ComposeGraph(QnnManager& qnn, Qnn_ContextHandle_t context_handle, + LiteRtSubgraph subgraph, + absl::string_view qnn_graph_name) { + LITERT_RETURN_STATUS_IF_NOT_OK( + MapGraph(qnn, context_handle, subgraph, qnn_graph_name)); + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tflite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.h b/tflite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.h new file mode 100644 index 00000000..aebfa707 --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.h @@ -0,0 +1,33 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_QNN_COMPOSE_GRAPH_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_QNN_COMPOSE_GRAPH_H_ + +#include "absl/strings/string_view.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_model.h" +#include "tflite/experimental/litert/vendors/qualcomm/qnn_manager.h" + +namespace litert::qnn { + +// Composes a new QNN Graph from given LiteRt Graph. Qnn Graph is written to +// context behind "qnn". Uses given graph_name to name entry point. +LiteRtStatus ComposeGraph(QnnManager& qnn, Qnn_ContextHandle_t context_handle, + LiteRtSubgraph subgraph, + absl::string_view qnn_graph_name); + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_QNN_COMPOSE_GRAPH_H_ diff --git a/tflite/experimental/litert/vendors/qualcomm/context_binary_info.cc b/tflite/experimental/litert/vendors/qualcomm/context_binary_info.cc new file mode 100644 index 00000000..dabd8136 --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/context_binary_info.cc @@ -0,0 +1,216 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
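The scope map described in the translation comments above reduces, once tensors are registered with the QNN context, to a pointer-to-id table. A minimal sketch of that idea follows; the TensorScope type is illustrative only, not the GraphMapper's actual bookkeeping:

#include <cstdint>
#include <unordered_map>

// Illustrative only: maps already-legalized LiteRt tensors to the uint32_t id
// QNN assigned when the tensor was registered in the context.
class TensorScope {
 public:
  // Record the QNN id assigned to a legalized LiteRt tensor.
  void Push(LiteRtTensor tensor, uint32_t qnn_id) { ids_[tensor] = qnn_id; }

  // Resolve an op input; fails if the producer has not been legalized yet.
  bool Lookup(LiteRtTensor tensor, uint32_t& qnn_id) const {
    auto it = ids_.find(tensor);
    if (it == ids_.end()) return false;
    qnn_id = it->second;
    return true;
  }

 private:
  std::unordered_map<LiteRtTensor, uint32_t> ids_;
};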
+
+#include "tflite/experimental/litert/vendors/qualcomm/context_binary_info.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "third_party/qairt/latest/include/QNN/QnnCommon.h"
+#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
+#include "third_party/qairt/latest/include/QNN/System/QnnSystemContext.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_logging.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/vendors/qualcomm/qnn_manager.h"
+#include "tflite/experimental/litert/vendors/qualcomm/qnn_tensor.h"
+
+namespace litert {
+namespace qnn {
+
+namespace {
+
+Expected<void> InsertQnnTensors(int num_qnn_tensors, Qnn_Tensor_t* qnn_tensors,
+                                std::vector<QnnTensor>* tensors) {
+  tensors->clear();
+  tensors->reserve(num_qnn_tensors);
+  for (auto i = 0; i < num_qnn_tensors; ++i) {
+    auto tensor = QnnTensor::Create(qnn_tensors[i]);
+    if (!tensor) {
+      return Unexpected(tensor.Error());
+    }
+    tensors->push_back(std::move(*tensor));
+  }
+  return {};
+}
+
+Expected<void> InsertQnnGraphInfos(
+    int num_qnn_graph_infos, QnnSystemContext_GraphInfo_t* qnn_graph_infos,
+    std::vector<GraphInfo>* graphs) {
+  graphs->clear();
+  graphs->reserve(num_qnn_graph_infos);
+  for (auto i = 0; i < num_qnn_graph_infos; ++i) {
+    auto graph = GraphInfo::Create(qnn_graph_infos[i]);
+    if (!graph) {
+      return Unexpected(graph.Error());
+    }
+    graphs->push_back(std::move(*graph));
+  }
+
+  return {};
+}
+
+}  // namespace
+
+Expected<GraphInfo> GraphInfo::Create(
+    const QnnSystemContext_GraphInfo_t& graph_info) {
+  GraphInfo info;
+  auto status = info.Init(graph_info);
+  if (status) {
+    return info;
+  } else {
+    return Unexpected(status.Error());
+  }
+}
+
+Expected<void> GraphInfo::Init(const QnnSystemContext_GraphInfo_t& graph_info) {
+  if (graph_info.version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_1) {
+    const auto& graph_info_ = graph_info.graphInfoV1;
+    name_ = graph_info_.graphName;
+    LITERT_LOG(LITERT_INFO, "Found qnn graph: %s", name_.c_str());
+
+    if (auto status = InsertQnnTensors(graph_info_.numGraphInputs,
+                                       graph_info_.graphInputs, &inputs_);
+        !status) {
+      return Unexpected(status.Error());
+    }
+    if (auto status = InsertQnnTensors(graph_info_.numGraphOutputs,
+                                       graph_info_.graphOutputs, &outputs_);
+        !status) {
+      return Unexpected(status.Error());
+    }
+
+  } else if (graph_info.version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_2) {
+    const auto& graph_info_ = graph_info.graphInfoV2;
+    name_ = graph_info_.graphName;
+    LITERT_LOG(LITERT_INFO, "Found qnn graph: %s", name_.c_str());
+
+    if (auto status = InsertQnnTensors(graph_info_.numGraphInputs,
+                                       graph_info_.graphInputs, &inputs_);
+        !status) {
+      return Unexpected(status.Error());
+    }
+    if (auto status = InsertQnnTensors(graph_info_.numGraphOutputs,
+                                       graph_info_.graphOutputs, &outputs_);
+        !status) {
+      return Unexpected(status.Error());
+    }
+  } else if (graph_info.version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_3) {
+    const auto& graph_info_ = graph_info.graphInfoV3;
+    name_ = graph_info_.graphName;
+    LITERT_LOG(LITERT_INFO, "Found qnn graph: %s", name_.c_str());
+
+    if (auto status = InsertQnnTensors(graph_info_.numGraphInputs,
+                                       graph_info_.graphInputs, &inputs_);
+        !status) {
+      return Unexpected(status.Error());
+    }
+    if (auto status = InsertQnnTensors(graph_info_.numGraphOutputs,
+                                       graph_info_.graphOutputs, &outputs_);
+        !status) {
+      return Unexpected(status.Error());
+    }
+
+  } else {
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Unsupported graph info version.");
+  }
+  return {};
+}
+
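The three version branches in GraphInfo::Init differ only in which variant struct they read. Assuming the V1-V3 structs keep the same field names (they do in the branches above), a small template could share the body; this is a sketch, not part of the actual file:

template <typename QnnGraphInfoVx>
Expected<void> InitFromAnyVersion(const QnnGraphInfoVx& gi, std::string& name,
                                  std::vector<QnnTensor>& inputs,
                                  std::vector<QnnTensor>& outputs) {
  name = gi.graphName;
  // Shared body: copy the graph name, then both tensor lists.
  if (auto status =
          InsertQnnTensors(gi.numGraphInputs, gi.graphInputs, &inputs);
      !status) {
    return Unexpected(status.Error());
  }
  return InsertQnnTensors(gi.numGraphOutputs, gi.graphOutputs, &outputs);
}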
+Expected<void> ContextBinaryInfo::Init(
+    const QnnSystemContext_BinaryInfo_t& binary_info) {
+  if (binary_info.version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) {
+    const auto& context_binary_info = binary_info.contextBinaryInfoV1;
+    if (auto status = InsertQnnTensors(context_binary_info.numContextTensors,
+                                       context_binary_info.contextTensors,
+                                       &context_tensors_);
+        !status) {
+      return Unexpected(status.Error());
+    }
+    if (auto status = InsertQnnGraphInfos(context_binary_info.numGraphs,
+                                          context_binary_info.graphs, &graphs_);
+        !status) {
+      return Unexpected(status.Error());
+    }
+
+  } else if (binary_info.version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) {
+    const auto& context_binary_info = binary_info.contextBinaryInfoV2;
+    if (auto status = InsertQnnTensors(context_binary_info.numContextTensors,
+                                       context_binary_info.contextTensors,
+                                       &context_tensors_);
+        !status) {
+      return Unexpected(status.Error());
+    }
+    if (auto status = InsertQnnGraphInfos(context_binary_info.numGraphs,
+                                          context_binary_info.graphs, &graphs_);
+        !status) {
+      return Unexpected(status.Error());
+    }
+  } else if (binary_info.version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) {
+    const auto& context_binary_info = binary_info.contextBinaryInfoV3;
+    if (auto status = InsertQnnTensors(context_binary_info.numContextTensors,
+                                       context_binary_info.contextTensors,
+                                       &context_tensors_);
+        !status) {
+      return Unexpected(status.Error());
+    }
+    if (auto status = InsertQnnGraphInfos(context_binary_info.numGraphs,
+                                          context_binary_info.graphs, &graphs_);
+        !status) {
+      return Unexpected(status.Error());
+    }
+  } else {
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Unsupported context binary version.");
+  }
+  return {};
+}
+
+Expected<ContextBinaryInfo> ContextBinaryInfo::Create(
+    QnnManager& qnn, const void* exec_bytecode_ptr, size_t exec_bytecode_size) {
+  auto system_context_handle = qnn.CreateSystemContextHandle();
+  if (!system_context_handle) {
+    return Unexpected(system_context_handle.Error());
+  }
+
+  const QnnSystemContext_BinaryInfo_t* binary_info = nullptr;
+  Qnn_ContextBinarySize_t binary_info_size = 0;
+  if (auto status = qnn.SystemApi()->systemContextGetBinaryInfo(
+          system_context_handle->get(), const_cast<void*>(exec_bytecode_ptr),
+          exec_bytecode_size, &binary_info, &binary_info_size);
+      status != QNN_SUCCESS) {
+    LITERT_LOG(LITERT_ERROR, "Failed to get context binary info: %d", status);
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Failed to get context binary info");
+  }
+
+  if (!binary_info) {
+    LITERT_LOG(LITERT_ERROR, "%s", "Null binary info");
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure, "Null binary info");
+  }
+
+  ContextBinaryInfo info;
+  auto status = info.Init(*binary_info);
+
+  if (status) {
+    return info;
+  } else {
+    return Unexpected(status.Error());
+  }
+}
+
+}  // namespace qnn
+}  // namespace litert
diff --git a/tflite/experimental/litert/vendors/qualcomm/context_binary_info.h b/tflite/experimental/litert/vendors/qualcomm/context_binary_info.h
new file mode 100644
index 00000000..d60b3e1c
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/context_binary_info.h
@@ -0,0 +1,68 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CONTEXT_BINARY_INFO_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CONTEXT_BINARY_INFO_H_
+
+#include <cstddef>
+#include <string>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "third_party/qairt/latest/include/QNN/QnnInterface.h"
+#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/vendors/qualcomm/qnn_manager.h"
+#include "tflite/experimental/litert/vendors/qualcomm/qnn_tensor.h"
+
+namespace litert {
+namespace qnn {
+
+class GraphInfo {
+ public:
+  static Expected<GraphInfo> Create(
+      const QnnSystemContext_GraphInfo_t& graph_info);
+  const std::string& Name() const { return name_; }
+  const std::vector<QnnTensor>& Inputs() const { return inputs_; }
+  const std::vector<QnnTensor>& Outputs() const { return outputs_; }
+
+ private:
+  GraphInfo() = default;
+  Expected<void> Init(const QnnSystemContext_GraphInfo_t& graph_info);
+  std::string name_;
+  std::vector<QnnTensor> inputs_;
+  std::vector<QnnTensor> outputs_;
+};
+
+class ContextBinaryInfo {
+ public:
+  static Expected<ContextBinaryInfo> Create(QnnManager& qnn,
+                                            const void* exec_bytecode_ptr,
+                                            size_t exec_bytecode_size);
+  const std::vector<QnnTensor>& ContextTensors() const {
+    return context_tensors_;
+  }
+  const std::vector<GraphInfo>& Graphs() const { return graphs_; }
+
+ private:
+  ContextBinaryInfo() = default;
+  Expected<void> Init(const QnnSystemContext_BinaryInfo_t& binary_info);
+  std::vector<QnnTensor> context_tensors_;
+  std::vector<GraphInfo> graphs_;
+};
+
+}  // namespace qnn
+}  // namespace litert
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CONTEXT_BINARY_INFO_H_
diff --git a/tflite/experimental/litert/vendors/qualcomm/dispatch/BUILD b/tflite/experimental/litert/vendors/qualcomm/dispatch/BUILD
new file mode 100644
index 00000000..e0134009
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/dispatch/BUILD
@@ -0,0 +1,93 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
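For reference, a hedged sketch of how the ContextBinaryInfo API above might be used to inspect a context binary; the LogGraphs helper is hypothetical and assumes the headers declared earlier in this patch:

#include <cstddef>

#include "tflite/experimental/litert/c/litert_logging.h"
#include "tflite/experimental/litert/cc/litert_expected.h"
#include "tflite/experimental/litert/vendors/qualcomm/context_binary_info.h"

litert::Expected<void> LogGraphs(litert::qnn::QnnManager& qnn,
                                 const void* bytecode, size_t size) {
  auto info = litert::qnn::ContextBinaryInfo::Create(qnn, bytecode, size);
  if (!info) {
    return litert::Unexpected(info.Error());
  }
  // Each GraphInfo corresponds to one callable graph in the binary.
  for (const auto& graph : info->Graphs()) {
    LITERT_LOG(LITERT_INFO, "graph %s: %zu inputs, %zu outputs",
               graph.Name().c_str(), graph.Inputs().size(),
               graph.Outputs().size());
  }
  return {};
}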
+ +load("//tflite/experimental/litert/build_common:litert_build_defs.bzl", "litert_dynamic_lib") + +package( + # copybara:uncomment default_applicable_licenses = ["@org_tensorflow//tensorflow:license"], + default_visibility = ["//tflite/experimental/litert:__subpackages__"], +) + +litert_dynamic_lib( + name = "dispatch_api", + srcs = [ + "dispatch_api.cc", + "litert_dispatch_device_context.cc", + "litert_dispatch_invocation_context.cc", + ], + hdrs = [ + "litert_dispatch_device_context.h", + "litert_dispatch_invocation_context.h", + "registry.h", + ], + export_litert_only = True, + linkopts = select({ + "@org_tensorflow//tensorflow:android": ["-landroid"], + "//conditions:default": [], + }), + linkstatic = 1, + shared_lib_name = "dispatch_api_so", + so_name = "libLiteRtDispatch.so", + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + visibility = ["//tflite/experimental/litert:__subpackages__"], + deps = [ + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings:string_view", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_logging", + "//tflite/experimental/litert/c:litert_model", + "//tflite/experimental/litert/c:litert_tensor_buffer", + "//tflite/experimental/litert/cc:litert_expected", + "//tflite/experimental/litert/core/util:tensor_type_util", + "//tflite/experimental/litert/vendors/c:litert_dispatch_c_api", + "//tflite/experimental/litert/vendors/qualcomm:common", + "//tflite/experimental/litert/vendors/qualcomm:context_binary_info", + "//tflite/experimental/litert/vendors/qualcomm:qnn_manager", + ], +) + +cc_test( + name = "dispatch_api_qualcomm_test", + srcs = [ + "dispatch_api_qualcomm_test.cc", + ], + data = [ + ":dispatch_api_so", + ], + linkopts = select({ + "@org_tensorflow//tensorflow:android": ["-landroid"], + "//conditions:default": [], + }), + linkstatic = 1, + tags = [ + "no-remote-exec", + "no_oss", + "notap", + ], + deps = [ + "//tflite/experimental/litert/c:litert_common", + "//tflite/experimental/litert/c:litert_tensor_buffer", + "//tflite/experimental/litert/core:filesystem", + "//tflite/experimental/litert/test:simple_model_npu", + "//tflite/experimental/litert/vendors/c:litert_dispatch_c_api", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tflite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api.cc b/tflite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api.cc new file mode 100644 index 00000000..ef913907 --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api.cc @@ -0,0 +1,296 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
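The BUILD target above packages the dispatch API as libLiteRtDispatch.so with only LiteRT symbols exported. A sketch of how a host runtime might load it and resolve the entry point defined at the bottom of the file that follows; the LoadDispatchApi name is hypothetical and error reporting is trimmed:

#include <dlfcn.h>

#include "tflite/experimental/litert/vendors/c/litert_dispatch_api.h"

LiteRtStatus LoadDispatchApi(const char* so_path, LiteRtDispatchApi* api) {
  // Open the vendor library; RTLD_LOCAL keeps its symbols out of the
  // global namespace.
  void* lib = ::dlopen(so_path, RTLD_NOW | RTLD_LOCAL);
  if (lib == nullptr) {
    return kLiteRtStatusErrorRuntimeFailure;
  }
  using GetApiFn = LiteRtStatus (*)(LiteRtDispatchApi*);
  auto get_api =
      reinterpret_cast<GetApiFn>(::dlsym(lib, "LiteRtDispatchGetApi"));
  if (get_api == nullptr) {
    return kLiteRtStatusErrorRuntimeFailure;
  }
  return get_api(api);
}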
+
+#include <cstring>
+#include <memory>
+#include <optional>
+#include <string>
+
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_logging.h"
+#include "tflite/experimental/litert/c/litert_model.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/vendors/c/litert_dispatch.h"
+#include "tflite/experimental/litert/vendors/c/litert_dispatch_api.h"
+#include "tflite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.h"
+#include "tflite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.h"
+#include "tflite/experimental/litert/vendors/qualcomm/qnn_manager.h"
+
+namespace {
+
+using ::litert::qnn::QnnManager;
+
+static constexpr const int VERSION_MAJOR = 0;
+static constexpr const int VERSION_MINOR = 1;
+static constexpr const int VERSION_PATCH = 0;
+
+static std::unique_ptr<QnnManager> TheQnnManager;
+
+QnnManager& Qnn() { return *TheQnnManager; }
+
+char BuildId[256];
+
+// /////////////////////////////////////////////////////////////////////////////
+// Basic Execution API
+// /////////////////////////////////////////////////////////////////////////////
+
+const char* GetSharedLibraryDir(const LiteRtDispatchOption* options,
+                                int num_options) {
+  for (auto i = 0; i < num_options; ++i) {
+    auto& option = options[i];
+    if (!strcmp(option.name, kDispatchOptionSharedLibraryDir)) {
+      return option.value.str_value;
+    }
+  }
+  return nullptr;
+}
+
+LiteRtStatus Initialize(const LiteRtDispatchOption* options, int num_options) {
+  auto* shared_library_dir = GetSharedLibraryDir(options, num_options);
+  std::optional<std::string> shared_library_dir_opt =
+      shared_library_dir ?
std::make_optional(std::string(shared_library_dir)) + : std::nullopt; + + auto configs = QnnManager::DefaultBackendConfigs(); + if (auto qnn_manager = QnnManager::Create(configs, shared_library_dir_opt); + !qnn_manager) { + LITERT_LOG(LITERT_ERROR, "%s", qnn_manager.Error().Message().data()); + return qnn_manager.Error().Status(); + } else { + std::swap(TheQnnManager, *qnn_manager); + } + + Qnn_ApiVersion_t qnn_api_version; + if (auto status = Qnn().Api()->backendGetApiVersion(&qnn_api_version); + status != QNN_SUCCESS) { + LITERT_LOG(LITERT_ERROR, "Failed to get QNN API version: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + const char* build_id; + if (auto status = Qnn().Api()->backendGetBuildId(&build_id); + status != QNN_SUCCESS) { + LITERT_LOG(LITERT_ERROR, "Failed to get QNN build ID: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + snprintf(BuildId, sizeof(BuildId), + "Qualcomm Dispatch API version %d.%d.%d, QNN API version %d.%d.%d, " + "build id: %s", + VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH, + qnn_api_version.coreApiVersion.major, + qnn_api_version.coreApiVersion.minor, + qnn_api_version.coreApiVersion.patch, build_id); + BuildId[sizeof(BuildId) - 1] = 0; + + return kLiteRtStatusOk; +} + +LiteRtStatus GetVendorId(const char** vendor_id) { + *vendor_id = "Qualcomm"; + return kLiteRtStatusOk; +} + +LiteRtStatus GetBuildId(const char** build_id) { + *build_id = BuildId; + return kLiteRtStatusOk; +} + +LiteRtStatus GetCapabilities(int* capabilities) { + *capabilities = kLiteRtDispatchCapabilitiesBasic; + return kLiteRtStatusOk; +} + +LiteRtStatus DeviceContextCreate(LiteRtDispatchDeviceContext* device_context) { + if (auto context = LiteRtDispatchDeviceContextT::Create(Qnn()); context) { + *device_context = context->release(); + return kLiteRtStatusOk; + } else { + LITERT_LOG(LITERT_ERROR, "Failed to create device context: %s", + context.Error().Message().data()); + return context.Error().Status(); + } +} + +LiteRtStatus DeviceContextDestroy(LiteRtDispatchDeviceContext device_context) { + delete device_context; + return kLiteRtStatusOk; +} + +LiteRtStatus GetInputRequirements( + LiteRtDispatchInvocationContext invocation_context, int input_index, + const LiteRtRankedTensorType* tensor_type, + LiteRtTensorBufferRequirements* tensor_buffer_requirements) { + if (auto requirements = + invocation_context->GetInputRequirements(input_index, *tensor_type); + requirements) { + *tensor_buffer_requirements = *requirements; + return kLiteRtStatusOk; + } else { + LITERT_LOG(LITERT_ERROR, "Failed to get tensor buffer requirements: %s", + requirements.Error().Message().data()); + return requirements.Error().Status(); + } +} + +LiteRtStatus GetOutputRequirements( + LiteRtDispatchInvocationContext invocation_context, int output_index, + const LiteRtRankedTensorType* tensor_type, + LiteRtTensorBufferRequirements* tensor_buffer_requirements) { + if (auto requirements = + invocation_context->GetOutputRequirements(output_index, *tensor_type); + requirements) { + *tensor_buffer_requirements = *requirements; + return kLiteRtStatusOk; + } else { + LITERT_LOG(LITERT_ERROR, "Failed to get tensor buffer requirements: %s", + requirements.Error().Message().data()); + return requirements.Error().Status(); + } +} + +LiteRtStatus RegisterTensorBuffer( + LiteRtDispatchDeviceContext device_context, LiteRtTensorBuffer buffer, + LiteRtTensorBufferHandle* tensor_buffer_handle) { + if (auto status = device_context->RegisterTensorBuffer(buffer); !status) { + LITERT_LOG(LITERT_ERROR, 
"Failed to register buffer: %s", + status.Error().Message().data()); + return status.Error().Status(); + } else { + *tensor_buffer_handle = *status; + return kLiteRtStatusOk; + } +} + +LiteRtStatus UnregisterTensorBuffer(LiteRtDispatchDeviceContext device_context, + LiteRtTensorBufferHandle handle) { + if (auto status = device_context->UnregisterTensorBuffer(handle); !status) { + LITERT_LOG(LITERT_ERROR, "Failed to unregister buffer: %s", + status.Error().Message().data()); + return status.Error().Status(); + } else { + return kLiteRtStatusOk; + } +} + +LiteRtStatus InvocationContextCreate( + LiteRtDispatchDeviceContext device_context, + LiteRtDispatchExecutableType exec_type, const void* exec_bytecode_ptr, + size_t exec_bytecode_size, const char* function_name, int num_inputs, + int num_outputs, LiteRtDispatchInvocationContext* invocation_context) { + auto context = LiteRtDispatchInvocationContextT::Create( + Qnn(), *device_context, exec_bytecode_ptr, exec_bytecode_size, + function_name); + if (!context) { + LITERT_LOG(LITERT_ERROR, "Failed to create context from context binary: %s", + context.Error().Message().data()); + return context.Error().Status(); + } + *invocation_context = context->release(); + device_context->SetInvocationContext(*invocation_context); + return kLiteRtStatusOk; +} + +LiteRtStatus InvocationContextDestroy( + LiteRtDispatchInvocationContext invocation_context) { + delete invocation_context; + return kLiteRtStatusOk; +} + +LiteRtStatus AttachInput(LiteRtDispatchInvocationContext invocation_context, + int graph_input_index, + LiteRtTensorBufferHandle tensor_buffer_handle) { + if (auto status = invocation_context->AttachInput(graph_input_index, + tensor_buffer_handle); + !status) { + LITERT_LOG(LITERT_ERROR, "Failed to attach input buffer: %s", + status.Error().Message().data()); + return status.Error().Status(); + } + return kLiteRtStatusOk; +} + +LiteRtStatus AttachOutput(LiteRtDispatchInvocationContext invocation_context, + int graph_output_index, + LiteRtTensorBufferHandle tensor_buffer_handle) { + if (auto status = invocation_context->AttachOutput(graph_output_index, + tensor_buffer_handle); + !status) { + LITERT_LOG(LITERT_ERROR, "Failed to attach output buffer: %s", + status.Error().Message().data()); + return status.Error().Status(); + } + return kLiteRtStatusOk; +} + +LiteRtStatus DetachInput(LiteRtDispatchInvocationContext invocation_context, + int graph_input_index, + LiteRtTensorBufferHandle tensor_buffer_handle) { + // Nothing to do here. + return kLiteRtStatusOk; +} + +LiteRtStatus DetachOutput(LiteRtDispatchInvocationContext invocation_context, + int graph_output_index, + LiteRtTensorBufferHandle tensor_buffer_handle) { + // Nothing to do here. 
+ return kLiteRtStatusOk; +} + +LiteRtStatus Invoke(LiteRtDispatchInvocationContext invocation_context) { + if (auto status = invocation_context->Execute(); !status) { + LITERT_LOG(LITERT_ERROR, "Failed to execute invocation context: %s", + status.Error().Message().data()); + return status.Error().Status(); + } + return kLiteRtStatusOk; +} + +// ///////////////////////////////////////////////////////////////////////////// + +LiteRtDispatchInterface TheInterface = { + /*.initialize=*/Initialize, + /*.get_vendor_id=*/GetVendorId, + /*.get_build_id=*/GetBuildId, + /*.get_capabilities=*/GetCapabilities, + /*.device_context_create=*/DeviceContextCreate, + /*.device_context_destroy=*/DeviceContextDestroy, + /*.get_input_requirements=*/GetInputRequirements, + /*.get_output_requirements=*/GetOutputRequirements, + /*.register_tensor_buffer=*/RegisterTensorBuffer, + /*.unregister_tensor_buffer=*/UnregisterTensorBuffer, + /*.invocation_context_create=*/InvocationContextCreate, + /*.invocation_context_destroy=*/InvocationContextDestroy, + /*.attach_input=*/AttachInput, + /*.attach_output=*/AttachOutput, + /*.detach_input=*/DetachInput, + /*.detach_output=*/DetachOutput, + /*.invoke=*/Invoke, +}; + +LiteRtDispatchApi TheApi = { + /*.version=*/{/*.major=*/VERSION_MAJOR, + /*.minor=*/VERSION_MINOR, + /*.patch=*/VERSION_PATCH}, + /*.interface=*/&TheInterface, + /*.async_interface=*/nullptr, + /*.graph_interface=*/nullptr, +}; + +} // namespace + +LiteRtStatus LiteRtDispatchGetApi(LiteRtDispatchApi* api) { + *api = TheApi; + return kLiteRtStatusOk; +} diff --git a/tflite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api_qualcomm_test.cc b/tflite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api_qualcomm_test.cc new file mode 100644 index 00000000..29057d91 --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api_qualcomm_test.cc @@ -0,0 +1,532 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
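Initialize above accepts a list of options; the only one it currently reads is the shared-library directory. A hedged sketch of passing it from the caller side; the directory path is a placeholder, and the kLiteRtAnyTypeString tag is assumed from LiteRT's usual LiteRtAny tagging:

LiteRtStatus InitializeQualcommDispatch() {
  LiteRtDispatchOption option;
  option.name = kDispatchOptionSharedLibraryDir;
  // Placeholder path; point it at the directory holding the QNN .so files.
  option.value.type = kLiteRtAnyTypeString;  // Assumed LiteRtAny tag.
  option.value.str_value = "/data/local/tmp/qnn_libs";
  // The dispatch library scans the option array by name.
  return LiteRtDispatchInitialize(&option, /*num_options=*/1);
}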
+ +#include +#include + +#include +#include +#include "absl/log/absl_log.h" +#include "absl/log/log.h" +#include "absl/types/span.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_tensor_buffer.h" +#include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h" +#include "tflite/experimental/litert/core/filesystem.h" +#include "tflite/experimental/litert/test/testdata/simple_model_test_vectors.h" +#include "tflite/experimental/litert/vendors/c/litert_dispatch.h" + +using ::testing::Pointwise; + +TEST(Qualcomm, DispatchApiWithFastRpc) { +#if !defined(__ANDROID__) + GTEST_SKIP() + << "This test is specific to Android devices with a Qualcomm NPU"; +#endif + + EXPECT_EQ(LiteRtDispatchInitialize(/*options=*/nullptr, /*num_options=*/0), + kLiteRtStatusOk); + + const char* vendor_id; + EXPECT_EQ(LiteRtDispatchGetVendorId(&vendor_id), kLiteRtStatusOk); + ABSL_LOG(INFO) << "vendor_id: " << vendor_id; + + const char* build_id; + EXPECT_EQ(LiteRtDispatchGetBuildId(&build_id), kLiteRtStatusOk); + ABSL_LOG(INFO) << "build_id: " << build_id; + + LiteRtApiVersion api_version; + EXPECT_EQ(LiteRtDispatchGetApiVersion(&api_version), kLiteRtStatusOk); + ABSL_LOG(INFO) << "api_version: " << api_version.major << "." + << api_version.minor << "." << api_version.patch; + + int capabilities; + EXPECT_EQ(LiteRtDispatchGetCapabilities(&capabilities), kLiteRtStatusOk); + ABSL_LOG(INFO) << "capabilities: " << capabilities; + + LiteRtDispatchDeviceContext device_context = nullptr; + EXPECT_EQ(LiteRtDispatchDeviceContextCreate(&device_context), + kLiteRtStatusOk); + ABSL_LOG(INFO) << "device_context: " << device_context; + + auto model_file_name = kQualcommModelFileName; + auto model = litert::internal::LoadBinaryFile(model_file_name); + EXPECT_TRUE(model); + ABSL_LOG(INFO) << "Loaded model " << model_file_name << ", " << model->Size() + << " bytes"; + + // /////////////////////////////////////////////////////////////////////////// + // Set up an invocation context for a given model. + // /////////////////////////////////////////////////////////////////////////// + + LiteRtDispatchInvocationContext invocation_context = nullptr; + EXPECT_EQ(LiteRtDispatchInvocationContextCreate( + device_context, kLiteRtDispatchExecutableTypeMlModel, + model->Data(), model->Size(), /*function_name=*/"simple", + /*num_inputs=*/2, /*num_outputs=*/1, &invocation_context), + kLiteRtStatusOk); + ABSL_LOG(INFO) << "Invocation context: " << invocation_context; + + // /////////////////////////////////////////////////////////////////////////// + // Determine tensor buffer requirements. 
+ // /////////////////////////////////////////////////////////////////////////// + + int num_tensor_buffer_types; + LiteRtTensorBufferRequirements input_0_tensor_buffer_requirements; + EXPECT_EQ(LiteRtDispatchGetInputRequirements( + invocation_context, /*input_index=*/0, &kInput0TensorType, + &input_0_tensor_buffer_requirements), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( + input_0_tensor_buffer_requirements, &num_tensor_buffer_types), + kLiteRtStatusOk); + EXPECT_GE(num_tensor_buffer_types, 1); + LiteRtTensorBufferType input_0_tensor_buffer_type; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + input_0_tensor_buffer_requirements, /*type_index=*/0, + &input_0_tensor_buffer_type), + kLiteRtStatusOk); + EXPECT_EQ(input_0_tensor_buffer_type, kLiteRtTensorBufferTypeFastRpc); + size_t input_0_tensor_buffer_size; + EXPECT_EQ( + LiteRtGetTensorBufferRequirementsBufferSize( + input_0_tensor_buffer_requirements, &input_0_tensor_buffer_size), + kLiteRtStatusOk); + EXPECT_GE(input_0_tensor_buffer_size, sizeof(kTestInput0Tensor)); + + LiteRtTensorBufferRequirements input_1_tensor_buffer_requirements; + EXPECT_EQ(LiteRtDispatchGetInputRequirements( + invocation_context, /*input_index=*/1, &kInput1TensorType, + &input_1_tensor_buffer_requirements), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( + input_1_tensor_buffer_requirements, &num_tensor_buffer_types), + kLiteRtStatusOk); + EXPECT_GE(num_tensor_buffer_types, 1); + LiteRtTensorBufferType input_1_tensor_buffer_type; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + input_1_tensor_buffer_requirements, /*type_index=*/0, + &input_1_tensor_buffer_type), + kLiteRtStatusOk); + EXPECT_EQ(input_1_tensor_buffer_type, kLiteRtTensorBufferTypeFastRpc); + size_t input_1_tensor_buffer_size; + EXPECT_EQ( + LiteRtGetTensorBufferRequirementsBufferSize( + input_1_tensor_buffer_requirements, &input_1_tensor_buffer_size), + kLiteRtStatusOk); + EXPECT_GE(input_1_tensor_buffer_size, sizeof(kTestInput1Tensor)); + + LiteRtTensorBufferRequirements output_tensor_buffer_requirements; + EXPECT_EQ(LiteRtDispatchGetOutputRequirements( + invocation_context, /*output_index=*/0, &kOutputTensorType, + &output_tensor_buffer_requirements), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( + output_tensor_buffer_requirements, &num_tensor_buffer_types), + kLiteRtStatusOk); + EXPECT_GE(num_tensor_buffer_types, 1); + LiteRtTensorBufferType output_tensor_buffer_type; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + output_tensor_buffer_requirements, /*type_index=*/0, + &output_tensor_buffer_type), + kLiteRtStatusOk); + EXPECT_EQ(output_tensor_buffer_type, kLiteRtTensorBufferTypeFastRpc); + size_t output_tensor_buffer_size; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsBufferSize( + output_tensor_buffer_requirements, &output_tensor_buffer_size), + kLiteRtStatusOk); + EXPECT_GE(output_tensor_buffer_size, sizeof(kTestOutputTensor)); + + // /////////////////////////////////////////////////////////////////////////// + // Allocate tensor buffers. 
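+  // (Managed buffers are allocated by LiteRT itself using the buffer type
+  // reported above; the dispatch API later sees only registered handles.)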
+ // /////////////////////////////////////////////////////////////////////////// + + LiteRtTensorBuffer input_0_tensor_buffer; + EXPECT_EQ(LiteRtCreateManagedTensorBuffer( + input_0_tensor_buffer_type, &kInput0TensorType, + input_0_tensor_buffer_size, &input_0_tensor_buffer), + kLiteRtStatusOk); + + LiteRtTensorBuffer input_1_tensor_buffer; + EXPECT_EQ(LiteRtCreateManagedTensorBuffer( + input_1_tensor_buffer_type, &kInput1TensorType, + input_1_tensor_buffer_size, &input_1_tensor_buffer), + kLiteRtStatusOk); + + LiteRtTensorBuffer output_tensor_buffer; + EXPECT_EQ(LiteRtCreateManagedTensorBuffer( + output_tensor_buffer_type, &kOutputTensorType, + output_tensor_buffer_size, &output_tensor_buffer), + kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Register tensor buffers. + // /////////////////////////////////////////////////////////////////////////// + + LiteRtTensorBufferHandle input_1_handle; + EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( + device_context, input_1_tensor_buffer, &input_1_handle), + kLiteRtStatusOk); + + LiteRtTensorBufferHandle input_0_handle; + EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( + device_context, input_0_tensor_buffer, &input_0_handle), + kLiteRtStatusOk); + + LiteRtTensorBufferHandle output_handle; + EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( + device_context, output_tensor_buffer, &output_handle), + kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Attach tensor buffers. + // /////////////////////////////////////////////////////////////////////////// + + EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, + /*graph_input_index=*/0, input_0_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, + /*graph_input_index=*/1, input_1_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchAttachOutput(invocation_context, + /*graph_output_index=*/0, output_handle), + kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Fill the input buffers with data. + // /////////////////////////////////////////////////////////////////////////// + + { + ABSL_LOG(INFO) << "Filling inputs with data"; + void* host_mem_addr; + + ASSERT_EQ(LiteRtLockTensorBuffer(input_0_tensor_buffer, &host_mem_addr, + /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTestInput0Tensor, sizeof(kTestInput0Tensor)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(input_0_tensor_buffer), kLiteRtStatusOk); + + ASSERT_EQ(LiteRtLockTensorBuffer(input_1_tensor_buffer, &host_mem_addr, + /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTestInput1Tensor, sizeof(kTestInput1Tensor)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(input_1_tensor_buffer), kLiteRtStatusOk); + } + + // /////////////////////////////////////////////////////////////////////////// + // Execute model. + // /////////////////////////////////////////////////////////////////////////// + + ABSL_LOG(INFO) << "Invoking execution..."; + EXPECT_EQ(LiteRtDispatchInvoke(invocation_context), kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Check output for correctness. 
+  // ///////////////////////////////////////////////////////////////////////////
+
+  {
+    ABSL_LOG(INFO) << "Checking output...";
+    void* host_mem_addr;
+    ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffer, &host_mem_addr,
+                                     /*event=*/nullptr),
+              kLiteRtStatusOk);
+    auto output = absl::MakeSpan(static_cast<float*>(host_mem_addr),
+                                 kTestOutputSize);
+    for (auto i = 0; i < kTestOutputSize; ++i) {
+      ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i];
+    }
+    EXPECT_THAT(output, Pointwise(testing::FloatNear(1e-3), kTestOutputTensor));
+    ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffer), kLiteRtStatusOk);
+  }
+
+  // ///////////////////////////////////////////////////////////////////////////
+  // Clean up resources.
+  // ///////////////////////////////////////////////////////////////////////////
+
+  EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context,
+                                      /*graph_input_index=*/0, input_0_handle),
+            kLiteRtStatusOk);
+  EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context,
+                                      /*graph_input_index=*/1, input_1_handle),
+            kLiteRtStatusOk);
+  EXPECT_EQ(LiteRtDispatchDetachOutput(invocation_context,
+                                       /*graph_output_index=*/0, output_handle),
+            kLiteRtStatusOk);
+  EXPECT_EQ(LiteRtDispatchUnregisterTensorBuffer(device_context, output_handle),
+            kLiteRtStatusOk);
+  EXPECT_EQ(
+      LiteRtDispatchUnregisterTensorBuffer(device_context, input_1_handle),
+      kLiteRtStatusOk);
+  EXPECT_EQ(
+      LiteRtDispatchUnregisterTensorBuffer(device_context, input_0_handle),
+      kLiteRtStatusOk);
+  LiteRtDestroyTensorBuffer(output_tensor_buffer);
+  LiteRtDestroyTensorBuffer(input_1_tensor_buffer);
+  LiteRtDestroyTensorBuffer(input_0_tensor_buffer);
+  EXPECT_EQ(LiteRtDispatchInvocationContextDestroy(invocation_context),
+            kLiteRtStatusOk);
+  EXPECT_EQ(LiteRtDispatchDeviceContextDestroy(device_context),
+            kLiteRtStatusOk);
+}
+
+TEST(Qualcomm, DispatchApiWithDmaBuf) {
+#if !defined(__ANDROID__)
+  GTEST_SKIP()
+      << "This test is specific to Android devices with a Qualcomm NPU";
+#endif
+
+  EXPECT_EQ(LiteRtDispatchInitialize(/*options=*/nullptr, /*num_options=*/0),
+            kLiteRtStatusOk);
+
+  const char* vendor_id;
+  EXPECT_EQ(LiteRtDispatchGetVendorId(&vendor_id), kLiteRtStatusOk);
+  ABSL_LOG(INFO) << "vendor_id: " << vendor_id;
+
+  const char* build_id;
+  EXPECT_EQ(LiteRtDispatchGetBuildId(&build_id), kLiteRtStatusOk);
+  ABSL_LOG(INFO) << "build_id: " << build_id;
+
+  LiteRtApiVersion api_version;
+  EXPECT_EQ(LiteRtDispatchGetApiVersion(&api_version), kLiteRtStatusOk);
+  ABSL_LOG(INFO) << "api_version: " << api_version.major << "."
+                 << api_version.minor << "." << api_version.patch;
+
+  int capabilities;
+  EXPECT_EQ(LiteRtDispatchGetCapabilities(&capabilities), kLiteRtStatusOk);
+  ABSL_LOG(INFO) << "capabilities: " << capabilities;
+
+  LiteRtDispatchDeviceContext device_context = nullptr;
+  EXPECT_EQ(LiteRtDispatchDeviceContextCreate(&device_context),
+            kLiteRtStatusOk);
+  ABSL_LOG(INFO) << "device_context: " << device_context;
+
+  auto model_file_name = kQualcommModelFileName;
+  auto model = ::litert::internal::LoadBinaryFile(model_file_name);
+  EXPECT_TRUE(model);
+  ABSL_LOG(INFO) << "Loaded model " << model_file_name << ", " << model->Size()
+                 << " bytes";
+
+  // ///////////////////////////////////////////////////////////////////////////
+  // Set up an invocation context for a given model.
+ // /////////////////////////////////////////////////////////////////////////// + + LiteRtDispatchInvocationContext invocation_context = nullptr; + EXPECT_EQ(LiteRtDispatchInvocationContextCreate( + device_context, kLiteRtDispatchExecutableTypeMlModel, + model->Data(), model->Size(), /*function_name=*/"simple", + /*num_inputs=*/2, /*num_outputs=*/1, &invocation_context), + kLiteRtStatusOk); + ABSL_LOG(INFO) << "Invocation context: " << invocation_context; + + // /////////////////////////////////////////////////////////////////////////// + // Determine tensor buffer requirements. + // /////////////////////////////////////////////////////////////////////////// + + int num_tensor_buffer_types; + LiteRtTensorBufferRequirements input_0_tensor_buffer_requirements; + EXPECT_EQ(LiteRtDispatchGetInputRequirements( + invocation_context, /*input_index=*/0, &kInput0TensorType, + &input_0_tensor_buffer_requirements), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( + input_0_tensor_buffer_requirements, &num_tensor_buffer_types), + kLiteRtStatusOk); + EXPECT_GE(num_tensor_buffer_types, 1); + LiteRtTensorBufferType input_0_tensor_buffer_type; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + input_0_tensor_buffer_requirements, /*type_index=*/1, + &input_0_tensor_buffer_type), + kLiteRtStatusOk); + EXPECT_EQ(input_0_tensor_buffer_type, kLiteRtTensorBufferTypeDmaBuf); + size_t input_0_tensor_buffer_size; + EXPECT_EQ( + LiteRtGetTensorBufferRequirementsBufferSize( + input_0_tensor_buffer_requirements, &input_0_tensor_buffer_size), + kLiteRtStatusOk); + EXPECT_GE(input_0_tensor_buffer_size, sizeof(kTestInput0Tensor)); + + LiteRtTensorBufferRequirements input_1_tensor_buffer_requirements; + EXPECT_EQ(LiteRtDispatchGetInputRequirements( + invocation_context, /*input_index=*/1, &kInput1TensorType, + &input_1_tensor_buffer_requirements), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( + input_1_tensor_buffer_requirements, &num_tensor_buffer_types), + kLiteRtStatusOk); + EXPECT_GE(num_tensor_buffer_types, 1); + LiteRtTensorBufferType input_1_tensor_buffer_type; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + input_1_tensor_buffer_requirements, /*type_index=*/1, + &input_1_tensor_buffer_type), + kLiteRtStatusOk); + EXPECT_EQ(input_1_tensor_buffer_type, kLiteRtTensorBufferTypeDmaBuf); + size_t input_1_tensor_buffer_size; + EXPECT_EQ( + LiteRtGetTensorBufferRequirementsBufferSize( + input_1_tensor_buffer_requirements, &input_1_tensor_buffer_size), + kLiteRtStatusOk); + EXPECT_GE(input_1_tensor_buffer_size, sizeof(kTestInput1Tensor)); + + LiteRtTensorBufferRequirements output_tensor_buffer_requirements; + EXPECT_EQ(LiteRtDispatchGetOutputRequirements( + invocation_context, /*output_index=*/0, &kOutputTensorType, + &output_tensor_buffer_requirements), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( + output_tensor_buffer_requirements, &num_tensor_buffer_types), + kLiteRtStatusOk); + EXPECT_GE(num_tensor_buffer_types, 1); + LiteRtTensorBufferType output_tensor_buffer_type; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + output_tensor_buffer_requirements, /*type_index=*/1, + &output_tensor_buffer_type), + kLiteRtStatusOk); + EXPECT_EQ(output_tensor_buffer_type, kLiteRtTensorBufferTypeDmaBuf); + size_t output_tensor_buffer_size; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsBufferSize( + output_tensor_buffer_requirements, 
&output_tensor_buffer_size), + kLiteRtStatusOk); + EXPECT_GE(output_tensor_buffer_size, sizeof(kTestOutputTensor)); + + // /////////////////////////////////////////////////////////////////////////// + // Allocate tensor buffers. + // /////////////////////////////////////////////////////////////////////////// + + LiteRtTensorBuffer input_0_tensor_buffer; + EXPECT_EQ(LiteRtCreateManagedTensorBuffer( + input_0_tensor_buffer_type, &kInput0TensorType, + input_0_tensor_buffer_size, &input_0_tensor_buffer), + kLiteRtStatusOk); + + LiteRtTensorBuffer input_1_tensor_buffer; + EXPECT_EQ(LiteRtCreateManagedTensorBuffer( + input_1_tensor_buffer_type, &kInput1TensorType, + input_1_tensor_buffer_size, &input_1_tensor_buffer), + kLiteRtStatusOk); + + LiteRtTensorBuffer output_tensor_buffer; + EXPECT_EQ(LiteRtCreateManagedTensorBuffer( + output_tensor_buffer_type, &kOutputTensorType, + output_tensor_buffer_size, &output_tensor_buffer), + kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Register tensor buffers. + // /////////////////////////////////////////////////////////////////////////// + + LiteRtTensorBufferHandle input_1_handle; + EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( + device_context, input_1_tensor_buffer, &input_1_handle), + kLiteRtStatusOk); + + LiteRtTensorBufferHandle input_0_handle; + EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( + device_context, input_0_tensor_buffer, &input_0_handle), + kLiteRtStatusOk); + + LiteRtTensorBufferHandle output_handle; + EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( + device_context, output_tensor_buffer, &output_handle), + kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Attach tensor buffers. + // /////////////////////////////////////////////////////////////////////////// + + EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, + /*graph_input_index=*/0, input_0_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, + /*graph_input_index=*/1, input_1_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchAttachOutput(invocation_context, + /*graph_output_index=*/0, output_handle), + kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Fill the input buffers with data. + // /////////////////////////////////////////////////////////////////////////// + + { + ABSL_LOG(INFO) << "Filling inputs with data"; + void* host_mem_addr; + + ASSERT_EQ(LiteRtLockTensorBuffer(input_0_tensor_buffer, &host_mem_addr, + /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTestInput0Tensor, sizeof(kTestInput0Tensor)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(input_0_tensor_buffer), kLiteRtStatusOk); + + ASSERT_EQ(LiteRtLockTensorBuffer(input_1_tensor_buffer, &host_mem_addr, + /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTestInput1Tensor, sizeof(kTestInput1Tensor)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(input_1_tensor_buffer), kLiteRtStatusOk); + } + + // /////////////////////////////////////////////////////////////////////////// + // Execute model. + // /////////////////////////////////////////////////////////////////////////// + + ABSL_LOG(INFO) << "Invoking execution..."; + EXPECT_EQ(LiteRtDispatchInvoke(invocation_context), kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Check output for correctness. 
+  // ///////////////////////////////////////////////////////////////////////////
+
+  {
+    ABSL_LOG(INFO) << "Checking output...";
+    void* host_mem_addr;
+    ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffer, &host_mem_addr,
+                                     /*event=*/nullptr),
+              kLiteRtStatusOk);
+    auto output = absl::MakeSpan(static_cast<float*>(host_mem_addr),
+                                 kTestOutputSize);
+    for (auto i = 0; i < kTestOutputSize; ++i) {
+      ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i];
+    }
+    EXPECT_THAT(output, Pointwise(testing::FloatNear(1e-3), kTestOutputTensor));
+    ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffer), kLiteRtStatusOk);
+  }
+
+  // ///////////////////////////////////////////////////////////////////////////
+  // Clean up resources.
+  // ///////////////////////////////////////////////////////////////////////////
+
+  EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context,
+                                      /*graph_input_index=*/0, input_0_handle),
+            kLiteRtStatusOk);
+  EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context,
+                                      /*graph_input_index=*/1, input_1_handle),
+            kLiteRtStatusOk);
+  EXPECT_EQ(LiteRtDispatchDetachOutput(invocation_context,
+                                       /*graph_output_index=*/0, output_handle),
+            kLiteRtStatusOk);
+  EXPECT_EQ(LiteRtDispatchUnregisterTensorBuffer(device_context, output_handle),
+            kLiteRtStatusOk);
+  EXPECT_EQ(
+      LiteRtDispatchUnregisterTensorBuffer(device_context, input_1_handle),
+      kLiteRtStatusOk);
+  EXPECT_EQ(
+      LiteRtDispatchUnregisterTensorBuffer(device_context, input_0_handle),
+      kLiteRtStatusOk);
+  LiteRtDestroyTensorBuffer(output_tensor_buffer);
+  LiteRtDestroyTensorBuffer(input_1_tensor_buffer);
+  LiteRtDestroyTensorBuffer(input_0_tensor_buffer);
+  EXPECT_EQ(LiteRtDispatchInvocationContextDestroy(invocation_context),
+            kLiteRtStatusOk);
+  EXPECT_EQ(LiteRtDispatchDeviceContextDestroy(device_context),
+            kLiteRtStatusOk);
+}
diff --git a/tflite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.cc b/tflite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.cc
new file mode 100644
index 00000000..6bf83f16
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.cc
@@ -0,0 +1,190 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
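Editor's note on the two tests above: both walk the same dispatch-API lifecycle (initialize, query metadata, create device and invocation contexts, query buffer requirements, allocate, register, attach, fill, invoke, verify, detach, unregister, destroy), differing only in which supported buffer type they pick from the requirements. A minimal sketch of the per-input portion of that protocol, using only calls that appear verbatim in the tests; the helper name SetUpInputBuffer and the collapsed bool error handling are illustrative, not part of the API:

// Hypothetical helper: create a managed buffer, register it with the device
// context, and attach it to one input slot of the invocation context.
static bool SetUpInputBuffer(LiteRtDispatchDeviceContext device_context,
                             LiteRtDispatchInvocationContext invocation_context,
                             LiteRtTensorBufferType buffer_type,
                             const LiteRtRankedTensorType* tensor_type,
                             size_t buffer_size, int graph_input_index,
                             LiteRtTensorBuffer* buffer,
                             LiteRtTensorBufferHandle* handle) {
  if (LiteRtCreateManagedTensorBuffer(buffer_type, tensor_type, buffer_size,
                                      buffer) != kLiteRtStatusOk) {
    return false;
  }
  if (LiteRtDispatchRegisterTensorBuffer(device_context, *buffer, handle) !=
      kLiteRtStatusOk) {
    return false;
  }
  return LiteRtDispatchAttachInput(invocation_context, graph_input_index,
                                   *handle) == kLiteRtStatusOk;
}

Teardown mirrors setup in reverse: detach, unregister, then destroy the buffer, exactly as the cleanup sections above do.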
+ +#include "tflite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.h" + +#include +#include + +#include "absl/log/absl_check.h" +#include "third_party/qairt/latest/include/QNN/HTP/QnnHtpMem.h" +#include "third_party/qairt/latest/include/QNN/QnnBackend.h" +#include "third_party/qairt/latest/include/QNN/QnnCommon.h" +#include "third_party/qairt/latest/include/QNN/QnnInterface.h" +#include "third_party/qairt/latest/include/QNN/QnnMem.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_model.h" +#include "tflite/experimental/litert/c/litert_tensor_buffer.h" +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/experimental/litert/vendors/c/litert_dispatch.h" +#include "tflite/experimental/litert/vendors/qualcomm/common.h" +#include "tflite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.h" +#include "tflite/experimental/litert/vendors/qualcomm/qnn_manager.h" + +using litert::Expected; +using litert::Unexpected; +using litert::qnn::QnnManager; + +Expected +LiteRtDispatchDeviceContextT::Create(QnnManager& qnn) { + return Ptr(new LiteRtDispatchDeviceContextT(qnn)); +} + +Expected LiteRtDispatchDeviceContextT::GetTensorBuffer( + LiteRtTensorBufferHandle tensor_buffer_handle) { + auto registry_entry = tensor_buffer_registry_.Get(tensor_buffer_handle); + if (!registry_entry) { + return Unexpected(registry_entry.Error()); + } + + return (*registry_entry)->tensor_buffer; +} + +Expected LiteRtDispatchDeviceContextT::GetMemHandle( + LiteRtTensorBufferHandle tensor_buffer_handle, const Qnn_Tensor_t& tensor) { + auto registry_entry = tensor_buffer_registry_.Get(tensor_buffer_handle); + if (!registry_entry) { + return Unexpected(registry_entry.Error()); + } + + if (!(*registry_entry)->qnn_mem_handle) { + auto qnn_mem_handle = + RegisterTensorBuffer((*registry_entry)->tensor_buffer, tensor); + if (!qnn_mem_handle) { + return Unexpected(qnn_mem_handle.Error()); + } + (*registry_entry)->qnn_mem_handle = *qnn_mem_handle; + } + + return (*registry_entry)->qnn_mem_handle; +} + +Expected LiteRtDispatchDeviceContextT::RegisterTensorBuffer( + LiteRtTensorBuffer tensor_buffer, const Qnn_Tensor_t& tensor) { + LiteRtTensorBufferType tensor_buffer_type; + if (auto status = + LiteRtGetTensorBufferType(tensor_buffer, &tensor_buffer_type); + status != kLiteRtStatusOk) { + return Unexpected(status, "Failed to get tensor buffer type"); + } + + size_t tensor_buffer_size; + if (auto status = + LiteRtGetTensorBufferSize(tensor_buffer, &tensor_buffer_size); + status != kLiteRtStatusOk) { + return Unexpected(status, "Failed to get tensor buffer size"); + } + + size_t tensor_buffer_offset; + if (auto status = + LiteRtGetTensorBufferOffset(tensor_buffer, &tensor_buffer_offset); + status != kLiteRtStatusOk) { + return Unexpected(status, "Failed to get tensor buffer offset"); + } + + LiteRtRankedTensorType tensor_type; + if (auto status = + LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type); + status != kLiteRtStatusOk) { + return Unexpected(status, "Failed to get tensor buffer's type"); + } + + auto element_type = + static_cast(tensor_type.element_type); + Qnn_DataType_t tensor_data_type; + if (auto status = LegalizeElementType(element_type, &tensor_data_type); + status != kLiteRtStatusOk) { + return Unexpected(status, "Failed to legalize datatype"); + } + + uint32_t tensor_rank = tensor_type.layout.rank; + uint32_t* 
tensor_dimensions = reinterpret_cast( + const_cast(tensor_type.layout.dimensions)); + auto* tensor_strides = tensor_type.layout.strides; + if (tensor_strides != nullptr) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "Tensor strides are not supported by QNN"); + } + + void* buffer_host_addr; + int buffer_fd; + (void)buffer_host_addr; + + switch (tensor_buffer_type) { + case kLiteRtTensorBufferTypeFastRpc: +#if LITERT_HAS_FASTRPC_SUPPORT + if (auto status = LiteRtGetTensorBufferFastRpcBuffer( + tensor_buffer, &buffer_host_addr, &buffer_fd); + status != kLiteRtStatusOk) { + return Unexpected(status, "Failed to get FastRPC buffer"); + } +#else + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "FastRPC support is missing on this platform"); +#endif // LRT_HAS_FASTRPC_SUPPORT + break; + + case kLiteRtTensorBufferTypeDmaBuf: +#if LITERT_HAS_DMABUF_SUPPORT + if (auto status = LiteRtGetTensorBufferDmaBufBuffer( + tensor_buffer, &buffer_host_addr, &buffer_fd); + status != kLiteRtStatusOk) { + return Unexpected(status, "Failed to get DMA-BUF buffer"); + } +#else + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "DmaBuf support is missing on this platform"); +#endif // LRT_HAS_DMABUF_SUPPORT + break; + + default: + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "Unsupported tensor buffer type"); + } + + QnnMemHtp_Descriptor_t mem_htp_descriptor = {}; + mem_htp_descriptor.type = QNN_HTP_MEM_SHARED_BUFFER; + mem_htp_descriptor.size = tensor_buffer_size; + mem_htp_descriptor.sharedBufferConfig = + QnnHtpMem_SharedBufferConfig_t{buffer_fd, tensor_buffer_offset}; + + Qnn_MemDescriptor_t mem_descriptor = {}; + mem_descriptor.memShape = {tensor_rank, tensor_dimensions, nullptr}; + mem_descriptor.dataType = tensor_data_type; + mem_descriptor.memType = QNN_MEM_TYPE_CUSTOM; + mem_descriptor.customInfo = &mem_htp_descriptor; + + if (invocation_context_ == nullptr) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "Missing invocation context"); + } + + Qnn_ContextHandle_t context_handle = invocation_context_->ContextHandle(); + + Qnn_MemHandle_t mem_handle = nullptr; + if (auto status = qnn_manager_.Api()->memRegister( + context_handle, &mem_descriptor, 1UL, &mem_handle); + status != QNN_SUCCESS) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "Failed to register tensor buffer"); + } + + if (!mem_handle) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "Failed to register buffer: null mem_handle"); + } + + return mem_handle; +} diff --git a/tflite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.h b/tflite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.h new file mode 100644 index 00000000..968dc938 --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.h @@ -0,0 +1,79 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
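Editor's note on RegisterTensorBuffer above: the heart of it is the two-level descriptor QNN's HTP backend expects, a QnnMemHtp_Descriptor_t carrying the shared-memory file descriptor and offset, wrapped in a Qnn_MemDescriptor_t whose memType is QNN_MEM_TYPE_CUSTOM. A stripped-down sketch of just that wiring; fd, offset, size, rank, dims, and qnn_data_type are placeholders standing in for the values the function computes:

// Sketch: wrap a shared-memory file descriptor for QNN memRegister().
int fd = -1;                  // dmabuf or FastRPC file descriptor (assumed)
size_t size = 0, offset = 0;  // buffer extent in bytes (assumed)
uint32_t rank = 0;
uint32_t* dims = nullptr;     // uint32_t[rank], no strides allowed
Qnn_DataType_t qnn_data_type = QNN_DATATYPE_FLOAT_32;

QnnMemHtp_Descriptor_t htp_desc = {};
htp_desc.type = QNN_HTP_MEM_SHARED_BUFFER;  // fd-backed shared memory
htp_desc.size = size;
htp_desc.sharedBufferConfig = QnnHtpMem_SharedBufferConfig_t{fd, offset};

Qnn_MemDescriptor_t mem_desc = {};
mem_desc.memShape = {rank, dims, nullptr};
mem_desc.dataType = qnn_data_type;
mem_desc.memType = QNN_MEM_TYPE_CUSTOM;  // custom => customInfo is consumed
mem_desc.customInfo = &htp_desc;
// qnn_api->memRegister(context, &mem_desc, /*numDescriptors=*/1, &mem_handle);

Registration is deliberately deferred until an invocation context exists, because memRegister needs the QNN context handle the buffer will be used with.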
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_
+
+#include <memory>
+
+#include "third_party/qairt/latest/include/QNN/QnnInterface.h"
+#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/vendors/c/litert_dispatch.h"
+#include "tflite/experimental/litert/vendors/qualcomm/dispatch/registry.h"
+#include "tflite/experimental/litert/vendors/qualcomm/qnn_manager.h"
+
+class LiteRtDispatchDeviceContextT {
+ public:
+  using Ptr = std::unique_ptr<LiteRtDispatchDeviceContextT>;
+
+  ~LiteRtDispatchDeviceContextT() = default;
+
+  static litert::Expected<Ptr> Create(litert::qnn::QnnManager& qnn_manager);
+
+  litert::Expected<LiteRtTensorBufferHandle> RegisterTensorBuffer(
+      LiteRtTensorBuffer tensor_buffer) {
+    return tensor_buffer_registry_.Register(
+        TensorBufferRegistryEntry(tensor_buffer));
+  }
+
+  litert::Expected<void> UnregisterTensorBuffer(
+      LiteRtTensorBufferHandle tensor_buffer_handle) {
+    return tensor_buffer_registry_.Unregister(tensor_buffer_handle);
+  }
+
+  litert::Expected<LiteRtTensorBuffer> GetTensorBuffer(
+      LiteRtTensorBufferHandle tensor_buffer_handle);
+
+  litert::Expected<Qnn_MemHandle_t> GetMemHandle(
+      LiteRtTensorBufferHandle tensor_buffer_handle,
+      const Qnn_Tensor_t& tensor);
+
+  void SetInvocationContext(
+      LiteRtDispatchInvocationContextT* invocation_context) {
+    invocation_context_ = invocation_context;
+  }
+
+ private:
+  struct TensorBufferRegistryEntry {
+    LiteRtTensorBuffer tensor_buffer;
+    Qnn_MemHandle_t qnn_mem_handle = nullptr;
+    explicit TensorBufferRegistryEntry(LiteRtTensorBuffer tensor_buffer_)
+        : tensor_buffer(tensor_buffer_) {}
+  };
+
+  using TensorBufferRegistry =
+      litert::qnn::Registry<LiteRtTensorBufferHandle, TensorBufferRegistryEntry>;
+
+  LiteRtDispatchDeviceContextT(litert::qnn::QnnManager& qnn_manager)
+      : qnn_manager_(qnn_manager) {}
+
+  litert::Expected<Qnn_MemHandle_t> RegisterTensorBuffer(
+      LiteRtTensorBuffer tensor_buffer, const Qnn_Tensor_t& tensor);
+
+  litert::qnn::QnnManager& qnn_manager_;
+  TensorBufferRegistry tensor_buffer_registry_;
+  LiteRtDispatchInvocationContextT* invocation_context_ = nullptr;
+};
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_
diff --git a/tflite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.cc b/tflite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.cc
new file mode 100644
index 00000000..fc4409f8
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.cc
@@ -0,0 +1,238 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
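Editor's note on the header above: GetMemHandle registers a buffer with QNN lazily and caches the resulting Qnn_MemHandle_t in the registry entry, so repeated attaches of the same buffer pay the memRegister cost exactly once. A hedged usage sketch from the caller's side; device_context, buffer, and tensor stand in for real objects:

// Register once; the returned handle keys all later lookups.
auto handle = device_context->RegisterTensorBuffer(buffer);
if (!handle) { /* propagate handle.Error() */ }

// First call performs QNN memRegister and caches the QNN handle...
auto mem1 = device_context->GetMemHandle(*handle, tensor);
// ...a second call is a cache hit and returns the same Qnn_MemHandle_t.
auto mem2 = device_context->GetMemHandle(*handle, tensor);

// Unregister releases the registry slot for reuse.
device_context->UnregisterTensorBuffer(*handle);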
+ +#include "tflite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.h" + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnCommon.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_model.h" +#include "tflite/experimental/litert/c/litert_tensor_buffer.h" +#include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h" +#include "tflite/experimental/litert/core/util/tensor_type_util.h" +#include "tflite/experimental/litert/vendors/c/litert_dispatch.h" +#include "tflite/experimental/litert/vendors/qualcomm/context_binary_info.h" +#include "tflite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.h" +#include "tflite/experimental/litert/vendors/qualcomm/qnn_manager.h" + +using litert::Expected; +using litert::Unexpected; +using litert::qnn::QnnManager; + +LiteRtDispatchInvocationContextT::LiteRtDispatchInvocationContextT( + litert::qnn::QnnManager& qnn_manager, + const litert::qnn::ContextBinaryInfo& context_binary_info, + LiteRtDispatchDeviceContextT& device_context, + QnnManager::ContextHandle&& context_handle, + Qnn_ProfileHandle_t profile_handle, int graph_index, + Qnn_GraphHandle_t graph_handle) + : qnn_manager_(qnn_manager), + device_context_(device_context), + context_handle_(std::move(context_handle)), + profile_handle_(profile_handle), + graph_index_(graph_index), + graph_handle_(graph_handle), + inputs_(context_binary_info.Graphs()[graph_index].Inputs()), + outputs_(context_binary_info.Graphs()[graph_index].Outputs()) {} + +Expected +LiteRtDispatchInvocationContextT::Create( + QnnManager& qnn, LiteRtDispatchDeviceContextT& device_context, + const void* exec_bytecode_ptr, size_t exec_bytecode_size, + const char* function_name) { + auto context_binary_info = litert::qnn::ContextBinaryInfo::Create( + qnn, exec_bytecode_ptr, exec_bytecode_size); + if (!context_binary_info) { + return Unexpected(context_binary_info.Error()); + } + + const auto& graphs = context_binary_info->Graphs(); + if (graphs.empty()) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, "No graph found"); + } + + int graph_index = -1; + // If the function_name is not specified and there is only one graph, then + // take that graph. 
+  if (absl::string_view(function_name).empty() && graphs.size() == 1) {
+    graph_index = 0;
+    const auto& graph = graphs[graph_index];
+    function_name = graph.Name().c_str();
+  } else {
+    for (auto i = 0; i < graphs.size(); ++i) {
+      const auto& graph = graphs[i];
+      if (graph.Name() == absl::string_view(function_name)) {
+        graph_index = i;
+        break;
+      }
+    }
+  }
+  if (graph_index < 0) {
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Function name not found");
+  }
+
+  auto configs = QnnManager::DefaultContextConfigs();
+  Qnn_ProfileHandle_t profile_handle = nullptr;
+  auto context_handle = qnn.CreateContextHandle(
+      configs,
+      absl::MakeSpan(static_cast<const uint8_t*>(exec_bytecode_ptr),
+                     exec_bytecode_size),
+      profile_handle);
+  if (!context_handle) {
+    return Unexpected(context_handle.Error());
+  }
+
+  Qnn_GraphHandle_t graph_handle;
+  if (auto status = qnn.Api()->graphRetrieve(context_handle->get(),
+                                             function_name, &graph_handle);
+      status != QNN_SUCCESS) {
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Failed to retrieve graph");
+  }
+
+  return Ptr(new LiteRtDispatchInvocationContextT(
+      qnn, std::move(*context_binary_info), device_context,
+      std::move(*context_handle), profile_handle, graph_index, graph_handle));
+}
+
+namespace {
+
+Expected<LiteRtTensorBufferRequirements> GetTensorBufferRequirements(
+    const LiteRtRankedTensorType& tensor_type) {
+  auto* tensor_strides = tensor_type.layout.strides;
+  if (tensor_strides != nullptr) {
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Tensor strides are not supported by QNN");
+  }
+
+  static constexpr std::array<LiteRtTensorBufferType, 2>
+      kSupportedTensorBufferTypes = {
+          kLiteRtTensorBufferTypeFastRpc,
+          kLiteRtTensorBufferTypeDmaBuf,
+      };
+
+  auto buffer_size = litert::internal::GetNumPackedBytes(tensor_type);
+  if (!buffer_size) {
+    return Unexpected(buffer_size.Error());
+  }
+
+  LiteRtTensorBufferRequirements requirements;
+  if (auto status = LiteRtCreateTensorBufferRequirements(
+          kSupportedTensorBufferTypes.size(),
+          kSupportedTensorBufferTypes.data(), *buffer_size, /*num_strides=*/0,
+          /*strides=*/nullptr, &requirements);
+      status != kLiteRtStatusOk) {
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Failed to create tensor buffer requirements");
+  }
+
+  return requirements;
+}
+
+}  // namespace
+
+Expected<LiteRtTensorBufferRequirements>
+LiteRtDispatchInvocationContextT::GetInputRequirements(
+    int input_index, const LiteRtRankedTensorType& tensor_type) {
+  return GetTensorBufferRequirements(tensor_type);
+}
+
+Expected<LiteRtTensorBufferRequirements>
+LiteRtDispatchInvocationContextT::GetOutputRequirements(
+    int output_index, const LiteRtRankedTensorType& tensor_type) {
+  return GetTensorBufferRequirements(tensor_type);
+}
+
+Expected<void> LiteRtDispatchInvocationContextT::AttachInput(
+    int graph_input_index, LiteRtTensorBufferHandle tensor_buffer_handle) {
+  if (graph_input_index < 0 || graph_input_index >= inputs_.size()) {
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Invalid graph_input_index");
+  }
+
+  auto& tensor = inputs_[graph_input_index];
+  return AttachBuffer(tensor.Tensor(), tensor_buffer_handle);
+}
+
+Expected<void> LiteRtDispatchInvocationContextT::AttachOutput(
+    int graph_output_index, LiteRtTensorBufferHandle tensor_buffer_handle) {
+  if (graph_output_index < 0 || graph_output_index >= outputs_.size()) {
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Invalid graph_output_index");
+  }
+
+  auto& tensor = outputs_[graph_output_index];
+  return AttachBuffer(tensor.Tensor(), tensor_buffer_handle);
+}
+
+Expected<void> LiteRtDispatchInvocationContextT::AttachBuffer(
+    Qnn_Tensor_t& tensor, LiteRtTensorBufferHandle tensor_buffer_handle) {
+  auto tensor_buffer = device_context_.GetTensorBuffer(tensor_buffer_handle);
+  if (!tensor_buffer) {
+    return Unexpected(tensor_buffer.Error());
+  }
+
+  auto mem_handle = device_context_.GetMemHandle(tensor_buffer_handle, tensor);
+  if (!mem_handle) {
+    return Unexpected(mem_handle.Error());
+  }
+
+  if (tensor.version == QNN_TENSOR_VERSION_1) {
+    tensor.v1.memType = QNN_TENSORMEMTYPE_MEMHANDLE;
+    tensor.v1.memHandle = *mem_handle;
+
+  } else if (tensor.version == QNN_TENSOR_VERSION_2) {
+    if (tensor.v2.isDynamicDimensions != nullptr) {
+      return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                        "Dynamic dimensions not yet supported");
+    }
+    tensor.v2.memType = QNN_TENSORMEMTYPE_MEMHANDLE;
+    tensor.v2.memHandle = *mem_handle;
+
+  } else {
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Unsupported QNN tensor version");
+  }
+
+  return {};
+}
+
+Expected<void> LiteRtDispatchInvocationContextT::Execute() {
+  const size_t num_ins = inputs_.size();
+  LITERT_STACK_ARRAY(Qnn_Tensor_t, inputs, num_ins, QNN_TENSOR_INIT);
+  for (size_t i = 0; i < num_ins; ++i) {
+    *(inputs + i) = inputs_.at(i).Tensor();
+  }
+
+  const size_t num_outs = outputs_.size();
+  LITERT_STACK_ARRAY(Qnn_Tensor_t, outputs, num_outs, QNN_TENSOR_INIT);
+  for (size_t i = 0; i < num_outs; ++i) {
+    *(outputs + i) = outputs_.at(i).Tensor();
+  }
+
+  if (auto status = qnn_manager_.Api()->graphExecute(
+          graph_handle_, inputs, num_ins, outputs, num_outs,
+          /*profileHandle=*/nullptr, /*signalHandle=*/nullptr);
+      status != QNN_SUCCESS) {
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Failed to execute graph");
+  }
+
+  return {};
+}
diff --git a/tflite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.h b/tflite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.h
new file mode 100644
index 00000000..ec5f761c
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.h
@@ -0,0 +1,82 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
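Editor's note on the implementation above: a caller drives one invocation as create-attach-execute; Execute() snapshots each QnnTensor's Qnn_Tensor_t into stack arrays before handing them to graphExecute. A compact sketch of that flow, with error paths elided; qnn, device_context, bytecode, and the buffer handles are assumed to already exist, and "simple" matches the function name used by the tests earlier in this patch:

// Sketch of one inference through the invocation context defined above.
auto ctx = LiteRtDispatchInvocationContextT::Create(
    qnn, device_context, bytecode.data(), bytecode.size(),
    /*function_name=*/"simple");
if (!ctx) { /* handle ctx.Error() */ }

(*ctx)->AttachInput(/*graph_input_index=*/0, input_handle);
(*ctx)->AttachOutput(/*graph_output_index=*/0, output_handle);

if (auto status = (*ctx)->Execute(); !status) {
  // graphExecute failed; status.Error() carries the message.
}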
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_
+
+#include <memory>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "third_party/qairt/latest/include/QNN/QnnInterface.h"
+#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer.h"
+#include "tflite/experimental/litert/c/litert_tensor_buffer_requirements.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/vendors/c/litert_dispatch.h"
+#include "tflite/experimental/litert/vendors/qualcomm/context_binary_info.h"
+#include "tflite/experimental/litert/vendors/qualcomm/dispatch/registry.h"
+#include "tflite/experimental/litert/vendors/qualcomm/qnn_manager.h"
+
+class LiteRtDispatchDeviceContextT;
+
+class LiteRtDispatchInvocationContextT {
+ public:
+  using Ptr = std::unique_ptr<LiteRtDispatchInvocationContextT>;
+
+  ~LiteRtDispatchInvocationContextT() = default;
+
+  static litert::Expected<Ptr> Create(
+      litert::qnn::QnnManager& qnn_manager,
+      LiteRtDispatchDeviceContextT& device_context,
+      const void* exec_bytecode_ptr, size_t exec_bytecode_size,
+      const char* function_name);
+
+  litert::Expected<LiteRtTensorBufferRequirements> GetInputRequirements(
+      int input_index, const LiteRtRankedTensorType& tensor_type);
+  litert::Expected<LiteRtTensorBufferRequirements> GetOutputRequirements(
+      int output_index, const LiteRtRankedTensorType& tensor_type);
+
+  litert::Expected<void> AttachInput(
+      int graph_input_index, LiteRtTensorBufferHandle tensor_buffer_handle);
+
+  litert::Expected<void> AttachOutput(
+      int graph_output_index, LiteRtTensorBufferHandle tensor_buffer_handle);
+
+  litert::Expected<void> Execute();
+
+  Qnn_ContextHandle_t ContextHandle() { return context_handle_.get(); }
+
+ private:
+  LiteRtDispatchInvocationContextT(
+      litert::qnn::QnnManager& qnn_manager,
+      const litert::qnn::ContextBinaryInfo& context_binary_info,
+      LiteRtDispatchDeviceContextT& device_context,
+      litert::qnn::QnnManager::ContextHandle&& context_handle,
+      Qnn_ProfileHandle_t profile_handle, int graph_index,
+      Qnn_GraphHandle_t graph_handle);
+
+  litert::Expected<void> AttachBuffer(
+      Qnn_Tensor_t& tensor, LiteRtTensorBufferHandle tensor_buffer_handle);
+
+  litert::qnn::QnnManager& qnn_manager_;
+  LiteRtDispatchDeviceContextT& device_context_;
+  litert::qnn::QnnManager::ContextHandle context_handle_;
+  Qnn_ProfileHandle_t profile_handle_;
+  int graph_index_;
+  Qnn_GraphHandle_t graph_handle_;
+  std::vector<litert::qnn::QnnTensor> inputs_;
+  std::vector<litert::qnn::QnnTensor> outputs_;
+};
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_
diff --git a/tflite/experimental/litert/vendors/qualcomm/dispatch/registry.h b/tflite/experimental/litert/vendors/qualcomm/dispatch/registry.h
new file mode 100644
index 00000000..642d3856
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/dispatch/registry.h
@@ -0,0 +1,73 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_REGISTRY_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_REGISTRY_H_
+
+#include <vector>
+
+#include "tflite/experimental/litert/cc/litert_expected.h"
+
+namespace litert {
+namespace qnn {
+
+template <typename H, typename V>
+class Registry {
+ public:
+  Expected<H> Register(const V& value) {
+    // TODO: improve this linear search by keeping an index to the first unused
+    // element.
+    for (auto i = 0; i < entries_.size(); ++i) {
+      auto& entry = entries_[i];
+      if (!entry.used) {
+        entry.value = value;
+        entry.used = true;
+        return static_cast<H>(i);
+      }
+    }
+    // Grow the set of entries.
+    H handle = static_cast<H>(entries_.size());
+    entries_.emplace_back(value);
+    return handle;
+  }
+
+  Expected<void> Unregister(H handle) {
+    if (handle < 0 || handle >= entries_.size()) {
+      return Unexpected(kLiteRtStatusErrorNotFound, "Unexpected handle");
+    }
+    entries_[handle].used = false;
+    return {};
+  }
+
+  Expected<V*> Get(H handle) {
+    if (handle < 0 || handle >= entries_.size()) {
+      return Unexpected(kLiteRtStatusErrorNotFound, "Unexpected handle");
+    }
+    return &entries_[handle].value;
+  }
+
+ private:
+  struct Entry {
+    V value;
+    bool used;
+    explicit Entry(const V& v) : value(v), used(true) {}
+  };
+
+  std::vector<Entry> entries_;
+};
+
+}  // namespace qnn
+}  // namespace litert
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_REGISTRY_H_
diff --git a/tflite/experimental/litert/vendors/qualcomm/qnn_log.cc b/tflite/experimental/litert/vendors/qualcomm/qnn_log.cc
new file mode 100644
index 00000000..87539ec4
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/qnn_log.cc
@@ -0,0 +1,64 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
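Editor's note on Registry<H, V> above: it is a small slot map. Handles are vector indices, freed slots are flagged used = false, and the linear scan in Register reclaims them before the vector grows. A self-contained usage sketch with toy types; int handles and std::string values are purely illustrative:

// #include <string>
// litert::qnn::Registry<int, std::string> as a toy slot map.
litert::qnn::Registry<int, std::string> registry;

auto h0 = registry.Register("first");   // -> handle 0 (grows the vector)
auto h1 = registry.Register("second");  // -> handle 1

registry.Unregister(*h0);               // slot 0 marked unused
auto h2 = registry.Register("third");   // reuses slot 0, so *h2 == *h0

auto value = registry.Get(*h2);         // Expected<std::string*>
// (*value)->c_str() yields "third"; Get on a stale handle still succeeds,
// which is why callers must not use handles after Unregister.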
+ +#include "tflite/experimental/litert/vendors/qualcomm/qnn_log.h" + +#include +#include +#include +#include + +#include "third_party/qairt/latest/include/QNN/QnnLog.h" + +namespace litert::qnn { +namespace { + +void DefaultStdOutLogger(const char* fmt, QnnLog_Level_t level, + uint64_t timestamp, va_list argp) { + const char* levelStr = ""; + switch (level) { + case QNN_LOG_LEVEL_ERROR: + levelStr = " ERROR "; + break; + case QNN_LOG_LEVEL_WARN: + levelStr = "WARNING"; + break; + case QNN_LOG_LEVEL_INFO: + levelStr = " INFO "; + break; + case QNN_LOG_LEVEL_DEBUG: + levelStr = " DEBUG "; + break; + case QNN_LOG_LEVEL_VERBOSE: + levelStr = "VERBOSE"; + break; + case QNN_LOG_LEVEL_MAX: + levelStr = "UNKNOWN"; + break; + } + char buffer1[256]; + char buffer2[256]; + double ms = timestamp; + snprintf(buffer1, sizeof(buffer1), "%8.1fms [%-7s] ", ms, levelStr); + buffer1[sizeof(buffer1) - 1] = 0; + vsnprintf(buffer2, sizeof(buffer2), fmt, argp); + buffer2[sizeof(buffer1) - 2] = 0; + std::cout << buffer1 << buffer2; +} + +} // namespace + +QnnLog_Callback_t GetDefaultStdOutLogger() { return DefaultStdOutLogger; } + +} // namespace litert::qnn diff --git a/tflite/experimental/litert/vendors/qualcomm/qnn_log.h b/tflite/experimental/litert/vendors/qualcomm/qnn_log.h new file mode 100644 index 00000000..934a164b --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/qnn_log.h @@ -0,0 +1,28 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_LOG_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_LOG_H_ + +#include "third_party/qairt/latest/include/QNN/QnnLog.h" + +namespace litert::qnn { + +// Gets a default logger implementation to stdout. +// This is used when initializing qnn logging. +QnnLog_Callback_t GetDefaultStdOutLogger(); + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_LOG_H_ diff --git a/tflite/experimental/litert/vendors/qualcomm/qnn_manager.cc b/tflite/experimental/litert/vendors/qualcomm/qnn_manager.cc new file mode 100644 index 00000000..4f4c8ac9 --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/qnn_manager.cc @@ -0,0 +1,387 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "tflite/experimental/litert/vendors/qualcomm/qnn_manager.h" + +#include +#include +#include +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "third_party/qairt/latest/include/QNN/QnnCommon.h" +#include "third_party/qairt/latest/include/QNN/QnnInterface.h" +#include "third_party/qairt/latest/include/QNN/QnnLog.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "third_party/qairt/latest/include/QNN/System/QnnSystemCommon.h" +#include "third_party/qairt/latest/include/QNN/System/QnnSystemContext.h" +#include "tflite/experimental/litert/c/litert_common.h" +#include "tflite/experimental/litert/c/litert_logging.h" +#include "tflite/experimental/litert/cc/litert_expected.h" +#include "tflite/experimental/litert/core/dynamic_loading.h" +#include "tflite/experimental/litert/vendors/qualcomm/common.h" +#include "tflite/experimental/litert/vendors/qualcomm/qnn_log.h" + +namespace litert::qnn { + +namespace { + +constexpr char kLibQnnGetProvidersSymbol[] = "QnnInterface_getProviders"; + +constexpr char kLibQnnSystemGetProvidersSymbol[] = + "QnnSystemInterface_getProviders"; + +typedef Qnn_ErrorHandle_t (*QnnInterfaceGetProvidersFn_t)( + const QnnInterface_t*** provider_list, uint32_t* num_providers); + +typedef Qnn_ErrorHandle_t (*QnnSystemInterfaceGetProvidersFn_t)( + const QnnSystemInterface_t***, uint32_t*); + +absl::Span LoadProvidersFromLib(void* lib_so) { + QnnInterfaceGetProvidersFn_t get_providers = nullptr; + LITERT_RETURN_VAL_IF_NOT_OK( + litert::internal::ResolveLibSymbol( + lib_so, kLibQnnGetProvidersSymbol, &get_providers), + {}); + + const QnnInterface_t** interface_providers = nullptr; + uint32_t num_providers = 0; + if (QNN_SUCCESS != get_providers(&interface_providers, &num_providers)) { + LITERT_LOG(LITERT_ERROR, "%s", "Failed to get providers\n"); + return {}; + } + + return absl::MakeSpan(interface_providers, num_providers); +} + +absl::Span LoadSystemProvidersFromLib( + void* lib_so) { + QnnSystemInterfaceGetProvidersFn_t get_providers = nullptr; + LITERT_RETURN_VAL_IF_NOT_OK( + litert::internal::ResolveLibSymbol( + lib_so, kLibQnnSystemGetProvidersSymbol, &get_providers), + {}); + + const QnnSystemInterface_t** interface_providers = nullptr; + uint32_t num_providers = 0; + if (QNN_SUCCESS != get_providers(&interface_providers, &num_providers)) { + LITERT_LOG(LITERT_ERROR, "%s", "Failed to get system providers\n"); + return {}; + } + + return absl::MakeSpan(interface_providers, num_providers); +} + +} // namespace + +QnnManager::~QnnManager() { + (void)FreeDevice(); + (void)FreeBackend(); + (void)FreeLogging(); +} + +LiteRtStatus QnnManager::LoadLib(absl::string_view path) { + LITERT_LOG(LITERT_INFO, "Loading qnn shared library from \"%s\"", + path.data()); + LITERT_RETURN_STATUS_IF_NOT_OK(litert::internal::OpenLib(path, &lib_so_)); + LITERT_LOG(LITERT_INFO, "Loaded qnn shared library", ""); + return kLiteRtStatusOk; +} + +LiteRtStatus QnnManager::LoadSystemLib(absl::string_view path) { + LITERT_RETURN_STATUS_IF_NOT_OK( + litert::internal::OpenLib(path, &lib_system_so_)); + return kLiteRtStatusOk; +} + +const QnnApi* QnnManager::Api() const { + if (interface_ == nullptr) { + return nullptr; + } + return &interface_->QNN_INTERFACE_VER_NAME; +} + +LiteRtStatus QnnManager::ResolveApi() { + if (lib_so_ == nullptr) { + LITERT_LOG(LITERT_ERROR, "%s", + "Cannot resolve functions: libQnn*.so has not been loaded.\n"); + return kLiteRtStatusErrorDynamicLoading; + } + + 
+  auto providers = LoadProvidersFromLib(lib_so_);
+  for (const auto& prov : providers) {
+    const bool major =
+        prov->apiVersion.coreApiVersion.major == QNN_API_VERSION_MAJOR;
+
+    const bool minor =
+        prov->apiVersion.coreApiVersion.minor == QNN_API_VERSION_MINOR;
+
+    const bool patch =
+        prov->apiVersion.coreApiVersion.patch == QNN_API_VERSION_PATCH;
+
+    if (major && minor && patch) {
+      interface_ = prov;
+      break;
+    }
+  }
+
+  if (interface_ == nullptr) {
+    LITERT_LOG(LITERT_ERROR, "%s", "No valid interface was provided\n");
+    return kLiteRtStatusErrorDynamicLoading;
+  }
+
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus QnnManager::ResolveSystemApi() {
+  if (lib_system_so_ == nullptr) {
+    LITERT_LOG(LITERT_ERROR, "%s",
+               "Cannot resolve functions: libQnnSystem.so has not been "
+               "loaded.\n");
+    return kLiteRtStatusErrorDynamicLoading;
+  }
+
+  auto system_providers = LoadSystemProvidersFromLib(lib_system_so_);
+  for (const auto& system_prov : system_providers) {
+    const bool major =
+        system_prov->systemApiVersion.major == QNN_SYSTEM_API_VERSION_MAJOR;
+
+    const bool minor =
+        system_prov->systemApiVersion.minor == QNN_SYSTEM_API_VERSION_MINOR;
+
+    const bool patch =
+        system_prov->systemApiVersion.patch == QNN_SYSTEM_API_VERSION_PATCH;
+
+    if (major && minor && patch) {
+      system_interface_ = system_prov;
+      break;
+    }
+  }
+
+  if (system_interface_ == nullptr) {
+    LITERT_LOG(LITERT_ERROR, "%s", "No valid system interface was provided\n");
+    return kLiteRtStatusErrorDynamicLoading;
+  }
+
+  return kLiteRtStatusOk;
+}
+
+const QnnSystemApi* QnnManager::SystemApi() const {
+  if (system_interface_ == nullptr) {
+    return nullptr;
+  }
+  return &system_interface_->QNN_SYSTEM_INTERFACE_VER_NAME;
+}
+
+LiteRtStatus QnnManager::FreeLogging() {
+  if (log_handle_ != nullptr) {
+    if (QNN_SUCCESS != Api()->logFree(log_handle_)) {
+      LITERT_LOG(LITERT_ERROR, "%s", "Failed to free logging\n");
+      return kLiteRtStatusErrorNotFound;
+    }
+  }
+  log_handle_ = nullptr;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus QnnManager::FreeBackend() {
+  if (backend_handle_ != nullptr) {
+    if (QNN_SUCCESS != Api()->backendFree(backend_handle_)) {
+      LITERT_LOG(LITERT_ERROR, "%s", "Failed to free backend\n");
+      return kLiteRtStatusErrorNotFound;
+    }
+  }
+  backend_handle_ = nullptr;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus QnnManager::FreeDevice() {
+  if (device_handle_ != nullptr) {
+    if (QNN_SUCCESS != Api()->deviceFree(device_handle_)) {
+      LITERT_LOG(LITERT_ERROR, "%s", "Failed to free device\n");
+      return kLiteRtStatusErrorNotFound;
+    }
+  }
+  device_handle_ = nullptr;
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus QnnManager::GenerateContextBinary(
+    Qnn_ContextHandle_t context_handle, std::vector<char>& buffer) {
+  Qnn_ContextBinarySize_t bin_size = 0;
+  if (QNN_SUCCESS != Api()->contextGetBinarySize(context_handle, &bin_size)) {
+    LITERT_LOG(LITERT_ERROR, "%s", "Failed to get context bin size\n");
+    return kLiteRtStatusErrorNotFound;
+  }
+  buffer.clear();
+  buffer.resize(bin_size);
+
+  Qnn_ContextBinarySize_t written_bin_size = 0;
+  if (QNN_SUCCESS != Api()->contextGetBinary(context_handle, buffer.data(),
+                                             buffer.size(),
+                                             &written_bin_size)) {
+    LITERT_LOG(LITERT_ERROR, "%s", "Failed to generate context binary\n");
+    return kLiteRtStatusErrorNotFound;
+  }
+
+  LITERT_LOG(LITERT_INFO, "Serialized a context bin of size (bytes): %lu\n",
+             written_bin_size);
+
+  return kLiteRtStatusOk;
+}
+
+LiteRtStatus QnnManager::Init(absl::Span<const QnnBackend_Config_t*> configs,
+                              std::optional<std::string> shared_library_dir,
+                              std::optional<QnnHtpDevice_Arch_t> soc_model) {
+  if (shared_library_dir.has_value()) {
+    // We must extend the environment variable used to load DSP libraries.
+    std::string new_adsp_library_path;
+    if (auto* adsp_library_path = getenv("ADSP_LIBRARY_PATH");
+        adsp_library_path != nullptr) {
+      new_adsp_library_path = absl::StrFormat(
+          "%s:%s", shared_library_dir->data(), adsp_library_path);
+    } else {
+      new_adsp_library_path = shared_library_dir->data();
+    }
+    LITERT_LOG(LITERT_INFO, "Setting ADSP_LIBRARY_PATH to %s",
+               new_adsp_library_path.data());
+    setenv("ADSP_LIBRARY_PATH", new_adsp_library_path.data(), /*overwrite=*/1);
+  }
+
+  auto lib_qnn_htp_so_path =
+      shared_library_dir.has_value()
+          ? absl::StrFormat("%s/%s", shared_library_dir->data(), kLibQnnHtpSo)
+          : kLibQnnHtpSo;
+  LITERT_RETURN_STATUS_IF_NOT_OK(LoadLib(lib_qnn_htp_so_path));
+  LITERT_RETURN_STATUS_IF_NOT_OK(ResolveApi());
+
+  auto lib_qnn_system_so_path =
+      shared_library_dir.has_value()
+          ? absl::StrFormat("%s/%s", shared_library_dir->data(),
+                            kLibQnnSystemSo)
+          : kLibQnnSystemSo;
+  LITERT_RETURN_STATUS_IF_NOT_OK(LoadSystemLib(lib_qnn_system_so_path));
+  LITERT_RETURN_STATUS_IF_NOT_OK(ResolveSystemApi());
+
+  if (auto status = Api()->logCreate(GetDefaultStdOutLogger(),
+                                     QNN_LOG_LEVEL_INFO, &LogHandle());
+      status != QNN_SUCCESS) {
+    LITERT_LOG(LITERT_ERROR, "Failed to create QNN logger: %d", status);
+    return kLiteRtStatusErrorRuntimeFailure;
+  }
+
+  if (auto status =
+          Api()->backendCreate(LogHandle(), configs.data(), &BackendHandle());
+      status != QNN_SUCCESS) {
+    LITERT_LOG(LITERT_ERROR, "Failed to create QNN backend: %d", status);
+    return kLiteRtStatusErrorRuntimeFailure;
+  }
+
+  if (soc_model.has_value()) {
+    LITERT_LOG(LITERT_INFO,
+               "Initializing QNN backend for device architecture %d",
+               *soc_model);
+    QnnHtpDevice_CustomConfig_t arch_custom_config = {};
+    arch_custom_config.option = QNN_HTP_DEVICE_CONFIG_OPTION_ARCH;
+    arch_custom_config.arch.arch = *soc_model;
+    arch_custom_config.arch.deviceId = 0;
+
+    QnnDevice_Config_t arch_device_config = {};
+    arch_device_config.option = QNN_DEVICE_CONFIG_OPTION_CUSTOM;
+    arch_device_config.customConfig = &arch_custom_config;
+
+    const QnnDevice_Config_t* device_configs[2] = {
+        &arch_device_config,
+        nullptr,
+    };
+
+    if (auto status =
+            Api()->deviceCreate(nullptr, device_configs, &DeviceHandle());
+        status != QNN_SUCCESS) {
+      LITERT_LOG(LITERT_ERROR, "Failed to create QNN device: %d", status);
+      return kLiteRtStatusErrorRuntimeFailure;
+    }
+  }
+
+  return kLiteRtStatusOk;
+}
+
+Expected<QnnManager::SystemContextHandle>
+QnnManager::CreateSystemContextHandle() {
+  QnnSystemContext_Handle_t system_context_handle;
+  if (auto status = SystemApi()->systemContextCreate(&system_context_handle);
+      status != QNN_SUCCESS) {
+    LITERT_LOG(LITERT_ERROR, "Failed to create QNN system context: %d",
+               status);
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Failed to create QNN system context");
+  }
+  auto deleter = SystemApi()->systemContextFree;
+  return SystemContextHandle{system_context_handle, deleter};
+}
+
+Expected<QnnManager::ContextHandle> QnnManager::CreateContextHandle(
+    absl::Span<const QnnContext_Config_t*> configs) {
+  Qnn_ContextHandle_t context_handle;
+  if (auto status = Api()->contextCreate(BackendHandle(), DeviceHandle(),
+                                         configs.data(), &context_handle);
+      status != QNN_SUCCESS) {
+    LITERT_LOG(LITERT_ERROR, "Failed to create QNN context: %d", status);
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Failed to create QNN context");
+  }
+  auto deleter = Api()->contextFree;
+  return ContextHandle{context_handle, /*profile_handle=*/nullptr, deleter};
+}
+
+Expected<QnnManager::ContextHandle> QnnManager::CreateContextHandle(
+    absl::Span<const QnnContext_Config_t*> configs,
+    absl::Span<const uint8_t> bytecode, Qnn_ProfileHandle_t profile_handle) {
+  Qnn_ContextHandle_t context_handle;
+  if (auto status = Api()->contextCreateFromBinary(
+          BackendHandle(), DeviceHandle(), configs.data(), bytecode.data(),
+          bytecode.size(), &context_handle, profile_handle);
+      status != QNN_SUCCESS) {
+    LITERT_LOG(LITERT_ERROR, "Failed to create QNN context: %d", status);
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Failed to create QNN context");
+  }
+  auto deleter = Api()->contextFree;
+  return ContextHandle{context_handle, profile_handle, deleter};
+}
+
+Expected<QnnManager::Ptr> QnnManager::Create(
+    absl::Span<const QnnBackend_Config_t*> configs,
+    std::optional<std::string> shared_library_dir,
+    std::optional<QnnHtpDevice_Arch_t> soc_model) {
+  Ptr qnn_manager(new QnnManager);
+  if (auto status = qnn_manager->Init(configs, shared_library_dir, soc_model);
+      status != kLiteRtStatusOk) {
+    return Unexpected(status, "Failed to set up QNN manager");
+  }
+  return qnn_manager;
+}
+
+absl::Span<const QnnBackend_Config_t*> QnnManager::DefaultBackendConfigs() {
+  static const QnnBackend_Config_t* configs[] = {nullptr};
+  return absl::MakeSpan(configs);
+}
+
+absl::Span<const QnnContext_Config_t*> QnnManager::DefaultContextConfigs() {
+  static const QnnContext_Config_t* configs[] = {nullptr};
+  return absl::MakeSpan(configs);
+}
+
+}  // namespace litert::qnn
diff --git a/tflite/experimental/litert/vendors/qualcomm/qnn_manager.h b/tflite/experimental/litert/vendors/qualcomm/qnn_manager.h
new file mode 100644
index 00000000..adeb1142
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/qnn_manager.h
@@ -0,0 +1,226 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_MANAGER_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_MANAGER_H_
+
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <ostream>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "third_party/qairt/latest/include/QNN/HTP/QnnHtpDevice.h"
+#include "third_party/qairt/latest/include/QNN/QnnBackend.h"
+#include "third_party/qairt/latest/include/QNN/QnnCommon.h"
+#include "third_party/qairt/latest/include/QNN/QnnContext.h"
+#include "third_party/qairt/latest/include/QNN/QnnInterface.h"
+#include "third_party/qairt/latest/include/QNN/System/QnnSystemContext.h"
+#include "third_party/qairt/latest/include/QNN/System/QnnSystemInterface.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+#include "tflite/experimental/litert/cc/litert_macros.h"  // IWYU pragma: keep
+#include "tflite/experimental/litert/vendors/qualcomm/common.h"
+
+//===----------------------------------------------------------------------===//
+//
+// QnnManager
+//
+// Syntactic sugar for various QNN SDK routines.
+//
+// Provides various utilities for linking shared libraries at runtime
+// against QNN symbols, as well as convenience getters and storage of handles
+// (pointers). Provides simple wrappers for freeing handles and returning
+// LiteRtStatus rather than QNN statuses. Additionally exposes hooks for
+// dumping API and shared library details.
+//
+// Does not own any memory and will always have a trivial ctor/dtor. The
+// user is responsible for freeing any QNN handles explicitly. Note that
+// QNN handles will be freed automatically when the library is unloaded
+// if they have not been freed already.
+//
+//===----------------------------------------------------------------------===//
+
+namespace litert::qnn {
+
+class QnnManager;
+
+namespace internal {
+
+void Dump(const QnnManager& qnn, std::ostream& out);
+
+}  // namespace internal
+
+class QnnManager {
+  friend void internal::Dump(const QnnManager& qnn, std::ostream& out);
+
+ public:
+  using Ptr = std::unique_ptr<QnnManager>;
+  using SystemContextHandle =
+      std::unique_ptr<std::remove_pointer<QnnSystemContext_Handle_t>::type,
+                      QnnSystemContext_FreeFn_t>;
+  class ContextHandle;
+
+  ~QnnManager();
+
+  static Expected<Ptr> Create(
+      absl::Span<const QnnBackend_Config_t*> configs,
+      std::optional<std::string> shared_library_dir = std::nullopt,
+      std::optional<QnnHtpDevice_Arch_t> soc_model = std::nullopt);
+
+  static absl::Span<const QnnBackend_Config_t*> DefaultBackendConfigs();
+  static absl::Span<const QnnContext_Config_t*> DefaultContextConfigs();
+
+  // Get resolved function pointers for qnn sdk calls. Nullptr if functions
+  // have not been resolved yet.
+  const QnnApi* Api() const;
+
+  // Get resolved function pointers for qnn sdk calls. Nullptr if functions
+  // have not been resolved yet.
+  const QnnSystemApi* SystemApi() const;
+
+  //
+  // QNN SDK Objects.
+  //
+
+  // Create system context handle.
+  Expected<SystemContextHandle> CreateSystemContextHandle();
+
+  // Create a context handle for compilation.
+  Expected<ContextHandle> CreateContextHandle(
+      absl::Span<const QnnContext_Config_t*> configs);
+
+  // Create a context handle for inference, from a given bytecode.
+  Expected<ContextHandle> CreateContextHandle(
+      absl::Span<const QnnContext_Config_t*> configs,
+      absl::Span<const uint8_t> bytecode, Qnn_ProfileHandle_t profile_handle);
+
+  //
+  // Context Binary
+  //
+
+  // Generates QNN context binary from current context. Writes to given
+  // buffer.
+  LiteRtStatus GenerateContextBinary(Qnn_ContextHandle_t context_handle,
+                                     std::vector<char>& buffer);
+
+ private:
+  QnnManager() = default;
+
+  LiteRtStatus Init(absl::Span<const QnnBackend_Config_t*> configs,
+                    std::optional<std::string> shared_library_dir,
+                    std::optional<QnnHtpDevice_Arch_t> soc_model);
+
+  //
+  // Manage libQnn*.so Loading
+  //
+
+  // Loads the libQnn*.so at given path.
+  LiteRtStatus LoadLib(absl::string_view path);
+
+  // Loads the libQnnSystem.so at given path.
+  LiteRtStatus LoadSystemLib(absl::string_view path);
+
+  //
+  // Resolve and Access QNN SDK Functions
+  //
+
+  // Resolve all available QNN SDK functions from (already) loaded so. If
+  // multiple providers are found, selects the first one with a suitable
+  // version. Fails if none can be found.
+  LiteRtStatus ResolveApi();
+
+  // Resolve all available QNN SDK functions from (already) loaded so. If
+  // multiple providers are found, selects the first one with a suitable
+  // version. Fails if none can be found.
+  LiteRtStatus ResolveSystemApi();
+
+  // Get qnn log handle. Nullptr if logCreate has not been successfully called.
+  Qnn_LogHandle_t& LogHandle() { return log_handle_; }
+
+  // Get qnn backend handle. Nullptr if backendCreate has not been successfully
+  // called.
+  Qnn_BackendHandle_t& BackendHandle() { return backend_handle_; }
+
+  // Get qnn device handle. Nullptr if deviceCreate has not been successfully
+  // called.
+  Qnn_DeviceHandle_t& DeviceHandle() { return device_handle_; }
+
+  // Signal QNN SDK to free any memory related to the device. Does nothing
+  // if deviceCreate has not been called.
+  LiteRtStatus FreeDevice();
+
+  // Signal QNN SDK to free any memory related to logging. Does nothing
+  // if logCreate has not been called.
+  LiteRtStatus FreeLogging();
+
+  // Signal QNN SDK to free any memory related to backend. Does nothing
+  // if backendCreate has not been called.
+  LiteRtStatus FreeBackend();
+
+  void* lib_so_ = nullptr;
+  void* lib_system_so_ = nullptr;
+
+  const QnnInterface_t* interface_ = nullptr;
+  const QnnSystemInterface_t* system_interface_ = nullptr;
+
+  Qnn_LogHandle_t log_handle_ = nullptr;
+  Qnn_BackendHandle_t backend_handle_ = nullptr;
+  Qnn_DeviceHandle_t device_handle_ = nullptr;
+};
+
+// Unfortunately we can't use std::unique_ptr with a deleter because
+// QnnContext_FreeFn_t takes a profile handle as a second argument.
+class QnnManager::ContextHandle {
+ public:
+  ContextHandle(Qnn_ContextHandle_t context_handle, Qnn_ProfileHandle_t profile,
+                QnnContext_FreeFn_t free_fn)
+      : context_handle_(context_handle), profile_(profile), free_fn_(free_fn) {}
+
+  ~ContextHandle() {
+    if (context_handle_ && free_fn_) {
+      free_fn_(context_handle_, profile_);
+    }
+  }
+
+  ContextHandle(ContextHandle&& other) { *this = std::move(other); }
+
+  ContextHandle(const ContextHandle& other) = delete;
+
+  ContextHandle& operator=(ContextHandle&& other) {
+    std::swap(context_handle_, other.context_handle_);
+    std::swap(profile_, other.profile_);
+    std::swap(free_fn_, other.free_fn_);
+    return *this;
+  }
+
+  ContextHandle& operator=(const ContextHandle& other) = delete;
+
+  Qnn_ContextHandle_t get() const noexcept { return context_handle_; }
+  explicit operator bool() const noexcept { return context_handle_ != nullptr; }
+
+ private:
+  Qnn_ContextHandle_t context_handle_ = nullptr;
+  Qnn_ProfileHandle_t profile_ = nullptr;
+  QnnContext_FreeFn_t free_fn_ = nullptr;
+};
+
+}  // namespace litert::qnn
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_MANAGER_H_
diff --git a/tflite/experimental/litert/vendors/qualcomm/qnn_manager_test.cc b/tflite/experimental/litert/vendors/qualcomm/qnn_manager_test.cc
new file mode 100644
index 00000000..1f650710
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/qnn_manager_test.cc
@@ -0,0 +1,50 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/vendors/qualcomm/qnn_manager.h"
+
+#include
+
+#include
+#include
+#include "tflite/experimental/litert/test/common.h"
+#include "tflite/experimental/litert/vendors/qualcomm/tools/dump.h"
+
+namespace {
+
+using ::litert::qnn::QnnManager;
+using ::litert::qnn::internal::Dump;
+using ::testing::HasSubstr;
+
+// NOTE: This tests that all of the dynamic loading works properly and
+// the QNN SDK instance can be properly initialized and destroyed.
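+// These tests assume the QNN shared libraries (e.g. libQnnHtp.so and
+// libQnnSystem.so) are loadable at runtime, since QnnManager::Create
+// performs the dynamic loading itself.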
+
+TEST(QnnManagerTest, SetupQnnManager) {
+  auto configs = QnnManager::DefaultBackendConfigs();
+  auto qnn = QnnManager::Create(configs);
+  ASSERT_TRUE(qnn);
+}
+
+TEST(QnnManagerTest, Dump) {
+  auto configs = QnnManager::DefaultBackendConfigs();
+  auto qnn = QnnManager::Create(configs);
+  ASSERT_TRUE(qnn);
+
+  std::ostringstream dump;
+  Dump(**qnn, dump);
+
+  EXPECT_THAT(dump.str(), HasSubstr("< QnnInterface_t >"));
+  EXPECT_THAT(dump.str(), HasSubstr("< QnnSystemInterface_t >"));
+}
+
+}  // namespace
diff --git a/tflite/experimental/litert/vendors/qualcomm/qnn_tensor.cc b/tflite/experimental/litert/vendors/qualcomm/qnn_tensor.cc
new file mode 100644
index 00000000..68404a94
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/qnn_tensor.cc
@@ -0,0 +1,104 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/vendors/qualcomm/qnn_tensor.h"
+
+#include
+#include
+#include
+
+#include "absl/log/absl_check.h"
+#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
+#include "tflite/experimental/litert/c/litert_common.h"
+#include "tflite/experimental/litert/c/litert_logging.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+
+namespace litert {
+namespace qnn {
+
+QnnTensor::QnnTensor(const QnnTensor& other) : QnnTensor(other.Tensor()) {
+  auto status = DeepCopy();
+  // This should never fail because the input QnnTensor was already deep-copied.
+  if (!status) {
+    LITERT_LOG(LITERT_ERROR, "Failed to build QnnTensor: %s",
+               status.Error().Message().data());
+    ABSL_CHECK(status);
+  }
+}
+
+QnnTensor::QnnTensor(QnnTensor&& other) {
+  tensor_ = other.tensor_;
+  // Swap managed memory.
+  std::swap(name_, other.name_);
+  std::swap(dimensions_, other.dimensions_);
+  std::swap(is_dynamic_dimensions_, other.is_dynamic_dimensions_);
+}
+
+Expected QnnTensor::Create(const Qnn_Tensor_t& tensor) {
+  QnnTensor qnn_tensor(tensor);
+  if (auto status = qnn_tensor.DeepCopy(); !status) {
+    return Unexpected(status.Error());
+  }
+  return qnn_tensor;
+}
+
+Expected QnnTensor::DeepCopy() {
+  if (tensor_.version == QNN_TENSOR_VERSION_1) {
+    dimensions_.reserve(tensor_.v1.rank);
+    std::copy(tensor_.v1.dimensions, tensor_.v1.dimensions + tensor_.v1.rank,
+              std::back_inserter(dimensions_));
+    tensor_.v1.dimensions = dimensions_.data();
+
+    // FIXME: Implement deep copy for quantizeParams.
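+    // (Assumption: blockwise-expansion and vector encodings reference
+    // out-of-line parameter data that the shallow struct copy above would
+    // alias, so they are rejected until a true deep copy is implemented.)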
+    if (tensor_.v1.quantizeParams.quantizationEncoding ==
+            QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION ||
+        tensor_.v1.quantizeParams.quantizationEncoding ==
+            QNN_QUANTIZATION_ENCODING_VECTOR) {
+      return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                        "Unsupported QNN quantization");
+    }
+
+  } else if (tensor_.version == QNN_TENSOR_VERSION_2) {
+    dimensions_.reserve(tensor_.v2.rank);
+    std::copy(tensor_.v2.dimensions, tensor_.v2.dimensions + tensor_.v2.rank,
+              std::back_inserter(dimensions_));
+    tensor_.v2.dimensions = dimensions_.data();
+
+    if (tensor_.v2.isDynamicDimensions) {
+      is_dynamic_dimensions_.reserve(tensor_.v2.rank);
+      std::copy(tensor_.v2.isDynamicDimensions,
+                tensor_.v2.isDynamicDimensions + tensor_.v2.rank,
+                std::back_inserter(is_dynamic_dimensions_));
+      tensor_.v2.isDynamicDimensions = is_dynamic_dimensions_.data();
+    }
+
+    // FIXME: Implement deep copy for quantizeParams.
+    if (tensor_.v2.quantizeParams.quantizationEncoding ==
+            QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION ||
+        tensor_.v2.quantizeParams.quantizationEncoding ==
+            QNN_QUANTIZATION_ENCODING_VECTOR) {
+      return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                        "Unsupported QNN quantization");
+    }
+
+  } else {
+    return Unexpected(kLiteRtStatusErrorRuntimeFailure,
+                      "Unsupported QNN tensor version");
+  }
+
+  return {};
+}
+
+}  // namespace qnn
+}  // namespace litert
diff --git a/tflite/experimental/litert/vendors/qualcomm/qnn_tensor.h b/tflite/experimental/litert/vendors/qualcomm/qnn_tensor.h
new file mode 100644
index 00000000..386d9b57
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/qnn_tensor.h
@@ -0,0 +1,60 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_TENSOR_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_TENSOR_H_
+
+#include
+#include
+#include
+#include
+
+#include "absl/strings/string_view.h"
+#include "third_party/qairt/latest/include/QNN/QnnInterface.h"
+#include "third_party/qairt/latest/include/QNN/QnnTypes.h"
+#include "tflite/experimental/litert/cc/litert_expected.h"
+
+namespace litert {
+namespace qnn {
+
+class QnnTensor {
+ public:
+  static Expected Create(const Qnn_Tensor_t& tensor);
+
+  QnnTensor(const QnnTensor& other);
+  QnnTensor(QnnTensor&& other);
+
+  QnnTensor& operator=(const QnnTensor&) = delete;
+  QnnTensor& operator=(QnnTensor&&) = delete;
+
+  Qnn_Tensor_t& Tensor() { return tensor_; }
+  const Qnn_Tensor_t& Tensor() const { return tensor_; }
+
+  size_t Rank() const { return dimensions_.size(); }
+  const uint32_t* Dimensions() const { return dimensions_.data(); }
+
+ private:
+  explicit QnnTensor(const Qnn_Tensor_t& tensor) : tensor_(tensor) {}
+  Expected DeepCopy();
+
+  Qnn_Tensor_t tensor_;
+  std::string name_;
+  std::vector dimensions_;
+  std::vector is_dynamic_dimensions_;
+};
+
+}  // namespace qnn
+}  // namespace litert
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_TENSOR_H_
diff --git a/tflite/experimental/litert/vendors/qualcomm/qualcomm_build_defs.bzl b/tflite/experimental/litert/vendors/qualcomm/qualcomm_build_defs.bzl
new file mode 100644
index 00000000..f30d5a3f
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/qualcomm_build_defs.bzl
@@ -0,0 +1,118 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Build definitions for Qualcomm backend."""
+
+load("//tflite/experimental/litert/build_common:litert_build_defs.bzl", "append_rule_kwargs", "litert_bin", "litert_lib", "make_rpaths")
+
+_QNN_LIBCC_X86_64 = [
+    # copybara:uncomment_begin(google-only)
+    # "@org_tensorflow//third_party/qairt/latest:lib/x86_64-linux-clang/libc++.so.1",
+    # "@org_tensorflow//third_party/qairt/latest:lib/x86_64-linux-clang/libc++abi.so.1",
+    # copybara:uncomment_end
+]  # @unused
+
+# TODO: Make rpaths dynamic with "$(location {})".
+_QNN_LIB_RPATHS_X86_64 = [
+    # copybara:uncomment_begin(google-only)
+    # "third_party/qairt/latest/lib/x86_64-linux-clang",
+    # copybara:uncomment_end
+]
+
+_QNN_LIB_HTP_X86_64 = [
+    # copybara:uncomment_begin(google-only)
+    # "@org_tensorflow//third_party/qairt/latest:lib/x86_64-linux-clang/libQnnHtp.so",
+    # copybara:uncomment_end
+]
+
+_QNN_LIB_SYSTEM_X86_64 = [
+    # copybara:uncomment_begin(google-only)
+    # "@org_tensorflow//third_party/qairt/latest:lib/x86_64-linux-clang/libQnnSystem.so",
+    # copybara:uncomment_end
+]
+
+def _litert_with_qnn_base(
+        litert_rule,
+        backend,
+        include_system,
+        use_custom_libcc,
+        **litert_rule_kwargs):
+    if backend != "htp":
+        fail("Only htp currently supported")
+
+    if use_custom_libcc:
+        # TODO: Figure out strategy for custom libcc.
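+        # (Assumption: the SDK-bundled libc++ listed in _QNN_LIBCC_X86_64
+        # above, currently marked @unused, is what a custom libcc strategy
+        # would link against.)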
+ fail("Custom libcc not yet supported") + + data_x86_64 = [] + data_x86_64.extend(_QNN_LIB_HTP_X86_64) + if include_system: + data_x86_64.extend(_QNN_LIB_SYSTEM_X86_64) + data = select({ + "@org_tensorflow//tensorflow:linux_x86_64": data_x86_64, + "//conditions:default": [], + }) + + append_rule_kwargs( + litert_rule_kwargs, + data = data, + linkopts = select({ + "@org_tensorflow//tensorflow:linux_x86_64": [make_rpaths(_QNN_LIB_RPATHS_X86_64)], + "//conditions:default": [], + }), + ) + + litert_rule(**litert_rule_kwargs) + +def litert_cc_lib_with_qnn( + backend = "htp", + include_system = False, + use_custom_libcc = False, + **litert_lib_kwargs): + """Creates a litert_lib target with QualComm backend dependencies. + + Args: + backend: The backend to use. Currently only "htp" is supported. + include_system: Whether to include libQnnSystem.so. + use_custom_libcc: Whether to use a custom libcc. Not yet supported. + **litert_lib_kwargs: Keyword arguments passed to litert_lib. + """ + _litert_with_qnn_base( + litert_lib, + backend, + include_system, + use_custom_libcc, + **litert_lib_kwargs + ) + +def litert_cc_bin_with_qnn( + backend = "htp", + include_system = False, + use_custom_libcc = False, + **litert_bin_kwargs): + """Creates a litert_bin target with QualComm backend dependencies. + + Args: + backend: The backend to use. Currently only "htp" is supported. + include_system: Whether to include libQnnSystem.so. + use_custom_libcc: Whether to use a custom libcc. Not yet supported. + **litert_bin_kwargs: Keyword arguments passed to litert_bin. + """ + _litert_with_qnn_base( + litert_bin, + backend, + include_system, + use_custom_libcc, + **litert_bin_kwargs + ) diff --git a/tflite/experimental/litert/vendors/qualcomm/tools/BUILD b/tflite/experimental/litert/vendors/qualcomm/tools/BUILD new file mode 100644 index 00000000..ec3bc743 --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/tools/BUILD @@ -0,0 +1,31 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # copybara:uncomment default_applicable_licenses = ["@org_tensorflow//tensorflow:license"], + default_visibility = ["//tflite/experimental/litert:__subpackages__"], +) + +cc_library( + name = "dump", + srcs = ["dump.cc"], + hdrs = ["dump.h"], + tags = ["nobuilder"], + deps = [ + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + # copybara:uncomment "@org_tensorflow//third_party/qairt/latest:qnn_lib_headers", + "//tflite/experimental/litert/vendors/qualcomm:qnn_manager_hdr", + ], +) diff --git a/tflite/experimental/litert/vendors/qualcomm/tools/dump.cc b/tflite/experimental/litert/vendors/qualcomm/tools/dump.cc new file mode 100644 index 00000000..d51b91ab --- /dev/null +++ b/tflite/experimental/litert/vendors/qualcomm/tools/dump.cc @@ -0,0 +1,88 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tflite/experimental/litert/vendors/qualcomm/tools/dump.h"
+
+#include
+
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "third_party/qairt/latest/include/QNN/QnnInterface.h"
+#include "third_party/qairt/latest/include/QNN/System/QnnSystemInterface.h"
+#include "tflite/experimental/litert/vendors/qualcomm/qnn_manager.h"
+
+namespace litert::qnn::internal {
+namespace {
+
+static constexpr absl::string_view kNullDumpTpl = "%s : nullptr\n";
+
+void Dump(const QnnInterface_t* interface, std::ostream& out) {
+  static constexpr absl::string_view kQnnInterfaceHeader = "< QnnInterface_t >";
+  // NOLINTBEGIN
+  static constexpr absl::string_view kQnnInterfaceDumpTpl =
+      "\
+  %s\n\
+  name: %s\n\
+  backend_id: %u\n\
+  core_api_version: %u.%u.%u\n\
+  backend_api_version: %u.%u.%u\n";
+  // NOLINTEND
+
+  if (interface == nullptr) {
+    out << absl::StreamFormat(kNullDumpTpl, kQnnInterfaceHeader);
+    return;
+  }
+
+  const auto core_version = interface->apiVersion.coreApiVersion;
+  const auto backend_version = interface->apiVersion.backendApiVersion;
+
+  out << absl::StreamFormat(kQnnInterfaceDumpTpl, kQnnInterfaceHeader,
+                            interface->providerName, interface->backendId,
+                            core_version.major, core_version.minor,
+                            core_version.patch, backend_version.major,
+                            backend_version.minor, backend_version.patch);
+}
+
+void Dump(const QnnSystemInterface_t* interface, std::ostream& out) {
+  static constexpr absl::string_view kQnnSystemInterfaceHeader =
+      "< QnnSystemInterface_t >";
+  // NOLINTBEGIN
+  static constexpr absl::string_view kQnnSystemInterfaceDumpTpl =
+      "\
+  %s\n\
+  name: %s\n\
+  backend_id: %u\n\
+  system_api_version: %u.%u.%u\n";
+  // NOLINTEND
+
+  if (interface == nullptr) {
+    out << absl::StreamFormat(kNullDumpTpl, kQnnSystemInterfaceHeader);
+    return;
+  }
+
+  const auto system_version = interface->systemApiVersion;
+
+  out << absl::StreamFormat(kQnnSystemInterfaceDumpTpl,
+                            kQnnSystemInterfaceHeader, interface->providerName,
+                            interface->backendId, system_version.major,
+                            system_version.minor, system_version.patch);
+}
+
+}  // namespace
+
+void Dump(const QnnManager& qnn, std::ostream& out) {
+  Dump(qnn.interface_, out);
+  Dump(qnn.system_interface_, out);
+}
+}  // namespace litert::qnn::internal
diff --git a/tflite/experimental/litert/vendors/qualcomm/tools/dump.h b/tflite/experimental/litert/vendors/qualcomm/tools/dump.h
new file mode 100644
index 00000000..11fd4a33
--- /dev/null
+++ b/tflite/experimental/litert/vendors/qualcomm/tools/dump.h
@@ -0,0 +1,29 @@
+// Copyright 2024 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_TOOLS_DUMP_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_TOOLS_DUMP_H_
+
+#include
+#include
+
+#include "tflite/experimental/litert/vendors/qualcomm/qnn_manager.h"
+
+namespace litert::qnn::internal {
+
+void Dump(const QnnManager& qnn, std::ostream& out = std::cerr);
+
+}  // namespace litert::qnn::internal
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_TOOLS_DUMP_H_
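
For orientation, below is a minimal usage sketch of the QnnManager and Dump APIs added by this patch. It is illustrative only: it assumes the QNN shared libraries are discoverable at runtime, and it fills in template arguments that the diff text above elides (Expected<Ptr> from Create, std::vector<char> as GenerateContextBinary's buffer type), so treat those as assumptions rather than exact signatures.

#include <vector>

#include "tflite/experimental/litert/vendors/qualcomm/qnn_manager.h"
#include "tflite/experimental/litert/vendors/qualcomm/tools/dump.h"

int main() {
  using ::litert::qnn::QnnManager;

  // Loads libQnn*.so, resolves the QNN APIs, and initializes the
  // log/backend/device handles.
  auto qnn = QnnManager::Create(QnnManager::DefaultBackendConfigs());
  if (!qnn) return 1;

  // Dump the resolved QnnInterface_t / QnnSystemInterface_t details to stderr.
  litert::qnn::internal::Dump(**qnn);

  // Create a compilation context, then serialize it to a context binary.
  auto context =
      (*qnn)->CreateContextHandle(QnnManager::DefaultContextConfigs());
  if (!context) return 1;

  std::vector<char> bytecode;  // Assumed buffer element type; elided above.
  if ((*qnn)->GenerateContextBinary((*context).get(), bytecode) !=
      kLiteRtStatusOk) {
    return 1;
  }

  // ContextHandle frees the underlying Qnn_ContextHandle_t when it leaves
  // scope; the QnnManager destructor presumably tears down the remaining
  // backend/device/log handles via the Free* helpers.
  return 0;
}

The ContextHandle RAII wrapper is what makes the early returns above safe: whichever path exits, the QNN context is freed exactly once.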