Skip to content

Commit

Permalink
Added logic to extract datasource from the query AST
Browse files Browse the repository at this point in the history
  • Loading branch information
Mohammed committed Oct 29, 2024
1 parent d36bb3f commit b58576e
Showing 1 changed file with 49 additions and 4 deletions.
53 changes: 49 additions & 4 deletions mindsdb/api/http/namespaces/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from mindsdb_sql import parse_sql
from mindsdb_sql.parser import ast
from mindsdb_sql.parser.ast import Constant, Identifier
from mindsdb_sql.parser.ast import Constant, Identifier, Select, Join
from mindsdb_sql.planner.utils import query_traversal

logger = log.getLogger(__name__)
Expand Down Expand Up @@ -141,7 +141,7 @@ def callback(n, **kwargs):
else:
identifier_count[identifier_str] = 0

identifier_to_constant[identifier_str] = (last_identifier.get_string(), n.value)
identifier_to_constant[identifier_str] = (last_identifier.get_string(), n.value, type(n.value).__name__)

if replace_constants and identifier_str in identifiers_to_replace:
n.value = '@' + identifiers_to_replace[identifier_str]
Expand All @@ -151,7 +151,48 @@ def callback(n, **kwargs):
query_traversal(node, callback)
return identifier_to_constant


def get_children(self, node):
if hasattr(node, 'children'):
return node.children
elif isinstance(node, Select):
children = []

if node.from_table:
children.append(node.from_table)
if node.cte:
# TODO: Handle CTEs
pass
return children
elif isinstance(node, ast.Join):
children = []
if node.left:
children.append(node.left)
if node.right:
children.append(node.right)
return children
else:
return []

def find_datasource(self, node):
datasource_node = None
max_depth = -1

def traverse(node, depth=0):
nonlocal datasource_node, max_depth
if isinstance(node, Identifier):
if depth > max_depth:
datasource_node = node
max_depth = depth
for child in self.get_children(node):
traverse(child, depth + 1)

traverse(node)
if len(datasource_node.parts) <= 1:
return ""

return datasource_node.parts[0]



@ns_conf.doc("query_constants")
@api_endpoint_metrics('POST', '/sql/query/constants')
Expand All @@ -164,12 +205,16 @@ def post(self):
try:
query_ast = parse_sql(query)
parameterized_query = query
datasource = ""
constants_with_identifiers = self.find_constants_with_identifiers(query_ast,replace_constants=replace_constants, identifiers_to_replace=identifiers_to_replace)
if replace_constants:
parameterized_query = query_ast.to_string()
else:
datasource = self.find_datasource(query_ast)
response = {
"constant_with_identifiers": constants_with_identifiers,
"parameterized_query": parameterized_query
"parameterized_query": parameterized_query,
"datasource": datasource
}
query_response = {"type": SQL_RESPONSE_TYPE.OK, "data": response}
except Exception as e:
Expand Down

0 comments on commit b58576e

Please sign in to comment.