Skip to content

Commit

Permalink
make style + simplify and enrich test
Browse files Browse the repository at this point in the history
  • Loading branch information
ylacombe committed Dec 14, 2023
1 parent 5005a2e commit 3d2ea36
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4596,9 +4596,9 @@ def generate(
if tgt_lang is not None:
# also accept __xxx__
tgt_lang = tgt_lang.replace("__", "")
if generate_speech:
if generate_speech:
keys_to_check = ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]
else:
else:
keys_to_check = ["text_decoder_lang_to_code_id"]
for key in keys_to_check:
lang_code_to_id = getattr(self.generation_config, key, None)
Expand Down
14 changes: 8 additions & 6 deletions tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,20 +790,17 @@ def update_generation(self, model):

model.generation_config = generation_config

def prepare_text_input(self, is_rus=False):
def prepare_text_input(self, tgt_lang):
config, inputs, decoder_input_ids, input_mask, lm_labels = self.text_model_tester.prepare_config_and_inputs()

input_dict = {
"input_ids": inputs,
"attention_mask": input_mask,
"tgt_lang": "eng",
"tgt_lang": tgt_lang,
"num_beams": 2,
"do_sample": True,
}

if is_rus:
input_dict["tgt_lang"] = "rus"

return config, input_dict

def prepare_speech_input(self):
Expand Down Expand Up @@ -847,7 +844,7 @@ def factory_generation_speech_test(self, model, inputs):
return output

def test_generation_languages(self):
config, input_text_rus = self.prepare_text_input(is_rus=True)
config, input_text_rus = self.prepare_text_input(tgt_lang="rus")

model = SeamlessM4Tv2Model(config=config)
self.update_generation(model)
Expand All @@ -861,6 +858,11 @@ def test_generation_languages(self):
# make sure that generating text only works
model.generate(**input_text_rus, generate_speech=False)

# make sure it works for languages supported by both output modalities
config, input_text_eng = self.prepare_text_input(tgt_lang="eng")
model.generate(**input_text_eng)
model.generate(**input_text_eng, generate_speech=False)

def test_speech_generation(self):
config, input_speech, input_text = self.prepare_speech_and_text_input()

Expand Down

0 comments on commit 3d2ea36

Please sign in to comment.