diff --git a/neuralmonkey/learning_utils.py b/neuralmonkey/learning_utils.py index 4a168da0e..50e0e0711 100644 --- a/neuralmonkey/learning_utils.py +++ b/neuralmonkey/learning_utils.py @@ -317,6 +317,8 @@ def run_on_dataset(tf_manager: TensorFlowManager, feedables = set.union(*[runner.feedables for runner in runners]) feedables |= dataset_runner.feedables + fetched_input = {s: [] for s in dataset.series} # type: Dict[str, List] + processed_examples = 0 for batch in dataset.batches(): if 0 < log_progress < time.process_time() - last_log_time: @@ -335,6 +337,9 @@ def run_on_dataset(tf_manager: TensorFlowManager, for script_list, ex_result in zip(batch_results, execution_results): script_list.append(ex_result) + for s_id in batch.series: + fetched_input[s_id].extend(batch.get_series(s_id)) + # Transpose runner interim results. all_results = [join_execution_results(res) for res in batch_results[:-1]] @@ -343,7 +348,6 @@ def run_on_dataset(tf_manager: TensorFlowManager, # fetched_input = { # k: [dic[k] for dic in input_transposed] for k in input_transposed[0]} - fetched_input = {s: list(dataset.get_series(s)) for s in dataset.series} fetched_input_lengths = {s: len(fetched_input[s]) for s in dataset.series} if len(set(fetched_input_lengths.values())) != 1: