Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix languages covered by M4Tv2 #28019

Merged
merged 3 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4596,7 +4596,12 @@ def generate(
if tgt_lang is not None:
# also accept __xxx__
tgt_lang = tgt_lang.replace("__", "")
for key in ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]:
keys_to_check = (
["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]
if generate_speech
else ["text_decoder_lang_to_code_id"]
)
ylacombe marked this conversation as resolved.
Show resolved Hide resolved
for key in keys_to_check:
lang_code_to_id = getattr(self.generation_config, key, None)
if lang_code_to_id is None:
raise ValueError(
Expand Down
34 changes: 29 additions & 5 deletions tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,13 @@ def setUp(self):
self.tmpdirname = tempfile.mkdtemp()

def update_generation(self, model):
lang_code_to_id = {
text_lang_code_to_id = {
"fra": 4,
"eng": 4,
"rus": 4,
}

speech_lang_code_to_id = {
"fra": 4,
"eng": 4,
}
Expand All @@ -773,9 +779,9 @@ def update_generation(self, model):

generation_config = copy.deepcopy(model.generation_config)

generation_config.__setattr__("text_decoder_lang_to_code_id", lang_code_to_id)
generation_config.__setattr__("t2u_lang_code_to_id", lang_code_to_id)
generation_config.__setattr__("vocoder_lang_code_to_id", lang_code_to_id)
generation_config.__setattr__("text_decoder_lang_to_code_id", text_lang_code_to_id)
generation_config.__setattr__("t2u_lang_code_to_id", speech_lang_code_to_id)
generation_config.__setattr__("vocoder_lang_code_to_id", speech_lang_code_to_id)
generation_config.__setattr__("id_to_text", id_to_text)
generation_config.__setattr__("char_to_id", char_to_id)
generation_config.__setattr__("eos_token_id", 0)
Expand All @@ -784,7 +790,7 @@ def update_generation(self, model):

model.generation_config = generation_config

def prepare_text_input(self):
def prepare_text_input(self, is_rus=False):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just make this tgt_lang and then you can easily test with different target languages both in an outside of the supported languages with generate_speech as True and False?

config, inputs, decoder_input_ids, input_mask, lm_labels = self.text_model_tester.prepare_config_and_inputs()

input_dict = {
Expand All @@ -795,6 +801,9 @@ def prepare_text_input(self):
"do_sample": True,
}

if is_rus:
input_dict["tgt_lang"] = "rus"

return config, input_dict

def prepare_speech_input(self):
Expand Down Expand Up @@ -837,6 +846,21 @@ def factory_generation_speech_test(self, model, inputs):
output = model.generate(**inputs)
return output

def test_generation_languages(self):
config, input_text_rus = self.prepare_text_input(is_rus=True)

model = SeamlessM4Tv2Model(config=config)
self.update_generation(model)
model.to(torch_device)
model.eval()

# make sure that generating speech, with a language that is only supported for text translation, raises error
with self.assertRaises(ValueError):
model.generate(**input_text_rus)

# make sure that generating text only works
model.generate(**input_text_rus, generate_speech=False)
Comment on lines +854 to +859
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also make sure it works in both cases for a language supported for all tasks


def test_speech_generation(self):
config, input_speech, input_text = self.prepare_speech_and_text_input()

Expand Down
Loading