From 48336b18cc6d76c6f3a57543a44cbc03bd7606c4 Mon Sep 17 00:00:00 2001 From: James Meakin <12661555+jmsmkn@users.noreply.github.com> Date: Thu, 15 Feb 2024 12:32:16 +0100 Subject: [PATCH] Change the ownership of extracted files to those of the process user (#22) See DIAGNijmegen/rse-roadmap#297 --- poetry.lock | 94 ++++++------- pyproject.toml | 2 +- sagemaker_shim/cli.py | 6 +- sagemaker_shim/models.py | 281 ++++++++++++++++++++++----------------- tests/test_cli.py | 4 +- tests/test_models.py | 105 ++++++++++----- 6 files changed, 283 insertions(+), 209 deletions(-) diff --git a/poetry.lock b/poetry.lock index 4323ae3..2b5dcf8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -44,17 +44,17 @@ trio = ["trio (>=0.23)"] [[package]] name = "boto3" -version = "1.34.35" +version = "1.34.42" description = "The AWS SDK for Python" optional = false python-versions = ">= 3.8" files = [ - {file = "boto3-1.34.35-py3-none-any.whl", hash = "sha256:53897701ab4f307fbcfdade673eae809dfc5eabb6102053c84907aa27de66e53"}, - {file = "boto3-1.34.35.tar.gz", hash = "sha256:4052031f9ac18924e94be7c30a5f0af5843a4752b8c8cb9034cd978be252b61b"}, + {file = "boto3-1.34.42-py3-none-any.whl", hash = "sha256:5069b2c647c73c8428378e88b32bd23f568001f897a6f01179fae25de72a7ca6"}, + {file = "boto3-1.34.42.tar.gz", hash = "sha256:2ed136f9cf79e783e12424db23e970d1c50e65a8d7a9077efa71cbf8496fb7a3"}, ] [package.dependencies] -botocore = ">=1.34.35,<1.35.0" +botocore = ">=1.34.42,<1.35.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.10.0,<0.11.0" @@ -63,13 +63,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "boto3-stubs" -version = "1.34.35" -description = "Type annotations for boto3 1.34.35 generated with mypy-boto3-builder 7.23.1" +version = "1.34.42" +description = "Type annotations for boto3 1.34.42 generated with mypy-boto3-builder 7.23.1" optional = false python-versions = ">=3.8" files = [ - {file = "boto3-stubs-1.34.35.tar.gz", hash = "sha256:91725aec6109f9b4d1d37a50e7c9478ce9caf41d5bd7debe80a31ec3e1f84a50"}, - {file = "boto3_stubs-1.34.35-py3-none-any.whl", hash = "sha256:b88e0a93e777bdbb9f2f965b7800f9f1fc9624f46d6af1a0032fba9a36690c43"}, + {file = "boto3-stubs-1.34.42.tar.gz", hash = "sha256:c20fcb0b1680be89c185951a2c88406e7dcd48a91814231b59f37dad6249108a"}, + {file = "boto3_stubs-1.34.42-py3-none-any.whl", hash = "sha256:751041b26076de192657f9d8a916c61f02bff6aece122a8260ce65b676e5b96e"}, ] [package.dependencies] @@ -120,7 +120,7 @@ bedrock-agent = ["mypy-boto3-bedrock-agent (>=1.34.0,<1.35.0)"] bedrock-agent-runtime = ["mypy-boto3-bedrock-agent-runtime (>=1.34.0,<1.35.0)"] bedrock-runtime = ["mypy-boto3-bedrock-runtime (>=1.34.0,<1.35.0)"] billingconductor = ["mypy-boto3-billingconductor (>=1.34.0,<1.35.0)"] -boto3 = ["boto3 (==1.34.35)", "botocore (==1.34.35)"] +boto3 = ["boto3 (==1.34.42)", "botocore (==1.34.42)"] braket = ["mypy-boto3-braket (>=1.34.0,<1.35.0)"] budgets = ["mypy-boto3-budgets (>=1.34.0,<1.35.0)"] ce = ["mypy-boto3-ce (>=1.34.0,<1.35.0)"] @@ -462,13 +462,13 @@ xray = ["mypy-boto3-xray (>=1.34.0,<1.35.0)"] [[package]] name = "botocore" -version = "1.34.35" +version = "1.34.42" description = "Low-level, data-driven core of boto 3." 
optional = false python-versions = ">= 3.8" files = [ - {file = "botocore-1.34.35-py3-none-any.whl", hash = "sha256:b67b8c865973202dc655a493317ae14b33d115e49ed6960874eb05d950167b37"}, - {file = "botocore-1.34.35.tar.gz", hash = "sha256:8a2b53ab772584a5f7e2fe1e4a59028b0602cfef8e39d622db7c6b670e4b1ee6"}, + {file = "botocore-1.34.42-py3-none-any.whl", hash = "sha256:93755c3ede4bd9f6180b3118f6b607acca8b633fd2668226794a543ce79a2434"}, + {file = "botocore-1.34.42.tar.gz", hash = "sha256:cf4fad50d09686f03e44418fcae9dd24369658daa556357cedc0790cfcd6fdac"}, ] [package.dependencies] @@ -481,13 +481,13 @@ crt = ["awscrt (==0.19.19)"] [[package]] name = "botocore-stubs" -version = "1.34.35" +version = "1.34.42" description = "Type annotations and code completion for botocore" optional = false python-versions = ">=3.8,<4.0" files = [ - {file = "botocore_stubs-1.34.35-py3-none-any.whl", hash = "sha256:1c386a1d4ec54ea5f2a116ee952ca0e93df8a61c2e7ab74fdf9230396cf67ff5"}, - {file = "botocore_stubs-1.34.35.tar.gz", hash = "sha256:df5d8f369820a64cb84f124a2841b62030f3e86b6bd07949376efab31c3845c5"}, + {file = "botocore_stubs-1.34.42-py3-none-any.whl", hash = "sha256:7f0d783e5c84a378721e5cae0922bb8ef6aad11d748eda3fd48eeb8d22ee6e42"}, + {file = "botocore_stubs-1.34.42.tar.gz", hash = "sha256:d1de9b7f4fd1c3491846862fc635fea8484a2d4c6d92291d736bafed4ab0344b"}, ] [package.dependencies] @@ -762,13 +762,13 @@ files = [ [[package]] name = "httpcore" -version = "1.0.2" +version = "1.0.3" description = "A minimal low-level HTTP client." optional = false python-versions = ">=3.8" files = [ - {file = "httpcore-1.0.2-py3-none-any.whl", hash = "sha256:096cc05bca73b8e459a1fc3dcf585148f63e534eae4339559c9b8a8d6399acc7"}, - {file = "httpcore-1.0.2.tar.gz", hash = "sha256:9fc092e4799b26174648e54b74ed5f683132a464e95643b226e00c2ed2fa6535"}, + {file = "httpcore-1.0.3-py3-none-any.whl", hash = "sha256:9a6a501c3099307d9fd76ac244e08503427679b1e81ceb1d922485e2f2462ad2"}, + {file = "httpcore-1.0.3.tar.gz", hash = "sha256:5c0f9546ad17dac4d0772b0808856eb616eb8b48ce94f49ed819fd6982a8a544"}, ] [package.dependencies] @@ -779,7 +779,7 @@ h11 = ">=0.13,<0.15" asyncio = ["anyio (>=4.0,<5.0)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] -trio = ["trio (>=0.22.0,<0.23.0)"] +trio = ["trio (>=0.22.0,<0.24.0)"] [[package]] name = "httpx" @@ -1026,23 +1026,23 @@ files = [ [[package]] name = "pyinstaller" -version = "6.3.0" +version = "6.4.0" description = "PyInstaller bundles a Python application and all its dependencies into a single package." 
optional = false python-versions = "<3.13,>=3.8" files = [ - {file = "pyinstaller-6.3.0-py3-none-macosx_10_13_universal2.whl", hash = "sha256:75a6f2a6f835a2e6e0899d10e60c10caf5defd25aced38b1dd48fbbabc89de07"}, - {file = "pyinstaller-6.3.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:de25beb176f73a944758553caacec46cc665bf3910ad8a174706d79cf6e95340"}, - {file = "pyinstaller-6.3.0-py3-none-manylinux2014_i686.whl", hash = "sha256:e436fcc0ea87c3f132baac916d508c24c84a8f6d8a06c3154fbc753f169b76c7"}, - {file = "pyinstaller-6.3.0-py3-none-manylinux2014_ppc64le.whl", hash = "sha256:b721d793a33b6d9946c7dd95d3ea7589c0424b51cf1b9fe580f03c544f1336b2"}, - {file = "pyinstaller-6.3.0-py3-none-manylinux2014_s390x.whl", hash = "sha256:96c37a1ee5b2fd5bb25c098ef510661d6d17b6515d0b86d8fc93727dd2475ba3"}, - {file = "pyinstaller-6.3.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:abe91106a3bbccc3f3a27af4325676ecdb6f46cb842ac663625002a870fc503b"}, - {file = "pyinstaller-6.3.0-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:41c937fe8f07ae02009b3b5a96ac3eb0800a4f8a97af142d4100060fe2135bb9"}, - {file = "pyinstaller-6.3.0-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:886b3b995b674905a20ad5b720b47cc395897d7b391117831027a4c8c5d67a58"}, - {file = "pyinstaller-6.3.0-py3-none-win32.whl", hash = "sha256:0597fb04337695e5cc5250253e0655530bf14f264b7a5b7d219cc65f6889c4bd"}, - {file = "pyinstaller-6.3.0-py3-none-win_amd64.whl", hash = "sha256:156b32ba943e0090bcc68e40ae1cb68fd92b7f1ab6fe0bdf8faf3d3cfc4e12dd"}, - {file = "pyinstaller-6.3.0-py3-none-win_arm64.whl", hash = "sha256:1eadbd1fae84e2e6c678d8b4ed6a232ec5c8fe3a839aea5a3071c4c0282f98cc"}, - {file = "pyinstaller-6.3.0.tar.gz", hash = "sha256:914d4c96cc99472e37ac552fdd82fbbe09e67bb592d0717fcffaa99ea74273df"}, + {file = "pyinstaller-6.4.0-py3-none-macosx_10_13_universal2.whl", hash = "sha256:a2e63fa71784f290bbf79b31b60a27c45b17a18b8c7f910757f9474e0c12c95d"}, + {file = "pyinstaller-6.4.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:3127724d1841f785a9916d7b4cfd9595f359925e9ce7d137a16db8c29ca8453b"}, + {file = "pyinstaller-6.4.0-py3-none-manylinux2014_i686.whl", hash = "sha256:a37f83850cb150ad1e00fe92acecc4d39b8e10162a1850a5836a05fcb2daa870"}, + {file = "pyinstaller-6.4.0-py3-none-manylinux2014_ppc64le.whl", hash = "sha256:28b98fa3c74602bdc4c5a7698e907f31e714cc40a13f6358082bcbc74ddab35c"}, + {file = "pyinstaller-6.4.0-py3-none-manylinux2014_s390x.whl", hash = "sha256:3ae62cf8858ec4dc54df6fa03d29bc78297e3c87caf532887eae8c3893be0789"}, + {file = "pyinstaller-6.4.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e3e1e6922a4260dcacf6f5655b0ca857451e05ac502d01642935d0f2873ad3c7"}, + {file = "pyinstaller-6.4.0-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:78fb66ca753ef8becdf059eaa1e764d384cacb8c2ec76800126f8c9ef6d19a50"}, + {file = "pyinstaller-6.4.0-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:e07cff584600647af7dc279dd04c60cd1b4b1b41947b0753f8fcf1969300a583"}, + {file = "pyinstaller-6.4.0-py3-none-win32.whl", hash = "sha256:c7bc0fbea8a9010484cfa7d3856416003af73271f03ca3da4bc0eaf14680ad17"}, + {file = "pyinstaller-6.4.0-py3-none-win_amd64.whl", hash = "sha256:ec8a08c983e3febb0247893cd9bd59f55b6767a1f649cb41a0a129b8f04ff2cb"}, + {file = "pyinstaller-6.4.0-py3-none-win_arm64.whl", hash = "sha256:11e6da6a6e441379352ee460a8880f2633dac91dac0f5a9eeff5d449d459b046"}, + {file = "pyinstaller-6.4.0.tar.gz", hash = "sha256:1bf608ed947b58614711275a7ff169289b32560dc97ec748ebd5fa8bdec80649"}, ] [package.dependencies] @@ -1050,7 +1050,7 @@ altgraph = "*" 
macholib = {version = ">=1.8", markers = "sys_platform == \"darwin\""} packaging = ">=22.0" pefile = {version = ">=2022.5.30", markers = "sys_platform == \"win32\""} -pyinstaller-hooks-contrib = ">=2021.4" +pyinstaller-hooks-contrib = ">=2024.0" pywin32-ctypes = {version = ">=0.2.1", markers = "sys_platform == \"win32\""} setuptools = ">=42.0.0" @@ -1060,13 +1060,13 @@ hook-testing = ["execnet (>=1.5.0)", "psutil", "pytest (>=2.7.3)"] [[package]] name = "pyinstaller-hooks-contrib" -version = "2024.0" +version = "2024.1" description = "Community maintained hooks for PyInstaller" optional = false python-versions = ">=3.7" files = [ - {file = "pyinstaller-hooks-contrib-2024.0.tar.gz", hash = "sha256:a7118c1a5c9788595e5c43ad058a7a5b7b6d59e1eceb42362f6ec1f0b61986b0"}, - {file = "pyinstaller_hooks_contrib-2024.0-py2.py3-none-any.whl", hash = "sha256:469b5690df53223e2e8abffb2e44d6ee596e7d79d4b1eed9465123b67439875a"}, + {file = "pyinstaller-hooks-contrib-2024.1.tar.gz", hash = "sha256:51a51ea9e1ae6bd5ffa7ec45eba7579624bf4f2472ff56dba0edc186f6ed46a6"}, + {file = "pyinstaller_hooks_contrib-2024.1-py2.py3-none-any.whl", hash = "sha256:131494f9cfce190aaa66ed82e82c78b2723d1720ce64d012fbaf938f4ab01d35"}, ] [package.dependencies] @@ -1095,17 +1095,17 @@ testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "no [[package]] name = "pytest-asyncio" -version = "0.23.3" +version = "0.23.5" description = "Pytest support for asyncio" optional = false python-versions = ">=3.8" files = [ - {file = "pytest-asyncio-0.23.3.tar.gz", hash = "sha256:af313ce900a62fbe2b1aed18e37ad757f1ef9940c6b6a88e2954de38d6b1fb9f"}, - {file = "pytest_asyncio-0.23.3-py3-none-any.whl", hash = "sha256:37a9d912e8338ee7b4a3e917381d1c95bfc8682048cb0fbc35baba316ec1faba"}, + {file = "pytest-asyncio-0.23.5.tar.gz", hash = "sha256:3a048872a9c4ba14c3e90cc1aa20cbc2def7d01c7c8db3777ec281ba9c057675"}, + {file = "pytest_asyncio-0.23.5-py3-none-any.whl", hash = "sha256:4e7093259ba018d58ede7d5315131d21923a60f8a6e9ee266ce1589685c89eac"}, ] [package.dependencies] -pytest = ">=7.0.0" +pytest = ">=7.0.0,<9" [package.extras] docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] @@ -1268,18 +1268,18 @@ crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"] [[package]] name = "setuptools" -version = "69.0.3" +version = "69.1.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-69.0.3-py3-none-any.whl", hash = "sha256:385eb4edd9c9d5c17540511303e39a147ce2fc04bc55289c322b9e5904fe2c05"}, - {file = "setuptools-69.0.3.tar.gz", hash = "sha256:be1af57fc409f93647f2e8e4573a142ed38724b8cdd389706a867bb4efcf1e78"}, + {file = "setuptools-69.1.0-py3-none-any.whl", hash = "sha256:c054629b81b946d63a9c6e732bc8b2513a7c3ea645f11d0139a2191d735c60c6"}, + {file = "setuptools-69.1.0.tar.gz", hash = "sha256:850894c4195f09c4ed30dba56213bf7c3f21d86ed6bdaafb5df5972593bfc401"}, ] [package.extras] docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] -testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", 
"pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] +testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.1)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] [[package]] @@ -1387,13 +1387,13 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = "uvicorn" -version = "0.27.0.post1" +version = "0.27.1" description = "The lightning-fast ASGI server." optional = false python-versions = ">=3.8" files = [ - {file = "uvicorn-0.27.0.post1-py3-none-any.whl", hash = "sha256:4b85ba02b8a20429b9b205d015cbeb788a12da527f731811b643fd739ef90d5f"}, - {file = "uvicorn-0.27.0.post1.tar.gz", hash = "sha256:54898fcd80c13ff1cd28bf77b04ec9dbd8ff60c5259b499b4b12bb0917f22907"}, + {file = "uvicorn-0.27.1-py3-none-any.whl", hash = "sha256:5c89da2f3895767472a35556e539fd59f7edbe9b1e9c0e1c99eebeadc61838e4"}, + {file = "uvicorn-0.27.1.tar.gz", hash = "sha256:3d9a267296243532db80c83a959a3400502165ade2c1338dea4e67915fd4745a"}, ] [package.dependencies] diff --git a/pyproject.toml b/pyproject.toml index 4931fdf..f2cc3bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "sagemaker-shim" -version = "0.3.0a1" +version = "0.3.0a2" description = "Adapts algorithms that implement the Grand Challenge inference API for running in SageMaker" authors = ["James Meakin <12661555+jmsmkn@users.noreply.github.com>"] license = "Apache-2.0" diff --git a/sagemaker_shim/cli.py b/sagemaker_shim/cli.py index 38c59c8..a50ee35 100644 --- a/sagemaker_shim/cli.py +++ b/sagemaker_shim/cli.py @@ -14,7 +14,7 @@ from sagemaker_shim.app import app from sagemaker_shim.logging import LOGGING_CONFIG from sagemaker_shim.models import ( - DependentData, + AuxiliaryData, InferenceTaskList, get_s3_file_content, ) @@ -39,7 +39,7 @@ def cli() -> None: @cli.command(short_help="Start the model server") def serve() -> None: - with DependentData(): + with AuxiliaryData(): uvicorn.run( app=app, host="0.0.0.0", port=8080, log_config=None, workers=1 ) @@ -86,7 +86,7 @@ async def invoke(tasks: str, file: str) -> None: ) from error if parsed_tasks.root: - with DependentData(): + with AuxiliaryData(): for task in parsed_tasks.root: # Only run one task at a time await task.invoke() diff --git a/sagemaker_shim/models.py b/sagemaker_shim/models.py index cc9d5a1..1d85ceb 100644 --- a/sagemaker_shim/models.py +++ b/sagemaker_shim/models.py @@ -26,9 +26,11 @@ from sagemaker_shim.logging import STDOUT_LEVEL if TYPE_CHECKING: + from _typeshed import StrOrBytesPath # pragma: no cover from mypy_boto3_s3 import S3Client # pragma: no cover else: S3Client = object + StrOrBytesPath = object logger = logging.getLogger(__name__) @@ -40,10 +42,115 @@ ) -def get_s3_client() -> S3Client: - return boto3.client( - "s3", endpoint_url=os.environ.get("AWS_S3_ENDPOINT_URL") - ) +class UserInfo(NamedTuple): + uid: int | None + gid: int | None + home: str | None + groups: list[int] + + 
+class ProcUserMixin:
+    @property
+    def _user(self) -> str:
+        user = os.environ.get("GRAND_CHALLENGE_COMPONENT_USER", "")
+        logger.debug(f"{user=}")
+        return user
+
+    @cached_property
+    def proc_user(self) -> UserInfo:
+        if self._user == "":
+            return UserInfo(uid=None, gid=None, home=None, groups=[])
+
+        match = re.fullmatch(
+            r"^(?P<user>[0-9a-zA-Z]*):?(?P<group>[0-9a-zA-Z]*)$", self._user
+        )
+
+        if match:
+            info = self._get_user_info(id_or_name=match.group("user"))
+            group_id = self._get_group_id(id_or_name=match.group("group"))
+
+            gid = info.gid if group_id is None else group_id
+
+            return UserInfo(
+                uid=info.uid,
+                gid=gid,
+                home=info.home,
+                groups=self._put_gid_first(gid=gid, groups=info.groups),
+            )
+        else:
+            raise RuntimeError(f"Invalid user '{self._user}'")
+
+    @classmethod
+    def _get_user_info(cls, id_or_name: str) -> UserInfo:
+        if id_or_name == "":
+            return UserInfo(uid=None, gid=None, home=None, groups=[])
+
+        try:
+            user = pwd.getpwnam(id_or_name)
+        except (KeyError, AttributeError):
+            try:
+                uid = int(id_or_name)
+            except ValueError as error:
+                raise RuntimeError(f"User '{id_or_name}' not found") from error
+
+            try:
+                user = pwd.getpwuid(uid)
+            except (KeyError, AttributeError):
+                return UserInfo(uid=uid, gid=None, home=None, groups=[])
+
+        return UserInfo(
+            uid=user.pw_uid,
+            gid=user.pw_gid,
+            home=user.pw_dir,
+            groups=cls._get_users_groups(user=user),
+        )
+
+    @classmethod
+    def _get_users_groups(cls, *, user: pwd.struct_passwd) -> list[int]:
+        users_groups = [
+            g.gr_gid for g in grp.getgrall() if user.pw_name in g.gr_mem
+        ]
+        return cls._put_gid_first(gid=user.pw_gid, groups=users_groups)
+
+    @staticmethod
+    def _put_gid_first(*, gid: int | None, groups: list[int]) -> list[int]:
+        if gid is None:
+            return groups
+        else:
+            user_groups = set(groups)
+
+            try:
+                user_groups.remove(gid)
+            except KeyError:
+                pass
+
+            return [gid, *sorted(user_groups)]
+
+    @staticmethod
+    def _get_group_id(id_or_name: str) -> int | None:
+        if id_or_name == "":
+            return None
+
+        try:
+            return grp.getgrnam(id_or_name).gr_gid
+        except (KeyError, AttributeError):
+            try:
+                return int(id_or_name)
+            except ValueError as error:
+                raise RuntimeError(
+                    f"Group '{id_or_name}' not found"
+                ) from error
+
+
+def clean_path(path: Path) -> None:
+    for f in path.glob("*"):
+        if f.is_file():
+            f.chmod(0o700)
+            f.unlink()
+        elif f.is_dir():
+            f.chmod(0o700)
+            clean_path(f)
+            f.rmdir()


 class S3File(NamedTuple):
@@ -51,6 +158,12 @@ class S3File(NamedTuple):
     key: str


+def get_s3_client() -> S3Client:
+    return boto3.client(
+        "s3", endpoint_url=os.environ.get("AWS_S3_ENDPOINT_URL")
+    )
+
+
 def parse_s3_uri(*, s3_uri: str) -> S3File:
     pattern = r"^(https|s3)://(?P<bucket>[^/]+)/?(?P<key>.*)$"
     match = re.fullmatch(pattern, s3_uri)
@@ -77,6 +190,42 @@ def get_s3_file_content(*, s3_uri: str) -> bytes:
     return content.read()


+def validate_bucket_name(v: str) -> str:
+    if BUCKET_NAME_REGEX.match(v) or BUCKET_ARN_REGEX.match(v):
+        return v
+    else:
+        raise ValueError("Invalid bucket name")
+
+
+class ProcUserTarfile(ProcUserMixin, tarfile.TarFile):
+    """
+    A tarfile that sets the owner of the extracted files to the user and group
+    specified in the GRAND_CHALLENGE_COMPONENT_USER environment variable.
+    """
+
+    def chown(
+        self,
+        tarinfo: tarfile.TarInfo,
+        targetpath: StrOrBytesPath,
+        numeric_owner: bool,
+    ) -> None:
+        if self.proc_user.uid is None and self.proc_user.gid is None:
+            # No user or group specified, use the default
+            return super().chown(
+                tarinfo=tarinfo,
+                targetpath=targetpath,
+                numeric_owner=numeric_owner,
+            )
+        else:
+            # Do not change owner if the user or group is not set
+            uid = -1 if self.proc_user.uid is None else self.proc_user.uid
+            gid = -1 if self.proc_user.gid is None else self.proc_user.gid
+
+            logger.debug(f"Changing owner of {targetpath=} to {uid=}, {gid=}")
+
+            os.chown(path=targetpath, uid=uid, gid=gid)
+
+
 def download_and_extract_tarball(*, s3_uri: str, dest: Path) -> None:
     s3_file = parse_s3_uri(s3_uri=s3_uri)
     s3_client = get_s3_client()
@@ -90,29 +239,11 @@ def download_and_extract_tarball(*, s3_uri: str, dest: Path) -> None:

     f.seek(0)

-    with tarfile.open(fileobj=f, mode="r") as tar:
+    with ProcUserTarfile.open(fileobj=f, mode="r") as tar:
         tar.extractall(path=dest, filter="data")


-def validate_bucket_name(v: str) -> str:
-    if BUCKET_NAME_REGEX.match(v) or BUCKET_ARN_REGEX.match(v):
-        return v
-    else:
-        raise ValueError("Invalid bucket name")
-
-
-def clean_path(path: Path) -> None:
-    for f in path.glob("*"):
-        if f.is_file():
-            f.chmod(0o700)
-            f.unlink()
-        elif f.is_dir():
-            f.chmod(0o700)
-            clean_path(f)
-            f.rmdir()
-
-
-class DependentData:
+class AuxiliaryData:
     @property
     def model_source(self) -> str | None:
         """s3 URI to a .tar.gz file that is extracted to model_dest"""
@@ -172,8 +303,8 @@ def post_clean_directories(self) -> list[Path]:
         logger.debug(f"{post_clean_directories=}")
         return post_clean_directories

-    def __enter__(self) -> "DependentData":
-        logger.info("Setting up Dependent Data")
+    def __enter__(self) -> "AuxiliaryData":
+        logger.info("Setting up Auxiliary Data")
         self.ensure_directories_are_writable()
         self.download_model()
         self.download_ground_truth()
@@ -185,7 +316,7 @@ def __exit__(
         exc_val: BaseException | None,
         exc_tb: TracebackType | None,
     ) -> None:
-        logger.info("Cleaning up Dependent Data")
+        logger.info("Cleaning up Auxiliary Data")
         for p in self.post_clean_directories:
             logger.info(f"Cleaning {p=}")
             clean_path(p)
@@ -301,73 +432,7 @@ class InferenceResult(BaseModel):
     sagemaker_shim_version: str = version("sagemaker-shim")


-class UserInfo(NamedTuple):
-    uid: int | None
-    gid: int | None
-    home: str | None
-    groups: list[int]
-
-
-def _get_user_info(id_or_name: str) -> UserInfo:
-    if id_or_name == "":
-        return UserInfo(uid=None, gid=None, home=None, groups=[])
-
-    try:
-        user = pwd.getpwnam(id_or_name)
-    except (KeyError, AttributeError):
-        try:
-            uid = int(id_or_name)
-        except ValueError as error:
-            raise RuntimeError(f"User '{id_or_name}' not found") from error
-
-        try:
-            user = pwd.getpwuid(uid)
-        except (KeyError, AttributeError):
-            return UserInfo(uid=uid, gid=None, home=None, groups=[])
-
-    return UserInfo(
-        uid=user.pw_uid,
-        gid=user.pw_gid,
-        home=user.pw_dir,
-        groups=_get_users_groups(user=user),
-    )
-
-
-def _get_users_groups(*, user: pwd.struct_passwd) -> list[int]:
-    users_groups = [
-        g.gr_gid for g in grp.getgrall() if user.pw_name in g.gr_mem
-    ]
-    return _put_gid_first(gid=user.pw_gid, groups=users_groups)
-
-
-def _put_gid_first(*, gid: int | None, groups: list[int]) -> list[int]:
-    if gid is None:
-        return groups
-    else:
-        user_groups = set(groups)
-
-        try:
-            user_groups.remove(gid)
-        except KeyError:
-            pass
-
-        return [gid, *sorted(user_groups)]
-
-
-def _get_group_id(id_or_name: str) -> int | None:
-    if id_or_name == "":
-        return None
-
-    try:
-        return grp.getgrnam(id_or_name).gr_gid
-    except (KeyError, AttributeError):
-        try:
-            return int(id_or_name)
-        except ValueError as error:
-            raise RuntimeError(f"Group '{id_or_name}' not found") from error
-
-
-class InferenceTask(BaseModel):
+class InferenceTask(ProcUserMixin, BaseModel):
     model_config = ConfigDict(frozen=True)

     pk: str
@@ -419,10 +484,6 @@ def entrypoint(self) -> Any:
         logger.debug(f"{entrypoint=}")
         return entrypoint

-    @property
-    def user(self) -> str:
-        return os.environ.get("GRAND_CHALLENGE_COMPONENT_USER", "")
-
     @property
     def input_path(self) -> Path:
         """Local path where the subprocess is expected to read its input files"""
@@ -507,30 +568,6 @@ def proc_env(self) -> dict[str, str]:

         return env

-    @cached_property
-    def proc_user(self) -> UserInfo:
-        if self.user == "":
-            return UserInfo(uid=None, gid=None, home=None, groups=[])
-
-        match = re.fullmatch(
-            r"^(?P<user>[0-9a-zA-Z]*):?(?P<group>[0-9a-zA-Z]*)$", self.user
-        )
-
-        if match:
-            info = _get_user_info(id_or_name=match.group("user"))
-            group_id = _get_group_id(id_or_name=match.group("group"))
-
-            gid = info.gid if group_id is None else group_id
-
-            return UserInfo(
-                uid=info.uid,
-                gid=gid,
-                home=info.home,
-                groups=_put_gid_first(gid=gid, groups=info.groups),
-            )
-        else:
-            raise RuntimeError(f"Invalid user '{self.user}'")
-
     async def invoke(self) -> InferenceResult:
         """Run the inference on a single case"""
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 620685b..09a069f 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -194,8 +194,8 @@ def test_logging_setup(minio, monkeypatch):
         '{"log": "hello", "level": "INFO", '
         f'"source": "stdout", "internal": false, "task": "{pk}"}}'
     ) in result.output
-    assert "Setting up Dependent Data" in result.output
-    assert "Cleaning up Dependent Data" in result.output
+    assert "Setting up Auxiliary Data" in result.output
+    assert "Cleaning up Auxiliary Data" in result.output


 def test_logging_stderr_setup(minio, monkeypatch):
diff --git a/tests/test_models.py b/tests/test_models.py
index 18eb142..cdafbc4 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -10,15 +10,33 @@
 import pytest

 from sagemaker_shim.models import (
-    DependentData,
+    AuxiliaryData,
     InferenceTask,
-    _get_users_groups,
-    _put_gid_first,
+    ProcUserMixin,
+    ProcUserTarfile,
     get_s3_client,
     validate_bucket_name,
 )


+@pytest.fixture
+def algorithm_model():
+    model_f = io.BytesIO()
+    with tarfile.open(fileobj=model_f, mode="w:gz") as tar:
+        content = b"Hello, World!"
+ file_info = tarfile.TarInfo("model-file1.txt") + file_info.size = len(content) + tar.addfile(file_info, io.BytesIO(content)) + + file_info = tarfile.TarInfo("model-sub/model-file2.txt") + file_info.size = len(content) + tar.addfile(file_info, io.BytesIO(content)) + + model_f.seek(0) + + return model_f + + def test_invalid_bucket_name(): with pytest.raises(ValueError): validate_bucket_name("$!#") @@ -74,9 +92,11 @@ def test_removing_ld_library_path(monkeypatch): ROOT_HOME = pwd.getpwnam("root").pw_dir -ROOT_GROUPS = _get_users_groups(user=pwd.getpwnam("root")) +ROOT_GROUPS = ProcUserMixin._get_users_groups(user=pwd.getpwnam("root")) USER_HOME = os.path.expanduser("~") -USER_GROUPS = _get_users_groups(user=pwd.getpwnam(getpass.getuser())) +USER_GROUPS = ProcUserMixin._get_users_groups( + user=pwd.getpwnam(getpass.getuser()) +) @pytest.mark.parametrize( @@ -97,7 +117,7 @@ def test_removing_ld_library_path(monkeypatch): 0, os.getgid(), ROOT_HOME, - _put_gid_first(gid=os.getgid(), groups=ROOT_GROUPS), + ProcUserMixin._put_gid_first(gid=os.getgid(), groups=ROOT_GROUPS), ), # User exists (f"{os.getuid()}", os.getuid(), os.getgid(), USER_HOME, USER_GROUPS), @@ -114,14 +134,14 @@ def test_removing_ld_library_path(monkeypatch): os.getuid(), 23746, USER_HOME, - _put_gid_first(gid=23746, groups=USER_GROUPS), + ProcUserMixin._put_gid_first(gid=23746, groups=USER_GROUPS), ), ( f"{getpass.getuser()}:23746", os.getuid(), 23746, USER_HOME, - _put_gid_first(gid=23746, groups=USER_GROUPS), + ProcUserMixin._put_gid_first(gid=23746, groups=USER_GROUPS), ), # User does not exist, but is an int ("23746", 23746, None, None, []), @@ -138,6 +158,7 @@ def test_proc_user( expected_group, expected_home, expected_extra_groups, + algorithm_model, ): monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_USER", user) @@ -145,21 +166,40 @@ def test_proc_user( pk="test", inputs=[], output_bucket_name="test", output_prefix="test" ) - assert t.user == user + assert t._user == user assert t.proc_user.uid == expected_user assert t.proc_user.gid == expected_group assert t.proc_user.home == expected_home assert t.extra_groups == expected_extra_groups assert None not in t.extra_groups + with ProcUserTarfile.open(fileobj=algorithm_model, mode="r") as tar: + assert tar.proc_user == t.proc_user + def test_put_gid_first(): # Setting None leaves the groups unmodified - assert _put_gid_first(gid=None, groups=[2, 1, 3, 4]) == [2, 1, 3, 4] + assert ProcUserMixin._put_gid_first(gid=None, groups=[2, 1, 3, 4]) == [ + 2, + 1, + 3, + 4, + ] # Setting an existing group puts it first and orders the rest - assert _put_gid_first(gid=3, groups=[2, 1, 3, 4]) == [3, 1, 2, 4] + assert ProcUserMixin._put_gid_first(gid=3, groups=[2, 1, 3, 4]) == [ + 3, + 1, + 2, + 4, + ] # Adding a group puts it first and orders the rest - assert _put_gid_first(gid=5, groups=[2, 1, 3, 4]) == [5, 1, 2, 3, 4] + assert ProcUserMixin._put_gid_first(gid=5, groups=[2, 1, 3, 4]) == [ + 5, + 1, + 2, + 3, + 4, + ] # Should error @@ -197,7 +237,7 @@ def test_proc_user_errors(monkeypatch, user, expected_error): pk="test", inputs=[], output_bucket_name="test", output_prefix="test" ) - assert t.user == user + assert t._user == user with pytest.raises(RuntimeError) as error: _ = t.proc_user @@ -205,16 +245,19 @@ def test_proc_user_errors(monkeypatch, user, expected_error): assert str(error.value) == expected_error -def test_proc_user_unset(): +def test_proc_user_unset(algorithm_model): t = InferenceTask( pk="test", inputs=[], output_bucket_name="test", output_prefix="test" ) - assert 
t.user == "" + assert t._user == "" assert t.proc_user.uid is None assert t.proc_user.gid is None assert t.proc_user.home is None + with ProcUserTarfile.open(fileobj=algorithm_model, mode="r") as tar: + assert tar.proc_user == t.proc_user + def test_home_is_set(monkeypatch): monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_USER", "root") @@ -226,26 +269,15 @@ def test_home_is_set(monkeypatch): assert t.proc_env["HOME"] == pwd.getpwnam("root").pw_dir -def test_model_and_ground_truth_extraction(minio, monkeypatch, tmp_path): +def test_model_and_ground_truth_extraction( + minio, monkeypatch, tmp_path, mocker, algorithm_model +): s3_client = get_s3_client() model_pk = str(uuid4()) - model_f = io.BytesIO() - with tarfile.open(fileobj=model_f, mode="w:gz") as tar: - content = b"Hello, World!" - file_info = tarfile.TarInfo("model-file1.txt") - file_info.size = len(content) - tar.addfile(file_info, io.BytesIO(content)) - - file_info = tarfile.TarInfo("model-sub/model-file2.txt") - file_info.size = len(content) - tar.addfile(file_info, io.BytesIO(content)) - - model_f.seek(0) - s3_client.upload_fileobj( - model_f, minio.input_bucket_name, f"{model_pk}/model.tar.gz" + algorithm_model, minio.input_bucket_name, f"{model_pk}/model.tar.gz" ) ground_truth_pk = str(uuid4()) @@ -292,7 +324,9 @@ def test_model_and_ground_truth_extraction(minio, monkeypatch, tmp_path): f"{model_destination}:{ground_truth_destination}", ) - with DependentData(): + spy = mocker.spy(ProcUserTarfile, "chown") + + with AuxiliaryData(): downloaded_files = { str(f.relative_to(tmp_path)) for f in tmp_path.rglob("**/*") @@ -312,9 +346,12 @@ def test_model_and_ground_truth_extraction(minio, monkeypatch, tmp_path): "ground_truth", } + # We cannot test chown as you need to be root, but we can test that it was called + assert spy.call_count == 4 + def test_ensure_directories_are_writable_unset(): - with DependentData() as d: + with AuxiliaryData() as d: assert d.writable_directories == [] assert d.post_clean_directories == [] assert d.model_source is None @@ -341,7 +378,7 @@ def test_ensure_directories_are_writable_set( directories, ) - d = DependentData() + d = AuxiliaryData() assert d.writable_directories == expected @@ -363,7 +400,7 @@ def test_ensure_directories_are_writable(tmp_path, monkeypatch): f"{data.absolute()}:{model.absolute()}:{checkpoints.absolute()}:{tmp.absolute()}", ) - with DependentData(): + with AuxiliaryData(): pass assert data.stat().st_mode == 0o40777
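
For reference, a minimal usage sketch of the new ProcUserTarfile (not part of the patch above). It assumes the sagemaker_shim.models module from this change is importable; the in-memory tarball and the /tmp destination are invented for the example, and the behaviour described in the comments follows from this patch plus the stock tarfile/os.chown semantics.

# Illustrative sketch -- not the shim's own entrypoint. The tarball contents
# and the extraction directory below are hypothetical.
import io
import os
import tarfile
from pathlib import Path

from sagemaker_shim.models import ProcUserTarfile

# Build a tiny gzipped tarball in memory, mirroring the algorithm_model
# fixture added in this PR.
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w:gz") as tar:
    payload = b"Hello, World!"
    info = tarfile.TarInfo("model-file1.txt")
    info.size = len(payload)
    tar.addfile(info, io.BytesIO(payload))
buf.seek(0)

# With GRAND_CHALLENGE_COMPONENT_USER unset, proc_user is all None and the
# overridden chown() defers to tarfile's default handling, which only
# changes ownership when running as root.
os.environ.pop("GRAND_CHALLENGE_COMPONENT_USER", None)

with ProcUserTarfile.open(fileobj=buf, mode="r") as tar:
    tar.extractall(path=Path("/tmp/extracted"), filter="data")

# If the variable were set (e.g. "1000:1000" or "myuser:"), chown() would
# call os.chown on each extracted path, passing -1 for whichever id is not
# specified so that part of the ownership is left unchanged. That call
# needs sufficient privileges, which is why the tests only spy on it.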