Skip to content

Commit

Permalink
MLPEncoder on condition
Browse files Browse the repository at this point in the history
  • Loading branch information
Mossforest committed Aug 2, 2024
1 parent c096db8 commit 2954305
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 57 deletions.
13 changes: 7 additions & 6 deletions grl/cxy_scripts/swiss_roll_nongen_vary.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@
from grl.utils import set_seed
from grl.utils.log import log

exp_name = "swiss-roll-nongen-varying-world-model-noise"
exp_name = "swiss-roll-nongen-varying-world-model-mlpencoder"

x_size = 2
condition_size=3
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
t_embedding_dim = 32
condition_dim=256
data_num=100000
config = EasyDict(
dict(
Expand All @@ -54,11 +55,11 @@
),
),
condition_encoder = dict(
type="GaussianFourierProjectionEncoder",
type="MLPEncoder",
args=dict(
embed_dim=t_embedding_dim, # after flatten, 32*3=96
x_shape=(condition_size,),
scale=30.0,
hidden_sizes=[condition_size] + [condition_dim] * 2,
output_size=condition_dim,
activation='relu',
),
),
backbone=dict(
Expand All @@ -67,7 +68,7 @@
hidden_sizes=[512, 256, 128],
output_dim=x_size,
t_dim=t_embedding_dim,
condition_dim=t_embedding_dim*condition_size,
condition_dim=condition_dim,
condition_hidden_dim=64,
t_condition_hidden_dim=128,
),
Expand Down
13 changes: 7 additions & 6 deletions grl/cxy_scripts/swiss_roll_nongen_vary_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@
from grl.utils import set_seed
from grl.utils.log import log

exp_name = "swiss-roll-nongen-varying-world-model-noise"
exp_name = "swiss-roll-nongen-varying-world-model-mlpencoder"

x_size = 2
condition_size=3
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
t_embedding_dim = 32
condition_dim=256
data_num=100000
config = EasyDict(
dict(
Expand All @@ -54,11 +55,11 @@
),
),
condition_encoder = dict(
type="GaussianFourierProjectionEncoder",
type="MLPEncoder",
args=dict(
embed_dim=t_embedding_dim, # after flatten, 32*3=96
x_shape=(condition_size,),
scale=30.0,
hidden_sizes=[condition_size] + [condition_dim] * 2,
output_size=condition_dim,
activation='relu',
),
),
backbone=dict(
Expand All @@ -67,7 +68,7 @@
hidden_sizes=[512, 256, 128],
output_dim=x_size,
t_dim=t_embedding_dim,
condition_dim=t_embedding_dim*condition_size,
condition_dim=condition_dim,
condition_hidden_dim=64,
t_condition_hidden_dim=128,
),
Expand Down
63 changes: 31 additions & 32 deletions grl/cxy_scripts/swiss_roll_world_model_2_vary.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,13 @@
from grl.utils import set_seed
from grl.utils.log import log

# exp_name = "swiss-roll-dynamic-icfm-varying-world-model"
# exp_name = "swiss-roll-dynamic-icfm-varying-world-model-noise"
exp_name = "swiss-roll-dynamic-icfm-varying-world-model-test"
exp_name = "swiss-roll-dynamic-icfm-varying-world-model-mlpencoder"

x_size = 2
condition_size=3
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
t_embedding_dim = 32
condition_dim=256
t_encoder = dict(
type="GaussianFourierProjectionTimeEncoder",
args=dict(
Expand All @@ -43,11 +42,11 @@
),
)
condition_encoder = dict(
type="GaussianFourierProjectionEncoder",
type="MLPEncoder",
args=dict(
embed_dim=t_embedding_dim, # after flatten, 96
x_shape=(condition_size,),
scale=30.0,
hidden_sizes=[condition_size] + [condition_dim] * 2,
output_size=condition_dim,
activation='relu',
),
)
data_num=100000
Expand Down Expand Up @@ -88,7 +87,7 @@
hidden_sizes=[512, 256, 128],
output_dim=x_size,
t_dim=t_embedding_dim,
condition_dim=t_embedding_dim*condition_size,
condition_dim=condition_dim,
condition_hidden_dim=64,
t_condition_hidden_dim=128,
),
Expand Down Expand Up @@ -294,31 +293,31 @@ def exit_handler(signal, frame):
if iteration <= last_iteration:
continue

if iteration > 0 and iteration % config.parameter.eval_freq == 0:
# if True:
flow_model.eval()
t_span = torch.linspace(0.0, 1.0, 1000)
customized_eval_dataset = DynamicSwissRollDataset(config.dataset, train=True)
x0_eval, x1_eval, action_eval, background_eval = customized_eval_dataset.data['state'], customized_eval_dataset.data['next_state'], customized_eval_dataset.data['action'], customized_eval_dataset.data['background']
x0_eval = torch.tensor(x0_eval).to(config.device)
x1_eval = torch.tensor(x1_eval).to(config.device)
condition_eval = torch.cat((torch.tensor(action_eval).unsqueeze(1), torch.tensor(background_eval)), dim=1).float().to(config.device)

