Skip to content

Commit

Permalink
cleanup docstrings and dtype usage
Browse files Browse the repository at this point in the history
  • Loading branch information
neindochoh committed Sep 20, 2023
1 parent d93a645 commit 111e8bc
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 30 deletions.
33 changes: 17 additions & 16 deletions renumics/spotlight/layouts/model_compare.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional, Union, Dict, Any
from renumics.spotlight import dtypes
from renumics.spotlight import layout
from renumics.spotlight.layout import (
Layout,
Expand All @@ -13,8 +15,6 @@
issues,
confusion_matrix,
)
from typing import Optional, Union
from renumics.spotlight import Audio, Image


def compare_classification(
Expand All @@ -25,22 +25,22 @@ def compare_classification(
model2_prediction: str = "m2_prediction",
model2_embedding: str = "",
model2_correct: str = "",
inspect: Optional[dict] = None,
inspect: Optional[Dict[str, Any]] = None,
) -> Layout:
"""This function generates a Spotlight layout for comparing two different machine learning classification models.
Args:
label (str, optional): Name of the dataframe column that contains the label. Defaults to "label".
model1_prediction (str, optional): Name of the dataframe column that contains the prediction for model 1. Defaults to "m1_prediction".
model1_embedding (str, optional): Name of the dataframe column that contains thee embedding for model 1. Defaults to "".
model1_correct (str, optional): Name of the dataframe column that contains a flag if the data sample is predicted correctly by model 1.
model2_prediction (str, optional): Name of the dataframe column that contains the prediction for model 2. Defaults to "m2_prediction".
model2_embedding (str, optional): Name of the dataframe column that contains thee embedding for model 2. Defaults to "".
model2_correct (str, optional): Name and type of the dataframe columns that are displayed in the inspector, e.g. {'audio': spotlight.Audio}. Defaults to None.
inspect (Optional[dict], optional): Name of the dataframe column that contains a flag if the data sample is predicted correctly by model 1.
label: Name of the column that contains the label.
model1_prediction: Name of the column that contains the prediction for model 1.
model1_embedding: Name of the column that contains thee embedding for model 1.
model1_correct: Name of the column that contains a flag if the data sample is predicted correctly by model 1.
model2_prediction: Name of the column that contains the prediction for model 2.
model2_embedding: Name of the column that contains thee embedding for model 2.
model2_correct: Name of the column that contains a flag if the data sample is predicted correctly by model 2.
inspect: Name and type of the columns that are displayed in the inspector, e.g. {'audio': spotlight.dtypes.audio_dtype}.
Returns:
Layout: _description_
The configured layout for `spotlight.show`.
"""

# first column: table + issues
Expand Down Expand Up @@ -133,13 +133,14 @@ def compare_classification(
# fourth column: inspector
inspector_fields = []
if inspect:
for item, _type in inspect.items():
if _type == Audio:
for item, dtype_like in inspect.items():
dtype = dtypes.create_dtype(dtype_like)
if dtypes.is_audio_dtype(dtype):
inspector_fields.append(lenses.audio(item))
elif _type == Image:
elif dtypes.is_image_dtype(dtype):
inspector_fields.append(lenses.image(item))
else:
print("Type {} not supported by this layout.".format(_type))
print(f"Type {dtype} not supported by this layout.")

inspector_fields.append(lenses.scalar(label))
inspector_fields.append(lenses.scalar(model1_prediction))
Expand Down
29 changes: 15 additions & 14 deletions renumics/spotlight/layouts/model_debug.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Optional, Union, Dict, List, Any
from renumics.spotlight import layout
from renumics.spotlight.layout import (
Layout,
Expand All @@ -14,28 +15,27 @@
confusion_matrix,
histogram,
)
from typing import Optional, Union
from renumics.spotlight import Audio, Image
from renumics.spotlight.dtypes import create_dtype, is_audio_dtype, is_image_dtype


def debug_classification(
label: str = "label",
prediction: str = "prediction",
embedding: str = "",
inspect: Optional[dict] = None,
features: Optional[list] = None,
inspect: Optional[Dict[str, Any]] = None,
features: Optional[List[str]] = None,
) -> Layout:
"""This function generates a Spotlight layout for debugging a machine learning classification model.
Args:
label (str, optional): Name of the dataframe column that contains the label. Defaults to "label".
prediction (str, optional): Name of the dataframe column that contains the prediction. Defaults to "prediction".
embedding (str, optional): Name of the dataframe column that contains the embedding. Defaults to "".
inspect (Optional[dict], optional): Name and type of the dataframe columns that are displayed in the inspector, e.g. {'audio': spotlight.Audio}. Defaults to None.
features (Optional[list], optional): Name of the dataframe columns that contain useful metadata and features. Defaults to None.
label: Name of the column that contains the label.
prediction: Name of the column that contains the prediction.
embedding: Name of the column that contains the embedding.
inspect: Name and type of the columns that are displayed in the inspector, e.g. {'audio': spotlight.dtypes.audio_dtype}.
features: Names of the columns that contain useful metadata and features.
Returns:
Layout: Layout to be displayed with Spotlight.
The configured layout for `spotlight.show`.
"""

# first column: table + issues
Expand Down Expand Up @@ -98,13 +98,14 @@ def debug_classification(
# fourth column: inspector
inspector_fields = []
if inspect:
for item, _type in inspect.items():
if _type == Audio:
for item, dtype_like in inspect.items():
dtype = create_dtype(dtype_like)
if is_audio_dtype(dtype):
inspector_fields.append(lenses.audio(item))
elif _type == Image:
elif is_image_dtype(dtype):
inspector_fields.append(lenses.image(item))
else:
print("Type {} not supported by this layout.".format(_type))
print("Type {} not supported by this layout.".format(dtype))

inspector_fields.append(lenses.scalar(label))
inspector_fields.append(lenses.scalar(prediction))
Expand Down

0 comments on commit 111e8bc

Please sign in to comment.