diff --git a/predicators/envs/spot_env.py b/predicators/envs/spot_env.py index 2fac61485..78af052e8 100644 --- a/predicators/envs/spot_env.py +++ b/predicators/envs/spot_env.py @@ -3357,7 +3357,8 @@ def _get_dry_task(self, train_or_test: str, "cup": "orange cup/orange cylinder/orange-ish mug", "orange_cup": "orange cup/orange cylinder/orange-ish mug", "blue_cup": "blue cup/blue mug/uncovered blue cup", - "green_cup": "green cup/greenish cup/green cylinder", + # "green_cup": "green cup/greenish cup/green cylinder", + "green_cup": "green cup", # Containers "green_bowl": "green bowl/greenish bowl", @@ -3404,12 +3405,13 @@ def _detection_id_to_obj(self) -> Dict[ObjectDetectionID, Object]: # List object identifier, object name (to find prompt), and type objects_to_detect = [ - ("red_block", _movable_object_type), + # ("block", "green_block", _movable_object_type), + ("block", "green_cup", _movable_object_type), ] - # Add detection IDs for each object - for obj_name, obj_type in objects_to_detect: - obj = Object(obj_name, obj_type) + # Add detection object prompt and save object identifier + for obj_identifier, obj_name, obj_type in objects_to_detect: + obj = Object(obj_identifier, obj_type) detection_id = _get_detection_id(obj_name) detection_id_to_obj[detection_id] = obj @@ -3421,7 +3423,7 @@ def _detection_id_to_obj(self) -> Dict[ObjectDetectionID, Object]: return detection_id_to_obj def _generate_goal_description(self) -> GoalDescription: - return "pick up the red block" + return "pick up the block" def _get_dry_task(self, train_or_test: str, task_idx: int) -> EnvironmentTask: @@ -3492,12 +3494,17 @@ def __init__(self, use_gui: bool = True) -> None: op_to_name = {o.name: o for o in _create_operators()} op_names_to_keep = { + # "MoveToReachObject", + # "PickObjectFromTop", + # "PlaceObjectOnTop", + # "DropObjectInside", + # "MoveToHandViewObject", + # "ObserveFromTop", "MoveToReachObject", - "PickObjectFromTop", + "MoveToHandViewObject", + "PickObjectFromTop", "PlaceObjectOnTop", "DropObjectInside", - "MoveToHandViewObject", - "ObserveFromTop", } self._strips_operators = {op_to_name[o] for o in op_names_to_keep} @@ -3512,12 +3519,12 @@ def _detection_id_to_obj(self) -> Dict[ObjectDetectionID, Object]: # List object identifier, object name (to find prompt), and type objects_to_detect = [ ("cardboard_box", "cardboard_box", _container_type), - ("block", "red_block", _movable_object_type), + # ("block", "red_block", _movable_object_type), # DEBUG objects: # ("block", "orange_cup", _container_type), # ("block", "spam_box", _container_type), # ("block", "yellow_apple", _movable_object_type), - # ("block", "green_cup", _container_type), + ("block", "green_cup", _container_type), # ("block", "green_block", _movable_object_type), # ("block", "green_apple", _movable_object_type), ] @@ -3653,6 +3660,63 @@ def _generate_goal_description(self) -> GoalDescription: def _get_dry_task(self, train_or_test: str, task_idx: int) -> EnvironmentTask: raise NotImplementedError("Dry task generation not implemented.") + + +class LISSpotTableMultiCupInBoxEnv(SpotRearrangementEnv): + """A partially observable environment where a cup (with certain property) on a table needs to be moved into a cardboard box.""" + + def __init__(self, use_gui: bool = True) -> None: + super().__init__(use_gui) + + op_to_name = {o.name: o for o in _create_operators()} + op_names_to_keep = { + "MoveToReachObject", + "PickObjectFromTop", + "PlaceObjectOnTop", + "DropObjectInside", + "MoveToHandViewObject", + "ObserveFromTop", + } + self._strips_operators = {op_to_name[o] for o in op_names_to_keep} + + @classmethod + def get_name(cls) -> str: + return "lis_spot_table_cup_in_box_env" + + @property + def _detection_id_to_obj(self) -> Dict[ObjectDetectionID, Object]: + detection_id_to_obj: Dict[ObjectDetectionID, Object] = {} + + # List object identifier, object name (to find prompt), and type + objects_to_detect = [ + ("cardboard_box", "cardboard_box", _container_type), + ("cup", "orange_cup", _movable_object_type), + ] + + # Add detection object prompt and save object identifier + for obj_identifier, obj_name, obj_type in objects_to_detect: + obj = Object(obj_identifier, obj_type) + detection_id = _get_detection_id(obj_name) + detection_id_to_obj[detection_id] = obj + + # AprilTag object + wooden_table = Object("wooden_table", _immovable_object_type) + wooden_table_detection = AprilTagObjectDetectionID(32) + detection_id_to_obj[wooden_table_detection] = wooden_table + + # Add known immovable objects + for obj, pose in get_known_immovable_objects().items(): + detection_id = KnownStaticObjectDetectionID(obj.name, pose) + detection_id_to_obj[detection_id] = obj + + return detection_id_to_obj + + def _generate_goal_description(self) -> GoalDescription: + return "put the cup into the cardboard box on floor" + + def _get_dry_task(self, train_or_test: str, + task_idx: int) -> EnvironmentTask: + raise NotImplementedError("Dry task generation not implemented.") class LISSpotEmptyCupBoxEnv(SpotRearrangementEnv):