diff --git a/backend/databases/table_manager.py b/backend/databases/table_manager.py index 6d24f84..1d40a93 100644 --- a/backend/databases/table_manager.py +++ b/backend/databases/table_manager.py @@ -1,13 +1,14 @@ from sqlalchemy.orm import Session -from typing import List, Optional +from typing import Optional import pandas as pd from fastapi import HTTPException from databases.sql_executor import SQLExecutor +from databases.table_map_manager import TableMapManager from databases.table_metadata_manager import TableMetadataManager from llms.base import BaseLLM -from models.organization_table_map import OrganizationTableMap +from models.table_map import TableMap from utils.sql_string_manipulator import SQLStringManipulator @@ -34,14 +35,12 @@ def _map_table_to_org( if self.session: print(f"Session: {self.session}") alias = table_name if not alias else alias - self.session.add( - OrganizationTableMap( - organization_id=org_id, table_name=table_name, table_alias=alias - ) + table_map_manager = TableMapManager(self.session) + table_map = TableMap( + organization_id=org_id, table_name=table_name, table_alias=alias ) - self.session.commit() + table_map_manager.create_table_map(table_map) except Exception as e: - self.session.rollback() if self.session else None print(f"An error occurred: {e}") raise HTTPException(status_code=400, detail=str(e)) @@ -184,24 +183,6 @@ def execute_select_query(self, query: str, format_as_dict: bool = True): print(f"An error occurred: {e}") raise HTTPException(status_code=400, detail=str(e)) - def get_org_tables(self, org_id: int) -> List: - """Returns a list of names of all of the tables associated with an organization.""" - try: - if self.session: - table_names = ( - self.session.query(OrganizationTableMap.table_name) - .filter(OrganizationTableMap.organization_id == org_id) - .all() - ) - return [ - name[0] for name in table_names - ] # Extracting table_name from each tuple - return [] - except Exception as e: - self.session.rollback() if self.session else None - print(f"An error occurred: {e}") - raise HTTPException(status_code=400, detail=str(e)) - def get_table_columns(self, table_name: str): """Returns a list of all of the columns present within the table.""" try: diff --git a/backend/databases/table_map_manager.py b/backend/databases/table_map_manager.py new file mode 100644 index 0000000..f81495e --- /dev/null +++ b/backend/databases/table_map_manager.py @@ -0,0 +1,61 @@ +from fastapi import HTTPException +from sqlalchemy.orm import Session +from typing import List + +from models.table_map import TableMap + + +class TableMapManager: + """ + A class to manage CRUD operations related to the TableMap model. + + Attributes: + db_session (Session): An active database session for performing operations. + """ + + def __init__(self, session: Session): + """ + Initializes the TableMap with the given database session. + + Args: + db_session (Session): The database session to be used for operations. + """ + self.db_session = session + + def create_table_map(self, table_map: TableMap): + """ + Add a new table map to the database. + + Args: + table_map (TableMap): The TableMap object to be added. + + Returns: + TableMap: The created TableMap object. + """ + try: + if self.db_session: + self.db_session.add(table_map) + self.db_session.commit() + return table_map + except Exception as e: + self.db_session.rollback() if self.db_session else None + print(f"An error occurred: {e}") + raise HTTPException(status_code=400, detail=str(e)) + + def get_org_tables(self, org_id: int) -> List: + """Returns a list of names of all of the tables associated with an organization.""" + try: + if self.db_session: + table_names = ( + self.db_session.query(TableMap.table_name) + .filter(TableMap.organization_id == org_id) + .all() + ) + return [ + name[0] for name in table_names + ] # Extracting table_name from each tuple + return [] + except Exception as e: + self.db_session.rollback() if self.db_session else None + print(f"An error occurred: {e}") + raise HTTPException(status_code=400, detail=str(e)) diff --git a/backend/llms/gpt_lang.py b/backend/llms/gpt_lang.py index ada1306..8a0fc18 100644 --- a/backend/llms/gpt_lang.py +++ b/backend/llms/gpt_lang.py @@ -1,15 +1,18 @@ +from typing import List + from langchain.agents import create_sql_agent from langchain.agents.agent_toolkits import SQLDatabaseToolkit from langchain.agents.agent_types import AgentType from langchain.llms.openai import OpenAI from langchain.sql_database import SQLDatabase - from settings import DB_URL -class GPTLang: - def __init__(self): - self.db = SQLDatabase.from_uri(DB_URL) +class GPTLangSQL: + def __init__(self, tables: List[str]): + if not tables: + raise ValueError("No tables provided") + self.db = SQLDatabase.from_uri(DB_URL, include_tables=tables) self.toolkit = SQLDatabaseToolkit(db=self.db, llm=OpenAI(temperature=0)) self.agent_executor = create_sql_agent( llm=OpenAI(temperature=0), diff --git a/backend/models/organization_table_map.py b/backend/models/table_map.py similarity index 87% rename from backend/models/organization_table_map.py rename to backend/models/table_map.py index 0512500..206f273 100644 --- a/backend/models/organization_table_map.py +++ b/backend/models/table_map.py @@ -3,7 +3,7 @@ from .base import Base -class OrganizationTableMap(Base): +class TableMap(Base): """ Represents a mapping between an organization and a table. @@ -14,7 +14,7 @@ class OrganizationTableMap(Base): table_alias (str): The alias for the table. """ - __tablename__ = "organization_table_map" + __tablename__ = "table_map" id = Column(Integer, primary_key=True, autoincrement=True) organization_id = Column(Integer) diff --git a/backend/routes/chat_routes.py b/backend/routes/chat_routes.py index 595bbc5..7ae3c1a 100644 --- a/backend/routes/chat_routes.py +++ b/backend/routes/chat_routes.py @@ -2,8 +2,9 @@ from databases.chat_history_manager import ChatHistoryManager from databases.database_manager import DatabaseManager +from databases.table_map_manager import TableMapManager from llms.base import BaseLLM -from llms.gpt_lang import GPTLang +from llms.gpt_lang import GPTLangSQL from llms.utils import ChatRequest, ChatResponse, get_llm_chat_object from models.chat import AnalyticsRequest, AnalyticsResponse from models.user import User @@ -28,7 +29,11 @@ async def chat_analytics_endpoint( request: AnalyticsRequest, current_user: User = Depends(get_current_user), ): - gpt = GPTLang() + with DatabaseManager() as session: + table_map_manager = TableMapManager(session) + org_tables = table_map_manager.get_org_tables(current_user.organization_id) + + gpt = GPTLangSQL(tables=org_tables) response = gpt.generate(request.prompt) return AnalyticsResponse(chat_id=1, response=response) diff --git a/frontend/src/pages/Analytics/AnalyticsPage.jsx b/frontend/src/pages/Analytics/AnalyticsPage.jsx index ba9943f..7dbe3a8 100644 --- a/frontend/src/pages/Analytics/AnalyticsPage.jsx +++ b/frontend/src/pages/Analytics/AnalyticsPage.jsx @@ -32,9 +32,6 @@ function AnalyticsPage() { 📊 Data Analytics - - -