diff --git a/parsl/config.py b/parsl/config.py index c3725eccf8..2121baedb5 100644 --- a/parsl/config.py +++ b/parsl/config.py @@ -5,6 +5,7 @@ from typing_extensions import Literal from parsl.dataflow.dependency_resolvers import DependencyResolver +from parsl.dataflow.memoization import Memoizer from parsl.dataflow.taskrecord import TaskRecord from parsl.errors import ConfigurationError from parsl.executors.base import ParslExecutor @@ -98,6 +99,7 @@ class Config(RepresentationMixin, UsageInformation): def __init__(self, executors: Optional[Iterable[ParslExecutor]] = None, app_cache: bool = True, + memoizer: Optional[Memoizer] = None, checkpoint_files: Optional[Sequence[str]] = None, checkpoint_mode: Union[None, Literal['task_exit'], @@ -127,6 +129,7 @@ def __init__(self, self._executors: Sequence[ParslExecutor] = executors self._validate_executors() + self.memoizer = memoizer self.app_cache = app_cache self.checkpoint_files = checkpoint_files self.checkpoint_mode = checkpoint_mode diff --git a/parsl/dataflow/dflow.py b/parsl/dataflow/dflow.py index 3ef0ef589f..1b1d61ff9d 100644 --- a/parsl/dataflow/dflow.py +++ b/parsl/dataflow/dflow.py @@ -169,9 +169,16 @@ def __init__(self, config: Config) -> None: else: checkpoint_files = [] - self.memoizer: Memoizer = BasicMemoizer(self, memoize=config.app_cache, checkpoint_files=checkpoint_files) - self.memoizer.run_dir = self.run_dir + # self.memoizer: Memoizer = BasicMemoizer(self, memoize=config.app_cache, checkpoint_files=checkpoint_files) + # the memoize flag might turn into the user choosing different instances + # of the Memoizer interface + self.memoizer: Memoizer + if config.memoizer is not None: + self.memoizer = config.memoizer + else: + self.memoizer = BasicMemoizer() + self.memoizer.start(dfk=self, memoize=config.app_cache, checkpoint_files=checkpoint_files, run_dir=self.run_dir) self._checkpoint_timer = None self.checkpoint_mode = config.checkpoint_mode diff --git a/parsl/dataflow/memoization.py b/parsl/dataflow/memoization.py index 5324bf8164..865705440a 100644 --- a/parsl/dataflow/memoization.py +++ b/parsl/dataflow/memoization.py @@ -121,6 +121,9 @@ def id_for_memo_function(f: types.FunctionType, output_ref: bool = False) -> byt class Memoizer: + def start(self, *, dfk: DataFlowKernel, memoize: bool = True, checkpoint_files: Sequence[str], run_dir: str) -> None: + raise NotImplementedError + def update_memo(self, task: TaskRecord, r: Future[Any]) -> None: raise NotImplementedError @@ -164,7 +167,10 @@ class BasicMemoizer(Memoizer): run_dir: str - def __init__(self, dfk: DataFlowKernel, *, memoize: bool = True, checkpoint_files: Sequence[str]): + def __init__(self) -> None: + pass + + def start(self, *, dfk: DataFlowKernel, memoize: bool = True, checkpoint_files: Sequence[str], run_dir: str) -> None: """Initialize the memoizer. Args: @@ -176,6 +182,7 @@ def __init__(self, dfk: DataFlowKernel, *, memoize: bool = True, checkpoint_file """ self.dfk = dfk self.memoize = memoize + self.run_dir = run_dir self.checkpointed_tasks = 0 diff --git a/parsl/tests/test_python_apps/test_memoize_plugin.py b/parsl/tests/test_python_apps/test_memoize_plugin.py new file mode 100644 index 0000000000..724facf165 --- /dev/null +++ b/parsl/tests/test_python_apps/test_memoize_plugin.py @@ -0,0 +1,53 @@ +import argparse + +import pytest + +import parsl +from parsl.app.app import python_app +from parsl.config import Config +from parsl.dataflow.memoization import BasicMemoizer +from parsl.dataflow.taskrecord import TaskRecord + + +class DontReuseSevenMemoizer(BasicMemoizer): + def check_memo(self, task_record: TaskRecord): + if task_record['args'][0] == 7: + return None # we didn't find a suitable memo record... + else: + return super().check_memo(task_record) + + +def local_config(): + return Config(memoizer=DontReuseSevenMemoizer()) + + +@python_app(cache=True) +def random_uuid(x, cache=True): + import uuid + return str(uuid.uuid4()) + + +@pytest.mark.local +def test_python_memoization(n=10): + """Testing python memoization disable + """ + + # TODO: this .result() needs to be here, not in the loop + # because otherwise we race to complete... and then + # we might sometimes get a memoization before the loop + # and sometimes not... + x = random_uuid(0).result() + + for i in range(0, n): + foo = random_uuid(0) + print(i) + print(foo.result()) + assert foo.result() == x, "Memoized results were incorrectly not used" + + y = random_uuid(7).result() + + for i in range(0, n): + foo = random_uuid(7) + print(i) + print(foo.result()) + assert foo.result() != y, "Memoized results were incorrectly used"