Skip to content

Commit

Permalink
Don't decode default bytes values in TFT Schema
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 576126369
  • Loading branch information
tf-transform-team authored and tfx-copybara committed Oct 24, 2023
1 parent 807025e commit a6cca79
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
11 changes: 10 additions & 1 deletion tensorflow_transform/tf_metadata/schema_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,16 @@ def _standardize_default_value(
assert isinstance(default_value, list), spec.default_value
# Convert bytes to str where they are valid UTF-8.
if spec.dtype == tf.string:
default_value = [value.decode('utf-8') for value in default_value]

# Handle bytes values by trying to decode them as UTF-8 (for legacy
# backwards compatibility); if decoding fails, keep the value as bytes.
def try_decode(value: bytes) -> Union[str, bytes]:
try:
return value.decode('utf-8')
except UnicodeError:
return value

default_value = [try_decode(value) for value in default_value]
# Unwrap a list with a single element.
if len(default_value) == 1:
default_value = default_value[0]
Expand Down
29 changes: 29 additions & 0 deletions tensorflow_transform/tf_metadata/schema_utils_test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,35 @@
'seq_string_feature': schema_pb2.StringDomain(value=['a', 'b'])
}
},
{
'testcase_name': 'fixed_len_bytes_encoding',
'ascii_proto': """
feature {
name: "x"
type: BYTES
value_count {
min: 1
max: 1
}
}
tensor_representation_group {
key: ""
value {
tensor_representation {
key: "x"
value {
dense_tensor {
column_name: "x"
shape { dim { size: 1 } }
default_value { bytes_value: "\\xd0" }
}
}
}
}
}
""",
'feature_spec': {'x': tf.io.FixedLenFeature([1], tf.string, b'\xd0')},
},
]

INVALID_SCHEMA_PROTOS = [
Expand Down

0 comments on commit a6cca79

Please sign in to comment.