Skip to content
This repository has been archived by the owner on Nov 27, 2024. It is now read-only.

Commit

Permalink
Merge pull request #584 from OP2/wence/communicator-fixes
Browse files Browse the repository at this point in the history
wence/communicator fixes
  • Loading branch information
wence- authored Jun 3, 2020
2 parents 74cf453 + 729b0f2 commit f72fc39
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 36 deletions.
41 changes: 9 additions & 32 deletions pyop2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,6 @@ class Set(object):
Halo send/receive data is stored on sets in a :class:`Halo`.
"""

_globalcount = 0

_CORE_SIZE = 0
_OWNED_SIZE = 1
_GHOST_SIZE = 2
Expand All @@ -417,12 +415,11 @@ def __init__(self, size, name=None, halo=None, comm=None):
assert size[Set._CORE_SIZE] <= size[Set._OWNED_SIZE] <= \
size[Set._GHOST_SIZE], "Set received invalid sizes: %s" % size
self._sizes = size
self._name = name or "set_%d" % Set._globalcount
self._name = name or "set_#x%x" % id(self)
self._halo = halo
self._partition_size = 1024
# A cache of objects built on top of this set
self._cache = {}
Set._globalcount += 1

@cached_property
def core_size(self):
Expand Down Expand Up @@ -906,7 +903,6 @@ class DataSet(ObjectCached):
Set used in the op2.Dat structures to specify the dimension of the data.
"""
_globalcount = 0

