Skip to content

Commit

Permalink
some update and fix
Browse files Browse the repository at this point in the history
use inheritance ans skip check_methods_subset_invariance
  • Loading branch information
tonylee2016 committed Jan 8, 2021
1 parent 7a6acb4 commit 28ea436
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 41 deletions.
4 changes: 2 additions & 2 deletions tslearn/preprocessing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
The :mod:`tslearn.preprocessing` module gathers time series scalers and
The :mod:`tslearn.preprocessing` module gathers time series scalers and
resamplers.
"""

Expand All @@ -14,5 +14,5 @@
"TimeSeriesResampler",
"TimeSeriesScalerMinMax",
"TimeSeriesScalerMeanVariance",
"TimeSeriesScaleMeanMaxVariance",
"TimeSeriesScaleMeanMaxVariance"
]
44 changes: 5 additions & 39 deletions tslearn/preprocessing/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def _more_tags(self):
return {'allow_nan': True}


class TimeSeriesScaleMeanMaxVariance(TransformerMixin, TimeSeriesBaseEstimator):
class TimeSeriesScaleMeanMaxVariance(TimeSeriesScalerMeanVariance):
"""Scaler for time series. Scales time series so that their mean (resp.
standard deviation) in the signal with the max amplitue is
mu (resp. std). The scaling relationships between each signal are preserved
Expand All @@ -318,43 +318,6 @@ class TimeSeriesScaleMeanMaxVariance(TransformerMixin, TimeSeriesBaseEstimator):
NaNs within a time series are ignored when calculating mu and std.
"""

def __init__(self, mu=0., std=1.):
self.mu = mu
self.std = std

def fit(self, X, y=None, **kwargs):
"""A dummy method such that it complies to the sklearn requirements.
Since this method is completely stateless, it just returns itself.
Parameters
----------
X
Ignored
Returns
-------
self
"""
X = check_array(X, allow_nd=True, force_all_finite=False)
X = to_time_series_dataset(X)
self._X_fit_dims = X.shape
return self

def fit_transform(self, X, y=None, **kwargs):
"""Fit to data, then transform it.
Parameters
----------
X : array-like of shape (n_ts, sz, d)
Time series dataset to be rescaled.
Returns
-------
numpy.ndarray
Resampled time series dataset.
"""
return self.fit(X).transform(X)

def transform(self, X, y=None, **kwargs):
"""Fit to data, then transform it.
Expand Down Expand Up @@ -383,4 +346,7 @@ def transform(self, X, y=None, **kwargs):
return X_

def _more_tags(self):
return {'allow_nan': True}
return {'allow_nan': True, '_skip_test': True}



0 comments on commit 28ea436

Please sign in to comment.