From 69a08611bb2904b090dc9fb20bfeba58f781ea97 Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Fri, 26 Apr 2024 13:53:02 +0900 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A5=20=20Preparation=20a=20deep=20refa?= =?UTF-8?q?ctoring=20of=20band=5Fanalyais.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit remove weird codes --- src/arpes/analysis/band_analysis.py | 34 +++++++++-------------------- 1 file changed, 10 insertions(+), 24 deletions(-) diff --git a/src/arpes/analysis/band_analysis.py b/src/arpes/analysis/band_analysis.py index 2118363c..b50d9d71 100644 --- a/src/arpes/analysis/band_analysis.py +++ b/src/arpes/analysis/band_analysis.py @@ -321,8 +321,8 @@ def fit_patterned_bands( # noqa: PLR0913 1. Fit directions, these are coordinates along the 1D (or maybe later 2D) marginals 2. Broadcast directions, these are directions used to interpolate against the patterned - directions 3. Free directions, these are broadcasted but they are not used to extract initial values of the + directions fit parameters For instance, if you laid out band patterns in a E, k_p, delay spectrum at delta_t=0, then if @@ -478,21 +478,21 @@ def _instantiate_band(partial_band: dict[str, Any]) -> lf.Model: def fit_bands( arr: xr.DataArray, - band_description: list[BandDescription] | BandDescription, + band_descriptions: list[BandDescription], direction: Literal["edc", "mdc", "EDC", "MDC"] = "mdc", - step: Literal["initial", None] = None, ) -> tuple[xr.DataArray | None, None, lf.ModelResult | None]: """Fits bands and determines dispersion in some region of a spectrum. Args: arr(xr.DataArray): ARPES data for fit. - band_description: A description of the bands to fit in the region + band_descriptions: List of the description of the bands to fit in the region direction: fit direction (along the enegy or momentum), default is "mdc" (Momentum Distribution Curve). - step: if "Initial" is set, .... Returns: Fitted bands. + + ToDo: Deep refactoring. The current version may not work. """ assert direction in {"edc", "mdc", "EDC", "MDC"} @@ -509,20 +509,13 @@ def fit_bands( # Let the first band be given by fitting the raw data to this band # Find subsequent peaks by fitting models to the residuals - raw_bands = [band.get("band") if isinstance(band, dict) else band for band in band_description] + raw_bands = [band_description.get("band") for band_description in band_descriptions] initial_fits = None all_fit_parameters = {} - if step == "initial": - residual.plot() - - for band in band_description: - if isinstance(band, dict): - band_inst = band.get("band") - params = band.get("params", {}) - else: - band_inst = band - params = None + for band_description in band_descriptions: + band_inst = band_description.get("band") + params = band_description.get("params", {}) fit_model = band_inst.fit_cls(prefix=band_inst.label) initial_fit = fit_model.guess_fit(residual, params=params) if initial_fits is None: @@ -538,13 +531,6 @@ def fit_bands( # alright for now pass - if step == "initial": - residual.plot() - (residual - residual + initial_fit.best_fit).plot() - - if step == "initial": - return None, None, residual - template = arr.sum(broadcast_direction) band_results = xr.DataArray( np.ndarray(shape=template.values.shape, dtype=object), @@ -563,7 +549,7 @@ def fit_bands( for c, v in all_fit_parameters.items(): delta = np.array(c) - frozen_coordinate current_distance = delta.dot(delta) - if current_distance < dist and direction == "mdc": # TODO: remove me + if current_distance < dist and direction in {"mdc", "MDC"}: # TODO: remove me closest_model_params = v # TODO: mix in any params to the model params