Skip to content

Commit

Permalink
A few more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jinlow committed Sep 19, 2023
1 parent a12eef2 commit 45b578b
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 26 deletions.
9 changes: 1 addition & 8 deletions py-forust/forust/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,14 +433,7 @@ def fit(

if sample_weight is None:
sample_weight = np.ones(y_.shape, dtype="float64")
sample_weight_ = (
sample_weight.to_numpy()
if isinstance(sample_weight, pd.Series)
else sample_weight
)

if not np.issubdtype(sample_weight_.dtype, "float64"):
sample_weight_ = sample_weight_.astype("float64", copy=False)
sample_weight_ = _convert_input_array(sample_weight)

# Convert the monotone constraints into the form needed
# by the rust code.
Expand Down
24 changes: 7 additions & 17 deletions py-forust/forust/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,13 @@


class BaseSerializer(ABC, Generic[T]):
def __call__(self, obj: Union[T, str]) -> Union[T, str]:
"""Serializer is callable, if it's a string we are deserializing, anything else we are serializing. For the string serializer, this works as well, because both serialize and deserialize just return itself.
Args:
obj (T | str): Object either to serialize, or deserialize.
Returns:
T | str: Object that is either serialized or deserialized.
"""
if isinstance(obj, str):
return self.deserialize(obj)
else:
return self.serialize(obj)

@abstractmethod
def serialize(self, obj: T) -> str:
...
"""serialize method - should take an object and return a string"""

@abstractmethod
def deserialize(self, obj_repr: str) -> T:
...
"""deserialize method - should take a string and return original object"""


Scaler = Union[int, float, str]
Expand Down Expand Up @@ -80,4 +66,8 @@ def serialize(self, obj: npt.NDArray) -> str:

def deserialize(self, obj_repr: str) -> npt.NDArray:
data = NumpyData(**json.loads(obj_repr))
return np.array(data.array, dtype=data.dtype, shape=data.shape) # type: ignore
a = np.array(data.array, dtype=data.dtype) # type: ignore
if len(data.shape) == 1:
return a
else:
return a.reshape(data.shape)
24 changes: 23 additions & 1 deletion py-forust/tests/test_serailze.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from __future__ import annotations

import numpy as np
import pytest

from forust.serialize import ObjectItem, ObjectSerializer, Scaler, ScalerSerializer
from forust.serialize import (
NumpySerializer,
ObjectItem,
ObjectSerializer,
Scaler,
ScalerSerializer,
)

scaler_values = [
1,
Expand Down Expand Up @@ -36,3 +43,18 @@ def test_object(value: ObjectItem):
r = serializer.serialize(value)
assert isinstance(r, str)
assert value == serializer.deserialize(r)


numpy_values = [
np.array([1.0, 2.23]),
np.array([1, 2, 3, 4, 5, 6]).reshape((2, 3)),
np.array([1, 2, 3, 4, 5, 6], dtype="int").reshape((2, 3)),
]


@pytest.mark.parametrize("value", numpy_values)
def test_numpy(value: np.ndarray):
serializer = NumpySerializer()
r = serializer.serialize(value)
assert isinstance(r, str)
assert np.array_equal(value, serializer.deserialize(r))

0 comments on commit 45b578b

Please sign in to comment.