Fixes to blocking cross-attention (#1087)
xingniu authored Mar 2, 2023
1 parent 01e2392 commit 4c30942
Showing 4 changed files with 22 additions and 10 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,12 @@ Note that Sockeye has checks in place to not translate with an old model that wa
 
 Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.
 
+## [3.1.34]
+
+### Fixed
+- Do not mask prepended tokens by default (for self-attention).
+- Do not require specifying `--end-of-prepending-tag` if it is already done when preparing the data.
+
 ## [3.1.33]
 
 ### Fixed
2 changes: 1 addition & 1 deletion sockeye/__init__.py
@@ -11,4 +11,4 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-__version__ = '3.1.33'
+__version__ = '3.1.34'
2 changes: 1 addition & 1 deletion sockeye/layers.py
@@ -321,7 +321,7 @@ def forward(self,
 
 
 def prepare_source_length_mask(lengths: pt.Tensor, heads: int, max_length: int, expand: bool = True,
-                               mask_prepended_tokens: bool = True) -> pt.Tensor:
+                               mask_prepended_tokens: bool = False) -> pt.Tensor:
     """
     Prepare source length masks where positions of invalid tokens are marked as True.
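
The one-line change above only flips the default of `mask_prepended_tokens`, so prepended tokens stay visible to encoder self-attention unless a caller explicitly asks to hide them (as the cross-attention blocking path does). Below is a minimal, self-contained sketch of such a length mask, for illustration only: the function name and the assumed `lengths` layout (column 0 = number of valid tokens, column 1 = number of prepended tokens) are hypothetical and not Sockeye's actual implementation.

```python
import torch as pt


def length_mask_sketch(lengths: pt.Tensor, max_length: int,
                       mask_prepended_tokens: bool = False) -> pt.Tensor:
    """Return a (batch, max_length) bool mask where True marks positions to ignore.

    Hypothetical layout: lengths[:, 0] = number of valid source tokens,
    lengths[:, 1] = number of prepended tokens (an assumption, not Sockeye's API).
    """
    positions = pt.arange(max_length).unsqueeze(0)   # (1, max_length)
    mask = positions >= lengths[:, 0:1]              # padding positions are always masked
    if mask_prepended_tokens:
        # Also hide the prepended prefix, e.g. when blocking decoder cross-attention.
        mask = mask | (positions < lengths[:, 1:2])
    return mask


lengths = pt.tensor([[5, 2]])  # 5 valid tokens, the first 2 of them prepended
print(length_mask_sketch(lengths, max_length=7))                               # only padding masked
print(length_mask_sketch(lengths, max_length=7, mask_prepended_tokens=True))   # prefix also masked
```

With the old default (`True`), the second variant would have been produced even for self-attention, which is what this commit fixes.
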
22 changes: 14 additions & 8 deletions sockeye/train.py
@@ -136,12 +136,6 @@ def check_arg_compatibility(args: argparse.Namespace):
         # Length 1: expand the list to the appropriate length
         args.target_factors_share_embedding = args.target_factors_share_embedding * n_target_factors
 
-    # Check arguments used for blocking cross-attention between decoder and encoded prepended tokens
-    if args.transformer_block_prepended_cross_attention:
-        check_condition(args.end_of_prepending_tag is not None,
-                        'In order to block cross-attention between decoder and encoded prepended tokens, '
-                        'please specify the tag indicating the end of prepended text using --end-of-prepending-tag')
-
     check_condition(not (args.amp and args.apex_amp), 'Use either --amp (safer) or --apex-amp (faster).')
 
     if args.dtype != C.DTYPE_FP32:
@@ -305,8 +299,6 @@ def create_data_iters_and_vocabs(args: argparse.Namespace,
                                                      C.TRAINING_ARG_PREPARED_DATA)
     if args.prepared_data is not None:
         utils.check_condition(args.source is None and args.target is None, either_raw_or_prepared_error_msg)
-        if args.end_of_prepending_tag is not None:
-            logger.warning("The end-of-prepending tag specified in the prepared data will be used.")
         if not resume_training:
             utils.check_condition(args.source_vocab is None and args.target_vocab is None,
                                   "You are using a prepared data folder, which is tied to a vocabulary. "
@@ -320,6 +312,15 @@ def create_data_iters_and_vocabs(args: argparse.Namespace,
             batch_type=args.batch_type,
             batch_sentences_multiple_of=args.batch_sentences_multiple_of)
 
+        # Check arguments used for blocking cross-attention between decoder and encoded prepended tokens
+        if args.transformer_block_prepended_cross_attention:
+            check_condition(data_config.eop_id != C.INVALID_ID,
+                            'In order to block cross-attention between decoder and encoded prepended tokens, '
+                            'please specify the tag indicating the end of prepended text when preparing the data using '
+                            '--end-of-prepending-tag')
+            if args.end_of_prepending_tag is not None:
+                logger.warning("The end-of-prepending tag specified in the prepared data will be used.")
+
         check_condition(all([combine in [C.FACTORS_COMBINE_SUM, C.FACTORS_COMBINE_AVERAGE]
                              for combine in args.source_factors_combine])
                         or len(source_vocabs) == len(args.source_factors_num_embed) + 1,
@@ -356,6 +357,11 @@ def create_data_iters_and_vocabs(args: argparse.Namespace,
     else:
         utils.check_condition(args.prepared_data is None and args.source is not None and args.target is not None,
                               either_raw_or_prepared_error_msg)
+        # Check arguments used for blocking cross-attention between decoder and encoded prepended tokens
+        if args.transformer_block_prepended_cross_attention:
+            check_condition(args.end_of_prepending_tag is not None,
+                            'In order to block cross-attention between decoder and encoded prepended tokens, '
+                            'please specify the tag indicating the end of prepended text using --end-of-prepending-tag')
 
         if resume_training:
             # Load the existing vocabs created when starting the training run.
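
Taken together, the train.py changes move the `--end-of-prepending-tag` requirement out of `check_arg_compatibility` and into `create_data_iters_and_vocabs`, where prepared data (whose data config already carries an `eop_id`) can be told apart from raw data (which still needs the tag on the command line). The sketch below restates that decision flow for illustration only; the helper name and stand-in constant are hypothetical, not the actual sockeye.train code.

```python
from typing import Optional

INVALID_ID = -1  # stand-in for C.INVALID_ID


def validate_eop_settings(block_prepended_cross_attention: bool,
                          using_prepared_data: bool,
                          prepared_data_eop_id: int,
                          end_of_prepending_tag: Optional[str]) -> None:
    """Hypothetical restatement of the relocated checks (illustration only)."""
    if not block_prepended_cross_attention:
        return  # nothing to validate
    if using_prepared_data:
        # The tag must have been given when the data was prepared ...
        if prepared_data_eop_id == INVALID_ID:
            raise ValueError("Prepare the data with --end-of-prepending-tag to block "
                             "cross-attention to prepended tokens.")
        # ... and a tag passed again at training time is ignored with a warning.
        if end_of_prepending_tag is not None:
            print("Warning: the end-of-prepending tag from the prepared data will be used.")
    else:
        # Raw training data: the tag is still required on the command line.
        if end_of_prepending_tag is None:
            raise ValueError("Specify --end-of-prepending-tag to block "
                             "cross-attention to prepended tokens.")
```
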