# randomly choose 500 samples from x0_eval, x1_eval, action_eval
x0_eval = x0_eval[:500]
x1_eval = x1_eval[:500]
condition_eval = condition_eval[:500]
# if iteration > 0 and iteration % config.parameter.eval_freq == 0:
# # if True:
# flow_model.eval()
# t_span = torch.linspace(0.0, 1.0, 1000)
# customized_eval_dataset = DynamicSwissRollDataset(config.dataset, train=True)
# x0_eval, x1_eval, action_eval, background_eval = customized_eval_dataset.data['state'], customized_eval_dataset.data['next_state'], customized_eval_dataset.data['action'], customized_eval_dataset.data['background']
# x0_eval = torch.tensor(x0_eval).to(config.device)
# x1_eval = torch.tensor(x1_eval).to(config.device)
# condition_eval = torch.cat((torch.tensor(action_eval).unsqueeze(1), torch.tensor(background_eval)), dim=1).float().to(config.device)

# # randomly choose 500 samples from x0_eval, x1_eval, action_eval
# x0_eval = x0_eval[:500]
# x1_eval = x1_eval[:500]
# condition_eval = condition_eval[:500]

# action_eval = -torch.ones_like(action_eval).to(config.device)*0.05
x_t = (
flow_model.sample_forward_process(t_span=t_span, x_0=x0_eval, condition=condition_eval)
.cpu()
.detach()
)
x_t = [
x.squeeze(0) for x in torch.split(x_t, split_size_or_sections=1, dim=0)
]
render_video(x_t, config.parameter.video_save_path, iteration, fps=100, dpi=100)
# # action_eval = -torch.ones_like(action_eval).to(config.device)*0.05
# x_t = (
# flow_model.sample_forward_process(t_span=t_span, x_0=x0_eval, condition=condition_eval)
# .cpu()
# .detach()
# )
# x_t = [
# x.squeeze(0) for x in torch.split(x_t, split_size_or_sections=1, dim=0)
# ]
# render_video(x_t, config.parameter.video_save_path, iteration, fps=100, dpi=100)

batch_data = next(data_generator)

Expand Down
23 changes: 12 additions & 11 deletions grl/cxy_scripts/swiss_roll_world_model_2_vary_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@
from grl.utils import set_seed
from grl.utils.log import log

exp_name = "swiss-roll-dynamic-icfm-varying-world-model"
exp_name = "swiss-roll-dynamic-icfm-varying-world-model-mlpencoder"

x_size = 2
condition_size=3
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
t_embedding_dim = 32
condition_dim=256
t_encoder = dict(
type="GaussianFourierProjectionTimeEncoder",
args=dict(
Expand All @@ -38,11 +39,11 @@
),
)
condition_encoder = dict(
type="GaussianFourierProjectionEncoder",
type="MLPEncoder",
args=dict(
embed_dim=t_embedding_dim, # after flatten, 96
x_shape=(condition_size,),
scale=30.0,
hidden_sizes=[condition_size] + [condition_dim] * 2,
output_size=condition_dim,
activation='relu',
),
)
data_num=100000
Expand All @@ -54,10 +55,10 @@
n_samples=data_num,
test_n_samples=data_num,
pair_samples=10000,
delta_t_barrie=0.1,
noise=0.001,
# delta_t_barrie=0.2,
# noise=0.3,
# delta_t_barrie=0.1,
# noise=0.001,
delta_t_barrie=0.2,
noise=0.3,
),
flow_model=dict(
device=device,
Expand All @@ -83,7 +84,7 @@
hidden_sizes=[512, 256, 128],
output_dim=x_size,
t_dim=t_embedding_dim,
condition_dim=t_embedding_dim*condition_size,
condition_dim=condition_dim,
condition_hidden_dim=64,
t_condition_hidden_dim=128,
),
Expand Down Expand Up @@ -294,4 +295,4 @@ def exit_handler(signal, frame):
x_t = [
x.squeeze(0) for x in torch.split(x_t, split_size_or_sections=1, dim=0)
]
render_eval_video(origin_line[0], x_t, config.parameter.video_save_path, f"eval_video_param_{param}", fps=100, dpi=100)
render_eval_video(origin_line, x_t, config.parameter.video_save_path, f"eval_video_param_{param}", fps=100, dpi=100)
4 changes: 2 additions & 2 deletions grl/datasets/swiss_roll_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ def make_swiss_roll(n_samples=100, noise=0.0, a = 1.5, b = 1):
origin_line = origin_line * 10 - 5

