Skip to content

Commit

Permalink
Codegen dataclassess (#79)
Browse files Browse the repository at this point in the history
* Wrote test to validate primitive type mapping

* Added test for type alias definition

* Built dataclasses code generator

* Built test for dataclassess code generator

* Changed to unittest

* Updated struct dataclass function to support all types

* feat: built dataclass generator

* Added copyright

* Added nested struct support

* Added test for nested struct

* Feat: update functions and put it in a class.

* chore: update tests

* Casted types to pass pyre check

* Updated types

* chore: Refactor typing.cast code

* Added copyright

* refactor: update tests to reflect new schema changes

* fix: Fix pyre typing errors
  • Loading branch information
Daquiver1 authored Dec 6, 2023
1 parent 26272f3 commit 2e88acd
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 0 deletions.
98 changes: 98 additions & 0 deletions projects/jdwp/codegen/dataclass_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from textwrap import dedent
from projects.jdwp.codegen.types import python_type_for
import typing

from projects.jdwp.defs.schema import (
Array,
Field,
Struct,
TaggedUnion,
)


StructLink = typing.Tuple[Struct, Field, Struct]


class StructGenerator:
def __init__(self, root: Struct, name: str):
self.__root = root
self.__struct_to_name = compute_struct_names(root, name)

def __get_python_type_for(self, struct: Struct, field: Field) -> str:
type = field.type
match type:
case Struct():
return self.__struct_to_name[type]
case Array():
array_type = typing.cast(Array, type)
return f"typing.List[{self.__struct_to_name[array_type.element_type]}]"
case TaggedUnion():
tagged_union_type = typing.cast(TaggedUnion, type)
union_types = [
self.__struct_to_name[case_struct]
for case_struct in tagged_union_type.cases
]
union_types_str = ", ".join(union_types)
return f"typing.Union[{union_types_str}]"
case _:
return python_type_for(type)

def __generate_dataclass(self, struct: Struct) -> str:
name = self.__struct_to_name[struct]
fields_def = "\n".join(
f" {field.name}: {self.__get_python_type_for(struct, field)}"
for field in struct.fields
)
class_def = f"@dataclasses.dataclass(frozen=True)\nclass {name}:\n{fields_def}"
return dedent(class_def)

def generate(self):
return [
self.__generate_dataclass(nested)
for _, _, nested in reversed(list(nested_structs(self.__root)))
] + [self.__generate_dataclass(self.__root)]


def format_enum_name(enum_value):
words = enum_value.name.split("_")
formatted_name = "".join(word.capitalize() for word in words)
return f"{formatted_name}Type"


def nested_structs(root: Struct) -> typing.Generator[StructLink, None, None]:
for field in root.fields:
field_type = field.type
match field_type:
case Array():
array_type = typing.cast(Array, field_type)
yield root, field, array_type.element_type
yield from nested_structs(array_type.element_type)
case TaggedUnion():
tagged_union_type = typing.cast(TaggedUnion, field_type)
for _, struct in tagged_union_type.cases:
yield root, field, struct
yield from nested_structs(struct)
case Struct():
yield root, field, field_type
yield from nested_structs(field_type)


def compute_struct_names(root: Struct, name: str) -> typing.Mapping[Struct, str]:
names = {root: name}
for parent, field, nested in nested_structs(root):
type = field.type
match type:
case Struct():
names[nested] = f"{names[parent]}{field.name.capitalize()}"
case Array():
names[nested] = f"{names[parent]}{field.name.capitalize()}Element"
case TaggedUnion():
tagged_union_type = typing.cast(TaggedUnion, type)
for case_value, case_struct in tagged_union_type.cases:
case_name = format_enum_name(case_value)
names[
case_struct
] = f"{names[parent]}{field.name.capitalize()}Case{case_name}"
return names
101 changes: 101 additions & 0 deletions projects/jdwp/tests/test_dataclass_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import unittest
from projects.jdwp.defs.schema import ArrayLength, IntegralType, Struct, Field, Array
from projects.jdwp.codegen.dataclass_generator import StructGenerator


class TestStructGenerator(unittest.TestCase):
def test_simple_struct(self):
simple_struct = Struct(
fields=[
Field(name="id", type=IntegralType.INT, description="An integer ID")
]
)

generator = StructGenerator(simple_struct, "SimpleStruct")

result = generator.generate()

expected = [
"@dataclasses.dataclass(frozen=True)\n"
"class SimpleStruct:\n"
" id: int"
]

self.assertEqual(result, expected)

def test_nested_struct(self):
inner_struct = Struct(
fields=[
Field(
name="inner_field",
type=IntegralType.INT,
description="Inner integer field",
)
]
)
outer_struct = Struct(
fields=[
Field(name="nested", type=inner_struct, description="Nested structure")
]
)

generator = StructGenerator(outer_struct, "OuterStruct")

result = generator.generate()

expected = [
"@dataclasses.dataclass(frozen=True)\n"
"class OuterStructNested:\n"
" inner_field: int",
"@dataclasses.dataclass(frozen=True)\n"
"class OuterStruct:\n"
" nested: OuterStructNested",
]

self.assertEqual(result, expected)

def test_struct_in_array(self):
# Define a structure
element_struct = Struct(
fields=[
Field(
name="element_field",
type=IntegralType.INT,
description="Element field",
)
]
)

array_length = ArrayLength(type=IntegralType.INT)

array_struct = Struct(
fields=[
Field(
name="array",
type=Array(
element_type=element_struct,
length=Field(
name="length", type=array_length, description="Array length"
),
),
description="Array of structures",
)
]
)

generator = StructGenerator(array_struct, "ArrayStruct")

result = generator.generate()

expected = [
"@dataclasses.dataclass(frozen=True)\n"
"class ArrayStructArrayElement:\n"
" element_field: int",
"@dataclasses.dataclass(frozen=True)\n"
"class ArrayStruct:\n"
" array: typing.List[ArrayStructArrayElement]",
]

self.assertEqual(result, expected)

0 comments on commit 2e88acd

Please sign in to comment.