Skip to content

Commit

Permalink
stm orb same axes ordering
Browse files Browse the repository at this point in the history
  • Loading branch information
eimrek committed Aug 11, 2019
1 parent 8a45b64 commit 83441aa
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 33 deletions.
69 changes: 38 additions & 31 deletions atomistic_tools/cp2k_stm_sts.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,19 +251,26 @@ def calculate_stm_maps(self, fwhms, isovalues, heights, energies):
ch_sts = np.concatenate((ch_sts_n, ch_sts_p), axis=4)
ch_stm = np.concatenate((ch_stm_n, ch_stm_p), axis=4)

# Move energy axis to position 2
cc_sts = np.moveaxis(cc_sts, 4, 2)
cc_stm = np.moveaxis(cc_stm, 4, 2)
ch_sts = np.moveaxis(ch_sts, 4, 2)
ch_stm = np.moveaxis(ch_stm, 4, 2)

self.stm_maps_data['cc_sts'] = cc_sts
self.stm_maps_data['cc_stm'] = cc_stm
self.stm_maps_data['ch_sts'] = ch_sts
self.stm_maps_data['ch_stm'] = ch_stm


def apply_zero_threshold(self, data_array, z_thresh):
def apply_zero_threshold(self, data_array, zero_thresh):
# apply it to every energy slice independently
for i_series in range(data_array.shape[0]):
for i_e in range(data_array.shape[3]):
sli = data_array[i_series, :, :, i_e]
slice_absmax = np.max(np.abs(sli))
sli[np.abs(sli) < slice_absmax*z_thresh] = 0.0
for i_0 in range(data_array.shape[0]): # spin or fwhm
for i_series in range(data_array.shape[1]):
for i_e in range(data_array.shape[2]):
sli = data_array[i_0, i_series, i_e, :, :]
slice_absmax = np.max(np.abs(sli))
sli[np.abs(sli) < slice_absmax*zero_thresh] = 0.0

