From 92fa064d1bcd0a032c7b10aded3767a308a48552 Mon Sep 17 00:00:00 2001 From: Thorrester Date: Tue, 30 Apr 2024 13:45:00 -0400 Subject: [PATCH] typing --- opsml/model/interfaces/pytorch.py | 7 ++++--- opsml/model/loader.py | 14 +++++++------- opsml/projects/project.py | 2 +- opsml/storage/card_loader.py | 8 ++++---- tests/conftest.py | 6 +++--- .../test_model_interface_saver_loader_api.py | 2 +- tests/test_registry/test_card.py | 16 ++++++++-------- 7 files changed, 28 insertions(+), 27 deletions(-) diff --git a/opsml/model/interfaces/pytorch.py b/opsml/model/interfaces/pytorch.py index f7ca256f4..fcdc23bfd 100644 --- a/opsml/model/interfaces/pytorch.py +++ b/opsml/model/interfaces/pytorch.py @@ -55,9 +55,7 @@ class TorchModel(ModelInterface): """ model: Optional[torch.nn.Module] = None - sample_data: Optional[Union[torch.Tensor, Dict[str, torch.Tensor], List[torch.Tensor], Tuple[torch.Tensor]]] = ( - None - ) + sample_data: Optional[Union[torch.Tensor, Dict[str, torch.Tensor], List[torch.Tensor], Tuple[torch.Tensor]]] = None onnx_args: Optional[TorchOnnxArgs] = None save_args: TorchSaveArgs = TorchSaveArgs() preprocessor: Optional[Any] = None @@ -168,6 +166,9 @@ def load_model(self, path: Path, **kwargs: Any) -> None: """ model_arch = kwargs.get(CommonKwargs.MODEL_ARCH.value) + # remove model_arch from kwargs. Will raise an error if passed to torch.load + kwargs.pop(CommonKwargs.MODEL_ARCH.value, None) + if model_arch is not None: model_arch.load_state_dict(torch.load(path, **kwargs)) model_arch.eval() diff --git a/opsml/model/loader.py b/opsml/model/loader.py index 6ee111cbc..eb9fceeba 100644 --- a/opsml/model/loader.py +++ b/opsml/model/loader.py @@ -6,10 +6,10 @@ import json from pathlib import Path -from typing import Any +from typing import Any, Optional from opsml.model import HuggingFaceModel, ModelInterface -from opsml.types import ModelMetadata, OnnxModel, SaveName, Suffix +from opsml.types import ModelMetadata, OnnxModel, SaveName, Suffix, HuggingFaceOnnxArgs class ModelLoader: @@ -151,12 +151,12 @@ def _load_huggingface_onnx_model(self, load_quantized: bool) -> None: self.interface.onnx_model = OnnxModel(onnx_version=self.metadata.onnx_version) self.interface.load_onnx_model(load_path) - def load_onnx_model(self, **kwargs: Any) -> None: + def load_onnx_model(self, load_quantized: bool = False, onnx_args: Optional[HuggingFaceOnnxArgs] = None) -> None: """Load onnx model from disk - Kwargs: + Args: - ------Note: These kwargs only apply to HuggingFace models------ + ------Note: These args only apply to HuggingFace models------ kwargs: load_quantized: @@ -167,8 +167,8 @@ def load_onnx_model(self, **kwargs: Any) -> None: """ if isinstance(self.interface, HuggingFaceModel): - self.interface.onnx_args = kwargs.get("onnx_args", None) - self._load_huggingface_onnx_model(**kwargs) + self.interface.onnx_args = onnx_args + self._load_huggingface_onnx_model(load_quantized) return load_path = (self.path / SaveName.ONNX_MODEL.value).with_suffix(Suffix.ONNX.value) diff --git a/opsml/projects/project.py b/opsml/projects/project.py index 2f143500b..39db9a65e 100644 --- a/opsml/projects/project.py +++ b/opsml/projects/project.py @@ -7,7 +7,7 @@ from contextlib import contextmanager from typing import Any, Dict, Iterator, List, Optional, Union, cast -from opsml.cards.base import Card +from opsml.cards import Card from opsml.cards.project import ProjectCard from opsml.cards.run import RunCard from opsml.helpers.logging import ArtifactLogger diff --git a/opsml/storage/card_loader.py b/opsml/storage/card_loader.py index a30c171fa..a2bdc1f52 100644 --- a/opsml/storage/card_loader.py +++ b/opsml/storage/card_loader.py @@ -501,7 +501,7 @@ def load_preprocessor(self, lpath: Optional[Path] = None, rpath: Optional[Path] self.card.interface.load_preprocessor(lpath) return - def _load_model(self, lpath: Path, rpath: Path) -> None: + def _load_model(self, lpath: Path, rpath: Path, **kwargs: Any) -> None: """Load model to interface Args: @@ -512,7 +512,7 @@ def _load_model(self, lpath: Path, rpath: Path) -> None: """ lpath = self.download(lpath, rpath, SaveName.TRAINED_MODEL.value, self.model_suffix) - self.card.interface.load_model(lpath) + self.card.interface.load_model(lpath, **kwargs) if isinstance(self.card.interface, HuggingFaceModel): if self.card.interface.is_pipeline: @@ -603,7 +603,7 @@ def load_onnx_model(self, load_preprocessor: bool = False, load_quantized: bool return None - def load_model(self, load_preprocessor: bool) -> None: + def load_model(self, load_preprocessor: bool, **kwargs: Any) -> None: """Load model, preprocessor and sample data""" if self.card.interface.model is not None: @@ -617,7 +617,7 @@ def load_model(self, load_preprocessor: bool) -> None: if load_preprocessor: self.load_preprocessor(lpath, rpath) self._load_sample_data(lpath, rpath) - self._load_model(lpath, rpath) + self._load_model(lpath, rpath, **kwargs) return None diff --git a/tests/conftest.py b/tests/conftest.py index 2b16dd641..ccb0e46e7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -160,7 +160,7 @@ def cleanup() -> None: @pytest.fixture -def gcp_cred_path(): +def gcp_cred_path() -> str: return os.path.join(os.path.dirname(__file__), "assets/fake_gcp_creds.json") @@ -171,7 +171,7 @@ def save_path() -> str: @pytest.fixture -def mock_gcp_vars(gcp_cred_path): +def mock_gcp_vars(gcp_cred_path) -> Any: creds, _ = load_credentials_from_file(gcp_cred_path) mock_vars = { "gcp_project": "test", @@ -185,7 +185,7 @@ def mock_gcp_vars(gcp_cred_path): @pytest.fixture -def mock_gcp_creds(mock_gcp_vars): +def mock_gcp_creds(mock_gcp_vars) -> Any: creds = GcpCreds( creds=mock_gcp_vars["gcp_creds"], project=mock_gcp_vars["gcp_project"], diff --git a/tests/test_interface/test_model_interface_saver_loader_api.py b/tests/test_interface/test_model_interface_saver_loader_api.py index 3611d9f7e..c11ef9bf6 100644 --- a/tests/test_interface/test_model_interface_saver_loader_api.py +++ b/tests/test_interface/test_model_interface_saver_loader_api.py @@ -102,7 +102,7 @@ def test_save_huggingface_modelcard_api_client( modelcard.download_model(path=path, load_preprocessor=False, load_onnx=True) assert (path / SaveName.ONNX_MODEL.value).exists() - modelcard.download_model(path=path, load_preprocessor=False, load_onnx=True, quantize=True) + modelcard.download_model(path=path, load_preprocessor=False, load_onnx=True, load_quantized=True) assert (path / SaveName.QUANTIZED_MODEL.value).exists() diff --git a/tests/test_registry/test_card.py b/tests/test_registry/test_card.py index bdae14772..35f3016df 100644 --- a/tests/test_registry/test_card.py +++ b/tests/test_registry/test_card.py @@ -1,6 +1,6 @@ import pytest -from opsml.cards import Card +from opsml.cards import ArtifactCard from opsml.helpers.utils import validate_name_repository_pattern from opsml.types import CardInfo, Comment, RegistryType @@ -8,7 +8,7 @@ def test_artifact_card_with_args() -> None: - card = Card( + card = ArtifactCard( name=card_info.name, repository=card_info.repository, contact=card_info.contact, @@ -20,14 +20,14 @@ def test_artifact_card_with_args() -> None: def test_artifact_card_without_args() -> None: - card = Card(info=card_info) + card = ArtifactCard(info=card_info) assert card.name == card_info.name assert card.repository == card_info.repository assert card.contact == card_info.contact def test_artifact_card_with_both() -> None: - card = Card(name="override_name", info=card_info) + card = ArtifactCard(name="override_name", info=card_info) assert card.name == "override-name" # string cleaning assert card.repository == card_info.repository @@ -42,7 +42,7 @@ def test_artifact_card_name_repository_fail() -> None: ) with pytest.raises(ValueError): - Card( + ArtifactCard( name=card_info.name, repository=card_info.repository, contact=card_info.contact, @@ -76,17 +76,17 @@ def test_argument_fail() -> None: ) with pytest.raises(ValueError): - Card( + ArtifactCard( repository=card_info.repository, contact=card_info.contact, ) with pytest.raises(ValueError): - Card( + ArtifactCard( repository=card_info.repository, ) with pytest.raises(ValueError): - Card( + ArtifactCard( info=card_info, )