Skip to content

Commit

Permalink
Add functionality to remove sites and connections from a topology (#761)
Browse files Browse the repository at this point in the history
* add connections property to site class

* store connection to each site when adding new connection

* add remove_connections method and clean up remove_site

* create lists from set first before removing things

* add method to topology class to get connections by site; remove connections attr from Site

* remove site's connections in remove_site(); add iter_connections method

* add unit tests

* add doc strings, handle strings in iter_connections

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* add missing doc strings

* raise error for bad connection types, add unit test

* raise errors when removing site/conn not in top

* fix unit test name

* add check site in top to iter_connections_by_site

* build up list of connections before removing from ordered set

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Co Quach <[email protected]>
  • Loading branch information
3 people authored Sep 13, 2023
1 parent a3fa9f4 commit 0c4c34c
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 1 deletion.
94 changes: 93 additions & 1 deletion gmso/core/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,61 @@ def get_scaling_factors(self, *, molecule_id=None):
]
)

def remove_site(self, site):
"""Remove a site from the topology.
Parameters
----------
site : gmso.core.Site
The site to be removed.
Notes
-----
When a site is removed, any connections that site belonged
to are also removed.
See Also
--------
gmso.core.topology.Topology.iter_connections_by_site
The method that shows all connections belonging to a specific site
"""
if site not in self._sites:
raise ValueError(
f"Site {site} is not currently part of this topology."
)
site_connections = [
conn for conn in self.iter_connections_by_site(site)
]
for conn in site_connections:
self.remove_connection(conn)
self._sites.remove(site)

def remove_connection(self, connection):
"""Remove a connection from the topology.
Parameters
----------
connection : gmso.abc.abstract_conneciton.Connection
The connection to be removed from the topology
Notes
-----
The sites that belong to this connection are
not removed from the topology.
"""
if connection not in self.connections:
raise ValueError(
f"Connection {connection} is not currently part of this topology."
)
if isinstance(connection, gmso.core.bond.Bond):
self._bonds.remove(connection)
elif isinstance(connection, gmso.core.angle.Angle):
self._angles.remove(connection)
elif isinstance(connection, gmso.core.dihedral.Dihedral):
self._dihedrals.remove(connection)
elif isinstance(connection, gmso.core.improper.Improper):
self._impropers.remove(connection)

def set_scaling_factors(self, lj, electrostatics, *, molecule_id=None):
"""Set both lj and electrostatics scaling factors."""
self.set_lj_scale(
Expand Down Expand Up @@ -831,7 +886,6 @@ def add_connection(self, connection, update_types=False):
Improper: self._impropers,
}
connections_sets[type(connection)].add(connection)

if update_types:
self.update_topology()

Expand Down Expand Up @@ -1367,6 +1421,44 @@ def iter_sites_by_molecule(self, molecule_tag):
else:
return self.iter_sites("molecule", molecule_tag)

def iter_connections_by_site(self, site, connections=None):
"""Iterate through this topology's connections which contain
this specific site.
Parameters
----------
site : gmso.core.Site
Site to limit connections search to.
connections : set or list or tuple, optional, default=None
The connection types to include in the search.
If None, iterates through all of a site's connections.
Options include "bonds", "angles", "dihedrals", "impropers"
Yields
------
gmso.abc.abstract_conneciton.Connection
Connection where site is in Connection.connection_members
"""
if site not in self._sites:
raise ValueError(
f"Site {site} is not currently part of this topology."
)
if connections is None:
connections = ["bonds", "angles", "dihedrals", "impropers"]
else:
connections = set([option.lower() for option in connections])
for option in connections:
if option not in ["bonds", "angles", "dihedrals", "impropers"]:
raise ValueError(
"Valid connection types are limited to: "
'"bonds", "angles", "dihedrals", "impropers"'
)
for conn_str in connections:
for conn in getattr(self, conn_str):
if site in conn.connection_members:
yield conn

def create_subtop(self, label_type, label):
"""Create a new Topology object from a molecule or graup of the current Topology.
Expand Down
76 changes: 76 additions & 0 deletions gmso/tests/test_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,19 @@ def test_add_site(self):
top.add_site(site)
assert top.n_sites == 1

def test_remove_site(self, ethane):
ethane.identify_connections()
for site in ethane.sites[2:]:
ethane.remove_site(site)
assert ethane.n_sites == 2
assert ethane.n_connections == 1

def test_remove_site_not_in_top(self, ethane):
top = Topology()
site = Atom(name="site")
with pytest.raises(ValueError):
top.remove_site(site)

def test_add_connection(self):
top = Topology()
atom1 = Atom(name="atom1")
Expand All @@ -64,6 +77,26 @@ def test_add_connection(self):

assert len(top.connections) == 1

def test_remove_connection(self):
top = Topology()
atom1 = Atom(name="atom1")
atom2 = Atom(name="atom2")
connect = Bond(connection_members=[atom1, atom2])

top.add_connection(connect)
top.add_site(atom1)
top.add_site(atom2)
top.remove_connection(connect)
assert top.n_connections == 0

def test_remove_connection_not_in_top(self):
top = Topology()
atom1 = Atom(name="atom1")
atom2 = Atom(name="atom2")
connect = Bond(connection_members=[atom1, atom2])
with pytest.raises(ValueError):
top.remove_connection(connect)

def test_add_box(self):
top = Topology()
box = Box(2 * u.nm * np.ones(3))
Expand Down Expand Up @@ -905,6 +938,49 @@ def test_iter_sites_by_molecule(self, labeled_top):
for site in labeled_top.iter_sites_by_molecule(molecule_name):
assert site.molecule.name == molecule_name

@pytest.mark.parametrize(
"connections",
["bonds", "angles", "dihedrals", "impropers"],
)
def test_iter_connections_by_site(self, ethane, connections):
type_dict = {
"bonds": Bond,
"angles": Angle,
"dihedrals": Dihedral,
"impropers": Improper,
}
ethane.identify_connections()
site = ethane.sites[0]
for conn in ethane.iter_connections_by_site(
site=site, connections=[connections]
):
assert site in conn.connection_members
assert isinstance(conn, type_dict[connections])

def test_iter_connections_by_site_none(self, ethane):
ethane.identify_connections()
site = ethane.sites[0]
for conn in ethane.iter_connections_by_site(
site=site, connections=None
):
assert site in conn.connection_members

def test_iter_connections_by_site_bad_param(self, ethane):
ethane.identify_connections()
site = ethane.sites[0]
with pytest.raises(ValueError):
for conn in ethane.iter_connections_by_site(
site=site, connections=["bond"]
):
pass

def test_iter_connections_by_site_not_in_top(self):
top = Topology()
site = Atom(name="site")
with pytest.raises(ValueError):
for conn in top.iter_connections_by_site(site):
pass

def test_write_forcefield(self, typed_water_system):
forcefield = typed_water_system.get_forcefield()
assert "opls_111" in forcefield.atom_types
Expand Down

0 comments on commit 0c4c34c

Please sign in to comment.