Skip to content

Commit

Permalink
feat(sdk): add predictionActual chart
Browse files Browse the repository at this point in the history
  • Loading branch information
dvalleri committed Dec 12, 2024
1 parent f3082bf commit 2f21592
Show file tree
Hide file tree
Showing 4 changed files with 221 additions and 17 deletions.
5 changes: 3 additions & 2 deletions sdk/radicalbit_platform_sdk/charts/regression/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .regression_chart import RegressionChart
from .regression_chart_data import RegressionChartData
from .regression_chart_data import RegressionDistributionChartData, RegressionPredictedActualChartData

__all__ = [
'RegressionChart',
'RegressionChartData'
'RegressionDistributionChartData',
'RegressionPredictedActualChartData'
]
132 changes: 125 additions & 7 deletions sdk/radicalbit_platform_sdk/charts/regression/regression_chart.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
from ipecharts import EChartsRawWidget
import numpy as np

from ..utils import get_formatted_bucket_data,get_chart_header
from .regression_chart_data import RegressionChartData
from ..utils import get_formatted_bucket_data, get_chart_header
from .regression_chart_data import RegressionDistributionChartData, RegressionPredictedActualChartData


class RegressionChart:
def __init__(self) -> None:
pass

def distribution_chart(self, data: RegressionChartData) -> EChartsRawWidget:
bucket_data_formatted = get_formatted_bucket_data(bucket_data=data.bucket_data)
def distribution_chart(self, data: RegressionDistributionChartData) -> EChartsRawWidget:
bucket_data_formatted = get_formatted_bucket_data(
bucket_data=data.bucket_data)

reference_json_data = data.model_dump().get('reference_data')
current_data_json = data.model_dump().get('current_data')

