Skip to content

Commit

Permalink
refactor(rpc): preserve the order of fields of oneof declarations g…
Browse files Browse the repository at this point in the history
…enerated for `Union` types
  • Loading branch information
sassanh committed Sep 28, 2024
1 parent c97426d commit f953da8
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 53 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

- feat(display): add `DisplayCompressedRenderEvent` as a compressed version of `DisplayRenderEvent`
- feat(rpc): add reflection to rpc server, limited to root service, but good enough for health checking purposes
- refactor(rpc): preserve the order of fields of `oneof` declarations generated for `Union` types

## Version 0.17.0

Expand Down
64 changes: 11 additions & 53 deletions ubo_app/rpc/generate_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@
import ast
import functools
import importlib
import operator
import re
import sys
from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Self, get_args

Expand All @@ -27,7 +25,6 @@
actions = []
events = []
states = {}
dataclass_fields = defaultdict(set)

global_messages: dict[str, tuple[str, list[tuple[str, _Type]]]] = {}
global_enums: dict[str, str] = {}
Expand Down Expand Up @@ -65,10 +62,6 @@ def get_embedded_definitions(
) -> str:
return self.get_definitions(name, current_package=current_package)

@property
def dependencies(self: Self) -> set[str]:
return set()

@property
def package(self: Self) -> str | None:
return None
Expand Down Expand Up @@ -120,18 +113,6 @@ def get_definitions(self: Self, name: str, *, current_package: str | None) -> st
_ = name, current_package
return ''

@property
def dependencies(self: Self) -> set[str]:
try:
if self.package is None:
return set()
except TypeError as exception:
if 'Unknown type' in str(exception):
return set()
raise
else:
return {self.package}

@property
def package(self: Self) -> str | None:
if self.type not in get_args(FieldType):
Expand Down Expand Up @@ -187,10 +168,6 @@ def get_embedded_definitions(
}}
"""

@property
def dependencies(self: Self) -> set[str]:
return self.type.dependencies


class _ListType(_Type):
type: _Type
Expand Down Expand Up @@ -231,10 +208,6 @@ def get_embedded_definitions(
}}
"""

@property
def dependencies(self: Self) -> set[str]:
return self.type.dependencies


class _UnionType(_Type):
types: tuple[_Type, ...]
Expand Down Expand Up @@ -298,14 +271,6 @@ def get_embedded_definitions(

return f'{sub_definitions}\n{definitions}'

@property
def dependencies(self: Self) -> set[str]:
return functools.reduce(
operator.or_,
(item.dependencies for item in self.types),
set(),
)


class _DictType(_Type):
key_type: _Type
Expand Down Expand Up @@ -356,10 +321,6 @@ def get_embedded_definitions(
{self.get_proto(name, current_package=current_package)} items = 1;
}}"""

@property
def dependencies(self: Self) -> set[str]:
return self.key_type.dependencies | self.value_type.dependencies


class _ProtoGenerator(ast.NodeVisitor):
def __init__(self: _ProtoGenerator, module: ModuleType) -> None:
Expand Down Expand Up @@ -469,7 +430,6 @@ def process_class(self: _ProtoGenerator, node: ast.ClassDef) -> None: # noqa: C
)
self.messages[message_name] = fields
global_messages[message_name] = (self.package_name, fields)
dataclass_fields[self.package_name].add(message_name)
if message_name.endswith('Action'):
actions.append((message_name, self.package_name))
if message_name.endswith('Event'):
Expand Down Expand Up @@ -557,14 +517,18 @@ def get_field_type( # noqa: C901, PLR0912
raise TypeError(msg)

if isinstance(value, ast.BinOp) and isinstance(value.op, ast.BitOr):
types: set[_Type] = set()
types: list[_Type] = []
try:
types.add(self.get_field_type(value=value.left))
type = self.get_field_type(value=value.left)
if type not in types:
types.append(type)
except TypeError as e:
if 'Callable types are not supported' not in str(e):
raise
try:
types.add(self.get_field_type(value=value.right))
type = self.get_field_type(value=value.right)
if type not in types:
types.append(type)
except TypeError as e:
if 'Callable types are not supported' not in str(e):
raise
Expand All @@ -575,7 +539,7 @@ def get_field_type( # noqa: C901, PLR0912
msg = f'Unsupported field type: {value}'
raise TypeError(msg)

def generate_proto(self: _ProtoGenerator) -> tuple[str, set[str]]: # noqa: C901
def generate_proto(self: _ProtoGenerator) -> str: # noqa: C901
try:
proto = ''
for enum_name, values in self.enums.items():
Expand All @@ -588,7 +552,6 @@ def generate_proto(self: _ProtoGenerator) -> tuple[str, set[str]]: # noqa: C901
betterproto.casing.snake_case(enum_name).upper()}_{
value_name} = {i + 1};\n"""
proto += '}\n\n'
dependencies: set[str] = set()
for message_name, fields in self.messages.items():
proto += f'message {message_name} {{\n'
proto += f""" option (package_info.v1.package_name) = "{
Expand All @@ -597,7 +560,6 @@ def generate_proto(self: _ProtoGenerator) -> tuple[str, set[str]]: # noqa: C901
proto += f"""{self.package_name.replace(".", "_dot_")} = {
META_FIELD_PREFIX_PACKAGE_NAME_INDEX};\n"""
for field_name, field_type in fields:
dependencies |= field_type.dependencies
proto += re.sub(
r'\n(?=.)',
'\n',
Expand Down Expand Up @@ -625,12 +587,11 @@ def generate_proto(self: _ProtoGenerator) -> tuple[str, set[str]]: # noqa: C901
name,
current_package=self.package_name,
)
dependencies |= field_type.dependencies
except TypeError as e:
msg = f'Error in {self.package_name}'
raise TypeError(msg) from e
else:
return proto, dependencies - {self.package_name}
return proto


def _generate_operations_proto(output_directory: Path) -> None:
Expand Down Expand Up @@ -690,7 +651,7 @@ def parse(input_module: ModuleType) -> _ProtoGenerator:
)
generators.extend(
parse(importlib.import_module(f'ubo_app.store.services.{file.stem}'))
for file in Path('ubo_app/store/services/').glob('*.py')
for file in sorted(Path('ubo_app/store/services/').glob('*.py'))
)

(output_directory / 'ubo' / 'v1').mkdir(
Expand All @@ -703,11 +664,8 @@ def parse(input_module: ModuleType) -> _ProtoGenerator:
file.write('import "package_info/v1/package_info.proto";\n\n')
for generator in generators:
sys.stdout.write(f'⚡ Generating proto for {generator.package_name} .')
proto_definitions, _ = generator.generate_proto()

sys.stdout.write('.')

file.write(proto_definitions)
file.write(generator.generate_proto())

print(' Done')

Expand Down

0 comments on commit f953da8

Please sign in to comment.