From 7f1ede126fec54304e2031a87138920059184b68 Mon Sep 17 00:00:00 2001 From: Stone Tao Date: Mon, 26 Feb 2024 13:23:58 -0800 Subject: [PATCH] work? --- mani_skill2/utils/wrappers/record.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/mani_skill2/utils/wrappers/record.py b/mani_skill2/utils/wrappers/record.py index c1add17b7..79c2dcadf 100644 --- a/mani_skill2/utils/wrappers/record.py +++ b/mani_skill2/utils/wrappers/record.py @@ -1,5 +1,6 @@ import copy import time +from dataclasses import dataclass from pathlib import Path from typing import List, Optional, Union @@ -78,6 +79,18 @@ def clean_trajectories(h5_file: h5py.File, json_dict: dict, prune_empty_action=T json_dict["episodes"] = new_json_episodes +@dataclass +class Step: + state: np.ndarray + observation: np.ndarray + action: np.ndarray + reward: np.ndarray + terminated: np.ndarray + truncated: np.ndarray + env_ptrs: np.ndarray + """track the step of each parallel env?""" + + def pack_step_data(state, obs, action, rew, terminated, truncated, info): data = dict( s=to_numpy(state) if state is not None else None, @@ -230,9 +243,9 @@ def reset( options.pop("save_trajectory", False) if self.save_on_reset and self._episode_id >= 0 and not skip_trajectory: - # To make things easier, we only flush data when there is no partial reset. + self.flush_trajectory(ignore_empty_transition=True) + # To make things easier, we only flush videos when there is no partial reset. if "env_idx" not in options: - self.flush_trajectory(ignore_empty_transition=True) self.flush_video() # Clear cache @@ -284,7 +297,6 @@ def step(self, action): image = put_info_on_image(image, scalar_info, extras=extra_texts) self._render_images.append(image) - print(self._video_steps, self.max_steps_per_video) if ( self.max_steps_per_video is not None and self._video_steps >= self.max_steps_per_video