Skip to content

Commit

Permalink
Merge pull request #895 from adrianeboyd/chore/update-develop-from-ma…
Browse files Browse the repository at this point in the history
…ster-v8.2

Update develop from master for v8.2
  • Loading branch information
adrianeboyd authored Aug 9, 2023
2 parents 4db3879 + 45e1bed commit 88df8a9
Show file tree
Hide file tree
Showing 9 changed files with 170 additions and 41 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ catalogue>=2.0.4,<2.1.0
confection>=0.0.1,<1.0.0
ml_datasets>=0.2.0,<0.3.0; python_version < "3.11"
# Third-party dependencies
pydantic>=1.7.4,!=1.8,!=1.8.1,<1.11.0
pydantic>=1.7.4,!=1.8,!=1.8.1,<3.0.0
numpy>=1.15.0; python_version < "3.9"
numpy>=1.19.0; python_version >= "3.9"
packaging>=20.0
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ install_requires =
setuptools
numpy>=1.15.0; python_version < "3.9"
numpy>=1.19.0; python_version >= "3.9"
pydantic>=1.7.4,!=1.8,!=1.8.1,<1.11.0
pydantic>=1.7.4,!=1.8,!=1.8.1,<3.0.0
packaging>=20.0
# Backports of modern Python features
dataclasses>=0.6,<1.0; python_version < "3.7"
Expand Down
47 changes: 33 additions & 14 deletions thinc/backends/numpy_ops.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,11 @@ class NumpyOps(Ops):
cdef int O = X.shape[1]
cdef int T = X.shape[0]

assert B != 0
assert O != 0
if B == 0 or O == 0:
if reals2d_ft is float2d_t:
return numpy.zeros(shape=(B, O), dtype="float32")
else:
return numpy.zeros(shape=(B, O), dtype="float64")

cdef np.ndarray means
if reals2d_ft is float2d_t:
Expand All @@ -340,8 +343,11 @@ class NumpyOps(Ops):
raise ValueError(f"all sequence lengths must be >= 0, got {length}")
T += length

assert T != 0
assert O != 0
if T == 0 or O == 0:
if reals2d_ft is float2d_t:
return numpy.zeros(shape=(T, O), dtype="float32")
else:
return numpy.zeros(shape=(T, O), dtype="float64")

cdef np.ndarray dX
if reals2d_ft is float2d_t:
Expand All @@ -358,8 +364,11 @@ class NumpyOps(Ops):
cdef int O = X.shape[1]
cdef int T = X.shape[0]

assert B != 0
assert O != 0
if B == 0 or O == 0:
if reals2d_ft is float2d_t:
return numpy.zeros(shape=(B, O), dtype="float32")
else:
return numpy.zeros(shape=(B, O), dtype="float64")

cdef np.ndarray sums
if reals2d_ft is float2d_t:
Expand All @@ -381,8 +390,11 @@ class NumpyOps(Ops):
raise ValueError(f"all sequence lengths must be >= 0, got {length}")
T += length

assert T != 0
assert O != 0
if T == 0 or O == 0:
if reals2d_ft is float2d_t:
return numpy.zeros(shape=(T, O), dtype="float32")
else:
return numpy.zeros(shape=(T, O), dtype="float64")

cdef np.ndarray dX
if reals2d_ft is float2d_t:
Expand All @@ -399,12 +411,16 @@ class NumpyOps(Ops):
cdef int O = X.shape[1]
cdef int T = X.shape[0]

assert B != 0
assert O != 0

cdef np.ndarray maxes
# Needs to be zero-initialized as we start by assuming that the first element is the max value.
cdef np.ndarray which = self.alloc(shape=(B, O), dtype="i", zeros=True)

if B == 0 or O == 0:
if reals2d_ft is float2d_t:
return numpy.zeros(shape=(B, O), dtype="float32"), which
else:
return numpy.zeros(shape=(B, O), dtype="float64"), which

cdef np.ndarray maxes
if reals2d_ft is float2d_t:
maxes = self.alloc(shape=(B, O), dtype="float32", zeros=False)
cpu_reduce_max(<float*>maxes.data, <int*>which.data, &X[0, 0], &lengths[0], B, T, O)
Expand All @@ -424,8 +440,11 @@ class NumpyOps(Ops):
raise ValueError(f"all sequence lengths must be > 0, got {length}")
T += length

