diff --git a/sagemaker_shim/models.py b/sagemaker_shim/models.py index 39bb582..ff5859e 100644 --- a/sagemaker_shim/models.py +++ b/sagemaker_shim/models.py @@ -158,6 +158,7 @@ class UserInfo(NamedTuple): uid: int | None gid: int | None home: str | None + extra_groups: list[int] class InferenceTask(BaseModel): @@ -244,7 +245,7 @@ def extra_groups(self) -> list[int] | None: ): return None else: - return [] + return self.proc_user.extra_groups @cached_property def _s3_client(self) -> S3Client: @@ -303,7 +304,7 @@ def proc_env(self) -> dict[str, str]: @staticmethod def _get_user_info(id_or_name: str) -> UserInfo: if id_or_name == "": - return UserInfo(uid=None, gid=None, home=None) + return UserInfo(uid=None, gid=None, home=None, extra_groups=[]) try: user = pwd.getpwnam(id_or_name) @@ -316,9 +317,26 @@ def _get_user_info(id_or_name: str) -> UserInfo: try: user = pwd.getpwuid(uid) except (KeyError, AttributeError): - return UserInfo(uid=uid, gid=None, home=None) + return UserInfo(uid=uid, gid=None, home=None, extra_groups=[]) - return UserInfo(uid=user.pw_uid, gid=user.pw_gid, home=user.pw_dir) + users_groups = { + g.gr_gid for g in grp.getgrall() if user.pw_name in g.gr_mem + } + + # pw_gid as the first group + try: + users_groups.remove(user.pw_gid) + except KeyError: + pass + + extra_groups = [user.pw_gid, *sorted(users_groups)] + + return UserInfo( + uid=user.pw_uid, + gid=user.pw_gid, + home=user.pw_dir, + extra_groups=extra_groups, + ) @staticmethod def _get_group_id(id_or_name: str) -> int | None: @@ -339,7 +357,7 @@ def _get_group_id(id_or_name: str) -> int | None: @cached_property def proc_user(self) -> UserInfo: if self.user == "": - return UserInfo(uid=None, gid=None, home=None) + return UserInfo(uid=None, gid=None, home=None, extra_groups=[]) match = re.fullmatch( r"^(?P<user>[0-9a-zA-Z]*):?(?P<group>[0-9a-zA-Z]*)$", self.user @@ -353,6 +371,7 @@ def proc_user(self) -> UserInfo: uid=info.uid, gid=info.gid if gid is None else gid, home=info.home, + 
extra_groups=info.extra_groups, ) else: raise RuntimeError(f"Invalid user '{self.user}'") diff --git a/tests/test_models.py b/tests/test_models.py index 5497148..f5ef8dd 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -62,46 +62,62 @@ def test_removing_ld_library_path(monkeypatch): assert env["LD_LIBRARY_PATH"] == "present" +ROOT_HOME = pwd.getpwnam("root").pw_dir +ROOT_GROUPS = sorted({g.gr_gid for g in grp.getgrall() if "root" in g.gr_mem}) +USER_HOME = os.path.expanduser("~") +USER_GROUPS = { + g.gr_gid for g in grp.getgrall() if getpass.getuser() in g.gr_mem +} +USER_GROUPS = [pwd.getpwnam(getpass.getuser()).pw_gid, *sorted(USER_GROUPS)] + + @pytest.mark.parametrize( - "user,expected_user,expected_group,expected_home", + "user,expected_user,expected_group,expected_home,expected_extra_groups", ( - ("0", 0, 0, pwd.getpwnam("root").pw_dir), - ("0:0", 0, 0, pwd.getpwnam("root").pw_dir), - (":0", None, 0, None), - ("", None, None, None), - ("root", 0, 0, pwd.getpwnam("root").pw_dir), - (f"root:{grp.getgrgid(0).gr_name}", 0, 0, pwd.getpwnam("root").pw_dir), - (f":{grp.getgrgid(0).gr_name}", None, 0, None), - ("root:0", 0, 0, pwd.getpwnam("root").pw_dir), - (f"0:{grp.getgrgid(0).gr_name}", 0, 0, pwd.getpwnam("root").pw_dir), - (f":{os.getgid()}", None, os.getgid(), None), - (f"root:{os.getgid()}", 0, os.getgid(), pwd.getpwnam("root").pw_dir), + ("0", 0, 0, ROOT_HOME, ROOT_GROUPS), + ("0:0", 0, 0, ROOT_HOME, ROOT_GROUPS), + (":0", None, 0, None, []), + ("", None, None, None, []), + ("root", 0, 0, ROOT_HOME, ROOT_GROUPS), + (f"root:{grp.getgrgid(0).gr_name}", 0, 0, ROOT_HOME, ROOT_GROUPS), + (f":{grp.getgrgid(0).gr_name}", None, 0, None, []), + ("root:0", 0, 0, ROOT_HOME, ROOT_GROUPS), + (f"0:{grp.getgrgid(0).gr_name}", 0, 0, ROOT_HOME, ROOT_GROUPS), + (f":{os.getgid()}", None, os.getgid(), None, []), + (f"root:{os.getgid()}", 0, os.getgid(), ROOT_HOME, ROOT_GROUPS), # User exists - (f"{os.getuid()}", os.getuid(), os.getgid(), os.path.expanduser("~")), + 
(f"{os.getuid()}", os.getuid(), os.getgid(), USER_HOME, USER_GROUPS), ( f"{getpass.getuser()}", os.getuid(), os.getgid(), - os.path.expanduser("~"), + USER_HOME, + USER_GROUPS, ), # Group does not exist, but is an int - (f"{os.getuid()}:23746", os.getuid(), 23746, os.path.expanduser("~")), + (f"{os.getuid()}:23746", os.getuid(), 23746, USER_HOME, USER_GROUPS), ( f"{getpass.getuser()}:23746", os.getuid(), 23746, - os.path.expanduser("~"), + USER_HOME, + USER_GROUPS, ), # User does not exist, but is an int - ("23746", 23746, None, None), - (f"23746:{grp.getgrgid(0).gr_name}", 23746, 0, None), - (f"23746:{os.getgid()}", 23746, os.getgid(), None), + ("23746", 23746, None, None, []), + (f"23746:{grp.getgrgid(0).gr_name}", 23746, 0, None, []), + (f"23746:{os.getgid()}", 23746, os.getgid(), None, []), # User and group do not exist, but are ints - ("23746:23746", 23746, 23746, None), + ("23746:23746", 23746, 23746, None, []), ), ) def test_proc_user( - monkeypatch, user, expected_user, expected_group, expected_home + monkeypatch, + user, + expected_user, + expected_group, + expected_home, + expected_extra_groups, ): monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_USER", user) @@ -113,6 +129,7 @@ def test_proc_user( assert t.proc_user.uid == expected_user assert t.proc_user.gid == expected_group assert t.proc_user.home == expected_home + assert t.extra_groups == expected_extra_groups # Should error