Remove Mean Op
This Op does not really fit the CAReduce API, as it requires an extra piece of information (the number of elements along the reduced axis) during the loop. A better solution would be a fused Elemwise+CAReduce.
ricardoV94 committed Nov 29, 2024
1 parent 1a3af4b commit 9e24b10
Showing 6 changed files with 3 additions and 161 deletions.
6 changes: 0 additions & 6 deletions pytensor/link/numba/dispatch/elemwise.py
@@ -34,7 +34,6 @@
Add,
Composite,
IntDiv,
Mean,
Mul,
ScalarMaximum,
ScalarMinimum,
@@ -77,11 +76,6 @@ def scalar_in_place_fn_Sub(op, idx, res, arr):
return f"{res}[{idx}] -= {arr}"


@scalar_in_place_fn.register(Mean)
def scalar_in_place_fn_Mean(op, idx, res, arr):
return f"{res}[{idx}] += ({arr} - {res}[{idx}]) / (i + 1)"


@scalar_in_place_fn.register(Mul)
def scalar_in_place_fn_Mul(op, idx, res, arr):
return f"{res}[{idx}] *= {arr}"
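For context, the removed Numba kernel above maintains a running mean, and its update needs the loop counter `i` on top of the accumulator and the incoming element; the binary in-place updates used by Add, Sub, or Mul need only the latter two. A minimal plain-NumPy sketch of that difference (illustrative only, not the PyTensor or Numba API):

```python
import numpy as np

arr = np.array([1.0, 4.0, 7.0])

acc_sum = 0.0   # Add fits CAReduce: the update only reads the accumulator and x
acc_mean = 0.0  # Mean does not: the update also depends on the iteration index i
for i, x in enumerate(arr):
    acc_sum += x
    acc_mean += (x - acc_mean) / (i + 1)  # the removed scalar_in_place_fn_Mean rule

assert acc_sum == arr.sum()
assert np.isclose(acc_mean, arr.mean())
```

That per-iteration element count is the "extra piece of information" the commit message refers to, and why a fused Elemwise+CAReduce is suggested instead.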
26 changes: 0 additions & 26 deletions pytensor/scalar/basic.py
@@ -1871,32 +1871,6 @@ def L_op(self, inputs, outputs, gout):
add = Add(upcast_out, name="add")


class Mean(ScalarOp):
identity = 0
commutative = True
associative = False
nfunc_spec = ("mean", 2, 1)
nfunc_variadic = "mean"

def impl(self, *inputs):
return sum(inputs) / len(inputs)

def c_code(self, node, name, inputs, outputs, sub):
(z,) = outputs
if not inputs:
return f"{z} = 0;"
else:
return f"{z} = ({' + '.join(inputs)}) / ((double) {len(inputs)});"

def L_op(self, inputs, outputs, gout):
(gz,) = gout
retval = [gz / len(inputs)] * len(inputs)
return retval


mean = Mean(float_out, name="mean")


class Mul(ScalarOp):
identity = 1
commutative = True
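The scalar `Mean` op deleted above averaged a variadic list of scalar inputs (`sum(inputs) / len(inputs)`) and back-propagated `1/n` to each of them. A hedged sketch of how the same value and gradient can be built from ordinary graph arithmetic after this commit (the names `a`, `b`, `c` are placeholders, not anything defined in the codebase):

```python
import pytensor
import pytensor.tensor as pt

a = pt.scalar("a")
b = pt.scalar("b")
c = pt.scalar("c")
z = (a + b + c) / 3  # same value the removed scalar `mean` produced for 3 inputs

f = pytensor.function([a, b, c], [z, pytensor.grad(z, a)])
value, grad_a = f(3.0, 4.0, 5.0)
# value -> 4.0, grad_a -> 1/3, matching the removed L_op (gz / len(inputs))
```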
77 changes: 1 addition & 76 deletions pytensor/tensor/math.py
@@ -1316,63 +1316,7 @@ def complex_from_polar(abs, angle):
"""Return complex-valued tensor from polar coordinate specification."""


class Mean(FixedOpCAReduce):
__props__ = ("axis",)
nfunc_spec = ("mean", 1, 1)

def __init__(self, axis=None):
super().__init__(ps.mean, axis)
assert self.axis is None or len(self.axis) == 1

def __str__(self):
if self.axis is not None:
args = ", ".join(str(x) for x in self.axis)
return f"Mean{{{args}}}"
else:
return "Mean"

def _output_dtype(self, idtype):
# we want to protect against overflow
return "float64"

def perform(self, node, inp, out):
(input,) = inp
(output,) = out
if self.axis is None:
axis = None
else:
axis = self.axis[0]
# numpy.asarray is needed as otherwise we can end up with a
# numpy scalar.
output[0] = np.asarray(np.mean(input, dtype="float64", axis=axis))

def c_code(self, node, name, inames, onames, sub):
ret = super().c_code(node, name, inames, onames, sub)

if self.axis is not None:
return ret

# TODO: c_code perform support only axis is None
return (
ret
+ f"""
*((double *)PyArray_DATA({onames[0]})) /= PyArray_SIZE({inames[0]});
"""
)

def clone(self, **kwargs):
axis = kwargs.get("axis", self.axis)
return type(self)(axis=axis)


# TODO: implement the grad. When done and tested, you can make this the default
# version.
# def grad(self, (x,), (gout,)):
# import pdb;pdb.set_trace()
# return grad(mean(x, self.axis, op=False),[x])


def mean(input, axis=None, dtype=None, op=False, keepdims=False, acc_dtype=None):
def mean(input, axis=None, dtype=None, keepdims=False, acc_dtype=None):
"""
Computes the mean value along the given axis(es) of a tensor `input`.
@@ -1397,25 +1341,6 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False, acc_dtype=None)
be in a float type). If None, then we use the same rules as `sum()`.
"""
input = as_tensor_variable(input)
if op:
if dtype not in (None, "float64"):
raise NotImplementedError(
"The Mean op does not support the dtype argument, "
"and will always use float64. If you want to specify "
"the dtype, call tensor.mean(..., op=False).",
dtype,
)
if acc_dtype not in (None, "float64"):
raise NotImplementedError(
"The Mean op does not support the acc_dtype argument, "
"and will always use float64. If you want to specify "
"acc_dtype, call tensor.mean(..., op=False).",
dtype,
)
out = Mean(axis)(input)
if keepdims:
out = makeKeepDims(input, out, axis)
return out

if dtype is not None:
# The summation will be done with the specified dtype.
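With the `op=True` branch gone, `mean` is always assembled from existing ops: a Sum (honouring `dtype`/`acc_dtype`) divided by the number of reduced elements. A rough usage sketch, assuming only that `x` is a float matrix; the manual expression approximates what the helper builds rather than reproducing its exact graph:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
m = pt.mean(x, axis=0)                     # graph-level helper, no `op` flag anymore
m_manual = pt.sum(x, axis=0) / x.shape[0]  # roughly what the helper lowers to

f = pytensor.function([x], [m, m_manual])
data = np.arange(6, dtype="float64").reshape(3, 2)
out, out_manual = f(data)
assert np.allclose(out, out_manual)
```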
14 changes: 1 addition & 13 deletions tests/link/numba/test_elemwise.py
@@ -16,7 +16,7 @@
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZeros, Sum
from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from tests.link.numba.test_basic import (
compare_numba_and_py,
@@ -256,18 +256,6 @@ def test_Dimshuffle_non_contiguous():
0,
set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x),
0,
set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x),
0,
set_test_value(
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
31 changes: 1 addition & 30 deletions tests/scalar/test_basic.py
@@ -43,7 +43,6 @@
log1p,
log2,
log10,
mean,
mul,
neg,
neq,
@@ -58,7 +57,7 @@
true_div,
uint8,
)
from pytensor.tensor.type import fscalar, imatrix, iscalar, matrix
from pytensor.tensor.type import fscalar, imatrix, matrix
from tests.link.test_link import make_function


@@ -521,34 +520,6 @@ def test_constant():
assert c.dtype == "float32"


@pytest.mark.parametrize("mode", [Mode("py"), Mode("cvm")])
def test_mean(mode):
a = iscalar("a")
b = iscalar("b")
z = mean(a, b)
z_fn = pytensor.function([a, b], z, mode=mode)
res = z_fn(1, 1)
assert np.allclose(res, 1.0)

a = fscalar("a")
b = fscalar("b")
c = fscalar("c")

z = mean(a, b, c)

z_fn = pytensor.function([a, b, c], pytensor.grad(z, [a]), mode=mode)
res = z_fn(3, 4, 5)
assert np.allclose(res, 1 / 3)

z_fn = pytensor.function([a, b, c], pytensor.grad(z, [b]), mode=mode)
res = z_fn(3, 4, 5)
assert np.allclose(res, 1 / 3)

z = mean()
z_fn = pytensor.function([], z, mode=mode)
assert z_fn() == 0


def test_shape():
a = float32("a")
assert isinstance(a.type, ScalarType)
10 changes: 0 additions & 10 deletions tests/tensor/test_math.py
@@ -40,7 +40,6 @@
Argmax,
Dot,
Max,
Mean,
Prod,
ProdWithoutZeros,
Sum,
@@ -2587,15 +2586,6 @@ def test_mod_compile():


class TestInferShape(utt.InferShapeTester):
def test_Mean(self):
adtens3 = dtensor3()
adtens3_val = random(3, 4, 5)
aiscal_val = 2
self._compile_and_check([adtens3], [Mean(None)(adtens3)], [adtens3_val], Mean)
self._compile_and_check(
[adtens3], [Mean(aiscal_val)(adtens3)], [adtens3_val], Mean
)

def test_Max(self):
adtens3 = dtensor3()
adtens3_val = random(4, 5, 3)
