From 299c1bca4c27dc16544820569a57ab27565ecc73 Mon Sep 17 00:00:00 2001
From: Dusan Varis
Date: Wed, 9 Jan 2019 16:29:51 +0100
Subject: [PATCH] workaround for train_set batching during inference time

---
 neuralmonkey/learning_utils.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/neuralmonkey/learning_utils.py b/neuralmonkey/learning_utils.py
index 50e0e0711..102eafb42 100644
--- a/neuralmonkey/learning_utils.py
+++ b/neuralmonkey/learning_utils.py
@@ -13,7 +13,7 @@
 from termcolor import colored
 
 from neuralmonkey.logging import log, log_print, warn
-from neuralmonkey.dataset import Dataset
+from neuralmonkey.dataset import Dataset, BatchingScheme
 from neuralmonkey.tf_manager import TensorFlowManager
 from neuralmonkey.runners.base_runner import (
     BaseRunner, ExecutionResult, GraphExecutor, OutputSeries)
@@ -85,6 +85,9 @@ def training_loop(cfg: Namespace) -> None:
             trainer_result = cfg.tf_manager.execute(
                 batch, feedables, cfg.trainers, train=True,
                 summaries=True)
+            # workaround: we need to use validation batching scheme
+            # during evaluation
+            batch.batching = BatchingScheme(batch_size=cfg.batch_size)
             train_results, train_outputs, f_batch = run_on_dataset(
                 cfg.tf_manager, cfg.runners, cfg.dataset_runner, batch,
                 cfg.postprocess, write_out=False)
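
For readers outside the Neural Monkey codebase: the patch mutates `batch.batching` in place so that the subsequent `run_on_dataset` evaluation call uses a plain fixed-size `BatchingScheme` built from `cfg.batch_size`, instead of whatever scheme was used to produce the training batch. The snippet below is a minimal, self-contained sketch of that pattern only; the `Dataset`, `BatchingScheme`, and `run_on_dataset` defined here are simplified stand-ins, not the real Neural Monkey classes, and `eval_batch_size` is a hypothetical placeholder for `cfg.batch_size`.

```python
from dataclasses import dataclass
from typing import Any, Iterator, List, Sequence


@dataclass
class BatchingScheme:
    """Simplified stand-in for neuralmonkey.dataset.BatchingScheme."""
    batch_size: int


@dataclass
class Dataset:
    """Simplified stand-in: a dataset that carries its own batching scheme."""
    examples: Sequence[Any]
    batching: BatchingScheme

    def batches(self) -> Iterator[List[Any]]:
        # Cut the examples into consecutive chunks of the configured size.
        size = self.batching.batch_size
        for start in range(0, len(self.examples), size):
            yield list(self.examples[start:start + size])


def run_on_dataset(dataset: Dataset) -> List[List[Any]]:
    """Stand-in for the evaluation path: it consumes whatever batching
    scheme the dataset currently carries."""
    return list(dataset.batches())


# Training-time batch, produced with a scheme that may not suit evaluation.
batch = Dataset(examples=list(range(10)), batching=BatchingScheme(batch_size=3))

# The workaround from the patch: override the scheme with a fixed
# evaluation batch size before handing the same data to the evaluation path.
eval_batch_size = 4  # hypothetical placeholder for cfg.batch_size
batch.batching = BatchingScheme(batch_size=eval_batch_size)

print(run_on_dataset(batch))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

As the commit subject says, this in-place reassignment is a workaround: it changes the training dataset's batching for the rest of its lifetime rather than constructing a separate evaluation view of the data.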