Skip to content

Commit

Permalink
update objects
Browse files Browse the repository at this point in the history
  • Loading branch information
lf-zhao committed Nov 5, 2024
1 parent cd3bf5c commit bf48b4f
Showing 1 changed file with 40 additions and 46 deletions.
86 changes: 40 additions & 46 deletions predicators/envs/spot_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3335,12 +3335,19 @@ def _get_dry_task(self, train_or_test: str,
_OBJECT_PROMPTS = {
# Movable objects
"red_block": "red block/orange block/yellow block",
"green_block": "green block/greenish block",
"yellow_apple": "yellow apple/yellowish apple",
"green_apple": "green apple/greenish apple",
"spam_box": "spam box/spam container/spam-ish box",

# Movable objects: cups, potentially with lids and non-empty
"cup": "orange cup/orange cylinder/orange-ish mug",
"blue_cup": "blue cup/blue mug/uncovered blue cup",
"green_cup": "green cup/greenish cup/green cylinder",

# Containers
"green_bowl": "green bowl/greenish bowl",
"cardboard_box": "cardboard box/paper box",
"cup": "orange cup/orange cylinder/orange-ish mug",
"blue_cup": "blue cup/blue mug/uncovered blue cup",

# Fixed objects with AprilTags
"wooden_table": 32, # AprilTag ID
Expand Down Expand Up @@ -3379,18 +3386,24 @@ def get_name(cls) -> str:

@property
def _detection_id_to_obj(self) -> Dict[ObjectDetectionID, Object]:

detection_id_to_obj: Dict[ObjectDetectionID, Object] = {}

red_block = Object("red_block", _movable_object_type)
red_block_detection = LanguageObjectDetectionID(
"red block/orange block/yellow block")
detection_id_to_obj[red_block_detection] = red_block


# List objects used in this env
objects_to_detect = [
("red_block", _movable_object_type),
]

# Add detection IDs for each object
for obj_name, obj_type in objects_to_detect:
obj = Object(obj_name, obj_type)
detection_id = _get_detection_id(obj_name)
detection_id_to_obj[detection_id] = obj

# Add known immovable objects
for obj, pose in get_known_immovable_objects().items():
detection_id = KnownStaticObjectDetectionID(obj.name, pose)
detection_id_to_obj[detection_id] = obj

return detection_id_to_obj

def _generate_goal_description(self) -> GoalDescription:
Expand All @@ -3415,7 +3428,7 @@ def __init__(self, use_gui: bool = True) -> None:
op_names_to_keep = {
"MoveToReachObject",
"MoveToHandViewObject",
"PickObjectFromTop",
"PickObjectFromTop",
"PlaceObjectOnTop",
"DropObjectInside",
}
Expand All @@ -3427,44 +3440,26 @@ def get_name(cls) -> str:

@property
def _detection_id_to_obj(self) -> Dict[ObjectDetectionID, Object]:

detection_id_to_obj: Dict[ObjectDetectionID, Object] = {}

red_block = Object("red_block", _movable_object_type)
red_block_detection = LanguageObjectDetectionID(
"red block/orange block/yellow block")
detection_id_to_obj[red_block_detection] = red_block

green_bowl = Object("green_bowl", _container_type)
# green_bowl_detection = LanguageObjectDetectionID(
# "green bowl/greenish bowl")
# TODO test
green_bowl_detection = LanguageObjectDetectionID(
"cardboard box/paper box")
detection_id_to_obj[green_bowl_detection] = green_bowl

# # TODO temp test new object sets
# green_bowl = Object("green_bowl", _container_type)
# green_bowl_detection = LanguageObjectDetectionID(
# "green bowl/greenish bowl")
# detection_id_to_obj[green_bowl_detection] = green_bowl

# Case 1: Mug facing up with no lid
# To try more cases below
# orange_mug = Object("orange_mug", _movable_object_type)
# orange_mug_detection = LanguageObjectDetectionID("orange mug/orange cup/uncovered orange mug")
# detection_id_to_obj[orange_mug_detection] = orange_mug

# TODO just use different prompt - for debugging
# red_block = Object("red_block", _movable_object_type)
# red_block_detection = LanguageObjectDetectionID(
# "orange mug/orange cup/uncovered orange mug")
# detection_id_to_obj[red_block_detection] = red_block

# List objects used in this env
objects_to_detect = [
("red_block", _movable_object_type),
("green_bowl", _container_type),
# ("orange_mug", _movable_object_type), # Case 1: Mug facing up with no lid
]

# Add detection IDs for each object
for obj_name, obj_type in objects_to_detect:
obj = Object(obj_name, obj_type)
detection_id = _get_detection_id(obj_name)
detection_id_to_obj[detection_id] = obj

# Add known immovable objects
for obj, pose in get_known_immovable_objects().items():
detection_id = KnownStaticObjectDetectionID(obj.name, pose)
detection_id_to_obj[detection_id] = obj

return detection_id_to_obj

def _generate_goal_description(self) -> GoalDescription:
Expand All @@ -3487,8 +3482,7 @@ def __init__(self, use_gui: bool = True) -> None:
"PickObjectFromTop",
"PlaceObjectOnTop",
"DropObjectInside",
# NOTE: add new; replacing "MoveToHandViewObject"
"MoveToHandObserveObjectFromTop",
"MoveToHandObserveObjectFromTop", # NOTE: replacing "MoveToHandViewObject"
"ObserveFromTop",
}
self._strips_operators = {op_to_name[o] for o in op_names_to_keep}
Expand Down

0 comments on commit bf48b4f

Please sign in to comment.