From a52219ae7403f6dad968f57930a1cb125c58e53c Mon Sep 17 00:00:00 2001 From: Stefan Suwelack Date: Tue, 19 Sep 2023 13:27:01 +0200 Subject: [PATCH] docstring added --- renumics/spotlight/layouts/__init__.py | 6 ++++-- renumics/spotlight/layouts/model_compare.py | 24 +++++++++++++++++++-- renumics/spotlight/layouts/model_debug.py | 18 ++++++++++++++-- 3 files changed, 42 insertions(+), 6 deletions(-) diff --git a/renumics/spotlight/layouts/__init__.py b/renumics/spotlight/layouts/__init__.py index c73f5fab..740fa53c 100644 --- a/renumics/spotlight/layouts/__init__.py +++ b/renumics/spotlight/layouts/__init__.py @@ -1,2 +1,4 @@ -from .model_debug import model_debug_classification -from .model_compare import model_compare_classification +from .model_debug import debug_classification +from .model_compare import compare_classification + +__all__ = ["debug_classification", "compare_classification"] diff --git a/renumics/spotlight/layouts/model_compare.py b/renumics/spotlight/layouts/model_compare.py index 29394942..75fb5ad1 100644 --- a/renumics/spotlight/layouts/model_compare.py +++ b/renumics/spotlight/layouts/model_compare.py @@ -1,6 +1,8 @@ from renumics.spotlight import layout from renumics.spotlight.layout import ( Layout, + Tab, + Split, lenses, table, similaritymap, @@ -11,11 +13,11 @@ issues, confusion_matrix, ) -from typing import Optional +from typing import Optional, Union from renumics.spotlight import Audio, Image -def model_compare_classification( +def compare_classification( label: str = "label", model1_prediction: str = "m1_prediction", model1_embedding: str = "", @@ -25,6 +27,22 @@ def model_compare_classification( model2_correct: str = "", inspect: Optional[dict] = None, ) -> Layout: + """This function generates a Spotlight layout for comparing two different machine learning classification models. + + Args: + label (str, optional): Name of the dataframe column that contains the label. Defaults to "label". + model1_prediction (str, optional): Name of the dataframe column that contains the prediction for model 1. Defaults to "m1_prediction". + model1_embedding (str, optional): Name of the dataframe column that contains thee embedding for model 1. Defaults to "". + model1_correct (str, optional): Name of the dataframe column that contains a flag if the data sample is predicted correctly by model 1. + model2_prediction (str, optional): Name of the dataframe column that contains the prediction for model 2. Defaults to "m2_prediction". + model2_embedding (str, optional): Name of the dataframe column that contains thee embedding for model 2. Defaults to "". + model2_correct (str, optional): Name and type of the dataframe columns that are displayed in the inspector, e.g. {'audio': spotlight.Audio}. Defaults to None. + inspect (Optional[dict], optional): Name of the dataframe column that contains a flag if the data sample is predicted correctly by model 1. + + Returns: + Layout: _description_ + """ + # first column: table + issues metrics = split( [ @@ -100,6 +118,8 @@ def model_compare_classification( ) column2_list.append(row3) + column2:Union[Tab, Split] + if len(column2_list) == 1: column2 = column2_list[0] elif len(column2_list) == 2: diff --git a/renumics/spotlight/layouts/model_debug.py b/renumics/spotlight/layouts/model_debug.py index 576b02a1..b5281979 100644 --- a/renumics/spotlight/layouts/model_debug.py +++ b/renumics/spotlight/layouts/model_debug.py @@ -16,13 +16,27 @@ from renumics.spotlight import Audio, Image -def model_debug_classification( + +def debug_classification( label: str = "label", prediction: str = "prediction", embedding: str = "", inspect: Optional[dict] = None, features: Optional[list] = None, ) -> Layout: + """This function generates a Spotlight layout for debugging a machine learning classification model. + + Args: + label (str, optional): Name of the dataframe column that contains the label. Defaults to "label". + prediction (str, optional): Name of the dataframe column that contains the prediction. Defaults to "prediction". + embedding (str, optional): Name of the dataframe column that contains the embedding. Defaults to "". + inspect (Optional[dict], optional): Name and type of the dataframe columns that are displayed in the inspector, e.g. {'audio': spotlight.Audio}. Defaults to None. + features (Optional[list], optional): Name of the dataframe columns that contain useful metadata and features. Defaults to None. + + Returns: + Layout: Layout to be displayed with Spotlight. + """ + # first column: table + issues metrics = tab( metric(name="Accuracy", metric="accuracy", columns=[label, prediction]), @@ -45,7 +59,7 @@ def model_debug_classification( ) ) - # third column: confusion matric, feature histograms (optional), embedding (optional) + # second column: confusion matric, feature histograms (optional), embedding (optional) if features is not None: histogram_list = [] for idx, feature in enumerate(features):