Skip to content

Commit

Permalink
report getNodeIdElementIdMapping function(#168)
Browse files Browse the repository at this point in the history
* allows to get the mapping from a circuit Node Id(s), and its element IDs in a report

Co-authored-by: Sergio <[email protected]>
Co-authored-by: Mike Gevaert <[email protected]>
Co-authored-by: Nadir Román Guerrero <[email protected]>
  • Loading branch information
4 people authored Dec 15, 2021
1 parent 970a7da commit fcc8483
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 39 deletions.
23 changes: 23 additions & 0 deletions include/bbp/sonata/report_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,29 @@ class SONATA_API ReportReader
* Return true if the data is sorted.
*/
bool getSorted() const;

/**
* Return all the node ids.
*/
std::vector<NodeID> getNodeIds() const;

/**
* Return the ElementIds for the passed Node.
* The return type will depend on the report reader:
* - For Soma report reader, the return value will be the Node ID to which the report
* value belongs to.
* - For Element/full compartment readers, the return value will be an array with 2
* elements, the first element is the Node ID and the second element is the
* compartment ID of the given Node.
*
* \param node_ids limit the report to the given selection. If nullptr, all nodes in the
* report are used
* \param fn lambda applied to all ranges for all node ids
*/
typename DataFrame<KeyType>::DataType getNodeIdElementIdMapping(
const nonstd::optional<Selection>& node_ids = nonstd::nullopt,
std::function<void(const Range&)> fn = nullptr) const;

/**
* \param node_ids limit the report to the given selection.
* \param tstart return voltages occurring on or after tstart. tstart=nonstd::nullopt
Expand All @@ -154,6 +175,8 @@ class SONATA_API ReportReader
std::string time_units_;
std::string data_units_;
bool nodes_ids_sorted_ = false;
Selection::Values node_ids_from_selection(
const nonstd::optional<Selection>& node_ids = nonstd::nullopt) const;

friend ReportReader;
};
Expand Down
8 changes: 8 additions & 0 deletions python/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,14 @@ void bindReportReader(py::module& m, const std::string& prefix) {
.def("get_node_ids",
&ReportType::Population::getNodeIds,
"Return the list of nodes ids for this population")
.def(
"get_node_id_element_id_mapping",
[](const typename ReportType::Population& population,
const nonstd::optional<Selection>& selection) {
return population.getNodeIdElementIdMapping(selection, nullptr);
},
DOC_REPORTREADER_POP(getNodeIdElementIdMapping),
"selection"_a = nonstd::nullopt)
.def_property_readonly("sorted",
&ReportType::Population::getSorted,
DOC_REPORTREADER_POP(getSorted))
Expand Down
19 changes: 18 additions & 1 deletion python/generated/docstrings.h
Original file line number Diff line number Diff line change
Expand Up @@ -431,14 +431,31 @@ static const char *__doc_bbp_sonata_ReportReader_Population_getDataUnits = R"doc

static const char *__doc_bbp_sonata_ReportReader_Population_getIndex = R"doc()doc";

static const char *__doc_bbp_sonata_ReportReader_Population_getNodeIds = R"doc()doc";
static const char *__doc_bbp_sonata_ReportReader_Population_getNodeIdElementIdMapping =
R"doc(Return the ElementIds for the passed Node. The return type will depend
on the report reader: - For Soma report reader, the return value will
be the Node ID to which the report value belongs to. - For
Element/full compartment readers, the return value will be an array
with 2 elements, the first element is the Node ID and the second
element is the compartment ID of the given Node.
Parameter ``node_ids``:
limit the report to the given selection. If nullptr, all nodes in
the report are used
Parameter ``fn``:
lambda applied to all ranges for all node ids)doc";

static const char *__doc_bbp_sonata_ReportReader_Population_getNodeIds = R"doc(Return all the node ids.)doc";

static const char *__doc_bbp_sonata_ReportReader_Population_getSorted = R"doc(Return true if the data is sorted.)doc";

static const char *__doc_bbp_sonata_ReportReader_Population_getTimeUnits = R"doc(Return the unit of time)doc";

static const char *__doc_bbp_sonata_ReportReader_Population_getTimes = R"doc(Return (tstart, tstop, tstep) of the population)doc";

static const char *__doc_bbp_sonata_ReportReader_Population_node_ids_from_selection = R"doc()doc";

static const char *__doc_bbp_sonata_ReportReader_Population_nodes_ids = R"doc()doc";

static const char *__doc_bbp_sonata_ReportReader_Population_nodes_ids_sorted = R"doc()doc";
Expand Down
9 changes: 9 additions & 0 deletions python/tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ def test_get_spikes_from_population(self):
def test_getTimes_from_population(self):
self.assertEqual(self.test_obj['All'].times, (0.1, 1.3))


class TestSomaReportPopulation(unittest.TestCase):
def setUp(self):
path = os.path.join(PATH, "somas.h5")
Expand Down Expand Up @@ -332,6 +333,10 @@ def test_get_reports_from_population(self):
sel_empty = self.test_obj['All'].get(node_ids=[])
np.testing.assert_allclose(sel_empty.data, np.empty(shape=(0, 0)))

def test_get_node_id_element_id_mapping(self):
self.assertEqual(self.test_obj['All'].get_node_id_element_id_mapping([[3, 5]]),
[3, 4])


class TestElementReportPopulation(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -384,6 +389,10 @@ def test_get_reports_from_population(self):
np.testing.assert_allclose(self.test_obj['All'].get(node_ids=[3, 4], tstride=4).data[2],
[81.0, 81.1, 81.2, 81.3, 81.4, 81.5, 81.6, 81.7, 81.8, 81.9], 1e-6, 0)

def test_get_node_id_element_id_mapping(self):
self.assertEqual(self.test_obj['All'].get_node_id_element_id_mapping([[3, 5]]),
[[3, 5], [3, 5], [3, 6], [3, 6], [3, 7], [4, 7], [4, 8], [4, 8], [4, 9], [4, 9]])


class TestNodePopulationFailure(unittest.TestCase):
def test_CorrectStructure(self):
Expand Down
93 changes: 55 additions & 38 deletions src/report_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,27 @@ std::vector<NodeID> ReportReader<T>::Population::getNodeIds() const {
return nodes_ids_;
}

template <typename T>
Selection::Values ReportReader<T>::Population::node_ids_from_selection(
const nonstd::optional<Selection>& selection) const {
Selection::Values node_ids;

if (!selection) { // Take all nodes in this case
node_ids.reserve(nodes_pointers_.size());
std::transform(nodes_pointers_.begin(),
nodes_pointers_.end(),
std::back_inserter(node_ids),
[](const std::pair<NodeID, Range>& node_pointer) {
return node_pointer.first;
});
} else if (selection->empty()) {
return {};
} else {
node_ids = selection->flatten();
}
return node_ids;
}

template <typename T>
std::pair<size_t, size_t> ReportReader<T>::Population::getIndex(
const nonstd::optional<double>& tstart, const nonstd::optional<double>& tstop) const {
Expand Down Expand Up @@ -318,6 +339,34 @@ std::pair<size_t, size_t> ReportReader<T>::Population::getIndex(
}


template <typename T>
typename DataFrame<T>::DataType ReportReader<T>::Population::getNodeIdElementIdMapping(
const nonstd::optional<Selection>& selection, std::function<void(const Range&)> fn) const {
typename DataFrame<T>::DataType ids{};

Selection::Values node_ids = node_ids_from_selection(selection);

auto dataset_elem_ids = pop_group_.getGroup("mapping").getDataSet("element_ids");
for (const auto& node_id : node_ids) {
const auto it = nodes_pointers_.find(node_id);
if (it == nodes_pointers_.end()) {
continue;
}

std::vector<ElementID> element_ids(it->second.second - it->second.first);
dataset_elem_ids.select({it->second.first}, {it->second.second - it->second.first})
.read(element_ids.data());
for (const auto& elem : element_ids) {
ids.push_back(make_key<T>(node_id, elem));
}

if (fn) {
fn(it->second);
}
}
return ids;
}

template <typename T>
DataFrame<T> ReportReader<T>::Population::get(const nonstd::optional<Selection>& selection,
const nonstd::optional<double>& tstart,
Expand All @@ -339,48 +388,16 @@ DataFrame<T> ReportReader<T>::Population::get(const nonstd::optional<Selection>&
data_frame.times.push_back(times_index_[i].second);
}

// Simplify selection
// We should remove duplicates
// And when we can work with ranges let's sort them
// auto nodes_ids_ = Selection::fromValues(node_ids.flatten().sort());
Selection::Values node_ids;

if (!selection) { // Take all nodes in this case
node_ids.reserve(nodes_pointers_.size());
std::transform(nodes_pointers_.begin(),
nodes_pointers_.end(),
std::back_inserter(node_ids),
[](const std::pair<NodeID, Range>& node_pointer) {
return node_pointer.first;
});
} else if (selection->empty()) {
return DataFrame<T>{{}, {}, {}};
} else {
node_ids = selection->flatten();
}

Ranges positions;
// min and max offsets of the node_ids requested are calculated
// to reduce the amount of IO that is brought to memory
Ranges positions;
uint64_t min = std::numeric_limits<uint64_t>::max();
uint64_t max = std::numeric_limits<uint64_t>::min();
auto dataset_elem_ids = pop_group_.getGroup("mapping").getDataSet("element_ids");
for (const auto& node_id : node_ids) {
const auto it = nodes_pointers_.find(node_id);
if (it == nodes_pointers_.end()) {
continue;
}
min = std::min(it->second.first, min);
max = std::max(it->second.second, max);
positions.emplace_back(it->second.first, it->second.second);

std::vector<ElementID> element_ids(it->second.second - it->second.first);
dataset_elem_ids.select({it->second.first}, {it->second.second - it->second.first})
.read(element_ids.data());
for (const auto& elem : element_ids) {
data_frame.ids.push_back(make_key<T>(node_id, elem));
}
}
data_frame.ids = getNodeIdElementIdMapping(selection, [&](const Range& range) {
min = std::min(range.first, min);
max = std::max(range.second, max);
positions.emplace_back(range.first, range.second);
});
if (data_frame.ids.empty()) { // At the end no data available (wrong node_ids?)
return DataFrame<T>{{}, {}, {}};
}
Expand Down
6 changes: 6 additions & 0 deletions tests/test_report_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ TEST_CASE("SomaReportReader", "[base]") {

auto data_empty = pop.get(Selection({}));
REQUIRE(data_empty.data == std::vector<float>{});

auto ids = pop.getNodeIdElementIdMapping(Selection({{3, 5}}));
REQUIRE(ids == std::vector<NodeID>{3, 4});
}

TEST_CASE("ElementReportReader limits", "[base]") {
Expand Down Expand Up @@ -155,4 +158,7 @@ TEST_CASE("ElementReportReader", "[base]") {
// Select only one time
REQUIRE(pop.get(Selection({{1, 2}}), 0.6, 0.6).data ==
std::vector<float>{30.0f, 30.1f, 30.2f, 30.3f, 30.4f});

auto ids = pop.getNodeIdElementIdMapping(Selection({{3, 5}}));
REQUIRE(ids == std::vector<CompartmentID>{{3, 5}, {3, 5}, {3, 6}, {3, 6}, {3, 7}, {4, 7}, {4, 8}, {4, 8}, {4, 9}, {4, 9}});
}

0 comments on commit fcc8483

Please sign in to comment.