diff --git a/src/dynamic_routing_analysis/decoding_utils.py b/src/dynamic_routing_analysis/decoding_utils.py index 393f53c..65f9ae6 100644 --- a/src/dynamic_routing_analysis/decoding_utils.py +++ b/src/dynamic_routing_analysis/decoding_utils.py @@ -1123,9 +1123,8 @@ def decode_context_with_linear_shift(session=None,params=None,trials=None,units= decoder_results[session_id]['results'][aa]['shift'][nunits]={} decoder_results[session_id]['results'][aa]['no_shift'][nunits]={} for rr in range(n_repeats): - if n_repeats>1: - decoder_results[session_id]['results'][aa]['shift'][nunits][rr]={} - decoder_results[session_id]['results'][aa]['no_shift'][nunits][rr]={} + decoder_results[session_id]['results'][aa]['shift'][nunits][rr]={} + decoder_results[session_id]['results'][aa]['no_shift'][nunits][rr]={} if input_data_type=='spikes': if nunits=='all': @@ -1248,6 +1247,7 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session= n_repeats=25 all_bal_acc={} + all_trials_bal_acc={} linear_shift_dict={ 'session_id':[], @@ -1276,10 +1276,11 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session= linear_shift_dict['null_accuracy_median_'+str(nu)]=[] linear_shift_dict['null_accuracy_std_'+str(nu)]=[] linear_shift_dict['p_value_'+str(nu)]=[] + linear_shift_dict['true_accuracy_all_trials_no_shift_'+str(nu)]=[] #loop through sessions for file in files: - try: + # try: decoder_results=pickle.loads(upath.UPath(file).read_bytes()) session_id=str(list(decoder_results.keys())[0]) session_info=npc_lims.get_session_info(session_id) @@ -1298,6 +1299,7 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session= continue all_bal_acc[session_id]={} + all_trials_bal_acc[session_id]={} nunits=decoder_results[session_id]['n_units'] if nunits!=nunits_global: @@ -1330,28 +1332,33 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session= for aa in areas: if aa in decoder_results[session_id]['results']: all_bal_acc[session_id][aa]={} + all_trials_bal_acc[session_id][aa]={} ### ADD LOOP FOR NUNITS ### for nu in nunits: if nu not in decoder_results[session_id]['results'][aa]['shift'].keys(): continue all_bal_acc[session_id][aa][nu]=[] + all_trials_bal_acc[session_id][aa][nu]=[] for rr in range(n_repeats): if rr in decoder_results[session_id]['results'][aa]['shift'][nu].keys(): temp_bal_acc=[] - temp_bal_acc_all_trials=[] # else: # print('n repeats invalid: '+str(rr)) # continue for sh in half_shift_inds: if sh in list(decoder_results[session_id]['results'][aa]['shift'][nu][rr].keys()): temp_bal_acc.append(decoder_results[session_id]['results'][aa]['shift'][nu][rr][sh]['balanced_accuracy_test']) - if sh==0: - temp_bal_acc_all_trials.append(decoder_results[session_id]['results'][aa]['no_shift'][nu][rr]['balanced_accuracy_test']) + if len(temp_bal_acc)>0: all_bal_acc[session_id][aa][nu].append(np.array(temp_bal_acc)) + + all_trials_bal_acc[session_id][aa][nu].append(decoder_results[session_id]['results'][aa]['no_shift'][nu][rr]['balanced_accuracy_test']) + all_bal_acc[session_id][aa][nu]=np.vstack(all_bal_acc[session_id][aa][nu]) all_bal_acc[session_id][aa][nu]=np.nanmean(all_bal_acc[session_id][aa][nu],axis=0) + all_trials_bal_acc[session_id][aa][nu]=np.nanmean(all_trials_bal_acc[session_id][aa][nu]) + if type(aa)==str: if '_probe' in aa: area_name=aa.split('_probe')[0] @@ -1391,6 +1398,12 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session= linear_shift_dict['null_accuracy_std_'+str(nu)].append(np.nan) linear_shift_dict['p_value_'+str(nu)].append(np.nan) + if nu in all_trials_bal_acc[session_id][aa].keys(): + true_accuracy=all_trials_bal_acc[session_id][aa][nu] + linear_shift_dict['true_accuracy_all_trials_no_shift_'+str(nu)].append(true_accuracy) + else: + linear_shift_dict['true_accuracy_all_trials_no_shift_'+str(nu)].append(np.nan) + #make big dict/dataframe for this: #save true decoding, mean/median null decoding, and p value for each area/probe linear_shift_dict['session_id'].append(session_id) @@ -1414,10 +1427,10 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session= linear_shift_dict['probe'].append(np.nan) print(aa+' done') - except Exception as e: - print(e) - print('error with session: '+session_id) - continue + # except Exception as e: + # print(e) + # print('error with session: '+session_id) + # continue linear_shift_df=pd.DataFrame(linear_shift_dict)