diff --git a/sagemaker_shim/models.py b/sagemaker_shim/models.py index aaf8e81..39bb582 100644 --- a/sagemaker_shim/models.py +++ b/sagemaker_shim/models.py @@ -234,6 +234,18 @@ def output_path(self) -> Path: logger.debug(f"{output_path=}") return output_path + @property + def extra_groups(self) -> list[int] | None: + if ( + os.environ.get( + "GRAND_CHALLENGE_COMPONENT_KEEP_EXTRA_GROUPS", "False" + ).lower() + == "true" + ): + return None + else: + return [] + @cached_property def _s3_client(self) -> S3Client: return get_s3_client() @@ -462,7 +474,7 @@ async def execute(self) -> int: *self.proc_args, user=self.proc_user.uid, group=self.proc_user.gid, - extra_groups=[], + extra_groups=self.extra_groups, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=self.proc_env, diff --git a/tests/test_app.py b/tests/test_app.py index 7dfa2a1..f61f950 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -73,6 +73,7 @@ def test_invocations_endpoint(client, tmp_path, monkeypatch, capsys, minio): "GRAND_CHALLENGE_COMPONENT_OUTPUT_PATH", str(output_path), ) + monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_KEEP_EXTRA_GROUPS", "True") debug_log = deepcopy(LOGGING_CONFIG) debug_log["root"]["level"] = "DEBUG" diff --git a/tests/test_cli.py b/tests/test_cli.py index 7fbc849..516a434 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -100,6 +100,7 @@ def test_inference_from_task_list( "GRAND_CHALLENGE_COMPONENT_CMD_B64J", encode_b64j(val=cmd), ) + monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_KEEP_EXTRA_GROUPS", "True") runner = CliRunner() runner.invoke(cli, ["invoke", "-t", json.dumps(tasks)]) @@ -142,6 +143,7 @@ def test_inference_from_s3_uri(minio, monkeypatch, cmd, expected_return_code): "GRAND_CHALLENGE_COMPONENT_CMD_B64J", encode_b64j(val=cmd), ) + monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_KEEP_EXTRA_GROUPS", "True") definition_key = f"{uuid4()}/invocations.json" @@ -183,6 +185,7 @@ def test_logging_setup(minio, monkeypatch): "GRAND_CHALLENGE_COMPONENT_CMD_B64J", encode_b64j(val=["echo", "hello"]), ) + monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_KEEP_EXTRA_GROUPS", "True") runner = CliRunner() result = runner.invoke(cli, ["invoke", "-t", json.dumps(tasks)]) @@ -210,6 +213,7 @@ def test_logging_stderr_setup(minio, monkeypatch): val=["bash", "-c", "echo 'hello' >> /dev/stderr && exit 1"] ), ) + monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_KEEP_EXTRA_GROUPS", "True") runner = CliRunner() result = runner.invoke(cli, ["invoke", "-t", json.dumps(tasks)]) diff --git a/tests/test_io.py b/tests/test_io.py index 05a0cd0..7416364 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -312,6 +312,7 @@ async def test_inference_result_upload( "GRAND_CHALLENGE_COMPONENT_CMD_B64J", encode_b64j(val=cmd), ) + monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_KEEP_EXTRA_GROUPS", "True") direct_invocation = await task.invoke()