Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ran black formatter #55

Merged
merged 1 commit into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions jlab_datascience_toolkit/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from jlab_datascience_toolkit.utils.registration import register, make, list_registered_modules
from jlab_datascience_toolkit.utils.registration import (
register,
make,
list_registered_modules,
)

register(
id="MultiClassClassificationAnalysis_v0",
entry_point="jlab_datascience_toolkit.analysis.multiclass_analysis_v0:Analysis"
entry_point="jlab_datascience_toolkit.analysis.multiclass_analysis_v0:Analysis",
)
53 changes: 40 additions & 13 deletions jlab_datascience_toolkit/analysis/multiclass_analysis_v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,49 @@ class Analysis:
def __init__(self, configs: dict):
self.configs = configs

def run(self, y_true, y_pred, labels: np.ndarray = None, target_names: np.ndarray = None, sample_weight: np.ndarray = None, logdir: str = None) -> list:
def run(
self,
y_true,
y_pred,
labels: np.ndarray = None,
target_names: np.ndarray = None,
sample_weight: np.ndarray = None,
logdir: str = None,
) -> list:
ans = []
for submodule in self.configs["submodules"]:
submodule_type = submodule["type"]
submodule_configs = submodule.get("configs", {})
if submodule_type == "confusion_matrix":
cm = confusion_matrix(y_true, y_pred, labels=labels, sample_weight=sample_weight, **submodule_configs)
cm = confusion_matrix(
y_true,
y_pred,
labels=labels,
sample_weight=sample_weight,
**submodule_configs,
)
ans.append(cm)
if logdir: np.save(os.path.join(logdir, 'confusion_matrix.npy'), cm)
if logdir:
np.save(os.path.join(logdir, "confusion_matrix.npy"), cm)
elif submodule_type == "accuracy_score":
acc = accuracy_score(y_true, y_pred, sample_weight=sample_weight, **submodule_configs)
acc = accuracy_score(
y_true, y_pred, sample_weight=sample_weight, **submodule_configs
)
ans.append(acc)
if logdir: np.save(os.path.join(logdir, 'accuracy_score.npy'), acc)
elif submodule_type == 'classification_report':
cr = classification_report(y_true, y_pred, labels=labels, target_names=target_names, sample_weight=sample_weight, **submodule_configs)
if logdir:
np.save(os.path.join(logdir, "accuracy_score.npy"), acc)
elif submodule_type == "classification_report":
cr = classification_report(
y_true,
y_pred,
labels=labels,
target_names=target_names,
sample_weight=sample_weight,
**submodule_configs,
)
ans.append(cr)
if logdir and isinstance(cr, dict):
for metric in ['precision', 'recall', 'f1-score']:
for metric in ["precision", "recall", "f1-score"]:
metric_list = []
for k, v in cr.items():
if isinstance(v, dict) and (metric in v.keys()):
Expand All @@ -34,16 +59,18 @@ def run(self, y_true, y_pred, labels: np.ndarray = None, target_names: np.ndarra
fig, ax = plt.subplots()
ax.bar(
[tup[0] for tup in metric_list],
[tup[1] for tup in metric_list]
[tup[1] for tup in metric_list],
)
ax.set_title(metric)
fig.tight_layout()
fig.savefig(
os.path.join(logdir, f'{metric}.jpg'),
os.path.join(logdir, f"{metric}.jpg"),
transparent=True,
dpi=300
dpi=300,
)
plt.close(fig=fig)
else:
raise NameError('Unsupported submodule type in Multi-Class Analysis Module !')
return ans
raise NameError(
"Unsupported submodule type in Multi-Class Analysis Module !"
)
return ans
9 changes: 5 additions & 4 deletions jlab_datascience_toolkit/core/jdst_analysis.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from jlab_datascience_toolkit.core.jdst_module import JDSTModule
from abc import ABC, abstractmethod

class JDSTAnalysis(JDSTModule,ABC):
'''

class JDSTAnalysis(JDSTModule, ABC):
"""
Base class for the post-training analysis. This class inherits from the module base class.
'''
"""

# Run the analysis:
@abstractmethod
def run(self):
raise NotImplementedError
raise NotImplementedError
12 changes: 6 additions & 6 deletions jlab_datascience_toolkit/core/jdst_data_parser.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from jlab_datascience_toolkit.core.jdst_module import JDSTModule
from abc import ABC, abstractmethod

class JDSTDataParser(JDSTModule,ABC):

'''
class JDSTDataParser(JDSTModule, ABC):
"""
Base class for data parsing. This class inherits from the module base class.
'''
"""

# Load and save the data:
@abstractmethod
def load_data(self):
raise NotImplementedError

@abstractmethod
def save_data(self):
raise NotImplementedError
raise NotImplementedError
10 changes: 5 additions & 5 deletions jlab_datascience_toolkit/core/jdst_data_prep.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
from jlab_datascience_toolkit.core.jdst_module import JDSTModule
from abc import ABC, abstractmethod

class JDSTDataPrep(JDSTModule,ABC):

'''
class JDSTDataPrep(JDSTModule, ABC):
"""
Base class for data preparation. This class inherits from the module base class.
'''
"""

# Save the data, if required
# This might be helpful, if the underlying data preperation is a computational intensive operation
# And we want to avoid calling it multiple times. Thus, we just store the data after preparation:
@abstractmethod
def save_data(self):
raise NotImplementedError

# Run the data preparation:
@abstractmethod
def run(self):
Expand All @@ -22,4 +22,4 @@ def run(self):
# Reverse the data preparation (if possible):
@abstractmethod
def reverse(self):
raise NotImplementedError
raise NotImplementedError
11 changes: 6 additions & 5 deletions jlab_datascience_toolkit/core/jdst_model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from jlab_datascience_toolkit.core.jdst_module import JDSTModule
from abc import ABC, abstractmethod

class JDSTModel(JDSTModule,ABC):
'''

class JDSTModel(JDSTModule, ABC):
"""
Base class for the model. This class inherits from the module base class.
'''
"""

# Get a prediction:
@abstractmethod
def predict(self):
raise NotImplementedError
raise NotImplementedError
26 changes: 13 additions & 13 deletions jlab_datascience_toolkit/core/jdst_module.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,36 @@
from abc import ABC, abstractmethod

class JDSTModule(ABC):

'''
Base class for any module that is written for the JLab Data Science Toolkit. The functions defined here have to be implemented in
class JDSTModule(ABC):
"""
Base class for any module that is written for the JLab Data Science Toolkit. The functions defined here have to be implemented in
any new module that is written
'''
"""

# Initialize:
def __init__(self,**kwargs):
self.module_name = "" # --> Define the name of the module
def __init__(self, **kwargs):
self.module_name = "" # --> Define the name of the module

# Get module info: Just briefly describe what this module is doing,
# Get module info: Just briefly describe what this module is doing,
# what are the inputs and what is returned?
@abstractmethod
def get_info(self):
raise NotImplementedError

# Load and save configuration files which run the module:
@abstractmethod
def load_config(self):
raise NotImplementedError

@abstractmethod
def save_config(self):
raise NotImplementedError

# Load and save for checkpointing (i.e. capture state of module)
@abstractmethod
def load(self):
raise NotImplementedError

@abstractmethod
def save(self):
raise NotImplementedError
raise NotImplementedError
11 changes: 6 additions & 5 deletions jlab_datascience_toolkit/core/jdst_trainer.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from jlab_datascience_toolkit.core.jdst_module import JDSTModule
from abc import ABC, abstractmethod

class JDSTTrainer(JDSTModule,ABC):
'''

class JDSTTrainer(JDSTModule, ABC):
"""
Base class for the Trainer. This class inherits from the module base class.
'''
"""

# Get a prediction:
@abstractmethod
def fit(self):
raise NotImplementedError
raise NotImplementedError
14 changes: 9 additions & 5 deletions jlab_datascience_toolkit/data_parser/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from jlab_datascience_toolkit.utils.registration import register, make, list_registered_modules
from jlab_datascience_toolkit.utils.registration import (
register,
make,
list_registered_modules,
)

register(
id="NumpyParser_v0",
entry_point="jlab_datascience_toolkit.data_parser.numpy_parser:NumpyParser"
entry_point="jlab_datascience_toolkit.data_parser.numpy_parser:NumpyParser",
)

from jlab_datascience_toolkit.data_parser.numpy_parser import NumpyParser

register(
id='CSVParser_v0',
id="CSVParser_v0",
entry_point="jlab_datascience_toolkit.data_parser.parser_to_dataframe:Parser2DataFrame",
kwargs={'registry_config': {'file_format': 'csv'}}
)
kwargs={"registry_config": {"file_format": "csv"}},
)
Loading