♻️Mypy: webserver (#6193)
sanderegg authored Aug 19, 2024
1 parent 58114c8 commit 95e54ff
Showing 84 changed files with 537 additions and 317 deletions.
5 changes: 5 additions & 0 deletions packages/models-library/src/models_library/basic_types.py
@@ -11,6 +11,7 @@
 )
 
 from .basic_regex import (
+    PROPERTY_KEY_RE,
     SEMANTIC_VERSION_RE_W_CAPTURE_GROUPS,
     SIMPLE_VERSION_RE,
     UUID_RE,
@@ -155,3 +156,7 @@ class BuildTargetEnum(str, Enum):
     CACHE = "cache"
     PRODUCTION = "production"
     DEVELOPMENT = "development"
+
+
+class KeyIDStr(ConstrainedStr):
+    regex = re.compile(PROPERTY_KEY_RE)
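For context, a ConstrainedStr subclass like KeyIDStr validates at parse time and behaves as a plain str afterwards. A minimal usage sketch (pydantic v1 API), assuming PROPERTY_KEY_RE accepts simple keys such as "input_1":

# Minimal sketch; assumes PROPERTY_KEY_RE matches keys like "input_1".
from pydantic import ValidationError, parse_obj_as

from models_library.basic_types import KeyIDStr

key = parse_obj_as(KeyIDStr, "input_1")  # validated against PROPERTY_KEY_RE
assert isinstance(key, str)  # ConstrainedStr is still a str at runtime

try:
    parse_obj_as(KeyIDStr, "not a valid key!")  # assumed to violate the regex
except ValidationError as err:
    print(err)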
9 changes: 2 additions & 7 deletions packages/models-library/src/models_library/projects_nodes.py
@@ -2,7 +2,6 @@
 Models Node as a central element in a project's pipeline
 """
 
-import re
 from copy import deepcopy
 from typing import Any, ClassVar, TypeAlias, Union
 
@@ -18,8 +17,7 @@
     validator,
 )
 
-from .basic_regex import PROPERTY_KEY_RE
-from .basic_types import EnvVarKey, HttpUrlWithCustomMinLength
+from .basic_types import EnvVarKey, HttpUrlWithCustomMinLength, KeyIDStr
 from .projects_access import AccessEnum
 from .projects_nodes_io import (
     DatCoreFileLink,
@@ -57,10 +55,6 @@
 ]
 
 
-class KeyIDStr(ConstrainedStr):
-    regex = re.compile(PROPERTY_KEY_RE)
-
-
 InputID: TypeAlias = KeyIDStr
 OutputID: TypeAlias = KeyIDStr
 
@@ -238,6 +232,7 @@ def convert_from_enum(cls, v):
 
     class Config:
         extra = Extra.forbid
+
         # NOTE: exporting without this trick does not make runHash as nullable.
         # It is a Pydantic issue see https://github.com/samuelcolvin/pydantic/issues/1270
         @staticmethod
13 changes: 4 additions & 9 deletions packages/models-library/src/models_library/projects_nodes_io.py
@@ -8,9 +8,10 @@
 
 import re
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias
+from typing import Any, ClassVar, TypeAlias
 from uuid import UUID
 
+from models_library.basic_types import KeyIDStr
 from pydantic import (
     AnyUrl,
     BaseModel,
@@ -23,15 +24,11 @@
 
 from .basic_regex import (
     DATCORE_FILE_ID_RE,
-    PROPERTY_KEY_RE,
     SIMCORE_S3_DIRECTORY_ID_RE,
     SIMCORE_S3_FILE_ID_RE,
     UUID_RE,
 )
 
-if TYPE_CHECKING:
-    pass
-
 NodeID = UUID
 
 
@@ -107,10 +104,9 @@ class PortLink(BaseModel):
         description="The node to get the port output from",
         alias="nodeUuid",
     )
-    output: str = Field(
+    output: KeyIDStr = Field(
         ...,
         description="The port key in the node given by nodeUuid",
-        regex=PROPERTY_KEY_RE,
     )
 
     class Config:
@@ -183,8 +179,7 @@ class SimCoreFileLink(BaseFileLink):
 
     dataset: str | None = Field(
         default=None,
-        deprecated=True
-        # TODO: Remove with storage refactoring
+        deprecated=True,
    )
 
     @validator("store", always=True)
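The practical effect of typing output as KeyIDStr is that a malformed port key now fails model validation instead of relying on a field-level regex. A minimal sketch (the nodeUuid below is a made-up example value):

# Minimal sketch: PortLink rejects malformed "output" keys via KeyIDStr.
from models_library.projects_nodes_io import PortLink

link = PortLink.parse_obj(
    {
        "nodeUuid": "11111111-2222-3333-4444-555555555555",  # hypothetical value
        "output": "out_1",  # must satisfy PROPERTY_KEY_RE
    }
)
print(link.output)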
@@ -15,9 +15,9 @@
 from aiohttp import web
 from models_library.utils.json_serialization import json_dumps
 from pydantic import BaseModel, Extra, ValidationError, parse_obj_as
+from servicelib.aiohttp import status
 
 from ..mimetype_constants import MIMETYPE_APPLICATION_JSON
-from . import status
 
 ModelClass = TypeVar("ModelClass", bound=BaseModel)
 ModelOrListOrDictType = TypeVar("ModelOrListOrDictType", bound=BaseModel | list | dict)
130 changes: 66 additions & 64 deletions packages/service-library/src/servicelib/logging_utils.py
@@ -14,7 +14,7 @@
 from datetime import datetime
 from inspect import getframeinfo, stack
 from pathlib import Path
-from typing import Any, TypeAlias, TypedDict
+from typing import Any, Iterator, TypeAlias, TypedDict, TypeVar
 
 from .utils_secrets import mask_sensitive_data
 
@@ -56,11 +56,11 @@ class CustomFormatter(logging.Formatter):
     2. Overrides 'filename' with the value of 'file_name_override', if it exists.
     """
 
-    def __init__(self, fmt: str, log_format_local_dev_enabled: bool):
+    def __init__(self, fmt: str, *, log_format_local_dev_enabled: bool) -> None:
         super().__init__(fmt)
         self.log_format_local_dev_enabled = log_format_local_dev_enabled
 
-    def format(self, record):
+    def format(self, record) -> str:
         if hasattr(record, "func_name_override"):
             record.funcName = record.func_name_override
         if hasattr(record, "file_name_override"):
@@ -86,7 +86,7 @@ def format(self, record):
 # log_level=%{WORD:log_level} \| log_timestamp=%{TIMESTAMP_ISO8601:log_timestamp} \| log_source=%{DATA:log_source} \| log_msg=%{GREEDYDATA:log_msg}
 
 
-def config_all_loggers(log_format_local_dev_enabled: bool):
+def config_all_loggers(*, log_format_local_dev_enabled: bool) -> None:
     """
     Applies common configuration to ALL registered loggers
     """
@@ -102,19 +102,26 @@ def config_all_loggers(
         fmt = LOCAL_FORMATTING
 
     for logger in loggers:
-        set_logging_handler(logger, fmt, log_format_local_dev_enabled)
+        set_logging_handler(
+            logger, fmt=fmt, log_format_local_dev_enabled=log_format_local_dev_enabled
+        )
 
 
 def set_logging_handler(
     logger: logging.Logger,
+    *,
     fmt: str,
     log_format_local_dev_enabled: bool,
 ) -> None:
     for handler in logger.handlers:
-        handler.setFormatter(CustomFormatter(fmt, log_format_local_dev_enabled))
+        handler.setFormatter(
+            CustomFormatter(
+                fmt, log_format_local_dev_enabled=log_format_local_dev_enabled
+            )
+        )
 
 
-def test_logger_propagation(logger: logging.Logger):
+def test_logger_propagation(logger: logging.Logger) -> None:
     """log propagation and levels can sometimes be daunting to get it right.
     This function uses the `logger` passed as argument to log the same message at different levels
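The move to keyword-only flags means call sites must name the boolean, which keeps the mypy-checked signatures unambiguous. A minimal call-site sketch, assuming the module is importable as servicelib.logging_utils:

# Minimal sketch: the flag is now keyword-only.
from servicelib.logging_utils import config_all_loggers

config_all_loggers(log_format_local_dev_enabled=True)  # OK
# config_all_loggers(True)  # TypeError: positional use is rejected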
@@ -161,75 +168,69 @@ def _log_arguments(
     # Before to the function execution, log function details.
     logger_obj.log(
         level,
-        "Arguments: %s - Begin function",
+        "%s:%s(%s) - Begin function",
+        func.__module__.split(".")[-1],
+        func.__name__,
         formatted_arguments,
         extra=extra_args,
     )
 
     return extra_args
 
 
+def _log_return_value(
+    logger_obj: logging.Logger,
+    level: int,
+    func: Callable,
+    result: Any,
+    extra_args: dict[str, str],
+) -> None:
+    logger_obj.log(
+        level,
+        "%s:%s returned %r - End function",
+        func.__module__.split(".")[-1],
+        func.__name__,
+        result,
+        extra=extra_args,
+    )
+
+
+F = TypeVar("F", bound=Callable[..., Any])
+
+
 def log_decorator(
-    logger=None, level: int = logging.DEBUG, *, log_traceback: bool = False
-):
-    # Build logger object
-    logger_obj = logger or _logger
+    logger: logging.Logger | None, level: int = logging.DEBUG
+) -> Callable[[F], F]:
+    the_logger = logger or _logger
 
-    def log_decorator_info(func):
+    def decorator(func: F) -> F:
         if iscoroutinefunction(func):
 
             @functools.wraps(func)
-            async def log_decorator_wrapper(*args, **kwargs):
-                extra_args = _log_arguments(logger_obj, level, func, *args, **kwargs)
-                try:
-                    # log return value from the function
-                    value = await func(*args, **kwargs)
-                    logger_obj.log(
-                        level, "Returned: - End function %r", value, extra=extra_args
-                    )
-                except:
-                    # log exception if occurs in function
-                    logger_obj.log(
-                        level,
-                        "Exception: %s",
-                        extra=extra_args,
-                        exc_info=log_traceback,
-                    )
-                    raise
-                # Return function value
-                return value
-
-        else:
-
-            @functools.wraps(func)
-            def log_decorator_wrapper(*args, **kwargs):
-                extra_args = _log_arguments(logger_obj, level, func, *args, **kwargs)
-                try:
-                    # log return value from the function
-                    value = func(*args, **kwargs)
-                    logger_obj.log(
-                        level, "Returned: - End function %r", value, extra=extra_args
-                    )
-                except:
-                    # log exception if occurs in function
-                    logger_obj.log(
-                        level,
-                        "Exception: %s",
-                        extra=extra_args,
-                        exc_info=log_traceback,
-                    )
-                    raise
-                # Return function value
-                return value
-
-            # Return the pointer to the function
-            return log_decorator_wrapper
-
-        return log_decorator_info
+            async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
+                extra_args = _log_arguments(the_logger, level, func, *args, **kwargs)
+                with log_catch(the_logger, reraise=True):
+                    result = await func(*args, **kwargs)
+                    _log_return_value(the_logger, level, func, result, extra_args)
+                    return result
+
+            return async_wrapper  # type: ignore[return-value] # decorators typing is hard
+
+        @functools.wraps(func)
+        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
+            extra_args = _log_arguments(the_logger, level, func, *args, **kwargs)
+            with log_catch(the_logger, reraise=True):
+                result = func(*args, **kwargs)
+                _log_return_value(the_logger, level, func, result, extra_args)
+                return result
+
+        return sync_wrapper  # type: ignore[return-value] # decorators typing is hard
+
+    return decorator

 @contextmanager
-def log_catch(logger: logging.Logger, reraise: bool = True):
+def log_catch(logger: logging.Logger, *, reraise: bool = True) -> Iterator[None]:
     try:
         yield
     except asyncio.CancelledError:
@@ -257,7 +258,7 @@ def get_log_record_extra(*, user_id: int | str | None = None) -> LogExtra | None
     return extra or None
 
 
-def _un_capitalize(s):
+def _un_capitalize(s: str) -> str:
     return s[:1].lower() + s[1:] if s else ""
 
 
@@ -277,15 +278,16 @@ def log_context(
     kwargs: dict[str, Any] = {}
     if extra:
         kwargs["extra"] = extra
-
-    logger.log(level, "Starting " + msg + " ...", *args, **kwargs)
+    log_msg = f"Starting {msg} ..."
+    logger.log(level, log_msg, *args, **kwargs)
     yield
     duration = (
         f" in {(datetime.now() - start ).total_seconds()}s"  # noqa: DTZ005
         if log_duration
         else ""
     )
-    logger.log(level, "Finished " + msg + duration, *args, **kwargs)
+    log_msg = f"Finished {msg}{duration}"
+    logger.log(level, log_msg, *args, **kwargs)
 
 
 def guess_message_log_level(message: str) -> LogLevelInt:
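For completeness, a usage sketch of the two context managers touched above. log_catch re-raises by default (reraise is now keyword-only); log_context's full signature is not visible in this diff, so the call below assumes it accepts (logger, level, msg):

# Minimal usage sketch for log_catch and log_context (signatures partly assumed).
import logging

from servicelib.logging_utils import log_catch, log_context

logging.basicConfig(level=logging.INFO)
_logger = logging.getLogger(__name__)

with log_context(_logger, logging.INFO, "data import"):
    ...  # logs "Starting data import ..." then "Finished data import"

with log_catch(_logger, reraise=False):  # logs the error instead of raising
    raise ValueError("boom")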