Skip to content

Commit

Permalink
Fix running tests as non-root
Browse files Browse the repository at this point in the history
  • Loading branch information
jmsmkn committed Jan 10, 2024
1 parent 0b37f5d commit 63e2e63
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 1 deletion.
14 changes: 13 additions & 1 deletion sagemaker_shim/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,18 @@ def output_path(self) -> Path:
logger.debug(f"{output_path=}")
return output_path

@property
def extra_groups(self) -> list[int] | None:
if (
os.environ.get(
"GRAND_CHALLENGE_COMPONENT_KEEP_EXTRA_GROUPS", "False"
).lower()
== "true"
):
return None
else:
return []

@cached_property
def _s3_client(self) -> S3Client:
return get_s3_client()
Expand Down Expand Up @@ -462,7 +474,7 @@ async def execute(self) -> int:
*self.proc_args,
user=self.proc_user.uid,
group=self.proc_user.gid,
extra_groups=[],
extra_groups=self.extra_groups,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=self.proc_env,
Expand Down
1 change: 1 addition & 0 deletions tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def test_invocations_endpoint(client, tmp_path, monkeypatch, capsys, minio):
"GRAND_CHALLENGE_COMPONENT_OUTPUT_PATH",
str(output_path),
)
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_KEEP_EXTRA_GROUPS", "True")

debug_log = deepcopy(LOGGING_CONFIG)
debug_log["root"]["level"] = "DEBUG"
Expand Down
4 changes: 4 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def test_inference_from_task_list(
"GRAND_CHALLENGE_COMPONENT_CMD_B64J",
encode_b64j(val=cmd),
)
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_KEEP_EXTRA_GROUPS", "True")

runner = CliRunner()
runner.invoke(cli, ["invoke", "-t", json.dumps(tasks)])
Expand Down Expand Up @@ -142,6 +143,7 @@ def test_inference_from_s3_uri(minio, monkeypatch, cmd, expected_return_code):
"GRAND_CHALLENGE_COMPONENT_CMD_B64J",
encode_b64j(val=cmd),
)
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_KEEP_EXTRA_GROUPS", "True")

definition_key = f"{uuid4()}/invocations.json"

Expand Down Expand Up @@ -183,6 +185,7 @@ def test_logging_setup(minio, monkeypatch):
"GRAND_CHALLENGE_COMPONENT_CMD_B64J",
encode_b64j(val=["echo", "hello"]),
)
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_KEEP_EXTRA_GROUPS", "True")

runner = CliRunner()
result = runner.invoke(cli, ["invoke", "-t", json.dumps(tasks)])
Expand Down Expand Up @@ -210,6 +213,7 @@ def test_logging_stderr_setup(minio, monkeypatch):
val=["bash", "-c", "echo 'hello' >> /dev/stderr && exit 1"]
),
)
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_KEEP_EXTRA_GROUPS", "True")

runner = CliRunner()
result = runner.invoke(cli, ["invoke", "-t", json.dumps(tasks)])
Expand Down
1 change: 1 addition & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ async def test_inference_result_upload(
"GRAND_CHALLENGE_COMPONENT_CMD_B64J",
encode_b64j(val=cmd),
)
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_KEEP_EXTRA_GROUPS", "True")

direct_invocation = await task.invoke()

Expand Down

0 comments on commit 63e2e63

Please sign in to comment.