Merge branch 'devel' into enerhess
1azyking authored Nov 15, 2024
2 parents 23107fa + 0ad4289 commit 44ee0b8
Showing 500 changed files with 3,357 additions and 2,361 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test_python.yml
@@ -62,6 +62,7 @@ jobs:
env:
NUM_WORKERS: 0
DP_TEST_TF2_ONLY: 1
+DP_DTYPE_PROMOTION_STRICT: 1
if: matrix.group == 1
- run: mv .test_durations .test_durations_${{ matrix.group }}
- name: Upload partial durations
2 changes: 1 addition & 1 deletion backend/read_env.py
@@ -131,7 +131,7 @@ def get_argument_from_env() -> tuple[str, list, list, dict, str, str]:
)


-def set_scikit_build_env():
+def set_scikit_build_env() -> None:
"""Set scikit-build environment variables before executing scikit-build."""
cmake_minimum_required_version, cmake_args, _, _, _, _ = get_argument_from_env()
os.environ["SKBUILD_CMAKE_MINIMUM_VERSION"] = cmake_minimum_required_version
2 changes: 1 addition & 1 deletion data/json/json2yaml.py
@@ -13,7 +13,7 @@
import yaml


-def _main():
+def _main() -> None:
parser = argparse.ArgumentParser(
description="convert json config file to yaml",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
4 changes: 2 additions & 2 deletions data/raw/copy_raw.py
@@ -8,7 +8,7 @@
import numpy as np


-def copy(in_dir, out_dir, ncopies=[1, 1, 1]):
+def copy(in_dir, out_dir, ncopies=[1, 1, 1]) -> None:
has_energy = os.path.isfile(in_dir + "/energy.raw")
has_force = os.path.isfile(in_dir + "/force.raw")
has_virial = os.path.isfile(in_dir + "/virial.raw")
@@ -71,7 +71,7 @@ def copy(in_dir, out_dir, ncopies=[1, 1, 1]):
np.savetxt(out_dir + "/ncopies.raw", ncopies, fmt="%d")


-def _main():
+def _main() -> None:
parser = argparse.ArgumentParser(description="parse copy raw args")
parser.add_argument("INPUT", default=".", help="input dir of raw files")
parser.add_argument("OUTPUT", default=".", help="output dir of copied raw files")
2 changes: 1 addition & 1 deletion data/raw/shuffle_raw.py
@@ -30,7 +30,7 @@ def detect_raw(path):
return raws


-def _main():
+def _main() -> None:
args = _parse_args()
raws = args.raws
inpath = args.INPUT
2 changes: 1 addition & 1 deletion deepmd/calculator.py
@@ -107,7 +107,7 @@ def calculate(
atoms: Optional["Atoms"] = None,
properties: list[str] = ["energy", "forces", "virial"],
system_changes: list[str] = all_changes,
-    ):
+    ) -> None:
"""Run calculation with deepmd model.
Parameters
14 changes: 6 additions & 8 deletions deepmd/common.py
@@ -62,13 +62,11 @@

if TYPE_CHECKING:
_DICT_VAL = TypeVar("_DICT_VAL")
-    __all__.extend(
-        [
-            "_DICT_VAL",
-            "_PRECISION",
-            "_ACTIVATION",
-        ]
-    )
+    __all__ += [
+        "_DICT_VAL",
+        "_PRECISION",
+        "_ACTIVATION",
+    ]


def select_idx_map(atom_types: np.ndarray, select_types: np.ndarray) -> np.ndarray:
@@ -237,7 +235,7 @@ def get_np_precision(precision: "_PRECISION") -> np.dtype:
raise RuntimeError(f"{precision} is not a valid precision")


-def symlink_prefix_files(old_prefix: str, new_prefix: str):
+def symlink_prefix_files(old_prefix: str, new_prefix: str) -> None:
"""Create symlinks from old checkpoint prefix to new one.
On Windows this function will copy files instead of creating symlinks.
25 changes: 13 additions & 12 deletions deepmd/dpmodel/atomic_model/base_atomic_model.py
@@ -43,15 +43,15 @@ def __init__(
pair_exclude_types: list[tuple[int, int]] = [],
rcond: Optional[float] = None,
preset_out_bias: Optional[dict[str, np.ndarray]] = None,
-    ):
+    ) -> None:
super().__init__()
self.type_map = type_map
self.reinit_atom_exclude(atom_exclude_types)
self.reinit_pair_exclude(pair_exclude_types)
self.rcond = rcond
self.preset_out_bias = preset_out_bias

-    def init_out_stat(self):
+    def init_out_stat(self) -> None:
"""Initialize the output bias."""
ntypes = self.get_ntypes()
self.bias_keys: list[str] = list(self.fitting_output_def().keys())
@@ -68,7 +68,7 @@ def init_out_stat(self):
self.out_bias = out_bias_data
self.out_std = out_std_data

-    def __setitem__(self, key, value):
+    def __setitem__(self, key, value) -> None:
if key in ["out_bias"]:
self.out_bias = value
elif key in ["out_std"]:
@@ -91,7 +91,7 @@ def get_type_map(self) -> list[str]:
def reinit_atom_exclude(
self,
exclude_types: list[int] = [],
-    ):
+    ) -> None:
self.atom_exclude_types = exclude_types
if exclude_types == []:
self.atom_excl = None
@@ -101,7 +101,7 @@ def reinit_pair_exclude(
def reinit_pair_exclude(
self,
exclude_types: list[tuple[int, int]] = [],
-    ):
+    ) -> None:
self.pair_exclude_types = exclude_types
if exclude_types == []:
self.pair_excl = None
@@ -201,18 +201,19 @@ def forward_common_atomic(
ret_dict = self.apply_out_stat(ret_dict, atype)

# nf x nloc
-        atom_mask = ext_atom_mask[:, :nloc].astype(xp.int32)
+        atom_mask = ext_atom_mask[:, :nloc]
if self.atom_excl is not None:
-            atom_mask *= self.atom_excl.build_type_exclude_mask(atype)
+            atom_mask = xp.logical_and(
+                atom_mask, self.atom_excl.build_type_exclude_mask(atype)
+            )

for kk in ret_dict.keys():
out_shape = ret_dict[kk].shape
out_shape2 = math.prod(out_shape[2:])
-            ret_dict[kk] = (
-                ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2])
-                * atom_mask[:, :, None]
-            ).reshape(out_shape)
-        ret_dict["mask"] = atom_mask
+            tmp_arr = ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2])
+            tmp_arr = xp.where(atom_mask[:, :, None], tmp_arr, xp.zeros_like(tmp_arr))
+            ret_dict[kk] = xp.reshape(tmp_arr, out_shape)
+        ret_dict["mask"] = xp.astype(atom_mask, xp.int32)

return ret_dict

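The masking rewrite above swaps an integer-mask multiply for a boolean mask plus `xp.where`, which leaves each output's dtype untouched; a plausible motivation is the strict dtype-promotion check enabled in the workflow file above, though the commit message does not say so. A minimal sketch of the new pattern, assuming NumPy as the array-API namespace and hypothetical shapes:

```python
import array_api_compat
import numpy as np

# Hypothetical sizes: nf frames, nloc local atoms, nd output components.
nf, nloc, nd = 2, 3, 4
out = np.ones((nf, nloc, nd), dtype=np.float64)
atom_mask = np.array([[True, True, False], [True, False, True]])  # nf x nloc

xp = array_api_compat.array_namespace(out)
# A boolean mask with xp.where zeroes excluded atoms without changing the
# dtype, unlike `out * mask.astype(int32)`, which relies on type promotion.
masked = xp.where(atom_mask[:, :, None], out, xp.zeros_like(out))
assert masked.dtype == out.dtype  # float64 preserved
```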
2 changes: 1 addition & 1 deletion deepmd/dpmodel/atomic_model/dp_atomic_model.py
@@ -45,7 +45,7 @@ def __init__(
fitting,
type_map: list[str],
**kwargs,
-    ):
+    ) -> None:
super().__init__(type_map, **kwargs)
self.type_map = type_map
self.descriptor = descriptor
4 changes: 2 additions & 2 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
@@ -52,7 +52,7 @@ def __init__(
models: list[BaseAtomicModel],
type_map: list[str],
**kwargs,
-    ):
+    ) -> None:
super().__init__(type_map, **kwargs)
super().init_out_stat()

@@ -391,7 +391,7 @@ def __init__(
type_map: list[str],
smin_alpha: Optional[float] = 0.1,
**kwargs,
-    ):
+    ) -> None:
models = [dp_model, zbl_model]
kwargs["models"] = models
kwargs["type_map"] = type_map
2 changes: 1 addition & 1 deletion deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
@@ -66,7 +66,7 @@ def __init__(
rcond: Optional[float] = None,
atom_ener: Optional[list[float]] = None,
**kwargs,
-    ):
+    ) -> None:
super().__init__(type_map, **kwargs)
super().init_out_stat()
self.tab_file = tab_file
2 changes: 1 addition & 1 deletion deepmd/dpmodel/atomic_model/property_atomic_model.py
@@ -9,6 +9,6 @@


class DPPropertyAtomicModel(DPAtomicModel):
-    def __init__(self, descriptor, fitting, type_map, **kwargs):
+    def __init__(self, descriptor, fitting, type_map, **kwargs) -> None:
assert isinstance(fitting, PropertyFittingNet)
super().__init__(descriptor, fitting, type_map, **kwargs)
108 changes: 106 additions & 2 deletions deepmd/dpmodel/common.py
@@ -3,9 +3,14 @@
ABC,
abstractmethod,
)
+from functools import (
+    wraps,
+)
from typing import (
Any,
+    Callable,
Optional,
+    overload,
)

import array_api_compat
@@ -29,7 +34,7 @@
"double": np.float64,
"int32": np.int32,
"int64": np.int64,
-    "bool": bool,
+    "bool": np.bool_,
"default": GLOBAL_NP_FLOAT_PRECISION,
# NumPy doesn't have bfloat16 (and doesn't plan to add)
# ml_dtypes is a solution, but it seems not supporting np.save/np.load
@@ -45,7 +50,7 @@
np.int32: "int32",
np.int64: "int64",
ml_dtypes.bfloat16: "bfloat16",
-    bool: "bool",
+    np.bool_: "bool",
}
assert set(RESERVED_PRECISON_DICT.keys()) == set(PRECISION_DICT.values())
DEFAULT_PRECISION = "float64"
@@ -116,6 +121,105 @@ def to_numpy_array(x: Any) -> Optional[np.ndarray]:
return np.from_dlpack(x)


def cast_precision(func: Callable[..., Any]) -> Callable[..., Any]:
    """A decorator that casts and casts back the input
    and output arrays of a method.

    The decorator should be used on an instance method.

    The decorator does the following things:
    (1) It casts input arrays from the global precision
    to the precision defined by the property `precision`.
    (2) It casts output arrays from `precision` back to
    the global precision.
    (3) It checks inputs and outputs and only casts when
    the input or output is an array and its dtype matches
    the global precision and `precision`, respectively.
    If it does not match (e.g. it is an integer), the decorator
    does nothing to it.

    The decorator supports the array API.

    Returns
    -------
    Callable
        a decorator that casts and casts back the input and
        output arrays of a method

    Examples
    --------
    >>> class A:
    ...     def __init__(self):
    ...         self.precision = "float32"
    ...
    ...     @cast_precision
    ...     def f(self, x: Array, y: Array) -> Array:
    ...         return x**2 + y
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # only convert arrays
        returned_tensor = func(
            self,
            *[safe_cast_array(vv, "global", self.precision) for vv in args],
            **{
                kk: safe_cast_array(vv, "global", self.precision)
                for kk, vv in kwargs.items()
            },
        )
        if isinstance(returned_tensor, tuple):
            return tuple(
                safe_cast_array(vv, self.precision, "global") for vv in returned_tensor
            )
        elif isinstance(returned_tensor, dict):
            return {
                kk: safe_cast_array(vv, self.precision, "global")
                for kk, vv in returned_tensor.items()
            }
        else:
            return safe_cast_array(returned_tensor, self.precision, "global")

    return wrapper


@overload
def safe_cast_array(
    input: np.ndarray, from_precision: str, to_precision: str
) -> np.ndarray: ...
@overload
def safe_cast_array(input: None, from_precision: str, to_precision: str) -> None: ...
def safe_cast_array(
    input: Optional[np.ndarray], from_precision: str, to_precision: str
) -> Optional[np.ndarray]:
    """Convert an array from one precision to another.

    If the input is not an array, or does not have the given
    `from_precision`, it is returned unchanged.

    The array API is supported.

    Parameters
    ----------
    input : np.ndarray or None
        Input array
    from_precision : str
        Array data type to cast from
    to_precision : str
        Array data type to cast to

    Returns
    -------
    np.ndarray or None
        the casted array
    """
    if array_api_compat.is_array_api_obj(input):
        xp = array_api_compat.array_namespace(input)
        if input.dtype == get_xp_precision(xp, from_precision):
            return xp.astype(input, get_xp_precision(xp, to_precision))
    return input


__all__ = [
"GLOBAL_NP_FLOAT_PRECISION",
"GLOBAL_ENER_FLOAT_PRECISION",
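A hypothetical usage sketch for the new `cast_precision` decorator (the class and method names below are made up, and the default global precision of float64 is assumed): the decorated method body computes in the instance's `precision`, while callers keep working in the global precision.

```python
import numpy as np

from deepmd.dpmodel.common import cast_precision


class ToyLayer:
    def __init__(self) -> None:
        self.precision = "float32"  # the precision the method body should see

    @cast_precision
    def forward(self, x: np.ndarray) -> np.ndarray:
        # x has been cast from the global precision (float64) to float32 here
        return x * 2.0


layer = ToyLayer()
y = layer.forward(np.ones(3, dtype=np.float64))
print(y.dtype)  # float64: the result is cast back to the global precision
```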
7 changes: 4 additions & 3 deletions deepmd/dpmodel/descriptor/descriptor.py
@@ -6,6 +6,7 @@
)
from typing import (
Callable,
+    NoReturn,
Optional,
Union,
)
@@ -83,7 +84,7 @@ def compute_input_stats(
self,
merged: Union[Callable[[], list[dict]], list[dict]],
path: Optional[DPPath] = None,
-    ):
+    ) -> NoReturn:
"""
Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data.
@@ -106,7 +107,7 @@ def get_stats(self) -> dict[str, StatItem]:
"""Get the statistics of the descriptor."""
raise NotImplementedError

-    def share_params(self, base_class, shared_level, resume=False):
+    def share_params(self, base_class, shared_level, resume=False) -> NoReturn:
"""
Share the parameters of self to the base_class with shared_level during multitask training.
If not start from checkpoint (resume is False),
@@ -135,7 +136,7 @@ def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor block needs sorted nlist when using `forward_lower`."""


-def extend_descrpt_stat(des, type_map, des_with_stat=None):
+def extend_descrpt_stat(des, type_map, des_with_stat=None) -> None:
r"""
Extend the statistics of a descriptor block with types from newly provided `type_map`.
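Note the distinction the annotations in this file draw: methods that return nothing get `-> None`, while `compute_input_stats` and `share_params`, which in this base class appear to raise unconditionally, get `-> NoReturn`. A minimal sketch of the convention:

```python
from typing import NoReturn


def finished() -> None:
    print("done")  # returns normally, just without a value


def not_implemented_here() -> NoReturn:
    raise NotImplementedError  # never returns at all
```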
(Diff truncated: the remaining changed files are not shown.)
