Skip to content

Commit

Permalink
add all trials accuracy to summary table; bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
egmcbride committed Dec 5, 2024
1 parent d0bcaa4 commit 6a9222c
Showing 1 changed file with 24 additions and 11 deletions.
35 changes: 24 additions & 11 deletions src/dynamic_routing_analysis/decoding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,9 +1123,8 @@ def decode_context_with_linear_shift(session=None,params=None,trials=None,units=
decoder_results[session_id]['results'][aa]['shift'][nunits]={}
decoder_results[session_id]['results'][aa]['no_shift'][nunits]={}
for rr in range(n_repeats):
if n_repeats>1:
decoder_results[session_id]['results'][aa]['shift'][nunits][rr]={}
decoder_results[session_id]['results'][aa]['no_shift'][nunits][rr]={}
decoder_results[session_id]['results'][aa]['shift'][nunits][rr]={}
decoder_results[session_id]['results'][aa]['no_shift'][nunits][rr]={}

if input_data_type=='spikes':
if nunits=='all':
Expand Down Expand Up @@ -1248,6 +1247,7 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session=
n_repeats=25

all_bal_acc={}
all_trials_bal_acc={}

linear_shift_dict={
'session_id':[],
Expand Down Expand Up @@ -1276,10 +1276,11 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session=
linear_shift_dict['null_accuracy_median_'+str(nu)]=[]
linear_shift_dict['null_accuracy_std_'+str(nu)]=[]
linear_shift_dict['p_value_'+str(nu)]=[]
linear_shift_dict['true_accuracy_all_trials_no_shift_'+str(nu)]=[]

#loop through sessions
for file in files:
try:
# try:
decoder_results=pickle.loads(upath.UPath(file).read_bytes())
session_id=str(list(decoder_results.keys())[0])
session_info=npc_lims.get_session_info(session_id)
Expand All @@ -1298,6 +1299,7 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session=
continue

all_bal_acc[session_id]={}
all_trials_bal_acc[session_id]={}

nunits=decoder_results[session_id]['n_units']
if nunits!=nunits_global:
Expand Down Expand Up @@ -1330,28 +1332,33 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session=
for aa in areas:
if aa in decoder_results[session_id]['results']:
all_bal_acc[session_id][aa]={}
all_trials_bal_acc[session_id][aa]={}
### ADD LOOP FOR NUNITS ###
for nu in nunits:
if nu not in decoder_results[session_id]['results'][aa]['shift'].keys():
continue
all_bal_acc[session_id][aa][nu]=[]
all_trials_bal_acc[session_id][aa][nu]=[]
for rr in range(n_repeats):
if rr in decoder_results[session_id]['results'][aa]['shift'][nu].keys():
temp_bal_acc=[]
temp_bal_acc_all_trials=[]
# else:
# print('n repeats invalid: '+str(rr))
# continue
for sh in half_shift_inds:
if sh in list(decoder_results[session_id]['results'][aa]['shift'][nu][rr].keys()):
temp_bal_acc.append(decoder_results[session_id]['results'][aa]['shift'][nu][rr][sh]['balanced_accuracy_test'])
if sh==0:
temp_bal_acc_all_trials.append(decoder_results[session_id]['results'][aa]['no_shift'][nu][rr]['balanced_accuracy_test'])

if len(temp_bal_acc)>0:
all_bal_acc[session_id][aa][nu].append(np.array(temp_bal_acc))

all_trials_bal_acc[session_id][aa][nu].append(decoder_results[session_id]['results'][aa]['no_shift'][nu][rr]['balanced_accuracy_test'])

all_bal_acc[session_id][aa][nu]=np.vstack(all_bal_acc[session_id][aa][nu])
all_bal_acc[session_id][aa][nu]=np.nanmean(all_bal_acc[session_id][aa][nu],axis=0)

all_trials_bal_acc[session_id][aa][nu]=np.nanmean(all_trials_bal_acc[session_id][aa][nu])

if type(aa)==str:
if '_probe' in aa:
area_name=aa.split('_probe')[0]
Expand Down Expand Up @@ -1391,6 +1398,12 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session=
linear_shift_dict['null_accuracy_std_'+str(nu)].append(np.nan)
linear_shift_dict['p_value_'+str(nu)].append(np.nan)

if nu in all_trials_bal_acc[session_id][aa].keys():
true_accuracy=all_trials_bal_acc[session_id][aa][nu]
linear_shift_dict['true_accuracy_all_trials_no_shift_'+str(nu)].append(true_accuracy)
else:
linear_shift_dict['true_accuracy_all_trials_no_shift_'+str(nu)].append(np.nan)

#make big dict/dataframe for this:
#save true decoding, mean/median null decoding, and p value for each area/probe
linear_shift_dict['session_id'].append(session_id)
Expand All @@ -1414,10 +1427,10 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session=
linear_shift_dict['probe'].append(np.nan)

print(aa+' done')
except Exception as e:
print(e)
print('error with session: '+session_id)
continue
# except Exception as e:
# print(e)
# print('error with session: '+session_id)
# continue


linear_shift_df=pd.DataFrame(linear_shift_dict)
Expand Down

0 comments on commit 6a9222c

Please sign in to comment.