Skip to content

Commit

Permalink
Update return type
Browse files Browse the repository at this point in the history
  • Loading branch information
jmsmkn committed Dec 16, 2023
1 parent 6abf1c6 commit 12b15e5
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 11 deletions.
23 changes: 14 additions & 9 deletions sagemaker_shim/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pathlib import Path
from shutil import rmtree
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, NamedTuple
from zipfile import BadZipFile

import boto3
Expand Down Expand Up @@ -153,6 +153,11 @@ class InferenceResult(BaseModel):
sagemaker_shim_version: str = version("sagemaker-shim")


class UserGroup(NamedTuple):
user: str | None
group: str | None


class InferenceTask(BaseModel):
model_config = ConfigDict(frozen=True)

Expand Down Expand Up @@ -279,18 +284,18 @@ def proc_env(self) -> dict[str, str]:
return env

@property
def proc_user(self) -> dict[str, str | None]:
def proc_user(self) -> UserGroup:
match = re.fullmatch(
r"^(?P<user>[0-9a-zA-Z]*):?(?P<group>[0-9a-zA-Z]*)$", self.user
)

if match:
return {
"user": match.group("user") or None,
"group": match.group("group") or None,
}
return UserGroup(
user=match.group("user") or None,
group=match.group("group") or None,
)
else:
return {"user": None, "group": None}
return UserGroup(user=None, group=None)

async def invoke(self) -> InferenceResult:
"""Run the inference on a single case"""
Expand Down Expand Up @@ -418,8 +423,8 @@ async def execute(self) -> int:

process = await asyncio.create_subprocess_exec(
*self.proc_args,
user=self.proc_user["user"],
group=self.proc_user["group"],
user=self.proc_user.user,
group=self.proc_user.group,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=self.proc_env,
Expand Down
6 changes: 4 additions & 2 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ def test_proc_user(monkeypatch, user, expected_user, expected_group):
)

assert t.user == user
assert t.proc_user == {"user": expected_user, "group": expected_group}
assert t.proc_user.user == expected_user
assert t.proc_user.group == expected_group


def test_proc_user_unset():
Expand All @@ -90,4 +91,5 @@ def test_proc_user_unset():
)

assert t.user == ""
assert t.proc_user == {"user": None, "group": None}
assert t.proc_user.user is None
assert t.proc_user.group is None

0 comments on commit 12b15e5

Please sign in to comment.