Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/load all values #310

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 2 additions & 23 deletions demo/scenarios/report.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -76,35 +76,14 @@
"source": [
"from mlte.spec.spec import Spec\n",
"from mlte.validation.spec_validator import SpecValidator\n",
"\n",
"from mlte.measurement.cpu import CPUStatistics\n",
"from mlte.measurement.memory import MemoryStatistics\n",
"from mlte.value.types.image import Image\n",
"from mlte.value.types.integer import Integer\n",
"\n",
"from values.multiple_accuracy import MultipleAccuracy\n",
"from values.multiple_ranksums import MultipleRanksums\n",
"from values.ranksums import RankSums\n",
"from mlte.value.artifact import Value\n",
"\n",
"# Load the specification\n",
"spec = Spec.load()\n",
"\n",
"# Add all values to the validator.\n",
"spec_validator = SpecValidator(spec)\n",
"spec_validator.add_value(MultipleAccuracy.load(\"accuracy across gardens.value\"))\n",
"spec_validator.add_value(RankSums.load(\"ranksums blur2x8.value\"))\n",
"spec_validator.add_value(RankSums.load(\"ranksums blur5x8.value\"))\n",
"spec_validator.add_value(RankSums.load(\"ranksums blur0x8.value\"))\n",
"spec_validator.add_value(\n",
" MultipleRanksums.load(\"multiple ranksums for clade2.value\")\n",
")\n",
"spec_validator.add_value(\n",
" MultipleRanksums.load(\"multiple ranksums between clade2 and 3.value\")\n",
")\n",
"spec_validator.add_value(Integer.load(\"model size.value\"))\n",
"spec_validator.add_value(CPUStatistics.load(\"predicting cpu.value\"))\n",
"spec_validator.add_value(MemoryStatistics.load(\"predicting memory.value\"))\n",
"spec_validator.add_value(Image.load(\"image attributions.value\"))"
"spec_validator.add_values(Value.load_all())"
]
},
{
Expand Down
16 changes: 2 additions & 14 deletions demo/simple/report.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -80,26 +80,14 @@
"source": [
"from mlte.spec.spec import Spec\n",
"from mlte.validation.spec_validator import SpecValidator\n",
"from mlte.value.types.integer import Integer\n",
"from mlte.value.types.real import Real\n",
"from mlte.value.types.image import Image\n",
"from mlte.measurement.cpu import CPUStatistics\n",
"from mlte.measurement.memory import MemoryStatistics\n",
"from confusion_matrix import ConfusionMatrix\n",
"from mlte.value.artifact import Value\n",
"\n",
"# Load the specification\n",
"spec = Spec.load()\n",
"\n",
"# TODO: This could also be done in bulk, with a dictionary-binding like indicating properties, validators and ids.\n",
"\n",
"# Add all values to the validator.\n",
"spec_validator = SpecValidator(spec)\n",
"spec_validator.add_value(Integer.load(\"model size.value\"))\n",
"spec_validator.add_value(CPUStatistics.load(\"training cpu.value\"))\n",
"spec_validator.add_value(MemoryStatistics.load(\"training memory.value\"))\n",
"spec_validator.add_value(Real.load(\"accuracy.value\"))\n",
"spec_validator.add_value(ConfusionMatrix.load(\"confusion matrix.value\"))\n",
"spec_validator.add_value(Image.load(\"class distribution.value\"))"
"spec_validator.add_values(Value.load_all())"
]
},
{
Expand Down
26 changes: 26 additions & 0 deletions mlte/_private/reflection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import importlib
from typing import Any, Type


def load_class(class_path: str) -> Type[Any]:
"""
Returns a class type of the given class name/path.
:param class_path: A path to a class to use, including absolute package/module path and class name.
"""
# Split into package/module and class name.
parts = class_path.rsplit(".", 1)
module_name = parts[0]
class_name = parts[1]

try:
loaded_module = importlib.import_module(module_name)
except Exception:
raise RuntimeError(f"Module {module_name} not found")
try:
class_type: Type[Any] = getattr(loaded_module, class_name)
except Exception:
raise RuntimeError(
f"Class {class_name} in module {module_name} not found"
)

return class_type
27 changes: 27 additions & 0 deletions mlte/artifact/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import Optional

import mlte._private.meta as meta
import mlte.store.query as query
from mlte.artifact.model import ArtifactHeaderModel, ArtifactModel
from mlte.artifact.type import ArtifactType
from mlte.context.context import Context
Expand Down Expand Up @@ -166,6 +167,32 @@ def load_with(
artifact.post_load_hook(context, store)
return artifact

@staticmethod
def load_all_models(artifact_type: ArtifactType) -> list[ArtifactModel]:
"""Loads all artifact models of the given type from the session."""
return Artifact.load_all_models_with(
artifact_type, context=session().context, store=session().store
)

@staticmethod
def load_all_models_with(
artifact_type: ArtifactType, context: Context, store: Store
) -> list[ArtifactModel]:
"""Loads all artifact models of the given type for the given context and store."""
with ManagedSession(store.session()) as handle:
query_instance = query.Query(
filter=query.ArtifactTypeFilter(
type=query.FilterType.TYPE, artifact_type=artifact_type
)
)
artifact_models = handle.search_artifacts(
context.namespace,
context.model,
context.version,
query_instance,
)
return artifact_models

@staticmethod
def get_default_id() -> str:
"""To be overriden by derived classes."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@
"metadata": {
"$ref": "#/$defs/EvidenceMetadata"
},
"value_class": {
"title": "Value Class",
"type": "string"
},
"value": {
"discriminator": {
"mapping": {
Expand Down Expand Up @@ -160,6 +164,7 @@
"required": [
"artifact_type",
"metadata",
"value_class",
"value"
],
"title": "ValueModel",
Expand Down
5 changes: 5 additions & 0 deletions mlte/schema/artifact/value/v0.0.1/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@
"metadata": {
"$ref": "#/$defs/EvidenceMetadata"
},
"value_class": {
"title": "Value Class",
"type": "string"
},
"value": {
"discriminator": {
"mapping": {
Expand Down Expand Up @@ -160,6 +164,7 @@
"required": [
"artifact_type",
"metadata",
"value_class",
"value"
],
"title": "ValueModel",
Expand Down
39 changes: 39 additions & 0 deletions mlte/value/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,16 @@
from __future__ import annotations

import abc
import typing

from mlte._private.reflection import load_class
from mlte.artifact.artifact import Artifact
from mlte.artifact.model import ArtifactModel
from mlte.artifact.type import ArtifactType
from mlte.context.context import Context
from mlte.evidence.metadata import EvidenceMetadata
from mlte.store.base import Store
from mlte.value.model import ValueModel


class Value(Artifact, metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -57,3 +62,37 @@ def from_model(cls, _: ArtifactModel) -> Value: # type: ignore[override]
is delegated to subclasses that implement concrete types
"""
raise NotImplementedError("Value.from_model()")

@staticmethod
def load_all() -> list[Value]:
"""Loads all artifact models of the given type for the current session."""
value_models = Value.load_all_models(ArtifactType.VALUE)
return Value._load_from_models(value_models)

@staticmethod
def load_all_with(context: Context, store: Store) -> list[Value]:
"""Loads all artifact models of the given type for the given context and store."""
value_models = Value.load_all_models_with(
ArtifactType.VALUE, context, store
)
return Value._load_from_models(value_models)

@staticmethod
def _load_from_models(value_models: list[ArtifactModel]) -> list[Value]:
"""Converts a list of value models (as Artifact Models) into values."""
values = []
for artifact_model in value_models:
value_model: ValueModel = typing.cast(
ValueModel, artifact_model.body
)
value_type: Value = typing.cast(
Value, load_class(value_model.value_class)
)
value = value_type.from_model(artifact_model)
values.append(value)
return values

@classmethod
def get_class_path(cls) -> str:
"""Returns the full path to this class, including module."""
return f"{cls.__module__}.{cls.__name__}"
12 changes: 3 additions & 9 deletions mlte/value/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,6 @@ def __subclasshook__(cls, subclass):
"""Define the interface for all Value subclasses."""
return meta.has_callables(subclass, "serialize", "deserialize")

def __init__(self, instance: ValueBase, metadata: EvidenceMetadata) -> None:
"""
Initialize a MLTE value.
:param instance: The subclass instance
:param metadata: Evidence metadata associated with the value
"""
super().__init__(instance, metadata)

@abc.abstractmethod
def serialize(self) -> Dict[str, Any]:
"""
Expand Down Expand Up @@ -72,8 +64,10 @@ def to_model(self) -> ArtifactModel:
body=ValueModel(
artifact_type=ArtifactType.VALUE,
metadata=self.metadata,
value_class=self.get_class_path(),
value=OpaqueValueModel(
value_type=ValueType.OPAQUE, data=self.serialize()
value_type=ValueType.OPAQUE,
data=self.serialize(),
),
),
)
Expand Down
3 changes: 3 additions & 0 deletions mlte/value/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ class ValueModel(BaseModel):
metadata: EvidenceMetadata
"""Evidence metadata associated with the value."""

value_class: str
"""Full path to class that implements this value."""

value: Union[
"IntegerValueModel",
"RealValueModel",
Expand Down
1 change: 1 addition & 0 deletions mlte/value/types/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def to_model(self) -> ArtifactModel:
body=ValueModel(
artifact_type=ArtifactType.VALUE,
metadata=self.metadata,
value_class=self.get_class_path(),
value=ImageValueModel(
value_type=ValueType.IMAGE,
data=base64.encodebytes(self.image).decode("utf-8"),
Expand Down
4 changes: 3 additions & 1 deletion mlte/value/types/integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ def to_model(self) -> ArtifactModel:
body=ValueModel(
artifact_type=ArtifactType.VALUE,
metadata=self.metadata,
value_class=self.get_class_path(),
value=IntegerValueModel(
value_type=ValueType.INTEGER, integer=self.value
value_type=ValueType.INTEGER,
integer=self.value,
),
),
)
Expand Down
4 changes: 3 additions & 1 deletion mlte/value/types/opaque.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ def to_model(self) -> ArtifactModel:
body=ValueModel(
artifact_type=ArtifactType.VALUE,
metadata=self.metadata,
value_class=self.get_class_path(),
value=OpaqueValueModel(
value_type=ValueType.OPAQUE, data=self.data
value_type=ValueType.OPAQUE,
data=self.data,
),
),
)
Expand Down
4 changes: 3 additions & 1 deletion mlte/value/types/real.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ def to_model(self) -> ArtifactModel:
body=ValueModel(
artifact_type=ArtifactType.VALUE,
metadata=self.metadata,
value_class=self.get_class_path(),
value=RealValueModel(
value_type=ValueType.REAL, real=self.value
value_type=ValueType.REAL,
real=self.value,
),
),
)
Expand Down
62 changes: 62 additions & 0 deletions test/artifact/test_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,24 @@

Unit tests for MLTE artifact protocol implementation.
"""
from typing import Tuple

import pytest

from mlte.artifact.artifact import Artifact
from mlte.artifact.type import ArtifactType
from mlte.context.context import Context
from mlte.evidence.metadata import EvidenceMetadata, Identifier
from mlte.negotiation.artifact import NegotiationCard
from mlte.report.artifact import Report
from mlte.session.state import set_context, set_store
from mlte.spec.spec import Spec
from mlte.store.base import Store
from mlte.validation.validated_spec import ValidatedSpec
from mlte.value.types.integer import Integer
from mlte.value.types.real import Real

from ..fixture.store import store_with_context # noqa
from ..fixture.store import FX_MODEL_ID, FX_NAMESPACE_ID, FX_VERSION_ID


Expand All @@ -24,3 +38,51 @@ def test_save_load_session() -> None:

a.save(parents=True)
_ = NegotiationCard.load("my-card")


def fill_test_store(ctx: Context, store: Store):
"""Fills a sample store."""
n1 = NegotiationCard("test-card")
n2 = NegotiationCard("test-card2")
s1 = Spec("test-spec1")
s2 = Spec("test-spec2")
vs1 = ValidatedSpec("test-validated1", s1)
vs2 = ValidatedSpec("test-validated2", s2)
m1 = EvidenceMetadata(
measurement_type="typename", identifier=Identifier(name="id1")
)
v1 = Integer(m1, 10)
m2 = EvidenceMetadata(
measurement_type="typename", identifier=Identifier(name="id2")
)
v2 = Real(m2, 3.14)
r1 = Report("r1")
r2 = Report("r2")

n1.save_with(ctx, store, parents=True)
n2.save_with(ctx, store)
s1.save_with(ctx, store)
s2.save_with(ctx, store)
vs1.save_with(ctx, store)
vs2.save_with(ctx, store)
v1.save_with(ctx, store)
v2.save_with(ctx, store)
r1.save_with(ctx, store)
r2.save_with(ctx, store)


@pytest.mark.parametrize("artifact_type", ArtifactType)
def test_load_all_models(
artifact_type: ArtifactType,
store_with_context: Tuple[Store, Context], # noqa
):
"""
Loading all models of a given type.
"""
store, ctx = store_with_context
fill_test_store(ctx, store)

models = Artifact.load_all_models_with(artifact_type, ctx, store)

assert len(models) == 2
assert models[0].header.type == artifact_type
Loading
Loading