From 16f039a0f5312f98df3079ed6fe09f3e56b8f61e Mon Sep 17 00:00:00 2001 From: leavauchier <120112647+leavauchier@users.noreply.github.com> Date: Tue, 11 Jun 2024 11:26:45 +0200 Subject: [PATCH] Fix grid decimation filter on empty input (#8) --- .../GridDecimationFilter.cpp | 248 +++++++++--------- test/test_grid_decimation.py | 36 ++- 2 files changed, 148 insertions(+), 136 deletions(-) diff --git a/src/filter_grid_decimation/GridDecimationFilter.cpp b/src/filter_grid_decimation/GridDecimationFilter.cpp index 366caab..62cea95 100755 --- a/src/filter_grid_decimation/GridDecimationFilter.cpp +++ b/src/filter_grid_decimation/GridDecimationFilter.cpp @@ -1,23 +1,21 @@ /****************************************************************************** -* Copyright (c) 2023, Antoine Lavenant (antoine.lavenant@ign.fr) -* -* All rights reserved. -* -****************************************************************************/ + * Copyright (c) 2023, Antoine Lavenant (antoine.lavenant@ign.fr) + * + * All rights reserved. + * + ****************************************************************************/ #include "GridDecimationFilter.hpp" #include #include -#include #include +#include -namespace pdal -{ +namespace pdal { -static StaticPluginInfo const s_info -{ +static StaticPluginInfo const s_info{ "filters.grid_decimation_deprecated", // better to use the pdal gridDecimation plugIN "keep max or min points in a grid", "", @@ -27,153 +25,145 @@ CREATE_SHARED_STAGE(GridDecimationFilter, s_info) std::string GridDecimationFilter::getName() const { return s_info.name; } -GridDecimationFilter::GridDecimationFilter() : m_args(new GridDecimationFilter::GridArgs) -{} +GridDecimationFilter::GridDecimationFilter() : m_args(new GridDecimationFilter::GridArgs) {} +GridDecimationFilter::~GridDecimationFilter() {} -GridDecimationFilter::~GridDecimationFilter() -{} +void GridDecimationFilter::addArgs(ProgramArgs &args) { + args.add("resolution", "Cell edge size, in units of X/Y", m_args->m_edgeLength, 1.); + args.add("output_type", "Point kept into the cells ('min', 'max')", m_args->m_methodKeep, "max"); + args.add("output_dimension", "Name of the added dimension", m_args->m_nameOutDimension, "grid"); + args.add("output_wkt", "Export the grid as wkt", m_args->m_nameWktgrid, ""); +} +void GridDecimationFilter::initialize() {} -void GridDecimationFilter::addArgs(ProgramArgs& args) -{ - args.add("resolution", "Cell edge size, in units of X/Y",m_args->m_edgeLength, 1.); - args.add("output_type", "Point keept into the cells ('min', 'max')", m_args->m_methodKeep, "max" ); - args.add("output_dimension", "Name of the added dimension", m_args->m_nameOutDimension, "grid" ); - args.add("output_wkt", "Export the grid as wkt", m_args->m_nameWktgrid, "" ); +void GridDecimationFilter::prepared(PointTableRef table) { PointLayoutPtr layout(table.layout()); } -} +void GridDecimationFilter::ready(PointTableRef table) { + if (m_args->m_edgeLength <= 0) + throwError("resolution must be positive."); -void GridDecimationFilter::initialize() -{ -} + if (m_args->m_methodKeep != "max" && m_args->m_methodKeep != "min") + throwError("The output_type must be 'max' or 'min'."); + + if (m_args->m_nameOutDimension.empty()) + throwError("The output_dimension must be given."); -void GridDecimationFilter::prepared(PointTableRef table) -{ - PointLayoutPtr layout(table.layout()); + if (!m_args->m_nameWktgrid.empty()) + std::remove(m_args->m_nameWktgrid.c_str()); } -void GridDecimationFilter::ready(PointTableRef table) -{ - if (m_args->m_edgeLength <=0) - throwError("resolution must be positive."); - - if (m_args->m_methodKeep != "max" && m_args->m_methodKeep != "min") - throwError("The output_type must be 'max' or 'min'."); - - if (m_args->m_nameOutDimension.empty()) - throwError("The output_dimension must be given."); - - if (!m_args->m_nameWktgrid.empty()) - std::remove(m_args->m_nameWktgrid.c_str()); +void GridDecimationFilter::addDimensions(PointLayoutPtr layout) { + m_args->m_dim = + layout->registerOrAssignDim(m_args->m_nameOutDimension, Dimension::Type::Unsigned8); } -void GridDecimationFilter::addDimensions(PointLayoutPtr layout) -{ - m_args->m_dim = layout->registerOrAssignDim(m_args->m_nameOutDimension, Dimension::Type::Unsigned8); +void GridDecimationFilter::processOne(BOX2D bounds, PointRef &point, PointViewPtr view) { + // get the grid cell + double x = point.getFieldAs(Dimension::Id::X); + double y = point.getFieldAs(Dimension::Id::Y); + int id = point.getFieldAs(Dimension::Id::PointId); + + // if x==(xmax of the cell), we assume the point are in the upper cell + // if y==(ymax of the cell), we assume the point are in the right cell + int width = static_cast((x - bounds.minx) / m_args->m_edgeLength); + int height = static_cast((y - bounds.miny) / m_args->m_edgeLength); + + // to avoid numeric pb with the division (append if the point is on the grid) + if (x < bounds.minx + width * m_args->m_edgeLength) + width--; + if (y < bounds.miny + height * m_args->m_edgeLength) + height--; + if (x >= bounds.minx + (width + 1) * m_args->m_edgeLength) + width++; + if (y >= bounds.miny + (height + 1) * m_args->m_edgeLength) + height++; + + auto mptRefid = this->grid.find(std::make_pair(width, height)); + assert(mptRefid != this->grid.end()); + auto ptRefid = mptRefid->second; + + if (ptRefid == -1) { + this->grid[std::make_pair(width, height)] = point.pointId(); + return; + } + + PointRef ptRef = view->point(ptRefid); + + double z = point.getFieldAs(Dimension::Id::Z); + double zRef = ptRef.getFieldAs(Dimension::Id::Z); + + if (this->m_args->m_methodKeep == "max" && z > zRef) + this->grid[std::make_pair(width, height)] = point.pointId(); + if (this->m_args->m_methodKeep == "min" && z < zRef) + this->grid[std::make_pair(width, height)] = point.pointId(); } -void GridDecimationFilter::processOne(BOX2D bounds, PointRef& point, PointViewPtr view) -{ - //get the grid cell - double x = point.getFieldAs(Dimension::Id::X); - double y = point.getFieldAs(Dimension::Id::Y); - int id = point.getFieldAs(Dimension::Id::PointId); - - // if x==(xmax of the cell), we assume the point are in the upper cell - // if y==(ymax of the cell), we assume the point are in the right cell - int width = static_cast((x - bounds.minx) / m_args->m_edgeLength); - int height = static_cast((y - bounds.miny) / m_args->m_edgeLength); - - // to avoid numeric pb with the division (append if the point is on the grid) - if (x < bounds.minx+width*m_args->m_edgeLength) width--; - if (y < bounds.miny+height*m_args->m_edgeLength) height--; - if (x >= bounds.minx+(width+1)*m_args->m_edgeLength) width++; - if (y >= bounds.miny+(height+1)*m_args->m_edgeLength) height++; - - auto mptRefid = this->grid.find( std::make_pair(width,height) ); - assert( mptRefid != this->grid.end() ); - auto ptRefid = mptRefid->second; - - if (ptRefid==-1) - { - this->grid[ std::make_pair(width,height) ] = point.pointId(); - return; - } - - PointRef ptRef = view->point(ptRefid); +void GridDecimationFilter::createGrid(BOX2D bounds) { - double z = point.getFieldAs(Dimension::Id::Z); - double zRef = ptRef.getFieldAs(Dimension::Id::Z); + size_t d_width = std::floor((bounds.maxx - bounds.minx) / m_args->m_edgeLength) + 1; + size_t d_height = std::floor((bounds.maxy - bounds.miny) / m_args->m_edgeLength) + 1; - if (this->m_args->m_methodKeep == "max" && z>zRef) - this->grid[ std::make_pair(width,height) ] = point.pointId(); - if (this->m_args->m_methodKeep == "min" && zgrid[ std::make_pair(width,height) ] = point.pointId(); -} + if (d_width < 0.0 || d_width > (std::numeric_limits::max)()) + throwError("Grid width out of range."); + if (d_height < 0.0 || d_height > (std::numeric_limits::max)()) + throwError("Grid height out of range."); -void GridDecimationFilter::createGrid(BOX2D bounds) -{ - size_t d_width = std::floor((bounds.maxx - bounds.minx) / m_args->m_edgeLength) + 1; - size_t d_height = std::floor((bounds.maxy - bounds.miny) / m_args->m_edgeLength) + 1; - - if (d_width < 0.0 || d_width > (std::numeric_limits::max)()) - throwError("Grid width out of range."); - if (d_height < 0.0 || d_height > (std::numeric_limits::max)()) - throwError("Grid height out of range."); - - int width = static_cast(d_width); - int height = static_cast(d_height); - - std::vector vgrid; - - for (size_t l(0); lm_edgeLength, bounds.miny + l*m_args->m_edgeLength, - bounds.minx + (c+1)*m_args->m_edgeLength, bounds.miny + (l+1)*m_args->m_edgeLength ); - vgrid.push_back(Polygon(bounds_dalle)); - this->grid.insert( std::make_pair( std::make_pair(c,l), -1) ); - } + int width = static_cast(d_width); + int height = static_cast(d_height); - if (!m_args->m_nameWktgrid.empty()) - { - std::ofstream oss (m_args->m_nameWktgrid); - for (auto pol : vgrid) - oss << pol.wkt() << std::endl; + std::vector vgrid; + + for (size_t l(0); l < height; l++) + for (size_t c(0); c < width; c++) { + BOX2D bounds_dalle(bounds.minx + c * m_args->m_edgeLength, + bounds.miny + l * m_args->m_edgeLength, + bounds.minx + (c + 1) * m_args->m_edgeLength, + bounds.miny + (l + 1) * m_args->m_edgeLength); + vgrid.push_back(Polygon(bounds_dalle)); + this->grid.insert(std::make_pair(std::make_pair(c, l), -1)); } - + + if (!m_args->m_nameWktgrid.empty()) { + std::ofstream oss(m_args->m_nameWktgrid); + for (auto pol : vgrid) + oss << pol.wkt() << std::endl; + } } -PointViewSet GridDecimationFilter::run(PointViewPtr view) -{ +PointViewSet GridDecimationFilter::run(PointViewPtr view) { + if (view->empty()) { + if (!m_args->m_nameWktgrid.empty()) + std::ofstream{m_args->m_nameWktgrid}; + + } else { BOX2D bounds; view->calculateBounds(bounds); createGrid(bounds); - for (PointId i = 0; i < view->size(); ++i) - { - PointRef point = view->point(i); - processOne(bounds,point,view); + for (PointId i = 0; i < view->size(); ++i) { + PointRef point = view->point(i); + processOne(bounds, point, view); } - + std::set keepPoint; for (auto it : this->grid) - if (it.second != -1) - keepPoint.insert(it.second); - - for (PointId i = 0; i < view->size(); ++i) - { - PointRef point = view->point(i); - if (keepPoint.find(view->point(i).pointId()) != keepPoint.end()) - point.setField(m_args->m_dim, int64_t(1)); - else - point.setField(m_args->m_dim, int64_t(0)); + if (it.second != -1) + keepPoint.insert(it.second); + + for (PointId i = 0; i < view->size(); ++i) { + PointRef point = view->point(i); + if (keepPoint.find(view->point(i).pointId()) != keepPoint.end()) + point.setField(m_args->m_dim, int64_t(1)); + else + point.setField(m_args->m_dim, int64_t(0)); } - - PointViewSet viewSet; - viewSet.insert(view); - return viewSet; + } + + PointViewSet viewSet; + viewSet.insert(view); + return viewSet; } } // namespace pdal diff --git a/test/test_grid_decimation.py b/test/test_grid_decimation.py index 8cf4563..080a1a8 100755 --- a/test/test_grid_decimation.py +++ b/test/test_grid_decimation.py @@ -4,6 +4,7 @@ import tempfile from test import utils +import numpy as np import pdal import pdaltools.las_info as li import pytest @@ -14,14 +15,14 @@ def contains(bounds, x, y): return bounds[0] <= x and x < bounds[1] and bounds[2] <= y and y < bounds[3] -def run_filter(type, resolution): +def run_filter(output_type, resolution): ini_las = "test/data/4_6.las" tmp_out_wkt = tempfile.NamedTemporaryFile(suffix=f"_{resolution}.wkt").name - filter = "filters.grid_decimation_deprecated" - utils.pdal_has_plugin(filter) + filter_name = "filters.grid_decimation_deprecated" + utils.pdal_has_plugin(filter_name) bounds = li.las_get_xy_bounds(ini_las) @@ -32,9 +33,9 @@ def run_filter(type, resolution): PIPELINE = [ {"type": "readers.las", "filename": ini_las}, { - "type": filter, + "type": filter_name, "resolution": resolution, - "output_type": type, + "output_type": output_type, "output_dimension": "grid", "output_wkt": tmp_out_wkt, }, @@ -75,10 +76,10 @@ def run_filter(type, resolution): continue z = pt["Z"] - if type == "max": + if output_type == "max": if ZRef == 0 or z > ZRef: ZRef = z - elif type == "min": + elif output_type == "min": if ZRef == 0 or z < ZRef: ZRef = z @@ -112,3 +113,24 @@ def test_grid_decimation_max(resolution): ) def test_grid_decimation_min(resolution): run_filter("min", resolution) + + +def test_grid_decimation_empty(): + ini_las = "test/data/4_6.las" + with tempfile.NamedTemporaryFile(suffix="_empty.wkt") as tmp_out_wkt: + pipeline = pdal.Pipeline() | pdal.Reader.las(filename=ini_las) + pipeline |= pdal.Filter.grid_decimation_deprecated( + resolution=10, + output_type="min", + output_dimension="grid", + output_wkt=tmp_out_wkt.name, + where="Classification==123", # should create an empty result + ) + pipeline.execute() + + with open(tmp_out_wkt.name, "r") as f: + reader = csv.reader(f, delimiter="\t") + lines = [line for line in reader] + assert len(lines) == 0 + + assert np.all(pipeline.arrays[0]["grid"] == 0)