Skip to content

Commit

Permalink
Apply formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
Flova committed Oct 26, 2024
1 parent b4cc464 commit b4a146a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 22 deletions.
21 changes: 10 additions & 11 deletions ddlitlab2024/ml/run_diffusion_transformer_robot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import torch
import torch.nn.functional as F # noqa
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from ema_pytorch import EMA
from torch import nn

Expand Down Expand Up @@ -89,18 +88,18 @@ def forward(self, x):

# Extract the joint command data all joints, and drop the time column
joints = [
#"LHipYaw",
#"RHipYaw",
#"LHipRoll",
#"RHipRoll",
#"LHipPitch",
#"RHipPitch",
# "LHipYaw",
# "RHipYaw",
# "LHipRoll",
# "RHipRoll",
# "LHipPitch",
# "RHipPitch",
"LKnee",
"RKnee",
#"LAnklePitch",
#"RAnklePitch",
#"LAnkleRoll",
#"RAnkleRoll",
# "LAnklePitch",
# "RAnklePitch",
# "LAnkleRoll",
# "RAnkleRoll",
]
trajectory_dim = len(joints)

Expand Down
22 changes: 11 additions & 11 deletions ddlitlab2024/ml/train_diffusion_transformer_robot.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F # noqa
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from ema_pytorch import EMA
from torch import nn
from tqdm import tqdm
import pandas as pd

# Check if CUDA is available and set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand Down Expand Up @@ -96,18 +96,18 @@ def forward(self, x):

# Extract the joint command data all joints, and drop the time column
joints = [
#"LHipYaw",
#"RHipYaw",
#"LHipRoll",
#"RHipRoll",
#"LHipPitch",
#"RHipPitch",
# "LHipYaw",
# "RHipYaw",
# "LHipRoll",
# "RHipRoll",
# "LHipPitch",
# "RHipPitch",
"LKnee",
"RKnee",
#"LAnklePitch",
#"RAnklePitch",
#"LAnkleRoll",
#"RAnkleRoll",
# "LAnklePitch",
# "RAnklePitch",
# "LAnkleRoll",
# "RAnkleRoll",
]
data = data[joints]
trajectory_dim = len(joints)
Expand Down

0 comments on commit b4a146a

Please sign in to comment.