Skip to content

Commit

Permalink
feat(sdk): add radicalbit_platform_sdk plot module (#204)
Browse files Browse the repository at this point in the history
* feat(sdk): add placeholder chart (WIP)

* feat(sdk): add chart_sdk

* feat(sdk): fix version

* feat(sdk): rename file and format with ruff

* feat: add numerical_bar_chart

* feat(chart_sdk_: add test to NumericalBarChart

* feat(chart_sdk): binary wip

* feat(chart_sdk): add binary distribution chart

* feat(chart_sdk): add multiclassification data distribution chart

* feat(chart_sdk): add regression distribution chart + utils to get bucket data fromatted

* feat(chart_sdk): ruff fix

* feat(chart_sdk): add legend and title on every chart

* feat(chart_sdk): remove console.debug

* feat(chart_sdk): add assert into binaryCharts

* feat(sdk): move chart_sdk inside sdk

* feat(sdk): add linearChart into sdk

* feat(sdk): add confusionMatrix chart

* feat(sdk): fix test_chart notebook

* feat(sdk): add predictionActual chart

* feat(sdk): add residualScatterChart

* feat(sdk): add residualBucket chart

* feat(sdk): ruff format

* feat(sdk): ruff fix

* feat(sdk): ruff fixies

* feat(sdk): fix linear chart in multiclass

* feat(sdk) replace print chart titel with display fn

* feat(sdk): remove Option form color in regressionChart

* feat(sdk): replace model_dump + get with list comprehension

* feat(sdk): remove placeholder chart

---------

Co-authored-by: Luca Tagliabue <[email protected]>
  • Loading branch information
dvalleri and Luca Tagliabue authored Dec 12, 2024
1 parent cb9ea58 commit 7821c9a
Show file tree
Hide file tree
Showing 20 changed files with 4,075 additions and 551 deletions.
2,899 changes: 2,349 additions & 550 deletions sdk/poetry.lock

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion sdk/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@ pydantic = "^2.7.1"
boto3 = "^1.34.111"
pandas = "^2.2.2"


[tool.poetry.group.dev.dependencies]
responses = "^0.25.0"
pytest = "^8.2.0"
ruff = "^0.4.4"
moto = {extras = ["s3"], version = "^5.0.7"}
jupyter = "^1.1.1"

[tool.poetry.extras]
chart = ["ipecharts"]

[build-system]
requires = ["poetry-core"]
Expand Down
8 changes: 8 additions & 0 deletions sdk/radicalbit_platform_sdk/charts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .chart_data import NumericalBarChartData, ConfusionMatrixChartData
from .chart import Chart

__all__ = [
'ConfusionMatrixChartData',
'NumericalBarChartData',
'Chart'
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .binary_chart import BinaryChart
from .binary_chart_data import BinaryDistributionChartData, BinaryLinearChartData

__all__ = [
'BinaryChart',
'BinaryDistributionChartData',
'BinaryLinearChartData'
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from ipecharts import EChartsRawWidget

from ..utils import get_chart_header
from .binary_chart_data import BinaryDistributionChartData, BinaryLinearChartData


class BinaryChart:
def __init__(self) -> None:
pass

def distribution_chart(self, data: BinaryDistributionChartData) -> EChartsRawWidget:
assert len(data.reference_data) <= 2
assert len(data.y_axis_label) <= 2

if data.current_data:
assert len(data.current_data) <= 2

reference_json_data = [binary_data.model_dump() for binary_data in data.reference_data]
current_data_json = [binary_data.model_dump() for binary_data in data.current_data] if data.current_data else []

reference_series_data = {
'title': data.title,
'type': 'bar',
'itemStyle': {'color': '#9B99A1'},
'data': reference_json_data,
'color': '#9B99A1',
'name': 'Reference',
'label': {
'show': True,
'position': 'insideRight',
'fontWeight': 'bold',
'color': '#FFFFFF',
},
}

current_series_data = {
'title': data.title + '_current',
'type': 'bar',
'itemStyle': {},
'data': current_data_json,
'color': '#3695d9',
'name': 'Current',
'label': {
'show': True,
'position': 'insideRight',
'fontWeight': 'bold',
'color': '#FFFFFF',
},
}

series = (
[reference_series_data]
if not data.current_data
else [reference_series_data, current_series_data]
)

option = {
'grid': {
'left': 0,
'right': 20,
'bottom': 0,
'top': 40,
'containLabel': True,
},
'xAxis': {
'type': 'value',
'axisLabel': {'fontSize': 9, 'color': '#9b99a1'},
'splitLine': {'lineStyle': {'color': '#9f9f9f54'}},
},
'yAxis': {
'type': 'category',
'axisTick': {'show': False},
'axisLine': {'show': False},
'splitLine': {'show': False},
'axisLabel': {'fontSize': 12, 'color': '#9B99A1'},
'data': data.y_axis_label,
},
'emphasis': {'disabled': True},
'barCategoryGap': '21%',
'barGap': '0',
'itemStyle': {'borderWidth': 1, 'borderColor': 'rgba(201, 25, 25, 1)'},
'series': series,
}

option.update(get_chart_header(title=data.title))

return EChartsRawWidget(option=option)

def linear_chart(self, data: BinaryLinearChartData) -> EChartsRawWidget:

reference_series_data = {
'name': 'Reference',
'type': 'line',
'lineStyle': {'width': 2.2, 'color': '#9B99A1', 'type': 'dotted'},
'symbol': 'none',
'data': data.reference_data,
'itemStyle': {'color': '#9B99A1'},
'endLabel': {'show': True, 'color': '#9B99A1'},
'color': '#9B99A1',
}

current_series_data = {
'name': data.title,
'type': 'line',
'lineStyle': {'width': 2.2, 'color': '#73B2E0'},
'symbol': 'none',
'data': data.current_data,
'itemStyle': {'color': '#73B2E0'},
}

series = [reference_series_data, current_series_data]

options = {
'tooltip': {
'trigger': 'axis',
'crosshairs': True,
'axisPointer': {'type': 'cross', 'label': {'show': True}},
},
'yAxis': {
'type': 'value',
'axisLabel': {'fontSize': 9, 'color': '#9b99a1'},
'splitLine': {'lineStyle': {'color': '#9f9f9f54'}},
'scale': True,
},
'xAxis': {
'type': 'time',
'axisTick': {'show': False},
'axisLine': {'show': False},
'splitLine': {'show': False},
'axisLabel': {'fontSize': 12, 'color': '#9b99a1'},
'scale': True,
},
'grid': {
'bottom': 0,
'top': 32,
'left': 0,
'right': 64,
'containLabel': True,
},
'series': series,
'legend': {
'show': True,
'textStyle': {'color': '#9B99A1'},
},
}

options.update(get_chart_header(title=data.title))

return EChartsRawWidget(option=options)
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import List, Optional

from pydantic import BaseModel


class BinaryDistributionData(BaseModel):
percentage: float
count: float
value: float


class BinaryDistributionChartData(BaseModel):
title: str
y_axis_label: List[str]
reference_data: List[BinaryDistributionData]
current_data: Optional[List[BinaryDistributionData]] = None


class BinaryLinearChartData(BaseModel):
title: str
reference_data: List[List[str]]
current_data: List[List[str]]
131 changes: 131 additions & 0 deletions sdk/radicalbit_platform_sdk/charts/chart.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from ipecharts import EChartsRawWidget
import numpy as np

from .chart_data import ConfusionMatrixChartData, NumericalBarChartData
from .utils import get_chart_header, get_formatted_bucket_data


class Chart:
def __init__(self) -> None:
pass

def numerical_bar_chart(self, data: NumericalBarChartData) -> EChartsRawWidget:
bucket_data_formatted = get_formatted_bucket_data(bucket_data=data.bucket_data)

reference_data_json = {
'title': 'reference',
'type': 'bar',
'name': 'Reference',
'itemStyle': {'color': '#9B99A1'},
'data': data.reference_data,
}

current_data_json = {
'title': 'current',
'type': 'bar',
'name': 'Current',
'itemStyle': {'color': '#3695D9'},
'data': data.current_data,
}

series = (
[reference_data_json]
if not data.current_data
else [reference_data_json, current_data_json]
)

option = {
'grid': {
'left': 0,
'right': 20,
'bottom': 0,
'top': 40,
'containLabel': True,
},
'xAxis': {
'type': 'category',
'axisTick': {'show': False},
'axisLine': {'show': False},
'splitLine': {'show': False},
'axisLabel': {
'fontSize': 12,
'interval': 0,
'color': '#9B99A1',
'rotate': 20,
},
'data': bucket_data_formatted,
},
'yAxis': {
'type': 'value',
'axisLabel': {'fontSize': 9, 'color': '#9B99A1'},
'splitLine': {'lineStyle': {'color': '#9f9f9f54'}},
},
'emphasis': {'disabled': True},
'barCategoryGap': '0',
'barGap': '0',
'itemStyle': {'borderWidth': 1, 'borderColor': 'rgba(201, 25, 25, 1)'},
'series': series,
}

option.update(get_chart_header(title=data.title))

return EChartsRawWidget(option=option)

def confusion_matrix_chart(
self, data: ConfusionMatrixChartData
) -> EChartsRawWidget:
assert len(data.matrix) == len(data.axis_label) * len(
data.axis_label
), 'axis_label count and matrix item count are not compatibile'

np_matrix = np.matrix(data.matrix)

options = {
'yAxis': {
'type': 'category',
'axisTick': {'show': False},
'axisLine': {'show': False},
'splitLine': {'show': False},
'axisLabel': {'fontSize': 12, 'color': '#9B99A1'},
'data': data.axis_label,
'name': 'Actual',
'nameGap': 25,
'nameLocation': 'middle',
},
'xAxis': {
'type': 'category',
'axisTick': {'show': False},
'axisLine': {'show': False},
'splitLine': {'show': False},
'axisLabel': {
'fontSize': 12,
'interval': 0,
'color': '#9b99a1',
'rotate': 45,
},
'data': data.axis_label.reverse(),
'name': 'Predicted',
'nameGap': 25,
'nameLocation': 'middle',
},
'grid': {'bottom': 60, 'top': 0, 'left': 44, 'right': 80},
'emphasis': {'disabled': True},
'axis': {'axisLabel': {'fontSize': 9, 'color': '#9b99a1'}},
'visualMap': {
'calculable': True,
'orient': 'vertical',
'right': 'right',
'top': 'center',
'itemHeight': '250rem',
'max': np_matrix.max(),
'inRange': {'color': data.color},
},
'series': {
'name': '',
'type': 'heatmap',
'label': {'show': True},
'data': data.matrix,
},
}

return EChartsRawWidget(option=options)
16 changes: 16 additions & 0 deletions sdk/radicalbit_platform_sdk/charts/chart_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from typing import List, Optional

from pydantic import BaseModel


class NumericalBarChartData(BaseModel):
title: str
bucket_data: List[str]
reference_data: List[float]
current_data: Optional[List[float]] = None


class ConfusionMatrixChartData(BaseModel):
axis_label: List[str]
matrix: List[List[float]]
color: Optional[List[str]] = ['#FFFFFF', '#9B99A1']
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .multi_class_chart import MultiClassificationChart
from .multi_class_chart_data import MultiClassificationDistributionChartData, MultiClassificationLinearChartData, MultiClassificationLinearData

__all__ = [
'MultiClassificationChart',
'MultiClassificationDistributionChartData',
'MultiClassificationLinearChartData',
'MultiClassificationLinearData'
]
Loading

0 comments on commit 7821c9a

Please sign in to comment.