diff --git a/neuralmonkey/learning_utils.py b/neuralmonkey/learning_utils.py index 50e0e0711..102eafb42 100644 --- a/neuralmonkey/learning_utils.py +++ b/neuralmonkey/learning_utils.py @@ -13,7 +13,7 @@ from termcolor import colored from neuralmonkey.logging import log, log_print, warn -from neuralmonkey.dataset import Dataset +from neuralmonkey.dataset import Dataset, BatchingScheme from neuralmonkey.tf_manager import TensorFlowManager from neuralmonkey.runners.base_runner import ( BaseRunner, ExecutionResult, GraphExecutor, OutputSeries) @@ -85,6 +85,9 @@ def training_loop(cfg: Namespace) -> None: trainer_result = cfg.tf_manager.execute( batch, feedables, cfg.trainers, train=True, summaries=True) + # workaround: we need to use validation batching scheme + # during evaluation + batch.batching = BatchingScheme(batch_size=cfg.batch_size) train_results, train_outputs, f_batch = run_on_dataset( cfg.tf_manager, cfg.runners, cfg.dataset_runner, batch, cfg.postprocess, write_out=False)