From df3d89fc7e0e4cef586bcb9b72b4dd0a35c571db Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Wed, 13 Dec 2023 20:57:17 +0000 Subject: [PATCH 1/3] correct language assessment + add tests --- .../modeling_seamless_m4t_v2.py | 7 +++- .../test_modeling_seamless_m4t_v2.py | 34 ++++++++++++++++--- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index f1a26b3e5b6924..d183753346a77c 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -4596,7 +4596,12 @@ def generate( if tgt_lang is not None: # also accept __xxx__ tgt_lang = tgt_lang.replace("__", "") - for key in ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]: + keys_to_check = ( + ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"] + if generate_speech + else ["text_decoder_lang_to_code_id"] + ) + for key in keys_to_check: lang_code_to_id = getattr(self.generation_config, key, None) if lang_code_to_id is None: raise ValueError( diff --git a/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py b/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py index 8627220c71aa51..e522cd44da6689 100644 --- a/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py +++ b/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py @@ -758,7 +758,13 @@ def setUp(self): self.tmpdirname = tempfile.mkdtemp() def update_generation(self, model): - lang_code_to_id = { + text_lang_code_to_id = { + "fra": 4, + "eng": 4, + "rus": 4, + } + + speech_lang_code_to_id = { "fra": 4, "eng": 4, } @@ -773,9 +779,9 @@ def update_generation(self, model): generation_config = copy.deepcopy(model.generation_config) - generation_config.__setattr__("text_decoder_lang_to_code_id", lang_code_to_id) - generation_config.__setattr__("t2u_lang_code_to_id", lang_code_to_id) - generation_config.__setattr__("vocoder_lang_code_to_id", lang_code_to_id) + generation_config.__setattr__("text_decoder_lang_to_code_id", text_lang_code_to_id) + generation_config.__setattr__("t2u_lang_code_to_id", speech_lang_code_to_id) + generation_config.__setattr__("vocoder_lang_code_to_id", speech_lang_code_to_id) generation_config.__setattr__("id_to_text", id_to_text) generation_config.__setattr__("char_to_id", char_to_id) generation_config.__setattr__("eos_token_id", 0) @@ -784,7 +790,7 @@ def update_generation(self, model): model.generation_config = generation_config - def prepare_text_input(self): + def prepare_text_input(self, is_rus=False): config, inputs, decoder_input_ids, input_mask, lm_labels = self.text_model_tester.prepare_config_and_inputs() input_dict = { @@ -795,6 +801,9 @@ def prepare_text_input(self): "do_sample": True, } + if is_rus: + input_dict["tgt_lang"] = "rus" + return config, input_dict def prepare_speech_input(self): @@ -837,6 +846,21 @@ def factory_generation_speech_test(self, model, inputs): output = model.generate(**inputs) return output + def test_generation_languages(self): + config, input_text_rus = self.prepare_text_input(is_rus=True) + + model = SeamlessM4Tv2Model(config=config) + self.update_generation(model) + model.to(torch_device) + model.eval() + + # make sure that generating speech, with a language that is only supported for text translation, raises error + with self.assertRaises(ValueError): + model.generate(**input_text_rus) + + # make sure that generating text only works + model.generate(**input_text_rus, generate_speech=False) + def test_speech_generation(self): config, input_speech, input_text = self.prepare_speech_and_text_input() From 5005a2e0ce4b5654997980927f7334b719f9fb5a Mon Sep 17 00:00:00 2001 From: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> Date: Thu, 14 Dec 2023 14:14:37 +0000 Subject: [PATCH 2/3] Update src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- .../models/seamless_m4t_v2/modeling_seamless_m4t_v2.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index d183753346a77c..e0d5e7392834f0 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -4596,11 +4596,10 @@ def generate( if tgt_lang is not None: # also accept __xxx__ tgt_lang = tgt_lang.replace("__", "") - keys_to_check = ( - ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"] - if generate_speech - else ["text_decoder_lang_to_code_id"] - ) + if generate_speech: + keys_to_check = ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"] + else: + keys_to_check = ["text_decoder_lang_to_code_id"] for key in keys_to_check: lang_code_to_id = getattr(self.generation_config, key, None) if lang_code_to_id is None: From 3d2ea3691f26204059c29aa605775b8d30604233 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Thu, 14 Dec 2023 14:20:48 +0000 Subject: [PATCH 3/3] make style + simplify and enrich test --- .../seamless_m4t_v2/modeling_seamless_m4t_v2.py | 4 ++-- .../test_modeling_seamless_m4t_v2.py | 14 ++++++++------ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index e0d5e7392834f0..bceb1b49460889 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -4596,9 +4596,9 @@ def generate( if tgt_lang is not None: # also accept __xxx__ tgt_lang = tgt_lang.replace("__", "") - if generate_speech: + if generate_speech: keys_to_check = ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"] - else: + else: keys_to_check = ["text_decoder_lang_to_code_id"] for key in keys_to_check: lang_code_to_id = getattr(self.generation_config, key, None) diff --git a/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py b/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py index e522cd44da6689..795f3d80422b2e 100644 --- a/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py +++ b/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py @@ -790,20 +790,17 @@ def update_generation(self, model): model.generation_config = generation_config - def prepare_text_input(self, is_rus=False): + def prepare_text_input(self, tgt_lang): config, inputs, decoder_input_ids, input_mask, lm_labels = self.text_model_tester.prepare_config_and_inputs() input_dict = { "input_ids": inputs, "attention_mask": input_mask, - "tgt_lang": "eng", + "tgt_lang": tgt_lang, "num_beams": 2, "do_sample": True, } - if is_rus: - input_dict["tgt_lang"] = "rus" - return config, input_dict def prepare_speech_input(self): @@ -847,7 +844,7 @@ def factory_generation_speech_test(self, model, inputs): return output def test_generation_languages(self): - config, input_text_rus = self.prepare_text_input(is_rus=True) + config, input_text_rus = self.prepare_text_input(tgt_lang="rus") model = SeamlessM4Tv2Model(config=config) self.update_generation(model) @@ -861,6 +858,11 @@ def test_generation_languages(self): # make sure that generating text only works model.generate(**input_text_rus, generate_speech=False) + # make sure it works for languages supported by both output modalities + config, input_text_eng = self.prepare_text_input(tgt_lang="eng") + model.generate(**input_text_eng) + model.generate(**input_text_eng, generate_speech=False) + def test_speech_generation(self): config, input_speech, input_text = self.prepare_speech_and_text_input()