diff --git a/macq/extract/learned_fluent.py b/macq/extract/learned_fluent.py index 59e2c1c6..37305a2c 100644 --- a/macq/extract/learned_fluent.py +++ b/macq/extract/learned_fluent.py @@ -50,7 +50,7 @@ def __hash__(self): return hash(self.details()) def __str__(self): - return self.details() + return self.details() + f" {self.param_act_inds}" def __repr__(self): return self.details() diff --git a/macq/extract/locm.py b/macq/extract/locm.py index 9042fa5f..82d6fe0d 100644 --- a/macq/extract/locm.py +++ b/macq/extract/locm.py @@ -25,7 +25,7 @@ class AP: sort: int def __repr__(self) -> str: - return f"{self.action.name}.{self.pos}" + return f"{self.action.name}.{self.pos} [sort{self.sort}]" def __hash__(self): return hash(self.action.name + str(self.pos)) @@ -106,6 +106,7 @@ class Hypothesis: another argument of the same sort G′ in position k′ and l′ respectively, we hypothesise that there may be a relation between sorts G and G′." """ + S: int B: AP k: int @@ -131,10 +132,26 @@ def __hash__(self) -> int: ) def __repr__(self) -> str: - out = "<\n" - for k, v in asdict(self).items(): - out += f" {k}={v}\n" - return out.strip() + "\n>" + out = f"\n [->B->S->C->] = ({str(self.B)}) -> {str(self.S)} -> ({str(self.C)})" + out += "\n\tG\tG'\tS\tk\tk'\tl\tl'" + out += ( + "\n\t" + + str(self.G) + + "\t" + + str(self.G_) + + "\t" + + str(self.S) + + "\t" + + str(self.k) + + "\t" + + str(self.k_) + + "\t" + + str(self.l) + + "\t" + + str(self.l_) + + "\n" + ) + return out @staticmethod def from_dict( @@ -216,7 +233,12 @@ def __new__( sorts = LOCM._get_sorts(obs_trace, debug=debug["get_sorts"]) if debug["sorts"]: - print(f"Sorts:\n{sorts}", end="\n\n") + sortid2objs = {v: [] for v in set(sorts.values())} + for k, v in sorts.items(): + sortid2objs[v].append(k) + print("\nSorts:\n") + pprint(sortid2objs) + print("\n") TS, ap_state_pointers, OS = LOCM._step1(obs_trace, sorts, debug["step1"]) HS = LOCM._step3(TS, ap_state_pointers, OS, sorts, debug["step3"]) @@ -229,13 +251,9 @@ def __new__( bindings, statics if statics is not None else {}, debug["step7"], + viz, ) - if viz: - state_machines = LOCM.get_state_machines(ap_state_pointers, OS, bindings) - for sm in state_machines: - sm.render(view=view) - return Model(fluents, actions) @staticmethod @@ -422,9 +440,12 @@ def _step1( ap_state_pointers = defaultdict(dict) # iterate over each object and its action sequence for obj, seq in obj_traces.items(): - state_n = 1 # count current (new) state id sort = sorts[obj.name] if obj != zero_obj else 0 TS[sort][obj] = seq # add the sequence to the transition set + # max of the states already in OS[sort], plus 1 + state_n = ( + max([max(s) for s in OS[sort]] + [0]) + 1 + ) # count current (new) state id prev_states: StatePointers = None # type: ignore # iterate over each transition A.P in the sequence for ap in seq: @@ -453,6 +474,7 @@ def _step1( OS[sort][prev_end_state] ) OS[sort].pop(prev_end_state) + assert len(set.union(*OS[sort])) == sum([len(s) for s in OS[sort]]) prev_states = ap_states @@ -548,7 +570,7 @@ def _step3( # Remove any unsupported hypotheses (but yet undisputed) for hind, hs in HS.copy().items(): - for h in hs: + for h in hs.copy(): if not h.supported: hs.remove(h) if len(hs) == 0: @@ -558,11 +580,8 @@ def _step3( return Hypothesis.from_dict(HS) @staticmethod - def _step4( - HS: Dict[int, Dict[int, Set[Hypothesis]]], debug: bool = False - ) -> Bindings: + def _step4(HS: Hypotheses, debug: bool = False) -> Bindings: """Step 4: Creation and merging of state parameters""" - # bindings = {sort: {state: [(hypothesis, state param)]}} bindings: Bindings = defaultdict(dict) for sort, hs_sort in HS.items(): @@ -592,7 +611,7 @@ def _step4( # check if hypothesis parameters (v1 & v2) need to be unified if ( (h1.B == h2.B and h1.k == h2.k and h1.k_ == h2.k_) - or + and # See https://github.com/AI-Planning/macq/discussions/200 (h1.C == h2.C and h1.l == h2.l and h1.l_ == h2.l_) # fmt: skip ): v1 = state_bindings[h1] @@ -609,6 +628,10 @@ def _step4( state_params.pop(P2) state_param_pointers[v2] = P1 + # fix state_param_pointers after v2 + for ind in range(v2 + 1, len(state_param_pointers)): + state_param_pointers[ind] -= 1 + # add state bindings for the sort to the output bindings # replacing hypothesis params with actual state params bindings[sort][state] = [ @@ -620,7 +643,7 @@ def _step4( @staticmethod def _step5( - HS: Dict[int, Dict[int, Set[Hypothesis]]], + HS: Hypotheses, bindings: Bindings, debug: bool = False, ) -> Bindings: @@ -628,12 +651,12 @@ def _step5( # check each bindings[G][S] -> (h, P) for sort, hs_sort in HS.items(): - for state in hs_sort: + for state_id in hs_sort: # track all the h.Bs that occur in bindings[G][S] all_hB = set() # track the set of h.B that set parameter P sets_P = defaultdict(set) - for h, P in bindings[sort][state]: + for h, P in bindings[sort][state_id]: sets_P[P].add(h.B) all_hB.add(h.B) @@ -642,11 +665,27 @@ def _step5( for P, setby in sets_P.items(): if not setby == all_hB: # P is a flawed parameter # remove all bindings referencing P - for h, P_ in bindings[sort][state].copy(): + for h, P_ in bindings[sort][state_id].copy(): if P_ == P: - bindings[sort][state].remove(Binding(h, P_)) - if len(bindings[sort][state]) == 0: - del bindings[sort][state] + bindings[sort][state_id].remove(Binding(h, P_)) + if len(bindings[sort][state_id]) == 0: + del bindings[sort][state_id] + + # do the same for checking h.C reading parameter P + # See https://github.com/AI-Planning/macq/discussions/200 + all_hC = set() + reads_P = defaultdict(set) + if state_id in bindings[sort]: + for h, P in bindings[sort][state_id]: + reads_P[P].add(h.C) + all_hC.add(h.C) + for P, readby in reads_P.items(): + if not readby == all_hC: + for h, P_ in bindings[sort][state_id].copy(): + if P_ == P: + bindings[sort][state_id].remove(Binding(h, P_)) + if len(bindings[sort][state_id]) == 0: + del bindings[sort][state_id] for k, v in bindings.copy().items(): if not v: @@ -655,41 +694,45 @@ def _step5( return bindings @staticmethod - def get_state_machines( - ap_state_pointers: APStatePointers, - OS: OSType, - bindings: Optional[Bindings] = None, - ): - from graphviz import Digraph + def _debug_state_machines(OS, ap_state_pointers, state_params): + import os + + import networkx as nx - state_machines = [] - for (sort, trans), states in zip(ap_state_pointers.items(), OS.values()): - graph = Digraph(f"LOCM-step1-sort{sort}") - for state in range(len(states)): - label = f"state{state}" + for sort in OS: + G = nx.DiGraph() + for n in range(len(OS[sort])): + lbl = f"state{n}" if ( - bindings is not None - and sort in bindings - and state in bindings[sort] + state_params is not None + and sort in state_params + and n in state_params[sort] ): - label += f"\n[" - params = [] - for binding in bindings[sort][state]: - params.append(f"{binding.hypothesis.G_}") - label += f",".join(params) - label += f"]" - graph.node(str(state), label=label, shape="oval") - for ap, apstate in trans.items(): + lbl += str( + [ + state_params[sort][n][v] + for v in sorted(state_params[sort][n].keys()) + ] + ) + G.add_node(n, label=lbl, shape="oval") + for ap, apstate in ap_state_pointers[sort].items(): start_idx, end_idx = LOCM._pointer_to_set( - states, apstate.start, apstate.end - ) - graph.edge( - str(start_idx), str(end_idx), label=f"{ap.action.name}.{ap.pos}" + OS[sort], apstate.start, apstate.end ) - - state_machines.append(graph) - - return state_machines + # check if edge is already in graph + if G.has_edge(start_idx, end_idx): + # append to the edge label + G.edges[start_idx, end_idx][ + "label" + ] += f"\n{ap.action.name}.{ap.pos}" + else: + G.add_edge(start_idx, end_idx, label=f"{ap.action.name}.{ap.pos}") + # write to dot file + nx.drawing.nx_pydot.write_dot(G, f"LOCM-step7-sort{sort}.dot") + os.system( + f"dot -Tpng LOCM-step7-sort{sort}.dot -o LOCM-step7-sort{sort}.png" + ) + os.system(f"rm LOCM-step7-sort{sort}.dot") @staticmethod def _step7( @@ -699,6 +742,7 @@ def _step7( bindings: Bindings, statics: Statics, debug: bool = False, + viz: bool = False, ) -> Tuple[Set[LearnedLiftedFluent], Set[LearnedLiftedAction]]: """Step 7: Formation of PDDL action schema Implicitly includes Step 6 (statics) by including statics as an argument @@ -710,114 +754,122 @@ def _step7( del OS[0] del ap_state_pointers[0] - if debug: - print("ap state pointers") - pprint(ap_state_pointers) - print() - - print("OS:") - pprint(OS) - print() - - print("bindings:") - pprint(bindings) - print() - - bound_param_sorts = { - sort: { - state: [ - binding.hypothesis.G_ - for binding in bindings.get(sort, {}).get(state, []) - ] - for state in range(len(states)) - } - for sort, states in OS.items() - } - - actions = {} - fluents = defaultdict(dict) - + # all_aps = {action_name: [AP]} all_aps: Dict[str, List[AP]] = defaultdict(list) for aps in ap_state_pointers.values(): for ap in aps: all_aps[ap.action.name].append(ap) - for action, aps in all_aps.items(): - actions[action] = LearnedLiftedAction( - action, [f"sort{ap.sort}" for ap in aps] - ) - - @dataclass - class TemplateFluent: - name: str - param_sorts: List[str] + state_params = defaultdict(dict) + state_params_to_hyps = defaultdict(dict) + for sort in bindings: + state_params[sort] = defaultdict(dict) + state_params_to_hyps[sort] = defaultdict(dict) + for state in bindings[sort]: + keys = {b.param for b in bindings[sort][state]} + typ = None + for key in keys: + hyps = [ + b.hypothesis for b in bindings[sort][state] if b.param == key + ] + # assert that all are the same G_ + assert len(set([h.G_ for h in hyps])) == 1 + state_params[sort][state][key] = hyps[0].G_ + state_params_to_hyps[sort][state][key] = hyps - def __hash__(self) -> int: - return hash(self.name + "".join(self.param_sorts)) + if viz: + LOCM._debug_state_machines(OS, ap_state_pointers, state_params) - for sort, state_bindings in bound_param_sorts.items(): - for state, bound_sorts in state_bindings.items(): - fluents[sort][state] = TemplateFluent( - f"sort{sort}_state{state}", - [f"sort{sort}"] + [f"sort{s}" for s in bound_sorts], - ) + fluents = defaultdict(dict) + actions = {} + for sort in ap_state_pointers: + sort_str = f"sort{sort}" + for ap in ap_state_pointers[sort]: + if ap.action.name not in actions: + actions[ap.action.name] = LearnedLiftedAction( + ap.action.name, + [None for _ in range(len(all_aps[ap.action.name]))], # type: ignore + ) + a = actions[ap.action.name] + a.param_sorts[ap.pos - 1] = sort_str - for (sort, aps), states in zip(ap_state_pointers.items(), OS.values()): - for ap, pointers in aps.items(): + start_pointer, end_pointer = ap_state_pointers[sort][ap] start_state, end_state = LOCM._pointer_to_set( - states, pointers.start, pointers.end + OS[sort], start_pointer, end_pointer ) - # preconditions += fluent for origin state - start_fluent_temp = fluents[sort][start_state] - - bound_param_inds = [] - - # for each bindings on the start state (if there are any) - # then add each binding.hypothesis.l_ - if sort in bindings and start_state in bindings[sort]: - bound_param_inds = [ - b.hypothesis.l_ - 1 for b in bindings[sort][start_state] - ] - - start_fluent = LearnedLiftedFluent( - start_fluent_temp.name, - start_fluent_temp.param_sorts, - [ap.pos - 1] + bound_param_inds, - ) - fluents[sort][start_state] = start_fluent - actions[ap.action.name].update_precond(start_fluent) - - if start_state != end_state: - # del += fluent for origin state - actions[ap.action.name].update_delete(start_fluent) - - # add += fluent for destination state - end_fluent_temp = fluents[sort][end_state] - bound_param_inds = [] - if sort in bindings and end_state in bindings[sort]: - bound_param_inds = [ - b.hypothesis.l_ - 1 for b in bindings[sort][end_state] - ] - end_fluent = LearnedLiftedFluent( - end_fluent_temp.name, - end_fluent_temp.param_sorts, - [ap.pos - 1] + bound_param_inds, + start_fluent_name = f"sort{sort}_state{start_state}" + if start_fluent_name not in fluents[ap.action.name]: + start_fluent = LearnedLiftedFluent( + start_fluent_name, + param_sorts=[sort_str], + param_act_inds=[ap.pos - 1], ) - fluents[sort][end_state] = end_fluent - actions[ap.action.name].update_add(end_fluent) + fluents[ap.action.name][start_fluent_name] = start_fluent - fluents = set(fluent for sort in fluents.values() for fluent in sort.values()) - actions = set(actions.values()) + start_fluent = fluents[ap.action.name][start_fluent_name] + + if ( + sort in state_params_to_hyps + and start_state in state_params_to_hyps[sort] + ): + for param in state_params_to_hyps[sort][start_state]: + psort = None + pind = None + for hyp in state_params_to_hyps[sort][start_state][param]: + if hyp.C == ap: + assert psort is None or psort == hyp.G_ + assert pind is None or pind == hyp.l_ + psort = hyp.G_ + pind = hyp.l_ + assert psort is not None + assert pind is not None + start_fluent.param_sorts.append(f"sort{psort}") + start_fluent.param_act_inds.append(pind - 1) + + a.update_precond(start_fluent) + + if end_state != start_state: + end_fluent_name = f"sort{sort}_state{end_state}" + if end_fluent_name not in fluents[ap.action.name]: + end_fluent = LearnedLiftedFluent( + end_fluent_name, + param_sorts=[sort_str], + param_act_inds=[ap.pos - 1], + ) + fluents[ap.action.name][end_fluent_name] = end_fluent + + end_fluent = fluents[ap.action.name][end_fluent_name] + + if ( + sort in state_params_to_hyps + and end_state in state_params_to_hyps[sort] + ): + for param in state_params_to_hyps[sort][end_state]: + psort = None + pind = None + for hyp in state_params_to_hyps[sort][end_state][param]: + if hyp.B == ap: + assert psort is None or psort == hyp.G_ + assert pind is None or pind == hyp.k_ + psort = hyp.G_ + pind = hyp.k_ + assert psort is not None + assert pind is not None + end_fluent.param_sorts.append(f"sort{psort}") + end_fluent.param_act_inds.append(pind - 1) + + a.update_delete(start_fluent) + a.update_add(end_fluent) # Step 6: Extraction of static preconditions - for action in actions: + for action in actions.values(): if action.name in statics: for static in statics[action.name]: action.update_precond(static) - if debug: - pprint(fluents) - pprint(actions) - - return fluents, actions + return set( + fluent + for action_fluents in fluents.values() + for fluent in action_fluents.values() + ), set(actions.values()) diff --git a/setup.py b/setup.py index 549107f5..023d33ec 100644 --- a/setup.py +++ b/setup.py @@ -16,6 +16,8 @@ "numpy", "clingo", "graphviz", + "networkx", + "pydot", ] DEV_DEPENDENCIES = [ diff --git a/tests/extract/test_locm.py b/tests/extract/test_locm.py index c92b0683..7bdab418 100644 --- a/tests/extract/test_locm.py +++ b/tests/extract/test_locm.py @@ -365,9 +365,7 @@ def test_locm_step4(HS=None, is_test=True): for G, bG in bindings.items(): for S, bGS in bG.items(): print(f"\nG={G}, S={S}") - for h, v in bGS: - print(f"{h} -> {v}\n") - assert v == 0 + assert 4 == len({v for _, v in bGS}) else: return bindings