Skip to content

Commit

Permalink
feat(sdk): add confusionMatrix chart
Browse files Browse the repository at this point in the history
  • Loading branch information
dvalleri committed Dec 11, 2024
1 parent e7ca03c commit 9a817bd
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 18 deletions.
3 changes: 2 additions & 1 deletion sdk/radicalbit_platform_sdk/charts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .chart_data import ChartData, NumericalBarChartData
from .chart_data import ChartData, NumericalBarChartData,ConfusionMatrixChartData
from .chart import Chart

__all__ = [
'ConfusionMatrixChartData',
'NumericalBarChartData',
'ChartData',
'Chart'
Expand Down
122 changes: 106 additions & 16 deletions sdk/radicalbit_platform_sdk/charts/chart.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from ipecharts import EChartsRawWidget

from .chart_data import ChartData, NumericalBarChartData
from .chart_data import ChartData, NumericalBarChartData, ConfusionMatrixChartData

from .utils import get_formatted_bucket_data,get_chart_header
from .utils import get_formatted_bucket_data, get_chart_header

import numpy as np


class Chart:
Expand All @@ -23,29 +25,31 @@ def placeholder_chart(self, data: ChartData) -> EChartsRawWidget:
return EChartsRawWidget(option=option)

def numerical_bar_chart(self, data: NumericalBarChartData) -> EChartsRawWidget:
bucket_data_formatted = get_formatted_bucket_data(bucket_data=data.bucket_data)
bucket_data_formatted = get_formatted_bucket_data(
bucket_data=data.bucket_data)

reference_data_json = {
"title": "reference",
"type": "bar",
"name":"Reference",
"title": "reference",
"type": "bar",
"name": "Reference",
"itemStyle": {
"color": "#9B99A1"
},
"data": data.reference_data
}
"data": data.reference_data
}

current_data_json = {
"title": "current",
"type": "bar",
"name":"Current",
"title": "current",
"type": "bar",
"name": "Current",
"itemStyle": {
"color": "#3695D9"
},
"data": data.current_data
}
"data": data.current_data
}

series = [reference_data_json] if not data.current_data else [reference_data_json, current_data_json]
series = [reference_data_json] if not data.current_data else [
reference_data_json, current_data_json]

option = {
"grid": {
Expand Down Expand Up @@ -97,7 +101,93 @@ def numerical_bar_chart(self, data: NumericalBarChartData) -> EChartsRawWidget:
},
"series": series
}

option.update(get_chart_header(title=data.title))

return EChartsRawWidget(option=option)

def confusion_matrix_chart(self, data: ConfusionMatrixChartData) -> EChartsRawWidget:
assert len(data.matrix) == len(data.axis_label) * len(data.axis_label) , "axis_label count and matrix item count are not compatibile"

np_matrix = np.matrix(data.matrix)

options = {
"yAxis": {
"type": "category",
"axisTick": {
"show": False
},
"axisLine": {
"show": False
},
"splitLine": {
"show": False
},
"axisLabel": {
"fontSize": 12,
"color": "#9B99A1"
},
"data":data.axis_label,
"name": "Actual",
"nameGap": 25,
"nameLocation": "middle"
},
"xAxis": {
"type": "category",
"axisTick": {
"show": False
},
"axisLine": {
"show": False
},
"splitLine": {
"show": False
},
"axisLabel": {
"fontSize": 12,
"interval": 0,
"color": "#9b99a1",
"rotate": 45
},
"data":data.axis_label.reverse(),
"name": "Predicted",
"nameGap": 25,
"nameLocation": "middle"
},
"grid": {
"bottom": 60,
"top": 0,
"left": 44,
"right": 80
},
"emphasis": {
"disabled": True
},
"axis": {
"axisLabel": {
"fontSize": 9,
"color": "#9b99a1"
}
},
"visualMap": {
"calculable": True,
"orient": "vertical",
"right": "right",
"top": "center",
"itemHeight": "250rem",
"max": np_matrix.max(),
"inRange": {
"color": data.color
}
},
"series": {
"name": "",
"type": "heatmap",
"label": {
"show": True
},
"data": data.matrix
}
}

return EChartsRawWidget(option=options)
5 changes: 5 additions & 0 deletions sdk/radicalbit_platform_sdk/charts/chart_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,8 @@ class NumericalBarChartData(BaseModel):
bucket_data: List[str]
reference_data: List[float]
current_data: Optional[List[float]] = None

