diff --git a/prover/proof_search.py b/prover/proof_search.py index b1e6579..47ca54a 100644 --- a/prover/proof_search.py +++ b/prover/proof_search.py @@ -19,7 +19,6 @@ ProofGivenUp, DojoInitError, DojoCrashError, - DojoTacticTimeoutError, ) from loguru import logger from dataclasses import dataclass @@ -75,7 +74,8 @@ def __init__( self.num_expansions = 0 self.actor_time = 0.0 self.environment_time = 0.0 - self.total_time = None + self.total_time = 0 + self.time_start = None def search( self, repo: LeanGitRepo, thm: Theorem, pos: Pos @@ -106,11 +106,14 @@ def search( ) self.nodes = {init_state: self.root} + self.priority_queue = asyncio.PriorityQueue() + self.priority_queue.put_nowait((-self.root.priority, self.root)) + self.pending_request_ids = set() + try: asyncio.run(self._best_first_search()) except DojoCrashError as ex: logger.warning(f"Dojo crashed with {ex} when proving {thm}") - pass if self.root.status == Status.PROVED: proof = [e.tactic for e in self.root.extract_proof()] @@ -134,87 +137,62 @@ def search( logger.warning(ex) return None - async def _best_first_search(self) -> None: - time_start = time.time() + def get_remaining_time(self) -> float: + return self.timeout - (time.time() - self.time_start) - priority_queue = asyncio.PriorityQueue() - priority_queue.put_nowait((-self.root.priority, self.root)) + async def _best_first_search(self) -> None: + self.time_start = time.time() + current_task = asyncio.current_task() while True: - if priority_queue.empty(): - logger.info("Ran out of nodes to search.") - break - - try: - await self._step(priority_queue) - except DojoTacticTimeoutError: - assert time.time() - time_start >= self.timeout + while self.priority_queue.empty(): + other_tasks = [t for t in asyncio.all_tasks() if t is not current_task] + remaining_time = self.get_remaining_time() + if len(other_tasks) == 0 or remaining_time <= 0: + break + await asyncio.wait( + other_tasks, + timeout=remaining_time, + return_when=asyncio.FIRST_COMPLETED, + ) + 
if self.root.status == Status.PROVED: + break - self.total_time = time.time() - time_start - if self.total_time > self.timeout or ( - self.max_expansions is not None - and self.num_expansions > self.max_expansions + if ( + self.priority_queue.empty() + or self.get_remaining_time() <= 0 + or self.root.status == Status.PROVED ): - if self.root.status == Status.PROVED: - logger.info("Found a proof!") - self.root.status = Status.OPEN - logger.info("Hit the resource limit (timeout or max_expansions).") break - if self.root.status == Status.FAILED: - logger.info("Failed early!") - break + _, search_node = self.priority_queue.get_nowait() + logger.debug(f"Expanding node: {search_node}") + asyncio.create_task(self._expand(search_node), name="_expand") + await asyncio.sleep(0) + self.num_expansions += 1 - if self.root.status == Status.PROVED: - logger.info("Found a proof!") - break + self.total_time = time.time() - self.time_start + if self.total_time > self.timeout or ( + self.max_expansions is not None + and self.num_expansions > self.max_expansions + ): + self.root.status = Status.OPEN + logger.info("Hit the resource limit (timeout or max_expansions).") - async def _step(self, priority_queue): - """ - Perform a single step of search. + logger.info(f"Cancel {len(self.pending_request_ids)} vLLM requests") + await asyncio.gather( + *[self.tac_gen.cancel(req_id) for req_id in self.pending_request_ids] + ) - Selects the node with the highest priority, queries the model for suggested - tactics, and tries each tactic in the environment, creating and enqueuing - a new node for each valid result. - """ - # Search the node with highest priority. - try: - _, search_node = priority_queue.get_nowait() - except asyncio.QueueEmpty: - return - logger.debug(f"Expanding node: {search_node}") + # TODO: Hack vLLM scheduler to treat different workers equally but prioritize within each worker. 
- if isinstance(search_node.state, TacticState): - ts = search_node.state.pp - else: - ts = search_node.state.unsolved_tactic_state - suggestions = await self._generate_tactics(ts) - - # Try all tactics in order of descending logprob, and collect the results. Any - # new nodes are added to `self.nodes`, and edges are added to the result node. - results = [] + async def _expand(self, node: InternalNode) -> None: + assert isinstance(node.state, TacticState) + suggestions = await self._generate_tactics(node.state.pp) for tactic, logprob in suggestions: - edge, finished = self._run_tactic( - search_node, tactic, logprob, priority_queue - ) - results.append(edge) - if finished: - break - - # Store the fixed out edges of this node, marking it as explored. - # This will trigger recursively recomputing tree statistics. - search_node.out_edges = results - self.num_expansions += 1 - priority_queue.task_done() - - # If we're running in debug mode, run a full test suite each step - if self.debug: - assert self.num_expansions == sum( - node.is_explored - for node in self.nodes.values() - if isinstance(node, InternalNode) + asyncio.create_task( + self._run_tactic(node, tactic, logprob), name="_run_tactic" ) - self.check_invariants() @torch.no_grad() async def _generate_tactics(self, ts: str) -> List[Tuple[str, float]]: @@ -225,24 +203,32 @@ async def _generate_tactics(self, ts: str) -> List[Tuple[str, float]]: if self.theorem.repo != self.repo: path = self.theorem.repo.get_packages_dir() / self.theorem.repo.name / path + req_id = str(uuid.uuid4().hex) + self.pending_request_ids.add(req_id) suggestions = await self.tac_gen.generate( + req_id, state=ts, file_path=path, theorem_full_name=self.theorem.full_name, theorem_pos=self.posision, num_samples=self.num_sampled_tactics, ) + self.pending_request_ids.remove(req_id) self.actor_time += time.time() - t0 logger.debug(f"Tactic suggestions: {suggestions}") return suggestions - def _run_tactic( - self, node: InternalNode, tactic: str, 
logprob: float, priority_queue - ) -> Tuple[Edge, bool]: + async def _run_tactic( + self, node: InternalNode, tactic: str, logprob: float + ) -> Edge: t0 = time.time() - response = self.dojo.run_tac(node.state, tactic) + response = self.dojo.run_tac( + node.state, tactic + ) # TODO: What if this blocks higher priority requests? + # TODO: We can have more precise synchronization around this critical section? + logger.debug(response) elapsed = time.time() - t0 self.environment_time += elapsed @@ -268,7 +254,7 @@ def _run_tactic( ) if result_node.status == Status.OPEN: # Don't search proved/failed nodes - priority_queue.put_nowait((-result_node.priority, result_node)) + self.priority_queue.put_nowait((-result_node.priority, result_node)) # Record the new node and add it to the search queue. self.nodes[response] = result_node @@ -277,31 +263,19 @@ def _run_tactic( # Will be added to the source node externally. edge = Edge(tactic=tactic, src=node, dst=result_node) + node.add_out_edge(edge, self.num_sampled_tactics) if isinstance(result_node, InternalNode): result_node.in_edges.append(edge) - return edge, isinstance(response, ProofFinished) + if node.status == Status.PROVED: + logger.debug("Found a proof!") - ######### - # DEBUG # - ######### + current_task = asyncio.current_task() + for t in asyncio.all_tasks(): + if t is not current_task and t.get_name() in ("_expand", "_run_tactic"): + t.cancel() # TODO: Use TaskGroup instead of cancelling tasks manually - def check_invariants(self): - """Perform some sanity checks.""" - - for response, node in self.nodes.items(): - if isinstance(response, ProofFinished): - assert isinstance(node, ProofFinishedNode) - assert self.root.status == Status.PROVED - elif type(response) in ( - LeanError, - TimeoutError, - ProofGivenUp, - ): - assert isinstance(node, ErrorNode) - else: - assert isinstance(node, InternalNode) - node.check_invariants() + return edge @ray.remote @@ -350,7 +324,9 @@ def initialize(self) -> None: ) self.engine = 
AsyncLLMEngine.from_engine_args(engine_args) - async def generate(self, prompt: str, num_samples: int) -> RequestOutput: + async def generate( + self, req_id: str, prompt: str, num_samples: int + ) -> RequestOutput: sampling_params = SamplingParams( n=num_samples, temperature=0, @@ -360,12 +336,14 @@ async def generate(self, prompt: str, num_samples: int) -> RequestOutput: logprobs=0, ) - async for oup in self.engine.generate( - prompt, sampling_params, request_id=str(uuid.uuid4().hex) - ): + async for oup in self.engine.generate(prompt, sampling_params, req_id): final_output = oup + return final_output + async def cancel(self, req_id: str) -> None: + await self.engine.abort(req_id) + class DistributedProver: """A distributed prover that uses Ray to parallelize the proof search. diff --git a/prover/search_tree.py b/prover/search_tree.py index 6959240..b8c7332 100644 --- a/prover/search_tree.py +++ b/prover/search_tree.py @@ -81,17 +81,11 @@ class InternalNode(Node): default_factory=list, init=False, compare=False, repr=False ) - # All edges out of this node that we've considered, or None for unexplored nodes. - # When a node is explored, this list is populated, and must not change after that. - _out_edges: Optional[List["Edge"]] = field( - default=None, init=False, compare=False, repr=False + out_edges: List["Edge"] = field( + default_factory=list, init=False, compare=False, repr=False ) - # A node is proved if any child is proved, and failed if every child is failed - # (or there are no children). A node that is proved or failed cannot change status - # because nothing is ever added to out_edges. _status is recomputed on an as-needed - # basis by children, since proving or failing a child may prove or fail this node. 
- _status: Status = field(default=Status.OPEN, init=False, compare=False, repr=True) + status: Status = field(default=Status.OPEN, init=False, compare=False, repr=True) is_terminal = False # type: ignore[override] @@ -101,18 +95,10 @@ class InternalNode(Node): default=math.inf, init=False, compare=False, repr=False ) - @property - def out_edges(self): - return self._out_edges - - # This setter implements exploring this node - @out_edges.setter - def out_edges(self, out_edges: Iterable["Edge"]) -> Optional[List["Edge"]]: - if self.is_explored: - raise RuntimeError("Node is already explored.") - - self._out_edges = list(out_edges) - self._recompute_status() + def add_out_edge(self, e: "Edge", max_num_edges: int) -> None: + assert e.src is self and e not in self.out_edges + self.out_edges.append(e) + self._recompute_status(max_num_edges) self._recompute_distance_to_proof() # A node is considered explored if we've evaluated the actor in the node to generate @@ -121,36 +107,30 @@ def out_edges(self, out_edges: Iterable["Edge"]) -> Optional[List["Edge"]]: def is_explored(self) -> bool: return self.out_edges is not None - @property - def status(self) -> Status: - return self._status - - @status.setter - def status(self, s): - self._status = s - - def _recompute_status(self): + def _recompute_status(self, max_num_edges: int): """ Recursively update the status of the current node and its ancestors. """ assert self.is_explored and self.out_edges is not None # If this node is proved or failed, nothing can change that - if self._status != Status.OPEN: + if self.status != Status.OPEN: return # If any child is proved, this node is proved, and so are parents recursively if any(edge.dst.status == Status.PROVED for edge in self.out_edges): - self._status = Status.PROVED + self.status = Status.PROVED # If all children failed, this node is failed. This may fail some parents too. 
-        if all(edge.dst.status == Status.FAILED for edge in self.out_edges):
-            self._status = Status.FAILED
+        if len(self.out_edges) == max_num_edges and all(
+            edge.dst.status == Status.FAILED for edge in self.out_edges
+        ):
+            self.status = Status.FAILED
 
         # If this node was proved or failed, parents may need to recompute.
         # This is guaranteed to terminate because only open nodes can change, and
         # there are a finite number of open nodes in the tree.
-        if self._status != Status.OPEN:
+        if self.status != Status.OPEN:
             for edge in self.in_edges:
-                edge.src._recompute_status()
+                edge.src._recompute_status(max_num_edges)
@@ -204,48 +184,6 @@ def extract_proof(self) -> Optional[List["Edge"]]:
         assert child_proof
         return [proving_edge, *child_proof]
 
-    #########
-    # Debug #
-    #########
-
-    def check_invariants(self):
-        """
-        Perform some sanity checks.
-        """
-        if not self.is_explored:
-            assert self.status == Status.OPEN
-            return  # Nothing more can be said about unexplored nodes
-
-        for edge in self.in_edges:
-            assert edge.dst is self
-
-        if self.out_edges == []:
-            assert self.status == Status.FAILED
-        else:
-            for edge in self.out_edges:  # type: ignore
-                assert edge.src is self
-
-        if self.status == Status.PROVED:
-            assert self.out_edges
-            assert any(edge.dst.status == Status.PROVED for edge in self.out_edges)
-            assert all(edge.dst.status == Status.PROVED for edge in self.in_edges)
-
-            proof_by_steps = self.extract_proof()
-            assert proof_by_steps is not None
-            assert self.distance_to_proof == len(proof_by_steps)
-
-        elif self.status == Status.FAILED:
-            assert self.out_edges is not None
-            assert all(edge.dst.status == Status.FAILED for edge in self.out_edges)
-            assert self.distance_to_proof == math.inf
-            assert self.extract_proof() == None
-        elif self.status == Status.OPEN:
-            assert self.out_edges
-            assert not any(edge.dst.status == Status.PROVED for edge in self.out_edges)
-            assert not all(edge.dst.status == Status.FAILED for edge in self.out_edges)
-            assert self.distance_to_proof == math.inf
-            assert self.extract_proof() ==
None - @dataclass class Edge: diff --git a/prover/tactic_generator.py b/prover/tactic_generator.py index 5e01960..1098ee4 100644 --- a/prover/tactic_generator.py +++ b/prover/tactic_generator.py @@ -308,6 +308,7 @@ def initialize(self) -> None: async def generate( self, + req_id: str, state: str, file_path: str, theorem_full_name: str, @@ -316,8 +317,11 @@ async def generate( ) -> List[Tuple[str, float]]: # prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n[GOAL]\n{state}\n[PROOFSTEP]\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" prompt = self.template % state - response = await self.vllm_actor.generate.remote(prompt, num_samples) + response = await self.vllm_actor.generate.remote(req_id, prompt, num_samples) return [ (remove_marks(x.text).strip(), x.cumulative_logprob) for x in response.outputs ] + + async def cancel(self, req_id: str) -> None: + await self.vllm_actor.cancel.remote(req_id)