Skip to content

Commit

Permalink
Merge pull request #899 from adrianeboyd/chore/update-develop-from-ma…
Browse files Browse the repository at this point in the history
…ster-v8.2-2

Update develop from master for v8.2
  • Loading branch information
adrianeboyd authored Aug 11, 2023
2 parents 88df8a9 + 027c07c commit 70840bf
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 4 deletions.
2 changes: 1 addition & 1 deletion thinc/backends/cupy_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def asarray(self, data, dtype=None):
elif is_mxnet_gpu_array(data):
array = mxnet2xp(data)
else:
array = self.xp.array(data)
array = self.xp.array(data, dtype=dtype)

if dtype is not None:
array = array.astype(dtype=dtype, copy=False)
Expand Down
2 changes: 1 addition & 1 deletion thinc/backends/numpy_ops.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class NumpyOps(Ops):
elif hasattr(data, "get"):
array = data.get()
else:
array = self.xp.array(data)
array = self.xp.array(data, dtype=dtype)

if dtype is not None:
array = array.astype(dtype=dtype, copy=False)
Expand Down
6 changes: 4 additions & 2 deletions thinc/layers/strings2arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ def strings2arrays() -> Model[InT, OutT]:


def forward(model: Model[InT, OutT], Xs: InT, is_train: bool) -> Tuple[OutT, Callable]:
hashes = [[hash_unicode(word) for word in X] for X in Xs]
hash_arrays = [model.ops.asarray2i(h, dtype="uint64") for h in hashes]
hashes = model.ops.asarray2i(
[[hash_unicode(word) for word in X] for X in Xs], dtype="int32"
)
hash_arrays = [model.ops.asarray1i(h, dtype="uint64") for h in hashes]
arrays = [model.ops.reshape2i(array, -1, 1) for array in hash_arrays]

def backprop(dX: OutT) -> InT:
Expand Down
7 changes: 7 additions & 0 deletions thinc/tests/backends/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1597,3 +1597,10 @@ def test_custom_kernel_compilation():
assert compiled_kernel is not None

assert compile_mmh() is not None


@pytest.mark.parametrize("ops", ALL_OPS)
def test_asarray_from_list_uint64(ops):
# list contains int values both above and below int64.max
uint64_list = [16, 11648197037703959513]
assert uint64_list == list(ops.asarray(uint64_list, dtype="uint64"))

0 comments on commit 70840bf

Please sign in to comment.