Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make MCF writer feature complete #560

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 123 additions & 96 deletions gmso/formats/mcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@ def write_mcf(top, filename):
in_ring, frag_list, frag_conn = _id_rings_fragments(top)

# TODO: What oh what to do about subtops?
# For now refuse topologies with subtops as MCF writer is for
# single molecules
if top.n_subtops > 0:
# Refuse topologies with subtops as MCF writer is for single molecules
if top.n_subtops > 1:
raise GMSOError(
"MCF writer does not support subtopologies. "
"Please provide a single molecule as an gmso.Topology "
Expand All @@ -66,18 +65,16 @@ def write_mcf(top, filename):
"!Molecular connectivity file\n"
"!***************************************"
"****************************************\n"
"!File {} written by gmso {} at {}\n\n".format(
filename, __version__, str(datetime.datetime.now())
)
f"!File {filename} written by gmso {__version__} "
f"at {str(datetime.datetime.now())}\n\n"
)

mcf.write(header)
_write_atom_information(mcf, top, in_ring)
_write_bond_information(mcf, top)
_write_angle_information(mcf, top)
_write_dihedral_information(mcf, top)
# TODO: Add improper information
# _write_improper_information(mcf, top)
_write_improper_information(mcf, top)
_write_fragment_information(mcf, top, frag_list, frag_conn)
_write_intrascaling_information(mcf, top)

