Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
k1o0 committed Oct 17, 2024
1 parent 136978f commit 063c713
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 3 deletions.
19 changes: 17 additions & 2 deletions ibllib/io/session_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import logging
import socket
from pathlib import Path
from itertools import chain
from copy import deepcopy

from one.converters import ConversionMixin
Expand Down Expand Up @@ -77,6 +78,9 @@ def _patch_file(data: dict) -> dict:
if 'tasks' in data and isinstance(data['tasks'], dict):
data['tasks'] = [{k: v} for k, v in data['tasks'].copy().items()]
data['version'] = SPEC_VERSION
# Ensure all items in tasks list are single value dicts
if 'tasks' in data:
data['tasks'] = [{k: v} for k, v in chain.from_iterable(map(dict.items, data['tasks']))]
return data


Expand Down Expand Up @@ -168,8 +172,19 @@ def merge_params(a, b, copy=False):
assert k not in a or a[k] == b[k], 'multiple sync fields defined'
if isinstance(b[k], list):
prev = list(a.get(k, []))
# For procedures and projects, remove duplicates
to_add = b[k] if k == 'tasks' else set(b[k]) - set(prev)
if k == 'tasks':
# For tasks, keep order and skip duplicates
# Assert tasks is a list of single value dicts
assert set(map(len, prev)) == {1} and set(map(len, b[k])) == {1}
# Convert protocol -> dict map to hashable tuple of protocol + sorted key value pairs
to_hashable = lambda itm: (itm[0], *chain.from_iterable(sorted(itm[1].items()))) # noqa
# Get the set of previous tasks
prev_tasks = set(map(to_hashable, chain.from_iterable(map(dict.items, prev))))
tasks = chain.from_iterable(map(dict.items, b[k]))
to_add = [dict([itm]) for itm in tasks if to_hashable(itm) not in prev_tasks]
else:
# For procedures and projects, remove duplicates
to_add = set(b[k]) - set(prev)
a[k] = prev + list(to_add)
elif isinstance(b[k], dict):
a[k] = {**a.get(k, {}), **b[k]}
Expand Down
20 changes: 19 additions & 1 deletion ibllib/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,10 +527,22 @@ def test_read_yaml(self):
self.assertCountEqual(self.fixture.keys(), data_keys)

def test_patch_data(self):
"""Test for session_params._patch_file function."""
with patch(session_params.__name__ + '.SPEC_VERSION', '1.0.0'), \
self.assertLogs(session_params.__name__, logging.WARNING):
data = session_params._patch_file({'version': '1.1.0'})
self.assertEqual(data, {'version': '1.0.0'})
# Check tasks dicts separated into lists
unpatched = {'version': '0.0.1', 'tasks': {
'fooChoiceWorld': {1: '1'}, 'barChoiceWorld': {2: '2'}}}
data = session_params._patch_file(unpatched)
self.assertIsInstance(data['tasks'], list)
self.assertEqual([['fooChoiceWorld'], ['barChoiceWorld']], list(map(list, data['tasks'])))
# Check patching list of dicts with some containing more than 1 key
unpatched = {'tasks': [{'foo': {1: '1'}}, {'bar': {2: '2'}, 'baz': {3: '3'}}]}
data = session_params._patch_file(unpatched)
self.assertEqual(3, len(data['tasks']))
self.assertEqual([['foo'], ['bar'], ['baz']], list(map(list, data['tasks'])))

def test_get_collections(self):
collections = session_params.get_collections(self.fixture)
Expand Down Expand Up @@ -561,10 +573,16 @@ def test_merge_params(self):
b = {'procedures': ['Imaging', 'Injection'], 'tasks': [{'fooChoiceWorld': {'collection': 'bar'}}]}
c = session_params.merge_params(a, b, copy=True)
self.assertCountEqual(['Imaging', 'Behavior training/tasks', 'Injection'], c['procedures'])
self.assertCountEqual(['passiveChoiceWorld', 'ephysChoiceWorld', 'fooChoiceWorld'], (list(x)[0] for x in c['tasks']))
self.assertEqual(['passiveChoiceWorld', 'ephysChoiceWorld', 'fooChoiceWorld'], [list(x)[0] for x in c['tasks']])
# Ensure a and b not modified
self.assertNotEqual(set(c['procedures']), set(a['procedures']))
self.assertNotEqual(set(a['procedures']), set(b['procedures']))
# Test duplicate tasks skipped while order kept constant
d = {'tasks': [a['tasks'][1], {'ephysChoiceWorld': {'collection': 'raw_task_data_02', 'sync_label': 'nidq'}}]}
e = session_params.merge_params(c, d, copy=True)
expected = ['passiveChoiceWorld', 'ephysChoiceWorld', 'fooChoiceWorld', 'ephysChoiceWorld']
self.assertEqual(expected, [list(x)[0] for x in e['tasks']])
self.assertDictEqual({'collection': 'raw_task_data_02', 'sync_label': 'nidq'}, e['tasks'][-1]['ephysChoiceWorld'])
# Test without copy
session_params.merge_params(a, b, copy=False)
self.assertCountEqual(['Imaging', 'Behavior training/tasks', 'Injection'], a['procedures'])
Expand Down

0 comments on commit 063c713

Please sign in to comment.