From cb000c8f04ea95241edee41fdb7f9957c52ad23d Mon Sep 17 00:00:00 2001 From: jamesbeedy Date: Wed, 17 Apr 2024 19:29:06 +0000 Subject: [PATCH 1/2] project cleanup These changes improve the slurmdbd charm in a number of ways. 1) remove slurm-ops-manager 2) remove unused code 3) emit relation data to the charm via events 4) remove dependency on slurmctld 5) general code cleanup 6) consolidate yaml files into charmcraft.yaml 7) add type tests 8) rename interface slurmdbd -> slurmctld 9) update readme Address feedback from PR review. * add package managers for each package, slurmdbd and munge * address docstring issues * move constants to constants.py * address test failure * pin requirements * install funtions handle failure cases * upgrade to checkout action v4 --- .github/workflows/ci.yaml | 20 +- .gitignore | 2 + README.md | 4 +- charmcraft.yaml | 56 +- config.yaml | 23 - .../data_platform_libs/v0/data_interfaces.py | 3001 +++++++++++++++-- lib/charms/operator_libs_linux/v0/apt.py | 1361 ++++++++ lib/charms/operator_libs_linux/v1/systemd.py | 251 +- metadata.yaml | 30 - pyproject.toml | 10 +- requirements.txt | 4 +- src/charm.py | 365 +- src/constants.py | 52 + src/interface_slurmctld.py | 118 + src/interface_slurmdbd.py | 149 - src/interface_slurmdbd_peer.py | 183 - src/slurmdbd_ops.py | 358 ++ src/utils/__init__.py | 16 - src/utils/confeditor.py | 1212 ------- src/utils/manager.py | 127 - tests/integration/test_charm.py | 6 +- tests/unit/test_charm.py | 264 +- tests/unit/test_confeditor.py | 128 - tests/unit/test_manager.py | 173 - tox.ini | 31 +- 25 files changed, 5041 insertions(+), 2903 deletions(-) delete mode 100644 config.yaml create mode 100644 lib/charms/operator_libs_linux/v0/apt.py delete mode 100644 metadata.yaml create mode 100644 src/constants.py create mode 100644 src/interface_slurmctld.py delete mode 100644 src/interface_slurmdbd.py delete mode 100644 src/interface_slurmdbd_peer.py create mode 100644 src/slurmdbd_ops.py delete mode 100644 src/utils/__init__.py delete mode 100644 src/utils/confeditor.py delete mode 100644 src/utils/manager.py delete mode 100644 tests/unit/test_confeditor.py delete mode 100644 tests/unit/test_manager.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 18c5fd3..fd144cb 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -23,7 +23,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Run tests uses: get-woke/woke-action@v0 with: @@ -34,7 +34,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install dependencies run: python3 -m pip install tox - name: Run linters @@ -45,12 +45,23 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install dependencies run: python3 -m pip install tox - name: Run tests run: tox -e unit + type-check: + name: Type check + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Install dependencies + run: python3 -m pip install tox + - name: Run tests + run: tox -e type + integration-test: strategy: fail-fast: true @@ -62,10 +73,11 @@ jobs: needs: - inclusive-naming-check - lint + - type-check - unit-test steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Setup operator environment uses: charmed-kubernetes/actions-operator@main with: diff --git a/.gitignore b/.gitignore index 9ceb2e5..e84bc51 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,5 @@ __pycache__/ .idea .vscode/ version +.ruff_cache/ +.mypy_cache/ diff --git a/README.md b/README.md index 2b7b968..5edc69e 100644 --- a/README.md +++ b/README.md @@ -27,10 +27,10 @@ $ juju deploy slurmd --channel edge $ juju deploy slurmdbd --channel edge $ juju deploy mysql --channel 8.0/edge $ juju deploy mysql-router slurmdbd-mysql-router --channel dpe/edge -$ juju integrate slurmctld:slurmd slurmd:slurmd +$ juju integrate slurmctld:slurmd slurmd:slurmctld +$ juju integrate slurmctld:slurmdbd slurmdbd:slurmctld $ juju integrate slurmdbd-mysql-router:backend-database mysql:database $ juju integrate slurmdbd:database slurmdbd-mysql-router:database -$ juju integrate slurmctld:slurmdbd slurmdbd:slurmdbd ``` ## Project & Community diff --git a/charmcraft.yaml b/charmcraft.yaml index 6a6f7d1..47e7e9f 100644 --- a/charmcraft.yaml +++ b/charmcraft.yaml @@ -1,7 +1,40 @@ -# Copyright 2020 Omnivector Solutions, LLC. +# Copyright 2020-2024 Omnivector, LLC. # See LICENSE file for licensing details. - +name: slurmdbd type: charm + +assumes: + - juju + +summary: | + Slurm DBD accounting daemon. + +description: | + This charm provides slurmdbd, munged, and the bindings to other utilities + that make lifecycle operations a breeze. + + slurmdbd provides a secure enterprise-wide interface to a database for + SLURM. This is particularly useful for archiving accounting records. + +links: + contact: https://matrix.to/#/#hpc:ubuntu.com + + source: + - https://github.com/omnivector-solutions/slurmdbd-operator + + issues: + - https://github.com/omnivector-solutions/slurmdbd-operator/issues + +requires: + database: + interface: mysql_client + fluentbit: + interface: fluentbit + +provides: + slurmctld: + interface: slurmdbd + bases: - build-on: - name: ubuntu @@ -10,22 +43,3 @@ bases: - name: ubuntu channel: "22.04" architectures: [amd64] - -parts: - charm: - build-packages: [git] - charm-python-packages: [setuptools] - - # Create a version file and pack it into the charm. This is dynamically generated - # as part of the build process for a charm to ensure that the git revision of the - # charm is always recorded in this version file. - version-file: - plugin: nil - build-packages: - - git - override-build: | - VERSION=$(git -C $CRAFT_PART_SRC/../../charm/src describe --dirty --always) - echo "Setting version to $VERSION" - echo $VERSION > $CRAFT_PART_INSTALL/version - stage: - - version diff --git a/config.yaml b/config.yaml deleted file mode 100644 index e7fc520..0000000 --- a/config.yaml +++ /dev/null @@ -1,23 +0,0 @@ -options: - custom-slurm-repo: - type: string - default: "" - description: > - Use a custom repository for Slurm installation. - - This can be set to the Organization's local mirror/cache of packages and - supersedes the Omnivector repositories. Alternatively, it can be used to - track a `testing` Slurm version, e.g. by setting to - `ppa:omnivector/osd-testing`. - - Note: The configuration `custom-slurm-repo` must be set *before* - deploying the units. Changing this value after deploying the units will - not reinstall Slurm. - slurmdbd-debug: - type: string - default: info - description: > - The level of detail to provide slurmdbd daemon's logs. The default value - is `info`. If the slurmdbd daemon is initiated with `-v` or `--verbose` - options, that debug level will be preserve or restored upon - reconfiguration. diff --git a/lib/charms/data_platform_libs/v0/data_interfaces.py b/lib/charms/data_platform_libs/v0/data_interfaces.py index 8d860a6..3ce69e1 100644 --- a/lib/charms/data_platform_libs/v0/data_interfaces.py +++ b/lib/charms/data_platform_libs/v0/data_interfaces.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Library to manage the relation for the data-platform products. +r"""Library to manage the relation for the data-platform products. This library contains the Requires and Provides classes for handling the relation between an application and multiple managed application supported by the data-team: -MySQL, Postgresql, MongoDB, Redis, and Kakfa. +MySQL, Postgresql, MongoDB, Redis, and Kafka. ### Database (MySQL, Postgresql, MongoDB, and Redis) @@ -144,6 +144,19 @@ def _on_cluster2_database_created(self, event: DatabaseCreatedEvent) -> None: ``` +When it's needed to check whether a plugin (extension) is enabled on the PostgreSQL +charm, you can use the is_postgresql_plugin_enabled method. To use that, you need to +add the following dependency to your charmcraft.yaml file: + +```yaml + +parts: + charm: + charm-binary-python-packages: + - psycopg[binary] + +``` + ### Provider Charm Following an example of using the DatabaseRequestedEvent, in the context of the @@ -277,23 +290,38 @@ def _on_topic_requested(self, event: TopicRequestedEvent): creating a new topic when other information other than a topic name is exchanged in the relation databag. """ -import abc + +import copy import json import logging from abc import ABC, abstractmethod -from collections import namedtuple +from collections import UserDict, namedtuple from datetime import datetime -from typing import List, Optional +from enum import Enum +from typing import ( + Callable, + Dict, + ItemsView, + KeysView, + List, + Optional, + Set, + Tuple, + Union, + ValuesView, +) +from ops import JujuVersion, Model, Secret, SecretInfo, SecretNotFoundError from ops.charm import ( CharmBase, CharmEvents, RelationChangedEvent, + RelationCreatedEvent, RelationEvent, - RelationJoinedEvent, + SecretChangedEvent, ) from ops.framework import EventSource, Object -from ops.model import Relation +from ops.model import Application, ModelError, Relation, Unit # The unique Charmhub library identifier, never change it LIBID = "6c3e6b6680d64e9c89e611d1a15f65be" @@ -303,7 +331,9 @@ def _on_topic_requested(self, event: TopicRequestedEvent): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 6 +LIBPATCH = 34 + +PYDEPS = ["ops>=2.0.0"] logger = logging.getLogger(__name__) @@ -316,7 +346,98 @@ def _on_topic_requested(self, event: TopicRequestedEvent): deleted - key that were deleted""" -def diff(event: RelationChangedEvent, bucket: str) -> Diff: +PROV_SECRET_PREFIX = "secret-" +REQ_SECRET_FIELDS = "requested-secrets" +GROUP_MAPPING_FIELD = "secret_group_mapping" +GROUP_SEPARATOR = "@" + + +class SecretGroup(str): + """Secret groups specific type.""" + + +class SecretGroupsAggregate(str): + """Secret groups with option to extend with additional constants.""" + + def __init__(self): + self.USER = SecretGroup("user") + self.TLS = SecretGroup("tls") + self.EXTRA = SecretGroup("extra") + + def __setattr__(self, name, value): + """Setting internal constants.""" + if name in self.__dict__: + raise RuntimeError("Can't set constant!") + else: + super().__setattr__(name, SecretGroup(value)) + + def groups(self) -> list: + """Return the list of stored SecretGroups.""" + return list(self.__dict__.values()) + + def get_group(self, group: str) -> Optional[SecretGroup]: + """If the input str translates to a group name, return that.""" + return SecretGroup(group) if group in self.groups() else None + + +SECRET_GROUPS = SecretGroupsAggregate() + + +class DataInterfacesError(Exception): + """Common ancestor for DataInterfaces related exceptions.""" + + +class SecretError(DataInterfacesError): + """Common ancestor for Secrets related exceptions.""" + + +class SecretAlreadyExistsError(SecretError): + """A secret that was to be added already exists.""" + + +class SecretsUnavailableError(SecretError): + """Secrets aren't yet available for Juju version used.""" + + +class SecretsIllegalUpdateError(SecretError): + """Secrets aren't yet available for Juju version used.""" + + +class IllegalOperationError(DataInterfacesError): + """To be used when an operation is not allowed to be performed.""" + + +def get_encoded_dict( + relation: Relation, member: Union[Unit, Application], field: str +) -> Optional[Dict[str, str]]: + """Retrieve and decode an encoded field from relation data.""" + data = json.loads(relation.data[member].get(field, "{}")) + if isinstance(data, dict): + return data + logger.error("Unexpected datatype for %s instead of dict.", str(data)) + + +def get_encoded_list( + relation: Relation, member: Union[Unit, Application], field: str +) -> Optional[List[str]]: + """Retrieve and decode an encoded field from relation data.""" + data = json.loads(relation.data[member].get(field, "[]")) + if isinstance(data, list): + return data + logger.error("Unexpected datatype for %s instead of list.", str(data)) + + +def set_encoded_field( + relation: Relation, + member: Union[Unit, Application], + field: str, + value: Union[str, list, Dict[str, str]], +) -> None: + """Set an encoded field from relation data.""" + relation.data[member].update({field: json.dumps(value)}) + + +def diff(event: RelationChangedEvent, bucket: Optional[Union[Unit, Application]]) -> Diff: """Retrieves the diff of the data in the relation changed databag. Args: @@ -328,305 +449,1940 @@ def diff(event: RelationChangedEvent, bucket: str) -> Diff: keys from the event relation databag. """ # Retrieve the old data from the data key in the application relation databag. - old_data = json.loads(event.relation.data[bucket].get("data", "{}")) + if not bucket: + return Diff([], [], []) + + old_data = get_encoded_dict(event.relation, bucket, "data") + + if not old_data: + old_data = {} + # Retrieve the new data from the event relation databag. - new_data = { - key: value for key, value in event.relation.data[event.app].items() if key != "data" - } + new_data = ( + {key: value for key, value in event.relation.data[event.app].items() if key != "data"} + if event.app + else {} + ) # These are the keys that were added to the databag and triggered this event. - added = new_data.keys() - old_data.keys() + added = new_data.keys() - old_data.keys() # pyright: ignore [reportAssignmentType] # These are the keys that were removed from the databag and triggered this event. - deleted = old_data.keys() - new_data.keys() + deleted = old_data.keys() - new_data.keys() # pyright: ignore [reportAssignmentType] # These are the keys that already existed in the databag, # but had their values changed. - changed = {key for key in old_data.keys() & new_data.keys() if old_data[key] != new_data[key]} + changed = { + key + for key in old_data.keys() & new_data.keys() # pyright: ignore [reportAssignmentType] + if old_data[key] != new_data[key] # pyright: ignore [reportAssignmentType] + } # Convert the new_data to a serializable format and save it for a next diff check. - event.relation.data[bucket].update({"data": json.dumps(new_data)}) + set_encoded_field(event.relation, bucket, "data", new_data) # Return the diff with all possible changes. return Diff(added, changed, deleted) -# Base DataProvides and DataRequires +def leader_only(f): + """Decorator to ensure that only leader can perform given operation.""" + def wrapper(self, *args, **kwargs): + if self.component == self.local_app and not self.local_unit.is_leader(): + logger.error( + "This operation (%s()) can only be performed by the leader unit", f.__name__ + ) + return + return f(self, *args, **kwargs) -class DataProvides(Object, ABC): - """Base provides-side of the data products relation.""" - - def __init__(self, charm: CharmBase, relation_name: str) -> None: - super().__init__(charm, relation_name) - self.charm = charm - self.local_app = self.charm.model.app - self.local_unit = self.charm.unit - self.relation_name = relation_name - self.framework.observe( - charm.on[relation_name].relation_changed, - self._on_relation_changed, - ) + wrapper.leader_only = True + return wrapper - def _diff(self, event: RelationChangedEvent) -> Diff: - """Retrieves the diff of the data in the relation changed databag. - Args: - event: relation changed event. +def juju_secrets_only(f): + """Decorator to ensure that certain operations would be only executed on Juju3.""" - Returns: - a Diff instance containing the added, deleted and changed - keys from the event relation databag. - """ - return diff(event, self.local_app) + def wrapper(self, *args, **kwargs): + if not self.secrets_enabled: + raise SecretsUnavailableError("Secrets unavailable on current Juju version") + return f(self, *args, **kwargs) - @abstractmethod - def _on_relation_changed(self, event: RelationChangedEvent) -> None: - """Event emitted when the relation data has changed.""" - raise NotImplementedError + return wrapper - def fetch_relation_data(self) -> dict: - """Retrieves data from relation. - This function can be used to retrieve data from a relation - in the charm code when outside an event callback. +def dynamic_secrets_only(f): + """Decorator to ensure that certain operations would be only executed when NO static secrets are defined.""" - Returns: - a dict of the values stored in the relation data bag - for all relation instances (indexed by the relation id). - """ - data = {} - for relation in self.relations: - data[relation.id] = { - key: value for key, value in relation.data[relation.app].items() if key != "data" - } - return data + def wrapper(self, *args, **kwargs): + if self.static_secret_fields: + raise IllegalOperationError( + "Unsafe usage of statically and dynamically defined secrets, aborting." + ) + return f(self, *args, **kwargs) - def _update_relation_data(self, relation_id: int, data: dict) -> None: - """Updates a set of key-value pairs in the relation. + return wrapper - This function writes in the application data bag, therefore, - only the leader unit can call it. - Args: - relation_id: the identifier for a particular relation. - data: dict containing the key-value pairs - that should be updated in the relation. - """ - if self.local_unit.is_leader(): - relation = self.charm.model.get_relation(self.relation_name, relation_id) - relation.data[self.local_app].update(data) +def either_static_or_dynamic_secrets(f): + """Decorator to ensure that static and dynamic secrets won't be used in parallel.""" - @property - def relations(self) -> List[Relation]: - """The list of Relation instances associated with this relation_name.""" - return list(self.charm.model.relations[self.relation_name]) + def wrapper(self, *args, **kwargs): + if self.static_secret_fields and set(self.current_secret_fields) - set( + self.static_secret_fields + ): + raise IllegalOperationError( + "Unsafe usage of statically and dynamically defined secrets, aborting." + ) + return f(self, *args, **kwargs) - def set_credentials(self, relation_id: int, username: str, password: str) -> None: - """Set credentials. + return wrapper - This function writes in the application data bag, therefore, - only the leader unit can call it. - Args: - relation_id: the identifier for a particular relation. - username: user that was created. - password: password of the created user. - """ - self._update_relation_data( - relation_id, - { - "username": username, - "password": password, - }, - ) +class Scope(Enum): + """Peer relations scope.""" - def set_tls(self, relation_id: int, tls: str) -> None: - """Set whether TLS is enabled. + APP = "app" + UNIT = "unit" - Args: - relation_id: the identifier for a particular relation. - tls: whether tls is enabled (True or False). - """ - self._update_relation_data(relation_id, {"tls": tls}) - def set_tls_ca(self, relation_id: int, tls_ca: str) -> None: - """Set the TLS CA in the application relation databag. +################################################################################ +# Secrets internal caching +################################################################################ - Args: - relation_id: the identifier for a particular relation. - tls_ca: TLS certification authority. - """ - self._update_relation_data(relation_id, {"tls_ca": tls_ca}) +class CachedSecret: + """Locally cache a secret. -class DataRequires(Object, ABC): - """Requires-side of the relation.""" + The data structure is precisely re-using/simulating as in the actual Secret Storage + """ def __init__( self, - charm, - relation_name: str, - extra_user_roles: str = None, + model: Model, + component: Union[Application, Unit], + label: str, + secret_uri: Optional[str] = None, + legacy_labels: List[str] = [], ): - """Manager of base client relations.""" - super().__init__(charm, relation_name) - self.charm = charm - self.extra_user_roles = extra_user_roles - self.local_app = self.charm.model.app - self.local_unit = self.charm.unit - self.relation_name = relation_name - self.framework.observe( - self.charm.on[relation_name].relation_joined, self._on_relation_joined_event - ) - self.framework.observe( - self.charm.on[relation_name].relation_changed, self._on_relation_changed_event - ) + self._secret_meta = None + self._secret_content = {} + self._secret_uri = secret_uri + self.label = label + self._model = model + self.component = component + self.legacy_labels = legacy_labels + self.current_label = None + + def add_secret( + self, + content: Dict[str, str], + relation: Optional[Relation] = None, + label: Optional[str] = None, + ) -> Secret: + """Create a new secret.""" + if self._secret_uri: + raise SecretAlreadyExistsError( + "Secret is already defined with uri %s", self._secret_uri + ) - @abstractmethod - def _on_relation_joined_event(self, event: RelationJoinedEvent) -> None: - """Event emitted when the application joins the relation.""" - raise NotImplementedError + label = self.label if not label else label - @abstractmethod - def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: - raise NotImplementedError + secret = self.component.add_secret(content, label=label) + if relation and relation.app != self._model.app: + # If it's not a peer relation, grant is to be applied + secret.grant(relation) + self._secret_uri = secret.id + self._secret_meta = secret + return self._secret_meta - def fetch_relation_data(self) -> dict: - """Retrieves data from relation. + @property + def meta(self) -> Optional[Secret]: + """Getting cached secret meta-information.""" + if not self._secret_meta: + if not (self._secret_uri or self.label): + return + + for label in [self.label] + self.legacy_labels: + try: + self._secret_meta = self._model.get_secret(label=label) + except SecretNotFoundError: + pass + else: + if label != self.label: + self.current_label = label + break + + # If still not found, to be checked by URI, to be labelled with the proposed label + if not self._secret_meta and self._secret_uri: + self._secret_meta = self._model.get_secret(id=self._secret_uri, label=self.label) + return self._secret_meta + + def get_content(self) -> Dict[str, str]: + """Getting cached secret content.""" + if not self._secret_content: + if self.meta: + try: + self._secret_content = self.meta.get_content(refresh=True) + except (ValueError, ModelError) as err: + # https://bugs.launchpad.net/juju/+bug/2042596 + # Only triggered when 'refresh' is set + known_model_errors = [ + "ERROR either URI or label should be used for getting an owned secret but not both", + "ERROR secret owner cannot use --refresh", + ] + if isinstance(err, ModelError) and not any( + msg in str(err) for msg in known_model_errors + ): + raise + # Due to: ValueError: Secret owner cannot use refresh=True + self._secret_content = self.meta.get_content() + return self._secret_content + + def _move_to_new_label_if_needed(self): + """Helper function to re-create the secret with a different label.""" + if not self.current_label or not (self.meta and self._secret_meta): + return - This function can be used to retrieve data from a relation - in the charm code when outside an event callback. - Function cannot be used in `*-relation-broken` events and will raise an exception. + # Create a new secret with the new label + old_meta = self._secret_meta + content = self._secret_meta.get_content() - Returns: - a dict of the values stored in the relation data bag - for all relation instances (indexed by the relation ID). - """ - data = {} - for relation in self.relations: - data[relation.id] = { - key: value for key, value in relation.data[relation.app].items() if key != "data" - } - return data + # I wish we could just check if we are the owners of the secret... + try: + self._secret_meta = self.add_secret(content, label=self.label) + except ModelError as err: + if "this unit is not the leader" not in str(err): + raise + old_meta.remove_all_revisions() + + def set_content(self, content: Dict[str, str]) -> None: + """Setting cached secret content.""" + if not self.meta: + return - def _update_relation_data(self, relation_id: int, data: dict) -> None: - """Updates a set of key-value pairs in the relation. + if content: + self._move_to_new_label_if_needed() + self.meta.set_content(content) + self._secret_content = content + else: + self.meta.remove_all_revisions() - This function writes in the application data bag, therefore, - only the leader unit can call it. + def get_info(self) -> Optional[SecretInfo]: + """Wrapper function to apply the corresponding call on the Secret object within CachedSecret if any.""" + if self.meta: + return self.meta.get_info() - Args: - relation_id: the identifier for a particular relation. - data: dict containing the key-value pairs - that should be updated in the relation. - """ - if self.local_unit.is_leader(): - relation = self.charm.model.get_relation(self.relation_name, relation_id) - relation.data[self.local_app].update(data) + def remove(self) -> None: + """Remove secret.""" + if not self.meta: + raise SecretsUnavailableError("Non-existent secret was attempted to be removed.") + try: + self.meta.remove_all_revisions() + except SecretNotFoundError: + pass + self._secret_content = {} + self._secret_meta = None + self._secret_uri = None + + +class SecretCache: + """A data structure storing CachedSecret objects.""" + + def __init__(self, model: Model, component: Union[Application, Unit]): + self._model = model + self.component = component + self._secrets: Dict[str, CachedSecret] = {} + + def get( + self, label: str, uri: Optional[str] = None, legacy_labels: List[str] = [] + ) -> Optional[CachedSecret]: + """Getting a secret from Juju Secret store or cache.""" + if not self._secrets.get(label): + secret = CachedSecret( + self._model, self.component, label, uri, legacy_labels=legacy_labels + ) + if secret.meta: + self._secrets[label] = secret + return self._secrets.get(label) + + def add(self, label: str, content: Dict[str, str], relation: Relation) -> CachedSecret: + """Adding a secret to Juju Secret.""" + if self._secrets.get(label): + raise SecretAlreadyExistsError(f"Secret {label} already exists") + + secret = CachedSecret(self._model, self.component, label) + secret.add_secret(content, relation) + self._secrets[label] = secret + return self._secrets[label] + + def remove(self, label: str) -> None: + """Remove a secret from the cache.""" + if secret := self.get(label): + try: + secret.remove() + self._secrets.pop(label) + except (SecretsUnavailableError, KeyError): + pass + else: + return + logging.debug("Non-existing Juju Secret was attempted to be removed %s", label) - def _diff(self, event: RelationChangedEvent) -> Diff: - """Retrieves the diff of the data in the relation changed databag. - Args: - event: relation changed event. +################################################################################ +# Relation Data base/abstract ancestors (i.e. parent classes) +################################################################################ - Returns: - a Diff instance containing the added, deleted and changed - keys from the event relation databag. - """ - return diff(event, self.local_unit) + +# Base Data + + +class DataDict(UserDict): + """Python Standard Library 'dict' - like representation of Relation Data.""" + + def __init__(self, relation_data: "Data", relation_id: int): + self.relation_data = relation_data + self.relation_id = relation_id + + @property + def data(self) -> Dict[str, str]: + """Return the full content of the Abstract Relation Data dictionary.""" + result = self.relation_data.fetch_my_relation_data([self.relation_id]) + try: + result_remote = self.relation_data.fetch_relation_data([self.relation_id]) + except NotImplementedError: + result_remote = {self.relation_id: {}} + if result: + result_remote[self.relation_id].update(result[self.relation_id]) + return result_remote.get(self.relation_id, {}) + + def __setitem__(self, key: str, item: str) -> None: + """Set an item of the Abstract Relation Data dictionary.""" + self.relation_data.update_relation_data(self.relation_id, {key: item}) + + def __getitem__(self, key: str) -> str: + """Get an item of the Abstract Relation Data dictionary.""" + result = None + + # Avoiding "leader_only" error when cross-charm non-leader unit, not to report useless error + if ( + not hasattr(self.relation_data.fetch_my_relation_field, "leader_only") + or self.relation_data.component != self.relation_data.local_app + or self.relation_data.local_unit.is_leader() + ): + result = self.relation_data.fetch_my_relation_field(self.relation_id, key) + + if not result: + try: + result = self.relation_data.fetch_relation_field(self.relation_id, key) + except NotImplementedError: + pass + + if not result: + raise KeyError + return result + + def __eq__(self, d: dict) -> bool: + """Equality.""" + return self.data == d + + def __repr__(self) -> str: + """String representation Abstract Relation Data dictionary.""" + return repr(self.data) + + def __len__(self) -> int: + """Length of the Abstract Relation Data dictionary.""" + return len(self.data) + + def __delitem__(self, key: str) -> None: + """Delete an item of the Abstract Relation Data dictionary.""" + self.relation_data.delete_relation_data(self.relation_id, [key]) + + def has_key(self, key: str) -> bool: + """Does the key exist in the Abstract Relation Data dictionary?""" + return key in self.data + + def update(self, items: Dict[str, str]): + """Update the Abstract Relation Data dictionary.""" + self.relation_data.update_relation_data(self.relation_id, items) + + def keys(self) -> KeysView[str]: + """Keys of the Abstract Relation Data dictionary.""" + return self.data.keys() + + def values(self) -> ValuesView[str]: + """Values of the Abstract Relation Data dictionary.""" + return self.data.values() + + def items(self) -> ItemsView[str, str]: + """Items of the Abstract Relation Data dictionary.""" + return self.data.items() + + def pop(self, item: str) -> str: + """Pop an item of the Abstract Relation Data dictionary.""" + result = self.relation_data.fetch_my_relation_field(self.relation_id, item) + if not result: + raise KeyError(f"Item {item} doesn't exist.") + self.relation_data.delete_relation_data(self.relation_id, [item]) + return result + + def __contains__(self, item: str) -> bool: + """Does the Abstract Relation Data dictionary contain item?""" + return item in self.data.values() + + def __iter__(self): + """Iterate through the Abstract Relation Data dictionary.""" + return iter(self.data) + + def get(self, key: str, default: Optional[str] = None) -> Optional[str]: + """Safely get an item of the Abstract Relation Data dictionary.""" + try: + if result := self[key]: + return result + except KeyError: + return default + + +class Data(ABC): + """Base relation data mainpulation (abstract) class.""" + + SCOPE = Scope.APP + + # Local map to associate mappings with secrets potentially as a group + SECRET_LABEL_MAP = { + "username": SECRET_GROUPS.USER, + "password": SECRET_GROUPS.USER, + "uris": SECRET_GROUPS.USER, + "tls": SECRET_GROUPS.TLS, + "tls-ca": SECRET_GROUPS.TLS, + } + + def __init__( + self, + model: Model, + relation_name: str, + ) -> None: + self._model = model + self.local_app = self._model.app + self.local_unit = self._model.unit + self.relation_name = relation_name + self._jujuversion = None + self.component = self.local_app if self.SCOPE == Scope.APP else self.local_unit + self.secrets = SecretCache(self._model, self.component) + self.data_component = None @property def relations(self) -> List[Relation]: """The list of Relation instances associated with this relation_name.""" return [ relation - for relation in self.charm.model.relations[self.relation_name] + for relation in self._model.relations[self.relation_name] if self._is_relation_active(relation) ] + @property + def secrets_enabled(self): + """Is this Juju version allowing for Secrets usage?""" + if not self._jujuversion: + self._jujuversion = JujuVersion.from_environ() + return self._jujuversion.has_secrets + + @property + def secret_label_map(self): + """Exposing secret-label map via a property -- could be overridden in descendants!""" + return self.SECRET_LABEL_MAP + + # Mandatory overrides for internal/helper methods + + @abstractmethod + def _get_relation_secret( + self, relation_id: int, group_mapping: SecretGroup, relation_name: Optional[str] = None + ) -> Optional[CachedSecret]: + """Retrieve a Juju Secret that's been stored in the relation databag.""" + raise NotImplementedError + + @abstractmethod + def _fetch_specific_relation_data( + self, relation: Relation, fields: Optional[List[str]] + ) -> Dict[str, str]: + """Fetch data available (directily or indirectly -- i.e. secrets) from the relation.""" + raise NotImplementedError + + @abstractmethod + def _fetch_my_specific_relation_data( + self, relation: Relation, fields: Optional[List[str]] + ) -> Dict[str, str]: + """Fetch data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" + raise NotImplementedError + + @abstractmethod + def _update_relation_data(self, relation: Relation, data: Dict[str, str]) -> None: + """Update data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" + raise NotImplementedError + + @abstractmethod + def _delete_relation_data(self, relation: Relation, fields: List[str]) -> None: + """Delete data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" + raise NotImplementedError + + # Internal helper methods + @staticmethod def _is_relation_active(relation: Relation): + """Whether the relation is active based on contained data.""" try: _ = repr(relation.data) return True - except RuntimeError: + except (RuntimeError, ModelError): return False @staticmethod - def _is_resource_created_for_relation(relation: Relation): - return ( - "username" in relation.data[relation.app] and "password" in relation.data[relation.app] - ) + def _is_secret_field(field: str) -> bool: + """Is the field in question a secret reference (URI) field or not?""" + return field.startswith(PROV_SECRET_PREFIX) - def is_resource_created(self, relation_id: Optional[int] = None) -> bool: - """Check if the resource has been created. + @staticmethod + def _generate_secret_label( + relation_name: str, relation_id: int, group_mapping: SecretGroup + ) -> str: + """Generate unique group_mappings for secrets within a relation context.""" + return f"{relation_name}.{relation_id}.{group_mapping}.secret" - This function can be used to check if the Provider answered with data in the charm code - when outside an event callback. + def _generate_secret_field_name(self, group_mapping: SecretGroup) -> str: + """Generate unique group_mappings for secrets within a relation context.""" + return f"{PROV_SECRET_PREFIX}{group_mapping}" - Args: - relation_id (int, optional): When provided the check is done only for the relation id - provided, otherwise the check is done for all relations + def _relation_from_secret_label(self, secret_label: str) -> Optional[Relation]: + """Retrieve the relation that belongs to a secret label.""" + contents = secret_label.split(".") - Returns: - True or False + if not (contents and len(contents) >= 3): + return - Raises: - IndexError: If relation_id is provided but that relation does not exist - """ - if relation_id is not None: - try: - relation = [relation for relation in self.relations if relation.id == relation_id][ - 0 - ] - return self._is_resource_created_for_relation(relation) - except IndexError: - raise IndexError(f"relation id {relation_id} cannot be accessed") - else: - return ( - all( - [ - self._is_resource_created_for_relation(relation) - for relation in self.relations - ] - ) - if self.relations - else False - ) + contents.pop() # ".secret" at the end + contents.pop() # Group mapping + relation_id = contents.pop() + try: + relation_id = int(relation_id) + except ValueError: + return + # In case '.' character appeared in relation name + relation_name = ".".join(contents) -# General events + try: + return self.get_relation(relation_name, relation_id) + except ModelError: + return + def _group_secret_fields(self, secret_fields: List[str]) -> Dict[SecretGroup, List[str]]: + """Helper function to arrange secret mappings under their group. -class ExtraRoleEvent(RelationEvent): - """Base class for data events.""" + NOTE: All unrecognized items end up in the 'extra' secret bucket. + Make sure only secret fields are passed! + """ + secret_fieldnames_grouped = {} + for key in secret_fields: + if group := self.secret_label_map.get(key): + secret_fieldnames_grouped.setdefault(group, []).append(key) + else: + secret_fieldnames_grouped.setdefault(SECRET_GROUPS.EXTRA, []).append(key) + return secret_fieldnames_grouped + + def _get_group_secret_contents( + self, + relation: Relation, + group: SecretGroup, + secret_fields: Union[Set[str], List[str]] = [], + ) -> Dict[str, str]: + """Helper function to retrieve collective, requested contents of a secret.""" + if (secret := self._get_relation_secret(relation.id, group)) and ( + secret_data := secret.get_content() + ): + return { + k: v for k, v in secret_data.items() if not secret_fields or k in secret_fields + } + return {} + + def _content_for_secret_group( + self, content: Dict[str, str], secret_fields: Set[str], group_mapping: SecretGroup + ) -> Dict[str, str]: + """Select : pairs from input, that belong to this particular Secret group.""" + if group_mapping == SECRET_GROUPS.EXTRA: + return { + k: v + for k, v in content.items() + if k in secret_fields and k not in self.secret_label_map.keys() + } - @property - def extra_user_roles(self) -> Optional[str]: - """Returns the extra user roles that were requested.""" - return self.relation.data[self.relation.app].get("extra-user-roles") + return { + k: v + for k, v in content.items() + if k in secret_fields and self.secret_label_map.get(k) == group_mapping + } + + @juju_secrets_only + def _get_relation_secret_data( + self, relation_id: int, group_mapping: SecretGroup, relation_name: Optional[str] = None + ) -> Optional[Dict[str, str]]: + """Retrieve contents of a Juju Secret that's been stored in the relation databag.""" + secret = self._get_relation_secret(relation_id, group_mapping, relation_name) + if secret: + return secret.get_content() + + # Core operations on Relation Fields manipulations (regardless whether the field is in the databag or in a secret) + # Internal functions to be called directly from transparent public interface functions (+closely related helpers) + + def _process_secret_fields( + self, + relation: Relation, + req_secret_fields: Optional[List[str]], + impacted_rel_fields: List[str], + operation: Callable, + *args, + **kwargs, + ) -> Tuple[Dict[str, str], Set[str]]: + """Isolate target secret fields of manipulation, and execute requested operation by Secret Group.""" + result = {} + + # If the relation started on a databag, we just stay on the databag + # (Rolling upgrades may result in a relation starting on databag, getting secrets enabled on-the-fly) + # self.local_app is sufficient to check (ignored if Requires, never has secrets -- works if Provider) + fallback_to_databag = ( + req_secret_fields + and (self.local_unit == self._model.unit and self.local_unit.is_leader()) + and set(req_secret_fields) & set(relation.data[self.component]) + ) + + normal_fields = set(impacted_rel_fields) + if req_secret_fields and self.secrets_enabled and not fallback_to_databag: + normal_fields = normal_fields - set(req_secret_fields) + secret_fields = set(impacted_rel_fields) - set(normal_fields) + + secret_fieldnames_grouped = self._group_secret_fields(list(secret_fields)) + + for group in secret_fieldnames_grouped: + # operation() should return nothing when all goes well + if group_result := operation(relation, group, secret_fields, *args, **kwargs): + # If "meaningful" data was returned, we take it. (Some 'operation'-s only return success/failure.) + if isinstance(group_result, dict): + result.update(group_result) + else: + # If it wasn't found as a secret, let's give it a 2nd chance as "normal" field + # Needed when Juju3 Requires meets Juju2 Provider + normal_fields |= set(secret_fieldnames_grouped[group]) + return (result, normal_fields) + + def _fetch_relation_data_without_secrets( + self, component: Union[Application, Unit], relation: Relation, fields: Optional[List[str]] + ) -> Dict[str, str]: + """Fetching databag contents when no secrets are involved. + + Since the Provider's databag is the only one holding secrest, we can apply + a simplified workflow to read the Require's side's databag. + This is used typically when the Provider side wants to read the Requires side's data, + or when the Requires side may want to read its own data. + """ + if component not in relation.data or not relation.data[component]: + return {} + + if fields: + return { + k: relation.data[component][k] for k in fields if k in relation.data[component] + } + else: + return dict(relation.data[component]) + + def _fetch_relation_data_with_secrets( + self, + component: Union[Application, Unit], + req_secret_fields: Optional[List[str]], + relation: Relation, + fields: Optional[List[str]] = None, + ) -> Dict[str, str]: + """Fetching databag contents when secrets may be involved. + + This function has internal logic to resolve if a requested field may be "hidden" + within a Relation Secret, or directly available as a databag field. Typically + used to read the Provider side's databag (eigher by the Requires side, or by + Provider side itself). + """ + result = {} + normal_fields = [] + + if not fields: + if component not in relation.data: + return {} + + all_fields = list(relation.data[component].keys()) + normal_fields = [field for field in all_fields if not self._is_secret_field(field)] + fields = normal_fields + req_secret_fields if req_secret_fields else normal_fields + + if fields: + result, normal_fields = self._process_secret_fields( + relation, req_secret_fields, fields, self._get_group_secret_contents + ) + + # Processing "normal" fields. May include leftover from what we couldn't retrieve as a secret. + # (Typically when Juju3 Requires meets Juju2 Provider) + if normal_fields: + result.update( + self._fetch_relation_data_without_secrets(component, relation, list(normal_fields)) + ) + return result + + def _update_relation_data_without_secrets( + self, component: Union[Application, Unit], relation: Relation, data: Dict[str, str] + ) -> None: + """Updating databag contents when no secrets are involved.""" + if component not in relation.data or relation.data[component] is None: + return + + if relation: + relation.data[component].update(data) + + def _delete_relation_data_without_secrets( + self, component: Union[Application, Unit], relation: Relation, fields: List[str] + ) -> None: + """Remove databag fields 'fields' from Relation.""" + if component not in relation.data or relation.data[component] is None: + return + + for field in fields: + try: + relation.data[component].pop(field) + except KeyError: + logger.debug( + "Non-existing field '%s' was attempted to be removed from the databag (relation ID: %s)", + str(field), + str(relation.id), + ) + pass + + # Public interface methods + # Handling Relation Fields seamlessly, regardless if in databag or a Juju Secret + + def as_dict(self, relation_id: int) -> UserDict: + """Dict behavior representation of the Abstract Data.""" + return DataDict(self, relation_id) + + def get_relation(self, relation_name, relation_id) -> Relation: + """Safe way of retrieving a relation.""" + relation = self._model.get_relation(relation_name, relation_id) + + if not relation: + raise DataInterfacesError( + "Relation %s %s couldn't be retrieved", relation_name, relation_id + ) + + return relation + + def fetch_relation_data( + self, + relation_ids: Optional[List[int]] = None, + fields: Optional[List[str]] = None, + relation_name: Optional[str] = None, + ) -> Dict[int, Dict[str, str]]: + """Retrieves data from relation. + + This function can be used to retrieve data from a relation + in the charm code when outside an event callback. + Function cannot be used in `*-relation-broken` events and will raise an exception. + + Returns: + a dict of the values stored in the relation data bag + for all relation instances (indexed by the relation ID). + """ + if not relation_name: + relation_name = self.relation_name + + relations = [] + if relation_ids: + relations = [ + self.get_relation(relation_name, relation_id) for relation_id in relation_ids + ] + else: + relations = self.relations + + data = {} + for relation in relations: + if not relation_ids or (relation_ids and relation.id in relation_ids): + data[relation.id] = self._fetch_specific_relation_data(relation, fields) + return data + + def fetch_relation_field( + self, relation_id: int, field: str, relation_name: Optional[str] = None + ) -> Optional[str]: + """Get a single field from the relation data.""" + return ( + self.fetch_relation_data([relation_id], [field], relation_name) + .get(relation_id, {}) + .get(field) + ) + + def fetch_my_relation_data( + self, + relation_ids: Optional[List[int]] = None, + fields: Optional[List[str]] = None, + relation_name: Optional[str] = None, + ) -> Optional[Dict[int, Dict[str, str]]]: + """Fetch data of the 'owner' (or 'this app') side of the relation. + + NOTE: Since only the leader can read the relation's 'this_app'-side + Application databag, the functionality is limited to leaders + """ + if not relation_name: + relation_name = self.relation_name + + relations = [] + if relation_ids: + relations = [ + self.get_relation(relation_name, relation_id) for relation_id in relation_ids + ] + else: + relations = self.relations + + data = {} + for relation in relations: + if not relation_ids or relation.id in relation_ids: + data[relation.id] = self._fetch_my_specific_relation_data(relation, fields) + return data + + def fetch_my_relation_field( + self, relation_id: int, field: str, relation_name: Optional[str] = None + ) -> Optional[str]: + """Get a single field from the relation data -- owner side. + + NOTE: Since only the leader can read the relation's 'this_app'-side + Application databag, the functionality is limited to leaders + """ + if relation_data := self.fetch_my_relation_data([relation_id], [field], relation_name): + return relation_data.get(relation_id, {}).get(field) + + @leader_only + def update_relation_data(self, relation_id: int, data: dict) -> None: + """Update the data within the relation.""" + relation_name = self.relation_name + relation = self.get_relation(relation_name, relation_id) + return self._update_relation_data(relation, data) + + @leader_only + def delete_relation_data(self, relation_id: int, fields: List[str]) -> None: + """Remove field from the relation.""" + relation_name = self.relation_name + relation = self.get_relation(relation_name, relation_id) + return self._delete_relation_data(relation, fields) + + +class EventHandlers(Object): + """Requires-side of the relation.""" + + def __init__(self, charm: CharmBase, relation_data: Data, unique_key: str = ""): + """Manager of base client relations.""" + if not unique_key: + unique_key = relation_data.relation_name + super().__init__(charm, unique_key) + + self.charm = charm + self.relation_data = relation_data + + self.framework.observe( + charm.on[self.relation_data.relation_name].relation_changed, + self._on_relation_changed_event, + ) + + def _diff(self, event: RelationChangedEvent) -> Diff: + """Retrieves the diff of the data in the relation changed databag. + + Args: + event: relation changed event. + + Returns: + a Diff instance containing the added, deleted and changed + keys from the event relation databag. + """ + return diff(event, self.relation_data.data_component) + + @abstractmethod + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the relation data has changed.""" + raise NotImplementedError + + +# Base ProviderData and RequiresData + + +class ProviderData(Data): + """Base provides-side of the data products relation.""" + + def __init__( + self, + model: Model, + relation_name: str, + ) -> None: + super().__init__(model, relation_name) + self.data_component = self.local_app + + # Private methods handling secrets + + @juju_secrets_only + def _add_relation_secret( + self, + relation: Relation, + group_mapping: SecretGroup, + secret_fields: Set[str], + data: Dict[str, str], + uri_to_databag=True, + ) -> bool: + """Add a new Juju Secret that will be registered in the relation databag.""" + secret_field = self._generate_secret_field_name(group_mapping) + if uri_to_databag and relation.data[self.component].get(secret_field): + logging.error("Secret for relation %s already exists, not adding again", relation.id) + return False + + content = self._content_for_secret_group(data, secret_fields, group_mapping) + + label = self._generate_secret_label(self.relation_name, relation.id, group_mapping) + secret = self.secrets.add(label, content, relation) + + # According to lint we may not have a Secret ID + if uri_to_databag and secret.meta and secret.meta.id: + relation.data[self.component][secret_field] = secret.meta.id + + # Return the content that was added + return True + + @juju_secrets_only + def _update_relation_secret( + self, + relation: Relation, + group_mapping: SecretGroup, + secret_fields: Set[str], + data: Dict[str, str], + ) -> bool: + """Update the contents of an existing Juju Secret, referred in the relation databag.""" + secret = self._get_relation_secret(relation.id, group_mapping) + + if not secret: + logging.error("Can't update secret for relation %s", relation.id) + return False + + content = self._content_for_secret_group(data, secret_fields, group_mapping) + + old_content = secret.get_content() + full_content = copy.deepcopy(old_content) + full_content.update(content) + secret.set_content(full_content) + + # Return True on success + return True + + def _add_or_update_relation_secrets( + self, + relation: Relation, + group: SecretGroup, + secret_fields: Set[str], + data: Dict[str, str], + uri_to_databag=True, + ) -> bool: + """Update contents for Secret group. If the Secret doesn't exist, create it.""" + if self._get_relation_secret(relation.id, group): + return self._update_relation_secret(relation, group, secret_fields, data) + else: + return self._add_relation_secret(relation, group, secret_fields, data, uri_to_databag) + + @juju_secrets_only + def _delete_relation_secret( + self, relation: Relation, group: SecretGroup, secret_fields: List[str], fields: List[str] + ) -> bool: + """Update the contents of an existing Juju Secret, referred in the relation databag.""" + secret = self._get_relation_secret(relation.id, group) + + if not secret: + logging.error("Can't delete secret for relation %s", str(relation.id)) + return False + + old_content = secret.get_content() + new_content = copy.deepcopy(old_content) + for field in fields: + try: + new_content.pop(field) + except KeyError: + logging.debug( + "Non-existing secret was attempted to be removed %s, %s", + str(relation.id), + str(field), + ) + return False + + # Remove secret from the relation if it's fully gone + if not new_content: + field = self._generate_secret_field_name(group) + try: + relation.data[self.component].pop(field) + except KeyError: + pass + label = self._generate_secret_label(self.relation_name, relation.id, group) + self.secrets.remove(label) + else: + secret.set_content(new_content) + + # Return the content that was removed + return True + + # Mandatory internal overrides + + @juju_secrets_only + def _get_relation_secret( + self, relation_id: int, group_mapping: SecretGroup, relation_name: Optional[str] = None + ) -> Optional[CachedSecret]: + """Retrieve a Juju Secret that's been stored in the relation databag.""" + if not relation_name: + relation_name = self.relation_name + + label = self._generate_secret_label(relation_name, relation_id, group_mapping) + if secret := self.secrets.get(label): + return secret + + relation = self._model.get_relation(relation_name, relation_id) + if not relation: + return + + secret_field = self._generate_secret_field_name(group_mapping) + if secret_uri := relation.data[self.local_app].get(secret_field): + return self.secrets.get(label, secret_uri) + + def _fetch_specific_relation_data( + self, relation: Relation, fields: Optional[List[str]] + ) -> Dict[str, str]: + """Fetching relation data for Provider. + + NOTE: Since all secret fields are in the Provider side of the databag, we don't need to worry about that + """ + if not relation.app: + return {} + + return self._fetch_relation_data_without_secrets(relation.app, relation, fields) + + def _fetch_my_specific_relation_data( + self, relation: Relation, fields: Optional[List[str]] + ) -> dict: + """Fetching our own relation data.""" + secret_fields = None + if relation.app: + secret_fields = get_encoded_list(relation, relation.app, REQ_SECRET_FIELDS) + + return self._fetch_relation_data_with_secrets( + self.local_app, + secret_fields, + relation, + fields, + ) + + def _update_relation_data(self, relation: Relation, data: Dict[str, str]) -> None: + """Set values for fields not caring whether it's a secret or not.""" + req_secret_fields = [] + if relation.app: + req_secret_fields = get_encoded_list(relation, relation.app, REQ_SECRET_FIELDS) + + _, normal_fields = self._process_secret_fields( + relation, + req_secret_fields, + list(data), + self._add_or_update_relation_secrets, + data=data, + ) + + normal_content = {k: v for k, v in data.items() if k in normal_fields} + self._update_relation_data_without_secrets(self.local_app, relation, normal_content) + + def _delete_relation_data(self, relation: Relation, fields: List[str]) -> None: + """Delete fields from the Relation not caring whether it's a secret or not.""" + req_secret_fields = [] + if relation.app: + req_secret_fields = get_encoded_list(relation, relation.app, REQ_SECRET_FIELDS) + + _, normal_fields = self._process_secret_fields( + relation, req_secret_fields, fields, self._delete_relation_secret, fields=fields + ) + self._delete_relation_data_without_secrets(self.local_app, relation, list(normal_fields)) + + # Public methods - "native" + + def set_credentials(self, relation_id: int, username: str, password: str) -> None: + """Set credentials. + + This function writes in the application data bag, therefore, + only the leader unit can call it. + + Args: + relation_id: the identifier for a particular relation. + username: user that was created. + password: password of the created user. + """ + self.update_relation_data(relation_id, {"username": username, "password": password}) + + def set_tls(self, relation_id: int, tls: str) -> None: + """Set whether TLS is enabled. + + Args: + relation_id: the identifier for a particular relation. + tls: whether tls is enabled (True or False). + """ + self.update_relation_data(relation_id, {"tls": tls}) + + def set_tls_ca(self, relation_id: int, tls_ca: str) -> None: + """Set the TLS CA in the application relation databag. + + Args: + relation_id: the identifier for a particular relation. + tls_ca: TLS certification authority. + """ + self.update_relation_data(relation_id, {"tls-ca": tls_ca}) + + # Public functions -- inherited + + fetch_my_relation_data = leader_only(Data.fetch_my_relation_data) + fetch_my_relation_field = leader_only(Data.fetch_my_relation_field) + + +class RequirerData(Data): + """Requirer-side of the relation.""" + + SECRET_FIELDS = ["username", "password", "tls", "tls-ca", "uris"] + + def __init__( + self, + model, + relation_name: str, + extra_user_roles: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], + ): + """Manager of base client relations.""" + super().__init__(model, relation_name) + self.extra_user_roles = extra_user_roles + self._secret_fields = list(self.SECRET_FIELDS) + if additional_secret_fields: + self._secret_fields += additional_secret_fields + self.data_component = self.local_unit + + @property + def secret_fields(self) -> Optional[List[str]]: + """Local access to secrets field, in case they are being used.""" + if self.secrets_enabled: + return self._secret_fields + + # Internal helper functions + + def _register_secret_to_relation( + self, relation_name: str, relation_id: int, secret_id: str, group: SecretGroup + ): + """Fetch secrets and apply local label on them. + + [MAGIC HERE] + If we fetch a secret using get_secret(id=, label=), + then will be "stuck" on the Secret object, whenever it may + appear (i.e. as an event attribute, or fetched manually) on future occasions. + + This will allow us to uniquely identify the secret on Provider side (typically on + 'secret-changed' events), and map it to the corresponding relation. + """ + label = self._generate_secret_label(relation_name, relation_id, group) + + # Fetchin the Secret's meta information ensuring that it's locally getting registered with + CachedSecret(self._model, self.component, label, secret_id).meta + + def _register_secrets_to_relation(self, relation: Relation, params_name_list: List[str]): + """Make sure that secrets of the provided list are locally 'registered' from the databag. + + More on 'locally registered' magic is described in _register_secret_to_relation() method + """ + if not relation.app: + return + + for group in SECRET_GROUPS.groups(): + secret_field = self._generate_secret_field_name(group) + if secret_field in params_name_list: + if secret_uri := relation.data[relation.app].get(secret_field): + self._register_secret_to_relation( + relation.name, relation.id, secret_uri, group + ) + + def _is_resource_created_for_relation(self, relation: Relation) -> bool: + if not relation.app: + return False + + data = self.fetch_relation_data([relation.id], ["username", "password"]).get( + relation.id, {} + ) + return bool(data.get("username")) and bool(data.get("password")) + + def is_resource_created(self, relation_id: Optional[int] = None) -> bool: + """Check if the resource has been created. + + This function can be used to check if the Provider answered with data in the charm code + when outside an event callback. + + Args: + relation_id (int, optional): When provided the check is done only for the relation id + provided, otherwise the check is done for all relations + + Returns: + True or False + + Raises: + IndexError: If relation_id is provided but that relation does not exist + """ + if relation_id is not None: + try: + relation = [relation for relation in self.relations if relation.id == relation_id][ + 0 + ] + return self._is_resource_created_for_relation(relation) + except IndexError: + raise IndexError(f"relation id {relation_id} cannot be accessed") + else: + return ( + all( + self._is_resource_created_for_relation(relation) for relation in self.relations + ) + if self.relations + else False + ) + + # Mandatory internal overrides + + @juju_secrets_only + def _get_relation_secret( + self, relation_id: int, group: SecretGroup, relation_name: Optional[str] = None + ) -> Optional[CachedSecret]: + """Retrieve a Juju Secret that's been stored in the relation databag.""" + if not relation_name: + relation_name = self.relation_name + + label = self._generate_secret_label(relation_name, relation_id, group) + return self.secrets.get(label) + + def _fetch_specific_relation_data( + self, relation, fields: Optional[List[str]] = None + ) -> Dict[str, str]: + """Fetching Requirer data -- that may include secrets.""" + if not relation.app: + return {} + return self._fetch_relation_data_with_secrets( + relation.app, self.secret_fields, relation, fields + ) + + def _fetch_my_specific_relation_data(self, relation, fields: Optional[List[str]]) -> dict: + """Fetching our own relation data.""" + return self._fetch_relation_data_without_secrets(self.local_app, relation, fields) + + def _update_relation_data(self, relation: Relation, data: dict) -> None: + """Updates a set of key-value pairs in the relation. + + This function writes in the application data bag, therefore, + only the leader unit can call it. + + Args: + relation: the particular relation. + data: dict containing the key-value pairs + that should be updated in the relation. + """ + return self._update_relation_data_without_secrets(self.local_app, relation, data) + + def _delete_relation_data(self, relation: Relation, fields: List[str]) -> None: + """Deletes a set of fields from the relation. + + This function writes in the application data bag, therefore, + only the leader unit can call it. + + Args: + relation: the particular relation. + fields: list containing the field names that should be removed from the relation. + """ + return self._delete_relation_data_without_secrets(self.local_app, relation, fields) + + # Public functions -- inherited + + fetch_my_relation_data = leader_only(Data.fetch_my_relation_data) + fetch_my_relation_field = leader_only(Data.fetch_my_relation_field) + + +class RequirerEventHandlers(EventHandlers): + """Requires-side of the relation.""" + + def __init__(self, charm: CharmBase, relation_data: RequirerData, unique_key: str = ""): + """Manager of base client relations.""" + super().__init__(charm, relation_data, unique_key) + + self.framework.observe( + self.charm.on[relation_data.relation_name].relation_created, + self._on_relation_created_event, + ) + self.framework.observe( + charm.on.secret_changed, + self._on_secret_changed_event, + ) + + # Event handlers + + def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: + """Event emitted when the relation is created.""" + if not self.relation_data.local_unit.is_leader(): + return + + if self.relation_data.secret_fields: # pyright: ignore [reportAttributeAccessIssue] + set_encoded_field( + event.relation, + self.relation_data.component, + REQ_SECRET_FIELDS, + self.relation_data.secret_fields, # pyright: ignore [reportAttributeAccessIssue] + ) + + @abstractmethod + def _on_secret_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the relation data has changed.""" + raise NotImplementedError + + +################################################################################ +# Peer Relation Data +################################################################################ + + +class DataPeerData(RequirerData, ProviderData): + """Represents peer relations data.""" + + SECRET_FIELDS = [] + SECRET_FIELD_NAME = "internal_secret" + SECRET_LABEL_MAP = {} + + def __init__( + self, + model, + relation_name: str, + extra_user_roles: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], + additional_secret_group_mapping: Dict[str, str] = {}, + secret_field_name: Optional[str] = None, + deleted_label: Optional[str] = None, + ): + """Manager of base client relations.""" + RequirerData.__init__( + self, + model, + relation_name, + extra_user_roles, + additional_secret_fields, + ) + self.secret_field_name = secret_field_name if secret_field_name else self.SECRET_FIELD_NAME + self.deleted_label = deleted_label + self._secret_label_map = {} + # Secrets that are being dynamically added within the scope of this event handler run + self._new_secrets = [] + self._additional_secret_group_mapping = additional_secret_group_mapping + + for group, fields in additional_secret_group_mapping.items(): + if group not in SECRET_GROUPS.groups(): + setattr(SECRET_GROUPS, group, group) + for field in fields: + secret_group = SECRET_GROUPS.get_group(group) + internal_field = self._field_to_internal_name(field, secret_group) + self._secret_label_map.setdefault(group, []).append(internal_field) + self._secret_fields.append(internal_field) + + @property + def scope(self) -> Optional[Scope]: + """Turn component information into Scope.""" + if isinstance(self.component, Application): + return Scope.APP + if isinstance(self.component, Unit): + return Scope.UNIT + + @property + def secret_label_map(self) -> Dict[str, str]: + """Property storing secret mappings.""" + return self._secret_label_map + + @property + def static_secret_fields(self) -> List[str]: + """Re-definition of the property in a way that dynamically extended list is retrieved.""" + return self._secret_fields + + @property + def secret_fields(self) -> List[str]: + """Re-definition of the property in a way that dynamically extended list is retrieved.""" + return ( + self.static_secret_fields if self.static_secret_fields else self.current_secret_fields + ) + + @property + def current_secret_fields(self) -> List[str]: + """Helper method to get all currently existing secret fields (added statically or dynamically).""" + if not self.secrets_enabled: + return [] + + if len(self._model.relations[self.relation_name]) > 1: + raise ValueError(f"More than one peer relation on {self.relation_name}") + + relation = self._model.relations[self.relation_name][0] + fields = [] + + ignores = [SECRET_GROUPS.get_group("user"), SECRET_GROUPS.get_group("tls")] + for group in SECRET_GROUPS.groups(): + if group in ignores: + continue + if content := self._get_group_secret_contents(relation, group): + fields += list(content.keys()) + return list(set(fields) | set(self._new_secrets)) + + @dynamic_secrets_only + def set_secret( + self, + relation_id: int, + field: str, + value: str, + group_mapping: Optional[SecretGroup] = None, + ) -> None: + """Public interface method to add a Relation Data field specifically as a Juju Secret. + + Args: + relation_id: ID of the relation + field: The secret field that is to be added + value: The string value of the secret + group_mapping: The name of the "secret group", in case the field is to be added to an existing secret + """ + full_field = self._field_to_internal_name(field, group_mapping) + if self.secrets_enabled and full_field not in self.current_secret_fields: + self._new_secrets.append(full_field) + if self._no_group_with_databag(field, full_field): + self.update_relation_data(relation_id, {full_field: value}) + + # Unlike for set_secret(), there's no harm using this operation with static secrets + # The restricion is only added to keep the concept clear + @dynamic_secrets_only + def get_secret( + self, + relation_id: int, + field: str, + group_mapping: Optional[SecretGroup] = None, + ) -> Optional[str]: + """Public interface method to fetch secrets only.""" + full_field = self._field_to_internal_name(field, group_mapping) + if ( + self.secrets_enabled + and full_field not in self.current_secret_fields + and field not in self.current_secret_fields + ): + return + if self._no_group_with_databag(field, full_field): + return self.fetch_my_relation_field(relation_id, full_field) + + @dynamic_secrets_only + def delete_secret( + self, + relation_id: int, + field: str, + group_mapping: Optional[SecretGroup] = None, + ) -> Optional[str]: + """Public interface method to delete secrets only.""" + full_field = self._field_to_internal_name(field, group_mapping) + if self.secrets_enabled and full_field not in self.current_secret_fields: + logger.warning(f"Secret {field} from group {group_mapping} was not found") + return + if self._no_group_with_databag(field, full_field): + self.delete_relation_data(relation_id, [full_field]) + + # Helpers + + @staticmethod + def _field_to_internal_name(field: str, group: Optional[SecretGroup]) -> str: + if not group or group == SECRET_GROUPS.EXTRA: + return field + return f"{field}{GROUP_SEPARATOR}{group}" + + @staticmethod + def _internal_name_to_field(name: str) -> Tuple[str, SecretGroup]: + parts = name.split(GROUP_SEPARATOR) + if not len(parts) > 1: + return (parts[0], SECRET_GROUPS.EXTRA) + secret_group = SECRET_GROUPS.get_group(parts[1]) + if not secret_group: + raise ValueError(f"Invalid secret field {name}") + return (parts[0], secret_group) + + def _group_secret_fields(self, secret_fields: List[str]) -> Dict[SecretGroup, List[str]]: + """Helper function to arrange secret mappings under their group. + + NOTE: All unrecognized items end up in the 'extra' secret bucket. + Make sure only secret fields are passed! + """ + secret_fieldnames_grouped = {} + for key in secret_fields: + field, group = self._internal_name_to_field(key) + secret_fieldnames_grouped.setdefault(group, []).append(field) + return secret_fieldnames_grouped + + def _content_for_secret_group( + self, content: Dict[str, str], secret_fields: Set[str], group_mapping: SecretGroup + ) -> Dict[str, str]: + """Select : pairs from input, that belong to this particular Secret group.""" + if group_mapping == SECRET_GROUPS.EXTRA: + return {k: v for k, v in content.items() if k in self.secret_fields} + return { + self._internal_name_to_field(k)[0]: v + for k, v in content.items() + if k in self.secret_fields + } + + # Backwards compatibility + + def _check_deleted_label(self, relation, fields) -> None: + """Helper function for legacy behavior.""" + current_data = self.fetch_my_relation_data([relation.id], fields) + if current_data is not None: + # Check if the secret we wanna delete actually exists + # Given the "deleted label", here we can't rely on the default mechanism (i.e. 'key not found') + if non_existent := (set(fields) & set(self.secret_fields)) - set( + current_data.get(relation.id, []) + ): + logger.debug( + "Non-existing secret %s was attempted to be removed.", + ", ".join(non_existent), + ) + + def _remove_secret_from_databag(self, relation, fields: List[str]) -> None: + """For Rolling Upgrades -- when moving from databag to secrets usage. + + Practically what happens here is to remove stuff from the databag that is + to be stored in secrets. + """ + if not self.secret_fields: + return + + secret_fields_passed = set(self.secret_fields) & set(fields) + for field in secret_fields_passed: + if self._fetch_relation_data_without_secrets(self.component, relation, [field]): + self._delete_relation_data_without_secrets(self.component, relation, [field]) + + def _remove_secret_field_name_from_databag(self, relation) -> None: + """Making sure that the old databag URI is gone. + + This action should not be executed more than once. + """ + # Nothing to do if 'internal-secret' is not in the databag + if not (relation.data[self.component].get(self._generate_secret_field_name())): + return + + # Making sure that the secret receives its label + # (This should have happened by the time we get here, rather an extra security measure.) + secret = self._get_relation_secret(relation.id) + + # Either app scope secret with leader executing, or unit scope secret + leader_or_unit_scope = self.component != self.local_app or self.local_unit.is_leader() + if secret and leader_or_unit_scope: + # Databag reference to the secret URI can be removed, now that it's labelled + relation.data[self.component].pop(self._generate_secret_field_name(), None) + + def _previous_labels(self) -> List[str]: + """Generator for legacy secret label names, for backwards compatibility.""" + result = [] + members = [self._model.app.name] + if self.scope: + members.append(self.scope.value) + result.append(f"{'.'.join(members)}") + return result + + def _no_group_with_databag(self, field: str, full_field: str) -> bool: + """Check that no secret group is attempted to be used together with databag.""" + if not self.secrets_enabled and full_field != field: + logger.error( + f"Can't access {full_field}: no secrets available (i.e. no secret groups either)." + ) + return False + return True + + # Event handlers + + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the relation has changed.""" + pass + + def _on_secret_changed_event(self, event: SecretChangedEvent) -> None: + """Event emitted when the secret has changed.""" + pass + + # Overrides of Relation Data handling functions + + def _generate_secret_label( + self, relation_name: str, relation_id: int, group_mapping: SecretGroup + ) -> str: + members = [relation_name, self._model.app.name] + if self.scope: + members.append(self.scope.value) + if group_mapping != SECRET_GROUPS.EXTRA: + members.append(group_mapping) + return f"{'.'.join(members)}" + + def _generate_secret_field_name(self, group_mapping: SecretGroup = SECRET_GROUPS.EXTRA) -> str: + """Generate unique group_mappings for secrets within a relation context.""" + return f"{self.secret_field_name}" + + @juju_secrets_only + def _get_relation_secret( + self, + relation_id: int, + group_mapping: SecretGroup = SECRET_GROUPS.EXTRA, + relation_name: Optional[str] = None, + ) -> Optional[CachedSecret]: + """Retrieve a Juju Secret specifically for peer relations. + + In case this code may be executed within a rolling upgrade, and we may need to + migrate secrets from the databag to labels, we make sure to stick the correct + label on the secret, and clean up the local databag. + """ + if not relation_name: + relation_name = self.relation_name + + relation = self._model.get_relation(relation_name, relation_id) + if not relation: + return + + label = self._generate_secret_label(relation_name, relation_id, group_mapping) + secret_uri = relation.data[self.component].get(self._generate_secret_field_name(), None) + + # URI or legacy label is only to applied when moving single legacy secret to a (new) label + if group_mapping == SECRET_GROUPS.EXTRA: + # Fetching the secret with fallback to URI (in case label is not yet known) + # Label would we "stuck" on the secret in case it is found + return self.secrets.get(label, secret_uri, legacy_labels=self._previous_labels()) + return self.secrets.get(label) + + def _get_group_secret_contents( + self, + relation: Relation, + group: SecretGroup, + secret_fields: Union[Set[str], List[str]] = [], + ) -> Dict[str, str]: + """Helper function to retrieve collective, requested contents of a secret.""" + secret_fields = [self._internal_name_to_field(k)[0] for k in secret_fields] + result = super()._get_group_secret_contents(relation, group, secret_fields) + if self.deleted_label: + result = {key: result[key] for key in result if result[key] != self.deleted_label} + if self._additional_secret_group_mapping: + return {self._field_to_internal_name(key, group): result[key] for key in result} + return result + + @either_static_or_dynamic_secrets + def _fetch_my_specific_relation_data( + self, relation: Relation, fields: Optional[List[str]] + ) -> Dict[str, str]: + """Fetch data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" + return self._fetch_relation_data_with_secrets( + self.component, self.secret_fields, relation, fields + ) + + @either_static_or_dynamic_secrets + def _update_relation_data(self, relation: Relation, data: Dict[str, str]) -> None: + """Update data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" + self._remove_secret_from_databag(relation, list(data.keys())) + _, normal_fields = self._process_secret_fields( + relation, + self.secret_fields, + list(data), + self._add_or_update_relation_secrets, + data=data, + uri_to_databag=False, + ) + self._remove_secret_field_name_from_databag(relation) + + normal_content = {k: v for k, v in data.items() if k in normal_fields} + self._update_relation_data_without_secrets(self.component, relation, normal_content) + + @either_static_or_dynamic_secrets + def _delete_relation_data(self, relation: Relation, fields: List[str]) -> None: + """Delete data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" + if self.secret_fields and self.deleted_label: + # Legacy, backwards compatibility + self._check_deleted_label(relation, fields) + + _, normal_fields = self._process_secret_fields( + relation, + self.secret_fields, + fields, + self._update_relation_secret, + data={field: self.deleted_label for field in fields}, + ) + else: + _, normal_fields = self._process_secret_fields( + relation, self.secret_fields, fields, self._delete_relation_secret, fields=fields + ) + self._delete_relation_data_without_secrets(self.component, relation, list(normal_fields)) + + def fetch_relation_data( + self, + relation_ids: Optional[List[int]] = None, + fields: Optional[List[str]] = None, + relation_name: Optional[str] = None, + ) -> Dict[int, Dict[str, str]]: + """This method makes no sense for a Peer Relation.""" + raise NotImplementedError( + "Peer Relation only supports 'self-side' fetch methods: " + "fetch_my_relation_data() and fetch_my_relation_field()" + ) + + def fetch_relation_field( + self, relation_id: int, field: str, relation_name: Optional[str] = None + ) -> Optional[str]: + """This method makes no sense for a Peer Relation.""" + raise NotImplementedError( + "Peer Relation only supports 'self-side' fetch methods: " + "fetch_my_relation_data() and fetch_my_relation_field()" + ) + + # Public functions -- inherited + + fetch_my_relation_data = Data.fetch_my_relation_data + fetch_my_relation_field = Data.fetch_my_relation_field + + +class DataPeerEventHandlers(RequirerEventHandlers): + """Requires-side of the relation.""" + + def __init__(self, charm: CharmBase, relation_data: RequirerData, unique_key: str = ""): + """Manager of base client relations.""" + super().__init__(charm, relation_data, unique_key) + + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the relation has changed.""" + pass + + def _on_secret_changed_event(self, event: SecretChangedEvent) -> None: + """Event emitted when the secret has changed.""" + pass + + +class DataPeer(DataPeerData, DataPeerEventHandlers): + """Represents peer relations.""" + + def __init__( + self, + charm, + relation_name: str, + extra_user_roles: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], + additional_secret_group_mapping: Dict[str, str] = {}, + secret_field_name: Optional[str] = None, + deleted_label: Optional[str] = None, + unique_key: str = "", + ): + DataPeerData.__init__( + self, + charm.model, + relation_name, + extra_user_roles, + additional_secret_fields, + additional_secret_group_mapping, + secret_field_name, + deleted_label, + ) + DataPeerEventHandlers.__init__(self, charm, self, unique_key) + + +class DataPeerUnitData(DataPeerData): + """Unit data abstraction representation.""" + + SCOPE = Scope.UNIT + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class DataPeerUnit(DataPeerUnitData, DataPeerEventHandlers): + """Unit databag representation.""" + + def __init__( + self, + charm, + relation_name: str, + extra_user_roles: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], + additional_secret_group_mapping: Dict[str, str] = {}, + secret_field_name: Optional[str] = None, + deleted_label: Optional[str] = None, + unique_key: str = "", + ): + DataPeerData.__init__( + self, + charm.model, + relation_name, + extra_user_roles, + additional_secret_fields, + additional_secret_group_mapping, + secret_field_name, + deleted_label, + ) + DataPeerEventHandlers.__init__(self, charm, self, unique_key) + + +class DataPeerOtherUnitData(DataPeerUnitData): + """Unit data abstraction representation.""" + + def __init__(self, unit: Unit, *args, **kwargs): + super().__init__(*args, **kwargs) + self.local_unit = unit + self.component = unit + def update_relation_data(self, relation_id: int, data: dict) -> None: + """This method makes no sense for a Other Peer Relation.""" + raise NotImplementedError("It's not possible to update data of another unit.") -class AuthenticationEvent(RelationEvent): - """Base class for authentication fields for events.""" + def delete_relation_data(self, relation_id: int, fields: List[str]) -> None: + """This method makes no sense for a Other Peer Relation.""" + raise NotImplementedError("It's not possible to delete data of another unit.") + + +class DataPeerOtherUnitEventHandlers(DataPeerEventHandlers): + """Requires-side of the relation.""" + + def __init__(self, charm: CharmBase, relation_data: DataPeerUnitData): + """Manager of base client relations.""" + unique_key = f"{relation_data.relation_name}-{relation_data.local_unit.name}" + super().__init__(charm, relation_data, unique_key=unique_key) + + +class DataPeerOtherUnit(DataPeerOtherUnitData, DataPeerOtherUnitEventHandlers): + """Unit databag representation for another unit than the executor.""" + + def __init__( + self, + unit: Unit, + charm: CharmBase, + relation_name: str, + extra_user_roles: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], + additional_secret_group_mapping: Dict[str, str] = {}, + secret_field_name: Optional[str] = None, + deleted_label: Optional[str] = None, + ): + DataPeerOtherUnitData.__init__( + self, + unit, + charm.model, + relation_name, + extra_user_roles, + additional_secret_fields, + additional_secret_group_mapping, + secret_field_name, + deleted_label, + ) + DataPeerOtherUnitEventHandlers.__init__(self, charm, self) + + +################################################################################ +# Cross-charm Relatoins Data Handling and Evenets +################################################################################ + +# Generic events + + +class ExtraRoleEvent(RelationEvent): + """Base class for data events.""" + + @property + def extra_user_roles(self) -> Optional[str]: + """Returns the extra user roles that were requested.""" + if not self.relation.app: + return None + + return self.relation.data[self.relation.app].get("extra-user-roles") + + +class RelationEventWithSecret(RelationEvent): + """Base class for Relation Events that need to handle secrets.""" + + @property + def _secrets(self) -> dict: + """Caching secrets to avoid fetching them each time a field is referrd. + + DON'T USE the encapsulated helper variable outside of this function + """ + if not hasattr(self, "_cached_secrets"): + self._cached_secrets = {} + return self._cached_secrets + + def _get_secret(self, group) -> Optional[Dict[str, str]]: + """Retrieveing secrets.""" + if not self.app: + return + if not self._secrets.get(group): + self._secrets[group] = None + secret_field = f"{PROV_SECRET_PREFIX}{group}" + if secret_uri := self.relation.data[self.app].get(secret_field): + secret = self.framework.model.get_secret(id=secret_uri) + self._secrets[group] = secret.get_content() + return self._secrets[group] + + @property + def secrets_enabled(self): + """Is this Juju version allowing for Secrets usage?""" + return JujuVersion.from_environ().has_secrets + + +class AuthenticationEvent(RelationEventWithSecret): + """Base class for authentication fields for events. + + The amount of logic added here is not ideal -- but this was the only way to preserve + the interface when moving to Juju Secrets + """ @property def username(self) -> Optional[str]: """Returns the created username.""" + if not self.relation.app: + return None + + if self.secrets_enabled: + secret = self._get_secret("user") + if secret: + return secret.get("username") + return self.relation.data[self.relation.app].get("username") @property def password(self) -> Optional[str]: """Returns the password for the created user.""" + if not self.relation.app: + return None + + if self.secrets_enabled: + secret = self._get_secret("user") + if secret: + return secret.get("password") + return self.relation.data[self.relation.app].get("password") @property def tls(self) -> Optional[str]: """Returns whether TLS is configured.""" + if not self.relation.app: + return None + + if self.secrets_enabled: + secret = self._get_secret("tls") + if secret: + return secret.get("tls") + return self.relation.data[self.relation.app].get("tls") @property def tls_ca(self) -> Optional[str]: """Returns TLS CA.""" + if not self.relation.app: + return None + + if self.secrets_enabled: + secret = self._get_secret("tls") + if secret: + return secret.get("tls-ca") + return self.relation.data[self.relation.app].get("tls-ca") @@ -639,12 +2395,26 @@ class DatabaseProvidesEvent(RelationEvent): @property def database(self) -> Optional[str]: """Returns the database that was requested.""" + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("database") class DatabaseRequestedEvent(DatabaseProvidesEvent, ExtraRoleEvent): """Event emitted when a new database is requested for use on this relation.""" + @property + def external_node_connectivity(self) -> bool: + """Returns the requested external_node_connectivity field.""" + if not self.relation.app: + return False + + return ( + self.relation.data[self.relation.app].get("external-node-connectivity", "false") + == "true" + ) + class DatabaseProvidesEvents(CharmEvents): """Database events. @@ -655,17 +2425,39 @@ class DatabaseProvidesEvents(CharmEvents): database_requested = EventSource(DatabaseRequestedEvent) -class DatabaseRequiresEvent(RelationEvent): +class DatabaseRequiresEvent(RelationEventWithSecret): """Base class for database events.""" + @property + def database(self) -> Optional[str]: + """Returns the database name.""" + if not self.relation.app: + return None + + return self.relation.data[self.relation.app].get("database") + @property def endpoints(self) -> Optional[str]: - """Returns a comma separated list of read/write endpoints.""" + """Returns a comma separated list of read/write endpoints. + + In VM charms, this is the primary's address. + In kubernetes charms, this is the service to the primary pod. + """ + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("endpoints") @property def read_only_endpoints(self) -> Optional[str]: - """Returns a comma separated list of read only endpoints.""" + """Returns a comma separated list of read only endpoints. + + In VM charms, this is the address of all the secondary instances. + In kubernetes charms, this is the service to all replica pod instances. + """ + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("read-only-endpoints") @property @@ -674,6 +2466,9 @@ def replset(self) -> Optional[str]: MongoDB only. """ + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("replset") @property @@ -682,6 +2477,14 @@ def uris(self) -> Optional[str]: MongoDB, Redis, OpenSearch. """ + if not self.relation.app: + return None + + if self.secrets_enabled: + secret = self._get_secret("user") + if secret: + return secret.get("uris") + return self.relation.data[self.relation.app].get("uris") @property @@ -690,6 +2493,9 @@ def version(self) -> Optional[str]: Version as informed by the database daemon. """ + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("version") @@ -719,27 +2525,23 @@ class DatabaseRequiresEvents(CharmEvents): # Database Provider and Requires -class DatabaseProvides(DataProvides): - """Provider-side of the database relations.""" - - on = DatabaseProvidesEvents() +class DatabaseProviderData(ProviderData): + """Provider-side data of the database relations.""" - def __init__(self, charm: CharmBase, relation_name: str) -> None: - super().__init__(charm, relation_name) + def __init__(self, model: Model, relation_name: str) -> None: + super().__init__(model, relation_name) - def _on_relation_changed(self, event: RelationChangedEvent) -> None: - """Event emitted when the relation has changed.""" - # Only the leader should handle this event. - if not self.local_unit.is_leader(): - return + def set_database(self, relation_id: int, database_name: str) -> None: + """Set database name. - # Check which data has changed to emit customs events. - diff = self._diff(event) + This function writes in the application data bag, therefore, + only the leader unit can call it. - # Emit a database requested event if the setup key (database name and optional - # extra user roles) was added to the relation databag by the application. - if "database" in diff.added: - self.on.database_requested.emit(event.relation, app=event.app, unit=event.unit) + Args: + relation_id: the identifier for a particular relation. + database_name: database name. + """ + self.update_relation_data(relation_id, {"database": database_name}) def set_endpoints(self, relation_id: int, connection_strings: str) -> None: """Set database primary connections. @@ -747,11 +2549,15 @@ def set_endpoints(self, relation_id: int, connection_strings: str) -> None: This function writes in the application data bag, therefore, only the leader unit can call it. + In VM charms, only the primary's address should be passed as an endpoint. + In kubernetes charms, the service endpoint to the primary pod should be + passed as an endpoint. + Args: relation_id: the identifier for a particular relation. connection_strings: database hosts and ports comma separated list. """ - self._update_relation_data(relation_id, {"endpoints": connection_strings}) + self.update_relation_data(relation_id, {"endpoints": connection_strings}) def set_read_only_endpoints(self, relation_id: int, connection_strings: str) -> None: """Set database replicas connection strings. @@ -763,7 +2569,7 @@ def set_read_only_endpoints(self, relation_id: int, connection_strings: str) -> relation_id: the identifier for a particular relation. connection_strings: database hosts and ports comma separated list. """ - self._update_relation_data(relation_id, {"read-only-endpoints": connection_strings}) + self.update_relation_data(relation_id, {"read-only-endpoints": connection_strings}) def set_replset(self, relation_id: int, replset: str) -> None: """Set replica set name in the application relation databag. @@ -774,7 +2580,7 @@ def set_replset(self, relation_id: int, replset: str) -> None: relation_id: the identifier for a particular relation. replset: replica set name. """ - self._update_relation_data(relation_id, {"replset": replset}) + self.update_relation_data(relation_id, {"replset": replset}) def set_uris(self, relation_id: int, uris: str) -> None: """Set the database connection URIs in the application relation databag. @@ -785,48 +2591,152 @@ def set_uris(self, relation_id: int, uris: str) -> None: relation_id: the identifier for a particular relation. uris: connection URIs. """ - self._update_relation_data(relation_id, {"uris": uris}) + self.update_relation_data(relation_id, {"uris": uris}) def set_version(self, relation_id: int, version: str) -> None: """Set the database version in the application relation databag. Args: - relation_id: the identifier for a particular relation. - version: database version. + relation_id: the identifier for a particular relation. + version: database version. + """ + self.update_relation_data(relation_id, {"version": version}) + + +class DatabaseProviderEventHandlers(EventHandlers): + """Provider-side of the database relation handlers.""" + + on = DatabaseProvidesEvents() # pyright: ignore [reportAssignmentType] + + def __init__( + self, charm: CharmBase, relation_data: DatabaseProviderData, unique_key: str = "" + ): + """Manager of base client relations.""" + super().__init__(charm, relation_data, unique_key) + # Just to calm down pyright, it can't parse that the same type is being used in the super() call above + self.relation_data = relation_data + + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the relation has changed.""" + # Leader only + if not self.relation_data.local_unit.is_leader(): + return + # Check which data has changed to emit customs events. + diff = self._diff(event) + + # Emit a database requested event if the setup key (database name and optional + # extra user roles) was added to the relation databag by the application. + if "database" in diff.added: + getattr(self.on, "database_requested").emit( + event.relation, app=event.app, unit=event.unit + ) + + +class DatabaseProvides(DatabaseProviderData, DatabaseProviderEventHandlers): + """Provider-side of the database relations.""" + + def __init__(self, charm: CharmBase, relation_name: str) -> None: + DatabaseProviderData.__init__(self, charm.model, relation_name) + DatabaseProviderEventHandlers.__init__(self, charm, self) + + +class DatabaseRequirerData(RequirerData): + """Requirer-side of the database relation.""" + + def __init__( + self, + model: Model, + relation_name: str, + database_name: str, + extra_user_roles: Optional[str] = None, + relations_aliases: Optional[List[str]] = None, + additional_secret_fields: Optional[List[str]] = [], + external_node_connectivity: bool = False, + ): + """Manager of database client relations.""" + super().__init__(model, relation_name, extra_user_roles, additional_secret_fields) + self.database = database_name + self.relations_aliases = relations_aliases + self.external_node_connectivity = external_node_connectivity + + def is_postgresql_plugin_enabled(self, plugin: str, relation_index: int = 0) -> bool: + """Returns whether a plugin is enabled in the database. + + Args: + plugin: name of the plugin to check. + relation_index: optional relation index to check the database + (default: 0 - first relation). + + PostgreSQL only. """ - self._update_relation_data(relation_id, {"version": version}) + # Psycopg 3 is imported locally to avoid the need of its package installation + # when relating to a database charm other than PostgreSQL. + import psycopg + + # Return False if no relation is established. + if len(self.relations) == 0: + return False + + relation_id = self.relations[relation_index].id + host = self.fetch_relation_field(relation_id, "endpoints") + + # Return False if there is no endpoint available. + if host is None: + return False + + host = host.split(":")[0] + + content = self.fetch_relation_data([relation_id], ["username", "password"]).get( + relation_id, {} + ) + user = content.get("username") + password = content.get("password") + + connection_string = ( + f"host='{host}' dbname='{self.database}' user='{user}' password='{password}'" + ) + try: + with psycopg.connect(connection_string) as connection: + with connection.cursor() as cursor: + cursor.execute( + "SELECT TRUE FROM pg_extension WHERE extname=%s::text;", (plugin,) + ) + return cursor.fetchone() is not None + except psycopg.Error as e: + logger.exception( + f"failed to check whether {plugin} plugin is enabled in the database: %s", str(e) + ) + return False -class DatabaseRequires(DataRequires): - """Requires-side of the database relation.""" +class DatabaseRequirerEventHandlers(RequirerEventHandlers): + """Requires-side of the relation.""" - on = DatabaseRequiresEvents() + on = DatabaseRequiresEvents() # pyright: ignore [reportAssignmentType] def __init__( - self, - charm, - relation_name: str, - database_name: str, - extra_user_roles: str = None, - relations_aliases: List[str] = None, + self, charm: CharmBase, relation_data: DatabaseRequirerData, unique_key: str = "" ): - """Manager of database client relations.""" - super().__init__(charm, relation_name, extra_user_roles) - self.database = database_name - self.relations_aliases = relations_aliases + """Manager of base client relations.""" + super().__init__(charm, relation_data, unique_key) + # Just to keep lint quiet, can't resolve inheritance. The same happened in super().__init__() above + self.relation_data = relation_data # Define custom event names for each alias. - if relations_aliases: + if self.relation_data.relations_aliases: # Ensure the number of aliases does not exceed the maximum # of connections allowed in the specific relation. - relation_connection_limit = self.charm.meta.requires[relation_name].limit - if len(relations_aliases) != relation_connection_limit: + relation_connection_limit = self.charm.meta.requires[ + self.relation_data.relation_name + ].limit + if len(self.relation_data.relations_aliases) != relation_connection_limit: raise ValueError( f"The number of aliases must match the maximum number of connections allowed in the relation. " - f"Expected {relation_connection_limit}, got {len(relations_aliases)}" + f"Expected {relation_connection_limit}, got {len(self.relation_data.relations_aliases)}" ) - for relation_alias in relations_aliases: + if self.relation_data.relations_aliases: + for relation_alias in self.relation_data.relations_aliases: self.on.define_event(f"{relation_alias}_database_created", DatabaseCreatedEvent) self.on.define_event( f"{relation_alias}_endpoints_changed", DatabaseEndpointsChangedEvent @@ -836,6 +2746,10 @@ def __init__( DatabaseReadOnlyEndpointsChangedEvent, ) + def _on_secret_changed_event(self, event: SecretChangedEvent): + """Event notifying about a new value of a secret.""" + pass + def _assign_relation_alias(self, relation_id: int) -> None: """Assigns an alias to a relation. @@ -845,29 +2759,32 @@ def _assign_relation_alias(self, relation_id: int) -> None: relation_id: the identifier for a particular relation. """ # If no aliases were provided, return immediately. - if not self.relations_aliases: + if not self.relation_data.relations_aliases: return # Return if an alias was already assigned to this relation # (like when there are more than one unit joining the relation). - if ( - self.charm.model.get_relation(self.relation_name, relation_id) - .data[self.local_unit] - .get("alias") - ): + relation = self.charm.model.get_relation(self.relation_data.relation_name, relation_id) + if relation and relation.data[self.relation_data.local_unit].get("alias"): return # Retrieve the available aliases (the ones that weren't assigned to any relation). - available_aliases = self.relations_aliases[:] - for relation in self.charm.model.relations[self.relation_name]: - alias = relation.data[self.local_unit].get("alias") + available_aliases = self.relation_data.relations_aliases[:] + for relation in self.charm.model.relations[self.relation_data.relation_name]: + alias = relation.data[self.relation_data.local_unit].get("alias") if alias: logger.debug("Alias %s was already assigned to relation %d", alias, relation.id) available_aliases.remove(alias) # Set the alias in the unit relation databag of the specific relation. - relation = self.charm.model.get_relation(self.relation_name, relation_id) - relation.data[self.local_unit].update({"alias": available_aliases[0]}) + relation = self.charm.model.get_relation(self.relation_data.relation_name, relation_id) + if relation: + relation.data[self.relation_data.local_unit].update({"alias": available_aliases[0]}) + + # We need to set relation alias also on the application level so, + # it will be accessible in show-unit juju command, executed for a consumer application unit + if self.relation_data.local_unit.is_leader(): + self.relation_data.update_relation_data(relation_id, {"alias": available_aliases[0]}) def _emit_aliased_event(self, event: RelationChangedEvent, event_name: str) -> None: """Emit an aliased event to a particular relation if it has an alias. @@ -891,40 +2808,54 @@ def _get_relation_alias(self, relation_id: int) -> Optional[str]: Returns: the relation alias or None if the relation was not found. """ - for relation in self.charm.model.relations[self.relation_name]: + for relation in self.charm.model.relations[self.relation_data.relation_name]: if relation.id == relation_id: - return relation.data[self.local_unit].get("alias") + return relation.data[self.relation_data.local_unit].get("alias") return None - def _on_relation_joined_event(self, event: RelationJoinedEvent) -> None: - """Event emitted when the application joins the database relation.""" + def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: + """Event emitted when the database relation is created.""" + super()._on_relation_created_event(event) + # If relations aliases were provided, assign one to the relation. self._assign_relation_alias(event.relation.id) # Sets both database and extra user roles in the relation # if the roles are provided. Otherwise, sets only the database. - if self.extra_user_roles: - self._update_relation_data( - event.relation.id, - { - "database": self.database, - "extra-user-roles": self.extra_user_roles, - }, - ) - else: - self._update_relation_data(event.relation.id, {"database": self.database}) + if not self.relation_data.local_unit.is_leader(): + return + + event_data = {"database": self.relation_data.database} + + if self.relation_data.extra_user_roles: + event_data["extra-user-roles"] = self.relation_data.extra_user_roles + + # set external-node-connectivity field + if self.relation_data.external_node_connectivity: + event_data["external-node-connectivity"] = "true" + + self.relation_data.update_relation_data(event.relation.id, event_data) def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: """Event emitted when the database relation has changed.""" # Check which data has changed to emit customs events. diff = self._diff(event) + # Register all new secrets with their labels + if any(newval for newval in diff.added if self.relation_data._is_secret_field(newval)): + self.relation_data._register_secrets_to_relation(event.relation, diff.added) + # Check if the database is created # (the database charm shared the credentials). - if "username" in diff.added and "password" in diff.added: + secret_field_user = self.relation_data._generate_secret_field_name(SECRET_GROUPS.USER) + if ( + "username" in diff.added and "password" in diff.added + ) or secret_field_user in diff.added: # Emit the default event (the one without an alias). logger.info("database created at %s", datetime.now()) - self.on.database_created.emit(event.relation, app=event.app, unit=event.unit) + getattr(self.on, "database_created").emit( + event.relation, app=event.app, unit=event.unit + ) # Emit the aliased event (if any). self._emit_aliased_event(event, "database_created") @@ -938,7 +2869,9 @@ def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: if "endpoints" in diff.added or "endpoints" in diff.changed: # Emit the default event (the one without an alias). logger.info("endpoints changed on %s", datetime.now()) - self.on.endpoints_changed.emit(event.relation, app=event.app, unit=event.unit) + getattr(self.on, "endpoints_changed").emit( + event.relation, app=event.app, unit=event.unit + ) # Emit the aliased event (if any). self._emit_aliased_event(event, "endpoints_changed") @@ -952,7 +2885,7 @@ def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: if "read-only-endpoints" in diff.added or "read-only-endpoints" in diff.changed: # Emit the default event (the one without an alias). logger.info("read-only-endpoints changed on %s", datetime.now()) - self.on.read_only_endpoints_changed.emit( + getattr(self.on, "read_only_endpoints_changed").emit( event.relation, app=event.app, unit=event.unit ) @@ -960,7 +2893,37 @@ def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: self._emit_aliased_event(event, "read_only_endpoints_changed") -# Kafka related events +class DatabaseRequires(DatabaseRequirerData, DatabaseRequirerEventHandlers): + """Provider-side of the database relations.""" + + def __init__( + self, + charm: CharmBase, + relation_name: str, + database_name: str, + extra_user_roles: Optional[str] = None, + relations_aliases: Optional[List[str]] = None, + additional_secret_fields: Optional[List[str]] = [], + external_node_connectivity: bool = False, + ): + DatabaseRequirerData.__init__( + self, + charm.model, + relation_name, + database_name, + extra_user_roles, + relations_aliases, + additional_secret_fields, + external_node_connectivity, + ) + DatabaseRequirerEventHandlers.__init__(self, charm, self) + + +################################################################################ +# Charm-specific Relations Data and Events +################################################################################ + +# Kafka Events class KafkaProvidesEvent(RelationEvent): @@ -969,8 +2932,19 @@ class KafkaProvidesEvent(RelationEvent): @property def topic(self) -> Optional[str]: """Returns the topic that was requested.""" + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("topic") + @property + def consumer_group_prefix(self) -> Optional[str]: + """Returns the consumer-group-prefix that was requested.""" + if not self.relation.app: + return None + + return self.relation.data[self.relation.app].get("consumer-group-prefix") + class TopicRequestedEvent(KafkaProvidesEvent, ExtraRoleEvent): """Event emitted when a new topic is requested for use on this relation.""" @@ -988,19 +2962,36 @@ class KafkaProvidesEvents(CharmEvents): class KafkaRequiresEvent(RelationEvent): """Base class for Kafka events.""" + @property + def topic(self) -> Optional[str]: + """Returns the topic.""" + if not self.relation.app: + return None + + return self.relation.data[self.relation.app].get("topic") + @property def bootstrap_server(self) -> Optional[str]: - """Returns a a comma-seperated list of broker uris.""" + """Returns a comma-separated list of broker uris.""" + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("endpoints") @property def consumer_group_prefix(self) -> Optional[str]: """Returns the consumer-group-prefix.""" + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("consumer-group-prefix") @property def zookeeper_uris(self) -> Optional[str]: """Returns a comma separated list of Zookeeper uris.""" + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("zookeeper-uris") @@ -1025,27 +3016,20 @@ class KafkaRequiresEvents(CharmEvents): # Kafka Provides and Requires -class KafkaProvides(DataProvides): +class KafkaProvidesData(ProviderData): """Provider-side of the Kafka relation.""" - on = KafkaProvidesEvents() - - def __init__(self, charm: CharmBase, relation_name: str) -> None: - super().__init__(charm, relation_name) - - def _on_relation_changed(self, event: RelationChangedEvent) -> None: - """Event emitted when the relation has changed.""" - # Only the leader should handle this event. - if not self.local_unit.is_leader(): - return + def __init__(self, model: Model, relation_name: str) -> None: + super().__init__(model, relation_name) - # Check which data has changed to emit customs events. - diff = self._diff(event) + def set_topic(self, relation_id: int, topic: str) -> None: + """Set topic name in the application relation databag. - # Emit a topic requested event if the setup key (topic name and optional - # extra user roles) was added to the relation databag by the application. - if "topic" in diff.added: - self.on.topic_requested.emit(event.relation, app=event.app, unit=event.unit) + Args: + relation_id: the identifier for a particular relation. + topic: the topic name. + """ + self.update_relation_data(relation_id, {"topic": topic}) def set_bootstrap_server(self, relation_id: int, bootstrap_server: str) -> None: """Set the bootstrap server in the application relation databag. @@ -1054,7 +3038,7 @@ def set_bootstrap_server(self, relation_id: int, bootstrap_server: str) -> None: relation_id: the identifier for a particular relation. bootstrap_server: the bootstrap server address. """ - self._update_relation_data(relation_id, {"endpoints": bootstrap_server}) + self.update_relation_data(relation_id, {"endpoints": bootstrap_server}) def set_consumer_group_prefix(self, relation_id: int, consumer_group_prefix: str) -> None: """Set the consumer group prefix in the application relation databag. @@ -1063,43 +3047,111 @@ def set_consumer_group_prefix(self, relation_id: int, consumer_group_prefix: str relation_id: the identifier for a particular relation. consumer_group_prefix: the consumer group prefix string. """ - self._update_relation_data(relation_id, {"consumer-group-prefix": consumer_group_prefix}) + self.update_relation_data(relation_id, {"consumer-group-prefix": consumer_group_prefix}) def set_zookeeper_uris(self, relation_id: int, zookeeper_uris: str) -> None: """Set the zookeeper uris in the application relation databag. Args: relation_id: the identifier for a particular relation. - zookeeper_uris: comma-seperated list of ZooKeeper server uris. + zookeeper_uris: comma-separated list of ZooKeeper server uris. """ - self._update_relation_data(relation_id, {"zookeeper-uris": zookeeper_uris}) + self.update_relation_data(relation_id, {"zookeeper-uris": zookeeper_uris}) -class KafkaRequires(DataRequires): - """Requires-side of the Kafka relation.""" +class KafkaProvidesEventHandlers(EventHandlers): + """Provider-side of the Kafka relation.""" + + on = KafkaProvidesEvents() # pyright: ignore [reportAssignmentType] + + def __init__(self, charm: CharmBase, relation_data: KafkaProvidesData) -> None: + super().__init__(charm, relation_data) + # Just to keep lint quiet, can't resolve inheritance. The same happened in super().__init__() above + self.relation_data = relation_data + + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the relation has changed.""" + # Leader only + if not self.relation_data.local_unit.is_leader(): + return + + # Check which data has changed to emit customs events. + diff = self._diff(event) + + # Emit a topic requested event if the setup key (topic name and optional + # extra user roles) was added to the relation databag by the application. + if "topic" in diff.added: + getattr(self.on, "topic_requested").emit( + event.relation, app=event.app, unit=event.unit + ) + + +class KafkaProvides(KafkaProvidesData, KafkaProvidesEventHandlers): + """Provider-side of the Kafka relation.""" + + def __init__(self, charm: CharmBase, relation_name: str) -> None: + KafkaProvidesData.__init__(self, charm.model, relation_name) + KafkaProvidesEventHandlers.__init__(self, charm, self) - on = KafkaRequiresEvents() - def __init__(self, charm, relation_name: str, topic: str, extra_user_roles: str = None): +class KafkaRequiresData(RequirerData): + """Requirer-side of the Kafka relation.""" + + def __init__( + self, + model: Model, + relation_name: str, + topic: str, + extra_user_roles: Optional[str] = None, + consumer_group_prefix: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], + ): """Manager of Kafka client relations.""" - # super().__init__(charm, relation_name) - super().__init__(charm, relation_name, extra_user_roles) - self.charm = charm + super().__init__(model, relation_name, extra_user_roles, additional_secret_fields) self.topic = topic + self.consumer_group_prefix = consumer_group_prefix or "" - def _on_relation_joined_event(self, event: RelationJoinedEvent) -> None: - """Event emitted when the application joins the Kafka relation.""" - # Sets both topic and extra user roles in the relation - # if the roles are provided. Otherwise, sets only the topic. - self._update_relation_data( - event.relation.id, - { - "topic": self.topic, - "extra-user-roles": self.extra_user_roles, - } - if self.extra_user_roles is not None - else {"topic": self.topic}, - ) + @property + def topic(self): + """Topic to use in Kafka.""" + return self._topic + + @topic.setter + def topic(self, value): + # Avoid wildcards + if value == "*": + raise ValueError(f"Error on topic '{value}', cannot be a wildcard.") + self._topic = value + + +class KafkaRequiresEventHandlers(RequirerEventHandlers): + """Requires-side of the Kafka relation.""" + + on = KafkaRequiresEvents() # pyright: ignore [reportAssignmentType] + + def __init__(self, charm: CharmBase, relation_data: KafkaRequiresData) -> None: + super().__init__(charm, relation_data) + # Just to keep lint quiet, can't resolve inheritance. The same happened in super().__init__() above + self.relation_data = relation_data + + def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: + """Event emitted when the Kafka relation is created.""" + super()._on_relation_created_event(event) + + if not self.relation_data.local_unit.is_leader(): + return + + # Sets topic, extra user roles, and "consumer-group-prefix" in the relation + relation_data = { + f: getattr(self, f.replace("-", "_"), "") + for f in ["consumer-group-prefix", "extra-user-roles", "topic"] + } + + self.relation_data.update_relation_data(event.relation.id, relation_data) + + def _on_secret_changed_event(self, event: SecretChangedEvent): + """Event notifying about a new value of a secret.""" + pass def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: """Event emitted when the Kafka relation has changed.""" @@ -1108,21 +3160,306 @@ def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: # Check if the topic is created # (the Kafka charm shared the credentials). - if "username" in diff.added and "password" in diff.added: + + # Register all new secrets with their labels + if any(newval for newval in diff.added if self.relation_data._is_secret_field(newval)): + self.relation_data._register_secrets_to_relation(event.relation, diff.added) + + secret_field_user = self.relation_data._generate_secret_field_name(SECRET_GROUPS.USER) + if ( + "username" in diff.added and "password" in diff.added + ) or secret_field_user in diff.added: # Emit the default event (the one without an alias). logger.info("topic created at %s", datetime.now()) - self.on.topic_created.emit(event.relation, app=event.app, unit=event.unit) + getattr(self.on, "topic_created").emit(event.relation, app=event.app, unit=event.unit) # To avoid unnecessary application restarts do not trigger # “endpoints_changed“ event if “topic_created“ is triggered. return - # Emit an endpoints (bootstap-server) changed event if the Kakfa endpoints + # Emit an endpoints (bootstrap-server) changed event if the Kafka endpoints # added or changed this info in the relation databag. if "endpoints" in diff.added or "endpoints" in diff.changed: # Emit the default event (the one without an alias). logger.info("endpoints changed on %s", datetime.now()) - self.on.bootstrap_server_changed.emit( + getattr(self.on, "bootstrap_server_changed").emit( event.relation, app=event.app, unit=event.unit ) # here check if this is the right design return + + +class KafkaRequires(KafkaRequiresData, KafkaRequiresEventHandlers): + """Provider-side of the Kafka relation.""" + + def __init__( + self, + charm: CharmBase, + relation_name: str, + topic: str, + extra_user_roles: Optional[str] = None, + consumer_group_prefix: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], + ) -> None: + KafkaRequiresData.__init__( + self, + charm.model, + relation_name, + topic, + extra_user_roles, + consumer_group_prefix, + additional_secret_fields, + ) + KafkaRequiresEventHandlers.__init__(self, charm, self) + + +# Opensearch related events + + +class OpenSearchProvidesEvent(RelationEvent): + """Base class for OpenSearch events.""" + + @property + def index(self) -> Optional[str]: + """Returns the index that was requested.""" + if not self.relation.app: + return None + + return self.relation.data[self.relation.app].get("index") + + +class IndexRequestedEvent(OpenSearchProvidesEvent, ExtraRoleEvent): + """Event emitted when a new index is requested for use on this relation.""" + + +class OpenSearchProvidesEvents(CharmEvents): + """OpenSearch events. + + This class defines the events that OpenSearch can emit. + """ + + index_requested = EventSource(IndexRequestedEvent) + + +class OpenSearchRequiresEvent(DatabaseRequiresEvent): + """Base class for OpenSearch requirer events.""" + + +class IndexCreatedEvent(AuthenticationEvent, OpenSearchRequiresEvent): + """Event emitted when a new index is created for use on this relation.""" + + +class OpenSearchRequiresEvents(CharmEvents): + """OpenSearch events. + + This class defines the events that the opensearch requirer can emit. + """ + + index_created = EventSource(IndexCreatedEvent) + endpoints_changed = EventSource(DatabaseEndpointsChangedEvent) + authentication_updated = EventSource(AuthenticationEvent) + + +# OpenSearch Provides and Requires Objects + + +class OpenSearchProvidesData(ProviderData): + """Provider-side of the OpenSearch relation.""" + + def __init__(self, model: Model, relation_name: str) -> None: + super().__init__(model, relation_name) + + def set_index(self, relation_id: int, index: str) -> None: + """Set the index in the application relation databag. + + Args: + relation_id: the identifier for a particular relation. + index: the index as it is _created_ on the provider charm. This needn't match the + requested index, and can be used to present a different index name if, for example, + the requested index is invalid. + """ + self.update_relation_data(relation_id, {"index": index}) + + def set_endpoints(self, relation_id: int, endpoints: str) -> None: + """Set the endpoints in the application relation databag. + + Args: + relation_id: the identifier for a particular relation. + endpoints: the endpoint addresses for opensearch nodes. + """ + self.update_relation_data(relation_id, {"endpoints": endpoints}) + + def set_version(self, relation_id: int, version: str) -> None: + """Set the opensearch version in the application relation databag. + + Args: + relation_id: the identifier for a particular relation. + version: database version. + """ + self.update_relation_data(relation_id, {"version": version}) + + +class OpenSearchProvidesEventHandlers(EventHandlers): + """Provider-side of the OpenSearch relation.""" + + on = OpenSearchProvidesEvents() # pyright: ignore[reportAssignmentType] + + def __init__(self, charm: CharmBase, relation_data: OpenSearchProvidesData) -> None: + super().__init__(charm, relation_data) + # Just to keep lint quiet, can't resolve inheritance. The same happened in super().__init__() above + self.relation_data = relation_data + + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the relation has changed.""" + # Leader only + if not self.relation_data.local_unit.is_leader(): + return + # Check which data has changed to emit customs events. + diff = self._diff(event) + + # Emit an index requested event if the setup key (index name and optional extra user roles) + # have been added to the relation databag by the application. + if "index" in diff.added: + getattr(self.on, "index_requested").emit( + event.relation, app=event.app, unit=event.unit + ) + + +class OpenSearchProvides(OpenSearchProvidesData, OpenSearchProvidesEventHandlers): + """Provider-side of the OpenSearch relation.""" + + def __init__(self, charm: CharmBase, relation_name: str) -> None: + OpenSearchProvidesData.__init__(self, charm.model, relation_name) + OpenSearchProvidesEventHandlers.__init__(self, charm, self) + + +class OpenSearchRequiresData(RequirerData): + """Requires data side of the OpenSearch relation.""" + + def __init__( + self, + model: Model, + relation_name: str, + index: str, + extra_user_roles: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], + ): + """Manager of OpenSearch client relations.""" + super().__init__(model, relation_name, extra_user_roles, additional_secret_fields) + self.index = index + + +class OpenSearchRequiresEventHandlers(RequirerEventHandlers): + """Requires events side of the OpenSearch relation.""" + + on = OpenSearchRequiresEvents() # pyright: ignore[reportAssignmentType] + + def __init__(self, charm: CharmBase, relation_data: OpenSearchRequiresData) -> None: + super().__init__(charm, relation_data) + # Just to keep lint quiet, can't resolve inheritance. The same happened in super().__init__() above + self.relation_data = relation_data + + def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: + """Event emitted when the OpenSearch relation is created.""" + super()._on_relation_created_event(event) + + if not self.relation_data.local_unit.is_leader(): + return + + # Sets both index and extra user roles in the relation if the roles are provided. + # Otherwise, sets only the index. + data = {"index": self.relation_data.index} + if self.relation_data.extra_user_roles: + data["extra-user-roles"] = self.relation_data.extra_user_roles + + self.relation_data.update_relation_data(event.relation.id, data) + + def _on_secret_changed_event(self, event: SecretChangedEvent): + """Event notifying about a new value of a secret.""" + if not event.secret.label: + return + + relation = self.relation_data._relation_from_secret_label(event.secret.label) + if not relation: + logging.info( + f"Received secret {event.secret.label} but couldn't parse, seems irrelevant" + ) + return + + if relation.app == self.charm.app: + logging.info("Secret changed event ignored for Secret Owner") + + remote_unit = None + for unit in relation.units: + if unit.app != self.charm.app: + remote_unit = unit + + logger.info("authentication updated") + getattr(self.on, "authentication_updated").emit( + relation, app=relation.app, unit=remote_unit + ) + + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the OpenSearch relation has changed. + + This event triggers individual custom events depending on the changing relation. + """ + # Check which data has changed to emit customs events. + diff = self._diff(event) + + # Register all new secrets with their labels + if any(newval for newval in diff.added if self.relation_data._is_secret_field(newval)): + self.relation_data._register_secrets_to_relation(event.relation, diff.added) + + secret_field_user = self.relation_data._generate_secret_field_name(SECRET_GROUPS.USER) + secret_field_tls = self.relation_data._generate_secret_field_name(SECRET_GROUPS.TLS) + updates = {"username", "password", "tls", "tls-ca", secret_field_user, secret_field_tls} + if len(set(diff._asdict().keys()) - updates) < len(diff): + logger.info("authentication updated at: %s", datetime.now()) + getattr(self.on, "authentication_updated").emit( + event.relation, app=event.app, unit=event.unit + ) + + # Check if the index is created + # (the OpenSearch charm shares the credentials). + if ( + "username" in diff.added and "password" in diff.added + ) or secret_field_user in diff.added: + # Emit the default event (the one without an alias). + logger.info("index created at: %s", datetime.now()) + getattr(self.on, "index_created").emit(event.relation, app=event.app, unit=event.unit) + + # To avoid unnecessary application restarts do not trigger + # “endpoints_changed“ event if “index_created“ is triggered. + return + + # Emit a endpoints changed event if the OpenSearch application added or changed this info + # in the relation databag. + if "endpoints" in diff.added or "endpoints" in diff.changed: + # Emit the default event (the one without an alias). + logger.info("endpoints changed on %s", datetime.now()) + getattr(self.on, "endpoints_changed").emit( + event.relation, app=event.app, unit=event.unit + ) # here check if this is the right design + return + + +class OpenSearchRequires(OpenSearchRequiresData, OpenSearchRequiresEventHandlers): + """Requires-side of the OpenSearch relation.""" + + def __init__( + self, + charm: CharmBase, + relation_name: str, + index: str, + extra_user_roles: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], + ) -> None: + OpenSearchRequiresData.__init__( + self, + charm.model, + relation_name, + index, + extra_user_roles, + additional_secret_fields, + ) + OpenSearchRequiresEventHandlers.__init__(self, charm, self) diff --git a/lib/charms/operator_libs_linux/v0/apt.py b/lib/charms/operator_libs_linux/v0/apt.py new file mode 100644 index 0000000..1400df7 --- /dev/null +++ b/lib/charms/operator_libs_linux/v0/apt.py @@ -0,0 +1,1361 @@ +# Copyright 2021 Canonical Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Abstractions for the system's Debian/Ubuntu package information and repositories. + +This module contains abstractions and wrappers around Debian/Ubuntu-style repositories and +packages, in order to easily provide an idiomatic and Pythonic mechanism for adding packages and/or +repositories to systems for use in machine charms. + +A sane default configuration is attainable through nothing more than instantiation of the +appropriate classes. `DebianPackage` objects provide information about the architecture, version, +name, and status of a package. + +`DebianPackage` will try to look up a package either from `dpkg -L` or from `apt-cache` when +provided with a string indicating the package name. If it cannot be located, `PackageNotFoundError` +will be returned, as `apt` and `dpkg` otherwise return `100` for all errors, and a meaningful error +message if the package is not known is desirable. + +To install packages with convenience methods: + +```python +try: + # Run `apt-get update` + apt.update() + apt.add_package("zsh") + apt.add_package(["vim", "htop", "wget"]) +except PackageNotFoundError: + logger.error("a specified package not found in package cache or on system") +except PackageError as e: + logger.error("could not install package. Reason: %s", e.message) +```` + +To find details of a specific package: + +```python +try: + vim = apt.DebianPackage.from_system("vim") + + # To find from the apt cache only + # apt.DebianPackage.from_apt_cache("vim") + + # To find from installed packages only + # apt.DebianPackage.from_installed_package("vim") + + vim.ensure(PackageState.Latest) + logger.info("updated vim to version: %s", vim.fullversion) +except PackageNotFoundError: + logger.error("a specified package not found in package cache or on system") +except PackageError as e: + logger.error("could not install package. Reason: %s", e.message) +``` + + +`RepositoryMapping` will return a dict-like object containing enabled system repositories +and their properties (available groups, baseuri. gpg key). This class can add, disable, or +manipulate repositories. Items can be retrieved as `DebianRepository` objects. + +In order add a new repository with explicit details for fields, a new `DebianRepository` can +be added to `RepositoryMapping` + +`RepositoryMapping` provides an abstraction around the existing repositories on the system, +and can be accessed and iterated over like any `Mapping` object, to retrieve values by key, +iterate, or perform other operations. + +Keys are constructed as `{repo_type}-{}-{release}` in order to uniquely identify a repository. + +Repositories can be added with explicit values through a Python constructor. + +Example: +```python +repositories = apt.RepositoryMapping() + +if "deb-example.com-focal" not in repositories: + repositories.add(DebianRepository(enabled=True, repotype="deb", + uri="https://example.com", release="focal", groups=["universe"])) +``` + +Alternatively, any valid `sources.list` line may be used to construct a new +`DebianRepository`. + +Example: +```python +repositories = apt.RepositoryMapping() + +if "deb-us.archive.ubuntu.com-xenial" not in repositories: + line = "deb http://us.archive.ubuntu.com/ubuntu xenial main restricted" + repo = DebianRepository.from_repo_line(line) + repositories.add(repo) +``` +""" + +import fileinput +import glob +import logging +import os +import re +import subprocess +from collections.abc import Mapping +from enum import Enum +from subprocess import PIPE, CalledProcessError, check_output +from typing import Iterable, List, Optional, Tuple, Union +from urllib.parse import urlparse + +logger = logging.getLogger(__name__) + +# The unique Charmhub library identifier, never change it +LIBID = "7c3dbc9c2ad44a47bd6fcb25caa270e5" + +# Increment this major API version when introducing breaking changes +LIBAPI = 0 + +# Increment this PATCH version before using `charmcraft publish-lib` or reset +# to 0 if you are raising the major API version +LIBPATCH = 13 + + +VALID_SOURCE_TYPES = ("deb", "deb-src") +OPTIONS_MATCHER = re.compile(r"\[.*?\]") + + +class Error(Exception): + """Base class of most errors raised by this library.""" + + def __repr__(self): + """Represent the Error.""" + return "<{}.{} {}>".format(type(self).__module__, type(self).__name__, self.args) + + @property + def name(self): + """Return a string representation of the model plus class.""" + return "<{}.{}>".format(type(self).__module__, type(self).__name__) + + @property + def message(self): + """Return the message passed as an argument.""" + return self.args[0] + + +class PackageError(Error): + """Raised when there's an error installing or removing a package.""" + + +class PackageNotFoundError(Error): + """Raised when a requested package is not known to the system.""" + + +class PackageState(Enum): + """A class to represent possible package states.""" + + Present = "present" + Absent = "absent" + Latest = "latest" + Available = "available" + + +class DebianPackage: + """Represents a traditional Debian package and its utility functions. + + `DebianPackage` wraps information and functionality around a known package, whether installed + or available. The version, epoch, name, and architecture can be easily queried and compared + against other `DebianPackage` objects to determine the latest version or to install a specific + version. + + The representation of this object as a string mimics the output from `dpkg` for familiarity. + + Installation and removal of packages is handled through the `state` property or `ensure` + method, with the following options: + + apt.PackageState.Absent + apt.PackageState.Available + apt.PackageState.Present + apt.PackageState.Latest + + When `DebianPackage` is initialized, the state of a given `DebianPackage` object will be set to + `Available`, `Present`, or `Latest`, with `Absent` implemented as a convenience for removal + (though it operates essentially the same as `Available`). + """ + + def __init__( + self, name: str, version: str, epoch: str, arch: str, state: PackageState + ) -> None: + self._name = name + self._arch = arch + self._state = state + self._version = Version(version, epoch) + + def __eq__(self, other) -> bool: + """Equality for comparison. + + Args: + other: a `DebianPackage` object for comparison + + Returns: + A boolean reflecting equality + """ + return isinstance(other, self.__class__) and ( + self._name, + self._version.number, + ) == (other._name, other._version.number) + + def __hash__(self): + """Return a hash of this package.""" + return hash((self._name, self._version.number)) + + def __repr__(self): + """Represent the package.""" + return "<{}.{}: {}>".format(self.__module__, self.__class__.__name__, self.__dict__) + + def __str__(self): + """Return a human-readable representation of the package.""" + return "<{}: {}-{}.{} -- {}>".format( + self.__class__.__name__, + self._name, + self._version, + self._arch, + str(self._state), + ) + + @staticmethod + def _apt( + command: str, + package_names: Union[str, List], + optargs: Optional[List[str]] = None, + ) -> None: + """Wrap package management commands for Debian/Ubuntu systems. + + Args: + command: the command given to `apt-get` + package_names: a package name or list of package names to operate on + optargs: an (Optional) list of additioanl arguments + + Raises: + PackageError if an error is encountered + """ + optargs = optargs if optargs is not None else [] + if isinstance(package_names, str): + package_names = [package_names] + _cmd = ["apt-get", "-y", *optargs, command, *package_names] + try: + env = os.environ.copy() + env["DEBIAN_FRONTEND"] = "noninteractive" + subprocess.run(_cmd, capture_output=True, check=True, text=True, env=env) + except CalledProcessError as e: + raise PackageError( + "Could not {} package(s) [{}]: {}".format(command, [*package_names], e.stderr) + ) from None + + def _add(self) -> None: + """Add a package to the system.""" + self._apt( + "install", + "{}={}".format(self.name, self.version), + optargs=["--option=Dpkg::Options::=--force-confold"], + ) + + def _remove(self) -> None: + """Remove a package from the system. Implementation-specific.""" + return self._apt("remove", "{}={}".format(self.name, self.version)) + + @property + def name(self) -> str: + """Returns the name of the package.""" + return self._name + + def ensure(self, state: PackageState): + """Ensure that a package is in a given state. + + Args: + state: a `PackageState` to reconcile the package to + + Raises: + PackageError from the underlying call to apt + """ + if self._state is not state: + if state not in (PackageState.Present, PackageState.Latest): + self._remove() + else: + self._add() + self._state = state + + @property + def present(self) -> bool: + """Returns whether or not a package is present.""" + return self._state in (PackageState.Present, PackageState.Latest) + + @property + def latest(self) -> bool: + """Returns whether the package is the most recent version.""" + return self._state is PackageState.Latest + + @property + def state(self) -> PackageState: + """Returns the current package state.""" + return self._state + + @state.setter + def state(self, state: PackageState) -> None: + """Set the package state to a given value. + + Args: + state: a `PackageState` to reconcile the package to + + Raises: + PackageError from the underlying call to apt + """ + if state in (PackageState.Latest, PackageState.Present): + self._add() + else: + self._remove() + self._state = state + + @property + def version(self) -> "Version": + """Returns the version for a package.""" + return self._version + + @property + def epoch(self) -> str: + """Returns the epoch for a package. May be unset.""" + return self._version.epoch + + @property + def arch(self) -> str: + """Returns the architecture for a package.""" + return self._arch + + @property + def fullversion(self) -> str: + """Returns the name+epoch for a package.""" + return "{}.{}".format(self._version, self._arch) + + @staticmethod + def _get_epoch_from_version(version: str) -> Tuple[str, str]: + """Pull the epoch, if any, out of a version string.""" + epoch_matcher = re.compile(r"^((?P\d+):)?(?P.*)") + matches = epoch_matcher.search(version).groupdict() + return matches.get("epoch", ""), matches.get("version") + + @classmethod + def from_system( + cls, package: str, version: Optional[str] = "", arch: Optional[str] = "" + ) -> "DebianPackage": + """Locates a package, either on the system or known to apt, and serializes the information. + + Args: + package: a string representing the package + version: an optional string if a specific version is requested + arch: an optional architecture, defaulting to `dpkg --print-architecture`. If an + architecture is not specified, this will be used for selection. + + """ + try: + return DebianPackage.from_installed_package(package, version, arch) + except PackageNotFoundError: + logger.debug( + "package '%s' is not currently installed or has the wrong architecture.", package + ) + + # Ok, try `apt-cache ...` + try: + return DebianPackage.from_apt_cache(package, version, arch) + except (PackageNotFoundError, PackageError): + # If we get here, it's not known to the systems. + # This seems unnecessary, but virtually all `apt` commands have a return code of `100`, + # and providing meaningful error messages without this is ugly. + raise PackageNotFoundError( + "Package '{}{}' could not be found on the system or in the apt cache!".format( + package, ".{}".format(arch) if arch else "" + ) + ) from None + + @classmethod + def from_installed_package( + cls, package: str, version: Optional[str] = "", arch: Optional[str] = "" + ) -> "DebianPackage": + """Check whether the package is already installed and return an instance. + + Args: + package: a string representing the package + version: an optional string if a specific version is requested + arch: an optional architecture, defaulting to `dpkg --print-architecture`. + If an architecture is not specified, this will be used for selection. + """ + system_arch = check_output( + ["dpkg", "--print-architecture"], universal_newlines=True + ).strip() + arch = arch if arch else system_arch + + # Regexps are a really terrible way to do this. Thanks dpkg + output = "" + try: + output = check_output(["dpkg", "-l", package], stderr=PIPE, universal_newlines=True) + except CalledProcessError: + raise PackageNotFoundError("Package is not installed: {}".format(package)) from None + + # Pop off the output from `dpkg -l' because there's no flag to + # omit it` + lines = str(output).splitlines()[5:] + + dpkg_matcher = re.compile( + r""" + ^(?P\w+?)\s+ + (?P.*?)(?P:\w+?)?\s+ + (?P.*?)\s+ + (?P\w+?)\s+ + (?P.*) + """, + re.VERBOSE, + ) + + for line in lines: + try: + matches = dpkg_matcher.search(line).groupdict() + package_status = matches["package_status"] + + if not package_status.endswith("i"): + logger.debug( + "package '%s' in dpkg output but not installed, status: '%s'", + package, + package_status, + ) + break + + epoch, split_version = DebianPackage._get_epoch_from_version(matches["version"]) + pkg = DebianPackage( + matches["package_name"], + split_version, + epoch, + matches["arch"], + PackageState.Present, + ) + if (pkg.arch == "all" or pkg.arch == arch) and ( + version == "" or str(pkg.version) == version + ): + return pkg + except AttributeError: + logger.warning("dpkg matcher could not parse line: %s", line) + + # If we didn't find it, fail through + raise PackageNotFoundError("Package {}.{} is not installed!".format(package, arch)) + + @classmethod + def from_apt_cache( + cls, package: str, version: Optional[str] = "", arch: Optional[str] = "" + ) -> "DebianPackage": + """Check whether the package is already installed and return an instance. + + Args: + package: a string representing the package + version: an optional string if a specific version is requested + arch: an optional architecture, defaulting to `dpkg --print-architecture`. + If an architecture is not specified, this will be used for selection. + """ + system_arch = check_output( + ["dpkg", "--print-architecture"], universal_newlines=True + ).strip() + arch = arch if arch else system_arch + + # Regexps are a really terrible way to do this. Thanks dpkg + keys = ("Package", "Architecture", "Version") + + try: + output = check_output( + ["apt-cache", "show", package], stderr=PIPE, universal_newlines=True + ) + except CalledProcessError as e: + raise PackageError( + "Could not list packages in apt-cache: {}".format(e.stderr) + ) from None + + pkg_groups = output.strip().split("\n\n") + keys = ("Package", "Architecture", "Version") + + for pkg_raw in pkg_groups: + lines = str(pkg_raw).splitlines() + vals = {} + for line in lines: + if line.startswith(keys): + items = line.split(":", 1) + vals[items[0]] = items[1].strip() + else: + continue + + epoch, split_version = DebianPackage._get_epoch_from_version(vals["Version"]) + pkg = DebianPackage( + vals["Package"], + split_version, + epoch, + vals["Architecture"], + PackageState.Available, + ) + + if (pkg.arch == "all" or pkg.arch == arch) and ( + version == "" or str(pkg.version) == version + ): + return pkg + + # If we didn't find it, fail through + raise PackageNotFoundError("Package {}.{} is not in the apt cache!".format(package, arch)) + + +class Version: + """An abstraction around package versions. + + This seems like it should be strictly unnecessary, except that `apt_pkg` is not usable inside a + venv, and wedging version comparisons into `DebianPackage` would overcomplicate it. + + This class implements the algorithm found here: + https://www.debian.org/doc/debian-policy/ch-controlfields.html#version + """ + + def __init__(self, version: str, epoch: str): + self._version = version + self._epoch = epoch or "" + + def __repr__(self): + """Represent the package.""" + return "<{}.{}: {}>".format(self.__module__, self.__class__.__name__, self.__dict__) + + def __str__(self): + """Return human-readable representation of the package.""" + return "{}{}".format("{}:".format(self._epoch) if self._epoch else "", self._version) + + @property + def epoch(self): + """Returns the epoch for a package. May be empty.""" + return self._epoch + + @property + def number(self) -> str: + """Returns the version number for a package.""" + return self._version + + def _get_parts(self, version: str) -> Tuple[str, str]: + """Separate the version into component upstream and Debian pieces.""" + try: + version.rindex("-") + except ValueError: + # No hyphens means no Debian version + return version, "0" + + upstream, debian = version.rsplit("-", 1) + return upstream, debian + + def _listify(self, revision: str) -> List[str]: + """Split a revision string into a listself. + + This list is comprised of alternating between strings and numbers, + padded on either end to always be "str, int, str, int..." and + always be of even length. This allows us to trivially implement the + comparison algorithm described. + """ + result = [] + while revision: + rev_1, remains = self._get_alphas(revision) + rev_2, remains = self._get_digits(remains) + result.extend([rev_1, rev_2]) + revision = remains + return result + + def _get_alphas(self, revision: str) -> Tuple[str, str]: + """Return a tuple of the first non-digit characters of a revision.""" + # get the index of the first digit + for i, char in enumerate(revision): + if char.isdigit(): + if i == 0: + return "", revision + return revision[0:i], revision[i:] + # string is entirely alphas + return revision, "" + + def _get_digits(self, revision: str) -> Tuple[int, str]: + """Return a tuple of the first integer characters of a revision.""" + # If the string is empty, return (0,'') + if not revision: + return 0, "" + # get the index of the first non-digit + for i, char in enumerate(revision): + if not char.isdigit(): + if i == 0: + return 0, revision + return int(revision[0:i]), revision[i:] + # string is entirely digits + return int(revision), "" + + def _dstringcmp(self, a, b): # noqa: C901 + """Debian package version string section lexical sort algorithm. + + The lexical comparison is a comparison of ASCII values modified so + that all the letters sort earlier than all the non-letters and so that + a tilde sorts before anything, even the end of a part. + """ + if a == b: + return 0 + try: + for i, char in enumerate(a): + if char == b[i]: + continue + # "a tilde sorts before anything, even the end of a part" + # (emptyness) + if char == "~": + return -1 + if b[i] == "~": + return 1 + # "all the letters sort earlier than all the non-letters" + if char.isalpha() and not b[i].isalpha(): + return -1 + if not char.isalpha() and b[i].isalpha(): + return 1 + # otherwise lexical sort + if ord(char) > ord(b[i]): + return 1 + if ord(char) < ord(b[i]): + return -1 + except IndexError: + # a is longer than b but otherwise equal, greater unless there are tildes + if char == "~": + return -1 + return 1 + # if we get here, a is shorter than b but otherwise equal, so check for tildes... + if b[len(a)] == "~": + return 1 + return -1 + + def _compare_revision_strings(self, first: str, second: str): # noqa: C901 + """Compare two debian revision strings.""" + if first == second: + return 0 + + # listify pads results so that we will always be comparing ints to ints + # and strings to strings (at least until we fall off the end of a list) + first_list = self._listify(first) + second_list = self._listify(second) + if first_list == second_list: + return 0 + try: + for i, item in enumerate(first_list): + # explicitly raise IndexError if we've fallen off the edge of list2 + if i >= len(second_list): + raise IndexError + # if the items are equal, next + if item == second_list[i]: + continue + # numeric comparison + if isinstance(item, int): + if item > second_list[i]: + return 1 + if item < second_list[i]: + return -1 + else: + # string comparison + return self._dstringcmp(item, second_list[i]) + except IndexError: + # rev1 is longer than rev2 but otherwise equal, hence greater + # ...except for goddamn tildes + if first_list[len(second_list)][0][0] == "~": + return 1 + return 1 + # rev1 is shorter than rev2 but otherwise equal, hence lesser + # ...except for goddamn tildes + if second_list[len(first_list)][0][0] == "~": + return -1 + return -1 + + def _compare_version(self, other) -> int: + if (self.number, self.epoch) == (other.number, other.epoch): + return 0 + + if self.epoch < other.epoch: + return -1 + if self.epoch > other.epoch: + return 1 + + # If none of these are true, follow the algorithm + upstream_version, debian_version = self._get_parts(self.number) + other_upstream_version, other_debian_version = self._get_parts(other.number) + + upstream_cmp = self._compare_revision_strings(upstream_version, other_upstream_version) + if upstream_cmp != 0: + return upstream_cmp + + debian_cmp = self._compare_revision_strings(debian_version, other_debian_version) + if debian_cmp != 0: + return debian_cmp + + return 0 + + def __lt__(self, other) -> bool: + """Less than magic method impl.""" + return self._compare_version(other) < 0 + + def __eq__(self, other) -> bool: + """Equality magic method impl.""" + return self._compare_version(other) == 0 + + def __gt__(self, other) -> bool: + """Greater than magic method impl.""" + return self._compare_version(other) > 0 + + def __le__(self, other) -> bool: + """Less than or equal to magic method impl.""" + return self.__eq__(other) or self.__lt__(other) + + def __ge__(self, other) -> bool: + """Greater than or equal to magic method impl.""" + return self.__gt__(other) or self.__eq__(other) + + def __ne__(self, other) -> bool: + """Not equal to magic method impl.""" + return not self.__eq__(other) + + +def add_package( + package_names: Union[str, List[str]], + version: Optional[str] = "", + arch: Optional[str] = "", + update_cache: Optional[bool] = False, +) -> Union[DebianPackage, List[DebianPackage]]: + """Add a package or list of packages to the system. + + Args: + package_names: single package name, or list of package names + name: the name(s) of the package(s) + version: an (Optional) version as a string. Defaults to the latest known + arch: an optional architecture for the package + update_cache: whether or not to run `apt-get update` prior to operating + + Raises: + TypeError if no package name is given, or explicit version is set for multiple packages + PackageNotFoundError if the package is not in the cache. + PackageError if packages fail to install + """ + cache_refreshed = False + if update_cache: + update() + cache_refreshed = True + + packages = {"success": [], "retry": [], "failed": []} + + package_names = [package_names] if isinstance(package_names, str) else package_names + if not package_names: + raise TypeError("Expected at least one package name to add, received zero!") + + if len(package_names) != 1 and version: + raise TypeError( + "Explicit version should not be set if more than one package is being added!" + ) + + for p in package_names: + pkg, success = _add(p, version, arch) + if success: + packages["success"].append(pkg) + else: + logger.warning("failed to locate and install/update '%s'", pkg) + packages["retry"].append(p) + + if packages["retry"] and not cache_refreshed: + logger.info("updating the apt-cache and retrying installation of failed packages.") + update() + + for p in packages["retry"]: + pkg, success = _add(p, version, arch) + if success: + packages["success"].append(pkg) + else: + packages["failed"].append(p) + + if packages["failed"]: + raise PackageError("Failed to install packages: {}".format(", ".join(packages["failed"]))) + + return packages["success"] if len(packages["success"]) > 1 else packages["success"][0] + + +def _add( + name: str, + version: Optional[str] = "", + arch: Optional[str] = "", +) -> Tuple[Union[DebianPackage, str], bool]: + """Add a package to the system. + + Args: + name: the name(s) of the package(s) + version: an (Optional) version as a string. Defaults to the latest known + arch: an optional architecture for the package + + Returns: a tuple of `DebianPackage` if found, or a :str: if it is not, and + a boolean indicating success + """ + try: + pkg = DebianPackage.from_system(name, version, arch) + pkg.ensure(state=PackageState.Present) + return pkg, True + except PackageNotFoundError: + return name, False + + +def remove_package( + package_names: Union[str, List[str]] +) -> Union[DebianPackage, List[DebianPackage]]: + """Remove package(s) from the system. + + Args: + package_names: the name of a package + + Raises: + PackageNotFoundError if the package is not found. + """ + packages = [] + + package_names = [package_names] if isinstance(package_names, str) else package_names + if not package_names: + raise TypeError("Expected at least one package name to add, received zero!") + + for p in package_names: + try: + pkg = DebianPackage.from_installed_package(p) + pkg.ensure(state=PackageState.Absent) + packages.append(pkg) + except PackageNotFoundError: + logger.info("package '%s' was requested for removal, but it was not installed.", p) + + # the list of packages will be empty when no package is removed + logger.debug("packages: '%s'", packages) + return packages[0] if len(packages) == 1 else packages + + +def update() -> None: + """Update the apt cache via `apt-get update`.""" + subprocess.run(["apt-get", "update"], capture_output=True, check=True) + + +def import_key(key: str) -> str: + """Import an ASCII Armor key. + + A Radix64 format keyid is also supported for backwards + compatibility. In this case Ubuntu keyserver will be + queried for a key via HTTPS by its keyid. This method + is less preferable because https proxy servers may + require traffic decryption which is equivalent to a + man-in-the-middle attack (a proxy server impersonates + keyserver TLS certificates and has to be explicitly + trusted by the system). + + Args: + key: A GPG key in ASCII armor format, including BEGIN + and END markers or a keyid. + + Returns: + The GPG key filename written. + + Raises: + GPGKeyError if the key could not be imported + """ + key = key.strip() + if "-" in key or "\n" in key: + # Send everything not obviously a keyid to GPG to import, as + # we trust its validation better than our own. eg. handling + # comments before the key. + logger.debug("PGP key found (looks like ASCII Armor format)") + if ( + "-----BEGIN PGP PUBLIC KEY BLOCK-----" in key + and "-----END PGP PUBLIC KEY BLOCK-----" in key + ): + logger.debug("Writing provided PGP key in the binary format") + key_bytes = key.encode("utf-8") + key_name = DebianRepository._get_keyid_by_gpg_key(key_bytes) + key_gpg = DebianRepository._dearmor_gpg_key(key_bytes) + gpg_key_filename = "/etc/apt/trusted.gpg.d/{}.gpg".format(key_name) + DebianRepository._write_apt_gpg_keyfile( + key_name=gpg_key_filename, key_material=key_gpg + ) + return gpg_key_filename + else: + raise GPGKeyError("ASCII armor markers missing from GPG key") + else: + logger.warning( + "PGP key found (looks like Radix64 format). " + "SECURELY importing PGP key from keyserver; " + "full key not provided." + ) + # as of bionic add-apt-repository uses curl with an HTTPS keyserver URL + # to retrieve GPG keys. `apt-key adv` command is deprecated as is + # apt-key in general as noted in its manpage. See lp:1433761 for more + # history. Instead, /etc/apt/trusted.gpg.d is used directly to drop + # gpg + key_asc = DebianRepository._get_key_by_keyid(key) + # write the key in GPG format so that apt-key list shows it + key_gpg = DebianRepository._dearmor_gpg_key(key_asc.encode("utf-8")) + gpg_key_filename = "/etc/apt/trusted.gpg.d/{}.gpg".format(key) + DebianRepository._write_apt_gpg_keyfile(key_name=gpg_key_filename, key_material=key_gpg) + return gpg_key_filename + + +class InvalidSourceError(Error): + """Exceptions for invalid source entries.""" + + +class GPGKeyError(Error): + """Exceptions for GPG keys.""" + + +class DebianRepository: + """An abstraction to represent a repository.""" + + def __init__( + self, + enabled: bool, + repotype: str, + uri: str, + release: str, + groups: List[str], + filename: Optional[str] = "", + gpg_key_filename: Optional[str] = "", + options: Optional[dict] = None, + ): + self._enabled = enabled + self._repotype = repotype + self._uri = uri + self._release = release + self._groups = groups + self._filename = filename + self._gpg_key_filename = gpg_key_filename + self._options = options + + @property + def enabled(self): + """Return whether or not the repository is enabled.""" + return self._enabled + + @property + def repotype(self): + """Return whether it is binary or source.""" + return self._repotype + + @property + def uri(self): + """Return the URI.""" + return self._uri + + @property + def release(self): + """Return which Debian/Ubuntu releases it is valid for.""" + return self._release + + @property + def groups(self): + """Return the enabled package groups.""" + return self._groups + + @property + def filename(self): + """Returns the filename for a repository.""" + return self._filename + + @filename.setter + def filename(self, fname: str) -> None: + """Set the filename used when a repo is written back to disk. + + Args: + fname: a filename to write the repository information to. + """ + if not fname.endswith(".list"): + raise InvalidSourceError("apt source filenames should end in .list!") + + self._filename = fname + + @property + def gpg_key(self): + """Returns the path to the GPG key for this repository.""" + return self._gpg_key_filename + + @property + def options(self): + """Returns any additional repo options which are set.""" + return self._options + + def make_options_string(self) -> str: + """Generate the complete options string for a a repository. + + Combining `gpg_key`, if set, and the rest of the options to find + a complex repo string. + """ + options = self._options if self._options else {} + if self._gpg_key_filename: + options["signed-by"] = self._gpg_key_filename + + return ( + "[{}] ".format(" ".join(["{}={}".format(k, v) for k, v in options.items()])) + if options + else "" + ) + + @staticmethod + def prefix_from_uri(uri: str) -> str: + """Get a repo list prefix from the uri, depending on whether a path is set.""" + uridetails = urlparse(uri) + path = ( + uridetails.path.lstrip("/").replace("/", "-") if uridetails.path else uridetails.netloc + ) + return "/etc/apt/sources.list.d/{}".format(path) + + @staticmethod + def from_repo_line(repo_line: str, write_file: Optional[bool] = True) -> "DebianRepository": + """Instantiate a new `DebianRepository` a `sources.list` entry line. + + Args: + repo_line: a string representing a repository entry + write_file: boolean to enable writing the new repo to disk + """ + repo = RepositoryMapping._parse(repo_line, "UserInput") + fname = "{}-{}.list".format( + DebianRepository.prefix_from_uri(repo.uri), repo.release.replace("/", "-") + ) + repo.filename = fname + + options = repo.options if repo.options else {} + if repo.gpg_key: + options["signed-by"] = repo.gpg_key + + # For Python 3.5 it's required to use sorted in the options dict in order to not have + # different results in the order of the options between executions. + options_str = ( + "[{}] ".format(" ".join(["{}={}".format(k, v) for k, v in sorted(options.items())])) + if options + else "" + ) + + if write_file: + with open(fname, "wb") as f: + f.write( + ( + "{}".format("#" if not repo.enabled else "") + + "{} {}{} ".format(repo.repotype, options_str, repo.uri) + + "{} {}\n".format(repo.release, " ".join(repo.groups)) + ).encode("utf-8") + ) + + return repo + + def disable(self) -> None: + """Remove this repository from consideration. + + Disable it instead of removing from the repository file. + """ + searcher = "{} {}{} {}".format( + self.repotype, self.make_options_string(), self.uri, self.release + ) + for line in fileinput.input(self._filename, inplace=True): + if re.match(r"^{}\s".format(re.escape(searcher)), line): + print("# {}".format(line), end="") + else: + print(line, end="") + + def import_key(self, key: str) -> None: + """Import an ASCII Armor key. + + A Radix64 format keyid is also supported for backwards + compatibility. In this case Ubuntu keyserver will be + queried for a key via HTTPS by its keyid. This method + is less preferable because https proxy servers may + require traffic decryption which is equivalent to a + man-in-the-middle attack (a proxy server impersonates + keyserver TLS certificates and has to be explicitly + trusted by the system). + + Args: + key: A GPG key in ASCII armor format, + including BEGIN and END markers or a keyid. + + Raises: + GPGKeyError if the key could not be imported + """ + self._gpg_key_filename = import_key(key) + + @staticmethod + def _get_keyid_by_gpg_key(key_material: bytes) -> str: + """Get a GPG key fingerprint by GPG key material. + + Gets a GPG key fingerprint (40-digit, 160-bit) by the ASCII armor-encoded + or binary GPG key material. Can be used, for example, to generate file + names for keys passed via charm options. + """ + # Use the same gpg command for both Xenial and Bionic + cmd = ["gpg", "--with-colons", "--with-fingerprint"] + ps = subprocess.run( + cmd, + stdout=PIPE, + stderr=PIPE, + input=key_material, + ) + out, err = ps.stdout.decode(), ps.stderr.decode() + if "gpg: no valid OpenPGP data found." in err: + raise GPGKeyError("Invalid GPG key material provided") + # from gnupg2 docs: fpr :: Fingerprint (fingerprint is in field 10) + return re.search(r"^fpr:{9}([0-9A-F]{40}):$", out, re.MULTILINE).group(1) + + @staticmethod + def _get_key_by_keyid(keyid: str) -> str: + """Get a key via HTTPS from the Ubuntu keyserver. + + Different key ID formats are supported by SKS keyservers (the longer ones + are more secure, see "dead beef attack" and https://evil32.com/). Since + HTTPS is used, if SSLBump-like HTTPS proxies are in place, they will + impersonate keyserver.ubuntu.com and generate a certificate with + keyserver.ubuntu.com in the CN field or in SubjAltName fields of a + certificate. If such proxy behavior is expected it is necessary to add the + CA certificate chain containing the intermediate CA of the SSLBump proxy to + every machine that this code runs on via ca-certs cloud-init directive (via + cloudinit-userdata model-config) or via other means (such as through a + custom charm option). Also note that DNS resolution for the hostname in a + URL is done at a proxy server - not at the client side. + 8-digit (32 bit) key ID + https://keyserver.ubuntu.com/pks/lookup?search=0x4652B4E6 + 16-digit (64 bit) key ID + https://keyserver.ubuntu.com/pks/lookup?search=0x6E85A86E4652B4E6 + 40-digit key ID: + https://keyserver.ubuntu.com/pks/lookup?search=0x35F77D63B5CEC106C577ED856E85A86E4652B4E6 + + Args: + keyid: An 8, 16 or 40 hex digit keyid to find a key for + + Returns: + A string contining key material for the specified GPG key id + + + Raises: + subprocess.CalledProcessError + """ + # options=mr - machine-readable output (disables html wrappers) + keyserver_url = ( + "https://keyserver.ubuntu.com" "/pks/lookup?op=get&options=mr&exact=on&search=0x{}" + ) + curl_cmd = ["curl", keyserver_url.format(keyid)] + # use proxy server settings in order to retrieve the key + return check_output(curl_cmd).decode() + + @staticmethod + def _dearmor_gpg_key(key_asc: bytes) -> bytes: + """Convert a GPG key in the ASCII armor format to the binary format. + + Args: + key_asc: A GPG key in ASCII armor format. + + Returns: + A GPG key in binary format as a string + + Raises: + GPGKeyError + """ + ps = subprocess.run(["gpg", "--dearmor"], stdout=PIPE, stderr=PIPE, input=key_asc) + out, err = ps.stdout, ps.stderr.decode() + if "gpg: no valid OpenPGP data found." in err: + raise GPGKeyError( + "Invalid GPG key material. Check your network setup" + " (MTU, routing, DNS) and/or proxy server settings" + " as well as destination keyserver status." + ) + else: + return out + + @staticmethod + def _write_apt_gpg_keyfile(key_name: str, key_material: bytes) -> None: + """Write GPG key material into a file at a provided path. + + Args: + key_name: A key name to use for a key file (could be a fingerprint) + key_material: A GPG key material (binary) + """ + with open(key_name, "wb") as keyf: + keyf.write(key_material) + + +class RepositoryMapping(Mapping): + """An representation of known repositories. + + Instantiation of `RepositoryMapping` will iterate through the + filesystem, parse out repository files in `/etc/apt/...`, and create + `DebianRepository` objects in this list. + + Typical usage: + + repositories = apt.RepositoryMapping() + repositories.add(DebianRepository( + enabled=True, repotype="deb", uri="https://example.com", release="focal", + groups=["universe"] + )) + """ + + def __init__(self): + self._repository_map = {} + # Repositories that we're adding -- used to implement mode param + self.default_file = "/etc/apt/sources.list" + + # read sources.list if it exists + if os.path.isfile(self.default_file): + self.load(self.default_file) + + # read sources.list.d + for file in glob.iglob("/etc/apt/sources.list.d/*.list"): + self.load(file) + + def __contains__(self, key: str) -> bool: + """Magic method for checking presence of repo in mapping.""" + return key in self._repository_map + + def __len__(self) -> int: + """Return number of repositories in map.""" + return len(self._repository_map) + + def __iter__(self) -> Iterable[DebianRepository]: + """Return iterator for RepositoryMapping.""" + return iter(self._repository_map.values()) + + def __getitem__(self, repository_uri: str) -> DebianRepository: + """Return a given `DebianRepository`.""" + return self._repository_map[repository_uri] + + def __setitem__(self, repository_uri: str, repository: DebianRepository) -> None: + """Add a `DebianRepository` to the cache.""" + self._repository_map[repository_uri] = repository + + def load(self, filename: str): + """Load a repository source file into the cache. + + Args: + filename: the path to the repository file + """ + parsed = [] + skipped = [] + with open(filename, "r") as f: + for n, line in enumerate(f): + try: + repo = self._parse(line, filename) + except InvalidSourceError: + skipped.append(n) + else: + repo_identifier = "{}-{}-{}".format(repo.repotype, repo.uri, repo.release) + self._repository_map[repo_identifier] = repo + parsed.append(n) + logger.debug("parsed repo: '%s'", repo_identifier) + + if skipped: + skip_list = ", ".join(str(s) for s in skipped) + logger.debug("skipped the following lines in file '%s': %s", filename, skip_list) + + if parsed: + logger.info("parsed %d apt package repositories", len(parsed)) + else: + raise InvalidSourceError("all repository lines in '{}' were invalid!".format(filename)) + + @staticmethod + def _parse(line: str, filename: str) -> DebianRepository: + """Parse a line in a sources.list file. + + Args: + line: a single line from `load` to parse + filename: the filename being read + + Raises: + InvalidSourceError if the source type is unknown + """ + enabled = True + repotype = uri = release = gpg_key = "" + options = {} + groups = [] + + line = line.strip() + if line.startswith("#"): + enabled = False + line = line[1:] + + # Check for "#" in the line and treat a part after it as a comment then strip it off. + i = line.find("#") + if i > 0: + line = line[:i] + + # Split a source into substrings to initialize a new repo. + source = line.strip() + if source: + # Match any repo options, and get a dict representation. + for v in re.findall(OPTIONS_MATCHER, source): + opts = dict(o.split("=") for o in v.strip("[]").split()) + # Extract the 'signed-by' option for the gpg_key + gpg_key = opts.pop("signed-by", "") + options = opts + + # Remove any options from the source string and split the string into chunks + source = re.sub(OPTIONS_MATCHER, "", source) + chunks = source.split() + + # Check we've got a valid list of chunks + if len(chunks) < 3 or chunks[0] not in VALID_SOURCE_TYPES: + raise InvalidSourceError("An invalid sources line was found in %s!", filename) + + repotype = chunks[0] + uri = chunks[1] + release = chunks[2] + groups = chunks[3:] + + return DebianRepository( + enabled, repotype, uri, release, groups, filename, gpg_key, options + ) + else: + raise InvalidSourceError("An invalid sources line was found in %s!", filename) + + def add(self, repo: DebianRepository, default_filename: Optional[bool] = False) -> None: + """Add a new repository to the system. + + Args: + repo: a `DebianRepository` object + default_filename: an (Optional) filename if the default is not desirable + """ + new_filename = "{}-{}.list".format( + DebianRepository.prefix_from_uri(repo.uri), repo.release.replace("/", "-") + ) + + fname = repo.filename or new_filename + + options = repo.options if repo.options else {} + if repo.gpg_key: + options["signed-by"] = repo.gpg_key + + with open(fname, "wb") as f: + f.write( + ( + "{}".format("#" if not repo.enabled else "") + + "{} {}{} ".format(repo.repotype, repo.make_options_string(), repo.uri) + + "{} {}\n".format(repo.release, " ".join(repo.groups)) + ).encode("utf-8") + ) + + self._repository_map["{}-{}-{}".format(repo.repotype, repo.uri, repo.release)] = repo + + def disable(self, repo: DebianRepository) -> None: + """Remove a repository. Disable by default. + + Args: + repo: a `DebianRepository` to disable + """ + searcher = "{} {}{} {}".format( + repo.repotype, repo.make_options_string(), repo.uri, repo.release + ) + + for line in fileinput.input(repo.filename, inplace=True): + if re.match(r"^{}\s".format(re.escape(searcher)), line): + print("# {}".format(line), end="") + else: + print(line, end="") + + self._repository_map["{}-{}-{}".format(repo.repotype, repo.uri, repo.release)] = repo diff --git a/lib/charms/operator_libs_linux/v1/systemd.py b/lib/charms/operator_libs_linux/v1/systemd.py index 5be34c1..cdcbad6 100644 --- a/lib/charms/operator_libs_linux/v1/systemd.py +++ b/lib/charms/operator_libs_linux/v1/systemd.py @@ -23,6 +23,7 @@ service_resume with run the mask/unmask and enable/disable invocations. Example usage: + ```python from charms.operator_libs_linux.v0.systemd import service_running, service_reload @@ -33,13 +34,14 @@ # Attempt to reload a service, restarting if necessary success = service_reload("nginx", restart_on_failure=True) ``` - """ -import logging -import subprocess - __all__ = [ # Don't export `_systemctl`. (It's not the intended way of using this lib.) + "SystemdError", + "daemon_reload", + "service_disable", + "service_enable", + "service_failed", "service_pause", "service_reload", "service_restart", @@ -47,9 +49,11 @@ "service_running", "service_start", "service_stop", - "daemon_reload", ] +import logging +import subprocess + logger = logging.getLogger(__name__) # The unique Charmhub library identifier, never change it @@ -60,122 +64,168 @@ # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 0 +LIBPATCH = 4 class SystemdError(Exception): - pass + """Custom exception for SystemD related errors.""" -def _popen_kwargs(): - return dict( - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - bufsize=1, - universal_newlines=True, - encoding="utf-8", - ) +def _systemctl(*args: str, check: bool = False) -> int: + """Control a system service using systemctl. + Args: + *args: Arguments to pass to systemctl. + check: Check the output of the systemctl command. Default: False. -def _systemctl( - sub_cmd: str, service_name: str = None, now: bool = None, quiet: bool = None -) -> bool: - """Control a system service. + Returns: + Returncode of systemctl command execution. - Args: - sub_cmd: the systemctl subcommand to issue - service_name: the name of the service to perform the action on - now: passes the --now flag to the shell invocation. - quiet: passes the --quiet flag to the shell invocation. + Raises: + SystemdError: Raised if calling systemctl returns a non-zero returncode and check is True. """ - cmd = ["systemctl", sub_cmd] - - if service_name is not None: - cmd.append(service_name) - if now is not None: - cmd.append("--now") - if quiet is not None: - cmd.append("--quiet") - if sub_cmd != "is-active": - logger.debug("Attempting to {} '{}' with command {}.".format(cmd, service_name, cmd)) - else: - logger.debug("Checking if '{}' is active".format(service_name)) - - proc = subprocess.Popen(cmd, **_popen_kwargs()) - last_line = "" - for line in iter(proc.stdout.readline, ""): - last_line = line - logger.debug(line) - - proc.wait() - - if sub_cmd == "is-active": - # If we are just checking whether a service is running, return True/False, rather - # than raising an error. - if proc.returncode < 1: - return True - if proc.returncode == 3: # Code returned when service is not active. - return False - - if proc.returncode < 1: - return True - - raise SystemdError( - "Could not {}{}: systemd output: {}".format( - sub_cmd, " {}".format(service_name) if service_name else "", last_line + cmd = ["systemctl", *args] + logger.debug(f"Executing command: {cmd}") + try: + proc = subprocess.run( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + encoding="utf-8", + check=check, + ) + logger.debug( + f"Command {cmd} exit code: {proc.returncode}. systemctl output:\n{proc.stdout}" + ) + return proc.returncode + except subprocess.CalledProcessError as e: + raise SystemdError( + f"Command {cmd} failed with returncode {e.returncode}. systemctl output:\n{e.stdout}" ) - ) def service_running(service_name: str) -> bool: - """Determine whether a system service is running. + """Report whether a system service is running. + + Args: + service_name: The name of the service to check. + + Return: + True if service is running/active; False if not. + """ + # If returncode is 0, this means that is service is active. + return _systemctl("--quiet", "is-active", service_name) == 0 + + +def service_failed(service_name: str) -> bool: + """Report whether a system service has failed. Args: - service_name: the name of the service + service_name: The name of the service to check. + + Returns: + True if service is marked as failed; False if not. """ - return _systemctl("is-active", service_name, quiet=True) + # If returncode is 0, this means that the service has failed. + return _systemctl("--quiet", "is-failed", service_name) == 0 -def service_start(service_name: str) -> bool: +def service_start(*args: str) -> bool: """Start a system service. Args: - service_name: the name of the service to stop + *args: Arguments to pass to `systemctl start` (normally the service name). + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl start ...` returns a non-zero returncode. """ - return _systemctl("start", service_name) + return _systemctl("start", *args, check=True) == 0 -def service_stop(service_name: str) -> bool: +def service_stop(*args: str) -> bool: """Stop a system service. Args: - service_name: the name of the service to stop + *args: Arguments to pass to `systemctl stop` (normally the service name). + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl stop ...` returns a non-zero returncode. """ - return _systemctl("stop", service_name) + return _systemctl("stop", *args, check=True) == 0 -def service_restart(service_name: str) -> bool: +def service_restart(*args: str) -> bool: """Restart a system service. Args: - service_name: the name of the service to restart + *args: Arguments to pass to `systemctl restart` (normally the service name). + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl restart ...` returns a non-zero returncode. """ - return _systemctl("restart", service_name) + return _systemctl("restart", *args, check=True) == 0 + + +def service_enable(*args: str) -> bool: + """Enable a system service. + + Args: + *args: Arguments to pass to `systemctl enable` (normally the service name). + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl enable ...` returns a non-zero returncode. + """ + return _systemctl("enable", *args, check=True) == 0 + + +def service_disable(*args: str) -> bool: + """Disable a system service. + + Args: + *args: Arguments to pass to `systemctl disable` (normally the service name). + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl disable ...` returns a non-zero returncode. + """ + return _systemctl("disable", *args, check=True) == 0 def service_reload(service_name: str, restart_on_failure: bool = False) -> bool: """Reload a system service, optionally falling back to restart if reload fails. Args: - service_name: the name of the service to reload - restart_on_failure: boolean indicating whether to fallback to a restart if the - reload fails. + service_name: The name of the service to reload. + restart_on_failure: + Boolean indicating whether to fall back to a restart if the reload fails. + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl reload|restart ...` returns a non-zero returncode. """ try: - return _systemctl("reload", service_name) + return _systemctl("reload", service_name, check=True) == 0 except SystemdError: if restart_on_failure: - return _systemctl("restart", service_name) + return service_restart(service_name) else: raise @@ -183,37 +233,56 @@ def service_reload(service_name: str, restart_on_failure: bool = False) -> bool: def service_pause(service_name: str) -> bool: """Pause a system service. - Stop it, and prevent it from starting again at boot. + Stops the service and prevents the service from starting again at boot. Args: - service_name: the name of the service to pause + service_name: The name of the service to pause. + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if service is still running after being paused by systemctl. """ - _systemctl("disable", service_name, now=True) + _systemctl("disable", "--now", service_name) _systemctl("mask", service_name) - if not service_running(service_name): - return True + if service_running(service_name): + raise SystemdError(f"Attempted to pause {service_name!r}, but it is still running.") - raise SystemdError("Attempted to pause '{}', but it is still running.".format(service_name)) + return True def service_resume(service_name: str) -> bool: """Resume a system service. - Re-enable starting again at boot. Start the service. + Re-enable starting the service again at boot. Start the service. Args: - service_name: the name of the service to resume + service_name: The name of the service to resume. + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if service is not running after being resumed by systemctl. """ _systemctl("unmask", service_name) - _systemctl("enable", service_name, now=True) + _systemctl("enable", "--now", service_name) - if service_running(service_name): - return True + if not service_running(service_name): + raise SystemdError(f"Attempted to resume {service_name!r}, but it is not running.") - raise SystemdError("Attempted to resume '{}', but it is not running.".format(service_name)) + return True def daemon_reload() -> bool: - """Reload systemd manager configuration.""" - return _systemctl("daemon-reload") + """Reload systemd manager configuration. + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl daemon-reload` returns a non-zero returncode. + """ + return _systemctl("daemon-reload", check=True) == 0 diff --git a/metadata.yaml b/metadata.yaml deleted file mode 100644 index 01f2b0d..0000000 --- a/metadata.yaml +++ /dev/null @@ -1,30 +0,0 @@ -name: slurmdbd -summary: | - Slurm DBD accounting daemon -description: | - This charm provides slurmdbd, munged, and the bindings to other utilities - that make lifecycle operations a breeze. - - slurmdbd provides a secure enterprise-wide interface to a database for - SLURM. This is particularly useful for archiving accounting records. -source: https://github.com/omnivector-solutions/slurmdbd-operator -issues: https://github.com/omnivector-solutions/slurmdbd-operator/issues -maintainers: - - OmniVector Solutions - - Jason C. Nucciarone - - David Gomez - -peers: - slurmdbd-peer: - interface: slurmdbd-peer -requires: - database: - interface: mysql_client - fluentbit: - interface: fluentbit -provides: - slurmdbd: - interface: slurmdbd - -assumes: - - juju diff --git a/pyproject.toml b/pyproject.toml index 99734d9..65f01b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,8 +36,8 @@ target-version = ["py38"] # Linting tools configuration [tool.ruff] line-length = 99 -select = ["E", "W", "F", "C", "N", "D", "I001"] -extend-ignore = [ +lint.select = ["E", "W", "F", "C", "N", "D", "I001"] +lint.extend-ignore = [ "D203", "D204", "D213", @@ -50,9 +50,9 @@ extend-ignore = [ "D409", "D413", ] -ignore = ["E501", "D107"] +lint.ignore = ["E501", "D107"] extend-exclude = ["__pycache__", "*.egg_info"] -per-file-ignores = {"tests/*" = ["D100","D101","D102","D103","D104"]} +lint.per-file-ignores = {"tests/*" = ["D100","D101","D102","D103","D104"]} -[tool.ruff.mccabe] +[tool.ruff.lint.mccabe] max-complexity = 10 diff --git a/requirements.txt b/requirements.txt index dfa795a..5889f38 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -ops==2.* -git+https://github.com/omnivector-solutions/slurm-ops-manager.git@0.8.16 \ No newline at end of file +ops==2.14.0 +distro==1.9.0 diff --git a/src/charm.py b/src/charm.py index 0c0d775..c86544e 100755 --- a/src/charm.py +++ b/src/charm.py @@ -1,13 +1,12 @@ #!/usr/bin/env python3 -# Copyright 2020 Omnivector Solutions, LLC. +# Copyright 2020-2024 Omnivector, LLC. # See LICENSE file for licensing details. """Slurmdbd Operator Charm.""" import logging -from pathlib import Path from time import sleep -from typing import Any, Dict +from typing import Any, Dict, Union from urllib.parse import urlparse from charms.data_platform_libs.v0.data_interfaces import ( @@ -15,149 +14,105 @@ DatabaseRequires, ) from charms.fluentbit.v0.fluentbit import FluentbitClient -from interface_slurmdbd import Slurmdbd -from interface_slurmdbd_peer import SlurmdbdPeer -from ops.charm import CharmBase, CharmEvents -from ops.framework import EventBase, EventSource, StoredState -from ops.main import main -from ops.model import ActiveStatus, BlockedStatus, ErrorStatus, WaitingStatus -from slurm_ops_manager import SlurmManager -from utils.manager import SlurmdbdManager +from constants import CHARM_MAINTAINED_PARAMETERS, SLURM_ACCT_DB +from interface_slurmctld import Slurmctld, SlurmctldAvailableEvent, SlurmctldUnavailableEvent +from ops import ( + ActiveStatus, + BlockedStatus, + CharmBase, + ConfigChangedEvent, + InstallEvent, + RelationCreatedEvent, + StoredState, + UpdateStatusEvent, + WaitingStatus, + main, +) +from slurmdbd_ops import SlurmdbdOpsManager logger = logging.getLogger(__name__) -SLURM_ACCT_DB = "slurm_acct_db" - - -class JwtAvailable(EventBase): - """Emitted when JWT RSA is available.""" - - -class MungeAvailable(EventBase): - """Emitted when JWT RSA is available.""" - - -class WriteConfigAndRestartSlurmdbd(EventBase): - """Emitted when config needs to be written.""" - - -class SlurmdbdCharmEvents(CharmEvents): - """Slurmdbd emitted events.""" - - jwt_available = EventSource(JwtAvailable) - munge_available = EventSource(MungeAvailable) - write_config = EventSource(WriteConfigAndRestartSlurmdbd) - - class SlurmdbdCharm(CharmBase): """Slurmdbd Charm.""" _stored = StoredState() - on = SlurmdbdCharmEvents() def __init__(self, *args, **kwargs) -> None: """Set the default class attributes.""" super().__init__(*args, **kwargs) self._stored.set_default( - db_info={}, - jwt_available=False, - munge_available=False, slurm_installed=False, - cluster_name=str(), + db_info={}, ) - self._slurmdbd_manager = SlurmdbdManager() self._db = DatabaseRequires(self, relation_name="database", database_name=SLURM_ACCT_DB) - self._slurm_manager = SlurmManager(self, "slurmdbd") - self._slurmdbd = Slurmdbd(self, "slurmdbd") - self._slurmdbd_peer = SlurmdbdPeer(self, "slurmdbd-peer") self._fluentbit = FluentbitClient(self, "fluentbit") + self._slurmdbd_ops_manager = SlurmdbdOpsManager() + self._slurmctld = Slurmctld(self, "slurmctld") - for event, handler in { + event_handler_bindings = { + # Charm core events self.on.install: self._on_install, - self.on.upgrade_charm: self._on_upgrade, self.on.update_status: self._on_update_status, self.on.config_changed: self._write_config_and_restart_slurmdbd, - self.on.jwt_available: self._on_jwt_available, - self.on.munge_available: self._on_munge_available, - self.on.write_config: self._write_config_and_restart_slurmdbd, + # Database relation self._db.on.database_created: self._on_database_created, - self._slurmdbd_peer.on.slurmdbd_peer_available: self._write_config_and_restart_slurmdbd, - self._slurmdbd.on.slurmctld_available: self._on_slurmctld_available, - self._slurmdbd.on.slurmctld_unavailable: self._on_slurmctld_unavailable, + # Slurmctld + self._slurmctld.on.slurmctld_available: self._on_slurmctld_available, + self._slurmctld.on.slurmctld_unavailable: self._on_slurmctld_unavailable, # fluentbit self.on["fluentbit"].relation_created: self._on_fluentbit_relation_created, - }.items(): + } + for event, handler in event_handler_bindings.items(): self.framework.observe(event, handler) - def _on_install(self, event): + def _on_install(self, event: InstallEvent) -> None: """Perform installation operations for slurmdbd.""" - self.unit.set_workload_version(Path("version").read_text().strip()) + if not self.model.unit.is_leader(): + self.unit.status = BlockedStatus("Only singleton slurmdbd currently supported.") + event.defer() + return self.unit.status = WaitingStatus("Installing slurmdbd") - custom_repo = self.config.get("custom-slurm-repo") - successful_installation = self._slurm_manager.install(custom_repo) - - if successful_installation: + if self._slurmdbd_ops_manager.install() is not False: + self.unit.set_workload_version(self._slurmdbd_ops_manager.version) self._stored.slurm_installed = True - self.unit.status = ActiveStatus("slurmdbd successfully installed") - else: - self.unit.status = BlockedStatus("Error installing slurmdbd") - event.defer() - return + if self._slurmdbd_ops_manager.start_munge(): + logger.debug("## Munge started successfully") + else: + logger.error("## Unable to start munge") + self.unit.status = BlockedStatus("Error restarting munge") + event.defer() + return + + self.unit.status = ActiveStatus("slurmdbd successfully installed") self._check_status() - def _on_fluentbit_relation_created(self, event): + def _on_fluentbit_relation_created(self, event: RelationCreatedEvent) -> None: """Set up Fluentbit log forwarding.""" - self._configure_fluentbit() - - def _configure_fluentbit(self): logger.debug("## Configuring fluentbit") - cfg = [] - cfg.extend(self._slurm_manager.fluentbit_config_nhc) - cfg.extend(self._slurm_manager.fluentbit_config_slurm) - self._fluentbit.configure(cfg) + self._fluentbit.configure(self._slurmdbd_ops_manager.fluentbit_config_slurm) - def _on_upgrade(self, event): - """Perform upgrade operations.""" - self.unit.set_workload_version(Path("version").read_text().strip()) - - def _on_update_status(self, event): + def _on_update_status(self, event: UpdateStatusEvent) -> None: """Handle update status.""" self._check_status() - def _on_jwt_available(self, event): - """Retrieve and configure the jwt_rsa key.""" - # jwt rsa lives in slurm spool dir, it is created when slurm is installed - if not self._stored.slurm_installed: - event.defer() - return - - jwt_rsa = self._slurmdbd.get_jwt_rsa() - self._slurm_manager.configure_jwt_rsa(jwt_rsa) - self._stored.jwt_available = True - - def _on_munge_available(self, event): - """Retrieve munge key and start munged.""" - # munge is installed together with slurm - if not self._stored.slurm_installed: + def _on_slurmctld_available(self, event: SlurmctldAvailableEvent) -> None: + """Retrieve and configure the jwt_rsa and munge_key when slurmctld_available.""" + if self._stored.slurm_installed is not True: event.defer() return - munge_key = self._slurmdbd.get_munge_key() - self._slurm_manager.configure_munge_key(munge_key) + if (jwt := event.jwt_rsa) is not None: + self._slurmdbd_ops_manager.write_jwt_rsa(jwt) + if (munge_key := event.munge_key) is not None: + self._slurmdbd_ops_manager.write_munge_key(munge_key) - if self._slurm_manager.restart_munged(): - logger.debug("## Munge restarted successfully") - self._stored.munge_available = True - else: - logger.error("## Unable to restart munge") - self.unit.status = BlockedStatus("Error restarting munge") - event.defer() + self._write_config_and_restart_slurmdbd(event) def _on_database_created(self, event: DatabaseCreatedEvent) -> None: """Process the DatabaseCreatedEvent and updates the database parameters. @@ -192,7 +147,7 @@ def _on_database_created(self, event: DatabaseCreatedEvent) -> None: # a human to look at and resolve the proper next steps. Reprocessing the # deferred event will only result in continual errors. logger.error(f"No endpoints provided: {event.endpoints}") - self.unit.status = ErrorStatus("No database endpoints") + self.unit.status = BlockedStatus("No database endpoints provided") raise ValueError(f"Unexpected endpoint types: {event.endpoints}") for endpoint in [ep.strip() for ep in event.endpoints.split(",")]: @@ -205,9 +160,9 @@ def _on_database_created(self, event: DatabaseCreatedEvent) -> None: tcp_endpoints.append(endpoint) db_info = { - "db_username": event.username, - "db_password": event.password, - "db_name": SLURM_ACCT_DB, + "StorageUser": event.username, + "StoragePass": event.password, + "StorageLoc": SLURM_ACCT_DB, } if socket_endpoints: @@ -222,12 +177,12 @@ def _on_database_created(self, event: DatabaseCreatedEvent) -> None: # Make sure to strip the file:// off the front of the first endpoint # otherwise slurmdbd will not be able to connect to the database socket = urlparse(socket_endpoints[0]).path - self._slurmdbd_manager.set_environment_var(mysql_unix_port=f'"{socket}"') + self._slurmdbd_ops_manager.set_environment_var(mysql_unix_port=f'"{socket}"') elif tcp_endpoints: # This must be using TCP endpoint and the connection information will # be host_address:port. Only one remote mysql service will be configured # in this case. - logger.debug("Using tcp endpoints specified in the relation.") + logger.debug("Using tcp endpoints specified in the relation") if len(tcp_endpoints) > 1: logger.warning( f"{len(tcp_endpoints)} tcp endpoints are specified, " @@ -239,166 +194,136 @@ def _on_database_created(self, event: DatabaseCreatedEvent) -> None: addr = addr[1:-1] db_info.update( { - "db_hostname": addr, - "db_port": port, + "StorageHost": addr, + "StoragePort": port, } ) # Make sure that the MYSQL_UNIX_PORT is removed from the env file. - self._slurmdbd_manager.set_environment_var(mysql_unix_port=None) + self._slurmdbd_ops_manager.set_environment_var(mysql_unix_port=None) else: # This is 100% an error condition that the charm doesn't know how to handle # and is an unexpected condition. This happens when there are commas but no # usable data in the endpoints. logger.error(f"No endpoints provided: {event.endpoints}") - self.unit.status = ErrorStatus("No database endpoints") + self.unit.status = BlockedStatus("No database endpoints provided") raise ValueError(f"No endpoints provided: {event.endpoints}") - self.set_db_info(db_info) + self._stored.db_info = db_info self._write_config_and_restart_slurmdbd(event) - def _on_slurmctld_available(self, event): - self.on.jwt_available.emit() - self.on.munge_available.emit() - - self.on.write_config.emit() - if self._fluentbit._relation is not None: - self._configure_fluentbit() - - def _on_slurmctld_unavailable(self, event): + def _on_slurmctld_unavailable(self, event: SlurmctldUnavailableEvent) -> None: """Reset state and charm status when slurmctld broken.""" - self._stored.jwt_available = False - self._stored.munge_available = False + self._stored.slurmctld_available = False self._check_status() - def _is_leader(self): - return self.model.unit.is_leader() - - def _write_config_and_restart_slurmdbd(self, event): - """Check for prereqs before writing config/restart of slurmdbd.""" + def _write_config_and_restart_slurmdbd( + self, + event: Union[ + ConfigChangedEvent, + DatabaseCreatedEvent, + InstallEvent, + SlurmctldAvailableEvent, + ], + ) -> None: + """Check that we have what we need before we proceed.""" # Ensure all pre-conditions are met with _check_status(), if not # defer the event. if not self._check_status(): event.defer() return - slurmdbd_config = { - "slurmdbd_debug": self.config.get("slurmdbd-debug"), - **self._slurmdbd_peer.get_slurmdbd_info(), - **self._stored.db_info, - } + if ( + charm_config_slurmdbd_conf_params := self.config.get("slurmdbd-conf-parameters") + ) is not None: + if ( + charm_config_slurmdbd_conf_params + != self._stored.user_supplied_slurmdbd_conf_params + ): + logger.debug("## User supplied parameters changed.") + self._stored.user_supplied_slurmdbd_conf_params = charm_config_slurmdbd_conf_params + + if binding := self.model.get_binding("slurmctld"): + slurmdbd_full_config = { + **CHARM_MAINTAINED_PARAMETERS, + **self._stored.db_info, + **{"DbdHost": self._slurmdbd_ops_manager.hostname}, + **{"DbdAddr": f"{binding.network.ingress_address}"}, + **self._get_user_supplied_parameters(), + } + + if self._slurmctld.is_joined: + slurmdbd_full_config["AuthAltParameters"] = ( + '"jwt_key=/var/spool/slurmdbd/jwt_hs256.key"' + ) - self._slurmdbd_manager.stop() - self._slurm_manager.render_slurm_configs(slurmdbd_config) + self._slurmdbd_ops_manager.stop_slurmdbd() + self._slurmdbd_ops_manager.write_slurmdbd_conf(slurmdbd_full_config) - # At this point, we must guarantee that slurmdbd is correctly - # initialized. Its startup might take a while, so we have to wait - # for it. - self._check_slurmdbd() + # At this point, we must guarantee that slurmdbd is correctly + # initialized. Its startup might take a while, so we have to wait + # for it. + self._check_slurmdbd() - # Only the leader can set relation data on the application. - # Enforce that no one other than the leader tries to set - # application relation data. - if self.model.unit.is_leader(): - self._slurmdbd.set_slurmdbd_info_on_app_relation_data( - slurmdbd_config, - ) + # Only the leader can set relation data on the application. + # Enforce that no one other than the leader tries to set + # application relation data. + if self.model.unit.is_leader(): + self._slurmctld.set_slurmdbd_host_on_app_relation_data( + self._slurmdbd_ops_manager.hostname + ) + else: + logger.debug("Cannot get network binding. Please Debug.") + event.defer() + return self._check_status() - def _check_slurmdbd(self, max_attemps=5) -> None: + def _get_user_supplied_parameters(self) -> Dict[Any, Any]: + """Gather, parse, and return the user supplied parameters.""" + user_supplied_parameters = {} + if custom_config := self.config.get("slurmdbd-conf-parameters"): + try: + user_supplied_parameters = { + line.split("=")[0]: line.split("=")[1] + for line in str(custom_config).split("\n") + if not line.startswith("#") and line.strip() != "" + } + except IndexError as e: + logger.error(f"Could not parse user supplied parameters: {e}.") + return user_supplied_parameters + + def _check_slurmdbd(self, max_attemps: int = 5) -> None: """Ensure slurmdbd is up and running.""" logger.debug("## Checking if slurmdbd is active") for i in range(max_attemps): - if self._slurmdbd_manager.active: + if self._slurmdbd_ops_manager.is_slurmdbd_active(): logger.debug("## Slurmdbd running") break else: logger.warning("## Slurmdbd not running, trying to start it") - self.unit.status = WaitingStatus("Starting slurmdbd") - self._slurmdbd_manager.restart() + self.unit.status = WaitingStatus("Starting slurmdbd ...") + self._slurmdbd_ops_manager.restart_slurmdbd() sleep(3 + i) - if self._slurmdbd_manager.active: + if self._slurmdbd_ops_manager.is_slurmdbd_active(): self._check_status() else: - self.unit.status = BlockedStatus("Cannot start slurmdbd") + self.unit.status = BlockedStatus("Cannot start slurmdbd.") - def _check_status(self) -> bool: # noqa C901 + def _check_status(self) -> bool: """Check that we have the things we need.""" - slurm_installed = self._stored.slurm_installed - if not slurm_installed: - self.unit.status = BlockedStatus("Error installing slurm") + if self._stored.slurm_installed is not True: + self.unit.status = BlockedStatus("Error installing slurmdbd.") return False - # we must be sure to initialize the charms correctly. Slurmdbd must - # first connect to the db to be able to connect to slurmctld correctly - slurmctld_available = self._stored.jwt_available and self._stored.munge_available - statuses = { - "MySQL": { - "available": self._stored.db_info != {}, - "joined": self._stored.db_info != {}, - }, - "slurmctld": {"available": slurmctld_available, "joined": self._slurmdbd.is_joined}, - } - - relations_needed = [] - waiting_on = [] - for component in statuses.keys(): - if not statuses[component]["joined"]: - relations_needed.append(component) - if not statuses[component]["available"]: - waiting_on.append(component) - - if len(relations_needed): - msg = f"Need relations: {','.join(relations_needed)}" - self.unit.status = BlockedStatus(msg) - return False - - if len(waiting_on): - msg = f"Waiting on: {','.join(waiting_on)}" - self.unit.status = WaitingStatus(msg) - return False - - slurmdbd_info = self._slurmdbd_peer.get_slurmdbd_info() - if not slurmdbd_info: - self.unit.status = WaitingStatus("slurmdbd starting") - return False - - if not self._slurm_manager.check_munged(): - self.unit.status = WaitingStatus("munged starting") + if self._stored.db_info == {}: + self.unit.status = WaitingStatus("Waiting on: MySQL") return False - self.unit.status = ActiveStatus("slurmdbd available") + self.unit.status = ActiveStatus() return True - def get_port(self): - """Return the port from slurm-ops-manager.""" - return self._slurm_manager.port - - def get_hostname(self): - """Return the hostname from slurm-ops-manager.""" - return self._slurm_manager.hostname - - def set_db_info(self, new_db_info: Dict[str, Any]) -> None: - """Set the db_info in the stored state. - - Args: - new_db_info (Dict[str, Any]): - New backend database information to set. - """ - self._stored.db_info.update(new_db_info) - - @property - def cluster_name(self) -> str: - """Return the cluster-name.""" - return self._stored.cluster_name - - @cluster_name.setter - def cluster_name(self, name: str): - """Set the cluster-name.""" - self._stored.cluster_name = name - if __name__ == "__main__": - main(SlurmdbdCharm) + main.main(SlurmdbdCharm) diff --git a/src/constants.py b/src/constants.py new file mode 100644 index 0000000..23a1afe --- /dev/null +++ b/src/constants.py @@ -0,0 +1,52 @@ +# Copyright 2024 Omnivector, LLC. +# See LICENSE file for licensing details. +"""Constants.""" +from pathlib import Path + +CHARM_MAINTAINED_PARAMETERS = { + "DbdPort": "6819", + "AuthType": "auth/munge", + "AuthInfo": '"socket=/var/run/munge/munge.socket.2"', + "SlurmUser": "slurm", + "PluginDir": "/usr/lib/x86_64-linux-gnu/slurm-wlm", + "PidFile": "/var/run/slurmdbd.pid", + "LogFile": "/var/log/slurm/slurmdbd.log", + "StorageType": "accounting_storage/mysql", +} + +SLURM_ACCT_DB = "slurm_acct_db" + +SLURMDBD_DEFAULTS_FILE = Path("/etc/default/slurmdbd") + +UBUNTU_HPC_PPA_KEY: str = """ +-----BEGIN PGP PUBLIC KEY BLOCK----- +Comment: Hostname: +Version: Hockeypuck 2.1.1-10-gec3b0e7 + +xsFNBGTuZb8BEACtJ1CnZe6/hv84DceHv+a54y3Pqq0gqED0xhTKnbj/E2ByJpmT +NlDNkpeITwPAAN1e3824Me76Qn31RkogTMoPJ2o2XfG253RXd67MPxYhfKTJcnM3 +CEkmeI4u2Lynh3O6RQ08nAFS2AGTeFVFH2GPNWrfOsGZW03Jas85TZ0k7LXVHiBs +W6qonbsFJhshvwC3SryG4XYT+z/+35x5fus4rPtMrrEOD65hij7EtQNaE8owuAju +Kcd0m2b+crMXNcllWFWmYMV0VjksQvYD7jwGrWeKs+EeHgU8ZuqaIP4pYHvoQjag +umqnH9Qsaq5NAXiuAIAGDIIV4RdAfQIR4opGaVgIFJdvoSwYe3oh2JlrLPBlyxyY +dayDifd3X8jxq6/oAuyH1h5K/QLs46jLSR8fUbG98SCHlRmvozTuWGk+e07ALtGe +sGv78ToHKwoM2buXaTTHMwYwu7Rx8LZ4bZPHdersN1VW/m9yn1n5hMzwbFKy2s6/ +D4Q2ZBsqlN+5aW2q0IUmO+m0GhcdaDv8U7RVto1cWWPr50HhiCi7Yvei1qZiD9jq +57oYZVqTUNCTPxi6NeTOdEc+YqNynWNArx4PHh38LT0bqKtlZCGHNfoAJLPVYhbB +b2AHj9edYtHU9AAFSIy+HstET6P0UDxy02IeyE2yxoUBqdlXyv6FL44E+wARAQAB +zRxMYXVuY2hwYWQgUFBBIGZvciBVYnVudHUgSFBDwsGOBBMBCgA4FiEErocSHcPk +oLD4H/Aj9tDF1ca+s3sFAmTuZb8CGwMFCwkIBwIGFQoJCAsCBBYCAwECHgECF4AA +CgkQ9tDF1ca+s3sz3w//RNawsgydrutcbKf0yphDhzWS53wgfrs2KF1KgB0u/H+u +6Kn2C6jrVM0vuY4NKpbEPCduOj21pTCepL6PoCLv++tICOLVok5wY7Zn3WQFq0js +Iy1wO5t3kA1cTD/05v/qQVBGZ2j4DsJo33iMcQS5AjHvSr0nu7XSvDDEE3cQE55D +87vL7lgGjuTOikPh5FpCoS1gpemBfwm2Lbm4P8vGOA4/witRjGgfC1fv1idUnZLM +TbGrDlhVie8pX2kgB6yTYbJ3P3kpC1ZPpXSRWO/cQ8xoYpLBTXOOtqwZZUnxyzHh +gM+hv42vPTOnCo+apD97/VArsp59pDqEVoAtMTk72fdBqR+BB77g2hBkKESgQIEq +EiE1/TOISioMkE0AuUdaJ2ebyQXugSHHuBaqbEC47v8t5DVN5Qr9OriuzCuSDNFn +6SBHpahN9ZNi9w0A/Yh1+lFfpkVw2t04Q2LNuupqOpW+h3/62AeUqjUIAIrmfeML +IDRE2VdquYdIXKuhNvfpJYGdyvx/wAbiAeBWg0uPSepwTfTG59VPQmj0FtalkMnN +ya2212K5q68O5eXOfCnGeMvqIXxqzpdukxSZnLkgk40uFJnJVESd/CxHquqHPUDE +fy6i2AnB3kUI27D4HY2YSlXLSRbjiSxTfVwNCzDsIh7Czefsm6ITK2+cVWs0hNQ= +=cs1s +-----END PGP PUBLIC KEY BLOCK----- +""" diff --git a/src/interface_slurmctld.py b/src/interface_slurmctld.py new file mode 100644 index 0000000..ec692d8 --- /dev/null +++ b/src/interface_slurmctld.py @@ -0,0 +1,118 @@ +"""Slurmctld interface for slurmdbd-operator.""" + +import json +import logging +from typing import List, Union + +from ops import ( + EventBase, + EventSource, + Object, + ObjectEvents, + Relation, + RelationBrokenEvent, + RelationChangedEvent, +) + +logger = logging.getLogger() + + +class SlurmctldAvailableEvent(EventBase): + """Emitted when slurmctld is unavailable.""" + + def __init__(self, handle, munge_key, jwt_rsa): + super().__init__(handle) + + self.munge_key = munge_key + self.jwt_rsa = jwt_rsa + + def snapshot(self): + """Snapshot the event data.""" + return { + "munge_key": self.munge_key, + "jwt_rsa": self.jwt_rsa, + } + + def restore(self, snapshot): + """Restore the snapshot of the event data.""" + self.munge_key = snapshot.get("munge_key") + self.jwt_rsa = snapshot.get("jwt_rsa") + + +class SlurmctldUnavailableEvent(EventBase): + """Emitted when slurmctld joins the relation.""" + + +class Events(ObjectEvents): + """Slurmctld relation events.""" + + slurmctld_available = EventSource(SlurmctldAvailableEvent) + slurmctld_unavailable = EventSource(SlurmctldUnavailableEvent) + + +class Slurmctld(Object): + """Slurmctld interface for slurmdbd.""" + + on = Events() # pyright: ignore [reportIncompatibleMethodOverride, reportAssignmentType] + + def __init__(self, charm, relation_name): + """Observe relation lifecycle events.""" + super().__init__(charm, relation_name) + + self._charm = charm + self._relation_name = relation_name + + self.framework.observe( + self._charm.on[self._relation_name].relation_changed, + self._on_relation_changed, + ) + + self.framework.observe( + self._charm.on[self._relation_name].relation_broken, + self._on_relation_broken, + ) + + @property + def _relations(self) -> Union[List[Relation], None]: + return self.model.relations.get(self._relation_name) + + @property + def is_joined(self) -> bool: + """Return True if self._relation is not None.""" + return True if self._relations else False + + def _on_relation_changed(self, event: RelationChangedEvent) -> None: + """Handle the relation-changed event. + + Get the cluster_info (munge_key and jwt_rsa) from slurmctld and emit to the charm. + """ + if app := event.app: + event_app_data = event.relation.data[app] + if cluster_info_json := event_app_data.get("cluster_info"): + try: + cluster_info = json.loads(cluster_info_json) + except json.JSONDecodeError as e: + logger.debug(e) + raise (e) + + self.on.slurmctld_available.emit(**cluster_info) + logger.debug(f"## 'cluster_info': {cluster_info}.") + else: + logger.debug("'cluster_info' not in application relation data.") + else: + logger.debug("## No application in the event.") + + def _on_relation_broken(self, event: RelationBrokenEvent) -> None: + """Clear the application relation data and emit the event.""" + self.set_slurmdbd_host_on_app_relation_data("") + self.on.slurmctld_unavailable.emit() + + def set_slurmdbd_host_on_app_relation_data(self, slurmdbd_host: str) -> None: + """Send slurmdbd_info to slurmctld.""" + # Iterate over each of the relations setting the relation data. + if (relations := self._relations) is not None: + logger.debug(f"## Setting slurmdbd_host on app relation data: {slurmdbd_host}") + for relation in relations: + relation.data[self.model.app]["slurmdbd_host"] = slurmdbd_host + else: + logger.debug("## No relation, not setting data.") diff --git a/src/interface_slurmdbd.py b/src/interface_slurmdbd.py deleted file mode 100644 index e113f81..0000000 --- a/src/interface_slurmdbd.py +++ /dev/null @@ -1,149 +0,0 @@ -#! /usr/bin/env python3 -"""Slurmdbd.""" -import json -import logging - -from ops.framework import ( - EventBase, - EventSource, - Object, - ObjectEvents, - StoredState, -) - -logger = logging.getLogger() - - -class SlurmctldUnAvailableEvent(EventBase): - """Emitted when slurmctld is unavailable.""" - - -class SlurmctldAvailableEvent(EventBase): - """Emitted when slurmctld joins the relation.""" - - -class SlurmdbdEvents(ObjectEvents): - """Slurmdbd relation events.""" - - slurmctld_available = EventSource(SlurmctldAvailableEvent) - slurmctld_unavailable = EventSource(SlurmctldUnAvailableEvent) - - -class Slurmdbd(Object): - """Slurmdbd.""" - - _stored = StoredState() - on = SlurmdbdEvents() - - def __init__(self, charm, relation_name): - """Observe relation lifecycle events.""" - super().__init__(charm, relation_name) - - self._charm = charm - self._relation_name = relation_name - - self._stored.set_default( - munge_key=str(), - jwt_key=str(), - slurmctld_joined=False, - ) - - self.framework.observe( - self._charm.on[self._relation_name].relation_created, - self._on_relation_created, - ) - - self.framework.observe( - self._charm.on[self._relation_name].relation_joined, - self._on_relation_joined, - ) - - self.framework.observe( - self._charm.on[self._relation_name].relation_broken, - self._on_relation_broken, - ) - - @property - def is_joined(self): - """Return True if juju related slurmdbd <-> slurmctld.""" - return self._stored.slurmctld_joined - - @is_joined.setter - def is_joined(self, flag): - """Set the is_joined property.""" - self._stored.slurmctld_joined = flag - - def _on_relation_created(self, event): - """Handle the relation-created event.""" - self.is_joined = True - - def _on_relation_joined(self, event): - """Handle the relation-joined event. - - Get the munge_key and jwt_rsa from slurmctld and save it to the charm - stored state. - """ - # Since we are in relation-joined (with the app on the other side) - # we can almost guarantee that the app object will exist in - # the event, but check for it just in case. - event_app_data = event.relation.data.get(event.app) - if not event_app_data: - event.defer() - return - - # slurmctld sets the munge_key on the relation-created event - # which happens before relation-joined. We can almost guarantee that - # the munge key will exist at this point, but check for it just in case. - munge_key = event_app_data.get("munge_key") - if not munge_key: - event.defer() - return - - # slurmctld sets the jwt_rsa on the relation-created event - # which happens before relation-joined. We can almost guarantee that - # the jwt_rsa will exist at this point, but check for it just in case. - jwt_rsa = event_app_data.get("jwt_rsa") - if not jwt_rsa: - event.defer() - return - - # Store the munge_key and jwt_rsa in the interface's stored state - # object and emit the slurmctld_available event. - self._store_munge_key(munge_key) - self._store_jwt_rsa(jwt_rsa) - self.on.slurmctld_available.emit() - - self._charm.cluster_name = event_app_data.get("cluster_name") - - def _on_relation_broken(self, event): - """Clear the application relation data and emit the event.""" - self.set_slurmdbd_info_on_app_relation_data("") - self.is_joined = False - self.on.slurmctld_unavailable.emit() - - def set_slurmdbd_info_on_app_relation_data(self, slurmdbd_info): - """Send slurmdbd_info to slurmctld.""" - logger.debug(f"## Setting info in app relation data: {slurmdbd_info}") - relations = self.framework.model.relations["slurmdbd"] - # Iterate over each of the relations setting the relation data. - for relation in relations: - if slurmdbd_info != "": - relation.data[self.model.app]["slurmdbd_info"] = json.dumps(slurmdbd_info) - else: - relation.data[self.model.app]["slurmdbd_info"] = "" - - def _store_munge_key(self, munge_key): - """Set the munge key in the stored state.""" - self._stored.munge_key = munge_key - - def get_munge_key(self): - """Retrieve the munge key from the stored state.""" - return self._stored.munge_key - - def _store_jwt_rsa(self, jwt_rsa): - """Store the jwt_rsa in the interface stored state.""" - self._stored.jwt_rsa = jwt_rsa - - def get_jwt_rsa(self): - """Retrieve the jwt_rsa from the stored state.""" - return self._stored.jwt_rsa diff --git a/src/interface_slurmdbd_peer.py b/src/interface_slurmdbd_peer.py deleted file mode 100644 index d4abd60..0000000 --- a/src/interface_slurmdbd_peer.py +++ /dev/null @@ -1,183 +0,0 @@ -#!/usr/bin/env python3 -"""SlurmdbdPeer.""" -import copy -import json -import logging -import subprocess - -from ops.framework import EventBase, EventSource, Object, ObjectEvents - -logger = logging.getLogger() - - -class SlurmdbdPeerAvailableEvent(EventBase): - """Emitted in the relation_changed event when a peer comes online.""" - - -class SlurmdbdPeerRelationEvents(ObjectEvents): - """Slurmdbd peer relation events.""" - - slurmdbd_peer_available = EventSource(SlurmdbdPeerAvailableEvent) - - -class SlurmdbdPeer(Object): - """SlurmdbdPeer Interface.""" - - on = SlurmdbdPeerRelationEvents() - - def __init__(self, charm, relation_name): - """Initialize and observe.""" - super().__init__(charm, relation_name) - self._charm = charm - self._relation_name = relation_name - - self.framework.observe( - self._charm.on[self._relation_name].relation_created, - self._on_relation_created, - ) - - self.framework.observe( - self._charm.on[self._relation_name].relation_changed, - self._on_relation_changed, - ) - - self.framework.observe( - self._charm.on[self._relation_name].relation_departed, - self._on_relation_departed, - ) - - def _on_relation_created(self, event): - """Set hostname and port on the unit data.""" - relation = self.framework.model.get_relation(self._relation_name) - unit_relation_data = relation.data[self.model.unit] - - unit_relation_data["hostname"] = self._charm.get_hostname() - unit_relation_data["port"] = self._charm.get_port() - - # Call _on_relation_changed to assemble the slurmdbd_info and - # emit the slurmdbd_peer_available event. - self._on_relation_changed(event) - - def _on_relation_changed(self, event): - """Use the leader and app relation data to schedule active/backup.""" - # We only modify the slurmdbd queue if we are the leader. - # As such, we don't need to perform any operations here - # if we are not the leader. - if self.framework.model.unit.is_leader(): - relation = self.framework.model.get_relation(self._relation_name) - - app_relation_data = relation.data[self.model.app] - unit_relation_data = relation.data[self.model.unit] - - slurmdbd_peers = _get_active_peers() - slurmdbd_peers_tmp = copy.deepcopy(slurmdbd_peers) - - active_slurmdbd = app_relation_data.get("active_slurmdbd") - backup_slurmdbd = app_relation_data.get("backup_slurmdbd") - - # Account for the active slurmdbd - # In this case, tightly couple the active slurmdbd to the leader. - # - # If we are the leader but are not the active slurmdbd, - # then the previous slurmdbd leader must have died. - # Set our unit to the active_slurmdbd. - if active_slurmdbd != self.model.unit.name: - app_relation_data["active_slurmdbd"] = self.model.unit.name - - # Account for the backup and standby slurmdbd - # - # If the backup slurmdbd exists in the application relation data - # then check that it also exists in the slurmdbd_peers. If it does - # exist in the slurmdbd peers then remove it from the list of - # active peers and set the rest of the peers to be standbys. - if backup_slurmdbd: - # Just because the backup_slurmdbd exists in the application - # data doesn't mean that it really exists. Check that the - # backup_slurmdbd that we have in the application data still - # exists in the list of active units. If the backup_slurmdbd - # isn't in the list of active units then check for - # slurmdbd_peers > 0 and try to promote a standby to a backup. - if backup_slurmdbd in slurmdbd_peers: - slurmdbd_peers_tmp.remove(backup_slurmdbd) - app_relation_data["standby_slurmdbd"] = json.dumps(slurmdbd_peers_tmp) - else: - if len(slurmdbd_peers) > 0: - app_relation_data["backup_slurmdbd"] = slurmdbd_peers_tmp.pop() - app_relation_data["standby_slurmdbd"] = json.dumps(slurmdbd_peers_tmp) - else: - app_relation_data["backup_slurmdbd"] = "" - app_relation_data["standby_slurmdbd"] = json.dumps([]) - else: - if len(slurmdbd_peers) > 0: - app_relation_data["backup_slurmdbd"] = slurmdbd_peers_tmp.pop() - app_relation_data["standby_slurmdbd"] = json.dumps(slurmdbd_peers_tmp) - else: - app_relation_data["standby_slurmdbd"] = json.dumps([]) - - ctxt = {} - backup_slurmdbd = app_relation_data.get("backup_slurmdbd") - - # NOTE: We only care about the active and backup slurdbd. - # Set the active slurmdbd info and check for and set the - # backup slurmdbd information if one exists. - ctxt["active_slurmdbd_ingress_address"] = unit_relation_data["ingress-address"] - ctxt["active_slurmdbd_hostname"] = self._charm.get_hostname() - ctxt["active_slurmdbd_port"] = str(self._charm.get_port()) - - # If we have > 0 slurmdbd (also have a backup), iterate over - # them retrieving the info for the backup and set it along with - # the info for the active slurmdbd, then emit the - # 'slurmdbd_peer_available' event. - if backup_slurmdbd: - for unit in relation.units: - if unit.name == backup_slurmdbd: - unit_data = relation.data[unit] - ctxt["backup_slurmdbd_ingress_address"] = unit_data["ingress-address"] - ctxt["backup_slurmdbd_hostname"] = unit_data["hostname"] - ctxt["backup_slurmdbd_port"] = unit_data["port"] - else: - ctxt["backup_slurmdbd_ingress_address"] = "" - ctxt["backup_slurmdbd_hostname"] = "" - ctxt["backup_slurmdbd_port"] = "" - - app_relation_data["slurmdbd_info"] = json.dumps(ctxt) - self.on.slurmdbd_peer_available.emit() - - def _on_relation_departed(self, event): - self.on.slurmdbd_peer_available.emit() - - @property - def _relation(self): - return self.framework.model.get_relation(self._relation_name) - - def get_slurmdbd_info(self) -> dict: - """Return slurmdbd info.""" - relation = self._relation - if relation: - app = relation.app - if app: - slurmdbd_info = relation.data[app].get("slurmdbd_info") - if slurmdbd_info: - return json.loads(slurmdbd_info) - return {} - - -def _related_units(relid): - """List of related units.""" - units_cmd_line = ["relation-list", "--format=json", "-r", relid] - return json.loads(subprocess.check_output(units_cmd_line).decode("UTF-8")) or [] - - -def _relation_ids(reltype): - """List of relation_ids.""" - relid_cmd_line = ["relation-ids", "--format=json", reltype] - return json.loads(subprocess.check_output(relid_cmd_line).decode("UTF-8")) or [] - - -def _get_active_peers(): - """Return the active_units.""" - active_units = [] - for rel_id in _relation_ids("slurmdbd-peer"): - for unit in _related_units(rel_id): - active_units.append(unit) - return active_units diff --git a/src/slurmdbd_ops.py b/src/slurmdbd_ops.py new file mode 100644 index 0000000..b7d07bf --- /dev/null +++ b/src/slurmdbd_ops.py @@ -0,0 +1,358 @@ +# Copyright 2020-2024 Omnivector, LLC. +"""SlurmdbdOpsManager.""" +import logging +import os +import socket +import subprocess +import textwrap +from base64 import b64decode +from datetime import datetime +from grp import getgrnam +from pathlib import Path +from pwd import getpwnam +from typing import Optional + +import charms.operator_libs_linux.v0.apt as apt +import charms.operator_libs_linux.v1.systemd as systemd +import distro +from constants import SLURMDBD_DEFAULTS_FILE, UBUNTU_HPC_PPA_KEY + +logger = logging.getLogger() + + +class SlurmdbdManagerError(BaseException): + """Exception for use with SlurmdbdManager.""" + + def __init__(self, message): + super().__init__(message) + self.message = message + + +class CharmedHPCPackageLifecycleManager: + """Facilitate ubuntu-hpc slurm component package lifecycles.""" + + def __init__(self, package_name: str): + self._package_name = package_name + self._keyring_path = Path(f"/usr/share/keyrings/ubuntu-hpc-{self._package_name}.asc") + + def _repo(self) -> apt.DebianRepository: + """Return the ubuntu-hpc repo.""" + ppa_url: str = "https://ppa.launchpadcontent.net/ubuntu-hpc/slurm-wlm-23.02/ubuntu" + sources_list: str = ( + f"deb [signed-by={self._keyring_path}] {ppa_url} {distro.codename()} main" + ) + return apt.DebianRepository.from_repo_line(sources_list) + + def install(self) -> bool: + """Install package using lib apt.""" + package_installed = False + + if self._keyring_path.exists(): + self._keyring_path.unlink() + self._keyring_path.write_text(UBUNTU_HPC_PPA_KEY) + + repositories = apt.RepositoryMapping() + repositories.add(self._repo()) + + try: + apt.update() + apt.add_package([self._package_name]) + package_installed = True + except apt.PackageNotFoundError: + logger.error(f"'{self._package_name}' not found in package cache or on system.") + except apt.PackageError as e: + logger.error(f"Could not install '{self._package_name}'. Reason: {e.message}") + + return package_installed + + def uninstall(self) -> None: + """Uninstall the package using libapt.""" + if apt.remove_package(self._package_name): + logger.info(f"'{self._package_name}' removed from system.") + else: + logger.error(f"'{self._package_name}' not found on system.") + + repositories = apt.RepositoryMapping() + repositories.disable(self._repo()) + + if self._keyring_path.exists(): + self._keyring_path.unlink() + + def upgrade_to_latest(self) -> None: + """Upgrade package to latest.""" + try: + slurm_package = apt.DebianPackage.from_system(self._package_name) + slurm_package.ensure(apt.PackageState.Latest) + logger.info(f"Updated '{self._package_name}' to: {slurm_package.version.number}.") + except apt.PackageNotFoundError: + logger.error(f"'{self._package_name}' not found in package cache or on system.") + except apt.PackageError as e: + logger.error(f"Could not install '{self._package_name}'. Reason: {e.message}") + + def version(self) -> str: + """Return the package version.""" + slurmdbd_vers = "" + try: + slurmdbd_vers = apt.DebianPackage.from_installed_package( + self._package_name + ).version.number + except apt.PackageNotFoundError: + logger.error(f"'{self._package_name}' not found on system.") + return slurmdbd_vers + + +class SlurmdbdOpsManager: + """SlurmdbdOpsManager.""" + + def __init__(self): + """Set the initial attribute values.""" + self._slurm_component = "slurmdbd" + + self._slurm_state_dir = Path("/var/spool/slurmdbd") + self._slurmdbd_conf_dir = Path("/etc/slurm") + self._slurmdbd_log_dir = Path("/var/log/slurm") + self._slurm_user = "slurm" + self._slurm_group = "slurm" + + self._munge_package = CharmedHPCPackageLifecycleManager("munge") + self._slurmdbd_package = CharmedHPCPackageLifecycleManager("slurmdbd") + + def write_slurmdbd_conf(self, slurmdbd_parameters: dict) -> None: + """Render slurmdbd.conf.""" + slurmdbd_conf = self._slurmdbd_conf_dir / "slurmdbd.conf" + slurm_user_uid = getpwnam(self._slurm_user).pw_uid + slurm_group_gid = getgrnam(self._slurm_group).gr_gid + + header = textwrap.dedent( + f""" + # + # {slurmdbd_conf} generated at {datetime.now()} + # + + """ + ) + slurmdbd_conf.write_text( + header + "\n".join([f"{k}={v}" for k, v in slurmdbd_parameters.items() if v != ""]) + ) + + slurmdbd_conf.chmod(0o600) + os.chown(f"{slurmdbd_conf}", slurm_user_uid, slurm_group_gid) + + def write_munge_key(self, munge_key_data: str) -> None: + """Base64 decode and write the munge key.""" + munge_key_path = Path("/etc/munge/munge.key") + munge_key_path.write_bytes(b64decode(munge_key_data.encode())) + + def write_jwt_rsa(self, jwt_rsa: str) -> None: + """Write the jwt_rsa key and set permissions.""" + jwt_rsa_path = self._slurm_state_dir / "jwt_hs256.key" + slurm_user_uid = getpwnam(self._slurm_user).pw_uid + slurm_group_gid = getgrnam(self._slurm_group).gr_gid + + # Write the jwt_rsa key to the file and chmod 0600 + chown to slurm_user. + jwt_rsa_path.write_text(jwt_rsa) + jwt_rsa_path.chmod(0o600) + + os.chown(f"{jwt_rsa_path}", slurm_user_uid, slurm_group_gid) + + def install(self) -> bool: + """Install slurmdbd and munge to the system and setup paths.""" + slurm_user_uid = getpwnam(self._slurm_user).pw_uid + slurm_group_gid = getgrnam(self._slurm_group).gr_gid + + if self._slurmdbd_package.install() is not True: + return False + systemd.service_stop("slurmdbd") + + if self._munge_package.install() is not True: + return False + systemd.service_stop("munge") + + # Create needed paths with correct permissions. + for syspath in [self._slurmdbd_conf_dir, self._slurmdbd_log_dir, self._slurm_state_dir]: + if not syspath.exists(): + syspath.mkdir() + os.chown(f"{syspath}", slurm_user_uid, slurm_group_gid) + return True + + def stop_slurmdbd(self) -> None: + """Stop slurmdbd.""" + systemd.service_stop("slurmdbd") + + def is_slurmdbd_active(self) -> bool: + """Get if slurmdbd daemon is active or not.""" + return systemd.service_running("slurmdbd") + + def stop_munge(self) -> None: + """Stop munge.""" + systemd.service_stop("munge") + + def start_munge(self) -> bool: + """Start the munged process. + + Return True on success, and False otherwise. + """ + logger.debug("Starting munge.") + try: + systemd.service_start("munge") + # Ignore pyright error for is not a valid exception class, reportGeneralTypeIssues + except SlurmdbdManagerError( + "Cannot start munge." + ) as e: # pyright: ignore [reportGeneralTypeIssues] + logger.error(e) + return False + return self.check_munged() + + def check_munged(self) -> bool: + """Check if munge is working correctly.""" + if not systemd.service_running("munge"): + return False + + output = "" + # check if munge is working, i.e., can use the credentials correctly + try: + logger.debug("## Testing if munge is working correctly") + munge = subprocess.Popen( + ["munge", "-n"], stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + if munge is not None: + unmunge = subprocess.Popen( + ["unmunge"], stdin=munge.stdout, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + output = unmunge.communicate()[0].decode() + if "Success" in output: + logger.debug(f"## Munge working as expected: {output}") + return True + logger.error(f"## Munge not working: {output}") + except subprocess.CalledProcessError as e: + logger.error(f"## Error testing munge: {e}") + + return False + + def restart_slurmdbd(self) -> bool: + """Restart the slurmdbd process. + + Return True on success, and False otherwise. + """ + logger.debug("Attempting to restart slurmdbd.") + try: + systemd.service_restart("slurmdbd") + # Ignore pyright error for is not a valid exception class, reportGeneralTypeIssues + except SlurmdbdManagerError( + "Cannot restart slurmdbd." + ) as e: # pyright: ignore [reportGeneralTypeIssues] + logger.error(e) + return False + return True + + @property + def fluentbit_config_slurm(self) -> list: + """Return Fluentbit configuration parameters to forward Slurm logs.""" + log_file = self._slurmdbd_log_dir / "slurmdbd.log" + + cfg = [ + { + "input": [ + ("name", "tail"), + ("path", log_file.as_posix()), + ("path_key", "filename"), + ("tag", self._slurm_component), + ("parser", "slurm"), + ] + }, + { + "parser": [ + ("name", "slurm"), + ("format", "regex"), + ("regex", r"^\[(?