Skip to content

Commit

Permalink
Add the users groups to extra groups
Browse files Browse the repository at this point in the history
  • Loading branch information
jmsmkn committed Jan 10, 2024
1 parent 63e2e63 commit c68a220
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 26 deletions.
29 changes: 24 additions & 5 deletions sagemaker_shim/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ class UserInfo(NamedTuple):
uid: int | None
gid: int | None
home: str | None
extra_groups: list[int]


class InferenceTask(BaseModel):
Expand Down Expand Up @@ -244,7 +245,7 @@ def extra_groups(self) -> list[int] | None:
):
return None
else:
return []
return self.proc_user.extra_groups

@cached_property
def _s3_client(self) -> S3Client:
Expand Down Expand Up @@ -303,7 +304,7 @@ def proc_env(self) -> dict[str, str]:
@staticmethod
def _get_user_info(id_or_name: str) -> UserInfo:
if id_or_name == "":
return UserInfo(uid=None, gid=None, home=None)
return UserInfo(uid=None, gid=None, home=None, extra_groups=[])

try:
user = pwd.getpwnam(id_or_name)
Expand All @@ -316,9 +317,26 @@ def _get_user_info(id_or_name: str) -> UserInfo:
try:
user = pwd.getpwuid(uid)
except (KeyError, AttributeError):
return UserInfo(uid=uid, gid=None, home=None)
return UserInfo(uid=uid, gid=None, home=None, extra_groups=[])

return UserInfo(uid=user.pw_uid, gid=user.pw_gid, home=user.pw_dir)
users_groups = {
g.gr_gid for g in grp.getgrall() if user.pw_name in g.gr_mem
}

# pw_gid as the first group
try:
users_groups.remove(user.pw_gid)
except KeyError:
pass

extra_groups = [user.pw_gid, *sorted(users_groups)]

return UserInfo(
uid=user.pw_uid,
gid=user.pw_gid,
home=user.pw_dir,
extra_groups=extra_groups,
)

@staticmethod
def _get_group_id(id_or_name: str) -> int | None:
Expand All @@ -339,7 +357,7 @@ def _get_group_id(id_or_name: str) -> int | None:
@cached_property
def proc_user(self) -> UserInfo:
if self.user == "":
return UserInfo(uid=None, gid=None, home=None)
return UserInfo(uid=None, gid=None, home=None, extra_groups=[])

match = re.fullmatch(
r"^(?P<user>[0-9a-zA-Z]*):?(?P<group>[0-9a-zA-Z]*)$", self.user
Expand All @@ -353,6 +371,7 @@ def proc_user(self) -> UserInfo:
uid=info.uid,
gid=info.gid if gid is None else gid,
home=info.home,
extra_groups=info.extra_groups,
)
else:
raise RuntimeError(f"Invalid user '{self.user}'")
Expand Down
59 changes: 38 additions & 21 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,46 +62,62 @@ def test_removing_ld_library_path(monkeypatch):
assert env["LD_LIBRARY_PATH"] == "present"


ROOT_HOME = pwd.getpwnam("root").pw_dir
ROOT_GROUPS = sorted({g.gr_gid for g in grp.getgrall() if "root" in g.gr_mem})
USER_HOME = os.path.expanduser("~")
USER_GROUPS = {
g.gr_gid for g in grp.getgrall() if getpass.getuser() in g.gr_mem
}
USER_GROUPS = [pwd.getpwnam(getpass.getuser()).pw_gid, *sorted(USER_GROUPS)]


@pytest.mark.parametrize(
"user,expected_user,expected_group,expected_home",
"user,expected_user,expected_group,expected_home,expected_extra_groups",
(
("0", 0, 0, pwd.getpwnam("root").pw_dir),
("0:0", 0, 0, pwd.getpwnam("root").pw_dir),
(":0", None, 0, None),
("", None, None, None),
("root", 0, 0, pwd.getpwnam("root").pw_dir),
(f"root:{grp.getgrgid(0).gr_name}", 0, 0, pwd.getpwnam("root").pw_dir),
(f":{grp.getgrgid(0).gr_name}", None, 0, None),
("root:0", 0, 0, pwd.getpwnam("root").pw_dir),
(f"0:{grp.getgrgid(0).gr_name}", 0, 0, pwd.getpwnam("root").pw_dir),
(f":{os.getgid()}", None, os.getgid(), None),
(f"root:{os.getgid()}", 0, os.getgid(), pwd.getpwnam("root").pw_dir),
("0", 0, 0, ROOT_HOME, ROOT_GROUPS),
("0:0", 0, 0, ROOT_HOME, ROOT_GROUPS),
(":0", None, 0, None, []),
("", None, None, None, []),
("root", 0, 0, ROOT_HOME, ROOT_GROUPS),
(f"root:{grp.getgrgid(0).gr_name}", 0, 0, ROOT_HOME, ROOT_GROUPS),
(f":{grp.getgrgid(0).gr_name}", None, 0, None, []),
("root:0", 0, 0, ROOT_HOME, ROOT_GROUPS),
(f"0:{grp.getgrgid(0).gr_name}", 0, 0, ROOT_HOME, ROOT_GROUPS),
(f":{os.getgid()}", None, os.getgid(), None, []),
(f"root:{os.getgid()}", 0, os.getgid(), ROOT_HOME, ROOT_GROUPS),
# User exists
(f"{os.getuid()}", os.getuid(), os.getgid(), os.path.expanduser("~")),
(f"{os.getuid()}", os.getuid(), os.getgid(), USER_HOME, USER_GROUPS),
(
f"{getpass.getuser()}",
os.getuid(),
os.getgid(),
os.path.expanduser("~"),
USER_HOME,
USER_GROUPS,
),
# Group does not exist, but is an int
(f"{os.getuid()}:23746", os.getuid(), 23746, os.path.expanduser("~")),
(f"{os.getuid()}:23746", os.getuid(), 23746, USER_HOME, USER_GROUPS),
(
f"{getpass.getuser()}:23746",
os.getuid(),
23746,
os.path.expanduser("~"),
USER_HOME,
USER_GROUPS,
),
# User does not exist, but is an int
("23746", 23746, None, None),
(f"23746:{grp.getgrgid(0).gr_name}", 23746, 0, None),
(f"23746:{os.getgid()}", 23746, os.getgid(), None),
("23746", 23746, None, None, []),
(f"23746:{grp.getgrgid(0).gr_name}", 23746, 0, None, []),
(f"23746:{os.getgid()}", 23746, os.getgid(), None, []),
# User and group do not exist, but are ints
("23746:23746", 23746, 23746, None),
("23746:23746", 23746, 23746, None, []),
),
)
def test_proc_user(
monkeypatch, user, expected_user, expected_group, expected_home
monkeypatch,
user,
expected_user,
expected_group,
expected_home,
expected_extra_groups,
):
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_USER", user)

Expand All @@ -113,6 +129,7 @@ def test_proc_user(
assert t.proc_user.uid == expected_user
assert t.proc_user.gid == expected_group
assert t.proc_user.home == expected_home
assert t.extra_groups == expected_extra_groups


# Should error
Expand Down

0 comments on commit c68a220

Please sign in to comment.