diff --git a/capella_ros_tools/.license_header.txt b/capella_ros_tools/.license_header.txt index 8c17559..f9f06ed 100644 --- a/capella_ros_tools/.license_header.txt +++ b/capella_ros_tools/.license_header.txt @@ -1,2 +1,2 @@ -# SPDX-FileCopyrightText: Copyright DB Netz AG +# SPDX-FileCopyrightText: Copyright DB InfraGO AG # SPDX-License-Identifier: Apache-2.0 diff --git a/capella_ros_tools/__main__.py b/capella_ros_tools/__main__.py index 38b2dd1..2e11894 100644 --- a/capella_ros_tools/__main__.py +++ b/capella_ros_tools/__main__.py @@ -11,9 +11,7 @@ from capellambse import cli_helpers, decl import capella_ros_tools -from capella_ros_tools import exporter, importer - -from . import logger +from capella_ros_tools import exporter, importer, logger @click.group() @@ -71,6 +69,11 @@ def cli(): type=click.Path(path_type=pathlib.Path, dir_okay=False), help="Produce a declarative YAML instead of modifying the source model.", ) +@click.option( + "--license-header", + type=click.Path(path_type=pathlib.Path, dir_okay=False), + help="Ignore the license header from the given file when importing msgs.", +) def import_msgs( input: str, model: capellambse.MelodyModel, @@ -79,6 +82,7 @@ def import_msgs( types: uuid.UUID, no_deps: bool, output: pathlib.Path, + license_header: pathlib.Path | None, ) -> None: """Import ROS messages into a Capella data package.""" if root: @@ -93,7 +97,7 @@ def import_msgs( else: params = {"types_parent_uuid": model.sa.data_package.uuid} - parsed = importer.Importer(input, no_deps) + parsed = importer.Importer(input, no_deps, license_header) logger.info("Loaded %d packages", len(parsed.messages.packages)) yml = parsed.to_yaml(root_uuid, **params) diff --git a/capella_ros_tools/data_model.py b/capella_ros_tools/data_model.py index becd789..d3e360b 100644 --- a/capella_ros_tools/data_model.py +++ b/capella_ros_tools/data_model.py @@ -217,12 +217,15 @@ def __eq__(self, other: object) -> bool: @classmethod def from_file( - cls, file: abc.AbstractFilePath | pathlib.Path + cls, + file: abc.AbstractFilePath | pathlib.Path, + license_header: str | None = None, ) -> MessageDef: """Create message definition from a .msg file.""" msg_name = file.stem msg_string = file.read_text() - msg_string = msg_string.removeprefix(LICENSE_HEADER) + license_header = license_header or LICENSE_HEADER + msg_string = msg_string.removeprefix(license_header) return cls.from_string(msg_name, msg_string) @classmethod @@ -391,7 +394,10 @@ def __eq__(self, other: object) -> bool: @classmethod def from_msg_folder( - cls, pkg_name: str, msg_path: abc.AbstractFilePath | pathlib.Path + cls, + pkg_name: str, + msg_path: abc.AbstractFilePath | pathlib.Path, + license_header: str | None = None, ) -> MessagePkgDef: """Create a message package definition from a folder.""" out = cls(pkg_name, [], []) @@ -400,6 +406,6 @@ def from_msg_folder( msg_path.rglob("*.msg"), ) for msg_file in sorted(files, key=os.fspath): - msg_def = MessageDef.from_file(msg_file) + msg_def = MessageDef.from_file(msg_file, license_header) out.messages.append(msg_def) return out diff --git a/capella_ros_tools/importer.py b/capella_ros_tools/importer.py index d546773..e2397be 100644 --- a/capella_ros_tools/importer.py +++ b/capella_ros_tools/importer.py @@ -3,6 +3,7 @@ """Tool for importing ROS messages to a Capella data package.""" import os +import pathlib import typing as t from capellambse import decl, filehandler, helpers @@ -27,10 +28,14 @@ def __init__( self, msg_path: str, no_deps: bool, + license_header_path: pathlib.Path | None = None, ): self.messages = data_model.MessagePkgDef("root", [], []) self._promise_ids: dict[str, None] = {} self._promise_id_refs: dict[str, None] = {} + self._license_header = None + if license_header_path is not None: + self._license_header = license_header_path.read_text("utf-8") self._add_packages("ros_msgs", msg_path) if no_deps: @@ -43,7 +48,9 @@ def _add_packages(self, name: str, path: str) -> None: root = filehandler.get_filehandler(path).rootdir for dir in sorted(root.rglob("msg"), key=os.fspath): pkg_name = dir.parent.name or name - pkg_def = data_model.MessagePkgDef.from_msg_folder(pkg_name, dir) + pkg_def = data_model.MessagePkgDef.from_msg_folder( + pkg_name, dir, self._license_header + ) self.messages.packages.append(pkg_def) logger.info("Loaded package %s from %s", pkg_name, dir) diff --git a/tests/data/data_model/custom_license_header.txt b/tests/data/data_model/custom_license_header.txt new file mode 100644 index 0000000..2e43018 --- /dev/null +++ b/tests/data/data_model/custom_license_header.txt @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright DB InfraGO AG +# SPDX-License-Identifier: Apache-2.0 +# Additional Stuff to be removed diff --git a/tests/data/data_model/custom_license_msgs/package/msg/SampleClassEnum.msg b/tests/data/data_model/custom_license_msgs/package/msg/SampleClassEnum.msg new file mode 100644 index 0000000..56cb8c8 --- /dev/null +++ b/tests/data/data_model/custom_license_msgs/package/msg/SampleClassEnum.msg @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright DB InfraGO AG +# SPDX-License-Identifier: Apache-2.0 +# Additional Stuff to be removed + +# SampleClassEnum.msg +# Properties in SampleClassEnum can reference +# enums in the same file. + +# This block comment is added to the +# enum description of SampleClassEnumStatus. +uint8 OK = 0 +uint8 WARN = 1 +uint8 ERROR = 2 +uint8 STALE = 3 + +# This block comment is added to the +# enum description of Color. +uint8 COLOR_RED = 0 +uint8 COLOR_BLUE = 1 +uint8 COLOR_YELLOW = 2 + +uint8 status # The property status is of type + # SampleClassEnumStatus. +uint8 color # The property color is of type Color. +uint8 field diff --git a/tests/data/data_model/example_msgs/package1/msg/SampleClass.msg b/tests/data/data_model/example_msgs/package1/msg/SampleClass.msg index 1a76f01..899b339 100644 --- a/tests/data/data_model/example_msgs/package1/msg/SampleClass.msg +++ b/tests/data/data_model/example_msgs/package1/msg/SampleClass.msg @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright DB Netz AG +# SPDX-FileCopyrightText: Copyright DB InfraGO AG # SPDX-License-Identifier: Apache-2.0 # SampleClass.msg diff --git a/tests/data/data_model/example_msgs/package1/msg/types/SampleEnum.msg b/tests/data/data_model/example_msgs/package1/msg/types/SampleEnum.msg index 5ea8f22..a189c39 100644 --- a/tests/data/data_model/example_msgs/package1/msg/types/SampleEnum.msg +++ b/tests/data/data_model/example_msgs/package1/msg/types/SampleEnum.msg @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright DB Netz AG +# SPDX-FileCopyrightText: Copyright DB InfraGO AG # SPDX-License-Identifier: Apache-2.0 # SampleEnum.msg diff --git a/tests/data/data_model/example_msgs/package2/msg/SampleClassEnum.msg b/tests/data/data_model/example_msgs/package2/msg/SampleClassEnum.msg index 60b0c13..10d40e5 100644 --- a/tests/data/data_model/example_msgs/package2/msg/SampleClassEnum.msg +++ b/tests/data/data_model/example_msgs/package2/msg/SampleClassEnum.msg @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright DB Netz AG +# SPDX-FileCopyrightText: Copyright DB InfraGO AG # SPDX-License-Identifier: Apache-2.0 # SampleClassEnum.msg diff --git a/tests/test_import_msgs.py b/tests/test_import_msgs.py index f1bc74f..6fab51d 100644 --- a/tests/test_import_msgs.py +++ b/tests/test_import_msgs.py @@ -21,8 +21,18 @@ PATH = pathlib.Path(__file__).parent SAMPLE_PACKAGE_PATH = PATH.joinpath("data/data_model/example_msgs") +CUSTOM_LICENSE_PACKAGE_PATH = PATH.joinpath( + "data/data_model/custom_license_msgs" +) SAMPLE_PACKAGE_YAML = PATH.joinpath("data/data_model/example_msgs.yaml") DUMMY_PATH = PATH.joinpath("data/empty_project_60") +CUSTOM_LICENSE_PATH = PATH.joinpath( + "data/data_model/custom_license_header.txt" +) +EXPECTED_DESCRIPTION_SAMPLE_CLASS_ENUM = ( + "SampleClassEnum.msg " + "Properties in SampleClassEnum can reference enums in the same file. " +) ROOT = helpers.UUIDString("00000000-0000-0000-0000-000000000000") SA_ROOT = helpers.UUIDString("00000000-0000-0000-0000-000000000001") @@ -220,3 +230,14 @@ def test_convert_package(): ) assert actual == expected + + +def test_custom_license_header(): + importer = Importer( + CUSTOM_LICENSE_PACKAGE_PATH.as_posix(), True, CUSTOM_LICENSE_PATH + ) + + assert ( + importer.messages.packages[0].messages[0].description + == EXPECTED_DESCRIPTION_SAMPLE_CLASS_ENUM + )