diff --git a/datajoint_utilities/dj_worker/worker_schema.py b/datajoint_utilities/dj_worker/worker_schema.py index b65e093..a70d408 100644 --- a/datajoint_utilities/dj_worker/worker_schema.py +++ b/datajoint_utilities/dj_worker/worker_schema.py @@ -11,6 +11,7 @@ import pandas as pd from datajoint.user_tables import Part, UserTable +from datajoint.condition import AndList logger = dj.logger @@ -155,15 +156,7 @@ def _rename_attributes(table, props): else table.proj() ) - def _remove_enclosed_parentheses(input_string): - pattern = r"\([^()]*\)" - # Use a while loop to recursively remove nested parentheses - while re.search(pattern, input_string): - # Replace all occurrences of the pattern with an {} - input_string = re.sub(pattern, "{}", input_string) - return input_string - - target = dj.FreeTable(full_table_name=target_full_table_name, conn=dj.conn()) + target = dj.FreeTable(full_table_name=target_full_table_name, conn=cls.connection) try: len(target) @@ -189,14 +182,39 @@ def _remove_enclosed_parentheses(input_string): return incomplete_sql @classmethod - def get_key_source_count(cls, key_source_sql, target_full_table_name): + def get_key_source_count(cls, key_source_sql: str, + target_full_table_name: str, + andlist_restriction: AndList = None, + return_sql=False): """ From `key_source_sql`, count the total and incomplete key_source entries in the target table + Args: + key_source_sql (str): SQL statement for the key_source of the table + target_full_table_name (str): full table name of the target table + andlist_restriction (list|AndList): list of additional restrictions to be added to the key_source_sql + - the `restriction` property of QueryExpression - e.g. (table & key).restriction + return_sql (bool): if True, return the SQL statement instead of the count """ incomplete_sql = cls.get_incomplete_key_source_sql(key_source_sql, target_full_table_name) + + if andlist_restriction: + restriction_str = ")AND(".join(str(s) for s in andlist_restriction) + + AND_or_WHERE = ( + "AND" + if "WHERE" in _remove_enclosed_parentheses(key_source_sql) + else " WHERE " + ) + + key_source_sql += f" {AND_or_WHERE} ({restriction_str})" + incomplete_sql += f" AND ({restriction_str})" + + if return_sql: + return key_source_sql, incomplete_sql + try: - total = len(dj.conn().query(key_source_sql).fetchall()) - incomplete = len(dj.conn().query(incomplete_sql).fetchall()) + total = len(cls.connection.query(key_source_sql).fetchall()) + incomplete = len(cls.connection.query(incomplete_sql).fetchall()) except Exception as e: logger.error( f"Error retrieving key_source for: {target_full_table_name}. \n{e}" @@ -222,7 +240,7 @@ class WorkerLog(dj.Manual): @classmethod def log_process_job(cls, process, worker_name="", db_prefix=("",)): process_name = get_process_name(process, db_prefix) - user = dj.conn().get_user() + user = cls.connection.get_user() if not worker_name: frame = inspect.currentframe() @@ -381,3 +399,12 @@ def is_djtable(obj, base_class=None) -> bool: def is_djparttable(obj) -> bool: return is_djtable(obj, Part) + + +def _remove_enclosed_parentheses(input_string): + pattern = r"\([^()]*\)" + # Use a while loop to recursively remove nested parentheses + while re.search(pattern, input_string): + # Replace all occurrences of the pattern with an {} + input_string = re.sub(pattern, "{}", input_string) + return input_string