diff --git a/projects/jdwp/codegen/dataclass_generator.py b/projects/jdwp/codegen/dataclass_generator.py index 7206db8..064805e 100644 --- a/projects/jdwp/codegen/dataclass_generator.py +++ b/projects/jdwp/codegen/dataclass_generator.py @@ -32,12 +32,7 @@ def __get_python_type_for(self, struct: Struct, field: Field) -> str: return self.__struct_to_name[type] case Array(): array_type = typing.cast(Array, type) - if isinstance(array_type.element_type, Struct): - return ( - f"typing.List[{self.__struct_to_name[array_type.element_type]}]" - ) - else: - return f"typing.List[{python_type_for(array_type.element_type)}]" + return f"typing.List[{self.__struct_to_name[array_type.element_type]}]" case TaggedUnion(): tagged_union_type = typing.cast(TaggedUnion, type) union_types = [ @@ -49,10 +44,6 @@ def __get_python_type_for(self, struct: Struct, field: Field) -> str: case _: return python_type_for(type) - def __get_element_struct_name(self, field: Field) -> str: - element_type = field.type.element_type - return self.__struct_to_name[element_type] - def __is_explicit_field(self, field: Field) -> bool: return not isinstance(field.type, (ArrayLength, UnionTag)) @@ -83,9 +74,10 @@ def __generate_serialize_method(self, struct: Struct) -> str: def __generate_serialize_array_code(self, field: Field) -> str: array_length_field_type = field.type.length + field_name = self.__get_field_name(field) array_code = ( - f"output.write_{array_length_field_type.lower()}(len(self.{field.name}))\n" - f"for element in self.{field.name}:\n" + f"output.write_{array_length_field_type.lower()}(len(self.{field_name}))\n" + f"for element in self.{field_name}:\n" f" element.serialize(output)\n" ) return dedent(array_code) @@ -93,75 +85,83 @@ def __generate_serialize_array_code(self, field: Field) -> str: def __generate_serialize_tagged_union_code( self, field: Field, union: TaggedUnion ) -> str: - union_code = "match self.{0}:\n".format(field.name) + field_name = self.__get_field_name(field) + union_code = "match self.{0}:\n".format(field_name) for enum_val, struct_type in union.cases: union_code += f" case {struct_type.__name__}():\n" union_code += f" output.write_{union.tag.type.name.lower()}({enum_val.value})\n" - union_code += f" self.{field.name}.serialize(output)\n" + union_code += f" self.{field_name}.serialize(output)\n" return dedent(union_code) def __generate_serialize_field_code(self, field: Field) -> str: + field_name = self.__get_field_name(field) match field.type: case Array(): return self.__generate_serialize_array_code(field) case TaggedUnion(): return self.__generate_serialize_tagged_union_code(field) case _: - return f"output.write_{field.type.name.lower()}(self.{field.name})" + return f"output.write_{field_name}(self.{field_name})" def __generate_parse_method(self, struct: Struct) -> str: struct_name = self.__struct_to_name[struct] parse_code = f"@staticmethod\nasync def parse(input: JDWPInputStreamBase) -> {struct_name}:\n" for field in struct.fields: - parse_code += ( - f" {field.name} = {self.__generate_parse_field_code(field)}\n" - ) + if self.__is_explicit_field(field): + field_name = self.__get_field_name(field) + parse_code += ( + f" {field_name} = {self.__generate_parse_field_code(field)}\n" + ) parse_code += f" return {struct_name}(\n" for field in struct.fields: - parse_code += f" {field.name}={field.name},\n" + if self.__is_explicit_field(field): + field_name = self.__get_field_name(field) + parse_code += f" {field_name}={field_name},\n" parse_code += " )" return dedent(parse_code) def __generate_parse_array_code(self, field: Field) -> str: - element_struct_name = self.__get_element_struct_name(field) - array_length_field = field.type.length + field_name = self.__get_field_name(field) + array_code = ( - f"{field.name}: typing.List[{element_struct_name}] = []\n" - f"for _ in range({array_length_field.name}):\n" - f" {field.name}.append(await {element_struct_name}.parse(input))\n" + f"{field_name}: typing.List[{field_name}] = []\n" + f"for _ in range({field_name}):\n" + f" {field_name}.append(await {field_name}.parse(input))\n" ) return dedent(array_code) def __generate_parse_tagged_union_code( self, field: Field, union: TaggedUnion ) -> str: - tag_field_name = self.__struct_to_name[union.tag.type] - union_code = ( - f"tag = await input.read_{tag_field_name}() # Read tag\n" f"match tag:\n" - ) + tag_type = union.tag.type + tag_read_method = f"read_{tag_type.tag.value.lower()}" + field_name = self.__get_field_name(field) + + # Generate the union code + union_code = f"tag = await input.{tag_read_method}() # Read tag\nmatch tag:\n" for enum_val, struct_type in union.cases: struct_name = self.__struct_to_name[struct_type] union_code += f" case {enum_val.value}:\n" - union_code += f" {field.name} = await {struct_name}.parse(input)\n" + union_code += f" {field_name} = await {struct_name}.parse(input)\n" union_code += " default:\n" - union_code += " raise ValueError('Unexpected tag value: {{tag}}')\n" + union_code += " raise ValueError('Unexpected tag value: {tag}')\n" return dedent(union_code) - def __generate_parse_field_code(self, field: Field) -> str: - match field.type: - case Array(): - return self.__generate_parse_array_code(field) - case TaggedUnion(): - return self.__generate_parse_tagged_union_code(field) - case _: - return f"await input.read_{field.type.name.lower()}()" + def __generate_parse_field_code(self, field: Field) -> str: + match field.type: + case Array(): + return self.__generate_parse_array_code(field) + case TaggedUnion(): + return self.__generate_parse_tagged_union_code(field) + case _: + return f"await input.read_{field.type.name.lower()}()" - def generate(self) -> typing.Generator[str, None, None]: - for _, _, nested in reversed(list(nested_structs(self.__root))): - yield self.__generate_dataclass(nested) - yield self.__generate_dataclass(self.__root) + def generate(self) -> typing.Generator[str, None, None]: + for _, _, nested in reversed(list(nested_structs(self.__root))): + yield self.__generate_dataclass(nested) + yield self.__generate_dataclass(self.__root) def format_enum_name(enum_value: enum.Enum) -> str: @@ -176,9 +176,8 @@ def nested_structs(root: Struct) -> typing.Generator[StructLink, None, None]: match field_type: case Array(): array_type = typing.cast(Array, field_type) - if isinstance(array_type.element_type, Struct): - yield root, field, array_type.element_type - yield from nested_structs(array_type.element_type) + yield root, field, array_type.element_type + yield from nested_structs(array_type.element_type) case TaggedUnion(): tagged_union_type = typing.cast(TaggedUnion, field_type) for _, struct in tagged_union_type.cases: diff --git a/projects/jdwp/codegen/new_type_generator.py b/projects/jdwp/codegen/new_type_generator.py index 844a93b..6846f50 100644 --- a/projects/jdwp/codegen/new_type_generator.py +++ b/projects/jdwp/codegen/new_type_generator.py @@ -11,6 +11,7 @@ def get_type_alias_definition(jdwp_type: IdType) -> str: def generate_new_types(): + print('import typing') for id_type in IdType: type_alias_definition = get_type_alias_definition(id_type) print(type_alias_definition)