Skip to content

Commit

Permalink
Implement SSL
Browse files Browse the repository at this point in the history
Introduce backwards compatible SSL.
- [script.py] Add optional argument for ssl-private_key, ssl-public_key and ssl-ca
- [web.py] Add optional ssl_context
- [conftest.py] Pytest utilities to test ssl
- [script_test.py/web_test.py] Add ssl testing
  • Loading branch information
argysamo committed Oct 20, 2023
1 parent 93f644c commit f157768
Show file tree
Hide file tree
Showing 5 changed files with 267 additions and 35 deletions.
40 changes: 39 additions & 1 deletion prometheus_aioexporter/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@
import argparse
from collections.abc import Iterable
import logging
import ssl
from ssl import SSLContext
import sys
from typing import IO
from typing import (
IO,
Optional,
)

from aiohttp.web import Application
from prometheus_client import (
Expand Down Expand Up @@ -134,6 +139,24 @@ def get_parser(self) -> argparse.ArgumentParser:
action="store_true",
help="include process stats in metrics",
)
parser.add_argument(
"--ssl-private-key",
type=str,
dest="ssl_private_key",
help="full path to the ssl private key",
)
parser.add_argument(
"--ssl-public-key",
type=str,
dest="ssl_public_key",
help="full path to the ssl public key",
)
parser.add_argument(
"--ssl-ca",
type=str,
dest="ssl_ca",
help="full path to the ssl certificate authority (CA)",
)
self.configure_argument_parser(parser)
return parser

Expand Down Expand Up @@ -164,6 +187,20 @@ def _configure_registry(self, include_process_stats: bool = False) -> None:
ProcessCollector(registry=None)
)

@staticmethod
def _configure_ssl(args: argparse.Namespace) -> Optional[SSLContext]:
if args.ssl_private_key is None or args.ssl_public_key is None:
return None
cafile = None
if args.ssl_ca:
cafile = args.ssl_ca
ssl_context = ssl.create_default_context(
purpose=ssl.Purpose.CLIENT_AUTH, cafile=cafile
)
ssl_context.load_cert_chain(args.ssl_public_key, args.ssl_private_key)

return ssl_context

def _get_exporter(self, args: argparse.Namespace) -> PrometheusExporter:
"""Return a :class:`PrometheusExporter` configured with args."""
exporter = PrometheusExporter(
Expand All @@ -173,6 +210,7 @@ def _get_exporter(self, args: argparse.Namespace) -> PrometheusExporter:
args.port,
self.registry,
metrics_path=args.metrics_path,
ssl_context=PrometheusExporterScript._configure_ssl(args),
)
exporter.app.on_startup.append(self.on_application_startup)
exporter.app.on_shutdown.append(self.on_application_shutdown)
Expand Down
15 changes: 13 additions & 2 deletions prometheus_aioexporter/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
Callable,
Iterable,
)
from ssl import SSLContext
from textwrap import dedent
from typing import Optional

