Skip to content

Commit

Permalink
refactor: Refactor test for test_incremental_evaluator_should_filter_…
Browse files Browse the repository at this point in the history
…previous_run_ids

Task: IL-315
  • Loading branch information
SebastianNiehusAA committed May 2, 2024
1 parent 151a331 commit 013ccaf
Showing 1 changed file with 31 additions and 46 deletions.
77 changes: 31 additions & 46 deletions tests/evaluation/test_diff_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,77 +56,62 @@ def do_run(self, input: str, tracer: Tracer) -> str:
def test_incremental_evaluator_should_filter_previous_run_ids() -> None:
# Given
examples = [Example(input="a", expected_output="0", id="id_0")]

dataset_repository = InMemoryDatasetRepository()
run_repository = InMemoryRunRepository()
evaluation_repository = InMemoryEvaluationRepository()
dataset = dataset_repository.create_dataset(
examples=examples, dataset_name="test_examples"
)

run_repository = InMemoryRunRepository()
first_runner = Runner(
task=DummyTask("Task1"),
dataset_repository=dataset_repository,
run_repository=run_repository,
description="test_runner_1",
)
first_run = first_runner.run_dataset(dataset.id)

evaluation_repository = InMemoryEvaluationRepository()
evaluator = IncrementalEvaluator(
dataset_repository=dataset_repository,
run_repository=run_repository,
evaluation_repository=evaluation_repository,
description="test_incremental_evaluator",
incremental_evaluation_logic=DummyIncrementalLogic(),
)
first_evaluation_overview = evaluator.evaluate_additional_runs(first_run.id)

second_runner = Runner(
task=DummyTask("Task2"),
dataset_repository=dataset_repository,
run_repository=run_repository,
description="test_runner_2",
)
second_run = second_runner.run_dataset(dataset.id)
def create_run(name: str) -> str:
runner = Runner(
task=DummyTask(name),
dataset_repository=dataset_repository,
run_repository=run_repository,
description=f"Runner of {name}",
)
return runner.run_dataset(dataset.id).id

first_run_id = create_run("first")

first_evaluation_overview = evaluator.evaluate_additional_runs(first_run_id)

second_run_id = create_run("second")

second_evaluation_overview = evaluator.evaluate_additional_runs(
first_run.id,
second_run.id,
first_run_id,
second_run_id,
previous_evaluation_ids=[first_evaluation_overview.id],
)

second_result = next(
iter(evaluator.evaluation_lineages(second_evaluation_overview.id))
).evaluation.result
assert isinstance(second_result, DummyEvaluation)
assert second_result.new_run_ids == [second_run.id]
assert second_result.old_run_ids == [[first_run.id]]
assert second_result.new_run_ids == [second_run_id]
assert second_result.old_run_ids == [[first_run_id]]

independent_runner = Runner(
task=DummyTask("TaskIndependent"),
dataset_repository=dataset_repository,
run_repository=run_repository,
description="test_runner_independent",
)
independent_run = independent_runner.run_dataset(dataset.id)
independent_evaluation_overview = evaluator.evaluate_additional_runs(
independent_run.id
)
independent_run_id = create_run("independent")

third_runner = Runner(
task=DummyTask("Task3"),
dataset_repository=dataset_repository,
run_repository=run_repository,
description="test_runner_3",
independent_evaluation_overview = evaluator.evaluate_additional_runs(
independent_run_id
)

third_run = third_runner.run_dataset(dataset.id)
third_run_id = create_run("third")

third_evaluation_overview = evaluator.evaluate_additional_runs(
first_run.id,
second_run.id,
independent_run.id,
third_run.id,
first_run_id,
second_run_id,
independent_run_id,
third_run_id,
previous_evaluation_ids=[
second_evaluation_overview.id,
independent_evaluation_overview.id,
Expand All @@ -137,6 +122,6 @@ def test_incremental_evaluator_should_filter_previous_run_ids() -> None:
iter(evaluator.evaluation_lineages(third_evaluation_overview.id))
).evaluation.result
assert isinstance(third_result, DummyEvaluation)
assert third_result.new_run_ids == [third_run.id]
assert sorted(third_result.old_run_ids[0]) == sorted([first_run.id, second_run.id])
assert sorted(third_result.old_run_ids[1]) == sorted([independent_run.id])
assert third_result.new_run_ids == [third_run_id]
assert sorted(third_result.old_run_ids[0]) == sorted([first_run_id, second_run_id])
assert sorted(third_result.old_run_ids[1]) == sorted([independent_run_id])

0 comments on commit 013ccaf

Please sign in to comment.