Skip to content

Commit

Permalink
Fix some tests, run both modes
Browse files Browse the repository at this point in the history
Signed-off-by: Igoshev, Iaroslav <[email protected]>
  • Loading branch information
YarShev committed Sep 23, 2023
1 parent 899d072 commit 1f6f567
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 11 deletions.
16 changes: 12 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,19 @@ jobs:
if: matrix.backend != 'mpi'
# when using a directory to run with mpiexec MPI gets hung after executing tests
# so we run the test files one by one
- run: mpiexec -n 1 python -m pytest unidist/test/test_actor.py
- run: |
UNIDIST_MPI_SHARED_OBJECT_STORE=True mpiexec -n 1 python -m pytest unidist/test/test_actor.py
UNIDIST_MPI_SHARED_OBJECT_STORE=False mpiexec -n 1 python -m pytest unidist/test/test_actor.py
if: matrix.backend == 'mpi'
- run: mpiexec -n 1 python -m pytest unidist/test/test_async_actor.py
- run: |
UNIDIST_MPI_SHARED_OBJECT_STORE=True mpiexec -n 1 python -m pytest unidist/test/test_async_actor.py
UNIDIST_MPI_SHARED_OBJECT_STORE=False mpiexec -n 1 python -m pytest unidist/test/test_async_actor.py
if: matrix.backend == 'mpi'
- run: mpiexec -n 1 python -m pytest unidist/test/test_task.py
- run: |
UNIDIST_MPI_SHARED_OBJECT_STORE=True mpiexec -n 1 python -m pytest unidist/test/test_task.py
UNIDIST_MPI_SHARED_OBJECT_STORE=False mpiexec -n 1 python -m pytest unidist/test/test_task.py
if: matrix.backend == 'mpi'
- run: mpiexec -n 1 python -m pytest unidist/test/test_general.py
- run: |
UNIDIST_MPI_SHARED_OBJECT_STORE=True mpiexec -n 1 python -m pytest unidist/test/test_general.py
UNIDIST_MPI_SHARED_OBJECT_STORE=False mpiexec -n 1 python -m pytest unidist/test/test_general.py
if: matrix.backend == 'mpi'
13 changes: 12 additions & 1 deletion unidist/core/backends/mpi/core/controller/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
garbage_collector,
)
from unidist.core.backends.mpi.core.controller.common import push_data, RoundRobin
from unidist.core.backends.mpi.core.controller.api import put


class ActorMethod:
Expand Down Expand Up @@ -41,6 +42,9 @@ def __call__(self, *args, num_returns=1, **kwargs):
unwrapped_args = [common.unwrap_data_ids(arg) for arg in args]
unwrapped_kwargs = {k: common.unwrap_data_ids(v) for k, v in kwargs.items()}

push_data(
self._actor._owner_rank, common.master_data_ids_to_base(self._method_name)
)
push_data(self._actor._owner_rank, unwrapped_args)
push_data(self._actor._owner_rank, unwrapped_kwargs)

Expand Down Expand Up @@ -181,8 +185,15 @@ def __reduce__(self):
state = self._serialization_helper()
return self._deserialization_helper, (state,)

# Cache for serialized actor methods {"method_name": DataID}
actor_methods = {}

def __getattr__(self, name):
return ActorMethod(self, name)
data_id_to_method = self.actor_methods.get(name, None)
if data_id_to_method is None:
data_id_to_method = put(name)
self.actor_methods[name] = data_id_to_method
return ActorMethod(self, data_id_to_method)

def __del__(self):
"""
Expand Down
3 changes: 2 additions & 1 deletion unidist/core/backends/mpi/core/worker/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ async def worker_loop():
request = pull_data(mpi_state.comm, source_rank)
if not ready_to_shutdown_posted:
# Prepare the data
method_name = request["task"]
# Actor method here is a data id so we have to retrieve it from the storage
method_name = local_store.get(request["task"])
handler = request["handler"]
actor_method = getattr(actor_map[handler], method_name)
request["task"] = actor_method
Expand Down
9 changes: 5 additions & 4 deletions unidist/core/backends/mpi/core/worker/task_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,10 +442,11 @@ def process_task_request(self, request):
# Parse request
local_store = LocalObjectStore.get_instance()
shared_store = SharedObjectStore.get_instance()
if is_data_id(request["task"]):
task = local_store.get(request["task"])
else:
task = request["task"]
task = request["task"]
# Remote function here is a data id so we have to retrieve it from the storage,
# whereas actor method is already materialized in the worker loop.
if is_data_id(task):
task = local_store.get(task)
args = request["args"]
kwargs = request["kwargs"]
output_ids = request["output"]
Expand Down
5 changes: 4 additions & 1 deletion unidist/core/backends/mpi/remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

"""An implementation of ``RemoteFunction`` interface using MPI."""

from unidist.core.backends.common.data_id import is_data_id
import unidist.core.backends.mpi.core as mpi
from unidist.core.backends.common.utils import unwrap_object_refs
from unidist.core.base.object_ref import ObjectRef
Expand All @@ -27,7 +28,7 @@ class MPIRemoteFunction(RemoteFunction):
"""

def __init__(self, function, num_cpus, num_returns, resources):
self._remote_function = mpi.put(function)
self._remote_function = function
self._num_cpus = num_cpus
self._num_returns = 1 if num_returns is None else num_returns
self._resources = resources
Expand Down Expand Up @@ -71,6 +72,8 @@ def _remote(self, *args, num_cpus=None, num_returns=None, resources=None, **kwar
unwrapped_args = [unwrap_object_refs(arg) for arg in args]
unwrapped_kwargs = {k: unwrap_object_refs(v) for k, v in kwargs.items()}

if not is_data_id(self._remote_function):
self._remote_function = mpi.put(self._remote_function)
data_ids = mpi.submit(
self._remote_function,
*unwrapped_args,
Expand Down

0 comments on commit 1f6f567

Please sign in to comment.