Skip to content

Commit

Permalink
Fixing temporary bug with inputs and bucketing
Browse files Browse the repository at this point in the history
When series are bucketed (i.e. batches() does not return data in the same order
as get_series()), inputs were returned in the wrong order.
  • Loading branch information
jindrahelcl committed Jan 7, 2019
1 parent 1f58c92 commit e3f0f68
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion neuralmonkey/learning_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,8 @@ def run_on_dataset(tf_manager: TensorFlowManager,
feedables = set.union(*[runner.feedables for runner in runners])
feedables |= dataset_runner.feedables

fetched_input = {s: [] for s in dataset.series} # type: Dict[str, List]

processed_examples = 0
for batch in dataset.batches():
if 0 < log_progress < time.process_time() - last_log_time:
Expand All @@ -335,6 +337,9 @@ def run_on_dataset(tf_manager: TensorFlowManager,
for script_list, ex_result in zip(batch_results, execution_results):
script_list.append(ex_result)

for s_id in batch.series:
fetched_input[s_id].extend(batch.get_series(s_id))

# Transpose runner interim results.
all_results = [join_execution_results(res) for res in batch_results[:-1]]

Expand All @@ -343,7 +348,6 @@ def run_on_dataset(tf_manager: TensorFlowManager,
# fetched_input = {
# k: [dic[k] for dic in input_transposed] for k in input_transposed[0]}

fetched_input = {s: list(dataset.get_series(s)) for s in dataset.series}
fetched_input_lengths = {s: len(fetched_input[s]) for s in dataset.series}

if len(set(fetched_input_lengths.values())) != 1:
Expand Down

0 comments on commit e3f0f68

Please sign in to comment.