Skip to content

Commit

Permalink
found a way to use VLM to evaluate; check if visible in current scene, only update these predicates
Browse files Browse the repository at this point in the history
  • Loading branch information
lf-zhao committed Apr 30, 2024
1 parent aec70de commit 94e6a4c
Showing 1 changed file with 52 additions and 26 deletions.
78 changes: 52 additions & 26 deletions predicators/envs/spot_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,12 @@ class _PartialPerceptionState(State):
in the classifier definitions for the dummy predicates
"""

# DEBUG Add an additional field to store Spot images
# This would be directly copied from the images in raw Observation
# NOTE: This is only used when using VLM for predicate evaluation
# NOTE: Performance aspect should be considered later
obs_images: Optional[Dict[str, RGBDImageWithContext]] = None
# TODO: it's still unclear how we select and store useful images!
# # DEBUG Add an additional field to store Spot images
# # This would be directly copied from the images in raw Observation
# # NOTE: This is only used when using VLM for predicate evaluation
# # NOTE: Performance aspect should be considered later
# cam_images: Optional[Dict[str, RGBDImageWithContext]] = None
# # TODO: it's still unclear how we select and store useful images!

@property
def _simulator_state_predicates(self) -> Set[Predicate]:
Expand Down Expand Up @@ -128,7 +128,8 @@ def copy(self) -> State:
"atoms": self._simulator_state_atoms.copy()
}
return _PartialPerceptionState(state_copy,
simulator_state=sim_state_copy)
simulator_state=sim_state_copy,
camera_images=self.camera_images)


def _create_dummy_predicate_classifier(
Expand Down Expand Up @@ -1114,10 +1115,16 @@ def _object_in_xy_classifier(state: State,

def _on_classifier(state: State, objects: Sequence[Object]) -> bool:
obj_on, obj_surface = objects
currently_visible = all([o in state.visible_objects for o in objects])

if CFG.spot_vlm_eval_predicate:
print("TODO!!")
print(state.camera_images)
print(currently_visible, state)

if CFG.spot_vlm_eval_predicate and not currently_visible:
# TODO: add all previous atoms to the state
raise NotImplementedError
elif CFG.spot_vlm_eval_predicate and currently_visible:
# TODO call VLM to evaluate predicate value
raise NotImplementedError
else:
# Check that the bottom of the object is close to the top of the surface.
expect = state.get(obj_surface, "z") + state.get(obj_surface, "height") / 2
Expand All @@ -1143,27 +1150,39 @@ def _top_above_classifier(state: State, objects: Sequence[Object]) -> bool:

def _inside_classifier(state: State, objects: Sequence[Object]) -> bool:
    """Return True iff ``obj_in`` is geometrically inside ``obj_container``.

    When VLM-based predicate evaluation is enabled
    (``CFG.spot_vlm_eval_predicate``), both the visible and not-visible
    cases are still unimplemented and raise ``NotImplementedError``;
    otherwise containment is decided from the state's geometric features.

    NOTE(review): this view of the file was a diff that interleaved the old
    and new bodies (duplicating the xy check and z computations); this is
    the reconstructed post-commit version with the stray debug ``print``
    removed.
    """
    obj_in, obj_container = objects
    # VLM evaluation is only possible when every argument object is visible
    # in the current scene.
    currently_visible = all(o in state.visible_objects for o in objects)

    if CFG.spot_vlm_eval_predicate and not currently_visible:
        # TODO: add all previous atoms to the state
        raise NotImplementedError
    if CFG.spot_vlm_eval_predicate and currently_visible:
        # TODO call VLM to evaluate predicate value
        raise NotImplementedError

    # Geometric fallback: require xy overlap with the container first.
    if not _object_in_xy_classifier(
            state, obj_in, obj_container, buffer=_INSIDE_SURFACE_BUFFER):
        return False

    obj_z = state.get(obj_in, "z")
    obj_half_height = state.get(obj_in, "height") / 2
    obj_bottom = obj_z - obj_half_height
    obj_top = obj_z + obj_half_height

    container_z = state.get(obj_container, "z")
    container_half_height = state.get(obj_container, "height") / 2
    container_bottom = container_z - container_half_height
    container_top = container_z + container_half_height

    # Check that the bottom is "above" the bottom of the container.
    if obj_bottom < container_bottom - _INSIDE_Z_THRESHOLD:
        return False

    # Check that the top is "below" the top of the container.
    return obj_top < container_top + _INSIDE_Z_THRESHOLD


def _not_inside_any_container_classifier(state: State,
Expand Down Expand Up @@ -1462,6 +1481,13 @@ def _get_sweeping_surface_for_container(container: Object,
_IsSemanticallyGreaterThan
}
_NONPERCEPT_PREDICATES: Set[Predicate] = set()
# NOTE: We maintain a list of predicates that we check via a VLM on the
# current camera images rather than via geometric classifiers.
# NOTE: In the future, we may include an attribute to denote whether a predicate
# is VLM perceptible or not.
# NOTE: candidates: on, inside, door opened, blocking, not blocked, ...
# BUG FIX: the original `_VLM_EVAL_PREDICATES: {_On, _Inside}` was an
# annotation-only statement (PEP 526) — it never assigned the name, so the
# set did not exist at runtime. Use a real assignment with a type annotation,
# matching `_NONPERCEPT_PREDICATES` above.
_VLM_EVAL_PREDICATES: Set[Predicate] = {
    _On,
    _Inside,
}


## Operators (needed in the environment for non-percept atom hack)
Expand Down

0 comments on commit 94e6a4c

Please sign in to comment.