class ConfusionMatrixChartData(BaseModel):
axis_label: List[str]
matrix: List[List[float]]
color: Optional[List[str]] = ["#FFFFFF","#9B99A1"]
86 changes: 85 additions & 1 deletion sdk/tests/chart/test_chart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,93 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"id": "3a2a5412-7085-48de-87b1-9b7995d918bc",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f332f2a7412f4eeea5570faf4f72150d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"EChartsRawWidget(option={'yAxis': {'type': 'category', 'axisTick': {'show': False}, 'axisLine': {'show': False…"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from radicalbit_platform_sdk.charts import ConfusionMatrixChartData, Chart\n",
"\n",
"Chart().confusion_matrix_chart(data=ConfusionMatrixChartData(axis_label=[\"0\",\"1\"],matrix=[[0,0,281],[1,0,2121],[0,1,448],[1,1,150]]))"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "5865d9fc-8c08-484e-9b2a-5755eced42f7",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "01703f01ecbf403f8e366c4c401637e9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"EChartsRawWidget(option={'yAxis': {'type': 'category', 'axisTick': {'show': False}, 'axisLine': {'show': False…"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from radicalbit_platform_sdk.charts import ConfusionMatrixChartData, Chart\n",
"\n",
"Chart().confusion_matrix_chart(data=ConfusionMatrixChartData(axis_label=[\"0\",\"1\"],matrix=[[0,0,281],[1,0,2121],[0,1,448],[1,1,150]],color=['#FFFFFF','#3695d9']))"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "09407640-2a3d-45db-841d-1d21e63ef6a0",
"metadata": {},
"outputs": [
{
"ename": "AssertionError",
"evalue": "axis_label count and matrix item count are not compatibile",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[1], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mradicalbit_platform_sdk\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcharts\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m ConfusionMatrixChartData, Chart\n\u001b[0;32m----> 3\u001b[0m \u001b[43mChart\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconfusion_matrix_chart\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mConfusionMatrixChartData\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 4\u001b[0m \u001b[43m \u001b[49m\u001b[43maxis_label\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m0\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m1\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43mmatrix\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m4\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m131\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m115\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m4\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m102\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 6\u001b[0m \u001b[43m \u001b[49m\u001b[43mcolor\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43m#FFFFFF\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43m#3695d9\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/radicalbit/radicalbit-ai-monitoring/sdk/radicalbit_platform_sdk/charts/chart.py:110\u001b[0m, in \u001b[0;36mChart.confusion_matrix_chart\u001b[0;34m(self, data)\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mconfusion_matrix_chart\u001b[39m(\u001b[38;5;28mself\u001b[39m, data: ConfusionMatrixChartData) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m EChartsRawWidget:\n\u001b[0;32m--> 110\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(data\u001b[38;5;241m.\u001b[39mmatrix) \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mlen\u001b[39m(data\u001b[38;5;241m.\u001b[39maxis_label) \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mlen\u001b[39m(data\u001b[38;5;241m.\u001b[39maxis_label) , \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124maxis_label count and matrix item count are not compatibile\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 112\u001b[0m np_matrix \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmatrix(data\u001b[38;5;241m.\u001b[39mmatrix)\n\u001b[1;32m 114\u001b[0m options \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 115\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124myAxis\u001b[39m\u001b[38;5;124m\"\u001b[39m: {\n\u001b[1;32m 116\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtype\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcategory\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 190\u001b[0m }\n\u001b[1;32m 191\u001b[0m }\n",
"\u001b[0;31mAssertionError\u001b[0m: axis_label count and matrix item count are not compatibile"
]
}
],
"source": [
"from radicalbit_platform_sdk.charts import ConfusionMatrixChartData, Chart\n",
"\n",
"Chart().confusion_matrix_chart(data=ConfusionMatrixChartData(\n",
" axis_label=[\"0\",\"1\"],\n",
" matrix=[[0,0,2],[1,0,4],[2,0,131],[0,1,3],[1,1,115],[2,1,4],[0,2,102],[1,2,3],[2,2,0]],\n",
" color=['#FFFFFF','#3695d9']))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c2ae305c-17b2-4eab-9bcb-370526358d08",
"metadata": {},
"outputs": [],
"source": []
}
Expand Down

0 comments on commit 9a817bd

Please sign in to comment.