From 71914bb222597d781b2144b8831421169b4247ee Mon Sep 17 00:00:00 2001 From: Felix Hieber Date: Tue, 5 Apr 2022 17:38:13 +0200 Subject: [PATCH] Clarify use of Translator.batch_size in code (#1033) --- CHANGELOG.md | 6 ++++++ sockeye/__init__.py | 2 +- sockeye/inference.py | 2 +- sockeye/translate.py | 5 ++--- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 295c88cb6..ebb578ae2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,12 @@ Note that Sockeye has checks in place to not translate with an old model that wa Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_. +## [3.1.9] + +### Changed + +- Clarified usage of `batch_size` in Translator code. + ## [3.1.8] ### Fixed diff --git a/sockeye/__init__.py b/sockeye/__init__.py index 192ff98d4..e0ef9df6c 100644 --- a/sockeye/__init__.py +++ b/sockeye/__init__.py @@ -11,4 +11,4 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -__version__ = '3.1.8' +__version__ = '3.1.9' diff --git a/sockeye/inference.py b/sockeye/inference.py index 757096f02..ff6b2bf63 100644 --- a/sockeye/inference.py +++ b/sockeye/inference.py @@ -1200,7 +1200,7 @@ def _get_best_translations(self, result: SearchResult) -> List[Translation]: batch_size = best_hyp_indices.shape[0] // self.beam_size nbest_translations = [] # type: List[List[Translation]] reference_lengths = estimated_reference_lengths \ - if estimated_reference_lengths is not None else np.zeros((self.batch_size * self.beam_size, 1)) + if estimated_reference_lengths is not None else np.zeros((batch_size * self.beam_size, 1)) for n in range(0, self.nbest_size): # Initialize the best_ids to the first item in each batch, plus current nbest index diff --git a/sockeye/translate.py b/sockeye/translate.py index 511b5e4a0..839b33aa7 100644 --- a/sockeye/translate.py +++ b/sockeye/translate.py @@ -211,7 +211,6 @@ def read_and_translate(translator: inference.Translator, :param input_factors: Optional list of paths to files that contain source factors. :param input_is_json: Whether the input is in json format. """ - batch_size = translator.max_batch_size if chunk_size is None: if translator.max_batch_size == 1: # No batching, therefore there is not need to read segments in chunks. @@ -222,8 +221,8 @@ def read_and_translate(translator: inference.Translator, else: if chunk_size < translator.max_batch_size: logger.warning("You specified a chunk size (%d) smaller than the max batch size (%d). This will lead to " - "a reduction in translation speed. Consider choosing a larger chunk size." % (chunk_size, - batch_size)) + "a reduction in translation speed. Consider choosing a larger chunk size.", + chunk_size, translator.max_batch_size) logger.info("Translating...")