Skip to content

Commit

Permalink
Merge pull request #7 from lunarlab-gatech/swei-dev
Browse files Browse the repository at this point in the history
lgan-com -> swei-dev -> main: Add CoM experiments' code.
  • Loading branch information
SizheWei authored Dec 3, 2024
2 parents 14458a1 + b6ccace commit f9ac502
Show file tree
Hide file tree
Showing 11 changed files with 850 additions and 465 deletions.
37 changes: 37 additions & 0 deletions cfg/solo-c2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
group_label: C2

# QJ: Joint Space symmetries____________________________________
# _____RL____|___FL____|____RR____|____FR___|
# q = [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
# Configure qj (joint-space) group actions
# _______gs (Sagittal Symmetry)__________|_______gt (Transversal symmetry)______
permutation_Q_js: [[6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5], [3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8]]
# Reflections are determined by joint frame predefined orientation.
reflection_Q_js: [[-1, 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, 1], [1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1]]

# B: Body Space symmetries____________________________________
# RL, FL, RR, FR
# q_lin (acc) = [x0, y0, z0, x1, y1, z1, x2, y2, z2, x3, y3, z3]
# q_ang (vel) = [x0, y0, z0, x1, y1, z1, x2, y2, z2, x3, y3, z3]
# _______gs (Sagittal Symmetry)__________|_______gt (Transversal symmetry)______
permutation_Q_bs: [[3, 4, 5, 0, 1, 2], [0, 1, 2, 3, 4, 5]]
# Reflections are determined by body frame predefined orientation.
reflection_Q_bs_lin: [[1, -1, 1, 1, -1, 1], [-1, 1, 1, -1, 1, 1]]
reflection_Q_bs_ang: [[-1, 1, -1, -1, 1, -1], [1, -1, -1, 1, -1, -1]]

# F: Foot Space symmetries____________________________________
# RL, FL, RR, FR
# q = [x0, y0, z0, x1, y1, z1, x2, y2, z2, x3, y3, z3]
# _______gs (Sagittal Symmetry)__________|_______gt (Transversal symmetry)______
permutation_Q_fs: [[6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5], [3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8]]
# Reflections are determined by body frame predefined orientation.
# gs: [1, -1, 1], gt: [-1, 1, 1]
reflection_Q_fs: [[1, -1, 1, 1, -1, 1, 1, -1, 1, 1, -1, 1], [-1, 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, 1]]

# L: Label Space symmetries____________________________________
# RL,FL,RR,FR
# q = [0, 1, 2, 3]
# _____gs______|______gt_____
permutation_Q_ls: [[2, 3, 0, 1], [1, 0, 3, 2]]
# Reflections are determined by body frame predefined orientation.
reflection_Q_ls: [[1, 1, 1, 1], [1, 1, 1, 1]]
38 changes: 38 additions & 0 deletions cfg/solo-k4.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
group_label: K4

# QJ: Joint Space symmetries____________________________________
# _____RL____|___FL____|____RR____|____FR___|
# q = [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
# Configure qj (joint-space) group actions
# _______gs (Sagittal Symmetry)__________|_______gt (Transversal symmetry)______
permutation_Q_js: [[6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5], [3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8]]
# Reflections are determined by joint frame predefined orientation.
reflection_Q_js: [[-1, 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, 1], [1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1]]

# B: Body Space symmetries____________________________________
# RL, FL, RR, FR
# q_lin (acc) = [x0, y0, z0, x1, y1, z1, x2, y2, z2, x3, y3, z3]
# q_ang (vel) = [x0, y0, z0, x1, y1, z1, x2, y2, z2, x3, y3, z3]
# _______gs (Sagittal Symmetry)__________|_______gt (Transversal symmetry)______
permutation_Q_bs: [[6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5], [3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8]]
# Reflections are determined by body frame predefined orientation.
reflection_Q_bs_lin: [[1, -1, 1, 1, -1, 1, 1, -1, 1, 1, -1, 1], [-1, 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, 1]]
reflection_Q_bs_ang: [[-1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1], [1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1]]

