Skip to content

Commit

Permalink
fix(examples/t5_summarize_cnn): move labels into `reward_fn` kwargs (#570)

Browse files Browse the repository at this point in the history

* fix(examples/t5_summarize_cnn): move labels into `reward_fn` kwargs

* style: satisfy black
  • Loading branch information
maxreciprocate authored Oct 20, 2023
1 parent d03fea7 commit db466cb
Showing 1 changed file with 7 additions and 30 deletions.
37 changes: 7 additions & 30 deletions examples/summarize_daily_cnn/t5_summarize_daily_cnn.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from typing import List

from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer

import trlx
from trlx.data.configs import (
Expand Down Expand Up @@ -93,11 +91,10 @@

if __name__ == "__main__":

def reward_fn(
    samples: List[str],
    prompts: List[str],
    outputs: List[str],
    original_summaries: List[str],
    **kwargs,
):
    """Score each generated summary against its reference with METEOR.

    Args:
        samples: Full decoded samples (prompt + output); unused here.
        prompts: Prompts fed to the model; unused here (the old version used
            them to look up labels in a prompt->summary dict, which this
            commit removes in favor of the explicit kwarg below).
        outputs: Model-generated summaries, one per prompt.
        original_summaries: Reference summaries, forwarded by trlx from the
            per-prompt dicts passed to `trlx.train`.
        **kwargs: Any remaining batch fields forwarded by trlx; ignored.

    Returns:
        A list of per-sample METEOR scores (floats in [0, 1]).
    """
    # NOTE(review): `meteor` is assumed to be a module-level metric object
    # loaded elsewhere in the file (e.g. evaluate.load("meteor")) — confirm.
    return [
        meteor.compute(predictions=[output.strip()], references=[reference])["meteor"]
        for reference, output in zip(original_summaries, outputs)
    ]

Expand All @@ -112,31 +109,11 @@ def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]):
val_prompts = ["Summarize: " + prompt for prompt in dataset["validation"]["article"][0:1000]]
val_summaries = dataset["validation"]["highlights"][0:1000]

# The reference summaries now travel with their prompts: trlx forwards the
# extra "original_summaries" field of each prompt dict to `reward_fn` as a
# keyword argument. This replaces the old approach of re-tokenizing every
# prompt to build a prompt->label lookup dict (deleted by this commit).
trlx.train(
    reward_fn=reward_fn,
    prompts=[
        {"prompt": prompt, "original_summaries": summary}
        for prompt, summary in zip(prompts, summaries)
    ],
    eval_prompts=[
        {"prompt": prompt, "original_summaries": summary}
        for prompt, summary in zip(val_prompts, val_summaries)
    ],
    config=config,
)

0 comments on commit db466cb

Please sign in to comment.