working; implement On predicate with VLM classifier pipeline! add call VLM and more
lf-zhao committed May 1, 2024
1 parent 1abf488 commit 1c82c44
Showing 1 changed file with 110 additions and 29 deletions.
139 changes: 110 additions & 29 deletions predicators/envs/spot_env.py
@@ -9,6 +9,7 @@
from typing import Any, Callable, ClassVar, Collection, Dict, Iterator, \
List, Optional, Sequence, Set, Tuple

import PIL.Image
import matplotlib
import numpy as np
import pbrspot
@@ -48,6 +49,7 @@
from predicators.structs import Action, EnvironmentTask, GoalDescription, \
GroundAtom, LiftedAtom, Object, Observation, Predicate, \
SpotActionExtraInfo, State, STRIPSOperator, Type, Variable
from predicators.pretrained_model_interface import OpenAIVLM

###############################################################################
# Base Class #
@@ -1035,6 +1037,39 @@ def _generate_goal_description(self) -> GoalDescription:
"""For now, we assume that there's only one goal per environment."""


###############################################################################
# VLM Predicate Evaluation Related #
###############################################################################

# Initialize VLM
vlm = OpenAIVLM(model_name="gpt-4-turbo", detail="auto")

# Prompt prefix (template) for VLM predicate evaluation
vlm_predicate_eval_prompt_prefix = """
Your goal is to answer questions related to object relationships in the
given image(s).
We will use the following predicate-style descriptions to ask questions:
Inside(object1, container)
Blocking(object1, object2)
On(object, surface)
Examples:
Does this predicate hold in the following image?
Inside(apple, bowl)
Answer (in a single word): Yes/No
Actual question:
Does this predicate hold in the following image?
{question}
Answer (in a single word):
"""

# Placeholder for visual few-shot examples (currently unused)
vlm_predicate_eval_prompt_example = ""

# TODO: Next, try including visual hints via segmentation ("Set of Masks")
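# ("Set of Masks" sketch, an assumption about the intended approach: overlay
# numbered segmentation masks on each image so the VLM can reference objects
# unambiguously, e.g. "Does On(object [1], surface [2]) hold?")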


###############################################################################
# Shared Types, Predicates, Operators #
###############################################################################
@@ -1117,14 +1152,46 @@ def _on_classifier(state: State, objects: Sequence[Object]) -> bool:
obj_on, obj_surface = objects
currently_visible = all(o in state.visible_objects for o in objects)

logging.debug(f"On classifier: all objects visible = {currently_visible}")

# If not all objects are visible and the VLM option is enabled, fall back
# to the predicate values from the previous time step.
if CFG.spot_vlm_eval_predicate and not currently_visible:
# TODO: add all previous atoms to the state
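# One possible approach (sketch, not implemented in this commit): keep a
# module-level cache of the last evaluated value per ground atom, updated
# after each successful VLM query, and reuse it while objects are hidden:
#   _last_vlm_atom_values: Dict[Tuple[str, str, str], bool] = {}  # hypothetical
#   key = ("On", obj_on.name, obj_surface.name)
#   if key in _last_vlm_atom_values:
#       return _last_vlm_atom_values[key]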
raise NotImplementedError

# Otherwise, call the VLM to evaluate the predicate value.
elif CFG.spot_vlm_eval_predicate and currently_visible:
predicate_str = f"On({obj_on.name}, {obj_surface.name})"
full_prompt = vlm_predicate_eval_prompt_prefix.format(
question=predicate_str
)

images_dict: Dict[str, RGBDImageWithContext] = state.camera_images
images = [PIL.Image.fromarray(v.rotated_rgb) for v in images_dict.values()]

# Logging: prompt
logging.info(f"VLM predicate evaluation for: {predicate_str}")
logging.info(f"Prompt: {full_prompt}")

vlm_responses = vlm.sample_completions(
prompt=full_prompt,
imgs=images,
temperature=0.2,
seed=int(time.time()),
num_completions=1,
)

# Logging
logging.info(f"VLM response 0: {vlm_responses[0]}")

vlm_response = vlm_responses[0].strip().lower()
if vlm_response == "yes":
return True
elif vlm_response == "no":
return False
else:
logging.error(f"VLM response not understood: {vlm_response}. Treat as False.")
return False
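# Note (suggestion, not in the original commit): replies like "Yes." or
# "No," would currently hit the error branch; normalizing first, e.g.
# vlm_responses[0].strip().lower().rstrip(".,!"), would be more forgiving.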

else:
# Check that the bottom of the object is close to the top of the surface.
expect = state.get(obj_surface, "z") + state.get(obj_surface, "height") / 2
@@ -1154,35 +1221,49 @@ def _inside_classifier(state: State, objects: Sequence[Object]) -> bool:

logging.debug(f"Inside classifier: all objects visible = {currently_visible}")

# if CFG.spot_vlm_eval_predicate and not currently_visible:
# # TODO: add all previous atoms to the state
# # TODO: then we just use the atom value from the last state
# raise NotImplementedError
# elif CFG.spot_vlm_eval_predicate and currently_visible:
# # TODO call VLM to evaluate predicate value
# full_prompt = vlm_predicate_eval_prompt_prefix.format(
# question=f"Inside({obj_in}, {obj_container})"
# )
# images = state.camera_images
#
# vlm_responses = vlm.sample_completions(
# prompt=full_prompt,
# imgs=images,
# temperature=0.2,
# seed=int(time.time()),
# num_completions=1,
# )
# vlm_response = vlm_responses[0].strip().lower()
# raise NotImplementedError
#
# else:

if not _object_in_xy_classifier(
state, obj_in, obj_container, buffer=_INSIDE_SURFACE_BUFFER):
return False

obj_z = state.get(obj_in, "z")
obj_half_height = state.get(obj_in, "height") / 2
obj_bottom = obj_z - obj_half_height
obj_top = obj_z + obj_half_height

container_z = state.get(obj_container, "z")
container_half_height = state.get(obj_container, "height") / 2
container_bottom = container_z - container_half_height
container_top = container_z + container_half_height

# Check that the bottom is "above" the bottom of the container.
if obj_bottom < container_bottom - _INSIDE_Z_THRESHOLD:
return False

# Check that the top is "below" the top of the container.
return obj_top < container_top + _INSIDE_Z_THRESHOLD


def _not_inside_any_container_classifier(state: State,
