From ce9b1c5b6c1b59ce13722a645afa9b57fce244b1 Mon Sep 17 00:00:00 2001 From: Andreas Huber Date: Fri, 22 Nov 2024 06:20:41 -0800 Subject: [PATCH] Use hparams in predict_proba --- onedal/ensemble/forest.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/onedal/ensemble/forest.py b/onedal/ensemble/forest.py index d1d3c9849a..f79bc55285 100644 --- a/onedal/ensemble/forest.py +++ b/onedal/ensemble/forest.py @@ -364,7 +364,7 @@ def _predict(self, X, module, queue, hparams=None): y = from_table(result.responses) return y - def _predict_proba(self, X, module, queue): + def _predict_proba(self, X, module, queue, hparams=None): _check_is_fitted(self) X = _check_array( X, dtype=[np.float64, np.float32], force_all_finite=True, accept_sparse=False @@ -376,7 +376,11 @@ def _predict_proba(self, X, module, queue): params["infer_mode"] = "class_probabilities" model = self._onedal_model - result = module.infer(policy, params, model, to_table(X)) + if hparams is not None and not hparams.is_default: + result = module.infer(policy, params, hparams.backend, model, to_table(X)) + else: + result = module.infer(policy, params, model, to_table(X)) + y = from_table(result.probabilities) return y @@ -472,8 +476,13 @@ def predict(self, X, queue=None): return np.take(self.classes_, pred.ravel().astype(np.int64, casting="unsafe")) def predict_proba(self, X, queue=None): + hparams = get_hyperparameters("decision_forest", "infer") + return super()._predict_proba( - X, self._get_backend("decision_forest", "classification", None), queue + X, + self._get_backend("decision_forest", "classification", None), + queue, + hparams, )