Skip to content

Commit

Permalink
overwrite vlm predicate classifier; reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
lf-zhao committed May 9, 2024
1 parent 63a1429 commit ba2575f
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions predicators/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def allclose(self, other: State) -> bool:
"""Return whether this state is close enough to another one, i.e., its
objects are the same, and the features are close."""
if self.simulator_state is not None or \
other.simulator_state is not None:
other.simulator_state is not None:
raise NotImplementedError("Cannot use allclose when "
"simulator_state is not None.")

Expand All @@ -206,7 +206,7 @@ def pretty_str(self) -> str:
if obj.type not in type_to_table:
type_to_table[obj.type] = []
type_to_table[obj.type].append([obj.name] + \
list(map(str, self[obj])))
list(map(str, self[obj])))
table_strs = []
for t in sorted(type_to_table):
headers = ["type: " + t.name] + list(t.feature_names)
Expand Down Expand Up @@ -325,6 +325,8 @@ class VLMPredicate(Predicate):
at once.
"""

_classifier: Optional[Callable[[State, Sequence[Object]], bool]] = None

def holds(self, state: State, objects: Sequence[Object]) -> bool:
"""Public method for getting predicate value.
Expand Down Expand Up @@ -732,8 +734,8 @@ def pddl_str(self) -> str:
for i, t in enumerate(pred.types))
pred_eff_variables_str = " ".join(f"?x{i}"
for i in range(pred.arity))
effects_str += f"(forall ({pred_types_str})" +\
f" (not ({pred.name} {pred_eff_variables_str})))"
effects_str += f"(forall ({pred_types_str})" + \
f" (not ({pred.name} {pred_eff_variables_str})))"
effects_str += "\n "
return f"""(:action {self.name}
:parameters ({params_str})
Expand Down Expand Up @@ -1586,7 +1588,7 @@ def __post_init__(self) -> None:
# The preconditions and goal preconditions should only use variables in
# the rule parameters.
for atom in self.pos_state_preconditions | \
self.neg_state_preconditions | self.goal_preconditions:
self.neg_state_preconditions | self.goal_preconditions:
assert all(v in self.parameters for v in atom.variables)

@lru_cache(maxsize=None)
Expand Down

0 comments on commit ba2575f

Please sign in to comment.