Skip to content

Commit

Permalink
feat: add extra_listeners (#269)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcoppenheimer authored Nov 12, 2024
1 parent 07be36c commit 808c702
Show file tree
Hide file tree
Showing 14 changed files with 256 additions and 30 deletions.
3 changes: 3 additions & 0 deletions actions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ get-admin-credentials:
The returned client_properties can be used for Kafka bin commands using `--bootstrap-server` and `--command-config` for admin level administration
This action must be called on the leader unit.

get-listeners:
description: Get all active listeners and their port allocations

pre-upgrade-check:
description: Run necessary pre-upgrade checks before executing a charm upgrade.

Expand Down
4 changes: 4 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ options:
description: Config options to add extra-sans to the ones used when requesting server certificates. The extra-sans are specified by comma-separated names to be added when requesting signed certificates. Use "{unit}" as a placeholder to be filled with the unit number, e.g. "worker-{unit}" will be translated as "worker-0" for unit 0 and "worker-1" for unit 1 when requesting the certificate.
type: string
default: ""
extra_listeners:
description: "Config options to add extra SANs to the ones used when requesting server certificates, and to define custom `advertised.listeners` and ports for clients external to the Juju model. These items are comma-separated. Use '{unit}' as a placeholder to be filled with the unit number if necessary. For port allocations, providing the port for a given listener will offset the generated port number by that amount, with an accepted value range of 20001-50000. For example, a provided value of 'worker-{unit}.domain.com:30000' will generate listeners for unit 0 with name 'worker-0.domain.com', and be allocated ports 39092, 39093 etc for each authentication scheme."
type: string
default: ""
log_level:
description: "Level of logging for the different components operated by the charm. Possible values: ERROR, WARNING, INFO, DEBUG"
type: string
Expand Down
4 changes: 2 additions & 2 deletions src/charm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
import time

import ops
from charms.data_platform_libs.v0.data_models import TypedCharmBase
from charms.grafana_agent.v0.cos_agent import COSAgentProvider
from charms.operator_libs_linux.v0 import sysctl
Expand All @@ -17,7 +18,6 @@
EventBase,
StatusBase,
)
from ops.main import main

from core.cluster import ClusterState
from core.models import Substrates
Expand Down Expand Up @@ -188,4 +188,4 @@ def _on_collect_status(self, event: CollectStatusEvent):


if __name__ == "__main__":
main(KafkaCharm)
ops.main(KafkaCharm) # pyright: ignore[reportCallIssue]
43 changes: 43 additions & 0 deletions src/core/structured_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class CharmConfig(BaseConfigModel):
zookeeper_ssl_cipher_suites: str | None
profile: str
certificate_extra_sans: str | None
extra_listeners: list[str]
log_level: str
network_bandwidth: int = Field(default=50000, validate_default=False, gt=0)
cruisecontrol_balance_threshold: float = Field(default=1.1, validate_default=False, ge=1)
Expand Down Expand Up @@ -265,3 +266,45 @@ def roles_values(cls, value: str) -> str:
raise ValueError("Unknown role(s):", unknown_roles)

return ",".join(sorted(roles)) # this has to be a string as it goes in to properties

@validator("certificate_extra_sans")
@classmethod
def certificate_extra_sans_values(cls, value: str) -> list[str]:
"""Formats certificate_extra_sans values to a list."""
return value.split(",") if value else []

@validator("extra_listeners", pre=True)
@classmethod
def extra_listeners_port_values(cls, value: str) -> list[str]:
"""Check extra_listeners port values for each listener, and format values to a list."""
if not value:
return []

listeners = value.split(",")

ports = []
for listener in listeners:
if ":" not in listener or not listener.split(":")[1].isdigit():
raise ValueError("Value for listener does not contain a valid port.")

port = int(listener.split(":")[1])
if not 20000 < port < 50000:
raise ValueError(
"Value for port out of accepted values. Accepted values for port are greater than 20000 and less than 50000"
)

ports.append(port)

current_port = 0
for port in ports:
if not current_port - 100 < int(port) > current_port + 100:
raise ValueError(
"Value for port is too close to other value for port. Accepted values must be at least 100 apart from each other"
)

current_port = int(port)

if len(ports) != len(set(ports)):
raise ValueError("Value for port is not unique for each listener.")

return listeners
34 changes: 30 additions & 4 deletions src/events/password_actions.py → src/events/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
# Copyright 2024 Canonical Ltd.
# See LICENSE file for licensing details.

"""Event handlers for password-related Juju Actions."""
"""Event handlers for Juju Actions."""

import logging
from typing import TYPE_CHECKING

Expand All @@ -18,11 +19,11 @@
logger = logging.getLogger(__name__)


class PasswordActionEvents(Object):
"""Event handlers for password-related Juju Actions."""
class ActionEvents(Object):
"""Event handlers for Juju Actions."""

def __init__(self, dependent: "BrokerOperator") -> None:
super().__init__(dependent, "password_events")
super().__init__(dependent, "action_events")
self.dependent = dependent
self.charm: "KafkaCharm" = dependent.charm

