diff --git a/tf_notebooks/conf_notebook/tf_PSF_NonParam_Euclid_resolution.ipynb b/tf_notebooks/conf_notebook/tf_PSF_NonParam_Euclid_resolution.ipynb new file mode 100644 index 00000000..0ff3d9f0 --- /dev/null +++ b/tf_notebooks/conf_notebook/tf_PSF_NonParam_Euclid_resolution.ipynb @@ -0,0 +1,836 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "v2d_1RWpVQNa" + }, + "source": [ + "# PSF modelling" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "R5-NReW0Lboc", + "outputId": "6310e764-c1f0-4dcf-a2fb-12b73715b666", + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found GPU at: /device:GPU:0\n", + "tf_version: 2.4.1\n" + ] + } + ], + "source": [ + "#@title Import packages\n", + "import sys\n", + "import numpy as np\n", + "import scipy.io as sio\n", + "import time\n", + "\n", + "# Import wavefront code\n", + "import wf_psf as wf\n", + "\n", + "import tensorflow as tf\n", + "device_name = tf.test.gpu_device_name()\n", + "if device_name != '/device:GPU:0':\n", + " raise SystemError('GPU device not found')\n", + "print('Found GPU at: {}'.format(device_name))\n", + "print('tf_version: ' + str(tf.__version__))\n", + "\n", + "import tensorflow_addons as tfa\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Define saving paths" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "id_name = '_euclid_res_1000stars_RecAdam_v2'\n", + "\n", + "home_folder = '/local/home/tliaudat/'\n", + "\n", + "model = 'mccd'\n", + "# model = 'poly'\n", + "# model = 'param'\n", + "\n", + "run_id_name = model + id_name\n", + "\n", + "log_save_file = home_folder + 'checkpoints/log-files/'\n", + "chkp_save_file = home_folder + 'checkpoints/chkp/'\n", + "optim_hist_file = home_folder + 'checkpoints/optim-hist/'\n", + "\n", + "saving_optim_hist = dict()\n", + "\n", + "# Input paths\n", + "dataset_path = home_folder + 'psf-datasets/'\n", + "\n", + "train_path = 'train_Euclid_res_1000_stars_dim256.npy'\n", + "test_path = 'test_Euclid_res_1000_stars_dim256.npy'\n", + "\n", + "Zcube_path = home_folder + 'data/Zernike45.mat'\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "GfFjrvT2xG-0" + }, + "outputs": [], + "source": [ + "# Save output prints to logfile\n", + "\n", + "old_stdout = sys.stdout\n", + "log_file = open(log_save_file + run_id_name + '_output.log','w')\n", + "sys.stdout = log_file\n", + "print('Starting the log file.')\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "T-1a31yjI7aU" + }, + "source": [ + "# Define new model " + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# Decimation factor for Zernike polynomials\n", + "decim_f = 4 # Original shape (1024x1024)\n", + "n_zernikes = 15\n", + "\n", + "# Some parameters\n", + "pupil_diameter = 1024 // decim_f\n", + "n_bins_lda = 20\n", + "\n", + "output_Q = 3.\n", + "oversampling_rate = 3.\n", + "\n", + "batch_size = 16\n", + "output_dim = 32\n", + "d_max = 2\n", + "d_max_nonparam = 3 # polynomial-constraint features\n", + "x_lims = [0, 1e3]\n", + "y_lims = [0, 1e3]\n", + "graph_features = 10 # Graph-constraint features\n", + "l1_rate = 1e-8 # L1 regularisation\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uWeeueu-Z9Jy" + }, + "source": [ + "# Prepare the inputs" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "I_U_6UGkXZnP", + "outputId": "7d52af12-46b0-4c09-82bc-a8ba68859e8c" + }, + "outputs": [], + "source": [ + "#title Input preparation\n", + "\n", + "Zcube = sio.loadmat(Zcube_path)\n", + "zernikes = []\n", + "\n", + "zernike_shape = int(1024/decim_f)\n", + "\n", + "\n", + "for it in range(n_zernikes):\n", + " zernikes.append(wf.utils.downsample_im(Zcube['Zpols'][0,it][5], zernike_shape))\n", + "\n", + "# Now as cubes\n", + "np_zernike_cube = np.zeros((len(zernikes), zernikes[0].shape[0], zernikes[0].shape[1]))\n", + "\n", + "for it in range(len(zernikes)):\n", + " np_zernike_cube[it,:,:] = zernikes[it]\n", + "\n", + "np_zernike_cube[np.isnan(np_zernike_cube)] = 0\n", + "\n", + "tf_zernike_cube = tf.convert_to_tensor(np_zernike_cube, dtype=tf.float32)\n", + "\n", + "print('Zernike cube:')\n", + "print(tf_zernike_cube.shape)\n", + "\n", + "del Zcube\n", + "\n", + "\n", + "# Load the dictionaries\n", + "train_dataset = np.load(dataset_path + train_path, allow_pickle=True)[()]\n", + "# train_stars = train_dataset['stars']\n", + "# noisy_train_stars = train_dataset['noisy_stars']\n", + "# train_pos = train_dataset['positions']\n", + "train_SEDs = train_dataset['SEDs']\n", + "# train_zernike_coef = train_dataset['zernike_coef']\n", + "train_C_poly = train_dataset['C_poly']\n", + "train_parameters = train_dataset['parameters']\n", + "\n", + "\n", + "test_dataset = np.load(dataset_path + test_path, allow_pickle=True)[()]\n", + "# test_stars = test_dataset['stars']\n", + "# test_pos = test_dataset['positions']\n", + "test_SEDs = test_dataset['SEDs']\n", + "# test_zernike_coef = test_dataset['zernike_coef']\n", + "\n", + "# Convert to tensor\n", + "tf_noisy_train_stars = tf.convert_to_tensor(train_dataset['noisy_stars'], dtype=tf.float32)\n", + "tf_train_stars = tf.convert_to_tensor(train_dataset['stars'], dtype=tf.float32)\n", + "tf_train_pos = tf.convert_to_tensor(train_dataset['positions'], dtype=tf.float32)\n", + "\n", + "tf_test_stars = tf.convert_to_tensor(test_dataset['stars'], dtype=tf.float32)\n", + "tf_test_pos = tf.convert_to_tensor(test_dataset['positions'], dtype=tf.float32)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "print('Dataset parameters:')\n", + "print(train_parameters)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2XtLl5KbdVnp" + }, + "source": [ + "## Continue initialisation" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "id": "fZR0vayTaTUo", + "scrolled": true + }, + "outputs": [], + "source": [ + "# Generate initializations\n", + "\n", + "\n", + "# Prepare np input\n", + "simPSF_np = wf.SimPSFToolkit(zernikes, max_order=n_zernikes,\n", + " pupil_diameter=pupil_diameter, output_dim=output_dim,\n", + " oversampling_rate=oversampling_rate, output_Q=output_Q)\n", + "simPSF_np.gen_random_Z_coeffs(max_order=n_zernikes)\n", + "z_coeffs = simPSF_np.normalize_zernikes(simPSF_np.get_z_coeffs(), simPSF_np.max_wfe_rms)\n", + "simPSF_np.set_z_coeffs(z_coeffs)\n", + "simPSF_np.generate_mono_PSF(lambda_obs=0.7, regen_sample=False)\n", + "\n", + "# Obscurations\n", + "obscurations = simPSF_np.generate_pupil_obscurations(N_pix=pupil_diameter, N_filter=2)\n", + "tf_obscurations = tf.convert_to_tensor(obscurations, dtype=tf.complex64)\n", + "\n", + "# Initialize the SED data list\n", + "packed_SED_data = [wf.utils.generate_packed_elems(_sed, simPSF_np, n_bins=n_bins_lda)\n", + " for _sed in train_SEDs]\n", + " \n", + "\n", + "\n", + "# Prepare the inputs for the training\n", + "tf_packed_SED_data = tf.convert_to_tensor(packed_SED_data, dtype=tf.float32)\n", + "tf_packed_SED_data = tf.transpose(tf_packed_SED_data, perm=[0, 2, 1])\n", + "\n", + "inputs = [tf_train_pos, tf_packed_SED_data]\n", + "\n", + "# Select the observed stars (noisy or noiseless)\n", + "outputs = tf_noisy_train_stars\n", + "# outputs = tf_train_stars\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Select the model" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'mccd'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m poly_dic, graph_dic = wf.tf_mccd_psf_field.build_mccd_spatial_dic_v2(obs_stars=outputs.numpy(),\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mobs_pos\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtf_train_pos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mx_lims\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx_lims\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0my_lims\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0my_lims\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.8/site-packages/wf_psf-0.0.1-py3.8.egg/wf_psf/tf_mccd_psf_field.py\u001b[0m in \u001b[0;36mbuild_mccd_spatial_dic_v2\u001b[0;34m(obs_stars, obs_pos, x_lims, y_lims, d_max, graph_features, verbose)\u001b[0m\n\u001b[1;32m 258\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 259\u001b[0m \u001b[0;31m# Compute graph-spatial constraint matrix\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 260\u001b[0;31m \u001b[0mVT\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mGraphBuilder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mgraph_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mVT\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 261\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 262\u001b[0m \u001b[0;31m# Compute polynomial-spatial constaint matrix\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.8/site-packages/wf_psf-0.0.1-py3.8.egg/wf_psf/graph_utils.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, obs_data, obs_pos, obs_weights, n_comp, n_eigenvects, n_iter, ea_gridsize, distances, auto_run, verbose)\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdistances\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdistances\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mauto_run\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 69\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_build_graphs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 70\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 71\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_build_graphs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.8/site-packages/wf_psf-0.0.1-py3.8.egg/wf_psf/graph_utils.py\u001b[0m in \u001b[0;36m_build_graphs\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[0mlist_eigenvects\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_comp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 96\u001b[0;31m \u001b[0me\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbest_VT\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mselect_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mR\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0me_range\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma_range\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 97\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msel_e\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 98\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msel_a\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.8/site-packages/wf_psf-0.0.1-py3.8.egg/wf_psf/graph_utils.py\u001b[0m in \u001b[0;36mselect_params\u001b[0;34m(self, R, e_range, a_range)\u001b[0m\n\u001b[1;32m 159\u001b[0m for a in a_range])\n\u001b[1;32m 160\u001b[0m all_eigenvects = np.array(\n\u001b[0;32m--> 161\u001b[0;31m [self.gen_eigenvects(Pea) for Pea in Peas])\n\u001b[0m\u001b[1;32m 162\u001b[0m ea_idx, eigen_idx, best_VT = select_vstar(all_eigenvects, R,\n\u001b[1;32m 163\u001b[0m self.obs_weights)\n", + "\u001b[0;32m~/.local/lib/python3.8/site-packages/wf_psf-0.0.1-py3.8.egg/wf_psf/graph_utils.py\u001b[0m in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 159\u001b[0m for a in a_range])\n\u001b[1;32m 160\u001b[0m all_eigenvects = np.array(\n\u001b[0;32m--> 161\u001b[0;31m [self.gen_eigenvects(Pea) for Pea in Peas])\n\u001b[0m\u001b[1;32m 162\u001b[0m ea_idx, eigen_idx, best_VT = select_vstar(all_eigenvects, R,\n\u001b[1;32m 163\u001b[0m self.obs_weights)\n", + "\u001b[0;32m~/.local/lib/python3.8/site-packages/wf_psf-0.0.1-py3.8.egg/wf_psf/graph_utils.py\u001b[0m in \u001b[0;36mgen_eigenvects\u001b[0;34m(self, mat)\u001b[0m\n\u001b[1;32m 172\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mthe\u001b[0m \u001b[0msmallest\u001b[0m \u001b[0meigenvalues\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 173\u001b[0m \"\"\"\n\u001b[0;32m--> 174\u001b[0;31m \u001b[0mU\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvT\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinalg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msvd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfull_matrices\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 175\u001b[0m \u001b[0mvT\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvT\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_eigenvects\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 176\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mvT\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m<__array_function__ internals>\u001b[0m in \u001b[0;36msvd\u001b[0;34m(*args, **kwargs)\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.8/site-packages/numpy/linalg/linalg.py\u001b[0m in \u001b[0;36msvd\u001b[0;34m(a, full_matrices, compute_uv, hermitian)\u001b[0m\n\u001b[1;32m 1659\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1660\u001b[0m \u001b[0msignature\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'D->DdD'\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misComplexType\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m'd->ddd'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1661\u001b[0;31m \u001b[0mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgufunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msignature\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msignature\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mextobj\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mextobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1662\u001b[0m \u001b[0mu\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mu\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mastype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult_t\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1663\u001b[0m \u001b[0ms\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mastype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_realType\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult_t\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "\n", + "\n", + "if model == 'mccd':\n", + " poly_dic, graph_dic = wf.tf_mccd_psf_field.build_mccd_spatial_dic_v2(obs_stars=outputs.numpy(),\n", + " obs_pos=tf_train_pos.numpy(),\n", + " x_lims=x_lims,\n", + " y_lims=y_lims,\n", + " d_max=d_max_nonparam,\n", + " graph_features=graph_features)\n", + "\n", + " spatial_dic = [poly_dic, graph_dic]\n", + "\n", + "\n", + " # Initialize the model\n", + " tf_semiparam_field = wf.tf_mccd_psf_field.TF_SP_MCCD_field(zernike_maps=tf_zernike_cube,\n", + " obscurations=tf_obscurations,\n", + " batch_size=batch_size,\n", + " obs_pos=tf_train_pos,\n", + " spatial_dic=spatial_dic,\n", + " output_Q=output_Q,\n", + " d_max_nonparam=d_max_nonparam,\n", + " graph_features=graph_features,\n", + " l1_rate=l1_rate,\n", + " output_dim=output_dim,\n", + " n_zernikes=n_zernikes,\n", + " d_max=d_max,\n", + " x_lims=x_lims,\n", + " y_lims=y_lims)\n", + "\n", + "elif model == 'poly':\n", + " # # Initialize the model\n", + " tf_semiparam_field = wf.tf_psf_field.TF_SemiParam_field(zernike_maps=tf_zernike_cube,\n", + " obscurations=tf_obscurations,\n", + " batch_size=batch_size,\n", + " output_Q=output_Q,\n", + " d_max_nonparam=d_max_nonparam,\n", + " output_dim=output_dim,\n", + " n_zernikes=n_zernikes,\n", + " d_max=d_max,\n", + " x_lims=x_lims,\n", + " y_lims=y_lims)\n", + "\n", + "elif model == 'param':\n", + " # Initialize the model\n", + " tf_semiparam_field = wf.tf_psf_field.TF_PSF_field_model(zernike_maps=tf_zernike_cube,\n", + " obscurations=tf_obscurations,\n", + " batch_size=batch_size,\n", + " output_dim=output_dim,\n", + " n_zernikes=n_zernikes,\n", + " d_max=d_max,\n", + " x_lims=x_lims,\n", + " y_lims=y_lims)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "v2ra16Mwcm-H" + }, + "source": [ + "# Parameter Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# steps_in_epoch = tf_train_pos.shape[0]/batch_size\n", + "\n", + "\n", + "# # Non-parametric part\n", + "# boundaries_epoch_non_param = [15, 70]\n", + "# boundaries_non_param = [_bound * steps_in_epoch for _bound in boundaries_epoch_non_param]\n", + "# values_non_param = [1e0, 1e-1, 1e-2]\n", + "\n", + "# lr_schedule_non_param = tf.keras.optimizers.schedules.PiecewiseConstantDecay(\n", + "# boundaries_non_param, values_non_param)\n", + "# opt_non_param = tf.keras.optimizers.Adam(learning_rate=lr_schedule_non_param)\n", + "\n", + "\n", + "# # Parametric part\n", + "# boundaries_epoch_param = [5, 15]\n", + "# boundaries_param = [_bound * steps_in_epoch for _bound in boundaries_epoch_param]\n", + "# values_param = [1e-1, 1e-2, 1e-3]\n", + "\n", + "# lr_schedule_param = tf.keras.optimizers.schedules.PiecewiseConstantDecay(\n", + "# boundaries_param, values_param)\n", + "# opt_param = tf.keras.optimizers.Adam(learning_rate=lr_schedule_param)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Set up the perfect initial values for the parametric part" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "opt_param = tfa.optimizers.RectifiedAdam(lr=1e-2)\n", + "opt_non_param = tfa.optimizers.RectifiedAdam(lr=1e-1)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MnAnPAgtuymN" + }, + "source": [ + "# Semi-param training" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "fhtxF1FmhHIL", + "outputId": "ca1a15d0-e5df-47ee-92d7-b6e287060830", + "scrolled": false + }, + "outputs": [], + "source": [ + "print('Starting cycle 1..')\n", + "start_cycle1 = time.time()\n", + "\n", + "# Compute the first training cycle\n", + "tf_semiparam_field, history_param, history_non_param = wf.train_utils.first_train_cycle(\n", + " tf_semiparam_field,\n", + " inputs, outputs, batch_size, \n", + " l_rate_param=1e-2, l_rate_non_param=1e-1,\n", + " param_optim=opt_param, non_param_optim=opt_non_param,\n", + " n_epochs_param=20, n_epochs_non_param=100,\n", + " verbose=2)\n", + "\n", + "tf_semiparam_field.save_weights(chkp_save_file + 'chkp_' + run_id_name + '_cycle1')\n", + "\n", + "\n", + "end_cycle1 = time.time()\n", + "print('Cycle1 elapsed time: %f'%(end_cycle1-start_cycle1))\n", + "\n", + "# Save optimisation history in the saving dict\n", + "saving_optim_hist['param_cycle1'] = history_param.history['loss']\n", + "saving_optim_hist['nonparam_cycle1'] = history_non_param.history['loss']\n", + "\n", + "\n", + "# Compute the train/test RMSE values\n", + "print('\\nCompute pixel metrics:')\n", + "test_res, train_res = wf.metrics.compute_metrics(tf_semiparam_field, simPSF_np,\n", + " test_SEDs=test_SEDs,\n", + " train_SEDs=train_SEDs,\n", + " tf_test_pos=tf_test_pos,\n", + " tf_test_stars=tf_test_stars,\n", + " tf_train_stars=tf_train_stars,\n", + " tf_train_pos=tf_train_pos,\n", + " n_bins_lda=n_bins_lda,\n", + " batch_size=batch_size)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "q5Me8CP2hHZ6", + "outputId": "17a76374-bbe2-4eab-e9e5-de85bd286b41", + "scrolled": false + }, + "outputs": [], + "source": [ + "print('Starting cycle 2..')\n", + "start_cycle2 = time.time()\n", + "\n", + "# Compute the next cycle\n", + "tf_semiparam_field, history_param, history_non_param = wf.train_utils.train_cycle(\n", + " tf_semiparam_field,\n", + " inputs, outputs, batch_size, \n", + " l_rate_param=1e-2, l_rate_non_param=1e-1,\n", + " n_epochs_param=15, n_epochs_non_param=100,\n", + " verbose=2)\n", + "\n", + "tf_semiparam_field.save_weights(chkp_save_file + 'chkp_' + run_id_name + '_cycle2')\n", + "\n", + "\n", + "end_cycle2 = time.time()\n", + "print('Cycle2 elapsed time: %f'%(end_cycle2 - start_cycle2))\n", + "\n", + "\n", + "# Save optimisation history in the saving dict\n", + "saving_optim_hist['param_cycle2'] = history_param.history['loss']\n", + "saving_optim_hist['nonparam_cycle2'] = history_non_param.history['loss']\n", + "\n", + "\n", + "# Compute the train/test RMSE values\n", + "test_res, train_res = wf.metrics.compute_metrics(tf_semiparam_field, simPSF_np,\n", + " test_SEDs=test_SEDs,\n", + " train_SEDs=train_SEDs,\n", + " tf_test_pos=tf_test_pos,\n", + " tf_test_stars=tf_test_stars,\n", + " tf_train_stars=tf_train_stars,\n", + " tf_train_pos=tf_train_pos,\n", + " n_bins_lda=n_bins_lda,\n", + " batch_size=batch_size)\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# Save optimisation history dictionary\n", + "np.save(optim_hist_file + 'optim_hist_' + run_id_name + '.npy', saving_optim_hist)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Analysis of model" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "NgXOkFP3hd4A", + "outputId": "535083d9-35bc-4dae-8dd2-dc3b5f21da1e" + }, + "outputs": [], + "source": [ + "# Preparate the GT model\n", + "\n", + "Zcube = sio.loadmat(Zcube_path)\n", + "zernikes = []\n", + "# Decimation factor for Zernike polynomials\n", + "decim_f = 4 # Original shape (1024x1024)\n", + "\n", + "n_zernikes_bis = 45\n", + "\n", + "for it in range(n_zernikes_bis):\n", + " zernike_map = wf.utils.downsample_im(Zcube['Zpols'][0,it][5], 1024//decim_f)\n", + " zernikes.append(zernike_map)\n", + "\n", + "# Now as cubes\n", + "np_zernike_cube = np.zeros((len(zernikes), zernikes[0].shape[0], zernikes[0].shape[1]))\n", + "\n", + "for it in range(len(zernikes)):\n", + " np_zernike_cube[it,:,:] = zernikes[it]\n", + "\n", + "np_zernike_cube[np.isnan(np_zernike_cube)] = 0\n", + "\n", + "tf_zernike_cube = tf.convert_to_tensor(np_zernike_cube, dtype=tf.float32)\n", + "\n", + "# print('Zernike cube:')\n", + "# print(tf_zernike_cube.shape)\n", + "\n", + "\n", + "# Initialize the model\n", + "GT_tf_semiparam_field = wf.tf_psf_field.TF_SemiParam_field(\n", + " zernike_maps=tf_zernike_cube,\n", + " obscurations=tf_obscurations,\n", + " batch_size=batch_size,\n", + " output_Q=output_Q,\n", + " d_max_nonparam=d_max_nonparam,\n", + " output_dim=output_dim,\n", + " n_zernikes=n_zernikes_bis,\n", + " d_max=d_max,\n", + " x_lims=x_lims,\n", + " y_lims=y_lims)\n", + "\n", + "\n", + "# For the Ground truth model\n", + "GT_tf_semiparam_field.tf_poly_Z_field.assign_coeff_matrix(train_C_poly)\n", + "_ = GT_tf_semiparam_field.tf_np_poly_opd.alpha_mat.assign(np.zeros_like(GT_tf_semiparam_field.tf_np_poly_opd.alpha_mat))\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "B-mIryiM5ltP" + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "fXYm0_rU5l6K", + "outputId": "72e13409-47f8-4857-c492-17ee9f4fc202" + }, + "outputs": [], + "source": [ + "print('\\n\\nStarting evaluation of cycle1:')\n", + "# Load the weights\n", + "tf_semiparam_field.load_weights(chkp_save_file + 'chkp_' + run_id_name + '_cycle1')\n", + "\n", + "# Compute the train/test OPD RMSE values\n", + "if model == 'mccd':\n", + " train_opd_res, test_opd_res = wf.metrics.compute_opd_metrics(\n", + " tf_semiparam_field, GT_tf_semiparam_field, tf_test_pos, tf_train_pos)\n", + " \n", + "elif model == 'poly':\n", + " train_opd_res, test_opd_res = wf.metrics.compute_opd_metrics_polymodel(\n", + " tf_semiparam_field, GT_tf_semiparam_field, tf_test_pos, tf_train_pos)\n", + " \n", + "elif model == 'param':\n", + " train_opd_res, test_opd_res = wf.metrics.compute_opd_metrics_param_model(\n", + " tf_semiparam_field, GT_tf_semiparam_field, tf_test_pos, tf_train_pos)\n", + " \n", + "\n", + "# Compute the train/test pixel RMSE values\n", + "test_res, train_res = wf.metrics.compute_metrics(\n", + " tf_semiparam_field, simPSF_np,\n", + " test_SEDs=test_SEDs,\n", + " train_SEDs=train_SEDs,\n", + " tf_test_pos=tf_test_pos,\n", + " tf_test_stars=tf_test_stars,\n", + " tf_train_stars=tf_train_stars,\n", + " tf_train_pos=tf_train_pos,\n", + " n_bins_lda=n_bins_lda,\n", + " batch_size=batch_size)\n", + "\n", + "_, _ = wf.metrics.compute_shape_metrics(\n", + " tf_semiparam_field,\n", + " GT_tf_semiparam_field,\n", + " simPSF_np,\n", + " SEDs=train_SEDs,\n", + " tf_pos=tf_train_pos,\n", + " n_bins_lda=n_bins_lda, \n", + " output_Q=1, output_dim=64, batch_size=16)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "sTrhhnec7Wx0", + "outputId": "2fbe85c1-2ee2-4094-ee0b-36fcff3f1e29" + }, + "outputs": [], + "source": [ + "print('\\n\\nStarting evaluation of cycle2:')\n", + "\n", + "# Load the weights\n", + "tf_semiparam_field.load_weights(chkp_save_file + 'chkp_' + run_id_name + '_cycle2')\n", + "\n", + "# Compute the train/test OPD RMSE values\n", + "if model == 'mccd':\n", + " train_opd_res, test_opd_res = wf.metrics.compute_opd_metrics(\n", + " tf_semiparam_field, GT_tf_semiparam_field, tf_test_pos, tf_train_pos)\n", + " \n", + "elif model == 'poly':\n", + " train_opd_res, test_opd_res = wf.metrics.compute_opd_metrics_polymodel(\n", + " tf_semiparam_field, GT_tf_semiparam_field, tf_test_pos, tf_train_pos)\n", + " \n", + "elif model == 'param':\n", + " train_opd_res, test_opd_res = wf.metrics.compute_opd_metrics_param_model(\n", + " tf_semiparam_field, GT_tf_semiparam_field, tf_test_pos, tf_train_pos)\n", + " \n", + "\n", + "# Compute the train/test pixel RMSE values\n", + "test_res, train_res = wf.metrics.compute_metrics(\n", + " tf_semiparam_field, simPSF_np,\n", + " test_SEDs=test_SEDs,\n", + " train_SEDs=train_SEDs,\n", + " tf_test_pos=tf_test_pos,\n", + " tf_test_stars=tf_test_stars,\n", + " tf_train_stars=tf_train_stars,\n", + " tf_train_pos=tf_train_pos,\n", + " n_bins_lda=n_bins_lda,\n", + " batch_size=batch_size)\n", + "\n", + "_, _ = wf.metrics.compute_shape_metrics(\n", + " tf_semiparam_field,\n", + " GT_tf_semiparam_field,\n", + " simPSF_np,\n", + " SEDs=train_SEDs,\n", + " tf_pos=tf_train_pos,\n", + " n_bins_lda=n_bins_lda, \n", + " output_Q=1, output_dim=64, batch_size=16)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Before ending" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# Close log file\n", + "print('\\n Good bye..')\n", + "sys.stdout = old_stdout\n", + "log_file.close()\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9Dt7D2ZQkCvW" + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "88_xEFHLVjOH" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "tf-PSF-NonParam-Euclid_resolution.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +}