Commit ac4dda0: Add files via upload
dhw059 authored Apr 7, 2024
1 parent 31831db commit ac4dda0
Showing 1 changed file with 180 additions and 0 deletions.
benchmarks/matbench_v0.1_DensGNN/train.ipynb (180 additions, 0 deletions)
@@ -0,0 +1,180 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os.path\n",
"import argparse\n",
"import pandas as pd\n",
"import tensorflow as tf\n",
"from matbench.bench import MatbenchBenchmark\n",
"from kgcnn.data.crystal import CrystalDataset\n",
"from kgcnn.literature.DenseGNN import make_model_asu\n",
"\n",
"from sklearn.preprocessing import StandardScaler\n",
"from kgcnn.training.schedule import LinearWarmupExponentialDecay\n",
"from kgcnn.training.scheduler import LinearLearningRateScheduler\n",
"import kgcnn.training.callbacks\n",
"from kgcnn.utils.devices import set_devices_gpu\n",
"import numpy as np\n",
"from copy import deepcopy\n",
"from hyper import *\n",
"\n",
"parser = argparse.ArgumentParser(description='Train DenseGNN.')\n",
"parser.add_argument(\"--gpu\", required=False, help=\"GPU index used for training.\",\n",
" default=None, nargs=\"+\", type=int)\n",
"args = vars(parser.parse_args())\n",
"print(\"Input of argparse:\", args)\n",
"gpu_to_use = args[\"gpu\"]\n",
"set_devices_gpu(gpu_to_use)\n",
"\n",
"subsets_compatible = [\"matbench_jdft2d\", \"matbench_phonons\", \"matbench_mp_gap\", \n",
" \"matbench_perovskites\",\n",
" \"matbench_log_kvrh\", \"matbench_log_gvrh\", \"matbench_dielectric\"]\n",
"mb = MatbenchBenchmark(subset=subsets_compatible, autoload=False)\n",
"\n",
"callbacks = {\n",
" \"graph_labels\": lambda st, ds: np.expand_dims(ds, axis=-1),\n",
" \"node_coordinates\": lambda st, ds: np.array(st.cart_coords, dtype=\"float\"),\n",
" \"node_frac_coordinates\": lambda st, ds: np.array(st.frac_coords, dtype=\"float\"),\n",
" \"graph_lattice\": lambda st, ds: np.ascontiguousarray(np.array(st.lattice.matrix), dtype=\"float\"),\n",
" \"abc\": lambda st, ds: np.array(st.lattice.abc),\n",
" \"charge\": lambda st, ds: np.array([st.charge], dtype=\"float\"),\n",
" \"volume\": lambda st, ds: np.array([st.lattice.volume], dtype=\"float\"),\n",
" \"node_number\": lambda st, ds: np.array(st.atomic_numbers, dtype=\"int\"),\n",
"}\n",
"\n",
"hyper_all = {\n",
" \"matbench_jdft2d\": hyper_1,\n",
" \"matbench_phonons\": hyper_2,\n",
" \"matbench_mp_gap\": hyper_3,\n",
" \"matbench_perovskites\": hyper_4,\n",
" \"matbench_log_kvrh\": hyper_5,\n",
" \"matbench_log_gvrh\": hyper_6,\n",
" \"matbench_dielectric\": hyper_7,\n",
"}\n",
"\n",
"restart_training = True\n",
"remove_invalid_graphs_on_predict = True\n",
"\n",
"for idx_task, task in enumerate(mb.tasks):\n",
" task.load()\n",
" for i, fold in enumerate(task.folds):\n",
" hyper = deepcopy(hyper_all[task.dataset_name])\n",
"\n",
" # Define loss for either classification or regression\n",
" loss = {\n",
" \"class_name\": \"BinaryCrossentropy\", \"config\": {\"from_logits\": True}\n",
" } if task.metadata[\"task_type\"] == \"classification\" else \"mean_absolute_error\"\n",
" hyper[\"training\"][\"compile\"][\"loss\"] = loss\n",
"\n",
" if restart_training and os.path.exists(\n",
" \"%s_predictions_%s_fold_%s.npy\" % (task.dataset_name, hyper[\"model\"][\"config\"][\"name\"], i)):\n",
" predictions = np.load(\n",
" \"%s_predictions_%s_fold_%s.npy\" % (task.dataset_name, hyper[\"model\"][\"config\"][\"name\"], i)\n",
" )\n",
" task.record(fold, predictions)\n",
" continue\n",
"\n",
" train_inputs, train_outputs = task.get_train_and_val_data(fold)\n",
" data_train = CrystalDataset()\n",
"\n",
" data_train._map_callbacks(train_inputs, pd.Series(train_outputs.values), callbacks)\n",
" print(\"Making graph... (this may take a while)\")\n",
" data_train.set_methods(hyper[\"data\"][\"dataset\"][\"methods\"])\n",
" data_train.clean(hyper[\"model\"][\"config\"][\"inputs\"])\n",
"\n",
" y_train = np.array(data_train.get(\"graph_labels\"))\n",
" x_train = data_train.tensor(hyper[\"model\"][\"config\"][\"inputs\"])\n",
"\n",
" if task.metadata[\"task_type\"] == \"classification\":\n",
" scaler = None\n",
" else:\n",
" scaler = StandardScaler(**hyper[\"training\"][\"scaler\"][\"config\"])\n",
" y_train = scaler.fit_transform(y_train)\n",
" print(y_train.shape)\n",
"\n",
" # train and validate your model\n",
" model = make_model_asu(**hyper[\"model\"][\"config\"])\n",
" model.compile(\n",
" loss=tf.keras.losses.get(hyper[\"training\"][\"compile\"][\"loss\"]),\n",
" optimizer=tf.keras.optimizers.get(hyper[\"training\"][\"compile\"][\"optimizer\"])\n",
" )\n",
" hist = model.fit(\n",
" x_train, y_train,\n",
" batch_size=hyper[\"training\"][\"fit\"][\"batch_size\"],\n",
" epochs=hyper[\"training\"][\"fit\"][\"epochs\"],\n",
" verbose=hyper[\"training\"][\"fit\"][\"verbose\"],\n",
" callbacks=[tf.keras.utils.deserialize_keras_object(x) for x in hyper[\"training\"][\"fit\"][\"callbacks\"]]\n",
" )\n",
"\n",
" # Get testing data\n",
" test_inputs = task.get_test_data(fold, include_target=False)\n",
" data_test = CrystalDataset()\n",
" data_test._map_callbacks(test_inputs, pd.Series(np.zeros(len(test_inputs))), callbacks)\n",
" print(\"Making graph... (this may take a while)\")\n",
" data_test.set_methods(hyper[\"data\"][\"dataset\"][\"methods\"])\n",
"\n",
" if remove_invalid_graphs_on_predict:\n",
" removed = data_test.clean(hyper[\"model\"][\"config\"][\"inputs\"])\n",
" np.save(\n",
" \"%s_predictions_invalid_%s_fold_%s.npy\" % (task.dataset_name, hyper[\"model\"][\"config\"][\"name\"], i),\n",
" removed\n",
" )\n",
" else:\n",
" removed = None\n",
"\n",
" # Predict on the testing data\n",
" x_test = data_test.tensor(hyper[\"model\"][\"config\"][\"inputs\"])\n",
" predictions_model = model.predict(x_test)\n",
"\n",
" if remove_invalid_graphs_on_predict:\n",
" indices_test = [j for j in range(len(test_inputs))]\n",
" for j in removed:\n",
" indices_test.pop(j)\n",
" predictions = np.expand_dims(np.zeros(len(test_inputs), dtype=\"float\"), axis=-1)\n",
" predictions[np.array(indices_test)] = predictions_model\n",
" else:\n",
" predictions = predictions_model\n",
"\n",
" if task.metadata[\"task_type\"] == \"classification\":\n",
" def np_sigmoid(x):\n",
" return np.exp(-np.logaddexp(0, -x))\n",
" predictions = np_sigmoid(predictions)\n",
" else:\n",
" predictions = scaler.inverse_transform(predictions)\n",
"\n",
" if predictions.shape[-1] == 1:\n",
" predictions = np.squeeze(predictions, axis=-1)\n",
"\n",
" np.save(\n",
" \"%s_predictions_%s_fold_%s.npy\" % (task.dataset_name, hyper[\"model\"][\"config\"][\"name\"], i),\n",
" predictions\n",
" )\n",
"\n",
" # Record data!\n",
" task.record(fold, predictions)\n",
"\n",
"# Save your results\n",
"mb.to_file(\"results_densegnn.json.gz\")\n",
"\n",
"for key, values in mb.scores.items():\n",
" factor = 1000.0 if key in [\"matbench_jdft2d\"] else 1.0\n",
" if key not in [\"matbench_mp_is_metal\"]:\n",
" print(key, factor*values[\"mae\"][\"mean\"], factor*values[\"mae\"][\"std\"])\n",
" else:\n",
" print(key, values[\"rocauc\"][\"mean\"], values[\"rocauc\"][\"std\"])\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
