Skip to content

Commit

Permalink
Merge branch 'main' into auto_eval
Browse files Browse the repository at this point in the history
  • Loading branch information
lvliang-intel authored Sep 27, 2024
2 parents f134058 + f2bff45 commit ccf74a7
Showing 1 changed file with 54 additions and 3 deletions.
57 changes: 54 additions & 3 deletions evals/evaluation/rag_eval/examples/eval_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
import json
import os

from tqdm import tqdm

from evals.evaluation.rag_eval import Evaluator
from evals.evaluation.rag_eval.template import CRUDTemplate
from evals.metrics.ragas import RagasMetric


class CRUD_Evaluator(Evaluator):
Expand Down Expand Up @@ -78,6 +81,45 @@ def get_template(self):
def post_process(self, result):
return result.split("<response>")[-1].split("</response>")[0].strip()

def get_ragas_metrics(self, results, arguments):
from langchain_huggingface import HuggingFaceEndpointEmbeddings

embeddings = HuggingFaceEndpointEmbeddings(model=arguments.tei_embedding_endpoint)

metric = RagasMetric(
threshold=0.5,
model=arguments.llm_endpoint,
embeddings=embeddings,
metrics=["faithfulness", "answer_relevancy"],
)

all_answer_relevancy = 0
all_faithfulness = 0
ragas_inputs = {
"question": [],
"answer": [],
"ground_truth": [],
"contexts": [],
}

valid_results = self.remove_invalid(results["results"])

for data in tqdm(valid_results):
data = data["original_data"]

query = self.get_query(data)
generated_text = data["generated_text"]
ground_truth = data["ground_truth_text"]
retrieved_documents = data["retrieved_documents"]

ragas_inputs["question"].append(query)
ragas_inputs["answer"].append(generated_text)
ragas_inputs["ground_truth"].append(ground_truth)
ragas_inputs["contexts"].append(retrieved_documents[:3])

ragas_metrics = metric.measure(ragas_inputs)
return ragas_metrics


def args_parser():
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -116,6 +158,13 @@ def args_parser():
parser.add_argument(
"--retrieval_endpoint", type=str, default="http://localhost:7000/v1/retrieval", help="Service URL address."
)
parser.add_argument(
"--tei_embedding_endpoint",
type=str,
default="http://localhost:8090",
help="Service URL address of tei embedding.",
)
parser.add_argument("--ragas_metrics", action="store_true", help="Whether to compute ragas metrics.")
parser.add_argument("--llm_endpoint", type=str, default=None, help="Service URL address.")
parser.add_argument(
"--show_progress_bar", action="store", default=True, type=bool, help="Whether to show a progress bar"
Expand Down Expand Up @@ -145,14 +194,16 @@ def main():
"summarization, question_answering, continuation and hallucinated_modified."
)
output_save_path = os.path.join(args.output_dir, f"{task}.json")
evaluator = CRUD_Evaluator(
dataset=dataset, output_path=output_save_path, task=task, llm_endpoint=args.llm_endpoint
)
evaluator = CRUD_Evaluator(dataset=dataset, output_path=output_save_path, task=task)
if args.ingest_docs:
CRUD_Evaluator.ingest_docs(args.docs_path, args.database_endpoint, args.chunk_size, args.chunk_overlap)
results = evaluator.evaluate(
args, show_progress_bar=args.show_progress_bar, contain_original_data=args.contain_original_data
)
print(results["overall"])
if args.ragas_metrics:
ragas_metrics = evaluator.get_ragas_metrics(results, args)
print(ragas_metrics)
print(f"Evaluation results of task {task} saved to {output_save_path}.")


Expand Down

0 comments on commit ccf74a7

Please sign in to comment.