refactor: split evaluation and aggregation
Task: IL-259
Valentina Galata committed Feb 21, 2024
1 parent fb2f422 commit 47fec3d
Showing 19 changed files with 775 additions and 379 deletions.
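
At a glance, the commit splits the single Evaluator.eval_and_aggregate_runs step into two: Evaluator.evaluate_runs produces an evaluation overview, and a new Aggregator.aggregate_evaluation turns that overview into an aggregation overview. The following is a minimal before/after sketch pieced together from the notebook diffs below; it assumes the repositories, task, logic objects and runner are already set up as in evaluation.ipynb, and the import path for Evaluator and Aggregator is inferred from the human_evaluation.ipynb imports, so treat it as illustrative rather than authoritative.

# Import path inferred from human_evaluation.ipynb below; adjust if the package layout differs.
from intelligence_layer.evaluation import Aggregator, Evaluator

# Before this commit (removed lines): one object performed both steps.
# aggregation_overview = evaluator.eval_and_aggregate_runs(run_overview.id)

# After this commit: evaluation and aggregation are separate objects and separate calls.
evaluator = Evaluator(
    dataset_repository,
    run_repository,
    evaluation_repository,
    "single-label-classify",
    evaluation_logic,
)
aggregator = Aggregator(
    evaluation_repository,
    aggregation_repository,
    "single-label-classify",
    aggregation_logic,
)

run_overview = runner.run_dataset(dataset_id)
evaluation_overview = evaluator.evaluate_runs(run_overview.id)
aggregation_overview = aggregator.aggregate_evaluation(evaluation_overview.id)
print("Statistics: ", aggregation_overview.statistics)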
8 changes: 4 additions & 4 deletions src/examples/classification.ipynb
@@ -43,13 +43,13 @@
"outputs": [],
"source": [
"from os import getenv\n",
"from intelligence_layer.connectors import LimitedConcurrencyClient\n",
"\n",
"from intelligence_layer.use_cases import ClassifyInput, PromptBasedClassify\n",
"from intelligence_layer.core import Chunk, InMemoryTracer\n",
"\n",
"from dotenv import load_dotenv\n",
"\n",
"from intelligence_layer.connectors import LimitedConcurrencyClient\n",
"from intelligence_layer.core import Chunk, InMemoryTracer\n",
"from intelligence_layer.use_cases import ClassifyInput, PromptBasedClassify\n",
"\n",
"load_dotenv()\n",
"\n",
"text_to_classify = Chunk(\"In the distant future, a space exploration party embarked on a thrilling journey to the uncharted regions of the galaxy. \\n\\\n",
4 changes: 2 additions & 2 deletions src/examples/document_index.ipynb
@@ -47,10 +47,10 @@
"source": [
"from os import getenv\n",
"\n",
"from intelligence_layer.connectors import DocumentIndexClient\n",
"\n",
"from dotenv import load_dotenv\n",
"\n",
"from intelligence_layer.connectors import DocumentIndexClient\n",
"\n",
"load_dotenv()\n",
"\n",
"\n",
37 changes: 23 additions & 14 deletions src/examples/evaluation.ipynb
@@ -50,15 +50,14 @@
" InMemoryRunRepository,\n",
" InMemoryDatasetRepository,\n",
" InMemoryAggregationRepository,\n",
" Runner,\n",
" Runner, Aggregator,\n",
")\n",
"from intelligence_layer.use_cases import (\n",
" PromptBasedClassify,\n",
" SingleLabelClassifyEvaluationLogic,\n",
" SingleLabelClassifyAggregationLogic,\n",
")\n",
"\n",
"\n",
"load_dotenv()\n",
"\n",
"client = LimitedConcurrencyClient.from_token(os.getenv(\"AA_TOKEN\"))\n",
@@ -75,9 +74,13 @@
" dataset_repository,\n",
" run_repository,\n",
" evaluation_repository,\n",
" aggregation_repository,\n",
" \"singel-label-classify\",\n",
" \"single-label-classify\",\n",
" evaluation_logic,\n",
")\n",
"aggregator = Aggregator(\n",
" evaluation_repository,\n",
" aggregation_repository,\n",
" \"single-label-classify\",\n",
" aggregation_logic,\n",
")\n",
"runner = Runner(task, dataset_repository, run_repository, \"prompt-based-classify\")"
@@ -114,10 +117,10 @@
"])\n",
"\n",
"run_overview = runner.run_dataset(single_example_dataset, NoOpTracer())\n",
"aggregation_overview = evaluator.eval_and_aggregate_runs(run_overview.id)\n",
"evaluation_overview = evaluator.evaluate_runs(run_overview.id)\n",
"aggregation_overview = aggregator.aggregate_evaluation(evaluation_overview.id)\n",
"\n",
"print(\"Statistics: \", aggregation_overview.statistics)\n",
"\n"
"print(\"Statistics: \", aggregation_overview.statistics)"
]
},
{
@@ -140,7 +143,7 @@
"dataset = load_dataset(\"cardiffnlp/tweet_topic_multi\")\n",
"test_set_name = \"validation_random\"\n",
"all_data = list(dataset[test_set_name])\n",
"data = all_data[:25] # this has 573 datapoints, let's take a look at 25 for now\n"
"data = all_data[:25] # this has 573 datapoints, let's take a look at 25 for now"
]
},
{
@@ -157,7 +160,7 @@
"metadata": {},
"outputs": [],
"source": [
"data[1]\n"
"data[1]"
]
},
{
@@ -215,7 +218,8 @@
"outputs": [],
"source": [
"run_overview = runner.run_dataset(dataset_id)\n",
"aggregation_overview = evaluator.eval_and_aggregate_runs(run_overview.id)\n",
"evaluation_overview = evaluator.evaluate_runs(run_overview.id)\n",
"aggregation_overview = aggregator.aggregate_evaluation(evaluation_overview.id)\n",
"aggregation_overview.raise_on_evaluation_failure()"
]
},
@@ -288,9 +292,13 @@
" dataset_repository,\n",
" run_repository,\n",
" evaluation_repository,\n",
" aggregation_repository,\n",
" \"multi-label-classify\",\n",
" eval_logic,\n",
")\n",
"embedding_based_classify_aggregator = Aggregator(\n",
" evaluation_repository,\n",
" aggregation_repository,\n",
" \"multi-label-classify\",\n",
" aggregation_logic,\n",
")\n",
"embedding_based_classify_runner = Runner(\n",
@@ -308,8 +316,9 @@
"outputs": [],
"source": [
"embedding_based_classify_run_result = embedding_based_classify_runner.run_dataset(dataset_id)\n",
"embedding_based_classify_evaluation_result = embedding_based_classify_evaluator.eval_and_aggregate_runs(embedding_based_classify_run_result.id)\n",
"embedding_based_classify_evaluation_result.raise_on_evaluation_failure()"
"embedding_based_classify_evaluation_result = embedding_based_classify_evaluator.evaluate_runs(embedding_based_classify_run_result.id)\n",
"embedding_based_classify_aggregation_result = embedding_based_classify_aggregator.aggregate_evaluation(embedding_based_classify_evaluation_result.id)\n",
"embedding_based_classify_aggregation_result.raise_on_evaluation_failure()"
]
},
{
@@ -318,7 +327,7 @@
"metadata": {},
"outputs": [],
"source": [
"embedding_based_classify_evaluation_result.statistics.macro_avg"
"embedding_based_classify_aggregation_result.statistics.macro_avg"
]
},
{
54 changes: 32 additions & 22 deletions src/examples/human_evaluation.ipynb
@@ -42,20 +42,26 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from typing import Iterable, cast\n",
"\n",
"from datasets import load_dataset\n",
"from dotenv import load_dotenv\n",
"from intelligence_layer.core import (\n",
" InstructInput, \n",
" Instruct, \n",
" PromptOutput\n",
")\n",
"from pydantic import BaseModel\n",
"\n",
"from intelligence_layer.connectors import (\n",
" LimitedConcurrencyClient, \n",
" Question, \n",
" ArgillaEvaluation, \n",
" DefaultArgillaClient, \n",
" Field, \n",
" LimitedConcurrencyClient,\n",
" Question,\n",
" ArgillaEvaluation,\n",
" DefaultArgillaClient,\n",
" Field,\n",
" RecordData\n",
")\n",
"from intelligence_layer.core import (\n",
" InstructInput,\n",
" Instruct,\n",
" PromptOutput\n",
")\n",
"from intelligence_layer.evaluation import (\n",
" ArgillaEvaluator,\n",
" AggregationLogic,\n",
@@ -70,10 +76,7 @@
" Runner,\n",
" SuccessfulExampleOutput\n",
")\n",
"from typing import Iterable, cast, Sequence\n",
"from datasets import load_dataset\n",
"import os\n",
"from pydantic import BaseModel\n",
"from intelligence_layer.evaluation.argilla import ArgillaAggregator\n",
"\n",
"load_dotenv()\n",
"\n",
@@ -318,14 +321,14 @@
" def _to_record(\n",
" self,\n",
" example: Example[InstructInput, None],\n",
" example_outputs: SuccessfulExampleOutput[PromptOutput],\n",
" *example_outputs: SuccessfulExampleOutput[PromptOutput],\n",
" ) -> RecordDataSequence:\n",
" return RecordDataSequence(\n",
" records=[\n",
" RecordData(\n",
" content={\n",
" \"input\": example.input.instruction,\n",
" \"output\": example_outputs.output.completion,\n",
" \"output\": example_outputs[0].output.completion,\n",
" },\n",
" example_id=example.id,\n",
" )\n",
@@ -340,16 +343,16 @@
"eval_logic = InstructArgillaEvaluationLogic()\n",
"aggregation_logic = InstructArgillaAggregationLogic()\n",
"\n",
"argilla_evaluation_repository = ArgillaEvaluationRepository(\n",
" evaluation_repository, argilla_client, workspace_id, fields, questions\n",
")\n",
"\n",
"evaluator = ArgillaEvaluator(\n",
" dataset_repository,\n",
" run_repository,\n",
" ArgillaEvaluationRepository(\n",
" evaluation_repository, argilla_client, workspace_id, fields, questions\n",
" ),\n",
" aggregation_repository,\n",
" argilla_evaluation_repository,\n",
" \"instruct\",\n",
" eval_logic,\n",
" aggregation_logic,\n",
")"
]
},
@@ -388,8 +391,15 @@
"metadata": {},
"outputs": [],
"source": [
"aggregator = ArgillaAggregator(\n",
" argilla_evaluation_repository,\n",
" aggregation_repository,\n",
" \"instruct\",\n",
" aggregation_logic,\n",
")\n",
"\n",
"if eval_overview:\n",
" output = evaluator.aggregate_evaluation(eval_overview.id)\n",
" output = aggregator.aggregate_evaluation(eval_overview.id)\n",
" print(output.statistics)"
]
}
6 changes: 4 additions & 2 deletions src/examples/performance_tips.ipynb
@@ -44,11 +44,13 @@
"metadata": {},
"outputs": [],
"source": [
"from intelligence_layer.core.task import Task\n",
"from intelligence_layer.core.tracer import TaskSpan, NoOpTracer\n",
"import time\n",
"from typing import Any\n",
"\n",
"from intelligence_layer.core.task import Task\n",
"from intelligence_layer.core.tracer import TaskSpan, NoOpTracer\n",
"\n",
"\n",
"class DummyTask(Task):\n",
" def do_run(self, input: Any, task_span: TaskSpan) -> Any:\n",
" time.sleep(2)\n",
2 changes: 1 addition & 1 deletion src/examples/qa.ipynb
@@ -19,8 +19,8 @@
"outputs": [],
"source": [
"from os import getenv\n",
"from dotenv import load_dotenv\n",
"\n",
"from dotenv import load_dotenv\n",
"\n",
"load_dotenv()\n",
"from intelligence_layer.connectors import LimitedConcurrencyClient\n",
(Remaining 13 changed files not shown.)