Skip to content

Commit

Permalink
Add user switching (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmsmkn authored Dec 16, 2023
1 parent 468d0d1 commit efb3564
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 336 deletions.
334 changes: 2 additions & 332 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sagemaker-shim"
version = "0.2.0"
version = "0.2.1a0"
description = "Adapts algorithms that implement the Grand Challenge inference API for running in SageMaker"
authors = ["James Meakin <[email protected]>"]
license = "Apache-2.0"
Expand All @@ -18,7 +18,7 @@ sagemaker-shim = "sagemaker_shim.cli:cli"
# Only support one version of python at a time
python = "^3.11,<3.12"
fastapi = "!=0.89.0"
uvicorn = {extras = ["standard"], version = "*"}
uvicorn = "*"
click = "*"
boto3 = "*"

Expand Down
62 changes: 61 additions & 1 deletion sagemaker_shim/models.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
import asyncio
import errno
import grp
import io
import json
import logging
import os
import pwd
import re
import subprocess
from base64 import b64decode
from collections.abc import Callable
from functools import cached_property
from importlib.metadata import version
from pathlib import Path
from shutil import rmtree
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, NamedTuple
from zipfile import BadZipFile

import boto3
Expand Down Expand Up @@ -153,6 +156,11 @@ class InferenceResult(BaseModel):
sagemaker_shim_version: str = version("sagemaker-shim")


class UserGroup(NamedTuple):
user: int | None
group: int | None


class InferenceTask(BaseModel):
model_config = ConfigDict(frozen=True)

Expand Down Expand Up @@ -205,6 +213,10 @@ def entrypoint(self) -> Any:
logger.debug(f"{entrypoint=}")
return entrypoint

@property
def user(self) -> str:
    """Raw user specification read from ``GRAND_CHALLENGE_COMPONENT_USER``.

    Returns the empty string when the environment variable is not set.
    """
    return os.getenv("GRAND_CHALLENGE_COMPONENT_USER", "")

@property
def input_path(self) -> Path:
"""Local path where the subprocess is expected to read its input files"""
Expand Down Expand Up @@ -274,6 +286,52 @@ def proc_env(self) -> dict[str, str]:

return env

@staticmethod
def _get_user_or_group_id(*, match: re.Match[str], key: str) -> int | None:
value = match.group(key)

if value == "":
return None

if key == "user":
name_lookup: Callable[
[str], pwd.struct_passwd | grp.struct_group
] = pwd.getpwnam
id_lookup: Callable[
[int], pwd.struct_passwd | grp.struct_group
] = pwd.getpwuid
attr = "pw_uid"
elif key == "group":
name_lookup = grp.getgrnam
id_lookup = grp.getgrgid
attr = "gr_gid"
else:
raise RuntimeError("Unknown key")

try:
out: int = getattr(name_lookup(value), attr)
except (KeyError, AttributeError):
try:
out = getattr(id_lookup(int(value)), attr)
except (KeyError, ValueError, AttributeError) as error:
raise RuntimeError(f"{key} {value} not found") from error

return out

@property
def proc_user(self) -> UserGroup:
    """Numeric user and group IDs parsed from ``self.user``.

    ``self.user`` is expected to look like ``user``, ``user:group`` or
    ``:group``, where each part consists of ASCII letters and digits.
    A value that does not have this shape yields a ``UserGroup`` with
    both fields set to ``None``.
    """
    parsed = re.fullmatch(
        r"^(?P<user>[0-9a-zA-Z]*):?(?P<group>[0-9a-zA-Z]*)$", self.user
    )

    if parsed is None:
        return UserGroup(user=None, group=None)

    uid = self._get_user_or_group_id(match=parsed, key="user")
    gid = self._get_user_or_group_id(match=parsed, key="group")
    return UserGroup(user=uid, group=gid)

async def invoke(self) -> InferenceResult:
"""Run the inference on a single case"""

Expand Down Expand Up @@ -400,6 +458,8 @@ async def execute(self) -> int:

process = await asyncio.create_subprocess_exec(
*self.proc_args,
user=self.proc_user.user,
group=self.proc_user.group,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=self.proc_env,
Expand Down
3 changes: 2 additions & 1 deletion tests/test_container_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _container_helper(request) -> None:


@contextmanager
def _container(*, base_image="hello-world:latest", host_port=8080, cmd=None):
def _container(*, base_image="ubuntu:latest", host_port=8080, cmd=None):
client = docker.from_env()
registry = client.containers.run(
image="registry:2.7",
Expand Down Expand Up @@ -93,6 +93,7 @@ def _container(*, base_image="hello-world:latest", host_port=8080, cmd=None):
init=False,
environment=container_env,
links={minio.container.name: "minio"},
user=0,
)

# Wait for startup
Expand Down
39 changes: 39 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import grp
import os

import pytest
Expand Down Expand Up @@ -57,3 +58,41 @@ def test_removing_ld_library_path(monkeypatch):

assert "LD_LIBRARY_PATH" not in t.proc_env
assert env["LD_LIBRARY_PATH"] == "present"


@pytest.mark.parametrize(
    "user,expected_user,expected_group",
    (
        # Numeric IDs
        ("0", 0, None),
        ("0:0", 0, 0),
        (":0", None, 0),
        # Empty specification
        ("", None, None),
        # Names
        ("root", 0, None),
        (f"root:{grp.getgrgid(0).gr_name}", 0, 0),
        (f":{grp.getgrgid(0).gr_name}", None, 0),
        # A value that does not match the expected format
        ("🙈:🙉", None, None),
        # Mixed names and numeric IDs
        ("root:0", 0, 0),
        (f"0:{grp.getgrgid(0).gr_name}", 0, 0),
    ),
)
def test_proc_user(monkeypatch, user, expected_user, expected_group):
    """Check parsing of ``GRAND_CHALLENGE_COMPONENT_USER`` into numeric IDs."""
    monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_USER", user)

    t = InferenceTask(
        pk="test", inputs=[], output_bucket_name="test", output_prefix="test"
    )

    assert t.user == user
    assert t.proc_user.user == expected_user
    assert t.proc_user.group == expected_group


def test_proc_user_unset(monkeypatch):
    """Without the environment variable no user or group is selected."""
    # Make sure the variable does not leak in from the calling environment,
    # otherwise this test does not exercise the unset case at all
    monkeypatch.delenv("GRAND_CHALLENGE_COMPONENT_USER", raising=False)

    t = InferenceTask(
        pk="test", inputs=[], output_bucket_name="test", output_prefix="test"
    )

    assert t.user == ""
    assert t.proc_user.user is None
    assert t.proc_user.group is None
assert t.proc_user.group is None
2 changes: 2 additions & 0 deletions tests/test_patch_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,13 @@ def test_patch_image(registry):
assert env_vars == {
"GRAND_CHALLENGE_COMPONENT_CMD_B64J": "WyJzaCJd",
"GRAND_CHALLENGE_COMPONENT_ENTRYPOINT_B64J": "bnVsbA==",
"GRAND_CHALLENGE_COMPONENT_USER": "0:0",
}
assert new_config["config"]["Entrypoint"] == ["/sagemaker-shim"]
assert "Cmd" not in new_config["config"]
assert set(new_config["config"]["Env"]) == {
"PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
"GRAND_CHALLENGE_COMPONENT_CMD_B64J=WyJzaCJd",
"GRAND_CHALLENGE_COMPONENT_ENTRYPOINT_B64J=bnVsbA==",
"GRAND_CHALLENGE_COMPONENT_USER=0:0",
}
1 change: 1 addition & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def get_new_env_vars(*, existing_config: dict[str, Any]) -> dict[str, str]:
"GRAND_CHALLENGE_COMPONENT_ENTRYPOINT_B64J": encode_b64j(
val=entrypoint
),
"GRAND_CHALLENGE_COMPONENT_USER": "0:0",
}


Expand Down

0 comments on commit efb3564

Please sign in to comment.