# Reconstructed from patch b58576e8d002 ("Added logic to extract datasource
# from the query AST") against mindsdb/api/http/namespaces/sql.py.
#
# Besides the two methods below, the patch also:
#   * extends the import to
#       from mindsdb_sql.parser.ast import Constant, Identifier, Select, Join
#   * stores a third tuple element, type(n.value).__name__, in
#       identifier_to_constant[identifier_str]
#   * makes post() add a "datasource" key to its response; the value is
#       self.find_datasource(query_ast) when replace_constants is falsy and
#       "" otherwise.
#
# Both functions are methods of the resource class that owns post(); they are
# shown at column 0 because the class header is not part of the patch.
# They rely on `ast`, `Identifier` and `Select` from mindsdb_sql being in scope.


def get_children(self, node):
    """Return the child nodes of *node* that ``find_datasource`` descends into.

    Resolution order:

    * any node exposing a ``children`` attribute -> that attribute, as is;
    * ``Select`` -> its ``from_table`` (when truthy); CTEs are not followed;
    * ``ast.Join`` -> its ``left`` and ``right`` sides (those that are truthy);
    * anything else -> an empty list.
    """
    if hasattr(node, 'children'):
        return node.children

    if isinstance(node, Select):
        # TODO: Handle CTEs (node.cte is currently not descended into).
        return [node.from_table] if node.from_table else []

    if isinstance(node, ast.Join):
        return [side for side in (node.left, node.right) if side]

    return []


def find_datasource(self, node):
    """Return the datasource prefix of the most deeply nested table identifier.

    The tree rooted at *node* is walked depth-first through
    ``self.get_children``. Among all ``Identifier`` nodes reached, the one at
    the greatest depth wins (the first one visited on ties, because the
    comparison is strict).

    Returns:
        ``parts[0]`` of the winning identifier when it has more than one
        part, otherwise ``""``. ``""`` is also returned when no identifier
        is reachable at all (e.g. a ``Select`` without ``from_table``).
    """
    deepest_identifier = None
    deepest_level = -1

    def traverse(current, depth=0):
        nonlocal deepest_identifier, deepest_level
        if isinstance(current, Identifier) and depth > deepest_level:
            deepest_identifier = current
            deepest_level = depth
        for child in self.get_children(current):
            traverse(child, depth + 1)

    traverse(node)

    # Guard against queries with no reachable identifier: previously this
    # dereferenced None and raised AttributeError.
    if deepest_identifier is None or len(deepest_identifier.parts) <= 1:
        return ""

    return deepest_identifier.parts[0]