From 6e2bb5b5921633d4ca2f5ceb1b17f6ee9f02a315 Mon Sep 17 00:00:00 2001 From: Toon Verstraelen Date: Tue, 29 Aug 2017 21:52:37 +0200 Subject: [PATCH] Replace astypes by more flexible asarray Also remove duplicate methods in System class. --- yaff/system.py | 113 ++++++--------------------------------- yaff/test/test_system.py | 13 +++++ 2 files changed, 28 insertions(+), 98 deletions(-) diff --git a/yaff/system.py b/yaff/system.py index e7a4cbcf..6496d82a 100644 --- a/yaff/system.py +++ b/yaff/system.py @@ -486,7 +486,7 @@ def from_hdf5(cls, f): # String arrays have to be converted back to unicode... for key in 'scopes', 'ffatypes': if key in sgrp: - kwargs[key] = sgrp[key][:].astype('U') + kwargs[key] = np.asarray(sgrp[key][:], 'U22') if log.do_high: log('Read system parameters from %s.' % f.filename) return cls(**kwargs) @@ -527,6 +527,10 @@ def to_file(self, fn): 'bonds': self.bonds, 'rvecs': self.cell.rvecs, 'charges': self.charges, + 'radii': self.radii, + 'valence_charges': self.valence_charges, + 'dipoles': self.dipoles, + 'radii2': self.radii2, 'masses': self.masses, }) elif fn.endswith('.h5'): @@ -557,12 +561,10 @@ def to_hdf5(self, f): sgrp.create_dataset('numbers', data=self.numbers) sgrp.create_dataset('pos', data=self.pos) if self.scopes is not None: - # Strings have to be stored as ascii - sgrp.create_dataset('scopes', data=self.scopes.astype('S22')) + sgrp.create_dataset('scopes', data=np.asarray(self.scopes, 'S22')) sgrp.create_dataset('scope_ids', data=self.scope_ids) if self.ffatypes is not None: - # Strings have to be stored as ascii - sgrp.create_dataset('ffatypes', data=self.ffatypes.astype('S22')) + sgrp.create_dataset('ffatypes', data=np.asarray(self.ffatypes, 'S22')) sgrp.create_dataset('ffatype_ids', data=self.ffatype_ids) if self.bonds is not None: sgrp.create_dataset('bonds', data=self.bonds) @@ -570,10 +572,17 @@ def to_hdf5(self, f): sgrp.create_dataset('rvecs', data=self.cell.rvecs) if self.charges is not None: sgrp.create_dataset('charges', data=self.charges) + if self.radii is not None: + sgrp.create_dataset('radii', data=self.radii) + if self.valence_charges is not None: + sgrp.create_dataset('valence_charges', data=self.charges) + if self.dipoles is not None: + sgrp.create_dataset('dipoles', data=self.dipoles) + if self.radii2 is not None: + sgrp.create_dataset('radii2', data=self.radii2) if self.masses is not None: sgrp.create_dataset('masses', data=self.masses) - def get_scope(self, index): """Return the of the scope (string) of atom with given index""" return self.scopes[self.scope_ids[index]] @@ -1163,95 +1172,3 @@ def error_sq_fn(x, y): log('Generating renumberings.') for match in iter_matches(dm0, dm1, allowed, 1e-3, error_sq_fn, overlapping): yield match - - def to_file(self, fn): - """Write the system to a file - - **Arguments:** - - fn - The file to write to. - - Supported formats are: - - chk - Internal human-readable checkpoint format. This format includes - all the information of a system object. All data are stored in - atomic units. - - h5 - Internal binary checkpoint format. This format includes - all the information of a system object. All data are stored in - atomic units. - - xyz - A simple file with atomic positions and elements. Coordinates - are written in Angstroms. - """ - if fn.endswith('.chk'): - from molmod.io import dump_chk - dump_chk(fn, { - 'numbers': self.numbers, - 'pos': self.pos, - 'ffatypes': self.ffatypes, - 'ffatype_ids': self.ffatype_ids, - 'scopes': self.scopes, - 'scope_ids': self.scope_ids, - 'bonds': self.bonds, - 'rvecs': self.cell.rvecs, - 'charges': self.charges, - 'radii': self.radii, - 'valence_charges': self.valence_charges, - 'dipoles': self.dipoles, - 'radii2': self.radii2, - 'masses': self.masses, - }) - elif fn.endswith('.h5'): - with h5.File(fn, 'w') as f: - self.to_hdf5(f) - elif fn.endswith('.xyz'): - from molmod.io import XYZWriter - from molmod.periodic import periodic - xyz_writer = XYZWriter(fn, [periodic[n].symbol for n in self.numbers]) - xyz_writer.dump(str(self), self.pos) - else: - raise NotImplementedError('The extension of %s does not correspond to any known format.' % fn) - if log.do_high: - with log.section('SYS'): - log('Wrote system to %s.' % fn) - - def to_hdf5(self, f): - """Write the system to a HDF5 file. - - **Arguments:** - - f - A Writable h5.File object. - """ - if 'system' in f: - raise ValueError('The HDF5 file already contains a system description.') - sgrp = f.create_group('system') - sgrp.create_dataset('numbers', data=self.numbers) - sgrp.create_dataset('pos', data=self.pos) - if self.scopes is not None: - sgrp.create_dataset('scopes', data=self.scopes.astype('S22')) - sgrp.create_dataset('scope_ids', data=self.scope_ids) - if self.ffatypes is not None: - sgrp.create_dataset('ffatypes', data=self.ffatypes.astype('S22')) - sgrp.create_dataset('ffatype_ids', data=self.ffatype_ids) - if self.bonds is not None: - sgrp.create_dataset('bonds', data=self.bonds) - if self.cell.nvec > 0: - sgrp.create_dataset('rvecs', data=self.cell.rvecs) - if self.charges is not None: - sgrp.create_dataset('charges', data=self.charges) - if self.radii is not None: - sgrp.create_dataset('radii', data=self.radii) - if self.valence_charges is not None: - sgrp.create_dataset('valence_charges', data=self.charges) - if self.dipoles is not None: - sgrp.create_dataset('dipoles', data=self.dipoles) - if self.radii2 is not None: - sgrp.create_dataset('radii2', data=self.radii2) - if self.masses is not None: - sgrp.create_dataset('masses', data=self.masses) diff --git a/yaff/test/test_system.py b/yaff/test/test_system.py index 7c76fd13..893057e2 100644 --- a/yaff/test/test_system.py +++ b/yaff/test/test_system.py @@ -82,6 +82,19 @@ def test_hdf5(): compare_water32(system0, system1) +def test_hdf5_assign_ffatypes(): + system0 = get_system_water32() + with tmpdir(__name__, 'test_hdf5_assign_ffatypes') as dirname: + system0.ffatypes = ['O', 'H'] + system0.ffatype_ids = np.array([0, 1, 1]*32) + fn = '%s/tmp.h5' % dirname + system0.to_file(fn) + with h5.File(fn) as f: + assert 'system' in f + system1 = System.from_file(fn) + compare_water32(system0, system1) + + def test_ffatypes(): system = get_system_water32() assert (system.ffatypes == ['O', 'H']).all()