
A multiple-choice-question (MCQ) answering LLM model based on Retrieval Augmented Generation (RAG)


garvitagarwal290/LLM_Science_Exam_Kaggle


Description

This is my implementation of an LLM-based model for answering hard science multiple-choice questions (MCQs) using Retrieval Augmented Generation (RAG). The project was inspired by Kaggle's LLM Science Exam competition, and some help was taken from this video from the YouTube channel DataScienceCastnet, along with a couple of other videos.

This solution comprises three steps:

  1. Take a dataset of Wikipedia articles on various STEM topics and convert each article into a fixed-size embedding.
  2. Take a dataset of many science MCQs and their relevant Wikipedia text and train a BERT-like encoder to choose the correct option out of 5 choices with the Wikipedia text as the context.
  3. For a new MCQ, find the Wikipedia article (from step 1) most likely to contain the relevant context for the question, and let the model trained in step 2 answer the question with this added context.

The details of each step are explained below:

1) Generating embeddings of STEM Wikipedia articles

For generating embeddings, the wiki-20220301-en-sci dataset, which contains about 130,000 Wikipedia articles, was used. After a cleaning process that removes references, external links, etc., each article was converted into a 768-dimensional embedding by the sentence transformer multi-qa-mpnet-base-dot-v1. The embeddings of all the articles are stored here.
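The sketch below shows one way this embedding step could be done with the sentence-transformers library; the file and column names are placeholders, not the actual ones used in this repository.

```python
# A minimal sketch of the embedding step; file and column names are assumptions.
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("multi-qa-mpnet-base-dot-v1")

# Assume the cleaned articles live in a CSV with a "text" column (hypothetical file name).
articles = pd.read_csv("wiki_sci_cleaned.csv")

# Encode each article into a 768-dimensional vector.
embeddings = model.encode(
    articles["text"].tolist(),
    batch_size=64,
    show_progress_bar=True,
    convert_to_numpy=True,
)

np.save("wiki_sci_embeddings.npy", embeddings)  # shape: (num_articles, 768)
```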

2) Fine-tuning an encoder for MCQ answering task

As the base encoder, the deberta-v3-base model was used. Parameter-efficient fine-tuning (PEFT) was done using LoRA (low-rank adaptation) with rank = 8 and alpha = 8.

The training dataset was taken from 15k_gpt3.5-turbo.csv, which contains 15,000 examples of STEM MCQs with contexts and their correct answers.

The fine-tuned model can be found here.
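A minimal sketch of this LoRA setup with the Hugging Face peft library is shown below; only the rank and alpha values come from the description above, while the dropout, target modules, and other settings are illustrative assumptions.

```python
# Minimal LoRA/PEFT sketch; only r=8 and alpha=8 come from the text, the rest are assumptions.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForMultipleChoice, AutoTokenizer

model_name = "microsoft/deberta-v3-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMultipleChoice.from_pretrained(model_name)

lora_config = LoraConfig(
    r=8,                                          # LoRA rank (as stated above)
    lora_alpha=8,                                 # LoRA alpha (as stated above)
    lora_dropout=0.1,                             # assumed value
    target_modules=["query_proj", "value_proj"],  # DeBERTa-v3 attention projections (assumed choice)
    modules_to_save=["classifier", "pooler"],     # train the multiple-choice head fully (assumed)
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# The adapted model can then be trained on the MCQ dataset, e.g. with the Hugging Face Trainer.
```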

3) Bringing everything together to answer questions

This RAG model was evaluated on the 200 STEM MCQs provided in the LLM Science Exam competition. For a given question, the question statement was turned into a sentence embedding, again using the multi-qa-mpnet-base-dot-v1 model. Next, the 10 most similar Wikipedia articles were identified using the nearest-neighbor search functionality of the Facebook AI Similarity Search (Faiss) library.
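The retrieval step might look like the following sketch; the index type, file names, and example question are assumptions for illustration.

```python
# Retrieval sketch with Faiss; index type, file names, and the example question are assumptions.
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

# Article embeddings produced in step 1 (hypothetical file name).
embeddings = np.load("wiki_sci_embeddings.npy").astype("float32")  # (num_articles, 768)

# multi-qa-mpnet-base-dot-v1 is tuned for dot-product similarity,
# so an inner-product index is a natural choice here.
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)

encoder = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
question = "Which particle mediates the electromagnetic force?"  # hypothetical example
query = encoder.encode([question], convert_to_numpy=True).astype("float32")

scores, article_ids = index.search(query, 10)  # indices of the 10 most similar articles
```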

Re-ranking of these 10 articles was done using the bge-reranker-base model. After re-ranking, the most similar article was used as the context for the fine-tuned deberta-v3-base model, and the scores it assigned to the 5 options were recorded.
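Below is a sketch of how the re-ranking and option scoring could be wired together; the model path, prompt format, and example inputs are assumptions rather than the exact code used here.

```python
# Sketch of re-ranking and option scoring; model path, prompt format, and inputs are assumptions.
import torch
from sentence_transformers import CrossEncoder
from transformers import AutoModelForMultipleChoice, AutoTokenizer

# Hypothetical inputs: the question, its 5 options, and the articles retrieved by Faiss.
question = "Which particle mediates the electromagnetic force?"
options = ["Gluon", "Photon", "W boson", "Graviton", "Higgs boson"]
retrieved_articles = ["...article text 1...", "...article text 2..."]  # top 10 from retrieval

# Re-rank the retrieved articles against the question and keep the best one as context.
reranker = CrossEncoder("BAAI/bge-reranker-base")
rerank_scores = reranker.predict([(question, article) for article in retrieved_articles])
context = retrieved_articles[int(rerank_scores.argmax())]

# Score the 5 options with the fine-tuned reader (hypothetical path to the saved model).
tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
reader = AutoModelForMultipleChoice.from_pretrained("path/to/finetuned-deberta")
reader.eval()

encoded = tokenizer(
    [f"{context} {question}"] * len(options),  # context + question paired with each option
    options,
    truncation=True,
    max_length=512,
    padding=True,
    return_tensors="pt",
)
with torch.no_grad():
    logits = reader(
        input_ids=encoded["input_ids"].unsqueeze(0),          # (1, 5, seq_len)
        attention_mask=encoded["attention_mask"].unsqueeze(0),
    ).logits                                                   # (1, 5)

answer = "ABCDE"[int(logits.argmax())]  # option with the highest score
```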

Taking the option with the highest score as the model's "answer", the above model answered only 63 of the 200 questions correctly (31.5%), only modestly better than a model that picks one of the 5 options at random (20% expected). The sub-par performance might be due to the following reasons:

  1. Limited GPU resources: All the above steps were implemented on Kaggle and were therefore subject to its GPU memory and runtime limits, which constrained the size of the answering/reader model.
  2. Limited context length: The maximum context length of the reader model was 512 tokens. This is likely not enough to accommodate many long Wikipedia articles, which are therefore truncated, possibly losing essential context needed to answer the question. An encoder with a longer context length should perform better, although the limited GPU memory might again be a roadblock.
  3. Limited Wikipedia dataset: Although the Wikipedia dataset has 130k articles, the test questions are meant to be difficult questions generated by GPT-3.5. For many test questions, the Wikipedia article identified by the retrieval model did not actually contain the information required to answer the question. A bigger retrieval dataset should therefore improve performance.

To conclude, more work is required to build a more accurate RAG-based MCQ answering model.
