diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py
index f1a26b3e5b6924..bceb1b49460889 100644
--- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py
+++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py
@@ -4596,7 +4596,11 @@ def generate(
         if tgt_lang is not None:
             # also accept __xxx__
             tgt_lang = tgt_lang.replace("__", "")
-            for key in ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]:
+            if generate_speech:
+                keys_to_check = ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]
+            else:
+                keys_to_check = ["text_decoder_lang_to_code_id"]
+            for key in keys_to_check:
                 lang_code_to_id = getattr(self.generation_config, key, None)
                 if lang_code_to_id is None:
                     raise ValueError(
diff --git a/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py b/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py
index 8627220c71aa51..795f3d80422b2e 100644
--- a/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py
+++ b/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py
@@ -758,7 +758,13 @@ def setUp(self):
         self.tmpdirname = tempfile.mkdtemp()
 
     def update_generation(self, model):
-        lang_code_to_id = {
+        text_lang_code_to_id = {
+            "fra": 4,
+            "eng": 4,
+            "rus": 4,
+        }
+
+        speech_lang_code_to_id = {
             "fra": 4,
             "eng": 4,
         }
@@ -773,9 +779,9 @@ def update_generation(self, model):
 
         generation_config = copy.deepcopy(model.generation_config)
 
-        generation_config.__setattr__("text_decoder_lang_to_code_id", lang_code_to_id)
-        generation_config.__setattr__("t2u_lang_code_to_id", lang_code_to_id)
-        generation_config.__setattr__("vocoder_lang_code_to_id", lang_code_to_id)
+        generation_config.__setattr__("text_decoder_lang_to_code_id", text_lang_code_to_id)
+        generation_config.__setattr__("t2u_lang_code_to_id", speech_lang_code_to_id)
+        generation_config.__setattr__("vocoder_lang_code_to_id", speech_lang_code_to_id)
         generation_config.__setattr__("id_to_text", id_to_text)
         generation_config.__setattr__("char_to_id", char_to_id)
         generation_config.__setattr__("eos_token_id", 0)
@@ -784,13 +790,13 @@
 
         model.generation_config = generation_config
 
-    def prepare_text_input(self):
+    def prepare_text_input(self, tgt_lang):
         config, inputs, decoder_input_ids, input_mask, lm_labels = self.text_model_tester.prepare_config_and_inputs()
 
         input_dict = {
             "input_ids": inputs,
             "attention_mask": input_mask,
-            "tgt_lang": "eng",
+            "tgt_lang": tgt_lang,
             "num_beams": 2,
             "do_sample": True,
         }
@@ -837,6 +843,26 @@ def factory_generation_speech_test(self, model, inputs):
         output = model.generate(**inputs)
         return output
 
+    def test_generation_languages(self):
+        config, input_text_rus = self.prepare_text_input(tgt_lang="rus")
+
+        model = SeamlessM4Tv2Model(config=config)
+        self.update_generation(model)
+        model.to(torch_device)
+        model.eval()
+
+        # make sure that generating speech, with a language that is only supported for text translation, raises error
+        with self.assertRaises(ValueError):
+            model.generate(**input_text_rus)
+
+        # make sure that generating text only works
+        model.generate(**input_text_rus, generate_speech=False)
+
+        # make sure it works for languages supported by both output modalities
+        config, input_text_eng = self.prepare_text_input(tgt_lang="eng")
+        model.generate(**input_text_eng)
+        model.generate(**input_text_eng, generate_speech=False)
+
     def test_speech_generation(self):
         config, input_speech, input_text = self.prepare_speech_and_text_input()