diff --git a/appveyor.yml b/appveyor.yml
index b1a5b8c..6663c0b 100644
--- a/appveyor.yml
+++ b/appveyor.yml
@@ -19,7 +19,6 @@ install: |
   for f in $(find . -maxdepth 1 -name 'requirements*.txt'); do
     pip install -r ${f}
   done
-  pip install pandas  # Needed for some estimator checks.
   pip install .
 
 test_script:
diff --git a/doc/index.rst b/doc/index.rst
index 2ae4e6c..cd95ca5 100644
--- a/doc/index.rst
+++ b/doc/index.rst
@@ -13,24 +13,28 @@ cases including complex pre-processing, model stacking and benchmarking.
     from skdag import DAGBuilder
 
     dag = (
-        DAGBuilder()
+        DAGBuilder(infer_dataframe=True)
         .add_step("impute", SimpleImputer())
-        .add_step("vitals", "passthrough", deps={"impute": slice(0, 4)})
+        .add_step(
+            "vitals",
+            "passthrough",
+            deps={"impute": ["age", "sex", "bmi", "bp"]},
+        )
         .add_step(
             "blood",
             PCA(n_components=2, random_state=0),
-            deps={"impute": slice(4, 10)}
+            deps={"impute": ["s1", "s2", "s3", "s4", "s5", "s6"]},
         )
         .add_step(
             "rf",
             RandomForestRegressor(max_depth=5, random_state=0),
-            deps=["blood", "vitals"]
+            deps=["blood", "vitals"],
         )
         .add_step("svm", SVR(C=0.7), deps=["blood", "vitals"])
         .add_step(
             "knn",
             KNeighborsRegressor(n_neighbors=5),
-            deps=["blood", "vitals"]
+            deps=["blood", "vitals"],
         )
         .add_step("meta", LinearRegression(), deps=["rf", "svm", "knn"])
         .make_dag(n_jobs=2, verbose=True)
diff --git a/doc/quick_start.rst b/doc/quick_start.rst
index e9e4794..e52f4b9 100644
--- a/doc/quick_start.rst
+++ b/doc/quick_start.rst
@@ -50,31 +50,34 @@ For more complex DAGs, it is recommended to use a
 :class:`skdag.dag.DAGBuilder`, which allows you to define the graph by specifying the
 dependencies of each new estimator:
 
->>> from skdag import DAGBuilder
->>> dag = (
-...     DAGBuilder()
-...     .add_step("impute", SimpleImputer())
-...     .add_step("vitals", "passthrough", deps={"impute": slice(0, 4)})
-...     .add_step("blood", PCA(n_components=2, random_state=0), deps={"impute": slice(4, 10)})
-...     .add_step("lr", LogisticRegression(random_state=0), deps=["blood", "vitals"])
-...     .make_dag()
-... )
->>> dag.draw()
-o    impute
-|\
-o o  blood,vitals
-|/
-o    lr
-
+.. code-block:: python
+
+    >>> from skdag import DAGBuilder
+    >>> dag = (
+    ...     DAGBuilder(infer_dataframe=True)
+    ...     .add_step("impute", SimpleImputer())
+    ...     .add_step("vitals", "passthrough", deps={"impute": ["age", "sex", "bmi", "bp"]})
+    ...     .add_step("blood", PCA(n_components=2, random_state=0), deps={"impute": slice(4, 10)})
+    ...     .add_step("lr", LogisticRegression(random_state=0), deps=["blood", "vitals"])
+    ...     .make_dag()
+    ... )
+    >>> dag.draw()
+    o    impute
+    |\
+    o o  blood,vitals
+    |/
+    o    lr
+
 .. image:: _static/img/dag2.png
 
 In the above examples we pass the first four columns directly to a regressor, but
 the remaining columns have dimensionality reduction applied first before being
-passed to the same regressor as extra input columns. Note that we can define our graph
-edges in two different ways: as a dict (if we need to select only certain columns from
-the source node) or as a simple list (if we want to simply grab all columns from all
-input nodes).
+passed to the same regressor as extra input columns.
+
+In this DAG, the ``deps`` option controls not only which estimators feed in to other
+estimators, but also which columns are used (and ignored) by each step. For more detail
+on how to control this behaviour, see the `User Guide `_.
 
 The DAG may now be used as an estimator in its own right:
diff --git a/doc/user_guide.rst b/doc/user_guide.rst
index 0c560ff..208ffa3 100644
--- a/doc/user_guide.rst
+++ b/doc/user_guide.rst
@@ -26,7 +26,8 @@ scikit-learn :class:`~sklearn.pipeline.Pipeline`. These DAGs may be created from
     ...         ("impute", SimpleImputer()),
     ...         ("pca", PCA()),
     ...         ("lr", LogisticRegression())
-    ...     ]
+    ...     ],
+    ...     infer_dataframe=True,
     ... )
 
 You may view a diagram of the DAG with the :meth:`~skdag.dag.DAG.show` method. In a
