From b4a146a984821332ad90d15432e59a5b9151dd7f Mon Sep 17 00:00:00 2001 From: Florian Vahl Date: Sat, 26 Oct 2024 22:42:53 +0200 Subject: [PATCH] Apply formatting --- .../ml/run_diffusion_transformer_robot.py | 21 +++++++++--------- .../ml/train_diffusion_transformer_robot.py | 22 +++++++++---------- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/ddlitlab2024/ml/run_diffusion_transformer_robot.py b/ddlitlab2024/ml/run_diffusion_transformer_robot.py index cf73d59..a757bfe 100644 --- a/ddlitlab2024/ml/run_diffusion_transformer_robot.py +++ b/ddlitlab2024/ml/run_diffusion_transformer_robot.py @@ -3,7 +3,6 @@ import torch import torch.nn.functional as F # noqa from diffusers.schedulers.scheduling_ddim import DDIMScheduler -from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from ema_pytorch import EMA from torch import nn @@ -89,18 +88,18 @@ def forward(self, x): # Extract the joint command data all joints, and drop the time column joints = [ - #"LHipYaw", - #"RHipYaw", - #"LHipRoll", - #"RHipRoll", - #"LHipPitch", - #"RHipPitch", + # "LHipYaw", + # "RHipYaw", + # "LHipRoll", + # "RHipRoll", + # "LHipPitch", + # "RHipPitch", "LKnee", "RKnee", - #"LAnklePitch", - #"RAnklePitch", - #"LAnkleRoll", - #"RAnkleRoll", + # "LAnklePitch", + # "RAnklePitch", + # "LAnkleRoll", + # "RAnkleRoll", ] trajectory_dim = len(joints) diff --git a/ddlitlab2024/ml/train_diffusion_transformer_robot.py b/ddlitlab2024/ml/train_diffusion_transformer_robot.py index f4eb23d..0421c05 100644 --- a/ddlitlab2024/ml/train_diffusion_transformer_robot.py +++ b/ddlitlab2024/ml/train_diffusion_transformer_robot.py @@ -1,12 +1,12 @@ import matplotlib.pyplot as plt import numpy as np +import pandas as pd import torch import torch.nn.functional as F # noqa from diffusers.schedulers.scheduling_ddim import DDIMScheduler from ema_pytorch import EMA from torch import nn from tqdm import tqdm -import pandas as pd # Check if CUDA is available and set the device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -96,18 +96,18 @@ def forward(self, x): # Extract the joint command data all joints, and drop the time column joints = [ - #"LHipYaw", - #"RHipYaw", - #"LHipRoll", - #"RHipRoll", - #"LHipPitch", - #"RHipPitch", + # "LHipYaw", + # "RHipYaw", + # "LHipRoll", + # "RHipRoll", + # "LHipPitch", + # "RHipPitch", "LKnee", "RKnee", - #"LAnklePitch", - #"RAnklePitch", - #"LAnkleRoll", - #"RAnkleRoll", + # "LAnklePitch", + # "RAnklePitch", + # "LAnkleRoll", + # "RAnkleRoll", ] data = data[joints] trajectory_dim = len(joints)