Skip to content

Commit

Permalink
Merge pull request #127 from DocShow-AI/restrict_tables
Browse files Browse the repository at this point in the history
restrict access of analytics ai to tables of the organization
  • Loading branch information
liberty-rising authored Dec 21, 2023
2 parents 4bc029f + 066db30 commit 4ae0c74
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 37 deletions.
33 changes: 7 additions & 26 deletions backend/databases/table_manager.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from sqlalchemy.orm import Session
from typing import List, Optional
from typing import Optional

import pandas as pd

from fastapi import HTTPException
from databases.sql_executor import SQLExecutor
from databases.table_map_manager import TableMapManager
from databases.table_metadata_manager import TableMetadataManager
from llms.base import BaseLLM
from models.organization_table_map import OrganizationTableMap
from models.table_map import TableMap
from utils.sql_string_manipulator import SQLStringManipulator


Expand All @@ -34,14 +35,12 @@ def _map_table_to_org(
if self.session:
print(f"Session: {self.session}")
alias = table_name if not alias else alias
self.session.add(
OrganizationTableMap(
organization_id=org_id, table_name=table_name, table_alias=alias
)
table_map_manager = TableMapManager(self.session)
table_map = TableMap(
organization_id=org_id, table_name=table_name, table_alias=alias
)
self.session.commit()
table_map_manager.create_table_map(table_map)
except Exception as e:
self.session.rollback() if self.session else None
print(f"An error occurred: {e}")
raise HTTPException(status_code=400, detail=str(e))

Expand Down Expand Up @@ -184,24 +183,6 @@ def execute_select_query(self, query: str, format_as_dict: bool = True):
print(f"An error occurred: {e}")
raise HTTPException(status_code=400, detail=str(e))

def get_org_tables(self, org_id: int) -> List:
"""Returns a list of names of all of the tables associated with an organization."""
try:
if self.session:
table_names = (
self.session.query(OrganizationTableMap.table_name)
.filter(OrganizationTableMap.organization_id == org_id)
.all()
)
return [
name[0] for name in table_names
] # Extracting table_name from each tuple
return []
except Exception as e:
self.session.rollback() if self.session else None
print(f"An error occurred: {e}")
raise HTTPException(status_code=400, detail=str(e))

def get_table_columns(self, table_name: str):
"""Returns a list of all of the columns present within the table."""
try:
Expand Down
61 changes: 61 additions & 0 deletions backend/databases/table_map_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from fastapi import HTTPException
from sqlalchemy.orm import Session
from typing import List

from models.table_map import TableMap


class TableMapManager:
"""
A class to manage CRUD operations related to the TableMap model.
Attributes:
db_session (Session): An active database session for performing operations.
"""

def __init__(self, session: Session):
"""
Initializes the TableMap with the given database session.
Args:
db_session (Session): The database session to be used for operations.
"""
self.db_session = session

def create_table_map(self, table_map: TableMap):
"""
Add a new table map to the database.
Args:
table_map (TableMap): The TableMap object to be added.
Returns:
TableMap: The created TableMap object.
"""
try:
if self.db_session:
self.db_session.add(table_map)
self.db_session.commit()
return table_map
except Exception as e:
self.db_session.rollback() if self.db_session else None
print(f"An error occurred: {e}")
raise HTTPException(status_code=400, detail=str(e))

def get_org_tables(self, org_id: int) -> List:
"""Returns a list of names of all of the tables associated with an organization."""
try:
if self.db_session:
table_names = (
self.db_session.query(TableMap.table_name)
.filter(TableMap.organization_id == org_id)
.all()
)
return [
name[0] for name in table_names
] # Extracting table_name from each tuple
return []
except Exception as e:
self.db_session.rollback() if self.db_session else None
print(f"An error occurred: {e}")
raise HTTPException(status_code=400, detail=str(e))
11 changes: 7 additions & 4 deletions backend/llms/gpt_lang.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from typing import List

from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.llms.openai import OpenAI
from langchain.sql_database import SQLDatabase

from settings import DB_URL


class GPTLang:
def __init__(self):
self.db = SQLDatabase.from_uri(DB_URL)
class GPTLangSQL:
def __init__(self, tables: List[str]):
if not tables:
raise ValueError("No tables provided")
self.db = SQLDatabase.from_uri(DB_URL, include_tables=tables)
self.toolkit = SQLDatabaseToolkit(db=self.db, llm=OpenAI(temperature=0))
self.agent_executor = create_sql_agent(
llm=OpenAI(temperature=0),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .base import Base


class OrganizationTableMap(Base):
class TableMap(Base):
"""
Represents a mapping between an organization and a table.
Expand All @@ -14,7 +14,7 @@ class OrganizationTableMap(Base):
table_alias (str): The alias for the table.
"""

__tablename__ = "organization_table_map"
__tablename__ = "table_map"

id = Column(Integer, primary_key=True, autoincrement=True)
organization_id = Column(Integer)
Expand Down
9 changes: 7 additions & 2 deletions backend/routes/chat_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

from databases.chat_history_manager import ChatHistoryManager
from databases.database_manager import DatabaseManager
from databases.table_map_manager import TableMapManager
from llms.base import BaseLLM
from llms.gpt_lang import GPTLang
from llms.gpt_lang import GPTLangSQL
from llms.utils import ChatRequest, ChatResponse, get_llm_chat_object
from models.chat import AnalyticsRequest, AnalyticsResponse
from models.user import User
Expand All @@ -28,7 +29,11 @@ async def chat_analytics_endpoint(
request: AnalyticsRequest,
current_user: User = Depends(get_current_user),
):
gpt = GPTLang()
with DatabaseManager() as session:
table_map_manager = TableMapManager(session)
org_tables = table_map_manager.get_org_tables(current_user.organization_id)

gpt = GPTLangSQL(tables=org_tables)
response = gpt.generate(request.prompt)
return AnalyticsResponse(chat_id=1, response=response)

Expand Down
3 changes: 0 additions & 3 deletions frontend/src/pages/Analytics/AnalyticsPage.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@ function AnalyticsPage() {
<Box>
<Typography variant="h4" gutterBottom>📊 Data Analytics</Typography>
<Grid container spacing={2}>
<Grid item xs={12}>
<TableSelectDropdown tables={tables} selectedTable={selectedTable} onTableSelect={handleTableSelect} />
</Grid>
<Grid item xs={12}>
<AIAssistant table={selectedTable} />
</Grid>
Expand Down

0 comments on commit 4ae0c74

Please sign in to comment.