Skip to content

Commit

Permalink
Fix user and group determination
Browse files Browse the repository at this point in the history
Closes #18
  • Loading branch information
jmsmkn committed Jan 10, 2024
1 parent a9852c5 commit 527befe
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 42 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sagemaker-shim"
version = "0.2.3"
version = "0.2.4"
description = "Adapts algorithms that implement the Grand Challenge inference API for running in SageMaker"
authors = ["James Meakin <[email protected]>"]
license = "Apache-2.0"
Expand Down
52 changes: 26 additions & 26 deletions sagemaker_shim/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,30 +289,42 @@ def proc_env(self) -> dict[str, str]:
return env

@staticmethod
def _get_user_info(id_or_name: str) -> pwd.struct_passwd | None:
def _get_user_info(id_or_name: str) -> UserGroup:
if id_or_name == "":
return None
return UserGroup(uid=None, gid=None, home=None)

try:
return pwd.getpwnam(id_or_name)
user = pwd.getpwnam(id_or_name)
except (KeyError, AttributeError):
try:
return pwd.getpwuid(int(id_or_name))
except (KeyError, ValueError, AttributeError) as error:
raise RuntimeError(f"User {id_or_name} not found") from error
uid = int(id_or_name)
except ValueError:
raise RuntimeError(
f"User '{id_or_name}' not found"
) from ValueError

try:
user = pwd.getpwuid(uid)
except (KeyError, AttributeError):
return UserGroup(uid=uid, gid=None, home=None)

return UserGroup(uid=user.pw_uid, gid=user.pw_gid, home=user.pw_dir)

@staticmethod
def _get_group_info(id_or_name: str) -> grp.struct_group | None:
def _get_group_info(id_or_name: str) -> int | None:
if id_or_name == "":
return None

try:
return grp.getgrnam(id_or_name)
group = grp.getgrnam(id_or_name)
return group.gr_gid
except (KeyError, AttributeError):
try:
return grp.getgrgid(int(id_or_name))
except (KeyError, ValueError, AttributeError) as error:
raise RuntimeError(f"Group {id_or_name} not found") from error
return int(id_or_name)
except ValueError:
raise RuntimeError(
f"Group '{id_or_name}' not found"
) from ValueError

@cached_property
def proc_user(self) -> UserGroup:
Expand All @@ -324,21 +336,9 @@ def proc_user(self) -> UserGroup:
user = self._get_user_info(id_or_name=match.group("user"))
group = self._get_group_info(id_or_name=match.group("group"))

if user is None:
uid = None
home = None
else:
uid = user.pw_uid
home = user.pw_dir

if group is None:
if user is None:
gid = None
else:
# Switch to the users primary group
gid = user.pw_gid
else:
gid = group.gr_gid
uid = user.uid
gid = user.gid if group is None else group
home = user.home

return UserGroup(uid=uid, gid=gid, home=home)
else:
Expand Down
95 changes: 80 additions & 15 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,24 +62,47 @@ def test_removing_ld_library_path(monkeypatch):


@pytest.mark.parametrize(
"user,expected_user,expected_group",
"user,expected_user,expected_group,expected_home",
(
("0", 0, 0),
("0:0", 0, 0),
(":0", None, 0),
("", None, None),
("root", 0, 0),
(f"root:{grp.getgrgid(0).gr_name}", 0, 0),
(f":{grp.getgrgid(0).gr_name}", None, 0),
("", None, None),
("🙈:🙉", None, None),
("root:0", 0, 0),
(f"0:{grp.getgrgid(0).gr_name}", 0, 0),
(f":{os.getgid()}", None, os.getgid()),
(f"root:{os.getgid()}", 0, os.getgid()),
("0", 0, 0, pwd.getpwnam("root").pw_dir),
("0:0", 0, 0, pwd.getpwnam("root").pw_dir),
(":0", None, 0, None),
("", None, None, None),
("root", 0, 0, "/var/root"),
(f"root:{grp.getgrgid(0).gr_name}", 0, 0, pwd.getpwnam("root").pw_dir),
(f":{grp.getgrgid(0).gr_name}", None, 0, None),
("🙈:🙉", None, None, None),
("root:0", 0, 0, pwd.getpwnam("root").pw_dir),
(f"0:{grp.getgrgid(0).gr_name}", 0, 0, pwd.getpwnam("root").pw_dir),
(f":{os.getgid()}", None, os.getgid(), None),
(f"root:{os.getgid()}", 0, os.getgid(), pwd.getpwnam("root").pw_dir),
# User exists
(f"{os.getuid()}", os.getuid(), os.getgid(), os.path.expanduser("~")),
(
f"{os.getlogin()}",
os.getuid(),
os.getgid(),
os.path.expanduser("~"),
),
# Group does not exist, but is an int
(f"{os.getuid()}:23746", os.getuid(), 23746, os.path.expanduser("~")),
(
f"{os.getlogin()}:23746",
os.getuid(),
23746,
os.path.expanduser("~"),
),
# User does not exist, but is an int
("23746", 23746, None, None),
(f"23746:{grp.getgrgid(0).gr_name}", 23746, 0, None),
(f"23746:{os.getgid()}", 23746, os.getgid(), None),
# User and group do not exist, but are ints
("23746:23746", 23746, 23746, None),
),
)
def test_proc_user(monkeypatch, user, expected_user, expected_group):
def test_proc_user(
monkeypatch, user, expected_user, expected_group, expected_home
):
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_USER", user)

t = InferenceTask(
Expand All @@ -89,6 +112,47 @@ def test_proc_user(monkeypatch, user, expected_user, expected_group):
assert t.user == user
assert t.proc_user.uid == expected_user
assert t.proc_user.gid == expected_group
assert t.proc_user.home == expected_home


# Should error
@pytest.mark.parametrize(
"user,expected_error",
(
(
f"{os.getuid()}:nonExistantGroup",
"Group 'nonExistantGroup' not found",
),
(
f"{os.getlogin()}:nonExistantGroup",
"Group 'nonExistantGroup' not found",
),
("nonExistantUser", "User 'nonExistantUser' not found"),
(
"nonExistantUser:nonExistantGroup",
"User 'nonExistantUser' not found",
),
(":nonExistantGroup", "Group 'nonExistantGroup' not found"),
(
f"nonExistantUser:{grp.getgrgid(0).gr_name}",
"User 'nonExistantUser' not found",
),
(f"nonExistantUser:{os.getgid()}", "User 'nonExistantUser' not found"),
),
)
def test_proc_user_errors(monkeypatch, user, expected_error):
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_USER", user)

t = InferenceTask(
pk="test", inputs=[], output_bucket_name="test", output_prefix="test"
)

assert t.user == user

with pytest.raises(RuntimeError) as error:
_ = t.proc_user

assert str(error.value) == expected_error


def test_proc_user_unset():
Expand All @@ -99,6 +163,7 @@ def test_proc_user_unset():
assert t.user == ""
assert t.proc_user.uid is None
assert t.proc_user.gid is None
assert t.proc_user.home is None


def test_home_is_set(monkeypatch):
Expand Down

0 comments on commit 527befe

Please sign in to comment.