assert T != 0
assert O != 0
if T == 0 or O == 0:
if reals2d_ft is float2d_t:
return numpy.zeros(shape=(T, O), dtype="float32")
else:
return numpy.zeros(shape=(T, O), dtype="float64")

cdef np.ndarray dX
if reals2d_ft is float2d_t:
Expand Down
10 changes: 6 additions & 4 deletions thinc/backends/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1289,17 +1289,19 @@ def reduce_max(self, X: Floats2d, lengths: Ints1d) -> Tuple[Floats2d, Ints2d]:
def backprop_reduce_first(
self, d_firsts: Floats2d, starts_ends: Ints1d
) -> Floats2d:
if starts_ends.size < 2:
raise ValueError(f"starts_ends should least have size 2")
if starts_ends.size == 0:
return self.alloc2f(0, d_firsts.shape[1], dtype=d_firsts.dtype, zeros=True)
elif starts_ends.size == 1:
raise ValueError(f"starts_ends must not have size 1")
dX = self.alloc2f(
int(starts_ends[-1]), d_firsts.shape[1], dtype=d_firsts.dtype, zeros=True
)
dX[starts_ends[:-1]] = d_firsts
return dX

def backprop_reduce_last(self, d_lasts: Floats2d, lasts: Ints1d) -> Floats2d:
if lasts.size < 1:
raise ValueError(f"lasts should least have size 2")
if lasts.size == 0:
return self.alloc2f(0, d_lasts.shape[1], dtype=d_lasts.dtype, zeros=True)
dX = self.alloc2f(
int(lasts[-1]) + 1, d_lasts.shape[1], dtype=d_lasts.dtype, zeros=True
)
Expand Down
79 changes: 79 additions & 0 deletions thinc/tests/backends/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,20 @@
FLOAT_TYPES = ["float32", "float64"]
INT_TYPES = ["int32", "int64"]

# Names of the ragged reduction ops that the parametrized reduction tests look
# up on the ops object (together with their "backprop_<name>" counterparts).
REDUCTIONS = ["reduce_first", "reduce_last", "reduce_max", "reduce_mean", "reduce_sum"]

# (reduction name, raises) pairs consumed by test_reduce_zero_seq_length: when
# the flag is True the reduction is expected to raise ValueError for a
# zero-length sequence, otherwise it is expected to produce a zero vector.
REDUCE_ZERO_LENGTH_RAISES = [
    ("reduce_first", True),
    ("reduce_last", True),
    ("reduce_max", True),
    # From a mathematical perspective we'd want mean reduction to raise for
    # zero-length sequences, since floating point numbers are not a monoid
    # under averaging. However, floret relies on reduce_mean to return a
    # zero-vector in this case.
    ("reduce_mean", False),
    ("reduce_sum", False),
]


def create_pytorch_funcs():
import math
Expand Down Expand Up @@ -1077,6 +1091,71 @@ def test_backprop_reduce_mean(ops, dtype):
)


@pytest.mark.parametrize("ops", ALL_OPS)
@pytest.mark.parametrize("dtype", FLOAT_TYPES)
@pytest.mark.parametrize("reduction", REDUCTIONS)
def test_reduce_empty_batch(ops, dtype, reduction):
    """Check that every reduction accepts an empty batch.

    The input has zero rows and ten columns and the lengths array is empty.
    Both the reduced output and the gradient produced by the matching
    ``backprop_<reduction>`` op must have shape ``(0, 10)``.
    """
    forward = getattr(ops, reduction)
    backward = getattr(ops, f"backprop_{reduction}")

    no_lengths = ops.asarray1i([])
    empty_input = ops.alloc((0, 10), dtype=dtype)
    result = forward(empty_input, no_lengths)

    # Work out what the backprop op needs: reduce_max returns the output plus
    # an extra array and its backprop also takes the lengths; other reductions
    # that return a tuple hand their second element to backprop; the rest
    # return the output alone and backprop from the lengths.
    if reduction == "reduce_max":
        pooled, which = result
        backprop_args = (pooled, which, no_lengths)
    elif isinstance(result, tuple):
        pooled, extra = result
        backprop_args = (pooled, extra)
    else:
        pooled = result
        backprop_args = (pooled, no_lengths)
    grad = backward(*backprop_args)

    assert pooled.shape == (0, 10)
    assert grad.shape == (0, 10)


