From 4ccfeb55e6548a2d15636796e42e1df292c073c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20Sch=C3=B6nberger?= Date: Mon, 26 Aug 2024 22:06:26 +0200 Subject: [PATCH] Implement transitive matcher with pair generator + tests (#2735) * Implement transitive matcher with pair generator + tests * d --- src/colmap/controllers/feature_matching.cc | 120 +-------------------- src/colmap/controllers/feature_matching.h | 6 +- src/colmap/feature/matcher.cc | 6 ++ src/colmap/feature/matcher.h | 5 + src/colmap/feature/pairing.cc | 96 +++++++++++++++++ src/colmap/feature/pairing.h | 29 +++++ src/colmap/feature/pairing_test.cc | 35 ++++++ 7 files changed, 176 insertions(+), 121 deletions(-) diff --git a/src/colmap/controllers/feature_matching.cc b/src/colmap/controllers/feature_matching.cc index 41aec4367..11d33728c 100644 --- a/src/colmap/controllers/feature_matching.cc +++ b/src/colmap/controllers/feature_matching.cc @@ -43,7 +43,7 @@ namespace colmap { namespace { void PrintElapsedTime(const Timer& timer) { - LOG(INFO) << StringPrintf(" in %.3fs", timer.ElapsedSeconds()); + LOG(INFO) << StringPrintf("in %.3fs", timer.ElapsedSeconds()); } template @@ -137,128 +137,12 @@ std::unique_ptr CreateSpatialFeatureMatcher( options, matching_options, geometry_options, database_path); } -namespace { - -class TransitiveFeatureMatcher : public Thread { - public: - TransitiveFeatureMatcher(const TransitiveMatchingOptions& options, - const SiftMatchingOptions& matching_options, - const TwoViewGeometryOptions& geometry_options, - const std::string& database_path) - : options_(options), - matching_options_(matching_options), - database_(std::make_shared(database_path)), - cache_(std::make_shared(options_.batch_size, - database_)), - matcher_( - matching_options, geometry_options, database_.get(), cache_.get()) { - THROW_CHECK(options.Check()); - THROW_CHECK(matching_options.Check()); - THROW_CHECK(geometry_options.Check()); - } - - private: - void Run() override { - PrintHeading1("Transitive feature matching"); - Timer run_timer; - run_timer.Start(); - - if (!matcher_.Setup()) { - return; - } - - cache_->Setup(); - - const std::vector image_ids = cache_->GetImageIds(); - - std::vector> image_pairs; - std::unordered_set image_pair_ids; - - for (int iteration = 0; iteration < options_.num_iterations; ++iteration) { - if (IsStopped()) { - run_timer.PrintMinutes(); - return; - } - - Timer timer; - timer.Start(); - - LOG(INFO) << StringPrintf( - "Iteration [%d/%d]", iteration + 1, options_.num_iterations); - - std::vector> existing_image_pairs; - std::vector existing_num_inliers; - database_->ReadTwoViewGeometryNumInliers(&existing_image_pairs, - &existing_num_inliers); - - THROW_CHECK_EQ(existing_image_pairs.size(), existing_num_inliers.size()); - - std::unordered_map> adjacency; - for (const auto& image_pair : existing_image_pairs) { - adjacency[image_pair.first].push_back(image_pair.second); - adjacency[image_pair.second].push_back(image_pair.first); - } - - const size_t batch_size = static_cast(options_.batch_size); - - size_t num_batches = 0; - image_pairs.clear(); - image_pair_ids.clear(); - for (const auto& image : adjacency) { - const auto image_id1 = image.first; - for (const auto& image_id2 : image.second) { - if (adjacency.count(image_id2) > 0) { - for (const auto& image_id3 : adjacency.at(image_id2)) { - const auto image_pair_id = - Database::ImagePairToPairId(image_id1, image_id3); - if (image_pair_ids.count(image_pair_id) == 0) { - image_pairs.emplace_back(image_id1, image_id3); - image_pair_ids.insert(image_pair_id); - if (image_pairs.size() >= batch_size) { - num_batches += 1; - LOG(INFO) << StringPrintf(" Batch %d", num_batches); - DatabaseTransaction database_transaction(database_.get()); - matcher_.Match(image_pairs); - image_pairs.clear(); - PrintElapsedTime(timer); - timer.Restart(); - - if (IsStopped()) { - run_timer.PrintMinutes(); - return; - } - } - } - } - } - } - } - - num_batches += 1; - LOG(INFO) << StringPrintf(" Batch %d", num_batches); - DatabaseTransaction database_transaction(database_.get()); - matcher_.Match(image_pairs); - PrintElapsedTime(timer); - } - - run_timer.PrintMinutes(); - } - - const TransitiveMatchingOptions options_; - const SiftMatchingOptions matching_options_; - const std::shared_ptr database_; - const std::shared_ptr cache_; - FeatureMatcherController matcher_; -}; - -} // namespace - std::unique_ptr CreateTransitiveFeatureMatcher( const TransitiveMatchingOptions& options, const SiftMatchingOptions& matching_options, const TwoViewGeometryOptions& geometry_options, const std::string& database_path) { - return std::make_unique( + return std::make_unique>( options, matching_options, geometry_options, database_path); } diff --git a/src/colmap/controllers/feature_matching.h b/src/colmap/controllers/feature_matching.h index 22b6e4c31..e50495096 100644 --- a/src/colmap/controllers/feature_matching.h +++ b/src/colmap/controllers/feature_matching.h @@ -108,9 +108,9 @@ std::unique_ptr CreateSpatialFeatureMatcher( const std::string& database_path); // Match transitive image pairs in a database with existing feature matches. -// This matcher transitively closes loops. For example, if image pairs A-B and -// B-C match but A-C has not been matched, then this matcher attempts to match -// A-C. This procedure is performed for multiple iterations. +// This matcher transitively closes loops/triplets. For example, if image pairs +// A-B and B-C match but A-C has not been matched, then this matcher attempts to +// match A-C. This procedure is performed for multiple iterations. std::unique_ptr CreateTransitiveFeatureMatcher( const TransitiveMatchingOptions& options, const SiftMatchingOptions& matching_options, diff --git a/src/colmap/feature/matcher.cc b/src/colmap/feature/matcher.cc index 7328ec4cb..96ca0e89b 100644 --- a/src/colmap/feature/matcher.cc +++ b/src/colmap/feature/matcher.cc @@ -103,6 +103,12 @@ void FeatureMatcherCache::Setup() { }); } +void FeatureMatcherCache::AccessDatabase( + const std::function& func) { + std::lock_guard lock(database_mutex_); + func(*database_); +} + const Camera& FeatureMatcherCache::GetCamera(const camera_t camera_id) const { return cameras_cache_.at(camera_id); } diff --git a/src/colmap/feature/matcher.h b/src/colmap/feature/matcher.h index 8959f673d..9394455a0 100644 --- a/src/colmap/feature/matcher.h +++ b/src/colmap/feature/matcher.h @@ -78,6 +78,11 @@ class FeatureMatcherCache { void Setup(); + // Executes a function that accesses the database. This function is thread + // safe and ensures that only one function can access the database at a time. + void AccessDatabase( + const std::function& func); + const Camera& GetCamera(camera_t camera_id) const; const Image& GetImage(image_t image_id) const; const PosePrior& GetPosePrior(image_t image_id) const; diff --git a/src/colmap/feature/pairing.cc b/src/colmap/feature/pairing.cc index 7ecc55707..42a62bd62 100644 --- a/src/colmap/feature/pairing.cc +++ b/src/colmap/feature/pairing.cc @@ -643,6 +643,102 @@ SpatialPairGenerator::ReadPositionPriorData(const FeatureMatcherCache& cache) { return position_matrix; } +TransitivePairGenerator::TransitivePairGenerator( + const TransitiveMatchingOptions& options, + const std::shared_ptr& cache) + : options_(options), cache_(cache) { + THROW_CHECK(options.Check()); +} + +TransitivePairGenerator::TransitivePairGenerator( + const TransitiveMatchingOptions& options, + const std::shared_ptr& database) + : TransitivePairGenerator( + options, + std::make_shared(CacheSize(options), + THROW_CHECK_NOTNULL(database), + /*do_setup=*/true)) {} + +void TransitivePairGenerator::Reset() { + current_iteration_ = 0; + current_batch_idx_ = 0; + image_pairs_.clear(); + image_pair_ids_.clear(); +} + +bool TransitivePairGenerator::HasFinished() const { + return current_iteration_ >= options_.num_iterations && image_pairs_.empty(); +} + +std::vector> TransitivePairGenerator::Next() { + if (!image_pairs_.empty()) { + current_batch_idx_++; + std::vector> batch; + while (!image_pairs_.empty() && + static_cast(batch.size()) < options_.batch_size) { + batch.push_back(image_pairs_.back()); + image_pairs_.pop_back(); + } + LOG(INFO) << StringPrintf( + "Matching batch [%d/%d]", current_batch_idx_, current_num_batches_); + return batch; + } + + if (current_iteration_ >= options_.num_iterations) { + return {}; + } + + current_batch_idx_ = 0; + current_num_batches_ = 0; + current_iteration_++; + + LOG(INFO) << StringPrintf( + "Iteration [%d/%d]", current_iteration_, options_.num_iterations); + + std::vector> existing_image_pairs; + std::vector existing_num_inliers; + cache_->AccessDatabase( + [&existing_image_pairs, &existing_num_inliers](const Database& database) { + database.ReadTwoViewGeometryNumInliers(&existing_image_pairs, + &existing_num_inliers); + }); + + std::unordered_map> adjacency; + for (const auto& image_pair : existing_image_pairs) { + adjacency[image_pair.first].push_back(image_pair.second); + adjacency[image_pair.second].push_back(image_pair.first); + image_pair_ids_.insert( + Database::ImagePairToPairId(image_pair.first, image_pair.second)); + } + + for (const auto& image : adjacency) { + const auto image_id1 = image.first; + for (const auto& image_id2 : image.second) { + const auto it = adjacency.find(image_id2); + if (it == adjacency.end()) { + continue; + } + for (const auto& image_id3 : it->second) { + if (image_id1 == image_id3) { + continue; + } + const auto image_pair_id = + Database::ImagePairToPairId(image_id1, image_id3); + if (image_pair_ids_.count(image_pair_id) != 0) { + continue; + } + image_pairs_.emplace_back(image_id1, image_id3); + image_pair_ids_.insert(image_pair_id); + } + } + } + + current_num_batches_ = + std::ceil(static_cast(image_pairs_.size()) / options_.batch_size); + + return Next(); +} + ImportedPairGenerator::ImportedPairGenerator( const ImagePairsMatchingOptions& options, const std::shared_ptr& cache) diff --git a/src/colmap/feature/pairing.h b/src/colmap/feature/pairing.h index 8668471d5..98446be98 100644 --- a/src/colmap/feature/pairing.h +++ b/src/colmap/feature/pairing.h @@ -310,6 +310,35 @@ class SpatialPairGenerator : public PairGenerator { int knn_ = 0; }; +class TransitivePairGenerator : public PairGenerator { + public: + using PairOptions = TransitiveMatchingOptions; + static size_t CacheSize(const TransitiveMatchingOptions& options) { + return 2 * options.batch_size; + } + + TransitivePairGenerator(const TransitiveMatchingOptions& options, + const std::shared_ptr& cache); + + TransitivePairGenerator(const TransitiveMatchingOptions& options, + const std::shared_ptr& database); + + void Reset() override; + + bool HasFinished() const override; + + std::vector> Next() override; + + private: + const TransitiveMatchingOptions options_; + const std::shared_ptr cache_; + int current_iteration_ = 0; + int current_batch_idx_ = 0; + int current_num_batches_ = 0; + std::vector> image_pairs_; + std::unordered_set image_pair_ids_; +}; + class ImportedPairGenerator : public PairGenerator { public: using PairOptions = ImagePairsMatchingOptions; diff --git a/src/colmap/feature/pairing_test.cc b/src/colmap/feature/pairing_test.cc index 8881fd92e..3a4cb4eb0 100644 --- a/src/colmap/feature/pairing_test.cc +++ b/src/colmap/feature/pairing_test.cc @@ -271,6 +271,41 @@ TEST(SpatialPairGenerator, Nominal) { } } +TEST(TransitivePairGenerator, Nominal) { + constexpr int kNumImages = 5; + auto database = std::make_shared(Database::kInMemoryDatabasePath); + CreateSyntheticDatabase(kNumImages, *database); + const std::vector images = database->ReadAllImages(); + CHECK_EQ(images.size(), kNumImages); + + TwoViewGeometry two_view_geometry; + two_view_geometry.inlier_matches.resize(10); + + database->ClearTwoViewGeometries(); + database->WriteTwoViewGeometry( + images[0].ImageId(), images[1].ImageId(), two_view_geometry); + database->WriteTwoViewGeometry( + images[0].ImageId(), images[2].ImageId(), two_view_geometry); + database->WriteTwoViewGeometry( + images[1].ImageId(), images[3].ImageId(), two_view_geometry); + + TransitiveMatchingOptions options; + TransitivePairGenerator generator(options, database); + const auto pairs1 = generator.Next(); + EXPECT_THAT(pairs1, + testing::UnorderedElementsAre( + std::make_pair(images[2].ImageId(), images[1].ImageId()), + std::make_pair(images[3].ImageId(), images[0].ImageId()))); + for (const auto& pair : pairs1) { + database->WriteTwoViewGeometry(pair.first, pair.second, two_view_geometry); + } + EXPECT_THAT(generator.Next(), + testing::ElementsAre( + std::make_pair(images[3].ImageId(), images[2].ImageId()))); + EXPECT_TRUE(generator.Next().empty()); + EXPECT_TRUE(generator.HasFinished()); +} + TEST(ImportedPairGenerator, Nominal) { constexpr int kNumImages = 10; auto database = std::make_shared(Database::kInMemoryDatabasePath);