diff --git a/tvbwidgets/ui/connectivity_matrix_editor_widget.py b/tvbwidgets/ui/connectivity_matrix_editor_widget.py index 4365189..e60b91c 100644 --- a/tvbwidgets/ui/connectivity_matrix_editor_widget.py +++ b/tvbwidgets/ui/connectivity_matrix_editor_widget.py @@ -18,11 +18,13 @@ from tvb.datatypes.connectivity import Connectivity from tvbwidgets.ui.base_widget import TVBWidget + class ConnectivityMatrixEditor(TVBWidget): - def __init__(self, connectivity, size = None): + def __init__(self, connectivity, size=None, **kwargs): + super().__init__(**kwargs) self.connectivity = connectivity - self.connectivities_history = [self.connectivity] - self.num_rows = int(len(self.connectivity.weights[0]) / 2) #num_cols will be equal to num_rows + self.connectivity_history_list = [self.connectivity] + self.num_rows = int(len(self.connectivity.weights[0]) / 2) # num_cols will be equal to num_rows if size is None: size = self.num_rows * 20 self.size = size @@ -30,35 +32,36 @@ def __init__(self, connectivity, size = None): self.is_connectivity_being_edited = True self.new_connectivity = self._prepare_new_connectivity() - self.header = widgets.HBox(layout = self.DEFAULT_BORDER) + self.header = widgets.HBox(layout=self.DEFAULT_BORDER) self._make_header() - self.tab = widgets.Tab(layout = self.DEFAULT_BORDER) - self._get_quadrant_range(selection = 1) + self.tab = widgets.Tab(layout=self.DEFAULT_BORDER) + self._get_quadrant_range(selection=1) self._prepare_matrices_tab() - + def _make_header(self): options = ["Quadrant 1", "Quadrant 2", "Quadrant 3", "Quadrant 4"] - - self.quadrants = widgets.Dropdown(options = options) - self.quadrants.observe(self._on_quadrant_select, names = ["value"]) - - self.cell_value = widgets.Text(description = "value", - layout = widgets.Layout(width = "200px", visibility = "hidden")) - - self.change_button = widgets.Button(description = "Change", - layout = widgets.Layout(width = "80px", visibility = "hidden")) - self.change_button.on_click(lambda change :self.on_apply_change(change)) - - self.save_button = widgets.Button(description = "Save", - layout = widgets.Layout(width = "100px", margin='0 0 0 auto')) + + self.quadrants = widgets.Dropdown(options=options) + self.quadrants.observe(self._on_quadrant_select, names=["value"]) + + self.cell_value = widgets.Text(description="value", + layout=widgets.Layout(width="200px", visibility="hidden")) + + self.change_button = widgets.Button(description="Change", + layout=widgets.Layout(width="80px", visibility="hidden")) + self.change_button.on_click(lambda change: self.on_apply_change(change)) + + self.save_button = widgets.Button(description="Save", + layout=widgets.Layout(width="100px", margin='0 0 0 auto')) self.save_button.on_click(self.on_click_save) - self.header.children = [self.quadrants, self.cell_value, self.change_button, self.save_button, self._get_history_dropdown()] + self.header.children = [self.quadrants, self.cell_value, self.change_button, self.save_button, + self._get_history_dropdown()] def _on_quadrant_select(self, change): self.cell_value.layout.visibility = "hidden" self.change_button.layout.visibility = "hidden" - + selection = int(change["new"][-1]) connectivity = self.new_connectivity if self.is_connectivity_being_edited else self.connectivity @@ -81,44 +84,44 @@ def _get_quadrant_range(self, selection): from_row = middle_val from_col = middle_val - #indexing starts from this row and col + # indexing starts from this row and col self.from_row = from_row self.from_col = from_col def _prepare_matrices_tab(self): self.weights_matrix = self._prepare_matrix("weights") - self.tract_lengths_matrix = self._prepare_matrix("tract_lengths") + self.tract_lengths_matrix = self._prepare_matrix("tract_lengths") self.weights_matrix.on_mouse_down(lambda x, y: self.on_cell_clicked(x, y, "weights")) self.weights_matrix.on_mouse_move(self.set_mouse_position) - + self.tract_lengths_matrix.on_mouse_down(lambda x, y: self.on_cell_clicked(x, y, "tract_lengths")) self.tract_lengths_matrix.on_mouse_move(self.set_mouse_position) - + out1 = widgets.Output() out2 = widgets.Output() with out1: display(self.weights_matrix) - + with out2: display(self.tract_lengths_matrix) - container1 = widgets.Box([out1], layout = widgets.Layout( - width = '1200px', - height = '600px', - overflow_x = 'auto', - overflow_y = 'auto', + container1 = widgets.Box([out1], layout=widgets.Layout( + width='1200px', + height='600px', + overflow_x='auto', + overflow_y='auto', )) - container2 = widgets.Box([out2], layout = container1.layout) + container2 = widgets.Box([out2], layout=container1.layout) self.tab.children = [container1, container2] self.tab.set_title(0, "weights") self.tab.set_title(1, "tract_lengths") def _prepare_matrix(self, matrix_name): - matrix = getattr(self.connectivity, matrix_name) #either weights or tracts matrix data + matrix = getattr(self.connectivity, matrix_name) matrix_full = canvas.MultiCanvas(6, width=self.size * 1.5, height=self.size * 1.2) matrix_view = matrix_full[0] @@ -128,79 +131,85 @@ def _prepare_matrix(self, matrix_name): grid = matrix_full[4] # sixth canvas is for displaying a grid around selected cell - #rotate the row_header canvas so they appear vertical + # rotate the row_header canvas so they appear vertical row_header.rotate(math.radians(-90)) - row_header.translate(-self.layout_offset ,0) + row_header.translate(-self.layout_offset, 0) with canvas.hold_canvas(matrix_full): - self.cell_x = np.tile(np.linspace(self.layout_offset, self.size, self.num_rows), self.num_rows) #x-coordinates of cells - self.cell_y = np.repeat(np.linspace(self.layout_offset, self.size, self.num_rows), self.num_rows) #y-coordinates of cells + self.cell_x = np.tile(np.linspace(self.layout_offset, self.size, self.num_rows), + self.num_rows) # x-coordinates of cells + self.cell_y = np.repeat(np.linspace(self.layout_offset, self.size, self.num_rows), + self.num_rows) # y-coordinates of cells self.cell_size = self.cell_x[1] - self.cell_x[0] - grid.stroke_rects(self.cell_x, self.cell_y, height = self.cell_size, width = self.cell_size) - value = matrix[self.from_row : self.from_row + self.num_rows, self.from_col : self.from_col + self.num_rows] - colors = self._generate_color(self.connectivity, value = value, matrix_name = matrix_name) - matrix_view.fill_styled_rects(self.cell_x, self.cell_y, color = colors, height = self.cell_size , width = self.cell_size) + grid.stroke_rects(self.cell_x, self.cell_y, height=self.cell_size, width=self.cell_size) + value = matrix[self.from_row: self.from_row + self.num_rows, self.from_col: self.from_col + self.num_rows] + colors = self._generate_color(self.connectivity, value=value, matrix_name=matrix_name) + matrix_view.fill_styled_rects(self.cell_x, self.cell_y, color=colors, height=self.cell_size, + width=self.cell_size) x = 0 y = np.linspace(self.layout_offset, self.size, self.num_rows) - grid.stroke_rects(y, x, height = self.layout_offset, width = self.cell_size) #grid for row headers - grid.stroke_rects(x, y, height = self.cell_size, width = self.layout_offset) #grid for column headers + grid.stroke_rects(y, x, height=self.layout_offset, width=self.cell_size) # grid for row headers + grid.stroke_rects(x, y, height=self.cell_size, width=self.layout_offset) # grid for column headers for i in range(self.num_rows): row_header.font = f"bold {self.cell_size}px px sans serif" row_header_text = f"{self.connectivity.region_labels[self.from_row + i]}" - row_header.fill_text(row_header_text, x + 10, y[i] + self.cell_size, max_width = self.layout_offset * 0.9) - + row_header.fill_text(row_header_text, x + 10, y[i] + self.cell_size, max_width=self.layout_offset * 0.9) + column_header.font = f"bold {self.cell_size}px px sans serif" column_header_text = f"{self.connectivity.region_labels[self.from_col + i]}" - column_header.fill_text(column_header_text, x + 10, y[i] + self.cell_size, max_width = self.layout_offset * 0.9) + column_header.fill_text(column_header_text, x + 10, y[i] + self.cell_size, + max_width=self.layout_offset * 0.9) self.colorbar_x = self.size * 1.1 gradient = grid.create_linear_gradient(self.colorbar_x, self.layout_offset, self.size * 1.2, self.size, - [(i/len(self.colors),self.colors[-i-1]) for i in range(len(self.colors))]) #color gradient for color-bar + [(i / len(self.colors), self.colors[-i - 1]) for i in + range(len(self.colors))]) # color gradient for color-bar grid.fill_style = gradient - grid.fill_rect(self.colorbar_x, self.layout_offset, 20 , self.size - self.layout_offset) + grid.fill_rect(self.colorbar_x, self.layout_offset, 20, self.size - self.layout_offset) grid.fill_style = "black" for i in range(7): label_text = f"--{round(matrix.max() * (6 - i) / 6, 2)}" - color_bar.fill_text(label_text, self.colorbar_x + 20, self.size * 0.8 / 6.1 * i + self.layout_offset + 5) #labels for colorbar - + color_bar.fill_text(label_text, self.colorbar_x + 20, + self.size * 0.8 / 6.1 * i + self.layout_offset + 5) # labels for colorbar + return matrix_full - def _generate_color(self, connectivity, i = 0, j = 0, matrix_name = None , value = None): - self.colors = ["#66797b", "#543146", "#5a1c5d", "#b468ab", "#6ade42", "#27913c", "#1c464a", - "#247663", "#38bcaa", "#a9e9ff", "#61cfff", "#37a5c1", "#e4e4e2", "#ff9f25", - "#fb5226"] + def _generate_color(self, connectivity, i=0, j=0, matrix_name=None, value=None): + self.colors = ["#66797b", "#543146", "#5a1c5d", "#b468ab", "#6ade42", "#27913c", "#1c464a", + "#247663", "#38bcaa", "#a9e9ff", "#61cfff", "#37a5c1", "#e4e4e2", "#ff9f25", + "#fb5226"] color_scheme = mcolors.LinearSegmentedColormap.from_list('color_scheme', self.colors) matrix = getattr(connectivity, matrix_name) - norm = mcolors.Normalize(vmin = 0, vmax = matrix.max()) + norm = mcolors.Normalize(vmin=0, vmax=matrix.max()) if not isinstance(value, np.ndarray): - if not value: + if not value: value = matrix[int(self.from_row + i)][int(self.from_col + j)] color = color_scheme(norm(value)) - color = f"rgba({color[0]*255:.0f}, {color[1]*255:.0f}, {color[2]*255:.0f}, {color[3]:.2f})" + color = f"rgba({color[0] * 255:.0f}, {color[1] * 255:.0f}, {color[2] * 255:.0f}, {color[3]:.2f})" return color - + colors = color_scheme(norm(value)) colors = colors[:, :, :3] * 255 return colors - + def set_mouse_position(self, x, y): self.x_coord = x self.y_coord = y - + def on_cell_clicked(self, x, y, matrix_name): self.clicked_matrix = matrix_name - x_coord, y_coord = self.x_coord , self.y_coord - col = ((x_coord - self.layout_offset) // self.cell_size) + x_coord, y_coord = self.x_coord, self.y_coord + col = ((x_coord - self.layout_offset) // self.cell_size) row = ((y_coord - self.layout_offset) // self.cell_size) - if row > -1 and row < self.num_rows and col > -1 and col < self.num_rows: + if -1 < row < self.num_rows and -1 < col < self.num_rows: self.row = row self.col = col connectivity = self.new_connectivity if self.is_connectivity_being_edited else self.connectivity @@ -223,7 +232,6 @@ def on_cell_clicked(self, x, y, matrix_name): matrix_ui[5].stroke_style = "white" matrix_ui[5].stroke_rect(x, y, self.cell_size, self.cell_size) - def on_apply_change(self, change): self.is_connectivity_being_edited = True @@ -233,14 +241,14 @@ def on_apply_change(self, change): value = float(self.cell_value.value) except: value = None - + if value is not None: - matrix_name = self.clicked_matrix + matrix_name = self.clicked_matrix matrix = getattr(self.new_connectivity, matrix_name) max_val = matrix.max() matrix[self.from_row + int(self.row)][self.from_col + int(self.col)] = value if max_val != matrix.max(): - self._update_matrices_view(self.new_connectivity) + self._update_matrices_view(self.new_connectivity) self.cell_value.layout.visibility = "hidden" self.change_button.layout.visibility = "hidden" @@ -249,7 +257,8 @@ def on_apply_change(self, change): y = self.layout_offset + self.row * self.cell_size with canvas.hold_canvas(matrix_ui[0]): - matrix_ui[0].fill_style = self._generate_color(self.new_connectivity, self.row, self.col, self.clicked_matrix, value) + matrix_ui[0].fill_style = self._generate_color(self.new_connectivity, self.row, self.col, + self.clicked_matrix, value) matrix_ui[0].fill_rect(x, y, self.cell_size, self.cell_size) matrix_ui[0].stroke_rect(x, y, self.cell_size, self.cell_size) @@ -257,14 +266,14 @@ def on_apply_change(self, change): def saved_connectivities(self): conn_list = [] - for conn in self.connectivities_history: + for conn in self.connectivity_history_list: conn_list.append(conn.gid.hex) - return conn_list + return conn_list - def get_connectivity(self, gid = None): + def get_connectivity(self, gid=None): if gid is None: return self.connectivity - for conn in self.connectivities_history: + for conn in self.connectivity_history_list: if conn.gid.hex == gid: return conn @@ -273,14 +282,14 @@ def on_click_save(self, change): self.change_button.layout.visibility = "hidden" conn = self.new_connectivity - self.connectivities_history.insert(0, conn) + self.connectivity_history_list.insert(0, conn) self.connectivity = conn self.header.children = list(self.header.children)[:-1] + [self._get_history_dropdown()] self.new_connectivity = self._prepare_new_connectivity() self.is_connectivity_being_edited = False self._update_matrices_view(self.connectivity) - + def _prepare_new_connectivity(self): new_connectivity = Connectivity() new_connectivity.parent_connectivity = self.connectivity.gid.hex @@ -293,18 +302,18 @@ def _prepare_new_connectivity(self): new_connectivity.weights = self.connectivity.weights new_connectivity.tract_lengths = self.connectivity.tract_lengths new_connectivity.configure() - + return new_connectivity - + def _get_history_dropdown(self): - values = [(conn.gid.hex, conn) for conn in self.connectivities_history] + values = [(conn.gid.hex, conn) for conn in self.connectivity_history_list] default = values[values.index((self.connectivity.gid.hex, self.connectivity))][1] - dropdown = widgets.Dropdown(options = values, - description = 'View history', - disabled = False, - value = default, - ) + dropdown = widgets.Dropdown(options=values, + description='View history', + disabled=False, + value=default, + ) def on_connectivity_change(change): self.cell_value.layout.visibility = "hidden" @@ -317,7 +326,7 @@ def on_connectivity_change(change): dropdown.observe(on_connectivity_change, 'value') return dropdown - + def _update_matrices_view(self, connectivity): matrices = ["weights", "tract_lengths"] for matrix_name in matrices: @@ -331,27 +340,32 @@ def _update_matrices_view(self, connectivity): matrix_view[3].clear() matrix_view[5].clear() - value = matrix[self.from_row : self.from_row + self.num_rows, self.from_col : self.from_col + self.num_rows] - color = self._generate_color(connectivity, value = value, matrix_name = matrix_name) - matrix_view[0].fill_styled_rects(self.cell_x, self.cell_y, color = color, height = self.cell_size, width = self.cell_size) - + value = matrix[self.from_row: self.from_row + self.num_rows, + self.from_col: self.from_col + self.num_rows] + color = self._generate_color(connectivity, value=value, matrix_name=matrix_name) + matrix_view[0].fill_styled_rects(self.cell_x, self.cell_y, color=color, height=self.cell_size, + width=self.cell_size) + max_value = matrix.max() region_labels = self.connectivity.region_labels - x = 0 + x = 0 y = np.linspace(self.layout_offset, self.size, self.num_rows) for i in range(self.num_rows): row_label = region_labels[self.from_row + i] - matrix_view[1].fill_text(row_label, x + 10, y[i] + self.cell_size, max_width = self.layout_offset * 0.9) + matrix_view[1].fill_text(row_label, x + 10, y[i] + self.cell_size, + max_width=self.layout_offset * 0.9) for i in range(self.num_rows): col_label = region_labels[self.from_col + i] - matrix_view[2].fill_text(col_label, x + 10, y[i] + self.cell_size, max_width = self.layout_offset * 0.9) + matrix_view[2].fill_text(col_label, x + 10, y[i] + self.cell_size, + max_width=self.layout_offset * 0.9) for i in range(7): value = f"--{round(max_value * (6 - i) / 6, 2)}" - matrix_view[3].fill_text(value, self.colorbar_x + 20, ((self.size - self.layout_offset) / 6.1) * i + self.layout_offset + 5) #labels for colorbar + matrix_view[3].fill_text(value, self.colorbar_x + 20, ((self.size - self.layout_offset) / 6.1) * i + + self.layout_offset + 5) # labels for colorbar def display(self): display(self.header) - display(self.tab) \ No newline at end of file + display(self.tab)