From 086cd7f16abc9d3367de3b741ce878c1c6ae89bb Mon Sep 17 00:00:00 2001 From: Russ Allbery Date: Fri, 1 Sep 2023 11:39:12 -0700 Subject: [PATCH 1/6] Remove trailing whitespace Remove some trailing whitespace caught by the new pre-commit hooks. --- docs/user-guide/github-apps/webhook-models.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/user-guide/github-apps/webhook-models.rst b/docs/user-guide/github-apps/webhook-models.rst index 60df6a53..794ace54 100644 --- a/docs/user-guide/github-apps/webhook-models.rst +++ b/docs/user-guide/github-apps/webhook-models.rst @@ -8,7 +8,7 @@ This page provides a quick reference for the Pydantic models provided in `safir. Safir's coverage of GitHub webhooks is not exhaustive. You can contribute additional models as needed. - + Additionally, the models are not necessarily complete. GitHub may provide additional fields that are not parsed by these models because they were not deemed relevant. To use additional fields documented by GitHub, you can either subclass these models and add additional fields, or contribute updates to the models in Safir. From 5f4228cd8f28f6e801c2f9b27fa497f241f85bb1 Mon Sep 17 00:00:00 2001 From: Russ Allbery Date: Fri, 1 Sep 2023 11:45:11 -0700 Subject: [PATCH 2/6] Fix the majority of Ruff-detected issues Let Ruff do its automatic fixes and fix nearly all of the remaining issues. The most controversial change is probably that Ruff by default wants the first line of a docstring to fit on a single line, so some docstrings have been restructured to satisfy this check. 
--- src/safir/__init__.py | 4 +- src/safir/arq.py | 64 ++++++++------- src/safir/database.py | 22 ++--- src/safir/datetime.py | 12 +-- src/safir/dependencies/arq.py | 19 ++--- src/safir/dependencies/db_session.py | 12 +-- src/safir/dependencies/http_client.py | 4 +- src/safir/dependencies/logger.py | 10 +-- src/safir/fastapi.py | 8 +- src/safir/github/_client.py | 24 +++--- src/safir/github/models.py | 26 +++--- src/safir/github/webhooks.py | 43 ++++++---- src/safir/logging.py | 4 +- src/safir/metadata.py | 19 ++--- src/safir/middleware/x_forwarded.py | 5 +- src/safir/models.py | 3 +- src/safir/pydantic.py | 10 +-- src/safir/redis.py | 14 ++-- src/safir/slack/blockkit.py | 29 +++---- src/safir/slack/webhook.py | 8 +- src/safir/testing/gcs.py | 31 ++++--- src/safir/testing/kubernetes.py | 112 +++++++++++++++----------- src/safir/testing/slack.py | 2 +- src/safir/testing/uvicorn.py | 18 ++--- tests/conftest.py | 2 +- tests/database_test.py | 10 +-- tests/datetime_test.py | 22 ++--- tests/dependencies/arq_test.py | 23 +++--- tests/github/webhooks_test.py | 2 +- tests/logging_test.py | 14 ++-- tests/metadata_test.py | 9 +-- tests/middleware/x_forwarded_test.py | 3 +- tests/pydantic_test.py | 29 ++++--- tests/redis_test.py | 8 +- tests/testing/gcs_test.py | 4 +- tests/testing/kubernetes_test.py | 6 +- 36 files changed, 321 insertions(+), 314 deletions(-) diff --git a/src/safir/__init__.py b/src/safir/__init__.py index 486c9a01..682c7dc9 100644 --- a/src/safir/__init__.py +++ b/src/safir/__init__.py @@ -1,4 +1,6 @@ -"""Safir is the Rubin Observatory's library for building FastAPI services +"""Support library for the Rubin Science Platform. + +Safir is the Rubin Observatory's library for building FastAPI services for the Rubin Science Platform. """ diff --git a/src/safir/arq.py b/src/safir/arq.py index 6ad86b46..0abac796 100644 --- a/src/safir/arq.py +++ b/src/safir/arq.py @@ -1,6 +1,4 @@ -"""An `arq `__ client with a mock for -testing. 
-""" +"""An arq_ client with a mock for testing.""" from __future__ import annotations @@ -9,13 +7,15 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Any, Optional, Self +from typing import Any, Self from arq import create_pool from arq.connections import ArqRedis, RedisSettings from arq.constants import default_queue_name as arq_default_queue_name from arq.jobs import Job, JobStatus +from .datetime import current_datetime + __all__ = [ "ArqJobError", "JobNotQueued", @@ -173,7 +173,7 @@ async def from_job(cls, job: Job) -> Self: status=job_status, # private attribute of Job; not available in JobDef # queue_name is available in JobResult - queue_name=job._queue_name, + queue_name=job._queue_name, # noqa: SLF001 ) @@ -269,11 +269,13 @@ async def from_job(cls, job: Job) -> Self: class ArqQueue(metaclass=abc.ABCMeta): - """An common interface for working with an arq queue that can be + """arq queue interface supporting either Redis or an in-memory repository. + + Provides a common interface for working with an arq queue that can be implemented either with a real Redis backend, or an in-memory repository for testing. - See also + See Also -------- RedisArqQueue Production implementation with a Redis store. @@ -288,8 +290,9 @@ def __init__( @property def default_queue_name(self) -> str: - """Name of the default queue, if the ``_queue_name`` parameter is - no set in method calls. + """Name of the default queue. + + Used if the ``_queue_name`` parameter is not set in method calls. """ return self._default_queue_name @@ -298,7 +301,7 @@ async def enqueue( self, task_name: str, *task_args: Any, - _queue_name: Optional[str] = None, + _queue_name: str | None = None, **task_kwargs: Any, ) -> JobMetadata: """Add a job to the queue. 
@@ -328,7 +331,7 @@ async def enqueue( @abc.abstractmethod async def get_job_metadata( - self, job_id: str, queue_name: Optional[str] = None + self, job_id: str, queue_name: str | None = None ) -> JobMetadata: """Get metadata about a `~arq.jobs.Job`. @@ -354,9 +357,9 @@ async def get_job_metadata( @abc.abstractmethod async def get_job_result( - self, job_id: str, queue_name: Optional[str] = None + self, job_id: str, queue_name: str | None = None ) -> JobResult: - """The job result, if available. + """Retrieve the job result, if available. Parameters ---------- @@ -410,7 +413,7 @@ async def enqueue( self, task_name: str, *task_args: Any, - _queue_name: Optional[str] = None, + _queue_name: str | None = None, **task_kwargs: Any, ) -> JobMetadata: job = await self._pool.enqueue_job( @@ -422,10 +425,11 @@ async def enqueue( if job: return await JobMetadata.from_job(job) else: - # TODO if implementing hard-coded job IDs, set as an argument + # TODO(jonathansick): if implementing hard-coded job IDs, set as + # an argument raise JobNotQueued(None) - def _get_job(self, job_id: str, queue_name: Optional[str] = None) -> Job: + def _get_job(self, job_id: str, queue_name: str | None = None) -> Job: return Job( job_id, self._pool, @@ -433,13 +437,13 @@ def _get_job(self, job_id: str, queue_name: Optional[str] = None) -> Job: ) async def get_job_metadata( - self, job_id: str, queue_name: Optional[str] = None + self, job_id: str, queue_name: str | None = None ) -> JobMetadata: job = self._get_job(job_id, queue_name=queue_name) return await JobMetadata.from_job(job) async def get_job_result( - self, job_id: str, queue_name: Optional[str] = None + self, job_id: str, queue_name: str | None = None ) -> JobResult: job = self._get_job(job_id, queue_name=queue_name) return await JobResult.from_job(job) @@ -466,7 +470,7 @@ async def enqueue( self, task_name: str, *task_args: Any, - _queue_name: Optional[str] = None, + _queue_name: str | None = None, **task_kwargs: Any, ) -> JobMetadata: 
queue_name = self._resolve_queue_name(_queue_name) @@ -475,7 +479,7 @@ async def enqueue( name=task_name, args=task_args, kwargs=task_kwargs, - enqueue_time=datetime.now(), + enqueue_time=current_datetime(microseconds=True), status=JobStatus.queued, queue_name=queue_name, ) @@ -483,25 +487,25 @@ async def enqueue( return new_job async def get_job_metadata( - self, job_id: str, queue_name: Optional[str] = None + self, job_id: str, queue_name: str | None = None ) -> JobMetadata: queue_name = self._resolve_queue_name(queue_name) try: return self._job_metadata[queue_name][job_id] - except KeyError: - raise JobNotFound(job_id) + except KeyError as e: + raise JobNotFound(job_id) from e async def get_job_result( - self, job_id: str, queue_name: Optional[str] = None + self, job_id: str, queue_name: str | None = None ) -> JobResult: queue_name = self._resolve_queue_name(queue_name) try: return self._job_results[queue_name][job_id] - except KeyError: - raise JobResultUnavailable(job_id) + except KeyError as e: + raise JobResultUnavailable(job_id) from e async def set_in_progress( - self, job_id: str, queue_name: Optional[str] = None + self, job_id: str, queue_name: str | None = None ) -> None: """Set a job's status to in progress, for mocking a queue in tests.""" job = await self.get_job_metadata(job_id, queue_name=queue_name) @@ -517,7 +521,7 @@ async def set_complete( *, result: Any, success: bool = True, - queue_name: Optional[str] = None, + queue_name: str | None = None, ) -> None: """Set a job's result, for mocking a queue in tests.""" queue_name = self._resolve_queue_name(queue_name) @@ -534,8 +538,8 @@ async def set_complete( kwargs=job_metadata.kwargs, status=job_metadata.status, enqueue_time=job_metadata.enqueue_time, - start_time=datetime.now(), - finish_time=datetime.now(), + start_time=current_datetime(microseconds=True), + finish_time=current_datetime(microseconds=True), result=result, success=success, queue_name=queue_name, diff --git a/src/safir/database.py 
b/src/safir/database.py index 380ddc96..d8b920b8 100644 --- a/src/safir/database.py +++ b/src/safir/database.py @@ -4,8 +4,8 @@ import asyncio import time -from datetime import datetime, timezone -from typing import Optional, overload +from datetime import UTC, datetime +from typing import overload from urllib.parse import quote, urlparse from sqlalchemy import create_engine @@ -107,9 +107,9 @@ def datetime_from_db(time: datetime | None) -> datetime | None: """ if not time: return None - if time.tzinfo not in (None, timezone.utc): + if time.tzinfo not in (None, UTC): raise ValueError(f"datetime {time} not in UTC") - return time.replace(tzinfo=timezone.utc) + return time.replace(tzinfo=UTC) @overload @@ -140,7 +140,7 @@ def datetime_to_db(time: datetime | None) -> datetime | None: """ if not time: return None - if time.tzinfo != timezone.utc: + if time.tzinfo != UTC: raise ValueError(f"datetime {time} not in UTC") return time.replace(tzinfo=None) @@ -149,7 +149,7 @@ def create_database_engine( url: str, password: str | None, *, - isolation_level: Optional[str] = None, + isolation_level: str | None = None, ) -> AsyncEngine: """Create a new async database engine. @@ -185,9 +185,9 @@ def create_database_engine( async def create_async_session( engine: AsyncEngine, - logger: Optional[BoundLogger] = None, + logger: BoundLogger | None = None, *, - statement: Optional[Select] = None, + statement: Select | None = None, ) -> async_scoped_session: """Create a new async database session. @@ -247,10 +247,10 @@ async def create_async_session( def create_sync_session( url: str, password: str | None, - logger: Optional[BoundLogger] = None, + logger: BoundLogger | None = None, *, - isolation_level: Optional[str] = None, - statement: Optional[Select] = None, + isolation_level: str | None = None, + statement: Select | None = None, ) -> scoped_session: """Create a new sync database session. 
diff --git a/src/safir/datetime.py b/src/safir/datetime.py index c6243c1a..beafd0f8 100644 --- a/src/safir/datetime.py +++ b/src/safir/datetime.py @@ -2,7 +2,7 @@ from __future__ import annotations -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import overload __all__ = [ @@ -36,7 +36,7 @@ def current_datetime(*, microseconds: bool = False) -> datetime: The current time forced to UTC and optionally with the microseconds field zeroed. """ - result = datetime.now(tz=timezone.utc) + result = datetime.now(tz=UTC) if microseconds: return result else: @@ -75,8 +75,8 @@ def format_datetime_for_logging(timestamp: datetime | None) -> str | None: Raised if the argument is in a time zone other than UTC. """ if timestamp: - if timestamp.tzinfo not in (None, timezone.utc): - raise ValueError("Datetime {timestamp} not in UTC") + if timestamp.tzinfo not in (None, UTC): + raise ValueError(f"datetime {timestamp} not in UTC") if timestamp.microsecond: result = timestamp.isoformat(sep=" ", timespec="milliseconds") else: @@ -106,8 +106,8 @@ def isodatetime(timestamp: datetime) -> str: ValueError The provided timestamp was not in UTC. """ - if timestamp.tzinfo not in (None, timezone.utc): - raise ValueError("Datetime {timestamp} not in UTC") + if timestamp.tzinfo not in (None, UTC): + raise ValueError(f"datetime {timestamp} not in UTC") return timestamp.strftime("%Y-%m-%dT%H:%M:%SZ") diff --git a/src/safir/dependencies/arq.py b/src/safir/dependencies/arq.py index 3356a383..af7ab6c0 100644 --- a/src/safir/dependencies/arq.py +++ b/src/safir/dependencies/arq.py @@ -1,11 +1,7 @@ -"""A FastAPI dependency that supplies a Redis connection for `arq -`__. 
-""" +"""A FastAPI dependency that supplies a Redis connection for arq_.""" from __future__ import annotations -from typing import Optional - from arq.connections import RedisSettings from ..arq import ArqMode, ArqQueue, MockArqQueue, RedisArqQueue @@ -14,12 +10,15 @@ class ArqDependency: - """A FastAPI dependency that maintains a Redis client for enqueing - tasks to the worker pool. + """FastAPI dependency providing a client for enqueuing tasks. + + This class maintains a singleton Redis client for enqueuing tasks to an + arq_ worker pool and provides it to handler methods via the FastAPI + dependency interface. """ def __init__(self) -> None: - self._arq_queue: Optional[ArqQueue] = None + self._arq_queue: ArqQueue | None = None async def initialize( self, *, mode: ArqMode, redis_settings: RedisSettings | None @@ -87,6 +86,4 @@ async def __call__(self) -> ArqQueue: arq_dependency = ArqDependency() -"""Singleton instance of `ArqDependency` that serves as a FastAPI -dependency. -""" +"""Singleton instance of `ArqDependency` as a FastAPI dependency.""" diff --git a/src/safir/dependencies/db_session.py b/src/safir/dependencies/db_session.py index c67ca4f1..f9ab8be2 100644 --- a/src/safir/dependencies/db_session.py +++ b/src/safir/dependencies/db_session.py @@ -1,7 +1,6 @@ """Manage an async database session.""" from collections.abc import AsyncIterator -from typing import Optional from sqlalchemy.ext.asyncio import AsyncEngine, async_scoped_session @@ -37,9 +36,9 @@ class DatabaseSessionDependency: """ def __init__(self) -> None: - self._engine: Optional[AsyncEngine] = None - self._override_engine: Optional[AsyncEngine] = None - self._session: Optional[async_scoped_session] = None + self._engine: AsyncEngine | None = None + self._override_engine: AsyncEngine | None = None + self._session: async_scoped_session | None = None async def __call__(self) -> AsyncIterator[async_scoped_session]: """Return the database session manager. 
@@ -49,7 +48,8 @@ async def __call__(self) -> AsyncIterator[async_scoped_session]: sqlalchemy.ext.asyncio.AsyncSession The newly-created session. """ - assert self._session, "db_session_dependency not initialized" + if not self._session: + raise RuntimeError("db_session_dependency not initialized") yield self._session # Following the recommendations in the SQLAlchemy documentation, each @@ -69,7 +69,7 @@ async def initialize( url: str, password: str | None, *, - isolation_level: Optional[str] = None, + isolation_level: str | None = None, ) -> None: """Initialize the session dependency. diff --git a/src/safir/dependencies/http_client.py b/src/safir/dependencies/http_client.py index 80a1b73d..28ecd794 100644 --- a/src/safir/dependencies/http_client.py +++ b/src/safir/dependencies/http_client.py @@ -2,8 +2,6 @@ from __future__ import annotations -from typing import Optional - import httpx __all__ = [ @@ -40,7 +38,7 @@ async def shutdown_event() -> None: """ def __init__(self) -> None: - self.http_client: Optional[httpx.AsyncClient] = None + self.http_client: httpx.AsyncClient | None = None async def __call__(self) -> httpx.AsyncClient: """Return the cached ``httpx.AsyncClient``.""" diff --git a/src/safir/dependencies/logger.py b/src/safir/dependencies/logger.py index 3efc6ef3..6a4e8b53 100644 --- a/src/safir/dependencies/logger.py +++ b/src/safir/dependencies/logger.py @@ -5,7 +5,6 @@ """ import uuid -from typing import Optional import structlog from fastapi import Request @@ -33,7 +32,7 @@ class LoggerDependency: """ def __init__(self) -> None: - self.logger: Optional[BoundLogger] = None + self.logger: BoundLogger | None = None async def __call__(self, request: Request) -> BoundLogger: """Return a logger bound with request information. 
@@ -45,7 +44,9 @@ async def __call__(self, request: Request) -> BoundLogger: """ if not self.logger: self.logger = structlog.get_logger(logging.logger_name) - assert self.logger + if not self.logger: + msg = f"Unable to get logger for {logging.logger_name}" + raise RuntimeError(msg) # Construct the httpRequest logging data (compatible with the format # expected by Google Log Explorer). @@ -59,11 +60,10 @@ async def __call__(self, request: Request) -> BoundLogger: if user_agent: request_data["userAgent"] = user_agent - logger = self.logger.new( + return self.logger.new( httpRequest=request_data, request_id=str(uuid.uuid4()), ) - return logger logger_dependency = LoggerDependency() diff --git a/src/safir/fastapi.py b/src/safir/fastapi.py index 9785cb11..ddad1bb8 100644 --- a/src/safir/fastapi.py +++ b/src/safir/fastapi.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import ClassVar, Optional +from typing import ClassVar from fastapi import Request, status from fastapi.responses import JSONResponse @@ -139,8 +139,8 @@ async def get_info(username: str) -> UserInfo: def __init__( self, message: str, - location: Optional[ErrorLocation] = None, - field_path: Optional[list[str]] = None, + location: ErrorLocation | None = None, + field_path: list[str] | None = None, ) -> None: super().__init__(message) self.location = location @@ -169,7 +169,7 @@ def to_dict(self) -> dict[str, list[str] | str]: } if self.location: if self.field_path: - result["loc"] = [self.location.value] + self.field_path + result["loc"] = [self.location.value, *self.field_path] else: result["loc"] = [self.location.value] return result diff --git a/src/safir/github/_client.py b/src/safir/github/_client.py index 354427e0..8a7a532a 100644 --- a/src/safir/github/_client.py +++ b/src/safir/github/_client.py @@ -1,15 +1,15 @@ from __future__ import annotations -from typing import Optional - import gidgethub.apps import httpx from gidgethub.httpx import GitHubAPI class GitHubAppClientFactory: 
- """Factory for creating GitHub App clients authenticated either as an app - or as an installation of that app. + """Create GitHub App clients. + + Provides a factory for creating GitHub App clients authenticated either as + an app or as an installation of that app. Parameters ---------- @@ -48,9 +48,7 @@ def get_app_jwt(self) -> str: app_id=self.app_id, private_key=self.app_key ) - def _create_client( - self, *, oauth_token: Optional[str] = None - ) -> GitHubAPI: + def _create_client(self, *, oauth_token: str | None = None) -> GitHubAPI: return GitHubAPI( self._http_client, self.app_name, oauth_token=oauth_token ) @@ -78,8 +76,10 @@ def create_app_client(self) -> GitHubAPI: async def create_installation_client( self, installation_id: str ) -> GitHubAPI: - """Create a client authenticated as an installation of the GitHub App - for a specific repository or organization. + """Create a client for an installation of the GitHub App. + + The resulting client is authenticated as an installation of the GitHub + App for a specific repository or organization. Parameters ---------- @@ -106,8 +106,10 @@ async def create_installation_client( async def create_installation_client_for_repo( self, owner: str, repo: str ) -> GitHubAPI: - """Create a client authenticated as an installation of the GitHub App - for a specific repository or organization. + """Create a client for a repository installation of the GitHub App. + + The resulting client is authenticated as an installation of the GitHub + App for a specific repository. Parameters ---------- diff --git a/src/safir/github/models.py b/src/safir/github/models.py index 4c214571..fa268075 100644 --- a/src/safir/github/models.py +++ b/src/safir/github/models.py @@ -67,8 +67,9 @@ class GitHubUserModel(BaseModel): class GitHubRepositoryModel(BaseModel): - """A Pydantic model for the ``repository`` field, often found in webhook - payloads. + """A Pydantic model for the ``repository`` field. 
+ + This field is often found in webhook payloads. https://docs.github.com/en/rest/repos/repos#get-a-repository """ @@ -270,8 +271,10 @@ class GitHubCheckSuiteConclusion(str, Enum): class GitHubCheckSuiteModel(BaseModel): - """A Pydantic model for the ``check_suite`` field in a ``check_suite`` - webhook (`~safir.github.webhooks.GitHubCheckSuiteEventModel`). + """A Pydantic model for the ``check_suite`` field. + + This field is found in a ``check_suite`` webhook + (`~safir.github.webhooks.GitHubCheckSuiteEventModel`). """ id: str = Field(description="Identifier for this check run.") @@ -320,9 +323,7 @@ class GitHubCheckRunConclusion(str, Enum): """The check run has failed.""" neutral = "neutral" - """The check run has a neutral outcome, perhaps because the check was - skipped. - """ + """The check run has a neutral outcome, perhaps because it was skipped.""" cancelled = "cancelled" """The check run was cancelled.""" @@ -369,8 +370,9 @@ class GitHubCheckRunOutput(BaseModel): class GitHubCheckRunPrInfoModel(BaseModel): - """A Pydantic model of the ``pull_requests[]`` items in a check run - GitHub API model (`GitHubCheckRunModel`). + """A Pydantic model of the ``pull_requests[]`` items. + + These are found in a check run GitHub API model (`GitHubCheckRunModel`). https://docs.github.com/en/rest/checks/runs#get-a-check-run """ @@ -379,8 +381,10 @@ class GitHubCheckRunPrInfoModel(BaseModel): class GitHubCheckRunModel(BaseModel): - """A Pydantic model for the "check_run" field in a check_run webhook - payload (`~safir.github.webhooks.GitHubCheckRunEventModel`). + """A Pydantic model for the ``check_run`` field. + + This is found in a check_run webhook payload + (`~safir.github.webhooks.GitHubCheckRunEventModel`). 
""" id: str = Field(description="Identifier for this check run.") diff --git a/src/safir/github/webhooks.py b/src/safir/github/webhooks.py index 5815679f..bfae5949 100644 --- a/src/safir/github/webhooks.py +++ b/src/safir/github/webhooks.py @@ -31,16 +31,18 @@ class GitHubAppInstallationModel(BaseModel): - """A Pydantic model for the ``installation`` field found in webhook - payloads for GitHub Apps. + """A Pydantic model for the ``installation`` field found. + + This field is found in webhook payloads for GitHub Apps. """ id: str = Field(description="The installation ID.") class GitHubPushEventModel(BaseModel): - """A Pydantic model for the ``push`` event webhook when a commit or - tag is pushed. + """A Pydantic model for the ``push`` event webhook. + + This webhook is triggered when a commit or tag is pushed. https://docs.github.com/en/webhooks/webhook-events-and-payloads#push """ @@ -71,8 +73,10 @@ class GitHubPushEventModel(BaseModel): class GitHubAppInstallationEventRepoModel(BaseModel): - """A pydantic model for repository objects used by - `GitHubAppInstallationRepositoriesEventModel`. + """A Pydantic model for repository objects used by installation events. + + See `GitHubAppInstallationRepositoriesEventModel` for where this model is + used. https://docs.github.com/en/webhooks/webhook-events-and-payloads#installation """ @@ -94,8 +98,10 @@ def owner_name(self) -> str: class GitHubAppInstallationEventAction(str, Enum): - """The action performed on an GitHub App ``installation`` webhook - (`GitHubAppInstallationEventModel`). + """The action performed on an GitHub App ``installation`` webhook. + + See `GitHubAppInstallationEventModel` for the model of the event where + this model is used. 
""" created = "created" @@ -136,8 +142,10 @@ class GitHubAppInstallationEventModel(BaseModel): class GitHubAppInstallationRepositoriesEventAction(str, Enum): - """The action performed on a GitHub App ``installation_repositories`` - webhook (`GitHubAppInstallationRepositoriesEventModel`). + """A Pydantic model for a ``installation_repositories`` action. + + This model is for the action performed on a ``installation_repositories`` + GitHub App webhook (`GitHubAppInstallationRepositoriesEventModel`). """ #: Someone added a repository to an installation. @@ -171,8 +179,9 @@ class GitHubAppInstallationRepositoriesEventModel(BaseModel): class GitHubPullRequestEventAction(str, Enum): - """The action performed on a GitHub ``pull_request`` webhook - (`GitHubPullRequestEventModel`). + """The action performed on a GitHub ``pull_request`` webhook. + + See `GitHubPullRequestEventModel` for where this model is used. """ assigned = "assigned" @@ -267,8 +276,9 @@ class GitHubPullRequestEventModel(BaseModel): class GitHubCheckSuiteEventAction(str, Enum): - """The action performed in a GitHub ``check_suite`` webhook - (`GitHubCheckSuiteEventModel`). + """The action performed in a GitHub ``check_suite`` webhook. + + See `GitHubCheckSuiteEventModel` for where this model is used. """ completed = "completed" @@ -307,8 +317,9 @@ class GitHubCheckSuiteEventModel(BaseModel): class GitHubCheckRunEventAction(str, Enum): - """The action performed in a GitHub ``check_run`` webhook - (`GitHubCheckRunEventModel`). + """The action performed in a GitHub ``check_run`` webhook. + + See `GitHubCheckRunEventModel` for where this model is used. 
""" completed = "completed" diff --git a/src/safir/logging.py b/src/safir/logging.py index 998e31bc..272357f1 100644 --- a/src/safir/logging.py +++ b/src/safir/logging.py @@ -7,7 +7,7 @@ import re import sys from enum import Enum -from typing import Any, Optional +from typing import Any import structlog from structlog.stdlib import add_log_level @@ -22,7 +22,7 @@ "logger_name", ] -logger_name: Optional[str] = None +logger_name: str | None = None """Name of the configured global logger. When `configure_logging` is called, the name of the configured logger is diff --git a/src/safir/metadata.py b/src/safir/metadata.py index 250603ca..a36d3c30 100644 --- a/src/safir/metadata.py +++ b/src/safir/metadata.py @@ -1,12 +1,10 @@ -"""Standardized metadata for Roundtable HTTP services. -""" +"""Standardized metadata for Roundtable HTTP services.""" from __future__ import annotations -import sys from email.message import Message from importlib.metadata import metadata -from typing import Optional, cast +from typing import cast from pydantic import BaseModel, Field @@ -20,15 +18,15 @@ class Metadata(BaseModel): version: str = Field(..., title="Version", example="1.0.0") - description: Optional[str] = Field( + description: str | None = Field( None, title="Description", example="string" ) - repository_url: Optional[str] = Field( + repository_url: str | None = Field( None, title="Repository URL", example="https://example.com/" ) - documentation_url: Optional[str] = Field( + documentation_url: str | None = Field( None, title="Documentation URL", example="https://example.com/" ) @@ -78,10 +76,7 @@ def get_metadata(*, package_name: str, application_name: str) -> Metadata: project_urls, Source code Used as the ``respository_url``. 
""" - if sys.version_info >= (3, 10): - pkg_metadata = cast(Message, metadata(package_name)) - else: - pkg_metadata = metadata(package_name) + pkg_metadata = cast(Message, metadata(package_name)) # Newer packages that use pyproject.toml only do not use the Home-page # field (setuptools in pyproject.toml mode does not support it) and use @@ -103,7 +98,7 @@ def get_metadata(*, package_name: str, application_name: str) -> Metadata: ) -def get_project_url(meta: Message, label: str) -> Optional[str]: +def get_project_url(meta: Message, label: str) -> str | None: """Get a specific URL from a package's ``project_urls`` metadata. Parameters diff --git a/src/safir/middleware/x_forwarded.py b/src/safir/middleware/x_forwarded.py index 8483c4ae..fb6f26d8 100644 --- a/src/safir/middleware/x_forwarded.py +++ b/src/safir/middleware/x_forwarded.py @@ -4,7 +4,6 @@ from collections.abc import Awaitable, Callable from ipaddress import _BaseAddress, _BaseNetwork, ip_address -from typing import Optional from fastapi import FastAPI, Request, Response from starlette.middleware.base import BaseHTTPMiddleware @@ -39,7 +38,7 @@ class XForwardedMiddleware(BaseHTTPMiddleware): """ def __init__( - self, app: FastAPI, *, proxies: Optional[list[_BaseNetwork]] = None + self, app: FastAPI, *, proxies: list[_BaseNetwork] | None = None ) -> None: super().__init__(app) if proxies: @@ -74,7 +73,7 @@ async def dispatch( client = None for n, ip in enumerate(forwarded_for): - if any((ip in network for network in self.proxies)): + if any(ip in network for network in self.proxies): continue client = str(ip) index = n diff --git a/src/safir/models.py b/src/safir/models.py index d1fa1b69..ee68241c 100644 --- a/src/safir/models.py +++ b/src/safir/models.py @@ -8,7 +8,6 @@ """ from enum import Enum -from typing import Optional from pydantic import BaseModel, Field @@ -35,7 +34,7 @@ class ErrorLocation(str, Enum): class ErrorDetail(BaseModel): """The detail of the error message.""" - loc: Optional[list[str]] = 
Field( + loc: list[str] | None = Field( None, title="Location", example=["area", "field"] ) diff --git a/src/safir/pydantic.py b/src/safir/pydantic.py index 24c28ea3..a0f41829 100644 --- a/src/safir/pydantic.py +++ b/src/safir/pydantic.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections.abc import Callable -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import Any, ParamSpec, TypeVar from pydantic import BaseModel @@ -60,11 +60,11 @@ class Info(BaseModel): if v is None: return v elif isinstance(v, int): - return datetime.fromtimestamp(v, tz=timezone.utc) + return datetime.fromtimestamp(v, tz=UTC) elif v.tzinfo and v.tzinfo.utcoffset(v) is not None: - return v.astimezone(timezone.utc) + return v.astimezone(UTC) else: - return v.replace(tzinfo=timezone.utc) + return v.replace(tzinfo=UTC) def normalize_isodatetime(v: str | None) -> datetime | None: @@ -113,7 +113,7 @@ class Info(BaseModel): try: return datetime.fromisoformat(v[:-1] + "+00:00") except Exception as e: - raise ValueError(f"Invalid date {v}: {str(e)}") from e + raise ValueError(f"Invalid date {v}: {e!s}") from e def to_camel_case(string: str) -> str: diff --git a/src/safir/redis.py b/src/safir/redis.py index 13703418..1234bff2 100644 --- a/src/safir/redis.py +++ b/src/safir/redis.py @@ -3,15 +3,15 @@ from __future__ import annotations from collections.abc import AsyncIterator -from typing import Generic, Optional, TypeVar +from typing import Generic, TypeVar try: import redis.asyncio as redis -except ImportError: +except ImportError as e: raise ImportError( "The safir.redis module requires the redis extra. " "Install it with `pip install safir[redis]`." - ) + ) from e from cryptography.fernet import Fernet from pydantic import BaseModel @@ -31,7 +31,9 @@ class DeserializeError(SlackException): - """Raised when a stored Pydantic object in Redis cannot be decoded (and + """Error decoding or deserializing a Pydantic object from Redis. 
+ + Raised when a stored Pydantic object in Redis cannot be decoded (and possibly decrypted) or deserialized. Parameters @@ -140,7 +142,7 @@ async def get(self, key: str) -> S | None: try: return self._deserialize(data) except Exception as e: - error = f"{type(e).__name__}: {str(e)}" + error = f"{type(e).__name__}: {e!s}" msg = f"Cannot deserialize data for key {full_key}" raise DeserializeError(msg, key=full_key, error=error) from e @@ -163,7 +165,7 @@ async def scan(self, pattern: str) -> AsyncIterator[str]: yield key.decode().removeprefix(self._key_prefix) async def store( - self, key: str, obj: S, lifetime: Optional[int] = None + self, key: str, obj: S, lifetime: int | None = None ) -> None: """Store an object. diff --git a/src/safir/slack/blockkit.py b/src/safir/slack/blockkit.py index 836e5b15..8216df18 100644 --- a/src/safir/slack/blockkit.py +++ b/src/safir/slack/blockkit.py @@ -4,7 +4,7 @@ from abc import ABCMeta, abstractmethod from datetime import datetime -from typing import Any, ClassVar, Optional, Self +from typing import Any, ClassVar, Self from httpx import HTTPError, HTTPStatusError from pydantic import BaseModel, validator @@ -249,9 +249,9 @@ class SlackException(Exception): def __init__( self, message: str, - user: Optional[str] = None, + user: str | None = None, *, - failed_at: Optional[datetime] = None, + failed_at: datetime | None = None, ) -> None: self.user = user if failed_at: @@ -307,9 +307,7 @@ class SlackWebException(SlackException): """ @classmethod - def from_exception( - cls, exc: HTTPError, user: Optional[str] = None - ) -> Self: + def from_exception(cls, exc: HTTPError, user: str | None = None) -> Self: """Create an exception from an HTTPX_ exception. 
Parameters @@ -337,7 +335,7 @@ def from_exception( body=exc.response.text, ) else: - message = f"{type(exc).__name__}: {str(exc)}" + message = f"{type(exc).__name__}: {exc!s}" # All httpx.HTTPError exceptions have a slot for the request, # initialized to None and then sometimes added by child @@ -360,12 +358,12 @@ def __init__( self, message: str, *, - failed_at: Optional[datetime] = None, - method: Optional[str] = None, - url: Optional[str] = None, - user: Optional[str] = None, - status: Optional[int] = None, - body: Optional[str] = None, + failed_at: datetime | None = None, + method: str | None = None, + url: str | None = None, + user: str | None = None, + status: int | None = None, + body: str | None = None, ) -> None: self.message = message self.method = method @@ -391,10 +389,7 @@ def to_slack(self) -> SlackMessage: message = super().to_slack() message.message = self.message if self.url: - if self.method: - text = f"{self.method} {self.url}" - else: - text = self.url + text = f"{self.method} {self.url}" if self.method else self.url message.blocks.append(SlackTextBlock(heading="URL", text=text)) if self.body: block = SlackCodeBlock(heading="Response", code=self.body) diff --git a/src/safir/slack/webhook.py b/src/safir/slack/webhook.py index dfa5b67a..906a1894 100644 --- a/src/safir/slack/webhook.py +++ b/src/safir/slack/webhook.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections.abc import Callable, Coroutine -from typing import Any, ClassVar, Optional +from typing import Any, ClassVar from fastapi import HTTPException, Request, Response from fastapi.exceptions import RequestValidationError @@ -111,12 +111,12 @@ async def post_uncaught_exception(self, exc: Exception) -> None: """ if isinstance(exc, SlackException): message = exc.to_slack() - msg = f"Uncaught exception in {self._application}: {str(exc)}" + msg = f"Uncaught exception in {self._application}: {exc!s}" message.message = msg else: date = 
format_datetime_for_logging(current_datetime()) name = type(exc).__name__ - error = f"{name}: {str(exc)}" + error = f"{name}: {exc!s}" message = SlackMessage( message=f"Uncaught exception in {self._application}", fields=[ @@ -164,7 +164,7 @@ class SlackRouteErrorHandler(APIRoute): ) """Uncaught exceptions that should not be sent to Slack.""" - _alert_client: ClassVar[Optional[SlackWebhookClient]] = None + _alert_client: ClassVar[SlackWebhookClient | None] = None """Global Slack alert client used by `SlackRouteErrorHandler`. Initialize with `initialize`. This object caches the alert confguration diff --git a/src/safir/testing/gcs.py b/src/safir/testing/gcs.py index 75438cd9..d93464ab 100644 --- a/src/safir/testing/gcs.py +++ b/src/safir/testing/gcs.py @@ -3,10 +3,10 @@ from __future__ import annotations from collections.abc import Iterator -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from io import BufferedReader from pathlib import Path -from typing import Any, Optional +from typing import Any from unittest.mock import Mock, patch from google.cloud import storage @@ -33,7 +33,7 @@ class MockBlob(Mock): """ def __init__( - self, name: str, expected_expiration: Optional[timedelta] = None + self, name: str, expected_expiration: timedelta | None = None ) -> None: super().__init__(spec=storage.blob.Blob) self.name = name @@ -45,8 +45,8 @@ def generate_signed_url( version: str, expiration: timedelta, method: str, - response_type: Optional[str] = None, - credentials: Optional[Any] = None, + response_type: str | None = None, + credentials: Any | None = None, ) -> str: """Generate a mock signed URL for testing. 
@@ -105,7 +105,7 @@ def __init__( self, name: str, path: Path, - expected_expiration: Optional[timedelta] = None, + expected_expiration: timedelta | None = None, ) -> None: super().__init__(name, expected_expiration) self._path = path @@ -113,7 +113,7 @@ def __init__( if self._exists: self.size = self._path.stat().st_size mtime = self._path.stat().st_mtime - self.updated = datetime.fromtimestamp(mtime, tz=timezone.utc) + self.updated = datetime.fromtimestamp(mtime, tz=UTC) self.etag = str(self._path.stat().st_ino) def download_as_bytes(self) -> bytes: @@ -163,7 +163,6 @@ def reload(self) -> None: This does nothing in the mock. """ - pass class MockBucket(Mock): @@ -183,8 +182,8 @@ class MockBucket(Mock): def __init__( self, bucket_name: str, - expected_expiration: Optional[timedelta] = None, - path: Optional[Path] = None, + expected_expiration: timedelta | None = None, + path: Path | None = None, ) -> None: super().__init__(spec=storage.bucket.Bucket) self._expected_expiration = expected_expiration @@ -234,9 +233,9 @@ class MockStorageClient(Mock): def __init__( self, - expected_expiration: Optional[timedelta] = None, - path: Optional[Path] = None, - bucket_name: Optional[str] = None, + expected_expiration: timedelta | None = None, + path: Path | None = None, + bucket_name: str | None = None, ) -> None: super().__init__(spec=storage.Client) self._bucket_name = bucket_name @@ -265,9 +264,9 @@ def bucket(self, bucket_name: str) -> MockBucket: def patch_google_storage( *, - expected_expiration: Optional[timedelta] = None, - path: Optional[Path] = None, - bucket_name: Optional[str] = None, + expected_expiration: timedelta | None = None, + path: Path | None = None, + bucket_name: str | None = None, ) -> Iterator[MockStorageClient]: """Replace the Google Cloud Storage API with a mock class. 
diff --git a/src/safir/testing/kubernetes.py b/src/safir/testing/kubernetes.py index 0d25b091..10c345d9 100644 --- a/src/safir/testing/kubernetes.py +++ b/src/safir/testing/kubernetes.py @@ -10,7 +10,7 @@ from collections import defaultdict from collections.abc import AsyncIterator, Callable, Iterator from datetime import timedelta -from typing import Any, Optional +from typing import Any from unittest.mock import AsyncMock, Mock, patch from kubernetes_asyncio import client, config @@ -77,13 +77,15 @@ def _parse_label_selector(label_selector: str) -> dict[str, str]: result = {} for requirement in label_selector.split(","): match = re.match(r"([^!=]+)==?(.*)", requirement) - assert match and match.group(1) and match.group(2) + assert match + assert match.group(1) + assert match.group(2) result[match.group(1)] = match.group(2) return result def _check_labels( - obj_labels: Optional[dict[str, str]], label_selector: Optional[str] + obj_labels: dict[str, str] | None, label_selector: str | None ) -> bool: """Check whether an object's labels match the label selector supplied. 
@@ -144,16 +146,18 @@ def strip_none(model: dict[str, Any]) -> dict[str, Any]: for key, value in model.items(): if value is None: continue + new_value = value if isinstance(value, dict): - value = strip_none(value) + new_value: Any = strip_none(value) elif isinstance(value, list): list_result = [] for item in value: if isinstance(item, dict): - item = strip_none(item) - list_result.append(item) - value = list_result - result[key] = value + list_result.append(strip_none(item)) + else: + list_result.append(item) + new_value = list_result + result[key] = new_value return result @@ -197,11 +201,11 @@ def add_event(self, event: dict[str, Any]) -> None: def build_watch_response( self, - resource_version: Optional[str] = None, - timeout_seconds: Optional[int] = None, + resource_version: str | None = None, + timeout_seconds: int | None = None, *, - field_selector: Optional[str] = None, - label_selector: Optional[str] = None, + field_selector: str | None = None, + label_selector: str | None = None, ) -> Mock: """Construct a response to a watch request. @@ -245,7 +249,7 @@ async def readline() -> bytes: response.content.readline.side_effect = readline return response - def _build_watcher( + def _build_watcher( # noqa: C901 self, resource_version: str | None, timeout_seconds: int | None, @@ -291,7 +295,8 @@ def _build_watcher( name = None if field_selector: match = re.match(r"metadata\.name=(.*)$", field_selector) - assert match and match.group(1) + assert match + assert match.group(1) name = match.group(1) # Create and register a new trigger. 
@@ -387,7 +392,7 @@ class MockKubernetesApi: """ def __init__(self) -> None: - self.error_callback: Optional[Callable[..., None]] = None + self.error_callback: Callable[..., None] | None = None self.initial_pod_phase = "Running" self._custom_kinds: dict[str, str] = {} @@ -417,12 +422,12 @@ def get_all_objects_for_test(self, kind: str) -> list[Any]: for namespace in sorted(self._objects.keys()): if key not in self._objects[namespace]: continue - for name, obj in sorted(self._objects[namespace][key].items()): + for _name, obj in sorted(self._objects[namespace][key].items()): results.append(obj) return results def get_namespace_objects_for_test(self, namespace: str) -> list[Any]: - """Returns all objects in the given namespace. + """Return all objects in the given namespace. Parameters ---------- @@ -613,8 +618,8 @@ async def list_cluster_custom_object( self._maybe_error("list_cluster_custom_object", group, version, plural) key = f"{group}/{version}/{plural}" results = [] - for namespace in self._objects.keys(): - for name, obj in self._objects[namespace].get(key, {}).items(): + for namespace in self._objects: + for obj in self._objects[namespace].get(key, {}).values(): results.append(obj) return {"items": results} @@ -829,12 +834,12 @@ async def list_namespaced_event( self, namespace: str, *, - field_selector: Optional[str] = None, - resource_version: Optional[str] = None, - timeout_seconds: Optional[int] = None, + field_selector: str | None = None, + resource_version: str | None = None, + timeout_seconds: int | None = None, watch: bool = False, _preload_content: bool = True, - _request_timeout: Optional[int] = None, + _request_timeout: int | None = None, ) -> CoreV1EventList | Mock: """List namespaced events. 
@@ -984,13 +989,13 @@ async def list_namespaced_ingress( self, namespace: str, *, - field_selector: Optional[str] = None, - label_selector: Optional[str] = None, - resource_version: Optional[str] = None, - timeout_seconds: Optional[int] = None, + field_selector: str | None = None, + label_selector: str | None = None, + resource_version: str | None = None, + timeout_seconds: int | None = None, watch: bool = False, _preload_content: bool = True, - _request_timeout: Optional[int] = None, + _request_timeout: int | None = None, ) -> V1IngressList | Mock: """List ingress objects in a namespace. @@ -1040,7 +1045,8 @@ async def list_namespaced_ingress( if not watch: if field_selector: match = re.match(r"metadata\.name=(.*)$", field_selector) - assert match and match.group(1) + assert match + assert match.group(1) try: ingress = self._get_object( namespace, "Ingress", match.group(1) @@ -1222,13 +1228,13 @@ async def list_namespaced_job( self, namespace: str, *, - field_selector: Optional[str] = None, - label_selector: Optional[str] = None, - resource_version: Optional[str] = None, - timeout_seconds: Optional[int] = None, + field_selector: str | None = None, + label_selector: str | None = None, + resource_version: str | None = None, + timeout_seconds: int | None = None, watch: bool = False, _preload_content: bool = True, - _request_timeout: Optional[int] = None, + _request_timeout: int | None = None, ) -> V1JobList | Mock: """List job objects in a namespace. 
@@ -1278,7 +1284,8 @@ async def list_namespaced_job( if not watch: if field_selector: match = re.match(r"metadata\.name=(.*)$", field_selector) - assert match and match.group(1) + assert match + assert match.group(1) try: job = self._get_object(namespace, "Job", match.group(1)) return V1JobList(kind="Job", items=[job]) @@ -1583,13 +1590,13 @@ async def list_namespaced_pod( self, namespace: str, *, - field_selector: Optional[str] = None, - label_selector: Optional[str] = None, - resource_version: Optional[str] = None, - timeout_seconds: Optional[int] = None, + field_selector: str | None = None, + label_selector: str | None = None, + resource_version: str | None = None, + timeout_seconds: int | None = None, watch: bool = False, _preload_content: bool = True, - _request_timeout: Optional[int] = None, + _request_timeout: int | None = None, ) -> V1PodList | Mock: """List pod objects in a namespace. @@ -1638,7 +1645,8 @@ async def list_namespaced_pod( if not watch: if field_selector: match = re.match(r"metadata\.name=(.*)$", field_selector) - assert match and match.group(1) + assert match + assert match.group(1) try: pod = self._get_object(namespace, "Pod", match.group(1)) if _check_labels(pod.metadata.labels, label_selector): @@ -1866,7 +1874,7 @@ async def patch_namespaced_secret( elif change["path"] == "/metadata/labels": obj.metadata.labels = change["value"] else: - assert False, f'unsupported path {change["path"]}' + raise AssertionError(f"unsupported path {change['path']}") self._store_object(namespace, "Secret", name, obj, replace=True) async def read_namespaced_secret( @@ -1977,13 +1985,13 @@ async def list_namespaced_service( self, namespace: str, *, - field_selector: Optional[str] = None, - label_selector: Optional[str] = None, - resource_version: Optional[str] = None, - timeout_seconds: Optional[int] = None, + field_selector: str | None = None, + label_selector: str | None = None, + resource_version: str | None = None, + timeout_seconds: int | None = None, 
watch: bool = False, _preload_content: bool = True, - _request_timeout: Optional[int] = None, + _request_timeout: int | None = None, ) -> V1ServiceList | Mock: """List service objects in a namespace. @@ -2033,7 +2041,8 @@ async def list_namespaced_service( if not watch: if field_selector: match = re.match(r"metadata\.name=(.*)$", field_selector) - assert match and match.group(1) + assert match + assert match.group(1) try: service = self._get_object( namespace, "Service", match.group(1) @@ -2131,7 +2140,11 @@ def _get_object(self, namespace: str, key: str, name: str) -> Any: return self._objects[namespace][key][name] def _maybe_error(self, method: str, *args: Any) -> None: - """Helper function to avoid using class method call syntax.""" + """Call the error callback if one is registered. + + This is a separate helper function to avoid using class method call + syntax. + """ if self.error_callback: callback = self.error_callback callback(method, *args) @@ -2142,6 +2155,7 @@ def _store_object( key: str, name: str, obj: Any, + *, replace: bool = False, ) -> None: """Store an object in internal data structures. diff --git a/src/safir/testing/slack.py b/src/safir/testing/slack.py index 1c5442bf..7874dffd 100644 --- a/src/safir/testing/slack.py +++ b/src/safir/testing/slack.py @@ -27,7 +27,7 @@ def __init__(self, url: str) -> None: self.messages: list[dict[str, Any]] = [] def post_webhook(self, request: Request) -> Response: - """Callback called whenever a Slack message is posted. + """Post a Slack message. The provided message is stored in the messages attribute. 
diff --git a/src/safir/testing/uvicorn.py b/src/safir/testing/uvicorn.py index 4875bc27..59d965bd 100644 --- a/src/safir/testing/uvicorn.py +++ b/src/safir/testing/uvicorn.py @@ -17,7 +17,6 @@ import time from dataclasses import dataclass from pathlib import Path -from typing import Optional __all__ = [ "ServerNotListeningError", @@ -68,7 +67,7 @@ def _wait_for_server(port: int, timeout: float = 5.0) -> None: sock.connect(("localhost", port)) except socket.timeout: pass - except socket.error as e: + except OSError as e: if e.errno not in (errno.ETIMEDOUT, errno.ECONNREFUSED): raise else: @@ -80,11 +79,11 @@ def _wait_for_server(port: int, timeout: float = 5.0) -> None: def spawn_uvicorn( *, working_directory: str | Path, - app: Optional[str] = None, - factory: Optional[str] = None, + app: str | None = None, + factory: str | None = None, capture: bool = False, timeout: float = 5.0, - env: Optional[dict[str, str]] = None, + env: dict[str, str] | None = None, ) -> UvicornProcess: """Spawn an ASGI app as a separate Uvicorn process. @@ -129,10 +128,7 @@ def spawn_uvicorn( raise ValueError("Only one of app or factory may be given") if not app and not factory: raise ValueError("Neither of app nor factory was given") - if env: - env = {**os.environ, **env} - else: - env = {**os.environ} + env = {**os.environ, **env} if env else {**os.environ} # Get a random port for the app to listen on. 
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -146,9 +142,9 @@ def spawn_uvicorn( elif factory: cmd.extend(("--factory", factory)) if "PYTHONPATH" in env: - env["PYTHONPATH"] += f":{os.getcwd()}" + env["PYTHONPATH"] += f":{Path.cwd()}" else: - env["PYTHONPATH"] = os.getcwd() + env["PYTHONPATH"] = str(Path.cwd()) logging.info("Starting server with command %s", " ".join(cmd)) if capture: process = subprocess.Popen( diff --git a/tests/conftest.py b/tests/conftest.py index 21f11885..e645f894 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -34,7 +34,7 @@ def mock_slack(respx_mock: respx.Router) -> MockSlackWebhook: @pytest_asyncio.fixture async def redis_client() -> AsyncIterator[redis.Redis]: - """A Redis client for testing. + """Redis client for testing. This fixture connects to the Redis server that runs via tox-docker. """ diff --git a/tests/database_test.py b/tests/database_test.py index 78a15e50..a093e2e8 100644 --- a/tests/database_test.py +++ b/tests/database_test.py @@ -3,7 +3,7 @@ from __future__ import annotations import os -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta, timezone from urllib.parse import unquote, urlparse import pytest @@ -170,7 +170,7 @@ async def test_create_sync_session() -> None: def test_datetime() -> None: - tz_aware = datetime.now(tz=timezone.utc) + tz_aware = datetime.now(tz=UTC) tz_naive = tz_aware.replace(tzinfo=None) assert datetime_to_db(tz_aware) == tz_naive @@ -180,11 +180,11 @@ def test_datetime() -> None: assert datetime_to_db(None) is None assert datetime_from_db(None) is None - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=r"datetime .* not in UTC"): datetime_to_db(tz_naive) tz_local = datetime.now(tz=timezone(timedelta(hours=1))) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=r"datetime .* not in UTC"): datetime_to_db(tz_local) - with pytest.raises(ValueError): + with pytest.raises(ValueError, 
match=r"datetime .* not in UTC"): datetime_from_db(tz_local) diff --git a/tests/datetime_test.py b/tests/datetime_test.py index d554ec81..5b95a9fc 100644 --- a/tests/datetime_test.py +++ b/tests/datetime_test.py @@ -2,7 +2,7 @@ from __future__ import annotations -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta, timezone import pytest @@ -17,16 +17,16 @@ def test_current_datetime() -> None: time = current_datetime() assert time.microsecond == 0 - assert time.tzinfo == timezone.utc - now = datetime.now(tz=timezone.utc) + assert time.tzinfo == UTC + now = datetime.now(tz=UTC) assert now - timedelta(seconds=2) <= time <= now time = current_datetime(microseconds=True) if not time.microsecond: time = current_datetime(microseconds=True) assert time.microsecond != 0 - assert time.tzinfo == timezone.utc - now = datetime.now(tz=timezone.utc) + assert time.tzinfo == UTC + now = datetime.now(tz=UTC) assert now - timedelta(seconds=2) <= time <= now @@ -34,17 +34,17 @@ def test_isodatetime() -> None: time = datetime.fromisoformat("2022-09-16T12:03:45+00:00") assert isodatetime(time) == "2022-09-16T12:03:45Z" - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=r"datetime .* not in UTC"): isodatetime(datetime.fromisoformat("2022-09-16T12:03:45+02:00")) def test_parse_isodatetime() -> None: time = parse_isodatetime("2022-09-16T12:03:45Z") - assert time == datetime(2022, 9, 16, 12, 3, 45, tzinfo=timezone.utc) + assert time == datetime(2022, 9, 16, 12, 3, 45, tzinfo=UTC) now = current_datetime() assert parse_isodatetime(isodatetime(now)) == now - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=r".* does not end with Z"): parse_isodatetime("2022-09-16T12:03:45+00:00") @@ -55,13 +55,13 @@ def test_format_datetime_for_logging() -> None: # Test with milliseconds, allowing for getting extremely unlucky and # having no microseconds. 
Getting unlucky twice seems impossible, so we'll # fail in that case rather than loop. - now = datetime.now(tz=timezone.utc) + now = datetime.now(tz=UTC) if not now.microsecond: - now = datetime.now(tz=timezone.utc) + now = datetime.now(tz=UTC) milliseconds = int(now.microsecond / 1000) expected = now.strftime("%Y-%m-%d %H:%M:%S") + f".{milliseconds:03n}" assert format_datetime_for_logging(now) == expected time = datetime.now(tz=timezone(timedelta(hours=1))) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=r"datetime .* not in UTC"): format_datetime_for_logging(time) diff --git a/tests/dependencies/arq_test.py b/tests/dependencies/arq_test.py index b6683c2d..f5dd9146 100644 --- a/tests/dependencies/arq_test.py +++ b/tests/dependencies/arq_test.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Optional +from typing import Any import pytest from arq.constants import default_queue_name @@ -37,7 +37,7 @@ async def post_job( @app.get("/jobs/{job_id}") async def get_metadata( job_id: str, - queue_name: Optional[str] = None, + queue_name: str | None = None, arq_queue: MockArqQueue = Depends(arq_dependency), ) -> dict[str, Any]: """Get metadata about a job.""" @@ -45,8 +45,8 @@ async def get_metadata( job = await arq_queue.get_job_metadata( job_id, queue_name=queue_name ) - except JobNotFound: - raise HTTPException(status_code=404) + except JobNotFound as e: + raise HTTPException(status_code=404) from e return { "job_id": job.id, "job_status": job.status, @@ -59,7 +59,7 @@ async def get_metadata( @app.get("/results/{job_id}") async def get_result( job_id: str, - queue_name: Optional[str] = None, + queue_name: str | None = None, arq_queue: MockArqQueue = Depends(arq_dependency), ) -> dict[str, Any]: """Get the results for a job.""" @@ -68,7 +68,7 @@ async def get_result( job_id, queue_name=queue_name ) except (JobNotFound, JobResultUnavailable) as e: - raise HTTPException(status_code=404, detail=str(e)) + raise 
HTTPException(status_code=404, detail=str(e)) from e return { "job_id": job_result.id, "job_status": job_result.status, @@ -83,20 +83,21 @@ async def get_result( @app.post("/jobs/{job_id}/inprogress") async def post_job_inprogress( job_id: str, - queue_name: Optional[str] = None, + queue_name: str | None = None, arq_queue: MockArqQueue = Depends(arq_dependency), ) -> None: """Toggle a job to in-progress, for testing.""" try: await arq_queue.set_in_progress(job_id, queue_name=queue_name) except JobNotFound as e: - raise HTTPException(status_code=404, detail=str(e)) + raise HTTPException(status_code=404, detail=str(e)) from e @app.post("/jobs/{job_id}/complete") async def post_job_complete( job_id: str, - queue_name: Optional[str] = None, - result: Optional[str] = None, + *, + queue_name: str | None = None, + result: str | None = None, success: bool = True, arq_queue: MockArqQueue = Depends(arq_dependency), ) -> None: @@ -106,7 +107,7 @@ async def post_job_complete( job_id, result=result, success=success, queue_name=queue_name ) except JobNotFound as e: - raise HTTPException(status_code=404, detail=str(e)) + raise HTTPException(status_code=404, detail=str(e)) from e @app.on_event("startup") async def startup() -> None: diff --git a/tests/github/webhooks_test.py b/tests/github/webhooks_test.py index f9b0fec9..68efed17 100644 --- a/tests/github/webhooks_test.py +++ b/tests/github/webhooks_test.py @@ -4,7 +4,7 @@ from pathlib import Path -import safir.github.webhooks as webhooks +from safir.github import webhooks from safir.github.models import ( GitHubCheckRunStatus, GitHubCheckSuiteConclusion, diff --git a/tests/logging_test.py b/tests/logging_test.py index 7ea9ae90..eaef02c8 100644 --- a/tests/logging_test.py +++ b/tests/logging_test.py @@ -5,7 +5,7 @@ import json import logging import re -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from pathlib import Path from unittest.mock import ANY @@ -96,8 +96,8 @@ def 
test_configure_logging_dev_timestamp(caplog: LogCaptureFixture) -> None: isotimestamp = match.group(1) assert isotimestamp.endswith("Z") timestamp = datetime.fromisoformat(isotimestamp[:-1]) - timestamp = timestamp.replace(tzinfo=timezone.utc) - now = datetime.now(tz=timezone.utc) + timestamp = timestamp.replace(tzinfo=UTC) + now = datetime.now(tz=UTC) assert now - timedelta(seconds=5) < timestamp < now @@ -148,8 +148,8 @@ def test_configure_logging_prod_timestamp(caplog: LogCaptureFixture) -> None: } assert data["timestamp"].endswith("Z") timestamp = datetime.fromisoformat(data["timestamp"][:-1]) - timestamp = timestamp.replace(tzinfo=timezone.utc) - now = datetime.now(tz=timezone.utc) + timestamp = timestamp.replace(tzinfo=UTC) + now = datetime.now(tz=UTC) assert now - timedelta(seconds=5) < timestamp < now @@ -188,7 +188,7 @@ def test_dev_exception_logging(caplog: LogCaptureFixture) -> None: logger = structlog.get_logger("myapp") try: - raise ValueError("this is some exception") + raise ValueError("this is some exception") # noqa: TRY301 except Exception: logger.exception("exception happened", foo="bar") @@ -206,7 +206,7 @@ def test_production_exception_logging(caplog: LogCaptureFixture) -> None: logger = structlog.get_logger("myapp") try: - raise ValueError("this is some exception") + raise ValueError("this is some exception") # noqa: TRY301 except Exception: logger.exception("exception happened", foo="bar") diff --git a/tests/metadata_test.py b/tests/metadata_test.py index 0ad3321f..37313dbb 100644 --- a/tests/metadata_test.py +++ b/tests/metadata_test.py @@ -1,9 +1,7 @@ -"""Tests for the safir.metadata module. 
-""" +"""Tests for the safir.metadata module.""" from __future__ import annotations -import sys from email.message import Message from importlib.metadata import metadata from typing import cast @@ -15,10 +13,7 @@ @pytest.fixture(scope="session") def safir_metadata() -> Message: - if sys.version_info >= (3, 10): - return cast(Message, metadata("safir")) - else: - return metadata("safir") + return cast(Message, metadata("safir")) def test_get_project_url(safir_metadata: Message) -> None: diff --git a/tests/middleware/x_forwarded_test.py b/tests/middleware/x_forwarded_test.py index 7abc2faa..69432c3f 100644 --- a/tests/middleware/x_forwarded_test.py +++ b/tests/middleware/x_forwarded_test.py @@ -3,7 +3,6 @@ from __future__ import annotations from ipaddress import _BaseNetwork, ip_network -from typing import Optional import pytest from fastapi import FastAPI, Request @@ -12,7 +11,7 @@ from safir.middleware.x_forwarded import XForwardedMiddleware -def build_app(proxies: Optional[list[_BaseNetwork]] = None) -> FastAPI: +def build_app(proxies: list[_BaseNetwork] | None = None) -> FastAPI: """Construct a test FastAPI app with the middleware registered.""" app = FastAPI() app.add_middleware(XForwardedMiddleware, proxies=proxies) diff --git a/tests/pydantic_test.py b/tests/pydantic_test.py index 18035332..161c06c5 100644 --- a/tests/pydantic_test.py +++ b/tests/pydantic_test.py @@ -3,8 +3,7 @@ from __future__ import annotations import json -from datetime import datetime, timedelta, timezone -from typing import Optional +from datetime import UTC, datetime, timedelta, timezone import pytest from pydantic import BaseModel, ValidationError, root_validator @@ -21,18 +20,18 @@ def test_normalize_datetime() -> None: assert normalize_datetime(None) is None - date = datetime.fromtimestamp(1668814932, tz=timezone.utc) + date = datetime.fromtimestamp(1668814932, tz=UTC) assert normalize_datetime(1668814932) == date mst_zone = timezone(-timedelta(hours=7)) mst_date = 
datetime.now(tz=mst_zone) - utc_date = mst_date.astimezone(timezone.utc) + utc_date = mst_date.astimezone(UTC) assert normalize_datetime(mst_date) == utc_date - naive_date = datetime.utcnow() + naive_date = datetime.utcnow() # noqa: DTZ003 aware_date = normalize_datetime(naive_date) - assert aware_date == naive_date.replace(tzinfo=timezone.utc) - assert aware_date.tzinfo == timezone.utc + assert aware_date == naive_date.replace(tzinfo=UTC) + assert aware_date.tzinfo == UTC def test_normalize_isodatetime() -> None: @@ -44,13 +43,13 @@ def test_normalize_isodatetime() -> None: date = datetime.fromisoformat("2023-01-25T15:44:00+00:00") assert date == normalize_isodatetime("2023-01-25T15:44Z") - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=r"Must be a string in .* format"): normalize_isodatetime("2023-01-25T15:44:00+00:00") - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=r"Must be a string in .* format"): normalize_isodatetime(1668814932) # type: ignore[arg-type] - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=r"Must be a string in .* format"): normalize_isodatetime("next thursday") @@ -96,9 +95,9 @@ class TestModel(CamelCaseModel): def test_validate_exactly_one_of() -> None: class Model(BaseModel): - foo: Optional[int] = None - bar: Optional[int] = None - baz: Optional[int] = None + foo: int | None = None + bar: int | None = None + baz: int | None = None _validate_type = root_validator(allow_reuse=True)( validate_exactly_one_of("foo", "bar", "baz") @@ -118,8 +117,8 @@ class Model(BaseModel): assert "one of foo, bar, and baz must be given" in str(excinfo.value) class TwoModel(BaseModel): - foo: Optional[int] = None - bar: Optional[int] = None + foo: int | None = None + bar: int | None = None _validate_type = root_validator(allow_reuse=True)( validate_exactly_one_of("foo", "bar") diff --git a/tests/redis_test.py b/tests/redis_test.py index b973a32f..7bed2e2b 100644 --- a/tests/redis_test.py +++ 
b/tests/redis_test.py @@ -23,9 +23,7 @@ class DemoModel(BaseModel): async def basic_testing(storage: PydanticRedisStorage[DemoModel]) -> None: - """Test basic storage operations for either encrypted or unencrypted - storage. - """ + """Test basic storage operations for encrypted and unencrypted storage.""" await storage.store("mark42", DemoModel(name="Mark", value=42)) await storage.store("mark13", DemoModel(name="Mark", value=13)) await storage.store("jon7", DemoModel(name="Jon", value=7)) @@ -143,9 +141,7 @@ async def test_deserialization_error(redis_client: redis.Redis) -> None: async def test_deserialization_error_with_key_prefix( redis_client: redis.Redis, ) -> None: - """Test that deserialization errors are presented correctly when a key - prefix is used. - """ + """Test deserialization error formatting when a key prefix is used.""" storage = PydanticRedisStorage( datatype=DemoModel, redis=redis_client, key_prefix="test:" ) diff --git a/tests/testing/gcs_test.py b/tests/testing/gcs_test.py index 4ab4c083..0a7d639c 100644 --- a/tests/testing/gcs_test.py +++ b/tests/testing/gcs_test.py @@ -6,7 +6,7 @@ from __future__ import annotations -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from pathlib import Path import google.auth @@ -79,7 +79,7 @@ def test_mock_files(mock_gcs_file: MockStorageClient) -> None: assert blob.exists() assert blob.size == this_file.stat().st_size assert blob.updated == datetime.fromtimestamp( - this_file.stat().st_mtime, tz=timezone.utc + this_file.stat().st_mtime, tz=UTC ) assert blob.etag == str(this_file.stat().st_ino) assert blob.download_as_bytes() == this_file.read_bytes() diff --git a/tests/testing/kubernetes_test.py b/tests/testing/kubernetes_test.py index c6e70e31..ce7425ec 100644 --- a/tests/testing/kubernetes_test.py +++ b/tests/testing/kubernetes_test.py @@ -7,7 +7,7 @@ from __future__ import annotations import asyncio -from typing import Any, Optional +from typing import Any 
import pytest from kubernetes_asyncio.client import ( @@ -95,7 +95,7 @@ def error(method: str, *args: Any) -> None: raise ValueError("some exception") mock_kubernetes.error_callback = error - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError, match="some exception") as excinfo: await mock_kubernetes.replace_namespaced_custom_object( "gafaelfawr.lsst.io", "v1alpha1", @@ -111,7 +111,7 @@ async def watch_events( mock_kubernetes: MockKubernetesApi, namespace: str, *, - resource_version: Optional[str] = None, + resource_version: str | None = None, ) -> list[CoreV1Event]: """Watch events, returning when an event with message ``Done`` is seen.""" method = mock_kubernetes.list_namespaced_event From 6271deb2b2cfa5a2c572d46d76c7a1648ac3faf9 Mon Sep 17 00:00:00 2001 From: Russ Allbery Date: Fri, 1 Sep 2023 12:25:08 -0700 Subject: [PATCH 3/6] Redo Kubernetes object listing in mock Several list_* mock Kubernetes API functions were using variations of the same pattern to support field and label selectors. Move that code into a helper method so that it can be shared, and clean up the coding style of the function that checks label selectors. --- src/safir/testing/kubernetes.py | 204 +++++++++++++++++--------------- 1 file changed, 109 insertions(+), 95 deletions(-) diff --git a/src/safir/testing/kubernetes.py b/src/safir/testing/kubernetes.py index 10c345d9..5346e66d 100644 --- a/src/safir/testing/kubernetes.py +++ b/src/safir/testing/kubernetes.py @@ -66,7 +66,7 @@ def _parse_label_selector(label_selector: str) -> dict[str, str]: Returns ------- - dict of str to str + dict of str Dictionary of required labels to their required values. Raises @@ -87,35 +87,33 @@ def _parse_label_selector(label_selector: str) -> dict[str, str]: def _check_labels( obj_labels: dict[str, str] | None, label_selector: str | None ) -> bool: - """Check whether an object's labels match the label selector supplied. + """Check whether an object's labels match a label selector. 
Parameters ---------- obj_labels - Kubernetes object labels - + Kubernetes object labels. label_selector - label selector in string form + Label selector in string form. Returns ------- bool - Did all of the supplied label_selector labels match - the object labels? + Whether this object matches the label selector. """ - if label_selector is None or label_selector == "": - # Everything matches the absence of a selector. + # Everything matches the absence of a selector. + if not label_selector: return True - if obj_labels is None or not obj_labels: - # If there are no labels but a non-empty selector, it doesn't - # match. + + # If there are no labels but a non-empty selector, it doesn't match. + if not obj_labels: return False + + # Check that all labels match the labels of the object. labels = _parse_label_selector(label_selector) - for lbl in labels: - if lbl not in obj_labels or labels[lbl] != obj_labels[lbl]: - # Label isn't present or its value isn't right + for label in labels: + if label not in obj_labels or labels[label] != obj_labels[label]: return False - # The whole selector is correct. 
return True @@ -148,7 +146,7 @@ def strip_none(model: dict[str, Any]) -> dict[str, Any]: continue new_value = value if isinstance(value, dict): - new_value: Any = strip_none(value) + new_value = strip_none(value) elif isinstance(value, list): list_result = [] for item in value: @@ -617,10 +615,9 @@ async def list_cluster_custom_object( """ self._maybe_error("list_cluster_custom_object", group, version, plural) key = f"{group}/{version}/{plural}" - results = [] + results: list[dict[str, Any]] = [] for namespace in self._objects: - for obj in self._objects[namespace].get(key, {}).values(): - results.append(obj) + results.extend(self._objects[namespace].get(key, {}).values()) return {"items": results} async def patch_namespaced_custom_object_status( @@ -1043,24 +1040,10 @@ async def list_namespaced_ingress( msg = f"Namespace {namespace} not found" raise ApiException(status=404, reason=msg) if not watch: - if field_selector: - match = re.match(r"metadata\.name=(.*)$", field_selector) - assert match - assert match.group(1) - try: - ingress = self._get_object( - namespace, "Ingress", match.group(1) - ) - return V1IngressList(kind="Ingress", items=[ingress]) - except ApiException: - return V1IngressList(kind="Ingress", items=[]) - else: - ingresss = [] - if "Ingress" in self._objects[namespace]: - for obj in self._objects[namespace]["Ingress"].values(): - if _check_labels(obj.metadata.labels, label_selector): - ingresss.append(obj) - return V1IngressList(kind="Ingress", items=ingresss) + ingresses = self._list_objects( + namespace, "Ingress", field_selector, label_selector + ) + return V1IngressList(kind="Ingress", items=ingresses) # All watches must not preload content since we're returning raw JSON. # This is done by the Kubernetes API Watch object. 
@@ -1282,22 +1265,10 @@ async def list_namespaced_job(
             msg = f"Namespace {namespace} not found"
             raise ApiException(status=404, reason=msg)
         if not watch:
-            if field_selector:
-                match = re.match(r"metadata\.name=(.*)$", field_selector)
-                assert match
-                assert match.group(1)
-                try:
-                    job = self._get_object(namespace, "Job", match.group(1))
-                    return V1JobList(kind="Job", items=[job])
-                except ApiException:
-                    return V1JobList(kind="Job", items=[])
-            else:
-                jobs = []
-                if "Job" in self._objects[namespace]:
-                    for obj in self._objects[namespace]["Job"].values():
-                        if _check_labels(obj.metadata.labels, label_selector):
-                            jobs.append(obj)
-                return V1JobList(kind="Job", items=jobs)
+            jobs = self._list_objects(
+                namespace, "Job", field_selector, label_selector
+            )
+            return V1JobList(kind="Job", items=jobs)
 
         # All watches must not preload content since we're returning raw JSON.
         # This is done by the Kubernetes API Watch object.
@@ -1434,9 +1405,7 @@ async def list_namespace(self) -> V1NamespaceList:
            synthesized namespace objects.
""" self._maybe_error("list_namespace") - namespaces = [] - for namespace in self._objects: - namespaces.append(await self.read_namespace(namespace)) + namespaces = [await self.read_namespace(n) for n in self._objects] return V1NamespaceList(items=namespaces) # NETWORKPOLICY API @@ -1643,24 +1612,10 @@ async def list_namespaced_pod( msg = f"Namespace {namespace} not found" raise ApiException(status=404, reason=msg) if not watch: - if field_selector: - match = re.match(r"metadata\.name=(.*)$", field_selector) - assert match - assert match.group(1) - try: - pod = self._get_object(namespace, "Pod", match.group(1)) - if _check_labels(pod.metadata.labels, label_selector): - return V1PodList(kind="Pod", items=[pod]) - return V1PodList(kind="Pod", items=[]) - except ApiException: - return V1PodList(kind="Pod", items=[]) - else: - pods = [] - if "Pod" in self._objects[namespace]: - for obj in self._objects[namespace]["Pod"].values(): - if _check_labels(obj.metadata.labels, label_selector): - pods.append(obj) - return V1PodList(kind="Pod", items=pods) + pods = self._list_objects( + namespace, "Pod", field_selector, label_selector + ) + return V1PodList(kind="Pod", items=pods) # All watches must not preload content since we're returning raw JSON. # This is done by the Kubernetes API Watch object. @@ -2005,8 +1960,8 @@ async def list_namespaced_service( Only ``metadata.name=...`` is supported. It is parsed to find the service name and only services matching that name will be returned. label_selector - Which events to retrieve when performing a watch. All - labels must match. + Which matching objects to retrieve by label. All labels must + match. resource_version Where to start in the event stream when performing a watch. If `None`, starts with the next change. 
@@ -2039,24 +1994,10 @@ async def list_namespaced_service(
             msg = f"Namespace {namespace} not found"
             raise ApiException(status=404, reason=msg)
         if not watch:
-            if field_selector:
-                match = re.match(r"metadata\.name=(.*)$", field_selector)
-                assert match
-                assert match.group(1)
-                try:
-                    service = self._get_object(
-                        namespace, "Service", match.group(1)
-                    )
-                    return V1ServiceList(kind="Service", items=[service])
-                except ApiException:
-                    return V1ServiceList(kind="Service", items=[])
-            else:
-                services = []
-                if "Service" in self._objects[namespace]:
-                    for obj in self._objects[namespace]["Service"].values():
-                        if _check_labels(obj.metadata.labels, label_selector):
-                            services.append(obj)
-                return V1ServiceList(kind="Service", items=services)
+            services = self._list_objects(
+                namespace, "Service", field_selector, label_selector
+            )
+            return V1ServiceList(kind="Service", items=services)
 
         # All watches must not preload content since we're returning raw JSON.
         # This is done by the Kubernetes API Watch object.
@@ -2101,6 +2042,15 @@ async def read_namespaced_service(
     def _delete_object(self, namespace: str, key: str, name: str) -> V1Status:
         """Delete an object from internal data structures.
 
+        Parameters
+        ----------
+        namespace
+            Namespace from which to delete an object.
+        key
+            Key under which the object is stored (usually the kind).
+        name
+            Name of the object.
+
         Returns
         -------
         kubernetes_asyncio.client.V1Status
@@ -2121,6 +2071,15 @@ def _delete_object(self, namespace: str, key: str, name: str) -> V1Status:
     def _get_object(self, namespace: str, key: str, name: str) -> Any:
         """Retrieve an object from internal data structures.
 
+        Parameters
+        ----------
+        namespace
+            Namespace from which to retrieve an object.
+        key
+            Key under which the object is stored (usually the kind).
+        name
+            Name of the object.
+ Returns ------- Any @@ -2139,6 +2098,61 @@ def _get_object(self, namespace: str, key: str, name: str) -> Any: raise ApiException(status=404, reason=reason) return self._objects[namespace][key][name] + def _list_objects( + self, + namespace: str, + key: str, + field_selector: str | None, + label_selector: str | None, + ) -> list[Any]: + """List objects, possibly with selector restrictions. + + Parameters + ---------- + namespace + Namespace in which to list objects. + key + Key under which the object is stored (usually the kind). + field_selector + If present, only ``metadata.name=...`` is supported. It is parsed + to find the object name and only an object matching that name will + be returned. + label_selector + Which matching objects to retrieve by label. All labels must + match. + + Returns + ------- + list + List of matching objects. + """ + if key not in self._objects[namespace]: + return [] + + # If there is a field selector, only name selectors are supported and + # we should retrieve the object by name. + if field_selector: + match = re.match(r"metadata\.name=(.*)$", field_selector) + if not match or not match.group(1): + msg = f"Field selector {field_selector} not supported" + raise ValueError(msg) + try: + obj = self._get_object(namespace, key, match.group(1)) + if _check_labels(obj.metadata.labels, label_selector): + return [obj] + else: + return [] + except ApiException: + return [] + + # Otherwise, construct the list of all objects matching the label + # selector. + return [ + o + for o in self._objects[namespace][key].values() + if _check_labels(o.metadata.labels, label_selector) + ] + def _maybe_error(self, method: str, *args: Any) -> None: """Call the error callback if one is registered. 
From 60abcc7bdadcd188d0685b0b5f67e9e99b9c8692 Mon Sep 17 00:00:00 2001 From: Russ Allbery Date: Fri, 1 Sep 2023 11:37:51 -0700 Subject: [PATCH 4/6] Convert to Ruff for linting Add a Ruff pre-commit hook and configure it based on the current SQuaRE PyPI library template. Also enable conflict and trailing whitespace checking in the standard pre-commit hooks. --- .flake8 | 7 -- .pre-commit-config.yaml | 18 ++-- changelog.d/20230905_084310_rra_DM_40628.md | 3 + pyproject.toml | 114 ++++++++++++++++++-- 4 files changed, 118 insertions(+), 24 deletions(-) delete mode 100644 .flake8 create mode 100644 changelog.d/20230905_084310_rra_DM_40628.md diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 01ee5229..00000000 --- a/.flake8 +++ /dev/null @@ -1,7 +0,0 @@ -[flake8] -max-line-length = 79 -# E203: whitespace before :, flake8 disagrees with PEP-8 -# W503: line break after binary operator, flake8 disagrees with PEP-8 -ignore = E203, W503 -exclude = - docs/conf.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 90020ff3..ce788b99 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,15 +2,16 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: - - id: check-yaml + - id: check-merge-conflict - id: check-toml + - id: check-yaml + - id: trailing-whitespace - - repo: https://github.com/PyCQA/isort - rev: 5.12.0 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.0.286 hooks: - - id: isort - additional_dependencies: - - toml + - id: ruff + args: [--fix, --exit-non-zero-on-fix] - repo: https://github.com/ambv/black rev: 23.3.0 @@ -23,8 +24,3 @@ repos: - id: blacken-docs additional_dependencies: [black==23.1.0] args: [-l, '76', -t, py311] - - - repo: https://github.com/PyCQA/flake8 - rev: 6.0.0 - hooks: - - id: flake8 diff --git a/changelog.d/20230905_084310_rra_DM_40628.md b/changelog.d/20230905_084310_rra_DM_40628.md new file mode 100644 index 00000000..806c2a10 --- /dev/null +++ 
b/changelog.d/20230905_084310_rra_DM_40628.md
@@ -0,0 +1,3 @@
+### Other changes
+
+- Safir now uses the [Ruff](https://beta.ruff.rs/docs/) linter instead of flake8 and isort.
diff --git a/pyproject.toml b/pyproject.toml
index d6c5a1af..db4c2ce8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -130,12 +130,6 @@ exclude = '''
 # Use single-quoted strings so TOML treats the string like a Python r-string
 # Multi-line strings are implicitly treated by black as regular expressions
 
-[tool.isort]
-profile = "black"
-line_length = 79
-known_first_party = ["safir", "tests"]
-skip = ["docs/conf.py"]
-
 [tool.pytest.ini_options]
 asyncio_mode = "strict"
 filterwarnings = [
@@ -171,6 +165,114 @@ init_typed = true
 warn_required_dynamic_aliases = true
 warn_untyped_fields = true
 
+# The rule used with Ruff configuration is to disable every lint that has
+# legitimate exceptions that are not dodgy code, rather than cluttering code
+# with noqa markers. This is therefore a relatively relaxed configuration that
+# errs on the side of disabling legitimate lints.
+# +# Reference for settings: https://beta.ruff.rs/docs/settings/ +# Reference for rules: https://beta.ruff.rs/docs/rules/ +[tool.ruff] +exclude = [ + "docs/**", +] +line-length = 79 +ignore = [ + "ANN101", # self should not have a type annotation + "ANN102", # cls should not have a type annotation + "ANN401", # sometimes Any is the right type + "ARG001", # unused function arguments are often legitimate + "ARG002", # unused method arguments are often legitimate + "ARG005", # unused lambda arguments are often legitimate + "BLE001", # we want to catch and report Exception in background tasks + "C414", # nested sorted is how you sort by multiple keys with reverse + "COM812", # omitting trailing commas allows black autoreformatting + "D102", # sometimes we use docstring inheritence + "D104", # don't see the point of documenting every package + "D105", # our style doesn't require docstrings for magic methods + "D106", # Pydantic uses a nested Config class that doesn't warrant docs + "EM101", # justification (duplicate string in traceback) is silly + "EM102", # justification (duplicate string in traceback) is silly + "FBT003", # positional booleans are normal for Pydantic field defaults + "FIX002", # point of a TODO comment is that we're not ready to fix it + "G004", # forbidding logging f-strings is appealing, but not our style + "RET505", # disagree that omitting else always makes code more readable + "PLR0913", # factory pattern uses constructors with many arguments + "PLR2004", # too aggressive about magic values + "S105", # good idea but too many false positives on non-passwords + "S106", # good idea but too many false positives on non-passwords + "S603", # not going to manually mark every subprocess call as reviewed + "S607", # using PATH is not a security vulnerability + "SIM102", # sometimes the formatting of nested if statements is clearer + "SIM117", # sometimes nested with contexts are clearer + "TCH001", # we decided to not maintain separate TYPE_CHECKING 
blocks + "TCH002", # we decided to not maintain separate TYPE_CHECKING blocks + "TCH003", # we decided to not maintain separate TYPE_CHECKING blocks + "TD003", # we don't require issues be created for TODOs + "TID252", # if we're going to use relative imports, use them always + "TRY003", # good general advice but lint is way too aggressive + + # Safir-specific rules. + "N818", # Exception is correct in some cases, others are part of API + "PLW0603", # necessary trick for safir.logging +] +select = ["ALL"] +target-version = "py311" + +[tool.ruff.per-file-ignores] +"src/safir/testing/**" = [ + "S101", # test support functions are allowed to use assert +] +"tests/**" = [ + "C901", # tests are allowed to be complex, sometimes that's convenient + "D101", # tests don't need docstrings + "D103", # tests don't need docstrings + "PLR0915", # tests are allowed to be long, sometimes that's convenient + "PT012", # way too aggressive about limiting pytest.raises blocks + "S101", # tests should use assert + "SLF001", # tests are allowed to access private members +] + +[tool.ruff.isort] +known-first-party = ["safir", "tests"] +split-on-trailing-comma = false + +[tool.ruff.flake8-bugbear] +extend-immutable-calls = [ + "fastapi.Form", + "fastapi.Header", + "fastapi.Depends", + "fastapi.Path", + "fastapi.Query", +] + +# These are too useful as attributes or methods to allow the conflict with the +# built-in to rule out their use. 
+[tool.ruff.flake8-builtins] +builtins-ignorelist = [ + "all", + "any", + "dict", + "help", + "id", + "list", + "open", + "type", +] + +[tool.ruff.flake8-pytest-style] +fixture-parentheses = false +mark-parentheses = false + +[tool.ruff.pep8-naming] +classmethod-decorators = [ + "pydantic.root_validator", + "pydantic.validator", +] + +[tool.ruff.pydocstyle] +convention = "numpy" + [tool.scriv] categories = [ "Backwards-incompatible changes", From da9f92f519adb345de1a8ccfb7633205969b208d Mon Sep 17 00:00:00 2001 From: Russ Allbery Date: Tue, 5 Sep 2023 10:59:49 -0700 Subject: [PATCH 5/6] Fix Ruff warning in new Click test --- tests/click_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/click_test.py b/tests/click_test.py index f3e33bea..5894e3bd 100644 --- a/tests/click_test.py +++ b/tests/click_test.py @@ -11,7 +11,7 @@ def test_display_help() -> None: @click.group() def main() -> None: - """Some command.""" + """Run some command.""" @main.command() @click.argument("topic", default=None, required=False, nargs=1) @@ -38,7 +38,7 @@ def something() -> None: runner = CliRunner() result = runner.invoke(main, ["help"], catch_exceptions=False) - assert "Some command" in result.output + assert "Run some command" in result.output result = runner.invoke(main, ["help", "foo"], catch_exceptions=False) assert "main foo [OPTIONS]" in result.output assert "Run foo" in result.output From 7cb30e2f91dc1824a63e53bdc2291b09513ca7c0 Mon Sep 17 00:00:00 2001 From: Russ Allbery Date: Tue, 5 Sep 2023 11:32:27 -0700 Subject: [PATCH 6/6] Revert docstring changes for D205 Our documentation style allows wrapped summary lines longer than a single line. Disable the corresponding Ruff diagnostic and revert changes to satisfy that diagnostic. 
--- pyproject.toml | 1 + src/safir/__init__.py | 4 +--- src/safir/arq.py | 9 +++----- src/safir/dependencies/arq.py | 11 ++++----- src/safir/github/_client.py | 18 +++++---------- src/safir/github/models.py | 26 +++++++++------------ src/safir/github/webhooks.py | 43 +++++++++++++---------------------- src/safir/redis.py | 4 +--- tests/redis_test.py | 8 +++++-- 9 files changed, 50 insertions(+), 74 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index db4c2ce8..f1480963 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -191,6 +191,7 @@ ignore = [ "D104", # don't see the point of documenting every package "D105", # our style doesn't require docstrings for magic methods "D106", # Pydantic uses a nested Config class that doesn't warrant docs + "D205", # our documentation style allows a folded first line "EM101", # justification (duplicate string in traceback) is silly "EM102", # justification (duplicate string in traceback) is silly "FBT003", # positional booleans are normal for Pydantic field defaults diff --git a/src/safir/__init__.py b/src/safir/__init__.py index 682c7dc9..486c9a01 100644 --- a/src/safir/__init__.py +++ b/src/safir/__init__.py @@ -1,6 +1,4 @@ -"""Support library for the Rubin Science Platform. - -Safir is the Rubin Observatory's library for building FastAPI services +"""Safir is the Rubin Observatory's library for building FastAPI services for the Rubin Science Platform. """ diff --git a/src/safir/arq.py b/src/safir/arq.py index 0abac796..1134aa24 100644 --- a/src/safir/arq.py +++ b/src/safir/arq.py @@ -269,9 +269,7 @@ async def from_job(cls, job: Job) -> Self: class ArqQueue(metaclass=abc.ABCMeta): - """arq queue interface supporting either Redis or an in-memory repository. - - Provides a common interface for working with an arq queue that can be + """A common interface for working with an arq queue that can be implemented either with a real Redis backend, or an in-memory repository for testing. 
@@ -290,9 +288,8 @@ def __init__( @property def default_queue_name(self) -> str: - """Name of the default queue. - - Used if the ``_queue_name`` parameter is not set in method calls. + """Name of the default queue, if the ``_queue_name`` parameter is not + set in method calls. """ return self._default_queue_name diff --git a/src/safir/dependencies/arq.py b/src/safir/dependencies/arq.py index af7ab6c0..35daab7e 100644 --- a/src/safir/dependencies/arq.py +++ b/src/safir/dependencies/arq.py @@ -10,11 +10,8 @@ class ArqDependency: - """FastAPI dependency providing a client for enqueuing tasks. - - This class maintains a singleton Redis client for enqueuing tasks to an - arq_ worker pool and provides it to handler methods via the FastAPI - dependency interface. + """A FastAPI dependency that maintains a Redis client for enqueuing tasks + to the worker pool. """ def __init__(self) -> None: @@ -86,4 +83,6 @@ async def __call__(self) -> ArqQueue: arq_dependency = ArqDependency() -"""Singleton instance of `ArqDependency` as a FastAPI dependency.""" +"""Singleton instance of `ArqDependency` that serves as a FastAPI +dependency. +""" diff --git a/src/safir/github/_client.py b/src/safir/github/_client.py index 8a7a532a..cf64e286 100644 --- a/src/safir/github/_client.py +++ b/src/safir/github/_client.py @@ -6,10 +6,8 @@ class GitHubAppClientFactory: - """Create GitHub App clients. - - Provides a factory for creating GitHub App clients authenticated either as - an app or as an installation of that app. + """Factory for creating GitHub App clients authenticated either as an app + or as an installation of that app. Parameters ---------- @@ -76,10 +74,8 @@ def create_app_client(self) -> GitHubAPI: async def create_installation_client( self, installation_id: str ) -> GitHubAPI: - """Create a client for an installation of the GitHub App. - - The resulting client is authenticated as an installation of the GitHub - App for a specific repository or organization. 
+ """Create a client authenticated as an installation of the GitHub App + for a specific repository or organization. Parameters ---------- @@ -106,10 +102,8 @@ async def create_installation_client( async def create_installation_client_for_repo( self, owner: str, repo: str ) -> GitHubAPI: - """Create a client for a repository installation of the GitHub App. - - The resulting client is authenticated as an installation of the GitHub - App for a specific repository. + """Create a client authenticated as an installation of the GitHub App + for a specific repository or organization. Parameters ---------- diff --git a/src/safir/github/models.py b/src/safir/github/models.py index fa268075..4c214571 100644 --- a/src/safir/github/models.py +++ b/src/safir/github/models.py @@ -67,9 +67,8 @@ class GitHubUserModel(BaseModel): class GitHubRepositoryModel(BaseModel): - """A Pydantic model for the ``repository`` field. - - This field is often found in webhook payloads. + """A Pydantic model for the ``repository`` field, often found in webhook + payloads. https://docs.github.com/en/rest/repos/repos#get-a-repository """ @@ -271,10 +270,8 @@ class GitHubCheckSuiteConclusion(str, Enum): class GitHubCheckSuiteModel(BaseModel): - """A Pydantic model for the ``check_suite`` field. - - This field is found in a ``check_suite`` webhook - (`~safir.github.webhooks.GitHubCheckSuiteEventModel`). + """A Pydantic model for the ``check_suite`` field in a ``check_suite`` + webhook (`~safir.github.webhooks.GitHubCheckSuiteEventModel`). """ id: str = Field(description="Identifier for this check run.") @@ -323,7 +320,9 @@ class GitHubCheckRunConclusion(str, Enum): """The check run has failed.""" neutral = "neutral" - """The check run has a neutral outcome, perhaps because it was skipped.""" + """The check run has a neutral outcome, perhaps because the check was + skipped. 
+ """ cancelled = "cancelled" """The check run was cancelled.""" @@ -370,9 +369,8 @@ class GitHubCheckRunOutput(BaseModel): class GitHubCheckRunPrInfoModel(BaseModel): - """A Pydantic model of the ``pull_requests[]`` items. - - These are found in a check run GitHub API model (`GitHubCheckRunModel`). + """A Pydantic model of the ``pull_requests[]`` items in a check run + GitHub API model (`GitHubCheckRunModel`). https://docs.github.com/en/rest/checks/runs#get-a-check-run """ @@ -381,10 +379,8 @@ class GitHubCheckRunPrInfoModel(BaseModel): class GitHubCheckRunModel(BaseModel): - """A Pydantic model for the ``check_run`` field. - - This is found in a check_run webhook payload - (`~safir.github.webhooks.GitHubCheckRunEventModel`). + """A Pydantic model for the "check_run" field in a check_run webhook + payload (`~safir.github.webhooks.GitHubCheckRunEventModel`). """ id: str = Field(description="Identifier for this check run.") diff --git a/src/safir/github/webhooks.py b/src/safir/github/webhooks.py index bfae5949..5815679f 100644 --- a/src/safir/github/webhooks.py +++ b/src/safir/github/webhooks.py @@ -31,18 +31,16 @@ class GitHubAppInstallationModel(BaseModel): - """A Pydantic model for the ``installation`` field found. - - This field is found in webhook payloads for GitHub Apps. + """A Pydantic model for the ``installation`` field found in webhook + payloads for GitHub Apps. """ id: str = Field(description="The installation ID.") class GitHubPushEventModel(BaseModel): - """A Pydantic model for the ``push`` event webhook. - - This webhook is triggered when a commit or tag is pushed. + """A Pydantic model for the ``push`` event webhook when a commit or + tag is pushed. https://docs.github.com/en/webhooks/webhook-events-and-payloads#push """ @@ -73,10 +71,8 @@ class GitHubPushEventModel(BaseModel): class GitHubAppInstallationEventRepoModel(BaseModel): - """A Pydantic model for repository objects used by installation events. 
- - See `GitHubAppInstallationRepositoriesEventModel` for where this model is - used. + """A pydantic model for repository objects used by + `GitHubAppInstallationRepositoriesEventModel`. https://docs.github.com/en/webhooks/webhook-events-and-payloads#installation """ @@ -98,10 +94,8 @@ def owner_name(self) -> str: class GitHubAppInstallationEventAction(str, Enum): - """The action performed on an GitHub App ``installation`` webhook. - - See `GitHubAppInstallationEventModel` for the model of the event where - this model is used. + """The action performed on an GitHub App ``installation`` webhook + (`GitHubAppInstallationEventModel`). """ created = "created" @@ -142,10 +136,8 @@ class GitHubAppInstallationEventModel(BaseModel): class GitHubAppInstallationRepositoriesEventAction(str, Enum): - """A Pydantic model for a ``installation_repositories`` action. - - This model is for the action performed on a ``installation_repositories`` - GitHub App webhook (`GitHubAppInstallationRepositoriesEventModel`). + """The action performed on a GitHub App ``installation_repositories`` + webhook (`GitHubAppInstallationRepositoriesEventModel`). """ #: Someone added a repository to an installation. @@ -179,9 +171,8 @@ class GitHubAppInstallationRepositoriesEventModel(BaseModel): class GitHubPullRequestEventAction(str, Enum): - """The action performed on a GitHub ``pull_request`` webhook. - - See `GitHubPullRequestEventModel` for where this model is used. + """The action performed on a GitHub ``pull_request`` webhook + (`GitHubPullRequestEventModel`). """ assigned = "assigned" @@ -276,9 +267,8 @@ class GitHubPullRequestEventModel(BaseModel): class GitHubCheckSuiteEventAction(str, Enum): - """The action performed in a GitHub ``check_suite`` webhook. - - See `GitHubCheckSuiteEventModel` for where this model is used. + """The action performed in a GitHub ``check_suite`` webhook + (`GitHubCheckSuiteEventModel`). 
""" completed = "completed" @@ -317,9 +307,8 @@ class GitHubCheckSuiteEventModel(BaseModel): class GitHubCheckRunEventAction(str, Enum): - """The action performed in a GitHub ``check_run`` webhook. - - See `GitHubCheckRunEventModel` for where this model is used. + """The action performed in a GitHub ``check_run`` webhook + (`GitHubCheckRunEventModel`). """ completed = "completed" diff --git a/src/safir/redis.py b/src/safir/redis.py index 1234bff2..854568a5 100644 --- a/src/safir/redis.py +++ b/src/safir/redis.py @@ -31,9 +31,7 @@ class DeserializeError(SlackException): - """Error decoding or deserializing a Pydantic object from Redis. - - Raised when a stored Pydantic object in Redis cannot be decoded (and + """Raised when a stored Pydantic object in Redis cannot be decoded (and possibly decrypted) or deserialized. Parameters diff --git a/tests/redis_test.py b/tests/redis_test.py index 7bed2e2b..b973a32f 100644 --- a/tests/redis_test.py +++ b/tests/redis_test.py @@ -23,7 +23,9 @@ class DemoModel(BaseModel): async def basic_testing(storage: PydanticRedisStorage[DemoModel]) -> None: - """Test basic storage operations for encrypted and unencrypted storage.""" + """Test basic storage operations for either encrypted or unencrypted + storage. + """ await storage.store("mark42", DemoModel(name="Mark", value=42)) await storage.store("mark13", DemoModel(name="Mark", value=13)) await storage.store("jon7", DemoModel(name="Jon", value=7)) @@ -141,7 +143,9 @@ async def test_deserialization_error(redis_client: redis.Redis) -> None: async def test_deserialization_error_with_key_prefix( redis_client: redis.Redis, ) -> None: - """Test deserialization error formatting when a key prefix is used.""" + """Test that deserialization errors are presented correctly when a key + prefix is used. + """ storage = PydanticRedisStorage( datatype=DemoModel, redis=redis_client, key_prefix="test:" )