diff --git a/README.md b/README.md
index 7b4b6d9..3d04461 100644
--- a/README.md
+++ b/README.md
@@ -272,7 +272,7 @@ After the models are trained, run the following commands to retrieve premises fo
 python retrieval/main.py predict --config retrieval/confs/cli_lean4_random.yaml --ckpt_path $PATH_TO_RETRIEVER_CHECKPOINT --trainer.logger.name predict_retriever_random --trainer.logger.save_dir logs/predict_retriever_random
 python retrieval/main.py predict --config retrieval/confs/cli_lean4_novel_premises.yaml --ckpt_path $PATH_TO_RETRIEVER_CHECKPOINT --trainer.logger.name predict_retriever_novel_premises --trainer.logger.save_dir logs/predict_retriever_novel_premises
 ```
-, where `PATH_TO_RETRIEVER_CHECKPOINT` is the model checkpoint produced in the previous step. Retrieved premises are saved to `./logs/predict_retriever_*/predictions.pickle`.
+Retrieved premises are saved to `./logs/predict_retriever_*/predictions.pickle`. Note that `PATH_TO_RETRIEVER_CHECKPOINT` is the DeepSpeed model checkpoint produced in the previous step. If you want to use a Hugging Face checkpoint instead, a workaround is to run training for one step with a zero learning rate (e.g., with `retrieval/confs/cli_dummy.yaml`).
 
 ### Evaluating the Retrieved Premises
 
diff --git a/generation/model.py b/generation/model.py
index bc3b01e..381bbaa 100644
--- a/generation/model.py
+++ b/generation/model.py
@@ -157,7 +157,6 @@ def _log_io_texts(
     def on_fit_start(self) -> None:
         if self.logger is not None:
             self.logger.log_hyperparams(self.hparams)
-            self.logger.watch(self.generator)
 
         assert self.trainer is not None
         logger.info(f"Logging to {self.trainer.log_dir}")
diff --git a/prover/proof_search.py b/prover/proof_search.py
index 1516c7e..b1e6579 100644
--- a/prover/proof_search.py
+++ b/prover/proof_search.py
@@ -344,8 +344,9 @@ def initialize(self) -> None:
         engine_args = AsyncEngineArgs(
             model=self.model_path,
             tensor_parallel_size=self.num_gpus,
-            max_num_batched_tokens=2048,
-            enable_chunked_prefill=True,
+            max_num_batched_tokens=8192,
+            # max_num_batched_tokens=2048,
+            # enable_chunked_prefill=True,
         )
         self.engine = AsyncLLMEngine.from_engine_args(engine_args)
 
diff --git a/retrieval/confs/cli_dummy.yaml b/retrieval/confs/cli_dummy.yaml
new file mode 100644
index 0000000..d5f0577
--- /dev/null
+++ b/retrieval/confs/cli_dummy.yaml
@@ -0,0 +1,30 @@
+seed_everything: 3407 # https://arxiv.org/abs/2109.08203
+trainer:
+  accelerator: gpu
+  devices: 1
+  precision: bf16-mixed
+  strategy:
+    class_path: pytorch_lightning.strategies.DeepSpeedStrategy
+    init_args:
+      stage: 2
+      offload_optimizer: false
+      cpu_checkpointing: false
+  gradient_clip_val: 1.0
+  max_steps: 1
+  logger: null
+
+model:
+  model_name: google/byt5-small
+  lr: 0
+  warmup_steps: 2000
+  num_retrieved: 100
+
+data:
+  data_path: data/leandojo_benchmark_4/random/
+  corpus_path: data/leandojo_benchmark_4/corpus.jsonl
+  num_negatives: 3
+  num_in_file_negatives: 1
+  batch_size: 8
+  eval_batch_size: 64
+  max_seq_len: 1024
+  num_workers: 4
diff --git a/retrieval/model.py b/retrieval/model.py
index 7583563..04d917f 100644
--- a/retrieval/model.py
+++ b/retrieval/model.py
@@ -9,7 +9,7 @@
 from loguru import logger
 import pytorch_lightning as pl
 import torch.nn.functional as F
-from typing import List, Dict, Any, Tuple, Union, Optional
+from typing import List, Dict, Any, Tuple, Union
 from transformers import AutoModelForTextEncoding, AutoTokenizer
 
 from common import (
@@ -146,7 +146,6 @@ def forward(
     def on_fit_start(self) -> None:
         if self.logger is not None:
             self.logger.log_hyperparams(self.hparams)
-            self.logger.watch(self.encoder)
 
         logger.info(f"Logging to {self.trainer.log_dir}")
         self.corpus = self.trainer.datamodule.corpus
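
For reference, a minimal sketch of the checkpoint-conversion workaround mentioned in the README hunk above, assuming the new `retrieval/confs/cli_dummy.yaml` and the existing LightningCLI entry point `retrieval/main.py`; `$PATH_TO_HF_CHECKPOINT` is a hypothetical placeholder for your Hugging Face model directory or hub name:

```
# Hypothetical one-step run: cli_dummy.yaml sets lr: 0 and max_steps: 1, so the
# weights are left unchanged while the trainer saves a DeepSpeed-format checkpoint.
python retrieval/main.py fit --config retrieval/confs/cli_dummy.yaml --model.model_name $PATH_TO_HF_CHECKPOINT
```

The saved checkpoint should then be usable as `$PATH_TO_RETRIEVER_CHECKPOINT` in the prediction commands shown in the README.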