From 2f21592c87a45050948ea4cd1f953cd79deb1e2c Mon Sep 17 00:00:00 2001 From: dvalleri Date: Thu, 12 Dec 2024 11:15:18 +0100 Subject: [PATCH] feat(sdk): add predictionActual chart --- .../charts/regression/__init__.py | 5 +- .../charts/regression/regression_chart.py | 132 +++++++++++++++++- .../regression/regression_chart_data.py | 8 +- sdk/tests/chart/regression.ipynb | 93 +++++++++++- 4 files changed, 221 insertions(+), 17 deletions(-) diff --git a/sdk/radicalbit_platform_sdk/charts/regression/__init__.py b/sdk/radicalbit_platform_sdk/charts/regression/__init__.py index 2200f7e6..e86e9d6c 100644 --- a/sdk/radicalbit_platform_sdk/charts/regression/__init__.py +++ b/sdk/radicalbit_platform_sdk/charts/regression/__init__.py @@ -1,7 +1,8 @@ from .regression_chart import RegressionChart -from .regression_chart_data import RegressionChartData +from .regression_chart_data import RegressionDistributionChartData, RegressionPredictedActualChartData __all__ = [ 'RegressionChart', - 'RegressionChartData' + 'RegressionDistributionChartData', + 'RegressionPredictedActualChartData' ] diff --git a/sdk/radicalbit_platform_sdk/charts/regression/regression_chart.py b/sdk/radicalbit_platform_sdk/charts/regression/regression_chart.py index 0cf7b56b..3bfe8acd 100644 --- a/sdk/radicalbit_platform_sdk/charts/regression/regression_chart.py +++ b/sdk/radicalbit_platform_sdk/charts/regression/regression_chart.py @@ -1,15 +1,17 @@ from ipecharts import EChartsRawWidget +import numpy as np -from ..utils import get_formatted_bucket_data,get_chart_header -from .regression_chart_data import RegressionChartData +from ..utils import get_formatted_bucket_data, get_chart_header +from .regression_chart_data import RegressionDistributionChartData, RegressionPredictedActualChartData class RegressionChart: def __init__(self) -> None: pass - def distribution_chart(self, data: RegressionChartData) -> EChartsRawWidget: - bucket_data_formatted = get_formatted_bucket_data(bucket_data=data.bucket_data) + def distribution_chart(self, data: RegressionDistributionChartData) -> EChartsRawWidget: + bucket_data_formatted = get_formatted_bucket_data( + bucket_data=data.bucket_data) reference_json_data = data.model_dump().get('reference_data') current_data_json = data.model_dump().get('current_data') @@ -17,7 +19,7 @@ def distribution_chart(self, data: RegressionChartData) -> EChartsRawWidget: reference_series_data = { "title": "reference", "type": "bar", - "name":"Reference", + "name": "Reference", "itemStyle": { "color": "#9B99A1" }, @@ -27,7 +29,7 @@ def distribution_chart(self, data: RegressionChartData) -> EChartsRawWidget: current_series_data = { "title": "current", "type": "bar", - "name":"Current", + "name": "Current", "itemStyle": { "color": "#3695d9" }, @@ -62,7 +64,7 @@ def distribution_chart(self, data: RegressionChartData) -> EChartsRawWidget: "color": "#9b99a1", "rotate": 20 }, - "data":bucket_data_formatted, + "data": bucket_data_formatted, }, "yAxis": { "type": "value", @@ -91,3 +93,119 @@ def distribution_chart(self, data: RegressionChartData) -> EChartsRawWidget: option.update(get_chart_header(title=data.title)) return EChartsRawWidget(option=option) + + def predicted_actual_chart(self, data: RegressionPredictedActualChartData) -> EChartsRawWidget: + + np_array = np.array(data.scatter_data) + x_max = np_array.max() + x_min = np_array.min() + + regression_line_data = [ + [x_min, (data.coefficient * x_min) + data.intercept], + [x_max, (data.coefficient * x_max) + data.intercept] + ] + + diagonal_line_data = [ + [x_min, x_min], + [x_max, x_max] + ] + + options = { + "grid": { + "left": 20, + "right": 0, + "bottom": 50, + "top": 24, + "containLabel": True + }, + "xAxis": { + "type": "value", + "axisLabel": { + "fontSize": 9, + "color": "#9b99a1" + }, + "splitLine": { + "lineStyle": { + "color": "#9f9f9f54" + } + }, + "name": "ground_truth", + "nameGap": 25, + "nameLocation": "middle", + "scale": True + }, + "yAxis": { + "type": "value", + "axisLabel": { + "fontSize": 9, + "color": "#9b99a1" + }, + "splitLine": { + "lineStyle": { + "color": "#9f9f9f54" + } + }, + "name": "prediction", + "nameGap": 25, + "nameLocation": "middle", + "scale": True + }, + "tooltip": { + "axisPointer": { + "show": True, + "type": "cross", + "lineStyle": { + "type": "dashed", + "width": 1 + } + } + }, + "series": [ + { + "name": "", + "type": "scatter", + "emphasis": { + "focus": "series" + }, + "color": data.color, + "data": data.scatter_data + }, + { + "name": "Diagonal line", + "type": "line", + "lineStyle": { + "width": 2.2, + "color": "#FFC000" + }, + "symbol": "none", + "data": diagonal_line_data, + "itemStyle": { + "color": "#FFC000" + } + }, + { + "name": "Regression line", + "type": "line", + "lineStyle": { + "width": 2.2, + "color": "#8D6ECF" + }, + "symbol": "none", + "data": regression_line_data, + "itemStyle": { + "color": "#8D6ECF" + } + } + ], + "legend": { + "show": True, + "textStyle": { + "color": "#9B99A1" + }, + "right": 0 + } + } + + print('\033[1m'+'prediction vs ground_truth') + + return EChartsRawWidget(option=options) diff --git a/sdk/radicalbit_platform_sdk/charts/regression/regression_chart_data.py b/sdk/radicalbit_platform_sdk/charts/regression/regression_chart_data.py index 599205fc..803ace97 100644 --- a/sdk/radicalbit_platform_sdk/charts/regression/regression_chart_data.py +++ b/sdk/radicalbit_platform_sdk/charts/regression/regression_chart_data.py @@ -3,8 +3,14 @@ from pydantic import BaseModel -class RegressionChartData(BaseModel): +class RegressionDistributionChartData(BaseModel): title: str bucket_data: List[str] reference_data: List[float] current_data: Optional[List[float]] = None + +class RegressionPredictedActualChartData(BaseModel): + scatter_data: List[List[float]] + coefficient: float + intercept: float + color: Optional[str] = '#9B99A1' diff --git a/sdk/tests/chart/regression.ipynb b/sdk/tests/chart/regression.ipynb index d6459b9a..9c284eec 100644 --- a/sdk/tests/chart/regression.ipynb +++ b/sdk/tests/chart/regression.ipynb @@ -9,7 +9,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "2c24436ea2f844ed95f1293926cf13c3", + "model_id": "3c5acc7c17474d38bc397e5be1ae4acb", "version_major": 2, "version_minor": 0 }, @@ -23,10 +23,10 @@ } ], "source": [ - "from radicalbit_platform_sdk.charts.regression import RegressionChart, RegressionChartData\n", + "from radicalbit_platform_sdk.charts.regression import RegressionChart, RegressionDistributionChartData\n", "\n", "RegressionChart().distribution_chart(data=\n", - " RegressionChartData(\n", + " RegressionDistributionChartData(\n", " title=\"ground_truth\",\n", " bucket_data=[\"0\",\"2.7\",\"5.4\",\"8.1\",\"10.8\",\"13.5\",\"16.2\",\"18.9\",\"21.6\",\"24.3\",\"27\"],\n", " reference_data=[383,355,367,244,379,331,252,326,331,374]\n", @@ -43,7 +43,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "8b3a2852923a46c3bd9a6a096ac5136b", + "model_id": "731a8a221b114cb6b4e1bea6f7d2751d", "version_major": 2, "version_minor": 0 }, @@ -57,10 +57,10 @@ } ], "source": [ - "from radicalbit_platform_sdk.charts.regression import RegressionChart, RegressionChartData\n", + "from radicalbit_platform_sdk.charts.regression import RegressionChart, RegressionDistributionChartData\n", "\n", "RegressionChart().distribution_chart(data=\n", - " RegressionChartData(\n", + " RegressionDistributionChartData(\n", " title=\"ground_truth\",\n", " bucket_data=[\"0\",\"2.7\",\"5.4\",\"8.1\",\"10.8\",\"13.5\",\"16.2\",\"18.9\",\"21.6\",\"24.3\",\"27\"],\n", " reference_data=[383,355,367,244,379,331,252,326,331,374],\n", @@ -71,9 +71,88 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "a5e0914f-434f-4656-9c07-9e85ea2b5704", "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1mprediction vs ground_truth\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "500bff92c4a249e2bd98e00a8c899c67", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "EChartsRawWidget(option={'grid': {'left': 20, 'right': 0, 'bottom': 50, 'top': 24, 'containLabel': True}, 'xAx…" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from radicalbit_platform_sdk.charts.regression import RegressionPredictedActualChartData, RegressionChart\n", + "\n", + "RegressionChart().predicted_actual_chart(data=RegressionPredictedActualChartData(\n", + " scatter_data=[[11,11],[2,2],[2,2],[17,17],[27,27],[13,13],[7,7],[18,18],[0,22],[13,13],[18,18],[4,4],[22,22],[13,9],[0,0],[27,27],[24,12],[13,13],[16,16],[16,16],[24,24],[13,13],[19,7],[13,13],[8,8],[3,14],[12,12],[6,6],[6,6],[4,4],[24,24],[5,5],[26,26],[12,12],[18,18],[27,27],[12,3],[11,11],[18,18],[3,3],[1,1],[11,8],[0,0],[27,27],[2,2],[26,26],[11,11],[3,3],[8,8],[11,11],[17,17],[26,26]],\n", + " coefficient=0.7942363125434776,\n", + " intercept=2.8227418911402906\n", + "))" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "d3be0c70-9dd1-432a-99fa-020490f212ca", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1mprediction vs ground_truth\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "22f4445d19fc487284a9c3276ddc0aca", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "EChartsRawWidget(option={'grid': {'left': 20, 'right': 0, 'bottom': 50, 'top': 24, 'containLabel': True}, 'xAx…" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from radicalbit_platform_sdk.charts.regression import RegressionPredictedActualChartData, RegressionChart\n", + "\n", + "RegressionChart().predicted_actual_chart(data=RegressionPredictedActualChartData(\n", + " scatter_data=[[11,11],[2,2],[2,2],[17,17],[27,27],[13,13],[7,7],[18,18],[0,22],[13,13],[18,18],[4,4],[22,22],[13,9],[0,0],[27,27],[24,12],[13,13],[16,16],[16,16],[24,24],[13,13],[19,7],[13,13],[8,8],[3,14],[12,12],[6,6],[6,6],[4,4],[24,24],[5,5],[26,26],[12,12],[18,18],[27,27],[12,3],[11,11],[18,18],[3,3],[1,1],[11,8],[0,0],[27,27],[2,2],[26,26],[11,11],[3,3],[8,8],[11,11],[17,17],[26,26]],\n", + " coefficient=0.7942363125434776,\n", + " intercept=2.8227418911402906,\n", + " color=\"#3695d9\"\n", + "))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f08f32fe-960b-4bd6-9386-317fa41e5148", + "metadata": {}, "outputs": [], "source": [] }