From dc6ae7b3f997b590b036a625f4d3a6d8064d23f2 Mon Sep 17 00:00:00 2001 From: Linfeng Date: Sat, 11 May 2024 13:41:21 -0400 Subject: [PATCH] formatting --- predicators/envs/spot_env.py | 28 +++++++++---------- .../perception/object_perception.py | 4 ++- .../spot_utils/skills/spot_find_objects.py | 6 ++-- predicators/structs.py | 4 ++- predicators/utils.py | 9 +++--- 5 files changed, 26 insertions(+), 25 deletions(-) diff --git a/predicators/envs/spot_env.py b/predicators/envs/spot_env.py index 4942b5fb2f..6f2d3f7fbc 100644 --- a/predicators/envs/spot_env.py +++ b/predicators/envs/spot_env.py @@ -18,10 +18,9 @@ from bosdyn.client.sdk import Robot from bosdyn.client.util import authenticate, setup_logging from gym.spaces import Box -from predicators.utils import log_rich_table -from scipy.spatial import Delaunay -from rich.table import Table from rich import print +from rich.table import Table +from scipy.spatial import Delaunay from predicators import utils from predicators.envs import BaseEnv @@ -55,6 +54,7 @@ GroundAtom, LiftedAtom, Object, Observation, Predicate, \ SpotActionExtraInfo, State, STRIPSOperator, Type, Variable, \ VLMGroundAtom, VLMPredicate +from predicators.utils import log_rich_table ############################################################################### # Base Class # @@ -757,11 +757,11 @@ def _build_realworld_observation( objects = list(all_objects_in_view.keys()) vlm_atoms = get_vlm_atom_combinations(objects, vlm_predicates) vlm_atom_new: Dict[VLMGroundAtom, - bool or None] = vlm_predicate_batch_classify( - vlm_atoms, - rgbds, - predicates=vlm_predicates, - get_dict=True) + bool or None] = vlm_predicate_batch_classify( + vlm_atoms, + rgbds, + predicates=vlm_predicates, + get_dict=True) # Update VLM atom value if the new ground atom value is not None # Otherwise, use the value in current obs @@ -785,19 +785,17 @@ def _build_realworld_observation( table_compare.add_column("Atom", style="cyan") table_compare.add_column("Value (Last)", style="blue") table_compare.add_column("Value (New)", style="magenta") - vlm_atom_union = set(vlm_atom_new.keys()) | set(curr_obs.vlm_atom_dict.keys()) + vlm_atom_union = set(vlm_atom_new.keys()) | set( + curr_obs.vlm_atom_dict.keys()) for atom in vlm_atom_union: table_compare.add_row( - str(atom), - str(curr_obs.vlm_atom_dict.get(atom, None)), - str(vlm_atom_new.get(atom, None)) - ) + str(atom), str(curr_obs.vlm_atom_dict.get(atom, None)), + str(vlm_atom_new.get(atom, None))) logging.info(log_rich_table(table_compare)) logging.info( f"True VLM atoms (after updated with current obs): " - f"{dict(filter(lambda it: it[1], vlm_atom_return.items()))}" - ) + f"{dict(filter(lambda it: it[1], vlm_atom_return.items()))}") else: vlm_predicates = set() diff --git a/predicators/spot_utils/perception/object_perception.py b/predicators/spot_utils/perception/object_perception.py index 0ee8a5fb8b..66bd7549f7 100644 --- a/predicators/spot_utils/perception/object_perception.py +++ b/predicators/spot_utils/perception/object_perception.py @@ -190,7 +190,9 @@ def vlm_predicate_batch_classify( if len(queries) == 0: return {} - queries_print = [atom.get_query_str(include_prompt=False) for atom in atoms] + queries_print = [ + atom.get_query_str(include_prompt=False) for atom in atoms + ] logging.info(f"VLM predicate evaluation queries: {queries_print}") # Call VLM to evaluate the queries diff --git a/predicators/spot_utils/skills/spot_find_objects.py b/predicators/spot_utils/skills/spot_find_objects.py index ac0801228e..f6c1286767 100644 --- a/predicators/spot_utils/skills/spot_find_objects.py +++ b/predicators/spot_utils/skills/spot_find_objects.py @@ -120,10 +120,8 @@ def _find_objects_with_choreographed_moves( # Logging print(f"Calculated VLM atoms (in all views): {dict(all_vlm_atom_dict)}") - print( - f"True VLM atoms (in all views; with values as True): " - f"{dict(filter(lambda it: it[1], all_vlm_atom_dict.items()))}" - ) + print(f"True VLM atoms (in all views; with values as True): " + f"{dict(filter(lambda it: it[1], all_vlm_atom_dict.items()))}") table = Table(title="Evaluated VLM atoms (in all views)") table.add_column("Atom", style="cyan") diff --git a/predicators/structs.py b/predicators/structs.py index 688ab03f1a..ffc432bdf7 100644 --- a/predicators/structs.py +++ b/predicators/structs.py @@ -472,7 +472,9 @@ class VLMGroundAtom(GroundAtom): # NOTE: This subclasses GroundAtom to support VLM predicates and classifiers predicate: VLMPredicate - def get_query_str(self, without_type: bool = False, include_prompt: bool = True) -> str: + def get_query_str(self, + without_type: bool = False, + include_prompt: bool = True) -> str: """Get a query string for this ground atom. Instead of directly evaluating the ground atom, we will use the diff --git a/predicators/utils.py b/predicators/utils.py index 9964c9e5b7..5056de1f91 100644 --- a/predicators/utils.py +++ b/predicators/utils.py @@ -42,6 +42,7 @@ import matplotlib.pyplot as plt import numpy as np import pathos.multiprocessing as mp +import rich.table from bosdyn.client import math_helpers from gym.spaces import Box from matplotlib import patches @@ -49,10 +50,9 @@ from pyperplan.heuristics.heuristic_base import \ Heuristic as _PyperplanBaseHeuristic from pyperplan.planner import HEURISTICS as _PYPERPLAN_HEURISTICS -from scipy.stats import beta as BetaRV -import rich.table from rich.console import Console from rich.text import Text +from scipy.stats import beta as BetaRV from predicators.args import create_arg_parser from predicators.pybullet_helpers.joint import JointPositions @@ -3869,8 +3869,9 @@ def run_ground_nsrt_with_assertions(ground_nsrt: _GroundNSRT, def log_rich_table(rich_table: rich.table.Table) -> "Texssas": - """Generate an ascii formatted presentation of a Rich table - Eliminates any column styling + """Generate an ascii formatted presentation of a Rich table. + + Eliminates any column styling. """ console = Console(width=150) with console.capture() as capture: