Skip to content

Commit

Permalink
add typing to tests, cleanup fixtures
Browse files Browse the repository at this point in the history
  • Loading branch information
albertodonato committed Oct 26, 2023
1 parent ede901c commit ee9b3c7
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 125 deletions.
29 changes: 20 additions & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,59 @@
import ssl
from typing import Iterator

import pytest
import trustme
from trustme import (
CA,
LeafCert,
)


@pytest.fixture
def ca():
yield trustme.CA()
def ca() -> Iterator[CA]:
"""A root CA."""
yield CA()


@pytest.fixture
def tls_ca_path(ca):
def tls_ca_path(ca: CA) -> Iterator[str]:
"""Path for the CA certificate."""
with ca.cert_pem.tempfile() as ca_cert_pem:
yield ca_cert_pem


@pytest.fixture
def tls_certificate(ca):
def tls_certificate(ca: CA) -> Iterator[LeafCert]:
"""A leaf certificate."""
yield ca.issue_cert("localhost", "127.0.0.1", "::1")


@pytest.fixture
def tls_public_key_path(tls_certificate):
def tls_public_key_path(tls_certificate: LeafCert) -> Iterator[str]:
"""Provide a certificate chain PEM file path via fixture."""
with tls_certificate.private_key_and_cert_chain_pem.tempfile() as cert_pem:
yield cert_pem


@pytest.fixture
def tls_private_key_path(tls_certificate):
def tls_private_key_path(tls_certificate: LeafCert) -> Iterator[str]:
"""Provide a certificate private key PEM file path via fixture."""
with tls_certificate.private_key_pem.tempfile() as cert_key_pem:
yield cert_key_pem


@pytest.fixture
def ssl_context(tls_certificate):
def ssl_context(tls_certificate: LeafCert) -> Iterator[ssl.SSLContext]:
"""SSL context with the test CA."""
ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
tls_certificate.configure_cert(ssl_ctx)
yield ssl_ctx


@pytest.fixture
def ssl_context_server(tls_public_key_path, ca):
def ssl_context_server(
ca: CA, tls_public_key_path: str
) -> Iterator[ssl.SSLContext]:
"""SSL context for server authentication."""
ssl_ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH)
ca.configure_trust(ssl_ctx)
yield ssl_ctx
28 changes: 12 additions & 16 deletions tests/metric_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import (
Any,
Callable,
cast,
)

from prometheus_client import Histogram
from prometheus_client.metrics import MetricWrapperBase
import pytest

Expand All @@ -14,8 +16,7 @@


class TestMetricConfig:
def test_invalid_metric_type(self):
"""An invalid metric type raises an error."""
def test_invalid_metric_type(self) -> None:
with pytest.raises(InvalidMetricType) as error:
MetricConfig("m1", "desc1", "unknown")
assert str(error.value) == (
Expand All @@ -29,8 +30,7 @@ def test_labels_sorted(self) -> None:


class TestMetricsRegistry:
def test_create_metrics(self):
"""Prometheus metrics are created from the specified config."""
def test_create_metrics(self) -> None:
configs = [
MetricConfig("m1", "desc1", "counter"),
MetricConfig("m2", "desc2", "histogram"),
Expand All @@ -40,27 +40,25 @@ def test_create_metrics(self):
assert metrics["m1"]._type == "counter"
assert metrics["m2"]._type == "histogram"

def test_create_metrics_with_config(self):
"""Metric configs are applied."""
def test_create_metrics_with_config(self) -> None:
configs = [
MetricConfig(
"m1", "desc1", "histogram", config={"buckets": [10, 20]}
)
]
metrics = MetricsRegistry().create_metrics(configs)
# The two specified bucket plus +Inf
assert len(metrics["m1"]._buckets) == 3
# Histogram has the two specified bucket plus +Inf
histogram = cast(Histogram, metrics["m1"])
assert len(histogram._buckets) == 3

def test_create_metrics_config_ignores_unknown(self):
"""Unknown metric configs are ignored and don't cause an error."""
def test_create_metrics_config_ignores_unknown(self) -> None:
configs = [
MetricConfig("m1", "desc1", "gauge", config={"unknown": "value"})
]
metrics = MetricsRegistry().create_metrics(configs)
assert len(metrics) == 1

def test_get_metrics(self):
"""get_metrics returns a dict with metrics."""
def test_get_metrics(self) -> None:
registry = MetricsRegistry()
metrics = registry.create_metrics(
[
Expand All @@ -70,8 +68,7 @@ def test_get_metrics(self):
)
assert registry.get_metrics() == metrics

def test_get_metric(self):
"""get_metric returns a metric."""
def test_get_metric(self) -> None:
configs = [
MetricConfig(
"m",
Expand All @@ -86,8 +83,7 @@ def test_get_metric(self):
assert metric._name == "m"
assert metric._labelvalues == ()

def test_get_metric_with_labels(self):
"""get_metric returns a metric configured with labels."""
def test_get_metric_with_labels(self) -> None:
configs = [
MetricConfig("m", "A test gauge", "gauge", labels=("l1", "l2"))
]
Expand Down
95 changes: 52 additions & 43 deletions tests/script_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from argparse import ArgumentParser
from io import StringIO
import logging
from ssl import SSLContext
from typing import Iterator
from unittest import mock

import pytest
from pytest_mock import MockerFixture

from prometheus_aioexporter._metric import MetricConfig
from prometheus_aioexporter._script import PrometheusExporterScript
Expand All @@ -15,41 +18,36 @@ class SampleScript(PrometheusExporterScript):
name = "sample-script"
default_port = 12345

def configure_argument_parser(self, parser: ArgumentParser) -> None:
parser.add_argument("--test", help="test argument")


@pytest.fixture
def script():
def script() -> Iterator[PrometheusExporterScript]:
yield SampleScript()


class TestPrometheusExporterScript:
def test_description(self, script):
"""The description attribute returns the class docstring."""
def test_description(self, script: PrometheusExporterScript) -> None:
assert script.description == "A sample script"

def test_description_empty(self, script):
"""The description is empty string if no docstring is set."""
def test_description_empty(self, script: PrometheusExporterScript) -> None:
script.__doc__ = None
assert script.description == ""

def test_logger(self, script):
"""The script logger uses the script name."""
def test_logger(self, script: PrometheusExporterScript) -> None:
assert script.logger.name == "sample-script"

def test_configure_argument_parser(self, script):
"""configure_argument_parser adds specified arguments."""

def configure_argument_parser(parser):
parser.add_argument("test", help="test argument")

script.configure_argument_parser = configure_argument_parser
def test_configure_argument_parser(
self, script: PrometheusExporterScript
) -> None:
parser = script.get_parser()

fh = StringIO()
parser.print_help(file=fh)
assert "test argument" in fh.getvalue()

def test_create_metrics(self, script):
"""Metrics are created based on the configuration."""
def test_create_metrics(self, script: PrometheusExporterScript) -> None:
configs = [
MetricConfig("m1", "desc1", "counter", {}),
MetricConfig("m2", "desc2", "histogram", {}),
Expand All @@ -59,8 +57,9 @@ def test_create_metrics(self, script):
assert metrics["m1"]._type == "counter"
assert metrics["m2"]._type == "histogram"

def test_setup_logging(self, mocker, script):
"""Logging is set up."""
def test_setup_logging(
self, mocker: MockerFixture, script: PrometheusExporterScript
) -> None:
mock_setup_logger = mocker.patch(
"prometheus_aioexporter._script.setup_logger"
)
Expand All @@ -79,34 +78,39 @@ def test_setup_logging(self, mocker, script):
]
mock_setup_logger.assert_has_calls(calls)

def test_change_metrics_path(self, script):
"""The path under which metrics are exposed can be changed."""
def test_change_metrics_path(
self, script: PrometheusExporterScript
) -> None:
args = script.get_parser().parse_args(
["--metrics-path", "/other-path"]
)
exporter = script._get_exporter(args)
assert exporter.metrics_path == "/other-path"

def test_only_ssl_key(self, script, tls_private_key_path):
"""The path under which metrics are exposed can be changed."""
def test_only_ssl_key(
self, script: PrometheusExporterScript, tls_private_key_path: str
) -> None:
args = script.get_parser().parse_args(
["--ssl-private-key", tls_private_key_path]
)
exporter = script._get_exporter(args)
assert exporter.ssl_context is None

def test_only_ssl_cert(self, script, tls_public_key_path):
"""The path under which metrics are exposed can be changed."""
def test_only_ssl_cert(
self, script: PrometheusExporterScript, tls_public_key_path: str
) -> None:
args = script.get_parser().parse_args(
["--ssl-public-key", tls_public_key_path]
)
exporter = script._get_exporter(args)
assert exporter.ssl_context is None

def test_ssl_components_without_ca(
self, script, tls_private_key_path, tls_public_key_path
):
"""The path under which metrics are exposed can be changed."""
self,
script: PrometheusExporterScript,
tls_private_key_path: str,
tls_public_key_path: str,
) -> None:
args = script.get_parser().parse_args(
[
"--ssl-public-key",
Expand All @@ -120,9 +124,12 @@ def test_ssl_components_without_ca(
assert len(exporter.ssl_context.get_ca_certs()) != 1

def test_ssl_components(
self, script, tls_private_key_path, tls_ca_path, tls_public_key_path
):
"""The path under which metrics are exposed can be changed."""
self,
script: PrometheusExporterScript,
tls_private_key_path: str,
tls_ca_path: str,
tls_public_key_path: str,
) -> None:
args = script.get_parser().parse_args(
[
"--ssl-public-key",
Expand All @@ -137,8 +144,9 @@ def test_ssl_components(
assert isinstance(exporter.ssl_context, SSLContext)
assert len(exporter.ssl_context.get_ca_certs()) == 1

def test_include_process_stats(self, mocker, script):
"""The script can include process stats in metrics."""
def test_include_process_stats(
self, mocker: MockerFixture, script: PrometheusExporterScript
) -> None:
mocker.patch("prometheus_aioexporter._web.PrometheusExporter.run")
script(["--process-stats"])
# process stats are present in the registry
Expand All @@ -147,22 +155,22 @@ def test_include_process_stats(self, mocker, script):
in script.registry.registry._names_to_collectors
)

def test_get_exporter_registers_handlers(self, script):
"""Startup/shutdown handlers are registered with the application."""
def test_get_exporter_registers_handlers(
self, script: PrometheusExporterScript
) -> None:
args = script.get_parser().parse_args([])
exporter = script._get_exporter(args)
assert script.on_application_startup in exporter.app.on_startup
assert script.on_application_shutdown in exporter.app.on_shutdown

def test_script_run_exporter_ssl(
self,
mocker,
script,
ssl_context,
tls_private_key_path,
tls_public_key_path,
):
"""The script runs the exporter application."""
mocker: MockerFixture,
script: PrometheusExporterScript,
ssl_context: SSLContext,
tls_private_key_path: str,
tls_public_key_path: str,
) -> None:
mock_run_app = mocker.patch("prometheus_aioexporter._web.run_app")
script(
[
Expand All @@ -177,8 +185,9 @@ def test_script_run_exporter_ssl(
mock_run_app.call_args.kwargs["ssl_context"], SSLContext
)

def test_script_run_exporter(self, mocker, script):
"""The script runs the exporter application."""
def test_script_run_exporter(
self, mocker: MockerFixture, script: PrometheusExporterScript
) -> None:
mock_run_app = mocker.patch("prometheus_aioexporter._web.run_app")
script([])
mock_run_app.assert_called_with(
Expand Down
Loading

0 comments on commit ee9b3c7

Please sign in to comment.