From ecf468e2c7ff77652d6b646f7f07d52ed19e84e3 Mon Sep 17 00:00:00 2001 From: Chris Broz Date: Tue, 27 Aug 2024 11:06:59 -0500 Subject: [PATCH] Periph table fallback on TableChain for experimenter summary (#1035) * Periph table fallback on TableChain * Update Changelog * Rely on search to remove no_visit, not id step * Include generic load_shared_schemas * Update changelog for release * Allow add custom prefix for load schemas * Fix merge error --- CHANGELOG.md | 18 +---- src/spyglass/utils/dj_graph.py | 43 +++++++++--- src/spyglass/utils/dj_mixin.py | 117 ++++++++++++++++----------------- 3 files changed, 95 insertions(+), 83 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a5d467ec1..57fd495d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,20 +1,6 @@ # Change Log -## [0.5.3] (Unreleased) - -## Release Notes - - - -```python -import datajoint as dj -from spyglass.common.common_behav import PositionIntervalMap -from spyglass.decoding.v1.core import PositionGroup - -dj.schema("common_ripple").drop() -PositionIntervalMap.alter() -PositionGroup.alter() -``` +## [0.5.3] (August 27, 2024) ### Infrastructure @@ -46,6 +32,8 @@ PositionGroup.alter() - Installation instructions -> Setup notebook. #1029 - Migrate SQL export tools to `utils` to support exporting `DandiPath` #1048 - Add tool for checking threads for metadata locks on a table #1063 +- Use peripheral tables as fallback in `TableChains` #1035 +- Ignore non-Spyglass tables during descendant check for `part_masters` #1035 ### Pipelines diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index 0ab4ab477..6b3928042 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -248,7 +248,7 @@ def _get_ft(self, table, with_restr=False, warn=True): return ft & restr - def _is_out(self, table, warn=True): + def _is_out(self, table, warn=True, keep_alias=False): """Check if table is outside of spyglass.""" table = ensure_names(table) if self.graph.nodes.get(table): @@ -805,7 +805,8 @@ class TableChain(RestrGraph): Returns path OrderedDict of full table names in chain. If directed is True, uses directed graph. If False, uses undirected graph. Undirected excludes PERIPHERAL_TABLES like interval_list, nwbfile, etc. to maintain - valid joins. + valid joins by default. If no path is found, another search is attempted + with PERIPHERAL_TABLES included. cascade(restriction: str = None, direction: str = "up") Given a restriction at the beginning, return a restricted FreeTable object at the end of the chain. If direction is 'up', start at the child @@ -835,8 +836,12 @@ def __init__( super().__init__(seed_table=seed_table, verbose=verbose) self._ignore_peripheral(except_tables=[self.parent, self.child]) + self._ignore_outside_spy(except_tables=[self.parent, self.child]) + self.no_visit.update(ensure_names(banned_tables) or []) + self.no_visit.difference_update(set([self.parent, self.child])) + self.searched_tables = set() self.found_restr = False self.link_type = None @@ -872,7 +877,19 @@ def _ignore_peripheral(self, except_tables: List[str] = None): except_tables = ensure_names(except_tables) ignore_tables = set(PERIPHERAL_TABLES) - set(except_tables or []) self.no_visit.update(ignore_tables) - self.undirect_graph.remove_nodes_from(ignore_tables) + + def _ignore_outside_spy(self, except_tables: List[str] = None): + """Ignore tables not shared on shared prefixes.""" + except_tables = ensure_names(except_tables) + ignore_tables = set( # Ignore tables not in shared modules + [ + t + for t in self.undirect_graph.nodes + if t not in except_tables + and self._is_out(t, warn=False, keep_alias=True) + ] + ) + self.no_visit.update(ignore_tables) # --------------------------- Dunder Properties --------------------------- @@ -1066,9 +1083,9 @@ def find_path(self, directed=True) -> List[str]: List of names in the path. """ source, target = self.parent, self.child - search_graph = self.graph if directed else self.undirect_graph - - search_graph.remove_nodes_from(self.no_visit) + search_graph = ( # Copy to ensure orig not modified by no_visit + self.graph.copy() if directed else self.undirect_graph.copy() + ) try: path = shortest_path(search_graph, source, target) @@ -1096,6 +1113,12 @@ def path(self) -> list: self.link_type = "directed" elif path := self.find_path(directed=False): self.link_type = "undirected" + else: # Search with peripheral + self.no_visit.difference_update(PERIPHERAL_TABLES) + if path := self.find_path(directed=True): + self.link_type = "directed with peripheral" + elif path := self.find_path(directed=False): + self.link_type = "undirected with peripheral" self.searched_path = True return path @@ -1126,9 +1149,11 @@ def cascade( # Cascade will stop if any restriction is empty, so set rest to None # This would cause issues if we want a table partway through the chain # but that's not a typical use case, were the start and end are desired - non_numeric = [t for t in self.path if not t.isnumeric()] - if any(self._get_restr(t) is None for t in non_numeric): - for table in non_numeric: + safe_tbls = [ + t for t in self.path if not t.isnumeric() and not self._is_out(t) + ] + if any(self._get_restr(t) is None for t in safe_tbls): + for table in safe_tbls: if table is not start: self._set_restr(table, False, replace=True) diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index ff3922087..04b873740 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -261,52 +261,41 @@ def fetch_pynapple(self, *attrs, **kwargs): # ------------------------ delete_downstream_parts ------------------------ - def _import_part_masters(self): - """Import tables that may constrain a RestrGraph. See #1002""" - from spyglass.decoding.decoding_merge import DecodingOutput # noqa F401 - from spyglass.decoding.v0.clusterless import ( - UnitMarksIndicatorSelection, - ) # noqa F401 - from spyglass.decoding.v0.sorted_spikes import ( - SortedSpikesIndicatorSelection, - ) # noqa F401 - from spyglass.decoding.v1.core import PositionGroup # noqa F401 - from spyglass.lfp.analysis.v1 import LFPBandSelection # noqa F401 - from spyglass.lfp.lfp_merge import LFPOutput # noqa F401 - from spyglass.linearization.merge import ( # noqa F401 - LinearizedPositionOutput, - LinearizedPositionV1, - ) - from spyglass.mua.v1.mua import MuaEventsV1 # noqa F401 - from spyglass.position.position_merge import PositionOutput # noqa F401 - from spyglass.ripple.v1.ripple import RippleTimesV1 # noqa F401 - from spyglass.spikesorting.analysis.v1.group import ( - SortedSpikesGroup, - ) # noqa F401 - from spyglass.spikesorting.spikesorting_merge import ( - SpikeSortingOutput, - ) # noqa F401 - from spyglass.spikesorting.v0.figurl_views import ( - SpikeSortingRecordingView, - ) # noqa F401 - - _ = ( - DecodingOutput(), - LFPBandSelection(), - LFPOutput(), - LinearizedPositionOutput(), - LinearizedPositionV1(), - MuaEventsV1(), - PositionGroup(), - PositionOutput(), - RippleTimesV1(), - SortedSpikesGroup(), - SortedSpikesIndicatorSelection(), - SpikeSortingOutput(), - SpikeSortingRecordingView(), - UnitMarksIndicatorSelection(), + def load_shared_schemas(self, additional_prefixes: list = None) -> None: + """Load shared schemas to include in graph traversal. + + Parameters + ---------- + additional_prefixes : list, optional + Additional prefixes to load. Default None. + """ + all_shared = [ + *SHARED_MODULES, + dj.config["database.user"], + "file", + "sharing", + ] + + if additional_prefixes: + all_shared.extend(additional_prefixes) + + # Get a list of all shared schemas in spyglass + schemas = dj.conn().query( + "SELECT DISTINCT table_schema " # Unique schemas + + "FROM information_schema.key_column_usage " + + "WHERE" + + ' table_name not LIKE "~%%"' # Exclude hidden + + " AND constraint_name='PRIMARY'" # Only primary keys + + "AND (" # Only shared schemas + + " OR ".join([f"table_schema LIKE '{s}_%%'" for s in all_shared]) + + ") " + + "ORDER BY table_schema;" ) + # Load the dependencies for all shared schemas + for schema in schemas: + dj.schema(schema[0]).connection.dependencies.load() + @cached_property def _part_masters(self) -> set: """Set of master tables downstream of self. @@ -318,23 +307,25 @@ def _part_masters(self) -> set: part_masters = set() def search_descendants(parent): - for desc in parent.descendants(as_objects=True): + for desc_name in parent.descendants(): if ( # Check if has master, is part - not (master := get_master(desc.full_table_name)) - # has other non-master parent - or not set(desc.parents()) - set([master]) + not (master := get_master(desc_name)) or master in part_masters # already in cache + or desc_name.replace("`", "").split("_")[0] + not in SHARED_MODULES ): continue - if master not in part_masters: - part_masters.add(master) - search_descendants(dj.FreeTable(self.connection, master)) + desc = dj.FreeTable(self.connection, desc_name) + if not set(desc.parents()) - set([master]): # no other parent + continue + part_masters.add(master) + search_descendants(dj.FreeTable(self.connection, master)) try: _ = search_descendants(self) except NetworkXError: - try: # Attempt to import missing table - self._import_part_masters() + try: # Attempt to import failing schema + self.load_shared_schemas() _ = search_descendants(self) except NetworkXError as e: table_name = "".join(e.args[0].split("`")[1:4]) @@ -484,7 +475,7 @@ def _delete_deps(self) -> List[Table]: self._member_pk = LabMember.primary_key[0] return [LabMember, LabTeam, Session, schema.external, IntervalList] - def _get_exp_summary(self): + def _get_exp_summary(self) -> Union[QueryExpression, None]: """Get summary of experimenters for session(s), including NULL. Parameters @@ -494,9 +485,12 @@ def _get_exp_summary(self): Returns ------- - str - Summary of experimenters for session(s). + Union[QueryExpression, None] + dj.Union object Summary of experimenters for session(s). If no link + to Session, return None. """ + if not self._session_connection.has_link: + return None Session = self._delete_deps[2] SesExp = Session.Experimenter @@ -521,8 +515,7 @@ def _session_connection(self): """Path from Session table to self. False if no connection found.""" from spyglass.utils.dj_graph import TableChain # noqa F401 - connection = TableChain(parent=self._delete_deps[2], child=self) - return connection if connection.has_link else False + return TableChain(parent=self._delete_deps[2], child=self, verbose=True) @cached_property def _test_mode(self) -> bool: @@ -564,7 +557,13 @@ def _check_delete_permission(self) -> None: ) return - sess_summary = self._get_exp_summary() + if not (sess_summary := self._get_exp_summary()): + logger.warn( + f"Could not find a connection from {self.camel_name} " + + "to Session.\n Be careful not to delete others' data." + ) + return + experimenters = sess_summary.fetch(self._member_pk) if None in experimenters: raise PermissionError(