Skip to content

Commit

Permalink
Merge branch 'daredevil' of https://github.com/arafune/arpes into dar…
Browse files Browse the repository at this point in the history
…edevil
  • Loading branch information
arafune committed Apr 23, 2024
2 parents df124ba + ec5178d commit 81946e1
Showing 1 changed file with 24 additions and 17 deletions.
41 changes: 24 additions & 17 deletions src/arpes/plotting/stack_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def flat_stack_plot( # noqa: PLR0913
fermi_level: float | None = None,
figsize: tuple[float, float] = (7, 5),
title: str = "",
max_stacks: int = 200,
out: str | Path = "",
loc: LEGENDLOCATION = "upper left",
**kwargs: Unpack[MPLPlotKwargsBasic],
Expand All @@ -229,6 +230,7 @@ def flat_stack_plot( # noqa: PLR0913
(Not drawn)
figsize (tuple[float, float]): figure size
title(str): Title string, by default ""
max_stacks(int): maximum number of the staking spectra
out(str | Path): Path to the figure, by default ""
loc: Legend location
**kwargs: pass to subplot if figsize is set, and ticks is set, and the others to be passed
Expand All @@ -243,11 +245,16 @@ def flat_stack_plot( # noqa: PLR0913
NotImplementedError
_description_
"""
data_array = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)
data = _rebinning(
data,
stack_axis=stack_axis,
max_stacks=max_stacks,
method="mean",
)[0]

if len(data_array.dims) != TWO_DIMENSION:
if len(data.dims) != TWO_DIMENSION:
msg = "In order to produce a stack plot, data must be image-like."
msg += f"Passed data included dimensions: {data_array.dims}"
msg += f"Passed data included dimensions: {data.dims}"
raise ValueError(
msg,
)
Expand All @@ -258,12 +265,12 @@ def flat_stack_plot( # noqa: PLR0913
ax_inset = inset_axes(ax, width="40%", height="5%", loc=loc)
assert isinstance(ax, Axes)
if not stack_axis:
stack_axis = str(data_array.dims[0])
stack_axis = str(data.dims[0])

horizontal_dim = next(str(d) for d in data_array.dims if d != stack_axis)
horizontal = data_array.coords[horizontal_dim]
horizontal_dim = next(str(d) for d in data.dims if d != stack_axis)
horizontal = data.coords[horizontal_dim]

if "eV" in data_array.dims and stack_axis != "eV" and fermi_level is not None:
if "eV" in data.dims and stack_axis != "eV" and fermi_level is not None:
ax.axvline(
fermi_level,
color="red",
Expand All @@ -274,15 +281,15 @@ def flat_stack_plot( # noqa: PLR0913

color = kwargs.pop("color", "viridis")

for i, (_, marginal) in enumerate(data_array.G.iterate_axis(stack_axis)):
for i, (_, marginal) in enumerate(data.G.iterate_axis(stack_axis)):
if mode == "line":
ax.plot(
horizontal,
marginal.values,
color=_color_for_plot(
color,
i,
len(data_array.coords[stack_axis]),
len(data.coords[stack_axis]),
),
**kwargs,
)
Expand All @@ -294,23 +301,22 @@ def flat_stack_plot( # noqa: PLR0913
color=_color_for_plot(
color,
i,
len(data_array.coords[stack_axis]),
len(data.coords[stack_axis]),
),
**kwargs,
)
assert isinstance(color, Colormap), "The 'color' arg is not Colormap name."
matplotlib.colorbar.Colorbar(
ax_inset,
orientation="horizontal",
label=label_for_dim(data_array, stack_axis),
label=label_for_dim(data, stack_axis),
norm=matplotlib.colors.Normalize(
vmin=data_array.coords[stack_axis].min().values,
vmax=data_array.coords[stack_axis].max().values,
vmin=data.coords[stack_axis].min().values,
vmax=data.coords[stack_axis].max().values,
),
ticks=matplotlib.ticker.MaxNLocator(2),
cmap=color,
)
ax.set_xlabel(label_for_dim(data_array, horizontal_dim))
ax.set_xlabel(label_for_dim(data, horizontal_dim))
ax.set_ylabel("Spectrum Intensity (arb).")
ax.set_title(title, fontsize=14)
ax.set_xlim(left=horizontal.min().item(), right=horizontal.max().item())
Expand Down Expand Up @@ -548,6 +554,7 @@ def _rebinning(
data: xr.DataArray,
stack_axis: str,
max_stacks: int,
method: Literal["sum", "mean"] = "sum",
) -> tuple[xr.DataArray, str, str]:
"""Preparation for stack plot.
Expand All @@ -557,8 +564,7 @@ def _rebinning(
"""
data_arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)
assert isinstance(data_arr, xr.DataArray)
data_arr_must_be_two_dimensional = 2
assert len(data_arr.dims) == data_arr_must_be_two_dimensional
assert len(data_arr.dims) == TWO_DIMENSION
if not stack_axis:
stack_axis = str(data_arr.dims[0])

Expand All @@ -572,6 +578,7 @@ def _rebinning(
rebin(
data_arr,
bin_width={stack_axis: int(np.ceil(len(stack_coord.values) / max_stacks))},
method=method,
),
stack_axis,
horizontal_axis,
Expand Down

0 comments on commit 81946e1

Please sign in to comment.