Skip to content

Commit

Permalink
fix calc class weight
Browse files Browse the repository at this point in the history
  • Loading branch information
oaksharks committed Feb 22, 2024
1 parent a92c0fa commit 253e50d
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion deeptables/models/deeptable.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ def get_class_weight(self, y):
# logger.info(f'class {i}:{weight}')

n = len(self.classes_)
class_weight = get_tool_box(y).compute_class_weight('balanced', classes=range(n), y=y)
class_weight = get_tool_box(y).compute_class_weight('balanced', classes=np.array(list(range(n))), y=y)
class_weight = {k: v for k, v in zip(range(n), class_weight)}
logger.info(f'classes weight: {class_weight}')

Expand Down

0 comments on commit 253e50d

Please sign in to comment.