-
Notifications
You must be signed in to change notification settings - Fork 377
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Imitation Learning #973
base: i210_dev
Are you sure you want to change the base?
Imitation Learning #973
Conversation
…nto akash-dagger
…nto akash-dagger
make_policy_optimizer=choose_policy_optimizer, | ||
validate_config=validate_config, | ||
after_optimizer_step=update_kl, | ||
after_train_result=warn_about_bad_reward_scales) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
missing blank line at end of file
self.action_network = action_network # neural network which specifies action to take | ||
self.multiagent = multiagent # whether env is multiagent or singleagent | ||
self.veh_id = veh_id # vehicle id that controller is controlling |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
comments not needed, should be in doc-string
try: | ||
rl_ids = env.get_sorted_rl_ids() | ||
except: | ||
print("Error caught: no get_sorted_rl_ids function, using get_rl_ids instead") | ||
rl_ids = env.k.vehicle.get_rl_ids() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you can use has_attr to do this instead of a try except
""" | ||
# observation is a dictionary for multiagent envs, list for singleagent envs | ||
if self.multiagent: | ||
observation = env.get_state()[self.veh_id] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what if you're on a no control edge and your ID isn't in the state?
from flow.controllers.imitation_learning.utils_tensorflow import * | ||
from flow.controllers.imitation_learning.keras_utils import * |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
import * not recommended; loading files you don't need?
# load network if specified, or construct network | ||
if load_model: | ||
self.load_network(load_path) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: no space between if ad else
if len(observation.shape)<=1: | ||
observation = observation[None] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice!
summary = tf.Summary(value=[tf.Summary.Value(tag="Variance norm", simple_value=variance_norm), ]) | ||
self.writer.add_summary(summary, global_step=self.action_steps) | ||
|
||
cov_matrix = np.diag(var[0]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why var[0]?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Magic number means you should write what the expected dimension of this object is. This makes it so at least the magic number is understandable
# build layers for policy | ||
for i in range(num_layers): | ||
size = self.model.layers[i + 1].output.shape[1].value | ||
activation = tf.keras.activations.serialize(self.model.layers[i + 1].activation) | ||
curr_layer = tf.keras.layers.Dense(size, activation=activation, name="policy_hidden_layer_{}".format(i + 1))(curr_layer) | ||
output_layer_policy = tf.keras.layers.Dense(self.model.output.shape[1].value, activation=None, name="policy_output_layer")(curr_layer) | ||
|
||
# build layers for value function | ||
curr_layer = input | ||
for i in range(num_layers): | ||
size = self.fcnet_hiddens[i] | ||
curr_layer = tf.keras.layers.Dense(size, activation="tanh", name="vf_hidden_layer_{}".format(i+1))(curr_layer) | ||
output_layer_vf = tf.keras.layers.Dense(1, activation=None, name="vf_output_layer")(curr_layer) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
watch out; you're implicitly assuming that vf_share_layers is never true. This should be warned about somewhere
""" | ||
|
||
env_name = config['env'] | ||
# agent_cls = get_agent_class(config['env_config']['run']) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit, remove
def compare_weights(ppo_model, imitation_path): | ||
imitation_model = tf.keras.models.load_model(imitation_path, custom_objects={'nll_loss': negative_log_likelihood_loss(0.5)}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
great call to put this in
(outputs, state) | ||
Tuple, first element is policy output, second element state | ||
""" | ||
# print(self.base_model.get_weights()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: remove
def forward(self, input_dict, state, seq_lens): | ||
""" | ||
Overrides parent class's method. Used to pass a input through model and get policy/vf output. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is it necessary to override this function?
""" Replay buffer class to store state, action, expert_action, reward, next_state, terminal tuples""" | ||
|
||
def __init__(self, max_size=100000): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
doc_string
action_dim = (1,)[0] | ||
|
||
sess = create_tf_session() | ||
action_network = ImitatingNetwork(sess, action_dim, obs_dim, None, None, None, None, load_existing=True, load_path='/Users/akashvelu/Documents/models2/') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not sure what this file is for; seems to be the equivalent of a unit test? If so, maybe make it a unit test?
import os | ||
import time | ||
import numpy as np | ||
from flow.controllers.imitation_learning.trainer import Trainer |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
header to make clear what this file is for
from flow.controllers.imitation_learning.run import * | ||
from examples.train import * |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: this file should probably not be in the controllers folder
import time | ||
from collections import OrderedDict |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I think a lot of these files should not be in this folder.
import numpy as np | ||
import tensorflow as tf |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
seems like this file could be used to simplify some of your other files?
flow/visualize/visualizer_rllib.py
Outdated
# TODO(akashvelu): remove this | ||
# print("NEW CONFIGGG: ", config['env_config']['run']) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: remove
flow/visualize/visualizer_rllib.py
Outdated
agent.import_model('/Users/akashvelu/Desktop/combined_test3/ppo_model.h5', 'av') | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: remove
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Really good; main things are minor nits and that we need to figure out the right folder to place this / integrate the imitation into train.py
Pull request information
Description
? (general description)