From f953da894d3fe97b7ef6a544653168ed91ac0421 Mon Sep 17 00:00:00 2001 From: Sassan Haradji Date: Sat, 28 Sep 2024 18:34:37 +0400 Subject: [PATCH] refactor(rpc): preserve the order of fields of `oneof` declarations generated for `Union` types --- CHANGELOG.md | 1 + ubo_app/rpc/generate_proto.py | 64 ++++++----------------------------- 2 files changed, 12 insertions(+), 53 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2cd641d..eb42fd9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ - feat(display): add `DisplayCompressedRenderEvent` as a compressed version of `DisplayRenderEvent` - feat(rpc): add reflection to rpc server, limited to root service, but good enough for health checking purposes +- refactor(rpc): preserve the order of fields of `oneof` declarations generated for `Union` types ## Version 0.17.0 diff --git a/ubo_app/rpc/generate_proto.py b/ubo_app/rpc/generate_proto.py index 0dd2dce..ea195f4 100644 --- a/ubo_app/rpc/generate_proto.py +++ b/ubo_app/rpc/generate_proto.py @@ -6,10 +6,8 @@ import ast import functools import importlib -import operator import re import sys -from collections import defaultdict from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, Self, get_args @@ -27,7 +25,6 @@ actions = [] events = [] states = {} -dataclass_fields = defaultdict(set) global_messages: dict[str, tuple[str, list[tuple[str, _Type]]]] = {} global_enums: dict[str, str] = {} @@ -65,10 +62,6 @@ def get_embedded_definitions( ) -> str: return self.get_definitions(name, current_package=current_package) - @property - def dependencies(self: Self) -> set[str]: - return set() - @property def package(self: Self) -> str | None: return None @@ -120,18 +113,6 @@ def get_definitions(self: Self, name: str, *, current_package: str | None) -> st _ = name, current_package return '' - @property - def dependencies(self: Self) -> set[str]: - try: - if self.package is None: - return set() - except TypeError as exception: - if 'Unknown type' in str(exception): - return set() - raise - else: - return {self.package} - @property def package(self: Self) -> str | None: if self.type not in get_args(FieldType): @@ -187,10 +168,6 @@ def get_embedded_definitions( }} """ - @property - def dependencies(self: Self) -> set[str]: - return self.type.dependencies - class _ListType(_Type): type: _Type @@ -231,10 +208,6 @@ def get_embedded_definitions( }} """ - @property - def dependencies(self: Self) -> set[str]: - return self.type.dependencies - class _UnionType(_Type): types: tuple[_Type, ...] @@ -298,14 +271,6 @@ def get_embedded_definitions( return f'{sub_definitions}\n{definitions}' - @property - def dependencies(self: Self) -> set[str]: - return functools.reduce( - operator.or_, - (item.dependencies for item in self.types), - set(), - ) - class _DictType(_Type): key_type: _Type @@ -356,10 +321,6 @@ def get_embedded_definitions( {self.get_proto(name, current_package=current_package)} items = 1; }}""" - @property - def dependencies(self: Self) -> set[str]: - return self.key_type.dependencies | self.value_type.dependencies - class _ProtoGenerator(ast.NodeVisitor): def __init__(self: _ProtoGenerator, module: ModuleType) -> None: @@ -469,7 +430,6 @@ def process_class(self: _ProtoGenerator, node: ast.ClassDef) -> None: # noqa: C ) self.messages[message_name] = fields global_messages[message_name] = (self.package_name, fields) - dataclass_fields[self.package_name].add(message_name) if message_name.endswith('Action'): actions.append((message_name, self.package_name)) if message_name.endswith('Event'): @@ -557,14 +517,18 @@ def get_field_type( # noqa: C901, PLR0912 raise TypeError(msg) if isinstance(value, ast.BinOp) and isinstance(value.op, ast.BitOr): - types: set[_Type] = set() + types: list[_Type] = [] try: - types.add(self.get_field_type(value=value.left)) + type = self.get_field_type(value=value.left) + if type not in types: + types.append(type) except TypeError as e: if 'Callable types are not supported' not in str(e): raise try: - types.add(self.get_field_type(value=value.right)) + type = self.get_field_type(value=value.right) + if type not in types: + types.append(type) except TypeError as e: if 'Callable types are not supported' not in str(e): raise @@ -575,7 +539,7 @@ def get_field_type( # noqa: C901, PLR0912 msg = f'Unsupported field type: {value}' raise TypeError(msg) - def generate_proto(self: _ProtoGenerator) -> tuple[str, set[str]]: # noqa: C901 + def generate_proto(self: _ProtoGenerator) -> str: # noqa: C901 try: proto = '' for enum_name, values in self.enums.items(): @@ -588,7 +552,6 @@ def generate_proto(self: _ProtoGenerator) -> tuple[str, set[str]]: # noqa: C901 betterproto.casing.snake_case(enum_name).upper()}_{ value_name} = {i + 1};\n""" proto += '}\n\n' - dependencies: set[str] = set() for message_name, fields in self.messages.items(): proto += f'message {message_name} {{\n' proto += f""" option (package_info.v1.package_name) = "{ @@ -597,7 +560,6 @@ def generate_proto(self: _ProtoGenerator) -> tuple[str, set[str]]: # noqa: C901 proto += f"""{self.package_name.replace(".", "_dot_")} = { META_FIELD_PREFIX_PACKAGE_NAME_INDEX};\n""" for field_name, field_type in fields: - dependencies |= field_type.dependencies proto += re.sub( r'\n(?=.)', '\n', @@ -625,12 +587,11 @@ def generate_proto(self: _ProtoGenerator) -> tuple[str, set[str]]: # noqa: C901 name, current_package=self.package_name, ) - dependencies |= field_type.dependencies except TypeError as e: msg = f'Error in {self.package_name}' raise TypeError(msg) from e else: - return proto, dependencies - {self.package_name} + return proto def _generate_operations_proto(output_directory: Path) -> None: @@ -690,7 +651,7 @@ def parse(input_module: ModuleType) -> _ProtoGenerator: ) generators.extend( parse(importlib.import_module(f'ubo_app.store.services.{file.stem}')) - for file in Path('ubo_app/store/services/').glob('*.py') + for file in sorted(Path('ubo_app/store/services/').glob('*.py')) ) (output_directory / 'ubo' / 'v1').mkdir( @@ -703,11 +664,8 @@ def parse(input_module: ModuleType) -> _ProtoGenerator: file.write('import "package_info/v1/package_info.proto";\n\n') for generator in generators: sys.stdout.write(f'⚡ Generating proto for {generator.package_name} .') - proto_definitions, _ = generator.generate_proto() - sys.stdout.write('.') - - file.write(proto_definitions) + file.write(generator.generate_proto()) print(' Done')