apply some comments, refactor operations
scarlehoff committed Nov 27, 2024
1 parent ef2bb37 commit 59464e9
Showing 3 changed files with 33 additions and 138 deletions.
156 changes: 31 additions & 125 deletions n3fit/src/n3fit/backends/keras_backend/operations.py
@@ -6,8 +6,6 @@
This includes an implementation of the NNPDF operations on fktable in the keras
language (with the mapping ``c_to_py_fun``) into Keras ``Lambda`` layers.
Tensor operations are compiled through the decorator for optimization
The rest of the operations in this module are divided into four categories:
numpy to tensor:
Operations that take a numpy array and return a tensorflow tensor
@@ -18,7 +16,12 @@
layer generation:
Instantiate a layer to be applied by the calling function
Some of these are just aliases to the backend (tensorflow or Keras) operations
Most of the operations in this module are just aliases to the backend
(Keras in this case) so that, when implementing new backends, it is clear
which operations may need to be overwritten.
For a few selected operations, a more complicated wrapper is included,
e.g., to turn them into layers or to apply some default.
Note that tensor operations can also be applied to layers, as the output of a layer
is a tensor; equally, operations are automatically converted to layers when used as such.
"""
@@ -27,12 +30,35 @@
from keras import ops as Kops
from keras.layers import ELU, Input
from keras.layers import Lambda as keras_Lambda
from keras.layers import multiply as keras_multiply
from keras.layers import subtract as keras_subtract
import numpy as np

from validphys.convolution import OP

# The following operations are either loaded directly from keras and exposed here
# or the name is changed slightly (usually for historical or collision reasons,
# e.g., our logs are always logs or we were using the tf version in the past)

# isort: off
from keras.ops import (
absolute,
einsum,
expand_dims,
leaky_relu,
reshape,
split,
sum,
tanh,
transpose,
)
from keras.ops import log as op_log
from keras.ops import power as pow
from keras.ops import take as gather
from keras.ops import tensordot as tensor_product
from keras.layers import multiply as op_multiply
from keras.layers import subtract as op_subtract

# isort: on
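For illustration only (not part of the diff), the renamed aliases behave exactly like their keras.ops counterparts:

from keras import ops as Kops
from n3fit.backends.keras_backend import operations as op

x = Kops.ones((2, 3))
y = op.op_log(x)                  # keras.ops.log under its historical name
z = op.pow(x, 2)                  # keras.ops.power
g = op.gather(x, [0, 2], axis=1)  # keras.ops.take
t = op.tensor_product(x, Kops.transpose(x), axes=1)  # keras.ops.tensordot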

# Backend dependent functions and operations
if K.backend() == "torch":
tensor_to_numpy_or_python = lambda x: x.detach().cpu().numpy()
@@ -144,40 +170,6 @@ def numpy_to_input(numpy_array, name=None):
return input_layer


#
# Layer to Layer operations
#
def op_multiply(o_list, **kwargs):
"""
Receives a list of layers of the same output size and multiplies them element-wise
"""
return keras_multiply(o_list, **kwargs)


def op_multiply_dim(o_list, **kwargs):
"""
Bypass in order to multiply two layers with different output dimensions,
for instance: (10000 x 14) * (14),
as the normal keras multiply doesn't accept it (but somehow it does accept it done this way)
"""
if len(o_list) != 2:
raise ValueError(
"The number of observables is incorrect, operations.py:op_multiply_dim, expected 2, received {}".format(
len(o_list)
)
)

layer_op = as_layer(lambda inputs: inputs[0] * inputs[1])
return layer_op(o_list)


def gather(*args, **kwargs):
"""
Gather elements from a tensor along an axis
"""
return Kops.take(*args, **kwargs)


def op_gather_keep_dims(tensor, indices, axis=0, **kwargs):
"""A convoluted way of providing ``x[:, indices, :]``
@@ -195,46 +187,11 @@ def tmp(x):
return layer_op(tensor)
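A hedged check of the documented equivalence (shapes invented, not from the commit; assumes `axis` selects the gathered dimension):

import numpy as np
from keras import ops as Kops

x = Kops.reshape(Kops.arange(24, dtype="float32"), (2, 4, 3))
idx = [0, 2]
res = op_gather_keep_dims(x, idx, axis=1)  # the layer-wrapped gather
ref = Kops.take(x, idx, axis=1)            # plain x[:, idx, :]
np.testing.assert_allclose(Kops.convert_to_numpy(res), Kops.convert_to_numpy(ref))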


#
# Tensor operations
# f(x: tensor[s]) -> y: tensor
#


# Generation operations
# generate tensors of given shape/content


def tensor_ones_like(*args, **kwargs):
"""
Generates a tensor of ones of the same shape as the input tensor
See full `docs <https://www.tensorflow.org/api_docs/python/tf/keras/backend/ones_like>`_
"""
return K.ones_like(*args, **kwargs)


# Property operations
# modify properties of the tensor like the shape or elements it has


def reshape(x, shape):
"""reshape tensor x"""
return Kops.reshape(x, shape)


def flatten(x):
"""Flatten tensor x"""
return reshape(x, (-1,))


def transpose(tensor, **kwargs):
"""
Transpose a layer,
see full `docs <https://www.tensorflow.org/api_docs/python/tf/keras/backend/transpose>`_
"""
return Kops.transpose(tensor, **kwargs)


def stack(tensor_list, axis=0, **kwargs):
"""Stack a list of tensors
see full `docs <https://www.tensorflow.org/api_docs/python/tf/stack>`_
@@ -254,29 +211,6 @@ def concatenate(tensor_list, axis=-1, target_shape=None, name=None):
return K.reshape(concatenated_tensor, target_shape)
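A hedged usage sketch of the target_shape convenience (shapes invented, based only on the visible body): the tensors are concatenated along `axis` and then reshaped in the same call.

from keras import ops as Kops

a = Kops.ones((2, 3))
b = Kops.zeros((2, 3))
# concatenate along the last axis, then reshape: (2, 6) -> (2, 2, 3)
c = concatenate([a, b], axis=-1, target_shape=(2, 2, 3))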


def einsum(equation, *args, **kwargs):
"""
Computes the tensor product using einsum
See full `docs <https://www.tensorflow.org/api_docs/python/tf/einsum>`_
"""
return Kops.einsum(equation, *args, **kwargs)


def tensor_product(*args, **kwargs):
"""
Computes the tensordot product between tensor_x and tensor_y
See full `docs <https://www.tensorflow.org/api_docs/python/tf/tensordot>`_
"""
return Kops.tensordot(*args, **kwargs)


def pow(tensor, power):
"""
Computes the power of the tensor
"""
return Kops.power(tensor, power)


def scatter_to_one(values, indices, output_shape):
"""
Like scatter_nd initialized to one instead of zero
@@ -286,14 +220,6 @@ def scatter_to_one(values, indices, output_shape):
return Kops.scatter_update(ones, indices, values)
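A hedged numeric example (values invented): the output starts from ones and gets `values` written at `indices`.

from keras import ops as Kops

vals = Kops.convert_to_tensor([5.0, 7.0])
idx = Kops.convert_to_tensor([[1], [3]])
out = scatter_to_one(vals, idx, output_shape=(5,))
# expected: [1., 5., 1., 7., 1.]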


def op_subtract(inputs, **kwargs):
"""
Computes the difference between two tensors.
see full `docs <https://www.tensorflow.org/api_docs/python/tf/keras/layers/subtract>`_
"""
return keras_subtract(inputs, **kwargs)


def swapaxes(tensor, source, destination):
"""
Moves the axis of the tensor from source to destination, as in numpy.swapaxes.
@@ -316,15 +242,6 @@ def elu(x, alpha=1.0, **kwargs):
return new_layer(x)


def backend_function(fun_name, *args, **kwargs):
"""
Wrapper to call non-explicitly implemented backend functions by name: (``fun_name``)
see full `docs <https://keras.io/api/utils/backend_utils/>`_ for some possibilities
"""
fun = getattr(K, fun_name)
return fun(*args, **kwargs)


def tensor_splitter(ishape, split_sizes, axis=2, name="splitter"):
"""
Generates a Lambda layer to apply the split operation to a given tensor shape.
Expand Down Expand Up @@ -371,14 +288,3 @@ def tensor_splitter(ishape, split_sizes, axis=2, name="splitter"):
lambda x: Kops.split(x, indices, axis=axis), output_shape=oshapes, name=name
)
return sp_layer
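A hedged usage sketch (sizes invented; assumes `ishape` carries the full input shape including the batch dimension):

from keras import ops as Kops

# partition the 14 channels of a (1, 7, 14) tensor into widths 4 and 10
splitter = tensor_splitter((1, 7, 14), split_sizes=[4, 10], axis=2)
p1, p2 = splitter(Kops.ones((1, 7, 14)))
# expected: p1.shape == (1, 7, 4) and p2.shape == (1, 7, 10)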


expand_dims = Kops.expand_dims
absolute = Kops.absolute
tanh = Kops.tanh
leaky_relu = Kops.leaky_relu
split = Kops.split
gather = Kops.take
take = Kops.take
sum = Kops.sum
op_log = Kops.log
11 changes: 2 additions & 9 deletions n3fit/src/n3fit/layers/mask.py
@@ -47,14 +47,7 @@ def build(self, input_shape):
indices = np.where(self._raw_mask)
# The batch dimension can be ignored
nreps = self.mask.shape[-2]
dims = (nreps, self.last_dim * nreps)
try:
self._flattened_indices = np.ravel_multi_index(indices, self._raw_mask.shape)
except:
import ipdb

ipdb.set_trace()

self._flattened_indices = np.ravel_multi_index(indices, self._raw_mask.shape)
self.masked_output_shape = [-1 if d is None else d for d in input_shape]
self.masked_output_shape[-1] = self.last_dim
self.masked_output_shape[-2] = nreps
@@ -73,7 +66,7 @@ def call(self, ret):
Tensor of shape (batch_size, n_replicas, n_features)
"""
if self.mask is not None:
ret = op.take(op.flatten(ret), self._flattened_indices)
ret = op.gather(op.flatten(ret), self._flattened_indices)
ret = op.reshape(ret, self.masked_output_shape)
if self.c is not None:
ret = ret * self.kernel
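The flattened-index gather used by the fixed build method can be checked in plain numpy; a hedged sketch with invented shapes (not from the diff):

import numpy as np

raw_mask = np.array([[True, False, True], [False, True, True]])
# same construction as Mask.build: flatten the multi-index of the True entries
flat_idx = np.ravel_multi_index(np.where(raw_mask), raw_mask.shape)
data = np.arange(6).reshape(2, 3)
# mirrors op.gather(op.flatten(ret), self._flattened_indices) in Mask.call
masked = data.flatten()[flat_idx]  # -> [0, 2, 4, 5]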
4 changes: 0 additions & 4 deletions n3fit/src/n3fit/tests/test_backend.py
@@ -137,10 +137,6 @@ def test_op_multiply():
numpy_check(op.op_multiply, operator.mul)


def test_op_multiply_dim():
numpy_check(op.op_multiply_dim, operator.mul, mode="diff")


def test_op_log():
numpy_check(op.op_log, np.log, mode='single')

