diff --git a/merlin/systems/workflow/base.py b/merlin/systems/workflow/base.py index 81090f60d..61c8a3c3d 100644 --- a/merlin/systems/workflow/base.py +++ b/merlin/systems/workflow/base.py @@ -27,6 +27,7 @@ import functools import json import logging +import os from merlin.dag import ColumnSelector, DataFormats, Supports from merlin.dag.executors import LocalExecutor, _convert_format, _data_format @@ -66,15 +67,24 @@ def __init__(self, workflow, output_dtypes, model_config, model_device): ) # recurse over all column groups, initializing operators for inference pipeline. - # (disabled for now while we sort out whether and how we want to use C++ implementations - # of NVTabular operators for performance optimization) - # self._initialize_ops(self.workflow.output_node) + # (disabled everything other than operators that are specifically listed + # by the `NVT_CPP_OPS` environment variable while we sort out whether + # and how we want to use C++ implementations of NVTabular operators for + # performance optimization) + _nvt_cpp_ops = os.environ.get("NVT_CPP_OPS", "Categorify").split(",") + self._initialize_ops(self.workflow.output_node, restrict=_nvt_cpp_ops) + + def _initialize_ops(self, workflow_node, visited=None, restrict=None): + restrict = restrict or [] - def _initialize_ops(self, workflow_node, visited=None): if visited is None: visited = set() - if workflow_node.op and hasattr(workflow_node.op, "inference_initialize"): + if ( + workflow_node.op + and hasattr(workflow_node.op, "inference_initialize") + and (not restrict or workflow_node.op.label in restrict) + ): inference_op = workflow_node.op.inference_initialize( workflow_node.selector, self.model_config ) @@ -96,7 +106,7 @@ def _initialize_ops(self, workflow_node, visited=None): for parent in workflow_node.parents_with_dependencies: if parent not in visited: visited.add(parent) - self._initialize_ops(parent, visited) + self._initialize_ops(parent, visited=visited, restrict=restrict) def run_workflow(self, input_tensors): transformable = TensorTable(input_tensors).to_df()