diff --git a/nih2mne/GUI/qa_bids_gui.py b/nih2mne/GUI/qa_bids_gui.py index 862b27b..f93fc85 100644 --- a/nih2mne/GUI/qa_bids_gui.py +++ b/nih2mne/GUI/qa_bids_gui.py @@ -30,6 +30,9 @@ import mne import numpy as np from scipy.stats import zscore, trim_mean +import pandas as pd +import pyctf +import mne_bids CFG_VERSION = 1.0 @@ -463,6 +466,8 @@ def __init__(self, fname): # Calculate the 10s epoch trim mean (60% mid) average power self._calc_chan_power() + # Compute PSD + self._compute_psd() def _get_task(self): tmp = self.fname.split('_') @@ -491,20 +496,6 @@ def load(self, load_val=False): self._chan_picks = [i for i in chan_picks if len(i)==5] #Some datasets have extra odd chans self._megall_picks = self._ref_picks+self._chan_picks - def save(self, fname=None, overwrite=False): - import pickle - if fname == None: - raise ValueError('fname must be set during save') - - if hasattr(self, 'raw'): - del self.raw - - if op.exists(fname) and overwrite==False: - raise ValueError(f'The fname already exists: {fname}') - else: - with open(fname, 'wb') as f: - pickle.dump(self, f) - def _calc_bad_segments(self): @@ -544,6 +535,10 @@ def _calc_chan_power(self): epo_pow_av = epo_pow.mean(axis=-1) #average over time in block epo_robust_av = trim_mean(epo_pow_av, proportiontocut=0.2, axis=0) self.chan_power = epo_robust_av + + def _compute_psd(self): + # epochs = mne.make_fixed_length_epochs(self.raw, duration=10.0, preload=True) + self.psd = self.raw.compute_psd() def _is_valid(self, set_value=None): '''Fill in more of this -- maybe ''' @@ -563,6 +558,19 @@ def button_set_status(self): def set_status(self, status): self.status=status + + @property + def coil_locs_dewar(self): + return pyctf.getHC.getHC(op.join(self.fname, op.basename(self.fname).replace('.ds','.hc')), 'dewar') + + #Currently not working -- these get stripped out in BIDS + @property + def coil_locs_head(self): + return pyctf.getHC.getHC(op.join(self.fname, op.basename(self.fname).replace('.ds','.hc')), 'head') + + @property + def event_counts(self): + return pd.DataFrame(self.raw.annotations).description.value_counts() def __repr__(self): tmp_ = f'megraw: {self.task} : {self.fname}' @@ -574,12 +582,22 @@ def __repr__(self): #%% class meglist_class: def __init__(self, subject=None, bids_root=None): - dsets = glob.glob(f'{op.join(subject, "**", "*.ds")}', - root_dir=bids_root, recursive=True) + dsets = glob.glob(f'{op.join(bids_root, subject, "**", "*.ds")}', + recursive=True) tmp = [qa_megraw_object(i) for i in dsets] self.meg_list = tmp self.meg_emptyroom = [i for i in self.meg_list if i.is_emptyroom] + def _print_meg_list_idxs(self): + for idx, dset in enumerate(self.meg_list): + print(f'{idx}: {dset.fname}') + + def plot_meg(self): + self._print_meg_list_idxs() + dset_idx = input('Enter the number associated with the MEG dataset to plot: \n') + dset_idx = int(dset_idx) + self.meg_list[dset_idx].raw.plot() + @property def meg_count(self): return len(self.meg_list) @@ -594,6 +612,7 @@ def __init__(self, subject=None, bids_root=None): self.all_mris = all_mri_list if len(self.all_mris)==0: self.mri = None + self.mri_json_qa = 'No MRIs' elif len(self.all_mris)==1: self.mri = self.all_mris[0] if self.mri.endswith('.nii'): @@ -602,6 +621,7 @@ def __init__(self, subject=None, bids_root=None): self.mri_json = self.mri.replace('.nii.gz','.json') else: self.mri = 'Multiple' + self.mri_json_qa = 'Undetermined - Multiple MRIs' if (self.mri != 'Multiple') and (self.mri != None): self._valid_fids() @@ -612,6 +632,9 @@ def _sort_T1(self): def _valid_fids(self): import json + if not op.exists(self.mri_json): + self.mri_json_qa = 'No MRI JSON' + return with open(self.mri_json) as f: json_data = json.load(f) if 'AnatomicalLandmarkCoordinates' not in json_data.keys(): @@ -639,6 +662,9 @@ def __init__(self, subject, bids_root=None, subjects_dir=None): else: self.bids_root = bids_root + if not op.exists(op.join(bids_root, self.subject)): + raise ValueError(f'Subject {self.subject} does not exist in {bids_root}') + if subjects_dir==None: self.subjects_dir = op.join(bids_root, 'derivatives','freesurfer','subjects') else: @@ -652,6 +678,27 @@ def __init__(self, subject, bids_root=None, subjects_dir=None): # Freesurfer Component self.fs_recon = check_fs_recon(self.subject, self.subjects_dir) + + def plot_mri_fids(self): + ''' Open a triaxial image of the fiducial locations''' + from nih2mne.utilities.qa_fids import plot_fids_qa + plot_fids_qa(subjid=self.subject, + bids_root=self.bids_root, + outfile=None, block=True) + # tmp_ = input('Hit any button to close') + + def plot_3D_coreg(self): + self._print_meg_list_idxs() + dset_idx = input('Enter the number associated with the MEG dataset to plot coreg: \n') + dset_idx = int(dset_idx) + bids_path = mne_bids.get_bids_path_from_fname(self.meg_list[dset_idx].fname) + t1_bids_path = mne_bids.get_bids_path_from_fname(self.mri) + trans = mne_bids.get_head_mri_trans(bids_path, t1_bids_path=t1_bids_path, + extra_params=dict(system_clock='ignore'), + fs_subject=self.subject, fs_subjects_dir=self.subjects_dir) + mne.viz.plot_alignment(self.meg_list[dset_idx].raw.info, + trans=trans,subject=self.subject, + subjects_dir = self.subjects_dir) @property def info(self): @@ -677,6 +724,22 @@ def info(self): def __repr__(self): return self.info + + def save(self, fname=None, overwrite=False): + import pickle + if fname == None: + raise ValueError('fname must be set during save') + + #Remove fully loaded meg before saving + for meg_dset in self.meg_list: + if hasattr(meg_dset, 'raw'): + del meg_dset.raw + + if op.exists(fname) and overwrite==False: + raise ValueError(f'The fname already exists: {fname}') + else: + with open(fname, 'wb') as f: + pickle.dump(self, f) class subject_tile(subject_bids_info): '''Attach GUI tile properties to bids information''' @@ -701,11 +764,74 @@ def set_status(self, status): #%% + +def test_subject_bids_info(): + import nih2mne + bids_root = op.join(nih2mne.__path__[0], 'test_data','BIDS_test') + test = subject_bids_info('sub-S01', bids_root=bids_root) + assert test.meg_count == 2 + tmp_ = test.mri + tmp_ = tmp_.split('test_data')[-1] + assert tmp_ == '/BIDS_test/sub-S01/ses-1/anat/sub-S01_ses-1_T1w.nii.gz' + + airpuff = test.meg_list[0] + val_counts = airpuff.event_counts + assert val_counts['stim']==103 + assert val_counts['missingstim']==17 + + +test = subject_bids_info('sub-ON02811', bids_root=os.getcwd()) +test = subject_bids_info('sub-ON69163', bids_root=os.getcwd()) +test.save(op.join('QA_objs', 'sub-ON69163.pkl'), overwrite=True) + test = subject_tile(subject='sub-ON11394', bids_root=os.getcwd()) test = subject_tile(subject='sub-ON08710', bids_root=os.getcwd()) subject_tile_list = [subject_tile(i, bids_root=bids_root) for i in glob.glob('sub-*')] + +fail=[] +for i in glob.glob('sub-*'): + try: + tmp = subject_bids_info(i, bids_root='/fast2/BIDS') + tmp.save(op.join('QA_objs', f'{i}_bidsqa.pkl'), overwrite=True) + del tmp + except: + fail.append(i) + +# power_stack = {} +# psd_stack = {} +psd_dframe_list = [] +pow_dframe_list = [] +fail = [] +for i in glob.glob('QA_objs/*bidsqa.pkl'): + with open(i, 'rb') as f: + qaobj = pickle.load(f) + # power_stack[qaobj.subject]={} + # psd_stack[qaobj.subject]={} + for dset in qaobj.meg_list: + try: + if (not dset.is_emptyroom) and (dset.task != 'artifact'): + #power_stack[qaobj.subject][dset.task]=dset.chan_power + dset.bids_root=os.getcwd() + dset.load() + pow_tmp = pd.DataFrame(dset.chan_power[np.newaxis,:], columns=[dset.raw.ch_names]) + pow_tmp['subject'] = qaobj.subject + pow_tmp['task'] = dset.task + pow_dframe_list.append(pow_tmp) + # psd_stack[qaobj.subject][dset.task]=dset.psd + tmp = pd.DataFrame(dset.psd._data, columns=dset.psd.freqs, index=dset.psd.ch_names) + tmp['subject']=qaobj.subject + tmp['task'] = dset.task + psd_dframe_list.append(tmp) + del tmp, pow_tmp + except: + fail.append(f'{qaobj.subject} : {dset.task}') + +psd_dframe = pd.concat +power_dframe = pd.concat(pow_dframe_list) +power_dframe.to_csv('/home/jstout/src/nih_to_mne/nih2mne/dataQA/power_hv_092424.csv', index=False) + def make_bids_subject_layout(row_num=6, col_num=4, subject_list=None, opts=None): '''Generate a Grid of datasets''' idx = 0 diff --git a/nih2mne/utilities/qa_fids.py b/nih2mne/utilities/qa_fids.py index 2ae3502..a700842 100644 --- a/nih2mne/utilities/qa_fids.py +++ b/nih2mne/utilities/qa_fids.py @@ -12,7 +12,7 @@ import json import nibabel as nib -def plot_fids_qa(subjid=None, bids_root=None, outfile=None): +def plot_fids_qa(subjid=None, bids_root=None, outfile=None, block=False): ''' Plot triaxial images of the fiducial locations and save out image @@ -60,17 +60,29 @@ def plot_fids_qa(subjid=None, bids_root=None, outfile=None): mri_pos = {'LPA':lpa, 'NAS': nas , 'RPA': rpa} # Plot it - fig, axs = plt.subplots(3, 1, figsize=(7, 7), facecolor="k") - for point_idx, label in enumerate(("LPA", "NAS", "RPA")): - plot_anat( - str(t1w_bids_path), - axes=axs[point_idx], - cut_coords=mri_pos[label],#, :], - title=label, - vmax=160, - output_file = outfile - ) - plt.show() + if block==False: + fig, axs = plt.subplots(3, 1, figsize=(7, 7), facecolor="k") + for point_idx, label in enumerate(("LPA", "NAS", "RPA")): + plot_anat( + str(t1w_bids_path), + axes=axs[point_idx], + cut_coords=mri_pos[label],#, :], + title=label, + vmax=160, + output_file = outfile + ) + plt.show() + else: + fig, axs = plt.subplots(3, 1, figsize=(7, 7), facecolor="k") + for point_idx, label in enumerate(("LPA", "NAS", "RPA")): + plot_anat( + str(t1w_bids_path), + axes=axs[point_idx], + cut_coords=mri_pos[label],#, :], + title=label, + vmax=160, + ) + plt.show(block=True) def main(): import argparse