From 9e24b10aa41dd669446ca7f39adcb7318dbecd60 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Fri, 29 Nov 2024 15:09:58 +0100
Subject: [PATCH] Remove Mean Op

This Op does not really fit the CAReduce API, as it requires an extra
bit of information (the number of elements in the axis) during the
loop. A better solution would be a fused Elemwise+CAReduce.
---
 pytensor/link/numba/dispatch/elemwise.py |  6 --
 pytensor/scalar/basic.py                 | 26 --------
 pytensor/tensor/math.py                  | 77 +-----------------------
 tests/link/numba/test_elemwise.py        | 14 +----
 tests/scalar/test_basic.py               | 31 +---------
 tests/tensor/test_math.py                | 10 ---
 6 files changed, 3 insertions(+), 161 deletions(-)

diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py
index 842cf695aa..def4746a18 100644
--- a/pytensor/link/numba/dispatch/elemwise.py
+++ b/pytensor/link/numba/dispatch/elemwise.py
@@ -34,7 +34,6 @@
     Add,
     Composite,
     IntDiv,
-    Mean,
     Mul,
     ScalarMaximum,
     ScalarMinimum,
@@ -77,11 +76,6 @@ def scalar_in_place_fn_Sub(op, idx, res, arr):
     return f"{res}[{idx}] -= {arr}"
 
 
-@scalar_in_place_fn.register(Mean)
-def scalar_in_place_fn_Mean(op, idx, res, arr):
-    return f"{res}[{idx}] += ({arr} - {res}[{idx}]) / (i + 1)"
-
-
 @scalar_in_place_fn.register(Mul)
 def scalar_in_place_fn_Mul(op, idx, res, arr):
     return f"{res}[{idx}] *= {arr}"
diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py
index bb2baf0636..3c33434e56 100644
--- a/pytensor/scalar/basic.py
+++ b/pytensor/scalar/basic.py
@@ -1871,32 +1871,6 @@ def L_op(self, inputs, outputs, gout):
 add = Add(upcast_out, name="add")
 
 
-class Mean(ScalarOp):
-    identity = 0
-    commutative = True
-    associative = False
-    nfunc_spec = ("mean", 2, 1)
-    nfunc_variadic = "mean"
-
-    def impl(self, *inputs):
-        return sum(inputs) / len(inputs)
-
-    def c_code(self, node, name, inputs, outputs, sub):
-        (z,) = outputs
-        if not inputs:
-            return f"{z} = 0;"
-        else:
-            return f"{z} = ({' + '.join(inputs)}) / ((double) {len(inputs)});"
-
-    def L_op(self, inputs, outputs, gout):
-        (gz,) = gout
-        retval = [gz / len(inputs)] * len(inputs)
-        return retval
-
-
-mean = Mean(float_out, name="mean")
-
-
 class Mul(ScalarOp):
     identity = 1
     commutative = True
diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py
index 8c86a834ea..efcc2500a7 100644
--- a/pytensor/tensor/math.py
+++ b/pytensor/tensor/math.py
@@ -1316,63 +1316,7 @@ def complex_from_polar(abs, angle):
     """Return complex-valued tensor from polar coordinate specification."""
 
 
-class Mean(FixedOpCAReduce):
-    __props__ = ("axis",)
-    nfunc_spec = ("mean", 1, 1)
-
-    def __init__(self, axis=None):
-        super().__init__(ps.mean, axis)
-        assert self.axis is None or len(self.axis) == 1
-
-    def __str__(self):
-        if self.axis is not None:
-            args = ", ".join(str(x) for x in self.axis)
-            return f"Mean{{{args}}}"
-        else:
-            return "Mean"
-
-    def _output_dtype(self, idtype):
-        # we want to protect against overflow
-        return "float64"
-
-    def perform(self, node, inp, out):
-        (input,) = inp
-        (output,) = out
-        if self.axis is None:
-            axis = None
-        else:
-            axis = self.axis[0]
-        # numpy.asarray is needed as otherwise we can end up with a
-        # numpy scalar.
-        output[0] = np.asarray(np.mean(input, dtype="float64", axis=axis))
-
-    def c_code(self, node, name, inames, onames, sub):
-        ret = super().c_code(node, name, inames, onames, sub)
-
-        if self.axis is not None:
-            return ret
-
-        # TODO: c_code perform support only axis is None
-        return (
-            ret
-            + f"""
-            *((double *)PyArray_DATA({onames[0]})) /= PyArray_SIZE({inames[0]});
-            """
-        )
-
-    def clone(self, **kwargs):
-        axis = kwargs.get("axis", self.axis)
-        return type(self)(axis=axis)
-
-
-# TODO: implement the grad. When done and tested, you can make this the default
-# version.
-# def grad(self, (x,), (gout,)):
-#    import pdb;pdb.set_trace()
-#    return grad(mean(x, self.axis, op=False),[x])
-
-
-def mean(input, axis=None, dtype=None, op=False, keepdims=False, acc_dtype=None):
+def mean(input, axis=None, dtype=None, keepdims=False, acc_dtype=None):
     """
     Computes the mean value along the given axis(es) of a tensor `input`.
 
@@ -1397,25 +1341,6 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False, acc_dtype=None)
         be in a float type). If None, then we use the same rules as `sum()`.
     """
     input = as_tensor_variable(input)
-    if op:
-        if dtype not in (None, "float64"):
-            raise NotImplementedError(
-                "The Mean op does not support the dtype argument, "
-                "and will always use float64. If you want to specify "
-                "the dtype, call tensor.mean(..., op=False).",
-                dtype,
-            )
-        if acc_dtype not in (None, "float64"):
-            raise NotImplementedError(
-                "The Mean op does not support the acc_dtype argument, "
-                "and will always use float64. If you want to specify "
-                "acc_dtype, call tensor.mean(..., op=False).",
-                dtype,
-            )
-        out = Mean(axis)(input)
-        if keepdims:
-            out = makeKeepDims(input, out, axis)
-        return out
     if dtype is not None:
         # The summation will be done with the specified dtype.
diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py
index 4c13004409..3fb3979c27 100644
--- a/tests/link/numba/test_elemwise.py
+++ b/tests/link/numba/test_elemwise.py
@@ -16,7 +16,7 @@
 from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.tensor.elemwise import DimShuffle
-from pytensor.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZeros, Sum
+from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
 from tests.link.numba.test_basic import (
     compare_numba_and_py,
@@ -256,18 +256,6 @@ def test_Dimshuffle_non_contiguous():
         0,
         set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
     ),
-    (
-        lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x),
-        0,
-        set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
-    ),
-    (
-        lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x),
-        0,
-        set_test_value(
-            pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
-        ),
-    ),
     (
         lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
             axis=axis, dtype=dtype, acc_dtype=acc_dtype
diff --git a/tests/scalar/test_basic.py b/tests/scalar/test_basic.py
index e648869d4c..5aab9a95cc 100644
--- a/tests/scalar/test_basic.py
+++ b/tests/scalar/test_basic.py
@@ -43,7 +43,6 @@
     log1p,
     log2,
     log10,
-    mean,
     mul,
     neg,
     neq,
@@ -58,7 +57,7 @@
     true_div,
     uint8,
 )
-from pytensor.tensor.type import fscalar, imatrix, iscalar, matrix
+from pytensor.tensor.type import fscalar, imatrix, matrix
 from tests.link.test_link import make_function
 
 
@@ -521,34 +520,6 @@ def test_constant():
     assert c.dtype == "float32"
 
 
-@pytest.mark.parametrize("mode", [Mode("py"), Mode("cvm")])
-def test_mean(mode):
-    a = iscalar("a")
-    b = iscalar("b")
-    z = mean(a, b)
-    z_fn = pytensor.function([a, b], z, mode=mode)
-    res = z_fn(1, 1)
-    assert np.allclose(res, 1.0)
-
-    a = fscalar("a")
-    b = fscalar("b")
-    c = fscalar("c")
-
-    z = mean(a, b, c)
-
-    z_fn = pytensor.function([a, b, c], pytensor.grad(z, [a]), mode=mode)
-    res = z_fn(3, 4, 5)
-    assert np.allclose(res, 1 / 3)
-
-    z_fn = pytensor.function([a, b, c], pytensor.grad(z, [b]), mode=mode)
-    res = z_fn(3, 4, 5)
-    assert np.allclose(res, 1 / 3)
-
-    z = mean()
-    z_fn = pytensor.function([], z, mode=mode)
-    assert z_fn() == 0
-
-
 def test_shape():
     a = float32("a")
     assert isinstance(a.type, ScalarType)
diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py
index 14bc2614e3..2d19ef0114 100644
--- a/tests/tensor/test_math.py
+++ b/tests/tensor/test_math.py
@@ -40,7 +40,6 @@
     Argmax,
     Dot,
     Max,
-    Mean,
     Prod,
     ProdWithoutZeros,
     Sum,
@@ -2587,15 +2586,6 @@ def test_mod_compile():
 
 
 class TestInferShape(utt.InferShapeTester):
-    def test_Mean(self):
-        adtens3 = dtensor3()
-        adtens3_val = random(3, 4, 5)
-        aiscal_val = 2
-        self._compile_and_check([adtens3], [Mean(None)(adtens3)], [adtens3_val], Mean)
-        self._compile_and_check(
-            [adtens3], [Mean(aiscal_val)(adtens3)], [adtens3_val], Mean
-        )
-
     def test_Max(self):
         adtens3 = dtensor3()
         adtens3_val = random(4, 5, 3)
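
To make the commit message concrete: a CAReduce update step sees only the accumulator and the current element (e.g. acc += x for Add), while the running-mean update that the removed scalar_in_place_fn_Mean generated also needs the loop counter. A minimal sketch in plain Python, assuming nothing beyond NumPy; careduce_style_sum and running_mean are illustrative names, not PyTensor APIs:

import numpy as np


def careduce_style_sum(x):
    # CAReduce contract: start from the op's identity and fold in one
    # element at a time, using only (accumulator, element).
    acc = 0.0  # Add.identity
    for xi in x:
        acc += xi
    return acc


def running_mean(x):
    # The removed Mean update also needs the element count (i + 1),
    # exactly the template deleted from the Numba dispatcher:
    #     res[idx] += (arr - res[idx]) / (i + 1)
    acc = 0.0
    for i, xi in enumerate(x):
        acc += (xi - acc) / (i + 1)
    return acc


x = np.arange(3.0)
assert np.isclose(running_mean(x), x.mean())
# The fused Elemwise+CAReduce alternative amounts to sum-then-divide:
assert np.isclose(careduce_style_sum(x) / x.size, x.mean())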
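
The user-facing helper is unaffected: with the op=True path removed, pytensor.tensor.mean always builds its graph from sum and a division. A quick sanity check, assuming a standard PyTensor install:

import numpy as np

import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
mean_x = pt.mean(x, axis=0)  # graph built from sum(x, axis=0) and a division
f = pytensor.function([x], mean_x)

data = np.arange(6.0).reshape(3, 2)
np.testing.assert_allclose(f(data), data.mean(axis=0))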