From 02b5bdeac2fb7d2f3d685e9de7b60e0b5f1b7460 Mon Sep 17 00:00:00 2001 From: liberty-rising Date: Mon, 11 Dec 2023 00:11:11 +0100 Subject: [PATCH] show only tables for the user that are part of their organization --- backend/databases/sql_executor.py | 3 -- backend/databases/table_manager.py | 33 +++++++++++++------ backend/envs/dev/utils.py | 11 +++++-- backend/models/organization_table_map.py | 2 +- .../src/pages/Charts/Configs/ChartConfig.jsx | 21 +++++++++--- 5 files changed, 49 insertions(+), 21 deletions(-) diff --git a/backend/databases/sql_executor.py b/backend/databases/sql_executor.py index 8614810..5dc7d5e 100644 --- a/backend/databases/sql_executor.py +++ b/backend/databases/sql_executor.py @@ -82,9 +82,6 @@ def get_all_table_names_as_list(self) -> List: print(f"An error occurred: {e}") raise - def get_org_tables(self): - pass - def get_table_columns(self, table_name: str) -> List: try: engine = self.session.bind diff --git a/backend/databases/table_manager.py b/backend/databases/table_manager.py index a4d7ae2..bae4911 100644 --- a/backend/databases/table_manager.py +++ b/backend/databases/table_manager.py @@ -1,5 +1,5 @@ from sqlalchemy.orm import Session -from typing import Optional +from typing import List, Optional import pandas as pd @@ -17,22 +17,27 @@ class TableManager: Has functions that integrate a Large Language Model (LLM) and SQLExecutor. Attributes: - llm (BaseLLM): Instance of a Large Language Model for SQL operations. session (Optional[str]): The session used for database operations. + llm (BaseLLM): Instance of a Large Language Model for SQL operations. """ - def __init__(self, llm: BaseLLM = None, session: Optional[Session] = None): - self.llm = llm + def __init__(self, session: Optional[Session] = None, llm: BaseLLM = None): self.session = session + self.llm = llm def _map_table_to_org( self, org_id: int, table_name: str, alias: Optional[str] = None ): """Maps a table to an organization.""" try: + print(f"Mapping table {table_name} to organization {org_id}") if self.session: + print(f"Session: {self.session}") + alias = table_name if not alias else alias self.session.add( - OrganizationTableMap(table_name=table_name, organization_id=org_id) + OrganizationTableMap( + organization_id=org_id, table_name=table_name, table_alias=alias + ) ) self.session.commit() except Exception as e: @@ -170,13 +175,21 @@ def drop_table(self, table_name: str): print(f"An error occurred: {e}") raise HTTPException(status_code=400, detail=str(e)) - def get_org_tables(self, org_id: int): - """Returns a list of all of the tables present within the organization.""" + def get_org_tables(self, org_id: int) -> List: + """Returns a list of names of all of the tables associated with an organization.""" try: - executor = SQLExecutor(self.session) - tables = executor.get_org_tables(org_id) - return tables + if self.session: + table_names = ( + self.session.query(OrganizationTableMap.table_name) + .filter(OrganizationTableMap.organization_id == org_id) + .all() + ) + return [ + name[0] for name in table_names + ] # Extracting table_name from each tuple + return [] except Exception as e: + self.session.rollback() if self.session else None print(f"An error occurred: {e}") raise HTTPException(status_code=400, detail=str(e)) diff --git a/backend/envs/dev/utils.py b/backend/envs/dev/utils.py index 661eace..4f55251 100644 --- a/backend/envs/dev/utils.py +++ b/backend/envs/dev/utils.py @@ -1,9 +1,12 @@ +import pandas as pd + from databases.database_manager import DatabaseManager from databases.sql_executor import SQLExecutor from databases.table_manager import TableManager from databases.table_metadata_manager import TableMetadataManager +from utils.utils import get_app_logger -import pandas as pd +logger = get_app_logger(__name__) def seed_db(): @@ -35,7 +38,9 @@ def seed_db(): if "sample_sales" not in existing_tables: df = pd.read_csv("envs/dev/sample_data/sample_sales_data.csv") df.columns = map(str.lower, df.columns) - table_manager.create_table_from_df(df, 1, "sample_sales") + table_manager.create_table_from_df( + df=df, org_id=1, table_name="sample_sales" + ) # Add metadata metadata_manager = TableMetadataManager(session) @@ -81,3 +86,5 @@ def seed_db(): and time-based sales performance. """, ) + logger.info('Created table "sample_sales".') + logger.info('Sample table "sample_sales" already exists.') diff --git a/backend/models/organization_table_map.py b/backend/models/organization_table_map.py index e360198..0512500 100644 --- a/backend/models/organization_table_map.py +++ b/backend/models/organization_table_map.py @@ -18,5 +18,5 @@ class OrganizationTableMap(Base): id = Column(Integer, primary_key=True, autoincrement=True) organization_id = Column(Integer) - table_name = Column(String) + table_name = Column(String, unique=True) table_alias = Column(String) diff --git a/frontend/src/pages/Charts/Configs/ChartConfig.jsx b/frontend/src/pages/Charts/Configs/ChartConfig.jsx index dfd1985..3a88600 100644 --- a/frontend/src/pages/Charts/Configs/ChartConfig.jsx +++ b/frontend/src/pages/Charts/Configs/ChartConfig.jsx @@ -11,18 +11,29 @@ function ChartConfig({ onConfigChange, onRequiredSelected, chartConfig }) { const [selectedTable, setSelectedTable] = useState(''); const [chartTypes, setChartTypes] = useState([]); const [selectedChartType, setSelectedChartType] = useState(''); + const [organizationId, setOrganizationId] = useState(null); useEffect(() => { - // Fetch tables from API - axios.get(`${API_URL}tables/`) - .then(response => setTables(response.data)) - .catch(error => console.error('Error fetching tables:', error)) + // Fetch user data from API + axios.get(`${API_URL}users/me/`) + .then(response => { + // Set organizationId state + setOrganizationId(response.data.organization_id); + + // Fetch tables from API using organizationId + axios.get(`${API_URL}organization/${response.data.organization_id}/tables/`) + .then(response => setTables(response.data)) + .catch(error => console.error('Error fetching tables:', error)); + }) + .catch(error => console.error('Error fetching user data:', error)); + }, []); + useEffect(() => { // Fetch chart types from API axios.get(`${API_URL}charts/types/`) .then(response => setChartTypes(response.data)) .catch(error => console.error('Error fetching chart types:', error)); - }, []); + }, []); // Empty dependency array means this effect runs once on mount useEffect(() => { onConfigChange({