diff --git a/docs/tutorials/design-with-safe.ipynb b/docs/tutorials/design-with-safe.ipynb
index 3de96b7..f5bb36c 100644
--- a/docs/tutorials/design-with-safe.ipynb
+++ b/docs/tutorials/design-with-safe.ipynb
@@ -49,7 +49,7 @@
     }
    ],
    "source": [
-    "designer = sf.SAFEDesign.load_default(verbose=True)\n"
+    "# designer = sf.SAFEDesign.load_default(verbose=True)\n"
    ]
   },
   {
@@ -58,13 +58,13 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "candidate_mol = \"O=C(C#CCN1CCCCC1)Nc1ccc2ncnc(Nc3cccc(Br)c3)c2c1\"\n",
+    "# candidate_mol = \"O=C(C#CCN1CCCCC1)Nc1ccc2ncnc(Nc3cccc(Br)c3)c2c1\"\n",
     "\n",
-    "scaffold = \"[*]N-c1ccc2ncnc(-N[*])c2c1\" # this is for scaffold decoration\n",
-    "superstructure = \"c1ccc2ncncc2c1\"\n",
-    "side_chains = '[1*]C(=O)C#CCN1CCCCC1.[2*]c1cccc(Br)c1' # this is for scaffold morphing\n",
-    "motif = \"[*]-N1CCCCC1\" # this is for motif extension\n",
-    "linker_generation = [\"[*]-N1CCCCC1\", \"Brc1cccc(Nc2ncnc3ccc(-[*])cc23)c1\"] # this is for linker generation\n"
+    "# scaffold = \"[*]N-c1ccc2ncnc(-N[*])c2c1\" # this is for scaffold decoration\n",
+    "# superstructure = \"c1ccc2ncncc2c1\"\n",
+    "# side_chains = '[1*]C(=O)C#CCN1CCCCC1.[2*]c1cccc(Br)c1' # this is for scaffold morphing\n",
+    "# motif = \"[*]-N1CCCCC1\" # this is for motif extension\n",
+    "# linker_generation = [\"[*]-N1CCCCC1\", \"Brc1cccc(Nc2ncnc3ccc(-[*])cc23)c1\"] # this is for linker generation\n"
    ]
   },
   {
@@ -184,7 +184,7 @@
     }
    ],
    "source": [
-    "dm.to_image(dm.to_mol(candidate_mol))\n"
+    "# dm.to_image(dm.to_mol(candidate_mol))\n"
    ]
   },
   {
@@ -193,7 +193,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "N_SAMPLES = 100\n"
+    "# N_SAMPLES = 100\n"
    ]
   },
   {
@@ -219,7 +219,7 @@
     }
    ],
    "source": [
-    "generated = designer.de_novo_generation(sanitize=True, n_samples_per_trial=N_SAMPLES)\n"
+    "# generated = designer.de_novo_generation(sanitize=True, n_samples_per_trial=N_SAMPLES)\n"
    ]
   },
   {
@@ -1633,7 +1633,7 @@
     }
    ],
    "source": [
-    "dm.to_image(generated[:20])\n"
+    "# dm.to_image(generated[:20])\n"
    ]
   },
   {
@@ -1710,7 +1710,7 @@
     }
    ],
    "source": [
-    "dm.to_image(scaffold)\n"
+    "# dm.to_image(scaffold)\n"
    ]
   },
   {
@@ -1727,7 +1727,7 @@
     }
    ],
    "source": [
-    "generated = designer.scaffold_decoration(scaffold=scaffold, n_samples_per_trial=N_SAMPLES, n_trials=2, sanitize=True, do_not_fragment_further=True)\n"
+    "# generated = designer.scaffold_decoration(scaffold=scaffold, n_samples_per_trial=N_SAMPLES, n_trials=2, sanitize=True, do_not_fragment_further=True)\n"
    ]
   },
   {
@@ -5996,7 +5996,7 @@
     }
    ],
    "source": [
-    "dm.viz.lasso_highlight_image([dm.to_mol(x) for x in generated[:20]], dm.from_smarts(scaffold))\n"
+    "# dm.viz.lasso_highlight_image([dm.to_mol(x) for x in generated[:20]], dm.from_smarts(scaffold))\n"
    ]
   },
   {
@@ -6061,7 +6061,7 @@
     }
    ],
    "source": [
-    "dm.to_image(superstructure)\n"
+    "# dm.to_image(superstructure)\n"
    ]
   },
   {
@@ -6078,8 +6078,8 @@
     }
    ],
    "source": [
-    "generated = designer.super_structure(core=superstructure, n_samples_per_trial=N_SAMPLES, n_trials=1, sanitize=True, do_not_fragment_further=False, attachment_point_depth=3)\n",
-    "#generated\n"
+    "# generated = designer.super_structure(core=superstructure, n_samples_per_trial=N_SAMPLES, n_trials=1, sanitize=True, do_not_fragment_further=False, attachment_point_depth=3)\n",
+    "# #generated\n"
    ]
   },
   {
@@ -7227,7 +7227,7 @@
     }
    ],
    "source": [
-    "dm.to_image(generated[:20])\n"
+    "# dm.to_image(generated[:20])\n"
    ]
   },
   {
@@ -7279,7 +7279,7 @@
     }
    ],
    "source": [
-    "dm.to_image(motif)\n"
+    "# dm.to_image(motif)\n"
    ]
   },
   {
@@ -7296,8 +7296,8 @@
     }
    ],
    "source": [
-    "# let's make some long sequence\n",
-    "generated = designer.motif_extension(motif=motif, n_samples_per_trial=N_SAMPLES, n_trials=1, sanitize=True, do_not_fragment_further=False, min_length=25, max_length=80)\n"
+    "# # let's make some long sequence\n",
+    "# generated = designer.motif_extension(motif=motif, n_samples_per_trial=N_SAMPLES, n_trials=1, sanitize=True, do_not_fragment_further=False, min_length=25, max_length=80)\n"
    ]
   },
   {
@@ -8453,7 +8453,7 @@
     }
    ],
    "source": [
-    "dm.to_image(generated[:20])\n"
+    "# dm.to_image(generated[:20])\n"
    ]
   },
   {
@@ -8543,7 +8543,7 @@
     }
    ],
    "source": [
-    "dm.to_image(side_chains)\n"
+    "# dm.to_image(side_chains)\n"
    ]
   },
   {
@@ -10007,8 +10007,8 @@
     }
    ],
    "source": [
-    "generated = designer.scaffold_morphing(side_chains=side_chains, n_samples_per_trial=N_SAMPLES, n_trials=1, sanitize=True, do_not_fragment_further=False, random_seed=100)\n",
-    "dm.to_image(generated[:20])\n"
+    "# generated = designer.scaffold_morphing(side_chains=side_chains, n_samples_per_trial=N_SAMPLES, n_trials=1, sanitize=True, do_not_fragment_further=False, random_seed=100)\n",
+    "# dm.to_image(generated[:20])\n"
    ]
   },
   {
@@ -10114,7 +10114,7 @@
     }
    ],
    "source": [
-    "dm.to_image(linker_generation)\n"
+    "# dm.to_image(linker_generation)\n"
    ]
   },
   {
@@ -12321,8 +12321,8 @@
     }
    ],
    "source": [
-    "generated = designer.linker_generation(*linker_generation, n_samples_per_trial=N_SAMPLES, n_trials=1, sanitize=True, do_not_fragment_further=False, random_seed=100)\n",
-    "dm.to_image(generated[:20])\n"
+    "# generated = designer.linker_generation(*linker_generation, n_samples_per_trial=N_SAMPLES, n_trials=1, sanitize=True, do_not_fragment_further=False, random_seed=100)\n",
+    "# dm.to_image(generated[:20])\n"
    ]
   }
  ],
diff --git a/scripts/model_trainer.py b/scripts/model_trainer.py
deleted file mode 100755
index 4ec8b6e..0000000
--- a/scripts/model_trainer.py
+++ /dev/null
@@ -1,10 +0,0 @@
-#!/usr/bin/env python
-import transformers
-from safe.trainer.cli import ModelArguments
-from safe.trainer.cli import DataArguments
-from safe.trainer.cli import TrainingArguments
-from safe.trainer.cli import train
-
-parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
-model_args, data_args, training_args = parser.parse_args_into_dataclasses()
-train(model_args, data_args, training_args)
diff --git a/scripts/mol_design.py b/scripts/mol_design.py
deleted file mode 100755
index da163ff..0000000
--- a/scripts/mol_design.py
+++ /dev/null
@@ -1,392 +0,0 @@
-from typing import Optional
-from typing_extensions import Annotated
-
-import typer
-
-import os
-import torch
-import safe as sf
-import numpy as np
-import pandas as pd
-import datamol as dm
-import itertools
-import wandb
-import fsspec
-
-from enum import Enum
-from loguru import logger
-from tqdm.auto import tqdm
-from safe import SAFEDesign
-from tdc import Oracle
-from tdc import Evaluator
-from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, create_reference_model
-
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-app = typer.Typer()
-
-
-class DesignMode(str, Enum):
-    denovo = "denovo"
-    superstructure = "superstructure"
-    scaffold = "scaffold"
-    motif = "motif"
-    morphing = "morphing"
-    linker = "linker"
-
-
-class DesignObjective(str, Enum):
-    clogp = "clogp"
-    qed = "qed"
-    sas = "sas"
-    tpsa = "tpsa"
-    mw = "mw"
-    cns = "cns"
-
-
-class TargetedReward:
-    """Reward function for goal directed design
-
-    reward = 1.0 / (1.0 + alpha * distance)
-    """
-
-    def __init__(self, objective="clogp", target=2, alpha=0.5):
-        self.obj = objective
-        self.target = target
-        self.alpha = alpha
-        self.cns_predictor = None
-        if self.obj == "cns":
-            import vdmpk.pka
-
-            self.cns_predictor = vdmpk.pka.PkaPredictor.from_ada()
-
-    @staticmethod
-    def _fail_silently(fn, *arg, **kwargs):
-        """Remains silent when the input function fails and return None instead"""
-        try:
-            return fn(*arg, **kwargs)
-        except Exception:
-            return None
-
-    def __call__(self, smiles):
-        if not isinstance(smiles, (list, tuple, np.ndarray)):
-            smiles = [smiles]
-        n_input = len(smiles)
-        mols = [TargetedReward._fail_silently(dm.to_mol, x) for x in smiles]
-        default_scores = np.zeros(len(mols))
-        valid = np.array([v is not None for v in mols])
-        if np.sum(valid) > 0:
-            valid_mol = list(itertools.compress(mols, valid))
-            out = self._metric(valid_mol)
-            default_scores[valid] = out
-        if n_input > 1:
-            return default_scores.astype(float)
-        return float(default_scores.flat[0])
-
-    def _compute_cns(self, mols):
-        out = []
-        for mol in mols:
-            try:
-                results = self.cns_predictor.predict_pka(
-                    mol=mol,
-                    return_all_pka_values=False,
-                    return_states=False,
-                    return_mols=False,
-                    return_clogd=False,
-                    return_cns_score=True,
-                    clogd_ph=7.4,
-                    clogp=None,
-                )
-                out.append(results["mpo_score"])
-            except:
-                out.append(-1)
-        return np.asarray(out)
-
-    def _metric(self, mols):
-        """Compute underlying metric objective for a set of smiles"""
-        if self.obj == "cns":
-            return self._compute_cns(mols)
-        if self.obj == "clogp":
-            out = [dm.descriptors.clogp(x) for x in mols]
-        elif self.obj == "mw":
-            out = [dm.descriptors.mw(x) for x in mols]
-        elif self.obj == "sas":
-            out = [dm.descriptors.sas(x) for x in mols]
-        elif self.obj == "tpsa":
-            out = [dm.descriptors.tpsa(x) for x in mols]
-        elif self.obj == "qed":
-            out = [dm.descriptors.qed(x) for x in mols]
-        else:
-            raise ValueError("Unknown objective")
-        out = np.asarray(out)
-        dist = np.abs(out - self.target)
-        return 1.0 / (1.0 + self.alpha * dist)
-
-
-def train(
-    ppo_config,
-    generation_kwargs,
-    model,
-    tokenizer,
-    oracle,
-    prefix=None,
-    n_episodes=100,
-    batch_size=32,
-):
-    safe_encoder = sf.SAFEConverter()
-    model_ref = create_reference_model(model)
-    config = PPOConfig(**ppo_config)
-
-    diversity_evaluator = Evaluator(name="Diversity")
-    uniqueness_evaluator = Evaluator(name="Uniqueness")
-
-    ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)
-
-    for _ in tqdm(range(n_episodes)):
-        fragment = ""
-        if isinstance(prefix, str):
-            fragment = safe_encoder.encoder(
-                prefix,
-                canonical=False,
-                randomize=True,
-                constraints=None,
-                allow_empty=True,
-            )
-            fragment = fragment.rstrip(".") + "."
-
-        if isinstance(fragment, str):
-            fragment = [fragment]
-
-        batch_size = ppo_config.get("batch_size", 32)
-
-        if len(fragment) < batch_size:
-            fragment = np.random.choice(fragment, size=batch_size)
-
-        game_data = {}
-        game_data["query"] = fragment
-        batch = tokenizer(
-            [tokenizer.bos_token + x for x in fragment],
-            return_tensors="pt",
-            add_special_tokens=False,
-        ).to(model.pretrained_model.device)
-        query_tensor = batch["input_ids"]
-        response_tensor = ppo_trainer.generate(
-            list(query_tensor), return_prompt=True, **generation_kwargs
-        )
-        decoded_safe_mols = tokenizer.batch_decode(response_tensor, skip_special_tokens=True)
-
-        decoded_smiles = [
-            sf.decode(
-                x,
-                as_mol=False,
-                fix=True,
-                remove_added_hs=True,
-                canonical=True,
-                ignore_errors=True,
-                remove_dummies=True,
-            )
-            for x in decoded_safe_mols
-        ]
-
-        game_data["response"] = decoded_safe_mols
-        rewards = np.zeros(len(decoded_smiles))
-        try:
-            valid_position, valid_smiles = zip(
-                *[(i, x) for i, x in enumerate(decoded_smiles) if x is not None]
-            )
-            batch_reward = oracle(list(valid_smiles))
-            rewards[np.asarray(valid_position)] = batch_reward
-        except Exception as e:
-            logger.error(e)
-        rewards = torch.from_numpy(rewards).to(device=model.pretrained_model.device)
-        rewards = list(rewards)
-        stats = ppo_trainer.step(list(query_tensor), list(response_tensor), rewards)
-        stats["validity"] = len(valid_position) / batch_size
-        stats["uniqueness"] = uniqueness_evaluator(list(valid_smiles))
-        stats["diversity"] = diversity_evaluator(list(valid_smiles))
-        ppo_trainer.log_stats(stats, game_data, rewards)
-    return ppo_trainer, model
-
-
-@app.command()
-def sample(
-    checkpoint: Annotated[Optional[str], typer.Option()] = None,
-    n_samples: Annotated[int, typer.Option()] = 1000,
-    n_trials: Annotated[int, typer.Option()] = 1,
-    sanitize: Annotated[bool, typer.Option()] = True,
-    allow_further_decomposition: Annotated[bool, typer.Option()] = False,
-    mode: Annotated[DesignMode, typer.Option(case_sensitive=False)] = DesignMode.denovo,
-    inputs: Annotated[str, typer.Option()] = None,
-    seed: Annotated[int, typer.Option()] = None,
-    max_n: Annotated[int, typer.Option()] = -1,
-    outfile: Annotated[str, typer.Option(default=...)] = None,
-):
-    """Sample molecule using SAFEDesign"""
-
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    designer = SAFEDesign.load_default(verbose=False, model_dir=checkpoint, device=device)
-    generate_params = {
-        "n_samples_per_trial": n_samples,
-        "n_trials": n_trials,
-        "sanitize": sanitize,
-        "do_not_fragment_further": (not allow_further_decomposition),
-        "random_seed": seed,
-    }
-
-    datas = []
-    if mode.value == "denovo":
-        generated = designer.de_novo_generation(
-            n_samples_per_trial=n_samples, n_trials=n_trials, sanitize=sanitize
-        )
-        data = {"smiles": generated}
-        data = pd.DataFrame(generated)
-        data["mode"] = mode.value
-        datas.append(data)
-
-    else:
-        inputs_df = pd.read_csv(inputs)
-        input_list = inputs_df[mode.value].tolist()
-        if max_n is not None and max_n > 0:
-            input_list = input_list[:max_n]
-        for cur_input in tqdm(input_list):
-            try:
-                if mode.value == "scaffold":
-                    generated = designer.scaffold_decoration(scaffold=cur_input, **generate_params)
-                elif mode.value == "superstructure":
-                    generated = designer.super_structure(
-                        core=cur_input, attachment_point_depth=3, **generate_params
-                    )
-                elif mode.value == "motif":
-                    generated = designer.motif_extension(
-                        motif=cur_input, min_length=len(inputs), **generate_params
-                    )
-                elif mode.value == "morphing":
-                    generated = designer.scaffold_morphing(side_chains=cur_input, **generate_params)
-                elif mode.value == "linker":
-                    generated = designer.linker_generation(*cur_input.split("."), **generate_params)
-
-                data = {"smiles": generated}
-                data = pd.DataFrame(generated)
-                if cur_input is not None:
-                    data["inputs"] = cur_input
-                data["mode"] = mode.value
-                datas.append(data)
-            except Exception as e:
-                logger.exception(e)
-    if len(datas) > 0:
-        datas = pd.concat(datas, ignore_index=True)
-        datas.to_csv(outfile, index=False)
-
-
-@app.command()
-def optim(
-    checkpoint: Annotated[Optional[str], typer.Option()] = None,
-    n_samples: Annotated[int, typer.Option()] = 500,
-    n_trials: Annotated[int, typer.Option()] = 2,
-    sanitize: Annotated[bool, typer.Option()] = True,
-    allow_further_decomposition: Annotated[bool, typer.Option()] = True,
-    seed: Annotated[int, typer.Option()] = 42,
-    inputs: Annotated[str, typer.Option()] = None,
-    task_id: Annotated[int, typer.Option()] = None,
-    objective: Annotated[DesignObjective, typer.Option(case_sensitive=False)] = DesignObjective.mw,
-    batch_size: Annotated[int, typer.Option()] = 100,
-    n_episodes: Annotated[int, typer.Option()] = 100,
-    target: Annotated[int, typer.Option()] = 350,
-    alpha: Annotated[float, typer.Option()] = 0.5,
-    learning_rate: Annotated[float, typer.Option()] = 5e-5,
-    max_new_tokens: Annotated[int, typer.Option()] = 150,
-    name: Annotated[str, typer.Option(default=...)] = None,
-    outdir: Annotated[str, typer.Option(default=...)] = None,
-):
-    """Perform optimization under a given objective"""
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    designer = SAFEDesign.load_default(verbose=False, model_dir=checkpoint, device=device)
-
-    safe_tokenizer = designer.tokenizer
-    tokenizer = safe_tokenizer.get_pretrained()
-    model = AutoModelForCausalLMWithValueHead(designer.model)
-    model.is_peft_model = False
-    TASKS = [
-        ("mw", 350),
-        ("mw", 400),
-        ("mw", 450),
-        ("clogp", 2),
-        ("clogp", 4),
-        ("clogp", 6),
-        ("tpsa", 40),
-        ("tpsa", 80),
-        ("tpsa", 120),
-        ("qed", 0.3),
-        ("qed", 0.5),
-        ("qed", 0.7),
-        ("cns", None),
-    ]
-    if task_id is None:
-        reward_fn = TargetedReward(objective=objective.value, target=target, alpha=alpha)
-    else:
-        reward_fn = TargetedReward(
-            objective=TASKS[task_id][0], target=TASKS[task_id][1], alpha=alpha
-        )
-        name = name + f"-{TASKS[task_id][0]}-{TASKS[task_id][1]}"
-        outdir = outdir.rstrip("/") + f"-{TASKS[task_id][0]}-{TASKS[task_id][1]}/"
-    if inputs == "None":
-        inputs = None
-    if name is None:
-        name = f"safe-{objective.value}"
-
-    ppo_config = {
-        "batch_size": batch_size,
-        "log_with": "wandb",
-        "model_name": "GPT",
-        "tracker_project_name": name,
-        "learning_rate": learning_rate,
-    }
-
-    generation_kwargs = {
-        "top_k": 0.0,
-        "top_p": 1.0,
-        "do_sample": True,
-        "pad_token_id": tokenizer.pad_token_id,
-        "bos_token_id": tokenizer.bos_token_id,
-        "eos_token_id": tokenizer.eos_token_id,
-        "max_new_tokens": max_new_tokens,
-    }
-
-    wandb.finish()
-    trainer, trained_model = train(
-        ppo_config,
-        generation_kwargs,
-        model,
-        tokenizer,
-        reward_fn,
-        prefix=inputs or None,
-        n_episodes=n_episodes,
-        batch_size=batch_size,
-    )
-
-    trained_model.eval()
-    designer.model = trained_model
-
-    generate_params = {"n_samples_per_trial": n_samples, "n_trials": n_trials, "sanitize": sanitize}
-    if inputs:
-        generated = designer.scaffold_decoration(
-            scaffold=inputs,
-            do_not_fragment_further=(not allow_further_decomposition),
-            random_seed=seed,
-            **generate_params,
-        )
-    else:
-        generated = designer.de_novo_generation(**generate_params)
-
-    data = {"smiles": generated}
-    data = pd.DataFrame(generated)
-    if inputs is not None:
-        data["inputs"] = inputs
-
-    with fsspec.open(os.path.join(outdir, f"data-{name}.csv"), "w", auto_mkdir=True) as IN:
-        data.to_csv(IN, index=False)
-    trained_model.save_pretrained(os.path.join(outdir, "model"))
-
-
-if __name__ == "__main__":
-    app()
diff --git a/scripts/tokenizer_trainer.py b/scripts/tokenizer_trainer.py
deleted file mode 100755
index c265790..0000000
--- a/scripts/tokenizer_trainer.py
+++ /dev/null
@@ -1,107 +0,0 @@
-from typing import Optional
-
-from collections.abc import Mapping
-from dataclasses import dataclass, field
-from loguru import logger
-from transformers import AutoTokenizer, HfArgumentParser
-from safe.tokenizer import SAFETokenizer
-from tqdm.auto import tqdm
-from safe.trainer.data_utils import batch_iterator, get_dataset
-
-
-def fast_batch_iterator(dataset, batch_size=1000, n_examples=None, column="inputs"):
-    cur_len = 0
-    for batch in tqdm(dataset.iter(batch_size=batch_size)):
-        yield batch[column]
-        cur_len += len(batch[column])
-        if n_examples is not None and cur_len >= n_examples:
-            break
-
-
-@dataclass
-class TokenizerTrainingArguments:
-    """
-    Configuration for tokenizer training.
-    """
-
-    tokenizer_type: Optional[str] = field(
-        default="bpe", metadata={"help": "Type of the tokenizer to train."}
-    )
-    base_tokenizer: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "Optional base tokenizer to you. Otherwise, the tokenizer will be learnt from scratch using the safe tokenizer."
-        },
-    )
-    splitter: Optional[str] = field(
-        default=None, metadata={"help": "Presplitter to use to train SAFE tokenizer."}
-    )
-    dataset: str = field(
-        default=None, metadata={"help": "Path to the dataset to load for training the tokenizer."}
-    )
-    text_column: Optional[str] = field(
-        default="inputs", metadata={"help": "Column containing text data to process."}
-    )
-    vocab_size: Optional[int] = field(
-        default=1000, metadata={"help": "Target vocab size of the final tokenizer."}
-    )
-    batch_size: Optional[int] = field(
-        default=100, metadata={"help": "Batch size for training the tokenizer."}
-    )
-    n_examples: Optional[int] = field(
-        default=None, metadata={"help": "Number of examples to train the tokenizer on."}
-    )
-    tokenizer_name: Optional[str] = field(
-        default="safe", metadata={"help": "Name of new tokenizer."}
-    )
-    outfile: Optional[str] = field(
-        default=None, metadata={"help": "Path to the local save of the trained tokenizer"}
-    )
-    all_split: Optional[bool] = field(
-        default=False,
-        metadata={
-            "help": "Whether to use all the splits or just the train split if only that is available."
-        },
-    )
-    push_to_hub: Optional[bool] = field(
-        default=False, metadata={"help": "Whether to push saved tokenizer to the hub."}
-    )
-
-
-if __name__ == "__main__":
-    # Configuration
-    parser = HfArgumentParser(TokenizerTrainingArguments)
-    args = parser.parse_args()
-    dataset = get_dataset(args.dataset, streaming=True, tokenize_column=args.text_column)
-    # Training and saving
-    if isinstance(dataset, Mapping) and not args.all_split:
-        dataset = dataset["train"]
-
-    if not args.all_split:
-        logger.info("Using fast batch iterator.")
-        dataset_iterator = fast_batch_iterator(
-            dataset, batch_size=args.batch_size, n_examples=args.n_examples, column=args.text_column
-        )
-    else:
-        logger.info("Using regular batch iterator.")
-        dataset_iterator = batch_iterator(
-            dataset, batch_size=args.batch_size, n_examples=args.n_examples, column=args.text_column
-        )
-
-    if args.base_tokenizer is not None:
-        tokenizer = AutoTokenizer.from_pretrained(args.base_tokenizer)
-        tokenizer = tokenizer.train_new_from_iterator(dataset_iterator, vocab_size=args.vocab_size)
-    else:
-        tokenizer = SAFETokenizer(
-            tokenizer_type=args.tokenizer_type,
-            splitter=args.splitter,
-            trainer_args={"vocab_size": args.vocab_size},
-        )
-        tokenizer.train_from_iterator(dataset_iterator)
-    tokenizer_name = f"{args.tokenizer_name}-{args.tokenizer_type}-{args.vocab_size}"
-    # also save locally to the outfile specified
-    if args.outfile is not None:
-        tokenizer.save(args.outfile)
-    tokenizer = tokenizer.get_pretrained()
-
-    tokenizer.save_pretrained(tokenizer_name, push_to_hub=args.push_to_hub)
diff --git a/tests/test_import.py b/tests/test_import.py
new file mode 100644
index 0000000..1819509
--- /dev/null
+++ b/tests/test_import.py
@@ -0,0 +1,2 @@
+def test_import():
+    import safe
diff --git a/tests/test_notebooks.py b/tests/test_notebooks.py
new file mode 100644
index 0000000..6a7e018
--- /dev/null
+++ b/tests/test_notebooks.py
@@ -0,0 +1,30 @@
+import pytest
+import pathlib
+
+import nbformat
+from nbconvert.preprocessors.execute import ExecutePreprocessor
+
+
+ROOT_DIR = pathlib.Path(__file__).parent.resolve()
+
+TUTORIALS_DIR = ROOT_DIR.parent / "docs" / "tutorials"
+DISABLE_NOTEBOOKS = []
+NOTEBOOK_PATHS = sorted(list(TUTORIALS_DIR.glob("*.ipynb")))
+NOTEBOOK_PATHS = list(filter(lambda x: x.name not in DISABLE_NOTEBOOKS, NOTEBOOK_PATHS))
+
+# Discard some notebooks
+NOTEBOOKS_TO_DISCARD = ["Basic_Concepts.ipynb"]
+NOTEBOOK_PATHS = list(filter(lambda x: x.name not in NOTEBOOKS_TO_DISCARD, NOTEBOOK_PATHS))
+
+
+@pytest.mark.parametrize("nb_path", NOTEBOOK_PATHS, ids=[str(n.name) for n in NOTEBOOK_PATHS])
+def test_notebook(nb_path):
+    # Setup and configure the processor to execute the notebook
+    ep = ExecutePreprocessor(timeout=600, kernel_name="python3")
+
+    # Open the notebook
+    with open(nb_path) as f:
+        nb = nbformat.read(f, as_version=nbformat.NO_CONVERT)
+
+    # Execute the notebook
+    ep.preprocess(nb, {"metadata": {"path": TUTORIALS_DIR}})
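A minimal usage sketch, not part of the patch above: this runs one tutorial notebook the same way the new tests/test_notebooks.py does, which can help when debugging a single notebook without invoking the whole pytest suite. It assumes the script is run from the repository root, that nbformat and nbconvert are installed, and that a "python3" Jupyter kernel is registered; the chosen notebook path is just an example.

import pathlib

import nbformat
from nbconvert.preprocessors.execute import ExecutePreprocessor

# Example notebook; the test suite discovers every docs/tutorials/*.ipynb.
nb_path = pathlib.Path("docs/tutorials/design-with-safe.ipynb")

# Read the notebook without converting it to another nbformat version.
with open(nb_path) as f:
    nb = nbformat.read(f, as_version=nbformat.NO_CONVERT)

# Same settings as the test: 600-second timeout per cell, "python3" kernel,
# and cells executed with the tutorials directory as the working path.
ep = ExecutePreprocessor(timeout=600, kernel_name="python3")
ep.preprocess(nb, {"metadata": {"path": str(nb_path.parent)}})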