@pytest.mark.parametrize("ops", ALL_OPS)
@pytest.mark.parametrize("dtype", FLOAT_TYPES)
@pytest.mark.parametrize("reduction", REDUCTIONS)
def test_reduce_empty_hidden(ops, dtype, reduction):
    """Check that every reduction accepts a zero-width hidden dimension.

    The input has five rows and zero columns, split into sequences of length
    two and three. The reduced output must have shape ``(2, 0)`` and the
    gradient from the matching ``backprop_<reduction>`` op shape ``(5, 0)``.
    """
    forward = getattr(ops, reduction)
    backward = getattr(ops, f"backprop_{reduction}")

    seq_lengths = ops.asarray1i([2, 3])
    zero_width_input = ops.alloc((5, 0), dtype=dtype)
    result = forward(zero_width_input, seq_lengths)

    # Work out what the backprop op needs: reduce_max returns the output plus
    # an extra array and its backprop also takes the lengths; other reductions
    # that return a tuple hand their second element to backprop; the rest
    # return the output alone and backprop from the lengths.
    if reduction == "reduce_max":
        pooled, which = result
        backprop_args = (pooled, which, seq_lengths)
    elif isinstance(result, tuple):
        pooled, extra = result
        backprop_args = (pooled, extra)
    else:
        pooled = result
        backprop_args = (pooled, seq_lengths)
    grad = backward(*backprop_args)

    assert pooled.shape == (2, 0)
    assert grad.shape == (5, 0)


@pytest.mark.parametrize("ops", ALL_OPS)
@pytest.mark.parametrize("dtype", FLOAT_TYPES)
@pytest.mark.parametrize("reduction_raises", REDUCE_ZERO_LENGTH_RAISES)
def test_reduce_zero_seq_length(ops, dtype, reduction_raises):
    """Check how each reduction treats a zero-length sequence in the batch.

    ``reduction_raises`` is a ``(name, raises)`` pair. The batch holds three
    sequences with lengths 3, 0 and 2. When ``raises`` is true the reduction
    must raise ``ValueError``; otherwise the row produced for the empty
    sequence must be all zeros.
    """
    name, expect_error = reduction_raises
    reduce_op = getattr(ops, name)

    rows = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [1.0, 2.0], [3.0, 4.0]]
    X = ops.asarray2f(rows, dtype=dtype)
    lengths = ops.asarray1i([3, 0, 2])

    if not expect_error:
        # All non-raising reductions have zero as their identity element.
        ops.xp.testing.assert_allclose(reduce_op(X, lengths)[1], [0.0, 0.0])
        return

    with pytest.raises(ValueError):
        reduce_op(X, lengths)


@pytest.mark.parametrize("ops", ALL_OPS)
@settings(max_examples=MAX_EXAMPLES, deadline=None)
@given(X=strategies.arrays_BI())
Expand Down
6 changes: 5 additions & 1 deletion thinc/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
import catalogue
import numpy
import pytest
from pydantic import BaseModel, PositiveInt, StrictBool, StrictFloat, constr

try:
from pydantic.v1 import BaseModel, PositiveInt, StrictBool, StrictFloat, constr
except ImportError:
from pydantic import BaseModel, PositiveInt, StrictBool, StrictFloat, constr # type: ignore

import thinc.config
from thinc.api import Config, Model, NumpyOps, RAdam
Expand Down
7 changes: 6 additions & 1 deletion thinc/tests/test_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import numpy
import pytest
from pydantic import ValidationError, create_model

try:
from pydantic.v1 import ValidationError, create_model
except ImportError:
from pydantic import ValidationError, create_model # type: ignore


