From 56b39a444d07b5af770881cb63554df5447b983a Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Mon, 14 Oct 2024 14:48:24 +0530 Subject: [PATCH 01/28] feat: api for fetching site specific database schema added --- agent/site.py | 9 +++++++++ agent/web.py | 4 ++++ 2 files changed, 13 insertions(+) diff --git a/agent/site.py b/agent/site.py index 2cba9785..350732e7 100644 --- a/agent/site.py +++ b/agent/site.py @@ -795,6 +795,15 @@ def get_database_free_tables(self): except Exception: return [] + def get_database_schema_sql(self): + return self.execute( + "set -o pipefail && " + "mysqldump --single-transaction --quick --lock-tables=false --no-data " + f"-h {self.host} -u {self.user} -p{self.password} " + f"{self.database}", + executable="/bin/bash", + ) + @property def job_record(self): return self.bench.server.job_record diff --git a/agent/web.py b/agent/web.py index e633c8c0..beb38fef 100644 --- a/agent/web.py +++ b/agent/web.py @@ -548,6 +548,10 @@ def backup_site(bench, site): job = Server().benches[bench].sites[site].backup_job(with_files, offsite) return {"job": job} +@application.route("/benches//sites//database-schema", methods=["POST"]) +@validate_bench_and_site +def fetch_database_schema(bench, site): + return {"data": Server().benches[bench].sites[site].get_database_schema_sql()} @application.route( "/benches//sites//migrate", From 94331ef31e8f526cbdb7984840bf9671f0c08332 Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Mon, 14 Oct 2024 14:57:53 +0530 Subject: [PATCH 02/28] chore: update api route --- agent/web.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agent/web.py b/agent/web.py index beb38fef..f2dca7e9 100644 --- a/agent/web.py +++ b/agent/web.py @@ -548,7 +548,7 @@ def backup_site(bench, site): job = Server().benches[bench].sites[site].backup_job(with_files, offsite) return {"job": job} -@application.route("/benches//sites//database-schema", methods=["POST"]) +@application.route("/benches//sites//database/schema", methods=["GET"]) @validate_bench_and_site def fetch_database_schema(bench, site): return {"data": Server().benches[bench].sites[site].get_database_schema_sql()} From 73eb7de1fdcabed5b10ccd8dc727279b1811b945 Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Mon, 14 Oct 2024 15:10:05 +0530 Subject: [PATCH 03/28] chore: debug json unserialize error --- agent/web.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agent/web.py b/agent/web.py index f2dca7e9..cca4a6ec 100644 --- a/agent/web.py +++ b/agent/web.py @@ -551,7 +551,7 @@ def backup_site(bench, site): @application.route("/benches//sites//database/schema", methods=["GET"]) @validate_bench_and_site def fetch_database_schema(bench, site): - return {"data": Server().benches[bench].sites[site].get_database_schema_sql()} + return {"data": str(Server().benches[bench].sites[site].get_database_schema_sql())} @application.route( "/benches//sites//migrate", From bf8347f74357ebc1797104707ae734f42f916af8 Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Mon, 14 Oct 2024 15:13:01 +0530 Subject: [PATCH 04/28] fix: json serialize error --- agent/site.py | 2 +- agent/web.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/agent/site.py b/agent/site.py index 350732e7..e7d6ed5a 100644 --- a/agent/site.py +++ b/agent/site.py @@ -802,7 +802,7 @@ def get_database_schema_sql(self): f"-h {self.host} -u {self.user} -p{self.password} " f"{self.database}", executable="/bin/bash", - ) + ).get("output") @property def job_record(self): diff --git a/agent/web.py b/agent/web.py index cca4a6ec..f2dca7e9 100644 --- a/agent/web.py +++ b/agent/web.py @@ -551,7 +551,7 @@ def backup_site(bench, site): @application.route("/benches//sites//database/schema", methods=["GET"]) @validate_bench_and_site def fetch_database_schema(bench, site): - return {"data": str(Server().benches[bench].sites[site].get_database_schema_sql())} + return {"data": Server().benches[bench].sites[site].get_database_schema_sql()} @application.route( "/benches//sites//migrate", From 2b0e5a0eb506d1fed62afc72059c9b1f520f8ee5 Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Tue, 15 Oct 2024 11:14:14 +0530 Subject: [PATCH 05/28] feat: replace mysqldump with information schema table column definitions --- agent/site.py | 14 ++++++-------- agent/web.py | 2 +- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/agent/site.py b/agent/site.py index e7d6ed5a..8e743b86 100644 --- a/agent/site.py +++ b/agent/site.py @@ -795,14 +795,10 @@ def get_database_free_tables(self): except Exception: return [] - def get_database_schema_sql(self): - return self.execute( - "set -o pipefail && " - "mysqldump --single-transaction --quick --lock-tables=false --no-data " - f"-h {self.host} -u {self.user} -p{self.password} " - f"{self.database}", - executable="/bin/bash", - ).get("output") + def get_database_table_schemas(self): + command = f"SELECT TABLE_NAME AS `table`, COLUMN_NAME AS `column`, DATA_TYPE AS `data_type`, IS_NULLABLE AS `is_nullable`, COLUMN_DEFAULT AS `default` FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA='{self.database}';" + command = quote(command) + return self.execute(f"mysql -sN -h {self.host} -u{self.user} -p{self.password} -e {command} --batch").get("output") @property def job_record(self): @@ -822,3 +818,5 @@ def generate_theme_files(self): " frappe.website.doctype.website_theme.website_theme" ".generate_theme_files_if_not_exist" ) + + diff --git a/agent/web.py b/agent/web.py index f2dca7e9..52ec80d8 100644 --- a/agent/web.py +++ b/agent/web.py @@ -548,7 +548,7 @@ def backup_site(bench, site): job = Server().benches[bench].sites[site].backup_job(with_files, offsite) return {"job": job} -@application.route("/benches//sites//database/schema", methods=["GET"]) +@application.route("/benches//sites//database/schemas", methods=["GET"]) @validate_bench_and_site def fetch_database_schema(bench, site): return {"data": Server().benches[bench].sites[site].get_database_schema_sql()} From a2c1c39b5a7dd9366f450a7c02fb7315b19344a9 Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Tue, 15 Oct 2024 11:24:09 +0530 Subject: [PATCH 06/28] chore: typo --- agent/web.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/agent/web.py b/agent/web.py index 52ec80d8..7fde5e29 100644 --- a/agent/web.py +++ b/agent/web.py @@ -550,8 +550,8 @@ def backup_site(bench, site): @application.route("/benches//sites//database/schemas", methods=["GET"]) @validate_bench_and_site -def fetch_database_schema(bench, site): - return {"data": Server().benches[bench].sites[site].get_database_schema_sql()} +def fetch_database_schemas(bench, site): + return {"data": Server().benches[bench].sites[site].get_database_table_schemas()} @application.route( "/benches//sites//migrate", From 7d0bf1d9a3e77891bb490b8748fe0fe79a667939 Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Tue, 15 Oct 2024 11:26:01 +0530 Subject: [PATCH 07/28] chore: send in csv format --- agent/site.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/agent/site.py b/agent/site.py index 8e743b86..e0c7fde2 100644 --- a/agent/site.py +++ b/agent/site.py @@ -798,7 +798,10 @@ def get_database_free_tables(self): def get_database_table_schemas(self): command = f"SELECT TABLE_NAME AS `table`, COLUMN_NAME AS `column`, DATA_TYPE AS `data_type`, IS_NULLABLE AS `is_nullable`, COLUMN_DEFAULT AS `default` FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA='{self.database}';" command = quote(command) - return self.execute(f"mysql -sN -h {self.host} -u{self.user} -p{self.password} -e {command} --batch").get("output") + data = self.execute(f"mysql -sN -h {self.host} -u{self.user} -p{self.password} -e {command} --batch").get("output") + data = data.split("\n") + data = [line.split("\t") for line in data] + return data @property def job_record(self): From de3ceefaf1d0fa220232bd5ff2935466f3fe04ed Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Tue, 15 Oct 2024 11:29:59 +0530 Subject: [PATCH 08/28] chore: structure the data in tables json --- agent/site.py | 15 ++++++++++++++- agent/web.py | 2 +- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/agent/site.py b/agent/site.py index e0c7fde2..a80291c7 100644 --- a/agent/site.py +++ b/agent/site.py @@ -801,7 +801,20 @@ def get_database_table_schemas(self): data = self.execute(f"mysql -sN -h {self.host} -u{self.user} -p{self.password} -e {command} --batch").get("output") data = data.split("\n") data = [line.split("\t") for line in data] - return data + tables = {} # : [, , ...] + for row in data: + if len(row) != 5: + continue + table = row[0] + if table not in tables: + tables[table] = [] + tables[table].append({ + "column": row[1], + "data_type": row[2], + "is_nullable": row[3], + "default": row[4], + }) + return tables @property def job_record(self): diff --git a/agent/web.py b/agent/web.py index 7fde5e29..284413d9 100644 --- a/agent/web.py +++ b/agent/web.py @@ -551,7 +551,7 @@ def backup_site(bench, site): @application.route("/benches//sites//database/schemas", methods=["GET"]) @validate_bench_and_site def fetch_database_schemas(bench, site): - return {"data": Server().benches[bench].sites[site].get_database_table_schemas()} + return Server().benches[bench].sites[site].get_database_table_schemas() @application.route( "/benches//sites//migrate", From fd481bbe084ab735a1469487792a4e6b62d98320 Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Tue, 15 Oct 2024 11:32:24 +0530 Subject: [PATCH 09/28] chore: make is_nullable field boolean --- agent/site.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agent/site.py b/agent/site.py index a80291c7..f104df66 100644 --- a/agent/site.py +++ b/agent/site.py @@ -811,7 +811,7 @@ def get_database_table_schemas(self): tables[table].append({ "column": row[1], "data_type": row[2], - "is_nullable": row[3], + "is_nullable": True if row[3] == "YES" else False, "default": row[4], }) return tables From f453fb3a2c75de9317e0f4db2abaa8021373ee02 Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Wed, 16 Oct 2024 07:56:12 +0000 Subject: [PATCH 10/28] feat: database query runner implemented --- agent/database2.py | 108 +++++++++++++++++++++++++++++++++++++++++++++ agent/site.py | 4 ++ agent/web.py | 12 +++++ 3 files changed, 124 insertions(+) create mode 100644 agent/database2.py diff --git a/agent/database2.py b/agent/database2.py new file mode 100644 index 00000000..11cec73e --- /dev/null +++ b/agent/database2.py @@ -0,0 +1,108 @@ +from __future__ import annotations +from peewee import MySQLDatabase, ProgrammingError, InternalError + +class Database: + def __init__(self, host, port, user, password, database): + self.db: 'MySQLDatabase' = MySQLDatabase( + database, + user=user, + password=password, + host=host, + port=port, + autocommit=False, + ) + + # Methods + def execute_query(self, query:str, commit:bool=False, as_dict:bool=False) -> list[bool, str]: + """ + This function will take the query and run in database. + + It will return a tuple of (bool, str) + bool: Whether the query has been executed successfully + str: The output of the query. It can be the output or error message as well + """ + try: + return True, self._sql(query, commit=commit, as_dict=as_dict) + except (ProgrammingError, InternalError) as e: + return False, "Error while executing query: " + str(e) + except Exception as e: + return False, "Failed to execute query. Please check the query and try again later." + + # Private helper methods + def _sql(self, query:str, params=(), commit:bool=False, as_dict:bool=False) -> dict|None: + """ + Run sql query in database + + Args: + query: SQL query + params: If you are using parameters in the query, you can pass them as a tuple + commit: True if you want to commit the changes. If commit is false, it will rollback the changes and also wouldnt allow to run ddl, dcl or tcl queries + as_dict: True if you want to return the result as a dictionary (like frappe.db.sql, get the results as dict). Otherwise it will return a dict of columns and data + + Return Format: + For as_dict = True: + [ + { + "name" : "Administrator", + "modified": "2019-01-01 00:00:00", + }, + ... + ] + + For as_dict = False: + { + "columns": ["name", "modified"], + "data": [ + ["Administrator", "2019-01-01 00:00:00"], + ... + ] + } + """ + + query = query.strip() + if not commit and self._is_restricted_query_for_no_commit_mode(query): + raise ProgrammingError("Provided query is not allowed in read only mode") + + # Start transaction + self.db.begin() + result = None + try: + cursor = self.db.execute_sql(query, params) + if cursor.description: + rows = cursor.fetchall() + columns = [d[0] for d in cursor.description] + if as_dict: + result = list(map(lambda x: dict(zip(columns, x)), rows)) + else: + result = { + "columns": columns, + "data": rows + } + except: + # if query execution fails, rollback the transaction and raise the error + self.db.rollback() + raise + else: + if commit: + # If commit is True, try to commit the transaction + try: + self.db.commit() + except: + self.db.rollback() + raise + else: + # If commit is False, rollback the transaction to discard the changes + self.db.rollback() + return result + + def _is_restricted_query_for_no_commit_mode(self, query:str) -> bool: + return self._is_ddl_query(query) or self._is_dcl_query(query) or self._is_tcl_query(query) + + def _is_ddl_query(self, query:str) -> bool: + return query.upper().startswith(("CREATE", "ALTER", "DROP", "TRUNCATE", "RENAME", "COMMENT")) + + def _is_dcl_query(self, query:str) -> bool: + return query.upper().startswith(("GRANT", "REVOKE")) + + def _is_tcl_query(self, query:str) -> bool: + return query.upper().replace(" ", "").startswith(("COMMIT", "ROLLBACK", "SAVEPOINT", "BEGINTRANSACTION")) diff --git a/agent/site.py b/agent/site.py index f104df66..8564d4fa 100644 --- a/agent/site.py +++ b/agent/site.py @@ -12,6 +12,7 @@ import requests from agent.base import AgentException, Base +from agent.database2 import Database from agent.job import job, step from agent.utils import b2mb, get_size @@ -816,6 +817,9 @@ def get_database_table_schemas(self): }) return tables + def run_sql_query(self, query:str, commit:bool=False, as_dict:bool=False): + return Database(self.host, 3306, self.user, self.password, self.database).execute_query(query, commit=commit, as_dict=as_dict) + @property def job_record(self): return self.bench.server.job_record diff --git a/agent/web.py b/agent/web.py index 284413d9..cd57794f 100644 --- a/agent/web.py +++ b/agent/web.py @@ -553,6 +553,18 @@ def backup_site(bench, site): def fetch_database_schemas(bench, site): return Server().benches[bench].sites[site].get_database_table_schemas() +@application.route("/benches//sites//database/query/execute", methods=["POST"]) +@validate_bench_and_site +def run_sql(bench, site): + query = request.json.get("query") + commit = request.json.get("commit") or False + as_dict = request.json.get("as_dict") or False + success, output = Server().benches[bench].sites[site].run_sql_query(query, commit, as_dict) + return { + "success": success, + "output": output + } + @application.route( "/benches//sites//migrate", methods=["POST"], From 6a7938a0d56736af31f38e3ede22d98c11362ecf Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Wed, 16 Oct 2024 13:47:26 +0530 Subject: [PATCH 11/28] chore: code cleanup and ruff format lint --- agent/database.py | 357 +++++++++++---------------------------- agent/database2.py | 108 ------------ agent/database_server.py | 277 ++++++++++++++++++++++++++++++ agent/site.py | 30 ++-- agent/web.py | 10 +- 5 files changed, 394 insertions(+), 388 deletions(-) delete mode 100644 agent/database2.py create mode 100644 agent/database_server.py diff --git a/agent/database.py b/agent/database.py index 395f0f43..deef35e0 100644 --- a/agent/database.py +++ b/agent/database.py @@ -1,277 +1,110 @@ from __future__ import annotations -import json -import os -import re -from datetime import datetime, timezone -from pathlib import Path - -from peewee import MySQLDatabase - -from agent.job import job, step -from agent.server import Server - - -class DatabaseServer(Server): - def __init__(self, directory=None): - self.directory = directory or os.getcwd() - self.config_file = os.path.join(self.directory, "config.json") - self.name = self.config["name"] - - self.mariadb_directory = "/var/lib/mysql" - self.pt_stalk_directory = "/var/lib/pt-stalk" - - self.job = None - self.step = None - - def search_binary_log( - self, - log, - database, - start_datetime, - stop_datetime, - search_pattern, - max_lines, - ): - log = os.path.join(self.mariadb_directory, log) - LINES_TO_SKIP = r"^(USE|COMMIT|START TRANSACTION|DELIMITER|ROLLBACK|#)" - command = ( - f"mysqlbinlog --short-form --database {database} " - f"--start-datetime '{start_datetime}' " - f"--stop-datetime '{stop_datetime}' " - f" {log} | grep -Piv '{LINES_TO_SKIP}'" +from peewee import InternalError, MySQLDatabase, ProgrammingError + + +class Database: + def __init__(self, host, port, user, password, database): + self.db: MySQLDatabase = MySQLDatabase( + database, + user=user, + password=password, + host=host, + port=port, + autocommit=False, ) - DELIMITER = "/*!*/;" - - events = [] - timestamp = 0 - for line in self.execute(command, skip_output_log=True)["output"].split(DELIMITER): - line = line.strip() - if line.startswith("SET TIMESTAMP"): - timestamp = int(line.split("=")[-1].split(".")[0]) - else: - if any(line.startswith(skip) for skip in ["SET", "/*!"]): - continue - if line and timestamp and re.search(search_pattern, line): - events.append( - { - "query": line, - "timestamp": str(datetime.utcfromtimestamp(timestamp)), - } - ) - if len(events) > max_lines: - break - return events - - @property - def binary_logs(self): - BINARY_LOG_FILE_PATTERN = r"mysql-bin.\d+" - files = [] - for file in Path(self.mariadb_directory).iterdir(): - if re.match(BINARY_LOG_FILE_PATTERN, file.name): - unix_timestamp = int(file.stat().st_mtime) - files.append( - { - "name": file.name, - "size": file.stat().st_size, - "modified": str(datetime.utcfromtimestamp(unix_timestamp)), - } - ) - return sorted(files, key=lambda x: x["name"]) - - def processes(self, private_ip, mariadb_root_password): - try: - mariadb = MySQLDatabase( - "mysql", - user="root", - password=mariadb_root_password, - host=private_ip, - port=3306, - ) - return self.sql(mariadb, "SHOW FULL PROCESSLIST") - except Exception: - import traceback - - traceback.print_exc() - return [] - - def locks(self, private_ip, mariadb_root_password): - try: - mariadb = MySQLDatabase( - "mysql", - user="root", - password=mariadb_root_password, - host=private_ip, - port=3306, - ) - return self.sql( - mariadb, - """ - SELECT l.*, t.* - FROM information_schema.INNODB_LOCKS l - JOIN information_schema.INNODB_TRX t ON l.lock_trx_id = t.trx_id - """, - ) - except Exception: - import traceback - - traceback.print_exc() - return [] + # Methods + def execute_query(self, query: str, commit: bool = False, as_dict: bool = False) -> list[bool, str]: + """ + This function will take the query and run in database. - def kill_processes(self, private_ip, mariadb_root_password, kill_threshold): - processes = self.processes(private_ip, mariadb_root_password) + It will return a tuple of (bool, str) + bool: Whether the query has been executed successfully + str: The output of the query. It can be the output or error message as well + """ try: - mariadb = MySQLDatabase( - "mysql", - user="root", - password=mariadb_root_password, - host=private_ip, - port=3306, - ) - for process in processes: - if (process["Time"] or 0) >= kill_threshold: - mariadb.execute_sql(f"KILL {process['Id']}") + return True, self._sql(query, commit=commit, as_dict=as_dict) + except (ProgrammingError, InternalError) as e: + return False, "Error while executing query: " + str(e) except Exception: - import traceback + return False, "Failed to execute query. Please check the query and try again later." - traceback.print_exc() - - def get_deadlocks( - self, - database, - start_datetime, - stop_datetime, - max_lines, - private_ip, - mariadb_root_password, - ): - mariadb = MySQLDatabase( - "percona", - user="root", - password=mariadb_root_password, - host=private_ip, - port=3306, - ) - - return self.sql( - mariadb, - f""" - select * - from deadlock - where user = %s - and ts >= %s - and ts <= %s - order by ts - limit {int(max_lines)}""", - (database, start_datetime, stop_datetime), - ) - - @staticmethod - def sql(db, query, params=()): - """Similar to frappe.db.sql, get the results as dict.""" - - cursor = db.execute_sql(query, params) - rows = cursor.fetchall() - columns = [d[0] for d in cursor.description] - return list(map(lambda x: dict(zip(columns, x)), rows)) - - @job("Column Statistics") - def fetch_column_stats(self, schema, table, private_ip, mariadb_root_password, doc_name): - self._fetch_column_stats(schema, table, private_ip, mariadb_root_password) - return {"doc_name": doc_name} - - @step("Fetch Column Statistics") - def _fetch_column_stats(self, schema, table, private_ip, mariadb_root_password): - """Get various stats about columns in a table. - - Refer: - - https://mariadb.com/kb/en/engine-independent-table-statistics/ - - https://mariadb.com/kb/en/mysqlcolumn_stats-table/ + # Private helper methods + def _sql(self, query: str, params=(), commit: bool = False, as_dict: bool = False) -> dict | None: + """ + Run sql query in database + + Args: + query: SQL query + params: If you are using parameters in the query, you can pass them as a tuple + commit: True if you want to commit the changes. If commit is false, it will rollback the changes and + also wouldnt allow to run ddl, dcl or tcl queries + as_dict: True if you want to return the result as a dictionary (like frappe.db.sql). + Otherwise it will return a dict of columns and data + + Return Format: + For as_dict = True: + [ + { + "name" : "Administrator", + "modified": "2019-01-01 00:00:00", + }, + ... + ] + + For as_dict = False: + { + "columns": ["name", "modified"], + "data": [ + ["Administrator", "2019-01-01 00:00:00"], + ... + ] + } """ - mariadb = MySQLDatabase( - "mysql", - user="root", - password=mariadb_root_password, - host=private_ip, - port=3306, - ) - - try: - self.sql( - mariadb, - f"ANALYZE TABLE `{schema}`.`{table}` PERSISTENT FOR ALL", - ) - - results = self.sql( - mariadb, - """ - SELECT - column_name, nulls_ratio, avg_length, avg_frequency, - decode_histogram(hist_type,histogram) as histogram - from mysql.column_stats - WHERE db_name = %s - and table_name = %s """, - (schema, table), - ) - - for row in results: - for column in ["nulls_ratio", "avg_length", "avg_frequency"]: - row[column] = float(row[column]) if row[column] else None - except Exception as e: - print(e) - - return {"output": json.dumps(results)} - - def explain_query(self, schema, query, private_ip, mariadb_root_password): - mariadb = MySQLDatabase( - schema, - user="root", - password=mariadb_root_password, - host=private_ip, - port=3306, - ) - if not query.lower().startswith(("select", "update", "delete")): - return [] + query = query.strip() + if not commit and self._is_restricted_query_for_no_commit_mode(query): + raise ProgrammingError("Provided query is not allowed in read only mode") + # Start transaction + self.db.begin() + result = None try: - return self.sql(mariadb, f"EXPLAIN {query}") - except Exception as e: - print(e) + cursor = self.db.execute_sql(query, params) + if cursor.description: + rows = cursor.fetchall() + columns = [d[0] for d in cursor.description] + if as_dict: + result = list(map(lambda x: dict(zip(columns, x)), rows)) + else: + result = {"columns": columns, "data": rows} + except: + # if query execution fails, rollback the transaction and raise the error + self.db.rollback() + raise + else: + if commit: + # If commit is True, try to commit the transaction + try: + self.db.commit() + except: + self.db.rollback() + raise + else: + # If commit is False, rollback the transaction to discard the changes + self.db.rollback() + return result + + def _is_restricted_query_for_no_commit_mode(self, query: str) -> bool: + return self._is_ddl_query(query) or self._is_dcl_query(query) or self._is_tcl_query(query) - def get_stalk(self, name): - diagnostics = [] - for file in Path(self.pt_stalk_directory).iterdir(): - if os.path.getsize(os.path.join(self.pt_stalk_directory, file.name)) > 16 * (1024**2): - # Skip files larger than 16 MB - continue - if re.match(name, file.name): - pt_stalk_path = (os.path.join(self.pt_stalk_directory, file.name),) - with open(pt_stalk_path, errors="replace") as f: - output = f.read() + def _is_ddl_query(self, query: str) -> bool: + return query.upper().startswith(("CREATE", "ALTER", "DROP", "TRUNCATE", "RENAME", "COMMENT")) - diagnostics.append( - { - "type": file.name.replace(name, "").strip("-"), - "output": output, - } - ) - return sorted(diagnostics, key=lambda x: x["type"]) + def _is_dcl_query(self, query: str) -> bool: + return query.upper().startswith(("GRANT", "REVOKE")) - def get_stalks(self): - stalk_pattern = r"(\d{4}_\d{2}_\d{2}_\d{2}_\d{2}_\d{2})-output" - stalks = [] - for file in Path(self.pt_stalk_directory).iterdir(): - matched = re.match(stalk_pattern, file.name) - if matched: - stalk = matched.group(1) - stalks.append( - { - "name": stalk, - "timestamp": datetime.strptime(stalk, "%Y_%m_%d_%H_%M_%S") - .replace(tzinfo=timezone.utc) - .isoformat(), - } - ) - return sorted(stalks, key=lambda x: x["name"]) + def _is_tcl_query(self, query: str) -> bool: + query = query.upper().replace(" ", "") + return query.startswith(("COMMIT", "ROLLBACK", "SAVEPOINT", "BEGINTRANSACTION")) diff --git a/agent/database2.py b/agent/database2.py deleted file mode 100644 index 11cec73e..00000000 --- a/agent/database2.py +++ /dev/null @@ -1,108 +0,0 @@ -from __future__ import annotations -from peewee import MySQLDatabase, ProgrammingError, InternalError - -class Database: - def __init__(self, host, port, user, password, database): - self.db: 'MySQLDatabase' = MySQLDatabase( - database, - user=user, - password=password, - host=host, - port=port, - autocommit=False, - ) - - # Methods - def execute_query(self, query:str, commit:bool=False, as_dict:bool=False) -> list[bool, str]: - """ - This function will take the query and run in database. - - It will return a tuple of (bool, str) - bool: Whether the query has been executed successfully - str: The output of the query. It can be the output or error message as well - """ - try: - return True, self._sql(query, commit=commit, as_dict=as_dict) - except (ProgrammingError, InternalError) as e: - return False, "Error while executing query: " + str(e) - except Exception as e: - return False, "Failed to execute query. Please check the query and try again later." - - # Private helper methods - def _sql(self, query:str, params=(), commit:bool=False, as_dict:bool=False) -> dict|None: - """ - Run sql query in database - - Args: - query: SQL query - params: If you are using parameters in the query, you can pass them as a tuple - commit: True if you want to commit the changes. If commit is false, it will rollback the changes and also wouldnt allow to run ddl, dcl or tcl queries - as_dict: True if you want to return the result as a dictionary (like frappe.db.sql, get the results as dict). Otherwise it will return a dict of columns and data - - Return Format: - For as_dict = True: - [ - { - "name" : "Administrator", - "modified": "2019-01-01 00:00:00", - }, - ... - ] - - For as_dict = False: - { - "columns": ["name", "modified"], - "data": [ - ["Administrator", "2019-01-01 00:00:00"], - ... - ] - } - """ - - query = query.strip() - if not commit and self._is_restricted_query_for_no_commit_mode(query): - raise ProgrammingError("Provided query is not allowed in read only mode") - - # Start transaction - self.db.begin() - result = None - try: - cursor = self.db.execute_sql(query, params) - if cursor.description: - rows = cursor.fetchall() - columns = [d[0] for d in cursor.description] - if as_dict: - result = list(map(lambda x: dict(zip(columns, x)), rows)) - else: - result = { - "columns": columns, - "data": rows - } - except: - # if query execution fails, rollback the transaction and raise the error - self.db.rollback() - raise - else: - if commit: - # If commit is True, try to commit the transaction - try: - self.db.commit() - except: - self.db.rollback() - raise - else: - # If commit is False, rollback the transaction to discard the changes - self.db.rollback() - return result - - def _is_restricted_query_for_no_commit_mode(self, query:str) -> bool: - return self._is_ddl_query(query) or self._is_dcl_query(query) or self._is_tcl_query(query) - - def _is_ddl_query(self, query:str) -> bool: - return query.upper().startswith(("CREATE", "ALTER", "DROP", "TRUNCATE", "RENAME", "COMMENT")) - - def _is_dcl_query(self, query:str) -> bool: - return query.upper().startswith(("GRANT", "REVOKE")) - - def _is_tcl_query(self, query:str) -> bool: - return query.upper().replace(" ", "").startswith(("COMMIT", "ROLLBACK", "SAVEPOINT", "BEGINTRANSACTION")) diff --git a/agent/database_server.py b/agent/database_server.py new file mode 100644 index 00000000..395f0f43 --- /dev/null +++ b/agent/database_server.py @@ -0,0 +1,277 @@ +from __future__ import annotations + +import json +import os +import re +from datetime import datetime, timezone +from pathlib import Path + +from peewee import MySQLDatabase + +from agent.job import job, step +from agent.server import Server + + +class DatabaseServer(Server): + def __init__(self, directory=None): + self.directory = directory or os.getcwd() + self.config_file = os.path.join(self.directory, "config.json") + self.name = self.config["name"] + + self.mariadb_directory = "/var/lib/mysql" + self.pt_stalk_directory = "/var/lib/pt-stalk" + + self.job = None + self.step = None + + def search_binary_log( + self, + log, + database, + start_datetime, + stop_datetime, + search_pattern, + max_lines, + ): + log = os.path.join(self.mariadb_directory, log) + LINES_TO_SKIP = r"^(USE|COMMIT|START TRANSACTION|DELIMITER|ROLLBACK|#)" + command = ( + f"mysqlbinlog --short-form --database {database} " + f"--start-datetime '{start_datetime}' " + f"--stop-datetime '{stop_datetime}' " + f" {log} | grep -Piv '{LINES_TO_SKIP}'" + ) + + DELIMITER = "/*!*/;" + + events = [] + timestamp = 0 + for line in self.execute(command, skip_output_log=True)["output"].split(DELIMITER): + line = line.strip() + if line.startswith("SET TIMESTAMP"): + timestamp = int(line.split("=")[-1].split(".")[0]) + else: + if any(line.startswith(skip) for skip in ["SET", "/*!"]): + continue + if line and timestamp and re.search(search_pattern, line): + events.append( + { + "query": line, + "timestamp": str(datetime.utcfromtimestamp(timestamp)), + } + ) + if len(events) > max_lines: + break + return events + + @property + def binary_logs(self): + BINARY_LOG_FILE_PATTERN = r"mysql-bin.\d+" + files = [] + for file in Path(self.mariadb_directory).iterdir(): + if re.match(BINARY_LOG_FILE_PATTERN, file.name): + unix_timestamp = int(file.stat().st_mtime) + files.append( + { + "name": file.name, + "size": file.stat().st_size, + "modified": str(datetime.utcfromtimestamp(unix_timestamp)), + } + ) + return sorted(files, key=lambda x: x["name"]) + + def processes(self, private_ip, mariadb_root_password): + try: + mariadb = MySQLDatabase( + "mysql", + user="root", + password=mariadb_root_password, + host=private_ip, + port=3306, + ) + return self.sql(mariadb, "SHOW FULL PROCESSLIST") + except Exception: + import traceback + + traceback.print_exc() + return [] + + def locks(self, private_ip, mariadb_root_password): + try: + mariadb = MySQLDatabase( + "mysql", + user="root", + password=mariadb_root_password, + host=private_ip, + port=3306, + ) + return self.sql( + mariadb, + """ + SELECT l.*, t.* + FROM information_schema.INNODB_LOCKS l + JOIN information_schema.INNODB_TRX t ON l.lock_trx_id = t.trx_id + """, + ) + except Exception: + import traceback + + traceback.print_exc() + return [] + + def kill_processes(self, private_ip, mariadb_root_password, kill_threshold): + processes = self.processes(private_ip, mariadb_root_password) + try: + mariadb = MySQLDatabase( + "mysql", + user="root", + password=mariadb_root_password, + host=private_ip, + port=3306, + ) + for process in processes: + if (process["Time"] or 0) >= kill_threshold: + mariadb.execute_sql(f"KILL {process['Id']}") + except Exception: + import traceback + + traceback.print_exc() + + def get_deadlocks( + self, + database, + start_datetime, + stop_datetime, + max_lines, + private_ip, + mariadb_root_password, + ): + mariadb = MySQLDatabase( + "percona", + user="root", + password=mariadb_root_password, + host=private_ip, + port=3306, + ) + + return self.sql( + mariadb, + f""" + select * + from deadlock + where user = %s + and ts >= %s + and ts <= %s + order by ts + limit {int(max_lines)}""", + (database, start_datetime, stop_datetime), + ) + + @staticmethod + def sql(db, query, params=()): + """Similar to frappe.db.sql, get the results as dict.""" + + cursor = db.execute_sql(query, params) + rows = cursor.fetchall() + columns = [d[0] for d in cursor.description] + return list(map(lambda x: dict(zip(columns, x)), rows)) + + @job("Column Statistics") + def fetch_column_stats(self, schema, table, private_ip, mariadb_root_password, doc_name): + self._fetch_column_stats(schema, table, private_ip, mariadb_root_password) + return {"doc_name": doc_name} + + @step("Fetch Column Statistics") + def _fetch_column_stats(self, schema, table, private_ip, mariadb_root_password): + """Get various stats about columns in a table. + + Refer: + - https://mariadb.com/kb/en/engine-independent-table-statistics/ + - https://mariadb.com/kb/en/mysqlcolumn_stats-table/ + """ + mariadb = MySQLDatabase( + "mysql", + user="root", + password=mariadb_root_password, + host=private_ip, + port=3306, + ) + + try: + self.sql( + mariadb, + f"ANALYZE TABLE `{schema}`.`{table}` PERSISTENT FOR ALL", + ) + + results = self.sql( + mariadb, + """ + SELECT + column_name, nulls_ratio, avg_length, avg_frequency, + decode_histogram(hist_type,histogram) as histogram + from mysql.column_stats + WHERE db_name = %s + and table_name = %s """, + (schema, table), + ) + + for row in results: + for column in ["nulls_ratio", "avg_length", "avg_frequency"]: + row[column] = float(row[column]) if row[column] else None + except Exception as e: + print(e) + + return {"output": json.dumps(results)} + + def explain_query(self, schema, query, private_ip, mariadb_root_password): + mariadb = MySQLDatabase( + schema, + user="root", + password=mariadb_root_password, + host=private_ip, + port=3306, + ) + + if not query.lower().startswith(("select", "update", "delete")): + return [] + + try: + return self.sql(mariadb, f"EXPLAIN {query}") + except Exception as e: + print(e) + + def get_stalk(self, name): + diagnostics = [] + for file in Path(self.pt_stalk_directory).iterdir(): + if os.path.getsize(os.path.join(self.pt_stalk_directory, file.name)) > 16 * (1024**2): + # Skip files larger than 16 MB + continue + if re.match(name, file.name): + pt_stalk_path = (os.path.join(self.pt_stalk_directory, file.name),) + with open(pt_stalk_path, errors="replace") as f: + output = f.read() + + diagnostics.append( + { + "type": file.name.replace(name, "").strip("-"), + "output": output, + } + ) + return sorted(diagnostics, key=lambda x: x["type"]) + + def get_stalks(self): + stalk_pattern = r"(\d{4}_\d{2}_\d{2}_\d{2}_\d{2}_\d{2})-output" + stalks = [] + for file in Path(self.pt_stalk_directory).iterdir(): + matched = re.match(stalk_pattern, file.name) + if matched: + stalk = matched.group(1) + stalks.append( + { + "name": stalk, + "timestamp": datetime.strptime(stalk, "%Y_%m_%d_%H_%M_%S") + .replace(tzinfo=timezone.utc) + .isoformat(), + } + ) + return sorted(stalks, key=lambda x: x["name"]) diff --git a/agent/site.py b/agent/site.py index 8564d4fa..6d9f3ca7 100644 --- a/agent/site.py +++ b/agent/site.py @@ -12,7 +12,7 @@ import requests from agent.base import AgentException, Base -from agent.database2 import Database +from agent.database import Database from agent.job import job, step from agent.utils import b2mb, get_size @@ -799,26 +799,32 @@ def get_database_free_tables(self): def get_database_table_schemas(self): command = f"SELECT TABLE_NAME AS `table`, COLUMN_NAME AS `column`, DATA_TYPE AS `data_type`, IS_NULLABLE AS `is_nullable`, COLUMN_DEFAULT AS `default` FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA='{self.database}';" command = quote(command) - data = self.execute(f"mysql -sN -h {self.host} -u{self.user} -p{self.password} -e {command} --batch").get("output") + data = self.execute( + f"mysql -sN -h {self.host} -u{self.user} -p{self.password} -e {command} --batch" + ).get("output") data = data.split("\n") data = [line.split("\t") for line in data] - tables = {} # : [, , ...] + tables = {} # : [, , ...] for row in data: if len(row) != 5: continue table = row[0] if table not in tables: tables[table] = [] - tables[table].append({ - "column": row[1], - "data_type": row[2], - "is_nullable": True if row[3] == "YES" else False, - "default": row[4], - }) + tables[table].append( + { + "column": row[1], + "data_type": row[2], + "is_nullable": row[3] == "YES", + "default": row[4], + } + ) return tables - def run_sql_query(self, query:str, commit:bool=False, as_dict:bool=False): - return Database(self.host, 3306, self.user, self.password, self.database).execute_query(query, commit=commit, as_dict=as_dict) + def run_sql_query(self, query: str, commit: bool = False, as_dict: bool = False): + return Database(self.host, 3306, self.user, self.password, self.database).execute_query( + query, commit=commit, as_dict=as_dict + ) @property def job_record(self): @@ -838,5 +844,3 @@ def generate_theme_files(self): " frappe.website.doctype.website_theme.website_theme" ".generate_theme_files_if_not_exist" ) - - diff --git a/agent/web.py b/agent/web.py index cd57794f..7d2797d7 100644 --- a/agent/web.py +++ b/agent/web.py @@ -14,7 +14,7 @@ from playhouse.shortcuts import model_to_dict from agent.builder import ImageBuilder, get_image_build_context_directory -from agent.database import DatabaseServer +from agent.database_server import DatabaseServer from agent.exceptions import BenchNotExistsException, SiteNotExistsException from agent.job import JobModel, connection from agent.minio import Minio @@ -548,11 +548,13 @@ def backup_site(bench, site): job = Server().benches[bench].sites[site].backup_job(with_files, offsite) return {"job": job} + @application.route("/benches//sites//database/schemas", methods=["GET"]) @validate_bench_and_site def fetch_database_schemas(bench, site): return Server().benches[bench].sites[site].get_database_table_schemas() + @application.route("/benches//sites//database/query/execute", methods=["POST"]) @validate_bench_and_site def run_sql(bench, site): @@ -560,10 +562,8 @@ def run_sql(bench, site): commit = request.json.get("commit") or False as_dict = request.json.get("as_dict") or False success, output = Server().benches[bench].sites[site].run_sql_query(query, commit, as_dict) - return { - "success": success, - "output": output - } + return {"success": success, "output": output} + @application.route( "/benches//sites//migrate", From f255605c1da10e8c9eb7ab48fe1b7906738e46a7 Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Wed, 16 Oct 2024 13:49:48 +0530 Subject: [PATCH 12/28] chore:ruff lint --- agent/site.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/agent/site.py b/agent/site.py index 6d9f3ca7..5fe20870 100644 --- a/agent/site.py +++ b/agent/site.py @@ -797,7 +797,17 @@ def get_database_free_tables(self): return [] def get_database_table_schemas(self): - command = f"SELECT TABLE_NAME AS `table`, COLUMN_NAME AS `column`, DATA_TYPE AS `data_type`, IS_NULLABLE AS `is_nullable`, COLUMN_DEFAULT AS `default` FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA='{self.database}';" + command = f"""SELECT + TABLE_NAME AS `table`, + COLUMN_NAME AS `column`, + DATA_TYPE AS `data_type`, + IS_NULLABLE AS `is_nullable`, + COLUMN_DEFAULT AS `default` + FROM + INFORMATION_SCHEMA.COLUMNS + WHERE + TABLE_SCHEMA='{self.database}'; + """ command = quote(command) data = self.execute( f"mysql -sN -h {self.host} -u{self.user} -p{self.password} -e {command} --batch" From b2c166cd901e237a99d1fd76934559dbf26a80b9 Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Fri, 18 Oct 2024 06:37:06 +0000 Subject: [PATCH 13/28] chore: remove custom error msg and log unknown error from sql runner --- agent/database.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/agent/database.py b/agent/database.py index deef35e0..41165e33 100644 --- a/agent/database.py +++ b/agent/database.py @@ -26,9 +26,13 @@ def execute_query(self, query: str, commit: bool = False, as_dict: bool = False) try: return True, self._sql(query, commit=commit, as_dict=as_dict) except (ProgrammingError, InternalError) as e: - return False, "Error while executing query: " + str(e) - except Exception: - return False, "Failed to execute query. Please check the query and try again later." + return False, str(e) + except Exception as e: + print(f"Error executing SQL Query on {self.database} : {e}") + return ( + False, + "Failed to execute query due to unknown error. Please check the query and try again later.", + ) # Private helper methods def _sql(self, query: str, params=(), commit: bool = False, as_dict: bool = False) -> dict | None: From 849990cdf56d959620551d4a3ae267be568a26c7 Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Mon, 21 Oct 2024 06:05:44 +0000 Subject: [PATCH 14/28] feat: multi line sql query support added --- agent/database.py | 72 ++++++++++++++++++++++++++++++++--------------- agent/web.py | 4 +-- 2 files changed, 52 insertions(+), 24 deletions(-) diff --git a/agent/database.py b/agent/database.py index 41165e33..2f8fd0fe 100644 --- a/agent/database.py +++ b/agent/database.py @@ -38,6 +38,7 @@ def execute_query(self, query: str, commit: bool = False, as_dict: bool = False) def _sql(self, query: str, params=(), commit: bool = False, as_dict: bool = False) -> dict | None: """ Run sql query in database + It supports multi-line SQL queries. Each SQL Query should be terminated with `;\n` Args: query: SQL query @@ -51,38 +52,65 @@ def _sql(self, query: str, params=(), commit: bool = False, as_dict: bool = Fals For as_dict = True: [ { - "name" : "Administrator", - "modified": "2019-01-01 00:00:00", + "output": [ + { + "name" : "Administrator", + "modified": "2019-01-01 00:00:00", + }, + ... + ] + "query": "SELECT name, modified FROM `tabUser`", + "row_count": 10 }, ... ] For as_dict = False: - { - "columns": ["name", "modified"], - "data": [ - ["Administrator", "2019-01-01 00:00:00"], - ... - ] - } + [ + { + "output": { + "columns": ["name", "modified"], + "data": [ + ["Administrator", "2019-01-01 00:00:00"], + ... + ] + }, + "query": "SELECT name, modified FROM `tabUser`", + "row_count": 10 + }, + ... + ] """ - query = query.strip() - if not commit and self._is_restricted_query_for_no_commit_mode(query): - raise ProgrammingError("Provided query is not allowed in read only mode") + queries = [x.strip() for x in query.split(";\n")] + queries = [x for x in queries if x] + + if len(queries) == 0: + raise ProgrammingError("No query provided") # Start transaction self.db.begin() - result = None + results = [] try: - cursor = self.db.execute_sql(query, params) - if cursor.description: - rows = cursor.fetchall() - columns = [d[0] for d in cursor.description] - if as_dict: - result = list(map(lambda x: dict(zip(columns, x)), rows)) - else: - result = {"columns": columns, "data": rows} + for q in queries: + if not commit and self._is_restricted_query_for_no_commit_mode(query): + raise ProgrammingError("Provided query is not allowed in read only mode") + output = None + row_count = None + cursor = self.db.execute_sql(q, params) + row_count = cursor.rowcount + if cursor.description: + rows = cursor.fetchall() + columns = [d[0] for d in cursor.description] + if as_dict: + output = list(map(lambda x: dict(zip(columns, x)), rows)) + else: + output = {"columns": columns, "data": rows} + results.append({ + "query": q, + "output": output, + "row_count": row_count + }) except: # if query execution fails, rollback the transaction and raise the error self.db.rollback() @@ -98,7 +126,7 @@ def _sql(self, query: str, params=(), commit: bool = False, as_dict: bool = Fals else: # If commit is False, rollback the transaction to discard the changes self.db.rollback() - return result + return results def _is_restricted_query_for_no_commit_mode(self, query: str) -> bool: return self._is_ddl_query(query) or self._is_dcl_query(query) or self._is_tcl_query(query) diff --git a/agent/web.py b/agent/web.py index 7d2797d7..45f588a2 100644 --- a/agent/web.py +++ b/agent/web.py @@ -561,8 +561,8 @@ def run_sql(bench, site): query = request.json.get("query") commit = request.json.get("commit") or False as_dict = request.json.get("as_dict") or False - success, output = Server().benches[bench].sites[site].run_sql_query(query, commit, as_dict) - return {"success": success, "output": output} + success, data = Server().benches[bench].sites[site].run_sql_query(query, commit, as_dict) + return {"success": success, "data": data} @application.route( From db8a4064d3e7b6991b87b5e92bee215065d21d4e Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Mon, 21 Oct 2024 06:08:04 +0000 Subject: [PATCH 15/28] chore: ignore _sql function complexity --- agent/database.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agent/database.py b/agent/database.py index 2f8fd0fe..8087ea70 100644 --- a/agent/database.py +++ b/agent/database.py @@ -35,7 +35,7 @@ def execute_query(self, query: str, commit: bool = False, as_dict: bool = False) ) # Private helper methods - def _sql(self, query: str, params=(), commit: bool = False, as_dict: bool = False) -> dict | None: + def _sql(self, query: str, params=(), commit: bool = False, as_dict: bool = False) -> dict | None: # noqa: C901 """ Run sql query in database It supports multi-line SQL queries. Each SQL Query should be terminated with `;\n` From 0ef1a7038a3ddb070d128a0e3c1f1fa3a021fad1 Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Mon, 21 Oct 2024 06:10:52 +0000 Subject: [PATCH 16/28] chore: ruff format --- agent/database.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/agent/database.py b/agent/database.py index 8087ea70..fd7fd79e 100644 --- a/agent/database.py +++ b/agent/database.py @@ -35,7 +35,7 @@ def execute_query(self, query: str, commit: bool = False, as_dict: bool = False) ) # Private helper methods - def _sql(self, query: str, params=(), commit: bool = False, as_dict: bool = False) -> dict | None: # noqa: C901 + def _sql(self, query: str, params=(), commit: bool = False, as_dict: bool = False) -> dict | None: # noqa: C901 """ Run sql query in database It supports multi-line SQL queries. Each SQL Query should be terminated with `;\n` @@ -106,11 +106,7 @@ def _sql(self, query: str, params=(), commit: bool = False, as_dict: bool = Fals output = list(map(lambda x: dict(zip(columns, x)), rows)) else: output = {"columns": columns, "data": rows} - results.append({ - "query": q, - "output": output, - "row_count": row_count - }) + results.append({"query": q, "output": output, "row_count": row_count}) except: # if query execution fails, rollback the transaction and raise the error self.db.rollback() From 2616bc694a1272f1b9af520d6f4652891a4b5020 Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Mon, 21 Oct 2024 08:15:31 +0000 Subject: [PATCH 17/28] fix: implement db.atomic based transaction --- agent/database.py | 66 ++++++++++++++++++++++++++--------------------- 1 file changed, 36 insertions(+), 30 deletions(-) diff --git a/agent/database.py b/agent/database.py index fd7fd79e..5c93ce70 100644 --- a/agent/database.py +++ b/agent/database.py @@ -1,5 +1,7 @@ from __future__ import annotations +import contextlib + from peewee import InternalError, MySQLDatabase, ProgrammingError @@ -91,37 +93,41 @@ def _sql(self, query: str, params=(), commit: bool = False, as_dict: bool = Fals # Start transaction self.db.begin() results = [] - try: - for q in queries: - if not commit and self._is_restricted_query_for_no_commit_mode(query): - raise ProgrammingError("Provided query is not allowed in read only mode") - output = None - row_count = None - cursor = self.db.execute_sql(q, params) - row_count = cursor.rowcount - if cursor.description: - rows = cursor.fetchall() - columns = [d[0] for d in cursor.description] - if as_dict: - output = list(map(lambda x: dict(zip(columns, x)), rows)) - else: - output = {"columns": columns, "data": rows} - results.append({"query": q, "output": output, "row_count": row_count}) - except: - # if query execution fails, rollback the transaction and raise the error - self.db.rollback() - raise - else: - if commit: - # If commit is True, try to commit the transaction - try: - self.db.commit() - except: - self.db.rollback() - raise + with self.db.atomic() as transaction: + try: + for q in queries: + if not commit and self._is_restricted_query_for_no_commit_mode(q): + raise ProgrammingError("Provided query is not allowed in read only mode") + output = None + row_count = None + cursor = self.db.execute_sql(q, params) + row_count = cursor.rowcount + if cursor.description: + rows = cursor.fetchall() + columns = [d[0] for d in cursor.description] + if as_dict: + output = list(map(lambda x: dict(zip(columns, x)), rows)) + else: + output = {"columns": columns, "data": rows} + results.append({"query": q, "output": output, "row_count": row_count}) + except: + # if query execution fails, rollback the transaction and raise the error + transaction.rollback() + raise else: - # If commit is False, rollback the transaction to discard the changes - self.db.rollback() + if commit: + # If commit is True, try to commit the transaction + try: + transaction.commit() + except: + transaction.rollback() + raise + else: + # If commit is False, rollback the transaction to discard the changes + transaction.rollback() + + with contextlib.suppress(Exception): + self.db.close() return results def _is_restricted_query_for_no_commit_mode(self, query: str) -> bool: From 31f8287c4fc8e8f974a4259d65437b42a57b9c1a Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Mon, 21 Oct 2024 08:36:37 +0000 Subject: [PATCH 18/28] feat: send failed query to request as well --- agent/database.py | 3 ++- agent/site.py | 12 +++++++++--- agent/web.py | 3 +-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/agent/database.py b/agent/database.py index 5c93ce70..85f670e3 100644 --- a/agent/database.py +++ b/agent/database.py @@ -96,8 +96,9 @@ def _sql(self, query: str, params=(), commit: bool = False, as_dict: bool = Fals with self.db.atomic() as transaction: try: for q in queries: + self.last_executed_query = q if not commit and self._is_restricted_query_for_no_commit_mode(q): - raise ProgrammingError("Provided query is not allowed in read only mode") + raise ProgrammingError(f"Provided query is not allowed in read only mode") output = None row_count = None cursor = self.db.execute_sql(q, params) diff --git a/agent/site.py b/agent/site.py index 5fe20870..c3d1034d 100644 --- a/agent/site.py +++ b/agent/site.py @@ -832,9 +832,15 @@ def get_database_table_schemas(self): return tables def run_sql_query(self, query: str, commit: bool = False, as_dict: bool = False): - return Database(self.host, 3306, self.user, self.password, self.database).execute_query( - query, commit=commit, as_dict=as_dict - ) + database = Database(self.host, 3306, self.user, self.password, self.database) + success, output = database.execute_query(query, commit=commit, as_dict=as_dict) + response = { + "success": success, + "data": output + } + if not success and hasattr(database, "last_executed_query"): + response["failed_query"] = database.last_executed_query + return response @property def job_record(self): diff --git a/agent/web.py b/agent/web.py index 45f588a2..3c0ecfd5 100644 --- a/agent/web.py +++ b/agent/web.py @@ -561,8 +561,7 @@ def run_sql(bench, site): query = request.json.get("query") commit = request.json.get("commit") or False as_dict = request.json.get("as_dict") or False - success, data = Server().benches[bench].sites[site].run_sql_query(query, commit, as_dict) - return {"success": success, "data": data} + return Server().benches[bench].sites[site].run_sql_query(query, commit, as_dict) @application.route( From a6f2f970d7e9c3b554f32741e9d41dff4380cb0b Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Mon, 21 Oct 2024 08:56:51 +0000 Subject: [PATCH 19/28] feat: ignore commented sql queries --- agent/database.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/agent/database.py b/agent/database.py index 85f670e3..bdfe4014 100644 --- a/agent/database.py +++ b/agent/database.py @@ -85,7 +85,7 @@ def _sql(self, query: str, params=(), commit: bool = False, as_dict: bool = Fals """ queries = [x.strip() for x in query.split(";\n")] - queries = [x for x in queries if x] + queries = [x for x in queries if x and not x.startswith("--")] if len(queries) == 0: raise ProgrammingError("No query provided") @@ -98,7 +98,7 @@ def _sql(self, query: str, params=(), commit: bool = False, as_dict: bool = Fals for q in queries: self.last_executed_query = q if not commit and self._is_restricted_query_for_no_commit_mode(q): - raise ProgrammingError(f"Provided query is not allowed in read only mode") + raise ProgrammingError("Provided query is not allowed in read only mode") output = None row_count = None cursor = self.db.execute_sql(q, params) From caaf3fe8e3a71671bf3fd9209ca95dd9efaa41fd Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Mon, 21 Oct 2024 10:26:27 +0000 Subject: [PATCH 20/28] feat: send index info with schemas --- agent/site.py | 36 ++++++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/agent/site.py b/agent/site.py index c3d1034d..dce47bd1 100644 --- a/agent/site.py +++ b/agent/site.py @@ -797,6 +797,7 @@ def get_database_free_tables(self): return [] def get_database_table_schemas(self): + index_info = self.get_database_table_indexes() command = f"""SELECT TABLE_NAME AS `table`, COLUMN_NAME AS `column`, @@ -827,17 +828,44 @@ def get_database_table_schemas(self): "data_type": row[2], "is_nullable": row[3] == "YES", "default": row[4], + "indexes": index_info.get(table, {}).get(row[1], []), } ) return tables + def get_database_table_indexes(self): + command = f""" + SELECT + TABLE_NAME AS `table`, + COLUMN_NAME AS `column`, + INDEX_NAME AS `index` + FROM + INFORMATION_SCHEMA.STATISTICS + WHERE + TABLE_SCHEMA='{self.database}' + """ + command = quote(command) + data = self.execute( + f"mysql -sN -h {self.host} -u{self.user} -p{self.password} -e {command} --batch" + ).get("output") + data = data.split("\n") + data = [line.split("\t") for line in data] + tables = {} # : { : [, , ...] } + for row in data: + if len(row) != 3: + continue + table = row[0] + if table not in tables: + tables[table] = {} + if row[1] not in tables[table]: + tables[table][row[1]] = [] + tables[table][row[1]].append(row[2]) + return tables + def run_sql_query(self, query: str, commit: bool = False, as_dict: bool = False): database = Database(self.host, 3306, self.user, self.password, self.database) success, output = database.execute_query(query, commit=commit, as_dict=as_dict) - response = { - "success": success, - "data": output - } + response = {"success": success, "data": output} if not success and hasattr(database, "last_executed_query"): response["failed_query"] = database.last_executed_query return response From 5b53dd86249c80c151446e9f17a7d5d8af65a1b3 Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Tue, 22 Oct 2024 01:35:27 +0530 Subject: [PATCH 21/28] feat(Test): added basic structure for writing db method tests --- agent/database.py | 7 +-- agent/tests/test_database.py | 100 +++++++++++++++++++++++++++++++++++ dev-requirements.txt | 1 + 3 files changed, 105 insertions(+), 3 deletions(-) create mode 100644 agent/tests/test_database.py diff --git a/agent/database.py b/agent/database.py index bdfe4014..9f569a49 100644 --- a/agent/database.py +++ b/agent/database.py @@ -1,6 +1,7 @@ from __future__ import annotations import contextlib +from typing import Any from peewee import InternalError, MySQLDatabase, ProgrammingError @@ -17,7 +18,7 @@ def __init__(self, host, port, user, password, database): ) # Methods - def execute_query(self, query: str, commit: bool = False, as_dict: bool = False) -> list[bool, str]: + def execute_query(self, query: str, commit: bool = False, as_dict: bool = False) -> tuple[bool, Any]: """ This function will take the query and run in database. @@ -30,14 +31,14 @@ def execute_query(self, query: str, commit: bool = False, as_dict: bool = False) except (ProgrammingError, InternalError) as e: return False, str(e) except Exception as e: - print(f"Error executing SQL Query on {self.database} : {e}") + print(f"Error executing SQL Query on {self.db} : {e}") return ( False, "Failed to execute query due to unknown error. Please check the query and try again later.", ) # Private helper methods - def _sql(self, query: str, params=(), commit: bool = False, as_dict: bool = False) -> dict | None: # noqa: C901 + def _sql(self, query: str, params=(), commit: bool = False, as_dict: bool = False) -> list[dict]: # noqa: C901 """ Run sql query in database It supports multi-line SQL queries. Each SQL Query should be terminated with `;\n` diff --git a/agent/tests/test_database.py b/agent/tests/test_database.py new file mode 100644 index 00000000..bffb45be --- /dev/null +++ b/agent/tests/test_database.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import unittest +from shlex import quote + +from testcontainers.mysql import MySqlContainer + +from agent.database import Database + + +class DatabaseTestInstance: + # Test database instance with few utility functions + def __init__(self) -> None: + self.db_root_password = "123456" + self.db_container = MySqlContainer(image="mysql:8.0", root_password=self.db_root_password) + self.db_container.start() + + @property + def host(self): + return self.db_container.get_container_host_ip() + + @property + def port(self): + return int(self.db_container.get_exposed_port(3306)) + + def destroy(self) -> None: + self.db_container.stop(force=True, delete_volume=True) + + def execute_cmd(self, cmd) -> str: + code, output = self.db_container.exec(cmd) + if code != 0: + raise Exception(output.decode()) + return output.decode() + + def create_database(self, db_name: str): + query = quote(f"CREATE DATABASE {db_name}") + root_password = quote(self.db_root_password) + self.execute_cmd(f"mysql -h 127.0.0.1 -uroot -p{root_password} -e {query}") + + def create_database_user(self, db_name: str, username: str, password: str): + queries = [ + f"CREATE USER '{username}'@'%' IDENTIFIED BY '{password}'", + f"GRANT ALL ON {db_name}.* TO '{username}'@'%'", + "FLUSH PRIVILEGES", + ] + for query in queries: + command = f'mysql -h 127.0.0.1 -uroot -p{self.db_root_password} -e "{query}"' + self.execute_cmd(command) + + +class TestDatabase(unittest.TestCase): + def setUp(self) -> None: + self.instance = DatabaseTestInstance() + + # create test databases (db1, db2) with user + self.db1__name = "db1" + self.db1__username = "db1_dummy_user1" + self.db1__password = "db1_dummy_password" + self._setup_db(self.db1__name, self.db1__username, self.db1__password) + + self.db2__name = "db2" + self.db2__username = "db2_dummy_user1" + self.db2__password = "db1_dummy_password" + self._setup_db(self.db2__name, self.db2__username, self.db2__password) + + def _setup_db(self, db_name: str, username: str, password: str): + self.instance.create_database(db_name) + self.instance.create_database_user(db_name, username, password) + + db = self._db(db_name, username, password) + success, _ = db.execute_query( + """ + CREATE TABLE Person ( + id int, + name varchar(255) + ); + INSERT INTO Person (id, name) VALUES (1, "John Doe"); + INSERT INTO Person (id, name) VALUES (2, "Jane Smith"); + INSERT INTO Person (id, name) VALUES (3, "Alice Johnson"); + INSERT INTO Person (id, name) VALUES (4, "Bob Brown"); + INSERT INTO Person (id, name) VALUES (5, "Charlie Davis"); + """, + commit=True, + as_dict=True, + ) + if not success: + raise Exception(f"Failed to prepare test database ({db_name})") + + def tearDown(self) -> None: + self.instance.destroy() + + def _db(self, db_name: str, username: str, password: str) -> Database: + return Database(self.instance.host, self.instance.port, username, password, db_name) + + def test_execute_query(self): + """Basic test for `execute_query` function""" + db = self._db(self.db1__name, self.db1__username, self.db1__password) + success, data = db.execute_query("SELECT * FROM Person", commit=False, as_dict=True) + self.assertEqual(success, True, "run sql query") + self.assertEqual(data[0]["row_count"], 5) diff --git a/dev-requirements.txt b/dev-requirements.txt index 20420596..0d70ae8b 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,2 +1,3 @@ pre-commit black +testcontainers[mysql] From 428a0b1a9a203a4a00fd6f86c60dfe47612e8292 Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Tue, 22 Oct 2024 01:40:09 +0530 Subject: [PATCH 22/28] chore(ci): update workflow to install test python packages --- .github/workflows/main.yml | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 8bad9d15..3405231e 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -6,12 +6,12 @@ on: - master pull_request: branches: - - '*' + - "*" workflow_dispatch: jobs: lint-and-format: - name: 'Lint and Format' + name: "Lint and Format" runs-on: ubuntu-latest steps: @@ -21,7 +21,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: '3.8' + python-version: "3.8" - name: Install dependencies run: | @@ -36,7 +36,7 @@ jobs: ruff format --check unit-tests: - name: 'Unit Tests' + name: "Unit Tests" runs-on: ubuntu-latest steps: @@ -46,13 +46,17 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: '3.8' + python-version: "3.8" - name: Install agent run: | python -m venv env env/bin/pip install -e . - + + - name: Install development packages + run: | + env/bin/pip install -r dev-requirements.txt + - name: Setup agent run: | source env/bin/activate From 4ab0f7eb7f9595a4d97d677ce7c34c322ea68571 Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Tue, 22 Oct 2024 01:44:01 +0530 Subject: [PATCH 23/28] chore(ci): pin version of testcontainers --- dev-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 0d70ae8b..a1965c81 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,3 +1,3 @@ pre-commit black -testcontainers[mysql] +testcontainers[mysql]==4.8.2 From 1f4fddfa0885f494dd46b18a74f005f2f90cfbe7 Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Tue, 22 Oct 2024 01:50:05 +0530 Subject: [PATCH 24/28] fix(ci): use older version of testcontainers for python < 3.9 --- agent/tests/test_database.py | 16 ++++++++++++---- dev-requirements.txt | 2 +- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/agent/tests/test_database.py b/agent/tests/test_database.py index bffb45be..911d6ea2 100644 --- a/agent/tests/test_database.py +++ b/agent/tests/test_database.py @@ -12,7 +12,9 @@ class DatabaseTestInstance: # Test database instance with few utility functions def __init__(self) -> None: self.db_root_password = "123456" - self.db_container = MySqlContainer(image="mysql:8.0", root_password=self.db_root_password) + self.db_container = MySqlContainer( + image="mysql:8.0", MYSQL_ROOT_PASSWORD=self.db_root_password + ) self.db_container.start() @property @@ -44,7 +46,9 @@ def create_database_user(self, db_name: str, username: str, password: str): "FLUSH PRIVILEGES", ] for query in queries: - command = f'mysql -h 127.0.0.1 -uroot -p{self.db_root_password} -e "{query}"' + command = ( + f'mysql -h 127.0.0.1 -uroot -p{self.db_root_password} -e "{query}"' + ) self.execute_cmd(command) @@ -90,11 +94,15 @@ def tearDown(self) -> None: self.instance.destroy() def _db(self, db_name: str, username: str, password: str) -> Database: - return Database(self.instance.host, self.instance.port, username, password, db_name) + return Database( + self.instance.host, self.instance.port, username, password, db_name + ) def test_execute_query(self): """Basic test for `execute_query` function""" db = self._db(self.db1__name, self.db1__username, self.db1__password) - success, data = db.execute_query("SELECT * FROM Person", commit=False, as_dict=True) + success, data = db.execute_query( + "SELECT * FROM Person", commit=False, as_dict=True + ) self.assertEqual(success, True, "run sql query") self.assertEqual(data[0]["row_count"], 5) diff --git a/dev-requirements.txt b/dev-requirements.txt index a1965c81..31076fb2 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,3 +1,3 @@ pre-commit black -testcontainers[mysql]==4.8.2 +testcontainers[mysql]==3.7.1 From e7aeb4e2d7c2f49700eb837244c64b8a4c4011f7 Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Tue, 22 Oct 2024 01:50:48 +0530 Subject: [PATCH 25/28] chore: ruff format --- agent/tests/test_database.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/agent/tests/test_database.py b/agent/tests/test_database.py index 911d6ea2..e115db82 100644 --- a/agent/tests/test_database.py +++ b/agent/tests/test_database.py @@ -12,9 +12,7 @@ class DatabaseTestInstance: # Test database instance with few utility functions def __init__(self) -> None: self.db_root_password = "123456" - self.db_container = MySqlContainer( - image="mysql:8.0", MYSQL_ROOT_PASSWORD=self.db_root_password - ) + self.db_container = MySqlContainer(image="mysql:8.0", MYSQL_ROOT_PASSWORD=self.db_root_password) self.db_container.start() @property @@ -46,9 +44,7 @@ def create_database_user(self, db_name: str, username: str, password: str): "FLUSH PRIVILEGES", ] for query in queries: - command = ( - f'mysql -h 127.0.0.1 -uroot -p{self.db_root_password} -e "{query}"' - ) + command = f'mysql -h 127.0.0.1 -uroot -p{self.db_root_password} -e "{query}"' self.execute_cmd(command) @@ -94,15 +90,11 @@ def tearDown(self) -> None: self.instance.destroy() def _db(self, db_name: str, username: str, password: str) -> Database: - return Database( - self.instance.host, self.instance.port, username, password, db_name - ) + return Database(self.instance.host, self.instance.port, username, password, db_name) def test_execute_query(self): """Basic test for `execute_query` function""" db = self._db(self.db1__name, self.db1__username, self.db1__password) - success, data = db.execute_query( - "SELECT * FROM Person", commit=False, as_dict=True - ) + success, data = db.execute_query("SELECT * FROM Person", commit=False, as_dict=True) self.assertEqual(success, True, "run sql query") self.assertEqual(data[0]["row_count"], 5) From d4469bae799f07c6d0f7b9d6f111a0a93453be20 Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Tue, 22 Oct 2024 01:53:20 +0530 Subject: [PATCH 26/28] chore(ci): add cryptography package --- dev-requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 31076fb2..ff306b75 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,3 +1,4 @@ pre-commit black -testcontainers[mysql]==3.7.1 +testcontainers[mysql]==3.7.1i +cryptography From fa2439e3f871181b6477e1a921a4e6f1cc08d3af Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Tue, 22 Oct 2024 01:54:54 +0530 Subject: [PATCH 27/28] chore(ci): add cryptography package --- dev-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index ff306b75..519d0717 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,4 +1,4 @@ pre-commit black -testcontainers[mysql]==3.7.1i +testcontainers[mysql]==3.7.1 cryptography From 5652143ba1c119246399a3d2b4d5e7cecc8b4976 Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Tue, 22 Oct 2024 11:14:31 +0530 Subject: [PATCH 28/28] feat: added testcases for _run_sql function --- agent/database.py | 27 +++---- agent/tests/test_database.py | 146 +++++++++++++++++++++++++++++++++-- 2 files changed, 148 insertions(+), 25 deletions(-) diff --git a/agent/database.py b/agent/database.py index 9f569a49..bba86435 100644 --- a/agent/database.py +++ b/agent/database.py @@ -8,14 +8,7 @@ class Database: def __init__(self, host, port, user, password, database): - self.db: MySQLDatabase = MySQLDatabase( - database, - user=user, - password=password, - host=host, - port=port, - autocommit=False, - ) + self.db: MySQLDatabase = MySQLDatabase(database, user=user, password=password, host=host, port=port) # Methods def execute_query(self, query: str, commit: bool = False, as_dict: bool = False) -> tuple[bool, Any]: @@ -27,18 +20,17 @@ def execute_query(self, query: str, commit: bool = False, as_dict: bool = False) str: The output of the query. It can be the output or error message as well """ try: - return True, self._sql(query, commit=commit, as_dict=as_dict) + return True, self._run_sql(query, commit=commit, as_dict=as_dict) except (ProgrammingError, InternalError) as e: return False, str(e) - except Exception as e: - print(f"Error executing SQL Query on {self.db} : {e}") + except Exception: return ( False, "Failed to execute query due to unknown error. Please check the query and try again later.", ) # Private helper methods - def _sql(self, query: str, params=(), commit: bool = False, as_dict: bool = False) -> list[dict]: # noqa: C901 + def _run_sql(self, query: str, params=(), commit: bool = False, as_dict: bool = False) -> list[dict]: # noqa: C901 """ Run sql query in database It supports multi-line SQL queries. Each SQL Query should be terminated with `;\n` @@ -98,8 +90,12 @@ def _sql(self, query: str, params=(), commit: bool = False, as_dict: bool = Fals try: for q in queries: self.last_executed_query = q - if not commit and self._is_restricted_query_for_no_commit_mode(q): - raise ProgrammingError("Provided query is not allowed in read only mode") + if not commit and self._is_ddl_query(q): + raise ProgrammingError("Provided DDL query is not allowed in read only mode") + if self._is_dcl_query(q): + raise ProgrammingError("DCL query is not allowed to execute") + if self._is_tcl_query(q): + raise ProgrammingError("TCL query is not allowed to execute") output = None row_count = None cursor = self.db.execute_sql(q, params) @@ -132,9 +128,6 @@ def _sql(self, query: str, params=(), commit: bool = False, as_dict: bool = Fals self.db.close() return results - def _is_restricted_query_for_no_commit_mode(self, query: str) -> bool: - return self._is_ddl_query(query) or self._is_dcl_query(query) or self._is_tcl_query(query) - def _is_ddl_query(self, query: str) -> bool: return query.upper().startswith(("CREATE", "ALTER", "DROP", "TRUNCATE", "RENAME", "COMMENT")) diff --git a/agent/tests/test_database.py b/agent/tests/test_database.py index e115db82..cb0c6de0 100644 --- a/agent/tests/test_database.py +++ b/agent/tests/test_database.py @@ -9,7 +9,8 @@ class DatabaseTestInstance: - # Test database instance with few utility functions + """Test database instance with few utility functions""" + def __init__(self) -> None: self.db_root_password = "123456" self.db_container = MySqlContainer(image="mysql:8.0", MYSQL_ROOT_PASSWORD=self.db_root_password) @@ -37,6 +38,11 @@ def create_database(self, db_name: str): root_password = quote(self.db_root_password) self.execute_cmd(f"mysql -h 127.0.0.1 -uroot -p{root_password} -e {query}") + def remove_database(self, db_name: str): + query = quote(f"DROP DATABASE IF EXISTS {db_name}") + root_password = quote(self.db_root_password) + self.execute_cmd(f"mysql -h 127.0.0.1 -uroot -p{root_password} -e {query}") + def create_database_user(self, db_name: str, username: str, password: str): queries = [ f"CREATE USER '{username}'@'%' IDENTIFIED BY '{password}'", @@ -47,11 +53,25 @@ def create_database_user(self, db_name: str, username: str, password: str): command = f'mysql -h 127.0.0.1 -uroot -p{self.db_root_password} -e "{query}"' self.execute_cmd(command) + def remove_database_user(self, username: str): + query = quote(f""" + DROP USER IF EXISTS '{username}'@'%'; + FLUSH PRIVILEGES; + """) + root_password = quote(self.db_root_password) + self.execute_cmd(f"mysql -h 127.0.0.1 -uroot -p{root_password} -e {query}") + class TestDatabase(unittest.TestCase): - def setUp(self) -> None: - self.instance = DatabaseTestInstance() + @classmethod + def setUpClass(cls): + cls.instance = DatabaseTestInstance() + @classmethod + def tearDownClass(cls): + cls.instance.destroy() + + def setUp(self) -> None: # create test databases (db1, db2) with user self.db1__name = "db1" self.db1__username = "db1_dummy_user1" @@ -64,6 +84,7 @@ def setUp(self) -> None: self._setup_db(self.db2__name, self.db2__username, self.db2__password) def _setup_db(self, db_name: str, username: str, password: str): + # setup self.instance.create_database(db_name) self.instance.create_database_user(db_name, username, password) @@ -87,14 +108,123 @@ def _setup_db(self, db_name: str, username: str, password: str): raise Exception(f"Failed to prepare test database ({db_name})") def tearDown(self) -> None: - self.instance.destroy() + self.instance.remove_database(self.db1__name) + self.instance.remove_database_user(self.db1__username) + + self.instance.remove_database(self.db2__name) + self.instance.remove_database_user(self.db2__username) def _db(self, db_name: str, username: str, password: str) -> Database: return Database(self.instance.host, self.instance.port, username, password, db_name) - def test_execute_query(self): - """Basic test for `execute_query` function""" + # Test cases for _run_sql method + def test_run_sql_fn(self): + """Basic test for `_run_sql` function""" db = self._db(self.db1__name, self.db1__username, self.db1__password) - success, data = db.execute_query("SELECT * FROM Person", commit=False, as_dict=True) - self.assertEqual(success, True, "run sql query") + data = db._run_sql("SELECT * FROM Person", commit=False, as_dict=True) self.assertEqual(data[0]["row_count"], 5) + + def test_db1_user_shouldnt_be_able_to_access_db2(self): + db = self._db(self.db2__name, self.db1__username, self.db1__password) + with self.assertRaises(Exception) as cm: + db._run_sql( + """ + SELECT * + FROM Person + WHERE name = 'Bob Brown' + """, + commit=False, + as_dict=True, + ) + self.assertIn("Access denied for user 'db1_dummy_user1'@'%' to database 'db2'", str(cm.exception)) + + def test_run_sql_fn_with_commit_disabled_shouldnt_allow_ddl_queries(self): + db = self._db(self.db1__name, self.db1__username, self.db1__password) + with self.assertRaises(Exception) as cm: + db._run_sql( + """ + CREATE TABLE Person2 ( + id int, + name varchar(255) + ); + """, + commit=False, + as_dict=True, + ) + self.assertIn( + "Provided DDL query is not allowed in read only mode", + str(cm.exception), + "DDL Query should be failed for non-commit mode", + ) + + def test_run_sql_fn_shouldnt_allow_dcl_queries(self): + db = self._db(self.db1__name, self.db1__username, self.db1__password) + with self.assertRaises(Exception) as cm: + db._run_sql( + """ + REVOKE ALL PRIVILEGES ON *.* FROM 'db1_dummy_user1'@'%'; + """, + commit=True, + as_dict=True, + ) + self.assertIn( + "DCL query is not allowed to execute", + str(cm.exception), + "DCL queries should be failed in any condition", + ) + + def test_run_sql_fn_shouldnt_allow_tcl_queries(self): + db = self._db(self.db1__name, self.db1__username, self.db1__password) + with self.assertRaises(Exception) as cm: + db._run_sql( + """ + COMMIT; + """, + commit=True, + as_dict=True, + ) + self.assertIn( + "TCL query is not allowed to execute", + str(cm.exception), + "TCL queries should be failed in any condition", + ) + + def test_run_sql_fn_with_commit_enabled_should_persist_changes(self): + db = self._db(self.db1__name, self.db1__username, self.db1__password) + db._run_sql( + """ + INSERT INTO Person (id, name) VALUES (6, "John Doe2"); + """, + commit=True, + as_dict=True, + ) + data = db._run_sql( + """ + SELECT * + FROM Person + WHERE name = 'John Doe2' + """, + commit=False, + as_dict=True, + ) + self.assertEqual(data[0]["row_count"], 1) + + def test_run_sql_fn_with_commit_disabled_should_not_persist_changes(self): + db = self._db(self.db1__name, self.db1__username, self.db1__password) + db._run_sql( + """ + INSERT INTO Person (id, name) VALUES (6, "John Doe2"); + """, + commit=False, + as_dict=True, + ) + data = db._run_sql( + """ + SELECT * + FROM Person + WHERE name = 'John Doe2' + """, + commit=False, + as_dict=True, + ) + self.assertEqual(data[0]["row_count"], 0)