Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lgan-com -> swei-dev -> main: Add CoM experiments' code. #7

Merged
merged 8 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions cfg/solo-c2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
group_label: C2

# QJ: Joint Space symmetries____________________________________
# _____RL____|___FL____|____RR____|____FR___|
# q = [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
# Configure qj (joint-space) group actions
# _______gs (Sagittal Symmetry)__________|_______gt (Transversal symmetry)______
permutation_Q_js: [[6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5], [3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8]]
# Reflections are determined by joint frame predefined orientation.
reflection_Q_js: [[-1, 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, 1], [1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1]]

# B: Body Space symmetries____________________________________
# RL, FL, RR, FR
# q_lin (acc) = [x0, y0, z0, x1, y1, z1, x2, y2, z2, x3, y3, z3]
# q_ang (vel) = [x0, y0, z0, x1, y1, z1, x2, y2, z2, x3, y3, z3]
# _______gs (Sagittal Symmetry)__________|_______gt (Transversal symmetry)______
permutation_Q_bs: [[3, 4, 5, 0, 1, 2], [0, 1, 2, 3, 4, 5]]
# Reflections are determined by body frame predefined orientation.
reflection_Q_bs_lin: [[1, -1, 1, 1, -1, 1], [-1, 1, 1, -1, 1, 1]]
reflection_Q_bs_ang: [[-1, 1, -1, -1, 1, -1], [1, -1, -1, 1, -1, -1]]

# F: Foot Space symmetries____________________________________
# RL, FL, RR, FR
# q = [x0, y0, z0, x1, y1, z1, x2, y2, z2, x3, y3, z3]
# _______gs (Sagittal Symmetry)__________|_______gt (Transversal symmetry)______
permutation_Q_fs: [[6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5], [3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8]]
# Reflections are determined by body frame predefined orientation.
# gs: [1, -1, 1], gt: [-1, 1, 1]
reflection_Q_fs: [[1, -1, 1, 1, -1, 1, 1, -1, 1, 1, -1, 1], [-1, 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, 1]]

# L: Label Space symmetries____________________________________
# RL,FL,RR,FR
# q = [0, 1, 2, 3]
# _____gs______|______gt_____
permutation_Q_ls: [[2, 3, 0, 1], [1, 0, 3, 2]]
# Reflections are determined by body frame predefined orientation.
reflection_Q_ls: [[1, 1, 1, 1], [1, 1, 1, 1]]
38 changes: 38 additions & 0 deletions cfg/solo-k4.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
group_label: K4

# QJ: Joint Space symmetries____________________________________
# _____RL____|___FL____|____RR____|____FR___|
# q = [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
# Configure qj (joint-space) group actions
# _______gs (Sagittal Symmetry)__________|_______gt (Transversal symmetry)______
permutation_Q_js: [[6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5], [3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8]]
# Reflections are determined by joint frame predefined orientation.
reflection_Q_js: [[-1, 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, 1], [1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1]]

# B: Body Space symmetries____________________________________
# RL, FL, RR, FR
# q_lin (acc) = [x0, y0, z0, x1, y1, z1, x2, y2, z2, x3, y3, z3]
# q_ang (vel) = [x0, y0, z0, x1, y1, z1, x2, y2, z2, x3, y3, z3]
# _______gs (Sagittal Symmetry)__________|_______gt (Transversal symmetry)______
permutation_Q_bs: [[6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5], [3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8]]
# Reflections are determined by body frame predefined orientation.
reflection_Q_bs_lin: [[1, -1, 1, 1, -1, 1, 1, -1, 1, 1, -1, 1], [-1, 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, 1]]
reflection_Q_bs_ang: [[-1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1], [1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1]]

# # F: Foot Space symmetries____________________________________
# # RL, FL, RR, FR
# # q = [x0, y0, z0, x1, y1, z1, x2, y2, z2, x3, y3, z3]
# # _______gs (Sagittal Symmetry)__________|_______gt (Transversal symmetry)______
# permutation_Q_fs: [[6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5], [3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8]]
# # Reflections are determined by body frame predefined orientation.
# # gs: [1, -1, 1], gt: [-1, 1, 1]
# reflection_Q_fs: [[1, -1, 1, 1, -1, 1, 1, -1, 1, 1, -1, 1], [-1, 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, 1]]

