Skip to content

Commit

Permalink
Implement recommendations from PR #39
Browse files Browse the repository at this point in the history
  • Loading branch information
LiamPattinson committed Nov 27, 2024
1 parent 2e3ad8e commit 9a7c0da
Showing 1 changed file with 25 additions and 45 deletions.
70 changes: 25 additions & 45 deletions src/sdf_xarray/plotting.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from typing import Optional
import matplotlib.pyplot as plt
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import xarray as xr
from IPython.display import HTML
from matplotlib.animation import FuncAnimation

if TYPE_CHECKING:
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation


def get_frame_title(dataset: xr.Dataset, frame: int, display_sdf_name: bool) -> str:
Expand All @@ -14,8 +18,8 @@ def get_frame_title(dataset: xr.Dataset, frame: int, display_sdf_name: bool) ->


def calculate_window_velocity_and_edges(
dataset, target_attribute, time_since_start, x_axis_coord
):
dataset: xr.Dataset, target_attribute: str, time_since_start: str, x_axis_coord: str
) -> tuple[float, tuple[float, float]]:
"""Calculate the moving window's velocity and initial edges.
1. Finds a lineout of the target atribute in the x coordinate of the first frame
Expand All @@ -36,13 +40,12 @@ def calculate_window_velocity_and_edges(
def generate_animation(
dataset: xr.Dataset,
target_attribute: str,
folder_path: Optional[str] = None,
display: bool = False,
display_sdf_name: bool = False,
fps: int = 10,
move_window: bool = False,
ax: plt.Axes | None = None,
**kwargs,
) -> Optional[HTML]:
) -> FuncAnimation:
"""Generate an animation for the given target attribute
Arguments
Expand All @@ -61,13 +64,20 @@ def generate_animation(
Frames per second for the animation (default: 10)
move_window:
If the simulation has a moving window, the animation will move along with it (default: False)
ax:
Matplotlib axes on which to plot
kwargs:
Dictionary of variables from matplotlib
Examples
--------
>>> generateAnimation(dataset, "Derived_Number_Density_Electron")
"""
fig, ax = plt.subplots()
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

if ax is None:
_, ax = plt.subplots()

N_frames = dataset.sizes.get("time")

# Time since the first frame of the simulation
Expand Down Expand Up @@ -96,9 +106,9 @@ def generate_animation(
title = get_frame_title(dataset, 0, display_sdf_name)
ax.set_title(title)
cbar = plt.colorbar(plot, ax=ax)
cbar.set_label(
f'{dataset[target_attribute].attrs.get("long_name")} [${dataset[target_attribute].attrs.get("units")}$]'
)
long_name = dataset[target_attribute].attrs.get("long_name")
units = dataset[target_attribute].attrs.get("units")
cbar.set_label(f"{long_name} [${units}$]")

window_initial_edge = (0, 0)

Expand Down Expand Up @@ -129,40 +139,10 @@ def update(frame):
title = get_frame_title(dataset, frame, display_sdf_name)
ax.set_title(title)

ani = FuncAnimation(
fig,
return FuncAnimation(
ax.get_figure(),
update,
frames=range(N_frames),
interval=1000 / fps,
repeat=True,
)

# Save the animation
if folder_path:
try:
ani.save(
f"{folder_path}/{target_attribute.replace('/', '_')}.mp4",
writer="ffmpeg",
fps=fps,
)
print(
f"Animation saved as MP4 at {folder_path}/{target_attribute.replace('/', '_')}.mp4"
)
except Exception as e:
print(f"Failed to save as MP4 due to {e}. Falling back to GIF.")
# Save as HTML
ani.save(
f"{folder_path}/{target_attribute.replace('/', '_')}.gif",
writer="pillow",
fps=fps,
)
print(
f"Animation saved as GIF at {folder_path}/{target_attribute.replace('/', '_')}.mp4"
)

# Close the figure to avoid displaying the first frame as a separate plot
plt.close(fig)

if display:
return HTML(ani.to_jshtml())
return None

0 comments on commit 9a7c0da

Please sign in to comment.