Skip to content

Commit

Permalink
fix: Support Pydantic 2.0 (#68)
Browse files Browse the repository at this point in the history
* fix: Support Pydantic 2.0

* fix: Linting

* fix: Remove typo
  • Loading branch information
jstlaurent authored Jul 10, 2023
1 parent b68689a commit 8cf5c9a
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 58 deletions.
2 changes: 1 addition & 1 deletion env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies:
- h5py
- pyarrow
- matplotlib
- pydantic
- pydantic >=2.0.0

# Chemistry
- datamol >=0.8.0
Expand Down
28 changes: 15 additions & 13 deletions molfeat/store/modelcard.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
from typing import Optional
from typing import List
from typing import Union

from datetime import datetime
from pydantic.typing import Literal
from pydantic import BaseModel
from pydantic import Field
from typing import List, Literal, Optional, Union

import datamol as dm
from pydantic import BaseModel, ConfigDict, Field


def get_model_init(card):
Expand Down Expand Up @@ -53,6 +49,12 @@ def get_model_init(card):


class ModelInfo(BaseModel):
model_config = ConfigDict(
protected_namespaces=(
"protected_",
) # Prevents warning from usage of model_ prefix in fields
)

name: str
inputs: str = "smiles"
type: Literal["pretrained", "hand-crafted", "hashed", "count"]
Expand All @@ -62,12 +64,12 @@ class ModelInfo(BaseModel):
description: str
representation: Literal["graph", "line-notation", "vector", "tensor", "other"]
require_3D: Optional[bool] = False
tags: Optional[List[str]]
tags: Optional[List[str]] = []
authors: Optional[List[str]]
reference: Optional[str]
reference: Optional[str] = None
created_at: datetime = Field(default_factory=datetime.now)
sha256sum: Optional[str]
model_usage: Optional[str]
sha256sum: Optional[str] = None
model_usage: Optional[str] = None

def path(self, root_path: str):
"""Generate the folder path where to save this model
Expand All @@ -86,9 +88,9 @@ def match(self, new_card: Union["ModelInfo", dict], match_only: Optional[List[st
match_only: list of minimum attribute that should match between the two model information
"""

self_content = self.dict().copy()
self_content = self.model_dump().copy()
if not isinstance(new_card, dict):
new_card = new_card.dict()
new_card = new_card.model_dump()
new_content = new_card.copy()
# we always remove the datetime field
self_content.pop("created_at", None)
Expand Down
20 changes: 8 additions & 12 deletions molfeat/store/modelstore.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,20 @@
from typing import Optional
from typing import Any
from typing import Union
from typing import Callable

import yaml
import joblib
import pathlib
import os
import fsspec
import pathlib
import tempfile
import platformdirs
import filelock
from typing import Any, Callable, Optional, Union

import datamol as dm
import filelock
import fsspec
import joblib
import platformdirs
import yaml
from dotenv import load_dotenv
from loguru import logger

from molfeat.store.modelcard import ModelInfo
from molfeat.utils import commons


load_dotenv()


Expand Down
9 changes: 5 additions & 4 deletions molfeat/trans/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import Path
from typing import Mapping
from typing import Union
from typing import List
Expand Down Expand Up @@ -575,11 +576,11 @@ def to_state_json(self) -> str:
def to_state_yaml(self) -> str:
return yaml.dump(self.to_state_dict(), Dumper=yaml.SafeDumper)

def to_state_json_file(self, filepath: str):
def to_state_json_file(self, filepath: Union[str, Path]):
with fsspec.open(filepath, "w") as f:
f.write(self.to_state_json()) # type: ignore

def to_state_yaml_file(self, filepath: str):
def to_state_yaml_file(self, filepath: Union[str, Path]):
with fsspec.open(filepath, "w") as f:
f.write(self.to_state_yaml()) # type: ignore

Expand Down Expand Up @@ -674,7 +675,7 @@ def from_state_yaml(

@staticmethod
def from_state_json_file(
filepath: str,
filepath: Union[str, Path],
override_args: Optional[dict] = None,
) -> "MoleculeTransformer":
with fsspec.open(filepath, "r") as f:
Expand All @@ -683,7 +684,7 @@ def from_state_json_file(

@staticmethod
def from_state_yaml_file(
filepath: str,
filepath: Union[str, Path],
override_args: Optional[dict] = None,
) -> "MoleculeTransformer":
with fsspec.open(filepath, "r") as f:
Expand Down
54 changes: 26 additions & 28 deletions tests/test_state.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,33 @@
import pytest

import numpy as np
import datamol as dm
import numpy as np
import pytest

from molfeat._version import __version__ as MOLFEAT_VERSION
from molfeat.trans.fp import FPVecTransformer
from molfeat.trans.fp import FPVecFilteredTransformer
from molfeat.trans.base import MoleculeTransformer
from molfeat.trans.base import PrecomputedMolTransformer
from molfeat.trans.graph import AdjGraphTransformer
from molfeat.trans.graph import DGLGraphTransformer
from molfeat.trans.graph import TopoDistGraphTransformer
from molfeat.trans.graph import PYGGraphTransformer
from molfeat.trans.pretrained import PretrainedDGLTransformer
from molfeat.trans.pretrained import GraphormerTransformer
from molfeat.trans.pretrained import PretrainedHFTransformer

from molfeat.calc.atom import AtomCalculator
from molfeat.calc import (
CATS,
FPCalculator,
Pharmacophore2D,
RDKitDescriptors2D,
ScaffoldKeyCalculator,
)
from molfeat.calc._atom_bond_features import atom_chiral_tag_one_hot, atom_one_hot
from molfeat.calc.atom import AtomCalculator, AtomMaterialCalculator
from molfeat.calc.bond import BondCalculator
from molfeat.calc.atom import AtomMaterialCalculator
from molfeat.calc import FPCalculator
from molfeat.calc import ScaffoldKeyCalculator
from molfeat.calc import RDKitDescriptors2D
from molfeat.calc import CATS
from molfeat.calc import Pharmacophore2D
from molfeat.calc._atom_bond_features import atom_chiral_tag_one_hot
from molfeat.calc._atom_bond_features import atom_one_hot
from molfeat.trans.graph import MolTreeDecompositionTransformer

from molfeat.utils.cache import MolToKey
from molfeat.utils.cache import FileCache
from molfeat.trans.base import MoleculeTransformer, PrecomputedMolTransformer
from molfeat.trans.fp import FPVecFilteredTransformer, FPVecTransformer
from molfeat.trans.graph import (
AdjGraphTransformer,
DGLGraphTransformer,
MolTreeDecompositionTransformer,
PYGGraphTransformer,
TopoDistGraphTransformer,
)
from molfeat.trans.pretrained import (
GraphormerTransformer,
PretrainedDGLTransformer,
PretrainedHFTransformer,
)
from molfeat.utils.cache import FileCache, MolToKey
from molfeat.utils.state import compare_state


Expand Down

0 comments on commit 8cf5c9a

Please sign in to comment.