diff --git a/grl/cxy_scripts/swiss_roll_nongen_vary.py b/grl/cxy_scripts/swiss_roll_nongen_vary.py index 25f7a45c..31913001 100644 --- a/grl/cxy_scripts/swiss_roll_nongen_vary.py +++ b/grl/cxy_scripts/swiss_roll_nongen_vary.py @@ -25,12 +25,13 @@ from grl.utils import set_seed from grl.utils.log import log -exp_name = "swiss-roll-nongen-varying-world-model-noise" +exp_name = "swiss-roll-nongen-varying-world-model-mlpencoder" x_size = 2 condition_size=3 device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") t_embedding_dim = 32 +condition_dim=256 data_num=100000 config = EasyDict( dict( @@ -54,11 +55,11 @@ ), ), condition_encoder = dict( - type="GaussianFourierProjectionEncoder", + type="MLPEncoder", args=dict( - embed_dim=t_embedding_dim, # after flatten, 32*3=96 - x_shape=(condition_size,), - scale=30.0, + hidden_sizes=[condition_size] + [condition_dim] * 2, + output_size=condition_dim, + activation='relu', ), ), backbone=dict( @@ -67,7 +68,7 @@ hidden_sizes=[512, 256, 128], output_dim=x_size, t_dim=t_embedding_dim, - condition_dim=t_embedding_dim*condition_size, + condition_dim=condition_dim, condition_hidden_dim=64, t_condition_hidden_dim=128, ), diff --git a/grl/cxy_scripts/swiss_roll_nongen_vary_eval.py b/grl/cxy_scripts/swiss_roll_nongen_vary_eval.py index 22bfd263..ddafa5ee 100644 --- a/grl/cxy_scripts/swiss_roll_nongen_vary_eval.py +++ b/grl/cxy_scripts/swiss_roll_nongen_vary_eval.py @@ -25,12 +25,13 @@ from grl.utils import set_seed from grl.utils.log import log -exp_name = "swiss-roll-nongen-varying-world-model-noise" +exp_name = "swiss-roll-nongen-varying-world-model-mlpencoder" x_size = 2 condition_size=3 device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") t_embedding_dim = 32 +condition_dim=256 data_num=100000 config = EasyDict( dict( @@ -54,11 +55,11 @@ ), ), condition_encoder = dict( - type="GaussianFourierProjectionEncoder", + type="MLPEncoder", args=dict( - embed_dim=t_embedding_dim, 
# after flatten, 32*3=96 - x_shape=(condition_size,), - scale=30.0, + hidden_sizes=[condition_size] + [condition_dim] * 2, + output_size=condition_dim, + activation='relu', ), ), backbone=dict( @@ -67,7 +68,7 @@ hidden_sizes=[512, 256, 128], output_dim=x_size, t_dim=t_embedding_dim, - condition_dim=t_embedding_dim*condition_size, + condition_dim=condition_dim, condition_hidden_dim=64, t_condition_hidden_dim=128, ), diff --git a/grl/cxy_scripts/swiss_roll_world_model_2_vary.py b/grl/cxy_scripts/swiss_roll_world_model_2_vary.py index 33c6ebf8..22cf2ea0 100644 --- a/grl/cxy_scripts/swiss_roll_world_model_2_vary.py +++ b/grl/cxy_scripts/swiss_roll_world_model_2_vary.py @@ -27,14 +27,13 @@ from grl.utils import set_seed from grl.utils.log import log -# exp_name = "swiss-roll-dynamic-icfm-varying-world-model" -# exp_name = "swiss-roll-dynamic-icfm-varying-world-model-noise" -exp_name = "swiss-roll-dynamic-icfm-varying-world-model-test" +exp_name = "swiss-roll-dynamic-icfm-varying-world-model-mlpencoder" x_size = 2 condition_size=3 device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") t_embedding_dim = 32 +condition_dim=256 t_encoder = dict( type="GaussianFourierProjectionTimeEncoder", args=dict( @@ -43,11 +42,11 @@ ), ) condition_encoder = dict( - type="GaussianFourierProjectionEncoder", + type="MLPEncoder", args=dict( - embed_dim=t_embedding_dim, # after flatten, 96 - x_shape=(condition_size,), - scale=30.0, + hidden_sizes=[condition_size] + [condition_dim] * 2, + output_size=condition_dim, + activation='relu', ), ) data_num=100000 @@ -88,7 +87,7 @@ hidden_sizes=[512, 256, 128], output_dim=x_size, t_dim=t_embedding_dim, - condition_dim=t_embedding_dim*condition_size, + condition_dim=condition_dim, condition_hidden_dim=64, t_condition_hidden_dim=128, ), @@ -294,31 +293,31 @@ def exit_handler(signal, frame): if iteration <= last_iteration: continue - if iteration > 0 and iteration % config.parameter.eval_freq == 0: - # if True: - 
flow_model.eval() - t_span = torch.linspace(0.0, 1.0, 1000) - customized_eval_dataset = DynamicSwissRollDataset(config.dataset, train=True) - x0_eval, x1_eval, action_eval, background_eval = customized_eval_dataset.data['state'], customized_eval_dataset.data['next_state'], customized_eval_dataset.data['action'], customized_eval_dataset.data['background'] - x0_eval = torch.tensor(x0_eval).to(config.device) - x1_eval = torch.tensor(x1_eval).to(config.device) - condition_eval = torch.cat((torch.tensor(action_eval).unsqueeze(1), torch.tensor(background_eval)), dim=1).float().to(config.device) - - # ramdom choose 500 samples from x0_eval, x1_eval, action_eval - x0_eval = x0_eval[:500] - x1_eval = x1_eval[:500] - condition_eval = condition_eval[:500] + # if iteration > 0 and iteration % config.parameter.eval_freq == 0: + # # if True: + # flow_model.eval() + # t_span = torch.linspace(0.0, 1.0, 1000) + # customized_eval_dataset = DynamicSwissRollDataset(config.dataset, train=True) + # x0_eval, x1_eval, action_eval, background_eval = customized_eval_dataset.data['state'], customized_eval_dataset.data['next_state'], customized_eval_dataset.data['action'], customized_eval_dataset.data['background'] + # x0_eval = torch.tensor(x0_eval).to(config.device) + # x1_eval = torch.tensor(x1_eval).to(config.device) + # condition_eval = torch.cat((torch.tensor(action_eval).unsqueeze(1), torch.tensor(background_eval)), dim=1).float().to(config.device) + + # # randomly choose 500 samples from x0_eval, x1_eval, action_eval + # x0_eval = x0_eval[:500] + # x1_eval = x1_eval[:500] + # condition_eval = condition_eval[:500] - # action_eval = -torch.ones_like(action_eval).to(config.device)*0.05 - x_t = ( - flow_model.sample_forward_process(t_span=t_span, x_0=x0_eval, condition=condition_eval) - .cpu() - .detach() - ) - x_t = [ - x.squeeze(0) for x in torch.split(x_t, split_size_or_sections=1, dim=0) - ] - render_video(x_t, config.parameter.video_save_path, iteration, fps=100, dpi=100) + # #
action_eval = -torch.ones_like(action_eval).to(config.device)*0.05 + # x_t = ( + # flow_model.sample_forward_process(t_span=t_span, x_0=x0_eval, condition=condition_eval) + # .cpu() + # .detach() + # ) + # x_t = [ + # x.squeeze(0) for x in torch.split(x_t, split_size_or_sections=1, dim=0) + # ] + # render_video(x_t, config.parameter.video_save_path, iteration, fps=100, dpi=100) batch_data = next(data_generator) diff --git a/grl/cxy_scripts/swiss_roll_world_model_2_vary_eval.py b/grl/cxy_scripts/swiss_roll_world_model_2_vary_eval.py index 7f4289dd..ba933c93 100644 --- a/grl/cxy_scripts/swiss_roll_world_model_2_vary_eval.py +++ b/grl/cxy_scripts/swiss_roll_world_model_2_vary_eval.py @@ -24,12 +24,13 @@ from grl.utils import set_seed from grl.utils.log import log -exp_name = "swiss-roll-dynamic-icfm-varying-world-model" +exp_name = "swiss-roll-dynamic-icfm-varying-world-model-mlpencoder" x_size = 2 condition_size=3 device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") t_embedding_dim = 32 +condition_dim=256 t_encoder = dict( type="GaussianFourierProjectionTimeEncoder", args=dict( @@ -38,11 +39,11 @@ ), ) condition_encoder = dict( - type="GaussianFourierProjectionEncoder", + type="MLPEncoder", args=dict( - embed_dim=t_embedding_dim, # after flatten, 96 - x_shape=(condition_size,), - scale=30.0, + hidden_sizes=[condition_size] + [condition_dim] * 2, + output_size=condition_dim, + activation='relu', ), ) data_num=100000 @@ -54,10 +55,10 @@ n_samples=data_num, test_n_samples=data_num, pair_samples=10000, - delta_t_barrie=0.1, - noise=0.001, - # delta_t_barrie=0.2, - # noise=0.3, + # delta_t_barrie=0.1, + # noise=0.001, + delta_t_barrie=0.2, + noise=0.3, ), flow_model=dict( device=device, @@ -83,7 +84,7 @@ hidden_sizes=[512, 256, 128], output_dim=x_size, t_dim=t_embedding_dim, - condition_dim=t_embedding_dim*condition_size, + condition_dim=condition_dim, condition_hidden_dim=64, t_condition_hidden_dim=128, ), @@ -294,4 +295,4 @@ def 
exit_handler(signal, frame): x_t = [ x.squeeze(0) for x in torch.split(x_t, split_size_or_sections=1, dim=0) ] - render_eval_video(origin_line[0], x_t, config.parameter.video_save_path, f"eval_video_param_{param}", fps=100, dpi=100) + render_eval_video(origin_line, x_t, config.parameter.video_save_path, f"eval_video_param_{param}", fps=100, dpi=100) diff --git a/grl/datasets/swiss_roll_dataset.py b/grl/datasets/swiss_roll_dataset.py index 0c94866e..8ca3f76b 100644 --- a/grl/datasets/swiss_roll_dataset.py +++ b/grl/datasets/swiss_roll_dataset.py @@ -227,8 +227,8 @@ def make_swiss_roll(n_samples=100, noise=0.0, a = 1.5, b = 1): origin_line = origin_line * 10 - 5 # # pair the sampled (x, t) 2by2 - idx_1 = torch.randint(100, (100,)) - idx_2 = torch.randint(n_samples, (100,)) + idx_1 = torch.randint(100, (1000,)) + idx_2 = torch.randint(n_samples, (1000,)) unfil_x_1 = x[idx_1] unfil_t_1 = t[idx_1] unfil_x_2 = x[idx_2] diff --git a/grl/neural_network/encoders.py b/grl/neural_network/encoders.py index 7cc1ddf0..c21a3316 100644 --- a/grl/neural_network/encoders.py +++ b/grl/neural_network/encoders.py @@ -1,8 +1,10 @@ import math +from typing import Callable, List, Optional, Union import numpy as np import torch import torch.nn as nn +from grl.neural_network.activation import get_activation def get_encoder(type: str): @@ -229,9 +231,73 @@ def forward(self, x): return emb +class MLPEncoder(nn.Module): + # ./grl/neural_network/__init__.py#L365 + + def __init__( + self, + hidden_sizes: List[int], + output_size: int, + activation: Union[str, List[str]], + dropout: float = None, + layernorm: bool = False, + final_activation: str = None, + scale: float = None, + shrink: float = None, + ): + super().__init__() + + self.model = nn.Sequential() + + for i in range(len(hidden_sizes) - 1): + self.model.add_module( + "linear" + str(i), nn.Linear(hidden_sizes[i], hidden_sizes[i + 1]) + ) + + if isinstance(activation, list): + self.model.add_module( + "activation" + str(i), 
get_activation(activation[i]) + ) + else: + self.model.add_module("activation" + str(i), get_activation(activation)) + if dropout is not None and dropout > 0: + self.model.add_module("dropout" + str(i), nn.Dropout(dropout)) + if layernorm: + self.model.add_module("layernorm" + str(i), nn.LayerNorm(hidden_sizes[i + 1])) + + self.model.add_module( + "linear" + str(len(hidden_sizes) - 1), + nn.Linear(hidden_sizes[-1], output_size), + ) + + if final_activation is not None: + self.model.add_module("final_activation", get_activation(final_activation)) + + if scale is not None: + self.scale = nn.Parameter(torch.tensor(scale), requires_grad=False) + else: + self.scale = 1.0 + + # Optionally shrink the final linear layer's initial weights to std=shrink (e.g. 0.01). + if shrink is not None: + if final_activation is not None: + self.model[-2].weight.data.normal_(0, shrink) + else: + self.model[-1].weight.data.normal_(0, shrink) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Return the output of the multi-layer perceptron. + Arguments: + - x (:obj:`torch.Tensor`): The input tensor. + """ + return self.scale * self.model(x) + ENCODERS = { "GaussianFourierProjectionTimeEncoder".lower(): GaussianFourierProjectionTimeEncoder, "GaussianFourierProjectionEncoder".lower(): GaussianFourierProjectionEncoder, "ExponentialFourierProjectionTimeEncoder".lower(): ExponentialFourierProjectionTimeEncoder, "SinusoidalPosEmb".lower(): SinusoidalPosEmb, + "MLPEncoder".lower(): MLPEncoder, }