reference_series_data = {
"title": "reference",
"type": "bar",
"name":"Reference",
"name": "Reference",
"itemStyle": {
"color": "#9B99A1"
},
Expand All @@ -27,7 +29,7 @@ def distribution_chart(self, data: RegressionChartData) -> EChartsRawWidget:
current_series_data = {
"title": "current",
"type": "bar",
"name":"Current",
"name": "Current",
"itemStyle": {
"color": "#3695d9"
},
Expand Down Expand Up @@ -62,7 +64,7 @@ def distribution_chart(self, data: RegressionChartData) -> EChartsRawWidget:
"color": "#9b99a1",
"rotate": 20
},
"data":bucket_data_formatted,
"data": bucket_data_formatted,
},
"yAxis": {
"type": "value",
Expand Down Expand Up @@ -91,3 +93,119 @@ def distribution_chart(self, data: RegressionChartData) -> EChartsRawWidget:
option.update(get_chart_header(title=data.title))

return EChartsRawWidget(option=option)

def predicted_actual_chart(self, data: RegressionPredictedActualChartData) -> EChartsRawWidget:

np_array = np.array(data.scatter_data)
x_max = np_array.max()
x_min = np_array.min()

regression_line_data = [
[x_min, (data.coefficient * x_min) + data.intercept],
[x_max, (data.coefficient * x_max) + data.intercept]
]

diagonal_line_data = [
[x_min, x_min],
[x_max, x_max]
]

options = {
"grid": {
"left": 20,
"right": 0,
"bottom": 50,
"top": 24,
"containLabel": True
},
"xAxis": {
"type": "value",
"axisLabel": {
"fontSize": 9,
"color": "#9b99a1"
},
"splitLine": {
"lineStyle": {
"color": "#9f9f9f54"
}
},
"name": "ground_truth",
"nameGap": 25,
"nameLocation": "middle",
"scale": True
},
"yAxis": {
"type": "value",
"axisLabel": {
"fontSize": 9,
"color": "#9b99a1"
},
"splitLine": {
"lineStyle": {
"color": "#9f9f9f54"
}
},
"name": "prediction",
"nameGap": 25,
"nameLocation": "middle",
"scale": True
},
"tooltip": {
"axisPointer": {
"show": True,
"type": "cross",
"lineStyle": {
"type": "dashed",
"width": 1
}
}
},
"series": [
{
"name": "",
"type": "scatter",
"emphasis": {
"focus": "series"
},
"color": data.color,
"data": data.scatter_data
},
{
"name": "Diagonal line",
"type": "line",
"lineStyle": {
"width": 2.2,
"color": "#FFC000"
},
"symbol": "none",
"data": diagonal_line_data,
"itemStyle": {
"color": "#FFC000"
}
},
{
"name": "Regression line",
"type": "line",
"lineStyle": {
"width": 2.2,
"color": "#8D6ECF"
},
"symbol": "none",
"data": regression_line_data,
"itemStyle": {
"color": "#8D6ECF"
}
}
],
"legend": {
"show": True,
"textStyle": {
"color": "#9B99A1"
},
"right": 0
}
}

print('\033[1m'+'prediction vs ground_truth')

return EChartsRawWidget(option=options)
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@
from pydantic import BaseModel


class RegressionChartData(BaseModel):
class RegressionDistributionChartData(BaseModel):
title: str
bucket_data: List[str]
reference_data: List[float]
current_data: Optional[List[float]] = None

class RegressionPredictedActualChartData(BaseModel):
scatter_data: List[List[float]]
coefficient: float
intercept: float
color: Optional[str] = '#9B99A1'
93 changes: 86 additions & 7 deletions sdk/tests/chart/regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2c24436ea2f844ed95f1293926cf13c3",
"model_id": "3c5acc7c17474d38bc397e5be1ae4acb",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -23,10 +23,10 @@
}
],
"source": [
"from radicalbit_platform_sdk.charts.regression import RegressionChart, RegressionChartData\n",
"from radicalbit_platform_sdk.charts.regression import RegressionChart, RegressionDistributionChartData\n",
"\n",
"RegressionChart().distribution_chart(data=\n",
" RegressionChartData(\n",
" RegressionDistributionChartData(\n",
" title=\"ground_truth\",\n",
" bucket_data=[\"0\",\"2.7\",\"5.4\",\"8.1\",\"10.8\",\"13.5\",\"16.2\",\"18.9\",\"21.6\",\"24.3\",\"27\"],\n",
" reference_data=[383,355,367,244,379,331,252,326,331,374]\n",
Expand All @@ -43,7 +43,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8b3a2852923a46c3bd9a6a096ac5136b",
"model_id": "731a8a221b114cb6b4e1bea6f7d2751d",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -57,10 +57,10 @@
}
],
"source": [
"from radicalbit_platform_sdk.charts.regression import RegressionChart, RegressionChartData\n",
"from radicalbit_platform_sdk.charts.regression import RegressionChart, RegressionDistributionChartData\n",
"\n",
"RegressionChart().distribution_chart(data=\n",
" RegressionChartData(\n",
" RegressionDistributionChartData(\n",
" title=\"ground_truth\",\n",
" bucket_data=[\"0\",\"2.7\",\"5.4\",\"8.1\",\"10.8\",\"13.5\",\"16.2\",\"18.9\",\"21.6\",\"24.3\",\"27\"],\n",
" reference_data=[383,355,367,244,379,331,252,326,331,374],\n",
Expand All @@ -71,9 +71,88 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "a5e0914f-434f-4656-9c07-9e85ea2b5704",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1mprediction vs ground_truth\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "500bff92c4a249e2bd98e00a8c899c67",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"EChartsRawWidget(option={'grid': {'left': 20, 'right': 0, 'bottom': 50, 'top': 24, 'containLabel': True}, 'xAx…"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from radicalbit_platform_sdk.charts.regression import RegressionPredictedActualChartData, RegressionChart\n",
"\n",
"RegressionChart().predicted_actual_chart(data=RegressionPredictedActualChartData(\n",
" scatter_data=[[11,11],[2,2],[2,2],[17,17],[27,27],[13,13],[7,7],[18,18],[0,22],[13,13],[18,18],[4,4],[22,22],[13,9],[0,0],[27,27],[24,12],[13,13],[16,16],[16,16],[24,24],[13,13],[19,7],[13,13],[8,8],[3,14],[12,12],[6,6],[6,6],[4,4],[24,24],[5,5],[26,26],[12,12],[18,18],[27,27],[12,3],[11,11],[18,18],[3,3],[1,1],[11,8],[0,0],[27,27],[2,2],[26,26],[11,11],[3,3],[8,8],[11,11],[17,17],[26,26]],\n",
" coefficient=0.7942363125434776,\n",
" intercept=2.8227418911402906\n",
"))"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d3be0c70-9dd1-432a-99fa-020490f212ca",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1mprediction vs ground_truth\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "22f4445d19fc487284a9c3276ddc0aca",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"EChartsRawWidget(option={'grid': {'left': 20, 'right': 0, 'bottom': 50, 'top': 24, 'containLabel': True}, 'xAx…"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from radicalbit_platform_sdk.charts.regression import RegressionPredictedActualChartData, RegressionChart\n",
"\n",
"RegressionChart().predicted_actual_chart(data=RegressionPredictedActualChartData(\n",
" scatter_data=[[11,11],[2,2],[2,2],[17,17],[27,27],[13,13],[7,7],[18,18],[0,22],[13,13],[18,18],[4,4],[22,22],[13,9],[0,0],[27,27],[24,12],[13,13],[16,16],[16,16],[24,24],[13,13],[19,7],[13,13],[8,8],[3,14],[12,12],[6,6],[6,6],[4,4],[24,24],[5,5],[26,26],[12,12],[18,18],[27,27],[12,3],[11,11],[18,18],[3,3],[1,1],[11,8],[0,0],[27,27],[2,2],[26,26],[11,11],[3,3],[8,8],[11,11],[17,17],[26,26]],\n",
" coefficient=0.7942363125434776,\n",
" intercept=2.8227418911402906,\n",
" color=\"#3695d9\"\n",
"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f08f32fe-960b-4bd6-9386-317fa41e5148",
"metadata": {},
"outputs": [],
"source": []
}
Expand Down

0 comments on commit 2f21592

Please sign in to comment.