Skip to content

Commit

Permalink
fix mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
pitneitemeier committed Feb 28, 2024
1 parent a115df5 commit 34f702d
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 13 deletions.
9 changes: 9 additions & 0 deletions src/examples/evaluation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,15 @@
"table.get_dataframe()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"run.finish()"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,14 @@ def list_datasets(self) -> Iterable[str]:
class WandbDatasetRepository(DatasetRepository):
def __init__(self) -> None:
self.team_name = "aleph-alpha-intelligence-layer-trial"
self._run = None
self._run: Run | None = None

def create_dataset(
self,
examples: Iterable[Example[Input, ExpectedOutput]],
) -> str:
if self._run is None:
raise ValueError("Run not started")
dataset_id = str(uuid4())
artifact = wandb.Artifact(name=dataset_id, type="dataset")
table = Table(columns=["example"]) # type: ignore
Expand Down Expand Up @@ -90,6 +92,8 @@ def examples_by_id(

@lru_cache(maxsize=1)
def _get_dataset(self, id: str) -> Table:
if self._run is None:
raise ValueError("Run not started")
artifact = self._run.use_artifact(
f"{self.team_name}/{self._run.project_name()}/{id}:latest"
)
Expand All @@ -103,9 +107,12 @@ def example(
expected_output_type: type[ExpectedOutput],
) -> Optional[Example[Input, ExpectedOutput]]:
examples = self.examples_by_id(dataset_id, input_type, expected_output_type)
if examples is None:
return None
for example in examples:
if example.id == example_id:
return example
return None

def delete_dataset(self, dataset_id: str) -> None:
    """Placeholder for dataset deletion; not implemented in this repository.

    Args:
        dataset_id: Identifier of the dataset that would be deleted.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
Expand Down
29 changes: 19 additions & 10 deletions src/intelligence_layer/evaluation/data_storage/run_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from functools import lru_cache
from os import getenv
from pathlib import Path
from typing import Iterable, Optional, Sequence, cast
from typing import Any, Iterable, Optional, Sequence, cast
from uuid import uuid4

import wandb
from dotenv import load_dotenv
from wandb import Table
from wandb.sdk.wandb_run import Run

from intelligence_layer.core.task import Output
Expand Down Expand Up @@ -273,10 +274,10 @@ def store_run_overview(self, overview: RunOverview) -> None:

class WandbRunRepository(RunRepository):
def __init__(self) -> None:
self._example_outputs = dict()
self._run_overviews = dict()
self._run = None
self.team_name = "aleph-alpha-intelligence-layer-trial"
self._example_outputs: dict[str, wandb.Table] = dict()
self._run_overviews: dict[str, wandb.Table] = dict()
self._run: Run | None = None
self.team_name: str = "aleph-alpha-intelligence-layer-trial"

def run_ids(self) -> Sequence[str]:
"""Returns the ids of all stored runs.
Expand All @@ -287,7 +288,7 @@ def run_ids(self) -> Sequence[str]:
Returns:
The ids of all stored runs.
"""
pass
raise NotImplementedError

def example_outputs(
self, run_id: str, output_type: type[Output]
Expand All @@ -307,13 +308,17 @@ def example_outputs(

@lru_cache(maxsize=2)
def _get_table(self, artifact_id: str, name: str) -> wandb.Table:
    """Read the entry ``name`` from the ``latest`` version of an artifact.

    The artifact is addressed as
    ``<team_name>/<project of the current run>/<artifact_id>:latest`` and is
    fetched through the run stored on this repository. Results are memoized
    by ``lru_cache(maxsize=2)``.

    Args:
        artifact_id: Name of the artifact within the current run's project.
        name: Name of the entry to read from the artifact.

    Returns:
        The object returned by ``artifact.get(name)``.

    Raises:
        ValueError: If no run has been set on this repository yet.
    """
    # NOTE(review): lru_cache on a method keys on ``self`` and keeps the
    # instance referenced for as long as the entry stays cached -- confirm
    # this is acceptable for the repository's lifetime.
    active_run = self._run
    if active_run is None:
        raise ValueError(
            "The run has not been started, are you using a WandbRunner?"
        )
    artifact_path = (
        f"{self.team_name}/{active_run.project_name()}/{artifact_id}:latest"
    )
    return active_run.use_artifact(artifact_path).get(name)  # type: ignore

def store_example_output(self, example_output: ExampleOutput[Output]) -> None:
self._example_outputs[example_output.run_id].add_data(
self._example_outputs[example_output.run_id].add_data( # type: ignore
example_output.model_dump_json(),
)

Expand Down Expand Up @@ -354,16 +359,20 @@ def store_run_overview(self, overview: RunOverview) -> None:
Args:
overview: The overview to be persisted.
"""
self._run_overviews[overview.id].add_data(
self._run_overviews[overview.id].add_data( # type: ignore
overview.model_dump_json(),
)

def start_run(self, run: Run, run_id: str) -> None:
self._run = run
self._example_outputs[run_id] = wandb.Table(columns=["example_output"])
self._run_overviews[run_id] = wandb.Table(columns=["run_overview"])
self._example_outputs[run_id] = Table(columns=["example_output"]) # type: ignore
self._run_overviews[run_id] = Table(columns=["run_overview"]) # type: ignore

def finish_run(self, run_id: str) -> None:
if self._run is None:
raise ValueError(
"The run has not been started, are you using a WandbRunner?"
)
artifact = wandb.Artifact(name=run_id, type="Run")
artifact.add(self._example_outputs[run_id], name="example_outputs")
artifact.add(self._run_overviews[run_id], name="run_overview")
Expand Down
8 changes: 6 additions & 2 deletions src/intelligence_layer/evaluation/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import wandb
from pydantic import JsonValue
from tqdm import tqdm
from wandb.sdk.wandb_run import Run

from intelligence_layer.core.task import Input, Output, Task
from intelligence_layer.core.tracer import CompositeTracer, Tracer, utc_now
Expand Down Expand Up @@ -89,6 +90,7 @@ def run_dataset(
An overview of the run. Outputs will not be returned but instead stored in the
:class:`RunRepository` provided in the __init__.
"""
run_id = str(uuid4()) if run_id is None else run_id

def run(
example: Example[Input, ExpectedOutput]
Expand All @@ -112,7 +114,6 @@ def run(
if num_examples:
examples = islice(examples, num_examples)

run_id = str(uuid4()) if run_id is None else run_id
start = utc_now()
with ThreadPoolExecutor(max_workers=10) as executor:
ids_and_outputs = tqdm(executor.map(run, examples), desc="Evaluating")
Expand Down Expand Up @@ -142,7 +143,7 @@ def run(
return run_overview


class WandbRunner(Runner):
class WandbRunner(Runner[Input, Output]):
def __init__(
self,
task: Task[Input, Output],
Expand All @@ -152,6 +153,7 @@ def __init__(
wandb_project_name: str,
) -> None:
super().__init__(task, dataset_repository, run_repository, description)
self._run_repository: WandbRunRepository = run_repository
self._wandb_project_name = wandb_project_name

def run_dataset(
Expand All @@ -163,7 +165,9 @@ def run_dataset(
) -> RunOverview:
run = wandb.init(project=self._wandb_project_name, job_type="Runner")
run_id = str(uuid4())
assert isinstance(run, Run)
self._run_repository.start_run(run, run_id)
run_overview = super().run_dataset(dataset_id, tracer, num_examples, run_id)
self._run_repository.finish_run(run_id)
run.finish()
return run_overview

0 comments on commit 34f702d

Please sign in to comment.