Implemented methods to save and restore PyBullet states. (#33)
* Implemented methods to save and restore PyBullet states.
* Fixed typos.
* Added docs for save_state() and remove_state().
* Make save and restore state docs visible in index.
* Added unit tests for save and restore states.
* Added unit test for remove state.
* Fixed save and restore test logic.
* isort and black
* Test for desired goal consistency during state saving and restoring.
* Save and restore task goal.
* Run linting.
* `p` to `self.physics_client`
* fix docstring style
* Update version

Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent 2f634e2, commit bdc9ae1
Showing 7 changed files with 109 additions and 2 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
.. _save_restore_states: | ||
|
||
Save and Restore States | ||
======================= | ||
|
||
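It is possible to save the state of the entire simulation environment and restore it later. ``save_state()`` returns an identifier that can be passed to ``restore_state()`` and, once the saved state is no longer needed, to ``remove_state()``. This is useful if your application requires lookahead search. Below is an example of a greedy random search. | ||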
|
||
.. code-block:: python | ||
|
||
    import gym | ||
    import panda_gym | ||
|
||
    env = gym.make("PandaReachDense-v2", render=True) | ||
    obs = env.reset() | ||
|
||
    while True: | ||
        # Snapshot the simulation before trying candidate actions. | ||
        state_id = env.save_state() | ||
        best_action = None | ||
        rew = best_rew = env.task.compute_reward( | ||
            obs["achieved_goal"], obs["desired_goal"], None) | ||
|
||
        # Sample random actions from the saved state until one beats the current reward. | ||
        while rew <= best_rew: | ||
            env.restore_state(state_id) | ||
            a = env.action_space.sample() | ||
            _, rew, _, _ = env.step(a) | ||
|
||
        # Rewind once more, apply the improving action, and free the snapshot. | ||
        env.restore_state(state_id) | ||
        obs, _, _, _ = env.step(a) | ||
        env.remove_state(state_id) | ||
|
||
    env.close() |
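The save, try, restore pattern in the documentation example can be exercised without a physics engine. The sketch below is an illustration only: `ToyEnv` and `greedy_random_search` are hypothetical names, the environment is a 1-D stand-in rather than anything from panda_gym, and `step()` returns only a reward to keep the example short. Unlike the documentation example, it bounds the number of tries per step and falls back to a zero action when no sampled action improves the reward.

```python
import random


class ToyEnv:
    """Hypothetical 1-D stand-in for a simulation; not part of panda_gym."""

    def __init__(self, goal=5.0):
        self.goal = goal
        self.pos = 0.0
        self._states = {}  # state_id -> saved position
        self._next_id = 0

    def save_state(self):
        """Store the current position and return an identifier for it."""
        state_id = self._next_id
        self._next_id += 1
        self._states[state_id] = self.pos
        return state_id

    def restore_state(self, state_id):
        """Rewind the position to a previously saved state."""
        self.pos = self._states[state_id]

    def remove_state(self, state_id):
        """Forget a saved state so it no longer uses memory."""
        del self._states[state_id]

    def step(self, action):
        """Move by `action` and return a dense reward (negative distance to goal)."""
        self.pos += action
        return -abs(self.goal - self.pos)


def greedy_random_search(env, n_steps=20, max_tries=100, seed=0):
    """At each step, commit the first sampled action that improves the reward."""
    rng = random.Random(seed)
    reward = -abs(env.goal - env.pos)
    for _ in range(n_steps):
        state_id = env.save_state()
        best_action, best_reward = 0.0, reward
        for _ in range(max_tries):
            env.restore_state(state_id)  # every trial starts from the same state
            action = rng.uniform(-1.0, 1.0)
            candidate = env.step(action)
            if candidate > best_reward:
                best_action, best_reward = action, candidate
                break
        env.restore_state(state_id)  # rewind, then commit the chosen action
        reward = env.step(best_action)
        env.remove_state(state_id)
    return reward


if __name__ == "__main__":
    env = ToyEnv()
    print("final reward:", greedy_random_search(env))
```

Because every trial is rewound before the chosen action is committed, the reward never decreases from one outer step to the next, and removing each snapshot after use keeps the store of saved states empty.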
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
2.0.3 | ||
2.0.4 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import gym | ||
import numpy as np | ||
import pybullet | ||
import pytest | ||
|
||
import panda_gym | ||
|
||
|
||
def test_save_and_restore_state(): | ||
    env = gym.make("PandaReach-v2") | ||
    env.reset() | ||
|
||
    state_id = env.save_state() | ||
|
||
    # Perform the action | ||
    action = env.action_space.sample() | ||
    next_obs1, reward, done, info = env.step(action) | ||
|
||
    # Restore and perform the same action | ||
    env.reset() | ||
    env.restore_state(state_id) | ||
    next_obs2, reward, done, info = env.step(action) | ||
|
||
    # The observations in both cases should be equal | ||
    assert np.all(next_obs1["achieved_goal"] == next_obs2["achieved_goal"]) | ||
    assert np.all(next_obs1["observation"] == next_obs2["observation"]) | ||
    assert np.all(next_obs1["desired_goal"] == next_obs2["desired_goal"]) | ||
|
||
|
||
def test_remove_state(): | ||
    env = gym.make("PandaReach-v2") | ||
    env.reset() | ||
    state_id = env.save_state() | ||
    env.remove_state(state_id) | ||
    with pytest.raises(pybullet.error): | ||
        env.restore_state(state_id) |
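The tests above check two behaviours: the desired goal survives a save and restore round trip, and restoring a removed state fails. One way such bookkeeping might look is sketched below. This is an assumption for illustration, not the code from this commit: `GoalAwareSnapshots` is a hypothetical helper, the physics state is modelled as a plain list, and a `KeyError` stands in for the `pybullet.error` that the real test expects.

```python
class GoalAwareSnapshots:
    """Hypothetical helper that pairs each saved physics state with the task goal."""

    def __init__(self):
        self._snapshots = {}  # state_id -> (physics_state, goal)
        self._next_id = 0

    def save(self, physics_state, goal):
        """Store copies of the state and goal, and return a new identifier."""
        state_id = self._next_id
        self._next_id += 1
        # Copy the values so later changes by the caller do not alter the snapshot.
        self._snapshots[state_id] = (list(physics_state), tuple(goal))
        return state_id

    def restore(self, state_id):
        """Return the saved (physics_state, goal) pair, or raise if it is gone."""
        if state_id not in self._snapshots:
            raise KeyError(f"unknown or removed state id: {state_id}")
        physics_state, goal = self._snapshots[state_id]
        return list(physics_state), goal

    def remove(self, state_id):
        """Delete a snapshot; restoring it afterwards raises KeyError."""
        del self._snapshots[state_id]


if __name__ == "__main__":
    store = GoalAwareSnapshots()
    state_id = store.save([0.0, 1.0], (0.1, 0.2, 0.3))
    print(store.restore(state_id))
```

Keeping the goal next to the physics snapshot means an environment reset between save and restore, as in `test_save_and_restore_state`, cannot leave a freshly sampled goal in place of the saved one.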