Skip to content

Commit

Permalink
Merge branch 'mlte-team:master' into feature/add-report-table-colors
Browse files Browse the repository at this point in the history
  • Loading branch information
sei-aderr authored Jan 23, 2024
2 parents 3aa8017 + 8e13603 commit 40a725d
Show file tree
Hide file tree
Showing 39 changed files with 500 additions and 425 deletions.
24 changes: 14 additions & 10 deletions mlte/artifact/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
from typing import Optional

import mlte._private.meta as meta
import mlte.store.query as query
import mlte.store.artifact.query as query
from mlte.artifact.model import ArtifactHeaderModel, ArtifactModel
from mlte.artifact.type import ArtifactType
from mlte.context.context import Context
from mlte.session.state import session
from mlte.store.base import ManagedSession, Store
from mlte.store.artifact.store import ArtifactStore, ManagedArtifactSession


class Artifact(metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -63,7 +63,7 @@ def from_model(cls, _: ArtifactModel) -> Artifact:
"Artifact.from_model() not implemented for abstract Artifact."
)

def pre_save_hook(self, context: Context, store: Store) -> None:
def pre_save_hook(self, context: Context, store: ArtifactStore) -> None:
"""
A method that artifact subclasses can override to enforce pre-save invariants.
:param context: The context in which to save the artifact
Expand All @@ -73,7 +73,7 @@ def pre_save_hook(self, context: Context, store: Store) -> None:
# Default implementation is a no-op
pass

def post_load_hook(self, context: Context, store: Store) -> None:
def post_load_hook(self, context: Context, store: ArtifactStore) -> None:
"""
A method that artifact subclasses may override to enforce post-load invariants.
:param context: The context in which to save the artifact
Expand Down Expand Up @@ -101,7 +101,7 @@ def save(self, *, force: bool = False, parents: bool = False) -> None:
def save_with(
self,
context: Context,
store: Store,
store: ArtifactStore,
*,
force: bool = False,
parents: bool = False,
Expand All @@ -118,7 +118,7 @@ def save_with(

artifact_model = self.to_model()
artifact_model.header.timestamp = int(time.time())
with ManagedSession(store.session()) as handle:
with ManagedArtifactSession(store.session()) as handle:
handle.write_artifact(
context.namespace,
context.model,
Expand All @@ -143,7 +143,11 @@ def load(cls, identifier: Optional[str] = None) -> Artifact:

@classmethod
def load_with(
cls, identifier: Optional[str] = None, *, context: Context, store: Store
cls,
identifier: Optional[str] = None,
*,
context: Context,
store: ArtifactStore,
) -> Artifact:
"""
Load an artifact with the given context and store configuration.
Expand All @@ -154,7 +158,7 @@ def load_with(
if identifier is None:
identifier = cls.get_default_id()

with ManagedSession(store.session()) as handle:
with ManagedArtifactSession(store.session()) as handle:
artifact = cls.from_model(
handle.read_artifact(
context.namespace,
Expand All @@ -176,10 +180,10 @@ def load_all_models(artifact_type: ArtifactType) -> list[ArtifactModel]:

@staticmethod
def load_all_models_with(
artifact_type: ArtifactType, context: Context, store: Store
artifact_type: ArtifactType, context: Context, store: ArtifactStore
) -> list[ArtifactModel]:
"""Loads all artifact models of the given type for the given context and store."""
with ManagedSession(store.session()) as handle:
with ManagedArtifactSession(store.session()) as handle:
query_instance = query.Query(
filter=query.ArtifactTypeFilter(
type=query.FilterType.TYPE, artifact_type=artifact_type
Expand Down
8 changes: 4 additions & 4 deletions mlte/report/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
ReportModel,
SummaryDescriptor,
)
from mlte.store.base import ManagedSession, Store
from mlte.store.artifact.store import ArtifactStore, ManagedArtifactSession

DEFAULT_REPORT_ID = "default.report"

Expand Down Expand Up @@ -90,7 +90,7 @@ def to_model(self) -> ArtifactModel:
),
)

def pre_save_hook(self, context: Context, store: Store) -> None:
def pre_save_hook(self, context: Context, store: ArtifactStore) -> None:
"""
Override Artifact.pre_save_hook().
:param context: The context in which to save the artifact
Expand All @@ -100,7 +100,7 @@ def pre_save_hook(self, context: Context, store: Store) -> None:
if self.validated_spec_id is None:
return

with ManagedSession(store.session()) as handle:
with ManagedArtifactSession(store.session()) as handle:
try:
artifact = handle.read_artifact(
context.namespace,
Expand All @@ -118,7 +118,7 @@ def pre_save_hook(self, context: Context, store: Store) -> None:
f"Validated specification with identifier {self.validated_spec_id} not found."
)

def post_load_hook(self, context: Context, store: Store) -> None:
def post_load_hook(self, context: Context, store: ArtifactStore) -> None:
"""
Override Artifact.post_load_hook().
:param context: The context in which to save the artifact
Expand Down
10 changes: 5 additions & 5 deletions mlte/session/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from typing import Optional

from mlte.context import Context
from mlte.store.base import Store
from mlte.store.factory import create_store
from mlte.store.artifact.factory import create_store
from mlte.store.artifact.store import ArtifactStore


class Session:
Expand All @@ -30,7 +30,7 @@ def __init__(self):
self._context: Optional[Context] = None
"""The MLTE context for the session."""

self._store: Optional[Store] = None
self._store: Optional[ArtifactStore] = None
"""The MLTE store instance for the session."""

@property
Expand All @@ -40,7 +40,7 @@ def context(self) -> Context:
return self._context

@property
def store(self) -> Store:
def store(self) -> ArtifactStore:
if self._store is None:
raise RuntimeError("Must initialize MLTE store for session.")
return self._store
Expand All @@ -49,7 +49,7 @@ def _set_context(self, context: Context) -> None:
"""Set the session context."""
self._context = context

def _set_store(self, store: Store) -> None:
def _set_store(self, store: ArtifactStore) -> None:
"""Set the session store."""
self._store = store

Expand Down
File renamed without changes.
13 changes: 7 additions & 6 deletions mlte/store/factory.py → mlte/store/artifact/factory.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
"""
mlte/store/factory.py
mlte/store/artifact/factory.py
Top-level functions for artifact store creation.
"""

from mlte.store.base import Store, StoreType, StoreURI
from mlte.store.underlying.fs import LocalFileSystemStore
from mlte.store.underlying.http import RemoteHttpStore
from mlte.store.underlying.memory import InMemoryStore
from mlte.store.artifact.store import ArtifactStore
from mlte.store.artifact.underlying.fs import LocalFileSystemStore
from mlte.store.artifact.underlying.http import RemoteHttpStore
from mlte.store.artifact.underlying.memory import InMemoryStore
from mlte.store.base import StoreType, StoreURI


def create_store(uri: str) -> Store:
def create_store(uri: str) -> ArtifactStore:
"""
Create a MLTE artifact store instance.
:param uri: The URI for the store instance
Expand Down
2 changes: 1 addition & 1 deletion mlte/store/query.py → mlte/store/artifact/query.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
mlte/store/query.py
mlte/store/artifact/query.py
Query and filtering functionality for store operations.
"""
Expand Down
Loading

0 comments on commit 40a725d

Please sign in to comment.