Expand All @@ -33,6 +34,31 @@ def __init__(self, dependent: "BrokerOperator") -> None:
getattr(self.charm.on, "get_admin_credentials_action"),
self._get_admin_credentials_action,
)
self.framework.observe(
getattr(self.charm.on, "get_listeners_action"), self._get_listeners_action
)

def _get_listeners_action(self, event: ActionEvent) -> None:
"""Handler for get-listeners action."""
listeners = self.dependent.config_manager.all_listeners

result = {}
for listener in listeners:
key = listener.name.replace("_", "-").lower()
result.update(
{
key: {
"name": listener.name,
"scope": listener.scope,
"port": listener.port,
"protocol": listener.protocol,
"auth-mechanism": listener.mechanism,
"advertised-listener": listener.advertised_listener,
}
}
)

event.set_results(result)

def _set_password_action(self, event: ActionEvent) -> None:
"""Handler for set-password action.
Expand Down
5 changes: 3 additions & 2 deletions src/events/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
UpdateStatusEvent,
)

from events.actions import ActionEvents
from events.oauth import OAuthHandler
from events.password_actions import PasswordActionEvents
from events.provider import KafkaProvider
from events.upgrade import KafkaDependencyModel, KafkaUpgrade
from events.zookeeper import ZooKeeperHandler
Expand Down Expand Up @@ -88,7 +88,8 @@ def __init__(self, charm) -> None:
**DEPENDENCIES # pyright: ignore[reportArgumentType]
),
)
self.password_action_events = PasswordActionEvents(self)
self.action_events = ActionEvents(self)

if not self.charm.state.runs_controller:
self.zookeeper = ZooKeeperHandler(self)

Expand Down
8 changes: 8 additions & 0 deletions src/events/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import logging
import os
import re
import warnings
from typing import TYPE_CHECKING

from charms.tls_certificates_interface.v1.tls_certificates import (
Expand Down Expand Up @@ -296,6 +297,13 @@ def _request_certificate(self):

sans = self.charm.broker.tls_manager.build_sans()

# only warn during certificate creation, not every event if in structured_config
if self.charm.config.certificate_extra_sans:
warnings.warn(
"'certificate_extra_sans' config option is deprecated, use 'extra_listeners' instead",
DeprecationWarning,
)

csr = generate_csr(
private_key=self.charm.state.unit_broker.private_key.encode("utf-8"),
subject=self.charm.state.unit_broker.relation_data.get("private-address", ""),
Expand Down
3 changes: 2 additions & 1 deletion src/literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,12 @@ class Ports:
client: int
internal: int
external: int
extra: int = 0


AuthProtocol = Literal["SASL_PLAINTEXT", "SASL_SSL", "SSL"]
AuthMechanism = Literal["SCRAM-SHA-512", "OAUTHBEARER", "SSL"]
Scope = Literal["INTERNAL", "CLIENT", "EXTERNAL"]
Scope = Literal["INTERNAL", "CLIENT", "EXTERNAL", "EXTRA"]
AuthMap = NamedTuple("AuthMap", protocol=AuthProtocol, mechanism=AuthMechanism)

SECURITY_PROTOCOL_PORTS: dict[AuthMap, Ports] = {
Expand Down
63 changes: 55 additions & 8 deletions src/managers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
"profile",
"log_level",
"certificate_extra_sans",
"extra_listeners",
"roles",
"expose_external",
]
Expand All @@ -86,19 +87,28 @@ class Listener:
Args:
auth_map: AuthMap representing the auth.protocol and auth.mechanism for the listener
scope: scope of the listener, CLIENT, INTERNAL or EXTERNAL
scope: scope of the listener, CLIENT, INTERNAL, EXTERNAL or EXTRA
host: string with the host that will be announced
baseport (optional): integer port to offset CLIENT port numbers for EXTRA listeners
node_port (optional): the node-port for the listener if scope=EXTERNAL
"""

def __init__(
self, auth_map: AuthMap, scope: Scope, host: str = "", node_port: int | None = None
self,
auth_map: AuthMap,
scope: Scope,
host: str = "",
baseport: int = 30000,
extra_count: int = -1,
node_port: int | None = None,
):
self.auth_map = auth_map
self.protocol = auth_map.protocol
self.mechanism = auth_map.mechanism
self.host = host
self.scope = scope
self.baseport = baseport
self.extra_count = extra_count
self.node_port = node_port

@property
Expand All @@ -109,8 +119,8 @@ def scope(self) -> Scope:
@scope.setter
def scope(self, value):
"""Internal scope validator."""
if value not in ["CLIENT", "INTERNAL", "EXTERNAL"]:
raise ValueError("Only CLIENT, INTERNAL and EXTERNAL scopes are accepted")
if value not in ["CLIENT", "INTERNAL", "EXTERNAL", "EXTRA"]:
raise ValueError("Only CLIENT, INTERNAL, EXTERNAL and EXTRA scopes are accepted")

self._scope = value

