Skip to content

Commit

Permalink
typing
Browse files Browse the repository at this point in the history
  • Loading branch information
thorrester committed Apr 30, 2024
1 parent 1d963ef commit 92fa064
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 27 deletions.
7 changes: 4 additions & 3 deletions opsml/model/interfaces/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,7 @@ class TorchModel(ModelInterface):
"""

model: Optional[torch.nn.Module] = None
sample_data: Optional[Union[torch.Tensor, Dict[str, torch.Tensor], List[torch.Tensor], Tuple[torch.Tensor]]] = (
None
)
sample_data: Optional[Union[torch.Tensor, Dict[str, torch.Tensor], List[torch.Tensor], Tuple[torch.Tensor]]] = None
onnx_args: Optional[TorchOnnxArgs] = None
save_args: TorchSaveArgs = TorchSaveArgs()
preprocessor: Optional[Any] = None
Expand Down Expand Up @@ -168,6 +166,9 @@ def load_model(self, path: Path, **kwargs: Any) -> None:
"""
model_arch = kwargs.get(CommonKwargs.MODEL_ARCH.value)

# remove model_arch from kwargs. Will raise an error if passed to torch.load
kwargs.pop(CommonKwargs.MODEL_ARCH.value, None)

if model_arch is not None:
model_arch.load_state_dict(torch.load(path, **kwargs))
model_arch.eval()
Expand Down
14 changes: 7 additions & 7 deletions opsml/model/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

import json
from pathlib import Path
from typing import Any
from typing import Any, Optional

from opsml.model import HuggingFaceModel, ModelInterface
from opsml.types import ModelMetadata, OnnxModel, SaveName, Suffix
from opsml.types import ModelMetadata, OnnxModel, SaveName, Suffix, HuggingFaceOnnxArgs


class ModelLoader:
Expand Down Expand Up @@ -151,12 +151,12 @@ def _load_huggingface_onnx_model(self, load_quantized: bool) -> None:
self.interface.onnx_model = OnnxModel(onnx_version=self.metadata.onnx_version)
self.interface.load_onnx_model(load_path)

def load_onnx_model(self, **kwargs: Any) -> None:
def load_onnx_model(self, load_quantized: bool = False, onnx_args: Optional[HuggingFaceOnnxArgs] = None) -> None:
"""Load onnx model from disk
Kwargs:
Args:
------Note: These kwargs only apply to HuggingFace models------
------Note: These args only apply to HuggingFace models------
kwargs:
load_quantized:
Expand All @@ -167,8 +167,8 @@ def load_onnx_model(self, **kwargs: Any) -> None:
"""
if isinstance(self.interface, HuggingFaceModel):
self.interface.onnx_args = kwargs.get("onnx_args", None)
self._load_huggingface_onnx_model(**kwargs)
self.interface.onnx_args = onnx_args
self._load_huggingface_onnx_model(load_quantized)
return

load_path = (self.path / SaveName.ONNX_MODEL.value).with_suffix(Suffix.ONNX.value)
Expand Down
2 changes: 1 addition & 1 deletion opsml/projects/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from contextlib import contextmanager
from typing import Any, Dict, Iterator, List, Optional, Union, cast

