From 162f3d98e655183332e0e242fe0b7f3762ce86ab Mon Sep 17 00:00:00 2001 From: James Meakin <12661555+jmsmkn@users.noreply.github.com> Date: Tue, 6 Feb 2024 14:24:09 +0100 Subject: [PATCH] Pre-commit --- .pre-commit-config.yaml | 8 ++-- sagemaker_shim/cli.py | 11 ++++-- sagemaker_shim/extract.py | 7 ++-- sagemaker_shim/models.py | 43 +++++++++++++++------- sagemaker_shim/vendor/werkzeug/security.py | 1 + tests/test_io.py | 3 +- 6 files changed, 49 insertions(+), 24 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7e1d7b0..6404691 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,16 +14,16 @@ repos: language: python args: [--py311-plus] - repo: https://github.com/pycqa/isort - rev: 5.12.0 + rev: 5.13.2 hooks: - id: isort - repo: https://github.com/ambv/black - rev: 23.11.0 + rev: 24.1.1 hooks: - id: black language: python - repo: https://github.com/PyCQA/flake8 - rev: 6.1.0 + rev: 7.0.0 hooks: - id: flake8 language: python @@ -35,7 +35,7 @@ repos: - mccabe - yesqa - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v1.7.1' + rev: 'v1.8.0' hooks: - id: mypy additional_dependencies: diff --git a/sagemaker_shim/cli.py b/sagemaker_shim/cli.py index 995404a..652f680 100644 --- a/sagemaker_shim/cli.py +++ b/sagemaker_shim/cli.py @@ -15,7 +15,11 @@ from sagemaker_shim.app import app from sagemaker_shim.logging import LOGGING_CONFIG -from sagemaker_shim.models import InferenceTaskList, get_s3_file_content, DependentData +from sagemaker_shim.models import ( + DependentData, + InferenceTaskList, + get_s3_file_content, +) T = TypeVar("T") @@ -53,7 +57,9 @@ def cli() -> None: @cli.command(short_help="Start the model server") def serve() -> None: with DependentData(): - uvicorn.run(app=app, host="0.0.0.0", port=8080, log_config=None, workers=1) + uvicorn.run( + app=app, host="0.0.0.0", port=8080, log_config=None, workers=1 + ) @cli.command(short_help="Invoke the model") @@ -105,7 +111,6 @@ async def invoke(tasks: str, 
file: str) -> None: raise click.UsageError("Empty task list provided") - if __name__ == "__main__": # https://pyinstaller.org/en/stable/runtime-information.html#run-time-information we_are_bundled = getattr(sys, "frozen", False) and hasattr(sys, "_MEIPASS") diff --git a/sagemaker_shim/extract.py b/sagemaker_shim/extract.py index dc455c8..e647d43 100644 --- a/sagemaker_shim/extract.py +++ b/sagemaker_shim/extract.py @@ -68,8 +68,9 @@ def safe_extract(*, src: Path, dest: Path) -> None: f"Extracting {member['src']=} from {src} to {file_dest}" ) - with zf.open(member["src"], "r") as fs, open( - file_dest, "wb" - ) as fd: + with ( + zf.open(member["src"], "r") as fs, + open(file_dest, "wb") as fd, + ): while chunk := fs.read(8192): fd.write(chunk) diff --git a/sagemaker_shim/models.py b/sagemaker_shim/models.py index 30d5ed7..bd1e631 100644 --- a/sagemaker_shim/models.py +++ b/sagemaker_shim/models.py @@ -13,7 +13,8 @@ from functools import cached_property from importlib.metadata import version from pathlib import Path -from tempfile import TemporaryDirectory, SpooledTemporaryFile +from tempfile import SpooledTemporaryFile, TemporaryDirectory +from types import TracebackType from typing import TYPE_CHECKING, Any, NamedTuple from zipfile import BadZipFile @@ -44,11 +45,13 @@ def get_s3_client() -> S3Client: "s3", endpoint_url=os.environ.get("AWS_S3_ENDPOINT_URL") ) + class S3File(NamedTuple): bucket: str key: str -def parse_s3_uri(*, s3_uri) -> S3File: + +def parse_s3_uri(*, s3_uri: str) -> S3File: pattern = r"^(https|s3)://(?P<bucket>[^/]+)/?(?P<key>.*)$" match = re.fullmatch(pattern, s3_uri) @@ -57,6 +60,7 @@ def parse_s3_uri(*, s3_uri) -> S3File: return S3File(bucket=match.group("bucket"), key=match.group("key")) + def get_s3_file_content(*, s3_uri: str) -> bytes: s3_file = parse_s3_uri(s3_uri=s3_uri) @@ -72,6 +76,7 @@ def get_s3_file_content(*, s3_uri: str) -> bytes: return content.read() + def download_and_extract_tarball(*, s3_uri: str, dest: Path) -> None: s3_file = 
parse_s3_uri(s3_uri=s3_uri) s3_client = get_s3_client() @@ -128,30 +133,44 @@ def ground_truth_dest(self) -> Path: @property def post_clean_directories(self) -> list[Path]: - return [Path(p) for p in os.environ.get("GRAND_CHALLENGE_COMPONENT_POST_CLEAN_DIRECTORIES", "").split(":") if p] - - def __enter__(self): + return [ + Path(p) + for p in os.environ.get( + "GRAND_CHALLENGE_COMPONENT_POST_CLEAN_DIRECTORIES", "" + ).split(":") + if p + ] + + def __enter__(self) -> None: logger.info("Entering DependentData") self.download_model() self.download_ground_truth() - - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: logger.info("Exiting DependentData") for p in self.post_clean_directories: clean_path(p) - def download_model(self): + def download_model(self) -> None: if self.model_source is not None: self.model_dest.mkdir(parents=True, exist_ok=True) - download_and_extract_tarball(s3_uri=self.model_source, dest=self.model_dest) + download_and_extract_tarball( + s3_uri=self.model_source, dest=self.model_dest + ) - def download_ground_truth(self): + def download_ground_truth(self) -> None: if self.ground_truth_source is not None: self.ground_truth_dest.mkdir(parents=True, exist_ok=True) - download_and_extract_tarball(s3_uri=self.ground_truth_source, dest=self.ground_truth_dest) + download_and_extract_tarball( + s3_uri=self.ground_truth_source, dest=self.ground_truth_dest + ) class InferenceIO(BaseModel): @@ -515,8 +534,6 @@ def clean_io(self) -> None: clean_path(path=self.input_path) clean_path(path=self.output_path) - - def download_input(self) -> None: """Download all the inputs to the input path""" for input_file in self.inputs: diff --git a/sagemaker_shim/vendor/werkzeug/security.py b/sagemaker_shim/vendor/werkzeug/security.py index 77729f5..0695cb4 100644 --- a/sagemaker_shim/vendor/werkzeug/security.py +++ 
b/sagemaker_shim/vendor/werkzeug/security.py @@ -1,4 +1,5 @@ """From Werkzeug 2.1""" + import os import posixpath from pathlib import Path diff --git a/tests/test_io.py b/tests/test_io.py index 36fe5a3..4a6871b 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -13,7 +13,8 @@ InferenceIO, InferenceResult, InferenceTask, - get_s3_client, clean_path, + clean_path, + get_s3_client, ) from tests.utils import encode_b64j