diff --git a/evals/evaluation/rag_eval/examples/eval_crud.py b/evals/evaluation/rag_eval/examples/eval_crud.py index 4a4ac8e6..1cb3b247 100644 --- a/evals/evaluation/rag_eval/examples/eval_crud.py +++ b/evals/evaluation/rag_eval/examples/eval_crud.py @@ -8,8 +8,11 @@ import json import os +from tqdm import tqdm + from evals.evaluation.rag_eval import Evaluator from evals.evaluation.rag_eval.template import CRUDTemplate +from evals.metrics.ragas import RagasMetric class CRUD_Evaluator(Evaluator): @@ -78,6 +81,45 @@ def get_template(self): def post_process(self, result): return result.split("")[-1].split("")[0].strip() + def get_ragas_metrics(self, results, arguments): + from langchain_huggingface import HuggingFaceEndpointEmbeddings + + embeddings = HuggingFaceEndpointEmbeddings(model=arguments.tei_embedding_endpoint) + + metric = RagasMetric( + threshold=0.5, + model=arguments.llm_endpoint, + embeddings=embeddings, + metrics=["faithfulness", "answer_relevancy"], + ) + + all_answer_relevancy = 0 + all_faithfulness = 0 + ragas_inputs = { + "question": [], + "answer": [], + "ground_truth": [], + "contexts": [], + } + + valid_results = self.remove_invalid(results["results"]) + + for data in tqdm(valid_results): + data = data["original_data"] + + query = self.get_query(data) + generated_text = data["generated_text"] + ground_truth = data["ground_truth_text"] + retrieved_documents = data["retrieved_documents"] + + ragas_inputs["question"].append(query) + ragas_inputs["answer"].append(generated_text) + ragas_inputs["ground_truth"].append(ground_truth) + ragas_inputs["contexts"].append(retrieved_documents[:3]) + + ragas_metrics = metric.measure(ragas_inputs) + return ragas_metrics + def args_parser(): parser = argparse.ArgumentParser() @@ -116,6 +158,13 @@ def args_parser(): parser.add_argument( "--retrieval_endpoint", type=str, default="http://localhost:7000/v1/retrieval", help="Service URL address." ) + parser.add_argument( + "--tei_embedding_endpoint", + type=str, + default="http://localhost:8090", + help="Service URL address of tei embedding.", + ) + parser.add_argument("--ragas_metrics", action="store_true", help="Whether to compute ragas metrics.") parser.add_argument("--llm_endpoint", type=str, default=None, help="Service URL address.") parser.add_argument( "--show_progress_bar", action="store", default=True, type=bool, help="Whether to show a progress bar" @@ -145,14 +194,16 @@ def main(): "summarization, question_answering, continuation and hallucinated_modified." ) output_save_path = os.path.join(args.output_dir, f"{task}.json") - evaluator = CRUD_Evaluator( - dataset=dataset, output_path=output_save_path, task=task, llm_endpoint=args.llm_endpoint - ) + evaluator = CRUD_Evaluator(dataset=dataset, output_path=output_save_path, task=task) if args.ingest_docs: CRUD_Evaluator.ingest_docs(args.docs_path, args.database_endpoint, args.chunk_size, args.chunk_overlap) results = evaluator.evaluate( args, show_progress_bar=args.show_progress_bar, contain_original_data=args.contain_original_data ) + print(results["overall"]) + if args.ragas_metrics: + ragas_metrics = evaluator.get_ragas_metrics(results, args) + print(ragas_metrics) print(f"Evaluation results of task {task} saved to {output_save_path}.")