Skip to content

Commit

Permalink
Merge pull request #538 from cophus/occupancy
Browse files Browse the repository at this point in the history
Adding occupancy to py4DSTEM diffraction
  • Loading branch information
bsavitzky authored Jan 23, 2024
2 parents d136d65 + c70914e commit 1a494e1
Show file tree
Hide file tree
Showing 3 changed files with 219 additions and 25 deletions.
113 changes: 102 additions & 11 deletions py4DSTEM/process/diffraction/crystal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fractions import Fraction
from typing import Union, Optional
import sys
import warnings

from emdfile import PointList
from py4DSTEM.process.utils import single_atom_scatter, electron_wavelength_angstrom
Expand Down Expand Up @@ -66,6 +67,7 @@ def __init__(
positions,
numbers,
cell,
occupancy=None,
):
"""
Args:
Expand All @@ -76,7 +78,7 @@ def __init__(
3 numbers: the three lattice parameters for an orthorhombic cell
6 numbers: the a,b,c lattice parameters and ɑ,β,ɣ angles for any cell
3x3 array: row vectors containing the (u,v,w) lattice vectors.
occupancy (np.array): Partial occupancy values for each atomic site. Must match the length of positions
"""
# Initialize Crystal
self.positions = np.asarray(positions) #: fractional atomic coordinates
Expand Down Expand Up @@ -131,6 +133,17 @@ def __init__(
else:
raise Exception("Cell cannot contain " + np.size(cell) + " entries")

# occupancy
if occupancy is not None:
self.occupancy = np.array(occupancy)
# check the occupancy shape makes sense
if self.occupancy.shape[0] != self.positions.shape[0]:
raise Warning(
f"Number of occupancies ({self.occupancy.shape[0]}) and atomic positions ({self.positions.shape[0]}) do not match"
)
else:
self.occupancy = np.ones(self.positions.shape[0], dtype=np.float32)

# pymatgen flag
if "pymatgen" in sys.modules:
self.pymatgen_available = True
Expand Down Expand Up @@ -257,7 +270,70 @@ def get_strained_crystal(
else:
return crystal_strained

def from_CIF(CIF, conventional_standard_structure=True):
@staticmethod
def from_ase(
atoms,
):
"""
Create a py4DSTEM Crystal object from an ASE atoms object
Args:
atoms (ase.Atoms): an ASE atoms object
"""
# get the occupancies from the atoms object
occupancies = (
atoms.arrays["occupancies"]
if "occupancies" in atoms.arrays.keys()
else None
)

if "occupancy" in atoms.info.keys():
warnings.warn(
"This Atoms object contains occupancy information but it will be ignored."
)

xtal = Crystal(
positions=atoms.get_scaled_positions(), # fractional coords
numbers=atoms.numbers,
cell=atoms.cell.array,
occupancy=occupancies,
)
return xtal

@staticmethod
def from_prismatic(filepath):
"""
Create a py4DSTEM Crystal object from an prismatic style xyz co-ordinate file
Args:
filepath (str|Pathlib.Path): path to the prismatic format xyz file
"""

from ase import io

# read the atoms using ase
atoms = io.read(filepath, format="prismatic")

# get the occupancies from the atoms object
occupancies = (
atoms.arrays["occupancies"]
if "occupancies" in atoms.arrays.keys()
else None
)
xtal = Crystal(
positions=atoms.get_scaled_positions(), # fractional coords
numbers=atoms.numbers,
cell=atoms.cell.array,
occupancy=occupancies,
)
return xtal

@staticmethod
def from_CIF(
CIF, primitive: bool = True, conventional_standard_structure: bool = True
):
"""
Create a Crystal object from a CIF file, using pymatgen to import the CIF
Expand All @@ -273,12 +349,13 @@ def from_CIF(CIF, conventional_standard_structure=True):

parser = CifParser(CIF)

structure = parser.get_structures(False)[0]
structure = parser.get_structures(primitive=primitive)[0]

return Crystal.from_pymatgen_structure(
structure, conventional_standard_structure=conventional_standard_structure
)

@staticmethod
def from_pymatgen_structure(
structure=None,
formula=None,
Expand Down Expand Up @@ -375,8 +452,6 @@ def from_pymatgen_structure(
else selected["structure"]
)

positions = structure.frac_coords #: fractional atomic coordinates

cell = np.array(
[
structure.lattice.a,
Expand All @@ -388,10 +463,22 @@ def from_pymatgen_structure(
]
)

numbers = np.array([s.species.elements[0].Z for s in structure])
site_data = np.array(
[
(*site.frac_coords, elem.number, comp)
for site in structure
for elem, comp in site.species.items()
]
)
positions = site_data[:, :3]
numbers = site_data[:, 3]
occupancies = site_data[:, 4]

return Crystal(positions, numbers, cell)
return Crystal(
positions=positions, numbers=numbers, cell=cell, occupancy=occupancies
)

@staticmethod
def from_unitcell_parameters(
latt_params,
elements,
Expand Down Expand Up @@ -575,10 +662,14 @@ def calculate_structure_factors(
# Calculate structure factors
self.struct_factors = np.zeros(np.size(self.g_vec_leng, 0), dtype="complex64")
for a0 in range(self.positions.shape[0]):
self.struct_factors += f_all[:, a0] * np.exp(
(2j * np.pi)
* np.sum(
self.hkl * np.expand_dims(self.positions[a0, :], axis=1), axis=0
self.struct_factors += (
f_all[:, a0]
* self.occupancy[a0]
* np.exp(
(2j * np.pi)
* np.sum(
self.hkl * np.expand_dims(self.positions[a0, :], axis=1), axis=0
)
)
)

Expand Down
7 changes: 4 additions & 3 deletions py4DSTEM/process/diffraction/crystal_bloch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def calculate_dynamical_structure_factors(
tol_structure_factor: float = 0.0,
recompute_kinematic_structure_factors=True,
g_vec_precision=None,
verbose=True,
):
"""
Calculate and store the relativistic corrected structure factors used for Bloch computations
Expand Down Expand Up @@ -92,7 +91,7 @@ def calculate_dynamical_structure_factors(

# Calculate the reciprocal lattice points to include based on k_max

k_max = np.asarray(k_max)
k_max: np.ndarray = np.asarray(k_max)

if recompute_kinematic_structure_factors:
if hasattr(self, "struct_factors"):
Expand Down Expand Up @@ -215,7 +214,9 @@ def get_f_e(q, Z, thermal_sigma, method):

# Calculate structure factors
struct_factors = np.sum(
f_e * np.exp(2.0j * np.pi * np.squeeze(self.positions[:, None, :] @ hkl)),
f_e
* self.occupancy[:, None]
* np.exp(2.0j * np.pi * np.squeeze(self.positions[:, None, :] @ hkl)),
axis=0,
)

Expand Down
124 changes: 113 additions & 11 deletions py4DSTEM/process/diffraction/crystal_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from matplotlib.axes import Axes
import matplotlib.tri as mtri
from mpl_toolkits.mplot3d import Axes3D, art3d
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

from scipy.signal import medfilt
from scipy.ndimage import gaussian_filter
from scipy.ndimage import distance_transform_edt
Expand Down Expand Up @@ -91,18 +93,26 @@ def plot_structure(

# Fractional atomic coordinates
pos = self.positions
occ = self.occupancy

# x tile
sub = pos[:, 0] < tol_distance
pos = np.vstack([pos, pos[sub, :] + np.array([1, 0, 0])])
ID = np.hstack([ID, ID[sub]])
if occ is not None:
occ = np.hstack([occ, occ[sub]])
# y tile
sub = pos[:, 1] < tol_distance
pos = np.vstack([pos, pos[sub, :] + np.array([0, 1, 0])])
ID = np.hstack([ID, ID[sub]])
if occ is not None:
occ = np.hstack([occ, occ[sub]])
# z tile
sub = pos[:, 2] < tol_distance
pos = np.vstack([pos, pos[sub, :] + np.array([0, 0, 1])])
ID = np.hstack([ID, ID[sub]])
if occ is not None:
occ = np.hstack([occ, occ[sub]])

# Cartesian atomic positions
xyz = pos @ self.lat_real
Expand Down Expand Up @@ -141,17 +151,109 @@ def plot_structure(

# atoms
ID_all = np.unique(ID)
for ID_plot in ID_all:
sub = ID == ID_plot
ax.scatter(
xs=xyz[sub, 1], # + d[0],
ys=xyz[sub, 0], # + d[1],
zs=xyz[sub, 2], # + d[2],
s=size_marker,
linewidth=2,
facecolors=atomic_colors(ID_plot),
edgecolor=[0, 0, 0],
)
if occ is None:
for ID_plot in ID_all:
sub = ID == ID_plot
ax.scatter(
xs=xyz[sub, 1], # + d[0],
ys=xyz[sub, 0], # + d[1],
zs=xyz[sub, 2], # + d[2],
s=size_marker,
linewidth=2,
facecolors=atomic_colors(ID_plot),
edgecolor=[0, 0, 0],
)
else:
# init
tol = 1e-4
num_seg = 180
radius = 0.7
zp = np.zeros(num_seg + 1)

mark = np.ones(xyz.shape[0], dtype="bool")
for a0 in range(xyz.shape[0]):
if mark[a0]:
xyz_plot = xyz[a0, :]
inds = np.argwhere(np.sum((xyz - xyz_plot) ** 2, axis=1) < tol)
occ_plot = occ[inds]
mark[inds] = False
ID_plot = ID[inds]

if np.sum(occ_plot) < 1.0:
occ_plot = np.append(occ_plot, 1 - np.sum(occ_plot))
ID_plot = np.append(ID_plot, -1)
else:
occ_plot = occ_plot[0]
ID_plot = ID_plot[0]

# Plot site as series of filled arcs
theta0 = 0
for a1 in range(occ_plot.shape[0]):
theta1 = theta0 + occ_plot[a1] * 2.0 * np.pi
theta = np.linspace(theta0, theta1, num_seg + 1)
xp = np.cos(theta) * radius
yp = np.sin(theta) * radius

# Rotate towards camera
xyz_rot = np.vstack((xp.ravel(), yp.ravel(), zp.ravel()))
if occ_plot[a1] < 1.0:
xyz_rot = np.append(
xyz_rot, np.array((0, 0, 0))[:, None], axis=1
)
xyz_rot = orientation_matrix @ xyz_rot

# add to plot
verts = [
list(
zip(
xyz_rot[1, :] + xyz_plot[1],
xyz_rot[0, :] + xyz_plot[0],
xyz_rot[2, :] + xyz_plot[2],
)
)
]
# ax.add_collection3d(
# Poly3DCollection(
# verts
# )
# )
collection = Poly3DCollection(
verts,
linewidths=2.0,
alpha=1.0,
edgecolors="k",
)
face_color = [
0.5,
0.5,
1,
] # alternative: matplotlib.colors.rgb2hex([0.5, 0.5, 1])
if ID_plot[a1] == -1:
collection.set_facecolor((1.0, 1.0, 1.0))
else:
collection.set_facecolor(atomic_colors(ID_plot[a1]))
ax.add_collection3d(collection)

# update start point
if a1 < occ_plot.size:
theta0 = theta1

# for ID_plot in ID_all:
# sub = ID == ID_plot
# ax.scatter(
# xs=xyz[sub, 1], # + d[0],
# ys=xyz[sub, 0], # + d[1],
# zs=xyz[sub, 2], # + d[2],
# s=size_marker,
# linewidth=2,
# facecolors='none',
# edgecolor=[0, 0, 0],
# )
# poly = PolyCollection(
# verts,
# facecolors=['r', 'g', 'b', 'y'],
# alpha = 0.6)
# ax.add_collection3d(poly, zs=zs, zdir='y')

# plot limit
if plot_limit is None:
Expand Down

0 comments on commit 1a494e1

Please sign in to comment.