Skip to content

Commit

Permalink
Make RepresentationRunner return a single array instead of a dict
Browse files Browse the repository at this point in the history
  • Loading branch information
cifkao committed Aug 24, 2018
1 parent 5acfb1d commit bff868f
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions neuralmonkey/runners/tensor_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ def __init__(self,
all_coders: Set[ModelPart],
fetches: FeedDict,
batch_dims: Dict[str, int],
select_session: Optional[int]) -> None:
select_session: Optional[int],
single_tensor: bool) -> None:
self._all_coders = all_coders
self._fetches = fetches
self._batch_dims = batch_dims
self._select_session = select_session
self._single_tensor = single_tensor

self.result = None # type: Optional[ExecutionResult]

Expand Down Expand Up @@ -70,6 +72,10 @@ def _fetch_values_from_session(self, sess_results: Dict) -> List:
batched = [dict(zip(transposed, col))
for col in zip(*transposed.values())]

if self._single_tensor:
# extract the only item from each dict
batched = [next(iter(d.values())) for d in batched]

return batched


Expand All @@ -90,7 +96,8 @@ def __init__(self,
tensors_by_ref: List[tf.Tensor],
batch_dims_by_name: List[int],
batch_dims_by_ref: List[int],
select_session: int = None) -> None:
select_session: int = None,
single_tensor: bool = False) -> None:
"""Construct a new ``TensorRunner`` object.
Note that at this time, one must specify the toplevel objects so that
Expand Down Expand Up @@ -120,15 +127,24 @@ def __init__(self,
in case of ensembling. When not used, tensors from all sessions
are stored. In case of a single session, this option has no
effect.
single_tensor: If `True`, it is assumed that only one tensor is to
be fetched, and the execution result will consist of this
tensor only. If `False`, the result will be a dict mapping
tensor names to NumPy arrays.
"""
check_argument_types()
BaseRunner[ModelPart].__init__(self, output_series, toplevel_modelpart)

total_tensors = len(tensors_by_name) + len(tensors_by_ref)
if single_tensor and total_tensors > 1:
raise ValueError("single_tensor is True, but {} tensors were given".format(total_tensors))

self._names = tensors_by_name
self._tensors = tensors_by_ref
self._batch_dims_name = batch_dims_by_name
self._batch_dims_ref = batch_dims_by_ref
self._select_session = select_session
self._single_tensor = single_tensor

log("Blessing toplevel tensors for tensor runner:")
for tensor in toplevel_tensors:
Expand Down Expand Up @@ -159,7 +175,7 @@ def get_executable(self,

return TensorExecutable(
self.all_coders, self._fetches, self._batch_ids,
self._select_session)
self._select_session, self._single_tensor)
# pylint: enable=unused-argument

@property
Expand Down Expand Up @@ -206,4 +222,5 @@ def __init__(self,
tensors_by_ref=[tensor_to_get],
batch_dims_by_name=[],
batch_dims_by_ref=[0],
select_session=select_session)
select_session=select_session,
single_tensor=True)

0 comments on commit bff868f

Please sign in to comment.