Skip to content

Commit

Permalink
feat: add strict typing
Browse files Browse the repository at this point in the history
  • Loading branch information
pkucmus committed Dec 20, 2024
1 parent 3a8ba24 commit 00f76a3
Show file tree
Hide file tree
Showing 16 changed files with 274 additions and 150 deletions.
29 changes: 20 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,6 @@ check = [
[tool.hatch.envs.types.scripts]
check = "mypy --install-types --non-interactive {args:src/smyth}"

[tool.mypy]
check_untyped_defs = true

[[tool.mypy.overrides]]
module = "setproctitle.*"
ignore_missing_imports = true


## Test environment

[tool.hatch.envs.hatch-test]
Expand Down Expand Up @@ -118,6 +110,25 @@ deploy = "mkdocs gh-deploy --force"

[tool.pytest.ini_options]

## Types configuration

[tool.mypy]
python_version = "3.10"
files = ["src/**/*.py"]
exclude = "tests/.*"
warn_redundant_casts = true
warn_unused_ignores = true
disallow_any_generics = true
check_untyped_defs = true
no_implicit_reexport = true
disallow_untyped_defs = true
strict = true
disable_error_code = ["import-untyped"]

[[tool.mypy.overrides]]
module = "setproctitle.*"
ignore_missing_imports = true

## Coverage configuration

[tool.coverage.run]
Expand Down Expand Up @@ -162,7 +173,7 @@ unfixable = ["UP007"] # typer does not handle PEP604 annotations
ban-relative-imports = "all"

[tool.ruff.lint.mccabe]
max-complexity = 10
max-complexity = 12

[tool.ruff.lint.isort]
known-first-party = ["smyth"]
Expand Down
2 changes: 1 addition & 1 deletion src/smyth/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def run(
quiet: Annotated[
bool, typer.Option(help="Effectively the same as --log-level=ERROR")
] = False,
):
) -> None:
if host:
config.host = host
if port:
Expand Down
7 changes: 4 additions & 3 deletions src/smyth/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any

import toml

Expand Down Expand Up @@ -29,7 +30,7 @@ class Config:
smyth_path_prefix: str = "/smyth"