def collect_local_grid(self, local_arr, global_shape, to_rank = 0):
"""
Expand Down Expand Up @@ -297,17 +304,17 @@ def collect_and_save_stm_maps(self, path = "./stm.npz"):
n_ch = len(self.stm_maps_data['heights'])
n_fwhms = len(self.stm_maps_data['fwhms'])

cc_sts = self.collect_local_grid(self.stm_maps_data['cc_sts'].swapaxes(0, 2), np.array([nx, n_cc, n_fwhms, ny, ne]))
cc_stm = self.collect_local_grid(self.stm_maps_data['cc_stm'].swapaxes(0, 2), np.array([nx, n_cc, n_fwhms, ny, ne]))
ch_sts = self.collect_local_grid(self.stm_maps_data['ch_sts'].swapaxes(0, 2), np.array([nx, n_ch, n_fwhms, ny, ne]))
ch_stm = self.collect_local_grid(self.stm_maps_data['ch_stm'].swapaxes(0, 2), np.array([nx, n_ch, n_fwhms, ny, ne]))
cc_sts = self.collect_local_grid(self.stm_maps_data['cc_sts'].swapaxes(0, 3), np.array([nx, n_cc, ne, n_fwhms, ny]))
cc_stm = self.collect_local_grid(self.stm_maps_data['cc_stm'].swapaxes(0, 3), np.array([nx, n_cc, ne, n_fwhms, ny]))
ch_sts = self.collect_local_grid(self.stm_maps_data['ch_sts'].swapaxes(0, 3), np.array([nx, n_ch, ne, n_fwhms, ny]))
ch_stm = self.collect_local_grid(self.stm_maps_data['ch_stm'].swapaxes(0, 3), np.array([nx, n_ch, ne, n_fwhms, ny]))

if self.mpi_rank == 0:
# back to correct orientation
cc_sts = cc_sts.swapaxes(2, 0)
cc_stm = cc_stm.swapaxes(2, 0)
ch_sts = ch_sts.swapaxes(2, 0)
ch_stm = ch_stm.swapaxes(2, 0)
cc_sts = cc_sts.swapaxes(3, 0)
cc_stm = cc_stm.swapaxes(3, 0)
ch_sts = ch_sts.swapaxes(3, 0)
ch_stm = ch_stm.swapaxes(3, 0)

save_data = {}
save_data['cc_stm'] = cc_stm.astype(np.float16) # all values either way ~ between -2 and 8
Expand All @@ -316,10 +323,10 @@ def collect_and_save_stm_maps(self, path = "./stm.npz"):
save_data['ch_sts'] = ch_sts.astype(np.float32)
### ----------------
### Reduce filesize further by zero threshold
z_thres = 1e-3
#self.apply_zero_threshold(save_data['cc_sts'], z_thres)
#self.apply_zero_threshold(save_data['ch_stm'], z_thres)
#self.apply_zero_threshold(save_data['ch_sts'], z_thres)
zero_thresh = 1e-3
self.apply_zero_threshold(save_data['cc_sts'], zero_thresh)
self.apply_zero_threshold(save_data['ch_stm'], zero_thresh)
self.apply_zero_threshold(save_data['ch_sts'], zero_thresh)
### ----------------

# additionally add info
Expand Down Expand Up @@ -473,10 +480,10 @@ def collect_and_save_orb_maps(self, path = "./orb.npz"):

### collect STM/STS maps at orbital energies

cc_sts = self.collect_local_grid(self.stm_maps_data['cc_sts'].swapaxes(0, 2), np.array([nx, n_cc, n_fwhms, ny, ne]))
cc_stm = self.collect_local_grid(self.stm_maps_data['cc_stm'].swapaxes(0, 2), np.array([nx, n_cc, n_fwhms, ny, ne]))
ch_sts = self.collect_local_grid(self.stm_maps_data['ch_sts'].swapaxes(0, 2), np.array([nx, n_ch, n_fwhms, ny, ne]))
ch_stm = self.collect_local_grid(self.stm_maps_data['ch_stm'].swapaxes(0, 2), np.array([nx, n_ch, n_fwhms, ny, ne]))
cc_sts = self.collect_local_grid(self.stm_maps_data['cc_sts'].swapaxes(0, 3), np.array([nx, n_cc, ne, n_fwhms, ny]))
cc_stm = self.collect_local_grid(self.stm_maps_data['cc_stm'].swapaxes(0, 3), np.array([nx, n_cc, ne, n_fwhms, ny]))
ch_sts = self.collect_local_grid(self.stm_maps_data['ch_sts'].swapaxes(0, 3), np.array([nx, n_ch, ne, n_fwhms, ny]))
ch_stm = self.collect_local_grid(self.stm_maps_data['ch_stm'].swapaxes(0, 3), np.array([nx, n_ch, ne, n_fwhms, ny]))

if self.mpi_rank == 0:

Expand All @@ -485,10 +492,10 @@ def collect_and_save_orb_maps(self, path = "./orb.npz"):
ch_orbs = ch_orbs.swapaxes(3, 0)
cc_orbs = cc_orbs.swapaxes(3, 0)

cc_sts = cc_sts.swapaxes(2, 0)
cc_stm = cc_stm.swapaxes(2, 0)
ch_sts = ch_sts.swapaxes(2, 0)
ch_stm = ch_stm.swapaxes(2, 0)
cc_sts = cc_sts.swapaxes(3, 0)
cc_stm = cc_stm.swapaxes(3, 0)
ch_sts = ch_sts.swapaxes(3, 0)
ch_stm = ch_stm.swapaxes(3, 0)

save_data = {}
save_data['cc_stm'] = cc_stm.astype(np.float16) # all values either way ~ between -2 and 8
Expand All @@ -501,11 +508,11 @@ def collect_and_save_orb_maps(self, path = "./orb.npz"):

### ----------------
### Reduce filesize further by zero threshold
z_thres = 1e-3
#self.apply_zero_threshold(save_data['cc_sts'], z_thres)
#self.apply_zero_threshold(save_data['ch_stm'], z_thres)
#self.apply_zero_threshold(save_data['ch_sts'], z_thres)
#self.apply_zero_threshold(save_data['ch_orbs'], z_thres)
zero_thresh = 1e-3
self.apply_zero_threshold(save_data['cc_sts'], zero_thresh)
self.apply_zero_threshold(save_data['ch_stm'], zero_thresh)
self.apply_zero_threshold(save_data['ch_sts'], zero_thresh)
self.apply_zero_threshold(save_data['ch_orbs'], zero_thresh)
### ----------------
# additionally add info
save_data['orbital_list'] = self.orb_maps_data['orbital_list']
Expand Down
4 changes: 2 additions & 2 deletions stm_sts_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def plot_stm_series(loaded_data, stm_dir, stm_itx_dir):
title = '%s\nE=%.2f eV'%(series_label, energy)
data = stm_series[series_label]
ax = plt.gca()
make_plot(fig, ax, data[:, :, i_e], extent, title=title, title_size=title_font_size, cmap=cmap, noadd=False)
make_plot(fig, ax, data[i_e, :, :], extent, title=title, title_size=title_font_size, cmap=cmap, noadd=False)

series_name = series_label.lower().replace(" ", '_').replace("=", '').replace(",", '')
plot_name = "/%s_%03de%.2f" % (series_name, i_e, energy)
Expand All @@ -134,7 +134,7 @@ def plot_stm_series(loaded_data, stm_dir, stm_itx_dir):
# ---------------------------------------------------
# export IGOR format
igorwave = igor.Wave2d(
data=data[:, :, i_e],
data=data[i_e, :, :],
xmin=extent[0],
xmax=extent[1],
xlabel='x [Angstroms]',
Expand Down

0 comments on commit 83441aa

Please sign in to comment.