From bacbafe9e7fe12be6c3726ad52c6a5944a7a2e40 Mon Sep 17 00:00:00 2001 From: Michael Harbarth Date: Thu, 28 Nov 2024 17:27:41 +0100 Subject: [PATCH] feat: Add dependency support for non-standard branches and enum fixes --- capella_ros_tools/__main__.py | 8 ++++- capella_ros_tools/data_model.py | 13 ++++--- capella_ros_tools/importer.py | 52 +++++++++++++++++++++------- tests/data/dependencies.json | 14 ++++++++ tests/data/dependencies.json.license | 2 ++ 5 files changed, 71 insertions(+), 18 deletions(-) create mode 100644 tests/data/dependencies.json create mode 100644 tests/data/dependencies.json.license diff --git a/capella_ros_tools/__main__.py b/capella_ros_tools/__main__.py index c156d9c..a962c70 100644 --- a/capella_ros_tools/__main__.py +++ b/capella_ros_tools/__main__.py @@ -79,6 +79,11 @@ def cli(): type=str, help="Regular expression to extract description from the file .", ) +@click.option( + "--dependency-json", + type=click.Path(path_type=pathlib.Path, dir_okay=False), + help="A path to a JSON containing dependencies which should be imported.", +) def import_msgs( input: str, model: capellambse.MelodyModel, @@ -89,6 +94,7 @@ def import_msgs( output: pathlib.Path, license_header: pathlib.Path | None, description_regex: str | None, + dependency_json: pathlib.Path | None, ) -> None: """Import ROS messages into a Capella data package.""" if root: @@ -104,7 +110,7 @@ def import_msgs( params = {"types_parent_uuid": model.sa.data_package.uuid} parsed = importer.Importer( - input, no_deps, license_header, description_regex + input, no_deps, license_header, description_regex, dependency_json ) logger.info("Loaded %d packages", len(parsed.messages.packages)) diff --git a/capella_ros_tools/data_model.py b/capella_ros_tools/data_model.py index 989a61e..2c38654 100644 --- a/capella_ros_tools/data_model.py +++ b/capella_ros_tools/data_model.py @@ -224,6 +224,7 @@ def __eq__(self, other: object) -> bool: @classmethod def from_file( cls, + pkg_name: str, file: abc.AbstractFilePath | pathlib.Path, license_header: str | None = None, msg_description_regex: re.Pattern[str] | None = None, @@ -233,11 +234,14 @@ def from_file( msg_string = file.read_text() license_header = license_header or LICENSE_HEADER msg_string = msg_string.removeprefix(license_header) - return cls.from_string(msg_name, msg_string, msg_description_regex) + return cls.from_string( + pkg_name, msg_name, msg_string, msg_description_regex + ) @classmethod def from_string( cls, + pkg_name: str, msg_name: str, msg_string: str, msg_description_regex: re.Pattern[str] | None = None, @@ -342,14 +346,15 @@ def from_string( if field.type.name == enum.literals[0].type.name: matched_field = matched_field or field if field.name.lower() == enum.name.lower(): + enum.name = msg_name + matched_field.name.capitalize() field.type.name = enum.name - field.type.package = msg_name + field.type.package = f"{pkg_name}.{msg_name}" break else: if matched_field: enum.name = msg_name + matched_field.name.capitalize() matched_field.type.name = enum.name - matched_field.type.package = msg_name + matched_field.type.package = f"{pkg_name}.{msg_name}" return msg @@ -422,7 +427,7 @@ def from_msg_folder( ) for msg_file in sorted(files, key=os.fspath): msg_def = MessageDef.from_file( - msg_file, license_header, msg_description_regex + pkg_name, msg_file, license_header, msg_description_regex ) out.messages.append(msg_def) return out diff --git a/capella_ros_tools/importer.py b/capella_ros_tools/importer.py index e3b0caf..d967b9d 100644 --- a/capella_ros_tools/importer.py +++ b/capella_ros_tools/importer.py @@ -1,7 +1,7 @@ # Copyright DB InfraGO AG and contributors # SPDX-License-Identifier: Apache-2.0 """Tool for importing ROS messages to a Capella data package.""" - +import json import os import pathlib import re @@ -31,7 +31,13 @@ def __init__( no_deps: bool, license_header_path: pathlib.Path | None = None, msg_description_regex: str | None = None, + dependency_json: pathlib.Path | None = None, ): + if dependency_json: + dependencies = json.loads(dependency_json.read_bytes()) + else: + dependencies = ROS2_INTERFACES + self.messages = data_model.MessagePkgDef("root", [], []) self._promise_ids: dict[str, None] = {} self._promise_id_refs: dict[str, None] = {} @@ -44,13 +50,23 @@ def __init__( if no_deps: return - for interface_name, interface_url in ROS2_INTERFACES.items(): - self._add_packages(interface_name, interface_url) + for interface_name, interface_spec in dependencies.items(): + kwargs = {} + if isinstance(interface_spec, dict): + interface_url = interface_spec.pop("path") + kwargs.update(interface_spec) + else: + interface_url = interface_spec + self._add_packages(interface_name, interface_url, **kwargs) def _add_packages( - self, name: str, path: str, msg_description_regex: str | None = None + self, + name: str, + path: str, + msg_description_regex: str | None = None, + **kwargs, ) -> None: - root = filehandler.get_filehandler(path).rootdir + root = filehandler.get_filehandler(path, **kwargs).rootdir msg_description_pattern = None if msg_description_regex is not None: msg_description_pattern = re.compile( @@ -95,7 +111,9 @@ def _convert_package( cls_yml = self._convert_class(pkg_def.name, msg_def) classes.append(cls_yml) for enum_def in msg_def.enums: - enums.append(self._convert_enum(msg_def.name, enum_def)) + enums.append( + self._convert_enum(pkg_def.name, msg_def.name, enum_def) + ) for new_pkg in pkg_def.packages: new_yml = { @@ -171,9 +189,9 @@ def _convert_class( return yml def _convert_enum( - self, pkg_name: str, enum_def: data_model.EnumDef + self, pkg_name: str, msg_name: str, enum_def: data_model.EnumDef ) -> dict[str, t.Any]: - promise_id = f"{pkg_name}.{enum_def.name}" + promise_id = f"{pkg_name}.{msg_name}.{enum_def.name}" self._promise_ids[promise_id] = None literals = [] for literal in enum_def.literals: @@ -190,16 +208,24 @@ def _convert_enum( if literal.description: literal_yml["set"]["description"] = literal.description literals.append(literal_yml) + + types = set(lit.type.name for lit in enum_def.literals) + assert len(types) == 1, "All values of an Enum must have the same type" + promise_ref = f"{pkg_name}.{types.pop()}" + self._promise_id_refs[promise_ref] = None + + set_values: dict[str, t.Any] = { + "domain_type": decl.Promise(promise_ref) + } + if enum_def.description: + set_values["description"] = enum_def.description + yml = { "promise_id": promise_id, "find": { "name": enum_def.name, }, - "set": ( - {"description": enum_def.description} - if enum_def.description - else {} - ), + "set": set_values, "sync": { "literals": literals, }, diff --git a/tests/data/dependencies.json b/tests/data/dependencies.json new file mode 100644 index 0000000..d40a7dc --- /dev/null +++ b/tests/data/dependencies.json @@ -0,0 +1,14 @@ +{ + "common_interfaces": { + "path": "git+https://github.com/ros2/common_interfaces", + "revision": "humble" + }, + "rcl_interfaces": { + "path": "git+https://github.com/ros2/rcl_interfaces", + "revision": "humble" + }, + "unique_identifier_msgs": { + "path": "git+https://github.com/ros2/unique_identifier_msgs", + "revision": "humble" + } +} diff --git a/tests/data/dependencies.json.license b/tests/data/dependencies.json.license new file mode 100644 index 0000000..62a1749 --- /dev/null +++ b/tests/data/dependencies.json.license @@ -0,0 +1,2 @@ +SPDX-FileCopyrightText: Copyright DB InfraGO AG +SPDX-License-Identifier: Apache-2.0