From edac7004da97c68ea9e06dfb168e21d1c8fd32fe Mon Sep 17 00:00:00 2001 From: Hoang Cuong Date: Thu, 10 Mar 2022 01:10:48 -0800 Subject: [PATCH] Support target prefix with JSON format (#1025) This feature allows Sockeye to add target prefix and target prefix factors during inference with JSON format. During inference target prefix can be specified with JSON format as follows: { "text": "The boy ate the waff@@ le .", "target_prefix": "2XX"} If a model was trained with target factors, we can add target prefix factors during inference with JSON format as follows: { "text": "The boy ate the waff@@ le .", "target_prefix_factors": ["O"]} Meanwhile, we can also add both target prefix and target prefix factors at the same time with JSON format, e.g.,: { "text": "The boy ate the waff@@ le .", "target_prefix": "2XX", "target_prefix_factors": ["O"]} Note that if an input is very long, Sockeye chunks the text and translates each chunk separately. By default, target prefix and target prefix factors are added to all chunks in that case. Alternatively, we can set use_target_prefix_all_chunks to false to add them only to the first chunk, e.g.,: { "text": "The boy ate the waff@@ le .", "target_prefix": "2XX", "target_prefix_factors": ["O"], "use_target_prefix_all_chunks": false} * support target prefix * revise illustration * fix space * add type ignore * add target prefix, revision 2 * add target prefix factors * slightly revise docs * revise based on Michael's suggestions * revise based on Felix's comments * small revise type for pylint * revise based on Tobias suggestion * revised based on Felix's comments * one_hot_encoding_from_prefix function for a full tensor * revise warning of empty prefix * use clamp instead of masked_fill_ * cleaner adding of target_prefix_factors to decode_step) * put pt.index_select vocab_slice_ids before beam loop * revised with prefix masking * revise a tiny comment * pre_expand outside, suggested by Tobi * pre_expand outside, a small fix * competely pre_expand outside, gen_prefix_masking generated only once * fix factors and add prefix to vocab_slice_ids by pytorch * small mypy fix * small fix mypy * minor revision * avoiding unnecesary copy tensor * avoid duplicate padding * extra revision suggested by Tobias * Fix translatorInput with pylint * Fix translatorInput with pylint Co-authored-by: Hoang --- CHANGELOG.md | 5 ++ docs/inference.md | 35 ++++++++ sockeye/__init__.py | 2 +- sockeye/beam_search.py | 131 +++++++++++++++++++++++++----- sockeye/constants.py | 4 + sockeye/inference.py | 134 ++++++++++++++++++++++++++++--- sockeye/model.py | 1 + sockeye/test_utils.py | 96 ++++++++++++++++++++-- sockeye/utils.py | 146 ++++++++++++++++++++++++++++++++++ test/common.py | 99 +++++++++++++++++++++-- test/unit/test_beam_search.py | 10 ++- test/unit/test_inference.py | 2 +- test/unit/test_utils.py | 119 +++++++++++++++++++++++++++ 13 files changed, 737 insertions(+), 47 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c24853c11..b7223ee66 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,11 @@ Note that Sockeye has checks in place to not translate with an old model that wa Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_. +## [3.1.4] + +### Added +- Added support for the use of adding target prefix and target prefix factors to the input in JSON format during inference. + ## [3.1.3] ### Added diff --git a/docs/inference.md b/docs/inference.md index 92a1c33f9..225602c4e 100644 --- a/docs/inference.md +++ b/docs/inference.md @@ -96,6 +96,41 @@ Similar to source factors, source prefix factors can be also specified with JSON { "text": "The boy ate the waff@@ le .", "source_prefix": "2XX", "source_prefix_factors": ["O"]} ``` +Finally, Sockeye also supports the use of adding target prefix and target prefix factors to the translation during inference. In the same spirit to the example above, let us assume a multilingual translation model trained with a target prefix 2XX (this time the prefix is added to the target sentence instead of the source sentence). During inference this target prefix can be specified with JSON format as follows: + +```json +{ "text": "The boy ate the waff@@ le .", "target_prefix": "2XX"} +``` + +This forces the decoder to generate `2XX` as its first target token (i.e. the one right after the `` token). + +If your model was trained with target factors, every target translation token aligns with one or more corresponding target factor tokens (depending of the number of target factors of the model). During inference, you can add target prefix factors to the translation with JSON format, e.g.: + +```json +{ "text": "The boy ate the waff@@ le .", "target_prefix_factors": ["O"]} +``` + +Here, the decoder is forced to generate a translation and its corresponding target factors so that the first target token aligns with factor `O` as its target factor. + +Note that you can also add both target prefix and target prefix factors with different length, e.g.,: + +```json +{ "text": "The boy ate the waff@@ le .", "target_prefix": "2XX", "target_prefix_factors": ["O O E"]} +``` +With this example, `2XX` is the force-decoded first target token of the translation. This token also aligns with factor `O` its corresponding target factor. Moreover, the next two target tokens after `2XX` align with `O E` as their corresponding target factors. + +Note that if an input is very long, Sockeye chunks the text and translates each chunk separately. By default, target prefix and target prefix factors are added to all chunks in that case. Alternatively, you can set `use_target_prefix_all_chunks` to `false` to add them only to the first chunk, e.g.,: + +```json +{ "text": "The boy ate the waff@@ le .", "target_prefix": "2XX", "target_prefix_factors": ["O"], "use_target_prefix_all_chunks": false} +``` + +Note also that the translation output includes the target prefix as its first string by default. Alternatively, you can remove the target prefix from the translation output by setting `keep_target_prefix` to `false`, e.g.,: + +```json +{ "text": "The boy ate the waff@@ le .", "target_prefix": "2XX", "keep_target_prefix": false} +``` + ## N-best translations Sockeye can return the n best hypotheses per input (*nbest lists*). diff --git a/sockeye/__init__.py b/sockeye/__init__.py index 854cb1cd3..24afdda6c 100644 --- a/sockeye/__init__.py +++ b/sockeye/__init__.py @@ -11,4 +11,4 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -__version__ = '3.1.3' +__version__ = '3.1.4' diff --git a/sockeye/beam_search.py b/sockeye/beam_search.py index 9a3e7c1af..0a84fbf5e 100644 --- a/sockeye/beam_search.py +++ b/sockeye/beam_search.py @@ -46,7 +46,17 @@ def encode_and_initialize(self, def decode_step(self, step_input: pt.Tensor, states: List, - vocab_slice_ids: Optional[pt.Tensor] = None): + vocab_slice_ids: Optional[pt.Tensor] = None, + target_prefix_factor_mask: Optional[pt.Tensor] = None, + factor_vocab_size: Optional[int] = None): + raise NotImplementedError() + + @property + def model_output_vocab_size(self): + raise NotImplementedError() + + @property + def model_output_factor_vocab_size(self): raise NotImplementedError() @@ -70,19 +80,23 @@ def encode_and_initialize(self, inputs: pt.Tensor, valid_length: Optional[pt.Ten def decode_step(self, step_input: pt.Tensor, states: List, - vocab_slice_ids: Optional[pt.Tensor] = None): + vocab_slice_ids: Optional[pt.Tensor] = None, + target_prefix_factor_mask: Optional[pt.Tensor] = None, + factor_vocab_size: Optional[int] = None): logits, states, target_factor_outputs = self._model.decode_step(step_input, states, vocab_slice_ids) if not self._skip_softmax: logits = pt.log_softmax(logits, dim=-1) - scores = -logits # shape: (batch, output_vocab_size/len(vocab_slice_ids)) + scores = -logits # shape: (batch*beam, output_vocab_size/len(vocab_slice_ids)) target_factors = None # type: Optional[pt.Tensor] if target_factor_outputs: predictions = [] # type: List[pt.Tensor] - for tf_logits in target_factor_outputs: + for i, tf_logits in enumerate(target_factor_outputs, 1): if not self._skip_softmax: tf_logits = pt.log_softmax(tf_logits, dim=-1) tf_scores = -tf_logits + if target_prefix_factor_mask is not None: + tf_scores += target_prefix_factor_mask[:,:,i-1].reshape(-1, factor_vocab_size) # target factors are greedily chosen, and score and index are collected via torch.min. # Shape per factor: (batch*beam, 1, 2), where last dimension holds values and indices. tf_prediction = pt.cat(tf_scores.min(dim=-1, keepdim=True), dim=1).unsqueeze(1) @@ -92,6 +106,14 @@ def decode_step(self, return scores, states, target_factors + @property + def model_output_vocab_size(self): + return self._model.output_layer_vocab_size + + @property + def model_output_factor_vocab_size(self): + return self._model.factor_vocab_size + class _EnsembleInference(_Inference): @@ -125,7 +147,9 @@ def encode_and_initialize(self, inputs: pt.Tensor, valid_length: Optional[pt.Ten def decode_step(self, step_input: pt.Tensor, states: List, - vocab_slice_ids: Optional[pt.Tensor] = None): + vocab_slice_ids: Optional[pt.Tensor] = None, + target_prefix_factor_mask: Optional[pt.Tensor] = None, + factor_vocab_size: Optional[int] = None): outputs = [] # type: List[pt.Tensor] new_states = [] # type: List[pt.Tensor] factor_outputs = [] # type: List[List[pt.Tensor]] @@ -138,6 +162,9 @@ def decode_step(self, outputs.append(probs) if target_factor_outputs: target_factor_probs = [tfo.softmax(dim=-1) for tfo in target_factor_outputs] + if target_prefix_factor_mask is not None: + for i in range(len(target_factor_probs)): + target_factor_probs[i] += target_prefix_factor_mask[:,:,i].reshape(-1, factor_vocab_size) factor_outputs.append(target_factor_probs) new_states += model_states scores = self._interpolation(outputs) @@ -160,6 +187,13 @@ def log_linear_interpolation(predictions): log_probs = utils.average_tensors([p.log() for p in predictions]) return -(log_probs.log_softmax()) + @property + def model_output_vocab_size(self): + return self._models[0].output_layer_vocab_size + + @property + def model_output_factor_vocab_size(self): + return self._models[0].factor_vocab_size @dataclass class SearchResult: @@ -495,14 +529,18 @@ def forward(self, best_hyp_indices, *states): def _get_vocab_slice_ids(restrict_lexicon: Optional[lexicon.TopKLexicon], source_words: pt.Tensor, eos_id: int, - beam_size: int) -> Tuple[pt.Tensor, int]: + beam_size: int, + target_prefix: Optional[pt.Tensor] = None) -> Tuple[pt.Tensor, int]: device = source_words.device - vocab_slice_ids = restrict_lexicon.get_trg_ids(source_words.cpu().int().numpy()) + vocab_slice_ids_np = restrict_lexicon.get_trg_ids(source_words.cpu().int().numpy()) # type: ignore + vocab_slice_ids = pt.tensor(vocab_slice_ids_np, device=device, dtype=pt.int64) + if target_prefix is not None: + # Ensuring that target prefix ids are part of vocab_slice_ids + vocab_slice_ids = pt.concat([vocab_slice_ids, target_prefix.flatten().type(pt.int64)], -1).unique() # Pad to a multiple of 8. - vocab_slice_ids = pt.nn.functional.pad(pt.tensor(vocab_slice_ids, device=source_words.device, dtype=pt.int64), # type: ignore - pad=(0, 7 - ((vocab_slice_ids.size - 1) % 8)), + vocab_slice_ids = pt.nn.functional.pad(vocab_slice_ids, \ + pad=(0, 7 - ((vocab_slice_ids.size(-1) - 1) % 8)), \ mode='constant', value=eos_id) - vocab_slice_ids_shape = vocab_slice_ids.size()[0] # type: ignore if vocab_slice_ids_shape < beam_size + 1: # This fixes an edge case for toy models, where the number of vocab ids from the lexicon is @@ -538,6 +576,8 @@ def __init__(self, self.bos_id = bos_id self.eos_id = eos_id self.device = device + self.output_vocab_size = inference.model_output_vocab_size + self.output_factor_vocab_size = inference.model_output_factor_vocab_size self._inference = inference self.num_source_factors = num_source_factors self.num_target_factors = num_target_factors @@ -549,8 +589,10 @@ def __init__(self, def forward(self, source: pt.Tensor, source_length: pt.Tensor, - restrict_lexicon: Optional[lexicon.TopKLexicon], - max_output_lengths: pt.Tensor) -> SearchResult: + restrict_lexicon: Optional[lexicon.TopKLexicon] = None, + max_output_lengths: pt.Tensor = None, + target_prefix: Optional[pt.Tensor] = None, + target_prefix_factors: Optional[pt.Tensor] = None) -> SearchResult: """ Translates a single sentence (batch_size=1) using greedy search. @@ -559,6 +601,8 @@ def forward(self, :param restrict_lexicon: Lexicon to use for vocabulary restriction. :param max_output_lengths: ndarray of maximum output lengths per input in source. Shape: (batch_size=1,). Dtype: int32. + :param target_prefix: Target prefix ids. Shape: (batch_size=1, max target prefix length). + :param target_prefix_factors: Target prefix factor ids. Shape: (batch_size=1, max target prefix factors length, num_target_factors). :return SearchResult. """ batch_size = source.size()[0] @@ -578,17 +622,36 @@ def forward(self, # target vocab for this sentence. if restrict_lexicon: source_words = source[:, :, 0] - vocab_slice_ids, _ = _get_vocab_slice_ids(restrict_lexicon, source_words, self.eos_id, beam_size=1) + vocab_slice_ids, _ = _get_vocab_slice_ids(restrict_lexicon, source_words, self.eos_id, beam_size=1, target_prefix=target_prefix) # (0) encode source sentence, returns a list model_states, _ = self._inference.encode_and_initialize(source, source_length) # TODO: check for disabled predicted output length + # Prefix masks, where scores are infinity for all other vocabulary items except target_prefix ids + prefix_masks, prefix_masks_length = (utils.gen_prefix_masking(target_prefix, self.output_vocab_size, self.dtype), target_prefix.size(1)) \ + if target_prefix is not None else (None, None) # type: ignore + if prefix_masks is not None and vocab_slice_ids is not None: + prefix_masks = pt.index_select(prefix_masks, -1, vocab_slice_ids) + # Prefix factor masks, where scores are also infinity for all other factor items except target_prefix_factor ids + target_prefix_factor_masks, target_prefix_factor_length = (utils.gen_prefix_masking(target_prefix_factors, self.output_factor_vocab_size, self.dtype), target_prefix_factors.size(1)) \ + if self.num_target_factors > 1 and target_prefix_factors is not None \ + else (None, None) # type: ignore + t = 1 for t in range(1, max_iterations + 1): + target_prefix_factor_mask = target_prefix_factor_masks[:, t-1] \ + if target_prefix_factor_masks is not None and t <= target_prefix_factor_length \ + else None scores, model_states, target_factors = self._inference.decode_step(best_word_index, model_states, - vocab_slice_ids=vocab_slice_ids) + vocab_slice_ids, + target_prefix_factor_mask, + self.output_factor_vocab_size) + if target_prefix is not None and t <= prefix_masks_length: + # Make sure search selects the current prefix token + scores += prefix_masks[:, t-1] + # shape: (batch*beam=1, 1) best_word_index = self.work_block(scores, vocab_slice_ids, target_factors) outputs.append(best_word_index) @@ -667,6 +730,7 @@ def __init__(self, self.bos_id = bos_id self.eos_id = eos_id self.output_vocab_size = output_vocab_size + self.output_factor_vocab_size = inference.model_output_factor_vocab_size self.device = device self._inference = inference self.beam_search_stop = beam_search_stop @@ -698,7 +762,9 @@ def forward(self, source: pt.Tensor, source_length: pt.Tensor, restrict_lexicon: Optional[lexicon.TopKLexicon], - max_output_lengths: pt.Tensor) -> SearchResult: + max_output_lengths: pt.Tensor, + target_prefix: Optional[pt.Tensor] = None, + target_prefix_factors: Optional[pt.Tensor] = None) -> SearchResult: """ Translates multiple sentences using beam search. @@ -707,6 +773,8 @@ def forward(self, :param restrict_lexicon: Lexicon to use for vocabulary restriction. :param max_output_lengths: Tensor of maximum output lengths per input in source. Shape: (batch_size,). Dtype: int32. + :param target_prefix: Target prefix ids. Shape: (batch_size, max prefix length). + :param target_prefix_factors: Target prefix factors ids. Shape: (batch_size, max prefix factors length, num_target_factors). :return SearchResult. """ batch_size = source.size()[0] @@ -734,6 +802,8 @@ def forward(self, batch_indices = pt.arange(0, batch_size * self.beam_size, self.beam_size, dtype=pt.int64, device=self.device) first_step_mask = pt.full((batch_size * self.beam_size, 1), fill_value=onp.inf, device=self.device, dtype=self.dtype) first_step_mask[batch_indices] = 0.0 + if target_prefix is not None: + first_step_mask = utils.adjust_first_step_masking(target_prefix, first_step_mask) # Best word and hypotheses indices across beam search steps from topk operation. best_hyp_indices_list = [] # type: List[pt.Tensor] @@ -761,7 +831,7 @@ def forward(self, if restrict_lexicon: source_words = source[:, :, 0] vocab_slice_ids, output_vocab_size = _get_vocab_slice_ids(restrict_lexicon, source_words, self.eos_id, - beam_size=1) + beam_size=1, target_prefix=target_prefix) pad_dist = pt.full((1, output_vocab_size), fill_value=onp.inf, device=self.device, dtype=self.dtype) pad_dist[0, 0] = 0 # [0, inf, inf, ...] @@ -779,15 +849,31 @@ def forward(self, # repeat estimated_reference_lengths to shape (batch_size * beam_size) estimated_reference_lengths = estimated_reference_lengths.repeat_interleave(self.beam_size, dim=0) + # Prefix token masks, where scores are infinity for all other vocabulary items except target_prefix ids + prefix_masks, prefix_masks_length = (utils.gen_prefix_masking(target_prefix, self.output_vocab_size, self.dtype), target_prefix.size(1)) \ + if target_prefix is not None else (None, None) # type: ignore + prefix_masks = prefix_masks.unsqueeze(2).expand(-1, -1, self.beam_size, -1) if target_prefix is not None else None # type: ignore + if prefix_masks is not None and vocab_slice_ids is not None: + prefix_masks = pt.index_select(prefix_masks, -1, vocab_slice_ids) + # Prefix factor masks, where scores are also infinity for all other factor items except target_prefix_factor ids + target_prefix_factor_masks, target_prefix_factor_length = (utils.gen_prefix_masking(target_prefix_factors, self.output_factor_vocab_size, self.dtype), target_prefix_factors.size(1)) \ + if self.num_target_factors > 1 and target_prefix_factors is not None \ + else (None, None) # type: ignore + target_prefix_factor_masks = target_prefix_factor_masks.unsqueeze(2).expand(-1, -1, self.beam_size, -1, -1) if target_prefix_factor_masks is not None else None # type: ignore t = 1 for t in range(1, max_iterations + 1): # max_iterations + 1 required to get correct results # (1) obtain next predictions and advance models' state # target_dists: (batch_size * beam_size, target_vocab_size) # target_factors: (batch_size * beam_size, num_secondary_factors, 2), # where last dimension holds indices and scores + target_prefix_factor_mask = target_prefix_factor_masks[:, t-1] \ + if target_prefix_factor_masks is not None and t <= target_prefix_factor_length \ + else None target_dists, model_states, target_factors = self._inference.decode_step(best_word_indices, model_states, - vocab_slice_ids) + vocab_slice_ids, + target_prefix_factor_mask, + self.output_factor_vocab_size) # (2) Produces the accumulated cost of target words in each row. # There is special treatment for finished rows. @@ -795,6 +881,10 @@ def forward(self, scores, lengths = self._update_scores(target_dists, finished, scores_accumulated, lengths, max_output_lengths, pad_dist, eos_dist) + if target_prefix is not None and t <= prefix_masks_length: + # Make sure search selects the current prefix token + scores += prefix_masks[:, t-1].reshape(-1, output_vocab_size) + # (3) Get beam_size winning hypotheses for each sentence block separately. Only look as # far as the active beam size for each sentence. if self._sample is not None: @@ -802,8 +892,11 @@ def forward(self, else: # On the first timestep, all hypotheses have identical histories, so force topk() to choose extensions # of the first row only by setting all other rows to inf - if t == 1: - scores += first_step_mask + if target_prefix is None: + scores = scores + first_step_mask if t == 1 else scores + else: + # While decoding target prefixes, we also mask all other hypotheses than the first + scores = scores + first_step_mask[:, t-1:t] if t <= first_step_mask.size(-1) else scores if self._traced_top is None: logger.debug("Tracing _top") diff --git a/sockeye/constants.py b/sockeye/constants.py index 3c5da9a60..6412f8b5e 100644 --- a/sockeye/constants.py +++ b/sockeye/constants.py @@ -124,6 +124,10 @@ JSON_FACTORS_KEY = "factors" JSON_SOURCE_PREFIX_KEY = "source_prefix" JSON_SOURCE_PREFIX_FACTORS_KEY = "source_prefix_factors" +JSON_TARGET_PREFIX_KEY = "target_prefix" +JSON_TARGET_PREFIX_FACTORS_KEY = "target_prefix_factors" +JSON_USE_TARGET_PREFIX_ALL_CHUNKS_KEY = "use_target_prefix_all_chunks" +JSON_KEEP_TARGET_PREFIX_KEY = "keep_target_prefix" JSON_RESTRICT_LEXICON_KEY = "restrict_lexicon" JSON_CONSTRAINTS_KEY = "constraints" JSON_AVOID_KEY = "avoid" diff --git a/sockeye/inference.py b/sockeye/inference.py index 0596d2890..471a65426 100644 --- a/sockeye/inference.py +++ b/sockeye/inference.py @@ -132,17 +132,20 @@ class TranslatorInput: factors: Optional[List[Tokens]] = None source_prefix_tokens: Optional[Tokens] = None source_prefix_factors: Optional[List[Tokens]] = None + target_prefix_tokens: Optional[Tokens] = None + target_prefix_factors: Optional[List[Tokens]] = None + use_target_prefix_all_chunks: Optional[bool] = True + keep_target_prefix_key: Optional[bool] = True restrict_lexicon: Optional[lexicon.TopKLexicon] = None constraints: Optional[List[Tokens]] = None avoid_list: Optional[List[Tokens]] = None pass_through_dict: Optional[Dict] = None def __str__(self): - return 'TranslatorInput(%s, %s, factors=%s, source_prefix_tokens=%s, source_prefix_factors=%s, constraints=%s, avoid=%s)' \ - % (self.sentence_id, self.tokens, self.factors, self.source_prefix_tokens, self.source_prefix_factors, self.constraints, self.avoid_list) + return f'TranslatorInput({self.sentence_id}, {self.tokens}, factors={self.factors}, source_prefix_tokens={self.source_prefix_tokens}, source_prefix_factors={self.source_prefix_factors}, target_prefix_tokens={self.target_prefix_tokens}, target_prefix_factors={self.target_prefix_factors}, use_target_prefix_all_chunks={self.use_target_prefix_all_chunks}, keep_target_prefix_key={self.keep_target_prefix_key}, constraints={self.constraints}, avoid={self.avoid_list})' def __len__(self): - return len(self.tokens) + self.num_source_prefix_tokens() + return len(self.tokens) + self.num_source_prefix_tokens @property def num_factors(self) -> int: @@ -157,12 +160,39 @@ def get_source_prefix_tokens(self) -> Tokens: """ return self.source_prefix_tokens if self.source_prefix_tokens is not None else [] + @property def num_source_prefix_tokens(self) -> int: """ Returns the number of source prefix tokens of this instance. """ return len(self.get_source_prefix_tokens()) + def get_target_prefix_tokens(self) -> Tokens: + """ + Returns the target prefix tokens of this instance. + """ + return self.target_prefix_tokens if self.target_prefix_tokens is not None else [] + + @property + def num_target_prefix_tokens(self) -> int: + """ + Returns the number of target prefix tokens of this instance. + """ + return len(self.get_target_prefix_tokens()) + + def get_target_prefix_factors(self) -> List[Tokens]: + """ + Returns the target prefix factors of this instance. + """ + return self.target_prefix_factors if self.target_prefix_factors is not None else [[]] + + @property + def num_target_prefix_factors(self) -> int: + """ + Returns the number of target prefix factors of this instance. + """ + return len(self.get_target_prefix_factors()[0]) + def chunks(self, chunk_size: int) -> Generator['TranslatorInput', None, None]: """ Takes a TranslatorInput (itself) and yields TranslatorInputs for chunks of size chunk_size. @@ -178,11 +208,15 @@ def chunks(self, chunk_size: int) -> Generator['TranslatorInput', None, None]: 'with the first chunk, which is probably wrong.', self.sentence_id, len(self.tokens), chunk_size) - for chunk_id, i in enumerate(range(0, len(self) - self.num_source_prefix_tokens(), chunk_size)): + for chunk_id, i in enumerate(range(0, len(self) - self.num_source_prefix_tokens, chunk_size)): factors = [factor[i:i + chunk_size] for factor in self.factors] if self.factors is not None else None # Constrained decoding is not supported for chunked TranslatorInputs. As a fall-back, constraints are # assigned to the first chunk constraints = self.constraints if chunk_id == 0 else None + # Target_prefix_tokens are assigned to all chunks if self.use_target_prefix_all_chunks is True, + # otherwise target_prefix_tokens are assigned only to the first chunk + target_prefix_tokens = self.target_prefix_tokens if chunk_id == 0 or self.use_target_prefix_all_chunks else None + target_prefix_factors = self.target_prefix_factors if chunk_id == 0 or self.use_target_prefix_all_chunks else None pass_through_dict = copy.deepcopy(self.pass_through_dict) \ if (chunk_id == 0 and self.pass_through_dict is not None) else None yield TranslatorInput(sentence_id=self.sentence_id, @@ -190,6 +224,10 @@ def chunks(self, chunk_size: int) -> Generator['TranslatorInput', None, None]: factors=factors, source_prefix_tokens=self.source_prefix_tokens, source_prefix_factors=self.source_prefix_factors, + target_prefix_tokens=target_prefix_tokens, + target_prefix_factors=self.target_prefix_factors, + use_target_prefix_all_chunks=self.use_target_prefix_all_chunks, + keep_target_prefix_key=self.keep_target_prefix_key, restrict_lexicon=self.restrict_lexicon, constraints=constraints, avoid_list=self.avoid_list, @@ -205,6 +243,10 @@ def with_eos(self) -> 'TranslatorInput': self.factors] if self.factors is not None else None, source_prefix_tokens=self.source_prefix_tokens, source_prefix_factors=self.source_prefix_factors, + target_prefix_tokens=self.target_prefix_tokens, + target_prefix_factors=self.target_prefix_factors, + use_target_prefix_all_chunks=self.use_target_prefix_all_chunks, + keep_target_prefix_key=self.keep_target_prefix_key, restrict_lexicon=self.restrict_lexicon, constraints=self.constraints, avoid_list=self.avoid_list, @@ -274,7 +316,9 @@ def make_input_from_dict(sentence_id: SentenceId, tokens = list(utils.get_tokens(tokens)) factors = input_dict.get(C.JSON_FACTORS_KEY) source_prefix_tokens = input_dict.get(C.JSON_SOURCE_PREFIX_KEY) - source_prefix_tokens = list(utils.get_tokens(source_prefix_tokens)) if source_prefix_tokens else None + source_prefix_tokens = list(utils.get_tokens(source_prefix_tokens)) if source_prefix_tokens is not None else None + if source_prefix_tokens is not None and not source_prefix_tokens: + logger.warning(f"Empty string is specified as a source prefix for input '{input_dict[C.JSON_SOURCE_PREFIX_KEY]}'.") source_prefix_factors = input_dict.get(C.JSON_SOURCE_PREFIX_FACTORS_KEY) if source_prefix_factors is not None and not source_prefix_tokens: logger.error("Source prefix factors cannot be specified when source prefix is not specified") @@ -295,6 +339,9 @@ def make_input_from_dict(sentence_id: SentenceId, if isinstance(source_prefix_factors, list): source_prefix_factors = [list(utils.get_tokens(source_prefix_factor)) for source_prefix_factor in source_prefix_factors] + for source_prefix_factor in source_prefix_factors: + if not source_prefix_factor: + logger.warning(f"Empty list is specified as source prefix factors for input '{input_dict[C.JSON_TEXT_KEY]}'.") lengths = [len(source_prefix_factor) for source_prefix_factor in source_prefix_factors] if not all(len(source_prefix_tokens) == length for length in lengths): logger.error("Source prefix has %d tokens but there are %s prefix factors", len(source_prefix_tokens), str(lengths)) @@ -303,6 +350,20 @@ def make_input_from_dict(sentence_id: SentenceId, logger.error("There is mismatch in source factors %d and prefix factors %d", len(factors), len(source_prefix_factors)) return _bad_input(sentence_id, reason=str(input_dict)) + target_prefix_tokens = input_dict.get(C.JSON_TARGET_PREFIX_KEY) + target_prefix_tokens = list(utils.get_tokens(target_prefix_tokens)) if target_prefix_tokens is not None else None + if target_prefix_tokens is not None and not target_prefix_tokens: + logger.warning(f"Empty string is specified as a target prefix for input '{input_dict[C.JSON_TEXT_KEY]}'.") + + target_prefix_factors = input_dict.get(C.JSON_TARGET_PREFIX_FACTORS_KEY) + if isinstance(target_prefix_factors, list): + target_prefix_factors = [list(utils.get_tokens(target_prefix_factor)) for target_prefix_factor in target_prefix_factors] + for target_prefix_factor in target_prefix_factors: + if not target_prefix_factor: + logger.warning(f"Empty list is specified as target prefix factors for input '{input_dict[C.JSON_TEXT_KEY]}'.") + + use_target_prefix_all_chunks = input_dict.get(C.JSON_USE_TARGET_PREFIX_ALL_CHUNKS_KEY, True) + keep_target_prefix_key = input_dict.get(C.JSON_KEEP_TARGET_PREFIX_KEY, True) # Lexicon for vocabulary selection/restriction: # This is only populated when using multiple lexicons, in which case the # restrict_lexicon key must exist and the value (name) must map to one @@ -344,6 +405,10 @@ def make_input_from_dict(sentence_id: SentenceId, return TranslatorInput(sentence_id=sentence_id, tokens=tokens, factors=factors, source_prefix_tokens=source_prefix_tokens, source_prefix_factors=source_prefix_factors, + target_prefix_tokens=target_prefix_tokens, + target_prefix_factors=target_prefix_factors, + use_target_prefix_all_chunks=use_target_prefix_all_chunks, + keep_target_prefix_key=keep_target_prefix_key, restrict_lexicon=restrict_lexicon, constraints=constraints, avoid_list=avoid_list, pass_through_dict=input_dict) @@ -588,9 +653,18 @@ def _expand_nbest_translation(translation: Translation) -> List[Translation]: for target_ids, score in zip(translation.nbest_translations.target_ids_list, translation.nbest_translations.scores): nbest_list.append(Translation(target_ids, score, estimated_reference_length=translation.estimated_reference_length)) - return nbest_list +def _remove_target_prefix_tokens(target_ids: TokenIds, num_target_prefix_tokens: int) -> TokenIds: + """ + Remove target prefix tokens from target token Ids + + :param target_ids: target token Ids of translation of an input + :param num_target_prefix_tokens: number of target prefix tokens included in the translation + :return: new target_ids + """ + starting_idx = min(len(target_ids), num_target_prefix_tokens) + return target_ids[starting_idx:] def _concat_translations(translations: List[Translation], stop_ids: Set[int], @@ -802,11 +876,11 @@ def translate(self, trans_inputs: List[TranslatorInput], fill_up_batches: bool = translated_chunks.append(IndexedTranslation(input_idx=trans_input_idx, chunk_idx=0, translation=empty_translation(add_nbest=(self.nbest_size > 1)))) else: - max_input_length_for_chunking = self.max_input_length - trans_input.num_source_prefix_tokens() # take length of source prefix, if used, into account while chunking + max_input_length_for_chunking = self.max_input_length - trans_input.num_source_prefix_tokens # take length of source prefix, if used, into account while chunking if max_input_length_for_chunking <= 0: logger.warning( "Input %s has a source prefix with length (%d) that already equals or exceeds max input length (%d). Return an empty translation instead.", \ - trans_input.sentence_id, trans_input.num_source_prefix_tokens(), self.max_input_length) + trans_input.sentence_id, trans_input.num_source_prefix_tokens, self.max_input_length) translated_chunks.append(IndexedTranslation(input_idx=trans_input_idx, chunk_idx=0, translation=empty_translation(add_nbest=(self.nbest_size > 1)))) elif len(trans_input.tokens) > max_input_length_for_chunking: @@ -869,11 +943,20 @@ def translate(self, trans_inputs: List[TranslatorInput], fill_up_batches: bool = chunks_by_input_idx = itertools.groupby(translated_chunks, key=lambda translation: translation.input_idx) for trans_input, (input_idx, translations_for_input_idx) in zip(trans_inputs, chunks_by_input_idx): translations_for_input_idx = list(translations_for_input_idx) # type: ignore + num_target_prefix_tokens = trans_input.num_target_prefix_tokens if len(translations_for_input_idx) == 1: # type: ignore translation = translations_for_input_idx[0].translation # type: ignore + if num_target_prefix_tokens > 0 and not trans_input.keep_target_prefix_key: + translation.target_ids = \ + _remove_target_prefix_tokens(translation.target_ids, num_target_prefix_tokens) else: translations_to_concat = [translated_chunk.translation for translated_chunk in translations_for_input_idx] + if num_target_prefix_tokens > 0 and not trans_input.keep_target_prefix_key: + for i in range(len(translations_to_concat)): + if i == 0 or trans_input.use_target_prefix_all_chunks: + translations_to_concat[i].target_ids = \ + _remove_target_prefix_tokens(translations_to_concat[i].target_ids, num_target_prefix_tokens) translation = self._concat_translations(translations_to_concat) results.append(self._make_result(trans_input, translation)) @@ -888,7 +971,9 @@ def translate(self, trans_inputs: List[TranslatorInput], fill_up_batches: bool = def _get_inference_input(self, trans_inputs: List[TranslatorInput]) -> Tuple[pt.Tensor, int, + Optional[pt.Tensor], Optional[lexicon.TopKLexicon], + pt.Tensor, pt.Tensor]: """ Assembles the numerical data for the batch. This comprises a tensor for the source sentences, @@ -896,16 +981,20 @@ def _get_inference_input(self, :param trans_inputs: List of TranslatorInputs. :return tensor of source ids (shape=(batch_size, bucket_key, num_factors)), - tensor of valid source lengths, lexicon for vocabulary restriction, and a tensor of maximum output + tensor of valid source lengths, target prefix, lexicon for vocabulary restriction, and a tensor of maximum output lengths. """ batch_size = len(trans_inputs) lengths = [len(inp) for inp in trans_inputs] + max_target_prefix_length = max(inp.num_target_prefix_tokens for inp in trans_inputs) + max_target_prefix_factors_length = max(inp.num_target_prefix_factors for inp in trans_inputs) max_length = max(len(inp) for inp in trans_inputs) # assembling source ids on cpu array (faster) and copy to Translator.device (potentially GPU) in one go below. source = onp.zeros((batch_size, max_length, self.num_source_factors), dtype='int32') + target_prefix = onp.zeros((batch_size, max_target_prefix_length), dtype='int32') if max_target_prefix_length > 0 else None + target_prefix_factors = onp.zeros((batch_size, max_target_prefix_factors_length, self.num_target_factors - 1), dtype='int32') if self.num_target_factors > 1 and max_target_prefix_factors_length > 0 else None restrict_lexicon = None # type: Optional[lexicon.TopKLexicon] max_output_lengths = [] # type: List[int] @@ -914,6 +1003,11 @@ def _get_inference_input(self, max_output_lengths.append(self._get_max_output_length(num_tokens)) source[j, :num_tokens, 0] = tokens2ids(itertools.chain(trans_input.get_source_prefix_tokens(), \ trans_input.tokens), self.source_vocabs[0]) + if target_prefix is not None and trans_input.num_target_prefix_tokens > 0: + target_prefix[j, :trans_input.num_target_prefix_tokens] = tokens2ids(trans_input.get_target_prefix_tokens(), self.vocab_targets[0]) + if target_prefix_factors is not None and self.num_target_factors > 1 and trans_input.num_target_prefix_factors > 0: + for i in range(1, self.num_target_factors): + target_prefix_factors[j, :trans_input.num_target_prefix_factors, i - 1] = tokens2ids(trans_input.get_target_prefix_factors()[i - 1], self.vocab_targets[i]) factors = trans_input.factors if trans_input.factors is not None else [] num_factors = 1 + len(factors) if num_factors != self.num_source_factors: @@ -951,7 +1045,15 @@ def _get_inference_input(self, source = pt.tensor(source, device=self.device, dtype=pt.int32) # type: ignore source_length = pt.tensor(lengths, device=self.device, dtype=pt.int32) # shape: (batch_size,) max_output_lengths = pt.tensor(max_output_lengths, device=self.device, dtype=pt.int32) # type: ignore - return source, source_length, restrict_lexicon, max_output_lengths # type: ignore + target_prefix = pt.tensor(target_prefix, device=self.device, dtype=pt.int32) if target_prefix is not None else None # type: ignore + target_prefix_factors_tensor = pt.tensor(target_prefix_factors, device=self.device, dtype=pt.int32) if target_prefix_factors is not None else None # type: ignore + + # During inference, if C.TARGET_FACTOR_SHIFT is True, predicted target_factors are left-shifted (see _unshift_target_factors function()) so that \ + # they re-align with the words. With that, target_prefix_factors need to be also right-shifted here if C.TARGET_FACTOR_SHIFT is True so that when \ + # they are shifted back later they would align with words. + target_prefix_factors_tensor = utils.shift_prefix_factors(target_prefix_factors_tensor) if target_prefix_factors_tensor is not None and C.TARGET_FACTOR_SHIFT else target_prefix_factors_tensor # type: ignore + + return source, source_length, restrict_lexicon, max_output_lengths, target_prefix, target_prefix_factors_tensor # type: ignore def _get_translation_tokens_and_factors(self, target_ids: List[List[int]]) -> Tuple[List[str], str, @@ -1033,20 +1135,28 @@ def _translate_np(self, source: pt.Tensor, source_length: pt.Tensor, restrict_lexicon: Optional[lexicon.TopKLexicon], - max_output_lengths: pt.Tensor) -> List[Translation]: + max_output_lengths: pt.Tensor, + target_prefix: Optional[pt.Tensor] = None, + target_prefix_factors: Optional[pt.Tensor] = None) -> List[Translation]: """ Translates source of source_length and returns list of Translations. :param source: Source ids. Shape: (batch_size, bucket_key, num_factors). :param source_length: Valid source lengths. :param restrict_lexicon: Lexicon to use for vocabulary restriction. + :param max_output_lengths: Tensor of maximum output lengths per input in source. + Shape: (batch_size,). Dtype: int32. + :param target_prefix: Target prefix ids. + :param target_prefix_factors: Target prefix factors ids. :return: List of translations. """ return self._get_best_translations(self._search(source, source_length, restrict_lexicon, - max_output_lengths)) + max_output_lengths, + target_prefix, + target_prefix_factors)) def _get_best_translations(self, result: SearchResult) -> List[Translation]: """ diff --git a/sockeye/model.py b/sockeye/model.py index 1f84ea8ad..8b81902bc 100644 --- a/sockeye/model.py +++ b/sockeye/model.py @@ -127,6 +127,7 @@ def __init__(self, out_features=factor_config.vocab_size, bias=True) self.factor_output_layers.append(output_layer) + self.factor_vocab_size = factor_config.vocab_size if self.target_factor_configs else None self.length_ratio = None # type: Optional[layers.LengthRatio] if self.config.config_length_task is not None: diff --git a/sockeye/test_utils.py b/sockeye/test_utils.py index 598b3b939..2474cbd01 100644 --- a/sockeye/test_utils.py +++ b/sockeye/test_utils.py @@ -18,7 +18,7 @@ import sys from contextlib import contextmanager from tempfile import TemporaryDirectory -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from unittest.mock import patch import sockeye.constants as C @@ -58,6 +58,50 @@ def generate_digits_file(source_path: str, print(C.TOKEN_SEPARATOR.join(digits), file=target_out) +def generate_json_input_file_with_tgt_prefix(src_path:str, tgt_path: str, json_file_with_tgt_prefix_path: str, \ + src_factors_path: Optional[List[str]] = None, tgt_factors_path: List[str] = None, seed=13): + random_gen = random.Random(seed) + with open(src_path, "r") as src_reader, open(tgt_path, "r") as tgt_reader: + with open(json_file_with_tgt_prefix_path, "w") as out: + list_src_factors = None + list_tgt_factors = None + + if src_factors_path is not None: + list_src_factors = [open(src_factors, "r") for src_factors in src_factors_path] + list_src_factors = [[sf.strip() for sf in src_factors] for src_factors in list_src_factors] + + if tgt_factors_path is not None: + list_tgt_factors = [open(tgt_factors, "r") for tgt_factors in tgt_factors_path] + list_tgt_factors = [[tf.strip().split() for tf in tgt_factors] for tgt_factors in list_tgt_factors] + + for i, stdigits in enumerate(zip(src_reader, tgt_reader)): + src_digits, tgt_digits = stdigits[0].strip(), stdigits[1].strip() + tgt_prefix = tgt_digits.split() + if len(tgt_digits) > 0: + random_pos = random_gen.choice([pos for pos in range(len(tgt_prefix))]) + tgt_prefix = tgt_prefix[:random_pos] + if tgt_factors_path is not None and len(list_tgt_factors[0][i]) > 0: + # Another random_pos, which is different to the one used for target prefix + # With this, target prefix and target factors may have different lengths for testing + random_pos = random_gen.choice([pos for pos in range(len(list_tgt_factors[0][i]))]) + for k in range(len(list_tgt_factors)): + list_tgt_factors[k][i] = list_tgt_factors[k][i][:random_pos] + tgt_prefix = C.TOKEN_SEPARATOR.join(tgt_prefix) + if src_factors_path is None and tgt_factors_path is None: + jsone_line = {"text": src_digits, "target_prefix": tgt_prefix} + elif src_factors_path is not None and tgt_factors_path is None: + jsone_line = {"text": src_digits, "factors": [src_factors[i] for src_factors in list_src_factors], \ + "target_prefix": tgt_prefix} + elif tgt_factors_path is not None and src_factors_path is None: + jsone_line = {"text": src_digits, "target_prefix_factors": [C.TOKEN_SEPARATOR.join(tgt_factors[i]) for tgt_factors in list_tgt_factors], \ + "target_prefix": tgt_prefix} + else: + jsone_line = {"text": src_digits, "factors": [src_factors[i] for src_factors in list_src_factors], \ + "target_prefix_factors": [C.TOKEN_SEPARATOR.join(tgt_factors[i]) for tgt_factors in list_tgt_factors], \ + "target_prefix": tgt_prefix} + print(json.dumps(jsone_line), file=out) + + def generate_low_high_factors(input_path: str, output_path: str): """ Writes low/high factor file given a file of digit sequences. @@ -115,6 +159,7 @@ def tmp_digits_dataset(prefix: str, dev_target_path = os.path.join(work_dir, "dev.tgt") test_source_path = os.path.join(work_dir, "test.src") test_target_path = os.path.join(work_dir, "test.tgt") + test_source_with_target_prefix_path = os.path.join(work_dir, "test_source_with_target_prefix.json") generate_digits_file(train_source_path, train_target_path, train_line_count, train_max_length, line_count_empty=train_line_count_empty, sort_target=sort_target, seed=seed_train) generate_digits_file(dev_source_path, dev_target_path, dev_line_count, dev_max_length, sort_target=sort_target, @@ -127,7 +172,8 @@ def tmp_digits_dataset(prefix: str, 'dev_source': dev_source_path, 'dev_target': dev_target_path, 'test_source': test_source_path, - 'test_target': test_target_path} + 'test_target': test_target_path, + 'test_source_with_target_prefix': test_source_with_target_prefix_path} if with_n_source_factors > 0: data['train_source_factors'] = [] @@ -157,8 +203,12 @@ def tmp_digits_dataset(prefix: str, generate_odd_even_factors(test_target_path, test_factor_path) data['train_target_factors'].append(train_factor_path) data['dev_target_factors'].append(dev_factor_path) - data['test_target_factors'].append(dev_factor_path) + data['test_target_factors'].append(test_factor_path) + source_factors_path = None if 'test_source_factors' not in data else data['test_source_factors'] + target_factors_path = None if 'test_target_factors' not in data else data['test_target_factors'] + generate_json_input_file_with_tgt_prefix(test_source_path, test_target_path, test_source_with_target_prefix_path, \ + source_factors_path, target_factors_path) yield data @@ -183,6 +233,8 @@ def tmp_digits_dataset(prefix: str, TRANSLATE_WITH_FACTORS_COMMON = " --input-factors {input_factors}" +TRANSLATE_WITH_JSON_FORMAT = " --json-input" + TRANSLATE_PARAMS_RESTRICT = "--restrict-lexicon {lexicon} --restrict-lexicon-topk {topk}" SCORE_PARAMS_COMMON = "--use-cpu --model {model} --source {source} --target {target} --output {output} " @@ -290,6 +342,23 @@ def run_train_translate(train_params: str, # Translate corpus with the 1st params and scoring output handler to obtain scores data['test_output'] = os.path.join(work_dir, "test.out") + data['test_with_target_prefix_output'] = os.path.join(work_dir, "test_with_target_prefix.out") + + # First set of params (with target prefix in JSON format) + params = "{} {} {}".format(sockeye.translate.__file__, + TRANSLATE_PARAMS_COMMON.format(model=data['model'], + input=data['test_source_with_target_prefix'], + output=data['test_with_target_prefix_output']), + translate_params) + params += TRANSLATE_WITH_JSON_FORMAT + logger.info("Translating with params %s", params) + with patch.object(sys, "argv", params.split()): + sockeye.translate.main() + + # Collect test translate outputs and scores + data['test_with_target_prefix_outputs'] = collect_translate_output_and_scores(data['test_with_target_prefix_output']) + + # Second set of params (without target prefix) params = "{} {} {}".format(sockeye.translate.__file__, TRANSLATE_PARAMS_COMMON.format(model=data['model'], input=data['test_source'], @@ -313,7 +382,7 @@ def run_train_translate(train_params: str, # Collect test translate outputs and scores data['test_outputs'] = collect_translate_output_and_scores(data['test_output']) - assert len(data['test_inputs']) == len(data['test_targets']) == len(data['test_outputs']) + assert len(data['test_inputs']) == len(data['test_targets']) == len(data['test_outputs']) == len(data['test_with_target_prefix_outputs']) return data @@ -324,7 +393,24 @@ def run_translate_restrict(data: Dict[str, Any], translate_params: str) -> Dict[ """ translate_mod = sockeye.translate out_path = os.path.join(data['work_dir'], "out-restrict.txt") + out_with_target_prefix_path = os.path.join(data['work_dir'], "out-with-target-prefix-restrict.txt") # Translate corpus with restrict-lexicon + + # First set of params (with target prefix in JSON format) + params = "{} {} {} {}".format(translate_mod.__file__, + TRANSLATE_PARAMS_COMMON.format(model=data['model'], + input=data['test_source_with_target_prefix'], + output=out_with_target_prefix_path), + translate_params, + TRANSLATE_PARAMS_RESTRICT.format(lexicon=data['lexicon'], topk=1)) + params += TRANSLATE_WITH_JSON_FORMAT + with patch.object(sys, "argv", params.split()): + translate_mod.main() + + # Collect test translate outputs and scores + data['test_with_target_prefix_outputs_restricted'] = collect_translate_output_and_scores(out_with_target_prefix_path) + + # Second set of params (without using target prefix) params = "{} {} {} {}".format(translate_mod.__file__, TRANSLATE_PARAMS_COMMON.format(model=data['model'], input=data['test_source'], @@ -338,7 +424,7 @@ def run_translate_restrict(data: Dict[str, Any], translate_params: str) -> Dict[ # Collect test translate outputs and scores data['test_outputs_restricted'] = collect_translate_output_and_scores(out_path) - assert len(data['test_outputs_restricted']) == len(data['test_outputs']) + assert len(data['test_with_target_prefix_outputs_restricted']) == len(data['test_outputs_restricted']) == len(data['test_outputs']) return data diff --git a/sockeye/utils.py b/sockeye/utils.py index 332df4d2c..953ce5fa7 100644 --- a/sockeye/utils.py +++ b/sockeye/utils.py @@ -264,6 +264,152 @@ def average_tensors(tensors: List[pt.Tensor]) -> pt.Tensor: return sum(tensors) / len(tensors) # type: ignore +def gen_prefix_masking(prefix: pt.Tensor, vocab_size: int, dtype: pt.dtype) -> pt.Tensor: + """ + Generate prefix masks from prefix ids, which are inf everywhere except zero for prefix ids. + + :param prefix: Target prefix token or factors in ids. Shape (batch size, max length of prefix). + :param vocab_size: vocabulary size + :param dtype: dtype of the retuning output + :return prefix_masks (batch size, max length of prefix, vocab_size), with type as dtype + + """ + prefix_masks_sizes = [s for s in prefix.size()] + prefix_masks_sizes.append(vocab_size) + + # prefix_masks are inf everywhere except zero for indices of prefix ids. + prefix_masks = pt.full(prefix_masks_sizes, fill_value=np.inf, device=prefix.device, dtype=dtype) + prefix_masks.scatter_(-1, prefix.to(pt.int64).unsqueeze(-1), 0.) + # Note: The use of prefix_masks.scatter_() function is equivalent (but much faster) to + # prefix_masks[prefix_one_hot != 0] = 0., where + # prefix_one_hot = pt.nn.functional.one_hot(prefix.to(pt.int64), num_classes=vocab_size).to(prefix.device) + + # In the same batch during inference, it is possible that some translations have target prefix + # while others do not have. It is also possible that translation may have a target prefix with + # different length to others. Thus prefix ids may include a full zero vector if a translation + # in the batch does not have prefix, or include a vector padding with zeros on the right if some + # translations are with shorter prefix. An example of prefix ids reflecting length differences \ + # is as follows: + # + # [1, 2, 3] + # [1, 2, 0] + # [0, 0, 0] + # + # Here, the first sentence has a prefix of length 3, the second one has a prefix of length 1 \ + # and the last one does not have prefix. + # + # At any timestep, some target prefix ids could be 0 (i.e. 0 in the target_prefix means 'no constraint'). \ + # If a prefix id is 0 for a translation at a timestep, all hots in the vocab are assigned to 0 (instead \ + # of only one hot is assigned to 0 and other hots are inf). This makes sure there is no constraint on \ + # selecting any specific target token for the translation in that case. + + prefix_masks.masked_fill_(prefix.unsqueeze(-1) == 0, 0) + return prefix_masks + + +def shift_prefix_factors(prefix_factors: pt.Tensor) -> pt.Tensor: + """ + Shift prefix factors one step to the right + + :param prefix_factors: tensor ids. Shape (batch size, length, num of factors). + :return new prefix_factors_shift (batch size, length + 1, num of factors) + """ + prefix_factors_sizes = prefix_factors.size() + prefix_factors_shift = pt.zeros(prefix_factors_sizes[0], prefix_factors_sizes[1] + 1, prefix_factors_sizes[2], dtype=prefix_factors.dtype, device=prefix_factors.device) + prefix_factors_shift[:, 1:] = prefix_factors + return prefix_factors_shift + + +def adjust_first_step_masking(target_prefix: pt.Tensor, first_step_mask: pt.Tensor) -> pt.Tensor: + """ + Adjust first_step_masking based on the target prefix + (Target prefix for each input in the same batch may have a different length. \ + Thus first_step_mask needs to be adjusted accordingly.) + + :param target_prefix: Shape (batch size, max target prefix length). + :param first_step_mask: Shape (batch_size * beam_size, 1) + :return (adjusted) first_steps_masking (batch_size * beam_size, max target prefix length + 1). + + An illustrative example of how first_step_masking is adjusted + + Inputs: + + target_prefix (batch_size = 2, max target prefix length = 2) + + tensor([1 2] + [1 0]) + Note: Two target prefix tokens in the first sentence, \ + one target prefix token in the second sentence. + + first_step_mask (batch_size = 2 * beam_size = 5, 1) + + tensor([[0], + [inf], + [inf], + [inf], + [inf], + [0], + [inf], + [inf], + [inf], + [inf]) + + Output: + Adjusted first_step_mask (batch_size * beam_size, max target prefix length + 1): + + tensor([[0 0 0], + [inf inf inf], + [inf inf inf], + [inf inf inf], + [inf inf, inf], + [0 0 0], + [inf inf 0], + [inf inf 0], + [inf inf 0], + [inf inf 0]]) + + The concrete steps of what this function does are as follows: + + Step 1: Create a zero masking matrix with shape (batch size, max target prefix length + 1) + Fill 1 into this masking matrix based on the target prefix + + target prefix initialize masking masking roll one step to the right + from target prefix is not 0 and assign 1 at index 0 + [1 2] -> [1 2 0] -> [1 1 0] -> [1 1 1] + [1 0] [1 0 0] [1 0 0] [1 1 0] + + Step 2: Adjust first_step_mask based on masking + + masking expand masking with expand first_step_mask with max target + beam size prefix length, fill 0 where masking is 0 + [1 1 1] -> [1 1 1] -> [0 0 0] + [1 1 0] [1 1 1] [inf inf inf] + [1 1 1] [inf inf inf] + [1 1 1] [inf inf inf] + [1 1 1] [inf inf inf] + [1 1 0] [0 0 0] + [1 1 0] [inf inf 0] + [1 1 0] [inf inf 0] + [1 1 0] [inf inf 0] + [1 1 0] [inf inf 0] + """ + batch_beam, _ = first_step_mask.size() + batch, max_prefix_len = target_prefix.size() + beam_size = batch_beam // batch + # Step 1 + masking = pt.zeros((batch, max_prefix_len + 1), device=target_prefix.device) + masking[:, :max_prefix_len] = target_prefix + masking = pt.clamp(masking, 0., 1.) # force all non zero ids to 1 + masking = pt.roll(masking, 1, -1) + masking[:, 0] = 1. + + # Step 2 + masking = masking.unsqueeze(1).expand(-1, beam_size, -1).reshape(batch_beam, -1) + first_step_mask = first_step_mask.expand(-1, masking.size(-1)).clone() + first_step_mask.masked_fill_(masking == 0., 0.) + return first_step_mask + + def parse_metrics_line(line_number: int, line: str) -> Dict[str, Any]: """ Parse a line of metrics into a mappings of key and values. diff --git a/test/common.py b/test/common.py index ed8e9af4b..fada0d226 100644 --- a/test/common.py +++ b/test/common.py @@ -25,7 +25,7 @@ from sockeye.test_utils import run_train_translate, run_translate_restrict, \ TRANSLATE_PARAMS_COMMON, TRANSLATE_WITH_FACTORS_COMMON, \ collect_translate_output_and_scores, SCORE_PARAMS_COMMON, \ - SCORE_WITH_SOURCE_FACTORS_COMMON, SCORE_WITH_TARGET_FACTORS_COMMON + SCORE_WITH_SOURCE_FACTORS_COMMON, SCORE_WITH_TARGET_FACTORS_COMMON, TRANSLATE_WITH_JSON_FORMAT logger = logging.getLogger(__name__) @@ -66,7 +66,7 @@ def check_train_translate(train_params: str, # - translate splits up too-long sentences and translates them in sequence, invalidating the score, so skip that # - scoring requires valid translation output to compare against if '--max-input-length' not in translate_params and _translate_output_is_valid(data['test_outputs']) \ - and 'greedy' not in translate_params: + and _translate_output_is_valid(data['test_with_target_prefix_outputs']) and 'greedy' not in translate_params: test_scoring(data, translate_params, compare_output) # Test correct prediction of target factors if enabled @@ -82,6 +82,22 @@ def test_translate_equivalence(data: Dict[str, Any], translate_params_equiv: str the previously generated outputs, referenced in the data dictionary. """ out_path = os.path.join(data['work_dir'], "test.out.equiv") + out_with_target_prefix_path = os.path.join(data['work_dir'], "test_with_target_prefix.out.equiv") + + # First set of params (with target prefix in JSON format) + params = "{} {} {}".format(sockeye.translate.__file__, + TRANSLATE_PARAMS_COMMON.format(model=data['model'], + input=data['test_source_with_target_prefix'], + output=out_with_target_prefix_path), + translate_params_equiv) + params += TRANSLATE_WITH_JSON_FORMAT + with patch.object(sys, "argv", params.split()): + sockeye.translate.main() + + # Collect translate outputs and scores + translate_outputs_with_target_prefix_equiv = collect_translate_output_and_scores(out_with_target_prefix_path) + + # Second set of params (without using target prefix) params = "{} {} {}".format(sockeye.translate.__file__, TRANSLATE_PARAMS_COMMON.format(model=data['model'], input=data['test_source'], @@ -95,14 +111,37 @@ def test_translate_equivalence(data: Dict[str, Any], translate_params_equiv: str translate_outputs_equiv = collect_translate_output_and_scores(out_path) assert 'test_outputs' in data - assert len(data['test_outputs']) == len(translate_outputs_equiv) + assert 'test_with_target_prefix_outputs' in data + assert len(data['test_outputs']) == len(data['test_with_target_prefix_outputs']) == len(translate_outputs_with_target_prefix_equiv) == len(translate_outputs_equiv) if compare_output: - for json_output, json_output_equiv in zip(data['test_outputs'], translate_outputs_equiv): + for json_output, json_output_with_target_prefix, json_output_equiv, json_output_with_target_prefix_equiv in zip(data['test_outputs'], data['test_with_target_prefix_outputs'], translate_outputs_equiv, translate_outputs_with_target_prefix_equiv): assert json_output['translation'] == json_output_equiv['translation'], \ f"'{json_output['translation']}' vs. '{json_output_equiv['translation']}'" + assert json_output_with_target_prefix['translation'] == json_output_with_target_prefix_equiv['translation'], \ + f"'{json_output_with_target_prefix['translation']}' vs. '{json_output_with_target_prefix_equiv['translation']}'" assert abs(json_output['score'] - json_output_equiv['score']) < 0.01 or \ np.isnan(json_output['score'] - json_output_equiv['score']), \ f"'{json_output['score']}' vs. '{ json_output_equiv['score']}'" + assert abs(json_output_with_target_prefix['score'] - json_output_with_target_prefix_equiv['score']) < 0.01 or \ + np.isnan(json_output_with_target_prefix['score'] - json_output_with_target_prefix_equiv['score']), \ + f"'{json_output_with_target_prefix['score']}' vs. '{ json_output_with_target_prefix_equiv['score']}'" + + # Check translation output always includes target prefix tokens + prefix = json_output_with_target_prefix['target_prefix'].split() + translation = json_output_with_target_prefix['translation'].split() + ending = min(len(prefix), len(translation)) + assert prefix[:ending] == translation[:ending], \ + f"'{prefix[:ending]}' vs. '{translation[:ending]}'" + + # Check translation output factors always include target prefix factors + if 'target_prefix_factors' in json_output_with_target_prefix: + prefix = json_output_with_target_prefix['target_prefix_factors'] + if len(prefix) > 0: + for j in range(1, len(prefix) + 1): + factors_from_translation = json_output_with_target_prefix[f'factor{j}'] + ending = min(len(prefix[j - 1]), len(factors_from_translation)) + assert prefix[j - 1][:ending] == factors_from_translation[:ending], \ + f"'{prefix[j - 1][:ending]}' vs. '{factors_from_translation[:ending]}' from . '{json_output_with_target_prefix}'" def test_scoring(data: Dict[str, Any], translate_params: str, test_similar_scores: bool): @@ -121,6 +160,7 @@ def test_scoring(data: Dict[str, Any], translate_params: str, test_similar_score if param in relevant_params: score_params = '{} {}'.format(param, params[i + 1]) out_path = os.path.join(data['work_dir'], "score.out") + out_with_target_prefix_path = os.path.join(data['work_dir'], "score_with_target_prefix.out") # write translate outputs as target file for scoring and collect tokens # also optionally collect factor outputs @@ -132,9 +172,42 @@ def test_scoring(data: Dict[str, Any], translate_params: str, test_similar_score for json_output in data['test_outputs']: print(json_output['translation'], file=target_out) for i, factor_out in enumerate(target_factor_outs, 1): - factor = json_output['factor%d' % i] + factor = json_output[f'factor{i}'] print(factor, file=factor_out) + target_with_target_prefix_path = os.path.join(data['work_dir'], "score_with_target_prefix.target") + target_with_target_prefix_factor_paths = [os.path.join(data['work_dir'], f"score_with_target_prefix.target.factor{i}") for i, _ in + enumerate(data.get('test_target_factors', []), 1)] + with open(target_with_target_prefix_path, 'w') as target_out, ExitStack() as exit_stack: + target_factor_outs = [exit_stack.enter_context(open(p, 'w')) for p in target_with_target_prefix_factor_paths] + for json_output in data['test_with_target_prefix_outputs']: + print(json_output['translation'], file=target_out) + for i, factor_out in enumerate(target_factor_outs, 1): + factor = json_output[f'factor{i}'] + print(factor, file=factor_out) + + + # First set of params (with target prefix in JSON format) + params = "{} {} {}".format(sockeye.score.__file__, + SCORE_PARAMS_COMMON.format(model=data['model'], + source=data['test_source'], + target=target_with_target_prefix_path, + output=out_with_target_prefix_path), + score_params) + if 'test_source_factors' in data: + params += SCORE_WITH_SOURCE_FACTORS_COMMON.format(source_factors=" ".join(data['test_source_factors'])) + if target_with_target_prefix_factor_paths: + params += SCORE_WITH_TARGET_FACTORS_COMMON.format(target_factors=" ".join(target_with_target_prefix_factor_paths)) + + logger.info("Scoring with params %s", params) + with patch.object(sys, "argv", params.split()): + sockeye.score.main() + + # Collect scores from output file + with open(out_with_target_prefix_path) as score_out: + data_scoring_with_target_prefix = [[float(x) for x in line.strip().split('\t')] for line in score_out] + + # Second set of params (without target prefix) params = "{} {} {}".format(sockeye.score.__file__, SCORE_PARAMS_COMMON.format(model=data['model'], source=data['test_source'], @@ -155,9 +228,8 @@ def test_scoring(data: Dict[str, Any], translate_params: str, test_similar_score data_scoring = [[float(x) for x in line.strip().split('\t')] for line in score_out] if test_similar_scores: - for inp, translate_json, score_scores in zip(data['test_inputs'], - data['test_outputs'], - data_scoring): + for inp, translate_json, translate_with_target_prefix_json, score_scores, score_with_target_prefix_scores in zip\ + (data['test_inputs'], data['test_outputs'], data['test_with_target_prefix_outputs'], data_scoring, data_scoring_with_target_prefix): score_score, *factor_scores = score_scores translate_tokens = translate_json['translation'].split() translate_score = translate_json['score'] @@ -169,6 +241,17 @@ def test_scoring(data: Dict[str, Any], translate_params: str, test_similar_score "input: %s || tokens: %s || translate score: %.6f || score score: %.6f" % (inp, translate_tokens, translate_score, score_score) + score_score, *factor_scores = score_with_target_prefix_scores + translate_tokens = translate_with_target_prefix_json['translation'].split() + translate_score = translate_with_target_prefix_json['score'] + logger.info("tokens: %s || translate score: %.4f || score score: %.4f", + translate_tokens, translate_score, score_score) + assert (translate_score == -np.inf and score_score == -np.inf) or np.isclose(translate_score, + score_score, + atol=1e-06),\ + "input: %s || tokens: %s || translate score: %.6f || score score: %.6f" % (inp, translate_tokens, + translate_score, + score_score) def _translate_output_is_valid(translate_outputs: List[str]) -> bool: diff --git a/test/unit/test_beam_search.py b/test/unit/test_beam_search.py index 761576a48..14c39f85c 100644 --- a/test/unit/test_beam_search.py +++ b/test/unit/test_beam_search.py @@ -258,7 +258,7 @@ def encode_and_initialize(self, def decode_step(self, step_input: pt.Tensor, states: List, - vocab_slice_ids: Optional[pt.Tensor] = None): + vocab_slice_ids: Optional[pt.Tensor] = None, *args): batch_beam_size, num_target_factors = step_input.size() print('step_input', step_input) @@ -286,6 +286,14 @@ def decode_step(self, self.states = states = [internal_lengths, pt.tensor([num_decode_step_calls], dtype=pt.int)] return scores, states, None + @property + def model_output_vocab_size(self): + return self.output_vocab_size + + @property + def model_output_factor_vocab_size(self): + return None + # TODO make this a useful test # TODO: add vocabulary selection test diff --git a/test/unit/test_inference.py b/test/unit/test_inference.py index 727ee90e4..3d633dcf1 100644 --- a/test/unit/test_inference.py +++ b/test/unit/test_inference.py @@ -168,7 +168,7 @@ def test_translator_input_with_source_prefix(sentence_id, sentence, factors, chu assert chunk_input.sentence_id == sentence_id assert chunk_input.tokens == trans_input.tokens[chunk_id * chunk_size: (chunk_id + 1) * chunk_size] assert chunk_input.source_prefix_tokens == trans_input.source_prefix_tokens - assert chunk_input.num_source_prefix_tokens() == trans_input.num_source_prefix_tokens() + assert chunk_input.num_source_prefix_tokens == trans_input.num_source_prefix_tokens if source_prefix_factors is not None: assert len(chunk_input.source_prefix_factors) == len(source_prefix_factors) for chunk_input_source_prefix_factor, source_prefix_factor in zip(chunk_input.source_prefix_factors, trans_input.source_prefix_factors): diff --git a/test/unit/test_utils.py b/test/unit/test_utils.py index 433a17c79..dc2e0988b 100644 --- a/test/unit/test_utils.py +++ b/test/unit/test_utils.py @@ -250,3 +250,122 @@ def test_write_read_metric_file(): assert len(read_metrics) == len(expected_metrics) assert expected_metrics == read_metrics + + +def test_adjust_first_step_masking(): + first_step_mask = pt.tensor([[0.], + [np.inf], + [np.inf], + [np.inf], + [0.], + [np.inf], + [np.inf], + [np.inf]]) + target_prefix = pt.tensor([[1, 2], [1, 0]]) + adjust_first_step_mask = pt.tensor([[0., 0., 0.], + [np.inf, np.inf, np.inf], + [np.inf, np.inf, np.inf], + [np.inf, np.inf, np.inf], + [0., 0., 0.], + [np.inf, np.inf, 0.], + [np.inf, np.inf, 0.], + [np.inf, np.inf, 0.]]) + assert pt.equal(adjust_first_step_mask, utils.adjust_first_step_masking(target_prefix, first_step_mask)) == True + target_prefix = pt.tensor([[1, 0], [2, 3]]) + adjust_first_step_mask = pt.tensor([[0., 0., 0.], + [np.inf, np.inf, 0.], + [np.inf, np.inf, 0.], + [np.inf, np.inf, 0.], + [0., 0., 0.], + [np.inf, np.inf, np.inf], + [np.inf, np.inf, np.inf], + [np.inf, np.inf, np.inf]]) + assert pt.equal(adjust_first_step_mask, utils.adjust_first_step_masking(target_prefix, first_step_mask)) == True + target_prefix = pt.tensor([[1, 0, 0], [2, 3, 4]]) + adjust_first_step_mask = pt.tensor([[0., 0., 0., 0.], + [np.inf, np.inf, 0., 0.], + [np.inf, np.inf, 0., 0.], + [np.inf, np.inf, 0., 0.], + [0., 0., 0., 0.], + [np.inf, np.inf, np.inf, np.inf], + [np.inf, np.inf, np.inf, np.inf], + [np.inf, np.inf, np.inf, np.inf]]) + assert pt.equal(adjust_first_step_mask, utils.adjust_first_step_masking(target_prefix, first_step_mask)) == True + target_prefix = pt.tensor([[1, 0, 0, 0], [2, 3, 4, 5]]) + adjust_first_step_mask = pt.tensor([[0., 0., 0., 0., 0.], + [np.inf, np.inf, 0., 0., 0.], + [np.inf, np.inf, 0., 0., 0.], + [np.inf, np.inf, 0., 0., 0.], + [0., 0., 0., 0., 0.], + [np.inf, np.inf, np.inf, np.inf, np.inf], + [np.inf, np.inf, np.inf, np.inf, np.inf], + [np.inf, np.inf, np.inf, np.inf, np.inf]]) + assert pt.equal(adjust_first_step_mask, utils.adjust_first_step_masking(target_prefix, first_step_mask)) == True + first_step_mask = pt.tensor([[0.], + [np.inf], + [np.inf], + [np.inf], + [0.], + [np.inf], + [np.inf], + [np.inf], + [0.], + [np.inf], + [np.inf], + [np.inf]]) + target_prefix = pt.tensor([[1, 0, 0, 0], [1, 0, 0, 0], [2, 3, 4, 5]]) + adjust_first_step_mask = pt.tensor([[0., 0., 0., 0., 0.], + [np.inf, np.inf, 0., 0., 0.], + [np.inf, np.inf, 0., 0., 0.], + [np.inf, np.inf, 0., 0., 0.], + [0., 0., 0., 0., 0.], + [np.inf, np.inf, 0., 0., 0.], + [np.inf, np.inf, 0., 0., 0.], + [np.inf, np.inf, 0., 0., 0.], + [0., 0., 0., 0., 0.], + [np.inf, np.inf, np.inf, np.inf, np.inf], + [np.inf, np.inf, np.inf, np.inf, np.inf], + [np.inf, np.inf, np.inf, np.inf, np.inf]]) + assert pt.equal(adjust_first_step_mask, utils.adjust_first_step_masking(target_prefix, first_step_mask)) == True + target_prefix = pt.tensor([[1, 0, 0, 0], [1, 3, 0, 0], [2, 3, 4, 5]]) + adjust_first_step_mask = pt.tensor([[0., 0., 0., 0., 0.], + [np.inf, np.inf, 0., 0., 0.], + [np.inf, np.inf, 0., 0., 0.], + [np.inf, np.inf, 0., 0., 0.], + [0., 0., 0., 0., 0.], + [np.inf, np.inf, np.inf, 0., 0.], + [np.inf, np.inf, np.inf, 0., 0.], + [np.inf, np.inf, np.inf, 0., 0.], + [0., 0., 0., 0., 0.], + [np.inf, np.inf, np.inf, np.inf, np.inf], + [np.inf, np.inf, np.inf, np.inf, np.inf], + [np.inf, np.inf, np.inf, np.inf, np.inf]]) + assert pt.equal(adjust_first_step_mask, utils.adjust_first_step_masking(target_prefix, first_step_mask)) == True + target_prefix = pt.tensor([[0, 0, 0, 0], [1, 3, 0, 0], [2, 3, 4, 5]]) + adjust_first_step_mask = pt.tensor([[0., 0., 0., 0., 0.], + [np.inf, 0., 0., 0., 0.], + [np.inf, 0., 0., 0., 0.], + [np.inf, 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [np.inf, np.inf, np.inf, 0., 0.], + [np.inf, np.inf, np.inf, 0., 0.], + [np.inf, np.inf, np.inf, 0., 0.], + [0., 0., 0., 0., 0.], + [np.inf, np.inf, np.inf, np.inf, np.inf], + [np.inf, np.inf, np.inf, np.inf, np.inf], + [np.inf, np.inf, np.inf, np.inf, np.inf]]) + assert pt.equal(adjust_first_step_mask, utils.adjust_first_step_masking(target_prefix, first_step_mask)) == True + target_prefix = pt.tensor([[0, 0, 0, 0], [0, 0, 0, 0], [2, 3, 4, 5]]) + adjust_first_step_mask = pt.tensor([[0., 0., 0., 0., 0.], + [np.inf, 0., 0., 0., 0.], + [np.inf, 0., 0., 0., 0.], + [np.inf, 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [np.inf, 0., 0., 0., 0.], + [np.inf, 0., 0., 0., 0.], + [np.inf, 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [np.inf, np.inf, np.inf, np.inf, np.inf], + [np.inf, np.inf, np.inf, np.inf, np.inf], + [np.inf, np.inf, np.inf, np.inf, np.inf]]) + assert pt.equal(adjust_first_step_mask, utils.adjust_first_step_masking(target_prefix, first_step_mask)) == True