From 1175a98419f4b613bf7fa36773c425e1a3b92f00 Mon Sep 17 00:00:00 2001
From: Ian Faust
Date: Tue, 26 Nov 2024 12:45:10 +0100
Subject: [PATCH] don't have more time at the moment to do this.

---
 onedal/basic_statistics/basic_statistics.py | 68 ++++++++-------------
 1 file changed, 27 insertions(+), 41 deletions(-)

diff --git a/onedal/basic_statistics/basic_statistics.py b/onedal/basic_statistics/basic_statistics.py
index c60d1599ac..549524a533 100644
--- a/onedal/basic_statistics/basic_statistics.py
+++ b/onedal/basic_statistics/basic_statistics.py
@@ -22,12 +22,14 @@
 from ..common._base import BaseEstimator
 from ..datatypes import _convert_to_supported, from_table, to_table
 from ..utils import _is_csr
-from ..utils.validation import _check_array
 
 
-class BaseBasicStatistics(BaseEstimator, metaclass=ABCMeta):
-    @abstractmethod
-    def __init__(self, result_options, algorithm):
+class BasicStatistics(BaseEstimator, metaclass=ABCMeta):
+    """
+    Basic Statistics oneDAL implementation.
+    """
+
+    def __init__(self, result_options="all", algorithm="by_default"):
         self.options = result_options
         self.algorithm = algorithm
 
@@ -46,62 +48,46 @@ def get_all_result_options():
             "second_order_raw_moment",
         ]
 
-    def _get_result_options(self, options):
-        if options == "all":
-            options = self.get_all_result_options()
-        if isinstance(options, list):
-            options = "|".join(options)
-        assert isinstance(options, str)
-        return options
-
+    @property
+    def options(self):
+        """Requested result options as an iterable of option names."""
+        # the setter splits strings on "|", so "all" is stored as ["all"]
+        if list(self._options) == ["all"]:
+            return self.get_all_result_options()
+        return self._options
+
+    @options.setter
+    def options(self, options):
+        # options always to be an iterable
+        self._options = options.split("|") if isinstance(options, str) else options
+
     def _get_onedal_params(self, is_csr, dtype=np.float32):
-        options = self._get_result_options(self.options)
         return {
             "fptype": dtype,
             "method": "sparse" if is_csr else self.algorithm,
-            "result_option": options,
+            "result_option": "|".join(self.options),
         }
 
-
-class BasicStatistics(BaseBasicStatistics):
-    """
-    Basic Statistics oneDAL implementation.
-    """
-
-    def __init__(self, result_options="all", algorithm="by_default"):
-        super().__init__(result_options, algorithm)
-
     def fit(self, data, sample_weight=None, queue=None):
+        """Compute the requested statistics and store each one as an attribute."""
         policy = self._get_policy(queue, data, sample_weight)
 
         is_csr = _is_csr(data)
 
-        if data is not None and not is_csr:
-            data = _check_array(data, ensure_2d=False)
-        if sample_weight is not None:
-            sample_weight = _check_array(sample_weight, ensure_2d=False)
-
         data, sample_weight = _convert_to_supported(policy, data, sample_weight)
         is_single_dim = data.ndim == 1
         data_table, weights_table = to_table(data, sample_weight)
 
-        dtype = data.dtype
-        raw_result = self._compute_raw(data_table, weights_table, policy, dtype, is_csr)
-        for opt, raw_value in raw_result.items():
-            value = from_table(raw_value).ravel()
+        module = self._get_backend("basic_statistics")
+        params = self._get_onedal_params(is_csr, data_table.dtype)
+        result = module.compute(policy, params, data_table, weights_table)
+
+        for opt in self.options:
+            value = from_table(getattr(result, opt)).ravel()
             if is_single_dim:
                 setattr(self, opt, value[0])
             else:
                 setattr(self, opt, value)
 
         return self
 
-    def _compute_raw(
-        self, data_table, weights_table, policy, dtype=np.float32, is_csr=False
-    ):
-        module = self._get_backend("basic_statistics")
-        params = self._get_onedal_params(is_csr, dtype)
-        result = module.compute(policy, params, data_table, weights_table)
-        options = self._get_result_options(self.options).split("|")
-
-        return {opt: getattr(result, opt) for opt in options}