diff --git a/neuralmonkey/runners/tensor_runner.py b/neuralmonkey/runners/tensor_runner.py index 21a4470dd..7ef7ec2ce 100644 --- a/neuralmonkey/runners/tensor_runner.py +++ b/neuralmonkey/runners/tensor_runner.py @@ -17,11 +17,13 @@ def __init__(self, all_coders: Set[ModelPart], fetches: FeedDict, batch_dims: Dict[str, int], - select_session: Optional[int]) -> None: + select_session: Optional[int], + single_tensor: bool) -> None: self._all_coders = all_coders self._fetches = fetches self._batch_dims = batch_dims self._select_session = select_session + self._single_tensor = single_tensor self.result = None # type: Optional[ExecutionResult] @@ -70,6 +72,10 @@ def _fetch_values_from_session(self, sess_results: Dict) -> List: batched = [dict(zip(transposed, col)) for col in zip(*transposed.values())] + if self._single_tensor: + # extract the only item from each dict + batched = [next(iter(d.values())) for d in batched] + return batched @@ -90,7 +96,8 @@ def __init__(self, tensors_by_ref: List[tf.Tensor], batch_dims_by_name: List[int], batch_dims_by_ref: List[int], - select_session: int = None) -> None: + select_session: int = None, + single_tensor: bool = False) -> None: """Construct a new ``TensorRunner`` object. Note that at this time, one must specify the toplevel objects so that @@ -120,15 +127,24 @@ def __init__(self, in case of ensembling. When not used, tensors from all sessions are stored. In case of a single session, this option has no effect. + single_tensor: If `True`, it is assumed that only one tensor is to + be fetched, and the execution result will consist of this + tensor only. If `False`, the result will be a dict mapping + tensor names to NumPy arrays. """ check_argument_types() BaseRunner[ModelPart].__init__(self, output_series, toplevel_modelpart) + total_tensors = len(tensors_by_name) + len(tensors_by_ref) + if single_tensor and total_tensors > 1: + raise ValueError("single_tensor is True, but {} tensors were given".format(total_tensors)) + self._names = tensors_by_name self._tensors = tensors_by_ref self._batch_dims_name = batch_dims_by_name self._batch_dims_ref = batch_dims_by_ref self._select_session = select_session + self._single_tensor = single_tensor log("Blessing toplevel tensors for tensor runner:") for tensor in toplevel_tensors: @@ -159,7 +175,7 @@ def get_executable(self, return TensorExecutable( self.all_coders, self._fetches, self._batch_ids, - self._select_session) + self._select_session, self._single_tensor) # pylint: enable=unused-argument @property @@ -206,4 +222,5 @@ def __init__(self, tensors_by_ref=[tensor_to_get], batch_dims_by_name=[], batch_dims_by_ref=[0], - select_session=select_session) + select_session=select_session, + single_tensor=True)