Skip to content

Commit

Permalink
Add GRAND_CHALLENGE_COMPONENT_WRITABLE_DIRECTORIES
Browse files Browse the repository at this point in the history
  • Loading branch information
jmsmkn committed Dec 16, 2023
1 parent 86c8005 commit 2eb8c16
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
12 changes: 12 additions & 0 deletions sagemaker_shim/cli.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import asyncio
import logging.config
import os
import sys
from collections.abc import Callable, Coroutine
from functools import wraps
from json import JSONDecodeError
from pathlib import Path
from typing import Any, TypeVar

import click
Expand All @@ -28,9 +30,19 @@ def wrapper(*args: Any, **kwargs: Any) -> T:
return wrapper


def _ensure_directories_are_writable() -> None:
for directory in os.environ.get(
"GRAND_CHALLENGE_COMPONENT_WRITABLE_DIRECTORIES", ""
).split(":"):
path = Path(directory)
path.mkdir(exist_ok=True, parents=True)
path.chmod(mode=0o777)


@click.group()
def cli() -> None:
logging.config.dictConfig(LOGGING_CONFIG)
_ensure_directories_are_writable()


@cli.command(short_help="Start the model server")
Expand Down
27 changes: 27 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,30 @@ def test_logging_stderr_setup(minio, monkeypatch):
'{"log": "hello", "level": "WARNING", '
f'"source": "stderr", "internal": false, "task": "{pk}"}}'
) in result.output


def test_ensure_directories_are_writable(tmp_path, monkeypatch):
data = tmp_path / "opt" / "ml" / "output" / "data"
data.mkdir(mode=0o755, parents=True)

model = tmp_path / "opt" / "ml" / "model"
model.mkdir(mode=0o755, parents=True)

# Do not create the checkpoints dir in the test
checkpoints = tmp_path / "opt" / "ml" / "checkpoints"

tmp = tmp_path / "tmp"
tmp.mkdir(mode=0o755)

monkeypatch.setenv(
"GRAND_CHALLENGE_COMPONENT_WRITABLE_DIRECTORIES",
f"{data.absolute()}:{model.absolute()}:{checkpoints.absolute()}:{tmp.absolute()}",
)

runner = CliRunner()
runner.invoke(cli, ["invoke"])

assert data.stat().st_mode == 0o40777
assert model.stat().st_mode == 0o40777
assert checkpoints.stat().st_mode == 0o40777
assert tmp.stat().st_mode == 0o40777

0 comments on commit 2eb8c16

Please sign in to comment.