Support target prefix with JSON format (#1025)
This feature allows Sockeye to add a target prefix and target prefix factors during inference with JSON input. A target prefix can be specified in JSON format as follows:

{ "text": "The boy ate the waff@@ le .", "target_prefix": "2XX"}
If a model was trained with target factors, we can add target prefix factors during inference in JSON format as follows:

{ "text": "The boy ate the waff@@ le .", "target_prefix_factors": ["O"]}
We can also add both a target prefix and target prefix factors at the same time, e.g.:

{ "text": "The boy ate the waff@@ le .", "target_prefix": "2XX", "target_prefix_factors": ["O"]}
Note that if an input is very long, Sockeye chunks the text and translates each chunk separately. By default, the target prefix and target prefix factors are added to all chunks in that case. Alternatively, we can set use_target_prefix_all_chunks to false to add them only to the first chunk, e.g.:

{ "text": "The boy ate the waff@@ le .", "target_prefix": "2XX", "target_prefix_factors": ["O"], "use_target_prefix_all_chunks": false}

* support target prefix

* revise illustration

* fix space

* add type ignore

* add target prefix, revision 2

* add target prefix factors

* slightly revise docs

* revise based on Michael's suggestions

* revise based on Felix's comments

* small revise type for pylint

* revise based on Tobias suggestion

* revised based on Felix's comments

* one_hot_encoding_from_prefix function for a full tensor

* revise warning of empty prefix

* use clamp instead of masked_fill_

* cleaner adding of target_prefix_factors to decode_step

* put pt.index_select vocab_slice_ids before beam loop

* revised with prefix masking

* revise a tiny comment

* pre_expand outside, suggested by Tobi

* pre_expand outside, a small fix

* completely pre_expand outside, gen_prefix_masking generated only once

* fix factors and add prefix to vocab_slice_ids by pytorch

* small mypy fix

* small fix mypy

* minor revision

* avoiding unnecessary copy tensor

* avoid duplicate padding

* extra revision suggested by Tobias

* Fix translatorInput with pylint

* Fix translatorInput with pylint

Co-authored-by: Hoang <[email protected]>
hoangcuong2011 and Hoang authored Mar 10, 2022
1 parent ea08143 commit edac700
Showing 13 changed files with 737 additions and 47 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,11 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [3.1.4]

### Added
- Added support for adding a target prefix and target prefix factors to the input in JSON format during inference.

## [3.1.3]

### Added
35 changes: 35 additions & 0 deletions docs/inference.md
@@ -96,6 +96,41 @@ Similar to source factors, source prefix factors can be also specified with JSON
{ "text": "The boy ate the waff@@ le .", "source_prefix": "2XX", "source_prefix_factors": ["O"]}
```

Finally, Sockeye also supports adding a target prefix and target prefix factors to the translation during inference. In the same spirit as the example above, let us assume a multilingual translation model trained with a target prefix `2XX` (this time the prefix is added to the target sentence instead of the source sentence). During inference, this target prefix can be specified in JSON format as follows:

```json
{ "text": "The boy ate the waff@@ le .", "target_prefix": "2XX"}
```

This forces the decoder to generate `2XX` as its first target token (i.e. the one right after the `<bos>` token).
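
Conceptually, forcing a target prefix token amounts to masking the decoder's scores so that only the prefix token can be selected at that step. A minimal PyTorch sketch of that idea (purely illustrative, not Sockeye's actual beam-search code):

```python
import torch

def force_prefix_token(logits: torch.Tensor, prefix_id: int) -> torch.Tensor:
    """Keep only `prefix_id` selectable at this decoding step (illustrative)."""
    masked = torch.full_like(logits, float("-inf"))  # block every other token
    masked[:, prefix_id] = logits[:, prefix_id]      # keep the prefix token's score
    return masked

# With the mask applied, the decoder can only pick the forced prefix token.
logits = torch.randn(2, 8)                           # (batch, vocab_size)
assert (force_prefix_token(logits, prefix_id=3).argmax(dim=-1) == 3).all()
```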

If your model was trained with target factors, every target translation token aligns with one or more corresponding target factor tokens (depending on the number of target factors of the model). During inference, you can add target prefix factors to the translation in JSON format, e.g.:

```json
{ "text": "The boy ate the waff@@ le .", "target_prefix_factors": ["O"]}
```

Here, the decoder is forced to generate a translation and its corresponding target factors such that the first target token aligns with `O` as its target factor.
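
As an aside, for a hypothetical model with two target factor streams, `target_prefix_factors` would carry one space-separated string per stream; a sketch of building such a request in Python (the factor values here are made up for illustration):

```python
import json

# Hypothetical two-factor model: one string per target factor stream.
request = {
    "text": "The boy ate the waff@@ le .",
    "target_prefix": "2XX",
    "target_prefix_factors": ["O O", "E E"],  # stream 1 and stream 2 prefixes
}
print(json.dumps(request))
```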

Note that you can also add both a target prefix and target prefix factors of different lengths, e.g.:

```json
{ "text": "The boy ate the waff@@ le .", "target_prefix": "2XX", "target_prefix_factors": ["O O E"]}
```
With this example, `2XX` is the force-decoded first target token of the translation, and it aligns with `O` as its corresponding target factor. Moreover, the next two target tokens after `2XX` align with `O E` as their corresponding target factors.

Note that if an input is very long, Sockeye chunks the text and translates each chunk separately. By default, the target prefix and target prefix factors are added to all chunks in that case. Alternatively, you can set `use_target_prefix_all_chunks` to `false` to add them only to the first chunk, e.g.:

```json
{ "text": "The boy ate the waff@@ le .", "target_prefix": "2XX", "target_prefix_factors": ["O"], "use_target_prefix_all_chunks": false}
```
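
A rough Python sketch of that chunking behaviour (illustrative only; the helper name and chunk size are not Sockeye's actual implementation):

```python
from typing import Dict, List

def build_chunk_requests(tokens: List[str], max_len: int, target_prefix: str,
                         use_target_prefix_all_chunks: bool) -> List[Dict[str, str]]:
    """Split a long input into chunks and attach the target prefix to every
    chunk, or only to the first one (illustrative sketch)."""
    chunks = [tokens[i:i + max_len] for i in range(0, len(tokens), max_len)]
    requests = []
    for idx, chunk in enumerate(chunks):
        request = {"text": " ".join(chunk)}
        if use_target_prefix_all_chunks or idx == 0:
            request["target_prefix"] = target_prefix
        requests.append(request)
    return requests
```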

Note also that the translation output includes the target prefix at the beginning of the output string by default. Alternatively, you can remove the target prefix from the translation output by setting `keep_target_prefix` to `false`, e.g.:

```json
{ "text": "The boy ate the waff@@ le .", "target_prefix": "2XX", "keep_target_prefix": false}
```
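
Conceptually, this simply strips the forced prefix tokens from the surface output, roughly as in this illustrative sketch (not Sockeye's code):

```python
def strip_target_prefix(translation: str, target_prefix: str) -> str:
    """Remove the forced prefix tokens from an output string (illustrative)."""
    prefix_tokens = target_prefix.split()
    output_tokens = translation.split()
    if output_tokens[:len(prefix_tokens)] == prefix_tokens:
        output_tokens = output_tokens[len(prefix_tokens):]
    return " ".join(output_tokens)

assert strip_target_prefix("2XX hello world", "2XX") == "hello world"
```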

## N-best translations

Sockeye can return the n best hypotheses per input (*nbest lists*).
2 changes: 1 addition & 1 deletion sockeye/__init__.py
@@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '3.1.3'
__version__ = '3.1.4'
131 changes: 112 additions & 19 deletions sockeye/beam_search.py

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions sockeye/constants.py
@@ -124,6 +124,10 @@
JSON_FACTORS_KEY = "factors"
JSON_SOURCE_PREFIX_KEY = "source_prefix"
JSON_SOURCE_PREFIX_FACTORS_KEY = "source_prefix_factors"
JSON_TARGET_PREFIX_KEY = "target_prefix"
JSON_TARGET_PREFIX_FACTORS_KEY = "target_prefix_factors"
JSON_USE_TARGET_PREFIX_ALL_CHUNKS_KEY = "use_target_prefix_all_chunks"
JSON_KEEP_TARGET_PREFIX_KEY = "keep_target_prefix"
JSON_RESTRICT_LEXICON_KEY = "restrict_lexicon"
JSON_CONSTRAINTS_KEY = "constraints"
JSON_AVOID_KEY = "avoid"
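
An illustrative use of the new JSON key constants when assembling a request (only the `JSON_*` constant names come from this diff; the surrounding snippet is hypothetical):

```python
import json

import sockeye.constants as C

# Hypothetical request built from the constants introduced above.
request = {
    "text": "The boy ate the waff@@ le .",
    C.JSON_TARGET_PREFIX_KEY: "2XX",
    C.JSON_TARGET_PREFIX_FACTORS_KEY: ["O"],
    C.JSON_USE_TARGET_PREFIX_ALL_CHUNKS_KEY: False,
    C.JSON_KEEP_TARGET_PREFIX_KEY: True,
}
print(json.dumps(request))
```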
134 changes: 122 additions & 12 deletions sockeye/inference.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions sockeye/model.py
@@ -127,6 +127,7 @@ def __init__(self,
out_features=factor_config.vocab_size,
bias=True)
self.factor_output_layers.append(output_layer)
self.factor_vocab_size = factor_config.vocab_size if self.target_factor_configs else None

self.length_ratio = None # type: Optional[layers.LengthRatio]
if self.config.config_length_task is not None:
96 changes: 91 additions & 5 deletions sockeye/test_utils.py
@@ -18,7 +18,7 @@
import sys
from contextlib import contextmanager
from tempfile import TemporaryDirectory
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from unittest.mock import patch

import sockeye.constants as C
@@ -58,6 +58,50 @@ def generate_digits_file(source_path: str,
print(C.TOKEN_SEPARATOR.join(digits), file=target_out)


def generate_json_input_file_with_tgt_prefix(src_path: str, tgt_path: str, json_file_with_tgt_prefix_path: str,
                                             src_factors_path: Optional[List[str]] = None,
                                             tgt_factors_path: Optional[List[str]] = None, seed: int = 13):
random_gen = random.Random(seed)
with open(src_path, "r") as src_reader, open(tgt_path, "r") as tgt_reader:
with open(json_file_with_tgt_prefix_path, "w") as out:
list_src_factors = None
list_tgt_factors = None

if src_factors_path is not None:
list_src_factors = [open(src_factors, "r") for src_factors in src_factors_path]
list_src_factors = [[sf.strip() for sf in src_factors] for src_factors in list_src_factors]

if tgt_factors_path is not None:
list_tgt_factors = [open(tgt_factors, "r") for tgt_factors in tgt_factors_path]
list_tgt_factors = [[tf.strip().split() for tf in tgt_factors] for tgt_factors in list_tgt_factors]

for i, stdigits in enumerate(zip(src_reader, tgt_reader)):
src_digits, tgt_digits = stdigits[0].strip(), stdigits[1].strip()
tgt_prefix = tgt_digits.split()
if len(tgt_digits) > 0:
random_pos = random_gen.choice([pos for pos in range(len(tgt_prefix))])
tgt_prefix = tgt_prefix[:random_pos]
if tgt_factors_path is not None and len(list_tgt_factors[0][i]) > 0:
# Another random_pos, which is different to the one used for target prefix
# With this, target prefix and target factors may have different lengths for testing
random_pos = random_gen.choice([pos for pos in range(len(list_tgt_factors[0][i]))])
for k in range(len(list_tgt_factors)):
list_tgt_factors[k][i] = list_tgt_factors[k][i][:random_pos]
tgt_prefix = C.TOKEN_SEPARATOR.join(tgt_prefix)
if src_factors_path is None and tgt_factors_path is None:
jsone_line = {"text": src_digits, "target_prefix": tgt_prefix}
elif src_factors_path is not None and tgt_factors_path is None:
jsone_line = {"text": src_digits, "factors": [src_factors[i] for src_factors in list_src_factors], \
"target_prefix": tgt_prefix}
elif tgt_factors_path is not None and src_factors_path is None:
jsone_line = {"text": src_digits, "target_prefix_factors": [C.TOKEN_SEPARATOR.join(tgt_factors[i]) for tgt_factors in list_tgt_factors], \
"target_prefix": tgt_prefix}
else:
jsone_line = {"text": src_digits, "factors": [src_factors[i] for src_factors in list_src_factors], \
"target_prefix_factors": [C.TOKEN_SEPARATOR.join(tgt_factors[i]) for tgt_factors in list_tgt_factors], \
"target_prefix": tgt_prefix}
print(json.dumps(jsone_line), file=out)


def generate_low_high_factors(input_path: str, output_path: str):
"""
Writes low/high factor file given a file of digit sequences.
@@ -115,6 +159,7 @@ def tmp_digits_dataset(prefix: str,
dev_target_path = os.path.join(work_dir, "dev.tgt")
test_source_path = os.path.join(work_dir, "test.src")
test_target_path = os.path.join(work_dir, "test.tgt")
test_source_with_target_prefix_path = os.path.join(work_dir, "test_source_with_target_prefix.json")
generate_digits_file(train_source_path, train_target_path, train_line_count, train_max_length,
line_count_empty=train_line_count_empty, sort_target=sort_target, seed=seed_train)
generate_digits_file(dev_source_path, dev_target_path, dev_line_count, dev_max_length, sort_target=sort_target,
@@ -127,7 +172,8 @@
'dev_source': dev_source_path,
'dev_target': dev_target_path,
'test_source': test_source_path,
'test_target': test_target_path}
'test_target': test_target_path,
'test_source_with_target_prefix': test_source_with_target_prefix_path}

if with_n_source_factors > 0:
data['train_source_factors'] = []
@@ -157,8 +203,12 @@ def tmp_digits_dataset(prefix: str,
generate_odd_even_factors(test_target_path, test_factor_path)
data['train_target_factors'].append(train_factor_path)
data['dev_target_factors'].append(dev_factor_path)
data['test_target_factors'].append(dev_factor_path)
data['test_target_factors'].append(test_factor_path)

source_factors_path = None if 'test_source_factors' not in data else data['test_source_factors']
target_factors_path = None if 'test_target_factors' not in data else data['test_target_factors']
generate_json_input_file_with_tgt_prefix(test_source_path, test_target_path, test_source_with_target_prefix_path, \
source_factors_path, target_factors_path)
yield data


@@ -183,6 +233,8 @@ def tmp_digits_dataset(prefix: str,

TRANSLATE_WITH_FACTORS_COMMON = " --input-factors {input_factors}"

TRANSLATE_WITH_JSON_FORMAT = " --json-input"

TRANSLATE_PARAMS_RESTRICT = "--restrict-lexicon {lexicon} --restrict-lexicon-topk {topk}"

SCORE_PARAMS_COMMON = "--use-cpu --model {model} --source {source} --target {target} --output {output} "
@@ -290,6 +342,23 @@ def run_train_translate(train_params: str,

# Translate corpus with the 1st params and scoring output handler to obtain scores
data['test_output'] = os.path.join(work_dir, "test.out")
data['test_with_target_prefix_output'] = os.path.join(work_dir, "test_with_target_prefix.out")

# First set of params (with target prefix in JSON format)
params = "{} {} {}".format(sockeye.translate.__file__,
TRANSLATE_PARAMS_COMMON.format(model=data['model'],
input=data['test_source_with_target_prefix'],
output=data['test_with_target_prefix_output']),
translate_params)
params += TRANSLATE_WITH_JSON_FORMAT
logger.info("Translating with params %s", params)
with patch.object(sys, "argv", params.split()):
sockeye.translate.main()

# Collect test translate outputs and scores
data['test_with_target_prefix_outputs'] = collect_translate_output_and_scores(data['test_with_target_prefix_output'])

# Second set of params (without target prefix)
params = "{} {} {}".format(sockeye.translate.__file__,
TRANSLATE_PARAMS_COMMON.format(model=data['model'],
input=data['test_source'],
@@ -313,7 +382,7 @@

# Collect test translate outputs and scores
data['test_outputs'] = collect_translate_output_and_scores(data['test_output'])
assert len(data['test_inputs']) == len(data['test_targets']) == len(data['test_outputs'])
assert len(data['test_inputs']) == len(data['test_targets']) == len(data['test_outputs']) == len(data['test_with_target_prefix_outputs'])
return data


@@ -324,7 +393,24 @@ def run_translate_restrict(data: Dict[str, Any], translate_params: str) -> Dict[
"""
translate_mod = sockeye.translate
out_path = os.path.join(data['work_dir'], "out-restrict.txt")
out_with_target_prefix_path = os.path.join(data['work_dir'], "out-with-target-prefix-restrict.txt")
# Translate corpus with restrict-lexicon

# First set of params (with target prefix in JSON format)
params = "{} {} {} {}".format(translate_mod.__file__,
TRANSLATE_PARAMS_COMMON.format(model=data['model'],
input=data['test_source_with_target_prefix'],
output=out_with_target_prefix_path),
translate_params,
TRANSLATE_PARAMS_RESTRICT.format(lexicon=data['lexicon'], topk=1))
params += TRANSLATE_WITH_JSON_FORMAT
with patch.object(sys, "argv", params.split()):
translate_mod.main()

# Collect test translate outputs and scores
data['test_with_target_prefix_outputs_restricted'] = collect_translate_output_and_scores(out_with_target_prefix_path)

# Second set of params (without using target prefix)
params = "{} {} {} {}".format(translate_mod.__file__,
TRANSLATE_PARAMS_COMMON.format(model=data['model'],
input=data['test_source'],
@@ -338,7 +424,7 @@

# Collect test translate outputs and scores
data['test_outputs_restricted'] = collect_translate_output_and_scores(out_path)
assert len(data['test_outputs_restricted']) == len(data['test_outputs'])
assert len(data['test_with_target_prefix_outputs_restricted']) == len(data['test_outputs_restricted']) == len(data['test_outputs'])
return data

