Skip to content

Commit

Permalink
🎨 add function: make_psf
Browse files Browse the repository at this point in the history
   * build a PSF (point spread function) from the DataArray
  • Loading branch information
arafune committed Feb 20, 2024
1 parent 74e855c commit ef24869
Show file tree
Hide file tree
Showing 11 changed files with 77 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Changes.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Major Changes from 3.0.1

- Most important change:

- Use correct method to convert from the angle to moementum. (Original approach was found to be incorrect)
- Use correct method to convert from the angle to momentum. (The original way was incorrect)

- New feature
- Provide SPD_main.py & prodigy_itx.py
Expand Down
4 changes: 2 additions & 2 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
:target: https://github.com/arafune/arpes/actions/workflows/test.yml
.. |code style| image:: https://img.shields.io/badge/code%20style-black-000000.svg
:target: https://github.com/psf/black


.. |code fromat| image:: https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json
:target: https://github.com/astral-sh/ruff

PyARPES
=======
Expand Down
68 changes: 64 additions & 4 deletions src/arpes/analysis/deconvolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

from __future__ import annotations

from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from typing import TYPE_CHECKING

import numpy as np
import scipy
import scipy.ndimage
import xarray as xr
from scipy.stats import multivariate_normal
from skimage.restoration import richardson_lucy

import arpes.xarray_extensions # noqa: F401
Expand All @@ -16,6 +18,8 @@
from arpes.utilities import normalize_to_spectrum

if TYPE_CHECKING:
from collections.abc import Hashable

from numpy.typing import NDArray


Expand All @@ -25,7 +29,17 @@
"make_psf1d",
)

TWO_DIMWENSION = 2
LOGLEVELS = (DEBUG, INFO)
LOGLEVEL = LOGLEVELS[1]
logger = getLogger(__name__)
fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s"
formatter = Formatter(fmt)
handler = StreamHandler()
handler.setLevel(LOGLEVEL)
logger.setLevel(LOGLEVEL)
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False


@update_provenance("Approximate Iterative Deconvolution")
Expand Down Expand Up @@ -120,15 +134,61 @@ def make_psf1d(data: xr.DataArray, dim: str, sigma: float) -> xr.DataArray:


@update_provenance("Make Point Spread Function")
def make_psf(
    data: xr.DataArray,
    sigmas: dict[Hashable, float],
    *,
    fwhm: bool = True,
) -> xr.DataArray:
    """Produce an n-dimensional gaussian point spread function for use in deconvolve_rl.

    Args:
        data (xr.DataArray): input data, used only for its dims/shape/strides.
        sigmas (dict[Hashable, float]): width for each dimension of ``data``.
        fwhm (bool): if True, values in ``sigmas`` are FWHM, not standard deviations.

    Returns:
        xr.DataArray: The PSF to use, centered at the origin of each axis.

    Raises:
        ValueError: if the keys of ``sigmas`` do not match the dimensions of ``data``.
    """
    strides = data.G.stride(generic_dim_names=False)
    logger.debug(f"strides: {strides}")
    # Explicit validation instead of `assert`, which is stripped under `python -O`.
    if set(strides) != set(sigmas):
        msg = "sigmas must provide a width for every dimension of data"
        raise ValueError(msg)
    # Force an odd pixel count along each axis so the PSF has a single,
    # well-defined center pixel at the origin.
    pixels: dict[Hashable, int] = {
        dim: size - 1 if size % 2 == 0 else size
        for dim, size in zip(data.dims, data.shape, strict=True)
    }

    if fwhm:
        # For a gaussian: FWHM = 2 * sqrt(2 * ln 2) * sigma.
        sigmas = {k: v / (2 * np.sqrt(2 * np.log(2))) for k, v in sigmas.items()}
    # Diagonal covariance matrix: sigma is a standard deviation, but
    # multivariate_normal expects a covariance.
    cov: NDArray[np.float64] = np.zeros((len(sigmas), len(sigmas)))
    for i, dim in enumerate(data.dims):
        cov[i][i] = sigmas[dim] ** 2
    logger.debug(f"cov: {cov}")

    # Symmetric, zero-centered coordinate axes with the same spacing as data.
    # Keys are the dim names themselves (not str(k)) so lookups in `pixels`
    # and the final DataArray coords stay consistent for any Hashable dim.
    psf_coords: dict[Hashable, NDArray[np.float64]] = {
        k: np.linspace(
            -(pixels[k] - 1) / 2 * strides[k],
            (pixels[k] - 1) / 2 * strides[k],
            pixels[k],
        )
        for k in data.dims
    }
    if LOGLEVEL == DEBUG:
        for k, v in psf_coords.items():
            logger.debug(
                f" psf_coords[{k}]: ±{np.max(v):.3f}",
            )
    coords = np.meshgrid(*[psf_coords[dim] for dim in data.dims], indexing="ij")

    coords_for_pdf_pos = np.stack(coords, axis=-1)  # point distribution function (pdf)
    logger.debug(f"shape of coords_for_pdf_pos: {coords_for_pdf_pos.shape}")
    return xr.DataArray(
        multivariate_normal(mean=np.zeros(len(sigmas)), cov=cov).pdf(
            coords_for_pdf_pos,
        ),
        dims=data.dims,
        coords=psf_coords,
        name="PSF",
    )
