From 45b578bc7b0fe65c809d8d2c2625d2238718689a Mon Sep 17 00:00:00 2001 From: jinlow Date: Tue, 19 Sep 2023 11:51:26 -0500 Subject: [PATCH] A few more tests --- py-forust/forust/__init__.py | 9 +-------- py-forust/forust/serialize.py | 24 +++++++----------------- py-forust/tests/test_serailze.py | 24 +++++++++++++++++++++++- 3 files changed, 31 insertions(+), 26 deletions(-) diff --git a/py-forust/forust/__init__.py b/py-forust/forust/__init__.py index a5136d4..e994aa3 100644 --- a/py-forust/forust/__init__.py +++ b/py-forust/forust/__init__.py @@ -433,14 +433,7 @@ def fit( if sample_weight is None: sample_weight = np.ones(y_.shape, dtype="float64") - sample_weight_ = ( - sample_weight.to_numpy() - if isinstance(sample_weight, pd.Series) - else sample_weight - ) - - if not np.issubdtype(sample_weight_.dtype, "float64"): - sample_weight_ = sample_weight_.astype("float64", copy=False) + sample_weight_ = _convert_input_array(sample_weight) # Convert the monotone constraints into the form needed # by the rust code. diff --git a/py-forust/forust/serialize.py b/py-forust/forust/serialize.py index 29b14d8..f55276c 100644 --- a/py-forust/forust/serialize.py +++ b/py-forust/forust/serialize.py @@ -13,27 +13,13 @@ class BaseSerializer(ABC, Generic[T]): - def __call__(self, obj: Union[T, str]) -> Union[T, str]: - """Serializer is callable, if it's a string we are deserializing, anything else we are serializing. For the string serializer, this works as well, because both serialize and deserialize just return itself. - - Args: - obj (T | str): Object either to serialize, or deserialize. - - Returns: - T | str: Object that is either serialized or deserialized. - """ - if isinstance(obj, str): - return self.deserialize(obj) - else: - return self.serialize(obj) - @abstractmethod def serialize(self, obj: T) -> str: - ... + """serialize method - should take an object and return a string""" @abstractmethod def deserialize(self, obj_repr: str) -> T: - ... + """deserialize method - should take a string and return original object""" Scaler = Union[int, float, str] @@ -80,4 +66,8 @@ def serialize(self, obj: npt.NDArray) -> str: def deserialize(self, obj_repr: str) -> npt.NDArray: data = NumpyData(**json.loads(obj_repr)) - return np.array(data.array, dtype=data.dtype, shape=data.shape) # type: ignore + a = np.array(data.array, dtype=data.dtype) # type: ignore + if len(data.shape) == 1: + return a + else: + return a.reshape(data.shape) diff --git a/py-forust/tests/test_serailze.py b/py-forust/tests/test_serailze.py index e8eceec..bd3c6ee 100644 --- a/py-forust/tests/test_serailze.py +++ b/py-forust/tests/test_serailze.py @@ -1,8 +1,15 @@ from __future__ import annotations +import numpy as np import pytest -from forust.serialize import ObjectItem, ObjectSerializer, Scaler, ScalerSerializer +from forust.serialize import ( + NumpySerializer, + ObjectItem, + ObjectSerializer, + Scaler, + ScalerSerializer, +) scaler_values = [ 1, @@ -36,3 +43,18 @@ def test_object(value: ObjectItem): r = serializer.serialize(value) assert isinstance(r, str) assert value == serializer.deserialize(r) + + +numpy_values = [ + np.array([1.0, 2.23]), + np.array([1, 2, 3, 4, 5, 6]).reshape((2, 3)), + np.array([1, 2, 3, 4, 5, 6], dtype="int").reshape((2, 3)), +] + + +@pytest.mark.parametrize("value", numpy_values) +def test_numpy(value: np.ndarray): + serializer = NumpySerializer() + r = serializer.serialize(value) + assert isinstance(r, str) + assert np.array_equal(value, serializer.deserialize(r))