Skip to content

Commit

Permalink
🔨 reduce the number of args in plot_movie
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Oct 2, 2023
1 parent 473e0dd commit 48e4316
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 31 deletions.
33 changes: 14 additions & 19 deletions arpes/plotting/movie.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,29 +56,26 @@ def plot_movie(
fig, ax = plt.subplots(figsize=(7, 7))
assert isinstance(ax, Axes)
assert isinstance(fig, Figure)

assert isinstance(arpes.config.SETTINGS, dict)
cmap = arpes.config.SETTINGS.get("interactive", {}).get("palette", "viridis")
vmax = data.max().item()
vmin = data.min().item()
kwargs.setdefault(
"cmap",
arpes.config.SETTINGS.get("interactive", {}).get(
"palette",
"viridis",
),
)
kwargs.setdefault("vmax", data.max().item())
kwargs.setdefault("vmim", data.min().item())

if data.S.is_subtracted:
cmap = "RdBu"
vmax = np.max([np.abs(vmin), np.abs(vmax)])
vmin = -vmax

if "vmax" in kwargs:
vmax = kwargs.pop("vmax")
if "vmin" in kwargs:
vmin = kwargs.pop("vmin")
kwargs["cmap"] = "RdBu"
kwargs["vmax"] = np.max([np.abs(kwargs["vmin"]), np.abs(kwargs["vmax"])])
kwargs["vmin"] = -kwargs["vmax"]

plot: QuadMesh = (
data.mean(time_dim)
.transpose()
.plot(
vmax=vmax,
vmin=vmin,
cmap=cmap,
**kwargs,
)
)
Expand All @@ -95,21 +92,19 @@ def animate(i: int) -> tuple[QuadMesh]:
plot.set_array(data_for_plot.values.G.ravel())
return (plot,)

computed_interval = interval

anim = animation.FuncAnimation(
fig,
animate,
init_func=init,
repeat=500,
frames=len(animation_coords),
interval=computed_interval,
interval=interval,
blit=True,
)

animation_writer = animation.writers["ffmpeg"]
writer = animation_writer(
fps=1000 / computed_interval,
fps=1000 / interval,
metadata={"artist": "Me"},
bitrate=1800,
)
Expand Down
33 changes: 21 additions & 12 deletions arpes/xarray_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ def select_around(

if isinstance(point, tuple | list):
warnings.warn("Dangerous iterable point argument to `select_around`", stacklevel=2)
point = dict(zip(point, self._obj.dims))
point = dict(zip(point, self._obj.dims, strict=True))
if isinstance(point, xr.Dataset):
point = {k: point[k].item() for k in point.data_vars}

Expand Down Expand Up @@ -614,8 +614,6 @@ def short_history(self, key: str = "by") -> list:
def _calculate_symmetry_points(
self,
symmetry_points: dict,
projection_distance: float = 0.05,
include_projected: float = True,
epsilon: float = 0.01,
) -> tuple:
# For each symmetry point, we need to determine if it is projected or not
Expand Down Expand Up @@ -671,7 +669,12 @@ def _calculate_symmetry_points(

return points, projected_points

def symmetry_points(self, *, raw: bool = False, **kwargs: float) -> dict | tuple:
def symmetry_points(
self,
*,
raw: bool = False,
**kwargs: float,
) -> dict | tuple:
"""[TODO:summary].
Args:
Expand Down Expand Up @@ -2008,7 +2011,10 @@ def show_d2(self, **kwargs: Incomplete) -> None:
curve_tool = CurvatureTool(**kwargs)
return curve_tool.make_tool(self._obj)

def show_band_tool(self, **kwargs: Incomplete):
def show_band_tool(
self,
**kwargs: float | str | bool | dict[str, bool],
) -> dict[str, None | xr.DataArray | xr.Dataset | dict[str, Any]]:
"""Opens the Bokeh based band placement tool."""
from arpes.plotting.band_tool import BandTool

Expand Down Expand Up @@ -2053,22 +2059,23 @@ def _referenced_scans_for_spatial_plot(
*,
use_id: bool = True,
pattern: str = "{}.png",
**kwargs: str,
out: str | bool = "",
) -> Path | tuple[Figure, NDArray[Axes]]:
"""[TODO:summary].
Args:
use_id ([TODO:type]): [TODO:description]
pattern ([TODO:type]): [TODO:description]
kwargs: pass to
out (str|bool): if str, Path for output figure. if True,
the file name is automatically set. If False/"", no output is given.
"""
out = kwargs.get("out")
label = self._obj.attrs["id"] if use_id else self.label
if out is not None and isinstance(out, bool):
if isinstance(out, bool) and out is True:
out = pattern.format(f"{label}_reference_scan_fs")
kwargs["out"] = out
elif isinstance(out, bool) and out is False:
out = ""

return plotting.spatial.reference_scan_spatial(self._obj, **kwargs)
return plotting.spatial.reference_scan_spatial(self._obj, out=out)

def _referenced_scans_for_map_plot(
self,
Expand All @@ -2086,6 +2093,8 @@ def _referenced_scans_for_map_plot(
return reference_scan_fermi_surface(self._obj, **kwargs)

class HvRefScanParam(LabeledFermiSurfaceParam):
"""Parameter for hf_ref_scan."""

e_cut: float
bkg_subtraction: float

Expand Down Expand Up @@ -2156,7 +2165,7 @@ def reference_plot(self, **kwargs: IncompleteMPL) -> Axes:
if self.spectrum_type == "cut":
return self._simple_spectrum_reference_plot(**kwargs) # [PColorMeshKwargs]
if self.spectrum_type in {"ucut", "spem"}:
return self._referenced_scans_for_spatial_plot(**kwargs)
return self._referenced_scans_for_spatial_plot(**kwargs) # not kwargs for plot
raise NotImplementedError

@property
Expand Down

0 comments on commit 48e4316

Please sign in to comment.