Skip to content

Commit

Permalink
some fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Retribution98 committed Sep 22, 2023
1 parent 4cd63fe commit c799be3
Show file tree
Hide file tree
Showing 13 changed files with 68 additions and 27,049 deletions.
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ dependencies:
- msgpack-python>=1.0.0
- cloudpickle
- packaging
- cython
- psutil
- pytest
# for downloading packages from PyPI
Expand Down
10 changes: 7 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
# Get the long description from the README file
long_description = (here / "README.md").read_text(encoding="utf-8")

# C++ Cython extension exposing the multithreaded memcpy helper
# (parallel_memcopy in unidist/core/backends/mpi/core/memory/_memory.pyx);
# passed to cythonize() in the setup() call below.
_memory = Extension(
    "unidist._memory",
    ["unidist/core/backends/mpi/core/memory/_memory.pyx"],
    language="c++",
)

setup(
name="unidist",
version=versioneer.get_version(),
Expand All @@ -36,7 +42,5 @@
"all": all_deps,
},
python_requires=">=3.7.1",
ext_modules=cythonize(
[Extension("unidist._memory", ["unidist/ext_modules/memory/cmemory.pyx"])]
),
ext_modules=cythonize([_memory]),
)
2 changes: 1 addition & 1 deletion unidist/core/backends/mpi/core/controller/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def put(data):
shared_store = SharedObjectStore.get_instance()
data_id = local_store.generate_data_id(garbage_collector)
local_store.put(data_id, data)
if shared_store.should_be_shared(data):
if MpiSharedObjectStore.get():
shared_store.put(data_id, data)

logger.debug("PUT {} id".format(data_id._id))
Expand Down
19 changes: 10 additions & 9 deletions unidist/core/backends/mpi/core/controller/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,17 +352,18 @@ def push_data(dest_rank, value, is_blocking_op=False):
data_id = value
if shared_store.contains(data_id):
_push_shared_data(dest_rank, data_id, is_blocking_op)
elif local_store.is_already_serialized(data_id):
_push_local_data(dest_rank, data_id, is_blocking_op, is_serialized=True)
elif local_store.contains(data_id):
data = local_store.get(data_id)
if shared_store.should_be_shared(data):
shared_store.put(data_id, data)
_push_shared_data(dest_rank, data_id, is_blocking_op)
if local_store.is_already_serialized(data_id):
_push_local_data(dest_rank, data_id, is_blocking_op, is_serialized=True)
else:
_push_local_data(
dest_rank, data_id, is_blocking_op, is_serialized=False
)
data = local_store.get(data_id)
if shared_store.should_be_shared(data):
shared_store.put(data_id, data)
_push_shared_data(dest_rank, data_id, is_blocking_op)
else:
_push_local_data(
dest_rank, data_id, is_blocking_op, is_serialized=False
)
elif local_store.contains_data_owner(data_id):
_push_data_owner(dest_rank, data_id)
else:
Expand Down
26 changes: 26 additions & 0 deletions unidist/core/backends/mpi/core/memory/_memory.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (C) 2021-2023 Modin authors
#
# SPDX-License-Identifier: Apache-2.0

from libc.stdint cimport uint8_t
cimport memory

def parallel_memcopy(const uint8_t[:] src, uint8_t[:] dst, int memcopy_threads):
    """
    Multithreaded data copying between buffers.

    Copies ``len(src)`` bytes from ``src`` into ``dst`` using the C++
    ``memory.parallel_memcopy`` routine.

    Parameters
    ----------
    src : uint8_t[:]
        Copied data.
    dst : uint8_t[:]
        Buffer for writing. Assumed to hold at least ``len(src)`` bytes;
        no bounds check is performed here.
    memcopy_threads : int
        Number of threads to write.
    """
    # Release the GIL so the underlying C++ copy can run its worker
    # threads in parallel with other Python threads.
    with nogil:
        memory.parallel_memcopy(&dst[0],
                                &src[0],
                                len(src),
                                64,  # NOTE(review): presumably the per-thread block/alignment size — confirm against memory.h
                                memcopy_threads)
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
/*
* Copyright (C) 2021-2023 Modin authors
*
* SPDX-License-Identifier: Apache-2.0
*/

#include "memory.h"

#include <cstring>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
/*
* Copyright (C) 2021-2023 Modin authors
*
* SPDX-License-Identifier: Apache-2.0
*/

#ifndef MEMORY_H
#define MEMORY_H

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Copyright (C) 2021-2023 Modin authors
#
# SPDX-License-Identifier: Apache-2.0

from libc.stdint cimport uint8_t, uintptr_t, int64_t

cdef extern from "memory.cpp" nogil:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

from array import array

from unidist.core.backends.mpi.utils import ImmutableDict

try:
import mpi4py
except ImportError:
Expand All @@ -17,6 +15,7 @@

from unidist.core.backends.mpi.core import communication, common
from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore
from unidist.core.backends.mpi.utils import ImmutableDict

# TODO: Find a way to move this after all imports
mpi4py.rc(recv_mprobe=False, initialize=False)
Expand Down Expand Up @@ -194,7 +193,7 @@ def put(self, data_id, memory_len):

def clear(self, data_id_list):
"""
Clear shared memory for the list of `DataID`.
Clear shared memory for the list of `DataID` if possible.
Parameters
----------
Expand Down
10 changes: 5 additions & 5 deletions unidist/core/backends/mpi/core/shared_object_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

class WinLock:
"""
Class that helps synchronize writes to shared memory.
Class that helps synchronize writes to shared memory.
Parameters
----------
Expand Down Expand Up @@ -229,7 +229,7 @@ def _increment_ref_number(self, data_id, service_index):
raise KeyError(
"it is not possible to increment the reference number for this data_id because it is not part of the shared data"
)
with SharedSignaler(self.service_win):
with WinLock(self.service_win):
prev_ref_number = self.service_shared_buffer[
service_index + self.REFERENCES_NUMBER
]
Expand Down Expand Up @@ -266,7 +266,7 @@ def _decrement_ref_number(self, data_id, service_index):
if MPI.Is_finalized():
return
if self._check_service_info(data_id, service_index):
with SharedSignaler(self.service_win):
with WinLock(self.service_win):
prev_ref_number = self.service_shared_buffer[
service_index + self.REFERENCES_NUMBER
]
Expand Down Expand Up @@ -296,7 +296,7 @@ def _put_service_info(self, service_index, data_id, first_index):
"""
worker_id, data_number = self._parse_data_id(data_id)

with SharedSignaler(self.service_win):
with WinLock(self.service_win):
self.service_shared_buffer[
service_index + self.FIRST_DATA_INDEX
] = first_index
Expand Down Expand Up @@ -626,7 +626,7 @@ def delete_service_info(self, data_id, service_index):
-----
This function should be called by the monitor during the cleanup of shared data.
"""
with SharedSignaler(self.service_win):
with WinLock(self.service_win):
# Read actual value
old_worker_id = self.service_shared_buffer[
service_index + self.WORKER_ID_INDEX
Expand Down
7 changes: 0 additions & 7 deletions unidist/ext_modules/memory/README.md

This file was deleted.

Loading

0 comments on commit c799be3

Please sign in to comment.