Skip to content

Commit

Permalink
Implemented methods to save and restore PyBullet states. (#33)
Browse files Browse the repository at this point in the history
* Implemented methods to save and restore PyBullet states.

* Fixed typos.

* Added docs for save_state() and remove_state().

* Make save and restore state docs visible in index.

* Added unit tests for save and restore states.

* Added unit test for remove state.

* Fixed save and restore test logic.

* isort and black

* Test for desired goal consistency during state saving and restoring.

* Save and restore task goal.

* Run linting.

* `p` to `self.physics_client`

* fix docstring style

* Update version

Co-authored-by: Quentin Gallouédec <[email protected]>
  • Loading branch information
louixp and qgallouedec authored Jul 5, 2022
1 parent 2f634e2 commit bdc9ae1
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
author = 'Quentin Gallouédec'

# The full version, including alpha/beta/rc tags
release = 'v2.0.3'
release = 'v2.0.4'


# -- General configuration ---------------------------------------------------
Expand Down
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Welcome to panda-gym's documentation!

usage/environments
usage/manual_control
usage/save_restore_state
usage/train_with_sb3

.. toctree::
Expand Down
31 changes: 31 additions & 0 deletions docs/usage/save_restore_state.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
.. _save_restore_states:

Save and Restore States
=======================

It is possible to save a state of the entire simulation environment. This is useful if your application requires lookahead search. Below is an example of a greedy random search.

.. code-block:: python

    import gym
    import panda_gym

    env = gym.make("PandaReachDense-v2", render=True)
    obs = env.reset()

    while True:
        state_id = env.save_state()
        best_action = None
        rew = best_rew = env.task.compute_reward(
            obs["achieved_goal"], obs["desired_goal"], None)

        while rew <= best_rew:
            env.restore_state(state_id)
            a = env.action_space.sample()
            _, rew, _, _ = env.step(a)

        env.restore_state(state_id)
        obs, _, _, _ = env.step(a)
        env.remove_state(state_id)

    env.close()
14 changes: 14 additions & 0 deletions panda_gym/envs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ def __init__(self, robot: PyBulletRobot, task: Task) -> None:
)
self.action_space = self.robot.action_space
self.compute_reward = self.task.compute_reward
self._saved_goal = dict()

def _get_obs(self) -> Dict[str, np.ndarray]:
robot_obs = self.robot.get_obs() # robot state
Expand All @@ -246,6 +247,19 @@ def reset(self, seed: Optional[int] = None) -> Dict[str, np.ndarray]:
self.task.reset()
return self._get_obs()

def save_state(self) -> int:
    """Save the simulation state together with the task's current goal.

    The simulation is saved through ``self.sim.save_state()``; the goal held
    by ``self.task`` at that moment is recorded under the returned id so that
    ``restore_state()`` can put it back.

    Returns:
        int: The state id returned by the simulation.
    """
    snapshot_id = self.sim.save_state()
    current_goal = self.task.goal
    self._saved_goal[snapshot_id] = current_goal
    return snapshot_id

def restore_state(self, state_id: int) -> None:
    """Restore a saved simulation state and the goal recorded with it.

    Args:
        state_id: The id returned by save_state().

    Raises:
        KeyError: If no goal is recorded under ``state_id``. The simulation
            restore is attempted first, so an error raised there takes
            precedence.
    """
    # Restore the simulator first so that its own error for a bad id is the
    # one that reaches the caller.
    self.sim.restore_state(state_id)
    goal_at_save = self._saved_goal[state_id]
    self.task.goal = goal_at_save

def remove_state(self, state_id: int) -> None:
    """Forget a saved state: drop its recorded goal, then remove it from the simulation.

    Args:
        state_id: The id returned by save_state().

    Raises:
        KeyError: If no goal is recorded under ``state_id``; in that case the
            simulation is not asked to remove anything.
    """
    del self._saved_goal[state_id]
    self.sim.remove_state(state_id)

def step(self, action: np.ndarray) -> Tuple[Dict[str, np.ndarray], float, bool, Dict[str, Any]]:
self.robot.set_action(action)
self.sim.step()
Expand Down
25 changes: 25 additions & 0 deletions panda_gym/pybullet.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,31 @@ def close(self) -> None:
"""Close the simulation."""
self.physics_client.disconnect()

def save_state(self) -> int:
    """Save the current simulation state.

    Returns:
        int: A state id assigned by PyBullet, which is the first non-negative
            integer available for indexing.
    """
    state_id = self.physics_client.saveState()
    return state_id

def restore_state(self, state_id: int) -> None:
    """Restore a previously saved simulation state.

    Args:
        state_id: The simulation state id returned by save_state().
    """
    client = self.physics_client
    client.restoreState(state_id)

def remove_state(self, state_id: int) -> None:
    """Remove a saved simulation state.

    Once removed, this state_id becomes available again as a return value of
    save_state().

    Args:
        state_id: The simulation state id returned by save_state().
    """
    client = self.physics_client
    client.removeState(state_id)

def render(
self,
mode: str = "human",
Expand Down
2 changes: 1 addition & 1 deletion panda_gym/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.0.3
2.0.4
36 changes: 36 additions & 0 deletions test/save_and_restore_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import gym
import numpy as np
import pybullet
import pytest

import panda_gym


def test_save_and_restore_state():
    """Replaying an action from a restored state must reproduce the same observation."""
    env = gym.make("PandaReach-v2")
    env.reset()

    checkpoint = env.save_state()

    # Take one random step starting from the checkpoint.
    action = env.action_space.sample()
    obs_first, _, _, _ = env.step(action)

    # Reset the environment, rewind to the checkpoint and replay the same action.
    env.reset()
    env.restore_state(checkpoint)
    obs_replay, _, _, _ = env.step(action)

    # The observations in both cases should be equal, entry by entry.
    for key in ("achieved_goal", "observation", "desired_goal"):
        assert np.all(obs_first[key] == obs_replay[key])


def test_remove_state():
    """Restoring a state after it has been removed must raise ``pybullet.error``."""
    env = gym.make("PandaReach-v2")
    env.reset()

    stale_id = env.save_state()
    env.remove_state(stale_id)

    # The id no longer refers to a saved state, so the restore is expected to fail.
    with pytest.raises(pybullet.error):
        env.restore_state(stale_id)

0 comments on commit bdc9ae1

Please sign in to comment.