diff --git a/lib/charms/hydra/v0/oauth.py b/lib/charms/hydra/v0/oauth.py new file mode 100644 index 00000000..375c690d --- /dev/null +++ b/lib/charms/hydra/v0/oauth.py @@ -0,0 +1,803 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. + +"""# Oauth Library. + +This library is designed to enable applications to register OAuth2/OIDC +clients with an OIDC Provider through the `oauth` interface. + +## Getting started + +To get started using this library you just need to fetch the library using `charmcraft`. **Note +that you also need to add `jsonschema` to your charm's `requirements.txt`.** + +```shell +cd some-charm +charmcraft fetch-lib charms.hydra.v0.oauth +EOF +``` + +Then, to initialize the library: +```python +# ... +from charms.hydra.v0.oauth import ClientConfig, OAuthRequirer + +OAUTH = "oauth" +OAUTH_SCOPES = "openid email" +OAUTH_GRANT_TYPES = ["authorization_code"] + +class SomeCharm(CharmBase): + def __init__(self, *args): + # ... + self.oauth = OAuthRequirer(self, client_config, relation_name=OAUTH) + + self.framework.observe(self.oauth.on.oauth_info_changed, self._configure_application) + # ... + + def _on_ingress_ready(self, event): + self.external_url = "https://example.com" + self._set_client_config() + + def _set_client_config(self): + client_config = ClientConfig( + urljoin(self.external_url, "/oauth/callback"), + OAUTH_SCOPES, + OAUTH_GRANT_TYPES, + ) + self.oauth.update_client_config(client_config) +``` +""" + +import json +import logging +import re +from dataclasses import asdict, dataclass, field, fields +from typing import Dict, List, Mapping, Optional + +import jsonschema +from ops.charm import CharmBase, RelationBrokenEvent, RelationChangedEvent, RelationCreatedEvent +from ops.framework import EventBase, EventSource, Handle, Object, ObjectEvents +from ops.model import Relation, Secret, TooManyRelatedAppsError + +# The unique Charmhub library identifier, never change it +LIBID = "a3a301e325e34aac80a2d633ef61fe97" + +# Increment this major API version when introducing breaking changes +LIBAPI = 0 + +# Increment this PATCH version before using `charmcraft publish-lib` or reset +# to 0 if you are raising the major API version +LIBPATCH = 9 + +PYDEPS = ["jsonschema"] + + +logger = logging.getLogger(__name__) + +DEFAULT_RELATION_NAME = "oauth" +ALLOWED_GRANT_TYPES = [ + "authorization_code", + "refresh_token", + "client_credentials", + "urn:ietf:params:oauth:grant-type:device_code", +] +ALLOWED_CLIENT_AUTHN_METHODS = ["client_secret_basic", "client_secret_post"] +CLIENT_SECRET_FIELD = "secret" + +url_regex = re.compile( + r"(^http://)|(^https://)" # http:// or https:// + r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|" + r"[A-Z0-9-]{2,}\.?)|" # domain... + r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})" # ...or ip + r"(?::\d+)?" # optional port + r"(?:/?|[/?]\S+)$", + re.IGNORECASE, +) + +OAUTH_PROVIDER_JSON_SCHEMA = { + "$schema": "http://json-schema.org/draft-07/schema", + "$id": "https://canonical.github.io/charm-relation-interfaces/interfaces/oauth/schemas/provider.json", + "type": "object", + "properties": { + "issuer_url": { + "type": "string", + }, + "authorization_endpoint": { + "type": "string", + }, + "token_endpoint": { + "type": "string", + }, + "introspection_endpoint": { + "type": "string", + }, + "userinfo_endpoint": { + "type": "string", + }, + "jwks_endpoint": { + "type": "string", + }, + "scope": { + "type": "string", + }, + "client_id": { + "type": "string", + }, + "client_secret_id": { + "type": "string", + }, + "groups": {"type": "string", "default": None}, + "ca_chain": {"type": "array", "items": {"type": "string"}, "default": []}, + "jwt_access_token": {"type": "string", "default": "False"}, + }, + "required": [ + "issuer_url", + "authorization_endpoint", + "token_endpoint", + "introspection_endpoint", + "userinfo_endpoint", + "jwks_endpoint", + "scope", + ], +} +OAUTH_REQUIRER_JSON_SCHEMA = { + "$schema": "http://json-schema.org/draft-07/schema", + "$id": "https://canonical.github.io/charm-relation-interfaces/interfaces/oauth/schemas/requirer.json", + "type": "object", + "properties": { + "redirect_uri": { + "type": "string", + "default": None, + }, + "audience": {"type": "array", "default": [], "items": {"type": "string"}}, + "scope": {"type": "string", "default": None}, + "grant_types": { + "type": "array", + "default": None, + "items": { + "enum": ALLOWED_GRANT_TYPES, + "type": "string", + }, + }, + "token_endpoint_auth_method": { + "type": "string", + "enum": ALLOWED_CLIENT_AUTHN_METHODS, + "default": "client_secret_basic", + }, + }, + "required": ["redirect_uri", "audience", "scope", "grant_types", "token_endpoint_auth_method"], +} + + +class ClientConfigError(Exception): + """Emitted when invalid client config is provided.""" + + +class DataValidationError(RuntimeError): + """Raised when data validation fails on relation data.""" + + +def _load_data(data: Mapping, schema: Optional[Dict] = None) -> Dict: + """Parses nested fields and checks whether `data` matches `schema`.""" + ret = {} + for k, v in data.items(): + try: + ret[k] = json.loads(v) + except json.JSONDecodeError: + ret[k] = v + + if schema: + _validate_data(ret, schema) + return ret + + +def _dump_data(data: Dict, schema: Optional[Dict] = None) -> Dict: + if schema: + _validate_data(data, schema) + + ret = {} + for k, v in data.items(): + if isinstance(v, (list, dict)): + try: + ret[k] = json.dumps(v) + except json.JSONDecodeError as e: + raise DataValidationError(f"Failed to encode relation json: {e}") + elif isinstance(v, bool): + ret[k] = str(v) + else: + ret[k] = v + return ret + + +def strtobool(val: str) -> bool: + """Convert a string representation of truth to true (1) or false (0). + + True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values + are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if + 'val' is anything else. + """ + if not isinstance(val, str): + raise ValueError(f"invalid value type {type(val)}") + + val = val.lower() + if val in ("y", "yes", "t", "true", "on", "1"): + return True + elif val in ("n", "no", "f", "false", "off", "0"): + return False + else: + raise ValueError(f"invalid truth value {val}") + + +class OAuthRelation(Object): + """A class containing helper methods for oauth relation.""" + + def _pop_relation_data(self, relation_id: Relation) -> None: + if not self.model.unit.is_leader(): + return + + if len(self.model.relations) == 0: + return + + relation = self.model.get_relation(self._relation_name, relation_id=relation_id) + if not relation or not relation.app: + return + + try: + for data in list(relation.data[self.model.app]): + relation.data[self.model.app].pop(data, "") + except Exception as e: + logger.info(f"Failed to pop the relation data: {e}") + + +def _validate_data(data: Dict, schema: Dict) -> None: + """Checks whether `data` matches `schema`. + + Will raise DataValidationError if the data is not valid, else return None. + """ + try: + jsonschema.validate(instance=data, schema=schema) + except jsonschema.ValidationError as e: + raise DataValidationError(data, schema) from e + + +@dataclass +class ClientConfig: + """Helper class containing a client's configuration.""" + + redirect_uri: str + scope: str + grant_types: List[str] + audience: List[str] = field(default_factory=lambda: []) + token_endpoint_auth_method: str = "client_secret_basic" + client_id: Optional[str] = None + + def validate(self) -> None: + """Validate the client configuration.""" + # Validate redirect_uri + if not re.match(url_regex, self.redirect_uri): + raise ClientConfigError(f"Invalid URL {self.redirect_uri}") + + if self.redirect_uri.startswith("http://"): + logger.warning("Provided Redirect URL uses http scheme. Don't do this in production") + + # Validate grant_types + for grant_type in self.grant_types: + if grant_type not in ALLOWED_GRANT_TYPES: + raise ClientConfigError( + f"Invalid grant_type {grant_type}, must be one " f"of {ALLOWED_GRANT_TYPES}" + ) + + # Validate client authentication methods + if self.token_endpoint_auth_method not in ALLOWED_CLIENT_AUTHN_METHODS: + raise ClientConfigError( + f"Invalid client auth method {self.token_endpoint_auth_method}, " + f"must be one of {ALLOWED_CLIENT_AUTHN_METHODS}" + ) + + def to_dict(self) -> Dict: + """Convert object to dict.""" + return {k: v for k, v in asdict(self).items() if v is not None} + + +@dataclass +class OauthProviderConfig: + """Helper class containing provider's configuration.""" + + issuer_url: str + authorization_endpoint: str + token_endpoint: str + introspection_endpoint: str + userinfo_endpoint: str + jwks_endpoint: str + scope: str + client_id: Optional[str] = None + client_secret: Optional[str] = None + groups: Optional[str] = None + ca_chain: Optional[str] = None + jwt_access_token: Optional[bool] = False + + @classmethod + def from_dict(cls, dic: Dict) -> "OauthProviderConfig": + """Generate OauthProviderConfig instance from dict.""" + jwt_access_token = False + if "jwt_access_token" in dic: + jwt_access_token = strtobool(dic["jwt_access_token"]) + return cls( + jwt_access_token=jwt_access_token, + **{ + k: v + for k, v in dic.items() + if k in [f.name for f in fields(cls)] and k != "jwt_access_token" + }, + ) + + +class OAuthInfoChangedEvent(EventBase): + """Event to notify the charm that the information in the databag changed.""" + + def __init__(self, handle: Handle, client_id: str, client_secret_id: str): + super().__init__(handle) + self.client_id = client_id + self.client_secret_id = client_secret_id + + def snapshot(self) -> Dict: + """Save event.""" + return { + "client_id": self.client_id, + "client_secret_id": self.client_secret_id, + } + + def restore(self, snapshot: Dict) -> None: + """Restore event.""" + super().restore(snapshot) + self.client_id = snapshot["client_id"] + self.client_secret_id = snapshot["client_secret_id"] + + +class InvalidClientConfigEvent(EventBase): + """Event to notify the charm that the client configuration is invalid.""" + + def __init__(self, handle: Handle, error: str): + super().__init__(handle) + self.error = error + + def snapshot(self) -> Dict: + """Save event.""" + return { + "error": self.error, + } + + def restore(self, snapshot: Dict) -> None: + """Restore event.""" + self.error = snapshot["error"] + + +class OAuthInfoRemovedEvent(EventBase): + """Event to notify the charm that the provider data was removed.""" + + def snapshot(self) -> Dict: + """Save event.""" + return {} + + def restore(self, snapshot: Dict) -> None: + """Restore event.""" + pass + + +class OAuthRequirerEvents(ObjectEvents): + """Event descriptor for events raised by `OAuthRequirerEvents`.""" + + oauth_info_changed = EventSource(OAuthInfoChangedEvent) + oauth_info_removed = EventSource(OAuthInfoRemovedEvent) + invalid_client_config = EventSource(InvalidClientConfigEvent) + + +class OAuthRequirer(OAuthRelation): + """Register an oauth client.""" + + on = OAuthRequirerEvents() + + def __init__( + self, + charm: CharmBase, + client_config: Optional[ClientConfig] = None, + relation_name: str = DEFAULT_RELATION_NAME, + ) -> None: + super().__init__(charm, relation_name) + self._charm = charm + self._relation_name = relation_name + self._client_config = client_config + events = self._charm.on[relation_name] + self.framework.observe(events.relation_created, self._on_relation_created_event) + self.framework.observe(events.relation_changed, self._on_relation_changed_event) + self.framework.observe(events.relation_broken, self._on_relation_broken_event) + + def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: + try: + self._update_relation_data(self._client_config, event.relation.id) + except ClientConfigError as e: + self.on.invalid_client_config.emit(e.args[0]) + + def _on_relation_broken_event(self, event: RelationBrokenEvent) -> None: + # Workaround for https://github.com/canonical/operator/issues/888 + self._pop_relation_data(event.relation.id) + if self.is_client_created(): + event.defer() + logger.info("Relation data still available. Deferring the event") + return + + # Notify the requirer that the relation data was removed + self.on.oauth_info_removed.emit() + + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: + data = event.relation.data[event.app] + if not data: + logger.info("No relation data available.") + return + + data = _load_data(data, OAUTH_PROVIDER_JSON_SCHEMA) + + client_id = data.get("client_id") + client_secret_id = data.get("client_secret_id") + if not client_id or not client_secret_id: + logger.info("OAuth Provider info is available, waiting for client to be registered.") + # The client credentials are not ready yet, so we do nothing + # This could mean that the client credentials were removed from the databag, + # but we don't allow that (for now), so we don't have to check for it. + return + + self.on.oauth_info_changed.emit(client_id, client_secret_id) + + def _update_relation_data( + self, client_config: Optional[ClientConfig], relation_id: Optional[int] = None + ) -> None: + if not self.model.unit.is_leader() or not client_config: + return + + if not isinstance(client_config, ClientConfig): + raise ValueError(f"Unexpected client_config type: {type(client_config)}") + + client_config.validate() + + try: + relation = self.model.get_relation( + relation_name=self._relation_name, relation_id=relation_id + ) + except TooManyRelatedAppsError: + raise RuntimeError("More than one relations are defined. Please provide a relation_id") + + if not relation or not relation.app: + return + + data = _dump_data(client_config.to_dict(), OAUTH_REQUIRER_JSON_SCHEMA) + relation.data[self.model.app].update(data) + + def is_client_created(self, relation_id: Optional[int] = None) -> bool: + """Check if the client has been created.""" + if len(self.model.relations) == 0: + return None + try: + relation = self.model.get_relation(self._relation_name, relation_id=relation_id) + except TooManyRelatedAppsError: + raise RuntimeError("More than one relations are defined. Please provide a relation_id") + + if not relation or not relation.app: + return None + + return ( + "client_id" in relation.data[relation.app] + and "client_secret_id" in relation.data[relation.app] + ) + + def get_provider_info( + self, relation_id: Optional[int] = None + ) -> Optional[OauthProviderConfig]: + """Get the provider information from the databag.""" + if len(self.model.relations) == 0: + return None + try: + relation = self.model.get_relation(self._relation_name, relation_id=relation_id) + except TooManyRelatedAppsError: + raise RuntimeError("More than one relations are defined. Please provide a relation_id") + if not relation or not relation.app: + return None + + data = relation.data[relation.app] + if not data: + logger.info("No relation data available.") + return + + data = _load_data(data, OAUTH_PROVIDER_JSON_SCHEMA) + + client_secret_id = data.get("client_secret_id") + if client_secret_id: + _client_secret = self.get_client_secret(client_secret_id) + client_secret = _client_secret.get_content()[CLIENT_SECRET_FIELD] + data["client_secret"] = client_secret + + oauth_provider = OauthProviderConfig.from_dict(data) + return oauth_provider + + def get_client_secret(self, client_secret_id: str) -> Secret: + """Get the client_secret.""" + client_secret = self.model.get_secret(id=client_secret_id) + return client_secret + + def update_client_config( + self, client_config: ClientConfig, relation_id: Optional[int] = None + ) -> None: + """Update the client config stored in the object.""" + self._client_config = client_config + self._update_relation_data(client_config, relation_id=relation_id) + + +class ClientCreatedEvent(EventBase): + """Event to notify the Provider charm to create a new client.""" + + def __init__( + self, + handle: Handle, + redirect_uri: str, + scope: str, + grant_types: List[str], + audience: List, + token_endpoint_auth_method: str, + relation_id: int, + ) -> None: + super().__init__(handle) + self.redirect_uri = redirect_uri + self.scope = scope + self.grant_types = grant_types + self.audience = audience + self.token_endpoint_auth_method = token_endpoint_auth_method + self.relation_id = relation_id + + def snapshot(self) -> Dict: + """Save event.""" + return { + "redirect_uri": self.redirect_uri, + "scope": self.scope, + "grant_types": self.grant_types, + "audience": self.audience, + "token_endpoint_auth_method": self.token_endpoint_auth_method, + "relation_id": self.relation_id, + } + + def restore(self, snapshot: Dict) -> None: + """Restore event.""" + self.redirect_uri = snapshot["redirect_uri"] + self.scope = snapshot["scope"] + self.grant_types = snapshot["grant_types"] + self.audience = snapshot["audience"] + self.token_endpoint_auth_method = snapshot["token_endpoint_auth_method"] + self.relation_id = snapshot["relation_id"] + + def to_client_config(self) -> ClientConfig: + """Convert the event information to a ClientConfig object.""" + return ClientConfig( + self.redirect_uri, + self.scope, + self.grant_types, + self.audience, + self.token_endpoint_auth_method, + ) + + +class ClientChangedEvent(EventBase): + """Event to notify the Provider charm that the client config changed.""" + + def __init__( + self, + handle: Handle, + redirect_uri: str, + scope: str, + grant_types: List, + audience: List, + token_endpoint_auth_method: str, + relation_id: int, + client_id: str, + ) -> None: + super().__init__(handle) + self.redirect_uri = redirect_uri + self.scope = scope + self.grant_types = grant_types + self.audience = audience + self.token_endpoint_auth_method = token_endpoint_auth_method + self.relation_id = relation_id + self.client_id = client_id + + def snapshot(self) -> Dict: + """Save event.""" + return { + "redirect_uri": self.redirect_uri, + "scope": self.scope, + "grant_types": self.grant_types, + "audience": self.audience, + "token_endpoint_auth_method": self.token_endpoint_auth_method, + "relation_id": self.relation_id, + "client_id": self.client_id, + } + + def restore(self, snapshot: Dict) -> None: + """Restore event.""" + self.redirect_uri = snapshot["redirect_uri"] + self.scope = snapshot["scope"] + self.grant_types = snapshot["grant_types"] + self.audience = snapshot["audience"] + self.token_endpoint_auth_method = snapshot["token_endpoint_auth_method"] + self.relation_id = snapshot["relation_id"] + self.client_id = snapshot["client_id"] + + def to_client_config(self) -> ClientConfig: + """Convert the event information to a ClientConfig object.""" + return ClientConfig( + self.redirect_uri, + self.scope, + self.grant_types, + self.audience, + self.token_endpoint_auth_method, + self.client_id, + ) + + +class ClientDeletedEvent(EventBase): + """Event to notify the Provider charm that the client was deleted.""" + + def __init__( + self, + handle: Handle, + relation_id: int, + ) -> None: + super().__init__(handle) + self.relation_id = relation_id + + def snapshot(self) -> Dict: + """Save event.""" + return {"relation_id": self.relation_id} + + def restore(self, snapshot: Dict) -> None: + """Restore event.""" + self.relation_id = snapshot["relation_id"] + + +class OAuthProviderEvents(ObjectEvents): + """Event descriptor for events raised by `OAuthProviderEvents`.""" + + client_created = EventSource(ClientCreatedEvent) + client_changed = EventSource(ClientChangedEvent) + client_deleted = EventSource(ClientDeletedEvent) + + +class OAuthProvider(OAuthRelation): + """A provider object for OIDC Providers.""" + + on = OAuthProviderEvents() + + def __init__(self, charm: CharmBase, relation_name: str = DEFAULT_RELATION_NAME) -> None: + super().__init__(charm, relation_name) + self._charm = charm + self._relation_name = relation_name + + events = self._charm.on[relation_name] + self.framework.observe( + events.relation_changed, + self._get_client_config_from_relation_data, + ) + self.framework.observe( + events.relation_broken, + self._on_relation_broken, + ) + + def _get_client_config_from_relation_data(self, event: RelationChangedEvent) -> None: + if not self.model.unit.is_leader(): + return + + data = event.relation.data[event.app] + if not data: + logger.info("No requirer relation data available.") + return + + client_data = _load_data(data, OAUTH_REQUIRER_JSON_SCHEMA) + redirect_uri = client_data.get("redirect_uri") + scope = client_data.get("scope") + grant_types = client_data.get("grant_types") + audience = client_data.get("audience") + token_endpoint_auth_method = client_data.get("token_endpoint_auth_method") + + data = event.relation.data[self._charm.app] + if not data: + logger.info("No provider relation data available.") + return + provider_data = _load_data(data, OAUTH_PROVIDER_JSON_SCHEMA) + client_id = provider_data.get("client_id") + + relation_id = event.relation.id + + if client_id: + # Modify an existing client + self.on.client_changed.emit( + redirect_uri, + scope, + grant_types, + audience, + token_endpoint_auth_method, + relation_id, + client_id, + ) + else: + # Create a new client + self.on.client_created.emit( + redirect_uri, scope, grant_types, audience, token_endpoint_auth_method, relation_id + ) + + def _get_secret_label(self, relation: Relation) -> str: + return f"client_secret_{relation.id}" + + def _on_relation_broken(self, event: RelationBrokenEvent) -> None: + # Workaround for https://github.com/canonical/operator/issues/888 + self._pop_relation_data(event.relation.id) + + self._delete_juju_secret(event.relation) + self.on.client_deleted.emit(event.relation.id) + + def _create_juju_secret(self, client_secret: str, relation: Relation) -> Secret: + """Create a juju secret and grant it to a relation.""" + secret = {CLIENT_SECRET_FIELD: client_secret} + juju_secret = self.model.app.add_secret(secret, label=self._get_secret_label(relation)) + juju_secret.grant(relation) + return juju_secret + + def _delete_juju_secret(self, relation: Relation) -> None: + secret = self.model.get_secret(label=self._get_secret_label(relation)) + secret.remove_all_revisions() + + def set_provider_info_in_relation_data( + self, + issuer_url: str, + authorization_endpoint: str, + token_endpoint: str, + introspection_endpoint: str, + userinfo_endpoint: str, + jwks_endpoint: str, + scope: str, + groups: Optional[str] = None, + ca_chain: Optional[str] = None, + jwt_access_token: Optional[bool] = False, + ) -> None: + """Put the provider information in the databag.""" + if not self.model.unit.is_leader(): + return + + data = { + "issuer_url": issuer_url, + "authorization_endpoint": authorization_endpoint, + "token_endpoint": token_endpoint, + "introspection_endpoint": introspection_endpoint, + "userinfo_endpoint": userinfo_endpoint, + "jwks_endpoint": jwks_endpoint, + "scope": scope, + "jwt_access_token": jwt_access_token, + } + if groups: + data["groups"] = groups + if ca_chain: + data["ca_chain"] = ca_chain + + for relation in self.model.relations[self._relation_name]: + relation.data[self.model.app].update(_dump_data(data)) + + def set_client_credentials_in_relation_data( + self, relation_id: int, client_id: str, client_secret: str + ) -> None: + """Put the client credentials in the databag.""" + if not self.model.unit.is_leader(): + return + + relation = self.model.get_relation(self._relation_name, relation_id) + if not relation or not relation.app: + return + # TODO: What if we are refreshing the client_secret? We need to add a + # new revision for that + secret = self._create_juju_secret(client_secret, relation) + data = dict(client_id=client_id, client_secret_id=secret.id) + relation.data[self.model.app].update(_dump_data(data)) diff --git a/metadata.yaml b/metadata.yaml index 92720ee9..08f564e0 100644 --- a/metadata.yaml +++ b/metadata.yaml @@ -42,6 +42,10 @@ requires: trusted-certificate: interface: tls-certificates optional: true + oauth: + interface: oauth + limit: 1 + optional: true provides: kafka-client: diff --git a/src/charm.py b/src/charm.py index 68415440..54794193 100755 --- a/src/charm.py +++ b/src/charm.py @@ -28,6 +28,7 @@ from core.cluster import ClusterState from core.models import Substrates from core.structured_config import CharmConfig +from events.oauth import OAuthHandler from events.password_actions import PasswordActionEvents from events.provider import KafkaProvider from events.tls import TLSHandler @@ -76,6 +77,7 @@ def __init__(self, *args): self.password_action_events = PasswordActionEvents(self) self.zookeeper = ZooKeeperHandler(self) self.tls = TLSHandler(self) + self.oauth = OAuthHandler(self) self.provider = KafkaProvider(self) self.upgrade = KafkaUpgrade( self, diff --git a/src/core/cluster.py b/src/core/cluster.py index a25e2134..54f3c921 100644 --- a/src/core/cluster.py +++ b/src/core/cluster.py @@ -17,14 +17,16 @@ from ops import Framework, Object, Relation from ops.model import Unit -from core.models import KafkaBroker, KafkaClient, KafkaCluster, ZooKeeper +from core.models import KafkaBroker, KafkaClient, KafkaCluster, OAuth, ZooKeeper from literals import ( INTERNAL_USERS, + OAUTH_REL_NAME, PEER, REL_NAME, SECRETS_UNIT, SECURITY_PROTOCOL_PORTS, ZK, + AuthMechanism, Status, Substrates, ) @@ -63,6 +65,11 @@ def client_relations(self) -> set[Relation]: """The relations of all client applications.""" return set(self.model.relations[REL_NAME]) + @property + def oauth_relation(self) -> Relation | None: + """The OAuth relation.""" + return self.model.get_relation(OAUTH_REL_NAME) + # --- CORE COMPONENTS --- @property @@ -127,6 +134,13 @@ def zookeeper(self) -> ZooKeeper: local_app=self.cluster.app, ) + @property + def oauth(self) -> OAuth: + """The oauth relation state.""" + return OAuth( + relation=self.oauth_relation, + ) + @property def clients(self) -> set[KafkaClient]: """The state for all related client Applications.""" @@ -180,10 +194,11 @@ def super_users(self) -> str: @property def port(self) -> int: """Return the port to be used internally.""" + mechanism: AuthMechanism = "SCRAM-SHA-512" return ( - SECURITY_PROTOCOL_PORTS["SASL_SSL"].client + SECURITY_PROTOCOL_PORTS["SASL_SSL", mechanism].client if (self.cluster.tls_enabled and self.unit_broker.certificate) - else SECURITY_PROTOCOL_PORTS["SASL_PLAINTEXT"].client + else SECURITY_PROTOCOL_PORTS["SASL_PLAINTEXT", mechanism].client ) @property diff --git a/src/core/models.py b/src/core/models.py index c5b93ea8..51016966 100644 --- a/src/core/models.py +++ b/src/core/models.py @@ -5,7 +5,9 @@ """Collection of state objects for the Kafka relations, apps and units.""" import logging +from typing import MutableMapping +import requests from charms.data_platform_libs.v0.data_interfaces import Data, DataPeerData, DataPeerUnitData from charms.zookeeper.v0.client import QuorumLeaderNotFoundError, ZooKeeperManager from kazoo.client import AuthFailedError, NoNodeError @@ -451,3 +453,49 @@ def extra_user_roles(self) -> str: When `admin` is set, the Kafka charm interprets this as a new super.user. """ return self.relation_data.get("extra-user-roles", "") + + +class OAuth: + """State collection metadata for the oauth relation.""" + + def __init__(self, relation: Relation | None): + self.relation = relation + + @property + def relation_data(self) -> MutableMapping[str, str]: + """Oauth relation data object.""" + if not self.relation or not self.relation.app: + return {} + + return self.relation.data[self.relation.app] + + @property + def issuer_url(self) -> str: + """The issuer URL to identify the IDP.""" + return self.relation_data.get("issuer_url", "") + + @property + def jwks_endpoint(self) -> str: + """The JWKS endpoint needed to validate JWT tokens.""" + return self.relation_data.get("jwks_endpoint", "") + + @property + def introspection_endpoint(self) -> str: + """The introspection endpoint needed to validate non-JWT tokens.""" + return self.relation_data.get("introspection_endpoint", "") + + @property + def jwt_access_token(self) -> bool: + """A flag indicating if the access token is JWT or not.""" + return self.relation_data.get("jwt_access_token", "false").lower() == "true" + + @property + def uses_trusted_ca(self) -> bool: + """A flag indicating if the IDP uses certificates signed by a trusted CA.""" + try: + requests.get(self.issuer_url, timeout=10) + return True + except requests.exceptions.SSLError: + return False + except requests.exceptions.RequestException: + return True diff --git a/src/events/oauth.py b/src/events/oauth.py new file mode 100644 index 00000000..575e037a --- /dev/null +++ b/src/events/oauth.py @@ -0,0 +1,40 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. + +"""Manager for handling Kafka OAuth configuration.""" + +import logging +from typing import TYPE_CHECKING + +from charms.hydra.v0.oauth import ClientConfig, OAuthRequirer +from ops.framework import EventBase, Object + +from literals import OAUTH_REL_NAME + +if TYPE_CHECKING: + from charm import KafkaCharm + +logger = logging.getLogger(__name__) + + +class OAuthHandler(Object): + """Handler for managing oauth relations.""" + + def __init__(self, charm): + super().__init__(charm, "oauth") + self.charm: "KafkaCharm" = charm + + client_config = ClientConfig("https://kafka.local", "openid email", ["client_credentials"]) + self.oauth = OAuthRequirer(charm, client_config, relation_name=OAUTH_REL_NAME) + self.framework.observe( + self.charm.on[OAUTH_REL_NAME].relation_changed, self._on_oauth_relation_changed + ) + self.framework.observe( + self.charm.on[OAUTH_REL_NAME].relation_broken, self._on_oauth_relation_changed + ) + + def _on_oauth_relation_changed(self, event: EventBase) -> None: + """Handler for `_on_oauth_relation_changed` event.""" + if not self.charm.unit.is_leader() or not self.charm.state.brokers: + return + self.charm._on_config_changed(event) diff --git a/src/literals.py b/src/literals.py index f4f7ba8e..1e9e154f 100644 --- a/src/literals.py +++ b/src/literals.py @@ -23,6 +23,7 @@ PEER = "cluster" ZK = "zookeeper" REL_NAME = "kafka-client" +OAUTH_REL_NAME = "oauth" TLS_RELATION = "certificates" TRUSTED_CERTIFICATE_RELATION = "trusted-certificate" TRUSTED_CA_RELATION = "trusted-ca" @@ -53,7 +54,8 @@ USER = 584788 GROUP = "root" -AuthMechanism = Literal["SASL_PLAINTEXT", "SASL_SSL", "SSL"] +AuthProtocol = Literal["SASL_PLAINTEXT", "SASL_SSL", "SSL"] +AuthMechanism = Literal["SCRAM-SHA-512", "OAUTHBEARER", "SSL"] Scope = Literal["INTERNAL", "CLIENT"] DebugLevel = Literal["DEBUG", "INFO", "WARNING", "ERROR"] DatabagScope = Literal["unit", "app"] @@ -84,10 +86,12 @@ class Ports: internal: int -SECURITY_PROTOCOL_PORTS: dict[AuthMechanism, Ports] = { - "SASL_PLAINTEXT": Ports(9092, 19092), - "SASL_SSL": Ports(9093, 19093), - "SSL": Ports(9094, 19094), +SECURITY_PROTOCOL_PORTS: dict[tuple[AuthProtocol, AuthMechanism], Ports] = { + ("SASL_PLAINTEXT", "SCRAM-SHA-512"): Ports(9092, 19092), + ("SASL_PLAINTEXT", "OAUTHBEARER"): Ports(9095, 19095), + ("SASL_SSL", "SCRAM-SHA-512"): Ports(9093, 19093), + ("SASL_SSL", "OAUTHBEARER"): Ports(9096, 19096), + ("SSL", "SSL"): Ports(9094, 19094), } diff --git a/src/managers/config.py b/src/managers/config.py index c31580b9..3ff5950c 100644 --- a/src/managers/config.py +++ b/src/managers/config.py @@ -6,7 +6,9 @@ import inspect import logging -from typing import cast +import os +import re +import textwrap from core.cluster import ClusterState from core.structured_config import CharmConfig, LogLevel @@ -19,13 +21,13 @@ JVM_MEM_MIN_GB, SECURITY_PROTOCOL_PORTS, AuthMechanism, + AuthProtocol, Scope, ) logger = logging.getLogger(__name__) DEFAULT_CONFIG_OPTIONS = """ -sasl.enabled.mechanisms=SCRAM-SHA-512 sasl.mechanism.inter.broker.protocol=SCRAM-SHA-512 authorizer.class.name=kafka.security.authorizer.AclAuthorizer allow.everyone.if.no.acl.found=false @@ -44,8 +46,9 @@ class Listener: scope: scope of the listener, CLIENT or INTERNAL """ - def __init__(self, host: str, protocol: AuthMechanism, scope: Scope): - self.protocol: AuthMechanism = protocol + def __init__(self, host: str, protocol: AuthProtocol, mechanism: AuthMechanism, scope: Scope): + self.protocol: AuthProtocol = protocol + self.mechanism: AuthMechanism = mechanism self.host = host self.scope = scope @@ -71,15 +74,15 @@ def port(self) -> int: Returns: Integer of port number """ + port = SECURITY_PROTOCOL_PORTS[self.protocol, self.mechanism] if self.scope == "CLIENT": - return SECURITY_PROTOCOL_PORTS[self.protocol].client - - return SECURITY_PROTOCOL_PORTS[self.protocol].internal + return port.client + return port.internal @property def name(self) -> str: """Name of the listener.""" - return f"{self.scope}_{self.protocol}" + return f"{self.scope}_{self.protocol}_{self.mechanism.replace('-', '_')}" @property def protocol_map(self) -> str: @@ -197,6 +200,18 @@ def kafka_opts(self) -> str: f"-Djava.security.auth.login.config={self.workload.paths.zk_jaas}", ] + http_proxy = os.environ.get("JUJU_CHARM_HTTP_PROXY") + https_proxy = os.environ.get("JUJU_CHARM_HTTPS_PROXY") + no_proxy = os.environ.get("JUJU_CHARM_NO_PROXY") + + for prot, proxy in {"http": http_proxy, "https": https_proxy}.items(): + if proxy: + proxy = re.sub(r"^https?://", "", proxy) + [host, port] = proxy.split(":") if ":" in proxy else [proxy, "8080"] + opts.append(f"-D{prot}.proxyHost={host} -D{prot}.proxyPort={port}") + if no_proxy: + opts.append(f"-Dhttp.nonProxyHosts={no_proxy}") + return f"KAFKA_OPTS='{' '.join(opts)}'" @property @@ -270,56 +285,119 @@ def scram_properties(self) -> list[str]: username = INTER_BROKER_USER password = self.state.cluster.internal_user_credentials.get(INTER_BROKER_USER, "") + listener_name = self.internal_listener.name.lower() + listener_mechanism = self.internal_listener.mechanism.lower() + scram_properties = [ - f'listener.name.{self.internal_listener.name.lower()}.scram-sha-512.sasl.jaas.config=org.apache.kafka.common.security.scram.ScramLoginModule required username="{username}" password="{password}";' - ] - client_scram = [ - auth.name for auth in self.client_listeners if auth.protocol.startswith("SASL_") + f'listener.name.{listener_name}.{listener_mechanism}.sasl.jaas.config=org.apache.kafka.common.security.scram.ScramLoginModule required username="{username}" password="{password}";', + f"listener.name.{listener_name}.sasl.enabled.mechanisms={self.internal_listener.mechanism}", ] - for name in client_scram: + + for auth in self.client_listeners: + if not auth.mechanism.startswith("SCRAM"): + continue + + scram_properties.append( + f'listener.name.{auth.name.lower()}.{auth.mechanism.lower()}.sasl.jaas.config=org.apache.kafka.common.security.scram.ScramLoginModule required username="{username}" password="{password}";' + ) scram_properties.append( - f'listener.name.{name.lower()}.scram-sha-512.sasl.jaas.config=org.apache.kafka.common.security.scram.ScramLoginModule required username="{username}" password="{password}";' + f"listener.name.{auth.name.lower()}.sasl.enabled.mechanisms={auth.mechanism}" ) return scram_properties @property - def security_protocol(self) -> AuthMechanism: + def oauth_properties(self) -> list[str]: + """Builds the properties for the oauth listener. + + Returns: + list of oauth properties to be set. + """ + if not self.state.oauth_relation: + return [] + + listener = [ + listener + for listener in self.client_listeners + if listener.mechanism.startswith("OAUTH") + ][0] + + username_claim = "email" + username_fallback_claim = "client_id" + + # use jwks validation if jwt token, otherwise use introspection validation + validation_cfg = ( + f'oauth.jwks.endpoint.uri="{self.state.oauth.jwks_endpoint}"' + if self.state.oauth.jwt_access_token + else f'oauth.introspection.endpoint.uri="{self.state.oauth.introspection_endpoint}"' + ) + + truststore_cfg = "" + if not self.state.oauth.uses_trusted_ca: + truststore_cfg = f'oauth.ssl.truststore.location="{self.workload.paths.truststore}" oauth.ssl.truststore.password="{self.state.unit_broker.truststore_password}" oauth.ssl.truststore.type="JKS"' + + scram_properties = [ + textwrap.dedent( + f"""\ + listener.name.{listener.name.lower()}.{listener.mechanism.lower()}.sasl.jaas.config=org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required \\ + oauth.client.id="kafka" \\ + oauth.valid.issuer.uri="{self.state.oauth.issuer_url}" \\ + {validation_cfg} \\ + oauth.username.claim="{username_claim}" \\ + oauth.fallback.username.claim="{username_fallback_claim}" \\ + oauth.check.audience="true" \\ + oauth.check.access.token.type="false" \\ + oauth.config.id="{listener.name}" \\ + unsecuredLoginStringClaim_sub="unused" \\ + {truststore_cfg};""" + ), + f"listener.name.{listener.name.lower()}.{listener.mechanism.lower()}.sasl.server.callback.handler.class=io.strimzi.kafka.oauth.server.JaasServerOauthValidatorCallbackHandler", + f"listener.name.{listener.name.lower()}.sasl.enabled.mechanisms={listener.mechanism}", + "principal.builder.class=io.strimzi.kafka.oauth.server.OAuthKafkaPrincipalBuilder", + ] + + return scram_properties + + @property + def security_protocol(self) -> AuthProtocol: """Infers current charm security.protocol based on current relations.""" - # FIXME: When we have multiple auth_mechanims/listeners, remove this method return ( "SASL_SSL" if (self.state.cluster.tls_enabled and self.state.unit_broker.certificate) else "SASL_PLAINTEXT" ) - @property - def auth_mechanisms(self) -> list[AuthMechanism]: - """Return a list of enabled auth mechanisms.""" - # TODO: At the moment only one mechanism for extra listeners. Will need to be - # extended with more depending on configuration settings. - protocol = [self.security_protocol] - if self.state.cluster.mtls_enabled: - protocol += ["SSL"] - - return cast(list[AuthMechanism], protocol) - @property def internal_listener(self) -> Listener: """Return the internal listener.""" protocol = self.security_protocol - return Listener(host=self.state.unit_broker.host, protocol=protocol, scope="INTERNAL") + mechanism: AuthMechanism = "SCRAM-SHA-512" + return Listener( + host=self.state.unit_broker.host, + protocol=protocol, + mechanism=mechanism, + scope="INTERNAL", + ) @property def client_listeners(self) -> list[Listener]: """Return a list of extra listeners.""" - # if there is a relation with kafka then add extra listener - if not self.state.client_relations: - return [] + protocol_mechanism_dict: list[tuple[AuthProtocol, AuthMechanism]] = [] + if self.state.client_relations: + protocol_mechanism_dict.append((self.security_protocol, "SCRAM-SHA-512")) + if self.state.oauth_relation: + protocol_mechanism_dict.append((self.security_protocol, "OAUTHBEARER")) + if self.state.cluster.mtls_enabled: + protocol_mechanism_dict.append(("SSL", "SSL")) return [ - Listener(host=self.state.unit_broker.host, protocol=auth, scope="CLIENT") - for auth in self.auth_mechanisms + Listener( + host=self.state.unit_broker.host, + protocol=protocol, + mechanism=mechanism, + scope="CLIENT", + ) + for protocol, mechanism in protocol_mechanism_dict ] @property @@ -334,7 +412,7 @@ def inter_broker_protocol_version(self) -> str: Returns: String with the `major.minor` version """ - # Remove patch number from full vervion. + # Remove patch number from full version. major_minor = self.current_version.split(".", maxsplit=2) return ".".join(major_minor[:2]) @@ -400,8 +478,9 @@ def server_properties(self) -> list[str]: f"inter.broker.listener.name={self.internal_listener.name}", f"inter.broker.protocol.version={self.inter_broker_protocol_version}", ] - + self.config_properties + self.scram_properties + + self.oauth_properties + + self.config_properties + self.default_replication_properties + self.auth_properties + self.rack_properties diff --git a/tests/integration/ha/ha_helpers.py b/tests/integration/ha/ha_helpers.py index 1c190ac9..8224369c 100644 --- a/tests/integration/ha/ha_helpers.py +++ b/tests/integration/ha/ha_helpers.py @@ -54,7 +54,7 @@ async def get_topic_description( for unit in ops_test.model.applications[APP_NAME].units: bootstrap_servers.append( await get_address(ops_test=ops_test, unit_num=unit.name.split("/")[-1]) - + f":{SECURITY_PROTOCOL_PORTS['SASL_PLAINTEXT'].client}" + + f":{SECURITY_PROTOCOL_PORTS['SASL_PLAINTEXT', 'SCRAM-SHA-512'].client}" ) unit_name = unit_name or ops_test.model.applications[APP_NAME].units[0].name @@ -85,7 +85,7 @@ async def get_topic_offsets( for unit in ops_test.model.applications[APP_NAME].units: bootstrap_servers.append( await get_address(ops_test=ops_test, unit_num=unit.name.split("/")[-1]) - + f":{SECURITY_PROTOCOL_PORTS['SASL_PLAINTEXT'].client}" + + f":{SECURITY_PROTOCOL_PORTS['SASL_PLAINTEXT', 'SCRAM-SHA-512'].client}" ) unit_name = unit_name or ops_test.model.applications[APP_NAME].units[0].name diff --git a/tests/integration/helpers.py b/tests/integration/helpers.py index c82e0a8d..1f3bc39f 100644 --- a/tests/integration/helpers.py +++ b/tests/integration/helpers.py @@ -263,7 +263,7 @@ async def run_client_properties(ops_test: OpsTest) -> str: """Runs command requiring admin permissions, authenticated with bootstrap-server.""" bootstrap_server = ( await get_address(ops_test=ops_test) - + f":{SECURITY_PROTOCOL_PORTS['SASL_PLAINTEXT'].client}" + + f":{SECURITY_PROTOCOL_PORTS['SASL_PLAINTEXT', 'SCRAM-SHA-512'].client}" ) result = check_output( f"JUJU_MODEL={ops_test.model_full_name} juju ssh kafka/0 sudo -i 'charmed-kafka.configs --bootstrap-server {bootstrap_server} --describe --all --command-config {PATHS['CONF']}/client.properties --entity-type users'", diff --git a/tests/integration/test_charm.py b/tests/integration/test_charm.py index 2cde6c32..f070ace7 100644 --- a/tests/integration/test_charm.py +++ b/tests/integration/test_charm.py @@ -134,11 +134,13 @@ async def test_remove_zk_relation_relate(ops_test: OpsTest): async def test_listeners(ops_test: OpsTest, app_charm): address = await get_address(ops_test=ops_test) assert check_socket( - address, SECURITY_PROTOCOL_PORTS["SASL_PLAINTEXT"].internal + address, SECURITY_PROTOCOL_PORTS["SASL_PLAINTEXT", "SCRAM-SHA-512"].internal ) # Internal listener # Client listener should not be enabled if there is no relations - assert not check_socket(address, SECURITY_PROTOCOL_PORTS["SASL_PLAINTEXT"].client) + assert not check_socket( + address, SECURITY_PROTOCOL_PORTS["SASL_PLAINTEXT", "SCRAM-SHA-512"].client + ) # Add relation with dummy app await asyncio.gather( @@ -153,7 +155,7 @@ async def test_listeners(ops_test: OpsTest, app_charm): await ops_test.model.wait_for_idle(apps=[APP_NAME, ZK_NAME, DUMMY_NAME]) # check that client listener is active - assert check_socket(address, SECURITY_PROTOCOL_PORTS["SASL_PLAINTEXT"].client) + assert check_socket(address, SECURITY_PROTOCOL_PORTS["SASL_PLAINTEXT", "SCRAM-SHA-512"].client) # remove relation and check that client listener is not active await ops_test.model.applications[APP_NAME].remove_relation( @@ -161,7 +163,9 @@ async def test_listeners(ops_test: OpsTest, app_charm): ) await ops_test.model.wait_for_idle(apps=[APP_NAME]) - assert not check_socket(address, SECURITY_PROTOCOL_PORTS["SASL_PLAINTEXT"].client) + assert not check_socket( + address, SECURITY_PROTOCOL_PORTS["SASL_PLAINTEXT", "SCRAM-SHA-512"].client + ) @pytest.mark.abort_on_fail diff --git a/tests/integration/test_tls.py b/tests/integration/test_tls.py index bbdd77f8..8e14aba2 100644 --- a/tests/integration/test_tls.py +++ b/tests/integration/test_tls.py @@ -111,7 +111,9 @@ async def test_kafka_tls(ops_test: OpsTest, app_charm): kafka_address = await get_address(ops_test=ops_test, app_name=CHARM_KEY) - assert not check_tls(ip=kafka_address, port=SECURITY_PROTOCOL_PORTS["SASL_SSL"].client) + assert not check_tls( + ip=kafka_address, port=SECURITY_PROTOCOL_PORTS["SASL_SSL", "SCRAM-SHA-512"].client + ) await asyncio.gather( ops_test.model.deploy(app_charm, application_name=DUMMY_NAME, num_units=1, series="jammy"), @@ -127,7 +129,9 @@ async def test_kafka_tls(ops_test: OpsTest, app_charm): apps=[CHARM_KEY, DUMMY_NAME], idle_period=30, status="active" ) - assert check_tls(ip=kafka_address, port=SECURITY_PROTOCOL_PORTS["SASL_SSL"].client) + assert check_tls( + ip=kafka_address, port=SECURITY_PROTOCOL_PORTS["SASL_SSL", "SCRAM-SHA-512"].client + ) # Rotate credentials new_private_key = generate_private_key().decode("utf-8") @@ -182,8 +186,8 @@ async def test_mtls(ops_test: OpsTest): broker_ca = extract_ca(ops_test=ops_test, unit_name=f"{CHARM_KEY}/0") address = await get_address(ops_test, app_name=CHARM_KEY) - ssl_port = SECURITY_PROTOCOL_PORTS["SSL"].client - sasl_port = SECURITY_PROTOCOL_PORTS["SASL_SSL"].client + ssl_port = SECURITY_PROTOCOL_PORTS["SSL", "SSL"].client + sasl_port = SECURITY_PROTOCOL_PORTS["SASL_SSL", "SCRAM-SHA-512"].client ssl_bootstrap_server = f"{address}:{ssl_port}" sasl_bootstrap_server = f"{address}:{sasl_port}" @@ -260,14 +264,18 @@ async def test_kafka_tls_scaling(ops_test: OpsTest): assert f"{chroot}/brokers/ids/2" in active_brokers kafka_address = await get_address(ops_test=ops_test, app_name=CHARM_KEY, unit_num=2) - assert check_tls(ip=kafka_address, port=SECURITY_PROTOCOL_PORTS["SASL_SSL"].client) + assert check_tls( + ip=kafka_address, port=SECURITY_PROTOCOL_PORTS["SASL_SSL", "SCRAM-SHA-512"].client + ) # remove relation and check connection again await ops_test.model.applications[CHARM_KEY].remove_relation( f"{CHARM_KEY}:{REL_NAME}", f"{DUMMY_NAME}:{REL_NAME_ADMIN}" ) await ops_test.model.wait_for_idle(apps=[CHARM_KEY]) - assert not check_tls(ip=kafka_address, port=SECURITY_PROTOCOL_PORTS["SASL_SSL"].client) + assert not check_tls( + ip=kafka_address, port=SECURITY_PROTOCOL_PORTS["SASL_SSL", "SCRAM-SHA-512"].client + ) async def test_tls_removed(ops_test: OpsTest): @@ -277,4 +285,6 @@ async def test_tls_removed(ops_test: OpsTest): ) kafka_address = await get_address(ops_test=ops_test, app_name=CHARM_KEY) - assert not check_tls(ip=kafka_address, port=SECURITY_PROTOCOL_PORTS["SASL_SSL"].client) + assert not check_tls( + ip=kafka_address, port=SECURITY_PROTOCOL_PORTS["SASL_SSL", "SCRAM-SHA-512"].client + ) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index a4a642b9..354f2403 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -2,6 +2,7 @@ # Copyright 2024 Canonical Ltd. # See LICENSE file for licensing details. +import os from pathlib import Path from unittest.mock import PropertyMock, mock_open, patch @@ -20,15 +21,17 @@ JMX_EXPORTER_PORT, JVM_MEM_MAX_GB, JVM_MEM_MIN_GB, + OAUTH_REL_NAME, PEER, SUBSTRATE, ZK, ) from managers.config import ConfigManager -CONFIG = str(yaml.safe_load(Path("./config.yaml").read_text())) -ACTIONS = str(yaml.safe_load(Path("./actions.yaml").read_text())) -METADATA = str(yaml.safe_load(Path("./metadata.yaml").read_text())) +BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..")) +CONFIG = str(yaml.safe_load(Path(BASE_DIR + "/config.yaml").read_text())) +ACTIONS = str(yaml.safe_load(Path(BASE_DIR + "/actions.yaml").read_text())) +METADATA = str(yaml.safe_load(Path(BASE_DIR + "/metadata.yaml").read_text())) # override conftest fixtures @pytest.fixture(autouse=False) @@ -145,8 +148,10 @@ def test_listeners_in_server_properties(harness: Harness): peer_relation_id, f"{CHARM_KEY}/0", {"private-address": "treebeard"} ) - expected_listeners = "listeners=INTERNAL_SASL_PLAINTEXT://:19092" - expected_advertised_listeners = f"advertised.listeners=INTERNAL_SASL_PLAINTEXT://{'treebeard' if SUBSTRATE == 'vm' else 'kafka-k8s-0.kafka-k8s-endpoints'}:19092" + host = "treebeard" if SUBSTRATE == "vm" else "kafka-k8s-0.kafka-k8s-endpoints" + sasl_pm = "SASL_PLAINTEXT_SCRAM_SHA_512" + expected_listeners = f"listeners=INTERNAL_{sasl_pm}://:19092" + expected_advertised_listeners = f"advertised.listeners=INTERNAL_{sasl_pm}://{host}:19092" with ( patch( @@ -159,6 +164,54 @@ def test_listeners_in_server_properties(harness: Harness): assert expected_advertised_listeners in harness.charm.config_manager.server_properties +def test_oauth_client_listeners_in_server_properties(harness): + """Checks that oauth client listeners are properly set when a relating through oauth.""" + harness.add_relation(ZK, CHARM_KEY) + peer_relation_id = harness.add_relation(PEER, CHARM_KEY) + harness.add_relation_unit(peer_relation_id, f"{CHARM_KEY}/1") + harness.update_relation_data( + peer_relation_id, f"{CHARM_KEY}/0", {"private-address": "treebeard"} + ) + + oauth_relation_id = harness.add_relation(OAUTH_REL_NAME, "hydra") + harness.update_relation_data( + oauth_relation_id, + "hydra", + { + "issuer_url": "issuer", + "jwks_endpoint": "jwks", + "authorization_endpoint": "authz", + "token_endpoint": "token", + "introspection_endpoint": "introspection", + "userinfo_endpoint": "userinfo", + "scope": "scope", + "jwt_access_token": "False", + }, + ) + + # let's add a scram client just for fun + client_relation_id = harness.add_relation("kafka-client", "app") + harness.update_relation_data(client_relation_id, "app", {"extra-user-roles": "admin,producer"}) + + host = "treebeard" if SUBSTRATE == "vm" else "kafka-k8s-0.kafka-k8s-endpoints" + internal_protocol, internal_port = "INTERNAL_SASL_PLAINTEXT_SCRAM_SHA_512", "19092" + scram_client_protocol, scram_client_port = "CLIENT_SASL_PLAINTEXT_SCRAM_SHA_512", "9092" + oauth_client_protocol, oauth_client_port = "CLIENT_SASL_PLAINTEXT_OAUTHBEARER", "9095" + + expected_listeners = ( + f"listeners={internal_protocol}://:{internal_port}," + f"{scram_client_protocol}://:{scram_client_port}," + f"{oauth_client_protocol}://:{oauth_client_port}" + ) + expected_advertised_listeners = ( + f"advertised.listeners={internal_protocol}://{host}:{internal_port}," + f"{scram_client_protocol}://{host}:{scram_client_port}," + f"{oauth_client_protocol}://{host}:{oauth_client_port}" + ) + assert expected_listeners in harness.charm.config_manager.server_properties + assert expected_advertised_listeners in harness.charm.config_manager.server_properties + + def test_ssl_listeners_in_server_properties(harness: Harness): """Checks that listeners are added after TLS relation are created.""" zk_relation_id = harness.add_relation(ZK, CHARM_KEY) @@ -196,11 +249,12 @@ def test_ssl_listeners_in_server_properties(harness: Harness): ) host = "treebeard" if SUBSTRATE == "vm" else "kafka-k8s-0.kafka-k8s-endpoints" + sasl_pm = "SASL_SSL_SCRAM_SHA_512" + ssl_pm = "SSL_SSL" expected_listeners = ( - "listeners=INTERNAL_SASL_SSL://:19093,CLIENT_SASL_SSL://:9093,CLIENT_SSL://:9094" + f"listeners=INTERNAL_{sasl_pm}://:19093,CLIENT_{sasl_pm}://:9093,CLIENT_{ssl_pm}://:9094" ) - expected_advertised_listeners = f"advertised.listeners=INTERNAL_SASL_SSL://{host}:19093,CLIENT_SASL_SSL://{host}:9093,CLIENT_SSL://{host}:9094" - + expected_advertised_listeners = f"advertised.listeners=INTERNAL_{sasl_pm}://{host}:19093,CLIENT_{sasl_pm}://{host}:9093,CLIENT_{ssl_pm}://{host}:9094" with ( patch( "core.models.KafkaCluster.internal_user_credentials",