Skip to content

Commit

Permalink
remove old code
Browse files Browse the repository at this point in the history
  • Loading branch information
icfaust committed Nov 28, 2024
1 parent 1175a98 commit 50ba766
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 33 deletions.
10 changes: 5 additions & 5 deletions onedal/basic_statistics/basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class BasicStatistics(BaseEstimator, metaclass=ABCMeta):
"""
Basic Statistics oneDAL implementation.
"""

@abstractmethod
def __init__(self, result_options="all", algorithm="by_default"):
self.options = result_options
Expand Down Expand Up @@ -58,7 +59,7 @@ def options(self):
def options(self, options):
# options always to be an iterable
self._options = options.split("|") if isinstance(options, str) else options

def _get_onedal_params(self, is_csr, dtype=np.float32):
return {
"fptype": dtype,
Expand All @@ -79,13 +80,12 @@ def fit(self, data, sample_weight=None, queue=None):
module = self._get_backend("basic_statistics")
params = self._get_onedal_params(is_csr, data_table.dtype)
result = module.compute(policy, params, data_table, weights_table)

for opt in self.options:
value = from_table(getattr(result, opt)).ravel()
value = from_table(getattr(result, opt))[:, 0] # two-dimensional table [n, 1]
if is_single_dim:
setattr(self, getattr(raw_result, opt), value[0])
setattr(self, opt, value[0])
else:
setattr(self, opt, value)

return self

40 changes: 12 additions & 28 deletions onedal/basic_statistics/incremental_basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,11 @@

import numpy as np

from daal4py.sklearn._utils import get_dtype

from ..datatypes import _convert_to_supported, from_table, to_table
from ..utils import _check_array
from .basic_statistics import BaseBasicStatistics
from .basic_statistics import BasicStatistics


class IncrementalBasicStatistics(BaseBasicStatistics):
class IncrementalBasicStatistics(BasicStatistics):
"""
Incremental estimator for basic statistics based on oneDAL implementation.
Allows to compute basic statistics if data are splitted into batches.
Expand Down Expand Up @@ -65,16 +62,16 @@ class IncrementalBasicStatistics(BaseBasicStatistics):
Second order moment of each feature over all samples.
"""

def __init__(self, result_options="all"):
super().__init__(result_options, algorithm="by_default")
def __init__(self, result_options="all", algorithm="by_default"):
super().__init__(result_options, algorithm)
self._reset()

def _reset(self):
self._partial_result = self._get_backend(
"basic_statistics", None, "partial_compute_result"
)

def partial_fit(self, X, weights=None, queue=None):
def partial_fit(self, X, sample_weight=None, queue=None):
"""
Computes partial data for basic statistics
from data batch X and saves it to `_partial_result`.
Expand All @@ -95,33 +92,20 @@ def partial_fit(self, X, weights=None, queue=None):
"""
self._queue = queue
policy = self._get_policy(queue, X)
X, weights = _convert_to_supported(policy, X, weights)

X = _check_array(
X, dtype=[np.float64, np.float32], ensure_2d=False, force_all_finite=False
)
if weights is not None:
weights = _check_array(
weights,
dtype=[np.float64, np.float32],
ensure_2d=False,
force_all_finite=False,
)
X, sample_weight = to_table(_convert_to_supported(policy, X, sample_weight))

if not hasattr(self, "_onedal_params"):
dtype = get_dtype(X)
self._onedal_params = self._get_onedal_params(False, dtype=dtype)
self._onedal_params = self._get_onedal_params(False, dtype=X.dtype)

X_table, weights_table = to_table(X, weights)
self._partial_result = self._get_backend(
"basic_statistics",
None,
"partial_compute",
policy,
self._onedal_params,
self._partial_result,
X_table,
weights_table,
X,
sample_weight,
)

def finalize_fit(self, queue=None):
Expand Down Expand Up @@ -153,8 +137,8 @@ def finalize_fit(self, queue=None):
self._onedal_params,
self._partial_result,
)
options = self._get_result_options(self.options).split("|")
for opt in options:
setattr(self, opt, from_table(getattr(result, opt)).ravel())

for opt in self.options:
setattr(self, opt, from_table(getattr(result, opt))[:, 0])

return self

0 comments on commit 50ba766

Please sign in to comment.