From 829f447ae931905f33baca8db320bd195288179b Mon Sep 17 00:00:00 2001 From: Jindra Helcl Date: Thu, 22 Nov 2018 21:19:09 +0100 Subject: [PATCH 01/16] Common base class for runners and trainers (GraphExecutor) GenericTrainer now inherits from GraphExecutor closes #415 --- neuralmonkey/runners/base_runner.py | 43 +++++++++++++++++++----- neuralmonkey/trainers/generic_trainer.py | 17 +++------- 2 files changed, 38 insertions(+), 22 deletions(-) diff --git a/neuralmonkey/runners/base_runner.py b/neuralmonkey/runners/base_runner.py index 5d21fc5ef..ae3a79dd8 100644 --- a/neuralmonkey/runners/base_runner.py +++ b/neuralmonkey/runners/base_runner.py @@ -1,3 +1,4 @@ +from abc import abstractmethod from typing import (Any, Dict, Tuple, List, NamedTuple, Union, Set, TypeVar, Generic, Optional) import numpy as np @@ -6,6 +7,8 @@ from neuralmonkey.logging import notice from neuralmonkey.model.model_part import GenericModelPart from neuralmonkey.model.feedable import Feedable +from neuralmonkey.model.parameterized import Parameterized + # pylint: disable=invalid-name FeedDict = Dict[tf.Tensor, Union[int, float, np.ndarray]] NextExecute = Tuple[Set[Feedable], Union[Dict, List], List[FeedDict]] @@ -40,32 +43,54 @@ class Executable: def result(self) -> Optional[ExecutionResult]: return getattr(self, "_result") + @abstractmethod def next_to_execute(self) -> NextExecute: raise NotImplementedError() + @abstractmethod def collect_results(self, results: List[Dict]) -> None: raise NotImplementedError() -class BaseRunner(Generic[MP]): +class GraphExecutor(GenericModelPart): + def __init__(self, - output_series: str, - decoder: MP) -> None: - self.output_series = output_series - self._decoder = decoder + dependencies: Set[GenericModelPart]) -> None: + self._dependencies = dependencies + self._feedables, self._parameterizeds = self.get_dependencies() - self.feedables, self.parameterizeds = decoder.get_dependencies() + @property + def dependencies(self) -> List[str]: + return ["_dependencies"] - if not hasattr(decoder, "data_id"): - notice("Top-level decoder {} does not have the 'data_id' attribute" - .format(decoder)) + @property + def feedables(self) -> Set[Feedable]: + return self._feedables + + @property + def parameterizeds(self) -> Set[Parameterized]: + return self._parameterizeds + @abstractmethod def get_executable(self, compute_losses: bool, summaries: bool, num_sessions: int) -> Executable: raise NotImplementedError() + +class BaseRunner(GraphExecutor, Generic[MP]): + def __init__(self, + output_series: str, + decoder: MP) -> None: + GraphExecutor.__init__(self, {decoder}) + self.output_series = output_series + self._decoder = decoder + + if not hasattr(decoder, "data_id"): + notice("Top-level decoder {} does not have the 'data_id' attribute" + .format(decoder)) + @property def decoder_data_id(self) -> Optional[str]: return getattr(self._decoder, "data_id", None) diff --git a/neuralmonkey/trainers/generic_trainer.py b/neuralmonkey/trainers/generic_trainer.py index 6e600b09b..563775556 100644 --- a/neuralmonkey/trainers/generic_trainer.py +++ b/neuralmonkey/trainers/generic_trainer.py @@ -1,4 +1,4 @@ -from typing import Dict, List, NamedTuple, Optional, Tuple, Union, Set +from typing import Dict, List, NamedTuple, Optional, Tuple, Union import re import tensorflow as tf @@ -7,10 +7,8 @@ from neuralmonkey.decorators import tensor from neuralmonkey.logging import log from neuralmonkey.model.model_part import GenericModelPart -from neuralmonkey.model.feedable import Feedable -from neuralmonkey.model.parameterized import Parameterized from neuralmonkey.runners.base_runner import ( - Executable, ExecutionResult, NextExecute) + GraphExecutor, Executable, ExecutionResult, NextExecute) # pylint: disable=invalid-name Gradients = List[Tuple[tf.Tensor, tf.Variable]] @@ -42,7 +40,7 @@ class Objective(NamedTuple( # pylint: disable=too-few-public-methods,too-many-locals,too-many-arguments -class GenericTrainer: +class GenericTrainer(GraphExecutor): @staticmethod def default_optimizer(): @@ -57,6 +55,7 @@ def __init__(self, var_scopes: List[str] = None, var_collection: str = None) -> None: check_argument_types() + GraphExecutor.__init__(self, {obj.decoder for obj in objectives}) self.objectives = objectives self.l1_weight = l1_weight @@ -70,14 +69,6 @@ def __init__(self, self.optimizer = ( optimizer if optimizer is not None else self.default_optimizer()) - self.feedables = set() # type: Set[Feedable] - self.parameterizeds = set() # type: Set[Parameterized] - - for obj in objectives: - feeds, params = obj.decoder.get_dependencies() - self.feedables |= feeds - self.parameterizeds |= params - log("Train op: {}".format(str(self.train_op))) # pylint: disable=no-self-use From 5c535d25f476281e8b12d6f5003cd2215db58512 Mon Sep 17 00:00:00 2001 From: Jindra Helcl Date: Thu, 22 Nov 2018 21:41:22 +0100 Subject: [PATCH 02/16] TF Manager retrieves feedables before it calls next_to_execute --- neuralmonkey/tf_manager.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/neuralmonkey/tf_manager.py b/neuralmonkey/tf_manager.py index 894c8ea50..7adef58d7 100644 --- a/neuralmonkey/tf_manager.py +++ b/neuralmonkey/tf_manager.py @@ -179,9 +179,9 @@ def validation_hook(self, score: float, epoch: int, batch: int) -> None: # pylint: disable=too-many-locals def _run_executables(self, batch: Dataset, + feedables: Set[Feedable], executables: List[Executable], train: bool) -> None: - all_feedables = set() # type: Set[Any] all_tensors_to_execute = {} # We might want to feed different values to each session @@ -193,10 +193,9 @@ def _run_executables(self, for executable in executables: if executable.result is None: - (feedables, + (_, tensors_to_execute, add_feed_dicts) = executable.next_to_execute() - all_feedables = all_feedables.union(feedables) all_tensors_to_execute[executable] = tensors_to_execute if add_feed_dicts: for fdict, add_fd in zip(feed_dicts, add_feed_dicts): @@ -205,7 +204,7 @@ def _run_executables(self, else: tensor_list_lengths.append(0) - feed_dict = _feed_dicts(batch, all_feedables, train=train) + feed_dict = _feed_dicts(batch, feedables, train=train) for fdict in feed_dicts: fdict.update(feed_dict) @@ -250,9 +249,11 @@ def execute(self, num_sessions=len(self.sessions)) for runner in runners] + feedables = set.union(*[runner.feedables for runner in runners]) + # TODO refactor runner results to properties while not all(getattr(ex, "result") is not None for ex in executables): - self._run_executables(batch, executables, train) + self._run_executables(batch, feedables, executables, train) return [getattr(ex, "result") for ex in executables] From 43d75e84d7d26cbb1f2dbbb9b30bf4afe4eeffdf Mon Sep 17 00:00:00 2001 From: Jindra Helcl Date: Thu, 22 Nov 2018 22:17:01 +0100 Subject: [PATCH 03/16] next_to_execute does not return feedables --- neuralmonkey/runners/base_runner.py | 3 ++- neuralmonkey/runners/beamsearch_runner.py | 12 ++++------ neuralmonkey/runners/ctc_debug_runner.py | 13 +++------- neuralmonkey/runners/label_runner.py | 11 +++------ neuralmonkey/runners/logits_runner.py | 13 ++++------ neuralmonkey/runners/perplexity_runner.py | 13 ++++------ neuralmonkey/runners/plain_runner.py | 11 +++------ neuralmonkey/runners/regression_runner.py | 11 +++------ neuralmonkey/runners/runner.py | 13 +++------- neuralmonkey/runners/tensor_runner.py | 12 ++++------ neuralmonkey/runners/word_alignment_runner.py | 14 ++++------- neuralmonkey/tf_manager.py | 24 +++++++------------ .../trainers/delayed_update_trainer.py | 6 +---- neuralmonkey/trainers/generic_trainer.py | 2 +- .../trainers/test_multitask_trainer.py | 14 ++++++----- 15 files changed, 55 insertions(+), 117 deletions(-) diff --git a/neuralmonkey/runners/base_runner.py b/neuralmonkey/runners/base_runner.py index ae3a79dd8..31bd42964 100644 --- a/neuralmonkey/runners/base_runner.py +++ b/neuralmonkey/runners/base_runner.py @@ -11,7 +11,7 @@ # pylint: disable=invalid-name FeedDict = Dict[tf.Tensor, Union[int, float, np.ndarray]] -NextExecute = Tuple[Set[Feedable], Union[Dict, List], List[FeedDict]] +NextExecute = Tuple[Union[Dict, List], List[FeedDict]] MP = TypeVar("MP", bound=GenericModelPart) # pylint: enable=invalid-name @@ -45,6 +45,7 @@ def result(self) -> Optional[ExecutionResult]: @abstractmethod def next_to_execute(self) -> NextExecute: + """Get the tensors and additional feed dicts for execution.""" raise NotImplementedError() @abstractmethod diff --git a/neuralmonkey/runners/beamsearch_runner.py b/neuralmonkey/runners/beamsearch_runner.py index 7bc125e99..8caf27b8c 100644 --- a/neuralmonkey/runners/beamsearch_runner.py +++ b/neuralmonkey/runners/beamsearch_runner.py @@ -1,10 +1,9 @@ -from typing import Callable, List, Dict, Optional, Set +from typing import Callable, List, Dict, Optional import scipy import numpy as np from typeguard import check_argument_types -from neuralmonkey.model.feedable import Feedable from neuralmonkey.decoders.beam_search_decoder import BeamSearchDecoder from neuralmonkey.runners.base_runner import ( BaseRunner, Executable, ExecutionResult, NextExecute) @@ -17,13 +16,11 @@ class BeamSearchExecutable(Executable): def __init__(self, rank: int, - feedables: Set[Feedable], num_sessions: int, decoder: BeamSearchDecoder, postprocess: Optional[Callable]) -> None: self._rank = rank self._num_sessions = num_sessions - self._feedables = feedables self._decoder = decoder self._postprocess = postprocess @@ -39,8 +36,7 @@ def __init__(self, self._result = None # type: Optional[ExecutionResult] def next_to_execute(self) -> NextExecute: - return (self._feedables, {"bs_outputs": self._decoder.outputs}, - self._next_feed) + return {"bs_outputs": self._decoder.outputs}, self._next_feed def collect_results(self, results: List[Dict]) -> None: # Recompute logits @@ -154,8 +150,8 @@ def get_executable(self, compute_losses: bool = False, summaries: bool = True, num_sessions: int = 1) -> BeamSearchExecutable: - return BeamSearchExecutable(self._rank, self.feedables, num_sessions, - self._decoder, self._postprocess) + return BeamSearchExecutable( + self._rank, num_sessions, self._decoder, self._postprocess) # pylint: enable=unused-argument @property diff --git a/neuralmonkey/runners/ctc_debug_runner.py b/neuralmonkey/runners/ctc_debug_runner.py index 55dbebd44..84148c3d6 100644 --- a/neuralmonkey/runners/ctc_debug_runner.py +++ b/neuralmonkey/runners/ctc_debug_runner.py @@ -1,11 +1,10 @@ -from typing import Dict, List, Set, Optional +from typing import Dict, List, Optional import numpy as np from typeguard import check_argument_types from neuralmonkey.runners.base_runner import ( BaseRunner, Executable, FeedDict, ExecutionResult, NextExecute) -from neuralmonkey.model.feedable import Feedable from neuralmonkey.vocabulary import Vocabulary from neuralmonkey.decoders.ctc_decoder import CTCDecoder @@ -13,18 +12,15 @@ class CTCDebugExecutable(Executable): def __init__(self, - feedables: Set[Feedable], fetches: FeedDict, vocabulary: Vocabulary) -> None: - self._feedables = feedables self._fetches = fetches self._vocabulary = vocabulary self._result = None # type: Optional[ExecutionResult] def next_to_execute(self) -> NextExecute: - """Get the feedables and tensors to run.""" - return self._feedables, self._fetches, [] + return self._fetches, [] def collect_results(self, results: List[Dict]) -> None: if len(results) != 1: @@ -68,10 +64,7 @@ def get_executable(self, num_sessions: int) -> CTCDebugExecutable: fetches = {"logits": self._decoder.logits} - return CTCDebugExecutable( - self.feedables, - fetches, - self._decoder.vocabulary) + return CTCDebugExecutable(fetches, self._decoder.vocabulary) # pylint: enable=unused-argument @property diff --git a/neuralmonkey/runners/label_runner.py b/neuralmonkey/runners/label_runner.py index e14bdca56..d6edb31c8 100644 --- a/neuralmonkey/runners/label_runner.py +++ b/neuralmonkey/runners/label_runner.py @@ -1,9 +1,8 @@ -from typing import List, Dict, Set, Optional, Callable +from typing import List, Dict, Optional, Callable import numpy as np from typeguard import check_argument_types from neuralmonkey.logging import log -from neuralmonkey.model.feedable import Feedable from neuralmonkey.vocabulary import Vocabulary, END_TOKEN_INDEX from neuralmonkey.runners.base_runner import ( BaseRunner, Executable, FeedDict, ExecutionResult, NextExecute) @@ -17,11 +16,9 @@ class LabelRunExecutable(Executable): def __init__(self, - feedables: Set[Feedable], fetches: FeedDict, vocabulary: Vocabulary, postprocess: Optional[Postprocessor]) -> None: - self._feedables = feedables self._fetches = fetches self._vocabulary = vocabulary self._postprocess = postprocess @@ -29,8 +26,7 @@ def __init__(self, self._result = None # type: Optional[ExecutionResult] def next_to_execute(self) -> NextExecute: - """Get the feedables and tensors to run.""" - return self._feedables, self._fetches, [] + return self._fetches, [] def collect_results(self, results: List[Dict]) -> None: loss = results[0].get("loss", 0.) @@ -93,8 +89,7 @@ def get_executable(self, fetches["loss"] = self._decoder.cost return LabelRunExecutable( - self.feedables, fetches, self._decoder.vocabulary, - self._postprocess) + fetches, self._decoder.vocabulary, self._postprocess) # pylint: enable: unused-argument @property diff --git a/neuralmonkey/runners/logits_runner.py b/neuralmonkey/runners/logits_runner.py index 0ed73e91b..8609f4462 100644 --- a/neuralmonkey/runners/logits_runner.py +++ b/neuralmonkey/runners/logits_runner.py @@ -1,6 +1,6 @@ """A runner outputing logits or normalized distriution from a decoder.""" -from typing import Dict, List, Set, Optional +from typing import Dict, List, Optional from typeguard import check_argument_types import numpy as np @@ -9,19 +9,16 @@ from neuralmonkey.decoders.classifier import Classifier from neuralmonkey.runners.base_runner import ( BaseRunner, Executable, FeedDict, ExecutionResult, NextExecute) -from neuralmonkey.model.feedable import Feedable from neuralmonkey.vocabulary import Vocabulary class LogitsExecutable(Executable): def __init__(self, - feedables: Set[Feedable], fetches: FeedDict, vocabulary: Vocabulary, normalize: bool, pick_index: Optional[int]) -> None: - self._feedables = feedables self._fetches = fetches self._vocabulary = vocabulary self._normalize = normalize @@ -30,8 +27,7 @@ def __init__(self, self._result = None # type: Optional[ExecutionResult] def next_to_execute(self) -> NextExecute: - """Get the feedables and tensors to run.""" - return self._feedables, self._fetches, [] + return self._fetches, [] def collect_results(self, results: List[Dict]) -> None: if len(results) != 1: @@ -131,9 +127,8 @@ def get_executable(self, fetches["train_loss"] = self._decoder.train_loss fetches["runtime_loss"] = self._decoder.runtime_loss - return LogitsExecutable( - self.feedables, fetches, self._decoder.vocabulary, - self._normalize, self._pick_index) + return LogitsExecutable(fetches, self._decoder.vocabulary, + self._normalize, self._pick_index) # pylint: enable: unused-argument @property diff --git a/neuralmonkey/runners/perplexity_runner.py b/neuralmonkey/runners/perplexity_runner.py index 2d6b2562f..08c277689 100644 --- a/neuralmonkey/runners/perplexity_runner.py +++ b/neuralmonkey/runners/perplexity_runner.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Set +from typing import Dict, List # pylint: disable=unused-import from typing import Optional # pylint: enable=unused-import @@ -7,24 +7,19 @@ import tensorflow as tf import numpy as np -from neuralmonkey.model.feedable import Feedable from neuralmonkey.decoders.autoregressive import AutoregressiveDecoder from neuralmonkey.runners.base_runner import ( BaseRunner, Executable, ExecutionResult, NextExecute) class PerplexityExecutable(Executable): - def __init__(self, - feedables: Set[Feedable], - xent_op: tf.Tensor) -> None: - self._feedables = feedables + def __init__(self, xent_op: tf.Tensor) -> None: self._xent_op = xent_op self._result = None # type: Optional[ExecutionResult] def next_to_execute(self) -> NextExecute: - """Get the feedables and tensors to run.""" - return self._feedables, {"xents": self._xent_op}, [] + return {"xents": self._xent_op}, [] def collect_results(self, results: List[Dict]) -> None: perplexities = np.mean([2 ** res["xents"] for res in results], axis=0) @@ -52,7 +47,7 @@ def get_executable(self, compute_losses: bool, summaries: bool, num_sessions: int) -> PerplexityExecutable: - return PerplexityExecutable(self.feedables, self._decoder_xent) + return PerplexityExecutable(self._decoder_xent) # pylint: enable=unused-argument @property diff --git a/neuralmonkey/runners/plain_runner.py b/neuralmonkey/runners/plain_runner.py index 0c82b0caa..2e2c4a857 100644 --- a/neuralmonkey/runners/plain_runner.py +++ b/neuralmonkey/runners/plain_runner.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Set, Union, Callable, Optional +from typing import Dict, List, Union, Callable, Optional import tensorflow as tf from typeguard import check_argument_types @@ -7,7 +7,6 @@ from neuralmonkey.decoders.ctc_decoder import CTCDecoder from neuralmonkey.decoders.classifier import Classifier from neuralmonkey.decoders.sequence_labeler import SequenceLabeler -from neuralmonkey.model.feedable import Feedable from neuralmonkey.runners.base_runner import ( BaseRunner, Executable, FeedDict, ExecutionResult, NextExecute) from neuralmonkey.vocabulary import Vocabulary @@ -22,12 +21,10 @@ class PlainExecutable(Executable): def __init__(self, - feedables: Set[Feedable], fetches: FeedDict, num_sessions: int, vocabulary: Vocabulary, postprocess: Optional[Postprocessor]) -> None: - self._feedables = feedables self._fetches = fetches self._num_sessions = num_sessions self._vocabulary = vocabulary @@ -36,8 +33,7 @@ def __init__(self, self._result = None # type: Optional[ExecutionResult] def next_to_execute(self) -> NextExecute: - """Get the feedables and tensors to run.""" - return self._feedables, self._fetches, [] + return self._fetches, [] def collect_results(self, results: List[Dict]) -> None: if len(results) != 1: @@ -87,8 +83,7 @@ def get_executable(self, fetches["runtime_loss"] = self._decoder.runtime_loss return PlainExecutable( - self.feedables, fetches, num_sessions, self._decoder.vocabulary, - self._postprocess) + fetches, num_sessions, self._decoder.vocabulary, self._postprocess) # pylint: enable=unused-argument @property diff --git a/neuralmonkey/runners/regression_runner.py b/neuralmonkey/runners/regression_runner.py index ada4c0203..351bb14b8 100644 --- a/neuralmonkey/runners/regression_runner.py +++ b/neuralmonkey/runners/regression_runner.py @@ -1,11 +1,10 @@ -from typing import Dict, List, Set, Callable, Optional +from typing import Dict, List, Callable, Optional import numpy as np import tensorflow as tf from typeguard import check_argument_types from neuralmonkey.decoders.sequence_regressor import SequenceRegressor -from neuralmonkey.model.feedable import Feedable from neuralmonkey.runners.base_runner import ( BaseRunner, Executable, ExecutionResult, NextExecute) @@ -17,18 +16,15 @@ class RegressionRunExecutable(Executable): def __init__(self, - feedables: Set[Feedable], fetches: Dict[str, tf.Tensor], postprocess: Optional[Postprocessor]) -> None: - self._feedables = feedables self._fetches = fetches self._postprocess = postprocess self._result = None # type: Optional[ExecutionResult] def next_to_execute(self) -> NextExecute: - """Get the feedables and tensors to run.""" - return self._feedables, self._fetches, [] + return self._fetches, [] def collect_results(self, results: List[Dict]) -> None: predictions_sum = np.zeros_like(results[0]["prediction"]) @@ -74,8 +70,7 @@ def get_executable(self, if compute_losses: fetches["mse"] = self._decoder.cost - return RegressionRunExecutable( - self.feedables, fetches, self._postprocess) + return RegressionRunExecutable(fetches, self._postprocess) # pylint: enable=unused-argument @property diff --git a/neuralmonkey/runners/runner.py b/neuralmonkey/runners/runner.py index b4e970c59..517dde5e7 100644 --- a/neuralmonkey/runners/runner.py +++ b/neuralmonkey/runners/runner.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Set, Optional, Callable, Union +from typing import Dict, List, Optional, Callable, Union import numpy as np import tensorflow as tf @@ -6,7 +6,6 @@ from neuralmonkey.runners.base_runner import ( BaseRunner, Executable, FeedDict, ExecutionResult, NextExecute) -from neuralmonkey.model.feedable import Feedable from neuralmonkey.vocabulary import Vocabulary from neuralmonkey.decoders.autoregressive import AutoregressiveDecoder from neuralmonkey.decoders.classifier import Classifier @@ -20,11 +19,9 @@ class GreedyRunExecutable(Executable): def __init__(self, - feedables: Set[Feedable], fetches: FeedDict, vocabulary: Vocabulary, postprocess: Optional[Postprocessor]) -> None: - self._feedables = feedables self._fetches = fetches self._vocabulary = vocabulary self._postprocess = postprocess @@ -32,8 +29,7 @@ def __init__(self, self._result = None # type: Optional[ExecutionResult] def next_to_execute(self) -> NextExecute: - """Get the feedables and tensors to run.""" - return self._feedables, self._fetches, [] + return self._fetches, [] def collect_results(self, results: List[Dict]) -> None: train_loss = 0. @@ -99,10 +95,7 @@ def get_executable(self, fetches["image_summaries"] = self.image_summaries return GreedyRunExecutable( - self.feedables, - fetches, - self._decoder.vocabulary, - self._postprocess) + fetches, self._decoder.vocabulary, self._postprocess) # pylint: enable=unused-argument @property diff --git a/neuralmonkey/runners/tensor_runner.py b/neuralmonkey/runners/tensor_runner.py index 4a6aedbd5..edec66dbb 100644 --- a/neuralmonkey/runners/tensor_runner.py +++ b/neuralmonkey/runners/tensor_runner.py @@ -1,11 +1,10 @@ -from typing import Dict, List, Set, Optional +from typing import Dict, List, Optional import numpy as np import tensorflow as tf from typeguard import check_argument_types from neuralmonkey.logging import log, warn -from neuralmonkey.model.feedable import Feedable from neuralmonkey.model.model_part import GenericModelPart from neuralmonkey.runners.base_runner import ( BaseRunner, Executable, ExecutionResult, NextExecute, FeedDict) @@ -15,12 +14,10 @@ class TensorExecutable(Executable): def __init__(self, - feedables: Set[Feedable], fetches: FeedDict, batch_dims: Dict[str, int], select_session: Optional[int], single_tensor: bool) -> None: - self._feedables = feedables self._fetches = fetches self._batch_dims = batch_dims self._select_session = select_session @@ -29,7 +26,7 @@ def __init__(self, self._result = None # type: Optional[ExecutionResult] def next_to_execute(self) -> NextExecute: - return self._feedables, self._fetches, [] + return self._fetches, [] def collect_results(self, results: List[Dict]) -> None: if len(results) > 1 and self._select_session is None: @@ -176,9 +173,8 @@ def get_executable(self, self._fetches[tensor.name] = tensor self._batch_ids[tensor.name] = bid - return TensorExecutable( - self.feedables, self._fetches, self._batch_ids, - self._select_session, self._single_tensor) + return TensorExecutable(self._fetches, self._batch_ids, + self._select_session, self._single_tensor) # pylint: enable=unused-argument @property diff --git a/neuralmonkey/runners/word_alignment_runner.py b/neuralmonkey/runners/word_alignment_runner.py index c8e28485a..53791f92b 100644 --- a/neuralmonkey/runners/word_alignment_runner.py +++ b/neuralmonkey/runners/word_alignment_runner.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Set +from typing import Dict, List # pylint: disable=unused-import from typing import Optional # pylint: enable=unused-import @@ -8,24 +8,18 @@ from neuralmonkey.attention.base_attention import BaseAttention from neuralmonkey.decoders.decoder import Decoder -from neuralmonkey.model.feedable import Feedable from neuralmonkey.runners.base_runner import ( BaseRunner, Executable, FeedDict, ExecutionResult, NextExecute) class WordAlignmentRunnerExecutable(Executable): - def __init__(self, - feedables: Set[Feedable], - fetches: FeedDict) -> None: - self._feedables = feedables + def __init__(self, fetches: FeedDict) -> None: self._fetches = fetches - self._result = None # type: Optional[ExecutionResult] def next_to_execute(self) -> NextExecute: - """Get the feedables and tensors to run.""" - return self._feedables, self._fetches, [] + return self._fetches, [] def collect_results(self, results: List[Dict]) -> None: self._result = ExecutionResult( @@ -60,7 +54,7 @@ def get_executable(self, alignment = tf.transpose(att_histories, perm=[1, 2, 0]) fetches = {"alignment": alignment} - return WordAlignmentRunnerExecutable(self.feedables, fetches) + return WordAlignmentRunnerExecutable(fetches) # pylint: enable=unused-argument @property diff --git a/neuralmonkey/tf_manager.py b/neuralmonkey/tf_manager.py index 7adef58d7..9f6420293 100644 --- a/neuralmonkey/tf_manager.py +++ b/neuralmonkey/tf_manager.py @@ -182,35 +182,27 @@ def _run_executables(self, feedables: Set[Feedable], executables: List[Executable], train: bool) -> None: - all_tensors_to_execute = {} + all_fetches = {} # We might want to feed different values to each session # E.g. when executing only step at a time during ensembling feed_dicts = [{} for _ in range(len(self.sessions))] \ # type: List[FeedDict] - tensor_list_lengths = [] # type: List[int] + for executable in (ex for ex in executables if ex.result is None): + fetches, add_feed_dicts = executable.next_to_execute() + all_fetches[executable] = fetches - for executable in executables: - if executable.result is None: - (_, - tensors_to_execute, - add_feed_dicts) = executable.next_to_execute() - all_tensors_to_execute[executable] = tensors_to_execute - if add_feed_dicts: - for fdict, add_fd in zip(feed_dicts, add_feed_dicts): - fdict.update(add_fd) - tensor_list_lengths.append(len(tensors_to_execute)) - else: - tensor_list_lengths.append(0) + if add_feed_dicts: + for fdict, add_fd in zip(feed_dicts, add_feed_dicts): + fdict.update(add_fd) feed_dict = _feed_dicts(batch, feedables, train=train) for fdict in feed_dicts: fdict.update(feed_dict) - session_results = [sess.run(all_tensors_to_execute, - feed_dict=fd) + session_results = [sess.run(all_fetches, feed_dict=fd) for sess, fd in zip(self.sessions, feed_dicts)] for executable in executables: diff --git a/neuralmonkey/trainers/delayed_update_trainer.py b/neuralmonkey/trainers/delayed_update_trainer.py index 461d1230e..3b0939966 100644 --- a/neuralmonkey/trainers/delayed_update_trainer.py +++ b/neuralmonkey/trainers/delayed_update_trainer.py @@ -197,7 +197,6 @@ def next_to_execute(self) -> NextExecute: fetches = {"accumulators": self.trainer.accumulate_ops, "counter": self.trainer.cumulator_counter, "losses": self.trainer.objective_values} - coders = self.trainer.feedables elif self.state == 1: # UPDATING fetches = { @@ -207,13 +206,10 @@ def next_to_execute(self) -> NextExecute: if self.summaries: fetches.update(self.trainer.summaries) - coders = self.trainer.feedables - else: # RESETTING fetches = {"resets": self.trainer.reset_ops} - coders = set() - return coders, fetches, [{}] + return fetches, [{}] def collect_results(self, results: List[Dict]) -> None: assert len(results) == 1 diff --git a/neuralmonkey/trainers/generic_trainer.py b/neuralmonkey/trainers/generic_trainer.py index 563775556..a461de482 100644 --- a/neuralmonkey/trainers/generic_trainer.py +++ b/neuralmonkey/trainers/generic_trainer.py @@ -253,7 +253,7 @@ def next_to_execute(self) -> NextExecute: fetches["losses"] = self.trainer.objective_values fetches["_update_ops"] = tf.get_collection(tf.GraphKeys.UPDATE_OPS) - return self.trainer.feedables, fetches, [{}] + return fetches, [{}] def collect_results(self, results: List[Dict]) -> None: assert len(results) == 1 diff --git a/neuralmonkey/trainers/test_multitask_trainer.py b/neuralmonkey/trainers/test_multitask_trainer.py index 79c4be379..92af59e5e 100644 --- a/neuralmonkey/trainers/test_multitask_trainer.py +++ b/neuralmonkey/trainers/test_multitask_trainer.py @@ -47,6 +47,7 @@ def setUp(self): self.trainer2 = GenericTrainer([objective_2], clip_norm=1.0) def test_mt_trainer(self): + # TODO multitask trainer is likely broken by changes in tf-data branch trainer = MultitaskTrainer( [self.trainer1, self.trainer2, self.trainer1]) @@ -60,24 +61,25 @@ def test_mt_trainer(self): self.assertTrue(trainer.trainer_idx == 0) executable = trainer.get_executable() - mparts, fetches, feeds = executable.next_to_execute() - self.assertSetEqual(mparts, {self.mpart}) + # mparts = trainer.feedables + fetches, feeds = executable.next_to_execute() + # self.assertSetEqual(mparts, {self.mpart}) self.assertFalse(feeds[0]) self.assertTrue(trainer.trainer_idx == 1) self.assertTrue(fetches["losses"][0] == self.mpart.loss) executable = trainer.get_executable() - mparts, fetches, feeds = executable.next_to_execute() - self.assertSetEqual(mparts, {self.mpart_2}) + fetches, feeds = executable.next_to_execute() + # self.assertSetEqual(mparts, {self.mpart_2}) self.assertFalse(feeds[0]) self.assertTrue(trainer.trainer_idx == 2) self.assertTrue(fetches["losses"][0] == self.mpart_2.loss) executable = trainer.get_executable() - mparts, fetches, feeds = executable.next_to_execute() - self.assertSetEqual(mparts, {self.mpart}) + fetches, feeds = executable.next_to_execute() + # self.assertSetEqual(mparts, {self.mpart}) self.assertFalse(feeds[0]) self.assertTrue(trainer.trainer_idx == 0) From 1277d0acf133b110839b66eac18b3059869bd886 Mon Sep 17 00:00:00 2001 From: Jindra Helcl Date: Thu, 22 Nov 2018 22:43:39 +0100 Subject: [PATCH 04/16] tf_manager.execute begins with calling default feed dicts --- neuralmonkey/tf_manager.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/neuralmonkey/tf_manager.py b/neuralmonkey/tf_manager.py index 9f6420293..7f004fadb 100644 --- a/neuralmonkey/tf_manager.py +++ b/neuralmonkey/tf_manager.py @@ -21,11 +21,8 @@ from neuralmonkey.logging import log from neuralmonkey.dataset import Dataset from neuralmonkey.model.feedable import Feedable -# pylint: disable=unused-import -from neuralmonkey.runners.base_runner import FeedDict -# pylint: enable=unused-import from neuralmonkey.runners.base_runner import ( - BaseRunner, ExecutionResult, Executable) + FeedDict, BaseRunner, ExecutionResult, Executable) from neuralmonkey.trainers.generic_trainer import GenericTrainer from neuralmonkey.trainers.multitask_trainer import MultitaskTrainer @@ -178,10 +175,8 @@ def validation_hook(self, score: float, epoch: int, batch: int) -> None: # pylint: disable=too-many-locals def _run_executables(self, - batch: Dataset, - feedables: Set[Feedable], - executables: List[Executable], - train: bool) -> None: + feed_dict: FeedDict, + executables: List[Executable]) -> None: all_fetches = {} # We might want to feed different values to each session @@ -197,8 +192,6 @@ def _run_executables(self, for fdict, add_fd in zip(feed_dicts, add_feed_dicts): fdict.update(add_fd) - feed_dict = _feed_dicts(batch, feedables, train=train) - for fdict in feed_dicts: fdict.update(feed_dict) @@ -236,16 +229,17 @@ def execute(self, A list of `ExecutionResult` tuples, one for each executable (runner). """ + feedables = set.union(*[runner.feedables for runner in runners]) + default_feed_dict = _feed_dicts(batch, feedables, train=train) + executables = [runner.get_executable(compute_losses=compute_losses, summaries=summaries, num_sessions=len(self.sessions)) for runner in runners] - feedables = set.union(*[runner.feedables for runner in runners]) - # TODO refactor runner results to properties while not all(getattr(ex, "result") is not None for ex in executables): - self._run_executables(batch, feedables, executables, train) + self._run_executables(default_feed_dict, executables) return [getattr(ex, "result") for ex in executables] From e0061f6f9a9fab1e0f184ce1037066235012ae83 Mon Sep 17 00:00:00 2001 From: Jindra Helcl Date: Thu, 22 Nov 2018 23:16:46 +0100 Subject: [PATCH 05/16] moving feedable extraction to learning_utils this alone should already make for modest speed improvements as the model part tree is traversed just once for training. For runtime, the tree is traversed each time run_on_dataset is called, which is pretty much all the time.. --- neuralmonkey/learning_utils.py | 14 ++++++++++---- neuralmonkey/tf_manager.py | 16 +++++----------- neuralmonkey/trainers/multitask_trainer.py | 16 ++++------------ 3 files changed, 19 insertions(+), 27 deletions(-) diff --git a/neuralmonkey/learning_utils.py b/neuralmonkey/learning_utils.py index 6166218e7..d27aca04c 100644 --- a/neuralmonkey/learning_utils.py +++ b/neuralmonkey/learning_utils.py @@ -200,6 +200,10 @@ def _is_logging_time(period_batch: Optional[int], except tf.errors.NotFoundError: warn("Some variables were not found in checkpoint.)") + # Ignoring type. Mypy complains about summing runner and trainer lists. + feedables = set.union( + *[ex.feedables for ex in runners + trainers]) # type: ignore + if log_directory: log("Initializing TensorBoard summary writer.") tb_writer = tf.summary.FileWriter( @@ -231,7 +235,7 @@ def _is_logging_time(period_batch: Optional[int], if _is_logging_time(log_period_batch, log_period_time, last_log_time): trainer_result = tf_manager.execute( - batch, trainers, train=True, summaries=True) + batch, feedables, trainers, train=True, summaries=True) train_results, train_outputs = run_on_dataset( tf_manager, runners, batch, postprocess, write_out=False, @@ -249,8 +253,8 @@ def _is_logging_time(period_batch: Optional[int], train=True) last_log_time = time.process_time() else: - tf_manager.execute( - batch, trainers, train=True, summaries=False) + tf_manager.execute(batch, feedables, trainers, train=True, + summaries=False) if _is_logging_time(val_period_batch, val_period_time, last_val_time): @@ -453,6 +457,8 @@ def run_on_dataset(tf_manager: TensorFlowManager, last_log_time = time.process_time() batch_results = [[] for _ in runners] # type: List[List[ExecutionResult]] + feedables = set.union(*[runner.feedables for runner in runners]) + processed_examples = 0 for batch in dataset.batches(batching_scheme): if 0 < log_progress < time.process_time() - last_log_time: @@ -460,7 +466,7 @@ def run_on_dataset(tf_manager: TensorFlowManager, last_log_time = time.process_time() execution_results = tf_manager.execute( - batch, runners, compute_losses=contains_targets) + batch, feedables, runners, compute_losses=contains_targets) processed_examples += len(batch) for script_list, ex_result in zip(batch_results, execution_results): diff --git a/neuralmonkey/tf_manager.py b/neuralmonkey/tf_manager.py index 7f004fadb..dbf1a7ce8 100644 --- a/neuralmonkey/tf_manager.py +++ b/neuralmonkey/tf_manager.py @@ -22,13 +22,7 @@ from neuralmonkey.dataset import Dataset from neuralmonkey.model.feedable import Feedable from neuralmonkey.runners.base_runner import ( - FeedDict, BaseRunner, ExecutionResult, Executable) -from neuralmonkey.trainers.generic_trainer import GenericTrainer -from neuralmonkey.trainers.multitask_trainer import MultitaskTrainer - -# pylint: disable=invalid-name -Trainer = Union[GenericTrainer, MultitaskTrainer] -# pylint: enable=invalid-name + FeedDict, ExecutionResult, Executable, GraphExecutor) class TensorFlowManager: @@ -206,7 +200,8 @@ def _run_executables(self, # pylint: disable=too-many-locals def execute(self, batch: Dataset, - runners: Sequence[Union[BaseRunner, Trainer]], + feedables: Set[Feedable], + runners: Sequence[GraphExecutor], train: bool = False, compute_losses: bool = True, summaries: bool = True) -> List[ExecutionResult]: @@ -229,7 +224,6 @@ def execute(self, A list of `ExecutionResult` tuples, one for each executable (runner). """ - feedables = set.union(*[runner.feedables for runner in runners]) default_feed_dict = _feed_dicts(batch, feedables, train=train) executables = [runner.get_executable(compute_losses=compute_losses, @@ -277,8 +271,8 @@ def restore_best_vars(self) -> None: # TODO warn when link does not exist self.restore(self.variables_files[self.best_score_index]) - def initialize_model_parts( - self, runners: List[Any], save: bool = False) -> None: + def initialize_model_parts(self, runners: Sequence[GraphExecutor], + save: bool = False) -> None: """Initialize model parts variables from their checkpoints.""" if any(not hasattr(r, "parameterizeds") for r in runners): diff --git a/neuralmonkey/trainers/multitask_trainer.py b/neuralmonkey/trainers/multitask_trainer.py index 26bfd810c..e1e4209a5 100644 --- a/neuralmonkey/trainers/multitask_trainer.py +++ b/neuralmonkey/trainers/multitask_trainer.py @@ -1,15 +1,13 @@ -from typing import List, Set +from typing import List from typeguard import check_argument_types -from neuralmonkey.model.feedable import Feedable -from neuralmonkey.model.parameterized import Parameterized -from neuralmonkey.runners.base_runner import Executable +from neuralmonkey.runners.base_runner import Executable, GraphExecutor from neuralmonkey.trainers.generic_trainer import GenericTrainer # pylint: disable=too-few-public-methods -class MultitaskTrainer: +class MultitaskTrainer(GraphExecutor): """Wrapper for scheduling multitask training. The wrapper contains a list of trainer objects. They are being @@ -20,19 +18,13 @@ class MultitaskTrainer: def __init__(self, trainers: List[GenericTrainer]) -> None: check_argument_types() + GraphExecutor.__init__(self, set(trainers)) self.trainers = trainers self.trainer_idx = 0 self.var_list = list(set.union(*[set(t.var_list) for t in trainers])) - self.feedables = set() # type: Set[Feedable] - self.parameterizeds = set() # type: Set[Parameterized] - - for trainer in self.trainers: - self.feedables |= trainer.feedables - self.parameterizeds |= trainer.parameterizeds - def get_executable( self, compute_losses: bool = True, summaries: bool = True, num_sessions: int = 1) -> Executable: From f2c4487635d7e1e9abb9c871180c65cd30371de9 Mon Sep 17 00:00:00 2001 From: Jindra Helcl Date: Sat, 24 Nov 2018 00:16:32 +0100 Subject: [PATCH 06/16] Refactoring tensor_runner not to work with tensors directly --- neuralmonkey/runners/tensor_runner.py | 82 +++++++++++++++------------ tests/bahdanau.ini | 7 +-- 2 files changed, 48 insertions(+), 41 deletions(-) diff --git a/neuralmonkey/runners/tensor_runner.py b/neuralmonkey/runners/tensor_runner.py index edec66dbb..0b24b538e 100644 --- a/neuralmonkey/runners/tensor_runner.py +++ b/neuralmonkey/runners/tensor_runner.py @@ -1,10 +1,14 @@ from typing import Dict, List, Optional -import numpy as np +# pylint: disable=unused-import +# Type annotation used in comment import tensorflow as tf +# pylint: enable=unused-import + +import numpy as np from typeguard import check_argument_types -from neuralmonkey.logging import log, warn +from neuralmonkey.logging import warn from neuralmonkey.model.model_part import GenericModelPart from neuralmonkey.runners.base_runner import ( BaseRunner, Executable, ExecutionResult, NextExecute, FeedDict) @@ -88,12 +92,11 @@ class TensorRunner(BaseRunner[GenericModelPart]): # pylint: disable=too-many-arguments def __init__(self, output_series: str, - toplevel_modelpart: GenericModelPart, - toplevel_tensors: List[tf.Tensor], + modelparts: List[GenericModelPart], + tensors: List[str], + batch_dims: List[int], tensors_by_name: List[str], - tensors_by_ref: List[tf.Tensor], batch_dims_by_name: List[int], - batch_dims_by_ref: List[int], select_session: int = None, single_tensor: bool = False) -> None: """Construct a new ``TensorRunner`` object. @@ -106,21 +109,16 @@ def __init__(self, Args: output_series: The name of the generated output data series. - toplevel_modelpart: A ``GenericModelPart`` object that is used as - the top-level component of the model. This object should depend - on values of all the wanted tensors. - toplevel_tensors: A list of tensors that should be constructed. Use - this when the toplevel model part does not depend on this - tensor. The tensors are constructed during running this - constructor method which prints them out. + modelparts: A list of ``GenericModelPart`` objects that hold the + tensors that will be retrieved. + tensors: A list of names of tensors that should be retrieved. + batch_dims_by_ref: A list of integers that correspond to the + batch dimension in each wanted tensor. tensors_by_name: A list of tensor names to fetch. If a tensor is not in the graph, a warning is generated and the tensor is ignored. - tensors_by_ref: A list of tensor objects to fetch. batch_dims_by_name: A list of integers that correspond to the batch dimension in each wanted tensor specified by name. - batch_dims_by_ref: A list of integers that correspond to the - batch dimension in each wanted tensor specified by reference. select_session: An optional integer specifying the session to use in case of ensembling. When not used, tensors from all sessions are stored. In case of a single session, this option has no @@ -131,28 +129,41 @@ def __init__(self, tensor names to NumPy arrays. """ check_argument_types() + + if not modelparts: + raise ValueError("At least one model part is expected") + BaseRunner[GenericModelPart].__init__( - self, output_series, toplevel_modelpart) + self, output_series, modelparts[0]) - total_tensors = len(tensors_by_name) + len(tensors_by_ref) + if len(modelparts) != len(tensors): + raise ValueError("TensorRunner: 'modelparts' and 'tensors' lists " + "must have the same length") + + total_tensors = len(tensors_by_name) + len(tensors) if single_tensor and total_tensors > 1: raise ValueError("single_tensor is True, but {} tensors were given" .format(total_tensors)) self._names = tensors_by_name - self._tensors = tensors_by_ref + self._modelparts = modelparts + self._tensors = tensors self._batch_dims_name = batch_dims_by_name - self._batch_dims_ref = batch_dims_by_ref + self._batch_dims = batch_dims self._select_session = select_session self._single_tensor = single_tensor - log("Blessing toplevel tensors for tensor runner:") - for tensor in toplevel_tensors: - log("Toplevel tensor: {}".format(tensor)) - self._fetches = {} # type: Dict[str, tf.Tensor] self._batch_ids = {} # type: Dict[str, int] + # pylint: enable=too-many-arguments + + # pylint: disable=unused-argument + def get_executable(self, + compute_losses: bool, + summaries: bool, + num_sessions: int) -> TensorExecutable: + for name, bid in zip(self._names, self._batch_dims_name): try: self._fetches[name] = ( @@ -161,15 +172,15 @@ def __init__(self, except KeyError: warn(("The tensor of name '{}' is not present in the " "graph.").format(name)) - # pylint: enable=too-many-arguments - # pylint: disable=unused-argument - def get_executable(self, - compute_losses: bool, - summaries: bool, - num_sessions: int) -> TensorExecutable: + for mpart, tname, bid in zip(self._modelparts, self._tensors, + self._batch_dims): + if not hasattr(mpart, tname): + raise ValueError("Model part {} does not have a tensor called " + "{}.".format(mpart, tname)) + + tensor = getattr(mpart, tname) - for tensor, bid in zip(self._tensors, self._batch_dims_ref): self._fetches[tensor.name] = tensor self._batch_ids[tensor.name] = bid @@ -211,16 +222,13 @@ def __init__(self, raise TypeError("The encoder '{}' does not have the specified " "attribute '{}'".format(encoder, attribute)) - tensor_to_get = getattr(encoder, attribute) - TensorRunner.__init__( self, output_series, - toplevel_modelpart=encoder, - toplevel_tensors=[], + modelparts=[encoder], + tensors=[attribute], + batch_dims=[0], tensors_by_name=[], - tensors_by_ref=[tensor_to_get], batch_dims_by_name=[], - batch_dims_by_ref=[0], select_session=select_session, single_tensor=True) diff --git a/tests/bahdanau.ini b/tests/bahdanau.ini index ded79f63d..3e9826a33 100644 --- a/tests/bahdanau.ini +++ b/tests/bahdanau.ini @@ -121,11 +121,10 @@ output_series="encoded" [debug_runner] class=runners.tensor_runner.TensorRunner -toplevel_modelpart= -toplevel_tensors=[] +modelparts=[, , ] +tensors=["output", "temporal_states", "runtime_logits"] +batch_dims=[0, 0, 1] ; tensors_by_name=["sentence_encoder/input_to_final_state/Tensordot/add:0"] tensors_by_name=[] batch_dims_by_name=[] -tensors_by_ref=[,,] -batch_dims_by_ref=[0, 0, 1] output_series="debugtensors" From e9b3f16b98db811e0dc6184dc41287c2f4123629 Mon Sep 17 00:00:00 2001 From: Jindra Helcl Date: Sat, 24 Nov 2018 00:19:17 +0100 Subject: [PATCH 07/16] Removing blessing from most of the codebase. The only place where the blessing remained is in the trainer and in the attention. This may break the running of the model though! --- neuralmonkey/attention/feed_forward.py | 8 ++++---- neuralmonkey/decoders/autoregressive.py | 23 +++++++++++++---------- neuralmonkey/decoders/ctc_decoder.py | 2 -- neuralmonkey/decoders/decoder.py | 4 ---- neuralmonkey/decoders/transformer.py | 6 +----- neuralmonkey/encoders/transformer.py | 3 --- neuralmonkey/runners/label_runner.py | 4 ---- neuralmonkey/tests/test_decoder.py | 9 --------- neuralmonkey/trainers/generic_trainer.py | 1 + 9 files changed, 19 insertions(+), 41 deletions(-) diff --git a/neuralmonkey/attention/feed_forward.py b/neuralmonkey/attention/feed_forward.py index 891fb49b4..2eb4ca966 100644 --- a/neuralmonkey/attention/feed_forward.py +++ b/neuralmonkey/attention/feed_forward.py @@ -42,10 +42,6 @@ def __init__(self, self._variable_scope.set_initializer( tf.random_normal_initializer(stddev=0.001)) - - # TODO blessing - log("Hidden features: {}".format(self.hidden_features)) - log("Attention mask: {}".format(self.attention_mask)) # pylint: enable=too-many-arguments @tensor @@ -170,6 +166,10 @@ def attention(self, return context, next_loop_state def initial_loop_state(self) -> AttentionLoopState: + # TODO blessing + log("Pre-computing attention tensors") + log("Hidden features: {}".format(self.hidden_features)) + return empty_attention_loop_state( self.batch_size, tf.shape(self.attention_states)[1], diff --git a/neuralmonkey/decoders/autoregressive.py b/neuralmonkey/decoders/autoregressive.py index e3f2765f6..86fe2a4b9 100644 --- a/neuralmonkey/decoders/autoregressive.py +++ b/neuralmonkey/decoders/autoregressive.py @@ -150,19 +150,22 @@ def __init__(self, self.encoder_masks = [] # type: List[tf.Tensor] # Check the values of the parameters (max_output_len, ...) - if max_output_len <= 0: - raise ValueError("Maximum sequence length must be " - "a positive integer.") + if self.max_output_len <= 0: + raise ValueError( + "Maximum sequence length must be a positive integer.") + + if self.embedding_size is not None and self.embedding_size <= 0: + raise ValueError("Embedding size must be a positive integer.") - if dropout_keep_prob < 0.0 or dropout_keep_prob > 1.0: - raise ValueError("Dropout keep probability must be" - "a real number in the interval [0,1].") + if self.dropout_keep_prob < 0.0 or self.dropout_keep_prob > 1.0: + raise ValueError("Dropout keep probability must be a real number " + "in the interval [0,1].") if self.embedding_size is None and self.embeddings_source is None: - raise ValueError("You must specify either embedding size or the " - "embedded sequence from which to reuse the " - "embeddings (e.g. set either 'embedding_size' or " - " 'embeddings_source' parameter)") + raise ValueError( + "You must specify either embedding size or the embedded " + "sequence from which to reuse the embeddings (e.g. set either " + "'embedding_size' or 'embeddings_source' parameter)") if self.embeddings_source is not None: if self.embedding_size is not None: diff --git a/neuralmonkey/decoders/ctc_decoder.py b/neuralmonkey/decoders/ctc_decoder.py index 2d3cef783..b75bd1a0a 100644 --- a/neuralmonkey/decoders/ctc_decoder.py +++ b/neuralmonkey/decoders/ctc_decoder.py @@ -6,7 +6,6 @@ from neuralmonkey.dataset import Dataset from neuralmonkey.decorators import tensor -from neuralmonkey.logging import log from neuralmonkey.model.feedable import FeedDict from neuralmonkey.model.parameterized import InitializerSpecs from neuralmonkey.model.model_part import ModelPart @@ -47,7 +46,6 @@ def __init__(self, self.merge_repeated_targets = merge_repeated_targets self.merge_repeated_outputs = merge_repeated_outputs self.beam_width = beam_width - log("CTC output tensor {}.".format(self.decoded)) # pylint: enable=too-many-arguments # pylint: disable=no-self-use diff --git a/neuralmonkey/decoders/decoder.py b/neuralmonkey/decoders/decoder.py index 0dc5d0c87..6957cccd4 100644 --- a/neuralmonkey/decoders/decoder.py +++ b/neuralmonkey/decoders/decoder.py @@ -214,10 +214,6 @@ def __init__(self, self._variable_scope.set_initializer( tf.random_normal_initializer(stddev=0.001)) - - # TODO when it is possible, remove the printing of the cost var - log("Decoder initalized. Cost var: {}".format(str(self.cost))) - log("Runtime logits tensor: {}".format(str(self.runtime_logits))) # pylint: enable=too-many-arguments,too-many-branches,too-many-statements @tensor diff --git a/neuralmonkey/decoders/transformer.py b/neuralmonkey/decoders/transformer.py index 9e2a1b964..5c12b8d2e 100644 --- a/neuralmonkey/decoders/transformer.py +++ b/neuralmonkey/decoders/transformer.py @@ -20,7 +20,7 @@ AutoregressiveDecoder, LoopState, DecoderFeedables) from neuralmonkey.encoders.transformer import ( TransformerLayer, position_signal) -from neuralmonkey.logging import log, warn +from neuralmonkey.logging import warn from neuralmonkey.model.sequence import EmbeddedSequence from neuralmonkey.model.parameterized import InitializerSpecs from neuralmonkey.model.model_part import ModelPart @@ -214,10 +214,6 @@ def __init__(self, self._variable_scope.set_initializer(tf.variance_scaling_initializer( mode="fan_avg", distribution="uniform")) - - log("Decoder cost op: {}".format(self.cost)) - self._variable_scope.reuse_variables() - log("Runtime logits: {}".format(self.runtime_logits)) # pylint: enable=too-many-arguments,too-many-locals,too-many-branches @property diff --git a/neuralmonkey/encoders/transformer.py b/neuralmonkey/encoders/transformer.py index c03401e47..50757d19e 100644 --- a/neuralmonkey/encoders/transformer.py +++ b/neuralmonkey/encoders/transformer.py @@ -12,7 +12,6 @@ Attendable, get_attention_states, get_attention_mask) from neuralmonkey.decorators import tensor from neuralmonkey.attention.scaled_dot_product import attention -from neuralmonkey.logging import log from neuralmonkey.model.parameterized import InitializerSpecs from neuralmonkey.model.model_part import ModelPart from neuralmonkey.model.stateful import (TemporalStateful, @@ -161,8 +160,6 @@ def __init__(self, self._variable_scope.set_initializer(tf.variance_scaling_initializer( mode="fan_avg", distribution="uniform")) - - log("Output op: {}".format(self.output)) # pylint: enable=too-many-arguments,too-many-locals @tensor diff --git a/neuralmonkey/runners/label_runner.py b/neuralmonkey/runners/label_runner.py index d6edb31c8..e2296c166 100644 --- a/neuralmonkey/runners/label_runner.py +++ b/neuralmonkey/runners/label_runner.py @@ -2,7 +2,6 @@ import numpy as np from typeguard import check_argument_types -from neuralmonkey.logging import log from neuralmonkey.vocabulary import Vocabulary, END_TOKEN_INDEX from neuralmonkey.runners.base_runner import ( BaseRunner, Executable, FeedDict, ExecutionResult, NextExecute) @@ -71,9 +70,6 @@ def __init__(self, self._postprocess = postprocess - # Make sure the lazy decoder creates its output tensor - log("Decoder output tensor: {}".format(decoder.decoded)) - # pylint: disable=unused-argument # Don't know why it works in Attention.attention and not here. # Parameters are unused beacause they are inherited. diff --git a/neuralmonkey/tests/test_decoder.py b/neuralmonkey/tests/test_decoder.py index 48b0152ef..4b1e7c968 100644 --- a/neuralmonkey/tests/test_decoder.py +++ b/neuralmonkey/tests/test_decoder.py @@ -54,15 +54,6 @@ def test_embedding_size(self): with self.assertRaises(ValueError): Decoder(**dparams) - def test_tie_embeddings(self): - dparams = copy.deepcopy(DECODER_PARAMS) - - dparams["tie_embeddings"] = True - dparams["rnn_size"] = 20 - dparams["embedding_size"] = 10 - with self.assertRaises(ValueError): - Decoder(**dparams) - def test_cell_type(self): dparams = copy.deepcopy(DECODER_PARAMS) diff --git a/neuralmonkey/trainers/generic_trainer.py b/neuralmonkey/trainers/generic_trainer.py index a461de482..ce8044c99 100644 --- a/neuralmonkey/trainers/generic_trainer.py +++ b/neuralmonkey/trainers/generic_trainer.py @@ -69,6 +69,7 @@ def __init__(self, self.optimizer = ( optimizer if optimizer is not None else self.default_optimizer()) + log("Building model") log("Train op: {}".format(str(self.train_op))) # pylint: disable=no-self-use From bd3ed4854d9d5ec4ddf7a5038983c80313561187 Mon Sep 17 00:00:00 2001 From: Jindra Helcl Date: Sun, 25 Nov 2018 20:14:51 +0100 Subject: [PATCH 08/16] Refactoring runners and trainers to use lazy execution. Graph construction is now started by "blessing" the fetches of all runners and trainers, right after building the cfg model in experiment.py --- neuralmonkey/attention/combination.py | 12 +- neuralmonkey/attention/feed_forward.py | 1 + .../decoders/word_alignment_decoder.py | 38 +-- neuralmonkey/encoders/recurrent.py | 2 + neuralmonkey/experiment.py | 19 +- neuralmonkey/model/parameterized.py | 2 + neuralmonkey/runners/base_runner.py | 88 +++++-- neuralmonkey/runners/beamsearch_runner.py | 224 +++++++++--------- neuralmonkey/runners/ctc_debug_runner.py | 80 +++---- neuralmonkey/runners/label_runner.py | 105 ++++---- neuralmonkey/runners/logits_runner.py | 126 ++++------ neuralmonkey/runners/perplexity_runner.py | 54 ++--- neuralmonkey/runners/plain_runner.py | 87 +++---- neuralmonkey/runners/regression_runner.py | 76 +++--- neuralmonkey/runners/runner.py | 108 ++++----- neuralmonkey/runners/tensor_runner.py | 161 ++++++------- neuralmonkey/runners/word_alignment_runner.py | 48 ++-- neuralmonkey/tf_manager.py | 4 +- .../trainers/delayed_update_trainer.py | 146 ++++++------ neuralmonkey/trainers/generic_trainer.py | 85 +++---- neuralmonkey/trainers/multitask_trainer.py | 15 +- tests/pydocstyle_run.sh | 2 +- 22 files changed, 683 insertions(+), 800 deletions(-) diff --git a/neuralmonkey/attention/combination.py b/neuralmonkey/attention/combination.py index aef6f9331..4f7f5de26 100644 --- a/neuralmonkey/attention/combination.py +++ b/neuralmonkey/attention/combination.py @@ -23,6 +23,7 @@ get_attention_states, get_attention_mask, Attendable) from neuralmonkey.attention.namedtuples import HierarchicalLoopState from neuralmonkey.checking import assert_shape +from neuralmonkey.decorators import tensor from neuralmonkey.model.model_part import ModelPart from neuralmonkey.model.parameterized import InitializerSpecs from neuralmonkey.tf_utils import get_variable @@ -49,11 +50,6 @@ def __init__(self, self._use_sentinels = use_sentinels self.att_scope_name = "attention_{}".format(name) - - with self.use_scope(): - self.attn_v = get_variable( - "attn_v", [1, 1, self.attention_state_size], - initializer=tf.random_normal_initializer(stddev=0.001)) # pylint: enable=unused-argument,too-many-arguments def attention(self, @@ -64,6 +60,12 @@ def attention(self, """Get context vector for given decoder state.""" raise NotImplementedError("Abstract method") + @tensor + def attn_v(self) -> tf.Tensor: + return get_variable( + "attn_v", [1, 1, self.attention_state_size], + initializer=tf.random_normal_initializer(stddev=0.001)) + @property def attn_size(self): return self.attention_state_size diff --git a/neuralmonkey/attention/feed_forward.py b/neuralmonkey/attention/feed_forward.py index 2eb4ca966..edef3c6f4 100644 --- a/neuralmonkey/attention/feed_forward.py +++ b/neuralmonkey/attention/feed_forward.py @@ -169,6 +169,7 @@ def initial_loop_state(self) -> AttentionLoopState: # TODO blessing log("Pre-computing attention tensors") log("Hidden features: {}".format(self.hidden_features)) + log("Hidden mask: {}".format(self.attention_mask)) return empty_attention_loop_state( self.batch_size, diff --git a/neuralmonkey/decoders/word_alignment_decoder.py b/neuralmonkey/decoders/word_alignment_decoder.py index 0932a9998..cb957d7f9 100644 --- a/neuralmonkey/decoders/word_alignment_decoder.py +++ b/neuralmonkey/decoders/word_alignment_decoder.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import cast, Tuple import numpy as np import tensorflow as tf @@ -40,19 +40,6 @@ def __init__(self, self.enc_input = cast(Sequence, self.encoder.input_sequence) - # TODO this is here to call the lazy properties which create - # the list of attention distribbutions - # pylint: disable=pointless-statement - self.decoder.runtime_logits - self.decoder.train_logits - # pylint: enable=pointless-statement - - _, self.train_loss = self._make_decoder(runtime_mode=False) - self.decoded, self.runtime_loss = self._make_decoder(runtime_mode=True) - - tf.summary.scalar("alignment_train_xent", self.train_loss, - collections=["summary_train"]) - @tensor def ref_alignment(self) -> tf.Tensor: # TODO dynamic shape? @@ -67,6 +54,29 @@ def alignment_target(self) -> tf.Tensor: # shape will be [max_output_len, batch_size, max_input_len] return tf.transpose(self.ref_alignment, perm=[1, 0, 2]) + @tensor + def train_loss(self) -> tf.Tensor: + loss = self._make_decoder(runtime_mode=False) + tf.summary.scalar( + "alignment_train_xent", loss, collections=["summary_train"]) + + return loss + + # pylint: disable=unsubscriptable-object + # Bug in pylint + @tensor + def decoded(self) -> tf.Tensor: + return self.runtime_outputs[0] + + @tensor + def runtime_loss(self) -> tf.Tensor: + return self.runtime_outputs[1] + # pylint: enable=unsubscriptable-object + + @tensor + def runtime_outputs(self) -> Tuple[tf.Tensor, tf.Tensor]: + return self._make_decoder(runtime_mode=True) + def _make_decoder(self, runtime_mode=False): attn_obj = self.decoder.get_attention_object(self.encoder, not runtime_mode) diff --git a/neuralmonkey/encoders/recurrent.py b/neuralmonkey/encoders/recurrent.py index 2a9c60cc4..687910169 100644 --- a/neuralmonkey/encoders/recurrent.py +++ b/neuralmonkey/encoders/recurrent.py @@ -114,7 +114,9 @@ def rnn_layer(rnn_input: tf.Tensor, "must match when applying residual connection. Reshaping " "the rnn output using linear projection.".format( outputs.get_shape(), rnn_input.get_shape())) + # pylint: disable=redefined-variable-type outputs = tf.layers.dense(outputs, rnn_input.shape.as_list()[-1]) + # pylint: enable=redefined-variable-type outputs += rnn_input return outputs, final_state diff --git a/neuralmonkey/experiment.py b/neuralmonkey/experiment.py index 4f1aa9ab8..dc4ec9386 100644 --- a/neuralmonkey/experiment.py +++ b/neuralmonkey/experiment.py @@ -105,6 +105,19 @@ def model(self) -> Namespace: return self._model + def _bless_graph_executors(self) -> None: + if hasattr(self.model, "trainer"): + if isinstance(self.model.trainer, List): + trainers = self.model.trainer + else: + trainers = [self.model.trainer] + + for trainer in trainers: + log("Trainer fetches: {}".format(trainer.fetches)) + + for runner in self.model.runners: + log("Runner fetches: {}".format(runner.fetches)) + def build_model(self) -> None: if self._model_built: raise RuntimeError("build_model() called twice") @@ -117,11 +130,13 @@ def build_model(self) -> None: # Enable the created model parts to find this experiment. type(self)._current_experiment = self # type: ignore - self.config.build_model(warn_unused=self.train_mode) - type(self)._current_experiment = None + self.config.build_model(warn_unused=self.train_mode) self._model = self.config.model self._model_built = True + self._bless_graph_executors() + + type(self)._current_experiment = None if self.model.runners_batch_size is None: self.model.runners_batch_size = self.model.batch_size diff --git a/neuralmonkey/model/parameterized.py b/neuralmonkey/model/parameterized.py index 60c10e137..1ab4e3975 100644 --- a/neuralmonkey/model/parameterized.py +++ b/neuralmonkey/model/parameterized.py @@ -89,6 +89,8 @@ def use_scope(self) -> Iterator[None]: """ # If we are already reusing, reuse regardless of self._reuse. reuse = self._variable_scope.reuse or self._reuse + if not reuse: + reuse = tf.AUTO_REUSE with tf.variable_scope(self._variable_scope, reuse=reuse): # tf.variable_scope always creates a NEW name scope for ops, but diff --git a/neuralmonkey/runners/base_runner.py b/neuralmonkey/runners/base_runner.py index 31bd42964..4026cf985 100644 --- a/neuralmonkey/runners/base_runner.py +++ b/neuralmonkey/runners/base_runner.py @@ -1,4 +1,4 @@ -from abc import abstractmethod +from abc import abstractmethod, abstractproperty from typing import (Any, Dict, Tuple, List, NamedTuple, Union, Set, TypeVar, Generic, Optional) import numpy as np @@ -13,6 +13,8 @@ FeedDict = Dict[tf.Tensor, Union[int, float, np.ndarray]] NextExecute = Tuple[Union[Dict, List], List[FeedDict]] MP = TypeVar("MP", bound=GenericModelPart) +Executor = TypeVar("Executor", bound="GraphExecutor") +Runner = TypeVar("Runner", bound="BaseRunner") # pylint: enable=invalid-name @@ -37,29 +39,63 @@ class ExecutionResult(NamedTuple( """ -class Executable: +class GraphExecutor(GenericModelPart): - @property - def result(self) -> Optional[ExecutionResult]: - return getattr(self, "_result") + class Executable(Generic[Executor]): - @abstractmethod - def next_to_execute(self) -> NextExecute: - """Get the tensors and additional feed dicts for execution.""" - raise NotImplementedError() + def __init__(self, + executor: Executor, + compute_losses: bool, + summaries: bool, + num_sessions: int) -> None: + self._executor = executor + self.compute_losses = compute_losses + self.summaries = summaries + self.num_sessions = num_sessions - @abstractmethod - def collect_results(self, results: List[Dict]) -> None: - raise NotImplementedError() + self._result = None # type: Optional[ExecutionResult] + def set_result(self, outputs: List[Any], losses: List[float], + scalar_summaries: tf.Summary, + histogram_summaries: tf.Summary, + image_summaries: tf.Summary) -> None: + self._result = ExecutionResult( + outputs, losses, scalar_summaries, histogram_summaries, + image_summaries) -class GraphExecutor(GenericModelPart): + @property + def result(self) -> Optional[ExecutionResult]: + return self._result + + @property + def executor(self) -> Executor: + return self._executor + + def next_to_execute(self) -> NextExecute: + """Get the tensors and additional feed dicts for execution.""" + return self.executor.fetches, [{}] + + @abstractmethod + def collect_results(self, results: List[Dict]) -> None: + return None def __init__(self, dependencies: Set[GenericModelPart]) -> None: self._dependencies = dependencies self._feedables, self._parameterizeds = self.get_dependencies() + def get_executable(self, + compute_losses: bool, + summaries: bool, + num_sessions: int) -> "GraphExecutor.Executable": + # Since the executable is always subclassed, we can instantiate it + return self.Executable( # type: ignore + self, compute_losses, summaries, num_sessions) + + @abstractproperty + def fetches(self) -> Dict[str, tf.Tensor]: + raise NotImplementedError() + @property def dependencies(self) -> List[str]: return ["_dependencies"] @@ -72,21 +108,29 @@ def feedables(self) -> Set[Feedable]: def parameterizeds(self) -> Set[Parameterized]: return self._parameterizeds - @abstractmethod - def get_executable(self, - compute_losses: bool, - summaries: bool, - num_sessions: int) -> Executable: - raise NotImplementedError() - class BaseRunner(GraphExecutor, Generic[MP]): + + # pylint: disable=too-few-public-methods + # Pylint issue here: https://github.com/PyCQA/pylint/issues/2607 + class Executable(GraphExecutor.Executable[Runner]): + + def next_to_execute(self) -> NextExecute: + fetches = self.executor.fetches + + if not self.compute_losses: + for loss in self.executor.loss_names: + fetches[loss] = tf.zeros([]) + + return fetches, [{}] + # pylint: enable=too-few-public-methods + def __init__(self, output_series: str, decoder: MP) -> None: GraphExecutor.__init__(self, {decoder}) self.output_series = output_series - self._decoder = decoder + self.decoder = decoder if not hasattr(decoder, "data_id"): notice("Top-level decoder {} does not have the 'data_id' attribute" @@ -94,7 +138,7 @@ def __init__(self, @property def decoder_data_id(self) -> Optional[str]: - return getattr(self._decoder, "data_id", None) + return getattr(self.decoder, "data_id", None) @property def loss_names(self) -> List[str]: diff --git a/neuralmonkey/runners/beamsearch_runner.py b/neuralmonkey/runners/beamsearch_runner.py index 8caf27b8c..92a384afc 100644 --- a/neuralmonkey/runners/beamsearch_runner.py +++ b/neuralmonkey/runners/beamsearch_runner.py @@ -1,125 +1,124 @@ -from typing import Callable, List, Dict, Optional +from typing import Callable, List, Dict import scipy import numpy as np +import tensorflow as tf from typeguard import check_argument_types +from neuralmonkey.decorators import tensor from neuralmonkey.decoders.beam_search_decoder import BeamSearchDecoder -from neuralmonkey.runners.base_runner import ( - BaseRunner, Executable, ExecutionResult, NextExecute) +from neuralmonkey.runners.base_runner import BaseRunner, NextExecute # pylint: disable=unused-import from neuralmonkey.runners.base_runner import FeedDict # pylint: enable=unused-import from neuralmonkey.vocabulary import END_TOKEN_INDEX -class BeamSearchExecutable(Executable): - def __init__(self, - rank: int, - num_sessions: int, - decoder: BeamSearchDecoder, - postprocess: Optional[Callable]) -> None: - self._rank = rank - self._num_sessions = num_sessions - self._decoder = decoder - self._postprocess = postprocess - - self._next_feed = [{} for _ in range(self._num_sessions)] \ - # type: List[FeedDict] - - # During ensembling, we set the decoder max_steps to zero because the - # loop is run manually in the runner. - if self._num_sessions > 1: - for fd in self._next_feed: - fd.update({self._decoder.max_steps: 0}) - - self._result = None # type: Optional[ExecutionResult] - - def next_to_execute(self) -> NextExecute: - return {"bs_outputs": self._decoder.outputs}, self._next_feed - - def collect_results(self, results: List[Dict]) -> None: - # Recompute logits - # Only necessary when ensembling models - prev_logprobs = [res["bs_outputs"].last_search_state.prev_logprobs - for res in results] - - # Arithmetic mean - ens_logprobs = (scipy.misc.logsumexp(prev_logprobs, 0) - - np.log(self._num_sessions)) - - if self._is_finished(results): - self.prepare_results( - results[0]["bs_outputs"].last_search_step_output) - return - - # Prepare the next feed_dict (required for ensembles) - self._next_feed = [] - for result in results: - bs_out = result["bs_outputs"] - - search_state = bs_out.last_search_state._replace( - prev_logprobs=ens_logprobs) - - dec_ls = bs_out.last_dec_loop_state - feedables = dec_ls.feedables._replace(step=1) - dec_ls = dec_ls._replace(feedables=feedables) - - fd = { - self._decoder.max_steps: 1, - self._decoder.search_state: search_state, - self._decoder.search_results: bs_out.last_search_step_output, - self._decoder.decoder_state: dec_ls} - - self._next_feed.append(fd) - - return - - def prepare_results(self, output): - bs_scores = [s[self._rank - 1] for s in output.scores] - - tok_ids = np.transpose(output.token_ids, [1, 2, 0]) - decoded_tokens = [toks[self._rank - 1][1:] for toks in tok_ids] - - for i, sent in enumerate(decoded_tokens): - decoded = [] - for tok_id in sent: - if tok_id == END_TOKEN_INDEX: - break - decoded.append(self._decoder.vocabulary.index_to_word[tok_id]) - decoded_tokens[i] = decoded - - if self._postprocess is not None: - decoded_tokens = self._postprocess(decoded_tokens) - - # TODO: provide better summaries in case (issue #599) - # we want to use the runner during training. - self._result = ExecutionResult( - outputs=decoded_tokens, - losses=[np.mean(bs_scores) * len(bs_scores)], - scalar_summaries=None, - histogram_summaries=None, - image_summaries=None) - - def _is_finished(self, results): - finished = [ - all(res["bs_outputs"].last_dec_loop_state.feedables.finished) - for res in results] - if all(finished): - return True - bs_out = results[0]["bs_outputs"] - step = len(bs_out.last_search_step_output.token_ids) - 1 - if step >= self._decoder.max_steps_int: - return True - return False - - class BeamSearchRunner(BaseRunner[BeamSearchDecoder]): """A runner which takes the output from a beam search decoder. The runner and the beam search decoder support ensembling. """ + class Executable(BaseRunner.Executable["BeamSearchRunner"]): + def __init__(self, + executor: "BeamSearchRunner", + compute_losses: bool, + summaries: bool, + num_sessions: int) -> None: + super().__init__(executor, compute_losses, summaries, num_sessions) + + self.rank = executor.rank + self.decoder = executor.decoder + self.postprocess = executor.postprocess + + self._next_feed = [{} for _ in range(self.num_sessions)] \ + # type: List[FeedDict] + + # During ensembling, we set the decoder max_steps to zero because + # the loop is run manually in the runner. + if self.num_sessions > 1: + for fd in self._next_feed: + fd.update({self.decoder.max_steps: 0}) + + def next_to_execute(self) -> NextExecute: + return {"bs_outputs": self.decoder.outputs}, self._next_feed + + def collect_results(self, results: List[Dict]) -> None: + # Recompute logits + # Only necessary when ensembling models + prev_logprobs = [res["bs_outputs"].last_search_state.prev_logprobs + for res in results] + + # Arithmetic mean + ens_logprobs = (scipy.misc.logsumexp(prev_logprobs, 0) + - np.log(self.num_sessions)) + + if self._is_finished(results): + self.prepare_results( + results[0]["bs_outputs"].last_search_step_output) + return + + # Prepare the next feed_dict (required for ensembles) + self._next_feed = [] + for result in results: + bout = result["bs_outputs"] + + search_state = bout.last_search_state._replace( + prev_logprobs=ens_logprobs) + + dec_ls = bout.last_dec_loop_state + feedables = dec_ls.feedables._replace(step=1) + dec_ls = dec_ls._replace(feedables=feedables) + + fd = { + self.decoder.max_steps: 1, + self.decoder.search_state: search_state, + self.decoder.search_results: bout.last_search_step_output, + self.decoder.decoder_state: dec_ls} + + self._next_feed.append(fd) + + return + + def prepare_results(self, output): + bs_scores = [s[self.rank - 1] for s in output.scores] + + tok_ids = np.transpose(output.token_ids, [1, 2, 0]) + decoded_tokens = [toks[self.rank - 1][1:] for toks in tok_ids] + + for i, sent in enumerate(decoded_tokens): + decoded = [] + for tok_id in sent: + if tok_id == END_TOKEN_INDEX: + break + decoded.append( + self.decoder.vocabulary.index_to_word[tok_id]) + decoded_tokens[i] = decoded + + if self.postprocess is not None: + decoded_tokens = self.postprocess(decoded_tokens) + + # TODO: provide better summaries in case (issue #599) + # we want to use the runner during training. + self.set_result(outputs=decoded_tokens, + losses=[np.mean(bs_scores) * len(bs_scores)], + scalar_summaries=None, + histogram_summaries=None, + image_summaries=None) + + def _is_finished(self, results): + finished = [ + all(res["bs_outputs"].last_dec_loop_state.feedables.finished) + for res in results] + if all(finished): + return True + bs_out = results[0]["bs_outputs"] + step = len(bs_out.last_search_step_output.token_ids) - 1 + if step >= self.decoder.max_steps_int: + return True + return False + def __init__(self, output_series: str, decoder: BeamSearchDecoder, @@ -142,17 +141,12 @@ def __init__(self, ("Rank of output hypothesis must be between 1 and the beam " "size ({}), was {}.").format(decoder.beam_size, rank)) - self._rank = rank - self._postprocess = postprocess - - # pylint: disable=unused-argument - def get_executable(self, - compute_losses: bool = False, - summaries: bool = True, - num_sessions: int = 1) -> BeamSearchExecutable: - return BeamSearchExecutable( - self._rank, num_sessions, self._decoder, self._postprocess) - # pylint: enable=unused-argument + self.rank = rank + self.postprocess = postprocess + + @tensor + def fetches(self) -> Dict[str, tf.Tensor]: + return {"bs_outputs": self.decoder.outputs} @property def loss_names(self) -> List[str]: diff --git a/neuralmonkey/runners/ctc_debug_runner.py b/neuralmonkey/runners/ctc_debug_runner.py index 84148c3d6..4c1e63e93 100644 --- a/neuralmonkey/runners/ctc_debug_runner.py +++ b/neuralmonkey/runners/ctc_debug_runner.py @@ -1,55 +1,45 @@ -from typing import Dict, List, Optional +from typing import Dict, List import numpy as np +import tensorflow as tf from typeguard import check_argument_types -from neuralmonkey.runners.base_runner import ( - BaseRunner, Executable, FeedDict, ExecutionResult, NextExecute) -from neuralmonkey.vocabulary import Vocabulary +from neuralmonkey.runners.base_runner import BaseRunner from neuralmonkey.decoders.ctc_decoder import CTCDecoder +from neuralmonkey.decorators import tensor -class CTCDebugExecutable(Executable): - - def __init__(self, - fetches: FeedDict, - vocabulary: Vocabulary) -> None: - self._fetches = fetches - self._vocabulary = vocabulary - - self._result = None # type: Optional[ExecutionResult] - - def next_to_execute(self) -> NextExecute: - return self._fetches, [] +class CTCDebugRunner(BaseRunner[CTCDecoder]): + """A runner that print out raw CTC output including the blank symbols.""" - def collect_results(self, results: List[Dict]) -> None: - if len(results) != 1: - raise RuntimeError("CTCDebug runner does not support ensembling.") + # pylint: disable=too-few-public-methods + # Pylint issue here: https://github.com/PyCQA/pylint/issues/2607 + class Executable(BaseRunner.Executable["CTCDebugRunner"]): - logits = results[0]["logits"] - argmaxes = np.argmax(logits, axis=2).T + def collect_results(self, results: List[Dict]) -> None: - decoded_batch = [] - for indices in argmaxes: - decoded_instance = [] - for index in indices: - if index == len(self._vocabulary): - symbol = "" - else: - symbol = self._vocabulary.index_to_word[index] - decoded_instance.append(symbol) - decoded_batch.append(decoded_instance) + vocabulary = self.executor.decoder.vocabulary + if len(results) != 1: + raise RuntimeError("CTCDebugRunners do not support ensembles.") - self._result = ExecutionResult( - outputs=decoded_batch, - losses=[], - scalar_summaries=None, - histogram_summaries=None, - image_summaries=None) + logits = results[0]["logits"] + argmaxes = np.argmax(logits, axis=2).T + decoded_batch = [] + for indices in argmaxes: + decoded_instance = [] + for index in indices: + if index == len(vocabulary): + symbol = "" + else: + symbol = vocabulary.index_to_word[index] + decoded_instance.append(symbol) + decoded_batch.append(decoded_instance) -class CTCDebugRunner(BaseRunner[CTCDecoder]): - """A runner that print out raw CTC output including the blank symbols.""" + self.set_result(outputs=decoded_batch, losses=[], + scalar_summaries=None, histogram_summaries=None, + image_summaries=None) + # pylint: enable=too-few-public-methods def __init__(self, output_series: str, @@ -57,15 +47,9 @@ def __init__(self, check_argument_types() BaseRunner[CTCDecoder].__init__(self, output_series, decoder) - # pylint: disable=unused-argument - def get_executable(self, - compute_losses: bool, - summaries: bool, - num_sessions: int) -> CTCDebugExecutable: - fetches = {"logits": self._decoder.logits} - - return CTCDebugExecutable(fetches, self._decoder.vocabulary) - # pylint: enable=unused-argument + @tensor + def fetches(self) -> Dict[str, tf.Tensor]: + return {"logits": self.decoder.logits} @property def loss_names(self) -> List[str]: diff --git a/neuralmonkey/runners/label_runner.py b/neuralmonkey/runners/label_runner.py index e2296c166..90cdaac78 100644 --- a/neuralmonkey/runners/label_runner.py +++ b/neuralmonkey/runners/label_runner.py @@ -1,10 +1,12 @@ -from typing import List, Dict, Optional, Callable +from typing import List, Dict, Callable + import numpy as np +import tensorflow as tf from typeguard import check_argument_types -from neuralmonkey.vocabulary import Vocabulary, END_TOKEN_INDEX -from neuralmonkey.runners.base_runner import ( - BaseRunner, Executable, FeedDict, ExecutionResult, NextExecute) +from neuralmonkey.decorators import tensor +from neuralmonkey.vocabulary import END_TOKEN_INDEX +from neuralmonkey.runners.base_runner import BaseRunner from neuralmonkey.decoders.sequence_labeler import SequenceLabeler # pylint: disable=invalid-name @@ -12,54 +14,41 @@ # pylint: enable=invalid-name -class LabelRunExecutable(Executable): - - def __init__(self, - fetches: FeedDict, - vocabulary: Vocabulary, - postprocess: Optional[Postprocessor]) -> None: - self._fetches = fetches - self._vocabulary = vocabulary - self._postprocess = postprocess - - self._result = None # type: Optional[ExecutionResult] - - def next_to_execute(self) -> NextExecute: - return self._fetches, [] - - def collect_results(self, results: List[Dict]) -> None: - loss = results[0].get("loss", 0.) - summed_logprobs = results[0]["label_logprobs"] - input_mask = results[0]["input_mask"] +class LabelRunner(BaseRunner[SequenceLabeler]): - for sess_result in results[1:]: - loss += sess_result.get("loss", 0.) - summed_logprobs = np.logaddexp(summed_logprobs, - sess_result["label_logprobs"]) - assert input_mask == sess_result["input_mask"] + # pylint: disable=too-few-public-methods + # Pylint issue here: https://github.com/PyCQA/pylint/issues/2607 + class Executable(BaseRunner.Executable["LabelRunner"]): - argmaxes = np.argmax(summed_logprobs, axis=2) + def collect_results(self, results: List[Dict]) -> None: + loss = results[0].get("loss", 0.) + summed_logprobs = results[0]["label_logprobs"] + input_mask = results[0]["input_mask"] - # CAUTION! FABULOUS HACK BELIEVE ME - argmaxes -= END_TOKEN_INDEX - argmaxes *= input_mask.astype(int) - argmaxes += END_TOKEN_INDEX + for sess_result in results[1:]: + loss += sess_result.get("loss", 0.) + summed_logprobs = np.logaddexp(summed_logprobs, + sess_result["label_logprobs"]) + assert input_mask == sess_result["input_mask"] - # must transpose argmaxes because vectors_to_sentences is time-major - decoded_labels = self._vocabulary.vectors_to_sentences(argmaxes.T) + argmaxes = np.argmax(summed_logprobs, axis=2) - if self._postprocess is not None: - decoded_labels = self._postprocess(decoded_labels) + # CAUTION! FABULOUS HACK BELIEVE ME + argmaxes -= END_TOKEN_INDEX + argmaxes *= input_mask.astype(int) + argmaxes += END_TOKEN_INDEX - self._result = ExecutionResult( - outputs=decoded_labels, - losses=[loss], - scalar_summaries=None, - histogram_summaries=None, - image_summaries=None) + # transpose argmaxes because vectors_to_sentences is time-major + vocabulary = self.executor.decoder.vocabulary + decoded_labels = vocabulary.vectors_to_sentences(argmaxes.T) + if self.executor.postprocess is not None: + decoded_labels = self.executor.postprocess(decoded_labels) -class LabelRunner(BaseRunner[SequenceLabeler]): + self.set_result(outputs=decoded_labels, losses=[loss], + scalar_summaries=None, histogram_summaries=None, + image_summaries=None) + # pylint: enable=too-few-public-methods def __init__(self, output_series: str, @@ -67,26 +56,14 @@ def __init__(self, postprocess: Postprocessor = None) -> None: check_argument_types() BaseRunner[SequenceLabeler].__init__(self, output_series, decoder) - - self._postprocess = postprocess - - # pylint: disable=unused-argument - # Don't know why it works in Attention.attention and not here. - # Parameters are unused beacause they are inherited. - def get_executable(self, - compute_losses: bool, - summaries: bool, - num_sessions: int) -> LabelRunExecutable: - fetches = { - "label_logprobs": self._decoder.logprobs, - "input_mask": self._decoder.encoder.input_sequence.temporal_mask} - - if compute_losses: - fetches["loss"] = self._decoder.cost - - return LabelRunExecutable( - fetches, self._decoder.vocabulary, self._postprocess) - # pylint: enable: unused-argument + self.postprocess = postprocess + + @tensor + def fetches(self) -> Dict[str, tf.Tensor]: + return { + "label_logprobs": self.decoder.logprobs, + "input_mask": self.decoder.encoder.input_sequence.temporal_mask, + "loss": self.decoder.cost} @property def loss_names(self) -> List[str]: diff --git a/neuralmonkey/runners/logits_runner.py b/neuralmonkey/runners/logits_runner.py index 8609f4462..9d19a988e 100644 --- a/neuralmonkey/runners/logits_runner.py +++ b/neuralmonkey/runners/logits_runner.py @@ -7,73 +7,56 @@ import tensorflow as tf from neuralmonkey.decoders.classifier import Classifier -from neuralmonkey.runners.base_runner import ( - BaseRunner, Executable, FeedDict, ExecutionResult, NextExecute) -from neuralmonkey.vocabulary import Vocabulary +from neuralmonkey.decorators import tensor +from neuralmonkey.runners.base_runner import BaseRunner -class LogitsExecutable(Executable): - - def __init__(self, - fetches: FeedDict, - vocabulary: Vocabulary, - normalize: bool, - pick_index: Optional[int]) -> None: - self._fetches = fetches - self._vocabulary = vocabulary - self._normalize = normalize - self._pick_index = pick_index - - self._result = None # type: Optional[ExecutionResult] - - def next_to_execute(self) -> NextExecute: - return self._fetches, [] - - def collect_results(self, results: List[Dict]) -> None: - if len(results) != 1: - raise ValueError("LogitsRunner needs exactly 1 execution result, " - "got {}".format(len(results))) +# pylint: disable=too-few-public-methods +class LogitsRunner(BaseRunner[Classifier]): + """A runner which takes the output from decoder.decoded_logits. - train_loss = results[0]["train_loss"] - runtime_loss = results[0]["runtime_loss"] + The logits / normalized probabilities are outputted as tab-separates string + values. If the decoder produces a list of logits (as the recurrent + decoder), the tab separated arrays are separated with commas. + Alternatively, we may be interested in a single distribution dimension. + """ - # logits_list in shape (time, batch, vocab) - logits_list = results[0]["logits"] + class Executable(BaseRunner.Executable["LogitsRunner"]): - # outputs are lists of strings (batch, time) - outputs = [[] for _ in logits_list[0]] # type: List[List[str]] + def collect_results(self, results: List[Dict]) -> None: + if len(results) != 1: + raise ValueError("LogitsRunner needs exactly 1 execution " + "result, got {}".format(len(results))) - for time_step in logits_list: - for logits, output_list in zip(time_step, outputs): + train_loss = results[0]["train_loss"] + runtime_loss = results[0]["runtime_loss"] - if self._normalize: - logits = np.exp(logits) / np.sum(np.exp(logits), axis=0) - if self._pick_index: - instance_logits = str(logits[self._pick_index]) - else: - instance_logits = ",".join(str(l) for l in logits) + # logits_list in shape (time, batch, vocab) + logits_list = results[0]["logits"] - output_list.append(instance_logits) + # outputs are lists of strings (batch, time) + outputs = [[] for _ in logits_list[0]] # type: List[List[str]] - str_outputs = [["\t".join(l)] for l in outputs] + for time_step in logits_list: + for logits, output_list in zip(time_step, outputs): - self._result = ExecutionResult( - outputs=str_outputs, - losses=[train_loss, runtime_loss], - scalar_summaries=None, - histogram_summaries=None, - image_summaries=None) + if self.executor.normalize: + logits = np.exp(logits) / np.sum(np.exp(logits), + axis=0) + if self.executor.pick_index: + instance_logits = str(logits[self.executor.pick_index]) + else: + instance_logits = ",".join(str(l) for l in logits) + output_list.append(instance_logits) -# pylint: disable=too-few-public-methods -class LogitsRunner(BaseRunner[Classifier]): - """A runner which takes the output from decoder.decoded_logits. + str_outputs = [["\t".join(l)] for l in outputs] - The logits / normalized probabilities are outputted as tab-separates string - values. If the decoder produces a list of logits (as the recurrent - decoder), the tab separated arrays are separated with commas. - Alternatively, we may be interested in a single distribution dimension. - """ + self.set_result(outputs=str_outputs, + losses=[train_loss, runtime_loss], + scalar_summaries=None, + histogram_summaries=None, + image_summaries=None) def __init__(self, output_series: str, @@ -100,36 +83,25 @@ def __init__(self, raise ValueError("Either a pick index or a vocabulary value can " "be specified, not both at the same time.") - self._pick_index = None # type: Optional[int] + self.pick_index = None # type: Optional[int] - self._normalize = normalize + self.normalize = normalize if pick_value is not None: - if pick_value in self._decoder.vocabulary: - vocab_map = self._decoder.vocabulary.word_to_index - self._pick_index = vocab_map[pick_value] + if pick_value in self.decoder.vocabulary: + vocab_map = self.decoder.vocabulary.word_to_index + self.pick_index = vocab_map[pick_value] else: raise ValueError( "Value '{}' is not in vocabulary of decoder '{}'".format( pick_value, decoder.name)) else: - self._pick_index = pick_index - - # pylint: disable=unused-argument - def get_executable(self, - compute_losses: bool, - summaries: bool, - num_sessions: int) -> LogitsExecutable: - fetches = {"logits": self._decoder.decoded_logits, - "train_loss": tf.zeros([]), - "runtime_loss": tf.zeros([])} - - if compute_losses: - fetches["train_loss"] = self._decoder.train_loss - fetches["runtime_loss"] = self._decoder.runtime_loss - - return LogitsExecutable(fetches, self._decoder.vocabulary, - self._normalize, self._pick_index) - # pylint: enable: unused-argument + self.pick_index = pick_index + + @tensor + def fetches(self) -> Dict[str, tf.Tensor]: + return {"logits": self.decoder.decoded_logits, + "train_loss": self.decoder.train_loss, + "runtime_loss": self.decoder.runtime_loss} @property def loss_names(self) -> List[str]: diff --git a/neuralmonkey/runners/perplexity_runner.py b/neuralmonkey/runners/perplexity_runner.py index 08c277689..eadbb9f90 100644 --- a/neuralmonkey/runners/perplexity_runner.py +++ b/neuralmonkey/runners/perplexity_runner.py @@ -1,38 +1,31 @@ from typing import Dict, List -# pylint: disable=unused-import -from typing import Optional -# pylint: enable=unused-import from typeguard import check_argument_types import tensorflow as tf import numpy as np from neuralmonkey.decoders.autoregressive import AutoregressiveDecoder -from neuralmonkey.runners.base_runner import ( - BaseRunner, Executable, ExecutionResult, NextExecute) +from neuralmonkey.decorators import tensor +from neuralmonkey.runners.base_runner import BaseRunner -class PerplexityExecutable(Executable): - def __init__(self, xent_op: tf.Tensor) -> None: - self._xent_op = xent_op - - self._result = None # type: Optional[ExecutionResult] - - def next_to_execute(self) -> NextExecute: - return {"xents": self._xent_op}, [] - - def collect_results(self, results: List[Dict]) -> None: - perplexities = np.mean([2 ** res["xents"] for res in results], axis=0) - xent = float(np.mean([res["xents"] for res in results])) - self._result = ExecutionResult( - outputs=perplexities.tolist(), - losses=[xent], - scalar_summaries=None, - histogram_summaries=None, - image_summaries=None) +class PerplexityRunner(BaseRunner[AutoregressiveDecoder]): + # pylint: disable=too-few-public-methods + # Pylint issue here: https://github.com/PyCQA/pylint/issues/2607 + class Executable(BaseRunner.Executable["PerplexityRunner"]): + + def collect_results(self, results: List[Dict]) -> None: + perplexities = np.mean( + [2 ** res["xents"] for res in results], axis=0) + xent = float(np.mean([res["xents"] for res in results])) + self.set_result(outputs=perplexities.tolist(), + losses=[xent], + scalar_summaries=None, + histogram_summaries=None, + image_summaries=None) + # pylint: enable=too-few-public-methods -class PerplexityRunner(BaseRunner[AutoregressiveDecoder]): def __init__(self, output_series: str, decoder: AutoregressiveDecoder) -> None: @@ -40,16 +33,11 @@ def __init__(self, BaseRunner[AutoregressiveDecoder].__init__( self, output_series, decoder) - self._decoder_xent = self._decoder.train_xents - - # pylint: disable=unused-argument - def get_executable(self, - compute_losses: bool, - summaries: bool, - num_sessions: int) -> PerplexityExecutable: - return PerplexityExecutable(self._decoder_xent) - # pylint: enable=unused-argument + @tensor + def fetches(self) -> Dict[str, tf.Tensor]: + return {"xents": self.decoder.train_xents} @property def loss_names(self) -> List[str]: + # TODO(tf-data) Shouldn't be "xents" here? return ["xent"] diff --git a/neuralmonkey/runners/plain_runner.py b/neuralmonkey/runners/plain_runner.py index 2e2c4a857..215a46994 100644 --- a/neuralmonkey/runners/plain_runner.py +++ b/neuralmonkey/runners/plain_runner.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Union, Callable, Optional +from typing import Dict, List, Union, Callable import tensorflow as tf from typeguard import check_argument_types @@ -7,9 +7,8 @@ from neuralmonkey.decoders.ctc_decoder import CTCDecoder from neuralmonkey.decoders.classifier import Classifier from neuralmonkey.decoders.sequence_labeler import SequenceLabeler -from neuralmonkey.runners.base_runner import ( - BaseRunner, Executable, FeedDict, ExecutionResult, NextExecute) -from neuralmonkey.vocabulary import Vocabulary +from neuralmonkey.decorators import tensor +from neuralmonkey.runners.base_runner import BaseRunner # pylint: disable=invalid-name SupportedDecoder = Union[ @@ -18,47 +17,35 @@ # pylint: enable=invalid-name -class PlainExecutable(Executable): - - def __init__(self, - fetches: FeedDict, - num_sessions: int, - vocabulary: Vocabulary, - postprocess: Optional[Postprocessor]) -> None: - self._fetches = fetches - self._num_sessions = num_sessions - self._vocabulary = vocabulary - self._postprocess = postprocess - - self._result = None # type: Optional[ExecutionResult] +class PlainRunner(BaseRunner[SupportedDecoder]): + """A runner which takes the output from decoder.decoded.""" - def next_to_execute(self) -> NextExecute: - return self._fetches, [] + # pylint: disable=too-few-public-methods + # Pylint issue here: https://github.com/PyCQA/pylint/issues/2607 + class Executable(BaseRunner.Executable["PlainRunner"]): - def collect_results(self, results: List[Dict]) -> None: - if len(results) != 1: - raise ValueError("PlainRunner needs exactly 1 execution result, " - "got {}".format(len(results))) + def collect_results(self, results: List[Dict]) -> None: + if len(results) != 1: + raise ValueError("PlainRunner needs exactly 1 execution " + "result, got {}".format(len(results))) - train_loss = results[0]["train_loss"] - runtime_loss = results[0]["runtime_loss"] - decoded = results[0]["decoded"] + vocabulary = self.executor.decoder.vocabulary - decoded_tokens = self._vocabulary.vectors_to_sentences(decoded) + train_loss = results[0]["train_loss"] + runtime_loss = results[0]["runtime_loss"] + decoded = results[0]["decoded"] - if self._postprocess is not None: - decoded_tokens = self._postprocess(decoded_tokens) + decoded_tokens = vocabulary.vectors_to_sentences(decoded) - self._result = ExecutionResult( - outputs=decoded_tokens, - losses=[train_loss, runtime_loss], - scalar_summaries=None, - histogram_summaries=None, - image_summaries=None) + if self.executor.postprocess is not None: + decoded_tokens = self.executor.postprocess(decoded_tokens) - -class PlainRunner(BaseRunner[SupportedDecoder]): - """A runner which takes the output from decoder.decoded.""" + self.set_result(outputs=decoded_tokens, + losses=[train_loss, runtime_loss], + scalar_summaries=None, + histogram_summaries=None, + image_summaries=None) + # pylint: enable=too-few-public-methods def __init__(self, output_series: str, @@ -66,25 +53,13 @@ def __init__(self, postprocess: Postprocessor = None) -> None: check_argument_types() BaseRunner[SupportedDecoder].__init__(self, output_series, decoder) + self.postprocess = postprocess - self._postprocess = postprocess - - # pylint: disable=unused-argument - def get_executable(self, - compute_losses: bool, - summaries: bool, - num_sessions: int): - fetches = {"decoded": self._decoder.decoded, - "train_loss": tf.zeros([]), - "runtime_loss": tf.zeros([])} - - if compute_losses: - fetches["train_loss"] = self._decoder.train_loss - fetches["runtime_loss"] = self._decoder.runtime_loss - - return PlainExecutable( - fetches, num_sessions, self._decoder.vocabulary, self._postprocess) - # pylint: enable=unused-argument + @tensor + def fetches(self) -> Dict[str, tf.Tensor]: + return {"decoded": self.decoder.decoded, + "train_loss": self.decoder.train_loss, + "runtime_loss": self.decoder.runtime_loss} @property def loss_names(self) -> List[str]: diff --git a/neuralmonkey/runners/regression_runner.py b/neuralmonkey/runners/regression_runner.py index 351bb14b8..c285a18f0 100644 --- a/neuralmonkey/runners/regression_runner.py +++ b/neuralmonkey/runners/regression_runner.py @@ -1,56 +1,46 @@ -from typing import Dict, List, Callable, Optional +from typing import Dict, List, Callable import numpy as np import tensorflow as tf from typeguard import check_argument_types from neuralmonkey.decoders.sequence_regressor import SequenceRegressor -from neuralmonkey.runners.base_runner import ( - BaseRunner, Executable, ExecutionResult, NextExecute) +from neuralmonkey.decorators import tensor +from neuralmonkey.runners.base_runner import BaseRunner # pylint: disable=invalid-name Postprocessor = Callable[[List[float]], List[float]] # pylint: enable=invalid-name -class RegressionRunExecutable(Executable): - - def __init__(self, - fetches: Dict[str, tf.Tensor], - postprocess: Optional[Postprocessor]) -> None: - self._fetches = fetches - self._postprocess = postprocess - - self._result = None # type: Optional[ExecutionResult] - - def next_to_execute(self) -> NextExecute: - return self._fetches, [] - - def collect_results(self, results: List[Dict]) -> None: - predictions_sum = np.zeros_like(results[0]["prediction"]) - mse_loss = 0. +class RegressionRunner(BaseRunner[SequenceRegressor]): + """A runnner that takes the predictions of a sequence regressor.""" - for sess_result in results: - if "mse" in sess_result: - mse_loss += sess_result["mse"] + # pylint: disable=too-few-public-methods + # Pylint issue here: https://github.com/PyCQA/pylint/issues/2607 + class Executable(BaseRunner.Executable["RegressionRunner"]): - predictions_sum += sess_result["prediction"] + def collect_results(self, results: List[Dict]) -> None: + predictions_sum = np.zeros_like(results[0]["prediction"]) + mse_loss = 0. - predictions = predictions_sum / len(results) + for sess_result in results: + if "mse" in sess_result: + mse_loss += sess_result["mse"] - if self._postprocess is not None: - predictions = self._postprocess(predictions) + predictions_sum += sess_result["prediction"] - self._result = ExecutionResult( - outputs=predictions.tolist(), - losses=[mse_loss], - scalar_summaries=None, - histogram_summaries=None, - image_summaries=None) + predictions = predictions_sum / len(results) + if self.executor.postprocess is not None: + predictions = self.executor.postprocess(predictions) -class RegressionRunner(BaseRunner[SequenceRegressor]): - """A runnner that takes the predictions of a sequence regressor.""" + self.set_result(outputs=predictions.tolist(), + losses=[mse_loss], + scalar_summaries=None, + histogram_summaries=None, + image_summaries=None) + # pylint: enable=too-few-public-methods def __init__(self, output_series: str, @@ -58,20 +48,12 @@ def __init__(self, postprocess: Postprocessor = None) -> None: check_argument_types() BaseRunner[SequenceRegressor].__init__(self, output_series, decoder) + self.postprocess = postprocess - self._postprocess = postprocess - - # pylint: disable=unused-argument - def get_executable(self, - compute_losses: bool, - summaries: bool, - num_sessions: int) -> Executable: - fetches = {"prediction": self._decoder.predictions} - if compute_losses: - fetches["mse"] = self._decoder.cost - - return RegressionRunExecutable(fetches, self._postprocess) - # pylint: enable=unused-argument + @tensor + def fetches(self) -> Dict[str, tf.Tensor]: + return {"prediction": self.decoder.predictions, + "mse": self.decoder.cost} @property def loss_names(self) -> List[str]: diff --git a/neuralmonkey/runners/runner.py b/neuralmonkey/runners/runner.py index 517dde5e7..33956aad3 100644 --- a/neuralmonkey/runners/runner.py +++ b/neuralmonkey/runners/runner.py @@ -1,14 +1,13 @@ -from typing import Dict, List, Optional, Callable, Union +from typing import Dict, List, Callable, Union import numpy as np import tensorflow as tf from typeguard import check_argument_types -from neuralmonkey.runners.base_runner import ( - BaseRunner, Executable, FeedDict, ExecutionResult, NextExecute) -from neuralmonkey.vocabulary import Vocabulary +from neuralmonkey.runners.base_runner import BaseRunner, NextExecute from neuralmonkey.decoders.autoregressive import AutoregressiveDecoder from neuralmonkey.decoders.classifier import Classifier +from neuralmonkey.decorators import tensor # pylint: disable=invalid-name SupportedDecoder = Union[AutoregressiveDecoder, Classifier] @@ -16,52 +15,52 @@ # pylint: enable=invalid-name -class GreedyRunExecutable(Executable): +class GreedyRunner(BaseRunner[SupportedDecoder]): - def __init__(self, - fetches: FeedDict, - vocabulary: Vocabulary, - postprocess: Optional[Postprocessor]) -> None: - self._fetches = fetches - self._vocabulary = vocabulary - self._postprocess = postprocess + class Executable(BaseRunner.Executable["GreedyRunner"]): - self._result = None # type: Optional[ExecutionResult] + def next_to_execute(self) -> NextExecute: + """Get the tensors and additional feed dicts for execution.""" + fetches = self.executor.fetches - def next_to_execute(self) -> NextExecute: - return self._fetches, [] + if not self.summaries: + fetches["image_summaries"] = None - def collect_results(self, results: List[Dict]) -> None: - train_loss = 0. - runtime_loss = 0. - summed_logprobs = [-np.inf for _ in range( - results[0]["decoded_logprobs"].shape[0])] + if not self.compute_losses: + fetches["train_xent"] = tf.zeros([]) + fetches["runtime_xent"] = tf.zeros([]) - for sess_result in results: - train_loss += sess_result["train_xent"] - runtime_loss += sess_result["runtime_xent"] + return fetches, [{}] - for i, logprob in enumerate(sess_result["decoded_logprobs"]): - summed_logprobs[i] = np.logaddexp(summed_logprobs[i], logprob) + def collect_results(self, results: List[Dict]) -> None: + train_loss = 0. + runtime_loss = 0. + summed_logprobs = [-np.inf for _ in range( + results[0]["decoded_logprobs"].shape[0])] - argmaxes = [np.argmax(l, axis=1) for l in summed_logprobs] + for sess_result in results: + train_loss += sess_result["train_xent"] + runtime_loss += sess_result["runtime_xent"] - decoded_tokens = self._vocabulary.vectors_to_sentences(argmaxes) + for i, logprob in enumerate(sess_result["decoded_logprobs"]): + summed_logprobs[i] = np.logaddexp( + summed_logprobs[i], logprob) - if self._postprocess is not None: - decoded_tokens = self._postprocess(decoded_tokens) + argmaxes = [np.argmax(l, axis=1) for l in summed_logprobs] - image_summaries = results[0].get("image_summaries") + decoded_tokens = self.executor.vocabulary.vectors_to_sentences( + argmaxes) - self._result = ExecutionResult( - outputs=decoded_tokens, - losses=[train_loss, runtime_loss], - scalar_summaries=None, - histogram_summaries=None, - image_summaries=image_summaries) + if self.executor.postprocess is not None: + decoded_tokens = self.executor.postprocess(decoded_tokens) + image_summaries = results[0].get("image_summaries") -class GreedyRunner(BaseRunner[SupportedDecoder]): + self.set_result(outputs=decoded_tokens, + losses=[train_loss, runtime_loss], + scalar_summaries=None, + histogram_summaries=None, + image_summaries=image_summaries) def __init__(self, output_series: str, @@ -71,32 +70,21 @@ def __init__(self, BaseRunner[AutoregressiveDecoder].__init__( self, output_series, decoder) - self._postprocess = postprocess + self.postprocess = postprocess + self.vocabulary = self.decoder.vocabulary + + @tensor + def fetches(self) -> Dict[str, tf.Tensor]: + + fetches = {"decoded_logprobs": self.decoder.runtime_logprobs, + "train_xent": self.decoder.train_loss, + "runtime_xent": self.decoder.runtime_loss} - self.image_summaries = None att_plot_summaries = tf.get_collection("summary_att_plots") if att_plot_summaries: - self.image_summaries = tf.summary.merge(att_plot_summaries) - - # pylint: disable=unused-argument - def get_executable(self, - compute_losses: bool, - summaries: bool, - num_sessions: int) -> GreedyRunExecutable: - fetches = {"decoded_logprobs": self._decoder.runtime_logprobs, - "train_xent": tf.zeros([]), - "runtime_xent": tf.zeros([])} - - if compute_losses: - fetches["train_xent"] = self._decoder.train_loss - fetches["runtime_xent"] = self._decoder.runtime_loss - - if summaries and self.image_summaries is not None: - fetches["image_summaries"] = self.image_summaries - - return GreedyRunExecutable( - fetches, self._decoder.vocabulary, self._postprocess) - # pylint: enable=unused-argument + fetches["image_summaries"] = tf.summary.merge(att_plot_summaries) + + return fetches @property def loss_names(self) -> List[str]: diff --git a/neuralmonkey/runners/tensor_runner.py b/neuralmonkey/runners/tensor_runner.py index 0b24b538e..7fd115d26 100644 --- a/neuralmonkey/runners/tensor_runner.py +++ b/neuralmonkey/runners/tensor_runner.py @@ -1,86 +1,16 @@ -from typing import Dict, List, Optional - -# pylint: disable=unused-import -# Type annotation used in comment -import tensorflow as tf -# pylint: enable=unused-import +from typing import Dict, List import numpy as np +import tensorflow as tf from typeguard import check_argument_types +from neuralmonkey.decorators import tensor from neuralmonkey.logging import warn from neuralmonkey.model.model_part import GenericModelPart -from neuralmonkey.runners.base_runner import ( - BaseRunner, Executable, ExecutionResult, NextExecute, FeedDict) +from neuralmonkey.runners.base_runner import BaseRunner from neuralmonkey.experiment import Experiment -class TensorExecutable(Executable): - - def __init__(self, - fetches: FeedDict, - batch_dims: Dict[str, int], - select_session: Optional[int], - single_tensor: bool) -> None: - self._fetches = fetches - self._batch_dims = batch_dims - self._select_session = select_session - self._single_tensor = single_tensor - - self._result = None # type: Optional[ExecutionResult] - - def next_to_execute(self) -> NextExecute: - return self._fetches, [] - - def collect_results(self, results: List[Dict]) -> None: - if len(results) > 1 and self._select_session is None: - sessions = [] - for res_dict in results: - sessions.append(self._fetch_values_from_session(res_dict)) - - # one call returns a list of dicts. we need to add another list - # dimension in between, so it'll become a 2D list of dicts - # with dimensions (batch, session, tensor_name) - # the ``sessions`` structure is of 'shape' - # (session, batch, tensor_name) so it should be sufficient to - # transpose it: - batched = list(zip(*sessions)) - else: - batched = self._fetch_values_from_session(results[0]) - - self._result = ExecutionResult( - outputs=batched, - losses=[], - scalar_summaries=None, - histogram_summaries=None, - image_summaries=None) - - def _fetch_values_from_session(self, sess_results: Dict) -> List: - - transposed = {} - for name, val in sess_results.items(): - batch_dim = self._batch_dims[name] - - perm = [batch_dim] - for dim in range(len(val.shape)): - if dim != batch_dim: - perm.append(dim) - - transposed_val = np.transpose(val, perm) - transposed[name] = transposed_val - - # now we have dict of tensors in batch. we need - # to have a batch of dicts with the batch dim removed - batched = [dict(zip(transposed, col)) - for col in zip(*transposed.values())] - - if self._single_tensor: - # extract the only item from each dict - batched = [next(iter(d.values())) for d in batched] - - return batched - - class TensorRunner(BaseRunner[GenericModelPart]): """Runner class for printing tensors from a model. @@ -89,6 +19,55 @@ class TensorRunner(BaseRunner[GenericModelPart]): will contain the tensors in a dictionary of numpy arrays. """ + # pylint: disable=too-few-public-methods + # Pylint issue here: https://github.com/PyCQA/pylint/issues/2607 + class Executable(BaseRunner.Executable["TensorRunner"]): + + def collect_results(self, results: List[Dict]) -> None: + if len(results) > 1 and self.executor.select_session is None: + sessions = [] + for res_dict in results: + sessions.append(self._fetch_values_from_session(res_dict)) + + # one call returns a list of dicts. we need to add another + # list dimension in between, so it'll become a 2D list of + # dicts with dimensions (batch, session, tensor_name) the + # ``sessions`` structure is of 'shape' (session, batch, + # tensor_name) so it should be sufficient to transpose it: + batched = list(zip(*sessions)) + else: + batched = self._fetch_values_from_session(results[0]) + + self.set_result(outputs=batched, losses=[], + scalar_summaries=None, histogram_summaries=None, + image_summaries=None) + + def _fetch_values_from_session(self, sess_results: Dict) -> List: + + transposed = {} + for name, val in sess_results.items(): + batch_dim = self.executor.batch_ids[name] + + perm = [batch_dim] + for dim in range(len(val.shape)): + if dim != batch_dim: + perm.append(dim) + + transposed_val = np.transpose(val, perm) + transposed[name] = transposed_val + + # now we have dict of tensors in batch. we need + # to have a batch of dicts with the batch dim removed + batched = [dict(zip(transposed, col)) + for col in zip(*transposed.values())] + + if self.executor.single_tensor: + # extract the only item from each dict + batched = [next(iter(d.values())) for d in batched] + + return batched + # pylint: enable=too-few-public-methods + # pylint: disable=too-many-arguments def __init__(self, output_series: str, @@ -149,44 +128,38 @@ def __init__(self, self._modelparts = modelparts self._tensors = tensors self._batch_dims_name = batch_dims_by_name - self._batch_dims = batch_dims - self._select_session = select_session - self._single_tensor = single_tensor - - self._fetches = {} # type: Dict[str, tf.Tensor] - self._batch_ids = {} # type: Dict[str, int] + self.batch_dims = batch_dims + self.select_session = select_session + self.single_tensor = single_tensor + self.batch_ids = {} # type: Dict[str, int] # pylint: enable=too-many-arguments - # pylint: disable=unused-argument - def get_executable(self, - compute_losses: bool, - summaries: bool, - num_sessions: int) -> TensorExecutable: + @tensor + def fetches(self) -> Dict[str, tf.Tensor]: + fetches = {} # type: Dict[str, tf.Tensor] for name, bid in zip(self._names, self._batch_dims_name): try: - self._fetches[name] = ( + fetches[name] = ( Experiment.get_current().graph.get_tensor_by_name(name)) - self._batch_ids[name] = bid + self.batch_ids[name] = bid except KeyError: warn(("The tensor of name '{}' is not present in the " "graph.").format(name)) for mpart, tname, bid in zip(self._modelparts, self._tensors, - self._batch_dims): + self.batch_dims): if not hasattr(mpart, tname): raise ValueError("Model part {} does not have a tensor called " "{}.".format(mpart, tname)) - tensor = getattr(mpart, tname) + tensorval = getattr(mpart, tname) - self._fetches[tensor.name] = tensor - self._batch_ids[tensor.name] = bid + fetches[tensorval.name] = tensorval + self.batch_ids[tensorval.name] = bid - return TensorExecutable(self._fetches, self._batch_ids, - self._select_session, self._single_tensor) - # pylint: enable=unused-argument + return fetches @property def loss_names(self) -> List[str]: diff --git a/neuralmonkey/runners/word_alignment_runner.py b/neuralmonkey/runners/word_alignment_runner.py index 53791f92b..96e93239a 100644 --- a/neuralmonkey/runners/word_alignment_runner.py +++ b/neuralmonkey/runners/word_alignment_runner.py @@ -1,36 +1,25 @@ from typing import Dict, List -# pylint: disable=unused-import -from typing import Optional -# pylint: enable=unused-import import tensorflow as tf from typeguard import check_argument_types from neuralmonkey.attention.base_attention import BaseAttention from neuralmonkey.decoders.decoder import Decoder -from neuralmonkey.runners.base_runner import ( - BaseRunner, Executable, FeedDict, ExecutionResult, NextExecute) +from neuralmonkey.decorators import tensor +from neuralmonkey.runners.base_runner import BaseRunner -class WordAlignmentRunnerExecutable(Executable): - - def __init__(self, fetches: FeedDict) -> None: - self._fetches = fetches - self._result = None # type: Optional[ExecutionResult] - - def next_to_execute(self) -> NextExecute: - return self._fetches, [] - - def collect_results(self, results: List[Dict]) -> None: - self._result = ExecutionResult( - outputs=results[0]["alignment"], - losses=[], - scalar_summaries=None, - histogram_summaries=None, - image_summaries=None) +class WordAlignmentRunner(BaseRunner[BaseAttention]): + # pylint: disable=too-few-public-methods + # Pylint issue here: https://github.com/PyCQA/pylint/issues/2607 + class Executable(BaseRunner.Executable["WordAlignmentRunner"]): -class WordAlignmentRunner(BaseRunner[BaseAttention]): + def collect_results(self, results: List[Dict]) -> None: + self.set_result(outputs=results[0]["alignment"], losses=[], + scalar_summaries=None, histogram_summaries=None, + image_summaries=None) + # pylint: enable=too-few-public-methods def __init__(self, output_series: str, @@ -41,21 +30,16 @@ def __init__(self, self._key = "{}_run".format(decoder.name) - # pylint: disable=unused-argument - def get_executable(self, - compute_losses: bool = False, - summaries: bool = True, - num_sessions: int = 1) -> WordAlignmentRunnerExecutable: - if self._key not in self._decoder.histories: + @tensor + def fetches(self) -> Dict[str, tf.Tensor]: + if self._key not in self.decoder.histories: raise KeyError("Attention has no recorded histories under " "key '{}'".format(self._key)) - att_histories = self._decoder.histories[self._key] + att_histories = self.decoder.histories[self._key] alignment = tf.transpose(att_histories, perm=[1, 2, 0]) - fetches = {"alignment": alignment} - return WordAlignmentRunnerExecutable(fetches) - # pylint: enable=unused-argument + return {"alignment": alignment} @property def loss_names(self) -> List[str]: diff --git a/neuralmonkey/tf_manager.py b/neuralmonkey/tf_manager.py index dbf1a7ce8..dc428a525 100644 --- a/neuralmonkey/tf_manager.py +++ b/neuralmonkey/tf_manager.py @@ -22,7 +22,7 @@ from neuralmonkey.dataset import Dataset from neuralmonkey.model.feedable import Feedable from neuralmonkey.runners.base_runner import ( - FeedDict, ExecutionResult, Executable, GraphExecutor) + FeedDict, ExecutionResult, GraphExecutor) class TensorFlowManager: @@ -170,7 +170,7 @@ def validation_hook(self, score: float, epoch: int, batch: int) -> None: # pylint: disable=too-many-locals def _run_executables(self, feed_dict: FeedDict, - executables: List[Executable]) -> None: + executables: List[GraphExecutor.Executable]) -> None: all_fetches = {} # We might want to feed different values to each session diff --git a/neuralmonkey/trainers/delayed_update_trainer.py b/neuralmonkey/trainers/delayed_update_trainer.py index 3b0939966..5cd970708 100644 --- a/neuralmonkey/trainers/delayed_update_trainer.py +++ b/neuralmonkey/trainers/delayed_update_trainer.py @@ -1,20 +1,80 @@ from typing import Dict, List, Tuple -# pylint: disable=unused-import -from typing import Optional -# pylint: enable=unused-import import tensorflow as tf from typeguard import check_argument_types from neuralmonkey.decorators import tensor -from neuralmonkey.runners.base_runner import ( - Executable, ExecutionResult, NextExecute) -from neuralmonkey.trainers.generic_trainer import ( - GenericTrainer, Objective, Gradients) +from neuralmonkey.runners.base_runner import GraphExecutor, NextExecute +from neuralmonkey.trainers.generic_trainer import (GenericTrainer, Objective, + Gradients) class DelayedUpdateTrainer(GenericTrainer): + class Executable(GraphExecutor.Executable["DelayedUpdateTrainer"]): + + def __init__(self, executor: "DelayedUpdateTrainer", + compute_losses: bool, summaries: bool, + num_sessions: int) -> None: + assert compute_losses + if num_sessions != 1: + raise ValueError( + "Trainer only supports execution in a single session") + + super().__init__(executor, compute_losses, summaries, num_sessions) + + self.state = 0 + self.res_hist_sums = None + self.res_scal_sums = None + self.res_losses = None + + def next_to_execute(self) -> NextExecute: + + if self.state == 0: # ACCUMULATING + fetches = {"accumulators": self.executor.accumulate_ops, + "counter": self.executor.cumulator_counter, + "losses": self.executor.objective_values} + + elif self.state == 1: # UPDATING + fetches = { + "train_op": self.executor.train_op, + "_update_ops": tf.get_collection(tf.GraphKeys.UPDATE_OPS)} + + if self.summaries: + fetches.update(self.executor.summaries) + + else: # RESETTING + fetches = {"resets": self.executor.reset_ops} + + return fetches, [{}] + + def collect_results(self, results: List[Dict]) -> None: + assert len(results) == 1 + result = results[0] + + if self.state == 0: # ACCUMULATING + self.res_losses = result["losses"] + + # Are we updating? + counter = result["counter"] + + if counter == self.executor.batches_per_update: + self.state = 1 + return + elif self.state == 1: + if self.summaries: + self.res_scal_sums = result["scalar_summaries"] + self.res_hist_sums = result["histogram_summaries"] + + self.state = 2 + return + + assert self.res_losses is not None + self.set_result([], losses=self.res_losses, + scalar_summaries=self.res_scal_sums, + histogram_summaries=self.res_hist_sums, + image_summaries=None) + # pylint: disable=too-many-arguments def __init__(self, batches_per_update: int, @@ -166,75 +226,3 @@ def summaries(self) -> Dict[str, tf.Tensor]: tf.get_collection("summary_train")), "histogram_summaries": tf.summary.merge( tf.get_collection("summary_gradients"))} - - def get_executable(self, - compute_losses: bool = True, - summaries: bool = True, - num_sessions: int = 1) -> Executable: - assert compute_losses - if num_sessions != 1: - raise ValueError( - "Trainer only supports execution in a single session") - - return DelayedTrainExecutable(self, summaries) - - -class DelayedTrainExecutable(Executable): - - def __init__(self, trainer: DelayedUpdateTrainer, summaries: bool) -> None: - self.trainer = trainer - self.summaries = summaries - self._result = None # type: Optional[ExecutionResult] - - self.state = 0 - self.res_hist_sums = None - self.res_scal_sums = None - self.res_losses = None - - def next_to_execute(self) -> NextExecute: - - if self.state == 0: # ACCUMULATING - fetches = {"accumulators": self.trainer.accumulate_ops, - "counter": self.trainer.cumulator_counter, - "losses": self.trainer.objective_values} - - elif self.state == 1: # UPDATING - fetches = { - "train_op": self.trainer.train_op, - "_update_ops": tf.get_collection(tf.GraphKeys.UPDATE_OPS)} - - if self.summaries: - fetches.update(self.trainer.summaries) - - else: # RESETTING - fetches = {"resets": self.trainer.reset_ops} - - return fetches, [{}] - - def collect_results(self, results: List[Dict]) -> None: - assert len(results) == 1 - result = results[0] - - if self.state == 0: # ACCUMULATING - self.res_losses = result["losses"] - - # Are we updating? - counter = result["counter"] - - if counter == self.trainer.batches_per_update: - self.state = 1 - return - elif self.state == 1: - if self.summaries: - self.res_scal_sums = result["scalar_summaries"] - self.res_hist_sums = result["histogram_summaries"] - - self.state = 2 - return - - assert self.res_losses is not None - self._result = ExecutionResult( - [], losses=self.res_losses, - scalar_summaries=self.res_scal_sums, - histogram_summaries=self.res_hist_sums, - image_summaries=None) diff --git a/neuralmonkey/trainers/generic_trainer.py b/neuralmonkey/trainers/generic_trainer.py index ce8044c99..10ba4fb9d 100644 --- a/neuralmonkey/trainers/generic_trainer.py +++ b/neuralmonkey/trainers/generic_trainer.py @@ -7,8 +7,7 @@ from neuralmonkey.decorators import tensor from neuralmonkey.logging import log from neuralmonkey.model.model_part import GenericModelPart -from neuralmonkey.runners.base_runner import ( - GraphExecutor, Executable, ExecutionResult, NextExecute) +from neuralmonkey.runners.base_runner import GraphExecutor, NextExecute # pylint: disable=invalid-name Gradients = List[Tuple[tf.Tensor, tf.Variable]] @@ -42,6 +41,39 @@ class Objective(NamedTuple( # pylint: disable=too-few-public-methods,too-many-locals,too-many-arguments class GenericTrainer(GraphExecutor): + class Executable(GraphExecutor.Executable["GenericTrainer"]): + + def __init__(self, executor: "GenericTrainer", compute_losses: bool, + summaries: bool, num_sessions: int) -> None: + assert compute_losses + if num_sessions != 1: + raise ValueError( + "Trainer only supports execution in a single session") + + super().__init__(executor, compute_losses, summaries, num_sessions) + + def next_to_execute(self) -> NextExecute: + fetches = self.executor.fetches + + if self.summaries: + fetches.update(self.executor.summaries) + + return fetches, [{}] + + def collect_results(self, results: List[Dict]) -> None: + assert len(results) == 1 + result = results[0] + + scalar_summaries = ( + result["scalar_summaries"] if self.summaries else None) + histogram_summaries = ( + result["histogram_summaries"] if self.summaries else None) + + self.set_result([], losses=result["losses"], + scalar_summaries=scalar_summaries, + histogram_summaries=histogram_summaries, + image_summaries=None) + @staticmethod def default_optimizer(): return tf.train.AdamOptimizer(learning_rate=1e-4) @@ -226,47 +258,8 @@ def summaries(self) -> Dict[str, tf.Tensor]: "histogram_summaries": tf.summary.merge( tf.get_collection("summary_gradients"))} - def get_executable(self, - compute_losses: bool = True, - summaries: bool = True, - num_sessions: int = 1) -> Executable: - assert compute_losses - if num_sessions != 1: - raise ValueError( - "Trainer only supports execution in a single session") - - return TrainExecutable(self, summaries) - - -class TrainExecutable(Executable): - - def __init__(self, trainer: GenericTrainer, summaries: bool) -> None: - self.trainer = trainer - self.summaries = summaries - self._result = None # type: Optional[ExecutionResult] - - def next_to_execute(self) -> NextExecute: - fetches = {"train_op": self.trainer.train_op} - - if self.summaries: - fetches.update(self.trainer.summaries) - - fetches["losses"] = self.trainer.objective_values - fetches["_update_ops"] = tf.get_collection(tf.GraphKeys.UPDATE_OPS) - - return fetches, [{}] - - def collect_results(self, results: List[Dict]) -> None: - assert len(results) == 1 - result = results[0] - - scalar_summaries = ( - result["scalar_summaries"] if self.summaries else None) - histogram_summaries = ( - result["histogram_summaries"] if self.summaries else None) - - self._result = ExecutionResult( - [], losses=result["losses"], - scalar_summaries=scalar_summaries, - histogram_summaries=histogram_summaries, - image_summaries=None) + @property + def fetches(self) -> Dict[str, tf.Tensor]: + return {"train_op": self.train_op, + "losses": self.objective_values, + "_update_ops": tf.get_collection(tf.GraphKeys.UPDATE_OPS)} diff --git a/neuralmonkey/trainers/multitask_trainer.py b/neuralmonkey/trainers/multitask_trainer.py index e1e4209a5..96fb04d7b 100644 --- a/neuralmonkey/trainers/multitask_trainer.py +++ b/neuralmonkey/trainers/multitask_trainer.py @@ -1,8 +1,10 @@ -from typing import List +from typing import List, Dict +import tensorflow as tf from typeguard import check_argument_types -from neuralmonkey.runners.base_runner import Executable, GraphExecutor +from neuralmonkey.decorators import tensor +from neuralmonkey.runners.base_runner import GraphExecutor from neuralmonkey.trainers.generic_trainer import GenericTrainer @@ -27,10 +29,17 @@ def __init__(self, def get_executable( self, compute_losses: bool = True, summaries: bool = True, - num_sessions: int = 1) -> Executable: + num_sessions: int = 1) -> GraphExecutor.Executable: focused_trainer = self.trainers[self.trainer_idx] self.trainer_idx = (self.trainer_idx + 1) % len(self.trainers) return focused_trainer.get_executable( compute_losses, summaries, num_sessions) + + @tensor + def fetches(self) -> Dict[str, tf.Tensor]: + fetches = {} + for trainer in self.trainers: + fetches.update(trainer.fetches) + return fetches diff --git a/tests/pydocstyle_run.sh b/tests/pydocstyle_run.sh index 9dedf49fa..88c74b99f 100755 --- a/tests/pydocstyle_run.sh +++ b/tests/pydocstyle_run.sh @@ -9,6 +9,6 @@ IGNORED="D202,D203,D213,D406,D407,D408,D409,D413" # These are currently turned off on master branch # because of the missing docstrings. However, they should # be switched on in the future. -IGNORED="D100,D101,D102,D103,D104,D107,$IGNORED" +IGNORED="D100,D101,D102,D103,D104,D106,D107,$IGNORED" pydocstyle --ignore=$IGNORED neuralmonkey From 0e329a89cc48725cf8328283e84276181a6eff05 Mon Sep 17 00:00:00 2001 From: Jindra Helcl Date: Mon, 26 Nov 2018 16:17:13 +0100 Subject: [PATCH 09/16] Lazifying autoregressive decoder family and input seqeunce --- neuralmonkey/decoders/autoregressive.py | 55 +++++----- neuralmonkey/decoders/beam_search_decoder.py | 6 +- neuralmonkey/decoders/decoder.py | 100 +++++++++++-------- neuralmonkey/decoders/transformer.py | 83 ++++++++------- neuralmonkey/model/sequence.py | 20 ++-- neuralmonkey/tests/test_decoder.py | 3 +- 6 files changed, 155 insertions(+), 112 deletions(-) diff --git a/neuralmonkey/decoders/autoregressive.py b/neuralmonkey/decoders/autoregressive.py index 86fe2a4b9..527245aa4 100644 --- a/neuralmonkey/decoders/autoregressive.py +++ b/neuralmonkey/decoders/autoregressive.py @@ -15,7 +15,7 @@ from neuralmonkey.model.feedable import FeedDict from neuralmonkey.model.parameterized import InitializerSpecs from neuralmonkey.model.model_part import ModelPart -from neuralmonkey.logging import log, warn +from neuralmonkey.logging import warn from neuralmonkey.model.sequence import EmbeddedSequence from neuralmonkey.nn.utils import dropout from neuralmonkey.tf_utils import get_variable, get_state_shape_invariants @@ -134,55 +134,62 @@ def __init__(self, ModelPart.__init__(self, name, reuse, save_checkpoint, load_checkpoint, initializers) - log("Initializing decoder, name: '{}'".format(name)) - self.vocabulary = vocabulary self.data_id = data_id self.max_output_len = max_output_len self.dropout_keep_prob = dropout_keep_prob - self.embedding_size = embedding_size + self._embedding_size = embedding_size self.embeddings_source = embeddings_source self.label_smoothing = label_smoothing self.tie_embeddings = tie_embeddings self.supress_unk = supress_unk - self.encoder_states = [] # type: List[tf.Tensor] - self.encoder_masks = [] # type: List[tf.Tensor] + self.encoder_states = lambda: [] # type: Callable[[], List[tf.Tensor]] + self.encoder_masks = lambda: [] # type: Callable[[], List[tf.Tensor]] # Check the values of the parameters (max_output_len, ...) if self.max_output_len <= 0: raise ValueError( "Maximum sequence length must be a positive integer.") - if self.embedding_size is not None and self.embedding_size <= 0: + if self._embedding_size is not None and self._embedding_size <= 0: raise ValueError("Embedding size must be a positive integer.") if self.dropout_keep_prob < 0.0 or self.dropout_keep_prob > 1.0: raise ValueError("Dropout keep probability must be a real number " "in the interval [0,1].") + # pylint: enable=too-many-arguments,too-many-locals - if self.embedding_size is None and self.embeddings_source is None: - raise ValueError( - "You must specify either embedding size or the embedded " - "sequence from which to reuse the embeddings (e.g. set either " - "'embedding_size' or 'embeddings_source' parameter)") + @property + def embedding_size(self) -> int: + if self.embeddings_source is None: + if self._embedding_size is None: + raise ValueError( + "You must specify either embedding size or the embedded " + "sequence from which to reuse the embeddings (e.g. set " + "'embedding_size' or 'embeddings_source' parameter)") + return self._embedding_size if self.embeddings_source is not None: - if self.embedding_size is not None: - warn("Overriding the embedding_size parameter with the" - " size of the reused embeddings from the encoder.") + if self._embedding_size is not None: + warn("Overriding the embedding_size parameter with the " + "size of the reused embeddings from the encoder.") - self.embedding_size = ( - self.embeddings_source.embedding_matrix.get_shape()[1].value) + return self.embeddings_source.embedding_matrix.get_shape()[1].value - with self.use_scope(): - self.go_symbols = tf.placeholder(tf.int32, [None], "go_symbols") + # pylint: disable=no-self-use + @tensor + def go_symbols(self) -> tf.Tensor: + return tf.placeholder(tf.int32, [None], "go_symbols") - self.train_inputs = tf.placeholder( - tf.int32, [None, None], "train_inputs") - self.train_mask = tf.placeholder( - tf.float32, [None, None], "train_mask") - # pylint: enable=too-many-arguments,too-many-locals + @tensor + def train_inputs(self) -> tf.Tensor: + return tf.placeholder(tf.int32, [None, None], "train_inputs") + + @tensor + def train_mask(self) -> tf.Tensor: + return tf.placeholder(tf.float32, [None, None], "train_mask") + # pylint: enable=no-self-use @tensor def decoding_w(self) -> tf.Variable: diff --git a/neuralmonkey/decoders/beam_search_decoder.py b/neuralmonkey/decoders/beam_search_decoder.py index f065c720c..8596da3d9 100644 --- a/neuralmonkey/decoders/beam_search_decoder.py +++ b/neuralmonkey/decoders/beam_search_decoder.py @@ -163,13 +163,15 @@ def __init__(self, # the beam. We need to access all the inner states of the network in # the graph, replace them with beam-size-times copied originals, create # the beam search graph, and then replace the inner states back. + self._building = False + enc_states = self.parent_decoder.encoder_states enc_masks = self.parent_decoder.encoder_masks setattr(self.parent_decoder, "encoder_states", - [self.expand_to_beam(states) for states in enc_states]) + lambda: [self.expand_to_beam(sts) for sts in enc_states()]) setattr(self.parent_decoder, "encoder_masks", - [self.expand_to_beam(mask) for mask in enc_masks]) + lambda: [self.expand_to_beam(mask) for mask in enc_masks()]) # Create the beam search symbolic graph. with self.use_scope(): diff --git a/neuralmonkey/decoders/decoder.py b/neuralmonkey/decoders/decoder.py index 6957cccd4..0677d29bc 100644 --- a/neuralmonkey/decoders/decoder.py +++ b/neuralmonkey/decoders/decoder.py @@ -154,55 +154,25 @@ def __init__(self, initializers=initializers) self.encoders = encoders - self.output_projection_spec = output_projection + self._output_projection_spec = output_projection self._conditional_gru = conditional_gru self._attention_on_input = attention_on_input self._rnn_cell_str = rnn_cell + self._rnn_size = rnn_size + self._encoder_projection = encoder_projection self.attentions = [] # type: List[BaseAttention] if attentions is not None: self.attentions = attentions - if rnn_size is not None: - self.rnn_size = rnn_size - - if encoder_projection is not None: - self.encoder_projection = encoder_projection - elif not self.encoders: - log("No direct encoder input. Using empty initial state") - self.encoder_projection = empty_initial_state - elif rnn_size is None: - log("No rnn_size or encoder_projection: Using concatenation of" - " encoded states") - self.encoder_projection = concat_encoder_projection - self.rnn_size = sum(e.output.get_shape()[1].value - for e in encoders) - else: - log("Using linear projection of encoders as the initial state") - self.encoder_projection = linear_encoder_projection( - self.dropout_keep_prob) - - assert self.rnn_size is not None + if not rnn_size and not encoder_projection and not encoders: + raise ValueError( + "No RNN size, no encoders and no encoder_projection specified") if self._rnn_cell_str not in RNN_CELL_TYPES: raise ValueError("RNN cell must be a either 'GRU', 'LSTM', or " "'NematusGRU'. Not {}".format(self._rnn_cell_str)) - if self.output_projection_spec is None: - log("No output projection specified - using tanh projection") - self.output_projection = nonlinear_output( - self.rnn_size, tf.tanh)[0] - self.output_projection_size = self.rnn_size - elif isinstance(self.output_projection_spec, tuple): - self.output_projection_spec = cast( - Tuple[OutputProjection, int], self.output_projection_spec) - (self.output_projection, - self.output_projection_size) = self.output_projection_spec - else: - self.output_projection = cast( - OutputProjection, self.output_projection_spec) - self.output_projection_size = self.rnn_size - if self._attention_on_input: self.input_projection = self.input_plus_attention else: @@ -216,6 +186,56 @@ def __init__(self, tf.random_normal_initializer(stddev=0.001)) # pylint: enable=too-many-arguments,too-many-branches,too-many-statements + @property + def encoder_projection(self) -> EncoderProjection: + if self._encoder_projection is not None: + return self._encoder_projection + + if not self.encoders: + log("No direct encoder input. Using empty initial state") + return empty_initial_state + + if self._rnn_size is None: + log("No rnn_size or encoder_projection: Using concatenation of " + "encoded states") + return concat_encoder_projection + + log("Using linear projection of encoders as the initial state") + return linear_encoder_projection(self.dropout_keep_prob) + + @property + def rnn_size(self) -> int: + if self._rnn_size is not None: + return self._rnn_size + + if self._encoder_projection is None: + assert self.encoders + return sum(e.output.get_shape()[1].value for e in self.encoders) + + raise ValueError("Cannot infer RNN size.") + + @tensor + def output_projection_spec(self) -> Tuple[OutputProjection, int]: + if self._output_projection_spec is None: + log("No output projection specified - using tanh projection") + return (nonlinear_output(self.rnn_size, tf.tanh)[0], self.rnn_size) + + if isinstance(self._output_projection_spec, tuple): + return self._output_projection_spec + + return cast(OutputProjection, + self._output_projection_spec), self.rnn_size + + # pylint: disable=unsubscriptable-object + @property + def output_projection(self) -> OutputProjection: + return self.output_projection_spec[0] + + @property + def output_dimension(self) -> int: + return self.output_projection_spec[1] + # pylint: enable=unsubscriptable-object + @tensor def initial_state(self) -> tf.Tensor: """Compute initial decoder state. @@ -224,12 +244,14 @@ def initial_state(self) -> tf.Tensor: the initial state of the decoder. """ with tf.variable_scope("initial_state"): + # pylint: disable=not-callable initial_state = dropout( self.encoder_projection(self.train_mode, self.rnn_size, self.encoders), self.dropout_keep_prob, self.train_mode) + # pylint: enable=not-callable init_state_shape = initial_state.get_shape() @@ -242,10 +264,6 @@ def initial_state(self) -> tf.Tensor: return initial_state - @property - def output_dimension(self) -> int: - return self.output_projection_size - def _get_rnn_cell(self) -> tf.contrib.rnn.RNNCell: return RNN_CELL_TYPES[self._rnn_cell_str](self.rnn_size) @@ -338,9 +356,11 @@ def body(*args) -> LoopState: self.embedding_matrix, loop_state.feedables.input_symbol) + # pylint: disable=not-callable output = self.output_projection( cell_output, embedded_input, list(contexts), self.train_mode) + # pylint: enable=not-callable logits = self.get_logits(output) / temperature diff --git a/neuralmonkey/decoders/transformer.py b/neuralmonkey/decoders/transformer.py index 5c12b8d2e..b6cefb976 100644 --- a/neuralmonkey/decoders/transformer.py +++ b/neuralmonkey/decoders/transformer.py @@ -161,33 +161,10 @@ def __init__(self, self.attention_combination_strategy = attention_combination_strategy self.n_heads_hier = n_heads_hier - self.encoder_states = [get_attention_states(e) for e in self.encoders] - self.encoder_masks = [get_attention_mask(e) for e in self.encoders] - - if self.encoder_states: - self.dimension = ( - self.encoder_states[0].get_shape()[2].value) # type: int - - for i, enc_states in enumerate(self.encoder_states): - enc_dim = enc_states.get_shape()[2].value - if enc_dim != self.dimension: - raise ValueError( - "Dimension of the {}-th encoder ({}) differs from the " - "dimension of the first one ({})." - .format(i, enc_dim, self.dimension)) - - elif not self.embedding_size: - raise ValueError("'embedding_size' must be specified when " - "no encoders are provided") - else: - self.dimension = self.embedding_size - - if not self.dimension: - raise ValueError("Decoder could not infer model dimension") - - if self.embedding_size != self.dimension: - raise ValueError("Model dimension and input embedding size" - "do not match") + self.encoder_states = lambda: [get_attention_states(e) + for e in self.encoders] + self.encoder_masks = lambda: [get_attention_mask(e) + for e in self.encoders] if self.attention_combination_strategy not in STRATEGIES: raise ValueError( @@ -216,6 +193,33 @@ def __init__(self, mode="fan_avg", distribution="uniform")) # pylint: enable=too-many-arguments,too-many-locals,too-many-branches + @property + def dimension(self) -> int: + enc_states = self.encoder_states() + + if enc_states: + first_dim = enc_states[0].get_shape()[2].value # type: int + + for i, states in enumerate(enc_states): + enc_dim = states.get_shape()[2].value + if enc_dim != first_dim: + raise ValueError( + "Dimension of the {}-th encoder ({}) differs from the " + "dimension of the first one ({})." + .format(i, enc_dim, first_dim)) + + if self.embedding_size is not None: + if self.embedding_size != first_dim: + raise ValueError("Model dimension and input embedding " + "size do not match") + return first_dim + + if self.embedding_size is None: + raise ValueError("'embedding_size' must be specified when " + "no encoders are provided") + + return self.embedding_size + @property def output_dimension(self) -> int: return self.dimension @@ -242,9 +246,11 @@ def embedded_train_inputs(self) -> tf.Tensor: # (just as a target) # shape (batch, 1 + (time - 1)) + # pylint: disable=unsubscriptable-object input_tokens = tf.concat( [tf.expand_dims(self.go_symbols, 1), tf.transpose(self.train_inputs[:-1])], 1) + # pylint: enable=unsubscriptable-object input_embeddings = self.embed_inputs(input_tokens) @@ -281,8 +287,10 @@ def self_attention_sublayer( def encoder_attention_sublayer(self, queries: tf.Tensor) -> tf.Tensor: """Create the encoder-decoder attention sublayer.""" - assert self.encoder_states is not None - assert self.encoder_masks is not None + enc_states = self.encoder_states() + enc_masks = self.encoder_masks() + assert enc_states is not None + assert enc_masks is not None # Attention dropout callbacks are created in a loop so we need to # use a factory function to prevent late binding. @@ -297,27 +305,26 @@ def callback(x: tf.Tensor) -> tf.Tensor: for prob in self.attention_dropout_keep_prob] if self.attention_combination_strategy == "serial": - return serial(queries, self.encoder_states, self.encoder_masks, - self.n_heads_enc, attn_dropout_cbs, dropout_cb) + return serial(queries, enc_states, enc_masks, self.n_heads_enc, + attn_dropout_cbs, dropout_cb) if self.attention_combination_strategy == "parallel": - return parallel(queries, self.encoder_states, self.encoder_masks, - self.n_heads_enc, attn_dropout_cbs, dropout_cb) + return parallel(queries, enc_states, enc_masks, self.n_heads_enc, + attn_dropout_cbs, dropout_cb) if self.attention_combination_strategy == "flat": assert len(set(self.n_heads_enc)) == 1 assert len(set(self.attention_dropout_keep_prob)) == 1 - return flat(queries, self.encoder_states, self.encoder_masks, - self.n_heads_enc[0], attn_dropout_cbs[0], dropout_cb) + return flat(queries, enc_states, enc_masks, self.n_heads_enc[0], + attn_dropout_cbs[0], dropout_cb) if self.attention_combination_strategy == "hierarchical": assert self.n_heads_hier is not None return hierarchical( - queries, self.encoder_states, self.encoder_masks, - self.n_heads_enc, self.n_heads_hier, attn_dropout_cbs, - dropout_cb) + queries, enc_states, enc_masks, self.n_heads_enc, + self.n_heads_hier, attn_dropout_cbs, dropout_cb) raise NotImplementedError( "Unknown attention combination strategy: {}" diff --git a/neuralmonkey/model/sequence.py b/neuralmonkey/model/sequence.py index 42a4303e3..9ea93dbb1 100644 --- a/neuralmonkey/model/sequence.py +++ b/neuralmonkey/model/sequence.py @@ -127,16 +127,22 @@ def __init__(self, raise ValueError( "When reusing embeedings, embeddings sizes must be equal.") - with self.use_scope(): - self.mask = tf.placeholder(tf.float32, [None, None], "mask") - self.input_factors = [ - tf.placeholder(tf.int32, [None, None], "factor_{}".format(did)) - for did in self.data_ids] - self._variable_scope.set_initializer( tf.random_normal_initializer(stddev=0.001)) # pylint: enable=too-many-arguments + # pylint: disable=no-self-use + @tensor + def mask(self) -> tf.Tensor: + return tf.placeholder(tf.float32, [None, None], "mask") + + @tensor + def input_factors(self) -> List[tf.Tensor]: + return [ + tf.placeholder(tf.int32, [None, None], "factor_{}".format(did)) + for did in self.data_ids] + # pylint: enable=no-self-use + # TODO this should be placed into the abstract embedding class def tb_embedding_visualization(self, logdir: str, prj: projector): @@ -301,12 +307,12 @@ def __init__(self, initializers=initializers) # pylint: enable=too-many-arguments + # pylint: disable=unsubscriptable-object @property def inputs(self) -> tf.Tensor: """Return a 2D placeholder for the sequence inputs.""" return self.input_factors[0] - # pylint: disable=unsubscriptable-object @property def embedding_matrix(self) -> tf.Tensor: """Return the embedding matrix for the sequence.""" diff --git a/neuralmonkey/tests/test_decoder.py b/neuralmonkey/tests/test_decoder.py index 4b1e7c968..7c89865b1 100644 --- a/neuralmonkey/tests/test_decoder.py +++ b/neuralmonkey/tests/test_decoder.py @@ -48,7 +48,8 @@ def test_embedding_size(self): dparams["embedding_size"] = None with self.assertRaises(ValueError): - Decoder(**dparams) + dec = Decoder(**dparams) + print(dec.embedding_size) dparams["embedding_size"] = -10 with self.assertRaises(ValueError): From 00d924c5f78c4823ed595b61926e1200b47d47b2 Mon Sep 17 00:00:00 2001 From: Jindra Helcl Date: Mon, 26 Nov 2018 23:26:37 +0100 Subject: [PATCH 10/16] Moving configuration disambiguation to a separate module freed some space in learning utils and experiment's build_model function --- neuralmonkey/config/configuration.py | 2 + neuralmonkey/config/disambiguate.py | 135 +++++++++++++++++++++++++ neuralmonkey/experiment.py | 36 +++---- neuralmonkey/learning_utils.py | 144 ++++----------------------- 4 files changed, 174 insertions(+), 143 deletions(-) create mode 100644 neuralmonkey/config/disambiguate.py diff --git a/neuralmonkey/config/configuration.py b/neuralmonkey/config/configuration.py index bf25871e6..7ef2fb04a 100644 --- a/neuralmonkey/config/configuration.py +++ b/neuralmonkey/config/configuration.py @@ -6,6 +6,7 @@ from neuralmonkey.logging import log from neuralmonkey.config.builder import build_config from neuralmonkey.config.parsing import parse_file, write_file +from neuralmonkey.config.disambiguate import disambiguate_configuration class Configuration: @@ -97,6 +98,7 @@ def build_model(self, warn_unused=False) -> None: exit(1) log("Model built.") self.model = self.make_namespace(model) + disambiguate_configuration(self.model) def _check_loaded_conf(self) -> None: """Check whether there are unexpected or missing fields.""" diff --git a/neuralmonkey/config/disambiguate.py b/neuralmonkey/config/disambiguate.py new file mode 100644 index 000000000..db9f3d93b --- /dev/null +++ b/neuralmonkey/config/disambiguate.py @@ -0,0 +1,135 @@ +"""Module for disambiguating and enhancing configuration.""" + +from argparse import Namespace +from datetime import timedelta +import re +import time +from typing import List, Union, Callable + +import numpy as np + +from neuralmonkey.dataset import BatchingScheme +from neuralmonkey.logging import warn +from neuralmonkey.tf_manager import get_default_tf_manager +from neuralmonkey.trainers.delayed_update_trainer import DelayedUpdateTrainer + + +# pylint: disable=too-many-branches +def disambiguate_configuration(cfg: Namespace) -> None: + if cfg.tf_manager is None: + cfg.tf_manager = get_default_tf_manager() + + if not isinstance(cfg.val_dataset, List): + cfg.val_datasets = [cfg.val_dataset] + else: + cfg.val_datasets = cfg.val_dataset + + if (cfg.batch_size is None) == (cfg.batching_scheme is None): + raise ValueError("You must specify either batch_size or " + "batching_scheme (not both).") + + if cfg.batch_size is not None: + assert cfg.batching_scheme is None + cfg.batching_scheme = BatchingScheme(batch_size=cfg.batch_size) + else: + assert cfg.batching_scheme is not None + cfg.batch_size = cfg.batching_scheme.batch_size + + if cfg.runners_batch_size is None: + cfg.runners_batch_size = cfg.batching_scheme.batch_size + + cfg.runners_batching_scheme = BatchingScheme( + batch_size=cfg.runners_batch_size, + token_level_batching=cfg.batching_scheme.token_level_batching, + use_leftover_buckets=True) + + if not isinstance(cfg.trainer, List): + cfg.trainers = [cfg.trainer] + else: + cfg.trainers = cfg.trainer + + cfg.evaluation = [(e[0], e[0], e[1]) if len(e) == 2 else e + for e in cfg.evaluation] + + if cfg.evaluation: + cfg.main_metric = "{}/{}".format(cfg.evaluation[-1][0], + cfg.evaluation[-1][-1].name) + else: + cfg.main_metric = "{}/{}".format(cfg.runners[-1].decoder_data_id, + cfg.runners[-1].loss_names[0]) + + if not cfg.tf_manager.minimize_metric: + raise ValueError("minimize_metric must be set to True in " + "TensorFlowManager when using loss as " + "the main metric") + + # deal with delayed trainer and logging periods + # the correct way if there are more trainers is perhaps to do a + # lowest common denominator of their batches_per_update. + # But we can also warn because it is a very weird setup. + + delayed_trainers = [t for t in cfg.trainers + if isinstance(t, DelayedUpdateTrainer)] + + denominator = 1 + if len(cfg.trainers) > 1 and delayed_trainers: + warn("Weird setup: using more trainers and one of them is delayed " + "update trainer. No-one can vouch for your safety, user!") + warn("Using the lowest common denominator of all delayed trainers'" + " batches_per_update parameters for logging period") + warn("Note that if you are using a multi-task trainer, it is on " + "your own risk") + + denominator = np.lcm.reduce([t.batches_per_update + for t in delayed_trainers]) + elif delayed_trainers: + assert len(cfg.trainers) == 1 + denominator = cfg.trainers[0].batches_per_update + + cfg.log_timer = _resolve_period(cfg.logging_period, denominator) + cfg.val_timer = _resolve_period(cfg.validation_period, denominator) + + +def _resolve_period(period: Union[str, int], + denominator: int) -> Callable[[int, float], bool]: + + def get_batch_logger(period: int) -> Callable[[int, float], bool]: + def is_time(step: int, _: float) -> bool: + return step != 0 and step % period == 0 + return is_time + + def get_time_logger(period: float) -> Callable[[int, float], bool]: + def is_time(step: int, last_time: float) -> bool: + if step % denominator != 0: + return False + return last_time + period < time.process_time() + return is_time + + if isinstance(period, int): + if period % denominator != 0: + raise ValueError( + "When using delayed update trainer, the logging/validation " + "periods must be divisible by batches_per_update.") + + return get_batch_logger(period) + + regex = re.compile( + r"((?P\d+?)d)?((?P\d+?)h)?((?P\d+?)m)?" + r"((?P\d+?)s)?") + parts = regex.match(period) + + if not parts: + raise ValueError( + "Validation or logging period have incorrect format. " + "It should be in format: 3h; 5m; 14s") + + time_params = {} + for (name, param) in parts.groupdict().items(): + if param: + time_params[name] = int(param) + + delta_seconds = timedelta(**time_params).total_seconds() + if delta_seconds <= 0: + raise ValueError("Validation or logging period must be bigger than 0") + + return get_time_logger(delta_seconds) diff --git a/neuralmonkey/experiment.py b/neuralmonkey/experiment.py index dc4ec9386..5128f5751 100644 --- a/neuralmonkey/experiment.py +++ b/neuralmonkey/experiment.py @@ -1,6 +1,7 @@ """Provides a high-level API for training and using a model.""" +# pylint: disable=too-many-lines -from argparse import Namespace # pylint: disable=unused-import +from argparse import Namespace import os import random import shutil @@ -15,15 +16,14 @@ from neuralmonkey.checking import (check_dataset_and_coders, CheckingException) +from neuralmonkey.dataset import BatchingScheme, Dataset from neuralmonkey.logging import Logging, log, debug, warn from neuralmonkey.config.configuration import Configuration from neuralmonkey.learning_utils import (training_loop, evaluation, run_on_dataset, print_final_evaluation) -from neuralmonkey.dataset import Dataset, BatchingScheme from neuralmonkey.model.sequence import EmbeddedFactorSequence from neuralmonkey.runners.base_runner import ExecutionResult -from neuralmonkey.tf_manager import get_default_tf_manager _TRAIN_ARGS = [ @@ -106,6 +106,7 @@ def model(self) -> Namespace: return self._model def _bless_graph_executors(self) -> None: + log("Building TF Graph") if hasattr(self.model, "trainer"): if isinstance(self.model.trainer, List): trainers = self.model.trainer @@ -113,10 +114,11 @@ def _bless_graph_executors(self) -> None: trainers = [self.model.trainer] for trainer in trainers: - log("Trainer fetches: {}".format(trainer.fetches)) + debug("Trainer fetches: {}".format(trainer.fetches)) for runner in self.model.runners: - log("Runner fetches: {}".format(runner.fetches)) + debug("Runner fetches: {}".format(runner.fetches)) + log("TF Graph built") def build_model(self) -> None: if self._model_built: @@ -134,16 +136,11 @@ def build_model(self) -> None: self.config.build_model(warn_unused=self.train_mode) self._model = self.config.model self._model_built = True + self._bless_graph_executors() type(self)._current_experiment = None - if self.model.runners_batch_size is None: - self.model.runners_batch_size = self.model.batch_size - - if self.model.tf_manager is None: - self.model.tf_manager = get_default_tf_manager() - if self.train_mode: check_dataset_and_coders(self.model.train_dataset, self.model.runners) @@ -184,23 +181,23 @@ def train(self) -> None: training_loop( tf_manager=self.model.tf_manager, epochs=self.model.epochs, - trainer=self.model.trainer, - batch_size=self.model.batch_size, + trainers=self.model.trainers, batching_scheme=self.model.batching_scheme, + runners_batching_scheme=self.model.runners_batching_scheme, log_directory=self.model.output, evaluators=self.model.evaluation, + main_metric=self.model.main_metric, runners=self.model.runners, train_dataset=self.model.train_dataset, - val_dataset=self.model.val_dataset, + val_datasets=self.model.val_datasets, test_datasets=self.model.test_datasets, - logging_period=self.model.logging_period, - validation_period=self.model.validation_period, + log_timer=self.model.log_timer, + val_timer=self.model.val_timer, val_preview_input_series=self.model.val_preview_input_series, val_preview_output_series=self.model.val_preview_output_series, val_preview_num_examples=self.model.val_preview_num_examples, postprocess=self.model.postprocess, train_start_offset=self.model.train_start_offset, - runners_batch_size=self.model.runners_batch_size, initial_variables=self.model.initial_variables, final_variables=self.get_path("variables.data.final")) @@ -257,10 +254,13 @@ def run_model(self, if not self._vars_loaded: self.load_variables() + toklevel = self.model.runners_batching_scheme.token_level_batching + assert self.model.runners_batching_scheme.batch_bucket_span is None + batching_scheme = BatchingScheme( batch_size=batch_size or self.model.runners_batch_size, batch_bucket_span=None, - token_level_batching=False, + token_level_batching=toklevel, bucketing_ignore_series=[]) with self.graph.as_default(): diff --git a/neuralmonkey/learning_utils.py b/neuralmonkey/learning_utils.py index d27aca04c..c7dcdfce2 100644 --- a/neuralmonkey/learning_utils.py +++ b/neuralmonkey/learning_utils.py @@ -7,8 +7,6 @@ # pylint: enable=unused-import import time -import re -from datetime import timedelta import numpy as np import tensorflow as tf from termcolor import colored @@ -33,29 +31,29 @@ # pylint: enable=invalid-name -# pylint: disable=too-many-arguments, too-many-locals, too-many-branches -# pylint: disable=too-many-statements, too-many-nested-blocks +# pylint: disable=too-many-arguments,too-many-locals,too-many-nested-blocks +# pylint: disable=too-many-branches,too-many-statements def training_loop(tf_manager: TensorFlowManager, epochs: int, - trainer: Union[Trainer, List[Trainer]], + trainers: List[Trainer], + batching_scheme: BatchingScheme, + runners_batching_scheme: BatchingScheme, log_directory: str, evaluators: EvalConfiguration, + main_metric: str, runners: List[BaseRunner], - final_variables: str, train_dataset: Dataset, - val_dataset: Union[Dataset, List[Dataset]], - test_datasets: List[Dataset] = None, - logging_period: Union[str, int] = 20, - validation_period: Union[str, int] = 500, - val_preview_input_series: List[str] = None, - val_preview_output_series: List[str] = None, - val_preview_num_examples: int = 15, - train_start_offset: int = 0, - batch_size: int = None, - batching_scheme: BatchingScheme = None, - runners_batch_size: int = None, - initial_variables: Union[str, List[str]] = None, - postprocess: Postprocess = None) -> None: + val_datasets: List[Dataset], + test_datasets: Optional[List[Dataset]], + log_timer: Callable[[int, float], bool], + val_timer: Callable[[int, float], bool], + val_preview_input_series: Optional[List[str]], + val_preview_output_series: Optional[List[str]], + val_preview_num_examples: int, + postprocess: Optional[Postprocess], + train_start_offset: int, + initial_variables: Optional[Union[str, List[str]]], + final_variables: str) -> None: """Execute the training loop for given graph and data. Args: @@ -104,91 +102,15 @@ def training_loop(tf_manager: TensorFlowManager, """ check_argument_types() - if (batch_size is None) == (batching_scheme is None): - raise ValueError("You must specify either batch_size or " - "batching_scheme (not both).") - - if batch_size is not None: - assert batching_scheme is None - batching_scheme = BatchingScheme(batch_size=batch_size) - - assert batching_scheme is not None - - if runners_batch_size is None: - runners_batch_size = batching_scheme.batch_size - - runners_batching_scheme = BatchingScheme( - batch_size=runners_batch_size, - token_level_batching=batching_scheme.token_level_batching, - use_leftover_buckets=True) - - if isinstance(val_dataset, List): - val_datasets = val_dataset - else: - val_datasets = [val_dataset] - - log_period_batch, log_period_time = _resolve_period(logging_period) - val_period_batch, val_period_time = _resolve_period(validation_period) - _check_series_collisions(runners, postprocess) - if isinstance(trainer, List): - trainers = trainer - else: - trainers = [trainer] - _log_model_variables( var_list=list(set().union(*[t.var_list for t in trainers]))) - evaluators = [(e[0], e[0], e[1]) if len(e) == 2 else e - for e in evaluators] - - if evaluators: - main_metric = "{}/{}".format(evaluators[-1][0], - evaluators[-1][-1].name) - else: - main_metric = "{}/{}".format(runners[-1].decoder_data_id, - runners[-1].loss_names[0]) - - if not tf_manager.minimize_metric: - raise ValueError("minimize_metric must be set to True in " - "TensorFlowManager when using loss as " - "the main metric") - - if log_period_batch is not None and isinstance( - trainer, DelayedUpdateTrainer): - if log_period_batch % trainer.batches_per_update != 0: - raise ValueError("When using delayed update trainer, the logging " - "period must be divisible by batches_per_update.") - - if val_period_batch is not None and isinstance( - trainer, DelayedUpdateTrainer): - if val_period_batch % trainer.batches_per_update != 0: - raise ValueError("When using delayed update trainer, validation " - "period must be divisible by batches_per_update.") - step = 0 seen_instances = 0 last_seen_instances = 0 - def _is_logging_time(period_batch: Optional[int], - period_time: Optional[float], - last_time: float) -> bool: - if step == 0: - return False - - if period_batch is not None: - return step % period_batch == 0 - - assert period_time is not None - - # deal with delayed trainer - if isinstance(trainer, DelayedUpdateTrainer): - if step % trainer.batches_per_update != 0: - return False - - return last_time + period_time < time.process_time() - if initial_variables is None: # Assume we don't look at coder checkpoints when global # initial variables are supplied @@ -232,8 +154,7 @@ def _is_logging_time(period_batch: Optional[int], step += 1 seen_instances += len(batch) - if _is_logging_time(log_period_batch, log_period_time, - last_log_time): + if log_timer(step, last_log_time): trainer_result = tf_manager.execute( batch, feedables, trainers, train=True, summaries=True) train_results, train_outputs = run_on_dataset( @@ -256,8 +177,7 @@ def _is_logging_time(period_batch: Optional[int], tf_manager.execute(batch, feedables, trainers, train=True, summaries=False) - if _is_logging_time(val_period_batch, val_period_time, - last_val_time): + if val_timer(step, last_val_time): log_print("") val_duration_start = time.process_time() val_examples = 0 @@ -371,32 +291,6 @@ def _is_logging_time(period_batch: Optional[int], raise interrupt # pylint: disable=raising-bad-type -def _resolve_period( - period: Union[str, int]) -> Tuple[Optional[int], Optional[float]]: - if isinstance(period, int): - return period, None - - regex = re.compile( - r"((?P\d+?)d)?((?P\d+?)h)?((?P\d+?)m)?" - r"((?P\d+?)s)?") - parts = regex.match(period) - - if not parts: - raise ValueError( - "Validation or logging period have incorrect format. " - "It should be in format: 3h; 5m; 14s") - - time_params = {} - for (name, param) in parts.groupdict().items(): - if param: - time_params[name] = int(param) - - delta_seconds = timedelta(**time_params).total_seconds() - if delta_seconds <= 0: - raise ValueError("Validation or logging period must be bigger than 0") - return None, delta_seconds - - def _check_series_collisions(runners: List[BaseRunner], postprocess: Postprocess) -> None: """Check if output series names do not collide.""" From 0a66e0f6a9e0eaeb60d8ea10bde811ae311a11d4 Mon Sep 17 00:00:00 2001 From: Jindra Helcl Date: Tue, 27 Nov 2018 00:13:53 +0100 Subject: [PATCH 11/16] Removing blessing of training Op from GenericTrainer --- neuralmonkey/config/configuration.py | 2 - neuralmonkey/config/disambiguate.py | 30 ++++++++------ neuralmonkey/experiment.py | 4 ++ neuralmonkey/tf_manager.py | 51 +++++++++++++----------- neuralmonkey/trainers/generic_trainer.py | 4 -- 5 files changed, 50 insertions(+), 41 deletions(-) diff --git a/neuralmonkey/config/configuration.py b/neuralmonkey/config/configuration.py index 7ef2fb04a..bf25871e6 100644 --- a/neuralmonkey/config/configuration.py +++ b/neuralmonkey/config/configuration.py @@ -6,7 +6,6 @@ from neuralmonkey.logging import log from neuralmonkey.config.builder import build_config from neuralmonkey.config.parsing import parse_file, write_file -from neuralmonkey.config.disambiguate import disambiguate_configuration class Configuration: @@ -98,7 +97,6 @@ def build_model(self, warn_unused=False) -> None: exit(1) log("Model built.") self.model = self.make_namespace(model) - disambiguate_configuration(self.model) def _check_loaded_conf(self) -> None: """Check whether there are unexpected or missing fields.""" diff --git a/neuralmonkey/config/disambiguate.py b/neuralmonkey/config/disambiguate.py index db9f3d93b..093f85faf 100644 --- a/neuralmonkey/config/disambiguate.py +++ b/neuralmonkey/config/disambiguate.py @@ -14,16 +14,14 @@ from neuralmonkey.trainers.delayed_update_trainer import DelayedUpdateTrainer -# pylint: disable=too-many-branches -def disambiguate_configuration(cfg: Namespace) -> None: +def disambiguate_configuration(cfg: Namespace, train_mode: bool) -> None: + + if train_mode: + _disambiguate_train_cfg(cfg) + if cfg.tf_manager is None: cfg.tf_manager = get_default_tf_manager() - if not isinstance(cfg.val_dataset, List): - cfg.val_datasets = [cfg.val_dataset] - else: - cfg.val_datasets = cfg.val_dataset - if (cfg.batch_size is None) == (cfg.batching_scheme is None): raise ValueError("You must specify either batch_size or " "batching_scheme (not both).") @@ -43,11 +41,6 @@ def disambiguate_configuration(cfg: Namespace) -> None: token_level_batching=cfg.batching_scheme.token_level_batching, use_leftover_buckets=True) - if not isinstance(cfg.trainer, List): - cfg.trainers = [cfg.trainer] - else: - cfg.trainers = cfg.trainer - cfg.evaluation = [(e[0], e[0], e[1]) if len(e) == 2 else e for e in cfg.evaluation] @@ -63,6 +56,19 @@ def disambiguate_configuration(cfg: Namespace) -> None: "TensorFlowManager when using loss as " "the main metric") + +def _disambiguate_train_cfg(cfg: Namespace) -> None: + + if not isinstance(cfg.val_dataset, List): + cfg.val_datasets = [cfg.val_dataset] + else: + cfg.val_datasets = cfg.val_dataset + + if not isinstance(cfg.trainer, List): + cfg.trainers = [cfg.trainer] + else: + cfg.trainers = cfg.trainer + # deal with delayed trainer and logging periods # the correct way if there are more trainers is perhaps to do a # lowest common denominator of their batches_per_update. diff --git a/neuralmonkey/experiment.py b/neuralmonkey/experiment.py index 5128f5751..1fe600a89 100644 --- a/neuralmonkey/experiment.py +++ b/neuralmonkey/experiment.py @@ -19,6 +19,7 @@ from neuralmonkey.dataset import BatchingScheme, Dataset from neuralmonkey.logging import Logging, log, debug, warn from neuralmonkey.config.configuration import Configuration +from neuralmonkey.config.disambiguate import disambiguate_configuration from neuralmonkey.learning_utils import (training_loop, evaluation, run_on_dataset, print_final_evaluation) @@ -134,10 +135,13 @@ def build_model(self) -> None: type(self)._current_experiment = self # type: ignore self.config.build_model(warn_unused=self.train_mode) + disambiguate_configuration(self.config.model, self.train_mode) + self._model = self.config.model self._model_built = True self._bless_graph_executors() + self.model.tf_manager.initialize_sessions() type(self)._current_experiment = None diff --git a/neuralmonkey/tf_manager.py b/neuralmonkey/tf_manager.py index dc428a525..e150d23da 100644 --- a/neuralmonkey/tf_manager.py +++ b/neuralmonkey/tf_manager.py @@ -25,6 +25,7 @@ FeedDict, ExecutionResult, GraphExecutor) +# pylint: disable=too-many-instance-attributes class TensorFlowManager: """Inteface between computational graph, data and TF sessions. @@ -38,7 +39,6 @@ def __init__(self, num_threads: int, save_n_best: int = 1, minimize_metric: bool = False, - variable_files: Optional[List[str]] = None, gpu_allow_growth: bool = True, per_process_gpu_memory_fraction: float = 1.0, enable_tf_debug: bool = False) -> None: @@ -54,19 +54,18 @@ def __init__(self, save_n_best: How many best models to keep minimize_metric: Whether the best model is the one with the lowest or the highest score - variable_files: List of variable files. gpu_allow_growth: TF to allocate incrementally, not all at once. per_process_gpu_memory_fraction: Limit TF memory use. """ check_argument_types() - session_cfg = tf.ConfigProto() - session_cfg.inter_op_parallelism_threads = num_threads - session_cfg.intra_op_parallelism_threads = num_threads - session_cfg.allow_soft_placement = True # needed for multiple GPUs + self.session_cfg = tf.ConfigProto() + self.session_cfg.inter_op_parallelism_threads = num_threads + self.session_cfg.intra_op_parallelism_threads = num_threads + self.session_cfg.allow_soft_placement = True # needed for more GPUs # pylint: disable=no-member - session_cfg.gpu_options.allow_growth = gpu_allow_growth - session_cfg.gpu_options.per_process_gpu_memory_fraction = \ + self.session_cfg.gpu_options.allow_growth = gpu_allow_growth + self.session_cfg.gpu_options.per_process_gpu_memory_fraction = \ per_process_gpu_memory_fraction # pylint: enable=no-member @@ -74,27 +73,16 @@ def __init__(self, raise Exception("save_n_best parameter must be greater than zero") self.saver_max_to_keep = save_n_best self.minimize_metric = minimize_metric + self.num_sessions = num_sessions - self.sessions = [tf.Session(config=session_cfg) - for _ in range(num_sessions)] + self.sessions = [tf.Session(config=self.session_cfg) + for _ in range(self.num_sessions)] if enable_tf_debug: self.sessions = [tf_debug.LocalCLIDebugWrapperSession(sess) for sess in self.sessions] - init_op = tf.global_variables_initializer() - for sess in self.sessions: - sess.run(init_op) - self.saver = tf.train.Saver(max_to_keep=None, - var_list=[g for g in tf.global_variables() - if "reward_" not in g.name]) - - if variable_files: - if len(variable_files) != num_sessions: - raise Exception(("The number of provided variable files ({}) " - "is different than a number sessions ({})") - .format(len(variable_files), num_sessions)) - self.restore(variable_files) + self.saver = None self.best_score_index = 0 self.best_score_epoch = 0 @@ -238,6 +226,9 @@ def execute(self, return [getattr(ex, "result") for ex in executables] def save(self, variable_files: Union[str, List[str]]) -> None: + if self.saver is None: + raise RuntimeError("Saver uninitialized") + if isinstance(variable_files, str) and len(self.sessions) == 1: self.saver.save(self.sessions[0], variable_files) return @@ -255,6 +246,9 @@ def save(self, variable_files: Union[str, List[str]]) -> None: self.saver.save(sess, file_name) def restore(self, variable_files: Union[str, List[str]]) -> None: + if self.saver is None: + raise RuntimeError("Saver uninitialized") + if isinstance(variable_files, str): variable_files = [variable_files] if len(variable_files) != len(self.sessions): @@ -271,6 +265,17 @@ def restore_best_vars(self) -> None: # TODO warn when link does not exist self.restore(self.variables_files[self.best_score_index]) + def initialize_sessions(self) -> None: + log("Initializing variables") + init_op = tf.global_variables_initializer() + for sess in self.sessions: + sess.run(init_op) + + log("Initializing tf.train.Saver") + self.saver = tf.train.Saver(max_to_keep=None, + var_list=[g for g in tf.global_variables() + if "reward_" not in g.name]) + def initialize_model_parts(self, runners: Sequence[GraphExecutor], save: bool = False) -> None: """Initialize model parts variables from their checkpoints.""" diff --git a/neuralmonkey/trainers/generic_trainer.py b/neuralmonkey/trainers/generic_trainer.py index 10ba4fb9d..f2b08eba2 100644 --- a/neuralmonkey/trainers/generic_trainer.py +++ b/neuralmonkey/trainers/generic_trainer.py @@ -5,7 +5,6 @@ from typeguard import check_argument_types from neuralmonkey.decorators import tensor -from neuralmonkey.logging import log from neuralmonkey.model.model_part import GenericModelPart from neuralmonkey.runners.base_runner import GraphExecutor, NextExecute @@ -101,9 +100,6 @@ def __init__(self, self.optimizer = ( optimizer if optimizer is not None else self.default_optimizer()) - log("Building model") - log("Train op: {}".format(str(self.train_op))) - # pylint: disable=no-self-use @tensor def regularization_losses(self) -> Tuple[tf.Tensor, tf.Tensor]: From 0bf5c29fe8838b53d54b99d9d9f7237fe9325782 Mon Sep 17 00:00:00 2001 From: Jindra Helcl Date: Tue, 11 Dec 2018 17:46:07 +0100 Subject: [PATCH 12/16] Rename and add docstrings to the config disambiguation/normalization --- .../config/{disambiguate.py => normalize.py} | 48 +++++++++++++++++-- neuralmonkey/experiment.py | 4 +- 2 files changed, 45 insertions(+), 7 deletions(-) rename neuralmonkey/config/{disambiguate.py => normalize.py} (72%) diff --git a/neuralmonkey/config/disambiguate.py b/neuralmonkey/config/normalize.py similarity index 72% rename from neuralmonkey/config/disambiguate.py rename to neuralmonkey/config/normalize.py index 093f85faf..a8cbc8bf7 100644 --- a/neuralmonkey/config/disambiguate.py +++ b/neuralmonkey/config/normalize.py @@ -1,4 +1,10 @@ -"""Module for disambiguating and enhancing configuration.""" +"""Module for configuration normalization. + +The `[main]` configuration section contains arguments that can be filled with +different types of values, e.g. `trainer` can be either a single trainer +object or a list of them. This module provides functions for unifying the +configuration interface. +""" from argparse import Namespace from datetime import timedelta @@ -14,10 +20,16 @@ from neuralmonkey.trainers.delayed_update_trainer import DelayedUpdateTrainer -def disambiguate_configuration(cfg: Namespace, train_mode: bool) -> None: +def normalize_configuration(cfg: Namespace, train_mode: bool) -> None: + """Given a configuration namespace, normalize the values it contains. + Arguments: + cfg: The namespace object returned by `Configuration.make_namespace` + train_mode: Boolean flag controlling normalization of parameters only + used during training. + """ if train_mode: - _disambiguate_train_cfg(cfg) + _normalize_train_cfg(cfg) if cfg.tf_manager is None: cfg.tf_manager = get_default_tf_manager() @@ -57,8 +69,14 @@ def disambiguate_configuration(cfg: Namespace, train_mode: bool) -> None: "the main metric") -def _disambiguate_train_cfg(cfg: Namespace) -> None: +def _normalize_train_cfg(cfg: Namespace) -> None: + """Given a configuration namespace, normalize the values it contains. + + This function is only executed when training mode has been invoked. + Arguments: + cfg: The namespace object returned by `Configuration.make_namespace` + """ if not isinstance(cfg.val_dataset, List): cfg.val_datasets = [cfg.val_dataset] else: @@ -98,7 +116,27 @@ def _disambiguate_train_cfg(cfg: Namespace) -> None: def _resolve_period(period: Union[str, int], denominator: int) -> Callable[[int, float], bool]: - + """Convert logging period into a function for logging time checks. + + Logging and validation periods can both be provided either as a number of + batches after which to log/validate, or as a time interval between the + logs/validation runs. + + This function unifies both representations into a function that decides + whether to log/validate based on a given training step and time since the + last log/validation. + + Arguments: + period: Either a string representing time, or a number representing + number of batches. + denominator: Only allow logging when the given step (number of batches + since the start of the training) is divisible by this value. + This is used e.g. when `DelayedUpdateTrainer` is used. + + Returns: + A function of the current training step and time since the last logging + period that returns a boolean value. + """ def get_batch_logger(period: int) -> Callable[[int, float], bool]: def is_time(step: int, _: float) -> bool: return step != 0 and step % period == 0 diff --git a/neuralmonkey/experiment.py b/neuralmonkey/experiment.py index 1fe600a89..461184165 100644 --- a/neuralmonkey/experiment.py +++ b/neuralmonkey/experiment.py @@ -19,7 +19,7 @@ from neuralmonkey.dataset import BatchingScheme, Dataset from neuralmonkey.logging import Logging, log, debug, warn from neuralmonkey.config.configuration import Configuration -from neuralmonkey.config.disambiguate import disambiguate_configuration +from neuralmonkey.config.normalize import normalize_configuration from neuralmonkey.learning_utils import (training_loop, evaluation, run_on_dataset, print_final_evaluation) @@ -135,7 +135,7 @@ def build_model(self) -> None: type(self)._current_experiment = self # type: ignore self.config.build_model(warn_unused=self.train_mode) - disambiguate_configuration(self.config.model, self.train_mode) + normalize_configuration(self.config.model, self.train_mode) self._model = self.config.model self._model_built = True From 5bb6b1f28e03dd6bad8dfd6c97be0fe5f6bf1a18 Mon Sep 17 00:00:00 2001 From: Jindra Helcl Date: Tue, 11 Dec 2018 18:00:39 +0100 Subject: [PATCH 13/16] Clarifying blessing TODO in ff attention module --- neuralmonkey/attention/feed_forward.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/neuralmonkey/attention/feed_forward.py b/neuralmonkey/attention/feed_forward.py index edef3c6f4..5118e5fb4 100644 --- a/neuralmonkey/attention/feed_forward.py +++ b/neuralmonkey/attention/feed_forward.py @@ -13,7 +13,7 @@ BaseAttention, AttentionLoopState, empty_attention_loop_state, get_attention_states, get_attention_mask, Attendable) from neuralmonkey.decorators import tensor -from neuralmonkey.logging import log +from neuralmonkey.logging import debug from neuralmonkey.model.model_part import ModelPart from neuralmonkey.model.parameterized import InitializerSpecs from neuralmonkey.nn.utils import dropout @@ -166,10 +166,18 @@ def attention(self, return context, next_loop_state def initial_loop_state(self) -> AttentionLoopState: - # TODO blessing - log("Pre-computing attention tensors") - log("Hidden features: {}".format(self.hidden_features)) - log("Hidden mask: {}".format(self.attention_mask)) + + # Here we need to make sure that the hidden_features and attention_mask + # are pre-computed. If this is used in combination with a decoder which + # has train and runtime while loops, these tensors need to be created + # outside of any of those loops in order to be available to both. + + # Note that we are not breaking lazy loading here because this method + # is called from a lazy tensor. + + debug("Pre-computing attention tensors", "bless") + debug("Hidden features: {}".format(self.hidden_features), "bless") + debug("Hidden mask: {}".format(self.attention_mask), "bless") return empty_attention_loop_state( self.batch_size, From 1496de4e23b870e9d7e10572782e5f97c135232d Mon Sep 17 00:00:00 2001 From: Jindra Helcl Date: Tue, 11 Dec 2018 18:08:35 +0100 Subject: [PATCH 14/16] Simplifying next_to_execute Do not to return a list with an empty dict, but just an empty list. --- neuralmonkey/runners/base_runner.py | 4 ++-- neuralmonkey/runners/runner.py | 2 +- neuralmonkey/trainers/delayed_update_trainer.py | 2 +- neuralmonkey/trainers/generic_trainer.py | 2 +- neuralmonkey/trainers/test_multitask_trainer.py | 6 +++--- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/neuralmonkey/runners/base_runner.py b/neuralmonkey/runners/base_runner.py index 4026cf985..bc60ea1a2 100644 --- a/neuralmonkey/runners/base_runner.py +++ b/neuralmonkey/runners/base_runner.py @@ -73,7 +73,7 @@ def executor(self) -> Executor: def next_to_execute(self) -> NextExecute: """Get the tensors and additional feed dicts for execution.""" - return self.executor.fetches, [{}] + return self.executor.fetches, [] @abstractmethod def collect_results(self, results: List[Dict]) -> None: @@ -122,7 +122,7 @@ def next_to_execute(self) -> NextExecute: for loss in self.executor.loss_names: fetches[loss] = tf.zeros([]) - return fetches, [{}] + return fetches, [] # pylint: enable=too-few-public-methods def __init__(self, diff --git a/neuralmonkey/runners/runner.py b/neuralmonkey/runners/runner.py index 33956aad3..3106ad9b5 100644 --- a/neuralmonkey/runners/runner.py +++ b/neuralmonkey/runners/runner.py @@ -30,7 +30,7 @@ def next_to_execute(self) -> NextExecute: fetches["train_xent"] = tf.zeros([]) fetches["runtime_xent"] = tf.zeros([]) - return fetches, [{}] + return fetches, [] def collect_results(self, results: List[Dict]) -> None: train_loss = 0. diff --git a/neuralmonkey/trainers/delayed_update_trainer.py b/neuralmonkey/trainers/delayed_update_trainer.py index 5cd970708..6af121323 100644 --- a/neuralmonkey/trainers/delayed_update_trainer.py +++ b/neuralmonkey/trainers/delayed_update_trainer.py @@ -46,7 +46,7 @@ def next_to_execute(self) -> NextExecute: else: # RESETTING fetches = {"resets": self.executor.reset_ops} - return fetches, [{}] + return fetches, [] def collect_results(self, results: List[Dict]) -> None: assert len(results) == 1 diff --git a/neuralmonkey/trainers/generic_trainer.py b/neuralmonkey/trainers/generic_trainer.py index f2b08eba2..aaf164b0b 100644 --- a/neuralmonkey/trainers/generic_trainer.py +++ b/neuralmonkey/trainers/generic_trainer.py @@ -57,7 +57,7 @@ def next_to_execute(self) -> NextExecute: if self.summaries: fetches.update(self.executor.summaries) - return fetches, [{}] + return fetches, [] def collect_results(self, results: List[Dict]) -> None: assert len(results) == 1 diff --git a/neuralmonkey/trainers/test_multitask_trainer.py b/neuralmonkey/trainers/test_multitask_trainer.py index 92af59e5e..3d3c4dc5f 100644 --- a/neuralmonkey/trainers/test_multitask_trainer.py +++ b/neuralmonkey/trainers/test_multitask_trainer.py @@ -64,7 +64,7 @@ def test_mt_trainer(self): # mparts = trainer.feedables fetches, feeds = executable.next_to_execute() # self.assertSetEqual(mparts, {self.mpart}) - self.assertFalse(feeds[0]) + self.assertFalse(feeds) self.assertTrue(trainer.trainer_idx == 1) self.assertTrue(fetches["losses"][0] == self.mpart.loss) @@ -72,7 +72,7 @@ def test_mt_trainer(self): executable = trainer.get_executable() fetches, feeds = executable.next_to_execute() # self.assertSetEqual(mparts, {self.mpart_2}) - self.assertFalse(feeds[0]) + self.assertFalse(feeds) self.assertTrue(trainer.trainer_idx == 2) self.assertTrue(fetches["losses"][0] == self.mpart_2.loss) @@ -80,7 +80,7 @@ def test_mt_trainer(self): executable = trainer.get_executable() fetches, feeds = executable.next_to_execute() # self.assertSetEqual(mparts, {self.mpart}) - self.assertFalse(feeds[0]) + self.assertFalse(feeds) self.assertTrue(trainer.trainer_idx == 0) self.assertTrue(fetches["losses"][0] == self.mpart.loss) From 4e06b0d7a903aafa438572f79ca3fb4bf88d0c7b Mon Sep 17 00:00:00 2001 From: Jindra Helcl Date: Tue, 11 Dec 2018 21:20:07 +0100 Subject: [PATCH 15/16] Adding docstrings to the Experiment class --- neuralmonkey/experiment.py | 69 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 67 insertions(+), 2 deletions(-) diff --git a/neuralmonkey/experiment.py b/neuralmonkey/experiment.py index 461184165..e3ae13103 100644 --- a/neuralmonkey/experiment.py +++ b/neuralmonkey/experiment.py @@ -101,12 +101,42 @@ def __init__(self, @property def model(self) -> Namespace: + """Get configuration namespace of the experiment. + + The `Experiment` stores the configuration recipe in `self.config`. + When the configuration is built (meaning the classes referenced from + the config file are instantiated), it is saved in the `model` property + of the experiment. + + Returns: + The built namespace config object. + + Raises: + `RuntimeError` when the configuration model has not been built. + """ if self._model is None: raise RuntimeError("Experiment argument model not initialized") return self._model def _bless_graph_executors(self) -> None: + """Pre-compute the tensors referenced by the graph executors. + + Due to the lazy nature of the computational graph related components, + nothing is actually added to the graph until it is "blessed" ( + referenced, and therefore, executed). + + "Blessing" is usually implemented in the form of a log or a debug call + with the blessed tensor as parameter. Referencing a `Tensor` causes the + whole computational graph that is needed to evaluate the tensor to be + built. + + This function "blesses" all tensors that could be potentially used + using the `fetches` property of the provided runner objects. + + If the experiment runs in the training mode, this function also + blesses the tensors fetched by the trainer(s). + """ log("Building TF Graph") if hasattr(self.model, "trainer"): if isinstance(self.model.trainer, List): @@ -115,13 +145,31 @@ def _bless_graph_executors(self) -> None: trainers = [self.model.trainer] for trainer in trainers: - debug("Trainer fetches: {}".format(trainer.fetches)) + debug("Trainer fetches: {}".format(trainer.fetches), "bless") for runner in self.model.runners: - debug("Runner fetches: {}".format(runner.fetches)) + debug("Runner fetches: {}".format(runner.fetches), "bless") log("TF Graph built") def build_model(self) -> None: + """Build the configuration and the computational graph. + + This function is invoked by all of the main entrypoints of the + `Experiment` class (`train`, `evaluate`, `run`). It manages the + building of the TensorFlow graph. + + The bulding procedure is executed as follows: + 1. Random seeds are set. + 2. Configuration is built (instantiated) and normalized. + 3. TODO(tf-data) tf.data.Dataset instance is created and registered + in the model parts. (This is not implemented yet!) + 4. Graph executors are "blessed". This causes the rest of the TF Graph + to be built. + 5. Sessions are initialized using the TF Manager object. + + Raises: + `RuntimeError` when the model is already built. + """ if self._model_built: raise RuntimeError("build_model() called twice") @@ -163,6 +211,15 @@ def build_model(self) -> None: self._check_unused_initializers() def train(self) -> None: + """Train model specified by this experiment. + + This function is one of the main functions (entrypoints) called on + the experiment. It builds the model (if needed) and runs the training + procedure. + + Raises: + `RuntimeError` when the experiment is not intended for training. + """ if not self.train_mode: raise RuntimeError("train() was called, but the experiment was " "created with train_mode=False") @@ -208,6 +265,14 @@ def train(self) -> None: self._vars_loaded = True def load_variables(self, variable_files: List[str] = None) -> None: + """Load variables from files. + + Arguments: + variable_files: A list of checkpoint file prefixes. A TF checkpoint + is usually three files with a common prefix. This list should + have the same number of files as there are sessions in the + `tf_manager` object. + """ if not self._model_built: self.build_model() From 4f0f44fe482c35d8f0a673cd5e95bac880793461 Mon Sep 17 00:00:00 2001 From: Jindra Helcl Date: Tue, 11 Dec 2018 21:38:07 +0100 Subject: [PATCH 16/16] Added docstring to the GraphExecutor and BaseRunner classes. plus, removing warning about missing data_ids, since the runners are way past using only "decoders" with "output series" --- neuralmonkey/runners/base_runner.py | 47 +++++++++++++++++++++++++---- 1 file changed, 41 insertions(+), 6 deletions(-) diff --git a/neuralmonkey/runners/base_runner.py b/neuralmonkey/runners/base_runner.py index bc60ea1a2..8075563c2 100644 --- a/neuralmonkey/runners/base_runner.py +++ b/neuralmonkey/runners/base_runner.py @@ -4,7 +4,6 @@ import numpy as np import tensorflow as tf -from neuralmonkey.logging import notice from neuralmonkey.model.model_part import GenericModelPart from neuralmonkey.model.feedable import Feedable from neuralmonkey.model.parameterized import Parameterized @@ -25,7 +24,7 @@ class ExecutionResult(NamedTuple( ("scalar_summaries", tf.Summary), ("histogram_summaries", tf.Summary), ("image_summaries", tf.Summary)])): - """A data structure that represents a result of a graph execution. + """A data structure that represents the result of a graph execution. The goal of each runner is to populate this structure and set it as its ``self._result``. @@ -40,8 +39,41 @@ class ExecutionResult(NamedTuple( class GraphExecutor(GenericModelPart): + """The abstract parent class of all graph executors. + + In Neural Monkey, a graph executor is an object that retrieves tensors + from the computational graph. The two major groups of graph executors are + trainers and runners. + + Each graph executor is an instance of `GenericModelPart` class, which means + it has parameterized and feedable dependencies which reference the model + part objects needed to be created in order to compute the tensors of + interest (called "fetches"). + + Every graph executor has a method called `get_executable`, which returns + an `GraphExecutor.Executable` instance, which specifies what tensors to + execute and collects results from the session execution. + """ class Executable(Generic[Executor]): + """Abstract base class for executables. + + Executables are objects associated with the graph executors. Each + executable has two main functions: `next_to_execute` and + `collect_results`. These functions are called in a loop, until + the executable's result has been set. + + To make use of Mypy's type checking, the executables are generic and + are parameterized by the type of their graph executor. Since Python + does not know the concept of nested classes, each executable receives + the instance of the graph executor through its constructor. + + When subclassing `GraphExecutor`, it is also necessary to subclass + the `Executable` class and name it `Executable`, so it overrides the + definition of this class. Following this guideline, the default + implementation of the `get_executable` function on the graph executor + will work without the need of overriding it. + """ def __init__(self, executor: Executor, @@ -110,6 +142,12 @@ def parameterizeds(self) -> Set[Parameterized]: class BaseRunner(GraphExecutor, Generic[MP]): + """Base class for runners. + + Runners are graph executors that retrieve tensors from the model without + changing the model parameters. Each runner has a top-level model part it + relates to. + """ # pylint: disable=too-few-public-methods # Pylint issue here: https://github.com/PyCQA/pylint/issues/2607 @@ -130,12 +168,9 @@ def __init__(self, decoder: MP) -> None: GraphExecutor.__init__(self, {decoder}) self.output_series = output_series + # TODO(tf-data) rename decoder to something more general self.decoder = decoder - if not hasattr(decoder, "data_id"): - notice("Top-level decoder {} does not have the 'data_id' attribute" - .format(decoder)) - @property def decoder_data_id(self) -> Optional[str]: return getattr(self.decoder, "data_id", None)