diff --git a/gmso/core/topology.py b/gmso/core/topology.py index 2472361cb..6058c4ca5 100644 --- a/gmso/core/topology.py +++ b/gmso/core/topology.py @@ -647,6 +647,61 @@ def get_scaling_factors(self, *, molecule_id=None): ] ) + def remove_site(self, site): + """Remove a site from the topology. + + Parameters + ---------- + site : gmso.core.Site + The site to be removed. + + Notes + ----- + When a site is removed, any connections that site belonged + to are also removed. + + See Also + -------- + gmso.core.topology.Topology.iter_connections_by_site + The method that shows all connections belonging to a specific site + """ + if site not in self._sites: + raise ValueError( + f"Site {site} is not currently part of this topology." + ) + site_connections = [ + conn for conn in self.iter_connections_by_site(site) + ] + for conn in site_connections: + self.remove_connection(conn) + self._sites.remove(site) + + def remove_connection(self, connection): + """Remove a connection from the topology. + + Parameters + ---------- + connection : gmso.abc.abstract_conneciton.Connection + The connection to be removed from the topology + + Notes + ----- + The sites that belong to this connection are + not removed from the topology. + """ + if connection not in self.connections: + raise ValueError( + f"Connection {connection} is not currently part of this topology." + ) + if isinstance(connection, gmso.core.bond.Bond): + self._bonds.remove(connection) + elif isinstance(connection, gmso.core.angle.Angle): + self._angles.remove(connection) + elif isinstance(connection, gmso.core.dihedral.Dihedral): + self._dihedrals.remove(connection) + elif isinstance(connection, gmso.core.improper.Improper): + self._impropers.remove(connection) + def set_scaling_factors(self, lj, electrostatics, *, molecule_id=None): """Set both lj and electrostatics scaling factors.""" self.set_lj_scale( @@ -831,7 +886,6 @@ def add_connection(self, connection, update_types=False): Improper: self._impropers, } connections_sets[type(connection)].add(connection) - if update_types: self.update_topology() @@ -1367,6 +1421,44 @@ def iter_sites_by_molecule(self, molecule_tag): else: return self.iter_sites("molecule", molecule_tag) + def iter_connections_by_site(self, site, connections=None): + """Iterate through this topology's connections which contain + this specific site. + + Parameters + ---------- + site : gmso.core.Site + Site to limit connections search to. + connections : set or list or tuple, optional, default=None + The connection types to include in the search. + If None, iterates through all of a site's connections. + Options include "bonds", "angles", "dihedrals", "impropers" + + Yields + ------ + gmso.abc.abstract_conneciton.Connection + Connection where site is in Connection.connection_members + + """ + if site not in self._sites: + raise ValueError( + f"Site {site} is not currently part of this topology." + ) + if connections is None: + connections = ["bonds", "angles", "dihedrals", "impropers"] + else: + connections = set([option.lower() for option in connections]) + for option in connections: + if option not in ["bonds", "angles", "dihedrals", "impropers"]: + raise ValueError( + "Valid connection types are limited to: " + '"bonds", "angles", "dihedrals", "impropers"' + ) + for conn_str in connections: + for conn in getattr(self, conn_str): + if site in conn.connection_members: + yield conn + def create_subtop(self, label_type, label): """Create a new Topology object from a molecule or graup of the current Topology. diff --git a/gmso/tests/test_topology.py b/gmso/tests/test_topology.py index 54fd29d7e..91e68f2e9 100644 --- a/gmso/tests/test_topology.py +++ b/gmso/tests/test_topology.py @@ -52,6 +52,19 @@ def test_add_site(self): top.add_site(site) assert top.n_sites == 1 + def test_remove_site(self, ethane): + ethane.identify_connections() + for site in ethane.sites[2:]: + ethane.remove_site(site) + assert ethane.n_sites == 2 + assert ethane.n_connections == 1 + + def test_remove_site_not_in_top(self, ethane): + top = Topology() + site = Atom(name="site") + with pytest.raises(ValueError): + top.remove_site(site) + def test_add_connection(self): top = Topology() atom1 = Atom(name="atom1") @@ -64,6 +77,26 @@ def test_add_connection(self): assert len(top.connections) == 1 + def test_remove_connection(self): + top = Topology() + atom1 = Atom(name="atom1") + atom2 = Atom(name="atom2") + connect = Bond(connection_members=[atom1, atom2]) + + top.add_connection(connect) + top.add_site(atom1) + top.add_site(atom2) + top.remove_connection(connect) + assert top.n_connections == 0 + + def test_remove_connection_not_in_top(self): + top = Topology() + atom1 = Atom(name="atom1") + atom2 = Atom(name="atom2") + connect = Bond(connection_members=[atom1, atom2]) + with pytest.raises(ValueError): + top.remove_connection(connect) + def test_add_box(self): top = Topology() box = Box(2 * u.nm * np.ones(3)) @@ -905,6 +938,49 @@ def test_iter_sites_by_molecule(self, labeled_top): for site in labeled_top.iter_sites_by_molecule(molecule_name): assert site.molecule.name == molecule_name + @pytest.mark.parametrize( + "connections", + ["bonds", "angles", "dihedrals", "impropers"], + ) + def test_iter_connections_by_site(self, ethane, connections): + type_dict = { + "bonds": Bond, + "angles": Angle, + "dihedrals": Dihedral, + "impropers": Improper, + } + ethane.identify_connections() + site = ethane.sites[0] + for conn in ethane.iter_connections_by_site( + site=site, connections=[connections] + ): + assert site in conn.connection_members + assert isinstance(conn, type_dict[connections]) + + def test_iter_connections_by_site_none(self, ethane): + ethane.identify_connections() + site = ethane.sites[0] + for conn in ethane.iter_connections_by_site( + site=site, connections=None + ): + assert site in conn.connection_members + + def test_iter_connections_by_site_bad_param(self, ethane): + ethane.identify_connections() + site = ethane.sites[0] + with pytest.raises(ValueError): + for conn in ethane.iter_connections_by_site( + site=site, connections=["bond"] + ): + pass + + def test_iter_connections_by_site_not_in_top(self): + top = Topology() + site = Atom(name="site") + with pytest.raises(ValueError): + for conn in top.iter_connections_by_site(site): + pass + def test_write_forcefield(self, typed_water_system): forcefield = typed_water_system.get_forcefield() assert "opls_111" in forcefield.atom_types