@validate_type(('iter_set', Set, SetTypeError),
('dim', (numbers.Integral, tuple, list), DimTypeError),
Expand All @@ -921,8 +917,7 @@ def __init__(self, iter_set, dim=1, name=None):
self._set = iter_set
self._dim = as_tuple(dim, numbers.Integral)
self._cdim = np.prod(self._dim).item()
self._name = name or "dset_%d" % DataSet._globalcount
DataSet._globalcount += 1
self._name = name or "dset_#x%x" % id(self)
self._initialized = True

@classmethod
Expand Down Expand Up @@ -1001,7 +996,6 @@ def __contains__(self, dat):
class GlobalDataSet(DataSet):
"""A proxy :class:`DataSet` for use in a :class:`Sparsity` where the
matrix has :class:`Global` rows or columns."""
_globalcount = 0

def __init__(self, global_):
"""
Expand Down Expand Up @@ -1348,13 +1342,12 @@ def pack(self):
from pyop2.codegen.builder import DatPack
return DatPack

_globalcount = 0
_modes = [READ, WRITE, RW, INC, MIN, MAX]

@validate_type(('dataset', (DataCarrier, DataSet, Set), DataSetTypeError),
('name', str, NameTypeError))
@validate_dtype(('dtype', None, DataTypeError))
def __init__(self, dataset, data=None, dtype=None, name=None, uid=None):
def __init__(self, dataset, data=None, dtype=None, name=None):

if isinstance(dataset, Dat):
self.__init__(dataset.dataset, None, dtype=dataset.dtype,
Expand All @@ -1371,14 +1364,7 @@ def __init__(self, dataset, data=None, dtype=None, name=None, uid=None):
self._dataset = dataset
self.comm = dataset.comm
self.halo_valid = True
# If the uid is not passed in from outside, assume that Dats
# have been declared in the same order everywhere.
if uid is None:
self._id = Dat._globalcount
Dat._globalcount += 1
else:
self._id = uid
self._name = name or "dat_%d" % self._id
self._name = name or "dat_#x%x" % id(self)

@cached_property
def _kernel_args_(self):
Expand Down Expand Up @@ -2283,7 +2269,6 @@ class Global(DataCarrier, _EmptyDataMixin):
initialised to be zero.
"""

_globalcount = 0
_modes = [READ, INC, MIN, MAX]

@validate_type(('name', str, NameTypeError))
Expand All @@ -2298,9 +2283,8 @@ def __init__(self, dim, data=None, dtype=None, name=None, comm=None):
self._cdim = np.prod(self._dim).item()
_EmptyDataMixin.__init__(self, data, dtype, self._dim)
self._buf = np.empty(self.shape, dtype=self.dtype)
self._name = name or "global_%d" % Global._globalcount
self._name = name or "global_#x%x" % id(self)
self.comm = comm
Global._globalcount += 1

@cached_property
def _kernel_args_(self):
Expand Down Expand Up @@ -2524,8 +2508,6 @@ class Map(object):
map result will be passed to the kernel.
"""

_globalcount = 0

dtype = IntType

@validate_type(('iterset', Set, SetTypeError), ('toset', Set, SetTypeError),
Expand All @@ -2539,14 +2521,13 @@ def __init__(self, iterset, toset, arity, values=None, name=None, offset=None):
(iterset.total_size, arity),
allow_none=True)
self.shape = (iterset.total_size, arity)
self._name = name or "map_%d" % Map._globalcount
self._name = name or "map_#x%x" % id(self)
if offset is None or len(offset) == 0:
self._offset = None
else:
self._offset = verify_reshape(offset, IntType, (arity, ))
# A cache for objects built on top of this map
self._cache = {}
Map._globalcount += 1

@cached_property
def _kernel_args_(self):
Expand Down Expand Up @@ -2847,8 +2828,7 @@ def __init__(self, dsets, maps, *, iteration_regions=None, name=None, nest=None,
raise ValueError("Haven't thought hard enough about different left and right communicators")
self.comm = self.lcomm

self._name = name or "sparsity_%d" % Sparsity._globalcount
Sparsity._globalcount += 1
self._name = name or "sparsity_#x%x" % id(self)

self.iteration_regions = iteration_regions
# If the Sparsity is defined on MixedDataSets, we need to build each
Expand Down Expand Up @@ -2885,7 +2865,6 @@ def __init__(self, dsets, maps, *, iteration_regions=None, name=None, nest=None,
self._initialized = True

_cache = {}
_globalcount = 0

@classmethod
@validate_type(('dsets', (Set, DataSet, tuple, list), DataSetTypeError),
Expand Down Expand Up @@ -3135,7 +3114,6 @@ def pack(self):
INSERT_VALUES = "INSERT_VALUES"
ADD_VALUES = "ADD_VALUES"

_globalcount = 0
_modes = [WRITE, INC]

@validate_type(('sparsity', Sparsity, SparsityTypeError),
Expand All @@ -3147,9 +3125,8 @@ def __init__(self, sparsity, dtype=None, name=None):
self.comm = sparsity.comm
dtype = dtype or ScalarType
self._datatype = np.dtype(dtype)
self._name = name or "mat_%d" % Mat._globalcount
self._name = name or "mat_#x%x" % id(self)
self.assembly_state = Mat.ASSEMBLED
Mat._globalcount += 1

@validate_in(('access', _modes, ModeValueError))
def __call__(self, access, path, lgmaps=None, unroll_map=False):
Expand Down Expand Up @@ -3443,7 +3420,7 @@ class JITModule(Cached):
def _cache_key(cls, kernel, iterset, *args, **kwargs):
counter = itertools.count()
seen = defaultdict(lambda: next(counter))
key = (kernel._wrapper_cache_key_ + iterset._wrapper_cache_key_
key = ((id(dup_comm(iterset.comm)), ) + kernel._wrapper_cache_key_ + iterset._wrapper_cache_key_
+ (iterset._extruded, (iterset._extruded and iterset.constant_layers), isinstance(iterset, Subset)))

for arg in args:
Expand Down
11 changes: 7 additions & 4 deletions pyop2/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,15 +186,16 @@ def __init__(self, cc, ld=None, cppargs=[], ldargs=[],

@property
def compiler_version(self):
key = (id(self.comm), self._cc)
try:
return Compiler.compiler_versions[self._cc]
return Compiler.compiler_versions[key]
except KeyError:
if self.comm.rank == 0:
ver = sniff_compiler_version(self._cc)
else:
ver = None
ver = self.comm.bcast(ver, root=0)
return Compiler.compiler_versions.setdefault(self._cc, ver)
return Compiler.compiler_versions.setdefault(key, ver)

@property
def workaround_cflags(self):
Expand Down Expand Up @@ -233,7 +234,7 @@ def get_so(self, jitmodule, extension):
library."""

# Determine cache key
hsh = md5(str(jitmodule.cache_key).encode())
hsh = md5(str(jitmodule.cache_key[1:]).encode())
hsh.update(self._cc.encode())
if self._ld:
hsh.update(self._ld.encode())
Expand Down Expand Up @@ -457,7 +458,9 @@ def load(jitmodule, extension, fn_name, cppargs=[], ldargs=[],
class StrCode(object):
def __init__(self, code, argtypes):
self.code_to_compile = code
self.cache_key = code
self.cache_key = (None, code) # We peel off the first
# entry, since for a jitmodule, it's a process-local
# cache key
self.argtypes = argtypes
code = StrCode(jitmodule, argtypes)
elif isinstance(jitmodule, JITModule):
Expand Down

0 comments on commit f72fc39

Please sign in to comment.