diff --git a/include/bbp/sonata/edges.h b/include/bbp/sonata/edges.h index 11cb7c91..9d6b9522 100644 --- a/include/bbp/sonata/edges.h +++ b/include/bbp/sonata/edges.h @@ -58,6 +58,16 @@ class SONATA_API EdgePopulation : public Population * Find edges connecting two given nodes. */ Selection connectingEdges(const std::vector& source, const std::vector& target) const; + + /** + * Write bidirectional node->edge indices to EdgePopulation HDF5. + */ + static void writeIndices( + const std::string& h5FilePath, const std::string& population, + NodeID maxSourceNodeID, + NodeID maxTargetNodeID, + bool overwrite = false + ); }; //-------------------------------------------------------------------------------------------------- diff --git a/python/bindings.cpp b/python/bindings.cpp index 5cddee6a..235f982f 100644 --- a/python/bindings.cpp +++ b/python/bindings.cpp @@ -433,6 +433,16 @@ PYBIND11_MODULE(libsonata, m) "target"_a, "Find all edges connecting two given node sets" ) + .def_static( + "write_indices", + &EdgePopulation::writeIndices, + "h5_filepath"_a, + "population"_a, + "max_source_node_id"_a, + "max_target_node_id"_a, + "overwrite"_a, + "Write bidirectional node->edge indices to EdgePopulation HDF5" + ) ; bindStorageClass( diff --git a/src/edge_index.cpp b/src/edge_index.cpp index 7ff26920..e55711a6 100644 --- a/src/edge_index.cpp +++ b/src/edge_index.cpp @@ -14,6 +14,7 @@ #include #include #include +#include namespace bbp { @@ -24,6 +25,10 @@ namespace { typedef std::vector> RawIndex; +const char* SOURCE_NODE_ID_DSET = "source_node_id"; +const char* TARGET_NODE_ID_DSET = "target_node_id"; + +const char* INDEX_GROUP = "indices"; const char* SOURCE_INDEX_GROUP = "indices/source_to_target"; const char* TARGET_INDEX_GROUP = "indices/target_to_source"; const char* NODE_ID_TO_RANGES_DSET = "node_id_to_ranges"; @@ -106,6 +111,127 @@ Selection resolve(const HighFive::Group& indexGroup, const std::vector& return Selection::fromValues(result); } + +namespace { + +std::unordered_map _groupNodeRanges(const std::vector& nodeIDs) +{ + std::unordered_map result; + + if (nodeIDs.empty()) { + return result; + } + + uint64_t i0 = 0; + NodeID lastNodeID = nodeIDs[0]; + for (uint64_t i = 1; i < nodeIDs.size(); ++i) { + if (nodeIDs[i] != lastNodeID) { + result[lastNodeID].push_back({i0, i}); + i0 = i; + lastNodeID = nodeIDs[i0]; + } + } + + result[lastNodeID].push_back({i0, nodeIDs.size()}); + + return result; +} + + +std::vector _readNodeIDs(const HighFive::Group& h5Root, const std::string& name) +{ + std::vector result; + h5Root.getDataSet(name).read(result); + return result; +} + + +void _writeIndexDataset(const RawIndex& data, const std::string& name, HighFive::Group& h5Group) +{ + auto dset = h5Group.createDataSet(name, HighFive::DataSpace::From(data)); + dset.write(data); +} + + +void _writeIndexGroup(const std::vector& nodeIDs, NodeID maxNodeID, HighFive::Group& h5Root, const std::string& name) +{ + auto indexGroup = h5Root.createGroup(name); + + auto nodeToRanges = _groupNodeRanges(nodeIDs); + const auto rangeCount = std::accumulate( + nodeToRanges.begin(), nodeToRanges.end(), uint64_t(0), + [](uint64_t total, const std::pair& item) { + return total + item.second.size(); + } + ); + + RawIndex primaryIndex; + RawIndex secondaryIndex; + + primaryIndex.reserve(maxNodeID); + secondaryIndex.reserve(rangeCount); + + uint64_t offset = 0; + for (NodeID nodeID = 0; nodeID < maxNodeID; ++nodeID) { + const auto it = nodeToRanges.find(nodeID); + if (it == nodeToRanges.end()) { + primaryIndex.push_back({offset, offset}); + } else { + auto& ranges = it->second; + primaryIndex.push_back({offset, offset + ranges.size()}); + offset += ranges.size(); + std::move( + ranges.begin(), ranges.end(), + std::back_inserter(secondaryIndex) + ); + } + } + + _writeIndexDataset(primaryIndex, NODE_ID_TO_RANGES_DSET, indexGroup); + _writeIndexDataset(secondaryIndex, RANGE_TO_EDGE_ID_DSET, indexGroup); +} + +} // unnamed namespace + + +void write( + HighFive::Group& h5Root, + NodeID maxSourceNodeID, + NodeID maxTargetNodeID, + bool overwrite +) +{ + if (h5Root.exist(INDEX_GROUP)) { + if (overwrite) { + // TODO: remove INDEX_GROUP + throw SonataError("Index overwrite not implemented yet"); + } else { + throw SonataError("Index group already exists"); + } + } + + try { + _writeIndexGroup( + _readNodeIDs(h5Root, SOURCE_NODE_ID_DSET), + maxSourceNodeID, + h5Root, + SOURCE_INDEX_GROUP + ); + _writeIndexGroup( + _readNodeIDs(h5Root, TARGET_NODE_ID_DSET), + maxTargetNodeID, + h5Root, + TARGET_INDEX_GROUP + ); + } catch(...) { + try { + // TODO: remove INDEX_GROUP + } catch(...) { + } + throw; + } +} + } } } // namespace bbp::sonata::edge_index \ No newline at end of file diff --git a/src/edge_index.h b/src/edge_index.h index 35a36c32..017e83b5 100644 --- a/src/edge_index.h +++ b/src/edge_index.h @@ -24,6 +24,13 @@ const HighFive::Group targetIndex(const HighFive::Group& h5Root); Selection resolve(const HighFive::Group& indexGroup, NodeID nodeID); Selection resolve(const HighFive::Group& indexGroup, const std::vector& nodeIDs); +void write( + HighFive::Group& h5Root, + NodeID maxSourceNodeID, + NodeID maxTargetNodeID, + bool overwrite +); + } } } // namespace bbp::sonata::edge_index \ No newline at end of file diff --git a/src/edges.cpp b/src/edges.cpp index 678295ef..1894275d 100644 --- a/src/edges.cpp +++ b/src/edges.cpp @@ -112,6 +112,22 @@ Selection EdgePopulation::connectingEdges(const std::vector& source, con return Selection::fromValues(result); } +//-------------------------------------------------------------------------------------------------- + +void EdgePopulation::writeIndices( + const std::string& h5FilePath, const std::string& population, + NodeID maxSourceNodeID, + NodeID maxTargetNodeID, + bool overwrite +) +{ + HDF5_LOCK_GUARD + HighFive::File h5File(h5FilePath, HighFive::File::ReadWrite); + auto h5Root = h5File.getGroup(fmt::format("/edges/{}", population)); + edge_index::write(h5Root, maxSourceNodeID, maxTargetNodeID, overwrite); +} + + //-------------------------------------------------------------------------------------------------- constexpr const char* EdgePopulation::ELEMENT; diff --git a/tests/data/edges-no-index.h5 b/tests/data/edges-no-index.h5 new file mode 100644 index 00000000..4e8e8be9 Binary files /dev/null and b/tests/data/edges-no-index.h5 differ diff --git a/tests/test_edges.cpp b/tests/test_edges.cpp index 5fb99d7d..91d72d42 100644 --- a/tests/test_edges.cpp +++ b/tests/test_edges.cpp @@ -2,7 +2,9 @@ #include +#include #include +#include #include #include @@ -116,6 +118,61 @@ TEST_CASE("EdgePopulation", "[edges]") } +namespace { + +// TODO: remove after switching to C++17 +void copyFile(const std::string& srcFilePath, const std::string& dstFilePath) +{ + std::ifstream src(srcFilePath, std::ios::binary); + std::ofstream dst(dstFilePath, std::ios::binary); + dst << src.rdbuf(); +} + +} // unnamed namespace + + +TEST_CASE("EdgePopulation::writeIndices", "[edges]") +{ + const std::string srcFilePath = "./data/edges-no-index.h5"; + const std::string dstFilePath = "./data/edges-no-index.h5.tmp"; + { + const EdgePopulation population(srcFilePath, "", "edges-AB"); + + // no index datasets yet + CHECK_THROWS_AS( + population.afferentEdges({1, 2}), + SonataError + ); + CHECK_THROWS_AS( + population.efferentEdges({1, 2}), + SonataError + ); + } + + + copyFile(srcFilePath, dstFilePath); + + try { + EdgePopulation::writeIndices(dstFilePath, "edges-AB", 4, 4); + const EdgePopulation population(dstFilePath, "", "edges-AB"); + CHECK( + population.afferentEdges({1, 2}) == Selection({{0, 4}, {5, 6}}) + ); + CHECK( + population.efferentEdges({1, 2}) == Selection({{0, 4}}) + ); + } catch(...) { + try { + std::remove(dstFilePath.c_str()); + } catch (...) { + } + throw; + } + + std::remove(dstFilePath.c_str()); +} + + TEST_CASE("EdgeStorage", "[edges]") { // CSV not supported at the moment