diff --git a/tensorflow/compiler/mlir/lite/schema/schema.fbs b/tensorflow/compiler/mlir/lite/schema/schema.fbs
index 78cccb9347f..d74477be913 100644
--- a/tensorflow/compiler/mlir/lite/schema/schema.fbs
+++ b/tensorflow/compiler/mlir/lite/schema/schema.fbs
@@ -70,6 +70,19 @@ table CustomQuantization {
 // Represents a specific quantization technique's parameters.
 union QuantizationDetails {
   CustomQuantization,
+  BlockwiseQuantization,
 }
+
+// Parameters for blockwise quantization.
+table BlockwiseQuantization {
+  // Index of the scale tensor; the tensor can be found in the tensors array
+  // of the subgraph.
+  scales: int;
+  // Index of the zero point tensor. If zero_points is -1, the zero point is
+  // assumed to be 0, following the convention of optional tensors in tflite.
+  zero_points: int;
+  // The block size of the tensor.
+  block_size: int;
+}
 
 // Parameters for converting a quantized tensor back to float.
@@ -474,6 +487,7 @@ enum BuiltinOperator : int32 {
   STABLEHLO_COMPOSITE = 206,  // WARNING: No runtime support
   STABLEHLO_SHIFT_LEFT = 207,
   STABLEHLO_CBRT = 208,  // WARNING: No runtime support
+  STABLEHLO_CASE = 209,
 }
 // LINT.ThenChange(nnapi_linter/linter.proto)
@@ -633,6 +647,7 @@ union BuiltinOptions2{
   ReduceWindowOptions (deprecated),
   StableHLOCompositeOptions,
   StablehloShiftLeftOptions,
+  StablehloCaseOptions,
 }
 
 table StablehloGatherOptions{
@@ -777,6 +792,10 @@ table StablehloScatterOptions {
   update_computation_subgraph_index: int;
 }
 
+table StablehloCaseOptions{
+  branch_subgraph_indices: [int];
+}
+
 enum RngAlgorithm : byte {
   // An algorithm auto-selected by the system according to device type.
   DEFAULT = 0,
diff --git a/tensorflow/lite/builtin_ops.h b/tensorflow/lite/builtin_ops.h
index 4179bd5dc9a..21c59eb4045 100644
--- a/tensorflow/lite/builtin_ops.h
+++ b/tensorflow/lite/builtin_ops.h
@@ -236,6 +236,7 @@ typedef enum {
   kTfLiteBuiltinStablehloComposite = 206,
   kTfLiteBuiltinStablehloShiftLeft = 207,
   kTfLiteBuiltinStablehloCbrt = 208,
+  kTfLiteBuiltinStablehloCase = 209,
 } TfLiteBuiltinOperator;
 
 #ifdef __cplusplus
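The two schema additions above are plain data tables: BlockwiseQuantization refers to its scale and zero-point tensors by index into the subgraph's tensors array, and StablehloCaseOptions carries one subgraph index per case branch. A minimal writer-side sketch using the generated C++ helpers this patch adds to schema_generated.h further down; the tensor indices, block size, and branch list are illustrative values, not part of the patch:

    // Sketch only; assumes the generated header produced by this change.
    #include <cstdint>
    #include <vector>

    #include "tensorflow/lite/schema/schema_generated.h"

    void EmitNewOptions(::flatbuffers::FlatBufferBuilder& fbb) {
      // scales/zero_points are indices into the subgraph's tensors array;
      // zero_points == -1 means no zero-point tensor (all zeros).
      auto blockwise = tflite::CreateBlockwiseQuantization(
          fbb, /*scales=*/1, /*zero_points=*/-1, /*block_size=*/32);
      (void)blockwise;

      // One subgraph index per branch of the stablehlo.case op.
      std::vector<int32_t> branches = {1, 2, 3};
      auto case_options =
          tflite::CreateStablehloCaseOptionsDirect(fbb, &branches);
      (void)case_options;
    }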
diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc
index 193a5c8c496..91b3ba65a7e 100644
--- a/tensorflow/lite/core/api/flatbuffer_conversions.cc
+++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc
@@ -27,6 +27,9 @@ limitations under the License.
 #include "tensorflow/lite/kernels/internal/compatibility.h"
 #include "tensorflow/lite/schema/schema_generated.h"
 
+// TODO(sosagarcia): Rework all function implementations to wrap around the
+// compiler flatbuffer_conversions.
+// LINT.IfChange
 namespace tflite {
 
 namespace {
@@ -928,6 +931,9 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
       return ParseStablehloShiftLeft(op, error_reporter, allocator,
                                      builtin_data);
     }
+    case BuiltinOperator_STABLEHLO_CASE: {
+      return ParseStablehloCase(op, error_reporter, allocator, builtin_data);
+    }
     // TODO: skip param parsing for now since ops below don't have kernels
     case BuiltinOperator_STABLEHLO_SLICE:
     case BuiltinOperator_STABLEHLO_BROADCAST_IN_DIM:
@@ -2421,6 +2427,46 @@ TfLiteStatus ParseStablehloShiftLeft(const Operator* op,
   return kTfLiteOk;
 }
 
+TfLiteStatus ParseStablehloCase(const Operator* op,
+                                ErrorReporter* error_reporter,
+                                BuiltinDataAllocator* allocator,
+                                void** builtin_data) {
+  CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
+
+  SafeBuiltinDataAllocator safe_allocator(allocator);
+  auto params = safe_allocator.Allocate<TfLiteStablehloCaseParams>();
+
+  const StablehloCaseOptions* schema_params =
+      op->builtin_options_2_as_StablehloCaseOptions();
+  if (schema_params) {
+    auto LoadAttr =
+        [&error_reporter](
+            int32_t* params_array, const size_t params_array_size_bytes,
+            const flatbuffers::Vector<int32_t>* const flatbuffer_vector,
+            const char* const attr_name) -> TfLiteStatus {
+      TfLiteStatus status = FlatBufferIntVectorToArray(
+          params_array_size_bytes, flatbuffer_vector, params_array,
+          error_reporter, "stablehlo.case");
+      if (status != kTfLiteOk) {
+        TF_LITE_REPORT_ERROR(error_reporter, "Check the '%s' attribute.",
+                             attr_name);
+      }
+      return status;
+    };
+
+    TF_LITE_ENSURE_STATUS(LoadAttr(params->branch_subgraph_indices,
+                                   sizeof(params->branch_subgraph_indices),
+                                   schema_params->branch_subgraph_indices(),
+                                   "branch subgraph indices"));
+    params->num_branches = schema_params->branch_subgraph_indices()->size();
+    *builtin_data = params.release();
+    return kTfLiteOk;
+  }
+  TF_LITE_REPORT_ERROR(error_reporter,
+                       "Could not get 'stablehlo.case' operation parameters.");
+  return kTfLiteError;
+}
+
 // We have this parse function instead of directly returning kTfLiteOk from the
 // switch-case in ParseOpData because this function is used as part of the
 // selective registration for the OpResolver implementation in micro.
@@ -2943,3 +2989,4 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
 }
 
 }  // namespace tflite
+// LINT.ThenChange(//tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.cc)
diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.h b/tensorflow/lite/core/api/flatbuffer_conversions.h
index f03a0a02938..376b0eeb630 100644
--- a/tensorflow/lite/core/api/flatbuffer_conversions.h
+++ b/tensorflow/lite/core/api/flatbuffer_conversions.h
@@ -456,6 +456,11 @@ TfLiteStatus ParseStablehloShiftLeft(const Operator* op,
                                      BuiltinDataAllocator* allocator,
                                      void** builtin_data);
 
+TfLiteStatus ParseStablehloCase(const Operator* op,
+                                ErrorReporter* error_reporter,
+                                BuiltinDataAllocator* allocator,
+                                void** builtin_data);
+
 }  // namespace tflite
 
 #endif  // TENSORFLOW_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_
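ParseStablehloCase above copies the flatbuffer vector into the fixed-size branch_subgraph_indices array of TfLiteStablehloCaseParams and records num_branches. A sketch of how a kernel could read the parsed struct back out of node->builtin_data; the kernel scaffolding is assumed, and only the two fields the parser fills are relied on:

    // Sketch only; TfLiteStablehloCaseParams comes from builtin_op_data.h.
    #include "tensorflow/lite/core/c/builtin_op_data.h"
    #include "tensorflow/lite/core/c/common.h"

    TfLiteStatus ValidateCaseParams(TfLiteContext* context, TfLiteNode* node) {
      const auto* params = reinterpret_cast<const TfLiteStablehloCaseParams*>(
          node->builtin_data);
      if (params == nullptr || params->num_branches <= 0) return kTfLiteError;
      for (int i = 0; i < params->num_branches; ++i) {
        // Each entry names the subgraph to invoke when that branch is taken.
        if (params->branch_subgraph_indices[i] < 0) return kTfLiteError;
      }
      return kTfLiteOk;
    }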
diff --git a/tensorflow/lite/core/c/builtin_op_data.h b/tensorflow/lite/core/c/builtin_op_data.h
index e1428e72307..cfe3d825a7f 100644
--- a/tensorflow/lite/core/c/builtin_op_data.h
+++ b/tensorflow/lite/core/c/builtin_op_data.h
@@ -20,642 +20,7 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_CORE_C_BUILTIN_OP_DATA_H_
 #define TENSORFLOW_LITE_CORE_C_BUILTIN_OP_DATA_H_
 
-#include <stdbool.h>
-#include <stddef.h>
-#include <stdint.h>
-
-#include "tensorflow/lite/core/c/common.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif  // __cplusplus
-
-// TfLiteReshapeParams can't have dynamic data so we fix the maximum possible
-// number of dimensions.
-#define TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT 8
-#define TFLITE_STABLEHLO_SCATTER_PARAMS_MAX_DIMENSION_COUNT 8
-#define TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT 8
-#define TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT 8
-#define TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT 8
-
-// TODO(aselle): Consider using "if this then that" for testing.
-
-// Useful placeholder to put in otherwise empty structs to avoid size warnings.
-typedef struct {
-  char dummy;
-} EmptyStructPlaceholder;
-
-// IMPORTANT: All new members of structs must be added at the end to ensure
-// backwards compatibility.
-
-// Possible padding types (for convolutions)
-typedef enum {
-  kTfLitePaddingUnknown = 0,
-  kTfLitePaddingSame,
-  kTfLitePaddingValid,
-} TfLitePadding;
-
-typedef enum {
-  kTfLiteMirrorPaddingUnknown = 0,
-  kTfLiteMirrorPaddingReflect,
-  kTfLiteMirrorPaddingSymmetric,
-} TfLiteMirrorPaddingMode;
-
-// TODO(b/130259536): We should move this out of builtin_op_data.
-typedef struct {
-  int width;
-  int height;
-  int width_offset;
-  int height_offset;
-} TfLitePaddingValues;
-
-typedef struct {
-  TfLiteMirrorPaddingMode mode;
-} TfLiteMirrorPaddingParams;
-
-// Possible fused activation functions.
-typedef enum {
-  kTfLiteActNone = 0,
-  kTfLiteActRelu,
-  kTfLiteActReluN1To1,  // min(max(-1, x), 1)
-  kTfLiteActRelu6,      // min(max(0, x), 6)
-  kTfLiteActTanh,
-  kTfLiteActSignBit,
-  kTfLiteActSigmoid,
-} TfLiteFusedActivation;
-
-typedef struct {
-  // Parameters for CONV_2D version 1.
-  TfLitePadding padding;
-  int stride_width;
-  int stride_height;
-  TfLiteFusedActivation activation;
-
-  // Parameters for CONV_2D version 2.
-  // Note: Version 2 supports dilation values not equal to 1.
-  int dilation_width_factor;
-  int dilation_height_factor;
-
-  // Parameters for CONV_2D version 7 or above.
-  // Used to determine the default value for the quantized bias.
-  TfLiteType quantized_bias_type;
-} TfLiteConvParams;
-
-typedef struct {
-  TfLitePadding padding;
-  int stride_width;
-  int stride_height;
-  int stride_depth;
-  int dilation_width_factor;
-  int dilation_height_factor;
-  int dilation_depth_factor;
-  TfLiteFusedActivation activation;
-} TfLiteConv3DParams;
-
-typedef TfLiteConv3DParams TfLiteConv3DTransposeParams;
-
-typedef struct {
-  TfLitePadding padding;
-  int stride_width;
-  int stride_height;
-  int filter_width;
-  int filter_height;
-  TfLiteFusedActivation activation;
-  struct {
-    TfLitePaddingValues padding;
-  } computed;
-} TfLitePoolParams;
-
-typedef struct {
-  // Parameters for DepthwiseConv version 1 or above.
-  TfLitePadding padding;
-  int stride_width;
-  int stride_height;
-  // `depth_multiplier` is redundant. It's used by CPU kernels in
-  // TensorFlow 2.0 or below, but ignored in versions above.
-  //
-  // The information can be deduced from the shape of input and the shape of
-  // weights. Since the TFLiteConverter toolchain doesn't support partially
-  // specified shapes, relying on `depth_multiplier` stops us from supporting
-  // graphs with dynamic shape tensors.
-  //
-  // Note: Some of the delegates (e.g. NNAPI, GPU) are still relying on this
-  // field.
- int depth_multiplier; - TfLiteFusedActivation activation; - // Parameters for DepthwiseConv version 2 or above. - int dilation_width_factor; - int dilation_height_factor; -} TfLiteDepthwiseConvParams; - -typedef struct { - int rank; - TfLiteFusedActivation activation; - - // Parameter for SVDF version 4. - bool asymmetric_quantize_inputs; -} TfLiteSVDFParams; - -typedef struct { - TfLiteFusedActivation activation; - - // Parameter for RNN version 3. - bool asymmetric_quantize_inputs; -} TfLiteRNNParams; - -typedef struct { - bool time_major; - TfLiteFusedActivation activation; - - // Parameter for Sequence RNN version 3. - bool asymmetric_quantize_inputs; -} TfLiteSequenceRNNParams; - -typedef struct { - bool time_major; - TfLiteFusedActivation activation; - bool merge_outputs; - - // Parameter for Bidirectional RNN version 3. - bool asymmetric_quantize_inputs; -} TfLiteBidirectionalSequenceRNNParams; - -typedef enum { - kTfLiteFullyConnectedWeightsFormatDefault = 0, - kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1, -} TfLiteFullyConnectedWeightsFormat; - -typedef struct { - // Parameters for FullyConnected version 1 or above. - TfLiteFusedActivation activation; - - // Parameters for FullyConnected version 2 or above. - TfLiteFullyConnectedWeightsFormat weights_format; - - // Parameters for FullyConnected version 5 or above. - // If set to true, then the number of dimensions in the input and the output - // tensors are the same. Furthermore, all but the last dimension of the input - // and output shapes will be equal. - bool keep_num_dims; - - // Parameters for FullyConnected version 7 or above. - // If set to true and the weights are quantized, then non constant inputs - // are quantized at evaluation time with asymmetric quantization. - bool asymmetric_quantize_inputs; - - // Parameters for FullyConnected version 10 or above. - // Used to determine the default value for the quantized bias. - TfLiteType quantized_bias_type; -} TfLiteFullyConnectedParams; - -typedef enum { - kTfLiteLshProjectionUnknown = 0, - kTfLiteLshProjectionSparse = 1, - kTfLiteLshProjectionDense = 2, -} TfLiteLSHProjectionType; - -typedef struct { - TfLiteLSHProjectionType type; -} TfLiteLSHProjectionParams; - -typedef struct { - float beta; -} TfLiteSoftmaxParams; - -typedef struct { - int axis; - TfLiteFusedActivation activation; -} TfLiteConcatenationParams; - -typedef struct { - TfLiteFusedActivation activation; - // Parameter added for the version 4. - bool pot_scale_int16; -} TfLiteAddParams; - -typedef struct { - EmptyStructPlaceholder placeholder; -} TfLiteSpaceToBatchNDParams; - -typedef struct { - EmptyStructPlaceholder placeholder; -} TfLiteBatchToSpaceNDParams; - -typedef struct { - bool adj_x; - bool adj_y; - // Parameters for BatchMatMul version 4 or above. - // If set to true and the weights are quantized, then non constant inputs - // are quantized at evaluation time with asymmetric quantization. - bool asymmetric_quantize_inputs; -} TfLiteBatchMatMulParams; - -typedef struct { - TfLiteFusedActivation activation; -} TfLiteMulParams; - -typedef struct { - TfLiteFusedActivation activation; - // Parameter added for the version 5. 
- bool pot_scale_int16; -} TfLiteSubParams; - -typedef struct { - TfLiteFusedActivation activation; -} TfLiteDivParams; - -typedef struct { - TfLiteFusedActivation activation; -} TfLiteL2NormParams; - -typedef struct { - int radius; - float bias; - float alpha; - float beta; -} TfLiteLocalResponseNormParams; - -typedef enum { - kTfLiteLSTMFullKernel = 0, - kTfLiteLSTMBasicKernel -} TfLiteLSTMKernelType; - -typedef struct { - // Parameters for LSTM version 1. - TfLiteFusedActivation activation; - float cell_clip; - float proj_clip; - - // Parameters for LSTM version 2. - // kTfLiteLSTMBasicKernel is only supported in version 2 or above. - TfLiteLSTMKernelType kernel_type; - - // Parameters for LSTM version 4. - bool asymmetric_quantize_inputs; -} TfLiteLSTMParams; - -typedef struct { - // Parameters needed for the underlying LSTM. - TfLiteFusedActivation activation; - float cell_clip; - float proj_clip; - - // If set to true then the first dimension is time, otherwise batch. - bool time_major; - - // Parameter for unidirectional sequence RNN version 3. - bool asymmetric_quantize_inputs; - - // Parameter for unidirectional sequence RNN version 4. - bool diagonal_recurrent_tensors; -} TfLiteUnidirectionalSequenceLSTMParams; - -typedef struct { - // Parameters supported by version 1: - // Parameters inherited for the LSTM kernel. - TfLiteFusedActivation activation; - float cell_clip; - float proj_clip; - - // If true, store the outputs of both directions in the first output. - bool merge_outputs; - - // Parameters supported by version 2: - // If set to true then the first dimension is time, otherwise batch. - bool time_major; - - // Parameters supported by version 3: - // If set to true, then hybrid ops use asymmetric quantization for inputs. - bool asymmetric_quantize_inputs; -} TfLiteBidirectionalSequenceLSTMParams; - -typedef struct { - bool align_corners; - // half_pixel_centers assumes pixels are of half the actual dimensions, and - // yields more accurate resizes. Corresponds to the same argument for the - // original TensorFlow op in TF2.0. - bool half_pixel_centers; -} TfLiteResizeBilinearParams; - -typedef struct { - bool align_corners; - bool half_pixel_centers; -} TfLiteResizeNearestNeighborParams; - -typedef struct { - EmptyStructPlaceholder placeholder; -} TfLitePadParams; - -typedef struct { - EmptyStructPlaceholder placeholder; -} TfLitePadV2Params; - -typedef struct { - // These fields are only used in old models for backward compatibility. - // In the current implementation, we use the 2nd input of the op as the shape, - // and these fields are unused. 
- int32_t shape[TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT]; - int num_dimensions; -} TfLiteReshapeParams; - -typedef struct { - int ngram_size; - int max_skip_size; - bool include_all_ngrams; -} TfLiteSkipGramParams; - -typedef struct { - int block_size; -} TfLiteSpaceToDepthParams; - -typedef struct { - int block_size; -} TfLiteDepthToSpaceParams; - -typedef struct { - TfLiteType in_data_type; - TfLiteType out_data_type; -} TfLiteCastParams; - -typedef enum { - kTfLiteCombinerTypeSum = 0, - kTfLiteCombinerTypeMean = 1, - kTfLiteCombinerTypeSqrtn = 2, -} TfLiteCombinerType; - -typedef struct { - TfLiteCombinerType combiner; -} TfLiteEmbeddingLookupSparseParams; - -typedef struct { - int axis; - int batch_dims; -} TfLiteGatherParams; - -typedef struct { - EmptyStructPlaceholder placeholder; -} TfLiteTransposeParams; - -typedef struct { - bool keep_dims; -} TfLiteReducerParams; - -typedef struct { - int num_splits; -} TfLiteSplitParams; - -typedef struct { - int num_splits; -} TfLiteSplitVParams; - -typedef struct { - // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. - // For now we will fix the maximum possible number of dimensions. - int32_t squeeze_dims[8]; - int num_squeeze_dims; -} TfLiteSqueezeParams; - -typedef struct { - int begin_mask; - int end_mask; - int ellipsis_mask; - int new_axis_mask; - int shrink_axis_mask; - - // Parameters supported by version 8: - // If true, then the end tensor is an offset of the begin tensor. - bool offset; -} TfLiteStridedSliceParams; - -typedef struct { - TfLiteType output_type; -} TfLiteArgMaxParams; - -typedef struct { - TfLiteType output_type; -} TfLiteArgMinParams; - -typedef struct { - // Parameters supported by version 1: - TfLitePadding padding; - int stride_width; - int stride_height; - - // Parameters supported by version 4: - TfLiteFusedActivation activation; - - // Parameters for TransposeConv version 5 or above. - // Used to determine the default value for the quantized bias. 
- TfLiteType quantized_bias_type; -} TfLiteTransposeConvParams; - -typedef struct { - bool validate_indices; -} TfLiteSparseToDenseParams; - -typedef struct { - TfLiteType out_type; -} TfLiteShapeParams; - -typedef struct { - EmptyStructPlaceholder placeholder; -} TfLiteRankParams; - -typedef struct { - // Parameters supported by version 1: - float min; - float max; - int num_bits; - - // Parameters supported by version 2: - bool narrow_range; -} TfLiteFakeQuantParams; - -typedef struct { - int values_count; - int axis; -} TfLitePackParams; - -typedef struct { - int axis; -} TfLiteOneHotParams; - -typedef struct { - int num; - int axis; -} TfLiteUnpackParams; - -typedef struct { - float alpha; -} TfLiteLeakyReluParams; - -typedef struct { - TfLiteType index_out_type; -} TfLiteUniqueParams; - -typedef struct { - int seq_dim; - int batch_dim; -} TfLiteReverseSequenceParams; - -typedef struct { - EmptyStructPlaceholder placeholder; -} TfLiteMatrixDiagParams; - -typedef struct { - EmptyStructPlaceholder placeholder; -} TfLiteMatrixSetDiagParams; - -typedef struct { - int then_subgraph_index; - int else_subgraph_index; -} TfLiteIfParams; - -typedef struct { - int cond_subgraph_index; - int body_subgraph_index; -} TfLiteWhileParams; - -typedef struct { - bool exclusive; - bool reverse; -} TfLiteCumsumParams; - -typedef struct { - int init_subgraph_index; -} TfLiteCallOnceParams; - -typedef struct { - int table_id; - TfLiteType key_dtype; - TfLiteType value_dtype; -} TfLiteHashtableParams; - -typedef struct { - const char* container; - const char* shared_name; -} TfLiteVarHandleParams; - -typedef struct { - int seed; - int seed2; -} TfLiteRandomParams; - -typedef struct { - int num_boundaries; - // This points to the memory stored in the model (flatbuffer), - // and is not owned. - const float* boundaries; -} TfLiteBucketizeParams; - -typedef struct { - bool approximate; -} TfLiteGeluParams; - -typedef struct { - int64_t dimension; -} TfLiteStablehloConcatenateParams; - -typedef struct { - // See the stablehlo spec for the explanation of the attributes: - // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter - bool indices_are_sorted; - int64_t - update_window_dims[TFLITE_STABLEHLO_SCATTER_PARAMS_MAX_DIMENSION_COUNT]; - int num_update_window_dims; - int64_t - inserted_window_dims[TFLITE_STABLEHLO_SCATTER_PARAMS_MAX_DIMENSION_COUNT]; - int num_inserted_window_dims; - int64_t scatter_dims_to_operand_dims - [TFLITE_STABLEHLO_SCATTER_PARAMS_MAX_DIMENSION_COUNT]; - int num_scatter_dims_to_operand_dims; - int64_t index_vector_dim; - bool unique_indices; - int update_computation_subgraph_index; -} TfLiteStablehloScatterParams; - -typedef enum { - kTfLiteRngAlgorithmUnknown = 0, - // An algorithm auto-selected by the system according to device type. 
- kTfLiteRngAlgorithmDefault, - // The Philox algorithm, as described in paper - // ['Parallel Random Numbers: As Easy as 1, 2, 3'] - // (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf) - kTfLiteRngAlgorithmPhilox, - // The ThreeFry algorithm, as described in paper - // ['Parallel Random Numbers: As Easy as 1, 2, 3'] - // (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf) - kTfLiteRngAlgorithmThreefry, -} TfLiteRngAlgorithm; - -typedef struct { - TfLiteRngAlgorithm algorithm; -} TfLiteStablehloRngBitGeneratorParams; - -typedef struct { - // See the stablehlo spec for the explanation of the attributes: - // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#gather - int64_t offset_dims[TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT]; - int num_offset_dims; - int64_t - collapsed_slice_dims[TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT]; - int num_collapsed_slice_dims; - int64_t start_index_map[TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT]; - int num_start_index_map; - int64_t index_vector_dim; - int64_t slice_sizes[TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT]; - int num_slice_sizes; - bool indices_are_sorted; -} TfLiteStablehloGatherParams; - -typedef struct { - // See the stablehlo spec for the explanation of the attributes: - // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce_window - int64_t window_dimensions - [TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; - int64_t - window_strides[TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; - int64_t - base_dilations[TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; - int64_t window_dilations - [TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; - int64_t - padding[2 * TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; - int body_subgraph_index; -} TfLiteStablehloReduceWindowParams; - -enum TfLiteReduceWindowFunction { - TfLiteReduceWindowFunctionUnsupported, - TfLiteReduceWindowFunctionAdd, - TfLiteReduceWindowFunctionMul, - TfLiteReduceWindowFunctionMin, - TfLiteReduceWindowFunctionMax, - TfLiteReduceWindowFunctionAll, - TfLiteReduceWindowFunctionAny -}; - -typedef struct { - enum TfLiteReduceWindowFunction reduce_function; -} TfLiteReduceWindowParams; - -typedef struct { - // See the stablehlo spec for the explanation of the attributes: - // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#pad - int64_t edge_padding_low[TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT]; - int64_t edge_padding_high[TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT]; - int64_t interior_padding[TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT]; -} TfLiteStablehloPadParams; - -typedef struct { - const char* name; - int32_t subgraph_index; - int32_t version; - const uint8_t* attributes; - size_t attributes_size; -} TfLiteStablehloCompositeParams; - -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus +#include "tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h" // IWYU pragma: export +#include "tensorflow/lite/core/c/common.h" // IWYU pragma: export #endif // TENSORFLOW_LITE_CORE_C_BUILTIN_OP_DATA_H_ diff --git a/tensorflow/lite/core/c/c_api_types.h b/tensorflow/lite/core/c/c_api_types.h index 79a00319709..b8f3b47f4dd 100644 --- a/tensorflow/lite/core/c/c_api_types.h +++ b/tensorflow/lite/core/c/c_api_types.h @@ -36,12 +36,12 @@ limitations under the License. 
 #ifndef TENSORFLOW_LITE_CORE_C_C_API_TYPES_H_
 #define TENSORFLOW_LITE_CORE_C_C_API_TYPES_H_
 
-#include <stdint.h>
-
 #ifdef __cplusplus
 extern "C" {
 #endif
 
+#include "tensorflow/compiler/mlir/lite/core/c/tflite_types.h"  // IWYU pragma: export
+
 // clang-format off
 // NOLINTBEGIN(whitespace/line_length)
 /** \defgroup c_api_types lite/c/c_api_types.h
@@ -56,7 +56,7 @@ extern "C" {
 #define TFL_CAPI_EXPORT
 #elif defined(TFL_STATIC_LIBRARY_BUILD)
 #define TFL_CAPI_EXPORT
-#else  // not definded TFL_STATIC_LIBRARY_BUILD
+#else  // not defined TFL_STATIC_LIBRARY_BUILD
 #if defined(_WIN32)
 #ifdef TFL_COMPILE_LIBRARY
 #define TFL_CAPI_EXPORT __declspec(dllexport)
@@ -119,42 +119,6 @@ typedef enum TfLiteStatus {
   kTfLiteOutputShapeNotKnown = 9,
 } TfLiteStatus;
 
-/// Types supported by tensor
-// LINT.IfChange
-typedef enum {
-  kTfLiteNoType = 0,
-  kTfLiteFloat32 = 1,
-  kTfLiteInt32 = 2,
-  kTfLiteUInt8 = 3,
-  kTfLiteInt64 = 4,
-  kTfLiteString = 5,
-  kTfLiteBool = 6,
-  kTfLiteInt16 = 7,
-  kTfLiteComplex64 = 8,
-  kTfLiteInt8 = 9,
-  kTfLiteFloat16 = 10,
-  kTfLiteFloat64 = 11,
-  kTfLiteComplex128 = 12,
-  kTfLiteUInt64 = 13,
-  kTfLiteResource = 14,
-  kTfLiteVariant = 15,
-  kTfLiteUInt32 = 16,
-  kTfLiteUInt16 = 17,
-  kTfLiteInt4 = 18,
-  kTfLiteBFloat16 = 19,
-} TfLiteType;
-// LINT.ThenChange(//tensorflow/lite/profiling/proto/model_runtime_info.proto:EdgeDataType)
-
-/// Legacy. Will be deprecated in favor of `TfLiteAffineQuantization`.
-/// If per-layer quantization is specified this field will still be populated in
-/// addition to `TfLiteAffineQuantization`.
-/// Parameters for asymmetric quantization. Quantized values can be converted
-/// back to float using: `real_value = scale * (quantized_value - zero_point)`
-typedef struct TfLiteQuantizationParams {
-  float scale;
-  int32_t zero_point;
-} TfLiteQuantizationParams;
-
 // --------------------------------------------------------------------------
 // Opaque types used by c_api.h, c_api_opaque.h and common.h.
diff --git a/tensorflow/lite/core/c/common.h b/tensorflow/lite/core/c/common.h
index 5d310081649..1131adb669a 100644
--- a/tensorflow/lite/core/c/common.h
+++ b/tensorflow/lite/core/c/common.h
@@ -80,7 +80,9 @@ typedef enum TfLiteExternalContextType {
   kTfLiteGemmLowpContext = 1,      /// include gemm_support.h to use.
   kTfLiteEdgeTpuContext = 2,       /// Placeholder for Edge TPU support.
   kTfLiteCpuBackendContext = 3,    /// include cpu_backend_context.h to use.
-  kTfLiteMaxExternalContexts = 4
+  kTfLiteLiteRtBufferContext =
+      4,  /// include external_litert_buffer_context.h to use.
+  kTfLiteMaxExternalContexts = 5
 } TfLiteExternalContextType;
 
 // Forward declare so dependent structs and methods can reference these types
@@ -320,7 +322,7 @@ typedef struct TfLiteBFloat16 {
 const char* TfLiteTypeGetName(TfLiteType type);
 
 /// SupportedQuantizationTypes.
-typedef enum TfLiteQuantizationType {
+typedef enum TfLiteQuantizationType : int {
   /// No quantization.
   kTfLiteNoQuantization = 0,
   /// Affine quantization (with support for per-channel quantization).
@@ -363,6 +365,7 @@ typedef union TfLitePtrUnion {
   uint64_t* u64;
   float* f;
   TfLiteFloat16* f16;
+  TfLiteBFloat16* bf16;
   double* f64;
   char* raw;
   const char* raw_const;
@@ -442,12 +445,6 @@ enum {
   kTfLiteNullBufferHandle = -1,
 };
 
-/// Storage format of each dimension in a sparse tensor.
-typedef enum TfLiteDimensionType {
-  kTfLiteDimDense = 0,
-  kTfLiteDimSparseCSR,
-} TfLiteDimensionType;
-
 /// Metadata to encode each dimension in a sparse tensor.
typedef struct TfLiteDimensionMetadata { TfLiteDimensionType format; diff --git a/tensorflow/lite/portable_type_to_tflitetype.h b/tensorflow/lite/portable_type_to_tflitetype.h index 03357db0076..03c5b87c9a7 100644 --- a/tensorflow/lite/portable_type_to_tflitetype.h +++ b/tensorflow/lite/portable_type_to_tflitetype.h @@ -72,6 +72,7 @@ MATCH_TYPE_AND_TFLITE_TYPE(unsigned char, kTfLiteUInt8); MATCH_TYPE_AND_TFLITE_TYPE(int8_t, kTfLiteInt8); MATCH_TYPE_AND_TFLITE_TYPE(bool, kTfLiteBool); MATCH_TYPE_AND_TFLITE_TYPE(TfLiteFloat16, kTfLiteFloat16); +MATCH_TYPE_AND_TFLITE_TYPE(TfLiteBFloat16, kTfLiteBFloat16); MATCH_TYPE_AND_TFLITE_TYPE(double, kTfLiteFloat64); MATCH_TYPE_AND_TFLITE_TYPE(uint64_t, kTfLiteUInt64); diff --git a/tensorflow/lite/python/schema_py_generated.py b/tensorflow/lite/python/schema_py_generated.py index b18edcf0d8f..5fb12737d43 100755 --- a/tensorflow/lite/python/schema_py_generated.py +++ b/tensorflow/lite/python/schema_py_generated.py @@ -32,6 +32,7 @@ class TensorType(object): class QuantizationDetails(object): NONE = 0 CustomQuantization = 1 + BlockwiseQuantization = 2 def QuantizationDetailsCreator(unionType, table): from flatbuffers.table import Table @@ -39,6 +40,8 @@ def QuantizationDetailsCreator(unionType, table): return None if unionType == QuantizationDetails().CustomQuantization: return CustomQuantizationT.InitFromBuf(table.Bytes, table.Pos) + if unionType == QuantizationDetails().BlockwiseQuantization: + return BlockwiseQuantizationT.InitFromBuf(table.Bytes, table.Pos) return None @@ -276,6 +279,7 @@ class BuiltinOperator(object): STABLEHLO_COMPOSITE = 206 STABLEHLO_SHIFT_LEFT = 207 STABLEHLO_CBRT = 208 + STABLEHLO_CASE = 209 class BuiltinOptions(object): @@ -690,6 +694,7 @@ class BuiltinOptions2(object): ReduceWindowOptions = 20 StableHLOCompositeOptions = 21 StablehloShiftLeftOptions = 22 + StablehloCaseOptions = 23 def BuiltinOptions2Creator(unionType, table): from flatbuffers.table import Table @@ -739,6 +744,8 @@ def BuiltinOptions2Creator(unionType, table): return StableHLOCompositeOptionsT.InitFromBuf(table.Bytes, table.Pos) if unionType == BuiltinOptions2().StablehloShiftLeftOptions: return StablehloShiftLeftOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions2().StablehloCaseOptions: + return StablehloCaseOptionsT.InitFromBuf(table.Bytes, table.Pos) return None @@ -945,6 +952,109 @@ def Pack(self, builder): return customQuantization +class BlockwiseQuantization(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = BlockwiseQuantization() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsBlockwiseQuantization(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + @classmethod + def BlockwiseQuantizationBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) + + # BlockwiseQuantization + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # BlockwiseQuantization + def Scales(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # BlockwiseQuantization + def ZeroPoints(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # BlockwiseQuantization + def BlockSize(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + +def BlockwiseQuantizationStart(builder): + builder.StartObject(3) + +def BlockwiseQuantizationAddScales(builder, scales): + builder.PrependInt32Slot(0, scales, 0) + +def BlockwiseQuantizationAddZeroPoints(builder, zeroPoints): + builder.PrependInt32Slot(1, zeroPoints, 0) + +def BlockwiseQuantizationAddBlockSize(builder, blockSize): + builder.PrependInt32Slot(2, blockSize, 0) + +def BlockwiseQuantizationEnd(builder): + return builder.EndObject() + + + +class BlockwiseQuantizationT(object): + + # BlockwiseQuantizationT + def __init__(self): + self.scales = 0 # type: int + self.zeroPoints = 0 # type: int + self.blockSize = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + blockwiseQuantization = BlockwiseQuantization() + blockwiseQuantization.Init(buf, pos) + return cls.InitFromObj(blockwiseQuantization) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos+n) + + @classmethod + def InitFromObj(cls, blockwiseQuantization): + x = BlockwiseQuantizationT() + x._UnPack(blockwiseQuantization) + return x + + # BlockwiseQuantizationT + def _UnPack(self, blockwiseQuantization): + if blockwiseQuantization is None: + return + self.scales = blockwiseQuantization.Scales() + self.zeroPoints = blockwiseQuantization.ZeroPoints() + self.blockSize = blockwiseQuantization.BlockSize() + + # BlockwiseQuantizationT + def Pack(self, builder): + BlockwiseQuantizationStart(builder) + BlockwiseQuantizationAddScales(builder, self.scales) + BlockwiseQuantizationAddZeroPoints(builder, self.zeroPoints) + BlockwiseQuantizationAddBlockSize(builder, self.blockSize) + blockwiseQuantization = BlockwiseQuantizationEnd(builder) + return blockwiseQuantization + + class QuantizationParameters(object): __slots__ = ['_tab'] @@ -1153,7 +1263,7 @@ def __init__(self): self.scale = None # type: List[float] self.zeroPoint = None # type: List[int] self.detailsType = 0 # type: int - self.details = None # type: Union[None, CustomQuantizationT] + self.details = None # type: Union[None, CustomQuantizationT, BlockwiseQuantizationT] self.quantizedDimension = 0 # type: int @classmethod @@ -5925,6 +6035,125 @@ def Pack(self, builder): return stablehloScatterOptions +class StablehloCaseOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = StablehloCaseOptions() + 
x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsStablehloCaseOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + @classmethod + def StablehloCaseOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) + + # StablehloCaseOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # StablehloCaseOptions + def BranchSubgraphIndices(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # StablehloCaseOptions + def BranchSubgraphIndicesAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # StablehloCaseOptions + def BranchSubgraphIndicesLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloCaseOptions + def BranchSubgraphIndicesIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + +def StablehloCaseOptionsStart(builder): + builder.StartObject(1) + +def StablehloCaseOptionsAddBranchSubgraphIndices(builder, branchSubgraphIndices): + builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(branchSubgraphIndices), 0) + +def StablehloCaseOptionsStartBranchSubgraphIndicesVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + +def StablehloCaseOptionsEnd(builder): + return builder.EndObject() + + +try: + from typing import List +except: + pass + +class StablehloCaseOptionsT(object): + + # StablehloCaseOptionsT + def __init__(self): + self.branchSubgraphIndices = None # type: List[int] + + @classmethod + def InitFromBuf(cls, buf, pos): + stablehloCaseOptions = StablehloCaseOptions() + stablehloCaseOptions.Init(buf, pos) + return cls.InitFromObj(stablehloCaseOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos+n) + + @classmethod + def InitFromObj(cls, stablehloCaseOptions): + x = StablehloCaseOptionsT() + x._UnPack(stablehloCaseOptions) + return x + + # StablehloCaseOptionsT + def _UnPack(self, stablehloCaseOptions): + if stablehloCaseOptions is None: + return + if not stablehloCaseOptions.BranchSubgraphIndicesIsNone(): + if np is None: + self.branchSubgraphIndices = [] + for i in range(stablehloCaseOptions.BranchSubgraphIndicesLength()): + self.branchSubgraphIndices.append(stablehloCaseOptions.BranchSubgraphIndices(i)) + else: + self.branchSubgraphIndices = stablehloCaseOptions.BranchSubgraphIndicesAsNumpy() + + # StablehloCaseOptionsT + def Pack(self, builder): + if self.branchSubgraphIndices is not None: + if np is not None and type(self.branchSubgraphIndices) is np.ndarray: + branchSubgraphIndices = builder.CreateNumpyVector(self.branchSubgraphIndices) + else: + StablehloCaseOptionsStartBranchSubgraphIndicesVector(builder, len(self.branchSubgraphIndices)) + for i in reversed(range(len(self.branchSubgraphIndices))): + builder.PrependInt32(self.branchSubgraphIndices[i]) + branchSubgraphIndices = 
builder.EndVector() + StablehloCaseOptionsStart(builder) + if self.branchSubgraphIndices is not None: + StablehloCaseOptionsAddBranchSubgraphIndices(builder, branchSubgraphIndices) + stablehloCaseOptions = StablehloCaseOptionsEnd(builder) + return stablehloCaseOptions + + class StablehloRngBitGeneratorOptions(object): __slots__ = ['_tab'] @@ -17079,7 +17308,7 @@ def __init__(self): self.largeCustomOptionsOffset = 0 # type: int self.largeCustomOptionsSize = 0 # type: int self.builtinOptions2Type = 0 # type: int - self.builtinOptions2 = None # type: Union[None, StablehloConcatenateOptionsT, StablehloBroadcastInDimOptionsT, StablehloSliceOptionsT, StablehloConvolutionOptionsT, StablehloCustomCallOptionsT, StablehloReduceOptionsT, StablehloScatterOptionsT, StablehloCompareOptionsT, StablehloDynamicSliceOptionsT, StablehloPadOptionsT, StablehloIotaOptionsT, StablehloDotGeneralOptionsT, StablehloReduceWindowOptionsT, StablehloSortOptionsT, StablehloWhileOptionsT, StablehloGatherOptionsT, StablehloTransposeOptionsT, DilateOptionsT, StablehloRngBitGeneratorOptionsT, ReduceWindowOptionsT, StableHLOCompositeOptionsT, StablehloShiftLeftOptionsT] + self.builtinOptions2 = None # type: Union[None, StablehloConcatenateOptionsT, StablehloBroadcastInDimOptionsT, StablehloSliceOptionsT, StablehloConvolutionOptionsT, StablehloCustomCallOptionsT, StablehloReduceOptionsT, StablehloScatterOptionsT, StablehloCompareOptionsT, StablehloDynamicSliceOptionsT, StablehloPadOptionsT, StablehloIotaOptionsT, StablehloDotGeneralOptionsT, StablehloReduceWindowOptionsT, StablehloSortOptionsT, StablehloWhileOptionsT, StablehloGatherOptionsT, StablehloTransposeOptionsT, DilateOptionsT, StablehloRngBitGeneratorOptionsT, ReduceWindowOptionsT, StableHLOCompositeOptionsT, StablehloShiftLeftOptionsT, StablehloCaseOptionsT] self.debugMetadataIndex = -1 # type: int @classmethod diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h index c4a37264fa0..1f055d2045f 100755 --- a/tensorflow/lite/schema/schema_generated.h +++ b/tensorflow/lite/schema/schema_generated.h @@ -19,6 +19,10 @@ struct CustomQuantization; struct CustomQuantizationBuilder; struct CustomQuantizationT; +struct BlockwiseQuantization; +struct BlockwiseQuantizationBuilder; +struct BlockwiseQuantizationT; + struct QuantizationParameters; struct QuantizationParametersBuilder; struct QuantizationParametersT; @@ -119,6 +123,10 @@ struct StablehloScatterOptions; struct StablehloScatterOptionsBuilder; struct StablehloScatterOptionsT; +struct StablehloCaseOptions; +struct StablehloCaseOptionsBuilder; +struct StablehloCaseOptionsT; + struct StablehloRngBitGeneratorOptions; struct StablehloRngBitGeneratorOptionsBuilder; struct StablehloRngBitGeneratorOptionsT; @@ -759,29 +767,32 @@ inline const char *EnumNameTensorType(TensorType e) { enum QuantizationDetails : uint8_t { QuantizationDetails_NONE = 0, QuantizationDetails_CustomQuantization = 1, + QuantizationDetails_BlockwiseQuantization = 2, QuantizationDetails_MIN = QuantizationDetails_NONE, - QuantizationDetails_MAX = QuantizationDetails_CustomQuantization + QuantizationDetails_MAX = QuantizationDetails_BlockwiseQuantization }; -inline const QuantizationDetails (&EnumValuesQuantizationDetails())[2] { +inline const QuantizationDetails (&EnumValuesQuantizationDetails())[3] { static const QuantizationDetails values[] = { QuantizationDetails_NONE, - QuantizationDetails_CustomQuantization + QuantizationDetails_CustomQuantization, + QuantizationDetails_BlockwiseQuantization }; return 
values;
 }
 
 inline const char * const *EnumNamesQuantizationDetails() {
-  static const char * const names[3] = {
+  static const char * const names[4] = {
     "NONE",
     "CustomQuantization",
+    "BlockwiseQuantization",
     nullptr
   };
   return names;
 }
 
 inline const char *EnumNameQuantizationDetails(QuantizationDetails e) {
-  if (::flatbuffers::IsOutRange(e, QuantizationDetails_NONE, QuantizationDetails_CustomQuantization)) return "";
+  if (::flatbuffers::IsOutRange(e, QuantizationDetails_NONE, QuantizationDetails_BlockwiseQuantization)) return "";
   const size_t index = static_cast<size_t>(e);
   return EnumNamesQuantizationDetails()[index];
 }
@@ -794,6 +805,10 @@ template<> struct QuantizationDetailsTraits<tflite::CustomQuantization> {
   static const QuantizationDetails enum_value = QuantizationDetails_CustomQuantization;
 };
 
+template<> struct QuantizationDetailsTraits<tflite::BlockwiseQuantization> {
+  static const QuantizationDetails enum_value = QuantizationDetails_BlockwiseQuantization;
+};
+
 template<typename T> struct QuantizationDetailsUnionTraits {
   static const QuantizationDetails enum_value = QuantizationDetails_NONE;
 };
@@ -802,6 +817,10 @@ template<> struct QuantizationDetailsUnionTraits<tflite::CustomQuantizationT> {
   static const QuantizationDetails enum_value = QuantizationDetails_CustomQuantization;
 };
 
+template<> struct QuantizationDetailsUnionTraits<tflite::BlockwiseQuantizationT> {
+  static const QuantizationDetails enum_value = QuantizationDetails_BlockwiseQuantization;
+};
+
 struct QuantizationDetailsUnion {
   QuantizationDetails type;
   void *value;
@@ -840,6 +859,14 @@ struct QuantizationDetailsUnion {
     return type == QuantizationDetails_CustomQuantization ?
       reinterpret_cast<tflite::CustomQuantizationT *>(value) : nullptr;
   }
+  tflite::BlockwiseQuantizationT *AsBlockwiseQuantization() {
+    return type == QuantizationDetails_BlockwiseQuantization ?
+      reinterpret_cast<tflite::BlockwiseQuantizationT *>(value) : nullptr;
+  }
+  const tflite::BlockwiseQuantizationT *AsBlockwiseQuantization() const {
+    return type == QuantizationDetails_BlockwiseQuantization ?
+      reinterpret_cast<const tflite::BlockwiseQuantizationT *>(value) : nullptr;
+  }
 };
 
 bool VerifyQuantizationDetails(::flatbuffers::Verifier &verifier, const void *obj, QuantizationDetails type);
@@ -1212,11 +1239,12 @@ enum BuiltinOperator : int32_t {
   BuiltinOperator_STABLEHLO_COMPOSITE = 206,
   BuiltinOperator_STABLEHLO_SHIFT_LEFT = 207,
   BuiltinOperator_STABLEHLO_CBRT = 208,
+  BuiltinOperator_STABLEHLO_CASE = 209,
   BuiltinOperator_MIN = BuiltinOperator_ADD,
-  BuiltinOperator_MAX = BuiltinOperator_STABLEHLO_CBRT
+  BuiltinOperator_MAX = BuiltinOperator_STABLEHLO_CASE
 };
 
-inline const BuiltinOperator (&EnumValuesBuiltinOperator())[209] {
+inline const BuiltinOperator (&EnumValuesBuiltinOperator())[210] {
   static const BuiltinOperator values[] = {
     BuiltinOperator_ADD,
     BuiltinOperator_AVERAGE_POOL_2D,
@@ -1426,13 +1454,14 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[209] {
     BuiltinOperator_REDUCE_WINDOW,
     BuiltinOperator_STABLEHLO_COMPOSITE,
     BuiltinOperator_STABLEHLO_SHIFT_LEFT,
-    BuiltinOperator_STABLEHLO_CBRT
+    BuiltinOperator_STABLEHLO_CBRT,
+    BuiltinOperator_STABLEHLO_CASE
   };
   return values;
 }
 
 inline const char * const *EnumNamesBuiltinOperator() {
-  static const char * const names[210] = {
+  static const char * const names[211] = {
     "ADD",
     "AVERAGE_POOL_2D",
     "CONCATENATION",
@@ -1642,13 +1671,14 @@ inline const char * const *EnumNamesBuiltinOperator() {
     "STABLEHLO_COMPOSITE",
     "STABLEHLO_SHIFT_LEFT",
     "STABLEHLO_CBRT",
+    "STABLEHLO_CASE",
     nullptr
   };
   return names;
 }
 
 inline const char *EnumNameBuiltinOperator(BuiltinOperator e) {
-  if (::flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_STABLEHLO_CBRT)) return "";
+  if (::flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_STABLEHLO_CASE)) return "";
   const size_t index = static_cast<size_t>(e);
   return EnumNamesBuiltinOperator()[index];
 }
@@ -4141,11 +4171,12 @@ enum BuiltinOptions2 : uint8_t {
   BuiltinOptions2_ReduceWindowOptions = 20,
   BuiltinOptions2_StableHLOCompositeOptions = 21,
   BuiltinOptions2_StablehloShiftLeftOptions = 22,
+  BuiltinOptions2_StablehloCaseOptions = 23,
   BuiltinOptions2_MIN = BuiltinOptions2_NONE,
-  BuiltinOptions2_MAX = BuiltinOptions2_StablehloShiftLeftOptions
+  BuiltinOptions2_MAX = BuiltinOptions2_StablehloCaseOptions
 };
 
-inline const BuiltinOptions2 (&EnumValuesBuiltinOptions2())[23] {
+inline const BuiltinOptions2 (&EnumValuesBuiltinOptions2())[24] {
   static const BuiltinOptions2 values[] = {
     BuiltinOptions2_NONE,
     BuiltinOptions2_StablehloConcatenateOptions,
@@ -4169,13 +4200,14 @@ inline const BuiltinOptions2 (&EnumValuesBuiltinOptions2())[23] {
     BuiltinOptions2_StablehloRngBitGeneratorOptions,
     BuiltinOptions2_ReduceWindowOptions,
     BuiltinOptions2_StableHLOCompositeOptions,
-    BuiltinOptions2_StablehloShiftLeftOptions
+    BuiltinOptions2_StablehloShiftLeftOptions,
+    BuiltinOptions2_StablehloCaseOptions
   };
   return values;
 }
 
 inline const char * const *EnumNamesBuiltinOptions2() {
-  static const char * const names[24] = {
+  static const char * const names[25] = {
     "NONE",
     "StablehloConcatenateOptions",
     "StablehloBroadcastInDimOptions",
@@ -4199,13 +4231,14 @@ inline const char * const *EnumNamesBuiltinOptions2() {
     "ReduceWindowOptions",
     "StableHLOCompositeOptions",
     "StablehloShiftLeftOptions",
+    "StablehloCaseOptions",
     nullptr
   };
   return names;
 }
 
 inline const char *EnumNameBuiltinOptions2(BuiltinOptions2 e) {
-  if (::flatbuffers::IsOutRange(e, BuiltinOptions2_NONE, BuiltinOptions2_StablehloShiftLeftOptions)) return "";
+  if (::flatbuffers::IsOutRange(e, BuiltinOptions2_NONE, BuiltinOptions2_StablehloCaseOptions)) return "";
   const size_t index = static_cast<size_t>(e);
   return EnumNamesBuiltinOptions2()[index];
 }
@@ -4302,6 +4335,10 @@ template<> struct BuiltinOptions2Traits<tflite::StablehloShiftLeftOptions> {
   static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloShiftLeftOptions;
 };
 
+template<> struct BuiltinOptions2Traits<tflite::StablehloCaseOptions> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloCaseOptions;
+};
+
 template<typename T> struct BuiltinOptions2UnionTraits {
   static const BuiltinOptions2 enum_value = BuiltinOptions2_NONE;
 };
@@ -4394,6 +4431,10 @@ template<> struct BuiltinOptions2UnionTraits<tflite::StablehloShiftLeftOptionsT> {
   static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloShiftLeftOptions;
 };
 
+template<> struct BuiltinOptions2UnionTraits<tflite::StablehloCaseOptionsT> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloCaseOptions;
+};
+
 struct BuiltinOptions2Union {
   BuiltinOptions2 type;
   void *value;
@@ -4600,6 +4641,14 @@ struct BuiltinOptions2Union {
     return type == BuiltinOptions2_StablehloShiftLeftOptions ?
       reinterpret_cast<tflite::StablehloShiftLeftOptionsT *>(value) : nullptr;
   }
+  tflite::StablehloCaseOptionsT *AsStablehloCaseOptions() {
+    return type == BuiltinOptions2_StablehloCaseOptions ?
+      reinterpret_cast<tflite::StablehloCaseOptionsT *>(value) : nullptr;
+  }
+  const tflite::StablehloCaseOptionsT *AsStablehloCaseOptions() const {
+    return type == BuiltinOptions2_StablehloCaseOptions ?
+      reinterpret_cast<const tflite::StablehloCaseOptionsT *>(value) : nullptr;
+  }
 };
 
 bool VerifyBuiltinOptions2(::flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions2 type);
@@ -5115,6 +5164,80 @@ inline ::flatbuffers::Offset CreateCustomQuantizationDirect(
 
 ::flatbuffers::Offset<CustomQuantization> CreateCustomQuantization(::flatbuffers::FlatBufferBuilder &_fbb, const CustomQuantizationT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr);
 
+struct BlockwiseQuantizationT : public ::flatbuffers::NativeTable {
+  typedef BlockwiseQuantization TableType;
+  int32_t scales = 0;
+  int32_t zero_points = 0;
+  int32_t block_size = 0;
+};
+
+struct BlockwiseQuantization FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+  typedef BlockwiseQuantizationT NativeTableType;
+  typedef BlockwiseQuantizationBuilder Builder;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_SCALES = 4,
+    VT_ZERO_POINTS = 6,
+    VT_BLOCK_SIZE = 8
+  };
+  int32_t scales() const {
+    return GetField<int32_t>(VT_SCALES, 0);
+  }
+  int32_t zero_points() const {
+    return GetField<int32_t>(VT_ZERO_POINTS, 0);
+  }
+  int32_t block_size() const {
+    return GetField<int32_t>(VT_BLOCK_SIZE, 0);
+  }
+  bool Verify(::flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyField<int32_t>(verifier, VT_SCALES, 4) &&
+           VerifyField<int32_t>(verifier, VT_ZERO_POINTS, 4) &&
+           VerifyField<int32_t>(verifier, VT_BLOCK_SIZE, 4) &&
+           verifier.EndTable();
+  }
+  BlockwiseQuantizationT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(BlockwiseQuantizationT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static ::flatbuffers::Offset<BlockwiseQuantization> Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BlockwiseQuantizationT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct BlockwiseQuantizationBuilder {
+  typedef BlockwiseQuantization Table;
+  ::flatbuffers::FlatBufferBuilder &fbb_;
+  ::flatbuffers::uoffset_t start_;
+  void add_scales(int32_t scales) {
+    fbb_.AddElement<int32_t>(BlockwiseQuantization::VT_SCALES, scales, 0);
+  }
+  void add_zero_points(int32_t zero_points) {
+    fbb_.AddElement<int32_t>(BlockwiseQuantization::VT_ZERO_POINTS, zero_points, 0);
+  }
+  void add_block_size(int32_t block_size) {
+    fbb_.AddElement<int32_t>(BlockwiseQuantization::VT_BLOCK_SIZE, block_size, 0);
+  }
+  explicit BlockwiseQuantizationBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  ::flatbuffers::Offset<BlockwiseQuantization> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = ::flatbuffers::Offset<BlockwiseQuantization>(end);
+    return o;
+  }
+};
+
+inline ::flatbuffers::Offset<BlockwiseQuantization> CreateBlockwiseQuantization(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    int32_t scales = 0,
+    int32_t zero_points = 0,
+    int32_t block_size = 0) {
+  BlockwiseQuantizationBuilder builder_(_fbb);
+  builder_.add_block_size(block_size);
+  builder_.add_zero_points(zero_points);
+  builder_.add_scales(scales);
+  return builder_.Finish();
+}
+
+::flatbuffers::Offset<BlockwiseQuantization> CreateBlockwiseQuantization(::flatbuffers::FlatBufferBuilder &_fbb, const BlockwiseQuantizationT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
 struct QuantizationParametersT : public ::flatbuffers::NativeTable {
   typedef QuantizationParameters TableType;
   std::vector<float> min{};
@@ -5159,6 +5282,9 @@ struct QuantizationParameters FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::T
   const tflite::CustomQuantization *details_as_CustomQuantization() const {
     return details_type() == tflite::QuantizationDetails_CustomQuantization ? static_cast<const tflite::CustomQuantization *>(details()) : nullptr;
   }
+  const tflite::BlockwiseQuantization *details_as_BlockwiseQuantization() const {
+    return details_type() == tflite::QuantizationDetails_BlockwiseQuantization ? static_cast<const tflite::BlockwiseQuantization *>(details()) : nullptr;
+  }
   int32_t quantized_dimension() const {
     return GetField<int32_t>(VT_QUANTIZED_DIMENSION, 0);
   }
@@ -5187,6 +5313,10 @@ template<> inline const tflite::CustomQuantization *QuantizationParameters::deta
   return details_as_CustomQuantization();
 }
 
+template<> inline const tflite::BlockwiseQuantization *QuantizationParameters::details_as<tflite::BlockwiseQuantization>() const {
+  return details_as_BlockwiseQuantization();
+}
+
 struct QuantizationParametersBuilder {
   typedef QuantizationParameters Table;
   ::flatbuffers::FlatBufferBuilder &fbb_;
@@ -7687,6 +7817,68 @@ inline ::flatbuffers::Offset CreateStablehloScatterOpti
 
 ::flatbuffers::Offset<StablehloScatterOptions> CreateStablehloScatterOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloScatterOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr);
 
+struct StablehloCaseOptionsT : public ::flatbuffers::NativeTable {
+  typedef StablehloCaseOptions TableType;
+  std::vector<int32_t> branch_subgraph_indices{};
+};
+
+struct StablehloCaseOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+  typedef StablehloCaseOptionsT NativeTableType;
+  typedef StablehloCaseOptionsBuilder Builder;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_BRANCH_SUBGRAPH_INDICES = 4
+  };
+  const ::flatbuffers::Vector<int32_t> *branch_subgraph_indices() const {
+    return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_BRANCH_SUBGRAPH_INDICES);
+  }
+  bool Verify(::flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyOffset(verifier, VT_BRANCH_SUBGRAPH_INDICES) &&
+           verifier.VerifyVector(branch_subgraph_indices()) &&
+           verifier.EndTable();
+  }
+  StablehloCaseOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(StablehloCaseOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static ::flatbuffers::Offset<StablehloCaseOptions> Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloCaseOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct StablehloCaseOptionsBuilder {
+  typedef StablehloCaseOptions Table;
+  ::flatbuffers::FlatBufferBuilder &fbb_;
+  ::flatbuffers::uoffset_t start_;
+  void add_branch_subgraph_indices(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> branch_subgraph_indices) {
+    fbb_.AddOffset(StablehloCaseOptions::VT_BRANCH_SUBGRAPH_INDICES, branch_subgraph_indices);
+  }
+  explicit StablehloCaseOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  ::flatbuffers::Offset<StablehloCaseOptions> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = ::flatbuffers::Offset<StablehloCaseOptions>(end);
+    return o;
+  }
+};
+
+inline ::flatbuffers::Offset<StablehloCaseOptions> CreateStablehloCaseOptions(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> branch_subgraph_indices = 0) {
+  StablehloCaseOptionsBuilder builder_(_fbb);
+  builder_.add_branch_subgraph_indices(branch_subgraph_indices);
+  return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<StablehloCaseOptions> CreateStablehloCaseOptionsDirect(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    const std::vector<int32_t> *branch_subgraph_indices = nullptr) {
+  auto branch_subgraph_indices__ = branch_subgraph_indices ? _fbb.CreateVector<int32_t>(*branch_subgraph_indices) : 0;
+  return tflite::CreateStablehloCaseOptions(
+      _fbb,
+      branch_subgraph_indices__);
+}
+
+::flatbuffers::Offset<StablehloCaseOptions> CreateStablehloCaseOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloCaseOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
 struct StablehloRngBitGeneratorOptionsT : public ::flatbuffers::NativeTable {
   typedef StablehloRngBitGeneratorOptions TableType;
   tflite::RngAlgorithm algorithm = tflite::RngAlgorithm_DEFAULT;
@@ -15328,6 +15520,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
   const tflite::StablehloShiftLeftOptions *builtin_options_2_as_StablehloShiftLeftOptions() const {
     return builtin_options_2_type() == tflite::BuiltinOptions2_StablehloShiftLeftOptions ? static_cast<const tflite::StablehloShiftLeftOptions *>(builtin_options_2()) : nullptr;
   }
+  const tflite::StablehloCaseOptions *builtin_options_2_as_StablehloCaseOptions() const {
+    return builtin_options_2_type() == tflite::BuiltinOptions2_StablehloCaseOptions ? static_cast<const tflite::StablehloCaseOptions *>(builtin_options_2()) : nullptr;
+  }
   int32_t debug_metadata_index() const {
     return GetField<int32_t>(VT_DEBUG_METADATA_INDEX, -1);
   }
@@ -15953,6 +16148,10 @@ template<> inline const tflite::StablehloShiftLeftOptions *Operator::builtin_opt
   return builtin_options_2_as_StablehloShiftLeftOptions();
 }
 
+template<> inline const tflite::StablehloCaseOptions *Operator::builtin_options_2_as<tflite::StablehloCaseOptions>() const {
+  return builtin_options_2_as_StablehloCaseOptions();
+}
+
 struct OperatorBuilder {
   typedef Operator Table;
   ::flatbuffers::FlatBufferBuilder &fbb_;
@@ -16777,6 +16976,38 @@ inline ::flatbuffers::Offset CreateCustomQuantization(::flat
       _custom);
 }
 
+inline BlockwiseQuantizationT *BlockwiseQuantization::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = std::unique_ptr<BlockwiseQuantizationT>(new BlockwiseQuantizationT());
+  UnPackTo(_o.get(), _resolver);
+  return _o.release();
+}
+
+inline void BlockwiseQuantization::UnPackTo(BlockwiseQuantizationT *_o, const ::flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+  { auto _e = scales(); _o->scales = _e; }
+  { auto _e = zero_points(); _o->zero_points = _e; }
+  { auto _e = block_size(); _o->block_size = _e; }
+}
+
+inline ::flatbuffers::Offset<BlockwiseQuantization> BlockwiseQuantization::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BlockwiseQuantizationT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateBlockwiseQuantization(_fbb, _o, _rehasher);
+}
+
+inline ::flatbuffers::Offset<BlockwiseQuantization> CreateBlockwiseQuantization(::flatbuffers::FlatBufferBuilder &_fbb, const BlockwiseQuantizationT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const BlockwiseQuantizationT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  auto _scales = _o->scales;
+  auto _zero_points = _o->zero_points;
+  auto _block_size = _o->block_size;
+  return tflite::CreateBlockwiseQuantization(
+      _fbb,
+      _scales,
+      _zero_points,
+      _block_size);
+}
+
 inline QuantizationParametersT *QuantizationParameters::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const {
   auto _o = std::unique_ptr<QuantizationParametersT>(new QuantizationParametersT());
   UnPackTo(_o.get(), _resolver);
@@ -17693,6 +17924,32 @@ inline ::flatbuffers::Offset CreateStablehloScatterOpti
       _update_computation_subgraph_index);
 }
 
+inline StablehloCaseOptionsT *StablehloCaseOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = std::unique_ptr<StablehloCaseOptionsT>(new StablehloCaseOptionsT());
+  UnPackTo(_o.get(), _resolver);
+  return _o.release();
+}
+
+inline void StablehloCaseOptions::UnPackTo(StablehloCaseOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+  { auto _e = branch_subgraph_indices(); if (_e) { _o->branch_subgraph_indices.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->branch_subgraph_indices[_i] = _e->Get(_i); } } else { _o->branch_subgraph_indices.resize(0); } }
+}
+
+inline ::flatbuffers::Offset<StablehloCaseOptions> StablehloCaseOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloCaseOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateStablehloCaseOptions(_fbb, _o, _rehasher);
+}
+
+inline ::flatbuffers::Offset<StablehloCaseOptions> CreateStablehloCaseOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloCaseOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const StablehloCaseOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  auto _branch_subgraph_indices = _o->branch_subgraph_indices.size() ? _fbb.CreateVector(_o->branch_subgraph_indices) : 0;
+  return tflite::CreateStablehloCaseOptions(
+      _fbb,
+      _branch_subgraph_indices);
+}
+
 inline StablehloRngBitGeneratorOptionsT *StablehloRngBitGeneratorOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const {
   auto _o = std::unique_ptr<StablehloRngBitGeneratorOptionsT>(new StablehloRngBitGeneratorOptionsT());
   UnPackTo(_o.get(), _resolver);
@@ -21560,6 +21817,10 @@ inline bool VerifyQuantizationDetails(::flatbuffers::Verifier &verifier, const v
       auto ptr = reinterpret_cast<const tflite::CustomQuantization *>(obj);
       return verifier.VerifyTable(ptr);
     }
+    case QuantizationDetails_BlockwiseQuantization: {
+      auto ptr = reinterpret_cast<const tflite::BlockwiseQuantization *>(obj);
+      return verifier.VerifyTable(ptr);
+    }
     default: return true;
   }
 }
@@ -21583,6 +21844,10 @@ inline void *QuantizationDetailsUnion::UnPack(const void *obj, QuantizationDetai
       auto ptr = reinterpret_cast<const tflite::CustomQuantization *>(obj);
       return ptr->UnPack(resolver);
     }
+    case QuantizationDetails_BlockwiseQuantization: {
+      auto ptr = reinterpret_cast<const tflite::BlockwiseQuantization *>(obj);
+      return ptr->UnPack(resolver);
+    }
    default: return nullptr;
  }
 }
@@ -21594,6 +21859,10 @@ inline ::flatbuffers::Offset QuantizationDetailsUnion::Pack(::flatbuffers:
      auto ptr = reinterpret_cast<const tflite::CustomQuantizationT *>(value);
      return CreateCustomQuantization(_fbb, ptr, _rehasher).Union();
    }
+    case QuantizationDetails_BlockwiseQuantization: {
+      auto ptr = reinterpret_cast<const tflite::BlockwiseQuantizationT *>(value);
+      return CreateBlockwiseQuantization(_fbb, ptr, _rehasher).Union();
+    }
    default: return 0;
  }
 }
@@ -21604,6 +21873,10 @@ inline QuantizationDetailsUnion::QuantizationDetailsUnion(const QuantizationDeta
      value = new tflite::CustomQuantizationT(*reinterpret_cast<tflite::CustomQuantizationT *>(u.value));
      break;
    }
+    case QuantizationDetails_BlockwiseQuantization: {
+      value = new tflite::BlockwiseQuantizationT(*reinterpret_cast<tflite::BlockwiseQuantizationT *>(u.value));
+      break;
+    }
    default:
      break;
  }
@@ -21616,6 +21889,11 @@ inline void QuantizationDetailsUnion::Reset() {
      delete ptr;
      break;
    }
+    case QuantizationDetails_BlockwiseQuantization: {
+      auto ptr = reinterpret_cast<tflite::BlockwiseQuantizationT *>(value);
+      delete ptr;
+      break;
+    }
    default: break;
  }
  value = nullptr;
@@ -24524,6 +24802,10 @@ inline bool VerifyBuiltinOptions2(::flatbuffers::Verifier &verifier, const void
      auto ptr = reinterpret_cast<const tflite::StablehloShiftLeftOptions *>(obj);
      return verifier.VerifyTable(ptr);
    }
+    case BuiltinOptions2_StablehloCaseOptions: {
+      auto ptr = reinterpret_cast<const tflite::StablehloCaseOptions *>(obj);
+      return verifier.VerifyTable(ptr);
+    }
    default: return true;
  }
 }
@@ -24631,6 +24913,10 @@ inline void *BuiltinOptions2Union::UnPack(const void *obj, BuiltinOptions2 type,
      auto ptr = reinterpret_cast<const tflite::StablehloShiftLeftOptions *>(obj);
      return ptr->UnPack(resolver);
    }
+    case BuiltinOptions2_StablehloCaseOptions: {
+      auto ptr = reinterpret_cast<const tflite::StablehloCaseOptions *>(obj);
+      return ptr->UnPack(resolver);
+    }
    default: return nullptr;
  }
 }
@@ -24726,6 +25012,10 @@ inline ::flatbuffers::Offset BuiltinOptions2Union::Pack(::flatbuffers::Fla
      auto ptr = reinterpret_cast<const tflite::StablehloShiftLeftOptionsT *>(value);
      return CreateStablehloShiftLeftOptions(_fbb, ptr, _rehasher).Union();
    }
+    case BuiltinOptions2_StablehloCaseOptions: {
+      auto ptr = reinterpret_cast<const tflite::StablehloCaseOptionsT *>(value);
+      return CreateStablehloCaseOptions(_fbb, ptr, _rehasher).Union();
+    }
    default: return 0;
  }
 }
@@ -24820,6 +25110,10 @@ inline BuiltinOptions2Union::BuiltinOptions2Union(const BuiltinOptions2Union &u)
      value = new tflite::StablehloShiftLeftOptionsT(*reinterpret_cast<tflite::StablehloShiftLeftOptionsT *>(u.value));
      break;
    }
+    case BuiltinOptions2_StablehloCaseOptions: {
+      value = new tflite::StablehloCaseOptionsT(*reinterpret_cast<tflite::StablehloCaseOptionsT *>(u.value));
+      break;
+    }
    default:
      break;
  }
@@ -24937,6 +25231,11 @@ inline void BuiltinOptions2Union::Reset() {
      delete ptr;
      break;
    }
+    case BuiltinOptions2_StablehloCaseOptions: {
+      auto ptr = reinterpret_cast<tflite::StablehloCaseOptionsT *>(value);
+      delete ptr;
+      break;
+    }
    default: break;
  }
  value = nullptr;
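The object API generated above round-trips the new table; a minimal sketch with illustrative branch indices, assuming the schema_generated.h produced by this patch:

    // Sketch only: pack a StablehloCaseOptionsT, then unpack and check it.
    #include <cassert>
    #include <memory>

    #include "tensorflow/lite/schema/schema_generated.h"

    void RoundTripStablehloCaseOptions() {
      tflite::StablehloCaseOptionsT options;
      options.branch_subgraph_indices = {1, 2};

      ::flatbuffers::FlatBufferBuilder fbb;
      fbb.Finish(tflite::StablehloCaseOptions::Pack(fbb, &options));

      const auto* table = ::flatbuffers::GetRoot<tflite::StablehloCaseOptions>(
          fbb.GetBufferPointer());
      std::unique_ptr<tflite::StablehloCaseOptionsT> unpacked(table->UnPack());
      assert(unpacked->branch_subgraph_indices.size() == 2);
    }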