This is the PyTorch implementation of the following paper, accepted at ICASSP 2023 in Rhodes Island, Greece:
Title: Choice Fusion as Knowledge for Zero-Shot Dialogue State Tracking
Authors: Ruolin Su, Jingfeng Yang, Ting-Wei Wu, Biing-Hwang Juang.
Set up the Python environment:

```bash
conda create -n py38 python=3.8
conda activate py38
pip install -r requirements.txt
```
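As a quick sanity check that the environment is usable (not part of the original instructions; the printed version depends on `requirements.txt`):

```bash
# Verify that PyTorch imports, and report its version and whether CUDA is available.
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```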
Download the RACE dataset and put it under the `qa_data` folder. Then download the other QA datasets:

```bash
./download_data.sh
```
Combine and pre-process the QA datasets:

```bash
python create_qa_data.py
```

Pre-process the MultiWOZ data for DST evaluation:

```bash
python create_data_mwoz.py
```
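After these steps, you can spot-check that the prepared files exist (the exact file names are produced by the scripts above, so this listing is only a quick check):

```bash
# List the prepared QA data; exact file names depend on the preprocessing scripts.
ls qa_data
```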
- Train our model with appreciated-choice selection:

  ```bash
  ./run_qa_pretrain_t5.sh pretrain
  ```
- Train our model with the choice fusion mechanism, which combines appreciated-choice selection with context-knowledge fusion:

  ```bash
  ./run_qa_pretrain_t5.sh pretrain_fusion
  ```
Training arguments:

- `--percentage`: The percentage of the combined QA data used for training.
- `--max_seq_length`: The maximum length of the input tokens.
- `--num_train_epochs`: The number of training epochs.
- `--overwrite_cache`: Whether or not to use the cached training dataset.
- The number of `CUDA_VISIBLE_DEVICES`, `--per_device_train_batch_size`, and `--gradient_accumulation_steps` multiply to give the total batch size (e.g., 2 GPUs with a per-device batch size of 8 and 4 accumulation steps give a total batch size of 64).
- `--neg_num`, `--neg_context_ratio`: Negative sampling rates that encourage the model to generate `none` values proactively.
Set these arguments inside `run_qa_pretrain_t5.sh`: (1) for `pretrain`, adjust `--percentage`, `--evaluation_strategy`, `--eval_steps`, `--save_strategy`, and `--save_steps` (a sketch of such an invocation is shown below); (2) for `predict`, see the arguments after the command.
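As a rough sketch of how the training arguments fit together, the invocation inside `run_qa_pretrain_t5.sh` might look like the following; the Python entry-point name and all values here are illustrative assumptions, not the repository's actual defaults:

```bash
# Hypothetical invocation; the entry-point name and every value are illustrative.
# With 2 visible GPUs: total batch size = 2 GPUs x 8 per-device x 4 accumulation steps = 64.
CUDA_VISIBLE_DEVICES=0,1 python run_qa_pretrain_t5.py \
  --percentage 1.0 \
  --max_seq_length 512 \
  --num_train_epochs 3 \
  --per_device_train_batch_size 8 \
  --gradient_accumulation_steps 4 \
  --neg_num 2 \
  --neg_context_ratio 0.3 \
  --evaluation_strategy steps \
  --eval_steps 500 \
  --save_strategy steps \
  --save_steps 500
```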
- Run prediction with the trained model:

  ```bash
  ./run_qa_pretrain_t5.sh predict
  ```
Prediction arguments (see the sketch after this list):

- `--history_turn`: The number of previous turns used as the dialogue context at test time.
- `--per_device_eval_batch_size`: The batch size for testing.
- `--test_type`: `dst` to evaluate on the test set of MultiWOZ, or `qa` to evaluate on the QA dev set.
- `--overwrite_cache`: Whether or not to use the cached DST test dataset.
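Analogously, a hypothetical sketch of the `predict` invocation (again, the entry-point name and values are assumptions for illustration):

```bash
# Hypothetical predict invocation; names and values are illustrative assumptions.
# --test_type dst evaluates on the MultiWOZ test set; use qa for the QA dev set.
CUDA_VISIBLE_DEVICES=0 python run_qa_pretrain_t5.py \
  --history_turn 2 \
  --per_device_eval_batch_size 16 \
  --test_type dst \
  --overwrite_cache
```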
- To generate DST slot-values with the trained context-knowledge fusion model, run:

  ```bash
  ./run_qa_pretrain_t5.sh predict_fusion
  ```