diff --git a/sagemaker_shim/cli.py b/sagemaker_shim/cli.py index 3bf4b89..e3081a4 100644 --- a/sagemaker_shim/cli.py +++ b/sagemaker_shim/cli.py @@ -1,9 +1,11 @@ import asyncio import logging.config +import os import sys from collections.abc import Callable, Coroutine from functools import wraps from json import JSONDecodeError +from pathlib import Path from typing import Any, TypeVar import click @@ -28,9 +30,19 @@ def wrapper(*args: Any, **kwargs: Any) -> T: return wrapper +def _ensure_directories_are_writable() -> None: + for directory in os.environ.get( + "GRAND_CHALLENGE_COMPONENT_WRITABLE_DIRECTORIES", "" + ).split(":"): + path = Path(directory) + path.mkdir(exist_ok=True, parents=True) + path.chmod(mode=0o777) + + @click.group() def cli() -> None: logging.config.dictConfig(LOGGING_CONFIG) + _ensure_directories_are_writable() @cli.command(short_help="Start the model server") diff --git a/tests/test_cli.py b/tests/test_cli.py index efcabb3..01406d4 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -218,3 +218,30 @@ def test_logging_stderr_setup(minio, monkeypatch): '{"log": "hello", "level": "WARNING", ' f'"source": "stderr", "internal": false, "task": "{pk}"}}' ) in result.output + + +def test_ensure_directories_are_writable(tmp_path, monkeypatch): + data = tmp_path / "opt" / "ml" / "output" / "data" + data.mkdir(mode=0o755, parents=True) + + model = tmp_path / "opt" / "ml" / "model" + model.mkdir(mode=0o755, parents=True) + + # Do not create the checkpoints dir in the test + checkpoints = tmp_path / "opt" / "ml" / "checkpoints" + + tmp = tmp_path / "tmp" + tmp.mkdir(mode=0o755) + + monkeypatch.setenv( + "GRAND_CHALLENGE_COMPONENT_WRITABLE_DIRECTORIES", + f"{data.absolute()}:{model.absolute()}:{checkpoints.absolute()}:{tmp.absolute()}", + ) + + runner = CliRunner() + runner.invoke(cli, ["invoke"]) + + assert data.stat().st_mode == 0o40777 + assert model.stat().st_mode == 0o40777 + assert checkpoints.stat().st_mode == 0o40777 + assert tmp.stat().st_mode == 0o40777