Expand All @@ -121,12 +131,18 @@ def port(self) -> int:
Returns:
Integer of port number
"""
# generates ports 39092, 39192, 39292 etc for listener auth if baseport=30000
if self.scope == "EXTRA":
return getattr(SECURITY_PROTOCOL_PORTS[self.auth_map], "client") + self.baseport

return getattr(SECURITY_PROTOCOL_PORTS[self.auth_map], self.scope.lower())

@property
def name(self) -> str:
"""Name of the listener."""
return f"{self.scope}_{self.protocol}_{self.mechanism.replace('-', '_')}"
return f"{self.scope}_{self.protocol}_{self.mechanism.replace('-', '_')}" + (
f"_{self.extra_count}" if self.extra_count >= 0 else ""
)

@property
def protocol_map(self) -> str:
Expand Down Expand Up @@ -383,7 +399,7 @@ def scram_properties(self) -> list[str]:
f'listener.name.{listener_name}.{listener_mechanism}.sasl.jaas.config=org.apache.kafka.common.security.scram.ScramLoginModule required username="{username}" password="{password}";',
f"listener.name.{listener_name}.sasl.enabled.mechanisms={self.internal_listener.mechanism}",
]
for auth in self.client_listeners + self.external_listeners:
for auth in self.client_listeners + self.external_listeners + self.extra_listeners:
if not auth.mechanism.startswith("SCRAM"):
continue

Expand Down Expand Up @@ -463,8 +479,34 @@ def controller_listener(self) -> None:
pass # TODO: No good abstraction in place for the controller use case

@property
def client_listeners(self) -> list[Listener]:
def extra_listeners(self) -> list[Listener]:
"""Return a list of extra listeners."""
extra_host_baseports = [
tuple(listener.split(":")) for listener in self.config.extra_listeners
]

extra_listeners = []
extra_count = 0
for host, baseport in extra_host_baseports:
for auth_map in self.state.enabled_auth:
host = host.replace("{unit}", str(self.state.unit_broker.unit_id))
extra_listeners.append(
Listener(
host=host,
auth_map=auth_map,
scope="EXTRA",
baseport=int(baseport),
extra_count=extra_count,
)
)

extra_count += 1

return extra_listeners

@property
def client_listeners(self) -> list[Listener]:
"""Return a list of client listeners."""
return [
Listener(
host=self.state.unit_broker.internal_address, auth_map=auth_map, scope="CLIENT"
Expand Down Expand Up @@ -508,7 +550,12 @@ def external_listeners(self) -> list[Listener]:
@property
def all_listeners(self) -> list[Listener]:
"""Return a list with all expected listeners."""
return [self.internal_listener] + self.client_listeners + self.external_listeners
return (
[self.internal_listener]
+ self.client_listeners
+ self.external_listeners
+ self.extra_listeners
)

@property
def inter_broker_protocol_version(self) -> str:
Expand Down
13 changes: 5 additions & 8 deletions src/managers/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,11 @@ def remove_cert(self, alias: str) -> None:

def _build_extra_sans(self) -> list[str]:
"""Parse the certificate_extra_sans config option."""
extra_sans = self.config.certificate_extra_sans or ""
parsed_sans = []

if extra_sans == "":
return parsed_sans

for sans in extra_sans.split(","):
parsed_sans.append(sans.replace("{unit}", str(self.state.unit_broker.unit_id)))
extra_sans = self.config.extra_listeners or self.config.certificate_extra_sans or []
clean_sans = [san.split(":")[0] for san in extra_sans]
parsed_sans = [
san.replace("{unit}", str(self.state.unit_broker.unit_id)) for san in clean_sans
]

return parsed_sans

Expand Down
8 changes: 4 additions & 4 deletions tests/integration/test_charm.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,10 +331,10 @@ async def test_observability_integration(ops_test: OpsTest):

@pytest.mark.abort_on_fail
async def test_deploy_with_existing_storage(ops_test: OpsTest):
unit_to_remove, *_ = await ops_test.model.applications[APP_NAME].add_units(count=3)
await ops_test.model.block_until(lambda: len(ops_test.model.applications[APP_NAME].units) == 4)
unit_to_remove, *_ = await ops_test.model.applications[APP_NAME].add_units(count=1)
await ops_test.model.block_until(lambda: len(ops_test.model.applications[APP_NAME].units) == 2)
await ops_test.model.wait_for_idle(
apps=[APP_NAME], status="active", timeout=1000, idle_period=30
apps=[APP_NAME], status="active", timeout=2000, idle_period=30
)

_, stdout, _ = await ops_test.juju("storage", "--format", "json")
Expand All @@ -347,7 +347,7 @@ async def test_deploy_with_existing_storage(ops_test: OpsTest):
break

await unit_to_remove.remove(destroy_storage=False)
await ops_test.model.block_until(lambda: len(ops_test.model.applications[APP_NAME].units) == 3)
await ops_test.model.block_until(lambda: len(ops_test.model.applications[APP_NAME].units) == 1)

add_unit_cmd = f"add-unit {APP_NAME} --model={ops_test.model.info.name} --attach-storage={data_storage_id}".split()
await ops_test.juju(*add_unit_cmd)
Expand Down
Loading

0 comments on commit 808c702

Please sign in to comment.