Skip to content

Commit

Permalink
fix: ensure certs are refreshed on SANs DNS changes
Browse files Browse the repository at this point in the history
  • Loading branch information
marcoppenheimer committed Dec 3, 2024
1 parent a7a2de4 commit 5212e45
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 67 deletions.
42 changes: 40 additions & 2 deletions src/charm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import logging
import time
from datetime import datetime

import ops
from charms.data_platform_libs.v0.data_models import TypedCharmBase
Expand Down Expand Up @@ -95,7 +96,7 @@ def __init__(self, *args):
current_version=self.upgrade.current_version,
)
self.tls_manager = TLSManager(
state=self.state, workload=self.workload, substrate=self.substrate
state=self.state, workload=self.workload, substrate=self.substrate, config=self.config
)
self.auth_manager = AuthManager(
state=self.state,
Expand Down Expand Up @@ -183,11 +184,48 @@ def _on_config_changed(self, event: EventBase) -> None:
zk_jaas = self.workload.read(self.workload.paths.zk_jaas)
zk_jaas_changed = set(zk_jaas) ^ set(self.config_manager.zk_jaas_config.splitlines())

if not properties or not zk_jaas:
current_sans = self.tls_manager.get_current_sans()
logger.info(f"{current_sans=}")
logger.info(f"{self.tls_manager.build_sans()=}")

if not (properties and zk_jaas):
# Event fired before charm has properly started
event.defer()
return

current_sans_ip = set(current_sans["sans_ip"]) if current_sans else set()
expected_sans_ip = set(self.tls_manager.build_sans()["sans_ip"]) if current_sans else set()
sans_ip_changed = current_sans_ip ^ expected_sans_ip

current_sans_dns = set(current_sans["sans_dns"]) if current_sans else set()
expected_sans_dns = (
set(self.tls_manager.build_sans()["sans_dns"]) if current_sans else set()
)
sans_dns_changed = current_sans_dns ^ expected_sans_dns

# update environment
self.config_manager.set_environment()

if sans_ip_changed or sans_dns_changed:
logger.info(
(
f'Broker {self.unit.name.split("/")[1]} updating certificate SANs - '
f"OLD SANs IP = {current_sans_ip - expected_sans_ip}, "
f"NEW SANs IP = {expected_sans_ip - current_sans_ip}, "
f"OLD SANs DNS = {current_sans_dns - expected_sans_dns}, "
f"NEW SANs DNS = {expected_sans_dns - current_sans_dns}"
)
)
self.tls.certificates.on.certificate_expiring.emit(
certificate=self.state.unit_broker.certificate,
expiry=datetime.now().isoformat(),
) # new cert will eventually be dynamically loaded by the broker
self.state.unit_broker.update(
{"certificate": ""}
) # ensures only single requested new certs, will be replaced on new certificate-available event

return # early return here to ensure new node cert arrives before updating advertised.listeners

# update environment
self.config_manager.set_environment()
self.unit.set_workload_version(self.workload.get_version())
Expand Down
11 changes: 11 additions & 0 deletions src/core/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import os
from functools import cached_property
from ipaddress import IPv4Address, IPv6Address

from charms.data_platform_libs.v0.data_interfaces import (
DatabaseRequirerData,
Expand Down Expand Up @@ -167,6 +168,16 @@ def clients(self) -> set[KafkaClient]:

# ---- GENERAL VALUES ----

@property
def bind_address(self) -> IPv4Address | IPv6Address | str:
"""The network binding address from the peer relation."""
bind_address = None
if self.peer_relation:
if binding := self.model.get_binding(self.peer_relation):
bind_address = binding.network.bind_address

return bind_address or ""

@property
def super_users(self) -> str:
"""Generates all users with super/admin permissions for the cluster from relations.
Expand Down
53 changes: 10 additions & 43 deletions src/events/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import logging
import os
import re
import socket
import warnings
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -156,13 +155,14 @@ def _trusted_relation_joined(self, event: RelationJoinedEvent) -> None:
subject = (
os.uname()[1] if self.charm.substrate == "k8s" else self.charm.state.unit_broker.host
)
sans = self.charm.tls_manager.build_sans()
csr = (
generate_csr(
add_unique_id_to_subject_name=bool(alias),
private_key=self.charm.state.unit_broker.private_key.encode("utf-8"),
subject=subject,
sans_ip=self._sans["sans_ip"],
sans_dns=self._sans["sans_dns"],
sans_ip=sans["sans_ip"],
sans_dns=sans["sans_dns"],
)
.decode()
.strip()
Expand Down Expand Up @@ -281,11 +281,12 @@ def _on_certificate_expiring(self, _) -> None:
logger.error("Missing unit private key and/or old csr")
return

sans = self.charm.tls_manager.build_sans()
new_csr = generate_csr(
private_key=self.charm.state.unit_broker.private_key.encode("utf-8"),
subject=self.charm.state.unit_broker.relation_data.get("private-address", ""),
sans_ip=self._sans["sans_ip"],
sans_dns=self._sans["sans_dns"],
sans_ip=sans["sans_ip"],
sans_dns=sans["sans_dns"],
)

self.certificates.request_certificate_renewal(
Expand Down Expand Up @@ -313,6 +314,8 @@ def _request_certificate(self):
logger.error("Can't request certificate, missing private key")
return

sans = self.charm.tls_manager.build_sans()

# only warn during certificate creation, not every event if in structured_config
if self.charm.config.certificate_extra_sans:
warnings.warn(
Expand All @@ -323,45 +326,9 @@ def _request_certificate(self):
csr = generate_csr(
private_key=self.charm.state.unit_broker.private_key.encode("utf-8"),
subject=self.charm.state.unit_broker.relation_data.get("private-address", ""),
sans_ip=self._sans["sans_ip"],
sans_dns=self._sans["sans_dns"],
sans_ip=sans["sans_ip"],
sans_dns=sans["sans_dns"],
)
self.charm.state.unit_broker.update({"csr": csr.decode("utf-8").strip()})

self.certificates.request_certificate_creation(certificate_signing_request=csr)

@property
def _sans(self) -> dict[str, list[str] | None]:
"""Builds a SAN dict of DNS names and IPs for the unit."""
if self.charm.substrate == "vm":
return {
"sans_ip": [self.charm.state.unit_broker.host],
"sans_dns": [self.model.unit.name, socket.getfqdn()] + self._extra_sans,
}
else:
bind_address = ""
if self.charm.state.peer_relation:
if binding := self.charm.model.get_binding(self.charm.state.peer_relation):
bind_address = binding.network.bind_address
return {
"sans_ip": [str(bind_address)],
"sans_dns": [
self.charm.state.unit_broker.host.split(".")[0],
self.charm.state.unit_broker.host,
socket.getfqdn(),
]
+ self._extra_sans,
}

@property
def _extra_sans(self) -> list[str]:
    """Return the extra SAN host names derived from the charm config.

    Uses ``extra_listeners`` when it is set, otherwise ``certificate_extra_sans``,
    otherwise nothing. For each entry, everything from the first ``:`` onwards is
    dropped and the ``{unit}`` placeholder is replaced with this unit's id.
    """
    config = self.charm.config
    configured = config.extra_listeners or config.certificate_extra_sans or []

    hosts = []
    for entry in configured:
        # entries may carry a ":<port>" suffix, only the host part is a SAN
        host = entry.split(":")[0]
        hosts.append(host.replace("{unit}", str(self.charm.state.unit_broker.unit_id)))

    return hosts
86 changes: 84 additions & 2 deletions src/managers/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,36 @@
"""Manager for handling Kafka TLS configuration."""

import logging
import subprocess # nosec B404
import socket
import subprocess
from typing import TypedDict # nosec B404

from ops.pebble import ExecError

from core.cluster import ClusterState
from core.structured_config import CharmConfig
from core.workload import WorkloadBase
from literals import GROUP, USER, Substrates

logger = logging.getLogger(__name__)

Sans = TypedDict("Sans", {"sans_ip": list[str], "sans_dns": list[str]})


class TLSManager:
"""Manager for building necessary files for Java TLS auth."""

def __init__(self, state: ClusterState, workload: WorkloadBase, substrate: Substrates):
def __init__(
self,
state: ClusterState,
workload: WorkloadBase,
substrate: Substrates,
config: CharmConfig,
):
self.state = state
self.workload = workload
self.substrate = substrate
self.config = config

self.keytool = "charmed-kafka.keytool" if self.substrate == "vm" else "keytool"

Expand Down Expand Up @@ -113,6 +125,76 @@ def remove_cert(self, alias: str) -> None:
logger.error(e.stdout)
raise e

def _build_extra_sans(self) -> list[str]:
    """Return the extra SAN host names derived from the charm config.

    Uses ``extra_listeners`` when it is set, otherwise ``certificate_extra_sans``,
    otherwise nothing. For each entry, everything from the first ``:`` onwards is
    dropped and the ``{unit}`` placeholder is replaced with this unit's id.
    """
    configured = self.config.extra_listeners or self.config.certificate_extra_sans or []

    return [
        # entries may carry a ":<port>" suffix, only the host part is a SAN
        entry.split(":")[0].replace("{unit}", str(self.state.unit_broker.unit_id))
        for entry in configured
    ]

def build_sans(self) -> Sans:
    """Build the expected SAN mapping (``sans_ip`` / ``sans_dns``) for this unit.

    On the ``vm`` substrate the IP entry is the broker host and the DNS entries are the
    unit name and the local FQDN, in that order. On any other substrate the IP entry is
    the stringified bind address from the cluster state and the DNS entries are the short
    host name, the full host name and the local FQDN; both lists are returned sorted.
    In both cases the parsed extra SANs are appended to the DNS entries.
    """
    if self.substrate != "vm":
        ip_addresses = [str(self.state.bind_address)]

        host = self.state.unit_broker.host
        dns_names = [host.split(".")[0], host, socket.getfqdn(), *self._build_extra_sans()]

        return {"sans_ip": sorted(ip_addresses), "sans_dns": sorted(dns_names)}

    ip_addresses = [self.state.unit_broker.host]
    dns_names = [self.state.unit_broker.unit.name, socket.getfqdn(), *self._build_extra_sans()]

    return {"sans_ip": ip_addresses, "sans_dns": dns_names}

def get_current_sans(self) -> Sans | None:
    """Read the SANs of the unit certificate currently stored on the workload.

    Runs ``openssl x509 -noout -ext subjectAltName -in server.pem`` in the workload's
    conf path and parses the ``DNS:`` and ``IP Address:`` entries from its output.

    Returns:
        ``None`` when the unit has no certificate in its state. Otherwise a dict with
        sorted ``sans_ip`` and ``sans_dns`` lists; both are empty when the output
        contains no such entries.

    Raises:
        subprocess.CalledProcessError, ExecError: re-raised when the command fails.
    """
    if not self.state.unit_broker.certificate:
        return None

    command = ["openssl", "x509", "-noout", "-ext", "subjectAltName", "-in", "server.pem"]

    try:
        sans_lines = self.workload.exec(
            command=" ".join(command), working_dir=self.workload.paths.conf_path
        ).splitlines()
    except (subprocess.CalledProcessError, ExecError) as e:
        logger.error(e.stdout)
        raise

    sans_ip: list[str] = []
    sans_dns: list[str] = []

    # Scan every line instead of guessing which one holds the entries: the extension
    # header has no "DNS"/"IP Address" type and is skipped, and certs with only DNS
    # or only IP SANs (or empty output) no longer depend on a leftover loop variable.
    for line in sans_lines:
        for item in line.split(","):
            # split on the first ":" only, IPv6 addresses contain further colons
            san_type, separator, san_value = item.partition(":")
            if not separator:
                continue

            san_type = san_type.strip()
            if san_type == "DNS":
                sans_dns.append(san_value.strip())
            elif san_type == "IP Address":
                sans_ip.append(san_value.strip())

    return {"sans_ip": sorted(sans_ip), "sans_dns": sorted(sans_dns)}

def remove_stores(self) -> None:
"""Cleans up all keys/certs/stores on a unit."""
try:
Expand Down
53 changes: 33 additions & 20 deletions tests/unit/test_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,22 +93,28 @@ def test_mtls_flag_added(harness: Harness):
assert isinstance(harness.charm.app.status, ActiveStatus)


def test_extra_sans_config(harness: Harness):
def test_extra_sans_config(harness: Harness[KafkaCharm]):
# Create peer relation
peer_relation_id = harness.add_relation(PEER, CHARM_KEY)
harness.add_relation_unit(peer_relation_id, f"{CHARM_KEY}/0")
harness.update_relation_data(
peer_relation_id, f"{CHARM_KEY}/0", {"private-address": "treebeard"}
)

harness.update_config({"certificate_extra_sans": ""})
assert harness.charm.tls._extra_sans == []
manager = harness.charm.tls_manager

harness.update_config({"certificate_extra_sans": "worker{unit}.com"})
assert harness.charm.tls._extra_sans == ["worker0.com"]
harness._update_config({"certificate_extra_sans": ""})
manager.config = harness.charm.config
assert manager._build_extra_sans() == []

harness._update_config({"certificate_extra_sans": "worker{unit}.com"})
manager.config = harness.charm.config
assert "worker0.com" in "".join(manager._build_extra_sans())

harness.update_config({"certificate_extra_sans": "worker{unit}.com,{unit}.example"})
assert harness.charm.tls._extra_sans == ["worker0.com", "0.example"]
harness._update_config({"certificate_extra_sans": "worker{unit}.com,{unit}.example"})
manager.config = harness.charm.config
assert "worker0.com" in "".join(manager._build_extra_sans())
assert "0.example" in "".join(manager._build_extra_sans())

# verifying that sans can be built with both certificate_extra_sans and extra_listeners
harness._update_config(
Expand All @@ -117,33 +123,40 @@ def test_extra_sans_config(harness: Harness):
"extra_listeners": "worker{unit}.com:30000,{unit}.example:40000,nonunit.domain.com:45000",
}
)
assert harness.charm.tls._extra_sans
assert "worker0.com" in "".join(harness.charm.tls._extra_sans)
assert "0.example" in "".join(harness.charm.tls._extra_sans)
assert "nonunit.domain.com" in "".join(harness.charm.tls._extra_sans)
manager.config = harness.charm.config
assert manager._build_extra_sans
assert "worker0.com" in "".join(manager._build_extra_sans())
assert "0.example" in "".join(manager._build_extra_sans())
assert "nonunit.domain.com" in "".join(manager._build_extra_sans())


def test_sans(harness: Harness):
def test_sans(harness: Harness[KafkaCharm]):
# Create peer relation
peer_relation_id = harness.add_relation(PEER, CHARM_KEY)
harness.add_relation_unit(peer_relation_id, f"{CHARM_KEY}/0")
harness.update_relation_data(
peer_relation_id, f"{CHARM_KEY}/0", {"private-address": "treebeard"}
)

manager = harness.charm.tls_manager
harness.update_config({"certificate_extra_sans": "worker{unit}.com"})
manager.config = harness.charm.config

sock_dns = socket.getfqdn()
if SUBSTRATE == "vm":
assert harness.charm.tls._sans == {
assert manager.build_sans() == {
"sans_ip": ["treebeard"],
"sans_dns": [f"{CHARM_KEY}/0", sock_dns, "worker0.com"],
}
elif SUBSTRATE == "k8s":
# NOTE previous k8s sans_ip like kafka-k8s-0.kafka-k8s-endpoints or binding pod address
with patch("ops.model.Model.get_binding"):
assert harness.charm.tls._sans["sans_dns"] == [
"kafka-k8s-0",
"kafka-k8s-0.kafka-k8s-endpoints",
sock_dns,
"worker0.com",
]
with (patch("ops.model.Model.get_binding")):
assert sorted(manager.build_sans()["sans_dns"]) == sorted(
[
"kafka-k8s-0",
"kafka-k8s-0.kafka-k8s-endpoints",
sock_dns,
"worker0.com",
]
)
assert "palantir" in "".join(manager.build_sans()["sans_ip"])

0 comments on commit 5212e45

Please sign in to comment.