Skip to content

Commit

Permalink
Fix grid decimation filter on empty input (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
leavauchier authored Jun 11, 2024
1 parent d5ef3bd commit 16f039a
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 136 deletions.
248 changes: 119 additions & 129 deletions src/filter_grid_decimation/GridDecimationFilter.cpp
Original file line number Diff line number Diff line change
@@ -1,23 +1,21 @@
/******************************************************************************
* Copyright (c) 2023, Antoine Lavenant ([email protected])
*
* All rights reserved.
*
****************************************************************************/
* Copyright (c) 2023, Antoine Lavenant ([email protected])
*
* All rights reserved.
*
****************************************************************************/

#include "GridDecimationFilter.hpp"

#include <pdal/PointView.hpp>
#include <pdal/StageFactory.hpp>

#include <sstream>
#include <cstdarg>
#include <sstream>

namespace pdal
{
namespace pdal {

static StaticPluginInfo const s_info
{
static StaticPluginInfo const s_info{
"filters.grid_decimation_deprecated", // better to use the pdal gridDecimation plugIN
"keep max or min points in a grid",
"",
Expand All @@ -27,153 +25,145 @@ CREATE_SHARED_STAGE(GridDecimationFilter, s_info)

std::string GridDecimationFilter::getName() const { return s_info.name; }

GridDecimationFilter::GridDecimationFilter() : m_args(new GridDecimationFilter::GridArgs)
{}
GridDecimationFilter::GridDecimationFilter() : m_args(new GridDecimationFilter::GridArgs) {}

GridDecimationFilter::~GridDecimationFilter() {}

GridDecimationFilter::~GridDecimationFilter()
{}
void GridDecimationFilter::addArgs(ProgramArgs &args) {
args.add("resolution", "Cell edge size, in units of X/Y", m_args->m_edgeLength, 1.);
args.add("output_type", "Point kept into the cells ('min', 'max')", m_args->m_methodKeep, "max");
args.add("output_dimension", "Name of the added dimension", m_args->m_nameOutDimension, "grid");
args.add("output_wkt", "Export the grid as wkt", m_args->m_nameWktgrid, "");
}

void GridDecimationFilter::initialize() {}

void GridDecimationFilter::addArgs(ProgramArgs& args)
{
args.add("resolution", "Cell edge size, in units of X/Y",m_args->m_edgeLength, 1.);
args.add("output_type", "Point keept into the cells ('min', 'max')", m_args->m_methodKeep, "max" );
args.add("output_dimension", "Name of the added dimension", m_args->m_nameOutDimension, "grid" );
args.add("output_wkt", "Export the grid as wkt", m_args->m_nameWktgrid, "" );
void GridDecimationFilter::prepared(PointTableRef table) { PointLayoutPtr layout(table.layout()); }

}
void GridDecimationFilter::ready(PointTableRef table) {
if (m_args->m_edgeLength <= 0)
throwError("resolution must be positive.");

void GridDecimationFilter::initialize()
{
}
if (m_args->m_methodKeep != "max" && m_args->m_methodKeep != "min")
throwError("The output_type must be 'max' or 'min'.");

if (m_args->m_nameOutDimension.empty())
throwError("The output_dimension must be given.");

void GridDecimationFilter::prepared(PointTableRef table)
{
PointLayoutPtr layout(table.layout());
if (!m_args->m_nameWktgrid.empty())
std::remove(m_args->m_nameWktgrid.c_str());
}

void GridDecimationFilter::ready(PointTableRef table)
{
if (m_args->m_edgeLength <=0)
throwError("resolution must be positive.");

if (m_args->m_methodKeep != "max" && m_args->m_methodKeep != "min")
throwError("The output_type must be 'max' or 'min'.");

if (m_args->m_nameOutDimension.empty())
throwError("The output_dimension must be given.");

if (!m_args->m_nameWktgrid.empty())
std::remove(m_args->m_nameWktgrid.c_str());
void GridDecimationFilter::addDimensions(PointLayoutPtr layout) {
m_args->m_dim =
layout->registerOrAssignDim(m_args->m_nameOutDimension, Dimension::Type::Unsigned8);
}

void GridDecimationFilter::addDimensions(PointLayoutPtr layout)
{
m_args->m_dim = layout->registerOrAssignDim(m_args->m_nameOutDimension, Dimension::Type::Unsigned8);
void GridDecimationFilter::processOne(BOX2D bounds, PointRef &point, PointViewPtr view) {
// get the grid cell
double x = point.getFieldAs<double>(Dimension::Id::X);
double y = point.getFieldAs<double>(Dimension::Id::Y);
int id = point.getFieldAs<double>(Dimension::Id::PointId);

// if x==(xmax of the cell), we assume the point are in the upper cell
// if y==(ymax of the cell), we assume the point are in the right cell
int width = static_cast<int>((x - bounds.minx) / m_args->m_edgeLength);
int height = static_cast<int>((y - bounds.miny) / m_args->m_edgeLength);

// to avoid numeric pb with the division (append if the point is on the grid)
if (x < bounds.minx + width * m_args->m_edgeLength)
width--;
if (y < bounds.miny + height * m_args->m_edgeLength)
height--;
if (x >= bounds.minx + (width + 1) * m_args->m_edgeLength)
width++;
if (y >= bounds.miny + (height + 1) * m_args->m_edgeLength)
height++;

auto mptRefid = this->grid.find(std::make_pair(width, height));
assert(mptRefid != this->grid.end());
auto ptRefid = mptRefid->second;

if (ptRefid == -1) {
this->grid[std::make_pair(width, height)] = point.pointId();
return;
}

PointRef ptRef = view->point(ptRefid);

double z = point.getFieldAs<double>(Dimension::Id::Z);
double zRef = ptRef.getFieldAs<double>(Dimension::Id::Z);

if (this->m_args->m_methodKeep == "max" && z > zRef)
this->grid[std::make_pair(width, height)] = point.pointId();
if (this->m_args->m_methodKeep == "min" && z < zRef)
this->grid[std::make_pair(width, height)] = point.pointId();
}

void GridDecimationFilter::processOne(BOX2D bounds, PointRef& point, PointViewPtr view)
{
//get the grid cell
double x = point.getFieldAs<double>(Dimension::Id::X);
double y = point.getFieldAs<double>(Dimension::Id::Y);
int id = point.getFieldAs<double>(Dimension::Id::PointId);

// if x==(xmax of the cell), we assume the point are in the upper cell
// if y==(ymax of the cell), we assume the point are in the right cell
int width = static_cast<int>((x - bounds.minx) / m_args->m_edgeLength);
int height = static_cast<int>((y - bounds.miny) / m_args->m_edgeLength);

// to avoid numeric pb with the division (append if the point is on the grid)
if (x < bounds.minx+width*m_args->m_edgeLength) width--;
if (y < bounds.miny+height*m_args->m_edgeLength) height--;
if (x >= bounds.minx+(width+1)*m_args->m_edgeLength) width++;
if (y >= bounds.miny+(height+1)*m_args->m_edgeLength) height++;

auto mptRefid = this->grid.find( std::make_pair(width,height) );
assert( mptRefid != this->grid.end() );
auto ptRefid = mptRefid->second;

if (ptRefid==-1)
{
this->grid[ std::make_pair(width,height) ] = point.pointId();
return;
}

PointRef ptRef = view->point(ptRefid);
void GridDecimationFilter::createGrid(BOX2D bounds) {

double z = point.getFieldAs<double>(Dimension::Id::Z);
double zRef = ptRef.getFieldAs<double>(Dimension::Id::Z);
size_t d_width = std::floor((bounds.maxx - bounds.minx) / m_args->m_edgeLength) + 1;
size_t d_height = std::floor((bounds.maxy - bounds.miny) / m_args->m_edgeLength) + 1;

if (this->m_args->m_methodKeep == "max" && z>zRef)
this->grid[ std::make_pair(width,height) ] = point.pointId();
if (this->m_args->m_methodKeep == "min" && z<zRef)
this->grid[ std::make_pair(width,height) ] = point.pointId();
}
if (d_width < 0.0 || d_width > (std::numeric_limits<int>::max)())
throwError("Grid width out of range.");
if (d_height < 0.0 || d_height > (std::numeric_limits<int>::max)())
throwError("Grid height out of range.");

void GridDecimationFilter::createGrid(BOX2D bounds)
{
size_t d_width = std::floor((bounds.maxx - bounds.minx) / m_args->m_edgeLength) + 1;
size_t d_height = std::floor((bounds.maxy - bounds.miny) / m_args->m_edgeLength) + 1;

if (d_width < 0.0 || d_width > (std::numeric_limits<int>::max)())
throwError("Grid width out of range.");
if (d_height < 0.0 || d_height > (std::numeric_limits<int>::max)())
throwError("Grid height out of range.");

int width = static_cast<int>(d_width);
int height = static_cast<int>(d_height);

std::vector<Polygon> vgrid;

for (size_t l(0); l<height; l++)
for (size_t c(0); c<width; c++)
{
BOX2D bounds_dalle (bounds.minx + c*m_args->m_edgeLength, bounds.miny + l*m_args->m_edgeLength,
bounds.minx + (c+1)*m_args->m_edgeLength, bounds.miny + (l+1)*m_args->m_edgeLength );
vgrid.push_back(Polygon(bounds_dalle));
this->grid.insert( std::make_pair( std::make_pair(c,l), -1) );
}
int width = static_cast<int>(d_width);
int height = static_cast<int>(d_height);

if (!m_args->m_nameWktgrid.empty())
{
std::ofstream oss (m_args->m_nameWktgrid);
for (auto pol : vgrid)
oss << pol.wkt() << std::endl;
std::vector<Polygon> vgrid;

for (size_t l(0); l < height; l++)
for (size_t c(0); c < width; c++) {
BOX2D bounds_dalle(bounds.minx + c * m_args->m_edgeLength,
bounds.miny + l * m_args->m_edgeLength,
bounds.minx + (c + 1) * m_args->m_edgeLength,
bounds.miny + (l + 1) * m_args->m_edgeLength);
vgrid.push_back(Polygon(bounds_dalle));
this->grid.insert(std::make_pair(std::make_pair(c, l), -1));
}


if (!m_args->m_nameWktgrid.empty()) {
std::ofstream oss(m_args->m_nameWktgrid);
for (auto pol : vgrid)
oss << pol.wkt() << std::endl;
}
}

PointViewSet GridDecimationFilter::run(PointViewPtr view)
{
PointViewSet GridDecimationFilter::run(PointViewPtr view) {
if (view->empty()) {
if (!m_args->m_nameWktgrid.empty())
std::ofstream{m_args->m_nameWktgrid};

} else {
BOX2D bounds;
view->calculateBounds(bounds);
createGrid(bounds);

for (PointId i = 0; i < view->size(); ++i)
{
PointRef point = view->point(i);
processOne(bounds,point,view);
for (PointId i = 0; i < view->size(); ++i) {
PointRef point = view->point(i);
processOne(bounds, point, view);
}

std::set<PointId> keepPoint;
for (auto it : this->grid)
if (it.second != -1)
keepPoint.insert(it.second);

for (PointId i = 0; i < view->size(); ++i)
{
PointRef point = view->point(i);
if (keepPoint.find(view->point(i).pointId()) != keepPoint.end())
point.setField(m_args->m_dim, int64_t(1));
else
point.setField(m_args->m_dim, int64_t(0));
if (it.second != -1)
keepPoint.insert(it.second);

for (PointId i = 0; i < view->size(); ++i) {
PointRef point = view->point(i);
if (keepPoint.find(view->point(i).pointId()) != keepPoint.end())
point.setField(m_args->m_dim, int64_t(1));
else
point.setField(m_args->m_dim, int64_t(0));
}

PointViewSet viewSet;
viewSet.insert(view);
return viewSet;
}

PointViewSet viewSet;
viewSet.insert(view);
return viewSet;
}

} // namespace pdal
36 changes: 29 additions & 7 deletions test/test_grid_decimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import tempfile
from test import utils

import numpy as np
import pdal
import pdaltools.las_info as li
import pytest
Expand All @@ -14,14 +15,14 @@ def contains(bounds, x, y):
return bounds[0] <= x and x < bounds[1] and bounds[2] <= y and y < bounds[3]


def run_filter(type, resolution):
def run_filter(output_type, resolution):

ini_las = "test/data/4_6.las"

tmp_out_wkt = tempfile.NamedTemporaryFile(suffix=f"_{resolution}.wkt").name

filter = "filters.grid_decimation_deprecated"
utils.pdal_has_plugin(filter)
filter_name = "filters.grid_decimation_deprecated"
utils.pdal_has_plugin(filter_name)

bounds = li.las_get_xy_bounds(ini_las)

Expand All @@ -32,9 +33,9 @@ def run_filter(type, resolution):
PIPELINE = [
{"type": "readers.las", "filename": ini_las},
{
"type": filter,
"type": filter_name,
"resolution": resolution,
"output_type": type,
"output_type": output_type,
"output_dimension": "grid",
"output_wkt": tmp_out_wkt,
},
Expand Down Expand Up @@ -75,10 +76,10 @@ def run_filter(type, resolution):
continue

z = pt["Z"]
if type == "max":
if output_type == "max":
if ZRef == 0 or z > ZRef:
ZRef = z
elif type == "min":
elif output_type == "min":
if ZRef == 0 or z < ZRef:
ZRef = z

Expand Down Expand Up @@ -112,3 +113,24 @@ def test_grid_decimation_max(resolution):
)
def test_grid_decimation_min(resolution):
run_filter("min", resolution)


def test_grid_decimation_empty():
ini_las = "test/data/4_6.las"
with tempfile.NamedTemporaryFile(suffix="_empty.wkt") as tmp_out_wkt:
pipeline = pdal.Pipeline() | pdal.Reader.las(filename=ini_las)
pipeline |= pdal.Filter.grid_decimation_deprecated(
resolution=10,
output_type="min",
output_dimension="grid",
output_wkt=tmp_out_wkt.name,
where="Classification==123", # should create an empty result
)
pipeline.execute()

with open(tmp_out_wkt.name, "r") as f:
reader = csv.reader(f, delimiter="\t")
lines = [line for line in reader]
assert len(lines) == 0

assert np.all(pipeline.arrays[0]["grid"] == 0)

0 comments on commit 16f039a

Please sign in to comment.