From 747b3daa598d310956681274fbe94783308a8e6c Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Mon, 22 Apr 2024 17:17:18 +0900 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20=20Fix=20a=20bug=20in=20stack=5F?= =?UTF-8?q?plot.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/arpes/plotting/stack_plot.py | 41 +++++++++++++++++++------------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/src/arpes/plotting/stack_plot.py b/src/arpes/plotting/stack_plot.py index aec7a9f9..d525119f 100644 --- a/src/arpes/plotting/stack_plot.py +++ b/src/arpes/plotting/stack_plot.py @@ -214,6 +214,7 @@ def flat_stack_plot( # noqa: PLR0913 fermi_level: float | None = None, figsize: tuple[float, float] = (7, 5), title: str = "", + max_stacks: int = 200, out: str | Path = "", loc: LEGENDLOCATION = "upper left", **kwargs: Unpack[MPLPlotKwargsBasic], @@ -229,6 +230,7 @@ def flat_stack_plot( # noqa: PLR0913 (Not drawn) figsize (tuple[float, float]): figure size title(str): Title string, by default "" + max_stacks(int): maximum number of the staking spectra out(str | Path): Path to the figure, by default "" loc: Legend location **kwargs: pass to subplot if figsize is set, and ticks is set, and the others to be passed @@ -243,11 +245,16 @@ def flat_stack_plot( # noqa: PLR0913 NotImplementedError _description_ """ - data_array = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) + data = _rebinning( + data, + stack_axis=stack_axis, + max_stacks=max_stacks, + method="mean", + )[0] - if len(data_array.dims) != TWO_DIMENSION: + if len(data.dims) != TWO_DIMENSION: msg = "In order to produce a stack plot, data must be image-like." - msg += f"Passed data included dimensions: {data_array.dims}" + msg += f"Passed data included dimensions: {data.dims}" raise ValueError( msg, ) @@ -258,12 +265,12 @@ def flat_stack_plot( # noqa: PLR0913 ax_inset = inset_axes(ax, width="40%", height="5%", loc=loc) assert isinstance(ax, Axes) if not stack_axis: - stack_axis = str(data_array.dims[0]) + stack_axis = str(data.dims[0]) - horizontal_dim = next(str(d) for d in data_array.dims if d != stack_axis) - horizontal = data_array.coords[horizontal_dim] + horizontal_dim = next(str(d) for d in data.dims if d != stack_axis) + horizontal = data.coords[horizontal_dim] - if "eV" in data_array.dims and stack_axis != "eV" and fermi_level is not None: + if "eV" in data.dims and stack_axis != "eV" and fermi_level is not None: ax.axvline( fermi_level, color="red", @@ -274,7 +281,7 @@ def flat_stack_plot( # noqa: PLR0913 color = kwargs.pop("color", "viridis") - for i, (_, marginal) in enumerate(data_array.G.iterate_axis(stack_axis)): + for i, (_, marginal) in enumerate(data.G.iterate_axis(stack_axis)): if mode == "line": ax.plot( horizontal, @@ -282,7 +289,7 @@ def flat_stack_plot( # noqa: PLR0913 color=_color_for_plot( color, i, - len(data_array.coords[stack_axis]), + len(data.coords[stack_axis]), ), **kwargs, ) @@ -294,23 +301,22 @@ def flat_stack_plot( # noqa: PLR0913 color=_color_for_plot( color, i, - len(data_array.coords[stack_axis]), + len(data.coords[stack_axis]), ), **kwargs, ) - assert isinstance(color, Colormap), "The 'color' arg is not Colormap name." matplotlib.colorbar.Colorbar( ax_inset, orientation="horizontal", - label=label_for_dim(data_array, stack_axis), + label=label_for_dim(data, stack_axis), norm=matplotlib.colors.Normalize( - vmin=data_array.coords[stack_axis].min().values, - vmax=data_array.coords[stack_axis].max().values, + vmin=data.coords[stack_axis].min().values, + vmax=data.coords[stack_axis].max().values, ), ticks=matplotlib.ticker.MaxNLocator(2), cmap=color, ) - ax.set_xlabel(label_for_dim(data_array, horizontal_dim)) + ax.set_xlabel(label_for_dim(data, horizontal_dim)) ax.set_ylabel("Spectrum Intensity (arb).") ax.set_title(title, fontsize=14) ax.set_xlim(left=horizontal.min().item(), right=horizontal.max().item()) @@ -548,6 +554,7 @@ def _rebinning( data: xr.DataArray, stack_axis: str, max_stacks: int, + method: Literal["sum", "mean"] = "sum", ) -> tuple[xr.DataArray, str, str]: """Preparation for stack plot. @@ -557,8 +564,7 @@ def _rebinning( """ data_arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) assert isinstance(data_arr, xr.DataArray) - data_arr_must_be_two_dimensional = 2 - assert len(data_arr.dims) == data_arr_must_be_two_dimensional + assert len(data_arr.dims) == TWO_DIMENSION if not stack_axis: stack_axis = str(data_arr.dims[0]) @@ -572,6 +578,7 @@ def _rebinning( rebin( data_arr, bin_width={stack_axis: int(np.ceil(len(stack_coord.values) / max_stacks))}, + method=method, ), stack_axis, horizontal_axis,