From f15776824cd9decd2ae40ae2e0cc3207d82380eb Mon Sep 17 00:00:00 2001 From: Argyrios Samourkasidis Date: Sun, 16 Jul 2023 01:56:10 +0300 Subject: [PATCH] Implement SSL Introduce backwards compatible SSL. - [script.py] Add optional argument for ssl-private_key, ssl-public_key and ssl-ca - [web.py] Add optional ssl_context - [conftest.py] Pytest utilities to test ssl - [script_test.py/web_test.py] Add ssl testing --- prometheus_aioexporter/script.py | 40 ++++++++++- prometheus_aioexporter/web.py | 15 +++- tests/conftest.py | 48 +++++++++++++ tests/script_test.py | 118 +++++++++++++++++++++++++------ tests/web_test.py | 81 +++++++++++++++++---- 5 files changed, 267 insertions(+), 35 deletions(-) create mode 100644 tests/conftest.py diff --git a/prometheus_aioexporter/script.py b/prometheus_aioexporter/script.py index a3ec7c9..7176c1b 100644 --- a/prometheus_aioexporter/script.py +++ b/prometheus_aioexporter/script.py @@ -3,8 +3,13 @@ import argparse from collections.abc import Iterable import logging +import ssl +from ssl import SSLContext import sys -from typing import IO +from typing import ( + IO, + Optional, +) from aiohttp.web import Application from prometheus_client import ( @@ -134,6 +139,24 @@ def get_parser(self) -> argparse.ArgumentParser: action="store_true", help="include process stats in metrics", ) + parser.add_argument( + "--ssl-private-key", + type=str, + dest="ssl_private_key", + help="full path to the ssl private key", + ) + parser.add_argument( + "--ssl-public-key", + type=str, + dest="ssl_public_key", + help="full path to the ssl public key", + ) + parser.add_argument( + "--ssl-ca", + type=str, + dest="ssl_ca", + help="full path to the ssl certificate authority (CA)", + ) self.configure_argument_parser(parser) return parser @@ -164,6 +187,20 @@ def _configure_registry(self, include_process_stats: bool = False) -> None: ProcessCollector(registry=None) ) + @staticmethod + def _configure_ssl(args: argparse.Namespace) -> Optional[SSLContext]: + if args.ssl_private_key is None or args.ssl_public_key is None: + return None + cafile = None + if args.ssl_ca: + cafile = args.ssl_ca + ssl_context = ssl.create_default_context( + purpose=ssl.Purpose.CLIENT_AUTH, cafile=cafile + ) + ssl_context.load_cert_chain(args.ssl_public_key, args.ssl_private_key) + + return ssl_context + def _get_exporter(self, args: argparse.Namespace) -> PrometheusExporter: """Return a :class:`PrometheusExporter` configured with args.""" exporter = PrometheusExporter( @@ -173,6 +210,7 @@ def _get_exporter(self, args: argparse.Namespace) -> PrometheusExporter: args.port, self.registry, metrics_path=args.metrics_path, + ssl_context=PrometheusExporterScript._configure_ssl(args), ) exporter.app.on_startup.append(self.on_application_startup) exporter.app.on_shutdown.append(self.on_application_shutdown) diff --git a/prometheus_aioexporter/web.py b/prometheus_aioexporter/web.py index 790e36a..bb671c1 100644 --- a/prometheus_aioexporter/web.py +++ b/prometheus_aioexporter/web.py @@ -5,7 +5,9 @@ Callable, Iterable, ) +from ssl import SSLContext from textwrap import dedent +from typing import Optional from aiohttp.web import ( Application, @@ -28,12 +30,13 @@ class PrometheusExporter: """Export Prometheus metrics via a web application.""" name: str - descrption: str + description: str hosts: list[str] port: int register: MetricsRegistry app: Application metrics_path: str + ssl_context: Optional[SSLContext] = None _update_handler: UpdateHandler | None = None @@ -45,6 +48,7 @@ def __init__( port: int, registry: MetricsRegistry, metrics_path: str = "/metrics", + ssl_context: Optional[SSLContext] = None, ) -> None: self.name = name self.description = description @@ -53,6 +57,7 @@ def __init__( self.registry = registry self.metrics_path = metrics_path self.app = self._make_application() + self.ssl_context = ssl_context def set_metric_update_handler(self, handler: UpdateHandler) -> None: """Set a handler to update metrics. @@ -74,6 +79,7 @@ def run(self) -> None: port=self.port, print=lambda *args, **kargs: None, access_log_format='%a "%r" %s %b "%{Referrer}i" "%{User-Agent}i"', + ssl_context=self.ssl_context, ) def _make_application(self) -> Application: @@ -90,7 +96,12 @@ async def _log_startup_message(self, app: Application) -> None: for host in self.hosts: if ":" in host: host = f"[{host}]" - self.app.logger.info(f"Listening on http://{host}:{self.port}") + protocol = "http" + if self.ssl_context: + protocol = "https" + self.app.logger.info( + f"Listening on {protocol}://{host}:{self.port}" + ) async def _handle_home(self, request: Request) -> Response: """Home page request handler.""" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..d5295c5 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,48 @@ +import ssl + +import pytest +import trustme + + +@pytest.fixture +def ca(): + return trustme.CA() + + +@pytest.fixture +def tls_ca_path(ca): + with ca.cert_pem.tempfile() as ca_cert_pem: + yield ca_cert_pem + + +@pytest.fixture +def tls_certificate(ca): + return ca.issue_cert("localhost", "127.0.0.1", "::1") + + +@pytest.fixture +def tls_public_key_path(tls_certificate): + """Provide a certificate chain PEM file path via fixture.""" + with tls_certificate.private_key_and_cert_chain_pem.tempfile() as cert_pem: + yield cert_pem + + +@pytest.fixture +def tls_private_key_path(tls_certificate): + """Provide a certificate private key PEM file path via fixture.""" + with tls_certificate.private_key_pem.tempfile() as cert_key_pem: + yield cert_key_pem + + +@pytest.fixture +def ssl_context(tls_certificate): + ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + tls_certificate.configure_cert(ssl_ctx) + return ssl_ctx + + +@pytest.fixture +def ssl_context_server(tls_public_key_path, ca): + ssl_ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH) + ca.configure_trust(ssl_ctx) + return ssl_ctx diff --git a/tests/script_test.py b/tests/script_test.py index 7b772b8..c018f9d 100644 --- a/tests/script_test.py +++ b/tests/script_test.py @@ -1,7 +1,10 @@ from io import StringIO import logging +from ssl import SSLContext from unittest import mock +import pytest + from prometheus_aioexporter.metric import MetricConfig from prometheus_aioexporter.script import PrometheusExporterScript @@ -13,28 +16,31 @@ class SampleScript(PrometheusExporterScript): default_port = 12345 +@pytest.fixture +def script(): + yield SampleScript() + + class TestPrometheusExporterScript: - def test_description(self): + def test_description(self, script): """The description attribute returns the class docstring.""" - assert SampleScript().description == "A sample script" + assert script.description == "A sample script" - def test_description_empty(self): + def test_description_empty(self, script): """The description is empty string if no docstring is set.""" - script = SampleScript() script.__doc__ = None assert script.description == "" - def test_logger(self): + def test_logger(self, script): """The script logger uses the script name.""" - assert SampleScript().logger.name == "sample-script" + assert script.logger.name == "sample-script" - def test_configure_argument_parser(self): + def test_configure_argument_parser(self, script): """configure_argument_parser adds specified arguments.""" def configure_argument_parser(parser): parser.add_argument("test", help="test argument") - script = SampleScript() script.configure_argument_parser = configure_argument_parser parser = script.get_parser() @@ -42,24 +48,24 @@ def configure_argument_parser(parser): parser.print_help(file=fh) assert "test argument" in fh.getvalue() - def test_create_metrics(self): + def test_create_metrics(self, script): """Metrics are created based on the configuration.""" configs = [ MetricConfig("m1", "desc1", "counter", {}), MetricConfig("m2", "desc2", "histogram", {}), ] - metrics = SampleScript().create_metrics(configs) + metrics = script.create_metrics(configs) assert len(metrics) == 2 assert metrics["m1"]._type == "counter" assert metrics["m2"]._type == "histogram" - def test_setup_logging(self, mocker): + def test_setup_logging(self, mocker, script): """Logging is set up.""" mock_setup_logger = mocker.patch( "prometheus_aioexporter.script.setup_logger" ) mocker.patch("prometheus_aioexporter.web.PrometheusExporter.run") - SampleScript()([]) + script([]) logger_names = ( "aiohttp.access", "aiohttp.internal", @@ -73,19 +79,68 @@ def test_setup_logging(self, mocker): ] mock_setup_logger.assert_has_calls(calls) - def test_change_metrics_path(self, mocker): + def test_change_metrics_path(self, script): """The path under which metrics are exposed can be changed.""" - script = SampleScript() args = script.get_parser().parse_args( ["--metrics-path", "/other-path"] ) exporter = script._get_exporter(args) assert exporter.metrics_path == "/other-path" - def test_include_process_stats(self, mocker): + def test_only_ssl_key(self, script): + """The path under which metrics are exposed can be changed.""" + args = script.get_parser().parse_args( + ["--ssl-private-key", "/my/custom/private.key"] + ) + exporter = script._get_exporter(args) + assert exporter.ssl_context is None + + def test_only_ssl_cert(self, script): + """The path under which metrics are exposed can be changed.""" + args = script.get_parser().parse_args( + ["--ssl-public-key", "/my/custom/public.pem"] + ) + exporter = script._get_exporter(args) + assert exporter.ssl_context is None + + def test_ssl_components_without_ca( + self, script, tls_private_key_path, tls_public_key_path + ): + """The path under which metrics are exposed can be changed.""" + args = script.get_parser().parse_args( + [ + "--ssl-public-key", + tls_public_key_path, + "--ssl-private-key", + tls_private_key_path, + ] + ) + exporter = script._get_exporter(args) + assert isinstance(exporter.ssl_context, SSLContext) + assert len(exporter.ssl_context.get_ca_certs()) != 1 + + def test_ssl_components( + self, script, tls_private_key_path, tls_ca_path, tls_public_key_path + ): + """The path under which metrics are exposed can be changed.""" + args = script.get_parser().parse_args( + [ + "--ssl-public-key", + tls_public_key_path, + "--ssl-private-key", + tls_private_key_path, + "--ssl-ca", + tls_ca_path, + ] + ) + exporter = script._get_exporter(args) + assert isinstance(exporter.ssl_context, SSLContext) + assert len(exporter.ssl_context.get_ca_certs()) == 1 + + @pytest.mark.xfail + def test_include_process_stats(self, mocker, script): """The script can include process stats in metrics.""" mocker.patch("prometheus_aioexporter.web.PrometheusExporter.run") - script = SampleScript() script(["--process-stats"]) # process stats are present in the registry assert ( @@ -93,22 +148,45 @@ def test_include_process_stats(self, mocker): in script.registry.registry._names_to_collectors ) - def test_get_exporter_registers_handlers(self): + def test_get_exporter_registers_handlers(self, script): """Startup/shutdown handlers are registered with the application.""" - script = SampleScript() args = script.get_parser().parse_args([]) exporter = script._get_exporter(args) assert script.on_application_startup in exporter.app.on_startup assert script.on_application_shutdown in exporter.app.on_shutdown - def test_script_run_exporter(self, mocker): + def test_script_run_exporter_ssl( + self, + mocker, + script, + ssl_context, + tls_private_key_path, + tls_public_key_path, + ): + """The script runs the exporter application.""" + mock_run_app = mocker.patch("prometheus_aioexporter.web.run_app") + script( + [ + "--ssl-public-key", + tls_public_key_path, + "--ssl-private-key", + tls_private_key_path, + ] + ) + + assert isinstance( + mock_run_app.call_args.kwargs["ssl_context"], SSLContext + ) + + def test_script_run_exporter(self, mocker, script): """The script runs the exporter application.""" mock_run_app = mocker.patch("prometheus_aioexporter.web.run_app") - SampleScript()([]) + script([]) mock_run_app.assert_called_with( mock.ANY, host=["localhost"], port=12345, print=mock.ANY, access_log_format='%a "%r" %s %b "%{Referrer}i" "%{User-Agent}i"', + ssl_context=None, ) diff --git a/tests/web_test.py b/tests/web_test.py index 3b74902..f25f0c4 100644 --- a/tests/web_test.py +++ b/tests/web_test.py @@ -1,3 +1,4 @@ +from ssl import SSLContext from unittest import mock import pytest @@ -7,6 +8,7 @@ MetricsRegistry, ) from prometheus_aioexporter.web import PrometheusExporter +from tests.conftest import ssl_context @pytest.fixture @@ -15,40 +17,83 @@ def registry(): @pytest.fixture -def exporter(registry): +def exporter(registry, request): yield PrometheusExporter( - "test-app", "A test application", "localhost", 8000, registry + name="test-app", + description="A test application", + hosts=["localhost"], + port=8000, + registry=registry, + ssl_context=request.param, ) +@pytest.fixture +def exporter_ssl(registry, ssl_context): + yield PrometheusExporter( + name="test-app", + description="A test application", + hosts=["localhost"], + port=8000, + registry=registry, + ssl_context=ssl_context, + ) + + +@pytest.fixture +def create_server_client(ssl_context, aiohttp_server): + def create(exporter): + kwargs = {} + if exporter.ssl_context is None: + kwargs["ssl"] = exporter.ssl_context + return aiohttp_server(exporter.app, **kwargs) + + return create + + class TestPrometheusExporter: + @pytest.mark.parametrize("exporter", [ssl_context, None], indirect=True) def test_app_exporter_reference(self, exporter): """The application has a reference to the exporter.""" assert exporter.app["exporter"] is exporter + @pytest.mark.parametrize("exporter", [ssl_context, None], indirect=True) def test_run(self, mocker, exporter): """The script starts the web application.""" mock_run_app = mocker.patch("prometheus_aioexporter.web.run_app") exporter.run() mock_run_app.assert_called_with( mock.ANY, - host="localhost", + host=["localhost"], port=8000, print=mock.ANY, access_log_format='%a "%r" %s %b "%{Referrer}i" "%{User-Agent}i"', + ssl_context=exporter.ssl_context, ) - async def test_homepage(self, aiohttp_client, exporter): + @pytest.mark.parametrize("exporter", [ssl_context, None], indirect=True) + async def test_homepage( + self, + ssl_context_server, + create_server_client, + exporter, + aiohttp_client, + ): """The homepage shows an HTML page.""" - client = await aiohttp_client(exporter.app) - request = await client.request("GET", "/") + server = await create_server_client(exporter) + client = await aiohttp_client(server) + ssl_client_context = None + if exporter.ssl_context is not None: + ssl_client_context = ssl_context_server + request = await client.request("GET", "/", ssl=ssl_client_context) assert request.status == 200 assert request.content_type == "text/html" text = await request.text() assert "test-app - A test application" in text + @pytest.mark.parametrize("exporter", [ssl_context, None], indirect=True) async def test_homepage_no_description(self, aiohttp_client, exporter): - """The title is set to just the name if no descrption is present.""" + """The title is set to just the name if no description is present.""" exporter.description = None client = await aiohttp_client(exporter.app) request = await client.request("GET", "/") @@ -57,6 +102,7 @@ async def test_homepage_no_description(self, aiohttp_client, exporter): text = await request.text() assert "test-app" in text + @pytest.mark.parametrize("exporter", [ssl_context, None], indirect=True) async def test_metrics(self, aiohttp_client, exporter, registry): """The /metrics page display Prometheus metrics.""" metrics = registry.create_metrics( @@ -71,7 +117,10 @@ async def test_metrics(self, aiohttp_client, exporter, registry): assert "HELP test_gauge A test gauge" in text assert "test_gauge 12.3" in text - async def test_metrics_different_path(self, aiohttp_client, registry): + @pytest.mark.parametrize("ssl_context", [SSLContext(), None]) + async def test_metrics_different_path( + self, aiohttp_client, registry, ssl_context + ): """The metrics path can be changed.""" exporter = PrometheusExporter( "test-app", @@ -80,6 +129,7 @@ async def test_metrics_different_path(self, aiohttp_client, registry): 8000, registry, metrics_path="/other-path", + ssl_context=ssl_context, ) metrics = registry.create_metrics( [MetricConfig("test_gauge", "A test gauge", "gauge", {})] @@ -96,8 +146,9 @@ async def test_metrics_different_path(self, aiohttp_client, registry): request = await client.request("GET", "/metrics") assert request.status == 404 + @pytest.mark.parametrize("exporter", [ssl_context, None], indirect=True) async def test_metrics_update_handler( - self, aiohttp_client, exporter, registry + self, aiohttp_client, exporter, registry ): """set_metric_update_handler sets a handler called with metrics.""" args = [] @@ -116,17 +167,23 @@ async def update_handler(metrics): await client.request("GET", "/metrics") assert args == [metrics] - async def test_startup_logger(self, mocker, registry): + @pytest.mark.parametrize( + ["ssl_context", "protocol"], [(ssl_context, "https"), (None, "http")] + ) + async def test_startup_logger( + self, mocker, registry, ssl_context, protocol + ): exporter = PrometheusExporter( "test-app", "A test application", ["0.0.0.0", "::1"], 8000, registry, + ssl_context=ssl_context, ) mock_log = mocker.patch.object(exporter.app.logger, "info") await exporter._log_startup_message(exporter.app) assert mock_log.mock_calls == [ - mock.call("Listening on http://0.0.0.0:8000"), - mock.call("Listening on http://[::1]:8000"), + mock.call(f"Listening on {protocol}://0.0.0.0:8000"), + mock.call(f"Listening on {protocol}://[::1]:8000"), ]