From 244d707297127b73c8901837ed2cbaba3db95417 Mon Sep 17 00:00:00 2001
From: Konstantin Nikolaou <87869540+KonstiNik@users.noreply.github.com>
Date: Tue, 25 Jul 2023 19:33:11 +0200
Subject: [PATCH] Konsti fix training for flaxmodels (#97)

* Fix the training strategy of the loss aware reservoir for flax models

Issue:
When setting the number of latest points in the loss aware reservoir (LaR)
strategy to the same size as the overall training data, a forward pass over
an empty data set was computed. This does not throw an error in stax, but it
does in flax.

- Remove the need for an empty forward pass
- Extend the tests to check for the fixed problem

* apply black and isort

* merge main

---------

Co-authored-by: knikolaou <>
---
 .../test_loss_aware_reservoir.py              | 157 +++++++++++++-----
 .../loss_aware_reservoir.py                   |  16 +-
 2 files changed, 129 insertions(+), 44 deletions(-)

diff --git a/CI/unit_tests/training_strategies/test_loss_aware_reservoir.py b/CI/unit_tests/training_strategies/test_loss_aware_reservoir.py
index c4fd3e3..4e47324 100644
--- a/CI/unit_tests/training_strategies/test_loss_aware_reservoir.py
+++ b/CI/unit_tests/training_strategies/test_loss_aware_reservoir.py
@@ -33,6 +33,7 @@
 import numpy as np
 import optax
+from flax import linen as nn
 from jax import random
 from neural_tangents import stax
 from numpy.testing import assert_array_equal
 
@@ -40,7 +41,7 @@
 from znnl.accuracy_functions import AccuracyFunction
 from znnl.distance_metrics import DistanceMetric
 from znnl.loss_functions import MeanPowerLoss
-from znnl.models import JaxModel, NTModel
+from znnl.models import FlaxModel, JaxModel, NTModel
 from znnl.training_recording import JaxRecorder
 from znnl.training_strategies import LossAwareReservoir, RecursiveMode
 from znnl.training_strategies.training_decorator import train_func
@@ -97,11 +98,46 @@ def train_model(
         return epochs, batch_size
 
 
+class FlaxArchitecture(nn.Module):
+    """
+    Test model for the Flax tests.
+    """
+
+    @nn.compact
+    def __call__(self, x):
+        x = nn.Dense(5, use_bias=True)(x)
+        x = nn.relu(x)
+        x = nn.Dense(features=1, use_bias=True)(x)
+        return x
+
+
 class TestLossAwareReservoir:
     """
     Unit test suite of the loss aware reservoir training strategy.
     """
 
+    @classmethod
+    def setup_class(cls):
+        """
+        Create models and data for the tests.
+        """
+        key1, key2 = random.split(random.PRNGKey(1), 2)
+        x = random.normal(key1, (10, 8))
+        y = random.normal(key1, (10, 1))
+        cls.train_ds = {"inputs": x, "targets": y}
+        cls.test_ds = {"inputs": x, "targets": y}
+
+        cls.nt_model = NTModel(
+            nt_module=stax.serial(stax.Dense(5), stax.Relu(), stax.Dense(1)),
+            optimizer=optax.adam(learning_rate=0.001),
+            input_shape=(1, 8),
+        )
+        cls.flax_model = FlaxModel(
+            flax_module=FlaxArchitecture(),
+            optimizer=optax.adam(learning_rate=0.001),
+            input_shape=(1, 8),
+        )
+
     def test_reservoir_sorting(self):
         """
         Test the sorting of the reservoir.
@@ -139,77 +175,120 @@ def test_reservoir_sorting(self):
         selection_idx = np.argsort(np.abs(raw_x))[::-1][:4]
         assert_array_equal(reservoir, selection_idx)
 
-    @classmethod
-    def setup_class(cls):
-        """
-        Create data for the tests.
-        """
-        key1, key2 = random.split(random.PRNGKey(1), 2)
-        x = random.normal(key1, (10, 8))
-        y = random.normal(key1, (10, 1))
-        cls.train_ds = {"inputs": x, "targets": y}
-        cls.test_ds = {"inputs": x, "targets": y}
-
-    def test_latest_point_exclusion(self):
+    def test_update_reservoir(self):
         """
-        Test the method _update_reservoir excludes the latest points from train_ds.
-
-        When selecting latest_points > 0, this number of points is separated from the
-        train data. The selected points will be appended to every batch.
-        This test checks if the method _update_reservoir removes the latest_points from
-        the data, as they cannot be part of the reservoir. The reservoir must only
-        consist of already seen data.
+        Test the method _update_reservoir.
+
+        Test whether the method excludes the latest points from train_ds.
+        When selecting latest_points > 0, this number of points is separated from
+        the train data. The selected points will be appended to every batch.
+        This test checks if the method _update_reservoir removes the latest_points
+        from the data, as they cannot be part of the reservoir. The reservoir must
+        only consist of already seen data.
 
         1. For reservoir_size = len(train_ds)
             * Shrinking reservoir for latest_points = 1
            * Shrinking reservoir for latest_points = 4
+            * Shrink the reservoir to include no points for latest_points = 10
         2. For reservoir_size = 5 and len(train_ds) = 10
             * No shrinking reservoir size for latest_points = 4
+
+        Perform both tests for the NT and Flax models.
         """
-        model = NTModel(
-            nt_module=stax.serial(stax.Dense(5), stax.Relu(), stax.Dense(1)),
-            optimizer=optax.adam(learning_rate=0.001),
-            input_shape=(1, 8),
-        )
+        nt_model = self.nt_model
+        flax_model = self.flax_model
 
         # Test for latest_points = 1
-        trainer = LossAwareReservoir(
-            model=model,
+        nt_trainer = LossAwareReservoir(
+            model=nt_model,
+            loss_fn=MeanPowerLoss(order=2),
+            disable_loading_bar=True,
+            reservoir_size=10,
+            latest_points=1,
+        )
+        flax_trainer = LossAwareReservoir(
+            model=flax_model,
             loss_fn=MeanPowerLoss(order=2),
             disable_loading_bar=True,
             reservoir_size=10,
             latest_points=1,
         )
-        trainer.train_data_size = len(self.train_ds["inputs"])
-        reservoir = trainer._update_reservoir(train_ds=self.train_ds)
-        assert len(self.train_ds["inputs"]) - 1 == len(reservoir)
+        nt_trainer.train_data_size = self.train_ds["inputs"].shape[0]
+        flax_trainer.train_data_size = self.train_ds["inputs"].shape[0]
+        reservoir = nt_trainer._update_reservoir(train_ds=self.train_ds)
+        assert self.train_ds["inputs"].shape[0] - 1 == len(reservoir)
+        reservoir = flax_trainer._update_reservoir(train_ds=self.train_ds)
+        assert self.train_ds["inputs"].shape[0] - 1 == len(reservoir)
 
         # Test for latest_points = 4
-        trainer = LossAwareReservoir(
-            model=model,
+        nt_trainer = LossAwareReservoir(
+            model=nt_model,
+            loss_fn=MeanPowerLoss(order=2),
+            disable_loading_bar=True,
+            reservoir_size=10,
+            latest_points=4,
+        )
+        flax_trainer = LossAwareReservoir(
+            model=flax_model,
             loss_fn=MeanPowerLoss(order=2),
             disable_loading_bar=True,
             reservoir_size=10,
             latest_points=4,
         )
-        trainer.train_data_size = len(self.train_ds["inputs"])
-        reservoir = trainer._update_reservoir(train_ds=self.train_ds)
-        assert len(self.train_ds["inputs"]) - 4 == len(reservoir)
+        nt_trainer.train_data_size = self.train_ds["inputs"].shape[0]
+        flax_trainer.train_data_size = self.train_ds["inputs"].shape[0]
+        reservoir = nt_trainer._update_reservoir(train_ds=self.train_ds)
+        assert self.train_ds["inputs"].shape[0] - 4 == len(reservoir)
+        reservoir = flax_trainer._update_reservoir(train_ds=self.train_ds)
+        assert self.train_ds["inputs"].shape[0] - 4 == len(reservoir)
+
+        # Test for latest_points = 10
+        nt_trainer = LossAwareReservoir(
+            model=nt_model,
+            loss_fn=MeanPowerLoss(order=2),
+            disable_loading_bar=True,
+            reservoir_size=10,
+            latest_points=10,
+        )
+        flax_trainer = LossAwareReservoir(
+            model=flax_model,
+            loss_fn=MeanPowerLoss(order=2),
+            disable_loading_bar=True,
+            reservoir_size=10,
+            latest_points=10,
+        )
+
+        nt_trainer.train_data_size = self.train_ds["inputs"].shape[0]
+        flax_trainer.train_data_size = self.train_ds["inputs"].shape[0]
+        reservoir = nt_trainer._update_reservoir(train_ds=self.train_ds)
+        assert 0 == len(reservoir)
+        reservoir = flax_trainer._update_reservoir(train_ds=self.train_ds)
+        assert 0 == len(reservoir)
 
         # Test for latest_points = 4 but for reservoir_size = 5. The reservoir size
         # should not be affected now.
-        trainer = LossAwareReservoir(
-            model=model,
+        nt_trainer = LossAwareReservoir(
+            model=nt_model,
+            loss_fn=MeanPowerLoss(order=2),
+            disable_loading_bar=True,
+            reservoir_size=5,
+            latest_points=4,
+        )
+        flax_trainer = LossAwareReservoir(
+            model=flax_model,
             loss_fn=MeanPowerLoss(order=2),
             disable_loading_bar=True,
             reservoir_size=5,
             latest_points=4,
         )
-        trainer.train_data_size = len(self.train_ds["inputs"])
-        reservoir = trainer._update_reservoir(train_ds=self.train_ds)
+        nt_trainer.train_data_size = self.train_ds["inputs"].shape[0]
+        flax_trainer.train_data_size = self.train_ds["inputs"].shape[0]
+        reservoir = nt_trainer._update_reservoir(train_ds=self.train_ds)
+        assert 5 == len(reservoir)
+        reservoir = flax_trainer._update_reservoir(train_ds=self.train_ds)
         assert 5 == len(reservoir)
 
     def test_update_training_kwargs(self):
diff --git a/znnl/training_strategies/loss_aware_reservoir.py b/znnl/training_strategies/loss_aware_reservoir.py
index b60af37..d759163 100644
--- a/znnl/training_strategies/loss_aware_reservoir.py
+++ b/znnl/training_strategies/loss_aware_reservoir.py
@@ -157,6 +157,7 @@ def _update_reservoir(self, train_ds: dict) -> List[int]:
         Updates the reservoir in the following steps:
 
         * Exclude latest_points from the train_data
+        * Check whether the reservoir will be empty or whether it can cover all data
         * Compute distance of representations of the remaining training set
         * Sort the training set according to the distance
 
@@ -179,13 +180,18 @@ def _update_reservoir(self, train_ds: dict) -> List[int]:
         else:
             old_data = {k: v[: -self.latest_points, ...] for k, v in train_ds.items()}
 
-        distances = self._compute_distance(old_data)
-
+        # If no old data is available, return an empty array
+        if old_data["inputs"].shape[0] == 0:
+            data_idx = np.array([])
         # Return the old train data indices if the reservoir can cover them all
-        if self.reservoir_size >= self.train_data_size - self.latest_points:
-            return np.arange(self.train_data_size - self.latest_points)
+        elif self.reservoir_size >= self.train_data_size - self.latest_points:
+            data_idx = np.arange(self.train_data_size - self.latest_points)
         # If the reservoir is smaller than the train data, select data via the loss
-        return np.argsort(distances)[::-1][: self.reservoir_size]
+        else:
+            distances = self._compute_distance(old_data)
+            data_idx = np.argsort(distances)[::-1][: self.reservoir_size]
+
+        return data_idx
 
     def _append_latest_points(self, data_idx: List[int], freq: int = 1):
         """
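
Reviewer note (not part of the patch): the sketch below distills the control flow the second hunk introduces, so the fix can be sanity-checked without znnl installed. The select_reservoir helper and the exploding_distance_fn stub are hypothetical stand-ins for LossAwareReservoir._update_reservoir and the model forward pass; only the ordering of the guards is taken from the patched method.

    import numpy as np


    def select_reservoir(train_size, reservoir_size, latest_points, distance_fn):
        """Hypothetical mirror of the patched _update_reservoir guard ordering."""
        n_old = train_size - latest_points
        if n_old == 0:
            # Every point counts as "latest": nothing to rank, so no forward
            # pass is run. This is the case that used to crash with flax.
            return np.array([])
        if reservoir_size >= n_old:
            # The reservoir covers all old points: return them without ranking.
            return np.arange(n_old)
        # Only here is the (model-dependent) distance computation invoked.
        distances = distance_fn()
        return np.argsort(distances)[::-1][:reservoir_size]


    def exploding_distance_fn():
        # Stands in for the model forward pass; raises if it is ever reached.
        raise AssertionError("forward pass over empty data")


    # Regression case from the issue: latest_points == len(train_ds).
    assert select_reservoir(10, 10, 10, exploding_distance_fn).size == 0
    # Reservoir covers all old points: still no forward pass needed.
    assert select_reservoir(10, 10, 4, exploding_distance_fn).size == 6
    # Smaller reservoir: the distance-based selection path runs as before.
    assert select_reservoir(10, 5, 1, lambda: np.arange(9.0)).size == 5

The ordering matters because, per the commit message, stax tolerates a forward pass over an empty data set while flax raises, so the empty-data guard has to sit in front of _compute_distance.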