diff --git a/predicators/envs/spot_env.py b/predicators/envs/spot_env.py
index 98c431bb22..19c9b7fa50 100644
--- a/predicators/envs/spot_env.py
+++ b/predicators/envs/spot_env.py
@@ -7,7 +7,7 @@
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Callable, ClassVar, Collection, Dict, Iterator, List, \
-    Optional, Sequence, Set, Tuple, Any
+    Optional, Sequence, Set, Tuple
 
 import PIL.Image
 import matplotlib
@@ -22,6 +22,7 @@
 
 from predicators import utils
 from predicators.envs import BaseEnv
+from predicators.pretrained_model_interface import OpenAIVLM
 from predicators.settings import CFG
 from predicators.spot_utils.perception.object_detection import \
     AprilTagObjectDetectionID, KnownStaticObjectDetectionID, \
@@ -49,7 +50,6 @@
 from predicators.structs import Action, EnvironmentTask, GoalDescription, \
     GroundAtom, LiftedAtom, Object, Observation, Predicate, \
     SpotActionExtraInfo, State, STRIPSOperator, Type, Variable
-from predicators.pretrained_model_interface import OpenAIVLM
 
 ###############################################################################
 #                                Base Class                                   #
@@ -96,12 +96,7 @@ class _PartialPerceptionState(State):
     in the classifier definitions for the dummy predicates
     """
 
-    # # DEBUG Add an additional field to store Spot images
-    # # This would be directly copied from the images in raw Observation
-    # # NOTE: This is only used when using VLM for predicate evaluation
-    # # NOTE: Performance aspect should be considered later
-    # cam_images: Optional[Dict[str, RGBDImageWithContext]] = None
-    # # TODO: it's still unclear how we select and store useful images!
+    # obs_images: Optional[Dict[str, RGBDImageWithContext]] = None
 
     @property
     def _simulator_state_predicates(self) -> Set[Predicate]:
@@ -1071,11 +1066,11 @@ def _generate_goal_description(self) -> GoalDescription:
 
 def vlm_predicate_classify(question: str, state: State) -> bool:
     """Use VLM to evaluate (classify) a predicate in a given state."""
-    full_prompt = vlm_predicate_eval_prompt_prefix.format(
-        question=question
-    )
+    full_prompt = vlm_predicate_eval_prompt_prefix.format(question=question)
     images_dict: Dict[str, RGBDImageWithContext] = state.camera_images
-    images = [PIL.Image.fromarray(v.rotated_rgb) for _, v in images_dict.items()]
+    images = [
+        PIL.Image.fromarray(v.rotated_rgb) for _, v in images_dict.items()
+    ]
 
     logging.info(f"VLM predicate evaluation for: {question}")
     logging.info(f"Prompt: {full_prompt}")
@@ -1095,7 +1090,8 @@ def vlm_predicate_classify(question: str, state: State) -> bool:
     elif vlm_response == "no":
         return False
     else:
-        logging.error(f"VLM response not understood: {vlm_response}. Treat as False.")
+        logging.error(
+            f"VLM response not understood: {vlm_response}. Treat as False.")
         return False
 
 
@@ -1197,7 +1193,8 @@ def _on_classifier(state: State, objects: Sequence[Object]) -> bool:
     else:
         # Check that the bottom of the object is close to the top of the
         # surface.
-        expect = state.get(obj_surface, "z") + state.get(obj_surface, "height") / 2
+        expect = state.get(obj_surface,
+                           "z") + state.get(obj_surface, "height") / 2
         actual = state.get(obj_on, "z") - state.get(obj_on, "height") / 2
         classification_val = abs(actual - expect) < _ONTOP_Z_THRESHOLD
 
@@ -1575,7 +1572,8 @@ def _get_sweeping_surface_for_container(container: Object,
 # is VLM perceptible or not.
 # NOTE: candidates: on, inside, door opened, blocking, not blocked, ...
 _VLM_EVAL_PREDICATES: {
-    _On, _Inside,
+    _On,
+    _Inside,
 }
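
Usage note (not part of the patch): a minimal sketch of how a predicate classifier could delegate to the new vlm_predicate_classify() helper. The classifier name and the question wording below are assumptions; only the helper's signature and the State/Object types come from the diff above.

    from typing import Sequence

    from predicators.structs import Object, State


    def _vlm_inside_classifier(state: State,
                               objects: Sequence[Object]) -> bool:
        """Hypothetical VLM-backed classifier for an Inside-style predicate."""
        obj_in, obj_container = objects
        # Phrase the predicate as a yes/no question about the named objects;
        # vlm_predicate_classify() pulls the camera images from `state` itself.
        question = f"Is {obj_in.name} inside {obj_container.name}?"
        return vlm_predicate_classify(question, state)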