Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved ModelConfiguration serialization #222

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 212 additions & 2 deletions common/setups/returnn_pytorch/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
import string
import textwrap
from collections import OrderedDict
from dataclasses import fields
from dataclasses import fields, is_dataclass
from enum import Enum
from inspect import isfunction
from typing import Any, Dict, List, Optional, Set, Tuple, Union, TYPE_CHECKING

import torch
from i6_core.util import instanciate_delayed
from sisyphus import gs, tk
from sisyphus.delayed_ops import DelayedBase
from sisyphus.delayed_ops import DelayedBase, DelayedFormat
from sisyphus.hash import sis_hash_helper

if TYPE_CHECKING:
Expand Down Expand Up @@ -251,3 +252,212 @@ def build_config_constructor_serializers(
imports = list(OrderedDict.fromkeys(imports)) # remove duplications

return Call(callable_name=type(cfg).__name__, kwargs=call_kwargs, return_assign_variables=variable_name), imports


def build_config_constructor_serializers_v2(
cfg: Any,
variable_name: Optional[str] = None,
unhashed_package_root: Optional[str] = None,
) -> Tuple[Call, List[Import]]:
"""
Creates a Call object that will re-construct the given ModelConfiguration when serialized and
optionally assigns the resulting config object to a variable. Automatically generates a list of all
necessary imports in order to perform the constructor call.

Compared to the previous version, this function can also serialize enum members and values of type
list, tuple or dict. It also fixes import deduplication.

:param cfg: ModelConfiguration or dataclass object that will be re-constructed by the Call serializer
:param variable_name: Name of the variable which the constructed ModelConfiguration
will be assigned to. If None, the result will not be assigned
to a variable.
:param unhashed_package_root: Will be passed to all generated Import objects.
:return: Call object and list of necessary imports.
"""
from i6_models.config import ModuleFactoryV1

# Helper function which can call itself recursively for nested types
def serialize_value(value: Any) -> Tuple[Union[str, DelayedBase], List[Import]]:
# Switch over serialization logic for different subtypes

if isinstance(value, ModuleFactoryV1):
# Example:
# ConformerEncoderConfig(
# frontend=ModuleFactoryV1(module_class=VGGFrontend, cfg=VGGFrontendConfig(...)))
# -> Import classes ModuleFactoryV1, VGGFrontend and VGGFrontendConfig
# -> Sub-Constructor-Call for VGGFrontendConfig
subcall, subimports = build_config_constructor_serializers_v2(
value.cfg, unhashed_package_root=unhashed_package_root
)
subimports.append(
Import(
code_object_path=f"{value.module_class.__module__}.{value.module_class.__name__}",
unhashed_package_root=(
unhashed_package_root
if unhashed_package_root is not None
and value.module_class.__module__.startswith(unhashed_package_root)
else None
),
)
)
subimports.append(
Import(
code_object_path=f"{ModuleFactoryV1.__module__}.{ModuleFactoryV1.__name__}",
unhashed_package_root=(
unhashed_package_root
if unhashed_package_root is not None
and ModuleFactoryV1.__module__.startswith(unhashed_package_root)
else None
),
)
)
return (
Call(
callable_name=ModuleFactoryV1.__name__,
kwargs=[
("module_class", value.module_class.__name__),
("cfg", subcall),
],
),
subimports,
)
elif is_dataclass(value):
# Example:
# ConformerBlockConfig(mhsa_config=ConformerMHSAConfig(...))
# -> Sub-Constructor-Call and imports for ConformerMHSAConfig
return build_config_constructor_serializers_v2(value, unhashed_package_root=unhashed_package_root)
elif isinstance(value, torch.nn.Module):
# Example:
# ConformerConvolutionConfig(norm=BatchNorm1d(...))
# -> Import class BatchNorm1d
# -> Sub-serialization of BatchNorm1d object.
# The __str__ function of torch.nn.Module already does this in the way we want.
return str(value), [
Import(
code_object_path=f"{value.__module__}.{type(value).__name__}",
unhashed_package_root=(
unhashed_package_root
if unhashed_package_root is not None and value.__module__.startswith(unhashed_package_root)
else None
),
)
]
elif isfunction(value):
# Example:
# ConformerConvolutionConfig(activation=torch.nn.functional.silu)
# -> Import function silu
# Builtins (e.g. 'sum') do not need to be imported
if value.__module__ != "builtins":
subimports = [
Import(
code_object_path=f"{value.__module__}.{value.__name__}",
unhashed_package_root=(
unhashed_package_root
if unhashed_package_root is not None and value.__module__.startswith(unhashed_package_root)
else None
),
)
]
else:
subimports = []
return value.__name__, subimports
elif isinstance(value, Enum):
# Example:
# FrontendLayerType.Conv2d
# -> Import enum class FrontendLayerType
subimports = [
Import(
code_object_path=f"{value.__class__.__module__}.{value.__class__.__name__}",
unhashed_package_root=(
unhashed_package_root
if unhashed_package_root is not None
and value.__class__.__module__.startswith(unhashed_package_root)
else None
),
)
]
return f"{value.__class__.__name__}.{value.name}", subimports
elif isinstance(value, list):
# -> Serialize list values individually, collect subimports
list_items = []
list_imports = []
for item in value:
item_serialized, item_imports = serialize_value(item)
list_items.append(item_serialized)
list_imports += item_imports
return DelayedFormat(f"[{', '.join(['{}'] * len(list_items))}]", *list_items), list_imports
elif isinstance(value, tuple):
# -> Serialize tuple values individually, collect subimports
tuple_items = []
tuple_imports = []
for item in value:
item_serialized, item_imports = serialize_value(item)
tuple_items.append(item_serialized)
tuple_imports += item_imports
return DelayedFormat(f"({', '.join(['{}'] * len(tuple_items))})", *tuple_items), tuple_imports
elif isinstance(value, dict):
# -> Serialize dict values individually, collect subimports
dict_items = [] # Will alternatingly contain key and value of all dict items
dict_imports = []
for key, val in value.items():
val_serialized, item_imports = serialize_value(val)
dict_items += [key, val_serialized]
dict_imports += item_imports
return (
DelayedFormat("{{" + ", ".join(["{}: {}"] * (len(dict_items) // 2)) + "}}", *dict_items),
dict_imports,
)
elif isinstance(value, tk.Path):
return DelayedFormat('tk.Path("{}")', value), [Import("sisyphus.tk")]
elif isinstance(value, DelayedBase):
# sisyphus variables are just given as-is and will be instanciated only when calling "get".
return value, []
elif isinstance(value, str):
return f'"{value}"', []
else:
# No special case (usually python primitives)
# -> Just get string representation
return str(value), []

# Import the class of `cfg`
imports = [
Import(
code_object_path=f"{type(cfg).__module__}.{type(cfg).__name__}",
unhashed_package_root=(
unhashed_package_root
if unhashed_package_root is not None and type(cfg).__module__.startswith(unhashed_package_root)
else None
),
)
]

call_kwargs = []

# Iterate over all dataclass fields and apply helper function to all values
for key in fields(type(cfg)):
# Value corresponding to dataclass field name
value = getattr(cfg, key.name)

serialized_value, value_imports = serialize_value(value)
call_kwargs.append((key.name, serialized_value))
imports += value_imports

# Deduplicate imports
seen_hashes = set()
unique_imports = []
for imp in imports:
imp_hash = hash(imp)
if imp_hash not in seen_hashes:
seen_hashes.add(imp_hash)
unique_imports.append(imp)

unique_imports.sort(key=lambda imp: str(imp))

return (
Call(
callable_name=type(cfg).__name__,
kwargs=call_kwargs,
return_assign_variables=variable_name,
),
unique_imports,
)
4 changes: 2 additions & 2 deletions common/setups/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,8 +367,8 @@ class Call(SerializerObject):
def __init__(
self,
callable_name: str,
kwargs: Optional[List[Tuple[str, Union[str, DelayedBase]]]] = None,
unhashed_kwargs: Optional[List[Tuple[str, Union[str, DelayedBase]]]] = None,
kwargs: Optional[List[Tuple[str, Any]]] = None,
unhashed_kwargs: Optional[List[Tuple[str, Any]]] = None,
return_assign_variables: Optional[Union[str, List[str]]] = None,
) -> None:
"""
Expand Down