Skip to content

Commit

Permalink
Separate local storage from shared storage
Browse files Browse the repository at this point in the history
  • Loading branch information
Retribution98 committed Dec 13, 2023
1 parent 2d66acd commit 1c47ae7
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 54 deletions.
14 changes: 5 additions & 9 deletions unidist/core/backends/mpi/core/controller/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
request_worker_data,
push_data,
RoundRobin,
get_data,
contains_data,
)
import unidist.core.backends.mpi.core.common as common
import unidist.core.backends.mpi.core.communication as communication
Expand Down Expand Up @@ -411,21 +413,17 @@ def get(data_ids):
object
A Python object.
"""
local_store = LocalObjectStore.get_instance()

is_list = isinstance(data_ids, list)
if not is_list:
data_ids = [data_ids]
remote_data_ids = [
data_id for data_id in data_ids if not local_store.contains(data_id)
]
remote_data_ids = [data_id for data_id in data_ids if not contains_data(data_id)]
# Remote data gets available in the local store inside `request_worker_data`
if remote_data_ids:
request_worker_data(remote_data_ids)

logger.debug("GET {} ids".format(common.unwrapped_data_ids_list(data_ids)))

values = [local_store.get(data_id) for data_id in data_ids]
values = [get_data(data_id) for data_id in data_ids]

# Initiate reference count based cleaup
# if all the tasks were completed
Expand Down Expand Up @@ -463,11 +461,9 @@ def wait(data_ids, num_returns=1):
not_ready = data_ids
pending_returns = num_returns
ready = []
local_store = LocalObjectStore.get_instance()

logger.debug("WAIT {} ids".format(common.unwrapped_data_ids_list(data_ids)))
for data_id in not_ready.copy():
if local_store.contains(data_id):
if contains_data(data_id):
ready.append(data_id)
not_ready.remove(data_id)
pending_returns -= 1
Expand Down
67 changes: 64 additions & 3 deletions unidist/core/backends/mpi/core/controller/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
from unidist.core.backends.mpi.core.async_operations import AsyncOperations
from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore
from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore
from unidist.core.backends.mpi.core.serialization import serialize_complex_data
from unidist.core.backends.mpi.core.serialization import (
deserialize_complex_data,
serialize_complex_data,
)


logger = common.get_logger("common", "common.log")
Expand Down Expand Up @@ -138,10 +141,10 @@ def pull_data(comm, owner_rank=None):
shared_store = SharedObjectStore.get_instance()
data_id = info_package["id"]

if local_store.contains(data_id):
if contains_data(data_id):
return {
"id": data_id,
"data": local_store.get(data_id),
"data": get_data(data_id),
}

data = shared_store.get(data_id, owner_rank, info_package)
Expand Down Expand Up @@ -411,3 +414,61 @@ def push_data(dest_rank, value, is_blocking_op=False):
_push_data_owner(dest_rank, data_id)
else:
raise ValueError("Unknown DataID!")


def contains_data(data_id):
    """
    Check if the data associated with `data_id` exists in the current process.

    The check covers every location data can live in within this process:
    the local object store, the local serialization cache, and the shared
    object store.

    Parameters
    ----------
    data_id : unidist.core.backends.mpi.core.common.MpiDataID
        An ID to data.

    Returns
    -------
    bool
        ``True`` if an object associated with `data_id` exists in the
        current process, ``False`` otherwise.
    """
    local_store = LocalObjectStore.get_instance()
    shared_store = SharedObjectStore.get_instance()
    return (
        local_store.contains(data_id)
        or local_store.is_already_serialized(data_id)
        or shared_store.contains(data_id)
    )


def get_data(data_id):
    """
    Get data associated with `data_id` from any location in the current process.

    Lookup order: the local object store first, then the local serialization
    cache, then the shared object store. A value found outside the local
    object store is materialized and put into the local store so subsequent
    calls hit the fast path.

    Parameters
    ----------
    data_id : unidist.core.backends.mpi.core.common.MpiDataID
        An ID to data.

    Returns
    -------
    object
        Data associated with `data_id`.

    Raises
    ------
    ValueError
        If the data associated with `data_id` is not contained
        in the current process.
    """
    local_store = LocalObjectStore.get_instance()
    shared_store = SharedObjectStore.get_instance()

    # Fast path: the value is already materialized locally.
    if local_store.contains(data_id):
        return local_store.get(data_id)

    if local_store.is_already_serialized(data_id):
        # The data was serialized to be sent to another process but was never
        # saved in materialized form; rebuild it from the serialization cache.
        serialized_data = local_store.get_serialized_data(data_id)
        value = deserialize_complex_data(
            serialized_data["s_data"],
            serialized_data["raw_buffers"],
            serialized_data["buffer_count"],
        )
    elif shared_store.contains(data_id):
        value = shared_store.get(data_id)
    else:
        raise ValueError("The current data ID is not contained in the process.")
    # Cache the materialized value locally to avoid repeated deserialization
    # or shared-memory reads on later calls.
    local_store.put(data_id, value)
    return value
34 changes: 3 additions & 31 deletions unidist/core/backends/mpi/core/local_object_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

import unidist.core.backends.mpi.core.common as common
import unidist.core.backends.mpi.core.communication as communication
from unidist.core.backends.mpi.core.serialization import deserialize_complex_data
from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore


class LocalObjectStore:
Expand Down Expand Up @@ -114,26 +112,7 @@ def get(self, data_id):
object
Return local data associated with `data_id`.
"""
if data_id in self._data_map:
return self._data_map[data_id]
# data can be prepared for sending to another process, but is not saved to local storage
else:
shared_object_store = SharedObjectStore.get_instance()
if self.is_already_serialized(data_id):
serialized_data = self.get_serialized_data(data_id)
value = deserialize_complex_data(
serialized_data["s_data"],
serialized_data["raw_buffers"],
serialized_data["buffer_count"],
)
elif shared_object_store.contains(data_id):
value = shared_object_store.get(data_id)
else:
raise ValueError(
"The current data ID is not contained in the LocalObjectStore."
)
self.put(data_id, value)
return value
return self._data_map[data_id]

