diff --git a/pyearth/earth.py b/pyearth/earth.py index e027d6a..a8185df 100644 --- a/pyearth/earth.py +++ b/pyearth/earth.py @@ -1314,8 +1314,15 @@ def score_samples(self, X, y=None, missing=None): X, y, sample_weight, output_weight, missing = self._scrub( X, y, None, None, missing) y_hat = self.predict(X, missing=missing) - residual = 1 - (y - y_hat) ** 2 / y**2 - return residual + if y_hat.ndim == 1: + y_hat = y_hat.reshape(-1, 1) + squared_errors = (y - y_hat) ** 2 + variances = np.var(y, axis=0).reshape(1, -1) + nze = variances != 0 # non-zero variance + nze = nze.ravel() + output = np.ones(squared_errors.shape) + output[:, nze] = 1 - squared_errors[:, nze] / variances[:, nze] + return output def transform(self, X, missing=None): '''