diff --git a/src/marqo/api/exceptions.py b/src/marqo/api/exceptions.py
index 2bbe40d75..1ef4946f3 100644
--- a/src/marqo/api/exceptions.py
+++ b/src/marqo/api/exceptions.py
@@ -226,6 +226,12 @@ class InternalError(MarqoWebError):
     status_code = HTTPStatus.INTERNAL_SERVER_ERROR
 
 
+class ServiceUnavailableError(MarqoWebError):
+    error_type = "service_unavailable"
+    code = "service_unavailable"
+    status_code = HTTPStatus.SERVICE_UNAVAILABLE
+
+
 class BackendCommunicationError(InternalError):
     """Error when connecting to Vespa"""
     code = "backend_communication_error"
diff --git a/src/marqo/config.py b/src/marqo/config.py
index d89926499..d2459244e 100644
--- a/src/marqo/config.py
+++ b/src/marqo/config.py
@@ -2,6 +2,7 @@
 
 from kazoo.handlers.threading import KazooTimeoutError
 
+from marqo.core.inference.device_manager import DeviceManager
 from marqo.vespa.zookeeper_client import ZookeeperClient
 from marqo.core.document.document import Document
 from marqo.core.embed.embed import Embed
@@ -39,6 +40,7 @@ def __init__(
         self.timeout = timeout
         self.backend = backend if backend is not None else enums.SearchDb.vespa
 
+        # TODO [Refactoring device logic] deprecate default_device since it's not used
         self.default_device = default_device if default_device is not None else (
             utils.read_env_vars_and_defaults(EnvVars.MARQO_BEST_AVAILABLE_DEVICE))
 
@@ -52,6 +54,7 @@ def __init__(
         self.document = Document(vespa_client, self.index_management)
         self.recommender = Recommender(vespa_client, self.index_management)
         self.embed = Embed(vespa_client, self.index_management, self.default_device)
+        self.device_manager = DeviceManager()
 
     def set_is_remote(self, vespa_client: VespaClient):
         local_host_markers = ["localhost", "0.0.0.0", "127.0.0.1"]
diff --git a/src/marqo/core/embed/embed.py b/src/marqo/core/embed/embed.py
index bead85221..0493192ef 100644
--- a/src/marqo/core/embed/embed.py
+++ b/src/marqo/core/embed/embed.py
@@ -29,6 +29,7 @@ def __init__(self, vespa_client: VespaClient, index_management: IndexManagement,
         self.default_device = default_device
 
     @pydantic.validator('default_device')
+    # TODO [Refactoring device logic] deprecate default_device since it's not used
     def validate_default_device(cls, value):
         if not value:
             raise ValueError("Default Device cannot be 'None'. Marqo default device must have been declared upon startup.")
@@ -66,6 +67,7 @@ def embed_content(
         )
 
         # Set default device if not provided
+        # TODO [Refactoring device logic] use device info gathered from device manager
         if device is None:
             device = utils.read_env_vars_and_defaults("MARQO_BEST_AVAILABLE_DEVICE")
 
diff --git a/src/marqo/core/exceptions.py b/src/marqo/core/exceptions.py
index 7467da5e3..5164b4b98 100644
--- a/src/marqo/core/exceptions.py
+++ b/src/marqo/core/exceptions.py
@@ -112,3 +112,15 @@ class DuplicateDocumentError(AddDocumentsError):
 
 class TooManyFieldsError(MarqoError):
     pass
+
+
+class DeviceError(MarqoError):
+    pass
+
+
+class CudaDeviceNotAvailableError(DeviceError):
+    pass
+
+
+class CudaOutOfMemoryError(DeviceError):
+    pass
diff --git a/src/marqo/core/inference/device_manager.py b/src/marqo/core/inference/device_manager.py
new file mode 100644
index 000000000..7fd345e57
--- /dev/null
+++ b/src/marqo/core/inference/device_manager.py
@@ -0,0 +1,109 @@
+from enum import Enum
+from functools import cached_property
+from typing import List, Optional
+
+import torch
+
+from marqo.base_model import ImmutableBaseModel
+from marqo.core.exceptions import CudaDeviceNotAvailableError, CudaOutOfMemoryError
+from marqo.logging import get_logger
+
+logger = get_logger('device_manager')
+
+
+class DeviceType(str, Enum):
+    CPU = 'cpu'
+    CUDA = 'cuda'
+
+
+class Device(ImmutableBaseModel):
+    id: int
+    name: str
+    type: DeviceType
+    total_memory: Optional[int] = None
+
+    @property
+    def full_name(self) -> str:
+        return f'{self.type.value}:{self.id}({self.name})'
+
+    @classmethod
+    def cpu(cls) -> 'Device':
+        return Device(id=-1, name='cpu', type=DeviceType.CPU)
+
+    @classmethod
+    def cuda(cls, device_id, name, total_memory) -> 'Device':
+        return Device(id=device_id, name=name, type=DeviceType.CUDA, total_memory=total_memory)
+
+
+class DeviceManager:
+    """
+    Device manager collects information and stats about the CPU and GPU devices to facilitate the preprocessing
+    and vectorisation processes. Based on this information, we choose the best device to load the embedding models,
+    process media files and vectorise the content to achieve optimal performance for search and document ingestion.
+    """
+    def __init__(self):
+        self._is_cuda_available_at_startup: bool = torch.cuda.is_available()
+        self.devices: List[Device] = [Device.cpu()]
+        self.best_available_device_type = DeviceType.CPU
+
+        if self._is_cuda_available_at_startup:
+            self.best_available_device_type = DeviceType.CUDA
+            device_count = torch.cuda.device_count()
+            for device_id in range(device_count):
+                self.devices.append(Device.cuda(device_id,
+                                                torch.cuda.get_device_name(device_id),
+                                                torch.cuda.get_device_properties(device_id).total_memory))
+
+        logger.debug(f'Found devices {self.devices}. Best available device set to: '
+                     f'{self.best_available_device_type.value}.')
+
+    @cached_property
+    def cuda_devices(self):
+        return [device for device in self.devices if device.type == DeviceType.CUDA]
+
+    def cuda_device_health_check(self) -> None:
+        """
+        Checks the status of the CUDA devices, and raises an exception if any of them
+        becomes unavailable or runs out of memory.
+
+        Raises:
+            - CudaDeviceNotAvailableError: if a CUDA device is no longer available.
+            - CudaOutOfMemoryError: if any CUDA device is out of memory.
+ """ + if not self._is_cuda_available_at_startup: + # If the instance is initialised without cuda devices, skip the check + return + + if not torch.cuda.is_available(): + # CUDA devices could become unavailable/unreachable if the docker container running Marqo loses access + # to the device symlinks. There is no way to recover from this, we will need to restart the container. + # See https://github.com/NVIDIA/nvidia-container-toolkit/issues/48 for more details. + raise CudaDeviceNotAvailableError('CUDA device(s) have become unavailable') + + oom_errors = [] + for device in self.cuda_devices: + try: + cuda_device = torch.device(f'cuda:{device.id}') + memory_stats = torch.cuda.memory_stats(cuda_device) + logger.debug(f'CUDA device {device.full_name} with total memory {device.total_memory}. ' + f'Memory stats: {str(memory_stats)}') + + torch.randn(3, device=cuda_device) + except RuntimeError as e: + if 'out of memory' in str(e).lower(): + logger.error(f'CUDA device {device.full_name} is out of memory. Total memory: {device.total_memory}. ' + f'Memory stats: {str(memory_stats)}') + allocated_mem = memory_stats.get("allocated.all.current", None) if memory_stats else None + oom_errors.append(f'CUDA device {device.full_name} is out of memory:' + f' ({allocated_mem}/{device.total_memory})') + else: + # Log out a warning message when encounter other transient errors. + logger.error(f'Encountered issue inspecting CUDA device {device.full_name}: {str(e)}') + except Exception as e: + # Log out a warning message when encounter other transient errors. + logger.error(f'Encountered issue inspecting CUDA device {device.full_name}: {str(e)}') + + if oom_errors: + # We error out if any CUDA device is out of memory. If this happens consistently, the memory might be held + # by a long-running thread, and Marqo will need to be restarted to get to a healthy status + raise CudaOutOfMemoryError(';'.join(oom_errors)) diff --git a/src/marqo/core/models/add_docs_params.py b/src/marqo/core/models/add_docs_params.py index 66cf12185..6b551c12d 100644 --- a/src/marqo/core/models/add_docs_params.py +++ b/src/marqo/core/models/add_docs_params.py @@ -62,6 +62,7 @@ class Config: batch_vectorisation_mode: BatchVectorisationMode = BatchVectorisationMode.PER_DOCUMENT def __init__(self, **data: Any): + # TODO [Refactoring device logic] use device info gathered from device manager # Ensure `None` and passing nothing are treated the same for device if "device" not in data or data["device"] is None: data["device"] = get_best_available_device() diff --git a/src/marqo/core/monitoring/monitoring.py b/src/marqo/core/monitoring/monitoring.py index 8a941603d..205686a9e 100644 --- a/src/marqo/core/monitoring/monitoring.py +++ b/src/marqo/core/monitoring/monitoring.py @@ -154,6 +154,7 @@ def _get_vespa_health(self, hostname_filter: Optional[str]) -> VespaHealthStatus ) def get_cuda_info(self) -> MarqoCudaInfoResponse: + # TODO [Refactoring device logic] move this logic to device manager """A function to get information about the CUDA devices on the machine Returns: diff --git a/src/marqo/tensor_search/api.py b/src/marqo/tensor_search/api.py index 06b62e6ac..d0333d1bc 100644 --- a/src/marqo/tensor_search/api.py +++ b/src/marqo/tensor_search/api.py @@ -112,6 +112,7 @@ def marqo_base_exception_handler(request: Request, exc: base_exceptions.MarqoErr (core_exceptions.InternalError, api_exceptions.InternalError, None, None), (core_exceptions.ApplicationRollbackError, api_exceptions.ApplicationRollbackError, None, None), 
         (core_exceptions.TooManyFieldsError, api_exceptions.BadRequestError, None, None),
+        (core_exceptions.DeviceError, api_exceptions.ServiceUnavailableError, None, None),
 
         # Vespa client exceptions
         (
@@ -579,16 +580,31 @@ def schema_validation(index_name: str, settings_object: dict):
     )
 
 
+@app.get('/memory', include_in_schema=False)
+@utils.enable_debug_apis()
+def memory():
+    return memory_profiler.get_memory_profile()
+
+
 @app.get("/health" , include_in_schema=False)
 def check_health(marqo_config: config.Config = Depends(get_config)):
     health_status = marqo_config.monitoring.get_health()
     return HealthResponse.from_marqo_health_status(health_status)
 
 
-@app.get('/memory', include_in_schema=False)
-@utils.enable_debug_apis()
-def memory():
-    return memory_profiler.get_memory_profile()
+@app.get("/healthz", include_in_schema=False)
+def liveness_check(marqo_config: config.Config = Depends(get_config)) -> JSONResponse:
+    """
+    This liveness check endpoint does a quick status check and errors out if any component encounters an
+    unrecoverable issue. Currently it only checks the CUDA devices.
+    Docker schedulers could leverage this endpoint to decide whether to restart the Marqo container.
+
+    Returns:
+        200 - if all checks pass
+        503 - if any check fails
+    """
+    marqo_config.device_manager.cuda_device_health_check()
+    return JSONResponse(content={"status": "ok"}, status_code=200)
 
 
 if __name__ == "__main__":
diff --git a/src/marqo/tensor_search/on_start_script.py b/src/marqo/tensor_search/on_start_script.py
index 73cc21ddb..39bcf5ac0 100644
--- a/src/marqo/tensor_search/on_start_script.py
+++ b/src/marqo/tensor_search/on_start_script.py
@@ -85,6 +85,7 @@ def run(self):
 
 
 class CUDAAvailable:
+    # TODO [Refactoring device logic] move this logic to device manager
     """checks the status of cuda
    """
     logger = get_logger('CUDA device summary')
@@ -109,6 +110,7 @@ def id_to_device(id):
 
 
 class SetBestAvailableDevice:
+    # TODO [Refactoring device logic] move this logic to device manager, get rid of MARQO_BEST_AVAILABLE_DEVICE envvar
     """sets the MARQO_BEST_AVAILABLE_DEVICE env var
     """
     logger = get_logger('SetBestAvailableDevice')
@@ -151,6 +153,7 @@ def __init__(self):
         self.models = warmed_models
 
         # TBD to include cross-encoder/ms-marco-TinyBERT-L-2-v2
+        # TODO [Refactoring device logic] use device info gathered from device manager
         self.default_devices = ['cpu'] if not torch.cuda.is_available() else ['cuda', 'cpu']
 
         self.logger.info(f"pre-loading {self.models} onto devices={self.default_devices}")
@@ -230,6 +233,7 @@ def __init__(self):
                     f"Invalid patch model: {model}. Please ensure that this is a valid patch model."
                 )
 
+        # TODO [Refactoring device logic] use device info gathered from device manager
         self.default_devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
 
     def run(self):
diff --git a/src/marqo/tensor_search/tensor_search.py b/src/marqo/tensor_search/tensor_search.py
index 8d893e823..a3a762b57 100644
--- a/src/marqo/tensor_search/tensor_search.py
+++ b/src/marqo/tensor_search/tensor_search.py
@@ -1557,6 +1557,7 @@ def search(config: Config, index_name: str, text: Optional[Union[str, dict, Cust
     if verbose:
         print(f"determined_search_method: {search_method}, text query: {text}")
 
+    # TODO [Refactoring device logic] use device info gathered from device manager
     if device is None:
         selected_device = utils.read_env_vars_and_defaults("MARQO_BEST_AVAILABLE_DEVICE")
         if selected_device is None:
@@ -2263,6 +2264,7 @@ def eject_model(model_name: str, device: str) -> dict:
     return result
 
 
+# TODO [Refactoring device logic] move to device manager
 def get_cpu_info() -> dict:
     return {
         "cpu_usage_percent": f"{psutil.cpu_percent(1)} %",  # The number 1 is a time interval for CPU usage calculation.
diff --git a/src/marqo/tensor_search/utils.py b/src/marqo/tensor_search/utils.py
index be10e9216..3ea6c8bfe 100644
--- a/src/marqo/tensor_search/utils.py
+++ b/src/marqo/tensor_search/utils.py
@@ -88,6 +88,7 @@ def construct_authorized_url(url_base: str, username: str, password: str) -> str
 
 
 def check_device_is_available(device: str) -> bool:
+    # TODO [Refactoring device logic] move this logic to device manager
     """Checks if a device is available on the machine
 
     Args:
@@ -341,6 +342,7 @@ def generate_batches(seq: Sequence, batch_size: int):
 
 
 def get_best_available_device() -> str:
+    # TODO [Refactoring device logic] replace this with device manager
     """Get the best available device for Marqo to use and validate it."""
     device = read_env_vars_and_defaults(EnvVars.MARQO_BEST_AVAILABLE_DEVICE)
     if device is None or not check_device_is_available(device):
diff --git a/src/marqo/tensor_search/web/api_utils.py b/src/marqo/tensor_search/web/api_utils.py
index 1db7cea68..6bc269781 100644
--- a/src/marqo/tensor_search/web/api_utils.py
+++ b/src/marqo/tensor_search/web/api_utils.py
@@ -10,6 +10,7 @@
 
 
 def translate_api_device(device: Optional[str]) -> Optional[str]:
+    # TODO [Refactoring device logic] move this logic to device manager
     """Translates an API device as given through the API into an internal enum.
 
     Args:
diff --git a/src/marqo/tensor_search/web/api_validation.py b/src/marqo/tensor_search/web/api_validation.py
index 0f2d85f67..09e1e8a65 100644
--- a/src/marqo/tensor_search/web/api_validation.py
+++ b/src/marqo/tensor_search/web/api_validation.py
@@ -5,6 +5,7 @@
 
 
 def validate_api_device_string(device: typing.Optional[str]) -> typing.Optional[str]:
+    # TODO [Refactoring device logic] move this logic to device manager
     """Validates a device which is an API parameter
 
     Args:
@@ -47,6 +48,7 @@ def validate_api_device_string(device: typing.Optional[str]) -> typing.Optional[
 
 
 async def validate_device(device: typing.Optional[str] = None) -> typing.Optional[str]:
+    # TODO [Refactoring device logic] move this logic to device manager
     """Translates and validates the device string.
     Checks if the requested device is available.
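
A minimal usage sketch follows (illustrative only, not part of the patch): it shows how a container scheduler or Docker HEALTHCHECK command might consume the new /healthz endpoint introduced above. The script name, host and port are assumptions (localhost:8882 is Marqo's default API port); only the Python standard library is used.

import sys
import urllib.request
from urllib.error import HTTPError, URLError

# Assumed host/port for a locally running Marqo container; adjust for your deployment.
HEALTHZ_URL = "http://localhost:8882/healthz"


def main() -> int:
    """Exit 0 when /healthz returns 200, non-zero otherwise (suitable as a liveness probe command)."""
    try:
        with urllib.request.urlopen(HEALTHZ_URL, timeout=5) as response:
            # The endpoint returns 200 with {"status": "ok"} when the CUDA health check passes.
            return 0 if response.status == 200 else 1
    except HTTPError as e:
        # DeviceError is mapped to ServiceUnavailableError, so CUDA problems surface here as HTTP 503.
        print(f"healthz check failed with HTTP {e.code}", file=sys.stderr)
        return 1
    except URLError as e:
        print(f"healthz endpoint unreachable: {e.reason}", file=sys.stderr)
        return 1


if __name__ == "__main__":
    sys.exit(main())
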
diff --git a/tests/core/inference/test_device_manager.py b/tests/core/inference/test_device_manager.py
new file mode 100644
index 000000000..2688e82a8
--- /dev/null
+++ b/tests/core/inference/test_device_manager.py
@@ -0,0 +1,126 @@
+import unittest
+from collections import OrderedDict
+from types import SimpleNamespace
+from unittest import mock
+
+import torch
+
+from marqo.core.exceptions import CudaDeviceNotAvailableError, CudaOutOfMemoryError
+from marqo.core.inference.device_manager import DeviceManager, Device
+
+
+class TestDeviceManager(unittest.TestCase):
+
+    def _device_manager_without_cuda(self):
+        with mock.patch("torch.cuda.is_available", return_value=False):
+            return DeviceManager()
+
+    def _device_manager_with_cuda(self, total_memory: int = 1_000_000):
+        with mock.patch("torch.cuda.is_available", return_value=True), \
+                mock.patch("torch.cuda.device_count", return_value=1), \
+                mock.patch("torch.cuda.get_device_name", return_value='Tesla T4'), \
+                mock.patch("torch.cuda.get_device_properties", return_value=SimpleNamespace(total_memory=total_memory)):
+
+            return DeviceManager()
+
+    def _device_manager_with_multiple_cuda_devices(self, total_memory: int = 1_000_000):
+        with mock.patch("torch.cuda.is_available", return_value=True), \
+                mock.patch("torch.cuda.device_count", return_value=2), \
+                mock.patch("torch.cuda.get_device_name", side_effect=['Tesla T4', 'Tesla H200']), \
+                mock.patch("torch.cuda.get_device_properties", return_value=SimpleNamespace(total_memory=total_memory)):
+
+            return DeviceManager()
+
+    def test_init_with_cpu(self):
+        device_manager = self._device_manager_without_cuda()
+
+        self.assertEqual(device_manager.best_available_device_type, 'cpu')
+        self.assertEqual(device_manager.devices, [Device.cpu()])
+        self.assertFalse(device_manager._is_cuda_available_at_startup)
+
+    def test_init_with_gpu(self):
+        device_manager = self._device_manager_with_cuda(total_memory=1_000_000)
+
+        self.assertEqual(device_manager.best_available_device_type, 'cuda')
+        self.assertEqual(device_manager.devices, [Device.cpu(), Device.cuda(0, 'Tesla T4', 1_000_000)])
+        self.assertTrue(device_manager._is_cuda_available_at_startup)
+
+    def test_cuda_health_check_should_skip_without_cuda_devices(self):
+        device_manager = self._device_manager_without_cuda()
+
+        with mock.patch("marqo.core.inference.device_manager.torch") as mock_cuda:
+            device_manager.cuda_device_health_check()
+            self.assertEqual(0, len(mock_cuda.mock_calls))
+
+    def test_cuda_health_check_should_pass_when_cuda_device_is_healthy(self):
+        device_manager = self._device_manager_with_cuda()
+
+        with mock.patch("torch.cuda.is_available", return_value=True), \
+                mock.patch("torch.randn", return_value=torch.tensor([1, 2, 3])), \
+                mock.patch("marqo.core.inference.device_manager.logger") as mock_logger:
+            device_manager.cuda_device_health_check()
+
+            # verify there's no warning or error level logging
+            for mock_logger_calls in mock_logger.mock_calls:
+                logger_call_method_name = mock_logger_calls[0]
+                self.assertNotIn(logger_call_method_name, ['warning', 'error'])
+
+    def test_cuda_health_check_should_fail_when_cuda_device_becomes_unavailable(self):
+        device_manager = self._device_manager_with_cuda()
+
+        with mock.patch("torch.cuda.is_available", return_value=False):
+            with self.assertRaises(CudaDeviceNotAvailableError) as err:
+                device_manager.cuda_device_health_check()
+
+            self.assertEqual(str(err.exception), "CUDA device(s) have become unavailable")
+
+    def test_cuda_health_check_should_fail_when_cuda_device_is_out_of_memory(self):
+        device_manager = self._device_manager_with_cuda(total_memory=1_000_000)
+
+        with mock.patch("torch.cuda.is_available", return_value=True), \
+                mock.patch("torch.randn", side_effect=RuntimeError("CUDA error: out of memory")), \
+                mock.patch("torch.cuda.memory_stats", return_value=OrderedDict({"allocated.all.current": 900_000})):
+            with self.assertRaises(CudaOutOfMemoryError) as err:
+                device_manager.cuda_device_health_check()
+
+            self.assertEqual(str(err.exception), "CUDA device cuda:0(Tesla T4) is out of memory: (900000/1000000)")
+
+    def test_cuda_health_check_should_fail_when_any_cuda_device_is_out_of_memory(self):
+        device_manager = self._device_manager_with_multiple_cuda_devices(total_memory=1_000_000)
+
+        with mock.patch("torch.cuda.is_available", return_value=True), \
+                mock.patch("torch.randn", side_effect=[torch.tensor([1, 2, 3]), RuntimeError("CUDA error: out of memory")]), \
+                mock.patch("torch.cuda.memory_stats", return_value=OrderedDict({"allocated.all.current": 900_000})):
+            with self.assertRaises(CudaOutOfMemoryError) as err:
+                device_manager.cuda_device_health_check()
+
+            self.assertEqual(str(err.exception), "CUDA device cuda:1(Tesla H200) is out of memory: (900000/1000000)")
+
+    def test_cuda_health_check_should_check_if_all_cuda_devices_are_out_of_memory(self):
+        device_manager = self._device_manager_with_multiple_cuda_devices(total_memory=1_000_000)
+
+        with mock.patch("torch.cuda.is_available", return_value=True), \
+                mock.patch("torch.randn",
+                           side_effect=[RuntimeError("CUDA error: out of memory"), RuntimeError("CUDA error: out of memory")]), \
+                mock.patch("torch.cuda.memory_stats", return_value=OrderedDict({"allocated.all.current": 900_000})):
+            with self.assertRaises(CudaOutOfMemoryError) as err:
+                device_manager.cuda_device_health_check()
+
+            self.assertEqual(str(err.exception), "CUDA device cuda:0(Tesla T4) is out of memory: (900000/1000000);"
+                                                 "CUDA device cuda:1(Tesla H200) is out of memory: (900000/1000000)")
+
+    def test_cuda_health_check_should_pass_and_log_error_message_when_cuda_calls_encounter_issue_other_than_oom(self):
+        device_manager = self._device_manager_with_multiple_cuda_devices()
+
+        with mock.patch("torch.cuda.is_available", return_value=True), \
+                mock.patch("torch.cuda.memory_stats", side_effect=[RuntimeError("not a memory issue"), Exception("random exception")]), \
+                mock.patch("marqo.core.inference.device_manager.logger") as mock_logger:
+            device_manager.cuda_device_health_check()
+
+            self.assertEqual('error', mock_logger.mock_calls[0][0])
+            self.assertEqual('Encountered issue inspecting CUDA device cuda:0(Tesla T4): not a memory issue',
+                             mock_logger.mock_calls[0][1][0])
+
+            self.assertEqual('error', mock_logger.mock_calls[1][0])
+            self.assertEqual('Encountered issue inspecting CUDA device cuda:1(Tesla H200): random exception',
+                             mock_logger.mock_calls[1][1][0])
\ No newline at end of file
diff --git a/tests/tensor_search/test_api.py b/tests/tensor_search/test_api.py
index b767d5f08..a0da58153 100644
--- a/tests/tensor_search/test_api.py
+++ b/tests/tensor_search/test_api.py
@@ -7,6 +7,7 @@
 import marqo.tensor_search.api as api
 from marqo import exceptions as base_exceptions
 from marqo.core import exceptions as core_exceptions
+from marqo.core.exceptions import CudaDeviceNotAvailableError, CudaOutOfMemoryError
 from marqo.core.models.marqo_index import FieldType
 from marqo.core.models.marqo_index_request import FieldRequest
 from marqo.tensor_search.enums import EnvVars
@@ -528,6 +529,23 @@ def test_invalid_structured_index_field_features(self):
         self.assertIn("allFields", response.text)
         self.assertIn("features", response.text)
 
+    def test_healthz_happy_pass(self):
+        response = self.client.get("/healthz")
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(response.json(), {"status": "ok"})
+
+    def test_healthz_fails_if_exception_raised(self):
+        for cuda_exception in [
+            CudaDeviceNotAvailableError('CUDA device(s) have become unavailable'),
+            CudaOutOfMemoryError('CUDA device cuda:0(Tesla T4) is out of memory')
+        ]:
+            with self.subTest(cuda_exception):
+                with patch("marqo.core.inference.device_manager.DeviceManager.cuda_device_health_check",
+                           side_effect=cuda_exception):
+                    response = self.client.get("/healthz")
+                    self.assertEqual(response.status_code, 503)
+                    self.assertIn(cuda_exception.message, response.json()['message'])
+
     def test_log_stack_trace_for_core_exceptions(self):
         """Ensure stack trace is logged for core exceptions, e.g.,IndexExistsError"""
         raised_error = core_exceptions.IndexExistsError("index1")