Skip to content

Commit

Permalink
feat(get_key_source_count): take AndList restriction as argument to f…
Browse files Browse the repository at this point in the history
…urther restrict the key_source
  • Loading branch information
ttngu207 committed Nov 14, 2024
1 parent 75a486c commit 8df4a2b
Showing 1 changed file with 40 additions and 13 deletions.
53 changes: 40 additions & 13 deletions datajoint_utilities/dj_worker/worker_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pandas as pd

from datajoint.user_tables import Part, UserTable
from datajoint.condition import AndList

logger = dj.logger

Expand Down Expand Up @@ -155,15 +156,7 @@ def _rename_attributes(table, props):
else table.proj()
)

def _remove_enclosed_parentheses(input_string):
# Collapse every parenthesized group in `input_string` to the literal placeholder "{}".
# The pattern matches one innermost pair: "(" ... ")" with no parentheses in between.
pattern = r"\([^()]*\)"
# Repeat until no complete pair is left, so nested groups collapse from the inside out
while re.search(pattern, input_string):
# Replace every currently-innermost group with the placeholder "{}"
input_string = re.sub(pattern, "{}", input_string)
return input_string

target = dj.FreeTable(full_table_name=target_full_table_name, conn=dj.conn())
target = dj.FreeTable(full_table_name=target_full_table_name, conn=cls.connection)

try:
len(target)
Expand All @@ -189,14 +182,39 @@ def _remove_enclosed_parentheses(input_string):
return incomplete_sql

@classmethod
def get_key_source_count(cls, key_source_sql, target_full_table_name):
def get_key_source_count(cls, key_source_sql: str,
target_full_table_name: str,
andlist_restriction: AndList = None,
return_sql=False):
"""
From `key_source_sql`, count the total and incomplete key_source entries in the target table
Args:
key_source_sql (str): SQL statement for the key_source of the table
target_full_table_name (str): full table name of the target table
andlist_restriction (list|AndList): list of additional restrictions to be added to the key_source_sql
- the `restriction` property of QueryExpression - e.g. (table & key).restriction
return_sql (bool): if True, return the SQL statement instead of the count
"""
incomplete_sql = cls.get_incomplete_key_source_sql(key_source_sql, target_full_table_name)

if andlist_restriction:
restriction_str = ")AND(".join(str(s) for s in andlist_restriction)

AND_or_WHERE = (
"AND"
if "WHERE" in _remove_enclosed_parentheses(key_source_sql)
else " WHERE "
)

key_source_sql += f" {AND_or_WHERE} ({restriction_str})"
incomplete_sql += f" AND ({restriction_str})"

if return_sql:
return key_source_sql, incomplete_sql

try:
total = len(dj.conn().query(key_source_sql).fetchall())
incomplete = len(dj.conn().query(incomplete_sql).fetchall())
total = len(cls.connection.query(key_source_sql).fetchall())
incomplete = len(cls.connection.query(incomplete_sql).fetchall())
except Exception as e:
logger.error(
f"Error retrieving key_source for: {target_full_table_name}. \n{e}"
Expand All @@ -222,7 +240,7 @@ class WorkerLog(dj.Manual):
@classmethod
def log_process_job(cls, process, worker_name="", db_prefix=("",)):
process_name = get_process_name(process, db_prefix)
user = dj.conn().get_user()
user = cls.connection.get_user()

if not worker_name:
frame = inspect.currentframe()
Expand Down Expand Up @@ -381,3 +399,12 @@ def is_djtable(obj, base_class=None) -> bool:

def is_djparttable(obj) -> bool:
    """Return the result of `is_djtable` for `obj` with `Part` as the base class."""
    return is_djtable(obj, base_class=Part)


def _remove_enclosed_parentheses(input_string):
    """Collapse every parenthesized group in `input_string` to the placeholder ``{}``.

    An innermost group is a ``(...)`` pair with no parentheses inside it. All
    innermost groups are replaced in one pass, and passes repeat until no
    complete pair remains, so nested groups collapse from the inside out.
    Unmatched parentheses are left untouched.

    Args:
        input_string (str): text to process, e.g. a SQL statement.

    Returns:
        str: the text with each outermost parenthesized group replaced by ``{}``.
    """
    innermost_group = re.compile(r"\([^()]*\)")
    # `subn` reports how many groups it replaced, so one scan per pass both
    # performs the substitution and tells us whether another pass is needed.
    replaced = 1
    while replaced:
        input_string, replaced = innermost_group.subn("{}", input_string)
    return input_string

0 comments on commit 8df4a2b

Please sign in to comment.