simplify env
albertbou92 committed May 16, 2024
1 parent 6f78b82 commit 499bd06
Showing 7 changed files with 19 additions and 148 deletions.
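Context for the diffs below: the training scripts previously built `TensorDictPrimer` transforms by hand from an `rnn_spec` attribute attached to each model; after this commit they derive the primers directly from the recurrent modules with `torchrl.modules.utils.get_primers_from_module`. A minimal, self-contained sketch of the new pattern, assuming a recent TorchRL with gymnasium installed; the module sizes, key names, and environment are illustrative, not the repository's actual models:

```python
import torch
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.envs import GymEnv, InitTracker, TransformedEnv
from torchrl.modules import GRUModule, MLP
from torchrl.modules.utils import get_primers_from_module

# Toy recurrent policy; sizes and key names are illustrative only.
embed = TensorDictModule(torch.nn.Linear(3, 32), in_keys=["observation"], out_keys=["embed"])
gru = GRUModule(
    input_size=32,
    hidden_size=64,
    in_keys=["embed", "recurrent_state"],
    out_keys=["features", ("next", "recurrent_state")],
)
head = TensorDictModule(MLP(in_features=64, out_features=2), in_keys=["features"], out_keys=["logits"])
policy = TensorDictSequential(embed, gru, head)

# The environment only needs InitTracker plus whatever primers the policy requires.
env = TransformedEnv(GymEnv("Pendulum-v1"))
env.append_transform(InitTracker())
env.append_transform(get_primers_from_module(policy))  # replaces the manual TensorDictPrimer setup
print(env.reset())  # the reset tensordict now carries a zero-initialised "recurrent_state"
```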
39 changes: 1 addition & 38 deletions acegen/models/gru.py
@@ -1,9 +1,7 @@
from copy import deepcopy
from typing import Optional

import torch
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.data.tensor_specs import CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs import ExplorationType
from torchrl.modules import ActorValueOperator, GRUModule, MLP, ProbabilisticActor

@@ -154,14 +152,7 @@ def create_gru_actor(
recurrent_state,
python_based,
)
spec = CompositeSpec(
**{
recurrent_state: UnboundedContinuousTensorSpec(
shape=torch.Size([gru.gru.num_layers, gru.gru.hidden_size]),
dtype=torch.float32,
)
}
)

actor_inference_model = TensorDictSequential(embedding, gru, head)
actor_training_model = TensorDictSequential(
embedding,
@@ -187,9 +178,6 @@ def create_gru_actor(
default_interaction_type=ExplorationType.RANDOM,
)

actor_training_model.rnn_spec = spec
actor_inference_model.rnn_spec = deepcopy(spec)

return actor_training_model, actor_inference_model


@@ -243,21 +231,10 @@ def create_gru_critic(
python_based,
)

spec = CompositeSpec(
**{
recurrent_state: UnboundedContinuousTensorSpec(
shape=torch.Size([gru.gru.num_layers, gru.gru.hidden_size]),
dtype=torch.float32,
)
}
)

critic_inference_model = TensorDictSequential(embedding, gru, head)
critic_training_model = TensorDictSequential(
embedding, gru.set_recurrent_mode(True), head
)
critic_training_model.rnn_spec = spec
critic_inference_model.rnn_spec = deepcopy(spec)
return critic_training_model, critic_inference_model


@@ -318,15 +295,6 @@ def create_gru_actor_critic(
python_based,
)

spec = CompositeSpec(
**{
recurrent_state: UnboundedContinuousTensorSpec(
shape=torch.Size([gru.gru.num_layers, gru.gru.hidden_size]),
dtype=torch.float32,
)
}
)

actor_head = ProbabilisticActor(
module=actor_head,
in_keys=["logits"],
@@ -365,9 +333,4 @@ def create_gru_actor_critic(
actor_training = actor_critic_training.get_policy_operator()
critic_training = actor_critic_training.get_value_operator()

actor_training.rnn_spec = spec
actor_inference.rnn_spec = deepcopy(spec)
critic_training.rnn_spec = deepcopy(spec)
critic_inference.rnn_spec = deepcopy(spec)

return actor_training, actor_inference, critic_training, critic_inference
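The blocks deleted above duplicated shape information that the `GRUModule` already carries. The change relies on TorchRL's recurrent modules being able to produce their own primer; a sketch of that assumption follows, with illustrative sizes and key names (not the repository's defaults):

```python
from torchrl.modules import GRUModule
from torchrl.modules.utils import get_primers_from_module

gru = GRUModule(
    input_size=128,
    hidden_size=512,
    num_layers=3,
    in_keys=["embed", "recurrent_state"],
    out_keys=["features", ("next", "recurrent_state")],
)

# The module can describe its own recurrent entry: the primer it builds covers
# "recurrent_state" with zeros of shape [num_layers, hidden_size], i.e. the same
# information as the deleted CompositeSpec/UnboundedContinuousTensorSpec block.
print(gru.make_tensordict_primer())

# get_primers_from_module gathers these primers from a (possibly nested) module,
# which is what the training scripts now append to the environment.
print(get_primers_from_module(gru))
```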
51 changes: 1 addition & 50 deletions acegen/models/lstm.py
@@ -1,9 +1,7 @@
from copy import deepcopy
from typing import Optional

import torch
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.data.tensor_specs import CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs import ExplorationType
from torchrl.modules import ActorValueOperator, LSTMModule, MLP, ProbabilisticActor

@@ -158,18 +156,7 @@ def create_lstm_actor(
recurrent_state,
python_based,
)
spec = CompositeSpec(
**{
f"{recurrent_state}_h": UnboundedContinuousTensorSpec(
shape=torch.Size([lstm.lstm.num_layers, lstm.lstm.hidden_size]),
dtype=torch.float32,
),
f"{recurrent_state}_c": UnboundedContinuousTensorSpec(
shape=torch.Size([lstm.lstm.num_layers, lstm.lstm.hidden_size]),
dtype=torch.float32,
),
}
)

actor_inference_model = TensorDictSequential(embedding, lstm, head)
actor_training_model = TensorDictSequential(
embedding,
@@ -195,9 +182,6 @@ def create_lstm_actor(
default_interaction_type=ExplorationType.RANDOM,
)

actor_training_model.rnn_spec = spec
actor_inference_model.rnn_spec = deepcopy(spec)

return actor_training_model, actor_inference_model


@@ -251,25 +235,10 @@ def create_lstm_critic(
python_based,
)

spec = CompositeSpec(
**{
f"{recurrent_state}_h": UnboundedContinuousTensorSpec(
shape=torch.Size([lstm.lstm.num_layers, lstm.lstm.hidden_size]),
dtype=torch.float32,
),
f"{recurrent_state}_c": UnboundedContinuousTensorSpec(
shape=torch.Size([lstm.lstm.num_layers, lstm.lstm.hidden_size]),
dtype=torch.float32,
),
}
)

critic_inference_model = TensorDictSequential(embedding, lstm, head)
critic_training_model = TensorDictSequential(
embedding, lstm.set_recurrent_mode(True), head
)
critic_training_model.rnn_spec = spec
critic_inference_model.rnn_spec = deepcopy(spec)
return critic_training_model, critic_inference_model


@@ -330,19 +299,6 @@ def create_lstm_actor_critic(
python_based,
)

spec = CompositeSpec(
**{
f"{recurrent_state}_h": UnboundedContinuousTensorSpec(
shape=torch.Size([lstm.lstm.num_layers, lstm.lstm.hidden_size]),
dtype=torch.float32,
),
f"{recurrent_state}_c": UnboundedContinuousTensorSpec(
shape=torch.Size([lstm.lstm.num_layers, lstm.lstm.hidden_size]),
dtype=torch.float32,
),
}
)

actor_head = ProbabilisticActor(
module=actor_head,
in_keys=["logits"],
@@ -381,9 +337,4 @@ def create_lstm_actor_critic(
actor_training = actor_critic_training.get_policy_operator()
critic_training = actor_critic_training.get_value_operator()

actor_training.rnn_spec = spec
actor_inference.rnn_spec = deepcopy(spec)
critic_training.rnn_spec = deepcopy(spec)
critic_inference.rnn_spec = deepcopy(spec)

return actor_training, actor_inference, critic_training, critic_inference
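For the LSTM models the deleted spec listed two entries, `f"{recurrent_state}_h"` and `f"{recurrent_state}_c"`. The same assumption applies: an `LSTMModule` can build a single primer covering both hidden and cell state, so dropping the hand-written spec loses nothing. An illustrative sketch (the suffixed key names follow the deleted block, the sizes do not):

```python
from torchrl.modules import LSTMModule
from torchrl.modules.utils import get_primers_from_module

lstm = LSTMModule(
    input_size=128,
    hidden_size=512,
    num_layers=3,
    in_keys=["embed", "recurrent_state_h", "recurrent_state_c"],
    out_keys=["features", ("next", "recurrent_state_h"), ("next", "recurrent_state_c")],
)

# One primer covering both recurrent entries, each of shape [num_layers, hidden_size]:
print(lstm.make_tensordict_primer())
# Equivalent, via the helper the scripts now use:
print(get_primers_from_module(lstm))
```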
20 changes: 4 additions & 16 deletions scripts/a2c/a2c.py
@@ -23,7 +23,8 @@
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

from torchrl.envs import InitTracker, TensorDictPrimer, TransformedEnv
from torchrl.envs import InitTracker, TransformedEnv
from torchrl.modules.utils import get_primers_from_module
from torchrl.objectives import A2CLoss
from torchrl.objectives.value.advantages import GAE
from torchrl.record.loggers import get_logger
@@ -175,19 +176,6 @@ def run_a2c(cfg, task):
# Create RL environment
####################################################################################################################

rhs_primers = []
# if rnn's, create a transform to populate initial tensordict with recurrent states equal to 0.0
if cfg.shared_nets and hasattr(actor_training, "rnn_spec"):
primers = actor_training.rnn_spec.expand(cfg.num_envs)
rhs_primers = [TensorDictPrimer(primers)]
elif hasattr(actor_training, "rnn_spec"):
actor_primers = actor_training.rnn_spec.expand(cfg.num_envs)
critic_primers = critic_training.rnn_spec.expand(cfg.num_envs)
rhs_primers = [
TensorDictPrimer(actor_primers),
TensorDictPrimer(critic_primers),
]

# Define environment kwargs
env_kwargs = {
"start_token": vocabulary.start_token_index,
@@ -203,8 +191,8 @@ def create_env_fn():
env = SMILESEnv(**env_kwargs)
env = TransformedEnv(env)
env.append_transform(InitTracker())
for rhs_primer in rhs_primers:
env.append_transform(rhs_primer)
env.append_transform(get_primers_from_module(actor_training))
env.append_transform(get_primers_from_module(critic_training))
return env

env = create_env_fn()
12 changes: 3 additions & 9 deletions scripts/ahc/ahc.py
@@ -27,7 +27,8 @@
TensorDictMaxValueWriter,
TensorDictReplayBuffer,
)
from torchrl.envs import InitTracker, TensorDictPrimer, TransformedEnv
from torchrl.envs import InitTracker, TransformedEnv
from torchrl.modules.utils import get_primers_from_module
from torchrl.record.loggers import get_logger

try:
@@ -155,12 +156,6 @@ def run_ahc(cfg, task):
# Create RL environment
####################################################################################################################

# For RNNs, create a transform to populate initial tensordict with recurrent states equal to 0.0
rhs_primers = []
if hasattr(actor_training, "rnn_spec"):
primers = actor_training.rnn_spec.expand(cfg.num_envs)
rhs_primers.append(TensorDictPrimer(primers))

env_kwargs = {
"start_token": vocabulary.start_token_index,
"end_token": vocabulary.end_token_index,
@@ -174,8 +169,7 @@ def create_env_fn():
env = SMILESEnv(**env_kwargs)
env = TransformedEnv(env)
env.append_transform(InitTracker())
for rhs_primer in rhs_primers:
env.append_transform(rhs_primer)
env.append_transform(get_primers_from_module(actor_training))
return env

env = create_env_fn()
21 changes: 4 additions & 17 deletions scripts/ppo/ppo.py
@@ -28,12 +28,12 @@
TensorDictMaxValueWriter,
TensorDictReplayBuffer,
)
from torchrl.envs import ExplorationType, InitTracker, TensorDictPrimer, TransformedEnv
from torchrl.envs import InitTracker, TransformedEnv
from torchrl.modules.utils import get_primers_from_module
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value.advantages import GAE
from torchrl.record.loggers import get_logger


try:
import molscore
from molscore import MolScoreBenchmark
@@ -181,19 +181,6 @@ def run_ppo(cfg, task):
# Create RL environment
####################################################################################################################

rhs_primers = []
# if rnn's, create a transform to populate initial tensordict with recurrent states equal to 0.0
if cfg.shared_nets and hasattr(actor_training, "rnn_spec"):
primers = actor_training.rnn_spec.expand(cfg.num_envs)
rhs_primers = [TensorDictPrimer(primers)]
elif hasattr(actor_training, "rnn_spec"):
actor_primers = actor_training.rnn_spec.expand(cfg.num_envs)
critic_primers = critic_training.rnn_spec.expand(cfg.num_envs)
rhs_primers = [
TensorDictPrimer(actor_primers),
TensorDictPrimer(critic_primers),
]

# Define environment kwargs
env_kwargs = {
"start_token": vocabulary.start_token_index,
@@ -209,8 +196,8 @@ def create_env_fn():
env = SMILESEnv(**env_kwargs)
env = TransformedEnv(env)
env.append_transform(InitTracker())
for rhs_primer in rhs_primers:
env.append_transform(rhs_primer)
env.append_transform(get_primers_from_module(actor_training))
env.append_transform(get_primers_from_module(critic_training))
return env

env = create_env_fn()
12 changes: 3 additions & 9 deletions scripts/reinforce/reinforce.py
@@ -27,7 +27,8 @@
TensorDictMaxValueWriter,
TensorDictReplayBuffer,
)
from torchrl.envs import InitTracker, TensorDictPrimer, TransformedEnv
from torchrl.envs import InitTracker, TransformedEnv
from torchrl.modules.utils import get_primers_from_module
from torchrl.record.loggers import get_logger

try:
@@ -153,12 +154,6 @@ def run_reinforce(cfg, task):
# Create RL environment
####################################################################################################################

# For RNNs, create a transform to populate initial tensordict with recurrent states equal to 0.0
rhs_primers = []
if hasattr(actor_training, "rnn_spec"):
primers = actor_training.rnn_spec.expand(cfg.num_envs)
rhs_primers.append(TensorDictPrimer(primers))

env_kwargs = {
"start_token": vocabulary.start_token_index,
"end_token": vocabulary.end_token_index,
@@ -172,8 +167,7 @@ def create_env_fn():
env = SMILESEnv(**env_kwargs)
env = TransformedEnv(env)
env.append_transform(InitTracker())
for rhs_primer in rhs_primers:
env.append_transform(rhs_primer)
env.append_transform(get_primers_from_module(actor_training))
return env

env = create_env_fn()
12 changes: 3 additions & 9 deletions scripts/reinvent/reinvent.py
@@ -27,7 +27,8 @@
TensorDictMaxValueWriter,
TensorDictReplayBuffer,
)
from torchrl.envs import InitTracker, TensorDictPrimer, TransformedEnv
from torchrl.envs import InitTracker, TransformedEnv
from torchrl.modules.utils import get_primers_from_module
from torchrl.record.loggers import get_logger

try:
@@ -155,12 +156,6 @@ def run_reinvent(cfg, task):
# Create RL environment
####################################################################################################################

# For RNNs, create a transform to populate initial tensordict with recurrent states equal to 0.0
rhs_primers = []
if hasattr(actor_training, "rnn_spec"):
primers = actor_training.rnn_spec.expand(cfg.num_envs)
rhs_primers.append(TensorDictPrimer(primers))

env_kwargs = {
"start_token": vocabulary.start_token_index,
"end_token": vocabulary.end_token_index,
@@ -174,8 +169,7 @@ def create_env_fn():
env = SMILESEnv(**env_kwargs)
env = TransformedEnv(env)
env.append_transform(InitTracker())
for rhs_primer in rhs_primers:
env.append_transform(rhs_primer)
env.append_transform(get_primers_from_module(actor_training))
return env

env = create_env_fn()
