Skip to content

Commit

Permalink
Apply formatting again
Browse files Browse the repository at this point in the history
  • Loading branch information
Flova committed Dec 19, 2024
1 parent c96aa86 commit 933444f
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions ddlitlab2024/ml/inference/plot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dataclasses import asdict

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F # noqa
Expand All @@ -12,8 +13,6 @@
from ddlitlab2024.ml.model import End2EndDiffusionTransformer
from ddlitlab2024.ml.model.encoder.image import ImageEncoderType, SequenceEncoderType
from ddlitlab2024.ml.model.encoder.imu import IMUEncoder
import matplotlib.pyplot as plt


# Check if CUDA is available and set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand Down Expand Up @@ -99,7 +98,7 @@

dataloader = iter(dataloader)

for i in range(num_samples):
for _ in range(num_samples):
batch = next(dataloader)
# Move the data to the device
batch = {k: v.to(device) for k, v in asdict(batch).items()}
Expand All @@ -126,16 +125,29 @@
trajectory = normalizer.denormalize(trajectory)
noisy_trajectory = normalizer.denormalize(noisy_trajectory)

# Plot the trajectory context, the noisy trajectory, the denoised trajectory and the target trajectory for each joint
# Plot the trajectory context, the noisy trajectory, the denoised trajectory
# and the target trajectory for each joint
plt.figure(figsize=(10, 10))
for j in range(num_joints):
plt.subplot(5, 4, j + 1)

joint_command_context = batch["joint_command_history"][0, :, j].cpu().numpy()
plt.plot(np.arange(len(joint_command_context)), joint_command_context, label="Context")
plt.plot(np.arange(len(joint_command_context), len(joint_command_context) + trajectory_prediction_length), noisy_trajectory[0, :, j].cpu().numpy(), label="Noisy Trajectory")
plt.plot(np.arange(len(joint_command_context), len(joint_command_context) + trajectory_prediction_length), joint_targets[0, :, j].cpu().numpy(), label="Target Trajectory")
plt.plot(np.arange(len(joint_command_context), len(joint_command_context) + trajectory_prediction_length), trajectory[0, :, j].cpu().numpy(), label="Denoised Trajectory")
plt.plot(
np.arange(len(joint_command_context), len(joint_command_context) + trajectory_prediction_length),
noisy_trajectory[0, :, j].cpu().numpy(),
label="Noisy Trajectory",
)
plt.plot(
np.arange(len(joint_command_context), len(joint_command_context) + trajectory_prediction_length),
joint_targets[0, :, j].cpu().numpy(),
label="Target Trajectory",
)
plt.plot(
np.arange(len(joint_command_context), len(joint_command_context) + trajectory_prediction_length),
trajectory[0, :, j].cpu().numpy(),
label="Denoised Trajectory",
)
plt.title(f"Joint {dataset.joint_names[j]}")
plt.legend()
plt.show()
Expand Down

0 comments on commit 933444f

Please sign in to comment.