Skip to content

Commit

Permalink
feat(database): apis for creating user, removing db user, permission …
Browse files Browse the repository at this point in the history
…updates
  • Loading branch information
tanmoysrt committed Oct 31, 2024
1 parent 1387dae commit 503d39c
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 34 deletions.
12 changes: 6 additions & 6 deletions agent/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ def execute_query(self, query: str, commit: bool = False, as_dict: bool = False)

"""
NOTE: These methods require root access to the database
- add_user
- create_user
- remove_user
- modify_user_access
"""

def add_user(self, username: str, password: str):
def create_user(self, username: str, password: str):
query = f"""
CREATE OR REPLACE USER '{username}'@'%' IDENTIFIED BY '{password}';
FLUSH PRIVILEGES;
Expand All @@ -58,7 +58,7 @@ def remove_user(self, username: str):
commit=True,
)

def modify_user_permission(self, username: str, mode: str, permissions: dict | None = None) -> None: # noqa C901
def modify_user_permissions(self, username: str, mode: str, permissions: dict | None = None) -> None: # noqa C901
"""
Args:
username: username of the user, whos privileges are to be modified
Expand Down Expand Up @@ -86,7 +86,7 @@ def modify_user_permission(self, username: str, mode: str, permissions: dict | N

if mode not in ["read_only", "read_write", "granular"]:
raise ValueError("mode must be read_only, read_write or granular")
privileges = {
privileges_map = {
"read_only": "SELECT",
"read_write": "ALL",
}
Expand Down Expand Up @@ -120,7 +120,7 @@ def modify_user_permission(self, username: str, mode: str, permissions: dict | N

# add new privileges
if mode == "read_only" or mode == "read_write":
privilege = privileges[mode]
privilege = privileges_map[mode]
queries.append(f"GRANT {privilege} ON {self.database_name}.* TO `{username}`@`%`;")
elif mode == "granular":
for table_name in permissions:
Expand All @@ -134,7 +134,7 @@ def modify_user_permission(self, username: str, mode: str, permissions: dict | N
columns = ",".join([f"`{x}`" for x in requested_columns])
columns = f"({columns})"

privilege = privileges[permissions[table_name]["mode"]]
privilege = privileges_map[permissions[table_name]["mode"]]
if columns == "" or privilege == "SELECT":
queries.append(
f"GRANT {privilege} {columns} ON `{self.database_name}`.`{table_name}` TO `{username}`@`%`;" # noqa: E501
Expand Down
44 changes: 33 additions & 11 deletions agent/site.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,13 +258,28 @@ def revoke_database_access_credentials(self, user, mariadb_root_password):
if user == self.user:
# Do not revoke access for the main user
return {}
queries = [
f"DROP USER IF EXISTS '{user}'@'%'",
"FLUSH PRIVILEGES",
]
for query in queries:
command = f"mysql -h {self.host} -uroot -p{mariadb_root_password}" f' -e "{query}"'
self.execute(command)
self.db_instance("root", mariadb_root_password).remove_user(user)
return {}

def create_database_user(self, user, password, mariadb_root_password):
if user == self.user:
# Do not perform any operation for the main user
return {}
self.db_instance("root", mariadb_root_password).create_user(user, password)
return {}

def remove_database_user(self, user, mariadb_root_password):
if user == self.user:
# Do not perform any operation for the main user
return {}
self.db_instance("root", mariadb_root_password).remove_user(user)
return {}

def modify_permissions_for_database_user(self, user, mode, permissions, mariadb_root_password):
if user == self.user:
# Do not perform any operation for the main user
return {}
self.db_instance("root", mariadb_root_password).modify_user_permissions(user, mode, permissions)
return {}

@job("Setup ERPNext", priority="high")
Expand Down Expand Up @@ -868,13 +883,20 @@ def get_database_table_indexes(self):
return tables

def run_sql_query(self, query: str, commit: bool = False, as_dict: bool = False):
database = Database(self.host, 3306, self.user, self.password, self.database)
success, output = database.execute_query(query, commit=commit, as_dict=as_dict)
db = self.db_instance()
success, output = db.execute_query(query, commit=commit, as_dict=as_dict)
response = {"success": success, "data": output}
if not success and hasattr(database, "last_executed_query"):
response["failed_query"] = database.last_executed_query
if not success and hasattr(db, "last_executed_query"):
response["failed_query"] = db.last_executed_query
return response

def db_instance(self, username: str | None = None, password: str | None = None) -> Database:
if not username:
username = self.user
if not password:
password = self.password
return Database(self.host, 3306, self.user, self.password, self.database)

@property
def job_record(self):
return self.bench.server.job_record
Expand Down
34 changes: 17 additions & 17 deletions agent/tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,12 +247,12 @@ def test_run_sql_fn_with_commit_disabled_should_not_persist_changes(self):
)
self.assertEqual(data[0]["row_count"], 0)

def test_add_user(self):
def test_create_user(self):
db = self._db(self.db1__name, "root", self.instance.db_root_password)

# add user
# create user
try:
db.add_user("test_user", "test_user_password")
db.create_user("test_user", "test_user_password")
except:
print(f"Failed query: {db.query}")
raise
Expand All @@ -269,7 +269,7 @@ def test_remove_user(self):
db = self._db(self.db1__name, "root", self.instance.db_root_password)

# add a dummy user
db.add_user("test_user", "test_user_password")
db.create_user("test_user", "test_user_password")

# remove user
try:
Expand All @@ -288,7 +288,7 @@ def test_remove_user(self):

def test_create_read_only_permission(self):
db = self._db(self.db1__name, "root", self.instance.db_root_password)
db.add_user("user1", "test_password") # add user
db.create_user("user1", "test_password") # create user

user1_db = self._db(self.db1__name, "user1", "test_password")

Expand All @@ -297,7 +297,7 @@ def test_create_read_only_permission(self):
self.assertFalse(success, "User `user1` should not have access to the database")
self.assertIn("Access denied for user", str(output))

db.modify_user_permission("user1", "read_only")
db.modify_user_permissions("user1", "read_only")

# try to access the database again
success, output = user1_db.execute_query("SELECT * FROM Person;")
Expand All @@ -311,7 +311,7 @@ def test_create_read_only_permission(self):

def test_create_read_write_permission(self):
db = self._db(self.db1__name, "root", self.instance.db_root_password)
db.add_user("user1", "test_password") # add user
db.create_user("user1", "test_password") # create user

user1_db = self._db(self.db1__name, "user1", "test_password")

Expand All @@ -320,7 +320,7 @@ def test_create_read_write_permission(self):
self.assertFalse(success, "User `user1` should not have access to the database")
self.assertIn("Access denied for user", str(output))

db.modify_user_permission("user1", "read_write")
db.modify_user_permissions("user1", "read_write")

# try to access the database again
success, output = user1_db.execute_query("SELECT * FROM Person;")
Expand All @@ -335,11 +335,11 @@ def test_granular_permission_table_level(self):
db = self._db(self.db1__name, "root", self.instance.db_root_password)
user1_db = self._db(self.db1__name, "user1", "test_password")

db.add_user("user1", "test_password") # add user
db.create_user("user1", "test_password") # create user

# modify access
try:
db.modify_user_permission(
db.modify_user_permissions(
"user1",
"granular",
permissions={
Expand Down Expand Up @@ -375,7 +375,7 @@ def test_granular_permission_table_level(self):
self.assertIn("SELECT command denied to user", str(output))

# purge access and verify
db.modify_user_permission("user1", "granular", permissions={})
db.modify_user_permissions("user1", "granular", permissions={})

success, output = user1_db.execute_query("SELECT * FROM Person;")
self.assertFalse(success, "User `user1` should not have access to `Person` table")
Expand All @@ -388,11 +388,11 @@ def test_granular_permission_column_level(self):
db = self._db(self.db1__name, "root", self.instance.db_root_password)
user1_db = self._db(self.db1__name, "user1", "test_password")

db.add_user("user1", "test_password") # add user
db.create_user("user1", "test_password") # create user

# modify access [read_only]
try:
db.modify_user_permission(
db.modify_user_permissions(
"user1", "granular", permissions={"Person": {"mode": "read_only", "columns": ["id"]}}
)
except Exception:
Expand All @@ -411,7 +411,7 @@ def test_granular_permission_column_level(self):

# modify access [read_write]
try:
db.modify_user_permission(
db.modify_user_permissions(
"user1", "granular", permissions={"Person": {"mode": "read_write", "columns": ["id"]}}
)
except Exception:
Expand Down Expand Up @@ -445,8 +445,8 @@ def test_granular_permission_column_level(self):

def test_revoke_permission(self):
db = self._db(self.db1__name, "root", self.instance.db_root_password)
db.add_user("user1", "test_password") # add user
db.modify_user_permission("user1", "read_only")
db.create_user("user1", "test_password") # create user
db.modify_user_permissions("user1", "read_only")

user1_db = self._db(self.db1__name, "user1", "test_password")

Expand All @@ -456,7 +456,7 @@ def test_revoke_permission(self):

# revoke permission
# setting no permission in granular mode, will revoke all existing permissions
db.modify_user_permission("user1", "granular", permissions={})
db.modify_user_permissions("user1", "granular", permissions={})

# try to access the database again
success, output = user1_db.execute_query("SELECT * FROM Person;")
Expand Down
38 changes: 38 additions & 0 deletions agent/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,44 @@ def run_sql(bench, site):
)


@application.route("/benches/<string:bench>/sites/<string:site>/database/users", methods=["POST"])
@validate_bench_and_site
def add_database_user(bench, site):
data = request.json
return (
Server()
.benches[bench]
.sites[site]
.create_database_user(data["user"], data["password"], data["mariadb_root_password"])
)


@application.route(
"/benches/<string:bench>/sites/<string:site>/database/users/<string:db_user>", methods=["DELETE"]
)
@validate_bench_and_site
def remove_database_user(bench, site, db_user):
data = request.json
return Server().benches[bench].sites[site].remove_database_user(db_user, data["mariadb_root_password"])


@application.route(
"/benches/<string:bench>/sites/<string:site>/database/users/<string:db_user>/permissions",
methods=["POST"],
)
@validate_bench_and_site
def add_database_permissions(bench, site, db_user):
data = request.json
return (
Server()
.benches[bench]
.sites[site]
.modify_permissions_for_database_user(
db_user, data["mode"], data.get("permissions", []), data["mariadb_root_password"]
)
)


@application.route(
"/benches/<string:bench>/sites/<string:site>/migrate",
methods=["POST"],
Expand Down

0 comments on commit 503d39c

Please sign in to comment.