Skip to content

Commit

Permalink
🔥 Preparation a deep refactoring of band_analyais.py
Browse files Browse the repository at this point in the history
remove weird codes
  • Loading branch information
arafune committed Apr 26, 2024
1 parent ce27ae7 commit 69a0861
Showing 1 changed file with 10 additions and 24 deletions.
34 changes: 10 additions & 24 deletions src/arpes/analysis/band_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,8 @@ def fit_patterned_bands( # noqa: PLR0913
1. Fit directions, these are coordinates along the 1D (or maybe later 2D) marginals
2. Broadcast directions, these are directions used to interpolate against the patterned
directions
3. Free directions, these are broadcasted but they are not used to extract initial values of the
directions
fit parameters
For instance, if you laid out band patterns in a E, k_p, delay spectrum at delta_t=0, then if
Expand Down Expand Up @@ -478,21 +478,21 @@ def _instantiate_band(partial_band: dict[str, Any]) -> lf.Model:

def fit_bands(
arr: xr.DataArray,
band_description: list[BandDescription] | BandDescription,
band_descriptions: list[BandDescription],
direction: Literal["edc", "mdc", "EDC", "MDC"] = "mdc",
step: Literal["initial", None] = None,
) -> tuple[xr.DataArray | None, None, lf.ModelResult | None]:
"""Fits bands and determines dispersion in some region of a spectrum.
Args:
arr(xr.DataArray): ARPES data for fit.
band_description: A description of the bands to fit in the region
band_descriptions: List of the description of the bands to fit in the region
direction: fit direction (along the enegy or momentum),
default is "mdc" (Momentum Distribution Curve).
step: if "Initial" is set, ....
Returns:
Fitted bands.
ToDo: Deep refactoring. The current version may not work.
"""
assert direction in {"edc", "mdc", "EDC", "MDC"}

Expand All @@ -509,20 +509,13 @@ def fit_bands(

# Let the first band be given by fitting the raw data to this band
# Find subsequent peaks by fitting models to the residuals
raw_bands = [band.get("band") if isinstance(band, dict) else band for band in band_description]
raw_bands = [band_description.get("band") for band_description in band_descriptions]
initial_fits = None
all_fit_parameters = {}

if step == "initial":
residual.plot()

for band in band_description:
if isinstance(band, dict):
band_inst = band.get("band")
params = band.get("params", {})
else:
band_inst = band
params = None
for band_description in band_descriptions:
band_inst = band_description.get("band")
params = band_description.get("params", {})
fit_model = band_inst.fit_cls(prefix=band_inst.label)
initial_fit = fit_model.guess_fit(residual, params=params)
if initial_fits is None:
Expand All @@ -538,13 +531,6 @@ def fit_bands(
# alright for now
pass

if step == "initial":
residual.plot()
(residual - residual + initial_fit.best_fit).plot()

if step == "initial":
return None, None, residual

template = arr.sum(broadcast_direction)
band_results = xr.DataArray(
np.ndarray(shape=template.values.shape, dtype=object),
Expand All @@ -563,7 +549,7 @@ def fit_bands(
for c, v in all_fit_parameters.items():
delta = np.array(c) - frozen_coordinate
current_distance = delta.dot(delta)
if current_distance < dist and direction == "mdc": # TODO: remove me
if current_distance < dist and direction in {"mdc", "MDC"}: # TODO: remove me
closest_model_params = v

# TODO: mix in any params to the model params
Expand Down

0 comments on commit 69a0861

Please sign in to comment.