diff --git a/parsl/dataflow/memoization.py b/parsl/dataflow/memoization.py index 4c511cdd61..a50b411a9d 100644 --- a/parsl/dataflow/memoization.py +++ b/parsl/dataflow/memoization.py @@ -121,6 +121,42 @@ def id_for_memo_function(f: types.FunctionType, output_ref: bool = False) -> byt return pickle.dumps(["types.FunctionType", f.__name__, f.__module__]) +def make_hash(task: TaskRecord) -> str: + """Create a hash of the task inputs. + + Args: + - task (dict) : Task dictionary from dfk.tasks + + Returns: + - hash (str) : A unique hash string + """ + + t: List[bytes] = [] + + # if kwargs contains an outputs parameter, that parameter is removed + # and normalised differently - with output_ref set to True. + # kwargs listed in ignore_for_cache will also be removed + + filtered_kw = task['kwargs'].copy() + + ignore_list = task['ignore_for_cache'] + + logger.debug("Ignoring these kwargs for checkpointing: %s", ignore_list) + for k in ignore_list: + logger.debug("Ignoring kwarg %s", k) + del filtered_kw[k] + + if 'outputs' in task['kwargs']: + outputs = task['kwargs']['outputs'] + del filtered_kw['outputs'] + t.append(id_for_memo(outputs, output_ref=True)) + + t.extend(map(id_for_memo, (filtered_kw, task['func'], task['args']))) + + x = b''.join(t) + return hashlib.md5(x).hexdigest() + + class Memoizer: def start(self, *, dfk: DataFlowKernel, memoize: bool = True, checkpoint_files: Sequence[str], run_dir: str) -> None: raise NotImplementedError @@ -200,41 +236,6 @@ def start(self, *, dfk: DataFlowKernel, memoize: bool = True, checkpoint_files: logger.info("App caching disabled for all apps") self.memo_lookup_table = {} - def make_hash(self, task: TaskRecord) -> str: - """Create a hash of the task inputs. - - Args: - - task (dict) : Task dictionary from dfk.tasks - - Returns: - - hash (str) : A unique hash string - """ - - t: List[bytes] = [] - - # if kwargs contains an outputs parameter, that parameter is removed - # and normalised differently - with output_ref set to True. - # kwargs listed in ignore_for_cache will also be removed - - filtered_kw = task['kwargs'].copy() - - ignore_list = task['ignore_for_cache'] - - logger.debug("Ignoring these kwargs for checkpointing: %s", ignore_list) - for k in ignore_list: - logger.debug("Ignoring kwarg %s", k) - del filtered_kw[k] - - if 'outputs' in task['kwargs']: - outputs = task['kwargs']['outputs'] - del filtered_kw['outputs'] - t.append(id_for_memo(outputs, output_ref=True)) - - t.extend(map(id_for_memo, (filtered_kw, task['func'], task['args']))) - - x = b''.join(t) - return hashlib.md5(x).hexdigest() - def check_memo(self, task: TaskRecord) -> Optional[Future[Any]]: """Create a hash of the task and its inputs and check the lookup table for this hash. @@ -256,7 +257,7 @@ def check_memo(self, task: TaskRecord) -> Optional[Future[Any]]: logger.debug("Task {} will not be memoized".format(task_id)) return None - hashsum = self.make_hash(task) + hashsum = make_hash(task) logger.debug("Task {} has memoization hash {}".format(task_id, hashsum)) result = None if hashsum in self.memo_lookup_table: