Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hypergraph refactor #77

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,5 @@ numpy>=1.19.5
quantities>=0.14.1
matplotlib>=3.3.2
seaborn>=0.9.0
bokeh>=3.0.0
holoviews>=1.16.0
networkx>=3.0.0
32 changes: 25 additions & 7 deletions viziphant/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,8 @@ def plot_patterns(spiketrains, patterns, circle_sizes=(3, 50, 70),
axes.yaxis.set_label_coords(-0.01, 0.5)
return axes

def plot_patterns_hypergraph(patterns, num_neurons=None):
def plot_patterns_hypergraph(patterns, pattern_size=None, num_neurons=None,\
must_involve_neuron=None, node_size=3, node_color='white', node_linewidth=1):
"""
Hypergraph visualization of spike patterns.

Expand Down Expand Up @@ -429,11 +430,22 @@ def plot_patterns_hypergraph(patterns, num_neurons=None):
:func:`elephant.spade.spade` or
:func:`elephant.cell_assembly_detection.cell_assembly_detection`
pattern detectors.
node_size (optional): int
Change the size of the drawen nodes
pattern_size (optional): range
Only draw patterns that are in range of pattern_size
num_neurons: None or int
If None, only the neurons that are part of a pattern are shown. If an
integer is passed, it identifies the total number of recorded neurons
including non-pattern neurons to be additionally shown in the graph.
Default: None
must_involve_neuron (optional) : int
Highlight pattern which includes neuron x
node_color (optional) : String
change the color of the nodes

node_linewidth (optional) : int
change the line width of the nodes
Comment on lines +433 to +448
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the order of the individual options seems a bit random. I think users will expect first the options that control what is shown (e.g., number of neurons), then how its shown (e.g., node color).


Returns
-------
Expand Down Expand Up @@ -461,8 +473,7 @@ def plot_patterns_hypergraph(patterns, num_neurons=None):
bst.rescale('ms')
patterns = cell_assembly_detection(bst, max_lag=2)

fig = viziphant.patterns.plot_patterns_hypergraph(patterns)
plt.show()
viziphant.patterns.plot_patterns_hypergraph(patterns)

Comment on lines -465 to 477
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What was the reason to remove the fig = and plt.show() parts?

"""
# If only patterns of a single dataset are given, wrap them in a list to
Expand Down Expand Up @@ -496,8 +507,15 @@ def plot_patterns_hypergraph(patterns, num_neurons=None):
# Create one hyperedge from every pattern
for pattern in patterns:
# A hyperedge is the set of neurons of a pattern
hyperedges.append(pattern['neurons'])

if pattern_size is None or len(pattern['neurons']) in pattern_size:
hyperedges.append(pattern['neurons'])

if must_involve_neuron is not None and isinstance(must_involve_neuron, int):
hyperedges = [edge for edge in hyperedges if must_involve_neuron in edge]

elif must_involve_neuron is not None and isinstance(must_involve_neuron, list):
hyperedges = [edge for edge in hyperedges if any(elem in edge for elem in must_involve_neuron)]

# Currently, all hyperedges receive the same weights
weights = [weight] * len(hyperedges)

Expand All @@ -507,8 +525,8 @@ def plot_patterns_hypergraph(patterns, num_neurons=None):
weights=weights,
repulse=repulsive)
hypergraphs.append(hg)

view = View(hypergraphs)
view = View(hypergraphs=hypergraphs, node_size=node_size,
node_color=node_color, node_linewidth=node_linewidth)
fig = view.show(subset_style=VisualizationStyle.COLOR,
triangulation_style=VisualizationStyle.INVISIBLE)

Expand Down
16 changes: 12 additions & 4 deletions viziphant/patterns_src/hypergraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,23 @@ def complete_and_star_associated_graph(self):
edges = []
weights = []
graph_vertices = list(self.vertices.copy())
if isinstance(self.vertices[0], int):
max_vertex = max(max(hyperedge) for hyperedge in self.hyperedges)
else:
max_vertex = 0
if isinstance(max_vertex, str):
max_vertex = 0
Comment on lines +107 to +112
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you include comments on the logic of these if statements? Why int and str?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume you need this to get unique IDs for the pseudo vertecies?

for i, hyperedge in enumerate(self.hyperedges):
# Pseudo-vertex corresponding to hyperedge
graph_vertices.append(-i - 1)
pseudo_vertex = max_vertex + i + 1
graph_vertices.append(pseudo_vertex)

