diff --git a/requirements.txt b/requirements.txt index d0f8b055b..b7682e738 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,7 @@ catalogue>=2.0.4,<2.1.0 confection>=0.0.1,<1.0.0 ml_datasets>=0.2.0,<0.3.0; python_version < "3.11" # Third-party dependencies -pydantic>=1.7.4,!=1.8,!=1.8.1,<1.11.0 +pydantic>=1.7.4,!=1.8,!=1.8.1,<3.0.0 numpy>=1.15.0; python_version < "3.9" numpy>=1.19.0; python_version >= "3.9" packaging>=20.0 diff --git a/setup.cfg b/setup.cfg index d38e994fb..f80422a8c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -51,7 +51,7 @@ install_requires = setuptools numpy>=1.15.0; python_version < "3.9" numpy>=1.19.0; python_version >= "3.9" - pydantic>=1.7.4,!=1.8,!=1.8.1,<1.11.0 + pydantic>=1.7.4,!=1.8,!=1.8.1,<3.0.0 packaging>=20.0 # Backports of modern Python features dataclasses>=0.6,<1.0; python_version < "3.7" diff --git a/thinc/backends/numpy_ops.pyx b/thinc/backends/numpy_ops.pyx index f64aa29dd..5ab4d0d8f 100644 --- a/thinc/backends/numpy_ops.pyx +++ b/thinc/backends/numpy_ops.pyx @@ -317,8 +317,11 @@ class NumpyOps(Ops): cdef int O = X.shape[1] cdef int T = X.shape[0] - assert B != 0 - assert O != 0 + if B == 0 or O == 0: + if reals2d_ft is float2d_t: + return numpy.zeros(shape=(B, O), dtype="float32") + else: + return numpy.zeros(shape=(B, O), dtype="float64") cdef np.ndarray means if reals2d_ft is float2d_t: @@ -340,8 +343,11 @@ class NumpyOps(Ops): raise ValueError(f"all sequence lengths must be >= 0, got {length}") T += length - assert T != 0 - assert O != 0 + if T == 0 or O == 0: + if reals2d_ft is float2d_t: + return numpy.zeros(shape=(T, O), dtype="float32") + else: + return numpy.zeros(shape=(T, O), dtype="float64") cdef np.ndarray dX if reals2d_ft is float2d_t: @@ -358,8 +364,11 @@ class NumpyOps(Ops): cdef int O = X.shape[1] cdef int T = X.shape[0] - assert B != 0 - assert O != 0 + if B == 0 or O == 0: + if reals2d_ft is float2d_t: + return numpy.zeros(shape=(B, O), dtype="float32") + else: + return numpy.zeros(shape=(B, O), 
dtype="float64") cdef np.ndarray sums if reals2d_ft is float2d_t: @@ -381,8 +390,11 @@ class NumpyOps(Ops): raise ValueError(f"all sequence lengths must be >= 0, got {length}") T += length - assert T != 0 - assert O != 0 + if T == 0 or O == 0: + if reals2d_ft is float2d_t: + return numpy.zeros(shape=(T, O), dtype="float32") + else: + return numpy.zeros(shape=(T, O), dtype="float64") cdef np.ndarray dX if reals2d_ft is float2d_t: @@ -399,12 +411,16 @@ class NumpyOps(Ops): cdef int O = X.shape[1] cdef int T = X.shape[0] - assert B != 0 - assert O != 0 - - cdef np.ndarray maxes # Needs to be zero-initialized as we start by assuming that the first element is the max value. cdef np.ndarray which = self.alloc(shape=(B, O), dtype="i", zeros=True) + + if B == 0 or O == 0: + if reals2d_ft is float2d_t: + return numpy.zeros(shape=(B, O), dtype="float32"), which + else: + return numpy.zeros(shape=(B, O), dtype="float64"), which + + cdef np.ndarray maxes if reals2d_ft is float2d_t: maxes = self.alloc(shape=(B, O), dtype="float32", zeros=False) cpu_reduce_max(maxes.data, which.data, &X[0, 0], &lengths[0], B, T, O) @@ -424,8 +440,11 @@ class NumpyOps(Ops): raise ValueError(f"all sequence lengths must be > 0, got {length}") T += length - assert T != 0 - assert O != 0 + if T == 0 or O == 0: + if reals2d_ft is float2d_t: + return numpy.zeros(shape=(T, O), dtype="float32") + else: + return numpy.zeros(shape=(T, O), dtype="float64") cdef np.ndarray dX if reals2d_ft is float2d_t: diff --git a/thinc/backends/ops.py b/thinc/backends/ops.py index 01bb2f852..e3fec5c86 100644 --- a/thinc/backends/ops.py +++ b/thinc/backends/ops.py @@ -1289,8 +1289,10 @@ def reduce_max(self, X: Floats2d, lengths: Ints1d) -> Tuple[Floats2d, Ints2d]: def backprop_reduce_first( self, d_firsts: Floats2d, starts_ends: Ints1d ) -> Floats2d: - if starts_ends.size < 2: - raise ValueError(f"starts_ends should least have size 2") + if starts_ends.size == 0: + return self.alloc2f(0, d_firsts.shape[1], 
dtype=d_firsts.dtype, zeros=True) + elif starts_ends.size == 1: + raise ValueError(f"starts_ends must not have size 1") dX = self.alloc2f( int(starts_ends[-1]), d_firsts.shape[1], dtype=d_firsts.dtype, zeros=True ) @@ -1298,8 +1300,8 @@ def backprop_reduce_first( return dX def backprop_reduce_last(self, d_lasts: Floats2d, lasts: Ints1d) -> Floats2d: - if lasts.size < 1: - raise ValueError(f"lasts should least have size 2") + if lasts.size == 0: + return self.alloc2f(0, d_lasts.shape[1], dtype=d_lasts.dtype, zeros=True) dX = self.alloc2f( int(lasts[-1]) + 1, d_lasts.shape[1], dtype=d_lasts.dtype, zeros=True ) diff --git a/thinc/tests/backends/test_ops.py b/thinc/tests/backends/test_ops.py index d5235ecc3..b867b14e4 100644 --- a/thinc/tests/backends/test_ops.py +++ b/thinc/tests/backends/test_ops.py @@ -41,6 +41,20 @@ FLOAT_TYPES = ["float32", "float64"] INT_TYPES = ["int32", "int64"] +REDUCTIONS = ["reduce_first", "reduce_last", "reduce_max", "reduce_mean", "reduce_sum"] + +REDUCE_ZERO_LENGTH_RAISES = [ + ("reduce_first", True), + ("reduce_last", True), + ("reduce_max", True), + # From a mathematical perspective we'd want mean reduction to raise for + # zero-length sequences, since floating point numbers are not a monoid + # under averaging. However, floret relies on reduce_mean to return a + # zero-vector in this case. 
+ ("reduce_mean", False), + ("reduce_sum", False), +] + def create_pytorch_funcs(): import math @@ -1077,6 +1091,71 @@ def test_backprop_reduce_mean(ops, dtype): ) +@pytest.mark.parametrize("ops", ALL_OPS) +@pytest.mark.parametrize("dtype", FLOAT_TYPES) +@pytest.mark.parametrize("reduction", REDUCTIONS) +def test_reduce_empty_batch(ops, dtype, reduction): + func = getattr(ops, reduction) + backprop_func = getattr(ops, f"backprop_{reduction}") + + lengths = ops.asarray1i([]) + Y = func(ops.alloc((0, 10), dtype=dtype), lengths) + + if reduction == "reduce_max": + Y, which = Y + dX = backprop_func(Y, which, lengths) + elif isinstance(Y, tuple): + Y, extra = Y + dX = backprop_func(Y, extra) + else: + dX = backprop_func(Y, lengths) + + assert Y.shape == (0, 10) + assert dX.shape == (0, 10) + + +@pytest.mark.parametrize("ops", ALL_OPS) +@pytest.mark.parametrize("dtype", FLOAT_TYPES) +@pytest.mark.parametrize("reduction", REDUCTIONS) +def test_reduce_empty_hidden(ops, dtype, reduction): + func = getattr(ops, reduction) + backprop_func = getattr(ops, f"backprop_{reduction}") + + lengths = ops.asarray1i([2, 3]) + Y = func(ops.alloc((5, 0), dtype=dtype), lengths) + + if reduction == "reduce_max": + Y, which = Y + dX = backprop_func(Y, which, lengths) + elif isinstance(Y, tuple): + Y, extra = Y + dX = backprop_func(Y, extra) + else: + dX = backprop_func(Y, lengths) + + assert Y.shape == (2, 0) + assert dX.shape == (5, 0) + + +@pytest.mark.parametrize("ops", ALL_OPS) +@pytest.mark.parametrize("dtype", FLOAT_TYPES) +@pytest.mark.parametrize("reduction_raises", REDUCE_ZERO_LENGTH_RAISES) +def test_reduce_zero_seq_length(ops, dtype, reduction_raises): + reduction_str, raises = reduction_raises + reduction = getattr(ops, reduction_str) + X = ops.asarray2f( + [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [1.0, 2.0], [3.0, 4.0]], dtype=dtype + ) + lengths = ops.asarray1i([3, 0, 2]) + + if raises: + with pytest.raises(ValueError): + reduction(X, lengths) + else: + # All non-raising reductions 
have zero as their identity element. + ops.xp.testing.assert_allclose(reduction(X, lengths)[1], [0.0, 0.0]) + + @pytest.mark.parametrize("ops", ALL_OPS) @settings(max_examples=MAX_EXAMPLES, deadline=None) @given(X=strategies.arrays_BI()) diff --git a/thinc/tests/test_config.py b/thinc/tests/test_config.py index fe2118e25..a3f4ede46 100644 --- a/thinc/tests/test_config.py +++ b/thinc/tests/test_config.py @@ -6,7 +6,11 @@ import catalogue import numpy import pytest -from pydantic import BaseModel, PositiveInt, StrictBool, StrictFloat, constr + +try: + from pydantic.v1 import BaseModel, PositiveInt, StrictBool, StrictFloat, constr +except ImportError: + from pydantic import BaseModel, PositiveInt, StrictBool, StrictFloat, constr # type: ignore import thinc.config from thinc.api import Config, Model, NumpyOps, RAdam diff --git a/thinc/tests/test_types.py b/thinc/tests/test_types.py index ebfbb6fb6..738a309f9 100644 --- a/thinc/tests/test_types.py +++ b/thinc/tests/test_types.py @@ -1,6 +1,11 @@ import numpy import pytest -from pydantic import ValidationError, create_model + +try: + from pydantic.v1 import ValidationError, create_model +except ImportError: + from pydantic import ValidationError, create_model # type: ignore + from thinc.types import ( Floats1d, diff --git a/thinc/util.py b/thinc/util.py index ce8fcbb78..5ca928698 100644 --- a/thinc/util.py +++ b/thinc/util.py @@ -24,7 +24,12 @@ import numpy from packaging.version import Version -from pydantic import ValidationError, create_model + +try: + from pydantic.v1 import ValidationError, create_model +except ImportError: + from pydantic import ValidationError, create_model # type: ignore + from wasabi import table from .compat import ( @@ -251,7 +256,6 @@ def to_categorical( *, label_smoothing: float = 0.0, ) -> FloatsXd: - if n_classes is None: n_classes = int(numpy.max(Y) + 1) # type: ignore diff --git a/website/docs/api-backends.md b/website/docs/api-backends.md index c5a54cff8..fc69a775d 100644 --- 
a/website/docs/api-backends.md +++ b/website/docs/api-backends.md @@ -937,9 +937,10 @@ Backpropagate the Swish activation -Dish or "Daniël's Swish-like activation" is an activation function with a non-monotinic shape similar to -[GELU](#gelu), [Swish](#swish) and [Mish](#mish). However, Dish does not rely on -elementary functions like `exp` or `erf`, making it much +Dish or "Daniël's Swish-like activation" is an activation function with a +non-monotonic shape similar to [GELU](#gelu), [Swish](#swish) and [Mish](#mish). +However, Dish does not rely on elementary functions like `exp` or `erf`, making +it much [faster to compute](https://twitter.com/danieldekok/status/1484898130441166853) in most cases. @@ -1264,9 +1265,12 @@ Backpropagate the hard Swish MobileNet activation. -Perform sequence-wise first pooling for data in the ragged format. Zero-length -sequences are not allowed. A `ValueError` is raised if any element in `lengths` -is zero. +Perform sequence-wise first pooling for data in the ragged format. + +- Zero-length sequences are not allowed. A `ValueError` is raised if any element + in `lengths` is zero. +- Batch and hidden dimensions can have a size of zero. In these cases the + corresponding dimensions in the output also have a size of zero. | Argument | Type | Description | | ----------- | ------------------------------- | --------------------------------------------------------------------- | @@ -1302,9 +1306,12 @@ Backpropagate the `reduce_first` operation. -Perform sequence-wise last pooling for data in the ragged format. Zero-length -sequences are not allowed. A `ValueError` is raised if any element in `lengths` -is zero. +Perform sequence-wise last pooling for data in the ragged format. + +- Zero-length sequences are not allowed. A `ValueError` is raised if any element + in `lengths` is zero. +- Batch and hidden dimensions can have a size of zero. In these cases the + corresponding dimensions in the output also have a size of zero. 
| Argument | Type | Description | | ----------- | ------------------------------- | ------------------------------------------------------------------------------- | @@ -1340,8 +1347,11 @@ Backpropagate the `reduce_last` operation. -Perform sequence-wise summation for data in the ragged format. Zero-length -sequences are reduced to the zero vector. +Perform sequence-wise summation for data in the ragged format. + +- Zero-length sequences are reduced to all-zero vectors. +- Batch and hidden dimensions can have a size of zero. In these cases the + corresponding dimensions in the output also have a size of zero. | Argument | Type | Description | | ----------- | ----------------- | ----------------------------- | @@ -1377,8 +1387,11 @@ Backpropagate the `reduce_sum` operation. -Perform sequence-wise averaging for data in the ragged format. Zero-length -sequences are reduced to the zero vector. +Perform sequence-wise averaging for data in the ragged format. + +- Zero-length sequences are reduced to all-zero vectors. +- Batch and hidden dimensions can have a size of zero. In these cases the + corresponding dimensions in the output also have a size of zero. | Argument | Type | Description | | ----------- | ----------------- | --------------------------- | @@ -1415,8 +1428,12 @@ Backpropagate the `reduce_mean` operation. Perform sequence-wise max pooling for data in the ragged format. Zero-length -sequences are not allowed. A `ValueError` is raised if any element in `lengths` -is zero. +sequences are not allowed. + +- A `ValueError` is raised if any element in `lengths` + is zero. +- Batch and hidden dimensions can have a size of zero. In these cases the + corresponding dimensions in the output also have a size of zero. | Argument | Type | Description | | ----------- | -------------------------------- | --------------------------- | @@ -1434,8 +1451,7 @@ is zero. -Backpropagate the `reduce_max` operation. 
A `ValueError` is raised if any -element in `lengths` is zero. +Backpropagate the `reduce_max` operation. | Argument | Type | Description | | ----------- | ----------------- | ------------------------------------------- |