# # F: Foot Space symmetries____________________________________
# # RL, FL, RR, FR
# # q = [x0, y0, z0, x1, y1, z1, x2, y2, z2, x3, y3, z3]
# # _______gs (Sagittal Symmetry)__________|_______gt (Transversal symmetry)______
# permutation_Q_fs: [[6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5], [3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8]]
# # Reflections are determined by body frame predefined orientation.
# # gs: [1, -1, 1], gt: [-1, 1, 1]
# reflection_Q_fs: [[1, -1, 1, 1, -1, 1, 1, -1, 1, 1, -1, 1], [-1, 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, 1]]

# L: Label Space symmetries____________________________________
# RL,FL,RR,FR
# q = [0, 1, 2, 3]
# _____gs______|______gt_____
permutation_Q_ls: [[6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5], [3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8]]
# Reflections are determined by body frame predefined orientation.
reflection_Q_ls_lin: [[1, -1, 1, 1, -1, 1, 1, -1, 1, 1, -1, 1], [-1, 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, 1]]
reflection_Q_ls_ang: [[-1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1], [1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1]]
35 changes: 18 additions & 17 deletions research/evaluator_regression-com.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def main(MorphSym_version: str,
elif MorphSym_version == 'S4':
from ms_hgnn.datasets_py.soloDataset import Solo12Dataset
model_type = 'heterogeneous_gnn_s4_com'
elif MorphSym_version == 'C2':
from mi_hgnn.datasets_py.soloDataset import Solo12Dataset
model_type = 'heterogeneous_gnn_c2_com'
else:
raise ValueError("Other MorphSym versions are not supported for this script yet!")

Expand Down Expand Up @@ -56,33 +59,31 @@ def main(MorphSym_version: str,
print(f"Model Type: {model_type}")
print(f"Path to Checkpoint: {path_to_checkpoint}")
print("================================================")

test_dataset = prepare_test_dataset(Solo12Dataset, path_to_urdf, model_type, history_length, normalize=True, swap_legs=swap_legs, symmetry_operator=symmetry_operator, symmetry_mode=symmetry_mode, group_operator_path=group_operator_path)


root = Path(Path('.').parent, 'datasets', 'Solo-12').absolute()
test_dataset = prepare_test_dataset(Solo12Dataset, root, path_to_urdf, model_type, history_length, normalize=True,
symmetry_operator=symmetry_operator, symmetry_mode=symmetry_mode, group_operator_path=group_operator_path)

# Evaluate with model
pred, labels, loss, cos_sim_lin, cos_sim_ang = evaluate_model(path_to_checkpoint, test_dataset, symmetry_mode=symmetry_mode, group_operator_path=group_operator_path, data_path = root)
pred, labels, mse_loss, cos_sim_lin, cos_sim_ang = evaluate_model(path_to_checkpoint, test_dataset, symmetry_mode=symmetry_mode, group_operator_path=group_operator_path, data_path = root)
# Save to DataFrame
df = pandas.concat([df, pandas.DataFrame([[symmetry_operator, loss.item(), cos_sim_lin.item(), cos_sim_ang.item()]], columns=columns)], ignore_index=True)
df = pandas.concat([df, pandas.DataFrame([[symmetry_operator, mse_loss.item(), cos_sim_lin.item(), cos_sim_ang.item()]], columns=columns)], ignore_index=True)

# Print the results
print_results(loss, cos_sim_lin, cos_sim_ang)
print_results(mse_loss, cos_sim_lin, cos_sim_ang)

# Save to csv
df.to_csv(path_to_save_csv, index=False)
print("===> Evaluation Finished! Results saved to: ", path_to_save_csv)

def prepare_test_dataset(Solo12Dataset, path_to_urdf, model_type, history_length, normalize=True, swap_legs=None, symmetry_operator=None, symmetry_mode=None, group_operator_path=None):
def prepare_test_dataset(Solo12Dataset, root, path_to_urdf, model_type, history_length, normalize=True, symmetry_operator=None, symmetry_mode=None, group_operator_path=None):
"""Prepare the test dataset"""
# Define train and val sets
test_dataset = Solo12Dataset(Path(Path('.').parent, 'datasets', 'Solo-12').absolute(), path_to_urdf,
'package://yobotics_description/', 'mini-cheetah-gazebo-urdf/yobo_model/yobotics_description', model_type, history_length, normalize)
solo12data_test = Solo12Dataset(root, path_to_urdf,
'package://yobotics_description/', 'mini-cheetah-gazebo-urdf/yobo_model/yobotics_description', model_type, history_length, normalize, stage='test',
symmetry_operator=symmetry_operator, symmetry_mode=symmetry_mode, group_operator_path=group_operator_path)

# Take first 85% for training, and last 15% for validation
# Also remove the last entries, as dynamics models can't use last entry due to derivative calculation
data_len_minus_1 = test_dataset.__len__() - 1
split_index_test = int(np.round(data_len_minus_1 * 0.85)) # When value has .5, round to nearest-even
test_dataset = torch.utils.data.Subset(test_dataset, np.arange(0, split_index_test))
test_dataset = torch.utils.data.Subset(solo12data_test, np.arange(0, solo12data_test.__len__()))

return test_dataset

Expand All @@ -94,9 +95,9 @@ def print_results(loss, cos_sim_lin, cos_sim_ang):

if __name__ == "__main__":
# K4
MorphSym_version = 'S4'
path_to_checkpoint = "models/rich-sky-40/epoch=22-val_MSE_loss=0.36576-val_avg_cos_sim=0.70491.ckpt"
group_operator_path = 'cfg/mini_cheetah-k4.yaml'
MorphSym_version = 'C2'
path_to_checkpoint = "models/com_exp/autumn-cloud-2/"
group_operator_path = "cfg/a1-c2.yaml"
symmetry_operator_list = [None] # Can be 'gs' or 'gt' or 'gr' or None
symmetry_mode = 'MorphSym' # Can be 'Euclidean' or 'MorphSym' or None

Expand Down
9 changes: 8 additions & 1 deletion research/train_regression-com_msgn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ def main(seed,
lr=0.0024,
epochs=30,
logger_project_name='com_debug',
model_type='heterogeneous_gnn_s4_com',
model_type='heterogeneous_gnn_k4_com',
symmetry_operator=None,
symmetry_mode='MorphSym',
group_operator_path='cfg/solo-k4.yaml',
wandb_api_key = "eed5fa86674230b63649180cc343f14e1f1ace78"):
# ================================= CHANGE THESE ===================================
# wandb_api_key = "eed5fa86674230b63649180cc343f14e1f1ace78"
Expand Down Expand Up @@ -61,6 +64,8 @@ def main(seed,
disable_test=True,
data_path = root,
subfoler_name=logger_project_name,
symmetry_mode=symmetry_mode,
group_operator_path=group_operator_path,
wandb_api_key=wandb_api_key)

if __name__ == '__main__':
Expand All @@ -77,6 +82,7 @@ def main(seed,
parser.add_argument('--logger_project_name', type=str, default='com_debug', help='Logger project name')
# Model parameters
parser.add_argument('--model_type', type=str, default='heterogeneous_gnn_s4_com', help='Model type, options: heterogeneous_gnn_s4_com, heterogeneous_gnn_k4_com')
parser.add_argument('--group_operator_path', type=str, default='cfg/solo-k4.yaml', help='cfg/solo-k4.yaml or cfg/solo-c2.yaml')
parser.add_argument('--wandb_api_key', type=str, default='eed5fa86674230b63649180cc343f14e1f1ace78', help="Check your key at https://wandb.ai/authorize",)
args = parser.parse_args()

Expand All @@ -90,4 +96,5 @@ def main(seed,
epochs=args.epochs,
logger_project_name=args.logger_project_name,
model_type=args.model_type,
group_operator_path=args.group_operator_path,
wandb_api_key=args.wandb_api_key)
12 changes: 5 additions & 7 deletions src/ms_hgnn/datasets_py/flexibleDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ def __init__(self,
# Check for valid data format
self.data_format = data_format
if self.data_format != 'dynamics' and self.data_format != 'mlp' and self.data_format != 'heterogeneous_gnn' and self.data_format != 'heterogeneous_gnn_k4' and self.data_format != 'heterogeneous_gnn_c2' \
and self.data_format != 'heterogeneous_gnn_k4_com' and self.data_format != 'heterogeneous_gnn_s4_com' and self.data_format != 'mlp_com':
and self.data_format != 'heterogeneous_gnn_k4_com' and self.data_format != 'heterogeneous_gnn_c2_com' and self.data_format != 'heterogeneous_gnn_s4_com' and self.data_format != 'mlp_com':
raise ValueError(
"Parameter 'data_format' must be 'dynamics', 'mlp', 'heterogeneous_gnn', or 'heterogeneous_gnn_k4' or 'heterogeneous_gnn_c2' or 'heterogeneous_gnn_k4_com' or 'heterogeneous_gnn_s4_com' or 'mlp_com'."
"Parameter 'data_format' must be 'dynamics', 'mlp', 'heterogeneous_gnn', or 'heterogeneous_gnn_k4' or 'heterogeneous_gnn_c2' or 'heterogeneous_gnn_k4_com' or 'heterogeneous_gnn_c2_com' or 'heterogeneous_gnn_s4_com' or 'mlp_com'."
)

# Setup the directories for raw and processed data, download it,
Expand Down Expand Up @@ -163,7 +163,7 @@ def __init__(self,
raise ValueError("Dataset must provide at least one input.")

# Premake the tensors for edge attributes and connections for HGNN
if self.data_format == 'heterogeneous_gnn' or self.data_format == 'heterogeneous_gnn_k4' or self.data_format == 'heterogeneous_gnn_c2' or self.data_format == 'heterogeneous_gnn_k4_com' or self.data_format == 'heterogeneous_gnn_s4_com':
if self.data_format == 'heterogeneous_gnn' or self.data_format == 'heterogeneous_gnn_k4' or self.data_format == 'heterogeneous_gnn_c2' or self.data_format == 'heterogeneous_gnn_k4_com' or self.data_format == 'heterogeneous_gnn_c2_com' or self.data_format == 'heterogeneous_gnn_s4_com':
bj, jb, jj, fj, jf = self.robotGraph.get_edge_index_matrices()
self.bj = torch.tensor(bj, dtype=torch.long)
self.jb = torch.tensor(jb, dtype=torch.long)
Expand All @@ -180,7 +180,7 @@ def __init__(self,

# Precompute feature matrix sizes for HGNN
# Calculate the size of the feature matrices
if self.data_format == 'heterogeneous_gnn' or self.data_format == 'heterogeneous_gnn_k4' or self.data_format == 'heterogeneous_gnn_c2' or self.data_format == 'heterogeneous_gnn_k4_com' or self.data_format == 'heterogeneous_gnn_s4_com':
if self.data_format == 'heterogeneous_gnn' or self.data_format == 'heterogeneous_gnn_k4' or self.data_format == 'heterogeneous_gnn_c2' or self.data_format == 'heterogeneous_gnn_k4_com' or self.data_format == 'heterogeneous_gnn_c2_com' or self.data_format == 'heterogeneous_gnn_s4_com':
self.hgnn_number_nodes = self.robotGraph.get_num_of_each_node_type()
self.base_width = len(self.variables_to_use_base) * 3 * self.history_length
self.joint_width = len(self.variables_to_use_joint) * self.history_length
Expand Down Expand Up @@ -455,9 +455,7 @@ def get(self, idx):
return self.get_helper_heterogeneous_gnn(idx)
elif self.data_format == 'heterogeneous_gnn_c2':
return self.get_helper_heterogeneous_gnn_c2(idx)
elif self.data_format == 'heterogeneous_gnn_k4_com':
return self.get_helper_heterogeneous_gnn(idx)
elif self.data_format == 'heterogeneous_gnn_s4_com':
elif self.data_format == 'heterogeneous_gnn_k4_com' or self.data_format == 'heterogeneous_gnn_s4_com' or self.data_format == 'heterogeneous_gnn_c2_com':
return self.get_helper_heterogeneous_gnn(idx)
else:
raise ValueError("Invalid data format.")
Expand Down
Loading

0 comments on commit f9ac502

Please sign in to comment.