Skip to content

Commit

Permalink
Merge pull request #47 from ttngu207/main
Browse files Browse the repository at this point in the history
feat(worker): apply restrictions to `get_key_source_count` in worker
  • Loading branch information
ttngu207 authored Nov 15, 2024
2 parents f342e8b + 416a60c commit 85afbb5
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 15 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.

## [0.5.1] - TBD

- Added - apply restrictions to `get_key_source_count` in worker


## [0.5.0] - 2024-11-08

### Added
Expand Down
53 changes: 40 additions & 13 deletions datajoint_utilities/dj_worker/worker_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pandas as pd

from datajoint.user_tables import Part, UserTable
from datajoint.condition import AndList

logger = dj.logger

Expand Down Expand Up @@ -155,15 +156,7 @@ def _rename_attributes(table, props):
else table.proj()
)

def _remove_enclosed_parentheses(input_string):
pattern = r"\([^()]*\)"
# Use a while loop to recursively remove nested parentheses
while re.search(pattern, input_string):
# Replace all occurrences of the pattern with an {}
input_string = re.sub(pattern, "{}", input_string)
return input_string

target = dj.FreeTable(full_table_name=target_full_table_name, conn=dj.conn())
target = dj.FreeTable(full_table_name=target_full_table_name, conn=cls.connection)

try:
len(target)
Expand All @@ -189,14 +182,39 @@ def _remove_enclosed_parentheses(input_string):
return incomplete_sql

@classmethod
def get_key_source_count(cls, key_source_sql, target_full_table_name):
def get_key_source_count(cls, key_source_sql: str,
target_full_table_name: str,
andlist_restriction: AndList = None,
return_sql=False):
"""
From `key_source_sql`, count the total and incomplete key_source entries in the target table
Args:
key_source_sql (str): SQL statement for the key_source of the table
target_full_table_name (str): full table name of the target table
andlist_restriction (list|AndList): list of additional restrictions to be added to the key_source_sql
- the `restriction` property of QueryExpression - e.g. (table & key).restriction
return_sql (bool): if True, return the SQL statement instead of the count
"""
incomplete_sql = cls.get_incomplete_key_source_sql(key_source_sql, target_full_table_name)

if andlist_restriction:
restriction_str = ")AND(".join(str(s) for s in andlist_restriction)

AND_or_WHERE = (
"AND"
if "WHERE" in _remove_enclosed_parentheses(key_source_sql)
else " WHERE "
)

key_source_sql += f" {AND_or_WHERE} ({restriction_str})"
incomplete_sql += f" AND ({restriction_str})"

if return_sql:
return key_source_sql, incomplete_sql

try:
total = len(dj.conn().query(key_source_sql).fetchall())
incomplete = len(dj.conn().query(incomplete_sql).fetchall())
total = len(cls.connection.query(key_source_sql).fetchall())
incomplete = len(cls.connection.query(incomplete_sql).fetchall())
except Exception as e:
logger.error(
f"Error retrieving key_source for: {target_full_table_name}. \n{e}"
Expand All @@ -222,7 +240,7 @@ class WorkerLog(dj.Manual):
@classmethod
def log_process_job(cls, process, worker_name="", db_prefix=("",)):
process_name = get_process_name(process, db_prefix)
user = dj.conn().get_user()
user = cls.connection.get_user()

if not worker_name:
frame = inspect.currentframe()
Expand Down Expand Up @@ -381,3 +399,12 @@ def is_djtable(obj, base_class=None) -> bool:

def is_djparttable(obj) -> bool:
    """Check ``obj`` against the DataJoint ``Part`` base class.

    Thin wrapper that delegates to ``is_djtable`` with ``Part`` supplied as
    the base class and returns its result unchanged.
    """
    return is_djtable(obj, base_class=Part)


def _remove_enclosed_parentheses(input_string):
pattern = r"\([^()]*\)"
# Use a while loop to recursively remove nested parentheses
while re.search(pattern, input_string):
# Replace all occurrences of the pattern with an {}
input_string = re.sub(pattern, "{}", input_string)
return input_string
2 changes: 1 addition & 1 deletion datajoint_utilities/version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Package metadata."""

__version__ = "0.5.0"
__version__ = "0.5.1"
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "datajoint-utilities"
version = "0.5.0"
version = "0.5.1"
description = "A general purpose repository containing all generic tools/utilities surrounding the DataJoint ecosystem"
requires-python = ">=3.9, <3.12"
license = { file = "LICENSE" }
Expand Down

0 comments on commit 85afbb5

Please sign in to comment.