From 4a655d8316d40c6422671bc3f1f5076554e32ef3 Mon Sep 17 00:00:00 2001 From: John Lambert Date: Thu, 12 Oct 2023 10:55:39 -0400 Subject: [PATCH] Minor fixes --- machine/jobs/nmt_engine_build_job.py | 3 ++- tests/jobs/test_nmt_engine_build_job.py | 3 ++- typings/networkx/classes/digraph.pyi | 4 ++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/machine/jobs/nmt_engine_build_job.py b/machine/jobs/nmt_engine_build_job.py index 8d30e95..29f2082 100644 --- a/machine/jobs/nmt_engine_build_job.py +++ b/machine/jobs/nmt_engine_build_job.py @@ -20,9 +20,10 @@ def __init__(self, config: Any, nmt_model_factory: NmtModelFactory, shared_file_ self._shared_file_service = shared_file_service self.clearml_task: Optional[Task] = None - def run(self, task: Optional[Task]) -> None: + def run(self, task: Optional[Task] = None) -> None: self.clearml_task = task self._send_clearml_config() + self._check_canceled() self._nmt_model_factory.init() diff --git a/tests/jobs/test_nmt_engine_build_job.py b/tests/jobs/test_nmt_engine_build_job.py index cc1617f..0e92443 100644 --- a/tests/jobs/test_nmt_engine_build_job.py +++ b/tests/jobs/test_nmt_engine_build_job.py @@ -27,8 +27,9 @@ def test_run() -> None: def test_cancel() -> None: env = _TestEnvironment() checker = _CancellationChecker(3) + setattr(env.job, "_check_canceled", checker.check_canceled) with pytest.raises(CanceledError): - env.job.run(checker.check_canceled) + env.job.run() assert env.target_pretranslations == "" diff --git a/typings/networkx/classes/digraph.pyi b/typings/networkx/classes/digraph.pyi index dfa5547..83a5930 100644 --- a/typings/networkx/classes/digraph.pyi +++ b/typings/networkx/classes/digraph.pyi @@ -14,11 +14,11 @@ class DiGraph(Graph[T]): edge_attr_dict_factory: Any = ... graph: Any = ... @overload - def __init__(self, incoming_graph_data: DiGraph[T] = ..., **attr: Any) -> None: ... + def __init__(self, incoming_graph_data: DiGraph[T] = ..., **attr: Any) -> None: ... # type: ignore # @overload # def __init__(self, incoming_graph_data: Optional[Any] = ..., **attr: Any) -> None: ... @overload - def __init__(self, incoming_graph_data: List[Tuple[T, T]] = ..., **attr: Any) -> None: ... + def __init__(self, incoming_graph_data: List[Tuple[T, T]] = ..., **attr: Any) -> None: ... # type: ignore @property def adj(self): ... @property