from opsml.cards.base import Card
from opsml.cards import Card
from opsml.cards.project import ProjectCard
from opsml.cards.run import RunCard
from opsml.helpers.logging import ArtifactLogger
Expand Down
8 changes: 4 additions & 4 deletions opsml/storage/card_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ def load_preprocessor(self, lpath: Optional[Path] = None, rpath: Optional[Path]
self.card.interface.load_preprocessor(lpath)
return

def _load_model(self, lpath: Path, rpath: Path) -> None:
def _load_model(self, lpath: Path, rpath: Path, **kwargs: Any) -> None:
"""Load model to interface
Args:
Expand All @@ -512,7 +512,7 @@ def _load_model(self, lpath: Path, rpath: Path) -> None:
"""

lpath = self.download(lpath, rpath, SaveName.TRAINED_MODEL.value, self.model_suffix)
self.card.interface.load_model(lpath)
self.card.interface.load_model(lpath, **kwargs)

if isinstance(self.card.interface, HuggingFaceModel):
if self.card.interface.is_pipeline:
Expand Down Expand Up @@ -603,7 +603,7 @@ def load_onnx_model(self, load_preprocessor: bool = False, load_quantized: bool

return None

def load_model(self, load_preprocessor: bool) -> None:
def load_model(self, load_preprocessor: bool, **kwargs: Any) -> None:
"""Load model, preprocessor and sample data"""

if self.card.interface.model is not None:
Expand All @@ -617,7 +617,7 @@ def load_model(self, load_preprocessor: bool) -> None:
if load_preprocessor:
self.load_preprocessor(lpath, rpath)
self._load_sample_data(lpath, rpath)
self._load_model(lpath, rpath)
self._load_model(lpath, rpath, **kwargs)

return None

Expand Down
6 changes: 3 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def cleanup() -> None:


@pytest.fixture
def gcp_cred_path():
def gcp_cred_path() -> str:
return os.path.join(os.path.dirname(__file__), "assets/fake_gcp_creds.json")


Expand All @@ -171,7 +171,7 @@ def save_path() -> str:


@pytest.fixture
def mock_gcp_vars(gcp_cred_path):
def mock_gcp_vars(gcp_cred_path) -> Any:
creds, _ = load_credentials_from_file(gcp_cred_path)
mock_vars = {
"gcp_project": "test",
Expand All @@ -185,7 +185,7 @@ def mock_gcp_vars(gcp_cred_path):


@pytest.fixture
def mock_gcp_creds(mock_gcp_vars):
def mock_gcp_creds(mock_gcp_vars) -> Any:
creds = GcpCreds(
creds=mock_gcp_vars["gcp_creds"],
project=mock_gcp_vars["gcp_project"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def test_save_huggingface_modelcard_api_client(
modelcard.download_model(path=path, load_preprocessor=False, load_onnx=True)
assert (path / SaveName.ONNX_MODEL.value).exists()

modelcard.download_model(path=path, load_preprocessor=False, load_onnx=True, quantize=True)
modelcard.download_model(path=path, load_preprocessor=False, load_onnx=True, load_quantized=True)
assert (path / SaveName.QUANTIZED_MODEL.value).exists()


Expand Down
16 changes: 8 additions & 8 deletions tests/test_registry/test_card.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import pytest

from opsml.cards import Card
from opsml.cards import ArtifactCard
from opsml.helpers.utils import validate_name_repository_pattern
from opsml.types import CardInfo, Comment, RegistryType

card_info = CardInfo(name="test", repository="opsml", contact="[email protected]")


def test_artifact_card_with_args() -> None:
card = Card(
card = ArtifactCard(
name=card_info.name,
repository=card_info.repository,
contact=card_info.contact,
Expand All @@ -20,14 +20,14 @@ def test_artifact_card_with_args() -> None:


def test_artifact_card_without_args() -> None:
card = Card(info=card_info)
card = ArtifactCard(info=card_info)
assert card.name == card_info.name
assert card.repository == card_info.repository
assert card.contact == card_info.contact


def test_artifact_card_with_both() -> None:
card = Card(name="override_name", info=card_info)
card = ArtifactCard(name="override_name", info=card_info)

assert card.name == "override-name" # string cleaning
assert card.repository == card_info.repository
Expand All @@ -42,7 +42,7 @@ def test_artifact_card_name_repository_fail() -> None:
)

with pytest.raises(ValueError):
Card(
ArtifactCard(
name=card_info.name,
repository=card_info.repository,
contact=card_info.contact,
Expand Down Expand Up @@ -76,17 +76,17 @@ def test_argument_fail() -> None:
)

with pytest.raises(ValueError):
Card(
ArtifactCard(
repository=card_info.repository,
contact=card_info.contact,
)

with pytest.raises(ValueError):
Card(
ArtifactCard(
repository=card_info.repository,
)

with pytest.raises(ValueError):
Card(
ArtifactCard(
info=card_info,
)

0 comments on commit 92fa064

Please sign in to comment.