From 43b1f897dc3f8c538baf35f123287b26f70368ef Mon Sep 17 00:00:00 2001 From: yihanzhao Date: Tue, 26 Nov 2024 09:37:26 +1100 Subject: [PATCH 1/7] introduce device_manager --- src/marqo/config.py | 3 + src/marqo/core/embed/embed.py | 2 + src/marqo/core/exceptions.py | 12 +++ src/marqo/core/inference/device_manager.py | 76 +++++++++++++++++++ src/marqo/core/models/add_docs_params.py | 1 + src/marqo/core/monitoring/monitoring.py | 1 + src/marqo/tensor_search/api.py | 1 + src/marqo/tensor_search/on_start_script.py | 4 + src/marqo/tensor_search/tensor_search.py | 2 + src/marqo/tensor_search/utils.py | 2 + src/marqo/tensor_search/web/api_utils.py | 1 + src/marqo/tensor_search/web/api_validation.py | 2 + 12 files changed, 107 insertions(+) create mode 100644 src/marqo/core/inference/device_manager.py diff --git a/src/marqo/config.py b/src/marqo/config.py index d89926499..d2459244e 100644 --- a/src/marqo/config.py +++ b/src/marqo/config.py @@ -2,6 +2,7 @@ from kazoo.handlers.threading import KazooTimeoutError +from marqo.core.inference.device_manager import DeviceManager from marqo.vespa.zookeeper_client import ZookeeperClient from marqo.core.document.document import Document from marqo.core.embed.embed import Embed @@ -39,6 +40,7 @@ def __init__( self.timeout = timeout self.backend = backend if backend is not None else enums.SearchDb.vespa + # TODO [Refactoring device logic] deprecate default_device since it's not used self.default_device = default_device if default_device is not None else ( utils.read_env_vars_and_defaults(EnvVars.MARQO_BEST_AVAILABLE_DEVICE)) @@ -52,6 +54,7 @@ def __init__( self.document = Document(vespa_client, self.index_management) self.recommender = Recommender(vespa_client, self.index_management) self.embed = Embed(vespa_client, self.index_management, self.default_device) + self.device_manager = DeviceManager() def set_is_remote(self, vespa_client: VespaClient): local_host_markers = ["localhost", "0.0.0.0", "127.0.0.1"] diff --git a/src/marqo/core/embed/embed.py b/src/marqo/core/embed/embed.py index bead85221..0493192ef 100644 --- a/src/marqo/core/embed/embed.py +++ b/src/marqo/core/embed/embed.py @@ -29,6 +29,7 @@ def __init__(self, vespa_client: VespaClient, index_management: IndexManagement, self.default_device = default_device @pydantic.validator('default_device') + # TODO [Refactoring device logic] deprecate default_device since it's not used def validate_default_device(cls, value): if not value: raise ValueError("Default Device cannot be 'None'. 
Marqo default device must have been declared upon startup.") @@ -66,6 +67,7 @@ def embed_content( ) # Set default device if not provided + # TODO [Refactoring device logic] use device info gathered from device manager if device is None: device = utils.read_env_vars_and_defaults("MARQO_BEST_AVAILABLE_DEVICE") diff --git a/src/marqo/core/exceptions.py b/src/marqo/core/exceptions.py index 7467da5e3..5164b4b98 100644 --- a/src/marqo/core/exceptions.py +++ b/src/marqo/core/exceptions.py @@ -112,3 +112,15 @@ class DuplicateDocumentError(AddDocumentsError): class TooManyFieldsError(MarqoError): pass + + +class DeviceError(MarqoError): + pass + + +class CudaDeviceNotAvailableError(DeviceError): + pass + + +class CudaOutOfMemoryError(DeviceError): + pass diff --git a/src/marqo/core/inference/device_manager.py b/src/marqo/core/inference/device_manager.py new file mode 100644 index 000000000..65a9c075f --- /dev/null +++ b/src/marqo/core/inference/device_manager.py @@ -0,0 +1,76 @@ +from enum import Enum +from typing import List + +import psutil +import torch +from pydantic import BaseModel + +from marqo.core.exceptions import CudaDeviceNotAvailableError, CudaOutOfMemoryError +from marqo.logging import get_logger + +logger = get_logger('device_manager') + + +class DeviceType(str, Enum): + cpu = 'cpu' + cuda = 'cuda' + + +class Device(BaseModel): + id: int + name: str + type: DeviceType + total_memory: int + + @classmethod + def cpu(cls) -> 'Device': + return Device(id=-1, name='cpu', type=DeviceType.cpu, total_memory=psutil.virtual_memory().total) + + @classmethod + def cuda(cls, device_id, name, total_memory) -> 'Device': + return Device(id=device_id, name=name, type=DeviceType.cuda, total_memory=total_memory) + + +class DeviceManager: + def __init__(self): + self._is_cuda_available_at_startup: bool = torch.cuda.is_available() + self.devices: List[Device] = [Device.cpu()] + self.best_available_device_type = DeviceType.cpu + + if self._is_cuda_available_at_startup: + self.best_available_device_type = DeviceType.cuda + device_count = torch.cuda.device_count() + for device_id in range(device_count): + self.devices.append(Device.cuda(device_id, + torch.cuda.get_device_name(device_id), + torch.cuda.get_device_properties(device_id).total_memory)) + + logger.debug(f'Found devices {self.devices}. Best available device set to: ' + f'{self.best_available_device_type.value}.') + + @property + def cuda_devices(self): + return [device for device in self.devices if device.type == DeviceType.cuda] + + def cuda_device_health_check(self) -> None: + if not self._is_cuda_available_at_startup: + return + + if not torch.cuda.is_available(): + logger.error('Cuda device becomes unavailable.') + raise CudaDeviceNotAvailableError('Cuda device becomes unavailable.') + + for device in self.cuda_devices: + cuda_device = torch.device(device.name) + memory_stats = torch.cuda.memory_stats(cuda_device) + logger.debug(f'Cuda device {device.name} with total memory {device.total_memory}. ' + f'Memory stats: {memory_stats}') + + try: + torch.randn(3).to(cuda_device) + except RuntimeError as e: + if 'out of memory' in str(e).lower(): + logger.error(f'Cuda device {device.name} is out of memory. Total memory: {device.total_memory}. ' + f'Memory stats: {memory_stats}') + raise CudaOutOfMemoryError(f'Cuda device {device.name} is out of memory. 
' + f'({memory_stats["allocated.all.current"]}/{device.total_memory})') diff --git a/src/marqo/core/models/add_docs_params.py b/src/marqo/core/models/add_docs_params.py index 66cf12185..6b551c12d 100644 --- a/src/marqo/core/models/add_docs_params.py +++ b/src/marqo/core/models/add_docs_params.py @@ -62,6 +62,7 @@ class Config: batch_vectorisation_mode: BatchVectorisationMode = BatchVectorisationMode.PER_DOCUMENT def __init__(self, **data: Any): + # TODO [Refactoring device logic] use device info gathered from device manager # Ensure `None` and passing nothing are treated the same for device if "device" not in data or data["device"] is None: data["device"] = get_best_available_device() diff --git a/src/marqo/core/monitoring/monitoring.py b/src/marqo/core/monitoring/monitoring.py index 8a941603d..205686a9e 100644 --- a/src/marqo/core/monitoring/monitoring.py +++ b/src/marqo/core/monitoring/monitoring.py @@ -154,6 +154,7 @@ def _get_vespa_health(self, hostname_filter: Optional[str]) -> VespaHealthStatus ) def get_cuda_info(self) -> MarqoCudaInfoResponse: + # TODO [Refactoring device logic] move this logic to device manager """A function to get information about the CUDA devices on the machine Returns: diff --git a/src/marqo/tensor_search/api.py b/src/marqo/tensor_search/api.py index 06b62e6ac..273a6863a 100644 --- a/src/marqo/tensor_search/api.py +++ b/src/marqo/tensor_search/api.py @@ -112,6 +112,7 @@ def marqo_base_exception_handler(request: Request, exc: base_exceptions.MarqoErr (core_exceptions.InternalError, api_exceptions.InternalError, None, None), (core_exceptions.ApplicationRollbackError, api_exceptions.ApplicationRollbackError, None, None), (core_exceptions.TooManyFieldsError, api_exceptions.BadRequestError, None, None), + (core_exceptions.DeviceError, api_exceptions.InternalError, None, None), # Vespa client exceptions ( diff --git a/src/marqo/tensor_search/on_start_script.py b/src/marqo/tensor_search/on_start_script.py index 73cc21ddb..39bcf5ac0 100644 --- a/src/marqo/tensor_search/on_start_script.py +++ b/src/marqo/tensor_search/on_start_script.py @@ -85,6 +85,7 @@ def run(self): class CUDAAvailable: + # TODO [Refactoring device logic] move this logic to device manager """checks the status of cuda """ logger = get_logger('CUDA device summary') @@ -109,6 +110,7 @@ def id_to_device(id): class SetBestAvailableDevice: + # TODO [Refactoring device logic] move this logic to device manager, get rid of MARQO_BEST_AVAILABLE_DEVICE envvar """sets the MARQO_BEST_AVAILABLE_DEVICE env var """ logger = get_logger('SetBestAvailableDevice') @@ -151,6 +153,7 @@ def __init__(self): self.models = warmed_models # TBD to include cross-encoder/ms-marco-TinyBERT-L-2-v2 + # TODO [Refactoring device logic] use device info gathered from device manager self.default_devices = ['cpu'] if not torch.cuda.is_available() else ['cuda', 'cpu'] self.logger.info(f"pre-loading {self.models} onto devices={self.default_devices}") @@ -230,6 +233,7 @@ def __init__(self): f"Invalid patch model: {model}. Please ensure that this is a valid patch model." 
) + # TODO [Refactoring device logic] use device info gathered from device manager self.default_devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] def run(self): diff --git a/src/marqo/tensor_search/tensor_search.py b/src/marqo/tensor_search/tensor_search.py index 8d893e823..a3a762b57 100644 --- a/src/marqo/tensor_search/tensor_search.py +++ b/src/marqo/tensor_search/tensor_search.py @@ -1557,6 +1557,7 @@ def search(config: Config, index_name: str, text: Optional[Union[str, dict, Cust if verbose: print(f"determined_search_method: {search_method}, text query: {text}") + # TODO [Refactoring device logic] use device info gathered from device manager if device is None: selected_device = utils.read_env_vars_and_defaults("MARQO_BEST_AVAILABLE_DEVICE") if selected_device is None: @@ -2263,6 +2264,7 @@ def eject_model(model_name: str, device: str) -> dict: return result +# TODO [Refactoring device logic] move to device manager def get_cpu_info() -> dict: return { "cpu_usage_percent": f"{psutil.cpu_percent(1)} %", # The number 1 is a time interval for CPU usage calculation. diff --git a/src/marqo/tensor_search/utils.py b/src/marqo/tensor_search/utils.py index be10e9216..3ea6c8bfe 100644 --- a/src/marqo/tensor_search/utils.py +++ b/src/marqo/tensor_search/utils.py @@ -88,6 +88,7 @@ def construct_authorized_url(url_base: str, username: str, password: str) -> str def check_device_is_available(device: str) -> bool: + # TODO [Refactoring device logic] move this logic to device manager """Checks if a device is available on the machine Args: @@ -341,6 +342,7 @@ def generate_batches(seq: Sequence, batch_size: int): def get_best_available_device() -> str: + # TODO [Refactoring device logic] replace this with device manager """Get the best available device for Marqo to use and validate it.""" device = read_env_vars_and_defaults(EnvVars.MARQO_BEST_AVAILABLE_DEVICE) if device is None or not check_device_is_available(device): diff --git a/src/marqo/tensor_search/web/api_utils.py b/src/marqo/tensor_search/web/api_utils.py index 1db7cea68..6bc269781 100644 --- a/src/marqo/tensor_search/web/api_utils.py +++ b/src/marqo/tensor_search/web/api_utils.py @@ -10,6 +10,7 @@ def translate_api_device(device: Optional[str]) -> Optional[str]: + # TODO [Refactoring device logic] move this logic to device manager """Translates an API device as given through the API into an internal enum. Args: diff --git a/src/marqo/tensor_search/web/api_validation.py b/src/marqo/tensor_search/web/api_validation.py index 0f2d85f67..09e1e8a65 100644 --- a/src/marqo/tensor_search/web/api_validation.py +++ b/src/marqo/tensor_search/web/api_validation.py @@ -5,6 +5,7 @@ def validate_api_device_string(device: typing.Optional[str]) -> typing.Optional[str]: + # TODO [Refactoring device logic] move this logic to device manager """Validates a device which is an API parameter Args: @@ -47,6 +48,7 @@ def validate_api_device_string(device: typing.Optional[str]) -> typing.Optional[ async def validate_device(device: typing.Optional[str] = None) -> typing.Optional[str]: + # TODO [Refactoring device logic] move this logic to device manager """Translates and validates the device string. Checks if the requested device is available. 
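
This first patch wires a DeviceManager instance into Config and adds the CUDA-specific exception types. The snippet below is a minimal, hypothetical sketch of how the new class could be exercised directly; the try/except wiring and the print calls are illustrative only and are not part of the patch itself.

    from marqo.core.exceptions import CudaDeviceNotAvailableError, CudaOutOfMemoryError
    from marqo.core.inference.device_manager import DeviceManager

    # Illustrative usage only: enumerate devices and run the CUDA health check.
    device_manager = DeviceManager()
    print(f"Best available device type: {device_manager.best_available_device_type.value}")
    print(f"CUDA devices found: {device_manager.cuda_devices}")

    try:
        # No-op on CPU-only hosts; raises if CUDA devices disappear or run out of memory.
        device_manager.cuda_device_health_check()
    except CudaDeviceNotAvailableError:
        # The container has lost access to the CUDA device symlinks; a restart is needed.
        raise
    except CudaOutOfMemoryError as e:
        # The message includes allocated/total memory gathered from torch.cuda.memory_stats.
        print(f"GPU memory pressure detected: {e}")
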
From 36eb39d0b587b29f5759c5460abc7f1c9e0328dd Mon Sep 17 00:00:00 2001 From: yihanzhao Date: Wed, 27 Nov 2024 16:30:11 +1100 Subject: [PATCH 2/7] expose /healthz for liveness check --- src/marqo/core/inference/device_manager.py | 15 +++++++++++++++ src/marqo/tensor_search/api.py | 14 ++++++++++---- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/src/marqo/core/inference/device_manager.py b/src/marqo/core/inference/device_manager.py index 65a9c075f..c7ec2317b 100644 --- a/src/marqo/core/inference/device_manager.py +++ b/src/marqo/core/inference/device_manager.py @@ -53,14 +53,27 @@ def cuda_devices(self): return [device for device in self.devices if device.type == DeviceType.cuda] def cuda_device_health_check(self) -> None: + """ + Checks the status of the CUDA devices, and raises exceptions if it becomes + not available or out of memory. + + raises + - CudaDeviceNotAvailableError if CUDA device is not available. + - CudaOutOfMemoryError if any CUDA device is out of memory. + """ if not self._is_cuda_available_at_startup: + # If the instance is initialised without cuda devices, skip the check return if not torch.cuda.is_available(): + # CUDA devices could become unavailable/unreachable if the docker container running Marqo loses access + # to the device symlinks. There is no way to recover from this, we will need to restart the container. + # See https://github.com/NVIDIA/nvidia-container-toolkit/issues/48 for more details. logger.error('Cuda device becomes unavailable.') raise CudaDeviceNotAvailableError('Cuda device becomes unavailable.') for device in self.cuda_devices: + # TODO confirm whether we should check all devices or just the default one cuda_device = torch.device(device.name) memory_stats = torch.cuda.memory_stats(cuda_device) logger.debug(f'Cuda device {device.name} with total memory {device.total_memory}. ' @@ -70,6 +83,8 @@ def cuda_device_health_check(self) -> None: torch.randn(3).to(cuda_device) except RuntimeError as e: if 'out of memory' in str(e).lower(): + # If we encounter 'CUDA error: out of memory' error consistently, it means some threads are + # holding the memory logger.error(f'Cuda device {device.name} is out of memory. Total memory: {device.total_memory}. ' f'Memory stats: {memory_stats}') raise CudaOutOfMemoryError(f'Cuda device {device.name} is out of memory. 
' diff --git a/src/marqo/tensor_search/api.py b/src/marqo/tensor_search/api.py index 273a6863a..c2dc8fc19 100644 --- a/src/marqo/tensor_search/api.py +++ b/src/marqo/tensor_search/api.py @@ -580,16 +580,22 @@ def schema_validation(index_name: str, settings_object: dict): ) +@app.get('/memory', include_in_schema=False) +@utils.enable_debug_apis() +def memory(): + return memory_profiler.get_memory_profile() + + @app.get("/health" , include_in_schema=False) def check_health(marqo_config: config.Config = Depends(get_config)): health_status = marqo_config.monitoring.get_health() return HealthResponse.from_marqo_health_status(health_status) -@app.get('/memory', include_in_schema=False) -@utils.enable_debug_apis() -def memory(): - return memory_profiler.get_memory_profile() +@app.get("/healthz", include_in_schema=False) +def check_health(marqo_config: config.Config = Depends(get_config)): + marqo_config.device_manager.cuda_device_health_check() + return JSONResponse(content={"status": "ok"}, status_code=200) if __name__ == "__main__": From 967df8f6fbc8ee6b23ba6dbdbe3056db51308d45 Mon Sep 17 00:00:00 2001 From: yihanzhao Date: Thu, 28 Nov 2024 14:23:30 +1100 Subject: [PATCH 3/7] Add unit tests --- src/marqo/core/inference/device_manager.py | 35 ++++---- tests/core/inference/test_device_manager.py | 90 +++++++++++++++++++++ 2 files changed, 109 insertions(+), 16 deletions(-) create mode 100644 tests/core/inference/test_device_manager.py diff --git a/src/marqo/core/inference/device_manager.py b/src/marqo/core/inference/device_manager.py index c7ec2317b..47bdb9ce8 100644 --- a/src/marqo/core/inference/device_manager.py +++ b/src/marqo/core/inference/device_manager.py @@ -1,7 +1,6 @@ from enum import Enum -from typing import List +from typing import List, Optional -import psutil import torch from pydantic import BaseModel @@ -20,11 +19,11 @@ class Device(BaseModel): id: int name: str type: DeviceType - total_memory: int + total_memory: Optional[int] = 0 @classmethod def cpu(cls) -> 'Device': - return Device(id=-1, name='cpu', type=DeviceType.cpu, total_memory=psutil.virtual_memory().total) + return Device(id=-1, name='cpu', type=DeviceType.cpu) @classmethod def cuda(cls, device_id, name, total_memory) -> 'Device': @@ -69,23 +68,27 @@ def cuda_device_health_check(self) -> None: # CUDA devices could become unavailable/unreachable if the docker container running Marqo loses access # to the device symlinks. There is no way to recover from this, we will need to restart the container. # See https://github.com/NVIDIA/nvidia-container-toolkit/issues/48 for more details. - logger.error('Cuda device becomes unavailable.') - raise CudaDeviceNotAvailableError('Cuda device becomes unavailable.') + logger.error('Cuda device becomes unavailable') + raise CudaDeviceNotAvailableError('Cuda device becomes unavailable') + # TODO confirm whether we should check all devices or just the default one for device in self.cuda_devices: - # TODO confirm whether we should check all devices or just the default one - cuda_device = torch.device(device.name) - memory_stats = torch.cuda.memory_stats(cuda_device) - logger.debug(f'Cuda device {device.name} with total memory {device.total_memory}. ' - f'Memory stats: {memory_stats}') - try: - torch.randn(3).to(cuda_device) + cuda_device = torch.device(device.name) + memory_stats = torch.cuda.memory_stats(cuda_device) + logger.debug(f'Cuda device {device.name} with total memory {device.total_memory}. 
' + f'Memory stats: {str(memory_stats)}') + + torch.randn(3, device=cuda_device) except RuntimeError as e: if 'out of memory' in str(e).lower(): # If we encounter 'CUDA error: out of memory' error consistently, it means some threads are # holding the memory logger.error(f'Cuda device {device.name} is out of memory. Total memory: {device.total_memory}. ' - f'Memory stats: {memory_stats}') - raise CudaOutOfMemoryError(f'Cuda device {device.name} is out of memory. ' - f'({memory_stats["allocated.all.current"]}/{device.total_memory})') + f'Memory stats: {str(memory_stats)}') + allocated_mem = memory_stats.get("allocated.all.current", None) if memory_stats else None + raise CudaOutOfMemoryError(f'Cuda device {device.name} is out of memory: ' + f'({allocated_mem}/{device.total_memory})') + except Exception as e: + # Log out a warning message when encounter other transient errors. + logger.warning(f'Encountered issue inspecting Cuda device {device.name}: {str(e)}') diff --git a/tests/core/inference/test_device_manager.py b/tests/core/inference/test_device_manager.py new file mode 100644 index 000000000..01836805c --- /dev/null +++ b/tests/core/inference/test_device_manager.py @@ -0,0 +1,90 @@ +import unittest +from collections import OrderedDict +from types import SimpleNamespace +from unittest import mock + +import torch + +from marqo.core.exceptions import CudaDeviceNotAvailableError, CudaOutOfMemoryError +from marqo.core.inference.device_manager import DeviceManager, Device + + +class TestDeviceManager(unittest.TestCase): + + def _device_manager_without_cuda(self): + with mock.patch("torch.cuda.is_available", return_value=False): + return DeviceManager() + + def _device_manager_with_cuda(self, total_memory: int = 1_000_000): + with mock.patch("torch.cuda.is_available", return_value=True), \ + mock.patch("torch.cuda.device_count", return_value=1), \ + mock.patch("torch.cuda.get_device_name", return_value='cuda:0'), \ + mock.patch("torch.cuda.get_device_properties", return_value=SimpleNamespace(total_memory=total_memory)): + + return DeviceManager() + + def test_init_with_cpu(self): + device_manager = self._device_manager_without_cuda() + + self.assertEqual(device_manager.best_available_device_type, 'cpu') + self.assertEqual(device_manager.devices, [Device.cpu()]) + self.assertFalse(device_manager._is_cuda_available_at_startup) + + def test_init_with_gpu(self): + device_manager = self._device_manager_with_cuda(total_memory=1_000_000) + + self.assertEqual(device_manager.best_available_device_type, 'cuda') + self.assertEqual(device_manager.devices, [Device.cpu(), Device.cuda(0, 'cuda:0', 1_000_000)]) + self.assertTrue(device_manager._is_cuda_available_at_startup) + + def test_cuda_health_check_should_skip_without_cuda_devices(self): + device_manager = self._device_manager_without_cuda() + + with mock.patch("marqo.core.inference.device_manager.torch") as mock_cuda: + device_manager.cuda_device_health_check() + self.assertEqual(0, len(mock_cuda.mock_calls)) + + def test_cuda_health_check_should_pass_when_cuda_device_is_healthy(self): + device_manager = self._device_manager_with_cuda() + + with mock.patch("torch.cuda.is_available", return_value=True), \ + mock.patch("torch.randn", return_value=torch.tensor([1, 2, 3])), \ + mock.patch("marqo.core.inference.device_manager.logger") as mock_logger: + device_manager.cuda_device_health_check() + + # verify there's no warning or error level logging + for mock_logger_calls in mock_logger.mock_calls: + logger_call_method_name = mock_logger_calls[0] + 
self.assertNotIn(logger_call_method_name, ['warning', 'error']) + + def test_cuda_health_check_should_fail_when_cuda_device_becomes_unavailable(self): + device_manager = self._device_manager_with_cuda() + + with mock.patch("torch.cuda.is_available", return_value=False): + with self.assertRaises(CudaDeviceNotAvailableError) as err: + device_manager.cuda_device_health_check() + + self.assertEqual(str(err.exception), "Cuda device becomes unavailable") + + def test_cuda_health_check_should_fail_when_cuda_device_is_out_of_memory(self): + device_manager = self._device_manager_with_cuda(total_memory=1_000_000) + + with mock.patch("torch.cuda.is_available", return_value=True), \ + mock.patch("torch.randn", side_effect=RuntimeError("CUDA error: out of memory")), \ + mock.patch("torch.cuda.memory_stats", return_value=OrderedDict({"allocated.all.current": 900_000})): + with self.assertRaises(CudaOutOfMemoryError) as err: + device_manager.cuda_device_health_check() + + self.assertEqual(str(err.exception), "Cuda device cuda:0 is out of memory: (900000/1000000)") + + def test_cuda_health_check_should_pass_and_log_warning_message_when_cuda_calls_encounter_issue(self): + device_manager = self._device_manager_with_cuda() + + with mock.patch("torch.cuda.is_available", return_value=True), \ + mock.patch("torch.cuda.memory_stats", side_effect=Exception("random exception")), \ + mock.patch("marqo.core.inference.device_manager.logger") as mock_logger: + device_manager.cuda_device_health_check() + + self.assertEqual('warning', mock_logger.mock_calls[0][0]) + self.assertEqual('Encountered issue inspecting Cuda device cuda:0: random exception', + mock_logger.mock_calls[0][1][0]) \ No newline at end of file From 389ae36c0048a939460518129bf0cae82db20c5b Mon Sep 17 00:00:00 2001 From: yihanzhao Date: Thu, 28 Nov 2024 14:50:12 +1100 Subject: [PATCH 4/7] Add test in api.py --- tests/tensor_search/test_api.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/tensor_search/test_api.py b/tests/tensor_search/test_api.py index b767d5f08..da5a4ffe9 100644 --- a/tests/tensor_search/test_api.py +++ b/tests/tensor_search/test_api.py @@ -7,6 +7,7 @@ import marqo.tensor_search.api as api from marqo import exceptions as base_exceptions from marqo.core import exceptions as core_exceptions +from marqo.core.exceptions import CudaDeviceNotAvailableError, CudaOutOfMemoryError from marqo.core.models.marqo_index import FieldType from marqo.core.models.marqo_index_request import FieldRequest from marqo.tensor_search.enums import EnvVars @@ -528,6 +529,23 @@ def test_invalid_structured_index_field_features(self): self.assertIn("allFields", response.text) self.assertIn("features", response.text) + def test_healthz_happy_pass(self): + response = self.client.get("/healthz") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), {"status": "ok"}) + + def test_healthz_fails_if_exception_raised(self): + for cuda_exception in [ + CudaDeviceNotAvailableError('Cuda device becomes unavailable'), + CudaOutOfMemoryError('Cuda device cuda:0 is out of memory') + ]: + with self.subTest(cuda_exception): + with patch("marqo.core.inference.device_manager.DeviceManager.cuda_device_health_check", + side_effect=cuda_exception): + response = self.client.get("/healthz") + self.assertEqual(response.status_code, 500) + self.assertIn(cuda_exception.message, response.json()['message']) + def test_log_stack_trace_for_core_exceptions(self): """Ensure stack trace is logged for core exceptions, 
e.g.,IndexExistsError""" raised_error = core_exceptions.IndexExistsError("index1") From d5ef430fd964ba69672ac561ddbe5d132c4b6d34 Mon Sep 17 00:00:00 2001 From: yihanzhao Date: Wed, 4 Dec 2024 12:19:26 +1100 Subject: [PATCH 5/7] Address PR comments --- src/marqo/core/inference/device_manager.py | 26 +++++++---- src/marqo/tensor_search/api.py | 11 ++++- tests/core/inference/test_device_manager.py | 48 ++++++++++++++++++--- 3 files changed, 70 insertions(+), 15 deletions(-) diff --git a/src/marqo/core/inference/device_manager.py b/src/marqo/core/inference/device_manager.py index 47bdb9ce8..b03715294 100644 --- a/src/marqo/core/inference/device_manager.py +++ b/src/marqo/core/inference/device_manager.py @@ -19,7 +19,7 @@ class Device(BaseModel): id: int name: str type: DeviceType - total_memory: Optional[int] = 0 + total_memory: Optional[int] = None @classmethod def cpu(cls) -> 'Device': @@ -68,10 +68,10 @@ def cuda_device_health_check(self) -> None: # CUDA devices could become unavailable/unreachable if the docker container running Marqo loses access # to the device symlinks. There is no way to recover from this, we will need to restart the container. # See https://github.com/NVIDIA/nvidia-container-toolkit/issues/48 for more details. - logger.error('Cuda device becomes unavailable') - raise CudaDeviceNotAvailableError('Cuda device becomes unavailable') + logger.error('CUDA device/s have become unavailable') + raise CudaDeviceNotAvailableError('CUDA device/s have become unavailable') - # TODO confirm whether we should check all devices or just the default one + oom_errors = [] for device in self.cuda_devices: try: cuda_device = torch.device(device.name) @@ -82,13 +82,23 @@ def cuda_device_health_check(self) -> None: torch.randn(3, device=cuda_device) except RuntimeError as e: if 'out of memory' in str(e).lower(): - # If we encounter 'CUDA error: out of memory' error consistently, it means some threads are - # holding the memory + # `~torch.cuda.empty_cache` doesn't increase the amount of GPU memory available for PyTorch. + # However, it may help reduce fragmentation of GPU memory in certain cases. + torch.cuda.empty_cache() + logger.error(f'Cuda device {device.name} is out of memory. Total memory: {device.total_memory}. ' f'Memory stats: {str(memory_stats)}') allocated_mem = memory_stats.get("allocated.all.current", None) if memory_stats else None - raise CudaOutOfMemoryError(f'Cuda device {device.name} is out of memory: ' - f'({allocated_mem}/{device.total_memory})') + oom_errors.append(f'Cuda device {device.name} is out of memory:' + f' ({allocated_mem}/{device.total_memory})') + else: + # Log out a warning message when encounter other transient errors. + logger.warning(f'Encountered issue inspecting Cuda device {device.name}: {str(e)}') except Exception as e: # Log out a warning message when encounter other transient errors. logger.warning(f'Encountered issue inspecting Cuda device {device.name}: {str(e)}') + + if oom_errors: + # We error out if any cuda device is out of memory. 
If this happens consistently, the memory might be held + # by a long-running thread, and Marqo will need to be restarted to get to a healthy status + raise CudaOutOfMemoryError(';'.join(oom_errors)) diff --git a/src/marqo/tensor_search/api.py b/src/marqo/tensor_search/api.py index c2dc8fc19..d92b41a23 100644 --- a/src/marqo/tensor_search/api.py +++ b/src/marqo/tensor_search/api.py @@ -593,7 +593,16 @@ def check_health(marqo_config: config.Config = Depends(get_config)): @app.get("/healthz", include_in_schema=False) -def check_health(marqo_config: config.Config = Depends(get_config)): +def liveness_check(marqo_config: config.Config = Depends(get_config)) -> JSONResponse: + """ + This liveness check endpoint does a quick status check, and error out if any component encounters unrecoverable + issues. This only does a check on the cuda devices right now. + Docker schedulers could leverage this endpoint to decide whether to restart the Marqo container. + + Returns: + 200 - if all checks pass + 500 - if any check fails + """ marqo_config.device_manager.cuda_device_health_check() return JSONResponse(content={"status": "ok"}, status_code=200) diff --git a/tests/core/inference/test_device_manager.py b/tests/core/inference/test_device_manager.py index 01836805c..c08669c3a 100644 --- a/tests/core/inference/test_device_manager.py +++ b/tests/core/inference/test_device_manager.py @@ -23,6 +23,14 @@ def _device_manager_with_cuda(self, total_memory: int = 1_000_000): return DeviceManager() + def _device_manager_with_multiple_cuda_devices(self, total_memory: int = 1_000_000): + with mock.patch("torch.cuda.is_available", return_value=True), \ + mock.patch("torch.cuda.device_count", return_value=2), \ + mock.patch("torch.cuda.get_device_name", side_effect=['cuda:0', 'cuda:1']), \ + mock.patch("torch.cuda.get_device_properties", return_value=SimpleNamespace(total_memory=total_memory)): + + return DeviceManager() + def test_init_with_cpu(self): device_manager = self._device_manager_without_cuda() @@ -64,7 +72,7 @@ def test_cuda_health_check_should_fail_when_cuda_device_becomes_unavailable(self with self.assertRaises(CudaDeviceNotAvailableError) as err: device_manager.cuda_device_health_check() - self.assertEqual(str(err.exception), "Cuda device becomes unavailable") + self.assertEqual(str(err.exception), "CUDA device/s have become unavailable") def test_cuda_health_check_should_fail_when_cuda_device_is_out_of_memory(self): device_manager = self._device_manager_with_cuda(total_memory=1_000_000) @@ -77,14 +85,42 @@ def test_cuda_health_check_should_fail_when_cuda_device_is_out_of_memory(self): self.assertEqual(str(err.exception), "Cuda device cuda:0 is out of memory: (900000/1000000)") - def test_cuda_health_check_should_pass_and_log_warning_message_when_cuda_calls_encounter_issue(self): - device_manager = self._device_manager_with_cuda() + def test_cuda_health_check_should_fail_when_any_cuda_device_is_out_of_memory(self): + device_manager = self._device_manager_with_multiple_cuda_devices(total_memory=1_000_000) with mock.patch("torch.cuda.is_available", return_value=True), \ - mock.patch("torch.cuda.memory_stats", side_effect=Exception("random exception")), \ + mock.patch("torch.randn", side_effect=[torch.tensor([1, 2, 3]), RuntimeError("CUDA error: out of memory")]), \ + mock.patch("torch.cuda.memory_stats", return_value=OrderedDict({"allocated.all.current": 900_000})): + with self.assertRaises(CudaOutOfMemoryError) as err: + device_manager.cuda_device_health_check() + + self.assertEqual(str(err.exception), 
"Cuda device cuda:1 is out of memory: (900000/1000000)") + + def test_cuda_health_check_should_check_if_all_cuda_devices_are_out_of_memory(self): + device_manager = self._device_manager_with_multiple_cuda_devices(total_memory=1_000_000) + + with mock.patch("torch.cuda.is_available", return_value=True), \ + mock.patch("torch.randn", + side_effect=[RuntimeError("CUDA error: out of memory"), RuntimeError("CUDA error: out of memory")]), \ + mock.patch("torch.cuda.memory_stats", return_value=OrderedDict({"allocated.all.current": 900_000})): + with self.assertRaises(CudaOutOfMemoryError) as err: + device_manager.cuda_device_health_check() + + self.assertEqual(str(err.exception), "Cuda device cuda:0 is out of memory: (900000/1000000);" + "Cuda device cuda:1 is out of memory: (900000/1000000)") + + def test_cuda_health_check_should_pass_and_log_warning_message_when_cuda_calls_encounter_issue_other_than_oom(self): + device_manager = self._device_manager_with_multiple_cuda_devices() + + with mock.patch("torch.cuda.is_available", return_value=True), \ + mock.patch("torch.cuda.memory_stats", side_effect=[RuntimeError("not a memory issue"), Exception("random exception")]), \ mock.patch("marqo.core.inference.device_manager.logger") as mock_logger: device_manager.cuda_device_health_check() self.assertEqual('warning', mock_logger.mock_calls[0][0]) - self.assertEqual('Encountered issue inspecting Cuda device cuda:0: random exception', - mock_logger.mock_calls[0][1][0]) \ No newline at end of file + self.assertEqual('Encountered issue inspecting Cuda device cuda:0: not a memory issue', + mock_logger.mock_calls[0][1][0]) + + self.assertEqual('warning', mock_logger.mock_calls[1][0]) + self.assertEqual('Encountered issue inspecting Cuda device cuda:1: random exception', + mock_logger.mock_calls[1][1][0]) \ No newline at end of file From c807325f6467caf57eb2b9dbb96e5c6c79654302 Mon Sep 17 00:00:00 2001 From: yihanzhao Date: Wed, 4 Dec 2024 15:15:09 +1100 Subject: [PATCH 6/7] Address more PR comments --- src/marqo/api/exceptions.py | 6 +++ src/marqo/core/inference/device_manager.py | 51 +++++++++++---------- src/marqo/tensor_search/api.py | 2 +- tests/core/inference/test_device_manager.py | 20 ++++---- tests/tensor_search/test_api.py | 6 +-- 5 files changed, 48 insertions(+), 37 deletions(-) diff --git a/src/marqo/api/exceptions.py b/src/marqo/api/exceptions.py index 2bbe40d75..1ef4946f3 100644 --- a/src/marqo/api/exceptions.py +++ b/src/marqo/api/exceptions.py @@ -226,6 +226,12 @@ class InternalError(MarqoWebError): status_code = HTTPStatus.INTERNAL_SERVER_ERROR +class ServiceUnavailableError(MarqoWebError): + error_type = "service_unavailable" + code = "service_unavailable" + status_code = HTTPStatus.SERVICE_UNAVAILABLE + + class BackendCommunicationError(InternalError): """Error when connecting to Vespa""" code = "backend_communication_error" diff --git a/src/marqo/core/inference/device_manager.py b/src/marqo/core/inference/device_manager.py index b03715294..c4fca1580 100644 --- a/src/marqo/core/inference/device_manager.py +++ b/src/marqo/core/inference/device_manager.py @@ -1,9 +1,10 @@ from enum import Enum +from functools import cached_property from typing import List, Optional import torch -from pydantic import BaseModel +from marqo.base_model import ImmutableBaseModel from marqo.core.exceptions import CudaDeviceNotAvailableError, CudaOutOfMemoryError from marqo.logging import get_logger @@ -11,33 +12,42 @@ class DeviceType(str, Enum): - cpu = 'cpu' - cuda = 'cuda' + CPU = 'cpu' + CUDA = 'cuda' -class 
Device(BaseModel): +class Device(ImmutableBaseModel): id: int name: str type: DeviceType total_memory: Optional[int] = None + @property + def full_name(self) -> str: + return f'{self.type.value}:{self.id}({self.name})' + @classmethod def cpu(cls) -> 'Device': - return Device(id=-1, name='cpu', type=DeviceType.cpu) + return Device(id=-1, name='cpu', type=DeviceType.CPU) @classmethod def cuda(cls, device_id, name, total_memory) -> 'Device': - return Device(id=device_id, name=name, type=DeviceType.cuda, total_memory=total_memory) + return Device(id=device_id, name=name, type=DeviceType.CUDA, total_memory=total_memory) class DeviceManager: + """ + Device manager collects information and stats of CPU and GPU devices to facilitate the preprocessing and + vectorisation processes. Based on the information, we will choose the best device to load the embedding models, + process media files and vectorise the content to achieve optimal performance for search and document ingestion. + """ def __init__(self): self._is_cuda_available_at_startup: bool = torch.cuda.is_available() self.devices: List[Device] = [Device.cpu()] - self.best_available_device_type = DeviceType.cpu + self.best_available_device_type = DeviceType.CPU if self._is_cuda_available_at_startup: - self.best_available_device_type = DeviceType.cuda + self.best_available_device_type = DeviceType.CUDA device_count = torch.cuda.device_count() for device_id in range(device_count): self.devices.append(Device.cuda(device_id, @@ -47,9 +57,9 @@ def __init__(self): logger.debug(f'Found devices {self.devices}. Best available device set to: ' f'{self.best_available_device_type.value}.') - @property + @cached_property def cuda_devices(self): - return [device for device in self.devices if device.type == DeviceType.cuda] + return [device for device in self.devices if device.type == DeviceType.CUDA] def cuda_device_health_check(self) -> None: """ @@ -68,37 +78,32 @@ def cuda_device_health_check(self) -> None: # CUDA devices could become unavailable/unreachable if the docker container running Marqo loses access # to the device symlinks. There is no way to recover from this, we will need to restart the container. # See https://github.com/NVIDIA/nvidia-container-toolkit/issues/48 for more details. - logger.error('CUDA device/s have become unavailable') - raise CudaDeviceNotAvailableError('CUDA device/s have become unavailable') + raise CudaDeviceNotAvailableError('CUDA device(s) have become unavailable') oom_errors = [] for device in self.cuda_devices: try: - cuda_device = torch.device(device.name) + cuda_device = torch.device(f'cuda:{device.id}') memory_stats = torch.cuda.memory_stats(cuda_device) - logger.debug(f'Cuda device {device.name} with total memory {device.total_memory}. ' + logger.debug(f'CUDA device {device.full_name} with total memory {device.total_memory}. ' f'Memory stats: {str(memory_stats)}') torch.randn(3, device=cuda_device) except RuntimeError as e: if 'out of memory' in str(e).lower(): - # `~torch.cuda.empty_cache` doesn't increase the amount of GPU memory available for PyTorch. - # However, it may help reduce fragmentation of GPU memory in certain cases. - torch.cuda.empty_cache() - - logger.error(f'Cuda device {device.name} is out of memory. Total memory: {device.total_memory}. ' + logger.error(f'CUDA device {device.full_name} is out of memory. Total memory: {device.total_memory}. 
' f'Memory stats: {str(memory_stats)}') allocated_mem = memory_stats.get("allocated.all.current", None) if memory_stats else None - oom_errors.append(f'Cuda device {device.name} is out of memory:' + oom_errors.append(f'CUDA device {device.full_name} is out of memory:' f' ({allocated_mem}/{device.total_memory})') else: # Log out a warning message when encounter other transient errors. - logger.warning(f'Encountered issue inspecting Cuda device {device.name}: {str(e)}') + logger.warning(f'Encountered issue inspecting CUDA device {device.full_name}: {str(e)}') except Exception as e: # Log out a warning message when encounter other transient errors. - logger.warning(f'Encountered issue inspecting Cuda device {device.name}: {str(e)}') + logger.warning(f'Encountered issue inspecting CUDA device {device.full_name}: {str(e)}') if oom_errors: - # We error out if any cuda device is out of memory. If this happens consistently, the memory might be held + # We error out if any CUDA device is out of memory. If this happens consistently, the memory might be held # by a long-running thread, and Marqo will need to be restarted to get to a healthy status raise CudaOutOfMemoryError(';'.join(oom_errors)) diff --git a/src/marqo/tensor_search/api.py b/src/marqo/tensor_search/api.py index d92b41a23..d0333d1bc 100644 --- a/src/marqo/tensor_search/api.py +++ b/src/marqo/tensor_search/api.py @@ -112,7 +112,7 @@ def marqo_base_exception_handler(request: Request, exc: base_exceptions.MarqoErr (core_exceptions.InternalError, api_exceptions.InternalError, None, None), (core_exceptions.ApplicationRollbackError, api_exceptions.ApplicationRollbackError, None, None), (core_exceptions.TooManyFieldsError, api_exceptions.BadRequestError, None, None), - (core_exceptions.DeviceError, api_exceptions.InternalError, None, None), + (core_exceptions.DeviceError, api_exceptions.ServiceUnavailableError, None, None), # Vespa client exceptions ( diff --git a/tests/core/inference/test_device_manager.py b/tests/core/inference/test_device_manager.py index c08669c3a..0280cfc8b 100644 --- a/tests/core/inference/test_device_manager.py +++ b/tests/core/inference/test_device_manager.py @@ -18,7 +18,7 @@ def _device_manager_without_cuda(self): def _device_manager_with_cuda(self, total_memory: int = 1_000_000): with mock.patch("torch.cuda.is_available", return_value=True), \ mock.patch("torch.cuda.device_count", return_value=1), \ - mock.patch("torch.cuda.get_device_name", return_value='cuda:0'), \ + mock.patch("torch.cuda.get_device_name", return_value='Tesla T4'), \ mock.patch("torch.cuda.get_device_properties", return_value=SimpleNamespace(total_memory=total_memory)): return DeviceManager() @@ -26,7 +26,7 @@ def _device_manager_with_cuda(self, total_memory: int = 1_000_000): def _device_manager_with_multiple_cuda_devices(self, total_memory: int = 1_000_000): with mock.patch("torch.cuda.is_available", return_value=True), \ mock.patch("torch.cuda.device_count", return_value=2), \ - mock.patch("torch.cuda.get_device_name", side_effect=['cuda:0', 'cuda:1']), \ + mock.patch("torch.cuda.get_device_name", side_effect=['Tesla T4', 'Tesla H200']), \ mock.patch("torch.cuda.get_device_properties", return_value=SimpleNamespace(total_memory=total_memory)): return DeviceManager() @@ -42,7 +42,7 @@ def test_init_with_gpu(self): device_manager = self._device_manager_with_cuda(total_memory=1_000_000) self.assertEqual(device_manager.best_available_device_type, 'cuda') - self.assertEqual(device_manager.devices, [Device.cpu(), Device.cuda(0, 'cuda:0', 
1_000_000)]) + self.assertEqual(device_manager.devices, [Device.cpu(), Device.cuda(0, 'Tesla T4', 1_000_000)]) self.assertTrue(device_manager._is_cuda_available_at_startup) def test_cuda_health_check_should_skip_without_cuda_devices(self): @@ -72,7 +72,7 @@ def test_cuda_health_check_should_fail_when_cuda_device_becomes_unavailable(self with self.assertRaises(CudaDeviceNotAvailableError) as err: device_manager.cuda_device_health_check() - self.assertEqual(str(err.exception), "CUDA device/s have become unavailable") + self.assertEqual(str(err.exception), "CUDA device(s) have become unavailable") def test_cuda_health_check_should_fail_when_cuda_device_is_out_of_memory(self): device_manager = self._device_manager_with_cuda(total_memory=1_000_000) @@ -83,7 +83,7 @@ def test_cuda_health_check_should_fail_when_cuda_device_is_out_of_memory(self): with self.assertRaises(CudaOutOfMemoryError) as err: device_manager.cuda_device_health_check() - self.assertEqual(str(err.exception), "Cuda device cuda:0 is out of memory: (900000/1000000)") + self.assertEqual(str(err.exception), "CUDA device cuda:0(Tesla T4) is out of memory: (900000/1000000)") def test_cuda_health_check_should_fail_when_any_cuda_device_is_out_of_memory(self): device_manager = self._device_manager_with_multiple_cuda_devices(total_memory=1_000_000) @@ -94,7 +94,7 @@ def test_cuda_health_check_should_fail_when_any_cuda_device_is_out_of_memory(sel with self.assertRaises(CudaOutOfMemoryError) as err: device_manager.cuda_device_health_check() - self.assertEqual(str(err.exception), "Cuda device cuda:1 is out of memory: (900000/1000000)") + self.assertEqual(str(err.exception), "CUDA device cuda:1(Tesla H200) is out of memory: (900000/1000000)") def test_cuda_health_check_should_check_if_all_cuda_devices_are_out_of_memory(self): device_manager = self._device_manager_with_multiple_cuda_devices(total_memory=1_000_000) @@ -106,8 +106,8 @@ def test_cuda_health_check_should_check_if_all_cuda_devices_are_out_of_memory(se with self.assertRaises(CudaOutOfMemoryError) as err: device_manager.cuda_device_health_check() - self.assertEqual(str(err.exception), "Cuda device cuda:0 is out of memory: (900000/1000000);" - "Cuda device cuda:1 is out of memory: (900000/1000000)") + self.assertEqual(str(err.exception), "CUDA device cuda:0(Tesla T4) is out of memory: (900000/1000000);" + "CUDA device cuda:1(Tesla H200) is out of memory: (900000/1000000)") def test_cuda_health_check_should_pass_and_log_warning_message_when_cuda_calls_encounter_issue_other_than_oom(self): device_manager = self._device_manager_with_multiple_cuda_devices() @@ -118,9 +118,9 @@ def test_cuda_health_check_should_pass_and_log_warning_message_when_cuda_calls_e device_manager.cuda_device_health_check() self.assertEqual('warning', mock_logger.mock_calls[0][0]) - self.assertEqual('Encountered issue inspecting Cuda device cuda:0: not a memory issue', + self.assertEqual('Encountered issue inspecting CUDA device cuda:0(Tesla T4): not a memory issue', mock_logger.mock_calls[0][1][0]) self.assertEqual('warning', mock_logger.mock_calls[1][0]) - self.assertEqual('Encountered issue inspecting Cuda device cuda:1: random exception', + self.assertEqual('Encountered issue inspecting CUDA device cuda:1(Tesla H200): random exception', mock_logger.mock_calls[1][1][0]) \ No newline at end of file diff --git a/tests/tensor_search/test_api.py b/tests/tensor_search/test_api.py index da5a4ffe9..a0da58153 100644 --- a/tests/tensor_search/test_api.py +++ b/tests/tensor_search/test_api.py @@ -536,14 +536,14 @@ def 
test_healthz_happy_pass(self): def test_healthz_fails_if_exception_raised(self): for cuda_exception in [ - CudaDeviceNotAvailableError('Cuda device becomes unavailable'), - CudaOutOfMemoryError('Cuda device cuda:0 is out of memory') + CudaDeviceNotAvailableError('CUDA device(s) have become unavailable'), + CudaOutOfMemoryError('CUDA device cuda:0(Tesla T4) is out of memory') ]: with self.subTest(cuda_exception): with patch("marqo.core.inference.device_manager.DeviceManager.cuda_device_health_check", side_effect=cuda_exception): response = self.client.get("/healthz") - self.assertEqual(response.status_code, 500) + self.assertEqual(response.status_code, 503) self.assertIn(cuda_exception.message, response.json()['message']) def test_log_stack_trace_for_core_exceptions(self): From c2966661a04a724e2f80279109f4cd5cf6e554cf Mon Sep 17 00:00:00 2001 From: yihanzhao Date: Wed, 4 Dec 2024 15:20:31 +1100 Subject: [PATCH 7/7] Log error message --- src/marqo/core/inference/device_manager.py | 4 ++-- tests/core/inference/test_device_manager.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/marqo/core/inference/device_manager.py b/src/marqo/core/inference/device_manager.py index c4fca1580..7fd345e57 100644 --- a/src/marqo/core/inference/device_manager.py +++ b/src/marqo/core/inference/device_manager.py @@ -98,10 +98,10 @@ def cuda_device_health_check(self) -> None: f' ({allocated_mem}/{device.total_memory})') else: # Log out a warning message when encounter other transient errors. - logger.warning(f'Encountered issue inspecting CUDA device {device.full_name}: {str(e)}') + logger.error(f'Encountered issue inspecting CUDA device {device.full_name}: {str(e)}') except Exception as e: # Log out a warning message when encounter other transient errors. - logger.warning(f'Encountered issue inspecting CUDA device {device.full_name}: {str(e)}') + logger.error(f'Encountered issue inspecting CUDA device {device.full_name}: {str(e)}') if oom_errors: # We error out if any CUDA device is out of memory. 
If this happens consistently, the memory might be held diff --git a/tests/core/inference/test_device_manager.py b/tests/core/inference/test_device_manager.py index 0280cfc8b..2688e82a8 100644 --- a/tests/core/inference/test_device_manager.py +++ b/tests/core/inference/test_device_manager.py @@ -109,7 +109,7 @@ def test_cuda_health_check_should_check_if_all_cuda_devices_are_out_of_memory(se self.assertEqual(str(err.exception), "CUDA device cuda:0(Tesla T4) is out of memory: (900000/1000000);" "CUDA device cuda:1(Tesla H200) is out of memory: (900000/1000000)") - def test_cuda_health_check_should_pass_and_log_warning_message_when_cuda_calls_encounter_issue_other_than_oom(self): + def test_cuda_health_check_should_pass_and_log_error_message_when_cuda_calls_encounter_issue_other_than_oom(self): device_manager = self._device_manager_with_multiple_cuda_devices() with mock.patch("torch.cuda.is_available", return_value=True), \ @@ -117,10 +117,10 @@ def test_cuda_health_check_should_pass_and_log_warning_message_when_cuda_calls_e mock.patch("marqo.core.inference.device_manager.logger") as mock_logger: device_manager.cuda_device_health_check() - self.assertEqual('warning', mock_logger.mock_calls[0][0]) + self.assertEqual('error', mock_logger.mock_calls[0][0]) self.assertEqual('Encountered issue inspecting CUDA device cuda:0(Tesla T4): not a memory issue', mock_logger.mock_calls[0][1][0]) - self.assertEqual('warning', mock_logger.mock_calls[1][0]) + self.assertEqual('error', mock_logger.mock_calls[1][0]) self.assertEqual('Encountered issue inspecting CUDA device cuda:1(Tesla H200): random exception', mock_logger.mock_calls[1][1][0]) \ No newline at end of file
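
Taken as a whole, the series leaves /healthz returning 200 when the CUDA devices are healthy and mapping DeviceError subclasses to a 503 ServiceUnavailableError response. The sketch below shows how an external liveness probe could poll the endpoint; the host, port, and restart handling are assumptions for illustration and are not defined anywhere in these patches.

    import requests

    # Assumed local Marqo endpoint; adjust host/port for the actual deployment.
    MARQO_HEALTHZ_URL = "http://localhost:8882/healthz"

    def marqo_is_live(timeout_seconds: float = 5.0) -> bool:
        """Return True when /healthz answers 200; False on 503 or connection failures."""
        try:
            response = requests.get(MARQO_HEALTHZ_URL, timeout=timeout_seconds)
        except requests.RequestException:
            return False
        return response.status_code == 200

    if __name__ == "__main__":
        if not marqo_is_live():
            # A scheduler (Docker, Kubernetes, etc.) would restart the Marqo container here.
            print("Marqo liveness check failed; a container restart is likely required.")

In Kubernetes terms this is the natural target for a livenessProbe, while the existing /health endpoint remains available for the readiness-style check against the Vespa backend.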