# L: Label Space symmetries____________________________________
# RL,FL,RR,FR
# q = [0, 1, 2, 3]
# _____gs______|______gt_____
permutation_Q_ls: [[6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5], [3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8]]
# Reflections are determined by body frame predefined orientation.
reflection_Q_ls_lin: [[1, -1, 1, 1, -1, 1, 1, -1, 1, 1, -1, 1], [-1, 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, 1]]
reflection_Q_ls_ang: [[-1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1], [1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1]]
35 changes: 18 additions & 17 deletions research/evaluator_regression-com.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def main(MorphSym_version: str,
elif MorphSym_version == 'S4':
from ms_hgnn.datasets_py.soloDataset import Solo12Dataset
model_type = 'heterogeneous_gnn_s4_com'
elif MorphSym_version == 'C2':
from mi_hgnn.datasets_py.soloDataset import Solo12Dataset
model_type = 'heterogeneous_gnn_c2_com'
else:
raise ValueError("Other MorphSym versions are not supported for this script yet!")

Expand Down Expand Up @@ -56,33 +59,31 @@ def main(MorphSym_version: str,
print(f"Model Type: {model_type}")
print(f"Path to Checkpoint: {path_to_checkpoint}")
print("================================================")

test_dataset = prepare_test_dataset(Solo12Dataset, path_to_urdf, model_type, history_length, normalize=True, swap_legs=swap_legs, symmetry_operator=symmetry_operator, symmetry_mode=symmetry_mode, group_operator_path=group_operator_path)


root = Path(Path('.').parent, 'datasets', 'Solo-12').absolute()
test_dataset = prepare_test_dataset(Solo12Dataset, root, path_to_urdf, model_type, history_length, normalize=True,
symmetry_operator=symmetry_operator, symmetry_mode=symmetry_mode, group_operator_path=group_operator_path)

# Evaluate with model
pred, labels, loss, cos_sim_lin, cos_sim_ang = evaluate_model(path_to_checkpoint, test_dataset, symmetry_mode=symmetry_mode, group_operator_path=group_operator_path, data_path = root)
pred, labels, mse_loss, cos_sim_lin, cos_sim_ang = evaluate_model(path_to_checkpoint, test_dataset, symmetry_mode=symmetry_mode, group_operator_path=group_operator_path, data_path = root)
# Save to DataFrame
df = pandas.concat([df, pandas.DataFrame([[symmetry_operator, loss.item(), cos_sim_lin.item(), cos_sim_ang.item()]], columns=columns)], ignore_index=True)
df = pandas.concat([df, pandas.DataFrame([[symmetry_operator, mse_loss.item(), cos_sim_lin.item(), cos_sim_ang.item()]], columns=columns)], ignore_index=True)

# Print the results
print_results(loss, cos_sim_lin, cos_sim_ang)
print_results(mse_loss, cos_sim_lin, cos_sim_ang)

# Save to csv
df.to_csv(path_to_save_csv, index=False)
print("===> Evaluation Finished! Results saved to: ", path_to_save_csv)

def prepare_test_dataset(Solo12Dataset, path_to_urdf, model_type, history_length, normalize=True, swap_legs=None, symmetry_operator=None, symmetry_mode=None, group_operator_path=None):
def prepare_test_dataset(Solo12Dataset, root, path_to_urdf, model_type, history_length, normalize=True, symmetry_operator=None, symmetry_mode=None, group_operator_path=None):
"""Prepare the test dataset"""
# Define train and val sets
test_dataset = Solo12Dataset(Path(Path('.').parent, 'datasets', 'Solo-12').absolute(), path_to_urdf,
'package://yobotics_description/', 'mini-cheetah-gazebo-urdf/yobo_model/yobotics_description', model_type, history_length, normalize)
solo12data_test = Solo12Dataset(root, path_to_urdf,
'package://yobotics_description/', 'mini-cheetah-gazebo-urdf/yobo_model/yobotics_description', model_type, history_length, normalize, stage='test',
symmetry_operator=symmetry_operator, symmetry_mode=symmetry_mode, group_operator_path=group_operator_path)

# Take first 85% for training, and last 15% for validation
# Also remove the last entries, as dynamics models can't use last entry due to derivative calculation
data_len_minus_1 = test_dataset.__len__() - 1
split_index_test = int(np.round(data_len_minus_1 * 0.85)) # When value has .5, round to nearest-even
test_dataset = torch.utils.data.Subset(test_dataset, np.arange(0, split_index_test))
test_dataset = torch.utils.data.Subset(solo12data_test, np.arange(0, solo12data_test.__len__()))

return test_dataset

Expand All @@ -94,9 +95,9 @@ def print_results(loss, cos_sim_lin, cos_sim_ang):

if __name__ == "__main__":
# K4
MorphSym_version = 'S4'
path_to_checkpoint = "models/rich-sky-40/epoch=22-val_MSE_loss=0.36576-val_avg_cos_sim=0.70491.ckpt"
group_operator_path = 'cfg/mini_cheetah-k4.yaml'
MorphSym_version = 'C2'
path_to_checkpoint = "models/com_exp/autumn-cloud-2/"
group_operator_path = "cfg/a1-c2.yaml"
symmetry_operator_list = [None] # Can be 'gs' or 'gt' or 'gr' or None
symmetry_mode = 'MorphSym' # Can be 'Euclidean' or 'MorphSym' or None

Expand Down
9 changes: 8 additions & 1 deletion research/train_regression-com_msgn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ def main(seed,
lr=0.0024,
epochs=30,
logger_project_name='com_debug',
model_type='heterogeneous_gnn_s4_com',
model_type='heterogeneous_gnn_k4_com',
symmetry_operator=None,
symmetry_mode='MorphSym',
group_operator_path='cfg/solo-k4.yaml',
wandb_api_key = "eed5fa86674230b63649180cc343f14e1f1ace78"):
# ================================= CHANGE THESE ===================================
# wandb_api_key = "eed5fa86674230b63649180cc343f14e1f1ace78"
Expand Down Expand Up @@ -61,6 +64,8 @@ def main(seed,
disable_test=True,
data_path = root,
subfoler_name=logger_project_name,
symmetry_mode=symmetry_mode,
group_operator_path=group_operator_path,
wandb_api_key=wandb_api_key)

if __name__ == '__main__':
Expand All @@ -77,6 +82,7 @@ def main(seed,
parser.add_argument('--logger_project_name', type=str, default='com_debug', help='Logger project name')
# Model parameters
parser.add_argument('--model_type', type=str, default='heterogeneous_gnn_s4_com', help='Model type, options: heterogeneous_gnn_s4_com, heterogeneous_gnn_k4_com')
parser.add_argument('--group_operator_path', type=str, default='cfg/solo-k4.yaml', help='cfg/solo-k4.yaml or cfg/solo-c2.yaml')
parser.add_argument('--wandb_api_key', type=str, default='eed5fa86674230b63649180cc343f14e1f1ace78', help="Check your key at https://wandb.ai/authorize",)
args = parser.parse_args()

Expand All @@ -90,4 +96,5 @@ def main(seed,
epochs=args.epochs,
logger_project_name=args.logger_project_name,
model_type=args.model_type,
group_operator_path=args.group_operator_path,
wandb_api_key=args.wandb_api_key)
12 changes: 5 additions & 7 deletions src/ms_hgnn/datasets_py/flexibleDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ def __init__(self,
# Check for valid data format
self.data_format = data_format
if self.data_format != 'dynamics' and self.data_format != 'mlp' and self.data_format != 'heterogeneous_gnn' and self.data_format != 'heterogeneous_gnn_k4' and self.data_format != 'heterogeneous_gnn_c2' \
and self.data_format != 'heterogeneous_gnn_k4_com' and self.data_format != 'heterogeneous_gnn_s4_com' and self.data_format != 'mlp_com':
and self.data_format != 'heterogeneous_gnn_k4_com' and self.data_format != 'heterogeneous_gnn_c2_com' and self.data_format != 'heterogeneous_gnn_s4_com' and self.data_format != 'mlp_com':
raise ValueError(
"Parameter 'data_format' must be 'dynamics', 'mlp', 'heterogeneous_gnn', or 'heterogeneous_gnn_k4' or 'heterogeneous_gnn_c2' or 'heterogeneous_gnn_k4_com' or 'heterogeneous_gnn_s4_com' or 'mlp_com'."
"Parameter 'data_format' must be 'dynamics', 'mlp', 'heterogeneous_gnn', or 'heterogeneous_gnn_k4' or 'heterogeneous_gnn_c2' or 'heterogeneous_gnn_k4_com' or 'heterogeneous_gnn_c2_com' or 'heterogeneous_gnn_s4_com' or 'mlp_com'."
)

# Setup the directories for raw and processed data, download it,
Expand Down Expand Up @@ -163,7 +163,7 @@ def __init__(self,
raise ValueError("Dataset must provide at least one input.")

# Premake the tensors for edge attributes and connections for HGNN
if self.data_format == 'heterogeneous_gnn' or self.data_format == 'heterogeneous_gnn_k4' or self.data_format == 'heterogeneous_gnn_c2' or self.data_format == 'heterogeneous_gnn_k4_com' or self.data_format == 'heterogeneous_gnn_s4_com':
if self.data_format == 'heterogeneous_gnn' or self.data_format == 'heterogeneous_gnn_k4' or self.data_format == 'heterogeneous_gnn_c2' or self.data_format == 'heterogeneous_gnn_k4_com' or self.data_format == 'heterogeneous_gnn_c2_com' or self.data_format == 'heterogeneous_gnn_s4_com':
bj, jb, jj, fj, jf = self.robotGraph.get_edge_index_matrices()
self.bj = torch.tensor(bj, dtype=torch.long)
self.jb = torch.tensor(jb, dtype=torch.long)
Expand All @@ -180,7 +180,7 @@ def __init__(self,

# Precompute feature matrix sizes for HGNN
# Calculate the size of the feature matrices
if self.data_format == 'heterogeneous_gnn' or self.data_format == 'heterogeneous_gnn_k4' or self.data_format == 'heterogeneous_gnn_c2' or self.data_format == 'heterogeneous_gnn_k4_com' or self.data_format == 'heterogeneous_gnn_s4_com':
if self.data_format == 'heterogeneous_gnn' or self.data_format == 'heterogeneous_gnn_k4' or self.data_format == 'heterogeneous_gnn_c2' or self.data_format == 'heterogeneous_gnn_k4_com' or self.data_format == 'heterogeneous_gnn_c2_com' or self.data_format == 'heterogeneous_gnn_s4_com':
self.hgnn_number_nodes = self.robotGraph.get_num_of_each_node_type()
self.base_width = len(self.variables_to_use_base) * 3 * self.history_length
self.joint_width = len(self.variables_to_use_joint) * self.history_length
Expand Down Expand Up @@ -455,9 +455,7 @@ def get(self, idx):
return self.get_helper_heterogeneous_gnn(idx)
elif self.data_format == 'heterogeneous_gnn_c2':
return self.get_helper_heterogeneous_gnn_c2(idx)
elif self.data_format == 'heterogeneous_gnn_k4_com':
return self.get_helper_heterogeneous_gnn(idx)
elif self.data_format == 'heterogeneous_gnn_s4_com':
elif self.data_format == 'heterogeneous_gnn_k4_com' or self.data_format == 'heterogeneous_gnn_s4_com' or self.data_format == 'heterogeneous_gnn_c2_com':
return self.get_helper_heterogeneous_gnn(idx)
else:
raise ValueError("Invalid data format.")
Expand Down
Loading
Loading