Skip to content

Commit

Permalink
Add GRAND_CHALLENGE_COMPONENT_MAX_MEMORY_MB (#30)
Browse files Browse the repository at this point in the history
Sets an explicit memory limit rather than trying to calculate it
internally. Removes psutil.

See DIAGNijmegen/rse-grand-challenge-admin#306
  • Loading branch information
jmsmkn authored Aug 12, 2024
1 parent 837d45f commit 494db4f
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 44 deletions.
44 changes: 14 additions & 30 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sagemaker-shim"
version = "0.3.4"
version = "0.3.5"
description = "Adapts algorithms that implement the Grand Challenge inference API for running in SageMaker"
authors = ["James Meakin <[email protected]>"]
license = "Apache-2.0"
Expand All @@ -21,7 +21,6 @@ fastapi = "!=0.89.0" # See https://github.com/DIAGNijmegen/rse-sagemaker-shim/i
uvicorn = "*"
click = "*"
boto3 = "*"
psutil = "*"

[tool.poetry.group.dev.dependencies]
pytest = "!=8.0.0" # pytest 8 is not yet supported by pytest-asyncio
Expand Down
22 changes: 10 additions & 12 deletions sagemaker_shim/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from typing import Any, TypeVar

import click
import psutil
import uvicorn
from botocore.exceptions import ClientError, NoCredentialsError
from pydantic import ValidationError
Expand All @@ -22,6 +21,8 @@
get_s3_file_content,
)

logger = logging.getLogger(__name__)

T = TypeVar("T")


Expand All @@ -38,6 +39,7 @@ def wrapper(*args: Any, **kwargs: Any) -> T:
@click.group()
def cli() -> None:
logging.config.dictConfig(LOGGING_CONFIG)
set_memory_limits()


@cli.command(short_help="Start the model server")
Expand Down Expand Up @@ -105,23 +107,19 @@ async def invoke(tasks: str, file: str) -> None:


def set_memory_limits() -> None:
reserved_bytes = int(
os.environ.get(
"GRAND_CHALLENGE_COMPONENT_RESERVED_BYTES", 1_073_741_824
)
max_memory_mb = int(
os.environ.get("GRAND_CHALLENGE_COMPONENT_MAX_MEMORY_MB", "0")
)

if reserved_bytes:
total_memory_bytes = psutil.virtual_memory().total

limit = total_memory_bytes - reserved_bytes

if max_memory_mb:
logger.info(f"Setting memory limit to {max_memory_mb} MB")
limit = max_memory_mb * 1024 * 1024
resource.setrlimit(resource.RLIMIT_DATA, (limit, limit))
else:
logger.info("Not setting a memory limit")


if __name__ == "__main__":
set_memory_limits()

# https://pyinstaller.org/en/stable/runtime-information.html#run-time-information
we_are_bundled = getattr(sys, "frozen", False) and hasattr(sys, "_MEIPASS")

Expand Down
62 changes: 62 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import io
import json
import resource
from unittest.mock import patch
from uuid import uuid4

import pytest
Expand Down Expand Up @@ -228,3 +230,63 @@ def test_logging_stderr_setup(minio, monkeypatch):
'{"log": "hello", "level": "WARNING", '
f'"source": "stderr", "internal": false, "task": "{pk}"}}'
) in result.output


def test_memory_limit_undefined(minio, monkeypatch):
pk = str(uuid4())
tasks = [
{
"pk": pk,
"inputs": [],
"output_bucket_name": minio.output_bucket_name,
"output_prefix": f"tasks/{pk}",
}
]

monkeypatch.setenv(
"GRAND_CHALLENGE_COMPONENT_CMD_B64J",
encode_b64j(val=["echo", "hello"]),
)
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_SET_EXTRA_GROUPS", "False")

runner = CliRunner()
result = runner.invoke(cli, ["invoke", "-t", json.dumps(tasks)])

assert (
'{"log": "Not setting a memory limit", "level": "INFO", '
'"source": "stdout", "internal": true, "task": null}'
) in result.output


def test_memory_limit_defined(minio, monkeypatch):
pk = str(uuid4())
tasks = [
{
"pk": pk,
"inputs": [],
"output_bucket_name": minio.output_bucket_name,
"output_prefix": f"tasks/{pk}",
}
]

monkeypatch.setenv(
"GRAND_CHALLENGE_COMPONENT_CMD_B64J",
encode_b64j(val=["echo", "hello"]),
)
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_SET_EXTRA_GROUPS", "False")
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_MAX_MEMORY_MB", "1337")

expected_limit = 1337 * 1024 * 1024

with patch("resource.setrlimit") as mock_setrlimit:
runner = CliRunner()
result = runner.invoke(cli, ["invoke", "-t", json.dumps(tasks)])

mock_setrlimit.assert_called_once_with(
resource.RLIMIT_DATA, (expected_limit, expected_limit)
)

assert (
'{"log": "Setting memory limit to 1337 MB", "level": "INFO", '
'"source": "stdout", "internal": true, "task": null}'
) in result.output

0 comments on commit 494db4f

Please sign in to comment.