Commit

minor fix

yangky11 committed Jul 15, 2024
1 parent 4263cfb commit e2ec3fd
Showing 3 changed files with 28 additions and 27 deletions.
46 changes: 22 additions & 24 deletions README.md
@@ -244,13 +244,8 @@
pip install git+https://github.com/pytorch/torchtune
4. Make sure `wget` and `tar` are available. Then, run `python scripts/download_data.py` to download [LeanDojo Benchmark 4](https://zenodo.org/doi/10.5281/zenodo.8040109). The data will be saved to `./data`.
5. Satisfy the requirements of [LeanDojo](https://github.com/lean-dojo/LeanDojo#requirements).
6. Use [LeanDojo](https://github.com/lean-dojo/LeanDojo) to trace all repos in the datasets: `python scripts/trace_repos.py`. This step may take some time. Please refer to [LeanDojo's documentation](https://leandojo.readthedocs.io/en/latest/) if you encounter any issues.
7. Run `wandb login` to log in to Weights & Biases.



## Premise Selection
@@ -264,28 +259,29 @@ The config files for our experiments are in [./retrieval/confs](./retrieval/conf

Run `python retrieval/main.py fit --help` to see how to use the training script. For example:
```bash
mkdir logs # Create the directory for training logs.
python retrieval/main.py fit --config retrieval/confs/cli_lean4_random.yaml --trainer.logger.name train_retriever_random --trainer.logger.save_dir logs/train_retriever_random # Train the retriever on the `random` split of LeanDojo Benchmark 4.
python retrieval/main.py fit --config retrieval/confs/cli_lean4_novel_premises.yaml --trainer.logger.name train_retriever_novel_premises --trainer.logger.save_dir logs/train_retriever_novel_premises # Train the retriever on the `novel_premises` split of LeanDojo Benchmark 4.
```
Hyperparameters and model checkpoints are saved in `./logs/train_retriever_*`, and you can monitor the training process on Weights & Biases.
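At its core, the retriever ranks premises by the similarity of their embeddings to the embedding of the current proof state. A minimal sketch of that ranking step, using toy 2-dimensional vectors in place of the trained Transformer encoder (the function names here are illustrative, not the repo's API):

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv)

def rank_premises(state_emb, premise_embs):
    """Return premise indices sorted by decreasing similarity to the state."""
    scores = [(cosine(state_emb, emb), i) for i, emb in enumerate(premise_embs)]
    return [i for _, i in sorted(scores, reverse=True)]

# Toy demo: premise 1 points almost the same way as the state.
state = [1.0, 0.0]
premises = [[0.0, 1.0], [0.9, 0.1], [0.5, 0.5]]
print(rank_premises(state, premises))  # → [1, 2, 0]
```

The trained model replaces the toy vectors with learned text embeddings, but the ranking logic is the same.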


### Retrieving Premises for All Proof States

After the models are trained, run the following commands to retrieve premises for all proof states in the dataset.
```bash
python retrieval/main.py predict --config retrieval/confs/cli_lean4_random.yaml --ckpt_path $PATH_TO_RETRIEVER_CHECKPOINT --trainer.logger.name predict_retriever_random --trainer.logger.save_dir logs/predict_retriever_random
python retrieval/main.py predict --config retrieval/confs/cli_lean4_novel_premises.yaml --ckpt_path $PATH_TO_RETRIEVER_CHECKPOINT --trainer.logger.name predict_retriever_novel_premises --trainer.logger.save_dir logs/predict_retriever_novel_premises
```
Here, `PATH_TO_RETRIEVER_CHECKPOINT` is the model checkpoint produced in the previous step. Retrieved premises are saved to `./logs/predict_retriever_*/predictions.pickle`.


### Evaluating the Retrieved Premises

After predictions are saved, evaluate them using metrics such as R@1, R@10, and MRR.
```bash
python retrieval/evaluate.py --data-path data/leandojo_benchmark_4/random --preds-file logs/predict_retriever_random/predictions.pickle
python retrieval/evaluate.py --data-path data/leandojo_benchmark_4/novel_premises --preds-file logs/predict_retriever_novel_premises/predictions.pickle
```
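These metrics have standard definitions: R@k is the fraction of ground-truth premises that appear in the top k retrieved ones, and MRR averages the reciprocal rank of the first relevant premise across proof states. A small sketch of the definitions (the exact accounting in `retrieval/evaluate.py` may differ):

```python
def recall_at_k(ranked, relevant, k):
    """Fraction of ground-truth premises found among the top-k retrieved ones."""
    return len(set(ranked[:k]) & set(relevant)) / len(relevant)

def mrr(ranked_lists, relevant_lists):
    """Mean reciprocal rank of the first relevant premise per proof state."""
    total = 0.0
    for ranked, relevant in zip(ranked_lists, relevant_lists):
        for pos, premise in enumerate(ranked, start=1):
            if premise in relevant:
                total += 1.0 / pos
                break
    return total / len(ranked_lists)

ranked = ["thm_b", "thm_a", "thm_c"]
print(recall_at_k(ranked, ["thm_a"], 1))   # → 0.0
print(recall_at_k(ranked, ["thm_a"], 10))  # → 1.0
print(mrr([ranked], [["thm_a"]]))          # → 0.5
```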


@@ -294,39 +290,41 @@ python retrieval/evaluate.py --data-path data/leandojo_benchmark_4/novel_premise

### Training the Tactic Generator

As with premise selection, you can run `python generation/main.py --help` and `python generation/main.py fit --help` to see the command line options.

To train tactic generators without retrieval:
```bash
python generation/main.py fit --config generation/confs/cli_lean4_random.yaml --trainer.logger.name train_generator_random --trainer.logger.save_dir logs/train_generator_random # LeanDojo Benchmark 4, `random` split
python generation/main.py fit --config generation/confs/cli_lean4_novel_premises.yaml --trainer.logger.name train_generator_novel_premises --trainer.logger.save_dir logs/train_generator_novel_premises # LeanDojo Benchmark 4, `novel_premises` split
```
Hyperparameters and model checkpoints are saved in `./logs/train_generator_*`, and you can monitor the training process on Weights & Biases.

To train models augmented by retrieval, we need to provide a retriever checkpoint and its predictions on all proof states in the dataset:
```bash
python generation/main.py fit --config generation/confs/cli_lean4_random.yaml --model.ret_ckpt_path $PATH_TO_RETRIEVER_CHECKPOINT --data.preds_path logs/predict_retriever_random/predictions.pickle --trainer.logger.name train_reprover_random --trainer.logger.save_dir logs/train_reprover_random
python generation/main.py fit --config generation/confs/cli_lean4_novel_premises.yaml --model.ret_ckpt_path $PATH_TO_RETRIEVER_CHECKPOINT --data.preds_path logs/predict_retriever_novel_premises/predictions.pickle --trainer.logger.name train_reprover_novel_premises --trainer.logger.save_dir logs/train_reprover_novel_premises
```
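Conceptually, retrieval augmentation concatenates the top retrieved premises with the proof state before the result is fed to the tactic generator. A toy sketch of that input construction (the `[STATE]` separator and the truncation policy here are illustrative assumptions, not ReProver's actual input format):

```python
def format_input(state, retrieved_premises, max_premises=3):
    """Prepend the top retrieved premises to the proof state."""
    ctx = "\n".join(retrieved_premises[:max_premises])
    return f"{ctx}\n[STATE]\n{state}"

prompt = format_input(
    "n : Nat |- n + 0 = n",
    ["theorem add_zero (n : Nat) : n + 0 = n"],
)
print(prompt)
```

The generator then conditions on both the premises and the state when proposing tactics.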


### Theorem Proving Evaluation on LeanDojo Benchmark

After the tactic generator is trained, we combine it with best-first search to prove theorems by interacting with Lean.

For models without retrieval, run:
```bash
python prover/evaluate.py --data-path data/leandojo_benchmark_4/random/ --ckpt_path $PATH_TO_MODEL_CHECKPOINT --split test --num-workers 5 --num-gpus 1
python prover/evaluate.py --data-path data/leandojo_benchmark_4/novel_premises/ --ckpt_path $PATH_TO_MODEL_CHECKPOINT --split test --num-workers 5 --num-gpus 1
```
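Best-first search keeps a priority queue of proof states ordered by cumulative tactic log-probability and always expands the highest-scoring state next. A minimal sketch, with a generic `expand` callback standing in for tactic generation plus Lean interaction (names and the toy demo are illustrative, not the prover's API):

```python
import heapq

def best_first_search(initial_state, expand, is_proved, max_steps=100):
    """Expand the most promising state first (highest cumulative log-prob)."""
    # Queue entries: (negated score, tiebreak id, state); heapq pops the smallest,
    # i.e. the highest-scoring state.
    queue = [(0.0, 0, initial_state)]
    counter = 1
    while queue and max_steps > 0:
        neg_score, _, state = heapq.heappop(queue)
        if is_proved(state):
            return state
        for next_state, logprob in expand(state):
            heapq.heappush(queue, (neg_score - logprob, counter, next_state))
            counter += 1
        max_steps -= 1
    return None  # Search budget exhausted or no states left.

# Toy demo: states are integers, each "tactic" increments by one,
# and the "theorem" is proved on reaching 3.
def count_up(n):
    return [(n + 1, -0.1)] if n < 3 else []

print(best_first_search(0, count_up, lambda s: s == 3))  # → 3
```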

For models with retrieval, first use the retriever to index the corpus (pre-computing the embeddings of all premises):
```bash
python retrieval/index.py --ckpt_path $PATH_TO_RETRIEVER_CHECKPOINT --corpus-path data/leandojo_benchmark_4/corpus.jsonl --output-path $PATH_TO_INDEXED_CORPUS
# Run this separately for each of the two data splits.
```

Then, run:
```bash
python prover/evaluate.py --data-path data/leandojo_benchmark_4/random/ --ckpt_path $PATH_TO_REPROVER_CHECKPOINT --indexed-corpus-path $PATH_TO_INDEXED_CORPUS --split test --num-cpus 8 --with-gpus
# Run this separately for each of the two data splits.
```
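Indexing runs every premise in the corpus through the retriever's encoder once and saves the embeddings, so the prover can score premises during search without re-encoding the corpus at every step. A toy sketch of the idea, with a stand-in `embed` function and a pickle file (the real script uses the trained model and its own on-disk format):

```python
import os
import pickle
import tempfile

def embed(text):
    # Stand-in for the trained encoder: a fixed-size toy embedding.
    return [float(sum(ord(c) for c in text) % 97), float(len(text))]

def index_corpus(premises, output_path):
    """Precompute one embedding per premise and save the index to disk."""
    index = {name: embed(statement) for name, statement in premises}
    with open(output_path, "wb") as f:
        pickle.dump(index, f)
    return index

path = os.path.join(tempfile.gettempdir(), "indexed_corpus_demo.pkl")
idx = index_corpus([("add_zero", "n + 0 = n")], path)
print(sorted(idx))  # → ['add_zero']
```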

5 changes: 4 additions & 1 deletion generation/model.py
@@ -257,4 +257,7 @@ def on_validation_epoch_end(self) -> None:

for path in [gen_ckpt_path, ret_ckpt_path, indexed_corpus_path]:
    if os.path.exists(path):
        if os.path.isdir(path):
            shutil.rmtree(path)
        else:
            os.remove(path)
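The fix above is needed because `shutil.rmtree` raises an error on regular files; it only deletes directories. An equivalent standalone helper, shown here purely for illustration:

```python
import os
import shutil
import tempfile

def remove_path(path):
    """Delete path whether it is a file or a directory; ignore missing paths."""
    if not os.path.exists(path):
        return
    if os.path.isdir(path):
        shutil.rmtree(path)  # Directories (possibly non-empty).
    else:
        os.remove(path)      # Regular files.

# Demo: works on both a file and a directory.
f = tempfile.NamedTemporaryFile(delete=False)
f.close()
d = tempfile.mkdtemp()
remove_path(f.name)
remove_path(d)
```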
4 changes: 2 additions & 2 deletions scripts/download_data.py
@@ -7,10 +7,10 @@


LEANDOJO_BENCHMARK_4_URL = (
    "https://zenodo.org/records/12740403/files/leandojo_benchmark_4.tar.gz?download=1"
)
DOWNLOADS = {
    LEANDOJO_BENCHMARK_4_URL: "25e1ee60cd8925b9d2e8673ddcc34b4c",
}
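The MD5 digest lets the download script detect corrupted or outdated archives. A chunked MD5 computation like the one such scripts typically use (a sketch, not necessarily `download_data.py`'s exact code):

```python
import hashlib

def md5sum(path, chunk_size=1 << 20):
    """Compute the MD5 hex digest of a file without loading it all into memory."""
    h = hashlib.md5()
    with open(path, "rb") as f:
        # Read in 1 MiB chunks so large archives don't exhaust memory.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()
```

After downloading, the computed digest is compared against the expected value and a mismatch triggers an error or a re-download.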

