Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
lf-zhao committed May 11, 2024
1 parent 5ebd449 commit dc6ae7b
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 25 deletions.
28 changes: 13 additions & 15 deletions predicators/envs/spot_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
from bosdyn.client.sdk import Robot
from bosdyn.client.util import authenticate, setup_logging
from gym.spaces import Box
from predicators.utils import log_rich_table
from scipy.spatial import Delaunay
from rich.table import Table
from rich import print
from rich.table import Table
from scipy.spatial import Delaunay

from predicators import utils
from predicators.envs import BaseEnv
Expand Down Expand Up @@ -55,6 +54,7 @@
GroundAtom, LiftedAtom, Object, Observation, Predicate, \
SpotActionExtraInfo, State, STRIPSOperator, Type, Variable, \
VLMGroundAtom, VLMPredicate
from predicators.utils import log_rich_table

###############################################################################
# Base Class #
Expand Down Expand Up @@ -757,11 +757,11 @@ def _build_realworld_observation(
objects = list(all_objects_in_view.keys())
vlm_atoms = get_vlm_atom_combinations(objects, vlm_predicates)
vlm_atom_new: Dict[VLMGroundAtom,
bool or None] = vlm_predicate_batch_classify(
vlm_atoms,
rgbds,
predicates=vlm_predicates,
get_dict=True)
bool or None] = vlm_predicate_batch_classify(
vlm_atoms,
rgbds,
predicates=vlm_predicates,
get_dict=True)

# Update VLM atom value if the new ground atom value is not None
# Otherwise, use the value in current obs
Expand All @@ -785,19 +785,17 @@ def _build_realworld_observation(
table_compare.add_column("Atom", style="cyan")
table_compare.add_column("Value (Last)", style="blue")
table_compare.add_column("Value (New)", style="magenta")
vlm_atom_union = set(vlm_atom_new.keys()) | set(curr_obs.vlm_atom_dict.keys())
vlm_atom_union = set(vlm_atom_new.keys()) | set(
curr_obs.vlm_atom_dict.keys())
for atom in vlm_atom_union:
table_compare.add_row(
str(atom),
str(curr_obs.vlm_atom_dict.get(atom, None)),
str(vlm_atom_new.get(atom, None))
)
str(atom), str(curr_obs.vlm_atom_dict.get(atom, None)),
str(vlm_atom_new.get(atom, None)))
logging.info(log_rich_table(table_compare))

logging.info(
f"True VLM atoms (after updated with current obs): "
f"{dict(filter(lambda it: it[1], vlm_atom_return.items()))}"
)
f"{dict(filter(lambda it: it[1], vlm_atom_return.items()))}")

else:
vlm_predicates = set()
Expand Down
4 changes: 3 additions & 1 deletion predicators/spot_utils/perception/object_perception.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,9 @@ def vlm_predicate_batch_classify(
if len(queries) == 0:
return {}

queries_print = [atom.get_query_str(include_prompt=False) for atom in atoms]
queries_print = [
atom.get_query_str(include_prompt=False) for atom in atoms
]
logging.info(f"VLM predicate evaluation queries: {queries_print}")

# Call VLM to evaluate the queries
Expand Down
6 changes: 2 additions & 4 deletions predicators/spot_utils/skills/spot_find_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,8 @@ def _find_objects_with_choreographed_moves(

# Logging
print(f"Calculated VLM atoms (in all views): {dict(all_vlm_atom_dict)}")
print(
f"True VLM atoms (in all views; with values as True): "
f"{dict(filter(lambda it: it[1], all_vlm_atom_dict.items()))}"
)
print(f"True VLM atoms (in all views; with values as True): "
f"{dict(filter(lambda it: it[1], all_vlm_atom_dict.items()))}")

table = Table(title="Evaluated VLM atoms (in all views)")
table.add_column("Atom", style="cyan")
Expand Down
4 changes: 3 additions & 1 deletion predicators/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,9 @@ class VLMGroundAtom(GroundAtom):
# NOTE: This subclasses GroundAtom to support VLM predicates and classifiers
predicate: VLMPredicate

def get_query_str(self, without_type: bool = False, include_prompt: bool = True) -> str:
def get_query_str(self,
without_type: bool = False,
include_prompt: bool = True) -> str:
"""Get a query string for this ground atom.
Instead of directly evaluating the ground atom, we will use the
Expand Down
9 changes: 5 additions & 4 deletions predicators/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,17 @@
import matplotlib.pyplot as plt
import numpy as np
import pathos.multiprocessing as mp
import rich.table
from bosdyn.client import math_helpers
from gym.spaces import Box
from matplotlib import patches
from numpy.typing import NDArray
from pyperplan.heuristics.heuristic_base import \
Heuristic as _PyperplanBaseHeuristic
from pyperplan.planner import HEURISTICS as _PYPERPLAN_HEURISTICS
from scipy.stats import beta as BetaRV
import rich.table
from rich.console import Console
from rich.text import Text
from scipy.stats import beta as BetaRV

from predicators.args import create_arg_parser
from predicators.pybullet_helpers.joint import JointPositions
Expand Down Expand Up @@ -3869,8 +3869,9 @@ def run_ground_nsrt_with_assertions(ground_nsrt: _GroundNSRT,


def log_rich_table(rich_table: rich.table.Table) -> "Texssas":
"""Generate an ascii formatted presentation of a Rich table
Eliminates any column styling
"""Generate an ascii formatted presentation of a Rich table.
Eliminates any column styling.
"""
console = Console(width=150)
with console.capture() as capture:
Expand Down

0 comments on commit dc6ae7b

Please sign in to comment.