From 12b15e50c0b75c334d76024407ad720d5233f928 Mon Sep 17 00:00:00 2001 From: James Meakin <12661555+jmsmkn@users.noreply.github.com> Date: Sat, 16 Dec 2023 11:28:55 +0100 Subject: [PATCH] Update return type --- sagemaker_shim/models.py | 23 ++++++++++++++--------- tests/test_models.py | 6 ++++-- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/sagemaker_shim/models.py b/sagemaker_shim/models.py index 77373ce..7133c24 100644 --- a/sagemaker_shim/models.py +++ b/sagemaker_shim/models.py @@ -12,7 +12,7 @@ from pathlib import Path from shutil import rmtree from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, NamedTuple from zipfile import BadZipFile import boto3 @@ -153,6 +153,11 @@ class InferenceResult(BaseModel): sagemaker_shim_version: str = version("sagemaker-shim") +class UserGroup(NamedTuple): + user: str | None + group: str | None + + class InferenceTask(BaseModel): model_config = ConfigDict(frozen=True) @@ -279,18 +284,18 @@ def proc_env(self) -> dict[str, str]: return env @property - def proc_user(self) -> dict[str, str | None]: + def proc_user(self) -> UserGroup: match = re.fullmatch( r"^(?P[0-9a-zA-Z]*):?(?P[0-9a-zA-Z]*)$", self.user ) if match: - return { - "user": match.group("user") or None, - "group": match.group("group") or None, - } + return UserGroup( + user=match.group("user") or None, + group=match.group("group") or None, + ) else: - return {"user": None, "group": None} + return UserGroup(user=None, group=None) async def invoke(self) -> InferenceResult: """Run the inference on a single case""" @@ -418,8 +423,8 @@ async def execute(self) -> int: process = await asyncio.create_subprocess_exec( *self.proc_args, - user=self.proc_user["user"], - group=self.proc_user["group"], + user=self.proc_user.user, + group=self.proc_user.group, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=self.proc_env, diff --git a/tests/test_models.py b/tests/test_models.py index 9fee19f..acab05a 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -81,7 +81,8 @@ def test_proc_user(monkeypatch, user, expected_user, expected_group): ) assert t.user == user - assert t.proc_user == {"user": expected_user, "group": expected_group} + assert t.proc_user.user == expected_user + assert t.proc_user.group == expected_group def test_proc_user_unset(): @@ -90,4 +91,5 @@ def test_proc_user_unset(): ) assert t.user == "" - assert t.proc_user == {"user": None, "group": None} + assert t.proc_user.user is None + assert t.proc_user.group is None