Skip to content

Commit

Permalink
Fix languages covered by M4Tv2 (huggingface#28019)
Browse files Browse the repository at this point in the history
* correct language assessment  + add tests

* Update src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py

Co-authored-by: amyeroberts <[email protected]>

* make style + simplify and enrich test

---------

Co-authored-by: amyeroberts <[email protected]>
  • Loading branch information
2 people authored and iantbutler01 committed Dec 16, 2023
1 parent f38b23d commit ba1f2c1
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4596,7 +4596,11 @@ def generate(
if tgt_lang is not None:
# also accept __xxx__
tgt_lang = tgt_lang.replace("__", "")
for key in ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]:
if generate_speech:
keys_to_check = ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]
else:
keys_to_check = ["text_decoder_lang_to_code_id"]
for key in keys_to_check:
lang_code_to_id = getattr(self.generation_config, key, None)
if lang_code_to_id is None:
raise ValueError(
Expand Down
38 changes: 32 additions & 6 deletions tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,13 @@ def setUp(self):
self.tmpdirname = tempfile.mkdtemp()

def update_generation(self, model):
lang_code_to_id = {
text_lang_code_to_id = {
"fra": 4,
"eng": 4,
"rus": 4,
}

speech_lang_code_to_id = {
"fra": 4,
"eng": 4,
}
Expand All @@ -773,9 +779,9 @@ def update_generation(self, model):

generation_config = copy.deepcopy(model.generation_config)

generation_config.__setattr__("text_decoder_lang_to_code_id", lang_code_to_id)
generation_config.__setattr__("t2u_lang_code_to_id", lang_code_to_id)
generation_config.__setattr__("vocoder_lang_code_to_id", lang_code_to_id)
generation_config.__setattr__("text_decoder_lang_to_code_id", text_lang_code_to_id)
generation_config.__setattr__("t2u_lang_code_to_id", speech_lang_code_to_id)
generation_config.__setattr__("vocoder_lang_code_to_id", speech_lang_code_to_id)
generation_config.__setattr__("id_to_text", id_to_text)
generation_config.__setattr__("char_to_id", char_to_id)
generation_config.__setattr__("eos_token_id", 0)
Expand All @@ -784,13 +790,13 @@ def update_generation(self, model):

model.generation_config = generation_config

def prepare_text_input(self):
def prepare_text_input(self, tgt_lang):
config, inputs, decoder_input_ids, input_mask, lm_labels = self.text_model_tester.prepare_config_and_inputs()

input_dict = {
"input_ids": inputs,
"attention_mask": input_mask,
"tgt_lang": "eng",
"tgt_lang": tgt_lang,
"num_beams": 2,
"do_sample": True,
}
Expand Down Expand Up @@ -837,6 +843,26 @@ def factory_generation_speech_test(self, model, inputs):
output = model.generate(**inputs)
return output

def test_generation_languages(self):
config, input_text_rus = self.prepare_text_input(tgt_lang="rus")

model = SeamlessM4Tv2Model(config=config)
self.update_generation(model)
model.to(torch_device)
model.eval()

# make sure that generating speech, with a language that is only supported for text translation, raises error
with self.assertRaises(ValueError):
model.generate(**input_text_rus)

# make sure that generating text only works
model.generate(**input_text_rus, generate_speech=False)

# make sure it works for languages supported by both output modalities
config, input_text_eng = self.prepare_text_input(tgt_lang="eng")
model.generate(**input_text_eng)
model.generate(**input_text_eng, generate_speech=False)

def test_speech_generation(self):
config, input_speech, input_text = self.prepare_speech_and_text_input()

Expand Down

0 comments on commit ba1f2c1

Please sign in to comment.