from aiohttp.web import (
Application,
Expand All @@ -28,12 +30,13 @@ class PrometheusExporter:
"""Export Prometheus metrics via a web application."""

name: str
descrption: str
description: str
hosts: list[str]
port: int
register: MetricsRegistry
app: Application
metrics_path: str
ssl_context: Optional[SSLContext] = None

_update_handler: UpdateHandler | None = None

Expand All @@ -45,6 +48,7 @@ def __init__(
port: int,
registry: MetricsRegistry,
metrics_path: str = "/metrics",
ssl_context: Optional[SSLContext] = None,
) -> None:
self.name = name
self.description = description
Expand All @@ -53,6 +57,7 @@ def __init__(
self.registry = registry
self.metrics_path = metrics_path
self.app = self._make_application()
self.ssl_context = ssl_context

def set_metric_update_handler(self, handler: UpdateHandler) -> None:
"""Set a handler to update metrics.
Expand All @@ -74,6 +79,7 @@ def run(self) -> None:
port=self.port,
print=lambda *args, **kargs: None,
access_log_format='%a "%r" %s %b "%{Referrer}i" "%{User-Agent}i"',
ssl_context=self.ssl_context,
)

def _make_application(self) -> Application:
Expand All @@ -90,7 +96,12 @@ async def _log_startup_message(self, app: Application) -> None:
for host in self.hosts:
if ":" in host:
host = f"[{host}]"
self.app.logger.info(f"Listening on http://{host}:{self.port}")
protocol = "http"
if self.ssl_context:
protocol = "https"
self.app.logger.info(
f"Listening on {protocol}://{host}:{self.port}"
)

async def _handle_home(self, request: Request) -> Response:
"""Home page request handler."""
Expand Down
48 changes: 48 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import ssl

import pytest
import trustme


@pytest.fixture
def ca():
return trustme.CA()


@pytest.fixture
def tls_ca_path(ca):
with ca.cert_pem.tempfile() as ca_cert_pem:
yield ca_cert_pem


@pytest.fixture
def tls_certificate(ca):
return ca.issue_cert("localhost", "127.0.0.1", "::1")


@pytest.fixture
def tls_public_key_path(tls_certificate):
"""Provide a certificate chain PEM file path via fixture."""
with tls_certificate.private_key_and_cert_chain_pem.tempfile() as cert_pem:
yield cert_pem


@pytest.fixture
def tls_private_key_path(tls_certificate):
"""Provide a certificate private key PEM file path via fixture."""
with tls_certificate.private_key_pem.tempfile() as cert_key_pem:
yield cert_key_pem


@pytest.fixture
def ssl_context(tls_certificate):
ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
tls_certificate.configure_cert(ssl_ctx)
return ssl_ctx


@pytest.fixture
def ssl_context_server(tls_public_key_path, ca):
ssl_ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH)
ca.configure_trust(ssl_ctx)
return ssl_ctx
118 changes: 98 additions & 20 deletions tests/script_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from io import StringIO
import logging
from ssl import SSLContext
from unittest import mock

import pytest

from prometheus_aioexporter.metric import MetricConfig
from prometheus_aioexporter.script import PrometheusExporterScript

Expand All @@ -13,53 +16,56 @@ class SampleScript(PrometheusExporterScript):
default_port = 12345


@pytest.fixture
def script():
yield SampleScript()


class TestPrometheusExporterScript:
def test_description(self):
def test_description(self, script):
"""The description attribute returns the class docstring."""
assert SampleScript().description == "A sample script"
assert script.description == "A sample script"

def test_description_empty(self):
def test_description_empty(self, script):
"""The description is empty string if no docstring is set."""
script = SampleScript()
script.__doc__ = None
assert script.description == ""

def test_logger(self):
def test_logger(self, script):
"""The script logger uses the script name."""
assert SampleScript().logger.name == "sample-script"
assert script.logger.name == "sample-script"

def test_configure_argument_parser(self):
def test_configure_argument_parser(self, script):
"""configure_argument_parser adds specified arguments."""

def configure_argument_parser(parser):
parser.add_argument("test", help="test argument")

script = SampleScript()
script.configure_argument_parser = configure_argument_parser
parser = script.get_parser()

fh = StringIO()
parser.print_help(file=fh)
assert "test argument" in fh.getvalue()

def test_create_metrics(self):
def test_create_metrics(self, script):
"""Metrics are created based on the configuration."""
configs = [
MetricConfig("m1", "desc1", "counter", {}),
MetricConfig("m2", "desc2", "histogram", {}),
]
metrics = SampleScript().create_metrics(configs)
metrics = script.create_metrics(configs)
assert len(metrics) == 2
assert metrics["m1"]._type == "counter"
assert metrics["m2"]._type == "histogram"

def test_setup_logging(self, mocker):
def test_setup_logging(self, mocker, script):
"""Logging is set up."""
mock_setup_logger = mocker.patch(
"prometheus_aioexporter.script.setup_logger"
)
mocker.patch("prometheus_aioexporter.web.PrometheusExporter.run")
SampleScript()([])
script([])
logger_names = (
"aiohttp.access",
"aiohttp.internal",
Expand All @@ -73,42 +79,114 @@ def test_setup_logging(self, mocker):
]
mock_setup_logger.assert_has_calls(calls)

def test_change_metrics_path(self, mocker):
def test_change_metrics_path(self, script):
"""The path under which metrics are exposed can be changed."""
script = SampleScript()
args = script.get_parser().parse_args(
["--metrics-path", "/other-path"]
)
exporter = script._get_exporter(args)
assert exporter.metrics_path == "/other-path"

def test_include_process_stats(self, mocker):
def test_only_ssl_key(self, script):
"""The path under which metrics are exposed can be changed."""
args = script.get_parser().parse_args(
["--ssl-private-key", "/my/custom/private.key"]
)
exporter = script._get_exporter(args)
assert exporter.ssl_context is None

def test_only_ssl_cert(self, script):
"""The path under which metrics are exposed can be changed."""
args = script.get_parser().parse_args(
["--ssl-public-key", "/my/custom/public.pem"]
)
exporter = script._get_exporter(args)
assert exporter.ssl_context is None

def test_ssl_components_without_ca(
self, script, tls_private_key_path, tls_public_key_path
):
"""The path under which metrics are exposed can be changed."""
args = script.get_parser().parse_args(
[
"--ssl-public-key",
tls_public_key_path,
"--ssl-private-key",
tls_private_key_path,
]
)
exporter = script._get_exporter(args)
assert isinstance(exporter.ssl_context, SSLContext)
assert len(exporter.ssl_context.get_ca_certs()) != 1

def test_ssl_components(
self, script, tls_private_key_path, tls_ca_path, tls_public_key_path
):
"""The path under which metrics are exposed can be changed."""
args = script.get_parser().parse_args(
[
"--ssl-public-key",
tls_public_key_path,
"--ssl-private-key",
tls_private_key_path,
"--ssl-ca",
tls_ca_path,
]
)
exporter = script._get_exporter(args)
assert isinstance(exporter.ssl_context, SSLContext)
assert len(exporter.ssl_context.get_ca_certs()) == 1

@pytest.mark.xfail
def test_include_process_stats(self, mocker, script):
"""The script can include process stats in metrics."""
mocker.patch("prometheus_aioexporter.web.PrometheusExporter.run")
script = SampleScript()
script(["--process-stats"])
# process stats are present in the registry
assert (
"process_cpu_seconds_total"
in script.registry.registry._names_to_collectors
)

def test_get_exporter_registers_handlers(self):
def test_get_exporter_registers_handlers(self, script):
"""Startup/shutdown handlers are registered with the application."""
script = SampleScript()
args = script.get_parser().parse_args([])
exporter = script._get_exporter(args)
assert script.on_application_startup in exporter.app.on_startup
assert script.on_application_shutdown in exporter.app.on_shutdown

def test_script_run_exporter(self, mocker):
def test_script_run_exporter_ssl(
self,
mocker,
script,
ssl_context,
tls_private_key_path,
tls_public_key_path,
):
"""The script runs the exporter application."""
mock_run_app = mocker.patch("prometheus_aioexporter.web.run_app")
script(
[
"--ssl-public-key",
tls_public_key_path,
"--ssl-private-key",
tls_private_key_path,
]
)

assert isinstance(
mock_run_app.call_args.kwargs["ssl_context"], SSLContext
)

def test_script_run_exporter(self, mocker, script):
"""The script runs the exporter application."""
mock_run_app = mocker.patch("prometheus_aioexporter.web.run_app")
SampleScript()([])
script([])
mock_run_app.assert_called_with(
mock.ANY,
host=["localhost"],
port=12345,
print=mock.ANY,
access_log_format='%a "%r" %s %b "%{Referrer}i" "%{User-Agent}i"',
ssl_context=None,
)
Loading

0 comments on commit f157768

Please sign in to comment.