diff --git a/macq/extract/slaf.py b/macq/extract/slaf.py index 40248460..2b730f2b 100644 --- a/macq/extract/slaf.py +++ b/macq/extract/slaf.py @@ -43,7 +43,7 @@ class SLAF: top = true bottom = false - def __new__(cls, o_list: ObservedTraceList, debug: bool = False): + def __new__(cls, o_list: ObservedTraceList, debug: bool = False, sample: bool = False): """Creates a new Model object. Args: @@ -52,6 +52,9 @@ def __new__(cls, o_list: ObservedTraceList, debug: bool = False): debug_mode (bool): An optional mode that helps the user track any fluents they desire by examining the evolution of their fluent-factored formulas through the steps. + sample (bool): + An optional mode that allows the user to sample the possible models instead of returning + one that includes only guaranteed entailed fluents. Raises: IncompatibleObservationToken: Raised if the observations are not identity observation. @@ -63,7 +66,7 @@ def __new__(cls, o_list: ObservedTraceList, debug: bool = False): raise Exception("The SLAF extraction technique only takes one trace.") SLAF.debug_mode = debug - entailed = SLAF.__as_strips_slaf(o_list) + entailed = SLAF.__as_strips_slaf(o_list, sample) # return the Model return SLAF.__sort_results(o_list, entailed) @@ -197,7 +200,9 @@ def __sort_results(observations: ObservedTraceList, entailed: Set): precond = info_split[0] action = info_split[1] # update the precondition of this action with the appropriate fluent - learned_actions[action].update_precond({precond}) + # only if it's a positive precondition + if "~" not in precond: + learned_actions[action].update_precond({precond}) # if this proposition holds information about an effect elif effect in e: # split to separate effect and action, get rid of extra brackets @@ -216,7 +221,7 @@ def __sort_results(observations: ObservedTraceList, entailed: Set): return Model(model_fluents, set(learned_actions.values())) @staticmethod - def __as_strips_slaf(o_list: ObservedTraceList): + def __as_strips_slaf(o_list: ObservedTraceList, sample: bool): """Implements the AS-STRIPS-SLAF algorithm from section 5.3 of the SLAF paper. Iterates through the action/observation pairs of each observation/trace, returning a fluent-factored transition belief formula that filters according to that action/observation. @@ -228,6 +233,9 @@ def __as_strips_slaf(o_list: ObservedTraceList): The list of observations/traces to apply the filtering algorithm to. NOTE: with the current implementation, SLAF only works with a single trace. + sample (bool): + If true, an arbitrary solution will be produced rather than just the entailed literals. + Returns: The set of fluents that are entailed. """ @@ -467,14 +475,22 @@ def __as_strips_slaf(o_list: ObservedTraceList): cnf_formula = And(map(SLAF.__or_refactor, full_formula.children)) entailed = set() - children = set(cnf_formula.children) - # iterate through all fluents, gathering those that are entailed - for f in all_var: - children.add(Or([~f])) - check_theory = And(children) - # if False, then f is entailed + if sample: with config(sat_backend="kissat"): - if not check_theory.solve(): - entailed.add(f) - children.discard(Or([~f])) + sol = cnf_formula.solve() + if sol: + for f in all_var: + if sol[str(f)]: + entailed.add(f) + else: + children = set(cnf_formula.children) + # iterate through all fluents, gathering those that are entailed + for f in all_var: + children.add(Or([~f])) + check_theory = And(children) + # if False, then f is entailed + with config(sat_backend="kissat"): + if not check_theory.solve(): + entailed.add(f) + children.discard(Or([~f])) return entailed