Merge pull request #11 from NASA-IMPACT/feature/pipeline-abstraction
[alpha] Addition of simple pipeline abstraction
NISH1001 authored Mar 27, 2023
2 parents ceeaaeb + 9ccf783 commit 9365268
Showing 12 changed files with 261 additions and 126 deletions.
2 changes: 2 additions & 0 deletions evalem/pipelines/__init__.py
@@ -0,0 +1,2 @@
# flake8: noqa
from .defaults import SimpleEvaluationPipeline
33 changes: 33 additions & 0 deletions evalem/pipelines/_base.py
@@ -0,0 +1,33 @@
#!/usr/bin/env python3

from abc import abstractmethod
from typing import Any

from .._base import AbstractBase


class Pipeline(AbstractBase):
"""
    Represents the base type for a pipeline component.
    All downstream pipeline objects should implement the `.run(...)` method.
    See `pipelines.defaults.SimpleEvaluationPipeline` for an implementation.
"""

@abstractmethod
    def run(self, *args, **kwargs) -> Any:
"""
Entry-point method to run the evaluation.
"""
raise NotImplementedError()

def __call__(self, *args, **kwargs) -> Any:
return self.run(*args, **kwargs)


def main():
pass


if __name__ == "__main__":
main()
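To make the `Pipeline` contract concrete, here is a minimal hypothetical subclass (not part of this commit). It assumes `Pipeline` is importable as `evalem.pipelines._base.Pipeline` and that `AbstractBase` needs no constructor arguments; the class name and its behavior are purely illustrative.

```python
from typing import Any, List

from evalem.pipelines._base import Pipeline


class UppercasePipeline(Pipeline):
    """
    Illustrative-only pipeline whose `run` upper-cases a list of strings.
    """

    def run(self, texts: List[str], **kwargs) -> Any:
        # A downstream pipeline only has to provide `run`; the base
        # class's `__call__` delegates here.
        return [text.upper() for text in texts]


if __name__ == "__main__":
    pipe = UppercasePipeline()
    print(pipe(["hello", "world"]))  # ['HELLO', 'WORLD']
```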
81 changes: 81 additions & 0 deletions evalem/pipelines/defaults.py
@@ -0,0 +1,81 @@
#!/usr/bin/env python3

from typing import Iterable, List, Mapping, Type, Union

from ..evaluators._base import Evaluator
from ..models._base import ModelWrapper
from ..structures import EvaluationReferenceInstance, MetricOutput
from ._base import Pipeline


class SimpleEvaluationPipeline(Pipeline):
"""
    This is a very basic evaluation pipeline that uses a single model
    and a list of evaluators to run the evaluation.
Args:
```model```: ```Type[ModelWrapper]```
Wrapped model to do the inference.
        ```evaluators```: ```Union[Evaluator, Iterable[Evaluator]]```
            Either a single evaluator or an iterable of evaluators.
            Note: If a single evaluator is provided, it will be wrapped
            in a list internally.
Usage:
        .. code-block:: python
from evalem.pipelines import SimpleEvaluationPipeline
from evalem.models import TextClassificationHFPipelineWrapper
from evalem.evaluators import TextClassificationEvaluator
model = TextClassificationHFPipelineWrapper()
evaluator = TextClassificationEvaluator()
pipe = SimpleEvaluationPipeline(model=model, evaluators=evaluator)
results = pipe(inputs, references)
# or
results = pipe.run(inputs, references)
"""

def __init__(
self,
model: Type[ModelWrapper],
evaluators: Union[Evaluator, Iterable[Evaluator]],
) -> None:
self.model = model

# if only single evaluator, wrap into an iterable
self.evaluators = (
[evaluators] if not isinstance(evaluators, Iterable) else evaluators
)

def run(
self,
inputs: Mapping,
references: EvaluationReferenceInstance,
**kwargs,
) -> List[MetricOutput]:
"""
```inputs```: ```Mapping```
Input data to run over the model to get predictions.
```references```: ```EvaluationReferenceInstance```
References/ground-truths to be used for evaluation.
See `evalem.metrics` for more information.
"""
predictions = self.model(inputs, **kwargs)
return list(
map(
lambda e: e(predictions=predictions, references=references),
self.evaluators,
),
)


def main():
pass


if __name__ == "__main__":
main()
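Since `run` simply maps every evaluator over the same predictions, passing a list of evaluators yields one result per evaluator, in order. A hedged sketch (not part of this commit) reusing the wrappers from the docstring above; the placeholder inputs and reference labels are assumptions for illustration only.

```python
from evalem.evaluators import TextClassificationEvaluator
from evalem.models import TextClassificationHFPipelineWrapper
from evalem.pipelines import SimpleEvaluationPipeline

model = TextClassificationHFPipelineWrapper()

# Two evaluators just to illustrate the list handling; a single (non-list)
# evaluator would be wrapped into a one-element list by __init__.
evaluators = [TextClassificationEvaluator(), TextClassificationEvaluator()]
pipe = SimpleEvaluationPipeline(model=model, evaluators=evaluators)

# Placeholder data -- the exact input/reference format is an assumption here.
my_inputs = ["A genuinely moving film.", "Flat and forgettable."]
my_references = ["positive", "negative"]

results = pipe(my_inputs, my_references)
assert len(results) == len(evaluators)  # one MetricOutput per evaluator
```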
3 changes: 2 additions & 1 deletion pytest.ini
@@ -1,4 +1,5 @@
[pytest]
markers =
    metrics: For testing only the metrics.
-   models: for testing only the models.
+   models: For testing only the models.
+   pipelines: For testing only the pipelines.
Empty file added tests/__init__.py
58 changes: 58 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,58 @@
#!/usr/bin/env python3

import json
from pathlib import Path

import pytest

from evalem.evaluators import QAEvaluator, TextClassificationEvaluator
from evalem.models import (
QuestionAnsweringHFPipelineWrapper,
TextClassificationHFPipelineWrapper,
)


@pytest.fixture(autouse=True, scope="session")
def squad_v2_data():
path = Path(__file__).parent.joinpath("data/squad_v2.json")
data = {}
with open(path, "r") as f:
data = json.load(f)
return data


@pytest.fixture(autouse=True, scope="session")
def imdb_data():
path = Path(__file__).parent.joinpath("data/imdb.json")
data = {}
with open(path, "r") as f:
data = json.load(f)
return data


@pytest.fixture(autouse=True, scope="session")
def model_qa_default():
yield QuestionAnsweringHFPipelineWrapper()


@pytest.fixture(autouse=True, scope="session")
def model_classification_default():
yield TextClassificationHFPipelineWrapper(hf_params=dict(truncation=True))


@pytest.fixture(autouse=True, scope="session")
def evaluator_qa_default():
yield QAEvaluator()


@pytest.fixture(autouse=True, scope="session")
def evaluator_classification_default():
yield TextClassificationEvaluator()


def main():
pass


if __name__ == "__main__":
main()
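The two data fixtures read JSON files from `tests/data/`, which are not shown in this diff. Judging only from how the tests consume them (`data.get("inputs", [])` and `data.get("references", [])`), each file presumably has at least these two top-level keys. The sketch below is an assumption for illustration, not the actual file contents.

```python
# Assumed shape of tests/data/imdb.json (illustrative values, not real data):
example_imdb_data = {
    "inputs": [
        "A surprisingly touching film with a strong cast.",
        "Two hours of my life I will never get back.",
    ],
    "references": ["positive", "negative"],
}
```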
32 changes: 0 additions & 32 deletions tests/models/fixtures.py

This file was deleted.

46 changes: 0 additions & 46 deletions tests/models/test_classification.py

This file was deleted.

39 changes: 39 additions & 0 deletions tests/models/test_defaults.py
@@ -0,0 +1,39 @@
#!/usr/bin/env python3

from typing import Iterable

import pytest

from evalem.structures import PredictionDTO


@pytest.mark.parametrize(
"data, model",
[
("squad_v2_data", "model_qa_default"),
("imdb_data", "model_classification_default"),
],
)
@pytest.mark.models
class TestDefaultModels:
def test_predictions_format(self, data, model, request):
data = request.getfixturevalue(data)
model = request.getfixturevalue(model)
predictions = model(data.get("inputs", []))
assert isinstance(predictions, Iterable)
assert isinstance(predictions[0], PredictionDTO)

def test_predictions_len(self, data, model, request):
data = request.getfixturevalue(data)
model = request.getfixturevalue(model)
predictions = model(data.get("inputs", []))
print(predictions)
assert len(predictions) == len(data.get("references", []))


def main():
pass


if __name__ == "__main__":
main()
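The class above parametrizes over fixture names as plain strings and resolves them at test time with `request.getfixturevalue`, pytest's public API for looking up a fixture dynamically. Here is a small self-contained sketch of the same pattern, using hypothetical fixtures unrelated to this repository.

```python
import pytest


@pytest.fixture
def small_numbers():
    return [1, 2, 3]


@pytest.fixture
def big_numbers():
    return [100, 200, 300]


@pytest.mark.parametrize("numbers", ["small_numbers", "big_numbers"])
def test_all_positive(numbers, request):
    # Resolve the fixture named by the parametrized string.
    values = request.getfixturevalue(numbers)
    assert all(v > 0 for v in values)
```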
47 changes: 0 additions & 47 deletions tests/models/test_qa.py

This file was deleted.

Empty file added tests/pipelines/__init__.py
46 changes: 46 additions & 0 deletions tests/pipelines/test_defaults.py
@@ -0,0 +1,46 @@
#!/usr/bin/env python3

from pprint import pprint
from typing import Iterable

import pytest

from evalem.pipelines import SimpleEvaluationPipeline

# from ..models.test_defaults import TestDefaultModels


@pytest.mark.dependency(depends=["TestDefaultModels"])
@pytest.mark.parametrize(
"data, model, evaluators",
[
("squad_v2_data", "model_qa_default", "evaluator_qa_default"),
(
"imdb_data",
"model_classification_default",
"evaluator_classification_default",
),
],
)
@pytest.mark.pipelines
class TestSimplePipeline:
def test_evaluation(self, data, model, evaluators, request):
data = request.getfixturevalue(data)
model = request.getfixturevalue(model)
evaluators = request.getfixturevalue(evaluators)
pipeline = SimpleEvaluationPipeline(model=model, evaluators=evaluators)

inputs, references = data.get("inputs", []), data.get("references", [])

results = pipeline(inputs, references)
pprint(results)

assert isinstance(results, Iterable)


def main():
pass


if __name__ == "__main__":
main()
