Skip to content

Commit

Permalink
Update optbuild.py
Browse files Browse the repository at this point in the history
remove AZ specific inference function
  • Loading branch information
lewismervin1 authored Jul 3, 2024
1 parent 7ede264 commit 5edf6ec
Showing 1 changed file with 0 additions and 50 deletions.
50 changes: 0 additions & 50 deletions optunaz/optbuild.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,44 +22,6 @@
logger = logging.getLogger(__name__)


def predict_pls(model_path, inference_path):
if inference_path == "None":
logger.info(f"Inference path is not set so AL predictions not performed")
return
else:
logger.info(f"Inference path is {inference_path}")
predict_args = [
"prog",
"--model-file",
str(model_path),
"--input-smiles-csv-file",
str(inference_path),
"--input-smiles-csv-column",
"Structure",
"--output-prediction-csv-file",
str(os.path.dirname(model_path)) + "/al.csv",
"--predict-uncertainty",
"--uncertainty_quantile",
"0.99",
]
try:
with patch.object(sys, "argv", predict_args):
logging.info("Performing active learning predictions")
predict.main()
except FileNotFoundError:
logger.info(
f"PLS file not found at {model_path}, AL predictions not performed"
)
except predict.UncertaintyError:
logging.info(
"PLS prediction not performed: algorithm does not support uncertainty prediction"
)
except predict.AuxCovariateMissing:
logging.info(
"PLS prediction not performed: algorithm requires corvariate auxiliary data for inference"
)


def main():
logging.basicConfig(level=logging.INFO)
parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -97,12 +59,6 @@ def main():
help="Turn off descriptor generation caching ",
action="store_true",
)
parser.add_argument(
"--inference_uncert",
help="Path for uncertainty inference and thresholding.",
type=pathlib.Path,
default="/projects/db-mirror/MLDatasets/PLS/pls.csv",
)
args = parser.parse_args()

AnyConfig = Union[OptimizationConfig, BuildConfig]
Expand All @@ -111,7 +67,6 @@ def main():

if isinstance(config, OptimizationConfig):
study_name = str(pathlib.Path(args.config).absolute())
pred_pls = False
if not args.no_cache:
config.set_cache()
cache = config._cache
Expand All @@ -123,7 +78,6 @@ def main():
if args.best_model_outpath or args.merged_model_outpath:
buildconfig = buildconfig_best(study)
elif isinstance(config, BuildConfig):
pred_pls = True
buildconfig = config
cache = None
cache_dir = None
Expand All @@ -140,15 +94,11 @@ def main():
args.best_model_outpath,
cache=cache,
)
if not args.merged_model_outpath and pred_pls:
predict_pls(args.best_model_outpath, args.inference_uncert)
if args.merged_model_outpath:
build_merged(
buildconfig,
args.merged_model_outpath,
cache=cache,
)
if pred_pls:
predict_pls(args.merged_model_outpath, args.inference_uncert)
if cache_dir is not None:
cache_dir.cleanup()

0 comments on commit 5edf6ec

Please sign in to comment.