4 changes: 2 additions & 2 deletions src/arpes/endstations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class EndstationBase:

ALIASES: ClassVar[list[str]] = []
PRINCIPAL_NAME = ""
ATTR_TRANSFORMS: ClassVar[dict[str, Callable]] = {}
ATTR_TRANSFORMS: ClassVar[dict[str, Callable[..., dict[str, float | list[str] | str]]]] = {}
MERGE_ATTRS: ClassVar[SPECTROMETER] = {}

_SEARCH_DIRECTORIES: tuple[str, ...] = (
Expand Down Expand Up @@ -176,7 +176,7 @@ def is_file_accepted(

return False
try:
_ = cls.find_first_file(str(file))
_ = cls.find_first_file(int(file))
except ValueError:
return False
return True
Expand Down
2 changes: 1 addition & 1 deletion src/arpes/endstations/plugin/ALG_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class ALGMainChamber(HemisphericalEndstation, FITSEndstation):
"ALG-Main Chamber",
]

ATTR_TRANSFORMS: ClassVar[dict[str, Callable[..., dict[str, list[str] | str]]]] = {
ATTR_TRANSFORMS: ClassVar[dict[str, Callable[..., dict[str, float | list[str] | str]]]] = {
"START_T": lambda _: {"time": " ".join(_.split(" ")[1:]).lower(), "date": _.split(" ")[0]},
}

Expand Down
2 changes: 1 addition & 1 deletion src/arpes/endstations/plugin/BL10_SARPES.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class BL10012SARPESEndstation(SynchrotronEndstation, HemisphericalEndstation, SE
# Look at merlin.py for details
}

ATTR_TRANSFORMS: ClassVar[dict[str, Callable]] = {
ATTR_TRANSFORMS: ClassVar[dict[str, Callable[..., dict[str, float | list[str] | str]]]] = {
# TODO: Kayla or another user should add these
# Look at merlin.py for details
}
Expand Down
4 changes: 2 additions & 2 deletions src/arpes/endstations/plugin/MAESTRO.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ class MAESTROMicroARPESEndstation(MAESTROARPESEndstationBase):
"Z": "z",
}

ATTR_TRANSFORMS: ClassVar[dict[str, Callable]] = {
ATTR_TRANSFORMS: ClassVar[dict[str, Callable[..., dict[str, int | list[str] | str]]]] = {
"START_T": lambda _: {
"time": " ".join(_.split(" ")[1:]).lower(),
"date": _.split(" ")[0],
Expand Down Expand Up @@ -255,7 +255,7 @@ class MAESTRONanoARPESEndstation(MAESTROARPESEndstationBase):
"Slit Defl.": "psi",
}

ATTR_TRANSFORMS: ClassVar[dict[str, Callable]] = {
ATTR_TRANSFORMS: ClassVar[dict[str, Callable[..., dict[str, float | list[str] | str]]]] = {
"START_T": lambda _: {
"time": " ".join(_.split(" ")[1:]).lower(),
"date": _.split(" ")[0],
Expand Down
2 changes: 1 addition & 1 deletion src/arpes/endstations/plugin/example_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def load_single_frame(
if len(coord.values.shape) and cname not in data.dims:
replacement_coords[cname] = coord.mean().item()

data = data.assign_coords(**replacement_coords)
data = data.assign_coords(replacement_coords)

# Wrap into a dataset
dataset = xr.Dataset({"spectrum": data})
Expand Down
2 changes: 1 addition & 1 deletion src/arpes/endstations/plugin/igor_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class IgorEndstation(SingleFileEndstation):

MERGE_ATTRS: ClassVar[SPECTROMETER] = {}

ATTR_TRANSFORMS: ClassVar[dict[str, Callable]] = {}
ATTR_TRANSFORMS: ClassVar[dict[str, Callable[..., dict[str, float | list[str] | str]]]] = {}

def load_single_frame(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/arpes/endstations/plugin/merlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class BL403ARPESEndstation(SynchrotronEndstation, HemisphericalEndstation, SESEn
"undulator_type": "elliptically_polarized_undulator",
}

ATTR_TRANSFORMS: ClassVar[dict[str, Callable]] = {
ATTR_TRANSFORMS: ClassVar[dict[str, Callable[..., dict[str, float | list[str] | str]]]] = {
"acquisition_mode": lambda _: _.lower(),
"lens_mode": lambda _: {
"lens_mode": None,
Expand Down
2 changes: 1 addition & 1 deletion src/arpes/utilities/bz.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,7 @@ def reduced_bz_E_mask(

selector = {}
selector[data.dims[skip_col]] = selector_val
sdata = data.sel(**selector, method="nearest")
sdata = data.sel(selector, method="nearest")

path = matplotlib.path.Path(poly_points)
grid = np.array(
Expand Down

0 comments on commit ef24869

Please sign in to comment.