from thinc.types import (
Floats1d,
Expand Down
8 changes: 6 additions & 2 deletions thinc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@

import numpy
from packaging.version import Version
from pydantic import ValidationError, create_model

try:
from pydantic.v1 import ValidationError, create_model
except ImportError:
from pydantic import ValidationError, create_model # type: ignore

from wasabi import table

from .compat import (
Expand Down Expand Up @@ -251,7 +256,6 @@ def to_categorical(
*,
label_smoothing: float = 0.0,
) -> FloatsXd:

if n_classes is None:
n_classes = int(numpy.max(Y) + 1) # type: ignore

Expand Down
50 changes: 33 additions & 17 deletions website/docs/api-backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -937,9 +937,10 @@ Backpropagate the Swish activation

</inline-list>

Dish or "Daniël's Swish-like activation" is an activation function with a non-monotinic shape similar to
[GELU](#gelu), [Swish](#swish) and [Mish](#mish). However, Dish does not rely on
elementary functions like `exp` or `erf`, making it much
Dish or "Daniël's Swish-like activation" is an activation function with a
non-monotonic shape similar to [GELU](#gelu), [Swish](#swish) and [Mish](#mish).
However, Dish does not rely on elementary functions like `exp` or `erf`, making
it much
[faster to compute](https://twitter.com/danieldekok/status/1484898130441166853)
in most cases.

Expand Down Expand Up @@ -1264,9 +1265,12 @@ Backpropagate the hard Swish MobileNet activation.

</inline-list>

Perform sequence-wise first pooling for data in the ragged format. Zero-length
sequences are not allowed. A `ValueError` is raised if any element in `lengths`
is zero.
Perform sequence-wise first pooling for data in the ragged format.

- Zero-length sequences are not allowed. A `ValueError` is raised if any element
in `lengths` is zero.
- Batch and hidden dimensions can have a size of zero. In these cases the
corresponding dimensions in the output also have a size of zero.

| Argument | Type | Description |
| ----------- | ------------------------------- | --------------------------------------------------------------------- |
Expand Down Expand Up @@ -1302,9 +1306,12 @@ Backpropagate the `reduce_first` operation.

</inline-list>

Perform sequence-wise last pooling for data in the ragged format. Zero-length
sequences are not allowed. A `ValueError` is raised if any element in `lengths`
is zero.
Perform sequence-wise last pooling for data in the ragged format.

- Zero-length sequences are not allowed. A `ValueError` is raised if any element
in `lengths` is zero.
- Batch and hidden dimensions can have a size of zero. In these cases the
corresponding dimensions in the output also have a size of zero.

| Argument | Type | Description |
| ----------- | ------------------------------- | ------------------------------------------------------------------------------- |
Expand Down Expand Up @@ -1340,8 +1347,11 @@ Backpropagate the `reduce_last` operation.

</inline-list>

Perform sequence-wise summation for data in the ragged format. Zero-length
sequences are reduced to the zero vector.
Perform sequence-wise summation for data in the ragged format.

- Zero-length sequences are reduced to all-zero vectors.
- Batch and hidden dimensions can have a size of zero. In these cases the
corresponding dimensions in the output also have a size of zero.

| Argument | Type | Description |
| ----------- | ----------------- | ----------------------------- |
Expand Down Expand Up @@ -1377,8 +1387,11 @@ Backpropagate the `reduce_sum` operation.

</inline-list>

Perform sequence-wise averaging for data in the ragged format. Zero-length
sequences are reduced to the zero vector.
Perform sequence-wise averaging for data in the ragged format.

- Zero-length sequences are reduced to all-zero vectors.
- Batch and hidden dimensions can have a size of zero. In these cases the
corresponding dimensions in the output also have a size of zero.

| Argument | Type | Description |
| ----------- | ----------------- | --------------------------- |
Expand Down Expand Up @@ -1415,8 +1428,12 @@ Backpropagate the `reduce_mean` operation.
</inline-list>

Perform sequence-wise max pooling for data in the ragged format. Zero-length
sequences are not allowed. A `ValueError` is raised if any element in `lengths`
is zero.
sequences are not allowed.

- Zero-length sequences are not allowed. A `ValueError` is raised if any element
in `lengths` is zero.
- Batch and hidden dimensions can have a size of zero. In these cases the
corresponding dimensions in the output also have a size of zero.

| Argument | Type | Description |
| ----------- | -------------------------------- | --------------------------- |
Expand All @@ -1434,8 +1451,7 @@ is zero.

</inline-list>

Backpropagate the `reduce_max` operation. A `ValueError` is raised if any
element in `lengths` is zero.
Backpropagate the `reduce_max` operation.

| Argument | Type | Description |
| ----------- | ----------------- | ------------------------------------------- |
Expand Down

0 comments on commit 88df8a9

Please sign in to comment.