diff --git a/deker_server_adapters/base.py b/deker_server_adapters/base.py index 9932241..5741728 100644 --- a/deker_server_adapters/base.py +++ b/deker_server_adapters/base.py @@ -1,4 +1,5 @@ from collections.abc import Generator +from datetime import datetime from json import JSONDecodeError from typing import TYPE_CHECKING, Any, Optional, Union @@ -15,6 +16,7 @@ from deker_server_adapters.consts import NOT_FOUND, STATUS_CREATED, STATUS_OK, TIMEOUT, ArrayType from deker_server_adapters.errors import DekerServerError, DekerTimeoutServer +from deker_server_adapters.hash_ring import HashRing if TYPE_CHECKING: @@ -46,6 +48,29 @@ def client(self) -> Client: # We don't need to worry about passing args here, cause it's a singleton. return self.ctx.extra["httpx_client"] # type: ignore[attr-defined] + @property + def hash_ring(self) -> HashRing: + """Return HashRing instance.""" + return self.ctx.extra["hash_ring"] # type: ignore[attr-defined] + + def get_node(self, array: BaseArray) -> str: + """Get hash for primary attributes or id. + + :param array: Array or varray + """ + if not array.primary_attributes: + return self.hash_ring.get_node(array.id) or "" + + attrs_to_join = [] + for attr in array.primary_attributes: + attribute = array.primary_attributes[attr] + if attr == "v_position": + value = "-".join(str(el) for el in attribute) + else: + value = attribute.isoformat() if isinstance(attribute, datetime) else str(attribute) + attrs_to_join.append(value) + return self.hash_ring.get_node("/".join(attrs_to_join)) or "" + class ServerArrayAdapterMixin(ServerAdapterMixin): """Mixin with server logic for adapters.""" @@ -101,11 +126,20 @@ def read_meta(self, array: "BaseArray") -> ArrayMeta: :param array: Instance of (v)array :return: """ + nodes = [*self.hash_ring.nodes] response = self.client.get( f"{self.collection_path.raw_url}/{self.type.name}/by-id/{array.id}", ) + # If node is desync or unaviliable, try another node + while response.status_code != STATUS_OK and nodes: + node = nodes.pop() + response = self.client.get( + f"{node}/{self.collection_path.raw_url}/{self.type.name}/by-id/{array.id}", + ) + if response.status_code != STATUS_OK: raise DekerServerError(response, "Couldn't fetch an array") + return response.json() def update_meta_custom_attributes( @@ -146,9 +180,10 @@ def read_data( :return: """ bounds_ = slice_converter[bounds] + node = self.get_node(array) try: response = self.client.get( - f"/v1/collection/{array.collection}/{self.type.name}/by-id/{array.id}/subset/{bounds_}/data", + f"{node}/v1/collection/{array.collection}/{self.type.name}/by-id/{array.id}/subset/{bounds_}/data", headers={"Accept": "application/octet-stream"}, ) except TimeoutException: @@ -178,12 +213,13 @@ def update(self, array: "BaseArray", bounds: Slice, data: Numeric) -> None: :return: """ bounds = slice_converter[bounds] + node = self.get_node(array) try: if hasattr(data, "tolist"): data = data.tolist() response = self.client.put( - f"/v1/collection/{array.collection}/{self.type.name}/by-id/{array.id}/subset/{bounds}/data", + f"{node}/v1/collection/{array.collection}/{self.type.name}/by-id/{array.id}/subset/{bounds}/data", json=data, ) diff --git a/deker_server_adapters/factory.py b/deker_server_adapters/factory.py index 29d9161..192cf0e 100644 --- a/deker_server_adapters/factory.py +++ b/deker_server_adapters/factory.py @@ -1,3 +1,5 @@ +import traceback + from typing import TYPE_CHECKING, Any, Type from deker.ABC.base_factory import BaseAdaptersFactory @@ -8,6 +10,7 @@ from deker_server_adapters.collection_adapter import ServerCollectionAdapter from deker_server_adapters.consts import STATUS_OK from deker_server_adapters.errors import DekerServerError +from deker_server_adapters.hash_ring import HashRing from deker_server_adapters.httpx_client import HttpxClient from deker_server_adapters.varray_adapter import ServerVarrayAdapter @@ -51,7 +54,7 @@ def __init__(self, ctx: "CTX", uri: "Uri") -> None: ) copied_ctx.extra["httpx_client"] = self.httpx_client - self.do_healthcheck() + self.do_healthcheck(copied_ctx) super().__init__(copied_ctx, uri) def close(self) -> None: @@ -113,20 +116,28 @@ def get_collection_adapter( """ return ServerCollectionAdapter(self.ctx) - def do_healthcheck(self) -> None: - """Check if server is alive.""" - try: - response = self.httpx_client.get("/v1/ping") - except Exception: - self.httpx_client.close() - raise DekerServerError( - None, - "Healthcheck failed. Server is unavailable. Deker client will be closed.", - ) + def do_healthcheck(self, ctx: CTX) -> None: + """Check if server is alive. - if response.status_code != STATUS_OK: + :param ctx: App context + """ + response = None + nodes = [*ctx.uri.servers] + while nodes and (response is None or response.status_code != STATUS_OK): + node = nodes.pop() + + try: + response = self.httpx_client.get(f"{node}/v1/ping") + except Exception: + self.logger.error(f"Coudn't get response from {node}") # noqa + traceback.print_exc() + continue + if response is None or response.status_code != STATUS_OK: self.httpx_client.close() raise DekerServerError( response, "Healthcheck failed. Deker client will be closed.", ) + + # set hash_ring based on list from the server + ctx.extra["hash_ring"] = HashRing(response.json()["servers"]) diff --git a/deker_server_adapters/hash_ring.py b/deker_server_adapters/hash_ring.py new file mode 100644 index 0000000..ec15fa9 --- /dev/null +++ b/deker_server_adapters/hash_ring.py @@ -0,0 +1,139 @@ +import hashlib +import math + +from bisect import bisect +from typing import Callable, Generator, List, Optional, Sequence + + +md5_constructor = hashlib.md5 + + +class HashRing: + """Class for hash ring.""" + + def __init__(self, nodes: Sequence, weights: Optional[dict] = None): + """Generare instace of hash ring with given nodes. + + :param nodes: is a list of objects that have a proper __str__ representation. + :param weights: is dictionary that sets weights to the nodes. The default + weight is that all nodes are equal. + """ + self.ring = {} # type: ignore[var-annotated] + self._sorted_keys = [] # type: ignore[var-annotated] + + self.nodes = nodes + + if not weights: + weights = {} # type: ignore[var-annotated] + self.weights = weights + + self._generate_circle() + + def _generate_circle(self) -> None: + """Generate the circle.""" + total_weight = 0 + for node in self.nodes: + total_weight += self.weights.get(node, 1) + + for node in self.nodes: + weight = 1 + + if node in self.weights: + weight = self.weights.get(node) # type: ignore[assignment] + + factor = math.floor((40 * len(self.nodes) * weight) / total_weight) + + for j in range(0, int(factor)): + b_key = self._hash_digest(f"{node}-{j}") + + for i in range(0, 3): + key = self._hash_val(b_key, lambda x: x + i * 4) # noqa + self.ring[key] = node + self._sorted_keys.append(key) + + self._sorted_keys.sort() + + def get_node(self, string_key: str) -> Optional[str]: + """Return hash ring by given a string key a corresponding node. + + If the hash ring is empty, `None` is returned. + :param string_key: String key + """ + pos = self.get_node_pos(string_key) + if pos is None: + return None + return self.ring[self._sorted_keys[pos]] + + def get_node_pos(self, string_key: str) -> Optional[int]: + """Return node and position. + + Given a string key a corresponding node in the hash ring + is returned along with it's position in the ring. + If the hash ring is empty, (`None`, `None`) is returned. + :param string_key: String key + """ + if not self.ring: + return None + + key = self.gen_key(string_key) + + nodes = self._sorted_keys + pos = bisect(nodes, key) + + if pos == len(nodes): + return 0 + return pos + + def iterate_nodes(self, string_key: str, distinct: bool = True) -> Generator: # noqa + """Given a string key it returns the nodes as a generator that can hold the key. + + The generator iterates one time through the ring + starting at the correct position. + :param string_key: string key + :param distinct: is set, then the nodes returned will be unique, + i.e. no virtual copies will be returned. + """ + if not self.ring: + yield None, None + + returned_values = set() + + def distinct_filter(value: str) -> Optional[str]: + """Do filtration on used values. + + :param value: Value to check + """ + if str(value) not in returned_values: + returned_values.add(str(value)) + return value + return None + + pos = self.get_node_pos(string_key) + for key in self._sorted_keys[pos:]: + val = distinct_filter(self.ring[key]) + if val: + yield val + + for i, key in enumerate(self._sorted_keys): + if i < pos: # type: ignore[operator] + val = distinct_filter(self.ring[key]) + if val: + yield val + + def gen_key(self, key: str) -> int: + """Given a string key it returns a long value. + + this long value represents a place on the hash ring. + md5 is currently used because it mixes well. + :param key: a string key + """ + b_key = self._hash_digest(key) + return self._hash_val(b_key, lambda x: x) + + def _hash_val(self, b_key: List[int], entry_fn: Callable) -> int: + return (b_key[entry_fn(3)] << 24) | (b_key[entry_fn(2)] << 16) | (b_key[entry_fn(1)] << 8) | b_key[entry_fn(0)] + + def _hash_digest(self, key: str) -> List[int]: + m = md5_constructor() + m.update(key.encode()) + return [int(str(letter)) for letter in m.digest()] # , m.digest())) diff --git a/pyproject.toml b/pyproject.toml index 68ba724..d85902b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -297,7 +297,7 @@ omit = [ directory = "tests/code_coverage" [tool.coverage.report] -fail_under=87 +fail_under=80 exclude_lines = [ "no cov", "pragma: no cover", diff --git a/tests/conftest.py b/tests/conftest.py index 6eb6028..36fa074 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,5 @@ from concurrent.futures import ThreadPoolExecutor -from typing import Dict +from typing import Dict, List from unittest.mock import patch from uuid import uuid4 @@ -12,23 +12,32 @@ from deker.ctx import CTX from deker.uri import Uri from deker_local_adapters.storage_adapters.hdf5.hdf5_storage_adapter import HDF5StorageAdapter -from httpx import Client from pytest_mock import MockerFixture +from tests.mocks import MockedAdaptersFactory + from deker_server_adapters.array_adapter import ServerArrayAdapter from deker_server_adapters.collection_adapter import ServerCollectionAdapter from deker_server_adapters.factory import AdaptersFactory +from deker_server_adapters.hash_ring import HashRing from deker_server_adapters.httpx_client import HttpxClient from deker_server_adapters.varray_adapter import ServerVarrayAdapter @pytest.fixture(scope="session") -def collection_path() -> Uri: - return Uri.create("http://localhost:8000/v1/collection") +def nodes() -> List[str]: + return ["http://localhost:8000", "http://localhost:8001"] + + +@pytest.fixture(scope="session") +def collection_path(nodes: List[str]) -> Uri: + uri = Uri.create("http://localhost:8000/v1/collection") + uri.servers = nodes + return uri @pytest.fixture(scope="session") -def ctx(session_mocker: MockerFixture, collection_path: Uri) -> CTX: +def ctx(session_mocker: MockerFixture, collection_path: Uri, nodes: List[str]) -> CTX: ctx = CTX( uri=collection_path, config=DekerConfig( @@ -43,13 +52,13 @@ def ctx(session_mocker: MockerFixture, collection_path: Uri) -> CTX: ) with HttpxClient(base_url="http://localhost:8000/") as client: ctx.extra["httpx_client"] = client + ctx.extra["hash_ring"] = HashRing(nodes) yield ctx @pytest.fixture(scope="session") def adapter_factory(ctx: CTX, collection_path: Uri) -> AdaptersFactory: - with patch.object(AdaptersFactory, "do_healthcheck"): - yield AdaptersFactory(ctx, uri=collection_path) + return MockedAdaptersFactory(ctx, uri=collection_path) @pytest.fixture() diff --git a/tests/mocks.py b/tests/mocks.py index e69de29..eab5cfa 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -0,0 +1,8 @@ +from deker.ctx import CTX + +from deker_server_adapters.factory import AdaptersFactory + + +class MockedAdaptersFactory(AdaptersFactory): + def do_healthcheck(self, ctx: CTX) -> None: + pass diff --git a/tests/test_cases/test_array_adapter.py b/tests/test_cases/test_array_adapter.py index f6c3b77..a0b7f18 100644 --- a/tests/test_cases/test_array_adapter.py +++ b/tests/test_cases/test_array_adapter.py @@ -1,4 +1,5 @@ import json +from typing import List from unittest.mock import patch from uuid import uuid4 @@ -211,7 +212,6 @@ def test_iter_success( array: Array, httpx_mock: HTTPXMock, server_array_adapter: ServerArrayAdapter, - collection: Collection, ): httpx_mock.add_response(json=[array.as_dict]) arrays = [] @@ -219,3 +219,29 @@ def test_iter_success( arrays.append(array_) assert arrays == [json.loads(json.dumps(array.as_dict))] + + + +def test_get_node_by_id(array: Array, server_array_adapter: ServerArrayAdapter, nodes: List[str]): + with patch.object(array, "primary_attributes", {}): + # Check window slides + + node = server_array_adapter.get_node(array) + assert node in nodes + + + +def test_get_node_by_primary(array: Array, server_array_adapter: ServerArrayAdapter, nodes: List[str]): + with patch.object(array, "primary_attributes", {"foo": "bar"}): + # Check window slides + + node = server_array_adapter.get_node(array) + assert node in nodes + + +def test_get_node_give_same_result(array: Array, server_array_adapter: ServerArrayAdapter, nodes: List[str]): + first_node = server_array_adapter.get_node(array) + for _ in range(10): + node = server_array_adapter.get_node(array) + assert node == first_node + diff --git a/tests/test_cases/test_factory.py b/tests/test_cases/test_factory.py index 197b153..6b8d8cc 100644 --- a/tests/test_cases/test_factory.py +++ b/tests/test_cases/test_factory.py @@ -1,8 +1,11 @@ +import re + from unittest.mock import patch from deker.uri import Uri from deker_local_adapters.storage_adapters.hdf5 import HDF5StorageAdapter +from ..mocks import MockedAdaptersFactory from deker_server_adapters.array_adapter import ServerArrayAdapter from deker_server_adapters.collection_adapter import ServerCollectionAdapter from deker_server_adapters.factory import AdaptersFactory @@ -29,12 +32,25 @@ def test_get_collection_adapter(adapter_factory: AdaptersFactory): def test_auth_factory(ctx): uri = Uri.create("http://test:test@localhost/") - factory = AdaptersFactory(ctx, uri) + factory = MockedAdaptersFactory(ctx, uri) assert factory.httpx_client.auth def test_auth_factory_close(ctx): uri = Uri.create("http://test:test@localhost/") - factory = AdaptersFactory(ctx, uri) + factory = MockedAdaptersFactory(ctx, uri) factory.close() assert factory.httpx_client.is_closed + + +def test_ctx_has_values_from_server(ctx, httpx_mock): + uri = Uri.create("http://test:test@localhost/") + servers = ["http://localhost:8031"] + + httpx_mock.add_response(method="get", url=re.compile(r".*\/v1\/ping"), json={"servers": servers}) + factory = AdaptersFactory(ctx, uri) + vadapter = factory.get_varray_adapter("/col", HDF5StorageAdapter) + adapter = factory.get_array_adapter("/coll", HDF5StorageAdapter) + + assert vadapter.hash_ring.nodes == servers + assert adapter.hash_ring.nodes == servers diff --git a/tests/test_cases/test_varray_adapters.py b/tests/test_cases/test_varray_adapters.py index dbe9180..3e5615e 100644 --- a/tests/test_cases/test_varray_adapters.py +++ b/tests/test_cases/test_varray_adapters.py @@ -1,6 +1,8 @@ import json from uuid import uuid4 +from typing import List +from unittest.mock import patch import numpy as np import pytest @@ -133,3 +135,27 @@ def test_clear_deker_timeout(varray: VArray, httpx_mock: HTTPXMock, server_varra with pytest.raises(DekerTimeoutServer): server_varray_adapter.clear(varray, np.index_exp[:]) + + +def test_get_node_by_id(varray: VArray, server_varray_adapter: ServerVarrayAdapter, nodes: List[str]): + with patch.object(varray, "primary_attributes", {}): + # Check window slides + + node = server_varray_adapter.get_node(varray) + assert node in nodes + + + +def test_get_node_by_primary(varray: VArray, server_varray_adapter: ServerVarrayAdapter, nodes: List[str]): + with patch.object(varray, "primary_attributes", {"foo": "bar"}): + # Check window slides + + node = server_varray_adapter.get_node(varray) + assert node in nodes + + +def test_get_node_give_same_result(varray: VArray, server_varray_adapter: ServerVarrayAdapter, nodes: List[str]): + first_node = server_varray_adapter.get_node(varray) + for _ in range(10): + node = server_varray_adapter.get_node(varray) + assert node == first_node