for j, vertex in enumerate(hyperedge):
# Every vertex of a hyperedge is adjacent to the pseudo-vertex
# corresponding to the hyperedge
edges.append([-i - 1, vertex])
# Weight is equal to the weight of the hyperedge (if
# applicable)
edges.append([pseudo_vertex, vertex])

# Weight is equal to the weight of the hyperedge (if applicable)
if self.weights:
weights.append(self.weights[i])
# Unique unordered combinations of vertices of this hyperedge
Expand Down
45 changes: 29 additions & 16 deletions viziphant/patterns_src/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from holoviews import opts
from holoviews.streams import Pipe
import numpy as np
import matplotlib.pyplot as plt

from viziphant.patterns_src.hypergraph import Hypergraph

Expand All @@ -33,8 +34,7 @@ class View:
In summary, this class represents an interactive tool
for the visualization of hypergraphs.
"""

def __init__(self, hypergraphs, title=None):
def __init__(self, hypergraphs, node_size=3, node_color='white', node_linewidth=1, title=None):
"""
Constructs a View object that handles the visualization
of the given hypergraphs.
Expand All @@ -44,6 +44,15 @@ def __init__(self, hypergraphs, title=None):
hypergraphs: list of Hypergraph objects
Hypergraphs to be visualized.
Each hypergraph should contain data of one data set.

node_size (optional) : int
Size of the nodes in the Hypergraphs

node_color (optional) : String
change the color of the nodes

node_linewidth (optional) : int
change the line width of the nodes
Comment on lines +55 to +63
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
node_size (optional) : int
Size of the nodes in the Hypergraphs
node_color (optional) : String
change the color of the nodes
node_linewidth (optional) : int
change the line width of the nodes
node_size (optional) : int
Size of the nodes in the Hypergraphs
node_color (optional) : String
change the color of the nodes
node_linewidth (optional) : int
change the line width of the nodes

"""

# Hyperedge drawings
Expand All @@ -52,8 +61,17 @@ def __init__(self, hypergraphs, title=None):
# Which color of the color map to use next
self.current_color = 1

# Size of the vertices
self.node_radius = 0.2
# radius of the hyperedges
self.node_radius = .2

# Size of the nodes (vertices of hypergraph)
self.node_size = node_size

# Color of the nodes
self.node_color = node_color

# Width of the Node lines
self.node_linewidth = node_linewidth

# Selected title of the figure
self.title = title
Expand Down Expand Up @@ -123,8 +141,7 @@ def _setup_graph_visualization(self):
# All in black
cmap=['#ffffff', '#ffffff'] * 50,
# Size of the nodes
node_size=self.node_radius))

node_size=self.node_size, node_color=self.node_color, node_linewidth=self.node_linewidth, show_legend=True))
return dynamic_map, pipe

def _setup_hyperedge_drawing(self):
Expand Down Expand Up @@ -172,11 +189,6 @@ def create_polygon(*args, **kwargs):
# differently
import colorcet
cmap = colorcet.glasbey[:len(self.hypergraphs[0].hyperedges)]
elif self.n_hypergraphs <= 10:
# Select Category10 colormap as default for up to 10 data sets
# This is an often used colormap
from bokeh.palettes import all_palettes
cmap = list(all_palettes['Category10'][10][1:self.n_hypergraphs+1])[::-1]
else:
# For larger numbers of data sets, select Glasbey colormap
import colorcet
Expand Down Expand Up @@ -229,10 +241,11 @@ def show(self,
plot = self.dynamic_map * self.dynamic_map_edges
# Set size of the plot to a square to avoid distortions
self.plot = plot.redim.range(x=(-1, 11), y=(-1, 11))

return hv.render(plot, backend="matplotlib")

def draw_hyperedges(self,
# TODO: how to get axes? currently figure
fig = hv.render(plot, backend="matplotlib")
return fig

def draw_hyperedges(self, highlight_neuron=None,
subset_style=VisualizationStyle.COLOR,
triangulation_style=VisualizationStyle.INVISIBLE):
"""
Expand Down Expand Up @@ -399,7 +412,7 @@ def _update_nodes(self, data):
nodes = hv.Nodes((pos_x, pos_y, vertex_ids, vertex_labels),
extents=(0.01, 0.01, 0.01, 0.01),
vdims='Label')

new_data = ((edge_source, edge_target), nodes)
self.pipe.send(new_data)

Expand Down
Loading
Loading