Commit
Save params (#107)
* new feature: save params

* add unittest for save()/restore()

* add an example demonstrating the usage

* rename the variable

* yapf

* fix comment
TomorrowIsAnOtherDay authored Aug 1, 2019
1 parent 2f11d0c commit 7dafee7
Showing 3 changed files with 129 additions and 23 deletions.
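
At a glance, this commit adds ``save()`` and ``restore()`` methods to ``parl.Agent`` and renames ``train_program`` to ``learn_program`` in the QuickStart example. A minimal sketch of the intended usage (``AtariAgent`` is a hypothetical ``parl.Agent`` subclass, borrowed from the docstring examples in this commit):

    # AtariAgent is assumed to be a parl.Agent subclass with a learn_program.
    agent = AtariAgent()
    agent.save('./model.ckpt')     # writes all parameters of agent.learn_program to one file
    agent.restore('./model.ckpt')  # loads them back into the same network

    # Parameters can also be restored into a fresh agent with the same structure.
    another_agent = AtariAgent()
    another_agent.restore('./model.ckpt')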
6 changes: 3 additions & 3 deletions examples/QuickStart/cartpole_agent.py
@@ -26,14 +26,14 @@ def __init__(self, algorithm, obs_dim, act_dim):

def build_program(self):
    self.pred_program = fluid.Program()
-   self.train_program = fluid.Program()
+   self.learn_program = fluid.Program()

    with fluid.program_guard(self.pred_program):
        obs = layers.data(
            name='obs', shape=[self.obs_dim], dtype='float32')
        self.act_prob = self.alg.predict(obs)

-   with fluid.program_guard(self.train_program):
+   with fluid.program_guard(self.learn_program):
        obs = layers.data(
            name='obs', shape=[self.obs_dim], dtype='float32')
        act = layers.data(name='act', shape=[1], dtype='int64')
@@ -68,5 +68,5 @@ def learn(self, obs, act, reward):
        'reward': reward.astype('float32')
    }
    cost = self.fluid_executor.run(
-       self.train_program, feed=feed, fetch_list=[self.cost])[0]
+       self.learn_program, feed=feed, fetch_list=[self.cost])[0]
    return cost
76 changes: 69 additions & 7 deletions parl/core/fluid/agent.py
@@ -30,7 +30,7 @@ class Agent(AgentBase):
| `alias`: ``parl.Agent``
| `alias`: ``parl.core.fluid.agent.Agent``
| Agent is one of the three basic classes of PARL.
| It is responsible for interacting with the environment and collecting data for training the policy.
| To implement a customized ``Agent``, users can:
@@ -57,10 +57,12 @@ def __init__(self, algorithm, act_dim):
- ``sample``: return a noisy action to perform exploration according to the policy.
- ``predict``: return an action given current observation.
- ``learn``: update the parameters of self.alg using the ``learn_program`` defined in ``build_program()``.
+ - ``save``: save parameters of the ``agent`` to a given path.
+ - ``restore``: restore previously saved parameters from a given path.

Todo:
    - allow users to get parameters of a specified model by specifying the model's name in ``get_weights()``.
"""

def __init__(self, algorithm, gpu_id=None):
@@ -90,13 +92,13 @@ def __init__(self, algorithm, gpu_id=None):
    self.fluid_executor.run(fluid.default_startup_program())

def build_program(self):
    """Build various programs here with the
    learn, predict and sample functions of the algorithm.

    Note:
        | Users **must** implement this function in an ``Agent``.
        | This function will be called automatically in the initialization function.

    To build a program, you must do the following:
    a. Create a fluid program with ``fluid.program_guard()``;
    b. Define data layers for feeding the data;
@@ -112,7 +114,7 @@ def build_program(self):
            obs = layers.data(
                name='obs', shape=[self.obs_dim], dtype='float32')
            self.act_prob = self.alg.predict(obs)
    """
    raise NotImplementedError
@@ -152,8 +154,68 @@ def predict(self, *args, **kwargs):

def sample(self, *args, **kwargs):
    """Return an action with noise when given the observation of the environment.

    In general, this function is used in the training process, as noise is added to the action to perform exploration.
    """
    raise NotImplementedError

def save(self, save_path, program=None):
    """Save parameters.

    Args:
        save_path(str): where to save the parameters.
        program(fluid.Program): program that describes the neural network structure.
            If None, will use self.learn_program.

    Raises:
        ValueError: if program is None and self.learn_program does not exist.

    Example:

    .. code-block:: python

        agent = AtariAgent()
        agent.save('./model.ckpt')

    """
    if program is None:
        program = self.learn_program
    dirname = '/'.join(save_path.split('/')[:-1])
    filename = save_path.split('/')[-1]
    fluid.io.save_params(
        executor=self.fluid_executor,
        dirname=dirname,
        main_program=program,
        filename=filename)
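
Note how ``save()`` derives a directory and a file name: the path is split on ``'/'``, and ``fluid.io.save_params`` then writes all parameters of the program into that single file. A quick illustration with hypothetical values:

    save_path = './my_model/model-2.ckpt'
    dirname = '/'.join(save_path.split('/')[:-1])  # -> './my_model'
    filename = save_path.split('/')[-1]            # -> 'model-2.ckpt'
    # fluid.io.save_params(..., dirname=dirname, filename=filename) stores
    # every parameter of the program in the single file './my_model/model-2.ckpt'.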

def restore(self, save_path, program=None):
    """Restore previously saved parameters.

    This method requires a program that describes the network structure.
    The save_path argument is typically a value previously passed to ``save()``.

    Args:
        save_path(str): path where the parameters were previously saved.
        program(fluid.Program): program that describes the neural network structure.
            If None, will use self.learn_program.

    Raises:
        ValueError: if program is None and self.learn_program does not exist.

    Example:

    .. code-block:: python

        agent = AtariAgent()
        agent.save('./model.ckpt')
        agent.restore('./model.ckpt')

    """
    if program is None:
        program = self.learn_program
    dirname = '/'.join(save_path.split('/')[:-1])
    filename = save_path.split('/')[-1]
    fluid.io.load_params(
        executor=self.fluid_executor,
        dirname=dirname,
        main_program=program,
        filename=filename)
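
Both methods also accept an explicit ``program``; a hedged sketch, assuming an agent whose ``build_program()`` additionally defined a ``predict_program`` (as ``TestAgent`` does in the test file below):

    # predict_program is an assumption: it exists only if build_program() created it.
    agent.save('./predict_model.ckpt', program=agent.predict_program)
    agent.restore('./predict_model.ckpt', program=agent.predict_program)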
70 changes: 57 additions & 13 deletions parl/core/fluid/tests/agent_base_test.py
@@ -16,42 +16,59 @@
import unittest
from paddle import fluid
from parl import layers
-from parl.core.fluid.agent import Agent
-from parl.core.fluid.algorithm import Algorithm
-from parl.core.fluid.model import Model
-from parl.utils.machine_info import get_gpu_count
+import parl
+import os


-class TestModel(Model):
+class TestModel(parl.Model):
    def __init__(self):
        self.fc1 = layers.fc(size=256)
-       self.fc2 = layers.fc(size=128)
+       self.fc2 = layers.fc(size=1)

    def policy(self, obs):
        out = self.fc1(obs)
        out = self.fc2(out)
        return out


-class TestAlgorithm(Algorithm):
+class TestAlgorithm(parl.Algorithm):
    def __init__(self, model):
        self.model = model

    def predict(self, obs):
        return self.model.policy(obs)

    def learn(self, obs, label):
        pred_output = self.model.policy(obs)
        # Mean squared error between the model's prediction and the label.
        cost = layers.square_error_cost(pred_output, label)
        cost = fluid.layers.reduce_mean(cost)
        return cost

-class TestAgent(Agent):
+class TestAgent(parl.Agent):
    def __init__(self, algorithm, gpu_id=None):
        super(TestAgent, self).__init__(algorithm, gpu_id)

    def build_program(self):
        self.predict_program = fluid.Program()
        self.learn_program = fluid.Program()
        with fluid.program_guard(self.predict_program):
            obs = layers.data(name='obs', shape=[10], dtype='float32')
            output = self.algorithm.predict(obs)
        self.predict_output = [output]

        with fluid.program_guard(self.learn_program):
            obs = layers.data(name='obs', shape=[10], dtype='float32')
            label = layers.data(name='label', shape=[1], dtype='float32')
            cost = self.algorithm.learn(obs, label)

    def learn(self, obs, label):
        output_np = self.fluid_executor.run(
            self.learn_program, feed={
                'obs': obs,
                'label': label
            })

    def predict(self, obs):
        output_np = self.fluid_executor.run(
            self.predict_program,
@@ -66,11 +83,38 @@ def setUp(self):
        self.algorithm = TestAlgorithm(self.model)

    def test_agent(self):
-       if get_gpu_count() > 0:
-           agent = TestAgent(self.algorithm)
-           obs = np.random.random([3, 10]).astype('float32')
-           output_np = agent.predict(obs)
-           self.assertIsNotNone(output_np)
+       agent = TestAgent(self.algorithm)
+       obs = np.random.random([3, 10]).astype('float32')
+       output_np = agent.predict(obs)
+       self.assertIsNotNone(output_np)

    def test_save(self):
        agent = TestAgent(self.algorithm)
        obs = np.random.random([3, 10]).astype('float32')
        output_np = agent.predict(obs)
        save_path1 = './model.ckpt'
        save_path2 = './my_model/model-2.ckpt'
        agent.save(save_path1)
        agent.save(save_path2)
        self.assertTrue(os.path.exists(save_path1))
        self.assertTrue(os.path.exists(save_path2))

    def test_restore(self):
        agent = TestAgent(self.algorithm)
        obs = np.random.random([3, 10]).astype('float32')
        output_np = agent.predict(obs)
        save_path1 = './model.ckpt'
        previous_output = agent.predict(obs)
        agent.save(save_path1)
        agent.restore(save_path1)
        current_output = agent.predict(obs)
        np.testing.assert_equal(current_output, previous_output)

        # a new agent instance
        another_agent = TestAgent(self.algorithm)
        another_agent.restore(save_path1)
        current_output = another_agent.predict(obs)
        np.testing.assert_equal(current_output, previous_output)


if __name__ == '__main__':
    unittest.main()
