initial commit
bryanwilie committed Nov 14, 2023
1 parent f745468 commit 95fd884
Showing 22 changed files with 2,132 additions and 1 deletion.
38 changes: 37 additions & 1 deletion README.md
# PICK: Polished & Informed Candidate Scoring for Knowledge-Grounded Dialogue Systems

This is the repository for the paper [PICK: Polished & Informed Candidate Scoring for Knowledge-Grounded Dialogue Systems](https://arxiv.org/pdf/2309.10413.pdf). PICK addresses key challenges in knowledge-grounded dialogue systems, such as hallucination and lack of coherence, with a generation re-scoring framework that lets models produce faithful and relevant responses without requiring additional labeled data or model tuning. Further details can be found [in the paper](https://arxiv.org/pdf/2309.10413.pdf).
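
At a high level, the framework samples several candidate responses and keeps the one that scores best against the grounding knowledge. The sketch below only illustrates that generate-then-rescore loop; it is not the repo's actual implementation (see `src/` for that), and the toy `overlap_score` stands in for the paper's scoring function. All names here are hypothetical.

```
from transformers import AutoModelForCausalLM, AutoTokenizer

def overlap_score(candidate: str, knowledge: str) -> float:
    # Toy faithfulness proxy: unigram overlap with the grounding knowledge.
    # The paper's scoring is more sophisticated; this just makes the sketch run.
    cand, know = set(candidate.lower().split()), set(knowledge.lower().split())
    return len(cand & know) / max(len(cand), 1)

def pick_response(context: str, knowledge: str, model, tokenizer, n_candidates: int = 8) -> str:
    inputs = tokenizer(context, return_tensors="pt")
    # 1) Sample several candidate responses from the dialogue model.
    outputs = model.generate(
        **inputs,
        do_sample=True,
        num_return_sequences=n_candidates,
        max_new_tokens=64,
        pad_token_id=tokenizer.eos_token_id,
    )
    prompt_len = inputs["input_ids"].shape[1]
    candidates = [tokenizer.decode(o[prompt_len:], skip_special_tokens=True) for o in outputs]
    # 2) Re-score each candidate against the knowledge and keep the best one.
    return max(candidates, key=lambda c: overlap_score(c, knowledge))

# Illustrative usage with an off-the-shelf GPT-2:
# model = AutoModelForCausalLM.from_pretrained("gpt2")
# tokenizer = AutoTokenizer.from_pretrained("gpt2")
# print(pick_response("Tell me about jazz.", "Jazz originated in New Orleans.", model, tokenizer))
```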

## Steps:
1. Make sure all requirements are installed, or install them via: `pip install -r requirements.txt`
2. Prepare the dataset:
    - Download the wizard_of_wikipedia dataset:
    - `wget -P data_pool/wizard_of_wikipedia http://parl.ai/downloads/wizard_of_wikipedia/wizard_of_wikipedia.tgz`
    - `tar -xvzf data_pool/wizard_of_wikipedia/wizard_of_wikipedia.tgz -C data_pool/wizard_of_wikipedia/`
    - `rm -rf data_pool/wizard_of_wikipedia/wizard_of_wikipedia.tgz`
3. Prepare caffeinated_pandas to help with parallelization:
    - Clone the caffeinated-pandas repository into this repository:
    - `git clone https://github.com/scollay/caffeinated-pandas.git`
    - `mv caffeinated-pandas caffeinated_pandas`
4. Fine-tune your model using `run_ft_*.sh` (see the example invocation after this list)
5. Run inference with your fine-tuned model using `run_eval_*.sh`
6. Score your generations further with other metrics, e.g. [FED](https://github.com/Shikib/fed.git), by cloning it locally.
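
The `run_ft_*.sh` scripts wrap `finetune.py` (added in this commit), which parses its arguments with `HfArgumentParser`: you can pass individual CLI flags, or a single path to a JSON config file. A hypothetical invocation for illustration; the flag values below are placeholders, and the exact settings live in the `run_ft_*.sh` scripts:

```
# Illustrative only; consult run_ft_*.sh for the exact flags.
python finetune.py \
    --model_name_or_path gpt2 \
    --dataset_name wizard_of_wikipedia \
    --output_dir save/ \
    --do_train --do_eval

# Equivalently, pass a single JSON config file (hypothetical name):
python finetune.py my_finetune_config.json
```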

## Citation

This work has been accepted at AACL-IJCNLP 2023; details are [in the paper](https://arxiv.org/pdf/2309.10413.pdf) (the AACL anthology link is not yet available). Please cite our work if you find it useful.
```
@inproceedings{wilie2023pick,
author = {Wilie, Bryan and Xu, Yan and Chung, Willy and
Cahyawijaya, Samuel and Lovenia, Holy and Fung, Pascale},
title = {PICK: Polished \& Informed Candidate Scoring for Knowledge-Grounded Dialogue Systems},
booktitle = {Proceedings of the 13th International Joint Conference on Natural Language Processing
and the 3rd Conference of the Asia-Pacific Chapter of
the Association for Computational Linguistics},
month = {November},
year = {2023},
address = {Nusa Dua, Bali},
publisher = {Association for Computational Linguistics},
pages = {980--995}
}
```
210 changes: 210 additions & 0 deletions finetune.py
import os
import sys
import math
import logging
import datasets
import transformers
from transformers import (
    default_data_collator,
    set_seed,
    EarlyStoppingCallback,
    HfArgumentParser,
    Trainer
)
from transformers.trainer_utils import get_last_checkpoint, is_main_process

from src.model.model import load_model_and_tokenizer
from src.dataset.load_dataset import load_dataset
from src.utils.train_args_helper import DataArguments, ModelArguments, TrainingArguments
from src.utils.trainer_helper import preprocess_logits_for_metrics, compute_metrics
from src.utils.general_helper import set_all_seeds

datasets.enable_caching()  # cache preprocessed datasets on disk so repeated runs skip re-tokenization
logger = logging.getLogger(__name__)


def do_train(model_args, data_args, training_args):

    # Init
    set_seed(training_args.seed)
    set_all_seeds(training_args.seed)
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    ##### ARGS #####
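    # Encode the key hyperparameters into a human-readable run name; it is
    # appended to `output_dir` below so each run gets a unique folder.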
    run_name = '{}_{}_{}_{}_{}dhuttr_nokn{}_npu{}_adddata{}_maxseqlen{}_bs{}_gradacc{}_lr{}_spktoken{}_{}pad_{}epoch_wd{}_ws{}'\
        .format(model_args.model_name_or_path,
                data_args.dataset_name,
                data_args.experiment_mode,
                data_args.eval_set,
                data_args.dh_uttr_count,
                data_args.no_knowledge_dataset,
                data_args.no_passages_used_settings,
                data_args.add_dataset,
                model_args.max_seq_len,
                training_args.per_device_train_batch_size,
                training_args.gradient_accumulation_steps,
                training_args.learning_rate,
                data_args.use_speaker_token,
                data_args.pad_token,
                training_args.num_train_epochs,
                training_args.weight_decay,
                training_args.warmup_steps).replace('/', '-')


    ##### MODEL #####
    model, tokenizer = load_model_and_tokenizer(data_args, model_args)


    ##### DATASET #####

    # Load the preprocessed dataset splits
    dataset_dict = {}
    for split in ['train', data_args.eval_set]:
        dataset_dict[split] = load_dataset(data_args, model_args, tokenizer, split)


    ##### TRAINING #####
    print('Preparing Trainer...')
    training_args.output_dir = training_args.output_dir + run_name

    # Initialize Trainer
    trainer = Trainer(
        train_dataset=dataset_dict['train'],
        eval_dataset=dataset_dict[data_args.eval_set],
        model=model,
        data_collator=default_data_collator,
        args=training_args,
        compute_metrics=compute_metrics if training_args.do_eval else None,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None,
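        # Note: EarlyStoppingCallback only takes effect when TrainingArguments also sets
        # `load_best_model_at_end=True` and a `metric_for_best_model` (assumed to be
        # configured in the repo's TrainingArguments subclass).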
        callbacks=[EarlyStoppingCallback(early_stopping_patience=training_args.early_stopping_patience)]
    )

    ###
    # Training Phase
    ###

    if training_args.do_train:
        print('*** Training Phase ***')
        checkpoint = None

        # Detecting last checkpoint.
        last_checkpoint = None
        if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
            last_checkpoint = get_last_checkpoint(training_args.output_dir)
            if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
                raise ValueError(
                    f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                    "Use --overwrite_output_dir to overcome."
                )
            elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
                logger.info(
                    f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                    "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
                )

        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)

        ### Saving
        trainer.save_model()  # Saves the tokenizer too for easy upload

        metrics = train_result.metrics
        metrics["train_samples"] = len(dataset_dict['train'])

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    ###
    # Evaluation Phase
    ###

    if training_args.do_eval:
        print("*** Evaluation Phase ***")

        metrics = trainer.evaluate()
        metrics["eval_samples"] = len(dataset_dict[data_args.eval_set])

        # Perplexity is exp(eval cross-entropy loss); guard against overflow for very large losses.
        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


#####
# Entry Point
#####
def main():

    ###
    # Parsing & Initialization
    ###

    # Parse arguments: a single .json argument is treated as a config file;
    # otherwise the CLI flags are parsed into the three dataclasses.
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()


    # Set random seed
    set_seed(training_args.seed)

    # Detect last checkpoint
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    ###
    # Prepare logger
    ###

    # Init logging
    os.makedirs("./log", exist_ok=True)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout), logging.FileHandler(
            "./log/log__{}".format(model_args.model_name_or_path.replace("/", "_")), mode="w")],
    )
    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)

    # Log a short summary on each process:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the Transformers logger verbosity to WARNING (on the main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity(transformers.logging.WARNING)
    logger.info("Training/evaluation parameters %s", training_args)

    do_train(model_args, data_args, training_args)


if __name__ == '__main__':
    main()