From ad452fa8ea7d21cdbe8df478af4156e0ce2c79bf Mon Sep 17 00:00:00 2001 From: Cedric Kulbach Date: Fri, 23 Dec 2022 14:10:59 +0100 Subject: [PATCH] refactor predict_proba_many and tensor_conversion.py --- .../classification/rolling_classifier.py | 2 +- deep_river/utils/tensor_conversion.py | 6 ++-- deep_river/utils/test_tensor_conversion.py | 28 +++++++++++-------- 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/deep_river/classification/rolling_classifier.py b/deep_river/classification/rolling_classifier.py index 52a1d910..d07d2578 100644 --- a/deep_river/classification/rolling_classifier.py +++ b/deep_river/classification/rolling_classifier.py @@ -333,7 +333,7 @@ def predict_proba_many(self, X: pd.DataFrame) -> pd.DataFrame: probas = [default_proba] * len(X) return pd.DataFrame(probas) - def _get_default_proba(self)-> List[Dict[ClfTarget, float]]: + def _get_default_proba(self) -> List[Dict[ClfTarget, float]]: if len(self.observed_classes) > 0: mean_proba = ( 1 / len(self.observed_classes) diff --git a/deep_river/utils/tensor_conversion.py b/deep_river/utils/tensor_conversion.py index e61a233d..00babeda 100644 --- a/deep_river/utils/tensor_conversion.py +++ b/deep_river/utils/tensor_conversion.py @@ -1,4 +1,4 @@ -from typing import Collection, Deque, Dict, Optional, Union +from typing import Deque, Dict, List, Optional, Union import numpy as np import pandas as pd @@ -146,7 +146,7 @@ def labels2onehot( def output2proba( preds: torch.Tensor, classes: OrderedSet, with_logits=False -) -> Collection[Dict[ClfTarget, float]]: +) -> List[Dict[ClfTarget, float]]: if with_logits: if preds.shape[-1] >= 1: preds = torch.softmax(preds, dim=-1) @@ -168,4 +168,4 @@ def output2proba( if preds_np.shape[0] == 1 else [dict(zip(classes, pred)) for pred in preds_np] ) - return [probas] if isinstance(probas, dict) else probas + return [probas] if isinstance(probas, dict) else list(probas) diff --git a/deep_river/utils/test_tensor_conversion.py b/deep_river/utils/test_tensor_conversion.py index 7b18197f..fff4637e 100644 --- a/deep_river/utils/test_tensor_conversion.py +++ b/deep_river/utils/test_tensor_conversion.py @@ -80,7 +80,9 @@ def test_output2proba(): def assert_dicts_almost_equal(d1, d2): for i in range(len(d1)): for k in d1[i]: - assert np.isclose(d1[i][k], d2[i][k]), f"{d1[i][k]} != {d2[i][k]}" + assert np.isclose( + d1[i][k], d2[i][k] + ), f"{d1[i][k]} != {d2[i][k]}" y = torch.tensor([[0.1, 0.2, 0.7]]) classes = ["first class", "second class", "third class"] @@ -92,20 +94,24 @@ def assert_dicts_almost_equal(d1, d2): classes = ["first class"] assert_dicts_almost_equal( output2proba(y, classes), - [dict( - zip( - ["first class", "unobserved 0"], - np.array([0.6, 0.4], dtype=np.float32), + [ + dict( + zip( + ["first class", "unobserved 0"], + np.array([0.6, 0.4], dtype=np.float32), + ) ) - )], + ], ) y = torch.tensor([[0.6, 0.4, 0.0]]) assert_dicts_almost_equal( output2proba(y, classes), - [dict( - zip( - ["first class", "unobserved 0", "unobserved 1"], - np.array([0.6, 0.4, 0.0], dtype=np.float32), + [ + dict( + zip( + ["first class", "unobserved 0", "unobserved 1"], + np.array([0.6, 0.4, 0.0], dtype=np.float32), + ) ) - )] + ], )