From e7ca03c5563ddb8646e87c42bbb7c64fd6ff7ee3 Mon Sep 17 00:00:00 2001 From: dvalleri Date: Wed, 11 Dec 2024 15:44:24 +0100 Subject: [PATCH] feat(sdk): add linearChart into sdk --- .../charts/binary_classification/__init__.py | 6 +- .../binary_classification/binary_chart.py | 114 +++++++++++++++- .../binary_chart_data.py | 13 +- .../charts/multi_classification/__init__.py | 7 +- .../multi_classification/multi_class_chart.py | 123 +++++++++++++++++- .../multi_class_chart_data.py | 22 +++- sdk/tests/chart/binary_charts.ipynb | 73 +++++++++-- sdk/tests/chart/multi_class.ipynb | 75 +++++++++-- 8 files changed, 389 insertions(+), 44 deletions(-) diff --git a/sdk/radicalbit_platform_sdk/charts/binary_classification/__init__.py b/sdk/radicalbit_platform_sdk/charts/binary_classification/__init__.py index 17044052..50753f5b 100644 --- a/sdk/radicalbit_platform_sdk/charts/binary_classification/__init__.py +++ b/sdk/radicalbit_platform_sdk/charts/binary_classification/__init__.py @@ -1,8 +1,8 @@ from .binary_chart import BinaryChart -from .binary_chart_data import BinaryChartData, Binary_Data +from .binary_chart_data import BinaryDistributionChartData, BinaryLinearChartData __all__ = [ 'BinaryChart', - 'BinaryChartData', - 'Binary_Data' + 'BinaryDistributionChartData', + 'BinaryLinearChartData' ] diff --git a/sdk/radicalbit_platform_sdk/charts/binary_classification/binary_chart.py b/sdk/radicalbit_platform_sdk/charts/binary_classification/binary_chart.py index 4857a460..4fa85f65 100644 --- a/sdk/radicalbit_platform_sdk/charts/binary_classification/binary_chart.py +++ b/sdk/radicalbit_platform_sdk/charts/binary_classification/binary_chart.py @@ -1,6 +1,6 @@ from ipecharts import EChartsRawWidget -from .binary_chart_data import BinaryChartData +from .binary_chart_data import BinaryDistributionChartData, BinaryLinearChartData from ..utils import get_chart_header @@ -8,15 +8,15 @@ class BinaryChart: def __init__(self) -> None: pass - def distribution_chart(self, data: BinaryChartData) -> EChartsRawWidget: + def distribution_chart(self, data: BinaryDistributionChartData) -> EChartsRawWidget: assert len(data.reference_data) <= 2 assert len(data.y_axis_label) <= 2 - + if data.current_data: assert len(data.current_data) <= 2 reference_json_data = data.model_dump().get('reference_data') - current_data_json = data.model_dump().get('current_data') + current_data_json = data.model_dump().get('current_data') reference_series_data = { "title": data.title, @@ -50,7 +50,8 @@ def distribution_chart(self, data: BinaryChartData) -> EChartsRawWidget: } } - series = [reference_series_data] if not data.current_data else [reference_series_data, current_series_data] + series = [reference_series_data] if not data.current_data else [ + reference_series_data, current_series_data] option = { "grid": { @@ -105,3 +106,106 @@ def distribution_chart(self, data: BinaryChartData) -> EChartsRawWidget: return EChartsRawWidget(option=option) + def linear_chart(self, data: BinaryLinearChartData) -> EChartsRawWidget: + + reference_json_data = data.model_dump().get('reference_data') + current_data_json = data.model_dump().get('current_data') + + reference_series_data = { + "name": "Reference", + "type": "line", + "lineStyle": { + "width": 2.2, + "color": "#9B99A1", + "type": "dotted" + }, + "symbol": "none", + "data": reference_json_data, + "itemStyle": { + "color": "#9B99A1" + }, + "endLabel": { + "show": True, + "color": "#9B99A1" + }, + "color": "#9B99A1" + } + + current_series_data = { + "name": data.title, + "type": "line", + "lineStyle": { + "width": 2.2, + "color": "#73B2E0" + }, + "symbol": "none", + "data": current_data_json, + "itemStyle": { + "color": "#73B2E0" + } + } + + series = [reference_series_data, current_series_data] + + options = { + + "tooltip": { + "trigger": "axis", + "crosshairs": True, + "axisPointer": { + "type": "cross", + "label": { + "show": True + } + } + }, + "yAxis": { + "type": "value", + "axisLabel": { + "fontSize": 9, + "color": "#9b99a1" + }, + "splitLine": { + "lineStyle": { + "color": "#9f9f9f54" + } + }, + "scale": True + }, + "xAxis": { + "type": "time", + "axisTick": { + "show": False + }, + "axisLine": { + "show": False + }, + "splitLine": { + "show": False + }, + "axisLabel": { + "fontSize": 12, + "color": "#9b99a1" + }, + "scale": True + }, + "grid": { + "bottom": 0, + "top": 32, + "left": 0, + "right": 64, + "containLabel": True + }, + "series": series, + "legend": { + "show": True, + "textStyle": { + "color": "#9B99A1" + }, + } + } + + options.update(get_chart_header(title=data.title)) + + return EChartsRawWidget(option=options) + diff --git a/sdk/radicalbit_platform_sdk/charts/binary_classification/binary_chart_data.py b/sdk/radicalbit_platform_sdk/charts/binary_classification/binary_chart_data.py index a7615e22..265f9eed 100644 --- a/sdk/radicalbit_platform_sdk/charts/binary_classification/binary_chart_data.py +++ b/sdk/radicalbit_platform_sdk/charts/binary_classification/binary_chart_data.py @@ -3,13 +3,18 @@ from pydantic import BaseModel -class Binary_Data(BaseModel): +class BinaryDistributionData(BaseModel): percentage: float count: float value: float -class BinaryChartData(BaseModel): +class BinaryDistributionChartData(BaseModel): title: str y_axis_label: List[str] - reference_data: List[Binary_Data] - current_data: Optional[List[Binary_Data]] = None + reference_data: List[BinaryDistributionData] + current_data: Optional[List[BinaryDistributionData]] = None + +class BinaryLinearChartData(BaseModel): + title: str + reference_data: List[List[str]] + current_data: List[List[str]] \ No newline at end of file diff --git a/sdk/radicalbit_platform_sdk/charts/multi_classification/__init__.py b/sdk/radicalbit_platform_sdk/charts/multi_classification/__init__.py index 41fcf326..c9ab2a3c 100644 --- a/sdk/radicalbit_platform_sdk/charts/multi_classification/__init__.py +++ b/sdk/radicalbit_platform_sdk/charts/multi_classification/__init__.py @@ -1,8 +1,9 @@ from .multi_class_chart import MultiClassificationChart -from .multi_class_chart_data import MultiClassificationData, MultiClassificationChartData +from .multi_class_chart_data import MultiClassificationDistributionChartData, MultiClassificationLinearChartData, MultiClassificationLinearData __all__ = [ 'MultiClassificationChart', - 'MultiClassificationData', - 'MultiClassificationChartData' + 'MultiClassificationDistributionChartData', + 'MultiClassificationLinearChartData', + 'MultiClassificationLinearData' ] diff --git a/sdk/radicalbit_platform_sdk/charts/multi_classification/multi_class_chart.py b/sdk/radicalbit_platform_sdk/charts/multi_classification/multi_class_chart.py index 75966945..73f7616b 100644 --- a/sdk/radicalbit_platform_sdk/charts/multi_classification/multi_class_chart.py +++ b/sdk/radicalbit_platform_sdk/charts/multi_classification/multi_class_chart.py @@ -1,6 +1,6 @@ from ipecharts import EChartsRawWidget -from .multi_class_chart_data import MultiClassificationChartData +from .multi_class_chart_data import MultiClassificationDistributionChartData, MultiClassificationLinearChartData from ..utils import get_chart_header @@ -8,7 +8,7 @@ class MultiClassificationChart: def __init__(self) -> None: pass - def distribution_chart(self, data: MultiClassificationChartData) -> EChartsRawWidget: + def distribution_chart(self, data: MultiClassificationDistributionChartData) -> EChartsRawWidget: reference_json_data = data.model_dump().get('reference_data') current_data_json = data.model_dump().get('current_data') @@ -33,7 +33,8 @@ def distribution_chart(self, data: MultiClassificationChartData) -> EChartsRawWi "data": current_data_json } - series = [reference_series_data] if not data.current_data else [reference_series_data, current_series_data] + series = [reference_series_data] if not data.current_data else [ + reference_series_data, current_series_data] option = { "grid": { @@ -99,3 +100,119 @@ def distribution_chart(self, data: MultiClassificationChartData) -> EChartsRawWi option.update(get_chart_header(title=data.title)) return EChartsRawWidget(option=option) + + def linear_chart(self, data: MultiClassificationLinearChartData) -> EChartsRawWidget: + + series = [] + + for element in data.current_data: + series.append({ + "name": element.name, + "type": "line", + "lineStyle": { + "width": 2.2 + }, + "symbol": "none", + "data": element.values + }) + + for element in data.reference_data: + series.append({ + "name": element.name, + "type": "line", + "lineStyle": { + "width": 2, + "type": "dotted" + }, + "symbol": "none", + "data": element.values, + "endLabel": { + "show": False, + "color": "#9B99A1", + } + }) + + options = { + "yAxis": { + "type": "value", + "axisLabel": { + "fontSize": 9, + "color": "#9b99a1" + }, + "splitLine": { + "lineStyle": { + "color": "#9f9f9f54" + } + } + }, + "xAxis": { + "type": "time", + "axisTick": { + "show": False + }, + "axisLine": { + "show": False + }, + "splitLine": { + "show": False + }, + "axisLabel": { + "fontSize": 12, + "color": "#9b99a1" + } + }, + "grid": { + "bottom": 0, + "top": 32, + "left": 0, + "right": 140, + "containLabel": True + }, + "color": [ + "#00BFFF", + "#1E90FF", + "#00CED1", + "#20B2AA", + "#4169E1", + "#6A5ACD", + "#8A2BE2", + "#9400D3", + "#BA55D3" + ], + "legend": { + "right": 0, + "top": 16, + "bottom": 0, + "orient": "vertical", + "type": "scroll", + "scrollDataIndex": "scroll", + "pageIconSize": 8, + "pageTextStyle": { + "fontSize": 9, + "color": "#9b99a1" + }, + "textStyle": { + "fontSize": 9, + "color": "#9B99A1", + "fontWeight": "300" + }, + }, + "tooltip": { + "trigger": "axis" + }, + "emphasis": { + "focus": "series" + }, + "title": { + "text": "••• Reference", + "textStyle": { + "fontSize": 10, + "fontWeight": "300", + "color": "#9B99A1" + }, + "right": 0 + }, + "series": series + } + print('\033[1m' + data.title) + return EChartsRawWidget(option=options) diff --git a/sdk/radicalbit_platform_sdk/charts/multi_classification/multi_class_chart_data.py b/sdk/radicalbit_platform_sdk/charts/multi_classification/multi_class_chart_data.py index 2ec81e67..b53ed5c1 100644 --- a/sdk/radicalbit_platform_sdk/charts/multi_classification/multi_class_chart_data.py +++ b/sdk/radicalbit_platform_sdk/charts/multi_classification/multi_class_chart_data.py @@ -2,13 +2,27 @@ from pydantic import BaseModel -class MultiClassificationData(BaseModel): +class MultiClassificationDistributionData(BaseModel): percentage: float count: float value: float -class MultiClassificationChartData(BaseModel): +class MultiClassificationDistributionChartData(BaseModel): title: str x_axis_label: List[str] - reference_data: List[MultiClassificationData] - current_data: Optional[List[MultiClassificationData]] = None + reference_data: List[MultiClassificationDistributionData] + current_data: Optional[List[MultiClassificationDistributionData]] = None + +class MultiClassificationLinearData(BaseModel): + name: str + values: List[List[str]] + + +class MultiClassificationLinearChartData(BaseModel): + title: str + reference_data: List[MultiClassificationLinearData] + current_data: List[MultiClassificationLinearData] + + + + \ No newline at end of file diff --git a/sdk/tests/chart/binary_charts.ipynb b/sdk/tests/chart/binary_charts.ipynb index 202e1db6..91f02a3d 100644 --- a/sdk/tests/chart/binary_charts.ipynb +++ b/sdk/tests/chart/binary_charts.ipynb @@ -2,14 +2,14 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "id": "9c508258-8fd0-4b79-9621-debfcb498ef3", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "8bf8620979544b5db95a4cfa88834f51", + "model_id": "fe122246acb54deb8a0d9cbb065ad96a", "version_major": 2, "version_minor": 0 }, @@ -17,16 +17,16 @@ "EChartsRawWidget(option={'grid': {'left': 0, 'right': 20, 'bottom': 0, 'top': 40, 'containLabel': True}, 'xAxi…" ] }, - "execution_count": 1, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "from radicalbit_platform_sdk.charts.binary_classification import BinaryChart, BinaryChartData\n", + "from radicalbit_platform_sdk.charts.binary_classification import BinaryChart, BinaryDistributionChartData\n", "\n", "BinaryChart().distribution_chart(data=\n", - " BinaryChartData(\n", + " BinaryDistributionChartData(\n", " title=\"income 50k\",\n", " y_axis_label=[\"0.0\",\"1.0\"],\n", " reference_data=[\n", @@ -39,14 +39,14 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "0d7cf4a4-89a2-4a83-b50c-1a125c65861f", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "d144e25569334c54ad7bdb0abae4f8a9", + "model_id": "bb41e0377f20454987d193aa5f170cb7", "version_major": 2, "version_minor": 0 }, @@ -54,16 +54,16 @@ "EChartsRawWidget(option={'grid': {'left': 0, 'right': 20, 'bottom': 0, 'top': 40, 'containLabel': True}, 'xAxi…" ] }, - "execution_count": 2, + "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "from radicalbit_platform_sdk.charts.binary_classification import BinaryChart, BinaryChartData\n", + "from radicalbit_platform_sdk.charts.binary_classification import BinaryChart, BinaryDistributionChartData\n", "\n", "BinaryChart().distribution_chart(data=\n", - " BinaryChartData(\n", + " BinaryDistributionChartData(\n", " title=\"income 50k\",\n", " y_axis_label=[\"0.0\",\"1.0\"],\n", " reference_data=[\n", @@ -80,9 +80,60 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "a197fd36-69cf-46f1-b892-a049cf14f6c4", "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8449bb9839fc473599743e3b9fb4e285", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "EChartsRawWidget(option={'tooltip': {'trigger': 'axis', 'crosshairs': True, 'axisPointer': {'type': 'cross', '…" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from radicalbit_platform_sdk.charts.binary_classification import BinaryChart,BinaryLinearChartData\n", + "\n", + "BinaryChart().linear_chart(data=\n", + " BinaryLinearChartData(\n", + " title=\"Accuracy\",\n", + " \n", + " reference_data=[\n", + " [ \"2024-03-10 00:00:00\",\"0.789\" ],\n", + " [ \"2024-03-10 01:00:00\",\"0.789\" ],\n", + " [ \"2024-03-10 02:00:00\",\"0.789\" ],\n", + " [ \"2024-03-10 03:00:00\",\"0.789\" ],\n", + " [ \"2024-03-10 04:00:00\",\"0.789\" ],\n", + " [ \"2024-03-10 05:00:00\",\"0.789\" ],\n", + " [ \"2024-03-10 06:00:00\",\"0.789\" ]\n", + " ],\n", + " current_data=[\n", + " [ \"2024-03-10 00:00:00\",\"0.792\" ],\n", + " [ \"2024-03-10 01:00:00\",\"0.125\" ],\n", + " [ \"2024-03-10 02:00:00\",\"0.792\" ],\n", + " [ \"2024-03-10 03:00:00\",\"0.708\" ],\n", + " [ \"2024-03-10 04:00:00\",\"1\" ],\n", + " [ \"2024-03-10 05:00:00\",\"0\" ],\n", + " [ \"2024-03-10 06:00:00\",\"1\" ]\n", + " ]\n", + " )\n", + ") " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6cdf251f-fc47-47d3-955c-03f78f011e13", + "metadata": {}, "outputs": [], "source": [] } diff --git a/sdk/tests/chart/multi_class.ipynb b/sdk/tests/chart/multi_class.ipynb index 874e4d64..6e54d9f2 100644 --- a/sdk/tests/chart/multi_class.ipynb +++ b/sdk/tests/chart/multi_class.ipynb @@ -2,14 +2,14 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "id": "9ded8395-69cd-4c5e-82dd-69e95d6b8182", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "c4f6a64ba49a48a1825267d381129953", + "model_id": "f4bead6a05e448b4b5c7066ed7a7170c", "version_major": 2, "version_minor": 0 }, @@ -17,16 +17,16 @@ "EChartsRawWidget(option={'grid': {'left': 0, 'right': 20, 'bottom': 0, 'top': 40, 'containLabel': True}, 'xAxi…" ] }, - "execution_count": 1, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "from radicalbit_platform_sdk.charts.multi_classification import MultiClassificationChart, MultiClassificationChartData\n", + "from radicalbit_platform_sdk.charts.multi_classification import MultiClassificationChart, MultiClassificationDistributionChartData\n", "\n", "MultiClassificationChart().distribution_chart(data=\n", - " MultiClassificationChartData(\n", + " MultiClassificationDistributionChartData(\n", " title=\"ground_truth\",\n", " x_axis_label= [\"0\",\"1\",\"2\"],\n", " reference_data=[\n", @@ -40,14 +40,14 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "id": "c1f952cc-af88-4479-9111-1b3d6a6cc705", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "3c6cab91629646cab538e26596d43c7b", + "model_id": "f12688cf0f184fb9b90a575e7aaf7f11", "version_major": 2, "version_minor": 0 }, @@ -55,16 +55,16 @@ "EChartsRawWidget(option={'grid': {'left': 0, 'right': 20, 'bottom': 0, 'top': 40, 'containLabel': True}, 'xAxi…" ] }, - "execution_count": 2, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "from radicalbit_platform_sdk.charts.multi_classification import MultiClassificationChart, MultiClassificationChartData\n", + "from radicalbit_platform_sdk.charts.multi_classification import MultiClassificationChart, MultiClassificationDistributionChartData\n", "\n", "MultiClassificationChart().distribution_chart(data=\n", - " MultiClassificationChartData(\n", + " MultiClassificationDistributionChartData(\n", " title=\"ground_truth\",\n", " x_axis_label= [\"0\",\"1\",\"2\"],\n", " reference_data=[\n", @@ -83,9 +83,62 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "4064888b-7b71-4581-92db-cb3d142e3213", "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1mRecall\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9a134d6f8ed64140967ec49c84fab6da", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "EChartsRawWidget(option={'yAxis': {'type': 'value', 'axisLabel': {'fontSize': 9, 'color': '#9b99a1'}, 'splitLi…" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from radicalbit_platform_sdk.charts.multi_classification import MultiClassificationChart, MultiClassificationLinearChartData, MultiClassificationLinearData\n", + "\n", + "MultiClassificationChart().linear_chart(data=MultiClassificationLinearChartData(\n", + " title=\"Recall\",\n", + " reference_data=[\n", + " MultiClassificationLinearData(name='0',values=[ \n", + " [ \"2024-04-09 00:00:00\",\"1\" ],\n", + " [ \"2024-04-10 00:00:00\",\"0.974\" ],\n", + " [ \"2024-04-11 00:00:00\",\"0.95\" ],\n", + " [ \"2024-04-12 00:00:00\",\"1\" ]\n", + " ])\n", + " ],\n", + " current_data=[\n", + " MultiClassificationLinearData(name='0',values=[ \n", + " [ \"2024-04-09 00:00:00\",\"1\" ],\n", + " [ \"2024-04-10 00:00:00\",\"0.974\" ],\n", + " [ \"2024-04-11 00:00:00\",\"0.95\" ],\n", + " [ \"2024-04-12 00:00:00\",\"1\" ]\n", + " ])\n", + " ] \n", + "))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "87e59bfc-9a6d-4983-8274-37b01b115d0d", + "metadata": {}, "outputs": [], "source": [] }