-
Notifications
You must be signed in to change notification settings - Fork 173
/
a3c_model.py
29 lines (21 loc) · 1.39 KB
/
a3c_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import tensorflow as tf
from keras import backend as K
from keras.layers import Convolution2D, Flatten, Dense, Input
from keras.models import Model
def build_policy_and_value_networks(num_actions, agent_history_length, resized_width, resized_height):
with tf.device("/cpu:0"):
state = tf.placeholder("float", [None, agent_history_length, resized_width, resized_height])
inputs = Input(shape=(agent_history_length, resized_width, resized_height,))
shared = Convolution2D(name="conv1", nb_filter=16, nb_row=8, nb_col=8, subsample=(4,4), activation='relu', border_mode='same')(inputs)
shared = Convolution2D(name="conv2", nb_filter=32, nb_row=4, nb_col=4, subsample=(2,2), activation='relu', border_mode='same')(shared)
shared = Flatten()(shared)
shared = Dense(name="h1", output_dim=256, activation='relu')(shared)
action_probs = Dense(name="p", output_dim=num_actions, activation='softmax')(shared)
state_value = Dense(name="v", output_dim=1, activation='linear')(shared)
policy_network = Model(input=inputs, output=action_probs)
value_network = Model(input=inputs, output=state_value)
p_params = policy_network.trainable_weights
v_params = value_network.trainable_weights
p_out = policy_network(state)
v_out = value_network(state)
return state, p_out, v_out, p_params, v_params