@classmethod
def from_dict(cls, config_dict: dict):
def from_dict(cls, config_dict: dict[str, Any]) -> "Config":
handler_data = config_dict.pop("handlers")
handlers = {
handler_name: HandlerConfig(**handler_config)
Expand All @@ -48,7 +49,7 @@ def get_config_file_path(file_name: str = "pyproject.toml") -> Path:
return directory.joinpath(file_name).resolve()


def get_config_dict(config_file_name: str | None = None) -> dict:
def get_config_dict(config_file_name: str | None = None) -> dict[str, Any]:
"""Get config dict."""
if config_file_name:
config_file_path = get_config_file_path(config_file_name)
Expand All @@ -58,7 +59,7 @@ def get_config_dict(config_file_name: str | None = None) -> dict:
return toml.load(config_file_path)


def get_config(config_dict: dict) -> Config:
def get_config(config_dict: dict[str, Any]) -> Config:
"""Get config."""
if environ_config := os.environ.get("__SMYTH_CONFIG"):
config_data = json.loads(environ_config)
Expand Down
2 changes: 1 addition & 1 deletion src/smyth/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

async def generate_context_data(
request: Request | None, smyth_handler: SmythHandler, process: RunnerProcessProtocol
):
) -> dict[str, Any]:
"""
The data returned by this function is passed to the
`smyth.runner.FaneContext` as kwargs.
Expand Down
8 changes: 6 additions & 2 deletions src/smyth/event.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from typing import Any

from starlette.requests import Request

from smyth.types import EventData


async def generate_api_gw_v2_event_data(request: Request):
async def generate_api_gw_v2_event_data(request: Request) -> EventData:
source_ip = None
if request.client:
source_ip = request.client.host
Expand All @@ -28,5 +32,5 @@ async def generate_api_gw_v2_event_data(request: Request):
}


async def generate_lambda_invocation_event_data(request: Request):
async def generate_lambda_invocation_event_data(request: Request) -> Any:
return await request.json()
21 changes: 12 additions & 9 deletions src/smyth/runner/fake_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import sys
from collections.abc import Callable
from time import strftime, time
from typing import Any

from aws_lambda_powertools.utilities.typing import LambdaContext

Expand All @@ -10,7 +12,7 @@ def __init__(
name: str | None = None,
version: str | None = "LATEST",
timeout: int | None = None,
**kwargs,
**kwargs: Any,
):
if name is None:
name = "Fake"
Expand Down Expand Up @@ -39,31 +41,32 @@ def get_remaining_time_in_millis(self) -> int: # type: ignore[override]
)

@property
def function_name(self):
def function_name(self) -> str:
return self.name

@property
def function_version(self):
def function_version(self) -> str:
return self.version

@property
def invoked_function_arn(self):
def invoked_function_arn(self) -> str:
return "arn:aws:lambda:serverless:" + self.name

@property
def memory_limit_in_mb(self):
# This indeed is a string in the real context hence the ignore[override]
def memory_limit_in_mb(self) -> str: # type: ignore[override]
return "1024"

@property
def aws_request_id(self):
def aws_request_id(self) -> str:
return "1234567890"

@property
def log_group_name(self):
def log_group_name(self) -> str:
return "/aws/lambda/" + self.name

@property
def log_stream_name(self):
def log_stream_name(self) -> str:
return (
strftime("%Y/%m/%d")
+ "/[$"
Expand All @@ -72,5 +75,5 @@ def log_stream_name(self):
)

@property
def log(self):
def log(self) -> Callable[[str], int] | Any:
return sys.stdout.write
106 changes: 64 additions & 42 deletions src/smyth/runner/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
import signal
import sys
import traceback
from collections.abc import Generator
from multiprocessing import Process, Queue, set_start_method
from queue import Empty
from time import time
from types import FrameType

from asgiref.sync import sync_to_async
from setproctitle import setproctitle
Expand All @@ -18,7 +20,18 @@
SubprocessError,
)
from smyth.runner.fake_context import FakeLambdaContext
from smyth.types import LambdaHandler, RunnerMessage, SmythHandlerState
from smyth.types import (
EventData,
LambdaErrorResponse,
LambdaHandler,
LambdaResponse,
RunnerErrorMessage,
RunnerInputMessage,
RunnerOutputMessage,
RunnerResponseMessage,
RunnerStatusMessage,
SmythHandlerState,
)
from smyth.utils import get_logging_config, import_attribute

set_start_method("spawn", force=True)
Expand All @@ -37,24 +50,24 @@ def __init__(self, name: str, lambda_handler_path: str, log_level: str = "INFO")
self.last_used_timestamp = 0
self.state = SmythHandlerState.COLD

self.input_queue: Queue[RunnerMessage] = Queue(maxsize=1)
self.output_queue: Queue[RunnerMessage] = Queue(maxsize=1)
self.input_queue: Queue[RunnerInputMessage] = Queue(maxsize=1)
self.output_queue: Queue[RunnerOutputMessage] = Queue(maxsize=1)

self.lambda_handler_path = lambda_handler_path
self.log_level = log_level
super().__init__(
name=name,
)

def stop(self):
self.input_queue.put({"type": "smyth.stop"})
def stop(self) -> None:
self.input_queue.put(RunnerInputMessage(type="smyth.stop"))
self.join()
self.input_queue.close()
self.output_queue.close()
self.input_queue.join_thread()
self.output_queue.join_thread()

def send(self, data) -> RunnerMessage | None:
def send(self, data: RunnerInputMessage) -> LambdaResponse | None:
LOGGER.debug("Sending data to process %s: %s", self.name, data)
self.task_counter += 1
self.last_used_timestamp = time()
Expand All @@ -78,30 +91,31 @@ def send(self, data) -> RunnerMessage | None:
return None

LOGGER.debug("Received message from process %s: %s", self.name, message)
if message["type"] == "smyth.lambda.status":
self.state = SmythHandlerState(message["status"])
elif message["type"] == "smyth.lambda.response":

if message.type == "smyth.lambda.status":
self.state = message.status
elif message.type == "smyth.lambda.response":
self.state = SmythHandlerState.WARM
return message["response"]
elif message["type"] == "smyth.lambda.error":
return message.response
elif message.type == "smyth.lambda.error":
self.state = SmythHandlerState.WARM
if message["response"]["type"] == "LambdaTimeoutError":
raise LambdaTimeoutError(message["response"]["message"])
if message.error.type == "LambdaTimeoutError":
raise LambdaTimeoutError(message.error.message)
else:
raise LambdaInvocationError(message["response"]["message"])
raise LambdaInvocationError(message.error.message)

@sync_to_async(thread_sensitive=False)
def asend(self, data) -> RunnerMessage | None:
def asend(self, data: RunnerInputMessage) -> LambdaResponse | None:
return self.send(data)

# Backend

def run(self):
def run(self) -> None:
setproctitle(f"smyth:{self.name}")
logging.config.dictConfig(get_logging_config(self.log_level))
self.lambda_invoker__()

def get_message__(self):
def get_message__(self) -> Generator[RunnerInputMessage, None, None]:
while True:
try:
message = self.input_queue.get(block=True, timeout=1)
Expand All @@ -112,21 +126,27 @@ def get_message__(self):
continue
else:
LOGGER.debug("Received message: %s", message)
if message["type"] == "smyth.stop":
if message.type == "smyth.stop":
LOGGER.debug("Stopping process")
return
yield message

def get_event__(self, message):
return message["event"]
def get_event__(self, message: RunnerInputMessage) -> EventData:
if message.event is None:
raise LambdaInvocationError("No event data provided")
return message.event

def get_context__(self, message):
return FakeLambdaContext(**message["context"])
def get_context__(self, message: RunnerInputMessage) -> FakeLambdaContext:
if message.context is None:
raise LambdaInvocationError("No context data provided")
return FakeLambdaContext(**message.context)

def import_handler__(self, lambda_handler_path, event, context):
def import_handler__(
self, lambda_handler_path: str, event: EventData, context: FakeLambdaContext
) -> LambdaHandler:
LOGGER.info("Starting cold, importing '%s'", lambda_handler_path)
try:
handler = import_attribute(lambda_handler_path)
handler: LambdaHandler = import_attribute(lambda_handler_path)
except ImportError as error:
raise LambdaHandlerLoadError(
f"Error importing handler: {error}, module not found"
Expand All @@ -146,21 +166,23 @@ def import_handler__(self, lambda_handler_path, event, context):
)
return handler

def set_status__(self, status: SmythHandlerState):
self.output_queue.put({"type": "smyth.lambda.status", "status": status})
def set_status__(self, status: SmythHandlerState) -> None:
self.output_queue.put(
RunnerStatusMessage(type="smyth.lambda.status", status=status)
)

@staticmethod
def timeout_handler__(signum, frame):
def timeout_handler__(signum: int, frame: FrameType | None) -> None:
raise LambdaTimeoutError("Lambda timeout")

def lambda_invoker__(self):
def lambda_invoker__(self) -> None:
sys.stdin = open("/dev/stdin")
lambda_handler: LambdaHandler | None = None
self.set_status__(SmythHandlerState.COLD)

for message in self.get_message__():
if message.get("type") != "smyth.lambda.invoke":
LOGGER.error("Invalid message type: %s", message.get("type"))
if message.type != "smyth.lambda.invoke":
LOGGER.error("Invalid message type: %s", message.type)
continue

event = self.get_event__(message)
Expand All @@ -186,21 +208,21 @@ def lambda_invoker__(self):
extra={"log_setting": "console_full_width"},
)
self.output_queue.put(
{
"type": "smyth.lambda.error",
"response": {
"type": type(error).__name__,
"message": str(error),
"stacktrace": traceback.format_exc(),
},
}
RunnerErrorMessage(
type="smyth.lambda.error",
error=LambdaErrorResponse(
type=type(error).__name__,
message=str(error),
stacktrace=traceback.format_exc(),
),
)
)
else:
self.output_queue.put(
{
"type": "smyth.lambda.response",
"response": response,
}
RunnerResponseMessage(
type="smyth.lambda.response",
response=response,
)
)
finally:
signal.alarm(0)
Loading

0 comments on commit 00f76a3

Please sign in to comment.