Skip to content

Commit

Permalink
Pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
jmsmkn committed Feb 6, 2024
1 parent 46d8084 commit 162f3d9
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 24 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@ repos:
language: python
args: [--py311-plus]
- repo: https://github.com/pycqa/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort
- repo: https://github.com/ambv/black
rev: 23.11.0
rev: 24.1.1
hooks:
- id: black
language: python
- repo: https://github.com/PyCQA/flake8
rev: 6.1.0
rev: 7.0.0
hooks:
- id: flake8
language: python
Expand All @@ -35,7 +35,7 @@ repos:
- mccabe
- yesqa
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.7.1'
rev: 'v1.8.0'
hooks:
- id: mypy
additional_dependencies:
Expand Down
11 changes: 8 additions & 3 deletions sagemaker_shim/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@

from sagemaker_shim.app import app
from sagemaker_shim.logging import LOGGING_CONFIG
from sagemaker_shim.models import InferenceTaskList, get_s3_file_content, DependentData
from sagemaker_shim.models import (
DependentData,
InferenceTaskList,
get_s3_file_content,
)

T = TypeVar("T")

Expand Down Expand Up @@ -53,7 +57,9 @@ def cli() -> None:
@cli.command(short_help="Start the model server")
def serve() -> None:
with DependentData():
uvicorn.run(app=app, host="0.0.0.0", port=8080, log_config=None, workers=1)
uvicorn.run(
app=app, host="0.0.0.0", port=8080, log_config=None, workers=1
)


@cli.command(short_help="Invoke the model")
Expand Down Expand Up @@ -105,7 +111,6 @@ async def invoke(tasks: str, file: str) -> None:
raise click.UsageError("Empty task list provided")



if __name__ == "__main__":
# https://pyinstaller.org/en/stable/runtime-information.html#run-time-information
we_are_bundled = getattr(sys, "frozen", False) and hasattr(sys, "_MEIPASS")
Expand Down
7 changes: 4 additions & 3 deletions sagemaker_shim/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ def safe_extract(*, src: Path, dest: Path) -> None:
f"Extracting {member['src']=} from {src} to {file_dest}"
)

with zf.open(member["src"], "r") as fs, open(
file_dest, "wb"
) as fd:
with (
zf.open(member["src"], "r") as fs,
open(file_dest, "wb") as fd,
):
while chunk := fs.read(8192):
fd.write(chunk)
43 changes: 30 additions & 13 deletions sagemaker_shim/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from functools import cached_property
from importlib.metadata import version
from pathlib import Path
from tempfile import TemporaryDirectory, SpooledTemporaryFile
from tempfile import SpooledTemporaryFile, TemporaryDirectory
from types import TracebackType
from typing import TYPE_CHECKING, Any, NamedTuple
from zipfile import BadZipFile

Expand Down Expand Up @@ -44,11 +45,13 @@ def get_s3_client() -> S3Client:
"s3", endpoint_url=os.environ.get("AWS_S3_ENDPOINT_URL")
)


class S3File(NamedTuple):
bucket: str
key: str

def parse_s3_uri(*, s3_uri) -> S3File:

def parse_s3_uri(*, s3_uri: str) -> S3File:
pattern = r"^(https|s3)://(?P<bucket>[^/]+)/?(?P<key>.*)$"
match = re.fullmatch(pattern, s3_uri)

Expand All @@ -57,6 +60,7 @@ def parse_s3_uri(*, s3_uri) -> S3File:

return S3File(bucket=match.group("bucket"), key=match.group("key"))


def get_s3_file_content(*, s3_uri: str) -> bytes:
s3_file = parse_s3_uri(s3_uri=s3_uri)

Expand All @@ -72,6 +76,7 @@ def get_s3_file_content(*, s3_uri: str) -> bytes:

return content.read()


def download_and_extract_tarball(*, s3_uri: str, dest: Path) -> None:
s3_file = parse_s3_uri(s3_uri=s3_uri)
s3_client = get_s3_client()
Expand Down Expand Up @@ -128,30 +133,44 @@ def ground_truth_dest(self) -> Path:

@property
def post_clean_directories(self) -> list[Path]:
return [Path(p) for p in os.environ.get("GRAND_CHALLENGE_COMPONENT_POST_CLEAN_DIRECTORIES", "").split(":") if p]

def __enter__(self):
return [
Path(p)
for p in os.environ.get(
"GRAND_CHALLENGE_COMPONENT_POST_CLEAN_DIRECTORIES", ""
).split(":")
if p
]

def __enter__(self) -> None:
logger.info("Entering DependentData")

self.download_model()
self.download_ground_truth()


def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
logger.info("Exiting DependentData")

for p in self.post_clean_directories:
clean_path(p)

def download_model(self):
def download_model(self) -> None:
if self.model_source is not None:
self.model_dest.mkdir(parents=True, exist_ok=True)
download_and_extract_tarball(s3_uri=self.model_source, dest=self.model_dest)
download_and_extract_tarball(
s3_uri=self.model_source, dest=self.model_dest
)

def download_ground_truth(self):
def download_ground_truth(self) -> None:
if self.ground_truth_source is not None:
self.ground_truth_dest.mkdir(parents=True, exist_ok=True)
download_and_extract_tarball(s3_uri=self.ground_truth_source, dest=self.ground_truth_dest)
download_and_extract_tarball(
s3_uri=self.ground_truth_source, dest=self.ground_truth_dest
)


class InferenceIO(BaseModel):
Expand Down Expand Up @@ -515,8 +534,6 @@ def clean_io(self) -> None:
clean_path(path=self.input_path)
clean_path(path=self.output_path)



def download_input(self) -> None:
"""Download all the inputs to the input path"""
for input_file in self.inputs:
Expand Down
1 change: 1 addition & 0 deletions sagemaker_shim/vendor/werkzeug/security.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""From Werkzeug 2.1"""

import os
import posixpath
from pathlib import Path
Expand Down
3 changes: 2 additions & 1 deletion tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
InferenceIO,
InferenceResult,
InferenceTask,
get_s3_client, clean_path,
clean_path,
get_s3_client,
)
from tests.utils import encode_b64j

Expand Down

0 comments on commit 162f3d9

Please sign in to comment.