def get_data_owner(self, data_id):
"""
Expand Down Expand Up @@ -165,12 +144,7 @@ def contains(self, data_id):
bool
Return the status if an object exist in local dictionary.
"""
shared_object_store = SharedObjectStore.get_instance()
return (
data_id in self._data_map
or self.is_already_serialized(data_id)
or shared_object_store.contains(data_id)
)
return data_id in self._data_map

def contains_data_owner(self, data_id):
"""
Expand Down Expand Up @@ -305,9 +279,7 @@ def cache_serialized_data(self, data_id, data):
"""
# Copying is necessary to avoid corruption of data obtained through out-of-band serialization,
# and buffers are marked read-only to prevent them from being modified.
data["raw_buffers"] = [
memoryview(buf.tobytes()).toreadonly() for buf in data["raw_buffers"]
]
data["raw_buffers"] = [buf.tobytes() for buf in data["raw_buffers"]]
self._serialization_cache[data_id] = data
self.maybe_update_data_id_map(data_id)

Expand Down
4 changes: 2 additions & 2 deletions unidist/core/backends/mpi/core/worker/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from unidist.core.backends.mpi.core.worker.request_store import RequestStore
from unidist.core.backends.mpi.core.worker.task_store import TaskStore
from unidist.core.backends.mpi.core.async_operations import AsyncOperations
from unidist.core.backends.mpi.core.controller.common import pull_data
from unidist.core.backends.mpi.core.controller.common import pull_data, get_data
from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore

# TODO: Find a way to move this after all imports
Expand Down Expand Up @@ -185,7 +185,7 @@ async def worker_loop():
if not ready_to_shutdown_posted:
# Prepare the data
# Actor method here is a data id so we have to retrieve it from the storage
method_name = local_store.get(request["task"])
method_name = get_data(request["task"])
handler = request["handler"]
actor_method = getattr(actor_map[handler], method_name)
request["task"] = actor_method
Expand Down
8 changes: 3 additions & 5 deletions unidist/core/backends/mpi/core/worker/request_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@

import unidist.core.backends.mpi.core.common as common
import unidist.core.backends.mpi.core.communication as communication
from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore
from unidist.core.backends.mpi.core.controller.common import push_data
from unidist.core.backends.mpi.core.controller.common import push_data, contains_data


mpi_state = communication.MPIState.get_instance()
Expand Down Expand Up @@ -213,7 +212,7 @@ def process_wait_request(self, data_id):
-----
Only ROOT rank is supported for now, therefore no rank argument needed.
"""
if LocalObjectStore.get_instance().contains(data_id):
if contains_data(data_id):
# Executor wait just for signal
# We use a blocking send here because the receiver is waiting for the result.
communication.mpi_send_object(
Expand Down Expand Up @@ -247,8 +246,7 @@ def process_get_request(self, source_rank, data_id, is_blocking_op=False):
-----
Request is asynchronous, no wait for the data sending.
"""
local_store = LocalObjectStore.get_instance()
if local_store.contains(data_id):
if contains_data(data_id):
push_data(
source_rank,
data_id,
Expand Down
8 changes: 4 additions & 4 deletions unidist/core/backends/mpi/core/worker/task_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from unidist.core.backends.common.data_id import is_data_id
import unidist.core.backends.mpi.core.common as common
import unidist.core.backends.mpi.core.communication as communication
from unidist.core.backends.mpi.core.controller.common import get_data, contains_data
from unidist.core.backends.mpi.core.async_operations import AsyncOperations
from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore
from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore
Expand Down Expand Up @@ -187,8 +188,8 @@ def unwrap_local_data_id(self, arg):
"""
if is_data_id(arg):
local_store = LocalObjectStore.get_instance()
if local_store.contains(arg):
value = local_store.get(arg)
if contains_data(arg):
value = get_data(arg)
# Data is already local or was pushed from master
return value, False
elif local_store.contains_data_owner(arg):
Expand Down Expand Up @@ -418,12 +419,11 @@ def process_task_request(self, request):
Same request if the task couldn`t be executed, otherwise ``None``.
"""
# Parse request
local_store = LocalObjectStore.get_instance()
task = request["task"]
# Remote function here is a data id so we have to retrieve it from the storage,
# whereas actor method is already materialized in the worker loop.
if is_data_id(task):
task = local_store.get(task)
task = get_data(task)
args = request["args"]
kwargs = request["kwargs"]
output_ids = request["output"]
Expand Down

0 comments on commit 1c47ae7

Please sign in to comment.