Skip to content

Commit

Permalink
Implement transitive matcher with pair generator + tests (colmap#2735)
Browse files Browse the repository at this point in the history
* Implement transitive matcher with pair generator + tests

* d
  • Loading branch information
ahojnnes authored Aug 26, 2024
1 parent 5b4941d commit 4ccfeb5
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 121 deletions.
120 changes: 2 additions & 118 deletions src/colmap/controllers/feature_matching.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ namespace colmap {
namespace {

void PrintElapsedTime(const Timer& timer) {
LOG(INFO) << StringPrintf(" in %.3fs", timer.ElapsedSeconds());
LOG(INFO) << StringPrintf("in %.3fs", timer.ElapsedSeconds());
}

template <typename DerivedPairGenerator>
Expand Down Expand Up @@ -137,128 +137,12 @@ std::unique_ptr<Thread> CreateSpatialFeatureMatcher(
options, matching_options, geometry_options, database_path);
}

namespace {

class TransitiveFeatureMatcher : public Thread {
public:
TransitiveFeatureMatcher(const TransitiveMatchingOptions& options,
const SiftMatchingOptions& matching_options,
const TwoViewGeometryOptions& geometry_options,
const std::string& database_path)
: options_(options),
matching_options_(matching_options),
database_(std::make_shared<Database>(database_path)),
cache_(std::make_shared<FeatureMatcherCache>(options_.batch_size,
database_)),
matcher_(
matching_options, geometry_options, database_.get(), cache_.get()) {
THROW_CHECK(options.Check());
THROW_CHECK(matching_options.Check());
THROW_CHECK(geometry_options.Check());
}

private:
void Run() override {
PrintHeading1("Transitive feature matching");
Timer run_timer;
run_timer.Start();

if (!matcher_.Setup()) {
return;
}

cache_->Setup();

const std::vector<image_t> image_ids = cache_->GetImageIds();

std::vector<std::pair<image_t, image_t>> image_pairs;
std::unordered_set<image_pair_t> image_pair_ids;

for (int iteration = 0; iteration < options_.num_iterations; ++iteration) {
if (IsStopped()) {
run_timer.PrintMinutes();
return;
}

Timer timer;
timer.Start();

LOG(INFO) << StringPrintf(
"Iteration [%d/%d]", iteration + 1, options_.num_iterations);

std::vector<std::pair<image_t, image_t>> existing_image_pairs;
std::vector<int> existing_num_inliers;
database_->ReadTwoViewGeometryNumInliers(&existing_image_pairs,
&existing_num_inliers);

THROW_CHECK_EQ(existing_image_pairs.size(), existing_num_inliers.size());

std::unordered_map<image_t, std::vector<image_t>> adjacency;
for (const auto& image_pair : existing_image_pairs) {
adjacency[image_pair.first].push_back(image_pair.second);
adjacency[image_pair.second].push_back(image_pair.first);
}

const size_t batch_size = static_cast<size_t>(options_.batch_size);

size_t num_batches = 0;
image_pairs.clear();
image_pair_ids.clear();
for (const auto& image : adjacency) {
const auto image_id1 = image.first;
for (const auto& image_id2 : image.second) {
if (adjacency.count(image_id2) > 0) {
for (const auto& image_id3 : adjacency.at(image_id2)) {
const auto image_pair_id =
Database::ImagePairToPairId(image_id1, image_id3);
if (image_pair_ids.count(image_pair_id) == 0) {
image_pairs.emplace_back(image_id1, image_id3);
image_pair_ids.insert(image_pair_id);
if (image_pairs.size() >= batch_size) {
num_batches += 1;
LOG(INFO) << StringPrintf(" Batch %d", num_batches);
DatabaseTransaction database_transaction(database_.get());
matcher_.Match(image_pairs);
image_pairs.clear();
PrintElapsedTime(timer);
timer.Restart();

if (IsStopped()) {
run_timer.PrintMinutes();
return;
}
}
}
}
}
}
}

num_batches += 1;
LOG(INFO) << StringPrintf(" Batch %d", num_batches);
DatabaseTransaction database_transaction(database_.get());
matcher_.Match(image_pairs);
PrintElapsedTime(timer);
}

run_timer.PrintMinutes();
}

const TransitiveMatchingOptions options_;
const SiftMatchingOptions matching_options_;
const std::shared_ptr<Database> database_;
const std::shared_ptr<FeatureMatcherCache> cache_;
FeatureMatcherController matcher_;
};

} // namespace

std::unique_ptr<Thread> CreateTransitiveFeatureMatcher(
const TransitiveMatchingOptions& options,
const SiftMatchingOptions& matching_options,
const TwoViewGeometryOptions& geometry_options,
const std::string& database_path) {
return std::make_unique<TransitiveFeatureMatcher>(
return std::make_unique<GenericFeatureMatcher<TransitivePairGenerator>>(
options, matching_options, geometry_options, database_path);
}

Expand Down
6 changes: 3 additions & 3 deletions src/colmap/controllers/feature_matching.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,9 @@ std::unique_ptr<Thread> CreateSpatialFeatureMatcher(
const std::string& database_path);

// Match transitive image pairs in a database with existing feature matches.
// This matcher transitively closes loops. For example, if image pairs A-B and
// B-C match but A-C has not been matched, then this matcher attempts to match
// A-C. This procedure is performed for multiple iterations.
// This matcher transitively closes loops/triplets. For example, if image pairs
// A-B and B-C match but A-C has not been matched, then this matcher attempts to
// match A-C. This procedure is performed for multiple iterations.
std::unique_ptr<Thread> CreateTransitiveFeatureMatcher(
const TransitiveMatchingOptions& options,
const SiftMatchingOptions& matching_options,
Expand Down
6 changes: 6 additions & 0 deletions src/colmap/feature/matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ void FeatureMatcherCache::Setup() {
});
}

void FeatureMatcherCache::AccessDatabase(
const std::function<void(const Database& database)>& func) {
std::lock_guard<std::mutex> lock(database_mutex_);
func(*database_);
}

const Camera& FeatureMatcherCache::GetCamera(const camera_t camera_id) const {
return cameras_cache_.at(camera_id);
}
Expand Down
5 changes: 5 additions & 0 deletions src/colmap/feature/matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ class FeatureMatcherCache {

void Setup();

// Executes a function that accesses the database. This function is thread
// safe and ensures that only one function can access the database at a time.
void AccessDatabase(
const std::function<void(const Database& database)>& func);

const Camera& GetCamera(camera_t camera_id) const;
const Image& GetImage(image_t image_id) const;
const PosePrior& GetPosePrior(image_t image_id) const;
Expand Down
96 changes: 96 additions & 0 deletions src/colmap/feature/pairing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,102 @@ SpatialPairGenerator::ReadPositionPriorData(const FeatureMatcherCache& cache) {
return position_matrix;
}

TransitivePairGenerator::TransitivePairGenerator(
const TransitiveMatchingOptions& options,
const std::shared_ptr<FeatureMatcherCache>& cache)
: options_(options), cache_(cache) {
THROW_CHECK(options.Check());
}

TransitivePairGenerator::TransitivePairGenerator(
const TransitiveMatchingOptions& options,
const std::shared_ptr<Database>& database)
: TransitivePairGenerator(
options,
std::make_shared<FeatureMatcherCache>(CacheSize(options),
THROW_CHECK_NOTNULL(database),
/*do_setup=*/true)) {}

void TransitivePairGenerator::Reset() {
current_iteration_ = 0;
current_batch_idx_ = 0;
image_pairs_.clear();
image_pair_ids_.clear();
}

bool TransitivePairGenerator::HasFinished() const {
return current_iteration_ >= options_.num_iterations && image_pairs_.empty();
}

std::vector<std::pair<image_t, image_t>> TransitivePairGenerator::Next() {
if (!image_pairs_.empty()) {
current_batch_idx_++;
std::vector<std::pair<image_t, image_t>> batch;
while (!image_pairs_.empty() &&
static_cast<int>(batch.size()) < options_.batch_size) {
batch.push_back(image_pairs_.back());
image_pairs_.pop_back();
}
LOG(INFO) << StringPrintf(
"Matching batch [%d/%d]", current_batch_idx_, current_num_batches_);
return batch;
}

if (current_iteration_ >= options_.num_iterations) {
return {};
}

current_batch_idx_ = 0;
current_num_batches_ = 0;
current_iteration_++;

LOG(INFO) << StringPrintf(
"Iteration [%d/%d]", current_iteration_, options_.num_iterations);

std::vector<std::pair<image_t, image_t>> existing_image_pairs;
std::vector<int> existing_num_inliers;
cache_->AccessDatabase(
[&existing_image_pairs, &existing_num_inliers](const Database& database) {
database.ReadTwoViewGeometryNumInliers(&existing_image_pairs,
&existing_num_inliers);
});

std::unordered_map<image_t, std::vector<image_t>> adjacency;
for (const auto& image_pair : existing_image_pairs) {
adjacency[image_pair.first].push_back(image_pair.second);
adjacency[image_pair.second].push_back(image_pair.first);
image_pair_ids_.insert(
Database::ImagePairToPairId(image_pair.first, image_pair.second));
}

for (const auto& image : adjacency) {
const auto image_id1 = image.first;
for (const auto& image_id2 : image.second) {
const auto it = adjacency.find(image_id2);
if (it == adjacency.end()) {
continue;
}
for (const auto& image_id3 : it->second) {
if (image_id1 == image_id3) {
continue;
}
const auto image_pair_id =
Database::ImagePairToPairId(image_id1, image_id3);
if (image_pair_ids_.count(image_pair_id) != 0) {
continue;
}
image_pairs_.emplace_back(image_id1, image_id3);
image_pair_ids_.insert(image_pair_id);
}
}
}

current_num_batches_ =
std::ceil(static_cast<double>(image_pairs_.size()) / options_.batch_size);

return Next();
}

ImportedPairGenerator::ImportedPairGenerator(
const ImagePairsMatchingOptions& options,
const std::shared_ptr<FeatureMatcherCache>& cache)
Expand Down
29 changes: 29 additions & 0 deletions src/colmap/feature/pairing.h
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,35 @@ class SpatialPairGenerator : public PairGenerator {
int knn_ = 0;
};

class TransitivePairGenerator : public PairGenerator {
public:
using PairOptions = TransitiveMatchingOptions;
static size_t CacheSize(const TransitiveMatchingOptions& options) {
return 2 * options.batch_size;
}

TransitivePairGenerator(const TransitiveMatchingOptions& options,
const std::shared_ptr<FeatureMatcherCache>& cache);

TransitivePairGenerator(const TransitiveMatchingOptions& options,
const std::shared_ptr<Database>& database);

void Reset() override;

bool HasFinished() const override;

std::vector<std::pair<image_t, image_t>> Next() override;

private:
const TransitiveMatchingOptions options_;
const std::shared_ptr<FeatureMatcherCache> cache_;
int current_iteration_ = 0;
int current_batch_idx_ = 0;
int current_num_batches_ = 0;
std::vector<std::pair<image_t, image_t>> image_pairs_;
std::unordered_set<image_pair_t> image_pair_ids_;
};

class ImportedPairGenerator : public PairGenerator {
public:
using PairOptions = ImagePairsMatchingOptions;
Expand Down
35 changes: 35 additions & 0 deletions src/colmap/feature/pairing_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,41 @@ TEST(SpatialPairGenerator, Nominal) {
}
}

TEST(TransitivePairGenerator, Nominal) {
constexpr int kNumImages = 5;
auto database = std::make_shared<Database>(Database::kInMemoryDatabasePath);
CreateSyntheticDatabase(kNumImages, *database);
const std::vector<Image> images = database->ReadAllImages();
CHECK_EQ(images.size(), kNumImages);

TwoViewGeometry two_view_geometry;
two_view_geometry.inlier_matches.resize(10);

database->ClearTwoViewGeometries();
database->WriteTwoViewGeometry(
images[0].ImageId(), images[1].ImageId(), two_view_geometry);
database->WriteTwoViewGeometry(
images[0].ImageId(), images[2].ImageId(), two_view_geometry);
database->WriteTwoViewGeometry(
images[1].ImageId(), images[3].ImageId(), two_view_geometry);

TransitiveMatchingOptions options;
TransitivePairGenerator generator(options, database);
const auto pairs1 = generator.Next();
EXPECT_THAT(pairs1,
testing::UnorderedElementsAre(
std::make_pair(images[2].ImageId(), images[1].ImageId()),
std::make_pair(images[3].ImageId(), images[0].ImageId())));
for (const auto& pair : pairs1) {
database->WriteTwoViewGeometry(pair.first, pair.second, two_view_geometry);
}
EXPECT_THAT(generator.Next(),
testing::ElementsAre(
std::make_pair(images[3].ImageId(), images[2].ImageId())));
EXPECT_TRUE(generator.Next().empty());
EXPECT_TRUE(generator.HasFinished());
}

TEST(ImportedPairGenerator, Nominal) {
constexpr int kNumImages = 10;
auto database = std::make_shared<Database>(Database::kInMemoryDatabasePath);
Expand Down

0 comments on commit 4ccfeb5

Please sign in to comment.