Skip to content

Commit

Permalink
Merge pull request #1 from ThisisPromise/Back-translation
Browse files Browse the repository at this point in the history
Back-translation
  • Loading branch information
ThisisPromise authored Jan 22, 2024
2 parents 84c4340 + 9d2433f commit 358022c
Showing 1 changed file with 44 additions and 0 deletions.
44 changes: 44 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import argparse
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

def load_model_and_tokenizer(model_name):
model = M2M100ForConditionalGeneration.from_pretrained(model_name)
tokenizer = M2M100Tokenizer.from_pretrained(model_name)
return model, tokenizer

def perform_translation(model, tokenizer, batch_texts):
formatted_batch_texts = [f"{text}" for text in batch_texts]
model_inputs = tokenizer(formatted_batch_texts, return_tensors="pt", padding=True, truncation=True)
translated = model.generate(**model_inputs)
translated_texts = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
return translated_texts

def perform_back_translation(original_texts, original_model, original_tokenizer, back_translation_model, back_translation_tokenizer):
temp_translated_batch = perform_translation(original_model, original_tokenizer, original_texts)
back_translated_batch = perform_translation(back_translation_model, back_translation_tokenizer, temp_translated_batch)
return list(set(original_texts) | set(back_translated_batch))

def main():
parser = argparse.ArgumentParser(description="Perform translation and back-translation with M2M100 models.")
parser.add_argument("--model_name", type=str, required=True, help="Model name for tokenizer and model loading.")
parser.add_argument("--original_texts", nargs="+", required=True, help="Original texts for translation.")
parser.add_argument("--back_translation_model_name", type=str, required=True, help="Model name for back translation.")
parser.add_argument("--tokenizer_name", type=str, required=True, help="Tokenizer name for loading.")

args = parser.parse_args()

# Load models and tokenizer
original_model, original_tokenizer = load_model_and_tokenizer(args.model_name)
back_translation_model, back_translation_tokenizer = load_model_and_tokenizer(args.back_translation_model_name)

# Perform translation
translated_texts = perform_translation(original_model, original_tokenizer, args.original_texts)
print("Translated Texts:", translated_texts)

# Perform back-translation
back_translated_texts = perform_back_translation(args.original_texts, original_model, original_tokenizer, back_translation_model, back_translation_tokenizer)
print("Back-Translated Texts:", back_translated_texts)

if __name__ == "__main__":
main()

0 comments on commit 358022c

Please sign in to comment.