-
Notifications
You must be signed in to change notification settings - Fork 31
/
base.py
99 lines (79 loc) · 3.62 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from hashlib import md5
# import logging
import pymatgen.core.structure
from pymatgen.core.structure import Structure
from typing import Callable, Union
from networkx import MultiDiGraph
from kgcnn.graph.base import GraphDict
# A separate module logger is not needed for the base class.
# logging.basicConfig() # Module logger
# module_logger = logging.getLogger(__name__)
# module_logger.setLevel(logging.INFO)
class CrystalPreprocessor(Callable[[Structure], MultiDiGraph]):
    """Base class for crystal preprocessors.

    Concrete CrystalPreprocessors must be implemented as subclasses that override :obj:`call`.
    Calling an instance turns a crystal into a graph, returned either as a networkx
    ``MultiDiGraph`` or, if ``output_graph_as_dict`` is set, as a :obj:`GraphDict`.
    """

    # Attribute names forwarded to ``GraphDict.from_networkx`` when ``output_graph_as_dict``
    # is True. Empty in the base class; presumably set by subclasses. These are shared
    # class-level lists, so treat them as read-only on instances.
    node_attributes = []
    edge_attributes = []
    graph_attributes = []

    def __init__(self, output_graph_as_dict: bool = False,
                 lattice: str = "graph_lattice", species: str = "node_number",
                 coords: str = "node_coordinates", charge: str = "charge"):
        """Initialize the preprocessor.

        Args:
            output_graph_as_dict (bool): If True, :obj:`__call__` returns a :obj:`GraphDict`
                instead of a ``MultiDiGraph``. Defaults to False.
            lattice (str): Key of the lattice when the input is a :obj:`GraphDict`.
            species (str): Key of the species when the input is a :obj:`GraphDict`.
            coords (str): Key of the (cartesian) coordinates when the input is a :obj:`GraphDict`.
            charge (str): Key of the charge when the input is a :obj:`GraphDict`.
        """
        self.output_graph_as_dict = output_graph_as_dict
        # Underscore-prefixed so get_config() skips it when scanning vars(self);
        # get_config() merges these keys back in explicitly.
        self._input_config = {
            "lattice": lattice, "species": species, "charge": charge, "coords": coords}

    def call(self, structure: Structure) -> MultiDiGraph:
        r"""Should be implemented in a subclass.

        Args:
            structure (Structure): Crystal for which the graph representation should be calculated.

        Raises:
            NotImplementedError: Always, in the base class.

        Returns:
            MultiDiGraph: Graph representation of the crystal.
        """
        raise NotImplementedError("Must be implemented in sub-classes.")

    def __call__(self, structure: Union[Structure, GraphDict]) -> Union[MultiDiGraph, GraphDict]:
        r"""Function to process crystal structures. Executes :obj:`call` .

        Args:
            structure (Structure, GraphDict): Crystal for which the graph representation should be
                calculated. A :obj:`GraphDict` is first converted into a pymatgen ``Structure``
                using the keys passed to the constructor, with coordinates treated as cartesian.

        Raises:
            NotImplementedError: If the subclass does not implement :obj:`call`.

        Returns:
            Union[MultiDiGraph, GraphDict]: Graph representation of the crystal. A
            :obj:`GraphDict` if ``output_graph_as_dict`` is True, otherwise a ``MultiDiGraph``.
        """
        if isinstance(structure, GraphDict):
            structure = pymatgen.core.structure.Structure(
                lattice=structure.get(self._input_config["lattice"]),
                species=structure.get(self._input_config["species"]),
                coords=structure.get(self._input_config["coords"]),
                charge=structure.get(self._input_config["charge"]),
                coords_are_cartesian=True
            )
        nxg = self.call(structure)
        if self.output_graph_as_dict:
            g = GraphDict()
            g.from_networkx(
                nxg, node_attributes=self.node_attributes, edge_attributes=self.edge_attributes,
                graph_attributes=self.graph_attributes, reverse_edge_indices=True)
            return g
        return nxg

    def get_config(self) -> dict:
        """Returns a dictionary uniquely identifying the CrystalPreprocessor and its configuration.

        The dictionary contains:

        - every public (non-underscore) instance attribute;
        - the class name under ``'preprocessor'``;
        - the GraphDict input keys (``lattice``, ``species``, ``charge``, ``coords``).

        Returns:
            dict: A dictionary uniquely identifying the CrystalPreprocessor and its configuration.
        """
        # The comprehension builds a new dict, so the instance __dict__ is never mutated.
        config = {k: v for k, v in vars(self).items() if not k.startswith("_")}
        config['preprocessor'] = self.__class__.__name__
        config.update(self._input_config)
        return config

    def hash(self) -> str:
        """Generates a unique hash for the CrystalPreprocessor and its configuration.

        Returns:
            str: Hex md5 digest of ``str(self.get_config())``.
        """
        return md5(str(self.get_config()).encode()).hexdigest()

    def __hash__(self):
        """Integer form of :obj:`hash`, so equal configurations hash equally."""
        return int(self.hash(), 16)

    def __eq__(self, other):
        """Two preprocessors are equal if their configuration digests match.

        Non-preprocessor operands return ``NotImplemented``, so Python falls back to its
        default comparison. Previously ``hash()`` was called on them, which raised
        ``TypeError`` for unhashable objects such as lists or dicts. The full md5 digests are
        compared rather than builtin ``hash()`` values, which Python truncates to machine size.
        """
        if not isinstance(other, CrystalPreprocessor):
            return NotImplemented
        return self.hash() == other.hash()