diff --git a/deeptables/models/deeptable.py b/deeptables/models/deeptable.py index 66dea62..02fc0d6 100644 --- a/deeptables/models/deeptable.py +++ b/deeptables/models/deeptable.py @@ -663,7 +663,7 @@ def get_class_weight(self, y): # logger.info(f'class {i}:{weight}') n = len(self.classes_) - class_weight = get_tool_box(y).compute_class_weight('balanced', classes=range(n), y=y) + class_weight = get_tool_box(y).compute_class_weight('balanced', classes=np.array(list(range(n))), y=y) class_weight = {k: v for k, v in zip(range(n), class_weight)} logger.info(f'classes weight: {class_weight}')