# # pair the sampled (x, t) 2by2
idx_1 = torch.randint(100, (100,))
idx_2 = torch.randint(n_samples, (100,))
idx_1 = torch.randint(100, (1000,))
idx_2 = torch.randint(n_samples, (1000,))
unfil_x_1 = x[idx_1]
unfil_t_1 = t[idx_1]
unfil_x_2 = x[idx_2]
Expand Down
66 changes: 66 additions & 0 deletions grl/neural_network/encoders.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import math
from typing import Callable, List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
from grl.neural_network.activation import get_activation


def get_encoder(type: str):
Expand Down Expand Up @@ -229,9 +231,73 @@ def forward(self, x):
return emb


class MLPEncoder(nn.Module):
    """
    Overview:
        Multi-layer perceptron encoder: maps an input tensor through a stack of
        linear layers (each followed by an activation, and optionally dropout
        and layer normalization) to a fixed-size embedding, optionally scaled
        by a constant factor.
    """
    # ./grl/neural_network/__init__.py#L365

    def __init__(
        self,
        hidden_sizes: List[int],
        output_size: int,
        activation: Union[str, List[str]],
        dropout: Optional[float] = None,
        layernorm: bool = False,
        final_activation: Optional[str] = None,
        scale: Optional[float] = None,
        shrink: Optional[float] = None,
    ):
        """
        Overview:
            Build the MLP.
        Arguments:
            - hidden_sizes (:obj:`List[int]`): Sizes of the input and hidden
              layers; ``hidden_sizes[0]`` is the input dimension.
            - output_size (:obj:`int`): Dimension of the produced embedding.
            - activation (:obj:`Union[str, List[str]]`): Activation name applied
              after each hidden linear layer, or one name per hidden layer.
            - dropout (:obj:`Optional[float]`): If set and > 0, dropout
              probability applied after each hidden activation.
            - layernorm (:obj:`bool`): If True, apply LayerNorm after each
              hidden layer.
            - final_activation (:obj:`Optional[str]`): Optional activation name
              applied after the final linear layer.
            - scale (:obj:`Optional[float]`): If set, a fixed (non-trainable)
              multiplier on the output.
            - shrink (:obj:`Optional[float]`): If set, re-initialize the final
              linear layer's weights from N(0, shrink) so the initial output of
              that layer is close to zero.
        """
        super().__init__()

        self.model = nn.Sequential()

        for i in range(len(hidden_sizes) - 1):
            self.model.add_module(
                "linear" + str(i), nn.Linear(hidden_sizes[i], hidden_sizes[i + 1])
            )

            if isinstance(activation, list):
                self.model.add_module(
                    "activation" + str(i), get_activation(activation[i])
                )
            else:
                self.model.add_module("activation" + str(i), get_activation(activation))
            if dropout is not None and dropout > 0:
                # The submodule name must be unique per layer: registering a
                # module under an existing name replaces the earlier one while
                # keeping its position, so a fixed "dropout" name would leave
                # only a single Dropout in the whole stack.
                self.model.add_module("dropout" + str(i), nn.Dropout(dropout))
            if layernorm:
                # Same uniqueness issue as dropout — and worse: the surviving
                # LayerNorm would be sized for the LAST hidden layer but sit at
                # the FIRST layer's position, causing a shape mismatch whenever
                # hidden sizes differ.
                self.model.add_module(
                    "layernorm" + str(i), nn.LayerNorm(hidden_sizes[i + 1])
                )

        self.model.add_module(
            "linear" + str(len(hidden_sizes) - 1),
            nn.Linear(hidden_sizes[-1], output_size),
        )

        if final_activation is not None:
            self.model.add_module("final_activation", get_activation(final_activation))

        if scale is not None:
            # Stored as a frozen Parameter so it travels with the module
            # (device moves, state_dict) without being trained.
            self.scale = nn.Parameter(torch.tensor(scale), requires_grad=False)
        else:
            self.scale = 1.0

        # Shrink the weights of the final linear layer towards zero by drawing
        # them from N(0, shrink); index -2 skips the final activation if present.
        if shrink is not None:
            if final_activation is not None:
                self.model[-2].weight.data.normal_(0, shrink)
            else:
                self.model[-1].weight.data.normal_(0, shrink)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Return the output of the multi-layer perceptron, multiplied by
            ``self.scale``.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor.
        """
        return self.scale * self.model(x)

# Registry mapping lower-cased encoder type names (the "type" strings used in
# config dicts) to their classes — presumably looked up by get_encoder(); the
# lookup body is not visible here, confirm against get_encoder's implementation.
ENCODERS = {
    "GaussianFourierProjectionTimeEncoder".lower(): GaussianFourierProjectionTimeEncoder,
    "GaussianFourierProjectionEncoder".lower(): GaussianFourierProjectionEncoder,
    "ExponentialFourierProjectionTimeEncoder".lower(): ExponentialFourierProjectionTimeEncoder,
    "SinusoidalPosEmb".lower(): SinusoidalPosEmb,
    "MLPEncoder".lower(): MLPEncoder,
}

0 comments on commit 2954305

Please sign in to comment.