From df85c02f54ae1c69fc6e389ee8a453e239f4a2cd Mon Sep 17 00:00:00 2001 From: liberty-rising Date: Mon, 29 Jan 2024 17:23:44 +0100 Subject: [PATCH] suggest data types for table columns --- backend/llms/gpt.py | 14 +- backend/llms/prompt_manager.py | 15 ++ backend/llms/system_message_manager.py | 2 +- backend/models/data_profile.py | 4 + backend/routes/data_profile_routes.py | 26 ++- frontend/src/api/dataProfilesRequests.jsx | 26 +++ .../pages/upload/CreateDataProfileWindow.jsx | 45 +++-- .../upload/DataPreviewAndSchemaEditor.jsx | 171 +++++++++--------- 8 files changed, 189 insertions(+), 114 deletions(-) create mode 100644 frontend/src/api/dataProfilesRequests.jsx diff --git a/backend/llms/gpt.py b/backend/llms/gpt.py index 20d037e..27f7445 100644 --- a/backend/llms/gpt.py +++ b/backend/llms/gpt.py @@ -368,16 +368,20 @@ async def generate_chart_config( return parsed_config - async def generate_suggested_column_types(self, data: dict): + async def generate_suggested_column_types(self, column_names: list, data: dict): """Generate suggested column types for the given data.""" self._add_system_message(assistant_type="column_type_suggestion") self._set_response_format(is_json=True) - prompt = self.prompt_manager.create_column_type_suggestion_prompt(data) + prompt = self.prompt_manager.create_column_type_suggestion_prompt( + column_names, data + ) gpt_response = await self._send_and_receive_message(prompt) - return gpt_response + suggested_column_types = json.loads(gpt_response) + + return suggested_column_types def fetch_table_name_from_sample( self, sample_content: str, extra_desc: str, table_metadata: str @@ -429,5 +433,9 @@ async def extract_data_from_jpgs( "\n```", "" ) data = json.loads(json_string) + + # If data is a dictionary, wrap it in a list + if isinstance(data, dict): + data = [data] print(data) return data diff --git a/backend/llms/prompt_manager.py b/backend/llms/prompt_manager.py index 61a1336..ef202a2 100644 --- a/backend/llms/prompt_manager.py +++ b/backend/llms/prompt_manager.py @@ -122,3 +122,18 @@ def jpg_data_extraction_prompt(self, instructions: str): Return only the requested information, no additional text or formatting. """ return prompt + + def create_column_type_suggestion_prompt(self, column_names, data): + prompt = f""" + Based on the following data, suggest the data types for each column in the table. + The available column types are: text, integer, money, date, boolean + + Column names: + {column_names} + + Data: + {data} + + Return a JSON with the column names as keys and the suggested data types as values. + """ + return prompt diff --git a/backend/llms/system_message_manager.py b/backend/llms/system_message_manager.py index 5904570..4a1f83b 100644 --- a/backend/llms/system_message_manager.py +++ b/backend/llms/system_message_manager.py @@ -6,7 +6,7 @@ def __init__(self): You will be generating SQL queries, and providing useful information for reports and analytics based on the given prompt.""", "column_type_suggestion": """ You are a column type suggestion assistant. - You will be suggesting PostgreSQL column types based on the given prompt. + You will be suggesting column data types based on the given prompt. """, "sql_code": """ You are a PostgreSQL SQL statement assistant. diff --git a/backend/models/data_profile.py b/backend/models/data_profile.py index b90d507..9e7dacf 100644 --- a/backend/models/data_profile.py +++ b/backend/models/data_profile.py @@ -52,3 +52,7 @@ class DataProfileCreateRequest(BaseModel): class DataProfileCreateResponse(BaseModel): name: str extract_instructions: str + + +class SuggestedColumnTypesRequest(BaseModel): + data: list diff --git a/backend/routes/data_profile_routes.py b/backend/routes/data_profile_routes.py index ea4f1f5..7a0ef7f 100644 --- a/backend/routes/data_profile_routes.py +++ b/backend/routes/data_profile_routes.py @@ -12,6 +12,7 @@ DataProfile, DataProfileCreateRequest, DataProfileCreateResponse, + SuggestedColumnTypesRequest, ) from models.user import User from security import get_current_user @@ -78,6 +79,11 @@ async def get_data_profile( return data_profile +@data_profile_router.get("/data-profiles/column-types/") +async def get_column_types(current_user: User = Depends(get_current_user)): + return ["text", "integer", "money", "date", "boolean"] + + @data_profile_router.post("/data-profiles/preview/") async def preview_data_profile( files: List[UploadFile] = File(...), @@ -128,12 +134,20 @@ async def preview_data_profile( return extracted_data -# @data_profile_router.post("/data-profiles/preview/column-types/") -# async def generate_suggested_column_types( -# data, current_user: User = Depends(get_current_user) -# ): -# gpt = GPTLLM(chat_id=1, user=current_user) -# suggested_column_types = await gpt.generate_suggested_column_types(data) +@data_profile_router.post("/data-profiles/preview/column-types/") +async def generate_suggested_column_types( + request: SuggestedColumnTypesRequest, current_user: User = Depends(get_current_user) +): + gpt = GPTLLM(chat_id=1, user=current_user) + if request.data: + column_names = list(request.data[0].keys()) + suggested_column_types = await gpt.generate_suggested_column_types( + column_names, request.data + ) + + print(suggested_column_types) + + return suggested_column_types @data_profile_router.post("/data-profiles/{data_profile_name}/preview/") diff --git a/frontend/src/api/dataProfilesRequests.jsx b/frontend/src/api/dataProfilesRequests.jsx new file mode 100644 index 0000000..e659d15 --- /dev/null +++ b/frontend/src/api/dataProfilesRequests.jsx @@ -0,0 +1,26 @@ +import axios from "axios"; +import { API_URL } from "../utils/constants"; + +export const getPreviewData = (sampleFiles, extractInstructions) => { + const formData = new FormData(); + sampleFiles.forEach((file) => { + formData.append("files", file); + }); + formData.append("extract_instructions", extractInstructions); + + return axios.post(`${API_URL}data-profiles/preview/`, formData, { + headers: { + "Content-Type": "multipart/form-data", + }, + }); +}; + +export const getAvailableColumnTypes = () => { + return axios.get(`${API_URL}data-profiles/column-types/`); +}; + +export const getSuggestedColumnTypes = (previewData) => { + return axios.post(`${API_URL}data-profiles/preview/column-types/`, { + data: previewData, + }); +}; diff --git a/frontend/src/pages/upload/CreateDataProfileWindow.jsx b/frontend/src/pages/upload/CreateDataProfileWindow.jsx index 0f3a21e..1cc9185 100644 --- a/frontend/src/pages/upload/CreateDataProfileWindow.jsx +++ b/frontend/src/pages/upload/CreateDataProfileWindow.jsx @@ -9,16 +9,21 @@ import { Stack, TextField, } from "@mui/material"; -import axios from "axios"; import FileUploader from "./FileUploader"; import DataPreviewAndSchemaEditor from "./DataPreviewAndSchemaEditor"; -import { API_URL } from "../../utils/constants"; +import { + getPreviewData, + getAvailableColumnTypes, + getSuggestedColumnTypes, +} from "../../api/dataProfilesRequests"; function CreateDataProfileWindow({ open, onClose, onCreate }) { const [name, setName] = useState(""); const [extractInstructions, setExtractInstructions] = useState(""); const [sampleFiles, setSampleFiles] = useState([]); const [previewData, setPreviewData] = useState(null); + const [availableColumnTypes, setAvailableColumnTypes] = useState([]); + const [selectedColumnTypes, setSelectedColumnTypes] = useState(null); const [isPreviewLoading, setIsPreviewLoading] = useState(false); const [isPreviewTableOpen, setIsPreviewTableOpen] = useState(false); @@ -30,25 +35,33 @@ function CreateDataProfileWindow({ open, onClose, onCreate }) { const handlePreview = () => { if (sampleFiles.length && extractInstructions) { setIsPreviewLoading(true); + setPreviewData(null); + setSelectedColumnTypes(null); + const formData = new FormData(); sampleFiles.forEach((file) => { - formData.append("files", file); // Append each file + formData.append("files", file); }); formData.append("extract_instructions", extractInstructions); - axios - .post(`${API_URL}data-profiles/preview/`, formData, { - headers: { - "Content-Type": "multipart/form-data", - }, + Promise.all([ + getPreviewData(sampleFiles, extractInstructions), + getAvailableColumnTypes(), + ]) + .then(([previewDataResponse, availableTypesResponse]) => { + setPreviewData(previewDataResponse.data); + setAvailableColumnTypes(availableTypesResponse.data); + + return getSuggestedColumnTypes(previewDataResponse.data); }) - .then((response) => { - setPreviewData(response.data); // Store the preview data + .then((suggestedTypesResponse) => { + setSelectedColumnTypes(suggestedTypesResponse.data); setIsPreviewTableOpen(true); - setIsPreviewLoading(false); }) .catch((error) => { - console.error("Error on preview:", error); + console.error("Error during preview setup:", error); + }) + .finally(() => { setIsPreviewLoading(false); }); } @@ -95,8 +108,12 @@ function CreateDataProfileWindow({ open, onClose, onCreate }) { /> - {previewData && ( - + {previewData && selectedColumnTypes && ( + )} diff --git a/frontend/src/pages/upload/DataPreviewAndSchemaEditor.jsx b/frontend/src/pages/upload/DataPreviewAndSchemaEditor.jsx index 82b72c4..7b16231 100644 --- a/frontend/src/pages/upload/DataPreviewAndSchemaEditor.jsx +++ b/frontend/src/pages/upload/DataPreviewAndSchemaEditor.jsx @@ -1,4 +1,4 @@ -import React, { useState, useEffect, useRef } from "react"; +import React, { useState, useEffect } from "react"; import { Box, IconButton, @@ -13,117 +13,108 @@ import { TableHead, TableRow, TextField, + Tooltip, } from "@mui/material"; import EditIcon from "@mui/icons-material/Edit"; -function DataPreviewAndSchemaEditor({ previewData }) { - const data = Array.isArray(previewData) ? previewData : [previewData]; - const [columnNames, setColumnNames] = useState([]); - const [columnTypes, setColumnTypes] = useState([]); - const [editingColumnIndex, setEditingColumnIndex] = useState(null); - const inputRefs = useRef([]); +function DataPreviewAndSchemaEditor({ + previewData, + availableColumnTypes, + selectedColumnTypes, +}) { + const [columns, setColumns] = useState([]); useEffect(() => { - if (data && data.length > 0) { - const newColumnNames = Object.keys(data[0]); - let newColumnTypes; - if (columnTypes.length === 0) { - newColumnTypes = newColumnNames.map(() => "text"); - } else { - newColumnTypes = columnTypes; - } - if (JSON.stringify(newColumnNames) !== JSON.stringify(columnNames)) { - setColumnNames(newColumnNames); - } - if (JSON.stringify(newColumnTypes) !== JSON.stringify(columnTypes)) { - setColumnTypes(newColumnTypes); - } + if (Array.isArray(previewData) && previewData.length > 0) { + const initialColumns = Object.keys(previewData[0]).map((key) => ({ + name: key, + type: selectedColumnTypes[key] || "text", + isEditing: false, + })); + setColumns(initialColumns); } - }, [data]); - - const generateHeaderRow = (data) => { - if (data && data.length > 0) { - return columnNames.map((key, index) => ( - - - handleColumnNameChange(index, event.target.value) - } - variant="standard" - InputProps={{ - disableUnderline: true, - readOnly: editingColumnIndex !== index, - endAdornment: ( - - handleEditClick(index)}> - - - - ), - }} - inputRef={(ref) => (inputRefs.current[index] = ref)} - onClick={() => handleEditClick(index)} - style={{ cursor: "pointer" }} - inputProps={{ - style: { - cursor: editingColumnIndex === index ? "text" : "pointer", - }, - }} - onKeyDown={(event) => { - if (event.key === "Enter") { - event.preventDefault(); - setEditingColumnIndex(null); - } - }} - /> - - - )); - } - }; + }, [previewData, selectedColumnTypes]); const handleColumnTypeChange = (index, newType) => { - let newColumnTypes = [...columnTypes]; - newColumnTypes[index] = newType; - setColumnTypes(newColumnTypes); + setColumns((prevColumns) => + prevColumns.map((column, colIndex) => + colIndex === index ? { ...column, type: newType } : column, + ), + ); }; const handleEditClick = (index) => { - setEditingColumnIndex(index); - inputRefs.current[index].select(); + setColumns((prevColumns) => + prevColumns.map((column, colIndex) => + colIndex === index + ? { ...column, isEditing: !column.isEditing } + : column, + ), + ); }; const handleColumnNameChange = (index, newName) => { - setColumnNames((prevColumnNames) => { - const newColumnNames = [...prevColumnNames]; - newColumnNames[index] = newName; - return newColumnNames; - }); + setColumns((prevColumns) => + prevColumns.map((column, colIndex) => + colIndex === index ? { ...column, name: newName } : column, + ), + ); }; return ( - {generateHeaderRow(data)} + + {columns.map((column, index) => ( + + + handleColumnNameChange(index, event.target.value) + } + variant="standard" + InputProps={{ + disableUnderline: true, + readOnly: !column.isEditing, + endAdornment: ( + + handleEditClick(index)}> + + + + ), + }} + style={{ cursor: "pointer" }} + /> + + + + + ))} + - {data.map((row, index) => ( - - {Object.values(row).map((value, idx) => ( - {value} + {previewData.map((row, rowIndex) => ( + + {Object.values(row).map((value, cellIndex) => ( + {value} ))} ))}