diff --git a/predicators/envs/spot_env.py b/predicators/envs/spot_env.py
index 7835a9ca2b..20fb77a8a1 100644
--- a/predicators/envs/spot_env.py
+++ b/predicators/envs/spot_env.py
@@ -94,12 +94,12 @@ class _PartialPerceptionState(State):
     in the classifier definitions for the dummy predicates
     """

-    # DEBUG Add an additional field to store Spot images
-    # This would be directly copied from the images in raw Observation
-    # NOTE: This is only used when using VLM for predicate evaluation
-    # NOTE: Performance aspect should be considered later
-    obs_images: Optional[Dict[str, RGBDImageWithContext]] = None
-    # TODO: it's still unclear how we select and store useful images!
+    # # DEBUG Add an additional field to store Spot images
+    # # This would be directly copied from the images in raw Observation
+    # # NOTE: This is only used when using VLM for predicate evaluation
+    # # NOTE: Performance aspect should be considered later
+    # cam_images: Optional[Dict[str, RGBDImageWithContext]] = None
+    # # TODO: it's still unclear how we select and store useful images!

     @property
     def _simulator_state_predicates(self) -> Set[Predicate]:
@@ -128,7 +128,8 @@ def copy(self) -> State:
             "atoms": self._simulator_state_atoms.copy()
         }
         return _PartialPerceptionState(state_copy,
-                                       simulator_state=sim_state_copy)
+                                       simulator_state=sim_state_copy,
+                                       camera_images=self.camera_images)


 def _create_dummy_predicate_classifier(
@@ -1114,10 +1115,16 @@ def _object_in_xy_classifier(state: State,
 def _on_classifier(state: State, objects: Sequence[Object]) -> bool:
     obj_on, obj_surface = objects
+    currently_visible = all([o in state.visible_objects for o in objects])

-    if CFG.spot_vlm_eval_predicate:
-        print("TODO!!")
-        print(state.camera_images)
+    print(currently_visible, state)
+
+    if CFG.spot_vlm_eval_predicate and not currently_visible:
+        # TODO: add all previous atoms to the state
+        raise NotImplementedError
+    elif CFG.spot_vlm_eval_predicate and currently_visible:
+        # TODO call VLM to evaluate predicate value
+        raise NotImplementedError
     else:
         # Check that the bottom of the object is close to the top of the surface.
         expect = state.get(obj_surface, "z") + state.get(obj_surface, "height") / 2
@@ -1143,27 +1150,39 @@ def _top_above_classifier(state: State, objects: Sequence[Object]) -> bool:
 def _inside_classifier(state: State, objects: Sequence[Object]) -> bool:
     obj_in, obj_container = objects
+    currently_visible = all([o in state.visible_objects for o in objects])

-    if not _object_in_xy_classifier(
-            state, obj_in, obj_container, buffer=_INSIDE_SURFACE_BUFFER):
-        return False
+    print(currently_visible, state)

-    obj_z = state.get(obj_in, "z")
-    obj_half_height = state.get(obj_in, "height") / 2
-    obj_bottom = obj_z - obj_half_height
-    obj_top = obj_z + obj_half_height
+    if CFG.spot_vlm_eval_predicate and not currently_visible:
+        # TODO: add all previous atoms to the state
+        raise NotImplementedError
+    elif CFG.spot_vlm_eval_predicate and currently_visible:
+        # TODO call VLM to evaluate predicate value
+        raise NotImplementedError

-    container_z = state.get(obj_container, "z")
-    container_half_height = state.get(obj_container, "height") / 2
-    container_bottom = container_z - container_half_height
-    container_top = container_z + container_half_height
+    else:

-    # Check that the bottom is "above" the bottom of the container.
-    if obj_bottom < container_bottom - _INSIDE_Z_THRESHOLD:
-        return False
+        if not _object_in_xy_classifier(
+                state, obj_in, obj_container, buffer=_INSIDE_SURFACE_BUFFER):
+            return False

-    # Check that the top is "below" the top of the container.
-    return obj_top < container_top + _INSIDE_Z_THRESHOLD
+        obj_z = state.get(obj_in, "z")
+        obj_half_height = state.get(obj_in, "height") / 2
+        obj_bottom = obj_z - obj_half_height
+        obj_top = obj_z + obj_half_height
+
+        container_z = state.get(obj_container, "z")
+        container_half_height = state.get(obj_container, "height") / 2
+        container_bottom = container_z - container_half_height
+        container_top = container_z + container_half_height
+
+        # Check that the bottom is "above" the bottom of the container.
+        if obj_bottom < container_bottom - _INSIDE_Z_THRESHOLD:
+            return False
+
+        # Check that the top is "below" the top of the container.
+        return obj_top < container_top + _INSIDE_Z_THRESHOLD


 def _not_inside_any_container_classifier(state: State,
@@ -1462,6 +1481,13 @@ def _get_sweeping_surface_for_container(container: Object,
     _IsSemanticallyGreaterThan
 }
 _NONPERCEPT_PREDICATES: Set[Predicate] = set()
+# NOTE: We maintain a set of predicates that we check via VLM.
+# NOTE: In the future, we may include an attribute to denote whether a predicate
+# is VLM perceptible or not.
+# NOTE: candidates: on, inside, door opened, blocking, not blocked, ...
+_VLM_EVAL_PREDICATES: Set[Predicate] = {
+    _On, _Inside,
+}


 ## Operators (needed in the environment for non-percept atom hack)
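
Both `raise NotImplementedError` branches above stub out the same missing piece: asking a VLM whether a spatial predicate holds in the images carried on the state. Below is a minimal sketch of what that call could look like, assuming a hypothetical `_query_vlm_yes_no` helper and the `state.camera_images` dict that `copy()` now threads through; none of these names are APIs confirmed by this diff.

from typing import Any, Dict, Sequence


def _query_vlm_yes_no(camera_images: Dict[str, Any], question: str) -> bool:
    """Hypothetical wrapper: submit the current images plus a yes/no
    question to a vision-language model and parse the answer. The real
    call depends on whichever VLM client the project adopts."""
    raise NotImplementedError


def _on_classifier_vlm(state: Any, objects: Sequence[Any]) -> bool:
    """Sketch of the `currently_visible` branch of _on_classifier."""
    obj_on, obj_surface = objects
    # Phrase the On(obj_on, obj_surface) atom as a natural-language query.
    question = (f"Is the {obj_on.name} resting on top of "
                f"the {obj_surface.name}?")
    return _query_vlm_yes_no(state.camera_images, question)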
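
The `_VLM_EVAL_PREDICATES` set added at the bottom suggests membership-based routing at atom-evaluation time. A sketch of how that dispatch could look, assuming the evaluator falls back to the existing geometric classifiers when the flag is off (the dispatch site itself is not part of this diff):

def _evaluate_predicate(predicate, objects, state) -> bool:
    """Hypothetical dispatch: send VLM-perceptible predicates down the
    VLM path when enabled; otherwise use the geometric classifier."""
    if CFG.spot_vlm_eval_predicate and predicate in _VLM_EVAL_PREDICATES:
        # e.g. _on_classifier_vlm above, or an _inside_classifier analogue
        raise NotImplementedError
    return predicate.holds(state, objects)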