Skip to content

Commit

Permalink
fix: used self get field name
Browse files Browse the repository at this point in the history
  • Loading branch information
Daquiver1 committed Dec 6, 2023
1 parent 307a76b commit 60a3b8c
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 45 deletions.
89 changes: 44 additions & 45 deletions projects/jdwp/codegen/dataclass_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,7 @@ def __get_python_type_for(self, struct: Struct, field: Field) -> str:
return self.__struct_to_name[type]
case Array():
array_type = typing.cast(Array, type)
if isinstance(array_type.element_type, Struct):
return (
f"typing.List[{self.__struct_to_name[array_type.element_type]}]"
)
else:
return f"typing.List[{python_type_for(array_type.element_type)}]"
return f"typing.List[{self.__struct_to_name[array_type.element_type]}]"
case TaggedUnion():
tagged_union_type = typing.cast(TaggedUnion, type)
union_types = [
Expand All @@ -49,10 +44,6 @@ def __get_python_type_for(self, struct: Struct, field: Field) -> str:
case _:
return python_type_for(type)

def __get_element_struct_name(self, field: Field) -> str:
element_type = field.type.element_type
return self.__struct_to_name[element_type]

def __is_explicit_field(self, field: Field) -> bool:
return not isinstance(field.type, (ArrayLength, UnionTag))

Expand Down Expand Up @@ -83,85 +74,94 @@ def __generate_serialize_method(self, struct: Struct) -> str:

def __generate_serialize_array_code(self, field: Field) -> str:
array_length_field_type = field.type.length
field_name = self.__get_field_name(field)
array_code = (
f"output.write_{array_length_field_type.lower()}(len(self.{field.name}))\n"
f"for element in self.{field.name}:\n"
f"output.write_{array_length_field_type.lower()}(len(self.{field_name}))\n"
f"for element in self.{field_name}:\n"
f" element.serialize(output)\n"
)
return dedent(array_code)

def __generate_serialize_tagged_union_code(
self, field: Field, union: TaggedUnion
) -> str:
union_code = "match self.{0}:\n".format(field.name)
field_name = self.__get_field_name(field)
union_code = "match self.{0}:\n".format(field_name)
for enum_val, struct_type in union.cases:
union_code += f" case {struct_type.__name__}():\n"
union_code += f" output.write_{union.tag.type.name.lower()}({enum_val.value})\n"
union_code += f" self.{field.name}.serialize(output)\n"
union_code += f" self.{field_name}.serialize(output)\n"
return dedent(union_code)

def __generate_serialize_field_code(self, field: Field) -> str:
field_name = self.__get_field_name(field)
match field.type:
case Array():
return self.__generate_serialize_array_code(field)
case TaggedUnion():
return self.__generate_serialize_tagged_union_code(field)
case _:
return f"output.write_{field.type.name.lower()}(self.{field.name})"
return f"output.write_{field_name}(self.{field_name})"

def __generate_parse_method(self, struct: Struct) -> str:
struct_name = self.__struct_to_name[struct]
parse_code = f"@staticmethod\nasync def parse(input: JDWPInputStreamBase) -> {struct_name}:\n"

for field in struct.fields:
parse_code += (
f" {field.name} = {self.__generate_parse_field_code(field)}\n"
)
if self.__is_explicit_field(field):
field_name = self.__get_field_name(field)
parse_code += (
f" {field_name} = {self.__generate_parse_field_code(field)}\n"
)

parse_code += f" return {struct_name}(\n"
for field in struct.fields:
parse_code += f" {field.name}={field.name},\n"
if self.__is_explicit_field(field):
field_name = self.__get_field_name(field)
parse_code += f" {field_name}={field_name},\n"
parse_code += " )"
return dedent(parse_code)

def __generate_parse_array_code(self, field: Field) -> str:
element_struct_name = self.__get_element_struct_name(field)
array_length_field = field.type.length
field_name = self.__get_field_name(field)

array_code = (
f"{field.name}: typing.List[{element_struct_name}] = []\n"
f"for _ in range({array_length_field.name}):\n"
f" {field.name}.append(await {element_struct_name}.parse(input))\n"
f"{field_name}: typing.List[{field_name}] = []\n"
f"for _ in range({field_name}):\n"
f" {field_name}.append(await {field_name}.parse(input))\n"
)
return dedent(array_code)

def __generate_parse_tagged_union_code(
self, field: Field, union: TaggedUnion
) -> str:
tag_field_name = self.__struct_to_name[union.tag.type]
union_code = (
f"tag = await input.read_{tag_field_name}() # Read tag\n" f"match tag:\n"
)
tag_type = union.tag.type
tag_read_method = f"read_{tag_type.tag.value.lower()}"
field_name = self.__get_field_name(field)

# Generate the union code
union_code = f"tag = await input.{tag_read_method}() # Read tag\nmatch tag:\n"
for enum_val, struct_type in union.cases:
struct_name = self.__struct_to_name[struct_type]
union_code += f" case {enum_val.value}:\n"
union_code += f" {field.name} = await {struct_name}.parse(input)\n"
union_code += f" {field_name} = await {struct_name}.parse(input)\n"
union_code += " default:\n"
union_code += " raise ValueError('Unexpected tag value: {{tag}}')\n"
union_code += " raise ValueError('Unexpected tag value: {tag}')\n"
return dedent(union_code)

def __generate_parse_field_code(self, field: Field) -> str:
match field.type:
case Array():
return self.__generate_parse_array_code(field)
case TaggedUnion():
return self.__generate_parse_tagged_union_code(field)
case _:
return f"await input.read_{field.type.name.lower()}()"
def __generate_parse_field_code(self, field: Field) -> str:
match field.type:
case Array():
return self.__generate_parse_array_code(field)
case TaggedUnion():
return self.__generate_parse_tagged_union_code(field)
case _:
return f"await input.read_{field.type.name.lower()}()"

def generate(self) -> typing.Generator[str, None, None]:
for _, _, nested in reversed(list(nested_structs(self.__root))):
yield self.__generate_dataclass(nested)
yield self.__generate_dataclass(self.__root)
def generate(self) -> typing.Generator[str, None, None]:
for _, _, nested in reversed(list(nested_structs(self.__root))):
yield self.__generate_dataclass(nested)
yield self.__generate_dataclass(self.__root)


def format_enum_name(enum_value: enum.Enum) -> str:
Expand All @@ -176,9 +176,8 @@ def nested_structs(root: Struct) -> typing.Generator[StructLink, None, None]:
match field_type:
case Array():
array_type = typing.cast(Array, field_type)
if isinstance(array_type.element_type, Struct):
yield root, field, array_type.element_type
yield from nested_structs(array_type.element_type)
yield root, field, array_type.element_type
yield from nested_structs(array_type.element_type)
case TaggedUnion():
tagged_union_type = typing.cast(TaggedUnion, field_type)
for _, struct in tagged_union_type.cases:
Expand Down
1 change: 1 addition & 0 deletions projects/jdwp/codegen/new_type_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def get_type_alias_definition(jdwp_type: IdType) -> str:


def generate_new_types():
print('import typing')
for id_type in IdType:
type_alias_definition = get_type_alias_definition(id_type)
print(type_alias_definition)
Expand Down

0 comments on commit 60a3b8c

Please sign in to comment.