Expand Down Expand Up @@ -106,7 +103,13 @@ def _id_rings_fragments(top):
# Identify atoms in rings
bond_graph = nx.Graph()
bond_graph.add_edges_from(
[[bond.atom1.idx, bond.atom2.idx] for bond in top.bonds]
[
[
top.get_index(bond.connection_members[0]),
top.get_index(bond.connection_members[1]),
]
for bond in top.bonds
]
)
if len(top.bonds) == 0:
warnings.warn(
Expand All @@ -118,7 +121,7 @@ def _id_rings_fragments(top):
return in_ring, frag_list, frag_conn

# Check if entire molecule is connected. Warn if not.
if nx.is_connected(bond_graph) == False:
if not nx.is_connected(bond_graph):
raise ValueError(
"Not all components of the molecule are connected. "
"MCF files are for a single molecule and thus "
Expand All @@ -142,21 +145,24 @@ def _id_rings_fragments(top):
i: list(bond_graph.neighbors(i))
for i in range(bond_graph.number_of_nodes())
}
# First ID fused rings
fused_rings = []
rings_to_remove = []
for i in range(len(all_rings)):
ring1 = all_rings[i]
for j in range(i + 1, len(all_rings)):
ring2 = all_rings[j]
shared_atoms = list(set(ring1) & set(ring2))
if len(shared_atoms) == 2:
fused_rings.append(list(set(ring1 + ring2)))
rings_to_remove.append(ring1)
rings_to_remove.append(ring2)
for ring in rings_to_remove:
all_rings.remove(ring)
all_rings = all_rings + fused_rings

# Handle fused/adjoining rings
rings_changed = True
while rings_changed:
rings_changed = False
for ring1 in all_rings:
if rings_changed:
break
for ring2 in all_rings:
if ring1 == ring2:
continue
if len(set(ring1) & set(ring2)) > 0:
all_rings.remove(ring1)
all_rings.remove(ring2)
all_rings.append(list(set(ring1 + ring2)))
rings_changed = True
break

# ID fragments which contain a ring
for ring in all_rings:
adjacent_atoms = []
Expand Down Expand Up @@ -208,53 +214,54 @@ def _write_atom_information(mcf, top, in_ring):
Topology object
in_ring : list
Boolean for each atom idx True if atom belongs to a ring

"""
names = [site.name for site in top.sites]
types = [site.atom_type.name for site in top.sites]
# Based upon Cassandra; updated following 1.2.2 release
max_element_length = 6
max_atomtype_length = 20

# Sort to make sure the atom order matches the topology indexing
sorted_sites = [site for site in top.sites]
sorted_sites.sort(key=lambda site: top.get_index(site))
names = [site.name for site in sorted_sites]
types = [site.atom_type.name for site in sorted_sites]

# Check constraints on atom type length and element name length
# TODO: Update these following Cassandra release
# to be more reasonable values
n_unique_names = len(set(names))
for name in names:
if len(name) > 2:
if len(name) > max_element_length:
warnings.warn(
"Warning, name {} will be shortened "
"to two characters. Please confirm your final "
"MCF.".format(name)
f"Name: {name} will be shortened to {max_element_length}"
"characters. Please confirm your final MCF."
)

# Confirm that shortening names to two characters does not
# cause two previously unique atom names to become identical.
names = [name[:2] for name in names]
names = [name[:max_element_length] for name in names]
if len(set(names)) < n_unique_names:
warnings.warn(
"Warning, the number of unique names has been "
"reduced due to shortening the name to two "
"characters."
"The number of unique names has been reduced due"
f"to shortening the name to {max_element_length} characters."
)

n_unique_types = len(set(types))
for itype in types:
if len(itype) > 6:
for type_ in types:
if len(type_) > max_atomtype_length:
warnings.warn(
"Warning, type name {} will be shortened to six "
"characters as {}. Please confirm your final "
"MCF.".format(itype, itype[-6:])
f"Type name: {type_} will be shortened to "
f"{max_atomtype_length} characters as "
f"{type[-max_atomtype_length:]}. Please confirm your final MCF."
)
types = [itype[-6:] for itype in types]
types = [itype[-max_atomtype_length:] for itype in types]
if len(set(types)) < n_unique_types:
warnings.warn(
"Warning, the number of unique atomtypes has been "
"reduced due to shortening the atomtype name to six "
"characters."
"The number of unique atomtypes has been reduced due to "
f"shortening the atomtype name to {max_atomtype_length} characters."
)

# Detect VDW style
vdw_styles = set()
for site in top.sites:
vdw_styles.add(_get_vdw_style(site.atom_type))
for site in sorted_sites:
vdw_styles.add(_get_vdw_style(site))
if len(vdw_styles) > 1:
raise GMSOError(
"More than one vdw_style detected. "
Expand All @@ -276,14 +283,14 @@ def _write_atom_information(mcf, top, in_ring):
)

mcf.write(header)
mcf.write("{:d}\n".format(len(top.sites)))
for (idx, site) in enumerate(top.sites):
mcf.write("{:d}\n".format(len(sorted_sites)))
for (idx, site) in enumerate(sorted_sites):
mcf.write(
"{:<4d} "
"{:<6s} "
"{:<2s} "
"{:7.3f} "
"{:7.3f} ".format(
"{:12.8f} ".format(
idx + 1,
types[idx],
names[idx],
Expand Down Expand Up @@ -353,8 +360,9 @@ def _write_bond_information(mcf, top):
"{:s} "
"{:10.5f}\n".format(
idx + 1,
bond.connection_members[0].idx + 1, # TODO: Confirm the +1 here
bond.connection_members[1].idx + 1,
top.get_index(bond.connection_members[0])
+ 1, # TODO: Confirm the +1 here
top.get_index(bond.connection_members[1]) + 1,
"fixed",
bond.connection_type.parameters["r_eq"]
.in_units(u.Angstrom)
Expand All @@ -373,8 +381,6 @@ def _write_angle_information(mcf, top):
top : Topology
Topology object
"""
# TODO: Add support for fixed angles
angle_style = "harmonic"
header = (
"\n!Angle Format\n"
"!index i j k type parameters\n"
Expand All @@ -386,27 +392,38 @@ def _write_angle_information(mcf, top):
mcf.write("{:d}\n".format(len(top.angles)))
for idx, angle in enumerate(top.angles):
mcf.write(
"{:<4d} "
"{:<4d} "
"{:<4d} "
"{:<4d} "
"{:s} "
"{:10.5f} "
"{:10.5f}\n".format(
idx + 1,
angle.connection_members[0].idx + 1,
angle.connection_members[1].idx
+ 1, # TODO: Confirm order for angles i-j-k
angle.connection_members[2].idx + 1,
angle_style,
(0.5 * angle.connection_type.parameters["k"] / u.kb)
.in_units("K/rad**2")
.value, # TODO: k vs. k/2. conversion
angle.connection_type.parameters["theta_eq"]
.in_units(u.degree)
.value,
)
f"{idx + 1:<4d} "
f"{top.get_index(angle.connection_members[0]) + 1:<4d} "
f"{top.get_index(angle.connection_members[1]) + 1:<4d} "
f"{top.get_index(angle.connection_members[2]) + 1:<4d} "
)
angle_style = _get_angle_style(angle)
if angle_style == "fixed":
mcf.write(
"{:s} "
"{:10.5f}\n".format(
angle_style,
angle.connection_type.parameters["theta_eq"]
.in_units(u.degree)
.value,
)
)
elif angle_style == "harmonic":
mcf.write(
"{:s} "
"{:10.5f} "
"{:10.5f}\n".format(
angle_style,
(0.5 * angle.connection_type.parameters["k"] / u.kb)
.in_units("K/rad**2")
.value, # TODO: k vs. k/2. conversion
angle.connection_type.parameters["theta_eq"]
.in_units(u.degree)
.value,
)
)
else:
raise GMSOError("Unsupported angle style for Cassandra MCF writer")


def _write_dihedral_information(mcf, top):
Expand All @@ -418,8 +435,6 @@ def _write_dihedral_information(mcf, top):
The file object of the Cassandra mcf being written
top : Topology
Topology object
dihedral_style : string
Dihedral style for Cassandra to use
"""
# Dihedral info
header = (
Expand All @@ -434,7 +449,6 @@ def _write_dihedral_information(mcf, top):

mcf.write(header)

# TODO: Are impropers buried in dihedrals?
mcf.write("{:d}\n".format(len(top.dihedrals)))
for (idx, dihedral) in enumerate(top.dihedrals):
mcf.write(
Expand All @@ -444,10 +458,10 @@ def _write_dihedral_information(mcf, top):
"{:<4d} "
"{:<4d} ".format(
idx + 1,
dihedral.connection_members[0].idx + 1,
dihedral.connection_members[1].idx + 1,
dihedral.connection_members[2].idx + 1,
dihedral.connection_members[3].idx + 1,
top.get_index(dihedral.connection_members[0]) + 1,
top.get_index(dihedral.connection_members[1]) + 1,
top.get_index(dihedral.connection_members[2]) + 1,
top.get_index(dihedral.connection_members[3]) + 1,
)
)
dihedral_style = _get_dihedral_style(dihedral)
Expand Down Expand Up @@ -541,19 +555,19 @@ def _write_improper_information(mcf, top):
mcf.write(header)
mcf.write("{:d}\n".format(len(top.impropers)))

improper_type = "harmonic"
improper_style = "harmonic"
for i, improper in enumerate(top.impropers):
mcf.write(
"{:<4d} {:<4d} {:<4d} {:<4d} {:<4d}"
" {:s} {:8.3f} {:8.3f}\n".format(
" {:s} {:10.5f} {:10.5f}\n".format(
i + 1,
improper.atom1.idx + 1,
improper.atom2.idx + 1,
improper.atom3.idx + 1,
improper.atom4.idx + 1,
improper_type,
improper.type.psi_k * KCAL_TO_KJ,
improper.type.psi_eq,
top.get_index(improper.connection_members[0]) + 1,
top.get_index(improper.connection_members[1]) + 1,
top.get_index(improper.connection_members[2]) + 1,
top.get_index(improper.connection_members[3]) + 1,
improper_style,
0.5 * improper.connection_type.parameters["k"],
improper.connection_type.parameters["phi_eq"],
)
)

Expand Down Expand Up @@ -655,25 +669,38 @@ def _check_compatibility(top):
accepted_potentials = [
potential_templates["LennardJonesPotential"],
potential_templates["MiePotential"],
potential_templates["FixedBondPotential"],
potential_templates["HarmonicBondPotential"],
potential_templates["HarmonicAnglePotential"],
potential_templates["FixedAnglePotential"],
potential_templates["PeriodicTorsionPotential"],
potential_templates["OPLSTorsionPotential"],
potential_templates["RyckaertBellemansTorsionPotential"],
]
check_compatibility(top, accepted_potentials)


def _get_vdw_style(atom_type):
def _get_vdw_style(site):
"""Return the vdw style."""
vdw_styles = {
"LJ": potential_templates["LennardJonesPotential"],
"Mie": potential_templates["MiePotential"],
}

return _get_potential_style(vdw_styles, atom_type)
return _get_potential_style(vdw_styles, site.atom_type)


def _get_angle_style(angle):
"""Return the angle style."""
angle_styles = {
"harmonic": potential_templates["HarmonicAnglePotential"],
"fixed": potential_templates["FixedAnglePotential"],
}

return _get_potential_style(angle_styles, angle.connection_type)


def _get_dihedral_style(dihedral_type):
def _get_dihedral_style(dihedral):
"""Return the dihedral style."""
dihedral_styles = {
"charmm": potential_templates["PeriodicTorsionPotential"],
Expand All @@ -682,7 +709,7 @@ def _get_dihedral_style(dihedral_type):
"ryckaert": potential_templates["RyckaertBellemansTorsionPotential"],
}

return _get_potential_style(dihedral_styles, dihedral_type)
return _get_potential_style(dihedral_styles, dihedral.connection_type)


def _get_potential_style(styles, potential):
Expand Down
5 changes: 5 additions & 0 deletions gmso/lib/jsons/FixedAnglePotential.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"name": "FixedAnglePotential",
"expression": "DiracDelta(theta-theta_eq)",
"independent_variables": "theta"
}
5 changes: 5 additions & 0 deletions gmso/lib/jsons/FixedBondPotential.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"name": "FixedBondPotential",
"expression": "DiracDelta(r-r_eq)",
"independent_variables": "r"
}
Loading