Skip to content

Commit

Permalink
update tasks for testing picking
Browse files Browse the repository at this point in the history
  • Loading branch information
lf-zhao committed Nov 12, 2024
1 parent 9839333 commit 9b081c9
Showing 1 changed file with 75 additions and 11 deletions.
86 changes: 75 additions & 11 deletions predicators/envs/spot_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3357,7 +3357,8 @@ def _get_dry_task(self, train_or_test: str,
"cup": "orange cup/orange cylinder/orange-ish mug",
"orange_cup": "orange cup/orange cylinder/orange-ish mug",
"blue_cup": "blue cup/blue mug/uncovered blue cup",
"green_cup": "green cup/greenish cup/green cylinder",
# "green_cup": "green cup/greenish cup/green cylinder",
"green_cup": "green cup",

# Containers
"green_bowl": "green bowl/greenish bowl",
Expand Down Expand Up @@ -3404,12 +3405,13 @@ def _detection_id_to_obj(self) -> Dict[ObjectDetectionID, Object]:

# List object identifier, object name (to find prompt), and type
objects_to_detect = [
("red_block", _movable_object_type),
# ("block", "green_block", _movable_object_type),
("block", "green_cup", _movable_object_type),
]

# Add detection IDs for each object
for obj_name, obj_type in objects_to_detect:
obj = Object(obj_name, obj_type)
# Add detection object prompt and save object identifier
for obj_identifier, obj_name, obj_type in objects_to_detect:
obj = Object(obj_identifier, obj_type)
detection_id = _get_detection_id(obj_name)
detection_id_to_obj[detection_id] = obj

Expand All @@ -3421,7 +3423,7 @@ def _detection_id_to_obj(self) -> Dict[ObjectDetectionID, Object]:
return detection_id_to_obj

def _generate_goal_description(self) -> GoalDescription:
return "pick up the red block"
return "pick up the block"

def _get_dry_task(self, train_or_test: str,
task_idx: int) -> EnvironmentTask:
Expand Down Expand Up @@ -3492,12 +3494,17 @@ def __init__(self, use_gui: bool = True) -> None:

op_to_name = {o.name: o for o in _create_operators()}
op_names_to_keep = {
# "MoveToReachObject",
# "PickObjectFromTop",
# "PlaceObjectOnTop",
# "DropObjectInside",
# "MoveToHandViewObject",
# "ObserveFromTop",
"MoveToReachObject",
"PickObjectFromTop",
"MoveToHandViewObject",
"PickObjectFromTop",
"PlaceObjectOnTop",
"DropObjectInside",
"MoveToHandViewObject",
"ObserveFromTop",
}
self._strips_operators = {op_to_name[o] for o in op_names_to_keep}

Expand All @@ -3512,12 +3519,12 @@ def _detection_id_to_obj(self) -> Dict[ObjectDetectionID, Object]:
# List object identifier, object name (to find prompt), and type
objects_to_detect = [
("cardboard_box", "cardboard_box", _container_type),
("block", "red_block", _movable_object_type),
# ("block", "red_block", _movable_object_type),
# DEBUG objects:
# ("block", "orange_cup", _container_type),
# ("block", "spam_box", _container_type),
# ("block", "yellow_apple", _movable_object_type),
# ("block", "green_cup", _container_type),
("block", "green_cup", _container_type),
# ("block", "green_block", _movable_object_type),
# ("block", "green_apple", _movable_object_type),
]
Expand Down Expand Up @@ -3653,6 +3660,63 @@ def _generate_goal_description(self) -> GoalDescription:
def _get_dry_task(self, train_or_test: str,
task_idx: int) -> EnvironmentTask:
raise NotImplementedError("Dry task generation not implemented.")


class LISSpotTableMultiCupInBoxEnv(SpotRearrangementEnv):
"""A partially observable environment where a cup (with certain property) on a table needs to be moved into a cardboard box."""

def __init__(self, use_gui: bool = True) -> None:
super().__init__(use_gui)

op_to_name = {o.name: o for o in _create_operators()}
op_names_to_keep = {
"MoveToReachObject",
"PickObjectFromTop",
"PlaceObjectOnTop",
"DropObjectInside",
"MoveToHandViewObject",
"ObserveFromTop",
}
self._strips_operators = {op_to_name[o] for o in op_names_to_keep}

@classmethod
def get_name(cls) -> str:
return "lis_spot_table_cup_in_box_env"

@property
def _detection_id_to_obj(self) -> Dict[ObjectDetectionID, Object]:
detection_id_to_obj: Dict[ObjectDetectionID, Object] = {}

# List object identifier, object name (to find prompt), and type
objects_to_detect = [
("cardboard_box", "cardboard_box", _container_type),
("cup", "orange_cup", _movable_object_type),
]

# Add detection object prompt and save object identifier
for obj_identifier, obj_name, obj_type in objects_to_detect:
obj = Object(obj_identifier, obj_type)
detection_id = _get_detection_id(obj_name)
detection_id_to_obj[detection_id] = obj

# AprilTag object
wooden_table = Object("wooden_table", _immovable_object_type)
wooden_table_detection = AprilTagObjectDetectionID(32)
detection_id_to_obj[wooden_table_detection] = wooden_table

# Add known immovable objects
for obj, pose in get_known_immovable_objects().items():
detection_id = KnownStaticObjectDetectionID(obj.name, pose)
detection_id_to_obj[detection_id] = obj

return detection_id_to_obj

def _generate_goal_description(self) -> GoalDescription:
return "put the cup into the cardboard box on floor"

def _get_dry_task(self, train_or_test: str,
task_idx: int) -> EnvironmentTask:
raise NotImplementedError("Dry task generation not implemented.")


class LISSpotEmptyCupBoxEnv(SpotRearrangementEnv):
Expand Down

0 comments on commit 9b081c9

Please sign in to comment.