diff --git a/projects/jdwp/codegen/dataclass_generator.py b/projects/jdwp/codegen/dataclass_generator.py new file mode 100644 index 0000000..b0ab223 --- /dev/null +++ b/projects/jdwp/codegen/dataclass_generator.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +from textwrap import dedent +from projects.jdwp.codegen.types import python_type_for +import typing + +from projects.jdwp.defs.schema import ( + Array, + Field, + Struct, + TaggedUnion, +) + + +StructLink = typing.Tuple[Struct, Field, Struct] + + +class StructGenerator: + def __init__(self, root: Struct, name: str): + self.__root = root + self.__struct_to_name = compute_struct_names(root, name) + + def __get_python_type_for(self, struct: Struct, field: Field) -> str: + type = field.type + match type: + case Struct(): + return self.__struct_to_name[type] + case Array(): + array_type = typing.cast(Array, type) + return f"typing.List[{self.__struct_to_name[array_type.element_type]}]" + case TaggedUnion(): + tagged_union_type = typing.cast(TaggedUnion, type) + union_types = [ + self.__struct_to_name[case_struct] + for case_struct in tagged_union_type.cases + ] + union_types_str = ", ".join(union_types) + return f"typing.Union[{union_types_str}]" + case _: + return python_type_for(type) + + def __generate_dataclass(self, struct: Struct) -> str: + name = self.__struct_to_name[struct] + fields_def = "\n".join( + f" {field.name}: {self.__get_python_type_for(struct, field)}" + for field in struct.fields + ) + class_def = f"@dataclasses.dataclass(frozen=True)\nclass {name}:\n{fields_def}" + return dedent(class_def) + + def generate(self): + return [ + self.__generate_dataclass(nested) + for _, _, nested in reversed(list(nested_structs(self.__root))) + ] + [self.__generate_dataclass(self.__root)] + + +def format_enum_name(enum_value): + words = enum_value.name.split("_") + formatted_name = "".join(word.capitalize() for word in words) + return f"{formatted_name}Type" + + +def nested_structs(root: Struct) -> typing.Generator[StructLink, None, None]: + for field in root.fields: + field_type = field.type + match field_type: + case Array(): + array_type = typing.cast(Array, field_type) + yield root, field, array_type.element_type + yield from nested_structs(array_type.element_type) + case TaggedUnion(): + tagged_union_type = typing.cast(TaggedUnion, field_type) + for _, struct in tagged_union_type.cases: + yield root, field, struct + yield from nested_structs(struct) + case Struct(): + yield root, field, field_type + yield from nested_structs(field_type) + + +def compute_struct_names(root: Struct, name: str) -> typing.Mapping[Struct, str]: + names = {root: name} + for parent, field, nested in nested_structs(root): + type = field.type + match type: + case Struct(): + names[nested] = f"{names[parent]}{field.name.capitalize()}" + case Array(): + names[nested] = f"{names[parent]}{field.name.capitalize()}Element" + case TaggedUnion(): + tagged_union_type = typing.cast(TaggedUnion, type) + for case_value, case_struct in tagged_union_type.cases: + case_name = format_enum_name(case_value) + names[ + case_struct + ] = f"{names[parent]}{field.name.capitalize()}Case{case_name}" + return names diff --git a/projects/jdwp/tests/test_dataclass_generator.py b/projects/jdwp/tests/test_dataclass_generator.py new file mode 100644 index 0000000..46905fa --- /dev/null +++ b/projects/jdwp/tests/test_dataclass_generator.py @@ -0,0 +1,101 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import unittest +from projects.jdwp.defs.schema import ArrayLength, IntegralType, Struct, Field, Array +from projects.jdwp.codegen.dataclass_generator import StructGenerator + + +class TestStructGenerator(unittest.TestCase): + def test_simple_struct(self): + simple_struct = Struct( + fields=[ + Field(name="id", type=IntegralType.INT, description="An integer ID") + ] + ) + + generator = StructGenerator(simple_struct, "SimpleStruct") + + result = generator.generate() + + expected = [ + "@dataclasses.dataclass(frozen=True)\n" + "class SimpleStruct:\n" + " id: int" + ] + + self.assertEqual(result, expected) + + def test_nested_struct(self): + inner_struct = Struct( + fields=[ + Field( + name="inner_field", + type=IntegralType.INT, + description="Inner integer field", + ) + ] + ) + outer_struct = Struct( + fields=[ + Field(name="nested", type=inner_struct, description="Nested structure") + ] + ) + + generator = StructGenerator(outer_struct, "OuterStruct") + + result = generator.generate() + + expected = [ + "@dataclasses.dataclass(frozen=True)\n" + "class OuterStructNested:\n" + " inner_field: int", + "@dataclasses.dataclass(frozen=True)\n" + "class OuterStruct:\n" + " nested: OuterStructNested", + ] + + self.assertEqual(result, expected) + + def test_struct_in_array(self): + # Define a structure + element_struct = Struct( + fields=[ + Field( + name="element_field", + type=IntegralType.INT, + description="Element field", + ) + ] + ) + + array_length = ArrayLength(type=IntegralType.INT) + + array_struct = Struct( + fields=[ + Field( + name="array", + type=Array( + element_type=element_struct, + length=Field( + name="length", type=array_length, description="Array length" + ), + ), + description="Array of structures", + ) + ] + ) + + generator = StructGenerator(array_struct, "ArrayStruct") + + result = generator.generate() + + expected = [ + "@dataclasses.dataclass(frozen=True)\n" + "class ArrayStructArrayElement:\n" + " element_field: int", + "@dataclasses.dataclass(frozen=True)\n" + "class ArrayStruct:\n" + " array: typing.List[ArrayStructArrayElement]", + ] + + self.assertEqual(result, expected)