Skip to content

Commit

Permalink
Update numpy_utils.cc to support ml_dtypes.bfloat16 (#2758)
Browse files Browse the repository at this point in the history
This commit adds NPY_USERDEF for ml_dtypes.bfloat16 only. Other types are not supported yet.

BUG=#2759
BUG=2759
  • Loading branch information
jaeyoo authored Nov 12, 2024
1 parent 79c3fde commit 26ada36
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions python/tflite_micro/numpy_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
case kTfLiteFloat16:
return NPY_FLOAT16;
case kTfLiteBFloat16:
// TODO(b/329491949): NPY_BFLOAT16 currently doesn't exist
return NPY_FLOAT16;
// TODO(b/329491949): Supports other ml_dtypes user-defined types.
return NPY_USERDEF;
case kTfLiteFloat64:
return NPY_FLOAT64;
case kTfLiteInt32:
Expand Down Expand Up @@ -114,6 +114,10 @@ TfLiteType TfLiteTypeFromPyType(int py_type) {
return kTfLiteComplex64;
case NPY_COMPLEX128:
return kTfLiteComplex128;
case NPY_USERDEF:
// User-defined types are defined in ml_dtypes. (bfloat16, float8, etc.)
// Fow now, we only support bfloat16.
return kTfLiteBFloat16;
// Avoid default so compiler errors created when new types are made.
}
return kTfLiteNoType;
Expand Down

0 comments on commit 26ada36

Please sign in to comment.