diff --git a/conda-recipe/meta.yaml b/conda-recipe/meta.yaml
index 62a07ce62f..3a6f6683c7 100644
--- a/conda-recipe/meta.yaml
+++ b/conda-recipe/meta.yaml
@@ -19,7 +19,8 @@ requirements:
     - pip
   run:
     - python >=3.9,<3.13
-    - tensorflow >=2.10,<2.17  # 2.17 works ok but the conda-forge package for macos doesn't
+    - tensorflow >=2.10
+    - keras >=3.1
     - psutil # to ensure n3fit affinity is with the right processors
     - hyperopt
     - mongodb
diff --git a/n3fit/src/n3fit/backends/keras_backend/MetaModel.py b/n3fit/src/n3fit/backends/keras_backend/MetaModel.py
index 2399c16a64..f1cdbc418a 100644
--- a/n3fit/src/n3fit/backends/keras_backend/MetaModel.py
+++ b/n3fit/src/n3fit/backends/keras_backend/MetaModel.py
@@ -8,23 +8,15 @@
 from pathlib import Path
 import re
 
+from keras import backend as K
+from keras import ops as Kops
+from keras import optimizers as Kopt
+from keras.models import Model
 import numpy as np
 import tensorflow as tf
-from tensorflow.keras import optimizers as Kopt
-from tensorflow.keras.models import Model
-from tensorflow.python.keras.utils import tf_utils  # pylint: disable=no-name-in-module
 
 import n3fit.backends.keras_backend.operations as op
 
-# We need a function to transform tensors to numpy/python primitives
-# which is not part of the official TF interface and can change with the version
-if hasattr(tf_utils, "to_numpy_or_python_type"):
-    _to_numpy_or_python_type = tf_utils.to_numpy_or_python_type
-elif hasattr(tf_utils, "sync_to_numpy_or_python_type"):  # from TF 2.5
-    _to_numpy_or_python_type = tf_utils.sync_to_numpy_or_python_type
-else:  # in case of disaster
-    _to_numpy_or_python_type = lambda ret: {k: i.numpy() for k, i in ret.items()}
-
 # Starting with TF 2.16, a memory leak in TF https://github.com/tensorflow/tensorflow/issues/64170
 # makes jit compilation unusable on GPU.
 # Before TF 2.16 it was set to `False` by default. From 2.16 onwards, it is set to `True`
@@ -121,7 +113,6 @@ def __init__(self, input_tensors, output_tensors, scaler=None, input_values=None
         self.compute_losses_function = None
         self._scaler = scaler
 
-    @tf.autograph.experimental.do_not_convert
     def _parse_input(self, extra_input=None):
         """Returns the input data the model was compiled with.
         Introduces the extra_input in the places assigned to the placeholders.
@@ -173,8 +165,8 @@ def perform_fit(self, x=None, y=None, epochs=1, **kwargs):
         steps_per_epoch = self._determine_steps_per_epoch(epochs)
 
         for k, v in x_params.items():
-            x_params[k] = tf.repeat(v, steps_per_epoch, axis=0)
-        y = [tf.repeat(yi, steps_per_epoch, axis=0) for yi in y]
+            x_params[k] = Kops.repeat(v, steps_per_epoch, axis=0)
+        y = [Kops.repeat(yi, steps_per_epoch, axis=0) for yi in y]
 
         history = super().fit(
             x=x_params, y=y, epochs=epochs // steps_per_epoch, batch_size=1, **kwargs
@@ -228,13 +220,13 @@ def compute_losses(self):
                 inputs[k] = v[:1]
 
         # Compile an evaluation function
-        @tf.function
+        @op.decorator_compiler
         def losses_fun():
             predictions = self(inputs)
             # If we only have one dataset the output changes
             if len(out_names) == 2:
                 predictions = [predictions]
-            total_loss = tf.reduce_sum(predictions, axis=0)
+            total_loss = Kops.sum(predictions, axis=0)
             ret = [total_loss] + predictions
             return dict(zip(out_names, ret))
 
@@ -244,7 +236,7 @@ def losses_fun():
 
         # The output of this function is to be used by python (and numpy)
         # so we need to convert the tensors
-        return _to_numpy_or_python_type(ret)
+        return op.dict_to_numpy_or_python(ret)
 
     def compile(
         self,
diff --git a/n3fit/src/n3fit/backends/keras_backend/base_layers.py b/n3fit/src/n3fit/backends/keras_backend/base_layers.py
index 2ed2628293..849cd74175 100644
--- a/n3fit/src/n3fit/backends/keras_backend/base_layers.py
+++ b/n3fit/src/n3fit/backends/keras_backend/base_layers.py
@@ -17,16 +17,14 @@
     The names of the layer and the activation function are the ones to be used in the n3fit runcard.
 """
 
-from tensorflow import expand_dims, math, nn
-from tensorflow.keras.layers import Dense as KerasDense
-from tensorflow.keras.layers import Dropout, Lambda
-from tensorflow.keras.layers import Input  # pylint: disable=unused-import
-from tensorflow.keras.layers import LSTM, Concatenate
-from tensorflow.keras.regularizers import l1_l2
+from keras.layers import Dense as KerasDense
+from keras.layers import Dropout, Lambda
+from keras.layers import Input  # pylint: disable=unused-import
+from keras.layers import LSTM, Concatenate
+from keras.regularizers import l1_l2
 
+from . import operations as ops
 from .MetaLayer import MetaLayer
-from .operations import concatenate_function
-
 
 # Custom activation functions
 def square_activation(x):
@@ -38,17 +36,17 @@ def square_singlet(x):
     """Square the singlet sector
     Defined as the two first values of the NN"""
     singlet_squared = x[..., :2] ** 2
-    return concatenate_function([singlet_squared, x[..., 2:]], axis=-1)
+    return ops.concatenate([singlet_squared, x[..., 2:]], axis=-1)
 
 
 def modified_tanh(x):
     """A non-saturating version of the tanh function"""
-    return math.abs(x) * nn.tanh(x)
+    return ops.absolute(x) * ops.tanh(x)
 
 
 def leaky_relu(x):
     """Computes the Leaky ReLU activation function"""
-    return nn.leaky_relu(x, alpha=0.2)
+    return ops.leaky_relu(x, alpha=0.2)
 
 
 custom_activations = {
@@ -64,7 +62,7 @@ def LSTM_modified(**kwargs):
     LSTM asks for a sample X timestep X features kind of thing so we need to reshape the input
     """
     the_lstm = LSTM(**kwargs)
-    ExpandDim = Lambda(lambda x: expand_dims(x, axis=-1))
+    ExpandDim = Lambda(lambda x: ops.expand_dims(x, axis=-1))
 
     def ReshapedLSTM(input_tensor):
         if len(input_tensor.shape) == 2:
diff --git a/n3fit/src/n3fit/backends/keras_backend/callbacks.py b/n3fit/src/n3fit/backends/keras_backend/callbacks.py
index 911f069e5c..29324e7cee 100644
--- a/n3fit/src/n3fit/backends/keras_backend/callbacks.py
+++ b/n3fit/src/n3fit/backends/keras_backend/callbacks.py
@@ -15,9 +15,9 @@
 import logging
 from time import time
 
+from keras.callbacks import Callback, TensorBoard
 import numpy as np
 import tensorflow as tf
-from tensorflow.keras.callbacks import Callback, TensorBoard
 
 log = logging.getLogger(__name__)
 
@@ -171,7 +171,6 @@ def on_train_begin(self, logs=None):
             layer = self.model.get_layer(layer_name)
             self.updateable_weights.append(layer.weights)
 
-    @tf.function
     def _update_weights(self):
         """Update all the weights with the corresponding multipliers
         Wrapped with tf.function to compensate the for loops as both weights variables
diff --git a/n3fit/src/n3fit/backends/keras_backend/constraints.py b/n3fit/src/n3fit/backends/keras_backend/constraints.py
index e943c1fcb6..bb6d85ff4b 100644
--- a/n3fit/src/n3fit/backends/keras_backend/constraints.py
+++ b/n3fit/src/n3fit/backends/keras_backend/constraints.py
@@ -2,9 +2,9 @@
 Implementations of weight constraints for initializers
 """
 
-import tensorflow as tf
-from tensorflow.keras import backend as K
-from tensorflow.keras.constraints import MinMaxNorm
+from keras import backend as K
+from keras import ops as Kops
+from keras.constraints import MinMaxNorm
 
 
 class MinMaxWeight(MinMaxNorm):
@@ -17,8 +17,8 @@ def __init__(self, min_value, max_value, **kwargs):
         super().__init__(min_value=min_value, max_value=max_value, axis=1, **kwargs)
 
     def __call__(self, w):
-        norms = K.sum(w, axis=self.axis, keepdims=True)
+        norms = Kops.sum(w, axis=self.axis, keepdims=True)
         desired = (
-            self.rate * K.clip(norms, self.min_value, self.max_value) + (1 - self.rate) * norms
+            self.rate * Kops.clip(norms, self.min_value, self.max_value) + (1 - self.rate) * norms
         )
         return w * desired / (K.epsilon() + norms)
diff --git a/n3fit/src/n3fit/backends/keras_backend/internal_state.py b/n3fit/src/n3fit/backends/keras_backend/internal_state.py
index e818716940..dd61068190 100644
--- a/n3fit/src/n3fit/backends/keras_backend/internal_state.py
+++ b/n3fit/src/n3fit/backends/keras_backend/internal_state.py
@@ -1,6 +1,7 @@
 """
     Library of functions that modify the internal state of Keras/Tensorflow
 """
+
 import os
 
 import psutil
@@ -13,10 +14,10 @@
 import logging
 import random as rn
 
+import keras
+from keras import backend as K
 import numpy as np
 import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import backend as K
 
 log = logging.getLogger(__name__)
 
diff --git a/n3fit/src/n3fit/backends/keras_backend/operations.py b/n3fit/src/n3fit/backends/keras_backend/operations.py
index b6ad0e010e..f521b0536e 100644
--- a/n3fit/src/n3fit/backends/keras_backend/operations.py
+++ b/n3fit/src/n3fit/backends/keras_backend/operations.py
@@ -6,7 +6,7 @@
     This includes an implementation of the NNPDF operations on fktable in the keras
     language (with the mapping ``c_to_py_fun``) into Keras ``Lambda`` layers.
 
-    Tensor operations are compiled through the @tf.function decorator for optimization
+    Tensor operations are compiled through the backend-dependent ``decorator_compiler`` for optimization
 
     The rest of the operations in this module are divided into four categories:
     numpy to tensor:
@@ -25,30 +25,28 @@
 
 from typing import Optional
 
+from keras import backend as K
+from keras import ops as Kops
+from keras.layers import ELU, Input
+from keras.layers import Lambda as keras_Lambda
+from keras.layers import multiply as keras_multiply
+from keras.layers import subtract as keras_subtract
 import numpy as np
-import numpy.typing as npt
-import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import backend as K
-from tensorflow.keras.layers import Input
-from tensorflow.keras.layers import Lambda as keras_Lambda
-from tensorflow.keras.layers import multiply as keras_multiply
-from tensorflow.keras.layers import subtract as keras_subtract
 
 from validphys.convolution import OP
 
-# Select a concatenate function depending on the tensorflow version
-try:
-    # For tensorflow >= 2.16, Keras >= 3
-    concatenate_function = keras.ops.concatenate
-except AttributeError:
-    # keras.ops was introduced in keras 3
-    concatenate_function = tf.concat
+# Backend dependent functions and operations
+if K.backend() == "torch":
+    tensor_to_numpy_or_python = lambda x: x.detach().numpy()
+    decorator_compiler = lambda f: f
+else:
+    import tensorflow as tf
+
+    tensor_to_numpy_or_python = lambda x: x.numpy()
+    decorator_compiler = tf.function
 
-
-def evaluate(tensor):
-    """Evaluate input tensor using the backend"""
-    return K.eval(tensor)
+dict_to_numpy_or_python = lambda ret: {k: tensor_to_numpy_or_python(i) for k, i in ret.items()}
 
 
 def as_layer(operation, op_args=None, op_kwargs=None, **kwargs):
@@ -101,7 +99,6 @@ def c_to_py_fun(op_name, name="dataset"):
     except KeyError as e:
         raise ValueError(f"Operation {op_name} not recognised") from e
 
-    @tf.function
     def operate_on_tensors(tensor_list):
         return operation(*tensor_list)
 
@@ -113,19 +110,19 @@ def numpy_to_tensor(ival, **kwargs):
     """
         Make the input into a tensor
     """
-    if kwargs.get("dtype", None) is not bool:
-        kwargs["dtype"] = tf.keras.backend.floatx()
-    return K.constant(ival, **kwargs)
+    if (dtype := kwargs.get("dtype", None)) is not bool:
+        dtype = K.floatx()
+    return Kops.cast(ival, dtype)
 
 
 # f(x: tensor) -> y: tensor
 def batchit(x, batch_dimension=0, **kwarg):
     """Add a batch dimension to tensor x"""
-    return tf.expand_dims(x, batch_dimension, **kwarg)
+    return Kops.expand_dims(x, batch_dimension, **kwarg)
 
 
 # layer generation
-def numpy_to_input(numpy_array: npt.NDArray, name: Optional[str] = None):
+def numpy_to_input(numpy_array: np.typing.NDArray, name: Optional[str] = None):
     """
     Takes a numpy array and generates an Input layer with the same shape,
     but with a batch dimension (of size 1) added.
@@ -173,6 +170,13 @@ def op_multiply_dim(o_list, **kwargs): return layer_op(o_list) +def gather(*args, **kwargs): + """ + Gather elements from a tensor along an axis + """ + return Kops.take(*args, **kwargs) + + def op_gather_keep_dims(tensor, indices, axis=0, **kwargs): """A convoluted way of providing ``x[:, indices, :]`` @@ -183,20 +187,13 @@ def op_gather_keep_dims(tensor, indices, axis=0, **kwargs): indices = tensor.shape[axis] - 1 def tmp(x): - y = tf.gather(x, indices, axis=axis, **kwargs) - return tf.expand_dims(y, axis=axis) + y = gather(x, indices, axis=axis) + return Kops.expand_dims(y, axis=axis) layer_op = as_layer(tmp) return layer_op(tensor) -def gather(*args, **kwargs): - """ - Gather elements from a tensor along an axis - """ - return tf.gather(*args, **kwargs) - - # # Tensor operations # f(x: tensor[s]) -> y: tensor @@ -205,7 +202,8 @@ def gather(*args, **kwargs): # Generation operations # generate tensors of given shape/content -@tf.function + + def tensor_ones_like(*args, **kwargs): """ Generates a tensor of ones of the same shape as the input tensor @@ -216,49 +214,31 @@ def tensor_ones_like(*args, **kwargs): # Property operations # modify properties of the tensor like the shape or elements it has -@tf.function -def flatten(x): - """Flatten tensor x""" - return tf.reshape(x, (-1,)) -@tf.function def reshape(x, shape): """reshape tensor x""" - return tf.reshape(x, shape) - - -@tf.function -def boolean_mask(*args, target_shape=None, **kwargs): - """ - Applies a boolean mask to a tensor + return Kops.reshape(x, shape) - Relevant parameters: (tensor, mask, axis=None) - see full `docs `_. - tensorflow's masking concatenates the masked dimensions, it is possible to - provide a `target_shape` to reshape the output to the desired shape - """ - ret = tf.boolean_mask(*args, **kwargs) - if target_shape is not None: - ret = reshape(ret, target_shape) - return ret +def flatten(x): + """Flatten tensor x""" + return reshape(x, (-1,)) -@tf.function def transpose(tensor, **kwargs): """ Transpose a layer, see full `docs `_ """ - return K.transpose(tensor, **kwargs) + return Kops.transpose(tensor, **kwargs) def stack(tensor_list, axis=0, **kwargs): """Stack a list of tensors see full `docs `_ """ - return tf.stack(tensor_list, axis=axis, **kwargs) + return Kops.stack(tensor_list, axis=axis) def concatenate(tensor_list, axis=-1, target_shape=None, name=None): @@ -266,7 +246,7 @@ def concatenate(tensor_list, axis=-1, target_shape=None, name=None): Concatenates a list of numbers or tensor into a bigger tensor If the target shape is given, the output is reshaped to said shape """ - concatenated_tensor = concatenate_function(tensor_list, axis=axis) + concatenated_tensor = Kops.concatenate(tensor_list, axis=axis) if target_shape is None: return concatenated_tensor @@ -278,7 +258,7 @@ def einsum(equation, *args, **kwargs): Computes the tensor product using einsum See full `docs `_ """ - return tf.einsum(equation, *args, **kwargs) + return Kops.einsum(equation, *args, **kwargs) def tensor_product(*args, **kwargs): @@ -286,40 +266,14 @@ def tensor_product(*args, **kwargs): Computes the tensordot product between tensor_x and tensor_y See full `docs `_ """ - return tf.tensordot(*args, **kwargs) + return Kops.tensordot(*args, **kwargs) -@tf.function def pow(tensor, power): """ Computes the power of the tensor """ - return tf.pow(tensor, power) - - -@tf.function(reduce_retracing=True) -def op_log(o_tensor, **kwargs): - """ - Computes the logarithm of the input - """ - return K.log(o_tensor) - - 
-@tf.function
-def sum(*args, **kwargs):
-    """
-    Computes the sum of the elements of the tensor
-    see full `docs `_
-    """
-    return K.sum(*args, **kwargs)
-
-
-def split(*args, **kwargs):
-    """
-    Splits the tensor on the selected axis
-    see full `docs `_
-    """
-    return tf.split(*args, **kwargs)
+    return Kops.power(tensor, power)
 
 
 def scatter_to_one(values, indices, output_shape):
@@ -327,8 +281,8 @@ def scatter_to_one(values, indices, output_shape):
     Like scatter_nd initialized to one instead of zero
     see full `docs `_
    """
-    ones = numpy_to_tensor(np.ones(output_shape))
-    return tf.tensor_scatter_nd_update(ones, indices, values)
+    ones = Kops.ones(output_shape)
+    return Kops.scatter_update(ones, indices, values)
 
 
 def op_subtract(inputs, **kwargs):
@@ -344,18 +298,23 @@ def swapaxes(tensor, source, destination):
    """
     Moves the axis of the tensor from source to destination, as in numpy.swapaxes.
     see full `docs `_
    """
-    indices = list(range(tensor.shape.rank))
+    rank = len(tensor.shape)
+    indices = list(range(rank))
     if source < 0:
-        source += tensor.shape.rank
+        source += rank
     if destination < 0:
-        destination += tensor.shape.rank
+        destination += rank
 
     indices[source], indices[destination] = indices[destination], indices[source]
 
-    return tf.transpose(tensor, indices)
+    return Kops.transpose(tensor, indices)
+
+
+def elu(x, alpha=1.0, **kwargs):
+    """Apply the Exponential Linear Unit activation as a layer"""
+    new_layer = ELU(alpha=alpha, **kwargs)
+    return new_layer(x)
 
 
-@tf.function
 def backend_function(fun_name, *args, **kwargs):
     """
     Wrapper to call non-explicitly implemented backend functions by name: (``fun_name``)
@@ -363,3 +322,62 @@ def backend_function(fun_name, *args, **kwargs):
     fun = getattr(K, fun_name)
     return fun(*args, **kwargs)
+
+
+def tensor_splitter(ishape, split_sizes, axis=2, name="splitter"):
+    """
+    Generates a Lambda layer to apply the split operation to a given tensor shape.
+    This wrapper cannot split along the batch index (axis=0).
+
+    Parameters
+    ----------
+    ishape: list(int)
+        input shape of the tensor that will be split
+    split_sizes: list(int)
+        size of each chunk
+    axis: int
+        axis along which the split will be applied
+    name: str
+        name of the layer
+
+    Returns
+    -------
+    sp_layer: layer
+        a keras layer that applies the split operation upon call
+    """
+    if axis < 1:
+        raise ValueError("tensor_splitter wrapper can only split along non-batch dimensions")
+
+    # Check that we can indeed split this
+    if ishape[axis] != np.sum(split_sizes):
+        raise ValueError(
+            f"Cannot split tensor of shape {ishape} along axis {axis} in chunks of {split_sizes}"
+        )
+
+    # Output shape of each split
+    oshapes = []
+    # Indices at which to apply the splits
+    # NB: tensorflow's split function would've taken the split_sizes directly,
+    # keras instead takes the indices at which to split
+    indices = []
+    current_idx = 0
+
+    for xsize in split_sizes:
+        current_idx += xsize
+        indices.append(current_idx)
+        oshapes.append((*ishape[1:axis], xsize, *ishape[axis + 1 :]))
+
+    # The last index equals the full size of the axis and would produce an
+    # extra, empty chunk, so it is dropped
+    indices = indices[:-1]
+
+    sp_layer = keras_Lambda(
+        lambda x: Kops.split(x, indices, axis=axis), output_shape=oshapes, name=name
+    )
+    return sp_layer
+
+
+expand_dims = Kops.expand_dims
+absolute = Kops.absolute
+tanh = Kops.tanh
+leaky_relu = Kops.leaky_relu
+split = Kops.split
+take = Kops.take
+sum = Kops.sum
+op_log = Kops.log
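The chunk-size to cut-index conversion in ``tensor_splitter`` is easy to get wrong, so here is a standalone sketch (not part of the diff) of the convention change; it assumes only that ``keras.ops.split`` follows the numpy convention of cut indices, while tensorflow's ``tf.split`` takes chunk sizes:

```python
# Hedged sketch: numpy.split shares the cut-index convention of keras.ops.split.
import numpy as np

split_sizes = [3, 2, 4]  # chunk sizes, as tf.split would take them
# Cumulative sums give the cut points; the final one equals the axis size
# and would produce a trailing empty chunk, so it is dropped.
indices = np.cumsum(split_sizes)[:-1]  # -> array([3, 5])

x = np.arange(18).reshape(2, 9)
chunks = np.split(x, indices, axis=1)
assert [c.shape[1] for c in chunks] == split_sizes
```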
diff --git a/n3fit/src/n3fit/layers/DY.py b/n3fit/src/n3fit/layers/DY.py
index f05416c5e4..94c982a391 100644
--- a/n3fit/src/n3fit/layers/DY.py
+++ b/n3fit/src/n3fit/layers/DY.py
@@ -86,7 +86,7 @@ def compute_dy_observable_many_replica(pdf, padded_fk):
     """
     pdfa = pdf[1]
     pdfb = pdf[0]
-
+
     temp = op.einsum('nxfyg, bryg -> brnxf', padded_fk, pdfa)
     return op.einsum('brnxf, brxf -> brn', temp, pdfb)
 
@@ -96,11 +96,13 @@ def compute_dy_observable_one_replica(pdf, mask_and_fk):
     Same operations as above but a specialized implementation that is more efficient
     for 1 replica, masking the PDF rather than the fk table.
     """
+    # mask: (channels, flavs_b, flavs_a) Ffg
+    # fk: (npoints, channels, x_a, x_b) nFyx
     mask, fk = mask_and_fk
     # Retrieve the two PDFs (which may potentially be coming from different initial states)
     # Since this is the one-replica function, remove the batch and replica dimension
-    pdfb = pdf[0][0][0]  # xf
-    pdfa = pdf[1][0][0]  # yg
+    pdfb = pdf[0][0][0]  # (x_b, flavs_b) xf
+    pdfa = pdf[1][0][0]  # (x_a, flavs_a) yg
 
     # TODO: check which PDF must go first in case of different initial states!!!
     mask_x_pdf = op.tensor_product(mask, pdfa, axes=[(2,), (1,)])  # Ffg, yg -> Ffy
diff --git a/n3fit/src/n3fit/layers/losses.py b/n3fit/src/n3fit/layers/losses.py
index b33547a6ce..ee6162a8d4 100644
--- a/n3fit/src/n3fit/layers/losses.py
+++ b/n3fit/src/n3fit/layers/losses.py
@@ -160,7 +160,7 @@ def __init__(self, alpha=1e-7, **kwargs):
         super().__init__(**kwargs)
 
     def apply_loss(self, y_pred):
-        loss = op.backend_function("elu", -y_pred, alpha=self.alpha)
+        loss = op.elu(-y_pred, alpha=self.alpha)
         # Sum over the batch and the datapoints
         return op.sum(loss, axis=[0, -1])
 
@@ -180,6 +180,6 @@ class LossIntegrability(LossLagrange):
     """
 
     def apply_loss(self, y_pred):
-        y = op.backend_function("square", y_pred)
+        y = y_pred * y_pred
         # Sum over the batch and the datapoints
         return op.sum(y, axis=[0, -1])
diff --git a/n3fit/src/n3fit/layers/mask.py b/n3fit/src/n3fit/layers/mask.py
index 3ed007a18f..089c0b6ba6 100644
--- a/n3fit/src/n3fit/layers/mask.py
+++ b/n3fit/src/n3fit/layers/mask.py
@@ -1,4 +1,4 @@
-from numpy import count_nonzero
+import numpy as np
 
 from n3fit.backends import MetaLayer
 from n3fit.backends import operations as op
@@ -23,12 +23,14 @@ class Mask(MetaLayer):
     """
 
     def __init__(self, bool_mask=None, c=None, **kwargs):
+        self._raw_mask = bool_mask
+        self._flattened_indices = None
         if bool_mask is None:
             self.mask = None
             self.last_dim = -1
         else:
             self.mask = op.numpy_to_tensor(bool_mask, dtype=bool)
-            self.last_dim = count_nonzero(bool_mask[0, ...])
+            self.last_dim = np.count_nonzero(bool_mask[0, ...])
         self.c = c
         self.masked_output_shape = None
         super().__init__(**kwargs)
@@ -40,9 +42,22 @@ def build(self, input_shape):
         # Make sure reshape will succeed: set the last dimension to the unmasked data length and before-last to
         # the number of replicas
         if self.mask is not None:
+            # Prepare the indices to mask
+            indices = np.where(self._raw_mask)
+            self._flattened_indices = np.ravel_multi_index(indices, self._raw_mask.shape)
+            # The batch dimension can be ignored
+            nreps = self.mask.shape[-2]
             self.masked_output_shape = [-1 if d is None else d for d in input_shape]
             self.masked_output_shape[-1] = self.last_dim
-            self.masked_output_shape[-2] = self.mask.shape[-2]
+            self.masked_output_shape[-2] = nreps
         super().build(input_shape)
 
     def call(self, ret):
@@ -58,7 +73,8 @@ def call(self, ret):
             Tensor of shape (batch_size, n_replicas, n_features)
         """
         if self.mask is not None:
-            ret = op.boolean_mask(ret, self.mask, axis=1, target_shape=self.masked_output_shape)
+            ret = op.take(op.flatten(ret), self._flattened_indices)
+            ret = op.reshape(ret, self.masked_output_shape)
         if self.c is not None:
             ret = ret * self.kernel
         return ret
diff --git a/n3fit/src/n3fit/layers/msr_normalization.py b/n3fit/src/n3fit/layers/msr_normalization.py
index 7695d4f11f..5159628c0d 100644
--- a/n3fit/src/n3fit/layers/msr_normalization.py
+++ b/n3fit/src/n3fit/layers/msr_normalization.py
@@ -194,6 +194,7 @@ def call(self, pdf_integrated, photon_integral):
             numerators += [self.vsr_factors]
 
         numerators = op.concatenate(numerators, axis=0)
+
         divisors = op.gather(y, self.divisor_indices, axis=0)
 
         # Fill in the rest of the flavours with 1
diff --git a/n3fit/src/n3fit/layers/observable.py b/n3fit/src/n3fit/layers/observable.py
index 8945cc4da4..29c5754ecc 100644
--- a/n3fit/src/n3fit/layers/observable.py
+++ b/n3fit/src/n3fit/layers/observable.py
@@ -89,7 +89,7 @@ def __init__(
         operation_name="NULL",
         nfl=14,
         n_replicas=1,
-        **kwargs
+        **kwargs,
     ):
         super(MetaLayer, self).__init__(**kwargs)
@@ -178,7 +178,10 @@ def call(self, pdf):
             rank 3 tensor (batchsize, replicas, ndata)
         """
         if self.splitting:
-            pdfs = op.split(pdf, self.splitting, axis=2)
+            splitter = op.tensor_splitter(
+                pdf.shape, self.splitting, axis=2, name=f"pdf_splitter_{self.name}"
+            )
+            pdfs = splitter(pdf)
         else:
             pdfs = [pdf] * len(self.padded_fk_tables)
 
@@ -222,7 +225,7 @@ def compute_float_mask(bool_mask):
     """
     # Create a tensor with the shape (**bool_mask.shape, num_active_flavours)
     masked_to_full = []
-    for idx in np.argwhere(bool_mask):
+    for idx in np.argwhere(np.array(bool_mask)):
         temp_matrix = np.zeros(bool_mask.shape)
         temp_matrix[tuple(idx)] = 1
         masked_to_full.append(temp_matrix)
diff --git a/n3fit/src/n3fit/model_gen.py b/n3fit/src/n3fit/model_gen.py
index 852f93caf3..f7fe7c5608 100644
--- a/n3fit/src/n3fit/model_gen.py
+++ b/n3fit/src/n3fit/model_gen.py
@@ -99,13 +99,10 @@ def _generate_experimental_layer(self, pdf):
         the input PDF is evaluated in all points that the experiment needs and needs to be split
         """
         if len(self.dataset_xsizes) > 1:
-            splitting_layer = op.as_layer(
-                op.split,
-                op_args=[self.dataset_xsizes],
-                op_kwargs={"axis": 2},
-                name=f"{self.name}_split",
+            sp_layer = op.tensor_splitter(
+                pdf.shape, self.dataset_xsizes, axis=2, name=f"{self.name}_split"
             )
-            sp_pdf = splitting_layer(pdf)
+            sp_pdf = sp_layer(pdf)
             output_layers = [obs(p) for obs, p in zip(self.observables, sp_pdf)]
         else:
             output_layers = [obs(pdf) for obs in self.observables]
diff --git a/n3fit/src/n3fit/model_trainer.py b/n3fit/src/n3fit/model_trainer.py
index d92e7cf51d..d864d2c6e5 100644
--- a/n3fit/src/n3fit/model_trainer.py
+++ b/n3fit/src/n3fit/model_trainer.py
@@ -16,7 +16,7 @@
 import numpy as np
 
 from n3fit import model_gen
-from n3fit.backends import NN_LAYER_ALL_REPLICAS, MetaModel, callbacks, clear_backend_state
+from n3fit.backends import NN_LAYER_ALL_REPLICAS, Lambda, MetaModel, callbacks, clear_backend_state
 from n3fit.backends import operations as op
 from n3fit.hyper_optimization.hyper_scan import HYPEROPT_STATUSES
 import n3fit.hyper_optimization.penalties
@@ -40,6 +40,9 @@
 # Each how many epochs do we increase the integrability Lagrange Multiplier
 PUSH_INTEGRABILITY_EACH = 100
 
+# Final number of flavours
+FLAVOURS = 14
+
 # See ModelTrainer::_xgrid_generation for the definition of each field and how they are generated
 InputInfo = namedtuple("InputInfo", ["input", "split", "idx"])
 
@@ -354,11 +357,13 @@ def _xgrid_generation(self):
             input_arr = self._scaler(input_arr)
         input_layer = op.numpy_to_input(input_arr, name="pdf_input")
 
-        # The PDF model will be called with a concatenation of all inputs
-        # now the output needs to be splitted so that each experiment takes its corresponding input
-        sp_ar = [[i.shape[1] for i in inputs_unique]]
-        sp_kw = {"axis": 2}
-        sp_layer = op.as_layer(op.split, op_args=sp_ar, op_kwargs=sp_kw, name="pdf_split")
+        # The PDF model is called with a concatenation of all inputs;
+        # each output layer might require a different subset, which is obtained
+        # by splitting the output back
+        # Input shape: (batch size, replicas, input array, flavours)
+        ishape = (1, len(self.replicas), input_arr.shape[0], FLAVOURS)
+        xsizes = [i.shape[1] for i in inputs_unique]
+        sp_layer = op.tensor_splitter(ishape, xsizes, axis=2, name="splitter")
 
         return InputInfo(input_layer, sp_layer, inputs_idx)
 
@@ -936,8 +941,10 @@ def hyperparametrizable(self, params):
             )
 
         if photons:
-            if self._scaler:  # select only the non-scaled input
-                pdf_model.get_layer("add_photon").register_photon(xinput.input.tensor_content[:,:,1:])
+            if self._scaler:  # select only the non-scaled input
+                pdf_model.get_layer("add_photon").register_photon(
+                    xinput.input.tensor_content[:, :, 1:]
+                )
             else:
                 pdf_model.get_layer("add_photon").register_photon(xinput.input.tensor_content)
 
diff --git a/n3fit/src/n3fit/performfit.py b/n3fit/src/n3fit/performfit.py
index 04703ef924..7e91c1b5ca 100644
--- a/n3fit/src/n3fit/performfit.py
+++ b/n3fit/src/n3fit/performfit.py
@@ -3,11 +3,8 @@
 """
 
 # Backend-independent imports
-import copy
 import logging
 
-import numpy as np
-
 import n3fit.checks
 from n3fit.vpinterface import N3PDF
 
diff --git a/n3fit/src/n3fit/tests/test_backend.py b/n3fit/src/n3fit/tests/test_backend.py
index eaae5667c8..e464ae2384 100644
--- a/n3fit/src/n3fit/tests/test_backend.py
+++ b/n3fit/src/n3fit/tests/test_backend.py
@@ -2,8 +2,11 @@
     This module tests the mathematical functions in the n3fit backend
     and ensures they do the same thing as their numpy counterparts
 """
+
 import operator
+
 import numpy as np
+
 from n3fit.backends import operations as op
 
 # General parameters
@@ -24,14 +27,14 @@
 
 
 def are_equal(result, reference, threshold=THRESHOLD):
-    """ checks the difference between array `reference` and tensor `result` is
-    below `threshold` for all elements """
-    res = op.evaluate(result)
+    """checks the difference between array `reference` and tensor `result` is
+    below `threshold` for all elements"""
+    res = op.tensor_to_numpy_or_python(result)
     assert np.allclose(res, reference, atol=threshold)
 
 
 def numpy_check(backend_op, python_op, mode="same"):
-    """ Receives a backend operation (`backend_op`) and a python operation
+    """Receives a backend operation (`backend_op`) and a python operation
     `python_op` and asserts that, applied to two random arrays, the result is the same.
The option `mode` selects the two arrays to be tested and accepts the following @@ -53,7 +56,28 @@ def numpy_check(backend_op, python_op, mode="same"): arrays = [ARR1, ARR2, ARR1, ARR1] elif mode == "twenty": tensors = [T1, T2, T1, T1, T1, T1, T1, T1, T1, T1, T1, T2, T1, T1, T1, T1, T1, T1, T1, T1] - arrays = [ARR1, ARR2, ARR1, ARR1, ARR1, ARR1, ARR1, ARR1, ARR1, ARR1, ARR1, ARR2, ARR1, ARR1, ARR1, ARR1, ARR1, ARR1, ARR1, ARR1] + arrays = [ + ARR1, + ARR2, + ARR1, + ARR1, + ARR1, + ARR1, + ARR1, + ARR1, + ARR1, + ARR1, + ARR1, + ARR2, + ARR1, + ARR1, + ARR1, + ARR1, + ARR1, + ARR1, + ARR1, + ARR1, + ] elif mode == "ten": tensors = [T1, T2, T1, T1, T1, T1, T1, T1, T1, T1] arrays = [ARR1, ARR2, ARR1, ARR1, ARR1, ARR1, ARR1, ARR1, ARR1, ARR1] @@ -98,13 +122,16 @@ def test_c_to_py_fun(): numpy_check(op_smp, reference, "four") # COM op_com = op.c_to_py_fun("COM") - reference = lambda x, y, z, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t : (x + y + z + d + e + f + g + h + i + j) / (k + l + m + n + o + p + q + r + s + t) + reference = lambda x, y, z, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t: ( + x + y + z + d + e + f + g + h + i + j + ) / (k + l + m + n + o + p + q + r + s + t) numpy_check(op_com, reference, "twenty") # SMT op_smt = op.c_to_py_fun("SMT") - reference = lambda x, y, z, d, e, f, g, h, i, j : (x + y + z + d + e + f + g + h + i + j) + reference = lambda x, y, z, d, e, f, g, h, i, j: (x + y + z + d + e + f + g + h + i + j) numpy_check(op_smt, reference, "ten") + # Tests operations def test_op_multiply(): numpy_check(op.op_multiply, operator.mul) @@ -122,17 +149,11 @@ def test_flatten(): numpy_check(op.flatten, np.ndarray.flatten, mode=(T3, [ARR3])) -def test_boolean_mask(): - bools = np.random.randint(0, 2, DIM, dtype=bool) - np_result = ARR1[bools] - tf_bools = op.numpy_to_tensor(bools) - tf_result = op.boolean_mask(T1, tf_bools, axis=0) - are_equal(np_result, tf_result) - def test_tensor_product(): np_result = np.tensordot(ARR3, ARR1, axes=1) tf_result = op.tensor_product(T3, T1, axes=1) - are_equal(np_result, tf_result) + are_equal(tf_result, np_result) + def test_sum(): numpy_check(op.sum, np.sum, mode='single') diff --git a/n3fit/src/n3fit/tests/test_layers.py b/n3fit/src/n3fit/tests/test_layers.py index 8615414c2f..84ef8c8eaf 100644 --- a/n3fit/src/n3fit/tests/test_layers.py +++ b/n3fit/src/n3fit/tests/test_layers.py @@ -169,7 +169,7 @@ def test_DIS(): kp = op.numpy_to_tensor([[pdf]]) # add batch and replica dimension # generate the n3fit results result_tensor = obs_layer(kp) - result = op.evaluate(result_tensor) + result = op.tensor_to_numpy_or_python(result_tensor) # Compute the numpy version of this layer all_masks = obs_layer.all_masks if len(all_masks) < nfk: @@ -195,7 +195,7 @@ def test_DY(): kp = op.numpy_to_tensor([[pdf]]) # add batch and replica dimension # generate the n3fit results result_tensor = obs_layer(kp) - result = op.evaluate(result_tensor) + result = op.tensor_to_numpy_or_python(result_tensor) # Compute the numpy version of this layer all_masks = obs_layer.all_masks if len(all_masks) < nfk: diff --git a/pyproject.toml b/pyproject.toml index d4f135d44c..09ab007019 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,7 @@ reportengine = { git = "https://github.com/NNPDF/reportengine" } # Fit psutil = "*" tensorflow = "*" +keras = "^3.1" eko = "^0.14.1" joblib = "*" # Hyperopt
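As a closing note, the replacement of ``op.boolean_mask`` by flatten/take/reshape in ``mask.py`` can be checked independently of any backend. The sketch below (illustrative names and shapes, not part of the test suite) relies only on the layer's existing assumptions: a batch dimension of size 1 and the same number of masked points per replica:

```python
# Hedged numpy-only check of the flatten + take + reshape masking strategy.
import numpy as np

rng = np.random.default_rng(42)
pdf = rng.random((1, 3, 7))  # (batch=1, replicas, points)
row_mask = rng.random(7) > 0.4  # one mask row, repeated for every replica
bool_mask = np.tile(row_mask, (3, 1))

# What build() precomputes: raveled indices of the True entries
flat_indices = np.ravel_multi_index(np.where(bool_mask), bool_mask.shape)
last_dim = np.count_nonzero(bool_mask[0])

# What call() now does: flatten, gather, reshape
masked = pdf.reshape(-1)[flat_indices].reshape(1, 3, last_dim)

# The boolean-mask behaviour it replaces
reference = pdf[0][bool_mask].reshape(1, 3, last_dim)
assert np.allclose(masked, reference)
```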