diff --git a/evals/evaluation/agent_eval/crag_eval/README.md b/evals/evaluation/agent_eval/crag_eval/README.md
index b458a930..5c2c0e84 100644
--- a/evals/evaluation/agent_eval/crag_eval/README.md
+++ b/evals/evaluation/agent_eval/crag_eval/README.md
@@ -46,29 +46,29 @@ cd $WORKDIR/GenAIEval/evals/evaluation/agent_eval/crag_eval/preprocess_data
 bash run_data_preprocess.sh
 ```
 **Note**: This is an example of data processing. You can develop and optimize your own data processing for this benchmark.
-3. Sample queries for benchmark
+3. (Optional) Sample queries for benchmark
 The CRAG dataset has more than 4000 queries, and running all of them can be very expensive and time-consuming. You can sample a subset for benchmark. Here we provide a script to sample up to 5 queries per question_type per dynamism in each domain. For example, we were able to get 92 queries from the music domain using the script.
 ```
 bash run_sample_data.sh
 ```

 ## Launch agent QnA system
-Here we showcase a RAG agent in GenAIExample repo. Please refer to the README in the [AgentQnA example](https://github.com/opea-project/GenAIExamples/tree/main/AgentQnA/README.md) for more details.
+Here we showcase an agent system in the OPEA GenAIExamples repo. Please refer to the README in the [AgentQnA example](https://github.com/opea-project/GenAIExamples/tree/main/AgentQnA/README.md) for more details.
 > **Please note**: This is an example. You can build your own agent systems using OPEA components, then expose your own systems as an endpoint for this benchmark.
-To launch the agent in our AgentQnA example, open another terminal and build images and launch agent system there.
+To launch the agent in our AgentQnA example on Intel Gaudi accelerators, open another terminal and follow the instructions below.
 1. Build images
 ```
 export $WORKDIR=
 cd $WORKDIR
 git clone https://github.com/opea-project/GenAIExamples.git
 cd GenAIExamples/AgentQnA/tests/
-bash 1_build_images.sh
+bash step1_build_images.sh
 ```
 2. Start retrieval tool
 ```
-bash 2_start_retrieval_tool.sh
+bash step2_start_retrieval_tool.sh
 ```
 3. Ingest data into vector database and validate retrieval tool
 ```
@@ -86,19 +86,21 @@ python3 index_data.py --host_ip $host_ip --filedir ${WORKDIR}/datasets/crag_docs
 ```
 # Go to the terminal where you launched the AgentQnA example
 cd $WORKDIR/GenAIExamples/AgentQnA/tests/
-bash 4_launch_and_validate_agent.sh
+bash step4_launch_and_validate_agent_gaudi.sh
 ```
+Note: There are two agents in the agent system: a RAG agent (the worker agent) and a ReAct agent (the supervisor agent). We can evaluate both agents; just specify the corresponding agent endpoint URL in the scripts (see the instructions below).

 ## Run CRAG benchmark
-Once you have your agent system up and running, the next step is to generate answers with agent. Change the variables in the script below and run the script. By default, it will run a sampled set of queries in music domain.
+Once you have your agent system up and running, the next step is to generate answers with the agent. Change the variables in the script below and run it. By default, it runs the entire set of queries in the music domain (373 queries in total). You can choose to run other domains or only a sampled subset of the music domain.
 ```
 # Come back to the interactive crag-eval docker container
 cd $WORKDIR/GenAIEval/evals/evaluation/agent_eval/crag_eval/run_benchmark
+# Remember to specify the agent endpoint URL in the script.
 bash run_generate_answer.sh
 ```
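+
+Before launching the full run, you can sanity-check the agent endpoint with a single query. The sketch below assumes an OpenAI-style chat completions payload on the `/v1/chat/completions` route used in `run_generate_answer.sh`; the request fields and the port are placeholders, so check `generate_answers.py` and your agent's compose file for the values this benchmark actually uses.
+```
+import os
+
+import requests
+
+host_ip = os.environ.get("host_ip", "localhost")
+url = f"http://{host_ip}:9095/v1/chat/completions"  # replace 9095 with your agent's port
+payload = {"messages": [{"role": "user", "content": "who sang thriller?"}], "stream": False}
+response = requests.post(url, json=payload, proxies={"http": ""}, timeout=300)
+print(response.json())
+```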

 ## Use LLM-as-judge to grade the answers
-1. Launch llm endpoint with HF TGI: in another terminal, run the command below. By default, `meta-llama/Meta-Llama-3-70B-Instruct` is used as the LLM judge.
+1. Launch llm endpoint with HF TGI: in another terminal, run the command below. By default, `meta-llama/Meta-Llama-3.1-70B-Instruct` is used as the LLM judge.
 ```
 cd llm_judge
 bash launch_llm_judge_endpoint.sh
 ```
@@ -123,3 +125,57 @@ python3 test_llm_endpoint.py
 cd $WORKDIR/GenAIEval/evals/evaluation/agent_eval/crag_eval/run_benchmark/
 bash run_grading.sh
 ```
+
+### Validation of LLM-as-judge
+We validated RAGAS answer correctness as the metric to evaluate agents. We sampled 92 queries from the full music domain dataset (up to 5 questions per sub-category for all 32 sub-categories) and conducted human evaluations of the conventional RAG, single RAG agent, and hierarchical ReAct agent answers to the 92 queries.
+
+We followed the criteria in the [CRAG paper](https://arxiv.org/pdf/2406.04744) to get human scores:
+1. score 1 if the answer matches the golden answer or is semantically similar.
+2. score 0 if the answer misses information, is "I don't know" or "I'm sorry I can't find ...", is a system error such as hitting the recursion limit, or is a request from the system to clarify the original question.
+3. score -1 if the answer contains incorrect information.
+
+The RAGAS `answer_correctness` score, on the other hand, is on a scale of 0 to 1 and is a weighted average of 1) an F1 score and 2) the similarity between the answer and the golden answer. The F1 score is based on the number of statements in the answer that are or are not supported by the golden answer, and the number of statements in the golden answer that do or do not appear in the answer. Please refer to the [RAGAS source code](https://github.com/explodinggradients/ragas/blob/main/src/ragas/metrics/_answer_correctness.py) for the implementation of its `answer_correctness` score. We ran RAGAS on Intel Gaudi2 accelerators and used `meta-llama/Meta-Llama-3.1-70B-Instruct` as the LLM judge.
+
+|Setup |Mean Human score|Mean RAGAS `answer_correctness` score|
+|----------------|-----------|------------------------------|
+|Conventional RAG|0.05 |0.37|
+|Single RAG agent|0.18 |0.43|
+|Hierarchical ReAct agent|0.22|0.54|
+
+We can see that the human scores and the RAGAS `answer_correctness` scores follow the same trend, although the two scoring methods used different grading criteria. Since LLM-as-judge is more scalable for larger datasets, we decided to use RAGAS `answer_correctness` scores (produced by `meta-llama/Meta-Llama-3.1-70B-Instruct` as the LLM judge) for the evaluation of OPEA agents on the full CRAG music domain dataset.
+
+Our scripts for calculating and comparing the mean RAGAS scores are available; refer to the `run_compare_scores.sh` script in the `run_benchmark` folder.
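+
+To see what a single `answer_correctness` measurement looks like outside of `grade_answers.py`, the sketch below scores one toy music-domain example against the LLM judge endpoint launched above. The embedding model name is only an example choice, and the expected dataset column names (`question`, `answer`, `ground_truth`) depend on the ragas version pinned in `requirements.txt`, so adjust as needed.
+```
+import os
+
+from datasets import Dataset
+from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
+from ragas import evaluate
+from ragas.metrics import answer_correctness
+
+# LLM judge served by launch_llm_judge_endpoint.sh on port 8085
+judge_llm = HuggingFaceEndpoint(
+    endpoint_url=f"http://{os.environ.get('host_ip', 'localhost')}:8085",
+    max_new_tokens=1024,
+    huggingfacehub_api_token=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
+)
+embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")  # example embedding model
+
+sample = Dataset.from_dict(
+    {
+        "question": ["who composed the four seasons?"],
+        "answer": ["The Four Seasons was composed by Antonio Vivaldi."],
+        "ground_truth": ["Antonio Vivaldi"],
+    }
+)
+print(evaluate(sample, metrics=[answer_correctness], llm=judge_llm, embeddings=embeddings))
+```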
+
+
+## Benchmark results for OPEA RAG Agents
+We evaluated the agents (the `rag_agent_llama` and `react_llama` strategies) in the OPEA AgentQnA example on the CRAG music domain dataset (373 questions in total). We used `meta-llama/Meta-Llama-3.1-70B-Instruct` and served the LLM with tgi-gaudi on 4 Intel Gaudi2 accelerator cards. Refer to the docker compose yaml files in the AgentQnA example for more details on the configurations.
+
+For the conventional RAG tests, we used the `run_conv_rag.sh` script in the `run_benchmark` folder, with the same LLM, serving configs, and generation parameters as the RAG agent.
+
+The conventional RAG setup and the single RAG agent use the same retriever. The hierarchical ReAct agent uses the single RAG agent as its retrieval tool and also has access to the CRAG APIs provided by Meta as part of the CRAG benchmark.
+
+
+|Setup |Mean RAGAS `answer_correctness` score|
+|----------------|------------------------------|
+|Conventional RAG|0.42|
+|Single RAG agent|0.43|
+|Hierarchical ReAct agent|0.53|
+
+From the results, we can see that the single RAG agent performs better than conventional RAG, while the hierarchical ReAct agent has the highest `answer_correctness` score. The reasons for these performance improvements are:
+1. The RAG agent rewrites the query and checks the quality of the retrieved documents before feeding them to generation, so the documents used to generate answers are more relevant. It can also decompose complex questions into modular tasks, retrieve related documents for each task, and then aggregate the information to come up with answers.
+2. The hierarchical ReAct agent was supplied with APIs to get information from knowledge graphs, which supplement the knowledge in the retrieval vector database. It can therefore answer questions that conventional RAG or the single RAG agent cannot, due to the lack of relevant information in the vector database.
+
+Note: The result for the hierarchical ReAct agent was obtained with tool selection, i.e., only a subset of tools is given to the agent based on the query, which we found can boost agent performance when the number of tools is large. However, OPEA agents do not support tool selection yet; we are in the process of enabling it.
+
+### Comparison with GPT-4o-mini
+Open-source LLM serving libraries (tgi and vllm) have limited capabilities in producing tool-call objects. Although vllm recently improved its tool-calling capabilities, parallel tool calling is still not well supported. Therefore, we wrote our own prompts and output parsers for the `rag_agent_llama` and `react_llama` strategies, so that OPEA agent microservices can use open-source LLMs served with open-source serving frameworks.
+
+Below we compare `meta-llama/Meta-Llama-3.1-70B-Instruct` with OpenAI's `gpt-4o-mini-2024-07-18` on 20 sampled queries from the CRAG music domain dataset, using the human evaluation criteria outlined above. The numbers are the average scores graded by humans. The parentheses denote the OPEA agent strategy used.
+
+|Setup|Llama3.1-70B-Instruct|gpt-4o-mini|
+|-----|---------------------|-----------|
+|Conventional RAG|0.15|0.05|
+|Single RAG agent|0.45 (`rag_agent_llama`)|0.65 (`rag_agent`)|
+|Hierarchical ReAct agent|0.55 (`react_llama`)|0.75 (`react_langgraph`)|
+
+From the comparisons on this small subset, we can see that OPEA agents using `meta-llama/Meta-Llama-3.1-70B-Instruct` with calibrated prompt templates and output parsers are only slightly behind `gpt-4o-mini-2024-07-18` with its proprietary tool-calling capabilities.
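+
+The mean RAGAS `answer_correctness` scores reported in this README can be computed from the graded CSV files that `run_grading.sh` produces (each gets an `answer_correctness` column from `grade_answers.py`). A minimal sketch, assuming the example file names used in `run_compare_scores.sh`:
+```
+import os
+
+import pandas as pd
+
+filedir = os.path.join(os.environ["WORKDIR"], "datasets/crag_results")
+graded_files = {
+    "Conventional RAG": "conv_rag_graded.csv",
+    "Single RAG agent": "ragagent_graded.csv",
+    "Hierarchical ReAct agent": "react_graded.csv",
+}
+for setup, fname in graded_files.items():
+    df = pd.read_csv(os.path.join(filedir, fname))
+    # pandas skips NaN scores (failed gradings) when computing the mean
+    print(f"{setup}: mean answer_correctness = {df['answer_correctness'].mean():.2f}")
+```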
diff --git a/evals/evaluation/agent_eval/crag_eval/docker/Dockerfile b/evals/evaluation/agent_eval/crag_eval/docker/Dockerfile
index 35f27199..0fc857fa 100644
--- a/evals/evaluation/agent_eval/crag_eval/docker/Dockerfile
+++ b/evals/evaluation/agent_eval/crag_eval/docker/Dockerfile
@@ -10,7 +10,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
     git \
     poppler-utils \
     libmkl-dev \
-    curl
+    curl \
+    nano

 COPY requirements.txt /home/user/requirements.txt
diff --git a/evals/evaluation/agent_eval/crag_eval/docker/build_image.sh b/evals/evaluation/agent_eval/crag_eval/docker/build_image.sh
index a743900f..33207c9e 100644
--- a/evals/evaluation/agent_eval/crag_eval/docker/build_image.sh
+++ b/evals/evaluation/agent_eval/crag_eval/docker/build_image.sh
@@ -4,8 +4,9 @@ dockerfile=Dockerfile

 docker build \
+    --no-cache \
     -f ${dockerfile} . \
-    -t crag-eval:latest \
+    -t crag-eval:v1.1 \
     --network=host \
     --build-arg http_proxy=${http_proxy} \
     --build-arg https_proxy=${https_proxy} \
diff --git a/evals/evaluation/agent_eval/crag_eval/docker/launch_eval_container.sh b/evals/evaluation/agent_eval/crag_eval/docker/launch_eval_container.sh
index 8698f452..cf25502f 100644
--- a/evals/evaluation/agent_eval/crag_eval/docker/launch_eval_container.sh
+++ b/evals/evaluation/agent_eval/crag_eval/docker/launch_eval_container.sh
@@ -4,4 +4,4 @@ volume=$WORKDIR
 host_ip=$(hostname -I | awk '{print $1}')

-docker run -it -v $volume:/home/user/ -e WORKDIR=/home/user -e HF_HOME=/home/user/hf_cache -e host_ip=$host_ip -e http_proxy=$http_proxy -e https_proxy=$https_proxy crag-eval:latest
+docker run -it --name crag_eval -v $volume:/home/user/ -e WORKDIR=/home/user -e HF_HOME=/home/user/hf_cache -e host_ip=$host_ip -e http_proxy=$http_proxy -e https_proxy=$https_proxy crag-eval:v1.1
diff --git a/evals/evaluation/agent_eval/crag_eval/docker/requirements.txt b/evals/evaluation/agent_eval/crag_eval/docker/requirements.txt
index b32606b7..a6d88f19 100644
--- a/evals/evaluation/agent_eval/crag_eval/docker/requirements.txt
+++ b/evals/evaluation/agent_eval/crag_eval/docker/requirements.txt
@@ -3,6 +3,8 @@ evaluate
 jieba
 langchain-community
 langchain-huggingface
+langchain-openai
+nltk
 pandas
 ragas
 sentence_transformers
diff --git a/evals/evaluation/agent_eval/crag_eval/run_benchmark/compare_scores.py b/evals/evaluation/agent_eval/crag_eval/run_benchmark/compare_scores.py
new file mode 100644
index 00000000..b91568f1
--- /dev/null
+++ b/evals/evaluation/agent_eval/crag_eval/run_benchmark/compare_scores.py
@@ -0,0 +1,90 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
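+#
+# This script merges the RAGAS-graded CSVs for the three setups with a human-scores CSV
+# and reports their Spearman and Pearson correlations. Expected inputs (see the README):
+#   - each *_graded.csv produced by grade_answers.py, with at least a "query" column and
+#     an "answer_correctness" column;
+#   - a human-scores CSV with "query", "conv_rag", "ragagent" and "reactagent" columns
+#     holding the 1 / 0 / -1 human scores.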
+
+import argparse
+
+import pandas as pd
+from scipy.stats import pearsonr, spearmanr
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--filedir", type=str, help="file directory")
+    parser.add_argument("--conv_rag", type=str, help="file with RAGAS scores for conventional RAG")
+    parser.add_argument("--ragagent", type=str, help="file with RAGAS scores for RAG agent")
+    parser.add_argument("--reactagent", type=str, help="file with RAGAS scores for React agent")
+    parser.add_argument("--human_scores_file", type=str, help="file with human scores for 3 setups")
+    return parser.parse_args()
+
+
+def merge_and_get_stats(filedir, conv_rag, ragagent, reactagent, prefix=""):
+    conv_rag_df = pd.read_csv(filedir + conv_rag)
+    ragagent_df = pd.read_csv(filedir + ragagent)
+    reactagent_df = pd.read_csv(filedir + reactagent)
+
+    conv_rag_df = conv_rag_df.rename(columns={"answer_correctness": "conv_rag_score"})
+    ragagent_df = ragagent_df.rename(columns={"answer_correctness": "ragagent_score"})
+    reactagent_df = reactagent_df.rename(columns={"answer_correctness": "reactagent_score"})
+    merged_df = pd.merge(conv_rag_df, ragagent_df, on="query")
+    merged_df = pd.merge(merged_df, reactagent_df, on="query")
+    print(merged_df.shape)
+    print(merged_df.describe())
+    merged_df.to_csv(filedir + prefix + "merged_scores.csv", index=False)
+
+    # drop rows with nan
+    merged_df_dropped = merged_df.dropna()
+    # merged_df = merged_df.reset_index(drop=True)
+    print(merged_df_dropped.shape)
+
+    # compare scores
+    print(merged_df_dropped.describe())
+    merged_df_dropped.to_csv(filedir + prefix + "merged_scores_nadropped.csv", index=False)
+    return merged_df, merged_df_dropped
+
+
+if __name__ == "__main__":
+    args = get_args()
+    filedir = args.filedir
+    conv_rag = args.conv_rag
+    ragagent = args.ragagent
+    reactagent = args.reactagent
+    human_scores_file = args.human_scores_file
+
+    # RAGAS scores
+    print("===============RAGAS scores==================")
+    merged_df, merged_df_dropped = merge_and_get_stats(filedir, conv_rag, ragagent, reactagent)
+
+    # human scores
+    print("===============Human scores==================")
+    human_scores_df = pd.read_csv(filedir + human_scores_file)
+    print(human_scores_df.describe())
+
+    human_scores_df_dropped = human_scores_df.loc[human_scores_df["query"].isin(merged_df_dropped["query"])]
+    print(human_scores_df_dropped.describe())
+    human_scores_df_dropped.to_csv(filedir + "human_scores_dropped.csv", index=False)
+
+    # concat conv_rag, ragagent, reactagent scores in merged_df_dropped
+    ragas_scores = pd.concat(
+        [
+            merged_df_dropped["conv_rag_score"],
+            merged_df_dropped["ragagent_score"],
+            merged_df_dropped["reactagent_score"],
+        ],
+        axis=0,
+    )
+    human_scores = pd.concat(
+        [
+            human_scores_df_dropped["conv_rag"],
+            human_scores_df_dropped["ragagent"],
+            human_scores_df_dropped["reactagent"],
+        ],
+        axis=0,
+    )
+
+    # calculate spearman correlation
+    print("===============Spearman correlation==================")
+    print(spearmanr(ragas_scores, human_scores))
+
+    # pearson correlation
+    print("===============Pearson correlation==================")
+    print(pearsonr(ragas_scores, human_scores))
diff --git a/evals/evaluation/agent_eval/crag_eval/run_benchmark/conventional_rag.py b/evals/evaluation/agent_eval/crag_eval/run_benchmark/conventional_rag.py
new file mode 100644
index 00000000..f6ee133c
--- /dev/null
+++ b/evals/evaluation/agent_eval/crag_eval/run_benchmark/conventional_rag.py
@@ -0,0 +1,158 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import argparse
+import json
+import os
+
+import pandas as pd
+import requests
+
+
+def get_test_dataset(args):
+    filepath = os.path.join(args.filedir, args.filename)
+    if filepath.endswith(".jsonl"):
+        df = pd.read_json(filepath, lines=True, convert_dates=False)
+    elif filepath.endswith(".csv"):
+        df = pd.read_csv(filepath)
+    else:
+        raise ValueError("Invalid file format")
+    return df
+
+
+def save_results(output_file, output_list):
+    with open(output_file, "w") as f:
+        for output in output_list:
+            f.write(json.dumps(output))
+            f.write("\n")
+
+
+def save_as_csv(output):
+    df = pd.read_json(output, lines=True, convert_dates=False)
+    df.to_csv(output.replace(".jsonl", ".csv"), index=False)
+    print(f"Saved to {output.replace('.jsonl', '.csv')}")
+
+
+def search_knowledge_base(query: str) -> str:
+    """Search the knowledge base for a specific query."""
+    url = os.environ.get("RETRIEVAL_TOOL_URL")
+    print(url)
+    proxies = {"http": ""}
+    payload = {
+        "text": query,
+    }
+    response = requests.post(url, json=payload, proxies=proxies)
+    print(response)
+    if "documents" in response.json():
+        docs = response.json()["documents"]
+        context = ""
+        for i, doc in enumerate(docs):
+            if i == 0:
+                context = doc
+            else:
+                context += "\n" + doc
+        # print(context)
+        return context
+    elif "text" in response.json():
+        return response.json()["text"]
+    elif "reranked_docs" in response.json():
+        docs = response.json()["reranked_docs"]
+        context = ""
+        for i, doc in enumerate(docs):
+            if i == 0:
+                context = doc["text"]
+            else:
+                context += "\n" + doc["text"]
+        # print(context)
+        return context
+    else:
+        return "Error parsing response from the knowledge base."
+
+
+PROMPT = """\
+### You are a helpful, respectful and honest assistant.
+You are given a Question and the time when it was asked in the Pacific Time Zone (PT), referred to as "Query
+Time". The query time is formatted as "mm/dd/yyyy, hh:mm:ss PT".
+Please follow these guidelines when formulating your answer:
+1. If the question contains a false premise or assumption, answer “invalid question”.
+2. If you are uncertain or do not know the answer, respond with “I don’t know”.
+3. Refer to the search results to form your answer.
+4. Give concise, factual and relevant answers.
+
+### Search results: {context} \n
+### Question: {question} \n
+### Query Time: {time} \n
+### Answer:
+"""
+
+
+def setup_chat_model(args):
+    from langchain_openai import ChatOpenAI
+
+    params = {
+        "temperature": args.temperature,
+        "max_tokens": args.max_new_tokens,
+        "top_p": args.top_p,
+        "streaming": False,
+    }
+    openai_endpoint = f"{args.llm_endpoint_url}/v1"
+    llm = ChatOpenAI(
+        openai_api_key="EMPTY",
+        openai_api_base=openai_endpoint,
+        model_name=args.model,
+        **params,
+    )
+    return llm
+
+
+def generate_answer(llm, query, context, time):
+    prompt = PROMPT.format(context=context, question=query, time=time)
+    response = llm.invoke(prompt)
+    return response.content
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--filedir", type=str, default="./", help="test file directory")
+    parser.add_argument("--filename", type=str, default="query.csv", help="query_list_file")
+    parser.add_argument("--output", type=str, default="output.csv", help="output file")
+    parser.add_argument("--llm_endpoint_url", type=str, default="http://localhost:8085", help="llm endpoint url")
+    parser.add_argument("--model", type=str, default="meta-llama/Meta-Llama-3.1-70B-Instruct", help="model name")
+    parser.add_argument("--temperature", type=float, default=0.01, help="temperature")
+    parser.add_argument("--max_new_tokens", type=int, default=8192, help="max_new_tokens")
+    parser.add_argument("--top_p", type=float, default=0.95, help="top_p")
+    args = parser.parse_args()
+    print(args)
+
+    df = get_test_dataset(args)
+    print(df.shape)
+
+    if not os.path.exists(os.path.dirname(args.output)):
+        os.makedirs(os.path.dirname(args.output))
+
+    llm = setup_chat_model(args)
+
+    contexts = []
+    output_list = []
+    for _, row in df.iterrows():
+        q = row["query"]
+        t = row["query_time"]
+        print("========== Query: ", q)
+        context = search_knowledge_base(q)
+        print("========== Context:\n", context)
+        answer = generate_answer(llm, q, context, t)
+        print("========== Answer:\n", answer)
+        contexts.append(context)
+        output_list.append(
+            {
+                "query": q,
+                "query_time": t,
+                "ref_answer": row["answer"],
+                "answer": answer,
+                "question_type": row["question_type"],
+                "static_or_dynamic": row["static_or_dynamic"],
+            }
+        )
+        save_results(args.output, output_list)
+
+    save_as_csv(args.output)
diff --git a/evals/evaluation/agent_eval/crag_eval/run_benchmark/grade_answers.py b/evals/evaluation/agent_eval/crag_eval/run_benchmark/grade_answers.py
index 8f95d497..5c826b5f 100644
--- a/evals/evaluation/agent_eval/crag_eval/run_benchmark/grade_answers.py
+++ b/evals/evaluation/agent_eval/crag_eval/run_benchmark/grade_answers.py
@@ -50,8 +50,8 @@ def grade_answers(args, test_case):
     scores = []
     for case in test_case:
         metric.measure(case)
-        scores.append(metric.score["answer_correctness"])
-        print(metric.score)
+        scores.append(metric.score["answer_correctness"][0])
+        print(metric.score["answer_correctness"][0])
         print("-" * 50)

     return scores
@@ -79,13 +79,15 @@ def grade_answers(args, test_case):
     # print(test_case)
     scores = grade_answers(args, test_case)
+    # print(scores)

     # save the scores
     if args.batch_grade:
         print("Aggregated answer correctness score: ", scores)
     else:
         data["answer_correctness"] = scores
-        print("Average answer correctness score: ", data["answer_correctness"].mean())
-        output_file = args.filename.split(".")[0] + "_graded.csv"
+        output_file = args.filename.replace(".csv", "_graded.csv")
         data.to_csv(os.path.join(args.filedir, output_file), index=False)
         print("Scores saved to ", os.path.join(args.filedir, output_file))
+
+        print("Average answer correctness score: ", data["answer_correctness"].mean())
diff --git a/evals/evaluation/agent_eval/crag_eval/run_benchmark/llm_judge/docker-compose-llm-judge-gaudi.yaml b/evals/evaluation/agent_eval/crag_eval/run_benchmark/llm_judge/docker-compose-llm-judge-gaudi.yaml
index 572011ef..1ba0962a 100644
--- a/evals/evaluation/agent_eval/crag_eval/run_benchmark/llm_judge/docker-compose-llm-judge-gaudi.yaml
+++ b/evals/evaluation/agent_eval/crag_eval/run_benchmark/llm_judge/docker-compose-llm-judge-gaudi.yaml
@@ -3,7 +3,7 @@
 services:
   tgi-service:
-    image: ghcr.io/huggingface/tgi-gaudi:latest
+    image: ghcr.io/huggingface/tgi-gaudi:2.0.5
     container_name: tgi-server
     ports:
       - "8085:80"
diff --git a/evals/evaluation/agent_eval/crag_eval/run_benchmark/llm_judge/launch_llm_judge_endpoint.sh b/evals/evaluation/agent_eval/crag_eval/run_benchmark/llm_judge/launch_llm_judge_endpoint.sh
index 0cb08d8f..1a57cd56 100644
--- a/evals/evaluation/agent_eval/crag_eval/run_benchmark/llm_judge/launch_llm_judge_endpoint.sh
+++ b/evals/evaluation/agent_eval/crag_eval/run_benchmark/llm_judge/launch_llm_judge_endpoint.sh
@@ -1,7 +1,7 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

-export LLM_MODEL_ID="meta-llama/Meta-Llama-3-70B-Instruct"
+export LLM_MODEL_ID="meta-llama/Meta-Llama-3.1-70B-Instruct"
 export HUGGINGFACEHUB_API_TOKEN=${HUGGINGFACEHUB_API_TOKEN}
 export HF_CACHE_DIR=${HF_CACHE_DIR}
 docker compose -f docker-compose-llm-judge-gaudi.yaml up -d
diff --git a/evals/evaluation/agent_eval/crag_eval/run_benchmark/run_compare_scores.sh b/evals/evaluation/agent_eval/crag_eval/run_benchmark/run_compare_scores.sh
new file mode 100644
index 00000000..38eea97d
--- /dev/null
+++ b/evals/evaluation/agent_eval/crag_eval/run_benchmark/run_compare_scores.sh
@@ -0,0 +1,15 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+filedir=$WORKDIR/datasets/crag_results/
+conv_rag="conv_rag_graded.csv" # replace with your file name
+ragagent="ragagent_graded.csv" # replace with your file name
+reactagent="react_graded.csv" # replace with your file name
+human_scores_file="human_scores.csv" # replace with your file name
+
+python3 compare_scores.py \
+--filedir $filedir \
+--conv_rag $conv_rag \
+--ragagent $ragagent \
+--reactagent $reactagent \
+--human_scores_file $human_scores_file
diff --git a/evals/evaluation/agent_eval/crag_eval/run_benchmark/run_conv_rag.sh b/evals/evaluation/agent_eval/crag_eval/run_benchmark/run_conv_rag.sh
new file mode 100644
index 00000000..b5a7766a
--- /dev/null
+++ b/evals/evaluation/agent_eval/crag_eval/run_benchmark/run_conv_rag.sh
@@ -0,0 +1,18 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+MODEL="meta-llama/Meta-Llama-3.1-70B-Instruct"
+LLMENDPOINT=http://${host_ip}:8085
+
+FILEDIR=$WORKDIR/datasets/crag_qas/
+FILENAME=crag_qa_music.jsonl
+OUTPUT=$WORKDIR/datasets/crag_results/conv_rag_music.jsonl
+
+export RETRIEVAL_TOOL_URL="http://${host_ip}:8889/v1/retrievaltool"
+
+python3 conventional_rag.py \
+--model ${MODEL} \
+--llm_endpoint_url ${LLMENDPOINT} \
+--filedir ${FILEDIR} \
+--filename ${FILENAME} \
+--output ${OUTPUT}
diff --git a/evals/evaluation/agent_eval/crag_eval/run_benchmark/run_generate_answer.sh b/evals/evaluation/agent_eval/crag_eval/run_benchmark/run_generate_answer.sh
index ee863bba..20e578a6 100644
--- a/evals/evaluation/agent_eval/crag_eval/run_benchmark/run_generate_answer.sh
+++ b/evals/evaluation/agent_eval/crag_eval/run_benchmark/run_generate_answer.sh
@@ -7,8 +7,8 @@ endpoint=${port}/v1/chat/completions # change this to the endpoint of the agent
 URL="http://${host_ip}:${endpoint}"
 echo "AGENT ENDPOINT URL: ${URL}"

-QUERYFILE=$WORKDIR/datasets/crag_qas/crag_qa_music_sampled.jsonl
-OUTPUTFILE=$WORKDIR/datasets/crag_results/crag_music_sampled_results.jsonl
+QUERYFILE=$WORKDIR/datasets/crag_qas/crag_qa_music.jsonl
+OUTPUTFILE=$WORKDIR/datasets/crag_results/ragagent_crag_music_results.jsonl

 python3 generate_answers.py \
 --endpoint_url ${URL} \
diff --git a/evals/evaluation/agent_eval/crag_eval/run_benchmark/run_grading.sh b/evals/evaluation/agent_eval/crag_eval/run_benchmark/run_grading.sh
index 5431d39b..b9af1a18 100644
--- a/evals/evaluation/agent_eval/crag_eval/run_benchmark/run_grading.sh
+++ b/evals/evaluation/agent_eval/crag_eval/run_benchmark/run_grading.sh
@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0

 FILEDIR=$WORKDIR/datasets/crag_results/
-FILENAME=crag_music_sampled_results.csv
+FILENAME=ragagent_crag_music_results.csv
 LLM_ENDPOINT=http://${host_ip}:8085 # change host_ip to the IP of LLM endpoint

 python3 grade_answers.py \
diff --git a/evals/metrics/ragas/ragas.py b/evals/metrics/ragas/ragas.py
index c31cb632..8b98b60b 100644
--- a/evals/metrics/ragas/ragas.py
+++ b/evals/metrics/ragas/ragas.py
@@ -12,7 +12,42 @@ from langchain_huggingface import HuggingFaceEndpoint

 # import * is only allowed at module level according to python syntax
-from ragas.metrics import *
+try:
+    # from ragas.metrics import *
+    from ragas import evaluate
+    from ragas.metrics import (
+        answer_correctness,
+        answer_relevancy,
+        answer_similarity,
+        context_precision,
+        context_recall,
+        faithfulness,
+    )
+except ModuleNotFoundError:
+    raise ModuleNotFoundError("Please install ragas to use this metric. `pip install ragas`.")
+
+try:
+    from datasets import Dataset
+except ModuleNotFoundError:
+    raise ModuleNotFoundError("Please install datasets. `pip install datasets`.")
+
+VALIDATED_LIST = [
+    "answer_correctness",
+    "answer_relevancy",
+    "answer_similarity",
+    "context_precision",
+    "context_recall",
+    "faithfulness",
+]
+
+metrics_mapping = {
+    "answer_correctness": answer_correctness,
+    "answer_relevancy": answer_relevancy,
+    "answer_similarity": answer_similarity,
+    "context_precision": context_precision,
+    "context_recall": context_recall,
+    "faithfulness": faithfulness,
+}


 def format_ragas_metric_name(name: str):
@@ -34,64 +69,6 @@ def __init__(
         self.embeddings = embeddings
         self.metrics = metrics

-        # self.validated_list = [
-        #     "answer_correctness",
-        #     "answer_relevancy",
-        #     "answer_similarity",
-        #     "context_precision",
-        #     "context_recall",
-        #     "faithfulness",
-        #     "context_utilization",
-        #     # "reference_free_rubrics_score",
-        # ]
-
-    async def a_measure(self, test_case: Dict):
-        return self.measure(test_case)
-
-    def measure(self, test_case: Dict):
-        # sends to server
-        try:
-            from ragas import evaluate
-            from ragas.metrics import ALL_METRICS
-
-            self.metric_names = [metric.__class__.__name__ for metric in ALL_METRICS]
-            self.metric_names = [re.sub(r"(?