Skip to content

Commit

Permalink
Rename is_supported_empty_aggregation_type and `is_supported_aggreg…
Browse files Browse the repository at this point in the history
…ation_type` functions.

PiperOrigin-RevId: 673582513
  • Loading branch information
niketkumar authored and t5-copybara committed Sep 12, 2024
1 parent c2f0c9c commit 8d297e8
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion t5x/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,12 @@ def detect_checkpoint_type(
return checkpoint_type


def _is_supported_empty_value(value: Any) -> bool:
if hasattr(ocp.type_handlers, 'is_supported_empty_aggregation_type'):
return ocp.type_handlers.is_supported_empty_aggregation_type(value)
return ocp.type_handlers.is_supported_empty_value(value)


def get_restore_parameters(
directory: epath.Path,
structure: PyTree,
Expand Down Expand Up @@ -280,7 +286,7 @@ def _get_param_info(
name: str,
meta_or_value: Union[Any, ocp.metadata.tree.ValueMetadataEntry],
) -> Union[ocp.type_handlers.ParamInfo, Any]:
if ocp.type_handlers.is_supported_empty_aggregation_type(meta_or_value):
if _is_supported_empty_value(meta_or_value):
# Empty node, ParamInfo should not be returned.
return meta_or_value
elif not isinstance(meta_or_value, ocp.metadata.tree.ValueMetadataEntry):
Expand Down

0 comments on commit 8d297e8

Please sign in to comment.