diff --git a/ludwig/utils/algorithms_utils.py b/ludwig/utils/algorithms_utils.py index ebebffb72ed..29e79836c0f 100644 --- a/ludwig/utils/algorithms_utils.py +++ b/ludwig/utils/algorithms_utils.py @@ -13,12 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -import logging - from ludwig.constants import TIED -logger = logging.getLogger(__name__) - def topological_sort(graph_unsorted): """Repeatedly go through all of the nodes in the graph, moving each of the nodes that has all its edges @@ -86,10 +82,3 @@ def topological_sort_feature_dependencies(features): dependencies_graph[feature["name"]] = dependencies output_features_dict[feature["name"]] = feature return [output_features_dict[node[0]] for node in topological_sort(dependencies_graph)] - - -if __name__ == "__main__": - graph_unsorted = [(2, []), (5, [11]), (11, [2, 9, 10]), (7, [11, 8]), (9, []), (10, []), (8, [9]), (3, [10, 8])] - logger.info(topological_sort(graph_unsorted)) - graph_unsorted = [("macro", ["action", "contact_type"]), ("contact_type", None), ("action", ["contact_type"])] - logger.info(topological_sort(graph_unsorted)) diff --git a/tests/ludwig/utils/test_algorithm_utils.py b/tests/ludwig/utils/test_algorithm_utils.py new file mode 100644 index 00000000000..30153cadfe8 --- /dev/null +++ b/tests/ludwig/utils/test_algorithm_utils.py @@ -0,0 +1,20 @@ +import pytest + +from ludwig.utils.algorithms_utils import topological_sort + + +@pytest.mark.parametrize( + "unsorted,sorted", + [ + ( + [(2, []), (5, [11]), (11, [2, 9, 10]), (7, [11, 8]), (9, []), (10, []), (8, [9]), (3, [10, 8])], + [(2, []), (9, []), (10, []), (8, [9]), (3, [10, 8]), (11, [2, 9, 10]), (7, [11, 8]), (5, [11])], + ), + ( + [("macro", ["action", "contact_type"]), ("contact_type", None), ("action", ["contact_type"])], + [("contact_type", []), ("action", ["contact_type"]), ("macro", ["action", "contact_type"])], + ), + ], +) +def test_topological_sort(unsorted: list, sorted: list) -> None: + assert topological_sort(unsorted) == sorted