Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
yangky11 committed Aug 25, 2024
1 parent 903fac4 commit b37ca6e
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 178 deletions.
178 changes: 78 additions & 100 deletions prover/proof_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
ProofGivenUp,
DojoInitError,
DojoCrashError,
DojoTacticTimeoutError,
)
from loguru import logger
from dataclasses import dataclass
Expand Down Expand Up @@ -75,7 +74,8 @@ def __init__(
self.num_expansions = 0
self.actor_time = 0.0
self.environment_time = 0.0
self.total_time = None
self.total_time = 0
self.time_start = None

def search(
self, repo: LeanGitRepo, thm: Theorem, pos: Pos
Expand Down Expand Up @@ -106,11 +106,14 @@ def search(
)
self.nodes = {init_state: self.root}

self.priority_queue = asyncio.PriorityQueue()
self.priority_queue.put_nowait((-self.root.priority, self.root))
self.pending_request_ids = set()

try:
asyncio.run(self._best_first_search())
except DojoCrashError as ex:
logger.warning(f"Dojo crashed with {ex} when proving {thm}")
pass

if self.root.status == Status.PROVED:
proof = [e.tactic for e in self.root.extract_proof()]
Expand All @@ -134,87 +137,62 @@ def search(
logger.warning(ex)
return None

async def _best_first_search(self) -> None:
time_start = time.time()
def get_remaining_time(self) -> float:
return self.timeout - (time.time() - self.time_start)

priority_queue = asyncio.PriorityQueue()
priority_queue.put_nowait((-self.root.priority, self.root))
async def _best_first_search(self) -> None:
self.time_start = time.time()
current_task = asyncio.current_task()

while True:
if priority_queue.empty():
logger.info("Ran out of nodes to search.")
break

try:
await self._step(priority_queue)
except DojoTacticTimeoutError:
assert time.time() - time_start >= self.timeout
while self.priority_queue.empty():
other_tasks = [t for t in asyncio.all_tasks() if t is not current_task]
remaining_time = self.get_remaining_time()
if len(other_tasks) == 0 or remaining_time <= 0:
break
await asyncio.wait(
other_tasks,
timeout=remaining_time,
return_when=asyncio.FIRST_COMPLETED,
)
if self.root.status == Status.PROVED:
break

self.total_time = time.time() - time_start
if self.total_time > self.timeout or (
self.max_expansions is not None
and self.num_expansions > self.max_expansions
if (
self.priority_queue.empty()
or self.get_remaining_time() <= 0
or self.root.status == Status.PROVED
):
if self.root.status == Status.PROVED:
logger.info("Found a proof!")
self.root.status = Status.OPEN
logger.info("Hit the resource limit (timeout or max_expansions).")
break

if self.root.status == Status.FAILED:
logger.info("Failed early!")
break
_, search_node = self.priority_queue.get_nowait()
logger.debug(f"Expanding node: {search_node}")
asyncio.create_task(self._expand(search_node), name="_expand")
await asyncio.sleep(0)
self.num_expansions += 1

if self.root.status == Status.PROVED:
logger.info("Found a proof!")
break
self.total_time = time.time() - self.time_start
if self.total_time > self.timeout or (
self.max_expansions is not None
and self.num_expansions > self.max_expansions
):
self.root.status = Status.OPEN
logger.info("Hit the resource limit (timeout or max_expansions).")

async def _step(self, priority_queue):
"""
Perform a single step of search.
logger.info(f"Cancel {len(self.pending_request_ids)} vLLM requests")
await asyncio.gather(
*[self.tac_gen.cancel(req_id) for req_id in self.pending_request_ids]
)

Selects the node with the highest priority, queries the model for suggested
tactics, and tries each tactic in the environment, creating and enqueuing
a new node for each valid result.
"""
# Search the node with highest priority.
try:
_, search_node = priority_queue.get_nowait()
except asyncio.QueueEmpty:
return
logger.debug(f"Expanding node: {search_node}")
# TODO: Hack vLLM scheduler to treat different workers equally but prioritize within each worker.

if isinstance(search_node.state, TacticState):
ts = search_node.state.pp
else:
ts = search_node.state.unsolved_tactic_state
suggestions = await self._generate_tactics(ts)

# Try all tactics in order of descending logprob, and collect the results. Any
# new nodes are added to `self.nodes`, and edges are added to the result node.
results = []
async def _expand(self, node: InternalNode) -> None:
assert isinstance(node.state, TacticState)
suggestions = await self._generate_tactics(node.state.pp)
for tactic, logprob in suggestions:
edge, finished = self._run_tactic(
search_node, tactic, logprob, priority_queue
)
results.append(edge)
if finished:
break

# Store the fixed out edges of this node, marking it as explored.
# This will trigger recursively recomputing tree statistics.
search_node.out_edges = results
self.num_expansions += 1
priority_queue.task_done()

# If we're running in debug mode, run a full test suite each step
if self.debug:
assert self.num_expansions == sum(
node.is_explored
for node in self.nodes.values()
if isinstance(node, InternalNode)
asyncio.create_task(
self._run_tactic(node, tactic, logprob), name="_run_tactic"
)
self.check_invariants()

@torch.no_grad()
async def _generate_tactics(self, ts: str) -> List[Tuple[str, float]]:
Expand All @@ -225,24 +203,32 @@ async def _generate_tactics(self, ts: str) -> List[Tuple[str, float]]:
if self.theorem.repo != self.repo:
path = self.theorem.repo.get_packages_dir() / self.theorem.repo.name / path

req_id = str(uuid.uuid4().hex)
self.pending_request_ids.add(req_id)
suggestions = await self.tac_gen.generate(
req_id,
state=ts,
file_path=path,
theorem_full_name=self.theorem.full_name,
theorem_pos=self.posision,
num_samples=self.num_sampled_tactics,
)
self.pending_request_ids.remove(req_id)

self.actor_time += time.time() - t0

logger.debug(f"Tactic suggestions: {suggestions}")
return suggestions

def _run_tactic(
self, node: InternalNode, tactic: str, logprob: float, priority_queue
) -> Tuple[Edge, bool]:
async def _run_tactic(
self, node: InternalNode, tactic: str, logprob: float
) -> Edge:
t0 = time.time()
response = self.dojo.run_tac(node.state, tactic)
response = self.dojo.run_tac(
node.state, tactic
) # TODO: What if this blocks higher priority requests?
# TODO: We can have more precise synchronization around this critical section?
logger.debug(response)

elapsed = time.time() - t0
self.environment_time += elapsed
Expand All @@ -268,7 +254,7 @@ def _run_tactic(
)

if result_node.status == Status.OPEN: # Don't search proved/failed nodes
priority_queue.put_nowait((-result_node.priority, result_node))
self.priority_queue.put_nowait((-result_node.priority, result_node))

# Record the new node and add it to the search queue.
self.nodes[response] = result_node
Expand All @@ -277,31 +263,19 @@ def _run_tactic(
# Will be added to the source node externally.
edge = Edge(tactic=tactic, src=node, dst=result_node)

node.add_out_edge(edge, self.num_sampled_tactics)
if isinstance(result_node, InternalNode):
result_node.in_edges.append(edge)

return edge, isinstance(response, ProofFinished)
if node.status == Status.PROVED:
logger.debug("Found a proof!")

#########
# DEBUG #
#########
current_task = asyncio.current_task()
for t in asyncio.all_tasks():
if t is not current_task and t.get_name() in ("_expand", "_run_tactic"):
t.cancel() # TODO: Use TaskGroup instead of cancelling tasks manually

def check_invariants(self):
"""Perform some sanity checks."""

for response, node in self.nodes.items():
if isinstance(response, ProofFinished):
assert isinstance(node, ProofFinishedNode)
assert self.root.status == Status.PROVED
elif type(response) in (
LeanError,
TimeoutError,
ProofGivenUp,
):
assert isinstance(node, ErrorNode)
else:
assert isinstance(node, InternalNode)
node.check_invariants()
return edge


@ray.remote
Expand Down Expand Up @@ -350,7 +324,9 @@ def initialize(self) -> None:
)
self.engine = AsyncLLMEngine.from_engine_args(engine_args)

async def generate(self, prompt: str, num_samples: int) -> RequestOutput:
async def generate(
self, req_id: str, prompt: str, num_samples: int
) -> RequestOutput:
sampling_params = SamplingParams(
n=num_samples,
temperature=0,
Expand All @@ -360,12 +336,14 @@ async def generate(self, prompt: str, num_samples: int) -> RequestOutput:
logprobs=0,
)

async for oup in self.engine.generate(
prompt, sampling_params, request_id=str(uuid.uuid4().hex)
):
async for oup in self.engine.generate(prompt, sampling_params, req_id):
final_output = oup

return final_output

async def cancel(self, req_id: str) -> None:
await self.engine.abort(req_id)


class DistributedProver:
"""A distributed prover that uses Ray to parallelize the proof search.
Expand Down
Loading

0 comments on commit b37ca6e

Please sign in to comment.