update
yangky11 committed Jul 16, 2024
1 parent e2ec3fd commit e643d94
Showing 4 changed files with 41 additions and 13 deletions.
25 changes: 17 additions & 8 deletions README.md
@@ -310,22 +310,31 @@ python generation/main.py fit --config generation/confs/cli_lean4_novel_premises

After the tactic generator is trained, we combine it with best-first search to prove theorems by interacting with Lean.

For models without retrieval, run:
The evaluation script takes Hugging Face model checkpoints (either local or remote) as input. For remote models, you can simply use their names, e.g., [kaiyuy/leandojo-lean4-tacgen-byt5-small](https://huggingface.co/kaiyuy/leandojo-lean4-tacgen-byt5-small). For locally trained models, you first need to convert them from PyTorch Lightning checkpoints to Hugging Face checkpoints:
```bash
python prover/evaluate.py --data-path data/leandojo_benchmark_4/random/ --ckpt_path $PATH_TO_MODEL_CHECKPOINT --split test --num-workers 5 --num-gpus 1
python prover/evaluate.py --data-path data/leandojo_benchmark_4/novel_premises/ --ckpt_path $PATH_TO_MODEL_CHECKPOINT --split test --num-workers 5 --num-gpus 1
python scripts/convert_checkpoint.py generator --src $PATH_TO_GENERATOR_CHECKPOINT --dst ./leandojo-lean4-tacgen-byt5-small
python scripts/convert_checkpoint.py retriever --src $PATH_TO_RETRIEVER_CHECKPOINT --dst ./leandojo-lean4-retriever-byt5-small
```
Here, `PATH_TO_GENERATOR_CHECKPOINT` and `PATH_TO_RETRIEVER_CHECKPOINT` are PyTorch Lightning checkpoints produced by the training script.

For models with retrieval, first use the retriever to index the corpus (pre-computing the embeddings of all premises):

To evaluate the model without retrieval, run (using the `random` data split as an example):
```bash
python prover/evaluate.py --data-path data/leandojo_benchmark_4/random/ --gen_ckpt_path ./leandojo-lean4-tacgen-byt5-small --split test --num-workers 5 --num-gpus 1
```
You may tweak `--num-workers` and `--num-gpus` to fit your hardware.
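
To run the same evaluation on more than one data split, the command above can be assembled programmatically. The following is a minimal sketch and not part of the repository: `build_eval_command` is a hypothetical helper that only mirrors the flags shown in the command above, and it builds the argument list without executing anything.

```python
import shlex

# Data split names that appear in this README's data paths.
SPLITS = ("random", "novel_premises")


def build_eval_command(split_name, gen_ckpt_path, num_workers=5, num_gpus=1):
    """Return the argv list for evaluating a generator without retrieval.

    Hypothetical helper: it only reuses the flags shown in the README.
    """
    return [
        "python",
        "prover/evaluate.py",
        "--data-path", f"data/leandojo_benchmark_4/{split_name}/",
        "--gen_ckpt_path", gen_ckpt_path,
        "--split", "test",
        "--num-workers", str(num_workers),
        "--num-gpus", str(num_gpus),
    ]


if __name__ == "__main__":
    for name in SPLITS:
        cmd = build_eval_command(name, "./leandojo-lean4-tacgen-byt5-small")
        # Only print the command; pass `cmd` to subprocess.run(cmd, check=True) to run it.
        print(shlex.join(cmd))
```

Keeping the command as a list rather than a single string avoids shell-quoting problems if a checkpoint path contains spaces.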


For the model with retrieval, first use the retriever to index the corpus (pre-computing the embeddings of all premises):
```bash
python retrieval/index.py --ckpt_path $PATH_TO_RETRIEVER_CHECKPOINT --corpus-path data/leandojo_benchmark_4/corpus.jsonl --output-path $PATH_TO_INDEXED_CORPUS
# Do it separately for two data splits.
python retrieval/index.py --ckpt_path ./leandojo-lean4-retriever-byt5-small --corpus-path data/leandojo_benchmark_4/corpus.jsonl --output-path $PATH_TO_INDEXED_CORPUS
```
It saves the indexed corpus as a pickle file to `PATH_TO_INDEXED_CORPUS`.

Then, run:
```bash
python prover/evaluate.py --data-path data/leandojo_benchmark_4/random/ --ckpt_path $PATH_TO_REPROVER_CHECKPOINT --indexed-corpus-path $PATH_TO_INDEXED_CORPUS --split test --num-cpus 8 --with-gpus
# Do it separately for two data splits.
python scripts/convert_checkpoint.py generator --src $PATH_TO_REPROVER_CHECKPOINT --dst ./leandojo-lean4-retriever-tacgen-byt5-small
python prover/evaluate.py --data-path data/leandojo_benchmark_4/random/ --gen_ckpt_path ./leandojo-lean4-retriever-tacgen-byt5-small --ret_ckpt_path ./leandojo-lean4-retriever-byt5-small --indexed-corpus-path $PATH_TO_INDEXED_CORPUS --split test --num-workers 5 --num-gpus 1
```


2 changes: 1 addition & 1 deletion prover/proof_search.py
@@ -65,6 +65,7 @@ def __init__(
        debug: bool,
    ) -> None:
        self.tac_gen = tac_gen
        self.tac_gen.initialize()
        self.timeout = timeout
        self.num_sampled_tactics = num_sampled_tactics
        self.debug = debug
@@ -309,7 +310,6 @@ def __init__(
        num_sampled_tactics: int,
        debug: bool,
    ) -> None:
        tac_gen.initialize()
        self.prover = BestFirstSearchProver(
            tac_gen,
            timeout,
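
Taken together, the two `__init__` hunks above move the `initialize()` call out of the outer wrapper and into the constructor that stores `self.tac_gen` (apparently `BestFirstSearchProver`), so code that constructs the prover directly also gets an initialized generator. The sketch below illustrates that pattern with hypothetical stand-in classes (`DummyGenerator`, `Prover`). The diff does not show what the real `initialize()` does, so treating it as idempotent setup is an assumption.

```python
class DummyGenerator:
    """Hypothetical stand-in for a tactic generator with deferred setup."""

    def __init__(self):
        self.initialized = False
        self.init_calls = 0

    def initialize(self):
        # Assumption: setup is expensive, so repeated calls are skipped.
        if self.initialized:
            return
        self.init_calls += 1
        self.initialized = True


class Prover:
    """Hypothetical prover that owns the initialization of its generator."""

    def __init__(self, tac_gen, timeout):
        self.tac_gen = tac_gen
        self.tac_gen.initialize()  # done here, not by the caller
        self.timeout = timeout


if __name__ == "__main__":
    gen = DummyGenerator()
    prover = Prover(gen, timeout=600)
    print(gen.initialized, gen.init_calls)
```

With the call inside the prover's constructor, a caller cannot forget to initialize the generator before starting a search.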
2 changes: 1 addition & 1 deletion retrieval/index.py
@@ -35,7 +35,7 @@ def main() -> None:
    model.reindex_corpus(batch_size=args.batch_size)

    pickle.dump(
        IndexedCorpus(model.corpus, model.corpus_embeddings.cpu()),
        IndexedCorpus(model.corpus, model.corpus_embeddings.to(torch.float32).cpu()),
        open(args.output_path, "wb"),
    )
    logger.info(f"Indexed corpus saved to {args.output_path}")
25 changes: 22 additions & 3 deletions scripts/stats.py
@@ -1,17 +1,36 @@
import re
import sys
import numpy as np
from glob import glob
from loguru import logger
import matplotlib.pyplot as plt

total_time = []
TOTAL_TIME_REGEX = re.compile(r"total_time=(?P<time>.+?),")

for filename in glob(sys.argv[1]):
    print(filename)
    logger.info(filename)
    num_total = num_correct = 0
    for line in open(filename):
        if "SearchResult" in line:
            num_total += 1
            if "Proved" in line:
                num_correct += 1
                total_time.append(float(TOTAL_TIME_REGEX.search(line)["time"]))

    if num_total == 0:
        print("N/A")
        logger.info("Pass@1: N/A")
    else:
        print(f"{num_correct} / {num_total} = {num_correct / num_total}")
        logger.info(f"Pass@1: {num_correct} / {num_total} = {num_correct / num_total}")

logger.info(f"Average time: {np.mean(total_time)}")

total_time.sort()
x = []
y = []
for i, t in enumerate(total_time):
    x.append(t)
    y.append(i + 1)
plt.scatter(x, y)
plt.savefig("stats.pdf")
logger.info("Figure saved to stats.pdf")
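
The counting logic in `scripts/stats.py` can be tried on a few lines without any log files or plotting. The following is a minimal sketch, not the script itself: `summarize` is a hypothetical helper, the sample log lines are made up to contain only the substrings the script looks for, and collecting times only for proved theorems is one reading of the flattened diff above.

```python
import re
from statistics import mean

# Same pattern as in the diff: capture the value after "total_time=" up to the next comma.
TOTAL_TIME_REGEX = re.compile(r"total_time=(?P<time>.+?),")


def summarize(lines):
    """Return (num_correct, num_total, proof_times) for an iterable of log lines."""
    num_total = num_correct = 0
    times = []
    for line in lines:
        if "SearchResult" not in line:
            continue
        num_total += 1
        if "Proved" in line:
            num_correct += 1
            match = TOTAL_TIME_REGEX.search(line)
            if match is not None:  # skip lines without a parsable time
                times.append(float(match["time"]))
    return num_correct, num_total, times


if __name__ == "__main__":
    # Made-up log lines, shaped only to match what the parser searches for.
    sample = [
        "SearchResult(status=Proved, total_time=1.5, num_nodes=3)",
        "SearchResult(status=Failed, total_time=600.0, num_nodes=90)",
        "unrelated log line",
    ]
    correct, total, times = summarize(sample)
    print(f"Pass@1: {correct} / {total}")
    if times:
        print(f"Average time of proved theorems: {mean(times)}")
```

Unlike the script, this sketch checks the regex match before indexing it, so a `SearchResult` line without a `total_time=` field is skipped instead of raising a `TypeError`.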
