Skip to content

Commit

Permalink
Merge pull request #114 from DocShow-AI/refactor_sqlexecutor
Browse files Browse the repository at this point in the history
use tablemanager directly instead of sqlexecutor
  • Loading branch information
liberty-rising authored Dec 11, 2023
2 parents b3df18e + cc8e3ca commit 4f2f44e
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 27 deletions.
19 changes: 17 additions & 2 deletions backend/databases/table_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,15 @@ def drop_table(self, table_name: str):
print(f"An error occurred: {e}")
raise HTTPException(status_code=400, detail=str(e))

def execute_select_query(self, query: str, format_as_dict: bool = True):
try:
executor = SQLExecutor(self.session)
result = executor.execute_select_query(query, format_as_dict)
return result
except Exception as e:
print(f"An error occurred: {e}")
raise HTTPException(status_code=400, detail=str(e))

def get_org_tables(self, org_id: int) -> List:
"""Returns a list of names of all of the tables associated with an organization."""
try:
Expand Down Expand Up @@ -208,8 +217,14 @@ def get_table_metadata(self):
pass

def list_all_tables(self):
# Logic to list all tables
pass
"""Returns a list of all of the names of tables present within the database."""
try:
executor = SQLExecutor(self.session)
table_names = executor.get_all_table_names_as_list()
return table_names
except Exception as e:
print(f"An error occurred: {e}")
raise HTTPException(status_code=400, detail=str(e))

def validate_table_exists(self, table_name: str):
# Logic to validate if a table exists
Expand Down
26 changes: 13 additions & 13 deletions backend/envs/dev/initialization/setup_dev_environment.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from databases.data_profile_manager import DataProfileManager
from databases.database_manager import DatabaseManager
from databases.dashboard_manager import DashboardManager
from databases.sql_executor import SQLExecutor
from databases.table_manager import TableManager
from databases.organization_manager import OrganizationManager
from databases.user_manager import UserManager
from models.organization import Organization
Expand Down Expand Up @@ -62,8 +62,8 @@ def create_sample_dashboard():
dashboard.name, dashboard.organization
)
if not existing_dashboard:
sql_executor = SQLExecutor(session)
charts = create_sample_charts(sql_executor)
table_manager = TableManager(session)
charts = create_sample_charts(table_manager)

dashboard.charts.extend(charts) # Directly append the list of charts
manager.save_dashboard(dashboard)
Expand All @@ -72,22 +72,22 @@ def create_sample_dashboard():
logger.debug("Sample dashboard already exists.")


def create_sample_charts(sql_executor: SQLExecutor):
def create_sample_charts(table_manager: TableManager):
# Add more chart creation logic here as needed
return [
create_sample_bar_chart(sql_executor),
create_sample_pie_chart(sql_executor),
create_sample_line_chart(sql_executor),
create_sample_bar_chart(table_manager),
create_sample_pie_chart(table_manager),
create_sample_line_chart(table_manager),
]


def create_sample_bar_chart(sql_executor: SQLExecutor):
def create_sample_bar_chart(table_manager: TableManager):
query = """
SELECT productline, SUM(sales) FROM sample_sales
GROUP BY productline
"""

sales_per_product = sql_executor.execute_select_query(query)
sales_per_product = table_manager.execute_select_query(query)

# Transforming the data for the bar chart
chart_data = [
Expand All @@ -110,15 +110,15 @@ def create_sample_bar_chart(sql_executor: SQLExecutor):
return bar_chart


def create_sample_pie_chart(sql_executor: SQLExecutor):
def create_sample_pie_chart(table_manager: TableManager):
query = """
SELECT country, ROUND(SUM(sales)::numeric,2) AS total_sales FROM sample_sales
GROUP BY country
ORDER BY total_sales DESC
LIMIT 10
"""

sales_by_country = sql_executor.execute_select_query(query)
sales_by_country = table_manager.execute_select_query(query)

# Transforming the data for the pie chart
chart_data = [
Expand All @@ -138,14 +138,14 @@ def create_sample_pie_chart(sql_executor: SQLExecutor):
return Chart(order=2, config=pie_chart_config)


def create_sample_line_chart(sql_executor: SQLExecutor):
def create_sample_line_chart(table_manager: TableManager):
query = """
SELECT year_id, SUM(sales) AS yearly_sales FROM sample_sales
GROUP BY year_id
ORDER BY year_id
"""

yearly_sales = sql_executor.execute_select_query(query)
yearly_sales = table_manager.execute_select_query(query)

# Transforming the data for the line chart
chart_data = [
Expand Down
9 changes: 3 additions & 6 deletions backend/envs/dev/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pandas as pd

from databases.database_manager import DatabaseManager
from databases.sql_executor import SQLExecutor
from databases.table_manager import TableManager
from databases.table_metadata_manager import TableMetadataManager
from utils.utils import get_app_logger
Expand All @@ -19,22 +18,20 @@ def seed_db():
- sample_tables: Dictionary mapping table names to corresponding CSV file names.
- existing_tables: List of table names that already exist in the database.
- session: Database session managed by DatabaseManager.
- executor: Instance of SQLExecutor for executing SQL queries.
- manager: Instance of TableManager for table operations.
Workflow:
1. Initialize `sample_tables` dictionary to hold table-to-file mappings.
2. Use DatabaseManager to create a session and SQLExecutor to get existing table names.
2. Use DatabaseManager to create a session and TableManager to get existing table names.
3. Loop through `sample_tables` and create tables if they don't exist, using data from CSV files.
"""

# Get existing tables
with DatabaseManager() as session:
executor = SQLExecutor(session)
existing_tables = executor.get_all_table_names_as_list()
table_manager = TableManager(session)
existing_tables = table_manager.list_all_tables()

# Create sample table if it doesn't exist
table_manager = TableManager(session)
if "sample_sales" not in existing_tables:
df = pd.read_csv("envs/dev/sample_data/sample_sales_data.csv")
df.columns = map(str.lower, df.columns)
Expand Down
5 changes: 2 additions & 3 deletions backend/routes/table_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from security import get_current_user

from databases.database_manager import DatabaseManager
from databases.sql_executor import SQLExecutor
from databases.table_manager import TableManager
from databases.table_metadata_manager import TableMetadataManager
from models.user import User
Expand Down Expand Up @@ -41,8 +40,8 @@ async def get_table_metadata(
@table_router.get("/tables/")
async def get_tables(current_user: User = Depends(get_current_user)):
with DatabaseManager() as session:
executor = SQLExecutor(session)
tables = executor.get_all_table_names_as_list()
table_manager = TableManager(session)
tables = table_manager.list_all_tables()
return tables


Expand Down
6 changes: 3 additions & 3 deletions backend/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
import sys

from databases.database_manager import DatabaseManager
from databases.sql_executor import SQLExecutor
from databases.table_manager import TableManager


def execute_select_query(query: str):
with DatabaseManager() as session:
sql_executor = SQLExecutor(session)
results = sql_executor.execute_select_query(query)
table_manager = TableManager(session)
results = table_manager.execute_select_query(query)
return results


Expand Down

0 comments on commit 4f2f44e

Please sign in to comment.