@@ -44,6 +45,12 @@ ASCII text:
 
 .. image:: _static/img/dag1.png
 
+Note that we also provided an extra option, ``infer_dataframe``. This is entirely
+optional, but if set, the DAG will ensure that dataframe inputs have column and index
+information preserved (or inferred), and the output of the pipeline will also be a
+dataframe. This is useful if you wish to filter down the inputs for one particular step
+to only include certain columns, as we shall see in action later.
+
 For more complex DAGs, it is recommended to use a
 :class:`skdag.dag.DAGBuilder`, which allows you to define the graph by specifying the
 dependencies of each new estimator:
@@ -51,11 +58,12 @@ estimator:
 .. code-block:: python
 
     >>> from skdag import DAGBuilder
+    >>> from sklearn.compose import make_column_selector
    >>> dag = (
-    ...     DAGBuilder()
+    ...     DAGBuilder(infer_dataframe=True)
     ...     .add_step("impute", SimpleImputer())
-    ...     .add_step("vitals", "passthrough", deps={"impute": slice(0, 4)})
-    ...     .add_step("blood", PCA(n_components=2, random_state=0), deps={"impute": slice(4, 10)})
+    ...     .add_step("vitals", "passthrough", deps={"impute": ["age", "sex", "bmi", "bp"]})
+    ...     .add_step("blood", PCA(n_components=2, random_state=0), deps={"impute": make_column_selector("s[0-9]+")})
     ...     .add_step("lr", LogisticRegression(random_state=0), deps=["blood", "vitals"])
     ...     .make_dag()
     ... )
@@ -73,7 +81,16 @@ the remaining columns have dimensionality reduction applied first before being
 passed to the same regressor. Note that we can define our graph
 edges in two different ways: as a dict (if we need to select only certain
 columns from the source node) or as a simple list (if we want to
 simply grab all columns from all input
-nodes).
+nodes). Columns may be specified as any kind of iterable (list, slice, etc.) or a column
+selector function that conforms to :func:`sklearn.compose.make_column_selector`.
+
+If you wish to specify string column names for dependencies, ensure you provide the
+``infer_dataframe=True`` option when you create a DAG. This guarantees that all
+estimator outputs are coerced into dataframes. Where possible, column names will be
+inferred; otherwise the column names will just be the name of the estimator step with an
+appended index number. If you do not specify ``infer_dataframe=True``, the DAG will
+leave the outputs unmodified, which in most cases will mean numpy arrays that only
+support numeric column indices.
 
 The DAG may now be used as an estimator in its own right:
 
