Skip to content

Commit

Permalink
Make sure that print_report has column names if input is a df
Browse files Browse the repository at this point in the history
  • Loading branch information
sachaMorin committed Aug 7, 2023
1 parent 7ce11f5 commit 59484bf
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
4 changes: 2 additions & 2 deletions stepmix/stepmix.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,7 @@ def report(self, X, Y=None, sample_weight=None):
Y : array-like of shape (n_samples, n_features_structural), default=None
sample_weight : array-like of shape(n_samples,), default=None
"""
utils.print_report(self, X, Y, sample_weight)
utils.print_report(self, X, Y, sample_weight, self.x_names_, self.y_names_)

def em(
self,
Expand Down Expand Up @@ -1237,7 +1237,7 @@ def predict_proba(self, X, Y=None):
List of n_features-dimensional data points to fit the structural model. Each row
corresponds to a single data point. If the data is categorical, by default it should be
0-indexed and integer encoded (not one-hot encoded).
Returns
-------
resp : array, shape (n_samples, n_components)
Expand Down
10 changes: 7 additions & 3 deletions stepmix/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def modal(resp, clip=False):
return modal_resp


def print_report(model, X, Y=None, sample_weight=None):
def print_report(model, X, Y=None, sample_weight=None, x_names=None, y_names=None):
"""Print detailed output for the model.
Parameters
Expand All @@ -291,6 +291,10 @@ def print_report(model, X, Y=None, sample_weight=None):
X : array-like of shape (n_samples, n_features)
Y : array-like of shape (n_samples, n_features_structural), default=None
sample_weight : array-like of shape(n_samples,), default=None
x_names : List of str, default=None
Column names of X.
y_names : List of str, default=None
Column names of Y.
"""
check_is_fitted(model)
n_classes = model.n_components
Expand All @@ -317,13 +321,13 @@ def print_report(model, X, Y=None, sample_weight=None):
print(" " + "=" * 76)
print(f" Measurement model parameters")
print(" " + "=" * 76)
model._mm.print_parameters(indent=2)
model._mm.print_parameters(indent=2, feature_names=x_names)

if hasattr(model, "_sm"):
print(" " + "=" * 76)
print(f" Structural model parameters")
print(" " + "=" * 76)
model._sm.print_parameters(indent=2)
model._sm.print_parameters(indent=2, feature_names=y_names)

print(" " + "=" * 76)
print(f" Class weights")
Expand Down

0 comments on commit 59484bf

Please sign in to comment.