
Commit

fixed llm
wmcclinton committed Aug 18, 2023
1 parent 791eb66 commit 838e4c2
Showing 5 changed files with 221 additions and 9 deletions.
@@ -0,0 +1,119 @@
{
"objects": {
"hammer": "tool",
"floor": "floor",
"low_wall_rack": "flat_surface",
"extra_room_table": "flat_surface",
"bucket": "bag",
"tool_room_table": "flat_surface",
"brush": "tool",
"high_wall_rack": "flat_surface",
"drill": "tool",
"toolbag": "bag",
"measuring_tape": "tool",
"platform": "platform",
"spot": "robot",
"soda_can": "tool",
"umbrella": "tool",
"work_room_table": "flat_surface"
},
"init": {
"hammer": {
"x": 9.76516329018639,
"y": -6.822070521356679,
"z": 0.5556237073192688,
"lost": 0.0,
"in_view": 1.0
},
"floor": {
"x": 0.0,
"y": 0.0,
"z": -1.0
},
"low_wall_rack": {
"x": 9.849155913120516,
"y": -6.983181743565726,
"z": 0.24044494963648955
},
"extra_room_table": {
"x": 8.649433068213108,
"y": -6.229963443684359,
"z": -0.04515858591900626
},
"bucket": {
"x": 7.13679051790645,
"y": -8.26595972649328,
"z": -0.2007359283710973
},
"tool_room_table": {
"x": 6.429854722858522,
"y": -6.345642880614663,
"z": 0.11198430745302461
},
"brush": {
"x": 6.434172550103909,
"y": -6.106715248379663,
"z": 0.189010690666345,
"lost": 0.0,
"in_view": 1.0
},
"high_wall_rack": {
"x": 9.946591808280566,
"y": -7.326505639775886,
"z": 0.9000969451682522
},
"drill": {
"x": 9.75439324480222,
"y": -7.160612462503434,
"z": 1.116727639997601,
"lost": 0.0,
"in_view": 1.0
},
"toolbag": {
"x": 7.867815591155479,
"y": -5.907131434232858,
"z": -0.22990444589247733
},
"measuring_tape": {
"x": 8.746071606743353,
"y": -6.369464294717586,
"z": -0.38775082324964677,
"lost": 0.0,
"in_view": 1.0
},
"platform": {
"x": 8.700455481253051,
"y": -7.90716092394469,
"z": -0.07673487813571947,
"lost": 0.0,
"in_view": 1.0
},
"spot": {
"gripper_open_percentage": 1.0372281074523926,
"curr_held_item_id": 0,
"x": 8.291213884735656,
"y": -6.993276021353255,
"z": 0.14816109444794257,
"yaw": 0.029029069289615964
},
"soda_can": {
"x": 6.429854722858522,
"y": -6.565642880614663,
"z": -0.38775082324964677,
"lost": 0.0,
"in_view": 1.0
},
"work_room_table": {
"x": 6.429854722858522,
"y": -6.345642880614663,
"z": 0.11198430745302461
},
"umbrella": {
"x": 9.75439324480222,
"y": -7.160612462503434,
"z": 0.5556237073192688,
"lost": 0.0,
"in_view": 1.0
}
}
}
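The JSON above is the task-spec format consumed by _load_task_from_json below: an "objects" map from object name to type name, plus an "init" map from object name to feature values (x/y/z pose, lost, in_view, and the robot's gripper and yaw features). A minimal sketch of reading such a file (the function name load_task_spec is illustrative, not part of this commit):

import json
from pathlib import Path
from typing import Dict, Tuple

def load_task_spec(
        json_file: Path
) -> Tuple[Dict[str, str], Dict[str, Dict[str, float]]]:
    """Illustrative reader for the task JSON above (not repository code)."""
    with open(json_file, "r", encoding="utf-8") as f:
        spec = json.load(f)
    objects = spec["objects"]  # object name -> type name
    init = spec["init"]  # object name -> {feature name: value}
    # Every object that has initial features should also have a declared type.
    assert set(init).issubset(set(objects))
    return objects, init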
1 change: 1 addition & 0 deletions predicators/envs/base_env.py
@@ -313,6 +313,7 @@ def _parse_language_goal_from_json(
        # to handle various errors and perhaps query the LLM for multiple
        # responses until we find one that can be parsed.
        goal_spec = json.loads(response)
        print(response)
        return self._parse_goal_from_json(goal_spec, id_to_obj)

    def _parse_goal_from_input_to_json(
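The comment above notes that this call should eventually handle parse errors and re-query the LLM; a hedged sketch of such a retry loop, where query_llm stands in for however the environment queries its LLM (not part of this commit):

import json
from typing import Any, Callable, Dict, Optional

def parse_goal_with_retries(query_llm: Callable[[str], str], prompt: str,
                            max_tries: int = 3) -> Dict[str, Any]:
    """Re-query the LLM until one response parses as JSON (illustrative)."""
    last_error: Optional[Exception] = None
    for _ in range(max_tries):
        response = query_llm(prompt)
        try:
            return json.loads(response)
        except json.JSONDecodeError as exc:
            last_error = exc  # Ask again with a fresh completion.
    raise ValueError(f"No parseable response after {max_tries} tries") from last_error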
95 changes: 91 additions & 4 deletions predicators/envs/spot_env.py
@@ -14,7 +14,7 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, ClassVar, Dict, List, Optional, Sequence, Set, \
    Tuple
    Tuple, Collection

import matplotlib
import numpy as np
@@ -28,7 +28,7 @@
    get_spot_interface, obj_name_to_apriltag_id
from predicators.structs import Action, Array, EnvironmentTask, GroundAtom, \
    Image, LiftedAtom, Object, Observation, Predicate, State, STRIPSOperator, \
    Type, Variable
    Type, Variable, Task

###############################################################################
# Base Class #
@@ -527,8 +527,66 @@ def render_state_plt(
            action: Optional[Action] = None,
            caption: Optional[str] = None) -> matplotlib.figure.Figure:
        raise NotImplementedError("This env does not use Matplotlib")

    def _get_language_goal_prompt_prefix(self,
                                         object_names: Collection[str]) -> str:
        # pylint:disable=line-too-long
        available_predicates = ", ".join(
            [p.name for p in sorted(self.goal_predicates)])
        available_objects = ", ".join(sorted(object_names))
        # # We could extract the object names, but this is simpler.
        # assert {"spot", "counter", "snack_table",
        #         "soda_can"}.issubset(object_names)
        prompt = f"""# The available predicates are: {available_predicates}
# The available objects are: {available_objects}
# Use the available predicates and objects to convert natural language goals into PDDL JSON goals.
# (eg. {{"On": [["apple", "snack_table"]]}})
"""
        return prompt
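To show how this prefix is meant to be used (illustrative only; the object list and goal string below are made up, and the exact concatenation with the language goal happens upstream in base_env.py):

# Hypothetical usage of the prompt prefix above; `env` is assumed to be a
# SpotEnv instance.
object_names = ["spot", "soda_can", "extra_room_table", "bucket"]
prefix = env._get_language_goal_prompt_prefix(object_names)
language_goal = "Put the soda can on the extra room table"
full_prompt = prefix + f"# Goal: {language_goal}\n"
# The LLM is then expected to answer with PDDL-style JSON such as:
#   {"On": [["soda_can", "extra_room_table"]]}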

    def _parse_init_preds_from_json(
            self, spec: Dict[str, List[List[str]]],
            id_to_obj: Dict[str, Object]) -> Set[GroundAtom]:
        """Helper for parsing init preds from JSON task specifications."""
        pred_names = {p.name for p in self.predicates}
        assert set(spec.keys()).issubset(pred_names)
        pred_to_args = {p: spec.get(p.name, []) for p in self.predicates}
        init_preds: Set[GroundAtom] = set()
        for pred, args in pred_to_args.items():
            for id_args in args:
                obj_args = [id_to_obj[a] for a in id_args]
                init_atom = GroundAtom(pred, obj_args)
                init_preds.add(init_atom)
        return init_preds
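For reference, a hypothetical spec this helper would accept (the predicate and object names are illustrative, and object_name_to_object maps names to Object instances):

spec = {"On": [["soda_can", "extra_room_table"], ["hammer", "low_wall_rack"]]}
init_atoms = env._parse_init_preds_from_json(spec, object_name_to_object)
# init_atoms now contains GroundAtom(On, [soda_can, extra_room_table]), etc.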

    def _load_task_from_json(self, json_file: Path) -> Task:
        """Create a task from a JSON file.
        By default, we assume JSON files are in the following format:
        {
            "objects": {
                <object name>: <type name>
            }
            "init": {
                <object name>: {
                    <feature name>: <value>
                }
            }
            "goal": {
                <predicate name> : [
                    [<object name>]
                ]
            }
        }
    def _load_task_from_json(self, json_file: Path) -> EnvironmentTask:
        Instead of "goal", "language_goal" can also be used.
        Environments can override this method to handle different formats.
        """
        with open(json_file, "r", encoding="utf-8") as f:
            json_dict = json.load(f)
        ########
        # Use the BaseEnv default code for loading from JSON, which will
        # create a State as an observation. We'll then convert that State
        # into a _SpotObservation instead.
@@ -565,6 +623,26 @@ def _load_task_from_json(self, json_file: Path) -> EnvironmentTask:
        )
        # The goal can remain the same.
        goal = base_env_task.goal
        task = EnvironmentTask(init_obs, goal)
        ########
        object_name_to_object: Dict[str, Object] = {}
        json_dict["init"] = init_obs
        for obj in init:
            object_name_to_object[obj.name] = obj
        # TODO make flag
        print(f"\n{object_name_to_object}\n")
        json_dict['language_goal'] = input("\n[ChatGPT-Spot] What do you need from me?\n\n>> ")
        print(json_dict)
        ########

        # Parse goal.
        if "goal" in json_dict:
            goal = self._parse_goal_from_json(json_dict["goal"],
                                              object_name_to_object)
        else:
            assert "language_goal" in json_dict
            goal = self._parse_language_goal_from_json(
                json_dict["language_goal"], object_name_to_object)
        return EnvironmentTask(init_obs, goal)
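Put together, a task file for this override can carry either a symbolic "goal" or a "language_goal"; a sketch of what json.load would produce for the latter (the values are made up, and note that the code above replaces any "language_goal" with whatever is typed at the [ChatGPT-Spot] prompt):

# Hypothetical parsed task spec; either "goal" or "language_goal" may appear.
json_dict = {
    "objects": {"soda_can": "tool", "extra_room_table": "flat_surface"},
    "init": {"soda_can": {"x": 6.4, "y": -6.6, "z": -0.4}},
    "language_goal": "Put the soda can on the extra room table",
}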


@@ -1200,14 +1278,18 @@ def _make_object_name_to_obj_dict(self) -> Dict[str, Object]:
        spot = Object("spot", self._robot_type)
        tool_room_table = Object("tool_room_table", self._surface_type)
        extra_room_table = Object("extra_room_table", self._surface_type)
        work_room_table = Object("work_room_table", self._surface_type)
        soda_can = Object("soda_can", self._tool_type)
        umbrella = Object("umbrella", self._tool_type)
        low_wall_rack = Object("low_wall_rack", self._surface_type)
        high_wall_rack = Object("high_wall_rack", self._surface_type)
        bucket = Object("bucket", self._bag_type)
        toolbag = Object("toolbag", self._bag_type)
        floor = Object("floor", self._floor_type)
        objects.extend([
            spot, tool_room_table, low_wall_rack, high_wall_rack, bucket,
            extra_room_table, floor, toolbag
            extra_room_table, floor, toolbag, work_room_table, soda_can,
            umbrella
        ])
        return {o.name: o for o in objects}

@@ -1216,6 +1298,11 @@ def _obj_name_to_obj(self, obj_name: str) -> Object:

    @property
    def goal_predicates(self) -> Set[Predicate]:
        goal_preds = set()
        for pred in self.predicates:
            if "Reachable" not in pred.name:
                goal_preds.add(pred)

        return self.predicates

    def _actively_construct_initial_object_views(
5 changes: 2 additions & 3 deletions predicators/spot_utils/perception_utils.py
@@ -578,11 +578,10 @@ def _run_offline_analysis() -> None:
    # Convenient script for identifying which classes might be best for a
    # group of images that all have the same object. The images should still
    # be manually inspected (in the debug dir).
    class_candidates = ["brown bag for tools", "carrybag"]
    class_candidates = ["soda can"]
    # pylint:disable=line-too-long
    files = [
        "20230818-131813_detic_sam_right_fisheye_image_object_locs_inputs.png",
        "20230818-131814_detic_sam_frontright_fisheye_image_object_locs_inputs.png"
        "20230818-161904_detic_sam_hand_color_image_object_locs_inputs.png"
    ]
    root_dir = Path(__file__).parent / "../.."
    utils.reset_config({
10 changes: 8 additions & 2 deletions predicators/spot_utils/spot_utils.py
@@ -70,6 +70,7 @@ def get_memorized_waypoint(obj_name: str) -> Optional[Tuple[str, Array]]:
        "low_wall_rack": "alight-coyote-Nvl0i02Mk7Ds8ax0sj0Hsw==",
        "high_wall_rack": "alight-coyote-Nvl0i02Mk7Ds8ax0sj0Hsw==",
        "extra_room_table": "alight-coyote-Nvl0i02Mk7Ds8ax0sj0Hsw==",
        "work_room_table": "ranked-oxen-G0kq38CpHN7H7R.0FCm7DA=="
    }
    offsets = {"extra_room_table": np.array([0.0, -0.3, np.pi / 2])}
    if obj_name not in graph_nav_loc_to_id:
@@ -92,15 +93,20 @@ def get_memorized_waypoint(obj_name: str) -> Optional[Tuple[str, Array]]:
    "platform": 411,
    "high_wall_rack": 412,
    "drill": 413,
    "toolbag": 414
    "toolbag": 415,
    "work_room_table": 414,
    "soda_can": 416,
    "umbrella": 417,
}
obj_name_to_vision_prompt = {
    "hammer": "red hammer tool",
    "brush": "brush",
    "measuring_tape": "small yellow measuring tape",
    "bucket": "bucket",
    "drill": "blue drill",
    "toolbag": "carrybag"
    "toolbag": "carrybag",
    "soda_can": "soda can",
    "umbrella": "umbrella"
}
vision_prompt_to_obj_name = {
    value: key
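One note on the prompt mappings above: the value-to-key inversion in vision_prompt_to_obj_name silently drops entries if two objects ever share a vision prompt, so a uniqueness check like the following sketch (not part of this commit) can be useful:

from typing import Dict

def check_prompts_unique(obj_name_to_vision_prompt: Dict[str, str]) -> None:
    """Assert that no two objects map to the same vision prompt."""
    prompts = list(obj_name_to_vision_prompt.values())
    duplicates = {p for p in prompts if prompts.count(p) > 1}
    assert not duplicates, f"Duplicate vision prompts: {duplicates}"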
