Decoding qol updates (#1198)
* generalize DecodingParameters fetch

* generalize DecodingParameters fetch1

* initial key decorator

* implement full_key_decorator within clusterless pipeline

* update changelog

* move decorator to mixin class method

* remove unused import

* update changelog

* Update src/spyglass/utils/dj_mixin.py

Co-authored-by: Chris Broz <[email protected]>

---------

Co-authored-by: Chris Broz <[email protected]>
samuelbray32 and CBroz1 authored Dec 5, 2024
1 parent f56aba0 commit 692b281
Showing 6 changed files with 148 additions and 33 deletions.
CHANGELOG.md (3 additions, 0 deletions)
@@ -44,6 +44,7 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
- Merge table delete removes orphaned master entries #1164
- Edit `merge_fetch` to expect positional before keyword arguments #1181
- Allow part restriction `SpyglassMixinPart.delete` #1192
- Add mixin method `get_fully_defined_key` #1198

### Pipelines

@@ -62,6 +63,8 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
- Decoding

- Fix edge case errors in spike time loading #1083
- Allow fetch of partial key from `DecodingParameters` #1198
- Allow data fetching with partial but unique key #1198

- Linearization

src/spyglass/decoding/v1/clusterless.py (38 additions, 8 deletions)
@@ -316,8 +316,8 @@ def fetch_model(self):
"""Retrieve the decoding model"""
return ClusterlessDetector.load_model(self.fetch1("classifier_path"))

-    @staticmethod
-    def fetch_environments(key):
+    @classmethod
+    def fetch_environments(cls, key):
"""Fetch the environments for the decoding model
Parameters
@@ -330,6 +330,9 @@ def fetch_environments(key):
List[TrackGraph]
list of track graphs in the trained model
"""
key = cls.get_fully_defined_key(
key, required_fields=["decoding_param_name"]
)
model_params = (
DecodingParameters
& {"decoding_param_name": key["decoding_param_name"]}
@@ -355,8 +358,8 @@ def fetch_environments(key):

return classifier.environments

-    @staticmethod
-    def fetch_position_info(key):
+    @classmethod
+    def fetch_position_info(cls, key):
"""Fetch the position information for the decoding model
Parameters
@@ -369,6 +372,15 @@ def fetch_position_info(key):
Tuple[pd.DataFrame, List[str]]
The position information and the names of the position variables
"""
key = cls.get_fully_defined_key(
key,
required_fields=[
"nwb_file_name",
"position_group_name",
"encoding_interval",
"decoding_interval",
],
)
position_group_key = {
"position_group_name": key["position_group_name"],
"nwb_file_name": key["nwb_file_name"],
@@ -381,8 +393,8 @@ def fetch_position_info(key):

return position_info, position_variable_names

-    @staticmethod
-    def fetch_linear_position_info(key):
+    @classmethod
+    def fetch_linear_position_info(cls, key):
"""Fetch the position information and project it onto the track graph
Parameters
@@ -395,6 +407,16 @@ def fetch_linear_position_info(key):
pd.DataFrame
The linearized position information
"""
key = cls.get_fully_defined_key(
key,
required_fields=[
"nwb_file_name",
"position_group_name",
"encoding_interval",
"decoding_interval",
],
)

environment = ClusterlessDecodingV1.fetch_environments(key)[0]

position_df = ClusterlessDecodingV1.fetch_position_info(key)[0]
@@ -417,8 +439,8 @@ def fetch_linear_position_info(key):
axis=1,
).loc[min_time:max_time]

-    @staticmethod
-    def fetch_spike_data(key, filter_by_interval=True):
+    @classmethod
+    def fetch_spike_data(cls, key, filter_by_interval=True):
"""Fetch the spike times for the decoding model
Parameters
@@ -434,6 +456,14 @@ def fetch_spike_data(key, filter_by_interval=True):
list[np.ndarray]
List of spike times for each unit in the model's spike group
"""
key = cls.get_fully_defined_key(
key,
required_fields=[
"nwb_file_name",
"waveform_features_group_name",
],
)

waveform_keys = (
(
UnitWaveformFeaturesGroup.UnitFeatures
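
A minimal sketch of the new partial-key behavior in this pipeline (the file and group names below are hypothetical; the key must still match exactly one decoding entry):

    from spyglass.decoding.v1.clusterless import ClusterlessDecodingV1

    # Hypothetical partial key: only two primary-key fields are given
    partial_key = {
        "nwb_file_name": "example20241205_.nwb",
        "position_group_name": "example_position_group",
    }

    # get_fully_defined_key completes the key before any data are loaded;
    # zero or multiple matching entries raise a KeyError instead
    position_info, position_variable_names = (
        ClusterlessDecodingV1.fetch_position_info(partial_key)
    )
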
src/spyglass/decoding/v1/core.py (34 additions, 14 deletions)
@@ -70,28 +70,48 @@ def insert(self, rows, *args, **kwargs):
def fetch(self, *args, **kwargs):
"""Return decoding parameters as a list of classes."""
rows = super().fetch(*args, **kwargs)
-        if len(rows) > 0 and len(rows[0]) > 1:
+        if kwargs.get("format", None) == "array":
+            # case when recalled by dj.fetch(), class conversion performed later in stack
+            return rows
+
+        if not len(args):
+            # infer args from table heading
+            args = tuple(self.heading)
+
+        if "decoding_params" not in args:
+            return rows
+
+        params_index = args.index("decoding_params")
+        if len(args) == 1:
+            # only fetching decoding_params
+            content = [restore_classes(r) for r in rows]
+        elif len(rows):
            content = []
-            for (
-                decoding_param_name,
-                decoding_params,
-                decoding_kwargs,
-            ) in rows:
-                content.append(
-                    (
-                        decoding_param_name,
-                        restore_classes(decoding_params),
-                        decoding_kwargs,
-                    )
-                )
+            for row in zip(*rows):
+                row = list(row)
+                row[params_index] = restore_classes(row[params_index])
+                content.append(tuple(row))
else:
content = rows
return content

def fetch1(self, *args, **kwargs):
"""Return one decoding paramset as a class."""
row = super().fetch1(*args, **kwargs)
row["decoding_params"] = restore_classes(row["decoding_params"])

if len(args) == 0:
row["decoding_params"] = restore_classes(row["decoding_params"])
return row

if "decoding_params" in args:
if len(args) == 1:
return restore_classes(row)
row = list(row)
row[args.index("decoding_params")] = restore_classes(
row[args.index("decoding_params")]
)
return tuple(row)

return row


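
A rough sketch of the call patterns the reworked overrides are meant to support (the paramset name here is hypothetical):

    from spyglass.decoding.v1.core import DecodingParameters

    restriction = {"decoding_param_name": "example_params"}

    # No attribute args: the full row, with decoding_params restored to classes
    row = (DecodingParameters & restriction).fetch1()

    # A single attribute: the restored parameter object alone
    params = (DecodingParameters & restriction).fetch1("decoding_params")

    # format="array" skips conversion, as when DataJoint calls fetch internally
    raw = (DecodingParameters & restriction).fetch(format="array")
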
src/spyglass/decoding/v1/sorted_spikes.py (44 additions, 8 deletions)
@@ -275,8 +275,8 @@ def fetch_model(self):
"""Retrieve the decoding model"""
return SortedSpikesDetector.load_model(self.fetch1("classifier_path"))

-    @staticmethod
-    def fetch_environments(key):
+    @classmethod
+    def fetch_environments(cls, key):
"""Fetch the environments for the decoding model
Parameters
@@ -289,6 +289,10 @@ def fetch_environments(key):
List[TrackGraph]
list of track graphs in the trained model
"""
key = cls.get_fully_defined_key(
key, required_fields=["decoding_param_name"]
)

model_params = (
DecodingParameters
& {"decoding_param_name": key["decoding_param_name"]}
@@ -314,8 +318,8 @@ def fetch_environments(key):

return classifier.environments

-    @staticmethod
-    def fetch_position_info(key):
+    @classmethod
+    def fetch_position_info(cls, key):
"""Fetch the position information for the decoding model
Parameters
@@ -328,6 +332,16 @@ def fetch_position_info(key):
Tuple[pd.DataFrame, List[str]]
The position information and the names of the position variables
"""
key = cls.get_fully_defined_key(
key,
required_fields=[
"position_group_name",
"nwb_file_name",
"encoding_interval",
"decoding_interval",
],
)

position_group_key = {
"position_group_name": key["position_group_name"],
"nwb_file_name": key["nwb_file_name"],
@@ -339,8 +353,8 @@ def fetch_position_info(key):

return position_info, position_variable_names

-    @staticmethod
-    def fetch_linear_position_info(key):
+    @classmethod
+    def fetch_linear_position_info(cls, key):
"""Fetch the position information and project it onto the track graph
Parameters
@@ -353,6 +367,16 @@ def fetch_linear_position_info(key):
pd.DataFrame
The linearized position information
"""
key = cls.get_fully_defined_key(
key,
required_fields=[
"position_group_name",
"nwb_file_name",
"encoding_interval",
"decoding_interval",
],
)

environment = SortedSpikesDecodingV1.fetch_environments(key)[0]

position_df = SortedSpikesDecodingV1.fetch_position_info(key)[0]
@@ -374,9 +398,13 @@ def fetch_linear_position_info(key):
axis=1,
).loc[min_time:max_time]

-    @staticmethod
+    @classmethod
    def fetch_spike_data(
-        key, filter_by_interval=True, time_slice=None, return_unit_ids=False
+        cls,
+        key,
+        filter_by_interval=True,
+        time_slice=None,
+        return_unit_ids=False,
) -> Union[list[np.ndarray], Optional[list[dict]]]:
"""Fetch the spike times for the decoding model
@@ -399,6 +427,14 @@ def fetch_spike_data(
list[np.ndarray]
List of spike times for each unit in the model's spike group
"""
key = cls.get_fully_defined_key(
key,
required_fields=[
"encoding_interval",
"decoding_interval",
],
)

spike_times, unit_ids = SortedSpikesGroup.fetch_spike_data(
key, return_unit_ids=True
)
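
As with the clusterless pipeline, a partial but unique key now suffices; a sketch (the key value is hypothetical):

    from spyglass.decoding.v1.sorted_spikes import SortedSpikesDecodingV1

    # Hypothetical partial key identifying a single decoding entry
    partial_key = {"nwb_file_name": "example20241205_.nwb"}

    # Spike times per unit, plus unit ids when requested
    spike_times, unit_ids = SortedSpikesDecodingV1.fetch_spike_data(
        partial_key, return_unit_ids=True
    )
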
src/spyglass/spikesorting/analysis/v1/group.py (7 additions, 3 deletions)
@@ -3,7 +3,6 @@

import datajoint as dj
import numpy as np
-from ripple_detection import get_multiunit_population_firing_rate

from spyglass.common import Session # noqa: F401
from spyglass.settings import test_mode
@@ -127,9 +126,12 @@ def filter_units(
include_mask[ind] = True
return include_mask

-    @staticmethod
+    @classmethod
    def fetch_spike_data(
-        key: dict, time_slice: list[float] = None, return_unit_ids: bool = False
+        cls,
+        key: dict,
+        time_slice: list[float] = None,
+        return_unit_ids: bool = False,
) -> Union[list[np.ndarray], Optional[list[dict]]]:
"""fetch spike times for units in the group
@@ -148,6 +150,8 @@ def fetch_spike_data(
list of np.ndarray
list of spike times for each unit in the group
"""
key = cls.get_fully_defined_key(key)

# get merge_ids for SpikeSortingOutput
merge_ids = (
(
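
Here `get_fully_defined_key` is called with no `required_fields`, so it defaults to the table's full primary key; a sketch (the key value and time window are hypothetical):

    from spyglass.spikesorting.analysis.v1.group import SortedSpikesGroup

    # Hypothetical partial key; must restrict the group table to one entry
    partial_key = {"sorted_spikes_group_name": "example_group"}

    # Optional time_slice limits returned spike times to a window (seconds)
    spike_times = SortedSpikesGroup.fetch_spike_data(
        partial_key, time_slice=[0.0, 100.0]
    )
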
src/spyglass/utils/dj_mixin.py (22 additions, 0 deletions)
@@ -137,6 +137,28 @@ def _safe_context(cls):
else nullcontext()
)

    @classmethod
    def get_fully_defined_key(
        cls, key: dict = None, required_fields: list[str] = None
    ) -> dict:
        """Complete a partial key if it matches exactly one table entry.

        Parameters
        ----------
        key : dict, optional
            Partial or full primary key. Defaults to an empty dict.
        required_fields : list[str], optional
            Fields the key must specify. Defaults to the table's primary key.

        Raises
        ------
        KeyError
            If required fields are missing and the key does not restrict the
            table to exactly one entry.
        """
        if key is None:
            key = dict()

        required_fields = required_fields or cls.primary_key
        if isinstance(key, (str, dict)):  # accept a key dict or restriction string
            if not all(field in key for field in required_fields):
                # key is underspecified: accept it only if it matches one entry
                if not len(query := cls() & key) == 1:
                    raise KeyError(
                        "Key is neither fully specified nor a unique entry "
                        + f"in the table.\n\tTable: {cls.full_table_name}"
                        + f"\n\tKey: {key}\n\tRequired fields: "
                        + f"{required_fields}\n\tResult: {query}"
                    )
                key = query.fetch1("KEY")

        return key

# ------------------------------- fetch_nwb -------------------------------

@cached_property
Expand Down
