diff --git a/mindsdb/api/http/namespaces/sql.py b/mindsdb/api/http/namespaces/sql.py index 628c5fb2bd2..79b4ea5dee5 100644 --- a/mindsdb/api/http/namespaces/sql.py +++ b/mindsdb/api/http/namespaces/sql.py @@ -16,6 +16,11 @@ from mindsdb.utilities.config import Config from mindsdb.utilities.context import context as ctx +from mindsdb_sql import parse_sql +from mindsdb_sql.parser import ast +from mindsdb_sql.parser.ast import Constant, Identifier +from mindsdb_sql.planner.utils import query_traversal + logger = log.getLogger(__name__) @@ -25,6 +30,8 @@ class Query(Resource): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + + @ns_conf.doc("query") @api_endpoint_metrics('POST', '/sql/query') def post(self): @@ -108,6 +115,73 @@ def post(self): return query_response, 200 +@ns_conf.route("/query/constants") +@ns_conf.param("query", "Get Constants for the query") +class QueryConstants(Resource): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def find_constants_with_identifiers(self, node, replace_constants=False, identifiers_to_replace={}): + identifier_to_constant = {} + identifier_count = {} + last_identifier = None + + + def callback(n, **kwargs): + nonlocal last_identifier + + if isinstance(n, Identifier): + last_identifier = n + elif isinstance(n, Constant): + if last_identifier: + identifier_str = last_identifier.get_string() + if identifier_str in identifier_count: + identifier_count[identifier_str] += 1 + identifier_str += str(identifier_count[identifier_str]) + else: + identifier_count[identifier_str] = 0 + + identifier_to_constant[identifier_str] = (last_identifier.get_string(), n.value) + + if replace_constants and identifier_str in identifiers_to_replace: + n.value = '@' + identifiers_to_replace[identifier_str] + last_identifier = None # Reset after associating with a Constant + return None + + query_traversal(node, callback) + return identifier_to_constant + + + + @ns_conf.doc("query_constants") + @api_endpoint_metrics('POST', '/sql/query/constants') + def post(self): + query = request.json["query"] + replace_constants = request.json.get("replace_constants", False) + identifiers_to_replace = request.json.get("identifiers_to_replace", {}) + context = request.json.get("context", {}) + + try: + query_ast = parse_sql(query) + parameterized_query = query + constants_with_identifiers = self.find_constants_with_identifiers(query_ast,replace_constants=replace_constants, identifiers_to_replace=identifiers_to_replace) + if replace_constants: + parameterized_query = query_ast.to_string() + response = { + "constant_with_identifiers": constants_with_identifiers, + "parameterized_query": parameterized_query + } + query_response = {"type": SQL_RESPONSE_TYPE.OK, "data": response} + except Exception as e: + query_response = { + "type": SQL_RESPONSE_TYPE.ERROR, + "error_code": 0, + "error_message": str(e), + } + + return query_response, 200 + + @ns_conf.route("/list_databases") @ns_conf.param("list_databases", "lists databases of mindsdb")