Skip to content

Commit

Permalink
auto-calculate ADDED_MOS based on basissets
Browse files Browse the repository at this point in the history
  • Loading branch information
dev-zero committed Apr 8, 2021
1 parent 05d8199 commit 2362f08
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 1 deletion.
17 changes: 16 additions & 1 deletion aiida_cp2k/calculations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
validate_pseudos_namespace,
write_basissets,
write_pseudos,
estimate_added_mos,
)
from ..utils import Cp2kInput

Expand Down Expand Up @@ -139,7 +140,7 @@ def prepare_for_submission(self, folder):
:return: `aiida.common.datastructures.CalcInfo` instance
"""

# pylint: disable=too-many-statements,too-many-branches
# pylint: disable=too-many-statements,too-many-branches,too-many-locals

# Create cp2k input file.
inp = Cp2kInput(self.inputs.parameters.get_dict())
Expand Down Expand Up @@ -167,6 +168,20 @@ def prepare_for_submission(self, folder):
self.inputs.structure if 'structure' in self.inputs else None)
write_basissets(inp, self.inputs.basissets, folder)

# if we have both basissets and structure we can start helping the user :)
if 'basissets' in self.inputs and 'structure' in self.inputs:
try:
scf_section = inp.get_section_dict('FORCE_EVAL/DFT/SCF')
except (KeyError, TypeError):
pass # no SCF, no smearing, or multiple FORCE_EVAL, nothing to do (yet)
else:
if 'SMEAR' in scf_section and 'ADDED_MOS' not in scf_section:
# now is our time to shine!
added_mos = estimate_added_mos(self.inputs.basissets, self.inputs.structure)
inp.add_keyword('FORCE_EVAL/DFT/SCF/ADDED_MOS', added_mos)
self.logger.info(f'The FORCE_EVAL/DFT/SCF/ADDED_MOS was added'
f' with an automatically estimated value of {added_mos}')

if 'pseudos' in self.inputs:
validate_pseudos(inp, self.inputs.pseudos, self.inputs.structure if 'structure' in self.inputs else None)
write_pseudos(inp, self.inputs.pseudos, folder)
Expand Down
25 changes: 25 additions & 0 deletions aiida_cp2k/utils/datatype_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,31 @@ def validate_basissets(inp, basissets, structure):
kind_sec["ELEMENT"] = bset.element


def estimate_added_mos(basissets, structure, fraction=0.3):
"""Calculate an estimate for ADDED_MOS based on used basis sets"""

symbols = [structure.get_kind(s.kind_name).get_symbols_string() for s in structure.sites]
n_mos = 0

# We are currently overcounting in the following cases:
# * if we get a mix of ORB basissets for the same chemical symbol but different sites
# * if we get multiple basissets for one element (merged within CP2K)

for label, bset in _unpack(basissets):
try:
_, bstype = label.split("_", maxsplit=1)
except ValueError:
bstype = "ORB"

if bstype != "ORB": # ignore non-ORB basissets
continue

n_mos += symbols.count(bset.element) * bset.n_orbital_functions

# at least one additional MO per site, otherwise a fraction of the total number of orbital functions
return max(len(symbols), int(fraction * n_mos))


def write_basissets(inp, basissets, folder):
"""Writes the unified BASIS_SETS file with the used basissets"""
_write_gdt(inp, basissets, folder, "BASIS_SET_FILE_NAME", "BASIS_SETS")
Expand Down
63 changes: 63 additions & 0 deletions aiida_cp2k/utils/input_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,69 @@ def add_keyword(self, kwpath, value, override=True, conflicting_keys=None):

Cp2kInput._add_keyword(kwpath, value, self._params, ovrd=override, cfct=conflicting_keys)

@staticmethod
def _stringify_path(kwpath):
"""Stringify a kwpath argument"""
if isinstance(kwpath, str):
return kwpath

assert isinstance(kwpath, Sequence), "path is neither Sequence nor String"
return "/".join(kwpath)

def get_section_dict(self, kwpath=""):
"""Get a copy of a section from the current input structure
Args:
kwpath: Can be a single keyword, a path with `/` as divider for sections & key,
or a sequence with sections and key.
"""

section = self._get_section_or_kw(kwpath)

if not isinstance(section, Mapping):
raise TypeError(f"Section '{self._stringify_path(kwpath)}' requested, but keyword found")

return deepcopy(section)

def get_keyword_value(self, kwpath):
"""Get the value of a keyword from the current input structure
Args:
kwpath: Can be a single keyword, a path with `/` as divider for sections & key,
or a sequence with sections and key.
"""

keyword = self._get_section_or_kw(kwpath)

if isinstance(keyword, Mapping):
raise TypeError(f"Keyword '{self._stringify_path(kwpath)}' requested, but section found")

return keyword

def _get_section_or_kw(self, kwpath):
"""Retrieve either a section or a keyword given a path"""

if isinstance(kwpath, str):
kwpath = kwpath.split("/") # convert to list of sections if string

# get a copy of the path in a mutable sequence
# accept any case, but internally we use uppercase
# strip empty strings to accept leading "/", "//", etc.
path = [k.upper() for k in kwpath if k]

# start with a reference to the root of the parameters
current = self._params

try:
while path:
current = current[path.pop(0)]
except KeyError:
raise KeyError(f"Section '{self._stringify_path(kwpath)}' not found in parameters")

return current

def render(self):
output = [self.DISCLAIMER]
self._render_section(output, deepcopy(self._params))
Expand Down
83 changes: 83 additions & 0 deletions test/test_gaussian_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,3 +778,86 @@ def test_without_kinds(cp2k_code, cp2k_basissets, cp2k_pseudos, clear_database):

_, calc_node = run_get_node(CalculationFactory("cp2k"), **inputs)
assert calc_node.exit_status == 0


def test_added_mos(cp2k_code, cp2k_basissets, cp2k_pseudos, clear_database): # pylint: disable=unused-argument
"""Testing CP2K with the Basis Set stored in gaussian.basisset and a smearing section but no predefined ADDED_MOS"""

structure = StructureData(cell=[[4.00759, 0.0, 0.0], [-2.003795, 3.47067475, 0.0],
[3.06349683e-16, 5.30613216e-16, 5.00307]],
pbc=True)
structure.append_atom(position=(-0.00002004, 2.31379473, 0.87543719), symbols="H")
structure.append_atom(position=(2.00381504, 1.15688001, 4.12763281), symbols="H")
structure.append_atom(position=(2.00381504, 1.15688001, 3.37697219), symbols="H")
structure.append_atom(position=(-0.00002004, 2.31379473, 1.62609781), symbols="H")

# parameters
parameters = Dict(
dict={
'GLOBAL': {
'RUN_TYPE': 'ENERGY',
},
'FORCE_EVAL': {
'METHOD': 'Quickstep',
'DFT': {
"XC": {
"XC_FUNCTIONAL": {
"_": "PBE",
},
},
"MGRID": {
"CUTOFF": 100.0,
"REL_CUTOFF": 10.0,
},
"QS": {
"METHOD": "GPW",
"EXTRAPOLATION": "USE_GUESS",
},
"SCF": {
"EPS_SCF": 1e-05,
"MAX_SCF": 3,
"MIXING": {
"METHOD": "BROYDEN_MIXING",
"ALPHA": 0.4,
},
"SMEAR": {
"METHOD": "FERMI_DIRAC",
"ELECTRONIC_TEMPERATURE": 300.0,
},
},
"KPOINTS": {
"SCHEME": "MONKHORST-PACK 2 2 1",
"FULL_GRID": False,
"SYMMETRY": False,
"PARALLEL_GROUP_SIZE": -1,
},
},
},
})

options = {
"resources": {
"num_machines": 1,
"num_mpiprocs_per_machine": 1
},
"max_wallclock_seconds": 1 * 3 * 60,
}

inputs = {
"structure": structure,
"parameters": parameters,
"code": cp2k_code,
"metadata": {
"options": options,
},
"basissets": {label: b for label, b in cp2k_basissets.items() if label == "H"},
"pseudos": {label: p for label, p in cp2k_pseudos.items() if label == "H"},
}

_, calc_node = run_get_node(CalculationFactory("cp2k"), **inputs)

assert calc_node.exit_status == 0

# check that the ADDED_MOS keyword was added within the calculation
with calc_node.open("aiida.inp") as fhandle:
assert any("ADDED_MOS" in line for line in fhandle), "ADDED_MOS not found in the generated CP2K input file"
28 changes: 28 additions & 0 deletions test/test_input_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,31 @@ def test_invalid_preprocessor():
inp = Cp2kInput({"@SET": "bar"})
with pytest.raises(ValueError):
inp.render()


def test_get_keyword_value():
"""Test get_keyword_value()"""
inp = Cp2kInput({"FOO": "bar", "A": {"KW1": "val1"}})
assert inp.get_keyword_value("FOO") == "bar"
assert inp.get_keyword_value("/FOO") == "bar"
assert inp.get_keyword_value("A/KW1") == "val1"
assert inp.get_keyword_value("/A/KW1") == "val1"
assert inp.get_keyword_value(["A", "KW1"]) == "val1"
with pytest.raises(TypeError):
inp.get_keyword_value("A")


def test_get_section_dict():
"""Test get_section_dict()"""
orig_dict = {"FOO": "bar", "A": {"KW1": "val1"}}
inp = Cp2kInput(orig_dict)
assert inp.get_section_dict("/") == orig_dict
assert inp.get_section_dict("////") == orig_dict
assert inp.get_section_dict("") == orig_dict
assert inp.get_section_dict() == orig_dict
assert inp.get_section_dict("/") is not orig_dict # make sure we get a distinct object
assert inp.get_section_dict("A") == orig_dict["A"]
assert inp.get_section_dict("/A") == orig_dict["A"]
assert inp.get_section_dict(["A"]) == orig_dict["A"]
with pytest.raises(TypeError):
inp.get_section_dict("FOO")

0 comments on commit 2362f08

Please sign in to comment.