diff --git a/agent/database.py b/agent/database.py index 4c21d323..018e069f 100644 --- a/agent/database.py +++ b/agent/database.py @@ -5,12 +5,13 @@ from decimal import Decimal from typing import Any -from peewee import InternalError, MySQLDatabase, ProgrammingError +import peewee class Database: def __init__(self, host, port, user, password, database): - self.db: MySQLDatabase = MySQLDatabase(database, user=user, password=password, host=host, port=port) + self.database_name = database + self.db: CustomPeeweeDB = CustomPeeweeDB(database, user=user, password=password, host=host, port=port) # Methods def execute_query(self, query: str, commit: bool = False, as_dict: bool = False) -> tuple[bool, Any]: @@ -23,7 +24,7 @@ def execute_query(self, query: str, commit: bool = False, as_dict: bool = False) """ try: return True, self._run_sql(query, commit=commit, as_dict=as_dict) - except (ProgrammingError, InternalError) as e: + except (peewee.ProgrammingError, peewee.InternalError, peewee.OperationalError) as e: return False, str(e) except Exception: return ( @@ -31,19 +32,143 @@ def execute_query(self, query: str, commit: bool = False, as_dict: bool = False) "Failed to execute query due to unknown error. Please check the query and try again later.", ) + """ + NOTE: These methods require root access to the database + - create_user + - remove_user + - modify_user_permissions + """ + + def create_user(self, username: str, password: str): + query = f""" + CREATE OR REPLACE USER '{username}'@'%' IDENTIFIED BY '{password}'; + FLUSH PRIVILEGES; + """ + self._run_sql( + query, + commit=True, + ) + + def remove_user(self, username: str): + self._run_sql( + f""" + DROP USER IF EXISTS '{username}'@'%'; + FLUSH PRIVILEGES; + """, + commit=True, + ) + + def modify_user_permissions(self, username: str, mode: str, permissions: dict | None = None) -> None: # noqa C901 + """ + Args: + username: username of the user, whos privileges are to be modified + mode: permission mode + - read_only: read only access to all tables + - read_write: read write access to all tables + - granular: granular access to tables + + permissions: list of permissions [only required if mode is granular] + { + "": { + "mode": "read_only" // read_only or read_write, + "columns": "*" // "*" or ["column1", "column2", ...] + }, + ... + } + all_read_only: True if you want to make all tables read only for the user + all_read_write: True if you want to make all tables read write for the user + + Returns: + It will return nothing, if anything goes wrong it will raise an exception + """ + if not permissions: + permissions = {} + + if mode not in ["read_only", "read_write", "granular"]: + raise ValueError("mode must be read_only, read_write or granular") + privileges_map = { + "read_only": "SELECT", + "read_write": "ALL", + } + # fetch existing privileges + records = self._run_sql(f"SHOW GRANTS FOR '{username}'@'%';", as_dict=False) + granted_records: list[str] = [] + if len(records) > 0 and records[0]["output"]["data"] and len(records[0]["output"]["data"]) > 0: + granted_records = [x[0] for x in records[0]["output"]["data"] if len(x) > 0] + + queries = [] + """ + First revoke all existing privileges + + Prepare revoke permission sql query + + `Show Grants` output: + GRANT SELECT ON `_cbace6eaa306751d`.* TO `_cbace6eaa306751d_read_only`@`%` + ... + + That need to be converted to this for revoke privileges + REVOKE SELECT ON _cbace6eaa306751d.* FROM '_cbace6eaa306751d_read_only'@'%' + """ + for record in granted_records: + if record.startswith("GRANT USAGE"): + # dont revoke usage + continue + queries.append( + record.replace("GRANT", "REVOKE").replace(f"TO `{username}`@`%`", f"FROM `{username}`@`%`") + + ";" + ) + + # add new privileges + if mode == "read_only" or mode == "read_write": + privilege = privileges_map[mode] + queries.append(f"GRANT {privilege} ON {self.database_name}.* TO `{username}`@`%`;") + elif mode == "granular": + for table_name in permissions: + columns = "" + if isinstance(permissions[table_name]["columns"], list): + if len(permissions[table_name]["columns"]) == 0: + raise ValueError( + "columns cannot be an empty list. please specify '*' or at least one column" + ) + requested_columns = permissions[table_name]["columns"] + columns = ",".join([f"`{x}`" for x in requested_columns]) + columns = f"({columns})" + + privilege = privileges_map[permissions[table_name]["mode"]] + if columns == "" or privilege == "SELECT": + queries.append( + f"GRANT {privilege} {columns} ON `{self.database_name}`.`{table_name}` TO `{username}`@`%`;" # noqa: E501 + ) + else: + # while usisng column level privileges `ALL` doesnt work + # So we need to provide all possible privileges for that columns + for p in ["SELECT", "INSERT", "UPDATE", "REFERENCES"]: + queries.append( + f"GRANT {p} {columns} ON `{self.database_name}`.`{table_name}` TO `{username}`@`%`;" # noqa: E501 + ) + + # flush privileges to apply changes + queries.append("FLUSH PRIVILEGES;") + queries_str = "\n".join(queries) + + self._run_sql(queries_str, commit=True, allow_all_stmt_types=True) + # Private helper methods - def _run_sql(self, query: str, params=(), commit: bool = False, as_dict: bool = False) -> list[dict]: # noqa: C901 + def _run_sql( # noqa C901 + self, query: str, commit: bool = False, as_dict: bool = False, allow_all_stmt_types: bool = False + ) -> list[dict]: """ Run sql query in database It supports multi-line SQL queries. Each SQL Query should be terminated with `;\n` Args: - query: SQL query - params: If you are using parameters in the query, you can pass them as a tuple + query: SQL query string commit: True if you want to commit the changes. If commit is false, it will rollback the changes and also wouldnt allow to run ddl, dcl or tcl queries as_dict: True if you want to return the result as a dictionary (like frappe.db.sql). - Otherwise it will return a dict of columns and data + Otherwise it will return a dict of columns and data + allow_all_stmt_types: True if you want to allow all type of sql statements + Default: False Return Format: For as_dict = True: @@ -83,7 +208,7 @@ def _run_sql(self, query: str, params=(), commit: bool = False, as_dict: bool = queries = [x for x in queries if x and not x.startswith("--")] if len(queries) == 0: - raise ProgrammingError("No query provided") + raise peewee.ProgrammingError("No query provided") # Start transaction self.db.begin() @@ -93,14 +218,14 @@ def _run_sql(self, query: str, params=(), commit: bool = False, as_dict: bool = for q in queries: self.last_executed_query = q if not commit and self._is_ddl_query(q): - raise ProgrammingError("Provided DDL query is not allowed in read only mode") - if self._is_dcl_query(q): - raise ProgrammingError("DCL query is not allowed to execute") - if self._is_tcl_query(q): - raise ProgrammingError("TCL query is not allowed to execute") + raise peewee.ProgrammingError("Provided DDL query is not allowed in read only mode") + if not allow_all_stmt_types and self._is_dcl_query(q): + raise peewee.ProgrammingError("DCL query is not allowed to execute") + if not allow_all_stmt_types and self._is_tcl_query(q): + raise peewee.ProgrammingError("TCL query is not allowed to execute") output = None row_count = None - cursor = self.db.execute_sql(q, params) + cursor = self.db.execute_sql(q) row_count = cursor.rowcount if cursor.description: rows = cursor.fetchall() @@ -146,3 +271,54 @@ def default(self, obj): if isinstance(obj, Decimal): return float(obj) return str(obj) + + +class CustomPeeweeDB(peewee.MySQLDatabase): + """ + Override peewee.MySQLDatabase to modify `execute_sql` method + + All queries coming from end-user has value inside query, so we can't pass the params seperately. + Peewee set `params` arg of `execute_sql` to `()` by default. + + We are overriding `execute_sql` method to pass the params as None + So that, pymysql doesn't try to parse the query and insert params in the query + """ + + __exception_wrapper__ = peewee.ExceptionWrapper( + { + "ConstraintError": peewee.IntegrityError, + "DatabaseError": peewee.DatabaseError, + "DataError": peewee.DataError, + "IntegrityError": peewee.IntegrityError, + "InterfaceError": peewee.InterfaceError, + "InternalError": peewee.InternalError, + "NotSupportedError": peewee.NotSupportedError, + "OperationalError": peewee.OperationalError, + "ProgrammingError": peewee.ProgrammingError, + "TransactionRollbackError": peewee.OperationalError, + } + ) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def execute_sql(self, sql): + if self.in_transaction(): + commit = False + elif self.commit_select: + commit = True + else: + commit = not sql[:6].lower().startswith("select") + + with self.__exception_wrapper__: + cursor = self.cursor(commit) + try: + cursor.execute(sql, None) # params passed as none + except Exception: + if self.autorollback and not self.in_transaction(): + self.rollback() + raise + else: + if commit and not self.in_transaction(): + self.commit() + return cursor diff --git a/agent/site.py b/agent/site.py index cebd5f15..1f13d968 100644 --- a/agent/site.py +++ b/agent/site.py @@ -258,13 +258,45 @@ def revoke_database_access_credentials(self, user, mariadb_root_password): if user == self.user: # Do not revoke access for the main user return {} - queries = [ - f"DROP USER IF EXISTS '{user}'@'%'", - "FLUSH PRIVILEGES", - ] - for query in queries: - command = f"mysql -h {self.host} -uroot -p{mariadb_root_password}" f' -e "{query}"' - self.execute(command) + self.db_instance("root", mariadb_root_password).remove_user(user) + return {} + + @job("Create Database User", priority="high") + def create_database_user_job(self, user, password, mariadb_root_password): + return self.create_database_user(user, password, mariadb_root_password) + + @step("Create Database User") + def create_database_user(self, user, password, mariadb_root_password): + if user == self.user: + # Do not perform any operation for the main user + return {} + self.db_instance("root", mariadb_root_password).create_user(user, password) + return { + "database": self.database, + } + + @job("Remove Database User", priority="high") + def remove_database_user_job(self, user, mariadb_root_password): + return self.remove_database_user(user, mariadb_root_password) + + @step("Remove Database User") + def remove_database_user(self, user, mariadb_root_password): + if user == self.user: + # Do not perform any operation for the main user + return {} + self.db_instance("root", mariadb_root_password).remove_user(user) + return {} + + @job("Modify Database User Permissions", priority="high") + def modify_database_user_permissions_job(self, user, mode, permissions, mariadb_root_password): + return self.modify_database_user_permissions(user, mode, permissions, mariadb_root_password) + + @step("Modify Database User Permissions") + def modify_database_user_permissions(self, user, mode, permissions, mariadb_root_password): + if user == self.user: + # Do not perform any operation for the main user + return {} + self.db_instance("root", mariadb_root_password).modify_user_permissions(user, mode, permissions) return {} @job("Setup ERPNext", priority="high") @@ -868,13 +900,20 @@ def get_database_table_indexes(self): return tables def run_sql_query(self, query: str, commit: bool = False, as_dict: bool = False): - database = Database(self.host, 3306, self.user, self.password, self.database) - success, output = database.execute_query(query, commit=commit, as_dict=as_dict) + db = self.db_instance() + success, output = db.execute_query(query, commit=commit, as_dict=as_dict) response = {"success": success, "data": output} - if not success and hasattr(database, "last_executed_query"): - response["failed_query"] = database.last_executed_query + if not success and hasattr(db, "last_executed_query"): + response["failed_query"] = db.last_executed_query return response + def db_instance(self, username: str | None = None, password: str | None = None) -> Database: + if not username: + username = self.user + if not password: + password = self.password + return Database(self.host, 3306, username, password, self.database) + @property def job_record(self): return self.bench.server.job_record diff --git a/agent/tests/test_database.py b/agent/tests/test_database.py index cb0c6de0..1d0007d2 100644 --- a/agent/tests/test_database.py +++ b/agent/tests/test_database.py @@ -13,7 +13,7 @@ class DatabaseTestInstance: def __init__(self) -> None: self.db_root_password = "123456" - self.db_container = MySqlContainer(image="mysql:8.0", MYSQL_ROOT_PASSWORD=self.db_root_password) + self.db_container = MySqlContainer(image="mariadb:10.6", MYSQL_ROOT_PASSWORD=self.db_root_password) self.db_container.start() @property @@ -100,6 +100,24 @@ def _setup_db(self, db_name: str, username: str, password: str): INSERT INTO Person (id, name) VALUES (3, "Alice Johnson"); INSERT INTO Person (id, name) VALUES (4, "Bob Brown"); INSERT INTO Person (id, name) VALUES (5, "Charlie Davis"); + CREATE TABLE Product ( + id int, + name varchar(255) + ); + INSERT INTO Product (id, name) VALUES (1, "Book"); + INSERT INTO Product (id, name) VALUES (2, "Car"); + INSERT INTO Product (id, name) VALUES (3, "House"); + INSERT INTO Product (id, name) VALUES (4, "Computer"); + INSERT INTO Product (id, name) VALUES (5, "Table"); + CREATE TABLE Account ( + id int, + name varchar(255) + ); + INSERT INTO Account (id, name) VALUES (1, "John Doe"); + INSERT INTO Account (id, name) VALUES (2, "Jane Smith"); + INSERT INTO Account (id, name) VALUES (3, "Alice Johnson"); + INSERT INTO Account (id, name) VALUES (4, "Bob Brown"); + INSERT INTO Account (id, name) VALUES (5, "Charlie Davis"); """, commit=True, as_dict=True, @@ -228,3 +246,219 @@ def test_run_sql_fn_with_commit_disabled_should_not_persist_changes(self): as_dict=True, ) self.assertEqual(data[0]["row_count"], 0) + + def test_create_user(self): + db = self._db(self.db1__name, "root", self.instance.db_root_password) + + # create user + try: + db.create_user("test_user", "test_user_password") + except: + print(f"Failed query: {db.query}") + raise + + # fetch users + success, output = db.execute_query("SELECT User FROM mysql.user;", commit=False, as_dict=False) + self.assertTrue(success) + self.assertIsInstance(output, list) + + users = [x[0] for x in output[0].get("output", []).get("data", [])] + self.assertIn("test_user", users) + + def test_remove_user(self): + db = self._db(self.db1__name, "root", self.instance.db_root_password) + + # add a dummy user + db.create_user("test_user", "test_user_password") + + # remove user + try: + db.remove_user("test_user") + except: + print(f"Failed query: {db.query}") + raise + + # fetch users + success, output = db.execute_query("SELECT User FROM mysql.user;", commit=False, as_dict=False) + self.assertTrue(success) + self.assertIsInstance(output, list) + + users = [x[0] for x in output[0].get("output", []).get("data", [])] + self.assertNotIn("test_user", users) + + def test_create_read_only_permission(self): + db = self._db(self.db1__name, "root", self.instance.db_root_password) + db.create_user("user1", "test_password") # create user + + user1_db = self._db(self.db1__name, "user1", "test_password") + + # try to access the database + success, output = user1_db.execute_query("SELECT * FROM Person;") + self.assertFalse(success, "User `user1` should not have access to the database") + self.assertIn("Access denied for user", str(output)) + + db.modify_user_permissions("user1", "read_only") + + # try to access the database again + success, output = user1_db.execute_query("SELECT * FROM Person;") + self.assertTrue(success, "User `user1` should have read access to the database") + self.assertGreater(len(output), 0) + + # user shouldnt have write access + success, output = user1_db.execute_query('INSERT INTO Person (id, name) VALUES (10, "Test Person");') + self.assertFalse(success, "User `user1` should not have write access to the database") + self.assertIn("INSERT command denied to user", str(output)) + + def test_create_read_write_permission(self): + db = self._db(self.db1__name, "root", self.instance.db_root_password) + db.create_user("user1", "test_password") # create user + + user1_db = self._db(self.db1__name, "user1", "test_password") + + # try to access the database + success, output = user1_db.execute_query("SELECT * FROM Person;") + self.assertFalse(success, "User `user1` should not have access to the database") + self.assertIn("Access denied for user", str(output)) + + db.modify_user_permissions("user1", "read_write") + + # try to access the database again + success, output = user1_db.execute_query("SELECT * FROM Person;") + self.assertTrue(success, "User `user1` should have read access to the database") + self.assertGreater(len(output), 0) + + # user should have write access + success, _ = user1_db.execute_query('INSERT INTO Person (id, name) VALUES (10, "Test Person");') + self.assertTrue(success, "User `user1` should have write access to the database") + + def test_granular_permission_table_level(self): + db = self._db(self.db1__name, "root", self.instance.db_root_password) + user1_db = self._db(self.db1__name, "user1", "test_password") + + db.create_user("user1", "test_password") # create user + + # modify access + try: + db.modify_user_permissions( + "user1", + "granular", + permissions={ + "Person": {"mode": "read_only", "columns": "*"}, + "Product": {"mode": "read_write", "columns": "*"}, + }, + ) + except Exception: + if hasattr(db, "last_executed_query"): + print("Failed query: ", db.last_executed_query) + raise + + # verify access for `Person` table + success, output = user1_db.execute_query("SELECT * FROM Person;") + self.assertTrue(success, "User `user1` should have read access to `Person` table") + self.assertGreater(len(output), 0) + + success, output = user1_db.execute_query('INSERT INTO Person (id, name) VALUES (10, "Test Person");') + self.assertFalse(success, "User `user1` should not have write access to `Person` table") + self.assertIn("INSERT command denied to user", str(output)) + + # verify access for `Product` table + success, output = user1_db.execute_query("SELECT * FROM Product;") + self.assertTrue(success, "User `user1` should have read access to `Product` table") + self.assertGreater(len(output), 0) + + success, _ = user1_db.execute_query('INSERT INTO Product (id, name) VALUES (10, "Test Product");') + self.assertTrue(success, "User `user1` should have write access to `Product` table") + + # verify access for `Account` table + success, output = user1_db.execute_query("SELECT * FROM Account;") + self.assertFalse(success, "User `user1` should not have access to `Account` table") + self.assertIn("SELECT command denied to user", str(output)) + + # purge access and verify + db.modify_user_permissions("user1", "granular", permissions={}) + + success, output = user1_db.execute_query("SELECT * FROM Person;") + self.assertFalse(success, "User `user1` should not have access to `Person` table") + success, output = user1_db.execute_query("SELECT * FROM Product;") + self.assertFalse(success, "User `user1` should not have access to `Product` table") + success, output = user1_db.execute_query("SELECT * FROM Account;") + self.assertFalse(success, "User `user1` should not have access to `Account` table") + + def test_granular_permission_column_level(self): + db = self._db(self.db1__name, "root", self.instance.db_root_password) + user1_db = self._db(self.db1__name, "user1", "test_password") + + db.create_user("user1", "test_password") # create user + + # modify access [read_only] + try: + db.modify_user_permissions( + "user1", "granular", permissions={"Person": {"mode": "read_only", "columns": ["id"]}} + ) + except Exception: + if hasattr(db, "last_executed_query"): + print("Failed query: ", db.last_executed_query) + raise + + # verify access for `Person` table + success, output = user1_db.execute_query("SELECT * FROM Person;") + self.assertFalse(success, "User `user1` should not have read access to all columns of `Person` table") + self.assertIn("SELECT command denied to user", str(output)) + + success, output = user1_db.execute_query("SELECT id FROM Person;") + self.assertTrue(success, "User `user1` should have read access to `id` column of `Person` table") + self.assertGreater(len(output), 0) + + # modify access [read_write] + try: + db.modify_user_permissions( + "user1", "granular", permissions={"Person": {"mode": "read_write", "columns": ["id"]}} + ) + except Exception: + if hasattr(db, "last_executed_query"): + print("Failed query: ", db.last_executed_query) + raise + + # verify access for `Person` table + success, output = user1_db.execute_query("SELECT * FROM Person;") + self.assertFalse(success, "User `user1` should not have read access to all columns of `Person` table") + self.assertIn("SELECT command denied to user", str(output)) + + success, output = user1_db.execute_query("SELECT id FROM Person;") + self.assertTrue(success, "User `user1` should have read access to `id` column of `Person` table") + self.assertGreater(len(output), 0) + + success, output = user1_db.execute_query("UPDATE Person SET name = 'Columbia' WHERE id = 1;") + self.assertFalse(success, "User `user1` should have write access to `name` col of `Person` table") + self.assertIn("UPDATE command denied to user", str(output)) + + success, output = user1_db.execute_query("UPDATE Person SET id = 2 WHERE name = 'Columbia';") + self.assertFalse(success, "User `user1` should have write access to `id` col of `Person` table") + self.assertIn("SELECT command denied to user", str(output)) + + success, output = user1_db.execute_query("UPDATE Person SET id = 50 WHERE id = 1;") + self.assertTrue(success, "User `user1` should have write access to `id` col of `Person` table") + + success, output = user1_db.execute_query("DELETE FROM Person WHERE id = 2;") + self.assertFalse(success, "User `user1` should have write access to `id` col of `Person` table") + self.assertIn("DELETE command denied to user", str(output)) + + def test_revoke_permission(self): + db = self._db(self.db1__name, "root", self.instance.db_root_password) + db.create_user("user1", "test_password") # create user + db.modify_user_permissions("user1", "read_only") + + user1_db = self._db(self.db1__name, "user1", "test_password") + + # try to access the database + success, _ = user1_db.execute_query("SELECT * FROM Person;") + self.assertTrue(success, "User `user1` should have read access to the database") + + # revoke permission + # setting no permission in granular mode, will revoke all existing permissions + db.modify_user_permissions("user1", "granular", permissions={}) + + # try to access the database again + success, output = user1_db.execute_query("SELECT * FROM Person;") + self.assertFalse(success, "User `user1` should not have access to the database") + self.assertIn("Access denied for user", str(output)) diff --git a/agent/web.py b/agent/web.py index 151b71d2..28264039 100644 --- a/agent/web.py +++ b/agent/web.py @@ -572,6 +572,47 @@ def run_sql(bench, site): ) +@application.route("/benches//sites//database/users", methods=["POST"]) +@validate_bench_and_site +def create_database_user(bench, site): + data = request.json + job = ( + Server() + .benches[bench] + .sites[site] + .create_database_user_job(data["username"], data["password"], data["mariadb_root_password"]) + ) + return {"job": job} + + +@application.route( + "/benches//sites//database/users/", methods=["DELETE"] +) +@validate_bench_and_site +def remove_database_user(bench, site, db_user): + data = request.json + job = Server().benches[bench].sites[site].remove_database_user_job(db_user, data["mariadb_root_password"]) + return {"job": job} + + +@application.route( + "/benches//sites//database/users//permissions", + methods=["POST"], +) +@validate_bench_and_site +def update_database_permissions(bench, site, db_user): + data = request.json + job = ( + Server() + .benches[bench] + .sites[site] + .modify_database_user_permissions_job( + db_user, data["mode"], data.get("permissions", {}), data["mariadb_root_password"] + ) + ) + return {"job": job} + + @application.route( "/benches//sites//migrate", methods=["POST"],