Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jmsmkn committed Feb 6, 2024
1 parent 6e1a586 commit 629c632
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 9 deletions.
45 changes: 36 additions & 9 deletions sagemaker_shim/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ def download_and_extract_tarball(*, s3_uri: str, dest: Path) -> None:
Fileobj=f,
)

f.seek(0)

with tarfile.open(fileobj=f, mode="r:gz") as tar:
tar.extractall(path=dest, filter="data")

Expand Down Expand Up @@ -116,34 +118,52 @@ class DependentData(BaseModel):
@property
def model_source(self) -> str | None:
"""s3 URI to a .tar.gz file that is extracted to model_dest"""
return os.environ.get("GRAND_CHALLENGE_COMPONENT_MODEL")
model = os.environ.get("GRAND_CHALLENGE_COMPONENT_MODEL")
logger.debug(f"{model=}")
return model

@property
def model_dest(self) -> Path:
return Path("/opt/ml/model/")
model_dest = Path(
os.environ.get(
"GRAND_CHALLENGE_COMPONENT_MODEL_DEST", "/opt/ml/model/"
)
)
logger.debug(f"{model_dest=}")
return model_dest

@property
def ground_truth_source(self) -> str | None:
"""s3 URI to a .tar.gz file that is extracted to ground_truth_dest"""
return os.environ.get("GRAND_CHALLENGE_COMPONENT_GROUND_TRUTH")
ground_truth = os.environ.get("GRAND_CHALLENGE_COMPONENT_GROUND_TRUTH")
logger.debug(f"{ground_truth=}")
return ground_truth

@property
def ground_truth_dest(self) -> Path:
return Path("/opt/ml/input/data/ground_truth/")
ground_truth_dest = Path(
os.environ.get(
"GRAND_CHALLENGE_COMPONENT_GROUND_TRUTH_DEST",
"/opt/ml/input/data/ground_truth/",
)
)
logger.debug(f"{ground_truth_dest=}")
return ground_truth_dest

@property
def post_clean_directories(self) -> list[Path]:
return [
post_clean_directories = [
Path(p)
for p in os.environ.get(
"GRAND_CHALLENGE_COMPONENT_POST_CLEAN_DIRECTORIES", ""
).split(":")
if p
]
logger.debug(f"{post_clean_directories=}")
return post_clean_directories

def __enter__(self) -> None:
logger.info("Entering DependentData")

logger.info("Setting up Dependent Data")
self.download_model()
self.download_ground_truth()

Expand All @@ -153,20 +173,27 @@ def __exit__(
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
logger.info("Exiting DependentData")

logger.info("Cleaning up Dependent Data")
for p in self.post_clean_directories:
logger.info(f"Cleaning {p=}")
clean_path(p)

def download_model(self) -> None:
if self.model_source is not None:
logger.info(
f"Downloading model from {self.model_source=} to {self.model_dest=}"
)
self.model_dest.mkdir(parents=True, exist_ok=True)
download_and_extract_tarball(
s3_uri=self.model_source, dest=self.model_dest
)

def download_ground_truth(self) -> None:
if self.ground_truth_source is not None:
logger.info(
f"Downloading ground truth from {self.ground_truth_source=} "
f"to {self.ground_truth_dest=}"
)
self.ground_truth_dest.mkdir(parents=True, exist_ok=True)
download_and_extract_tarball(
s3_uri=self.ground_truth_source, dest=self.ground_truth_dest
Expand Down
2 changes: 2 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ def test_logging_setup(minio, monkeypatch):
'{"log": "hello", "level": "INFO", '
f'"source": "stdout", "internal": false, "task": "{pk}"}}'
) in result.output
assert "Setting up Dependent Data" in result.output
assert "Cleaning up Dependent Data" in result.output


def test_logging_stderr_setup(minio, monkeypatch):
Expand Down
92 changes: 92 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
import getpass
import grp
import io
import os
import pwd
import tarfile
from uuid import uuid4

import pytest

from sagemaker_shim.models import (
DependentData,
InferenceTask,
_get_users_groups,
_put_gid_first,
get_s3_client,
validate_bucket_name,
)

Expand Down Expand Up @@ -218,3 +223,90 @@ def test_home_is_set(monkeypatch):
)

assert t.proc_env["HOME"] == pwd.getpwnam("root").pw_dir


def test_model_and_ground_truth_extraction(minio, monkeypatch, tmp_path):
s3_client = get_s3_client()

model_pk = str(uuid4())

model_f = io.BytesIO()
with tarfile.open(fileobj=model_f, mode="w:gz") as tar:
content = b"Hello, World!"
file_info = tarfile.TarInfo("model-file1.txt")
file_info.size = len(content)
tar.addfile(file_info, io.BytesIO(content))

file_info = tarfile.TarInfo("model-sub/model-file2.txt")
file_info.size = len(content)
tar.addfile(file_info, io.BytesIO(content))

model_f.seek(0)

s3_client.upload_fileobj(
model_f, minio.input_bucket_name, f"{model_pk}/model.tar.gz"
)

ground_truth_pk = str(uuid4())

ground_truth_f = io.BytesIO()
with tarfile.open(fileobj=ground_truth_f, mode="w:gz") as tar:
content = b"Hello, World!"
file_info = tarfile.TarInfo("gt-file1.txt")
file_info.size = len(content)
tar.addfile(file_info, io.BytesIO(content))

file_info = tarfile.TarInfo("gt-sub/gt-file2.txt")
file_info.size = len(content)
tar.addfile(file_info, io.BytesIO(content))

ground_truth_f.seek(0)

s3_client.upload_fileobj(
ground_truth_f,
minio.input_bucket_name,
f"{ground_truth_pk}/ground_truth.tar.gz",
)

model_destination = tmp_path / "model"
ground_truth_destination = tmp_path / "ground_truth"

monkeypatch.setenv(
"GRAND_CHALLENGE_COMPONENT_MODEL",
f"s3://{minio.input_bucket_name}/{model_pk}/model.tar.gz",
)
monkeypatch.setenv(
"GRAND_CHALLENGE_COMPONENT_MODEL_DEST", str(model_destination)
)
monkeypatch.setenv(
"GRAND_CHALLENGE_COMPONENT_GROUND_TRUTH",
f"s3://{minio.input_bucket_name}/{ground_truth_pk}/ground_truth.tar.gz",
)
monkeypatch.setenv(
"GRAND_CHALLENGE_COMPONENT_GROUND_TRUTH_DEST",
str(ground_truth_destination),
)
monkeypatch.setenv(
"GRAND_CHALLENGE_COMPONENT_POST_CLEAN_DIRECTORIES",
f"{model_destination}:{ground_truth_destination}",
)

with DependentData():
downloaded_files = {
str(f.relative_to(tmp_path))
for f in tmp_path.rglob("**/*")
if f.is_file()
}

assert downloaded_files == {
"model/model-file1.txt",
"model/model-sub/model-file2.txt",
"ground_truth/gt-file1.txt",
"ground_truth/gt-sub/gt-file2.txt",
}

# Files should be cleaned up
assert {str(f.relative_to(tmp_path)) for f in tmp_path.rglob("**/*")} == {
"model",
"ground_truth",
}

0 comments on commit 629c632

Please sign in to comment.