diff --git a/ddlitlab2024/ml/inference/plot.py b/ddlitlab2024/ml/inference/plot.py index f7a4978..2f9dc89 100644 --- a/ddlitlab2024/ml/inference/plot.py +++ b/ddlitlab2024/ml/inference/plot.py @@ -1,5 +1,6 @@ from dataclasses import asdict +import matplotlib.pyplot as plt import numpy as np import torch import torch.nn.functional as F # noqa @@ -12,8 +13,6 @@ from ddlitlab2024.ml.model import End2EndDiffusionTransformer from ddlitlab2024.ml.model.encoder.image import ImageEncoderType, SequenceEncoderType from ddlitlab2024.ml.model.encoder.imu import IMUEncoder -import matplotlib.pyplot as plt - # Check if CUDA is available and set the device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -99,7 +98,7 @@ dataloader = iter(dataloader) - for i in range(num_samples): + for _ in range(num_samples): batch = next(dataloader) # Move the data to the device batch = {k: v.to(device) for k, v in asdict(batch).items()} @@ -126,16 +125,29 @@ trajectory = normalizer.denormalize(trajectory) noisy_trajectory = normalizer.denormalize(noisy_trajectory) - # Plot the trajectory context, the noisy trajectory, the denoised trajectory and the target trajectory for each joint + # Plot the trajectory context, the noisy trajectory, the denoised trajectory + # and the target trajectory for each joint plt.figure(figsize=(10, 10)) for j in range(num_joints): plt.subplot(5, 4, j + 1) joint_command_context = batch["joint_command_history"][0, :, j].cpu().numpy() plt.plot(np.arange(len(joint_command_context)), joint_command_context, label="Context") - plt.plot(np.arange(len(joint_command_context), len(joint_command_context) + trajectory_prediction_length), noisy_trajectory[0, :, j].cpu().numpy(), label="Noisy Trajectory") - plt.plot(np.arange(len(joint_command_context), len(joint_command_context) + trajectory_prediction_length), joint_targets[0, :, j].cpu().numpy(), label="Target Trajectory") - plt.plot(np.arange(len(joint_command_context), len(joint_command_context) + trajectory_prediction_length), trajectory[0, :, j].cpu().numpy(), label="Denoised Trajectory") + plt.plot( + np.arange(len(joint_command_context), len(joint_command_context) + trajectory_prediction_length), + noisy_trajectory[0, :, j].cpu().numpy(), + label="Noisy Trajectory", + ) + plt.plot( + np.arange(len(joint_command_context), len(joint_command_context) + trajectory_prediction_length), + joint_targets[0, :, j].cpu().numpy(), + label="Target Trajectory", + ) + plt.plot( + np.arange(len(joint_command_context), len(joint_command_context) + trajectory_prediction_length), + trajectory[0, :, j].cpu().numpy(), + label="Denoised Trajectory", + ) plt.title(f"Joint {dataset.joint_names[j]}") plt.legend() plt.show()