From 9128e646d4d53174806b4e377824219b88f79d79 Mon Sep 17 00:00:00 2001 From: Viktor Sip Date: Mon, 6 Mar 2017 10:35:13 +0100 Subject: [PATCH] parc_fod_to_connectome cleanup --- bin/run_parc_fod_to_connectome.py | 4 +- tvb/recon/flow/parc_fod_to_connectome.py | 69 +++++++++++++----------- 2 files changed, 39 insertions(+), 34 deletions(-) diff --git a/bin/run_parc_fod_to_connectome.py b/bin/run_parc_fod_to_connectome.py index 14553ed..8938703 100644 --- a/bin/run_parc_fod_to_connectome.py +++ b/bin/run_parc_fod_to_connectome.py @@ -7,8 +7,8 @@ logging.basicConfig(level=logging.INFO) -parc, fod, con, gmwmi, ftt = sys.argv[1:] +parc, fod, track_counts, mean_lengths, gmwmi, ftt = sys.argv[1:] runner = SimpleRunner() -convtest = ParcFodToConnectome(parc, fod, con, gmwmi, ftt) +convtest = ParcFodToConnectome(parc, fod, track_counts, mean_lengths, gmwmi, ftt) convtest.run(runner) \ No newline at end of file diff --git a/tvb/recon/flow/parc_fod_to_connectome.py b/tvb/recon/flow/parc_fod_to_connectome.py index 10f3334..b054e9b 100644 --- a/tvb/recon/flow/parc_fod_to_connectome.py +++ b/tvb/recon/flow/parc_fod_to_connectome.py @@ -11,21 +11,26 @@ class ParcFodToConnectome(Flow): FRACTION_SIFT = 0.2 - CONV_CO = 1e-3 - CONV_ML = 1e-3 - CONV_TC = 1e-3 - MIN_TRACKS = 12500 - MAX_TRACKS = 50 * 10**6 - - def __init__(self, parc: os.PathLike, fod: os.PathLike, out_conn: os.PathLike, - gmwmi: os.PathLike=None, ftt: os.PathLike=None): + CONV_TC = 1e-2 + MIN_TRACKS = 5 * 10**5 + MAX_TRACKS = 64 * 10**6 + + def __init__(self, + parc: os.PathLike, + fod: os.PathLike, + out_track_counts: os.PathLike, + out_mean_lengths: os.PathLike, + gmwmi: os.PathLike=None, + ftt: os.PathLike=None): self.parc = parc self.fod = fod self.gmwmi = gmwmi self.ftt = ftt - self.out_conn = out_conn + self.out_track_counts = out_track_counts + self.out_mean_lengths = out_mean_lengths - def _get_track_stats(self, track_assignments: os.PathLike, track_lengths: os.PathLike) -> (np.ndarray, np.ndarray): + def _get_track_stats(self, track_assignments: os.PathLike, track_lengths: os.PathLike, keep_unassigned=False)\ + -> (np.ndarray, np.ndarray): tck_ass = np.genfromtxt(os.fspath(track_assignments), dtype=int) tck_len = np.genfromtxt(os.fspath(track_lengths), dtype=float) @@ -50,6 +55,11 @@ def _get_track_stats(self, track_assignments: os.PathLike, track_lengths: os.Pat # Replace NaNs with zeros. NaNs are present where is no connection. mean_len_mtx[mean_len_mtx != mean_len_mtx] = 0 + if not keep_unassigned: + # Unassigned values are in the first row/column according to tck2connectome docs + mean_len_mtx = mean_len_mtx[1:, 1:] + track_count_mtx = track_count_mtx[1:, 1:] + # Normalize track counts ntracks = np.sum(track_count_mtx) track_count_mtx /= ntracks @@ -75,15 +85,11 @@ def _gen_connectome(self, runner: Runner, tracks: os.PathLike, ntracks: int, par runner.run(mrtrix.run_tck2connectome(tracks_sifted, parc_lbl, conn, assignment=mrtrix.tck2connectome.Assignment.radial_search(2.0), out_assignments=track_assignments)) - conn_mtx = np.genfromtxt(os.fspath(conn), dtype=float) - - # Normalize - conn_mtx /= np.sum(conn_mtx) # Get the track statistics mean_len_mtx, track_count_mtx = self._get_track_stats(track_assignments, track_lengths) - return conn_mtx, mean_len_mtx, track_count_mtx + return mean_len_mtx, track_count_mtx def run(self, runner: Runner): log = logging.getLogger('parc_fod_to_connectome') @@ -96,8 +102,7 @@ def run(self, runner: Runner): runner.run(mrtrix.run_labelconvert(self.parc, lut_in, lut_out, parc_lbl)) ntracks = self.MIN_TRACKS - - fconv = open("conv.txt", "w", 1) + conv_history = [] # Generate initial number of tracks, the connectome and its statistics log.info('Generating initial %i tracks' % ntracks) @@ -105,8 +110,9 @@ def run(self, runner: Runner): runner.run(mrtrix.run_tckgen(self.fod, tracks_a, ntracks, seed_gmwmi=self.gmwmi, act=self.ftt)) log.info('Generating the connectome from %i tracks' % ntracks) - conn_mtx_0, mean_len_mtx_0, track_count_mtx_0 = self._gen_connectome(runner, tracks_a, ntracks, parc_lbl) + mean_len_mtx_0, track_count_mtx_0 = self._gen_connectome(runner, tracks_a, ntracks, parc_lbl) + converged = False while ntracks <= self.MAX_TRACKS/2: # TODO: catching errors, warnings @@ -120,29 +126,28 @@ def run(self, runner: Runner): tracks_a = tracks_merged log.info('Generating the connectome from %i tracks' % ntracks) - conn_mtx_1, mean_len_mtx_1, track_count_mtx_1 = self._gen_connectome(runner, tracks_a, ntracks, parc_lbl) + mean_len_mtx_1, track_count_mtx_1 = self._gen_connectome(runner, tracks_a, ntracks, parc_lbl) # Evaluate convergence criteria - norm_co = np.linalg.norm(conn_mtx_1 - conn_mtx_0) - norm_ml = np.linalg.norm(mean_len_mtx_1 - mean_len_mtx_0) - norm_tc = np.linalg.norm(track_count_mtx_1 - track_count_mtx_0) + norm_tc = np.linalg.norm(track_count_mtx_1 - track_count_mtx_0)/np.linalg.norm(track_count_mtx_1) log.info('N = %i' % ntracks) - log.info(' ||connectome difference|| = %s' % norm_co) - log.info(' ||mean lengths difference|| = %s' % norm_ml) - log.info(' ||track lengths difference|| = %s' % norm_tc) - fconv.write("%s %s %s %s\n" % (ntracks, norm_co, norm_ml, norm_tc)) + log.info(' Rel. track lengths difference norm = %s' % norm_tc) + conv_history.append((ntracks, norm_tc)) - conn_mtx_0, mean_len_mtx_0, track_count_mtx_0 = conn_mtx_1, mean_len_mtx_1, track_count_mtx_1 + mean_len_mtx_0, track_count_mtx_0 = mean_len_mtx_1, track_count_mtx_1 - if norm_co < self.CONV_CO and norm_ml < self.CONV_ML and norm_tc < self.CONV_TC: + if norm_tc < self.CONV_TC: + converged = True break - if ntracks > self.MAX_TRACKS: - log.warning('MAX_TRACKS reached, convergence criteria not satisfied for N=%s tracks') + if not converged: + log.warning('MAX_TRACKS reached, convergence criteria not satisfied for N=%s tracks' % ntracks) else: log.info('Convergence criteria satisfied for N=%i tracks' % ntracks) - np.savetxt(os.fspath(self.out_conn), conn_mtx_0) - fconv.close() + np.savetxt(os.fspath(self.out_track_counts), track_count_mtx_0) + np.savetxt(os.fspath(self.out_mean_lengths), mean_len_mtx_0) + log.info('Convergence history: ') + log.info(' '.join([str(a) for a in conv_history])) log.info('complete in %0.2fs', time.time() - tic)