@@ -189,7 +206,7 @@ as a dictionary of step name to column indices instead:
 
     >>> from sklearn.ensemble import RandomForestClassifier
     >>> from sklearn.svm import SVC
     >>> clf_stack = (
-    ...     DAGBuilder()
+    ...     DAGBuilder(infer_dataframe=True)
     ...     .add_step("pass", "passthrough")
     ...     .add_step("rf", RandomForestClassifier(), deps=["pass"])
.add_step("svr", SVC(), deps=["pass"]) diff --git a/requirements_test.txt b/requirements_test.txt index 9955dec..277d732 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -1,2 +1,3 @@ +pandas pytest pytest-cov diff --git a/skdag/_version.py b/skdag/_version.py index 3b93d0b..27fdca4 100644 --- a/skdag/_version.py +++ b/skdag/_version.py @@ -1 +1 @@ -__version__ = "0.0.2" +__version__ = "0.0.3" diff --git a/skdag/dag/_builder.py b/skdag/dag/_builder.py index 9c7d114..f0c44e3 100644 --- a/skdag/dag/_builder.py +++ b/skdag/dag/_builder.py @@ -16,6 +16,15 @@ class DAGBuilder: that reference each step by name. Note that steps must be defined before they are used as dependencies. + Parameters + ---------- + + infer_dataframe : bool, default = False + If True, assume ``dataframe_columns="infer"`` every time :meth:`.add_step` is + called, if ``dataframe_columns`` is set to ``None``. This effectively makes the + resulting DAG always try to coerce output into pandas DataFrames wherever + possible. + See Also -------- :class:`skdag.DAG` : The estimator DAG created by this utility. @@ -43,10 +52,66 @@ class DAGBuilder: o lr """ - def __init__(self): + def __init__(self, infer_dataframe=False): self.graph = nx.DiGraph() + self.infer_dataframe = infer_dataframe + + def from_pipeline(self, steps, **kwargs): + """ + Construct a DAG from a simple linear sequence of steps. The resulting DAG will + be equivalent to a :class:`~sklearn.pipeline.Pipeline`. + + Parameters + ---------- + + steps : sequence of (str, estimator) + An ordered sequence of pipeline steps. A step is simply a pair of + ``(name, estimator)``, just like a scikit-learn Pipeline. + + infer_dataframe : bool, default = False + If True, assume ``dataframe_columns="infer"`` every time :meth:`.add_step` + is called, if ``dataframe_columns`` is set to ``None``. This effectively + makes the resulting DAG always try to coerce output into pandas DataFrames + wherever possible. + + kwargs : kwargs + Any other hyperparameters that are accepted by :class:`~skdag.dag.DAG`'s + contructor. 
+ """ + if hasattr(steps, "steps"): + pipe = steps + steps = pipe.steps + if hasattr(pipe, "get_params"): + kwargs = { + **{ + k: v + for k, v in pipe.get_params().items() + if k in ("memory", "verbose") + }, + **kwargs, + } + + dfcols = "infer" if self.infer_dataframe else None + + for i in range(len(steps)): + name, estimator = steps[i] + self._validate_name(name) + deps = {} + if i > 0: + dep = steps[i - 1][0] + deps[dep] = None + self._validate_deps(deps) + + step = DAGStep(name, estimator, deps, dataframe_columns=dfcols) + self.graph.add_node(name, step=step) + if deps: + self.graph.add_edge(dep, name) - def add_step(self, name, est, deps=None): + self._validate_graph() + + return self + + def add_step(self, name, est, deps=None, dataframe_columns=None): self._validate_name(name) if isinstance(deps, Sequence): deps = {dep: None for dep in deps} @@ -56,7 +121,12 @@ def add_step(self, name, est, deps=None): else: deps = {} - step = DAGStep(name, est, deps=deps) + if dataframe_columns is None and self.infer_dataframe: + dfcols = "infer" + else: + dfcols = dataframe_columns + + step = DAGStep(name, est, deps=deps, dataframe_columns=dfcols) self.graph.add_node(name, step=step) for dep in deps: diff --git a/skdag/dag/_dag.py b/skdag/dag/_dag.py index 71f8e55..24a6884 100644 --- a/skdag/dag/_dag.py +++ b/skdag/dag/_dag.py @@ -14,7 +14,9 @@ from scipy.sparse import dok_matrix, issparse from skdag.dag._render import DAGRenderer from skdag.dag._utils import ( + _format_output, _in_notebook, + _is_pandas, _is_passthrough, _is_predictor, _is_transformer, @@ -22,7 +24,7 @@ ) from sklearn.base import clone from sklearn.exceptions import NotFittedError -from sklearn.utils import Bunch, _print_elapsed_time, _safe_indexing +from sklearn.utils import Bunch, _print_elapsed_time, _safe_indexing, deprecated from sklearn.utils._tags import _safe_tags from sklearn.utils.metaestimators import _BaseComposition, available_if from sklearn.utils.validation import check_is_fitted, check_memory @@ -34,7 +36,10 @@ def _stack_inputs(dag, X, node): # For root nodes, the dependency is just the node name itself. deps = {node.name: None} if node.is_root else node.deps - cols = [_safe_indexing(X[dep], deps[dep], axis=1) for dep in deps] + cols = [ + X[dep][cols(X[dep])] if callable(cols) else _safe_indexing(X[dep], cols, axis=1) + for dep, cols in deps.items() + ] to_stack = [ # If we sliced a single column from an input, reshape it to a 2d array. @@ -49,6 +54,31 @@ def _stack_inputs(dag, X, node): return X_stacked +def _leaf_estimators_have(attr, how="all"): + """Check that leaves have `attr`. 
+    Used together with `available_if` in `DAG`."""
+
+    def check_leaves(self):
+        # raises `AttributeError` with all details if `attr` does not exist
+        failed = []
+        for leaf in self.leaves_:
+            try:
+                _is_passthrough(leaf.estimator) or getattr(leaf.estimator, attr)
+            except AttributeError:
+                failed.append(leaf.estimator)
+
+        if (how == "all" and failed) or (
+            how == "any" and len(failed) != len(self.leaves_)
+        ):
+            raise AttributeError(
+                f"{', '.join([repr(type(est)) for est in failed])} "
+                f"object(s) has no attribute '{attr}'"
+            )
+        return True
+
+    return check_leaves
+
+
 def _transform_one(transformer, X, weight, allow_predictor=True, **fit_params):
     if _is_passthrough(transformer):
         res = X
@@ -66,9 +96,10 @@ def _transform_one(transformer, X, weight, allow_predictor=True, **fit_params):
     else:
         res = transformer.transform(X)
     # if we have a weight for this transformer, multiply output
-    if weight is None:
-        return res
-    return res * weight
+    if weight is not None:
+        res = res * weight
+
+    return res
 
 
 def _fit_transform_one(
@@ -115,36 +146,10 @@ def _fit_transform_one(
             f"'{type(transformer).__name__}' object has no attribute 'transform'"
         )
 
-    if weight is None:
-        return res, transformer
-    return res * weight, transformer
-
-
-def _leaf_estimators_have(attr, how="all"):
-    """Check that leaves have `attr`.
-    Used together with `avaliable_if` in `DAG`."""
-
-    def check_leaves(self):
-        # raises `AttributeError` with all details if `attr` does not exist
-        failed = []
-        for leaf in self.leaves_:
-            try:
-                (_is_passthrough(leaf.estimator) and attr == "transform") or getattr(
-                    leaf.estimator, attr
-                )
-            except AttributeError:
-                failed.append(leaf.estimator)
-
-        if (how == "all" and failed) or (
-            how == "any" and len(failed) != len(self.leaves_)
-        ):
-            raise AttributeError(
-                f"{', '.join([repr(type(est)) for est in failed])} "
-                f"object(s) has no attribute '{attr}'"
-            )
-        return True
+    if weight is not None:
+        res = res * weight
 
-    return check_leaves
+    return res, transformer
 
 
@@ -159,27 +164,29 @@ def _parallel_fit(dag, step, Xin, Xs, y, fit_transform_fn, memory, **fit_params)
     X = _stack_inputs(dag, Xin, step)
     clsname = type(dag).__name__
 
-    if transformer is None or transformer == "passthrough":
-        with _print_elapsed_time(clsname, dag._log_message(step)):
-            return X, transformer
-
-    if hasattr(memory, "location") and memory.location is None:
-        # we do not clone when caching is disabled to
-        # preserve backward compatibility
-        cloned_transformer = transformer
-    else:
-        cloned_transformer = clone(transformer)
-
-    # Fit or load from cache the current transformer
-    Xt, fitted_transformer = fit_transform_fn(
-        cloned_transformer,
-        X,
-        y,
-        None,
-        message_clsname=clsname,
-        message=dag._log_message(step),
-        **fit_params,
-    )
+    with _print_elapsed_time(clsname, dag._log_message(step)):
+        if transformer is None or transformer == "passthrough":
+            Xt, fitted_transformer = X, transformer
+        else:
+            if hasattr(memory, "location") and memory.location is None:
+                # we do not clone when caching is disabled to
+                # preserve backward compatibility
+                cloned_transformer = transformer
+            else:
+                cloned_transformer = clone(transformer)
+
+            # Fit or load from cache the current transformer
+            Xt, fitted_transformer = fit_transform_fn(
+                cloned_transformer,
+                X,
+                y,
+                None,
+                message_clsname=clsname,
+                message=dag._log_message(step),
+                **fit_params,
+            )
+
+        Xt = _format_output(Xt, X, step)
 
     return Xt, fitted_transformer
 
 
@@ -195,19 +202,21 @@
 def _parallel_transform(dag, step, Xin, Xs, transform_fn, **fn_params):
     # X = Xin[step.name]
     clsname = type(dag).__name__
 
-    if transformer is None or transformer == "passthrough":
-        with _print_elapsed_time(clsname, dag._log_message(step)):
-            return X
-
-    # Fit or load from cache the current transformer
-    Xt = transform_fn(
-        transformer,
-        X,
-        None,
-        message_clsname=clsname,
-        message=dag._log_message(step),
-        **fn_params,
-    )
+    with _print_elapsed_time(clsname, dag._log_message(step)):
+        if transformer is None or transformer == "passthrough":
+            Xt = X
+        else:
+            # Fit or load from cache the current transformer
+            Xt = transform_fn(
+                transformer,
+                X,
+                None,
+                message_clsname=clsname,
+                message=dag._log_message(step),
+                **fn_params,
+            )
+
+        Xt = _format_output(Xt, X, step)
 
     return Xt
 
@@ -244,7 +253,10 @@ def _parallel_execute(
     else:
         Xout = est_fn(Xt, **fn_params)
 
+    Xout = _format_output(Xout, Xt, leaf)
+
     fitted_estimator = leaf.estimator
+
     return Xout, fitted_estimator
 
@@ -261,6 +273,14 @@ class DAGStep:
     deps : dict
         A map of dependency names to columns. If columns is ``None``, then all input
         columns will be selected.
+    dataframe_columns : list of str or "infer" (optional)
+        Either a hard-coded list of column names to apply to any output data, or the
+        string "infer", which means the column outputs will be assumed to match the
+        column inputs if the output is 2d and not already a dataframe, the estimator is
+        a transformer, and the final axis dimensions match the inputs. Otherwise the
+        column names will be assumed to be the step name + index if the output is not
+        already a dataframe. If set to ``None`` or inference is not possible, the
+        outputs will be left unmodified.
     axis : int, default = 1
         The strategy for merging inputs if there is more than upstream dependency.
         ``axis=0`` will assume all inputs have the same features and stack the rows
@@ -268,10 +288,11 @@ class DAGStep:
         same samples.
     """
 
-    def __init__(self, name, estimator, deps, axis=1):
+    def __init__(self, name, estimator, deps, dataframe_columns, axis=1):
         self.name = name
         self.estimator = estimator
         self.deps = deps
+        self.dataframe_columns = dataframe_columns
         self.axis = axis
         self.index = None
         self.is_root = False
@@ -428,49 +449,14 @@ class DAG(_BaseComposition):
     _required_parameters = ["graph"]
 
     @classmethod
+    @deprecated(
+        "DAG.from_pipeline is deprecated in 0.0.3 and will be removed in a future "
+        "release. Please use DAGBuilder.from_pipeline instead."
+    )
     def from_pipeline(cls, steps, **kwargs):
-        """
-        Construct a DAG from a simple linear sequence of steps. The resulting DAG will
-        be equivalent to a :class:`~sklearn.pipeline.Pipeline`.
+        from skdag.dag._builder import DAGBuilder
 
-        Parameters
-        ----------
-
-        steps : sequence of (str, estimator)
-            An ordered sequence of pipeline steps. A step is simply a pair of
-            ``(name, estimator)``, just like a scikit-learn Pipeline.
-
-        kwargs : kwargs
-            Any other hyperparameters that are accepted by :class:`~skdag.dag.DAG`'s
-            contructor.
- """ - graph = nx.DiGraph() - if hasattr(steps, "steps"): - pipe = steps - steps = pipe.steps - if hasattr(pipe, "get_params"): - kwargs = { - **{ - k: v - for k, v in pipe.get_params().items() - if k in ("memory", "verbose") - }, - **kwargs, - } - - for i in range(len(steps)): - name, estimator = steps[i] - deps = {} - if i > 0: - dep = steps[i - 1][0] - deps[dep] = None - - step = DAGStep(name, estimator, deps) - graph.add_node(name, step=step) - if deps: - graph.add_edge(dep, name) - - return cls(graph=graph, **kwargs) + return DAGBuilder().from_pipeline(steps, **kwargs).make_dag() def __init__(self, graph, *, memory=None, n_jobs=None, verbose=False): self.graph = graph @@ -672,7 +658,10 @@ def _resolve_inputs(self, X): ) X = {self.roots_[0].name: X} - X = {step: x if issparse(x) else np.asarray(x) for step, x in X.items()} + X = { + step: x if issparse(x) or _is_pandas(x) else np.asarray(x) + for step, x in X.items() + } return X @@ -1240,21 +1229,27 @@ def join(self, other, edges, **kwargs): if v not in other.graph_: raise KeyError(v) - attrs = other.graph_.nodes[v] - old_step = attrs["step"] - step = DAGStep( + # source node can no longer be a leaf + ustep = newgraph.nodes[u]["step"] + if ustep.is_leaf: + ustep.is_leaf = False + + vnode = other.graph_.nodes[v] + old_step = vnode["step"] + vstep = DAGStep( name=old_step.name, estimator=old_step.estimator, deps=old_step.deps, + dataframe_columns=old_step.dataframe_columns, axis=old_step.axis, ) - if u not in step.deps: - step.deps[u] = idx + if u not in vstep.deps: + vstep.deps[u] = idx - attrs["step"] = step + vnode["step"] = vstep - newgraph.add_node(v, **attrs) + newgraph.add_node(v, **vnode) newgraph.add_edge(u, v) return DAG(newgraph, **kwargs) @@ -1346,6 +1341,7 @@ def show(self, style=None, detailed=False, format=None, layout="dot"): """ if format is None: format = "svg" if _in_notebook() else "txt" + data = self.draw(style=style, detailed=detailed, format=format, layout=layout) if format == "svg": from IPython.display import SVG, display diff --git a/skdag/dag/_utils.py b/skdag/dag/_utils.py index 4364b48..bfa454f 100644 --- a/skdag/dag/_utils.py +++ b/skdag/dag/_utils.py @@ -1,6 +1,11 @@ import numpy as np from scipy import sparse +try: + import pandas as pd +except ImportError: + pd = None + def _is_passthrough(estimator): return estimator is None or estimator == "passthrough" @@ -39,7 +44,7 @@ def _stack(Xs, axis=0): hstack (combination of features). Higher axes are only supported for non-sparse data sources. 
""" - if any(sparse.issparse(f) for f in Xs): + if any(sparse.issparse(x) for x in Xs): if -2 <= axis < 2: axis = axis % 2 else: @@ -51,6 +56,8 @@ def _stack(Xs, axis=0): Xs = sparse.vstack(Xs).tocsr() elif axis == 1: Xs = sparse.hstack(Xs).tocsr() + elif pd and all(_is_pandas(x) for x in Xs): + Xs = pd.concat(Xs, axis=axis) else: if axis == 1: Xs = np.hstack(Xs) @@ -58,3 +65,53 @@ def _stack(Xs, axis=0): Xs = np.stack(Xs, axis=axis) return Xs + + +def _is_pandas(X): + "Check if X is a DataFrame or Series" + return hasattr(X, "iloc") + + +def _format_output(X, input, node): + outdim = np.asarray(X).ndim + if ( + outdim > 2 + or outdim < 1 + or node.dataframe_columns is None + or pd is None + or _is_pandas(X) + ): + return X + else: + inshape = np.asarray(input).shape + outshape = np.asarray(X).shape + indim = np.asarray(input).ndim + if node.dataframe_columns == "infer": + if ( + hasattr(node.estimator, "transform") + and ( + (inshape == outshape) + or (indim > 1 and outdim > 1 and inshape[1] == outshape[1]) + ) + and hasattr(input, "columns") + ): + columns = input.columns + else: + if outdim == 1: + columns = [node.name] + else: + columns = [f"{node.name}{i}" for i in range(outshape[1])] + else: + columns = node.dataframe_columns + + if hasattr(input, "index"): + index = input.index + else: + index = None + + if outdim == 2: + df = pd.DataFrame(X, columns=columns, index=index) + else: + df = pd.Series(X, name=columns[0], index=index) + + return df diff --git a/skdag/dag/tests/test_dag.py b/skdag/dag/tests/test_dag.py index 215242d..125105f 100644 --- a/skdag/dag/tests/test_dag.py +++ b/skdag/dag/tests/test_dag.py @@ -5,14 +5,18 @@ import time import numpy as np +import pandas as pd import pytest from skdag import DAG, DAGBuilder from skdag.dag.tests.utils import FitParamT, Mult, NoFit, NoTrans, Transf from sklearn import datasets +from sklearn import preprocessing from sklearn.base import BaseEstimator, clone +from sklearn.compose import make_column_selector from sklearn.decomposition import PCA -from sklearn.ensemble import RandomForestClassifier +from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor from sklearn.feature_selection import SelectKBest, f_classif +from sklearn.impute import SimpleImputer from sklearn.linear_model import LinearRegression, LogisticRegression from sklearn.pipeline import Pipeline from sklearn.preprocessing import StandardScaler @@ -36,9 +40,8 @@ def test_dag_invalid_parameters(): # Test the various init parameters of the dag in fit # method - dag = DAG.from_pipeline([(1, 1)]) - with pytest.raises(TypeError): - dag.fit([[1]], [1]) + with pytest.raises(KeyError): + dag = DAGBuilder().from_pipeline([(1, 1)]).make_dag() # Check that we can't fit DAGs with objects without fit # method @@ -46,13 +49,13 @@ def test_dag_invalid_parameters(): "Leaf nodes of a DAG should implement fit or be the string 'passthrough'" ".*NoFit.*" ) - dag = DAG.from_pipeline([("clf", NoFit())]) + dag = DAGBuilder().from_pipeline([("clf", NoFit())]).make_dag() with pytest.raises(TypeError, match=msg): dag.fit([[1]], [1]) # Smoke test with only an estimator clf = NoTrans() - dag = DAG.from_pipeline([("svc", clf)]) + dag = DAGBuilder().from_pipeline([("svc", clf)]).make_dag() assert dag.get_params(deep=True) == dict( svc__a=None, svc__b=None, svc=clf, **dag.get_params(deep=False) ) @@ -67,7 +70,7 @@ def test_dag_invalid_parameters(): # Test with two objects clf = SVC() filter1 = SelectKBest(f_classif) - dag = DAG.from_pipeline([("anova", filter1), ("svc", clf)]) + dag = 
DAGBuilder().from_pipeline([("anova", filter1), ("svc", clf)]).make_dag() # Check that estimators are not cloned on pipeline construction assert dag.named_steps["anova"] is filter1 @@ -76,7 +79,7 @@ def test_dag_invalid_parameters(): # Check that we can't fit with non-transformers on the way # Note that NoTrans implements fit, but not transform msg = "All intermediate steps should be transformers.*\\bNoTrans\\b.*" - dag2 = DAG.from_pipeline([("t", NoTrans()), ("svc", clf)]) + dag2 = DAGBuilder().from_pipeline([("t", NoTrans()), ("svc", clf)]).make_dag() with pytest.raises(TypeError, match=msg): dag2.fit([[1]], [1]) @@ -132,7 +135,7 @@ def test_dag_pipeline_init(): steps = (("transf", Transf()), ("clf", FitParamT())) pipe = Pipeline(steps, verbose=False) for inp in [pipe, steps]: - dag = DAG.from_pipeline(inp) + dag = DAGBuilder().from_pipeline(inp).make_dag() dag.fit(X, y=None) dag.score(X) @@ -148,7 +151,9 @@ def test_dag_methods_anova(): # Test with Anova + LogisticRegression clf = LogisticRegression() filter1 = SelectKBest(f_classif, k=2) - dag1 = DAG.from_pipeline([("anova", filter1), ("logistic", clf)]) + dag1 = ( + DAGBuilder().from_pipeline([("anova", filter1), ("logistic", clf)]).make_dag() + ) dag2 = ( DAGBuilder() .add_step("anova", filter1) @@ -165,7 +170,11 @@ def test_dag_methods_anova(): def test_dag_fit_params(): # Test that the pipeline can take fit parameters - dag = DAG.from_pipeline([("transf", Transf()), ("clf", FitParamT())]) + dag = ( + DAGBuilder() + .from_pipeline([("transf", Transf()), ("clf", FitParamT())]) + .make_dag() + ) dag.fit(X=None, y=None, clf__should_succeed=True) # classifier should return True assert dag.predict(None) @@ -182,7 +191,11 @@ def test_dag_fit_params(): def test_dag_sample_weight_supported(): # DAG should pass sample_weight X = np.array([[1, 2]]) - dag = DAG.from_pipeline([("transf", Transf()), ("clf", FitParamT())]) + dag = ( + DAGBuilder() + .from_pipeline([("transf", Transf()), ("clf", FitParamT())]) + .make_dag() + ) dag.fit(X, y=None) assert dag.score(X) == 3 assert dag.score(X, y=None) == 3 @@ -193,7 +206,7 @@ def test_dag_sample_weight_supported(): def test_dag_sample_weight_unsupported(): # When sample_weight is None it shouldn't be passed X = np.array([[1, 2]]) - dag = DAG.from_pipeline([("transf", Transf()), ("clf", Mult())]) + dag = DAGBuilder().from_pipeline([("transf", Transf()), ("clf", Mult())]).make_dag() dag.fit(X, y=None) assert dag.score(X) == 3 assert dag.score(X, sample_weight=None) == 3 @@ -205,7 +218,7 @@ def test_dag_sample_weight_unsupported(): def test_dag_raise_set_params_error(): # Test dag raises set params error message for nested models. 
- dag = DAG.from_pipeline([("cls", LinearRegression())]) + dag = DAGBuilder().from_pipeline([("cls", LinearRegression())]).make_dag() # expected error message error_msg = ( @@ -259,6 +272,12 @@ def test_dag_stacking_pca_svm_rf(idx): assert dag.predict_log_proba(X).shape == prob_shape assert isinstance(dag.score(X, y), (float, np.floating)) + root = dag["log"] + for attr in ["n_features_in_", "feature_names_in_"]: + if hasattr(root, attr): + assert hasattr(dag, attr) + + def test_dag_draw(): txt = DAGBuilder().make_dag().draw(format="txt") @@ -304,10 +323,101 @@ def test_dag_draw(): assert f"{type(est).__name__}" in svg +def _dag_from_steplist(steps, **builder_opts): + builder = DAGBuilder(**builder_opts) + for step in steps: + builder.add_step(**step) + return builder.make_dag() + + +@pytest.mark.parametrize( + "steps", + [ + [ + { + "name": "pca", + "est": PCA(n_components=1), + }, + { + "name": "svc", + "est": SVC(probability=True, random_state=0), + "deps": ["pca"], + }, + { + "name": "rf", + "est": RandomForestClassifier(random_state=0), + "deps": ["pca"], + }, + { + "name": "log", + "est": LogisticRegression(), + "deps": ["svc", "rf"], + }, + ], + ], +) +@pytest.mark.parametrize( + "X,y", + [datasets.make_blobs(n_samples=200, n_features=10, centers=3, random_state=0)], +) +def test_pandas(X, y, steps): + dag_np = _dag_from_steplist(steps, infer_dataframe=False) + dag_pd = _dag_from_steplist(steps, infer_dataframe=True) + + dag_np.fit(X, y) + dag_pd.fit(X, y) + + y_pred_np = dag_np.predict_proba(X) + y_pred_pd = dag_pd.predict_proba(X) + assert isinstance(y_pred_np, np.ndarray) + assert isinstance(y_pred_pd, pd.DataFrame) + assert np.allclose(y_pred_np, y_pred_pd) + + +def test_pandas_indexing(): + X, y = datasets.load_diabetes(return_X_y=True, as_frame=True) + + passcols = ["age", "sex", "bmi", "bp"] + preprocessing = ( + DAGBuilder(infer_dataframe=True) + .add_step("imp", SimpleImputer()) + .add_step("vitals", "passthrough", deps={"imp": passcols}) + .add_step( + "blood", + PCA(n_components=2, random_state=0), + deps={"imp": make_column_selector("s[0-9]+")}, + ) + .add_step("out", "passthrough", deps=["vitals", "blood"]) + .make_dag() + ) + + X_tr = preprocessing.fit_transform(X, y) + assert isinstance(X_tr, pd.DataFrame) + assert (X_tr.index == X.index).all() + assert X_tr.columns.tolist() == passcols + ["blood0", "blood1"] + + predictor = ( + DAGBuilder(infer_dataframe=True) + .add_step("rf", RandomForestRegressor(random_state=0)) + .make_dag() + ) + + dag = preprocessing.join( + predictor, + edges=[("out", "rf")], + ) + + y_pred = dag.fit_predict(X, y) + + assert isinstance(y_pred, pd.Series) + assert (y_pred.index == y.index).all() + assert y_pred.name == dag.leaves_[0].name + + @parametrize_with_checks( [ - DAG.from_pipeline([("ss", StandardScaler())]), - DAG.from_pipeline([("lr", LinearRegression())]), + DAGBuilder().from_pipeline([("ss", StandardScaler())]).make_dag(), + DAGBuilder().from_pipeline([("lr", LinearRegression())]).make_dag(), ( DAGBuilder() .add_step("pca", PCA(n_components=1))