-
Notifications
You must be signed in to change notification settings - Fork 63
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Wrote test to validate primitive type mapping * Added test for type alias definition * Built dataclasses code generator * Built test for dataclassess code generator * Changed to unittest * Updated struct dataclass function to support all types * feat: built dataclass generator * Added copyright * Added nested struct support * Added test for nested struct * Feat: update functions and put it in a class. * chore: update tests * Casted types to pass pyre check * Updated types * chore: Refactor typing.cast code * Added copyright * refactor: update tests to reflect new schema changes * fix: Fix pyre typing errors
- Loading branch information
Showing
2 changed files
with
199 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
|
||
from textwrap import dedent | ||
from projects.jdwp.codegen.types import python_type_for | ||
import typing | ||
|
||
from projects.jdwp.defs.schema import ( | ||
Array, | ||
Field, | ||
Struct, | ||
TaggedUnion, | ||
) | ||
|
||
|
||
StructLink = typing.Tuple[Struct, Field, Struct] | ||
|
||
|
||
class StructGenerator: | ||
def __init__(self, root: Struct, name: str): | ||
self.__root = root | ||
self.__struct_to_name = compute_struct_names(root, name) | ||
|
||
def __get_python_type_for(self, struct: Struct, field: Field) -> str: | ||
type = field.type | ||
match type: | ||
case Struct(): | ||
return self.__struct_to_name[type] | ||
case Array(): | ||
array_type = typing.cast(Array, type) | ||
return f"typing.List[{self.__struct_to_name[array_type.element_type]}]" | ||
case TaggedUnion(): | ||
tagged_union_type = typing.cast(TaggedUnion, type) | ||
union_types = [ | ||
self.__struct_to_name[case_struct] | ||
for case_struct in tagged_union_type.cases | ||
] | ||
union_types_str = ", ".join(union_types) | ||
return f"typing.Union[{union_types_str}]" | ||
case _: | ||
return python_type_for(type) | ||
|
||
def __generate_dataclass(self, struct: Struct) -> str: | ||
name = self.__struct_to_name[struct] | ||
fields_def = "\n".join( | ||
f" {field.name}: {self.__get_python_type_for(struct, field)}" | ||
for field in struct.fields | ||
) | ||
class_def = f"@dataclasses.dataclass(frozen=True)\nclass {name}:\n{fields_def}" | ||
return dedent(class_def) | ||
|
||
def generate(self): | ||
return [ | ||
self.__generate_dataclass(nested) | ||
for _, _, nested in reversed(list(nested_structs(self.__root))) | ||
] + [self.__generate_dataclass(self.__root)] | ||
|
||
|
||
def format_enum_name(enum_value): | ||
words = enum_value.name.split("_") | ||
formatted_name = "".join(word.capitalize() for word in words) | ||
return f"{formatted_name}Type" | ||
|
||
|
||
def nested_structs(root: Struct) -> typing.Generator[StructLink, None, None]: | ||
for field in root.fields: | ||
field_type = field.type | ||
match field_type: | ||
case Array(): | ||
array_type = typing.cast(Array, field_type) | ||
yield root, field, array_type.element_type | ||
yield from nested_structs(array_type.element_type) | ||
case TaggedUnion(): | ||
tagged_union_type = typing.cast(TaggedUnion, field_type) | ||
for _, struct in tagged_union_type.cases: | ||
yield root, field, struct | ||
yield from nested_structs(struct) | ||
case Struct(): | ||
yield root, field, field_type | ||
yield from nested_structs(field_type) | ||
|
||
|
||
def compute_struct_names(root: Struct, name: str) -> typing.Mapping[Struct, str]: | ||
names = {root: name} | ||
for parent, field, nested in nested_structs(root): | ||
type = field.type | ||
match type: | ||
case Struct(): | ||
names[nested] = f"{names[parent]}{field.name.capitalize()}" | ||
case Array(): | ||
names[nested] = f"{names[parent]}{field.name.capitalize()}Element" | ||
case TaggedUnion(): | ||
tagged_union_type = typing.cast(TaggedUnion, type) | ||
for case_value, case_struct in tagged_union_type.cases: | ||
case_name = format_enum_name(case_value) | ||
names[ | ||
case_struct | ||
] = f"{names[parent]}{field.name.capitalize()}Case{case_name}" | ||
return names |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
|
||
import unittest | ||
from projects.jdwp.defs.schema import ArrayLength, IntegralType, Struct, Field, Array | ||
from projects.jdwp.codegen.dataclass_generator import StructGenerator | ||
|
||
|
||
class TestStructGenerator(unittest.TestCase): | ||
def test_simple_struct(self): | ||
simple_struct = Struct( | ||
fields=[ | ||
Field(name="id", type=IntegralType.INT, description="An integer ID") | ||
] | ||
) | ||
|
||
generator = StructGenerator(simple_struct, "SimpleStruct") | ||
|
||
result = generator.generate() | ||
|
||
expected = [ | ||
"@dataclasses.dataclass(frozen=True)\n" | ||
"class SimpleStruct:\n" | ||
" id: int" | ||
] | ||
|
||
self.assertEqual(result, expected) | ||
|
||
def test_nested_struct(self): | ||
inner_struct = Struct( | ||
fields=[ | ||
Field( | ||
name="inner_field", | ||
type=IntegralType.INT, | ||
description="Inner integer field", | ||
) | ||
] | ||
) | ||
outer_struct = Struct( | ||
fields=[ | ||
Field(name="nested", type=inner_struct, description="Nested structure") | ||
] | ||
) | ||
|
||
generator = StructGenerator(outer_struct, "OuterStruct") | ||
|
||
result = generator.generate() | ||
|
||
expected = [ | ||
"@dataclasses.dataclass(frozen=True)\n" | ||
"class OuterStructNested:\n" | ||
" inner_field: int", | ||
"@dataclasses.dataclass(frozen=True)\n" | ||
"class OuterStruct:\n" | ||
" nested: OuterStructNested", | ||
] | ||
|
||
self.assertEqual(result, expected) | ||
|
||
def test_struct_in_array(self): | ||
# Define a structure | ||
element_struct = Struct( | ||
fields=[ | ||
Field( | ||
name="element_field", | ||
type=IntegralType.INT, | ||
description="Element field", | ||
) | ||
] | ||
) | ||
|
||
array_length = ArrayLength(type=IntegralType.INT) | ||
|
||
array_struct = Struct( | ||
fields=[ | ||
Field( | ||
name="array", | ||
type=Array( | ||
element_type=element_struct, | ||
length=Field( | ||
name="length", type=array_length, description="Array length" | ||
), | ||
), | ||
description="Array of structures", | ||
) | ||
] | ||
) | ||
|
||
generator = StructGenerator(array_struct, "ArrayStruct") | ||
|
||
result = generator.generate() | ||
|
||
expected = [ | ||
"@dataclasses.dataclass(frozen=True)\n" | ||
"class ArrayStructArrayElement:\n" | ||
" element_field: int", | ||
"@dataclasses.dataclass(frozen=True)\n" | ||
"class ArrayStruct:\n" | ||
" array: typing.List[ArrayStructArrayElement]", | ||
] | ||
|
||
self.assertEqual(result, expected) |