Skip to content

Commit

Permalink
don't have more time at the moment to do this.
Browse files Browse the repository at this point in the history
  • Loading branch information
icfaust authored Nov 26, 2024
1 parent 527ce22 commit 1175a98
Showing 1 changed file with 25 additions and 41 deletions.
66 changes: 25 additions & 41 deletions onedal/basic_statistics/basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@
from ..common._base import BaseEstimator
from ..datatypes import _convert_to_supported, from_table, to_table
from ..utils import _is_csr
from ..utils.validation import _check_array


class BaseBasicStatistics(BaseEstimator, metaclass=ABCMeta):
class BasicStatistics(BaseEstimator, metaclass=ABCMeta):
"""
Basic Statistics oneDAL implementation.
"""
@abstractmethod
def __init__(self, result_options, algorithm):
def __init__(self, result_options="all", algorithm="by_default"):
self.options = result_options
self.algorithm = algorithm

Expand All @@ -46,62 +48,44 @@ def get_all_result_options():
"second_order_raw_moment",
]

def _get_result_options(self, options):
if options == "all":
options = self.get_all_result_options()
if isinstance(options, list):
options = "|".join(options)
assert isinstance(options, str)
return options

@property
def options(self):
    """Return the requested result options as an iterable of option names.

    The setter splits string input on "|", so the "all" shorthand is
    stored as ``["all"]`` rather than the bare string. Both spellings are
    expanded here to the full list from ``get_all_result_options()``;
    any other value is returned as stored.
    """
    stored = self._options
    # Comparing only against the string "all" never matches after the
    # setter has split it, which would leak the literal "all" downstream.
    is_all = stored == "all" or (
        isinstance(stored, (list, tuple)) and list(stored) == ["all"]
    )
    if is_all:
        return self.get_all_result_options()
    return stored

@options.setter
def options(self, options):
    """Store the result options; a "|"-separated string is split into a list."""
    # options always to be an iterable
    self._options = options.split("|") if isinstance(options, str) else options

def _get_onedal_params(self, is_csr, dtype=np.float32):
options = self._get_result_options(self.options)
return {
"fptype": dtype,
"method": "sparse" if is_csr else self.algorithm,
"result_option": options,
"result_option": "|".join(self.options),
}


class BasicStatistics(BaseBasicStatistics):
"""
Basic Statistics oneDAL implementation.
"""

def __init__(self, result_options="all", algorithm="by_default"):
super().__init__(result_options, algorithm)

def fit(self, data, sample_weight=None, queue=None):
policy = self._get_policy(queue, data, sample_weight)

is_csr = _is_csr(data)

if data is not None and not is_csr:
data = _check_array(data, ensure_2d=False)
if sample_weight is not None:
sample_weight = _check_array(sample_weight, ensure_2d=False)

data, sample_weight = _convert_to_supported(policy, data, sample_weight)
is_single_dim = data.ndim == 1
data_table, weights_table = to_table(data, sample_weight)

dtype = data.dtype
raw_result = self._compute_raw(data_table, weights_table, policy, dtype, is_csr)
for opt, raw_value in raw_result.items():
value = from_table(raw_value).ravel()
dtype = data_table.dtype
module = self._get_backend("basic_statistics")
params = self._get_onedal_params(is_csr, data_table.dtype)
result = module.compute(policy, params, data_table, weights_table)

for opt in self.options:
value = from_table(getattr(result, opt)).ravel()
if is_single_dim:
setattr(self, opt, value[0])
setattr(self, getattr(raw_result, opt), value[0])
else:
setattr(self, opt, value)

return self

def _compute_raw(
self, data_table, weights_table, policy, dtype=np.float32, is_csr=False
):
module = self._get_backend("basic_statistics")
params = self._get_onedal_params(is_csr, dtype)
result = module.compute(policy, params, data_table, weights_table)
options = self._get_result_options(self.options).split("|")

return {opt: getattr(result, opt) for opt in options}

0 comments on commit 1175a98

Please sign in to comment.