Skip to content

Commit

Permalink
feat: add search endpoint - expand query parameter parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
rkettelerij committed Dec 4, 2024
1 parent 85e6f0d commit 927cd70
Show file tree
Hide file tree
Showing 5 changed files with 264 additions and 67 deletions.
5 changes: 4 additions & 1 deletion internal/search/datasources/datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@ package datasources

import (
"context"

"github.com/PDOK/gomagpie/internal/search/domain"
)

// Datasource knows how make different kinds of queries/actions on the underlying actual datastore.
// This abstraction allows the rest of the system to stay datastore agnostic.
type Datasource interface {
Suggest(ctx context.Context, suggestForThis string) ([]string, error)
Suggest(ctx context.Context, searchTerm string, collections map[string]map[string]string,
srid domain.SRID, limit int) ([]string, error)

// Close closes (connections to) the datasource gracefully
Close()
Expand Down
19 changes: 9 additions & 10 deletions internal/search/datasources/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"

"github.com/PDOK/gomagpie/internal/search/domain"
"github.com/jackc/pgx/v5"
pgxgeom "github.com/twpayne/pgx-geom"

Expand Down Expand Up @@ -36,17 +37,17 @@ func (p *Postgres) Close() {
_ = p.db.Close(p.ctx)
}

func (p *Postgres) Suggest(ctx context.Context, suggestForThis string) ([]string, error) {
func (p *Postgres) Suggest(ctx context.Context, searchTerm string, _ map[string]map[string]string, _ domain.SRID, limit int) ([]string, error) {
queryCtx, cancel := context.WithTimeout(ctx, p.queryTimeout)
defer cancel()

// Prepare dynamic full-text search query
// Split terms by spaces and append :* to each term
terms := strings.Fields(suggestForThis)
terms := strings.Fields(searchTerm)
for i, term := range terms {
terms[i] = term + ":*"
}
searchTerm := strings.Join(terms, " & ")
searchTermForPostgres := strings.Join(terms, " & ")

sqlQuery := fmt.Sprintf(
`SELECT
Expand All @@ -56,17 +57,15 @@ func (p *Postgres) Suggest(ctx context.Context, suggestForThis string) ([]string
FROM (
SELECT display_name,
ts_rank_cd(ts, to_tsquery('%[1]s'), 1) AS rank,
ts_headline('dutch', suggest, to_tsquery('%[2]s')) AS highlighted_text
FROM
%[3]s
WHERE ts @@ to_tsquery('%[4]s') LIMIT 500
ts_headline('dutch', suggest, to_tsquery('%[1]s')) AS highlighted_text
FROM %[2]s
WHERE ts @@ to_tsquery('%[1]s') LIMIT 500
) r
GROUP BY display_name
ORDER BY rank DESC, display_name ASC LIMIT 50`,
searchTerm, searchTerm, p.searchIndex, searchTerm)
ORDER BY rank DESC, display_name ASC LIMIT $1`, searchTermForPostgres, p.searchIndex)

// Execute query
rows, err := p.db.Query(queryCtx, sqlQuery)
rows, err := p.db.Query(queryCtx, sqlQuery, limit)
if err != nil {
return nil, fmt.Errorf("query '%s' failed: %w", sqlQuery, err)
}
Expand Down
12 changes: 12 additions & 0 deletions internal/search/domain/spatialref.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package domain

const (
CrsURIPrefix = "http://www.opengis.net/def/crs/"
UndefinedSRID = 0
WGS84SRIDPostgis = 4326 // Use the same SRID as used during ETL
WGS84CodeOGC = "CRS84"
)

// SRID Spatial Reference System Identifier: a unique value to unambiguously identify a spatial coordinate system.
// For example '28992' in https://www.opengis.net/def/crs/EPSG/0/28992
type SRID int
164 changes: 139 additions & 25 deletions internal/search/main.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,37 @@
package search

import (
"context"
"errors"
"fmt"
"log"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"

"github.com/PDOK/gomagpie/internal/engine"
ds "github.com/PDOK/gomagpie/internal/search/datasources"
"github.com/PDOK/gomagpie/internal/search/datasources/postgres"
"github.com/PDOK/gomagpie/internal/search/domain"
)

const timeout = time.Second * 15
const (
queryParam = "q"
limitParam = "limit"
crsParam = "crs"

limitDefault = 10
limitMax = 50

timeout = time.Second * 15
)

var (
deepObjectParamRegex = regexp.MustCompile(`\w+\[\w+]`)
)

type Search struct {
engine *engine.Engine
Expand All @@ -31,31 +50,42 @@ func NewSearch(e *engine.Engine, dbConn string, searchIndex string) *Search {
// Suggest autosuggest locations based on user input
func (s *Search) Suggest() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
params := parseQueryParams(r.URL.Query())
searchQuery := params["q"]
delete(params, "q")
format := params["f"]
delete(params, "f")
crs := params["crs"]
delete(params, "crs")
limit := params["limit"]
delete(params, "limit")

log.Printf("crs %s, limit %d, format %s, query %s, params %v", crs, limit, format, searchQuery, params)

suggestions, err := s.datasource.Suggest(r.Context(), searchQuery.(string)) // TODO check before casting
collections, searchTerm, outputSRID, limit, err := parseQueryParams(r.URL.Query())
if err != nil {
engine.RenderProblem(engine.ProblemServerError, w, err.Error())
engine.RenderProblem(engine.ProblemBadRequest, w, err.Error())
return
}
suggestions, err := s.datasource.Suggest(r.Context(), searchTerm, collections, outputSRID, limit)
if err != nil {
handleQueryError(w, err)
return
}
format := s.engine.CN.NegotiateFormat(r)
switch format {
case engine.FormatGeoJSON, engine.FormatJSON:
serveJSON(suggestions, engine.MediaTypeGeoJSON, w)
default:
engine.RenderProblem(engine.ProblemNotAcceptable, w, fmt.Sprintf("format '%s' is not supported", format))
return
}
serveJSON(suggestions, engine.MediaTypeGeoJSON, w)
}
}

func parseQueryParams(query url.Values) map[string]any {
result := make(map[string]any, len(query))
func parseQueryParams(query url.Values) (collectionsWithParams map[string]map[string]string, searchTerm string, outputSRID domain.SRID, limit int, err error) {
err = validateNoUnknownParams(query)
if err != nil {
return
}
searchTerm, searchTermErr := parseSearchTerm(query)
collectionsWithParams = parseCollectionDeepObjectParams(query)
outputSRID, outputSRIDErr := parseCrsToSRID(query, crsParam)
limit, limitErr := parseLimit(query)
err = errors.Join(searchTermErr, limitErr, outputSRIDErr)
return
}

deepObjectParams := make(map[string]map[string]string)
func parseCollectionDeepObjectParams(query url.Values) map[string]map[string]string {
deepObjectParams := make(map[string]map[string]string, len(query))
for key, values := range query {
if strings.Contains(key, "[") {
// Extract deepObject parameters
Expand All @@ -67,15 +97,17 @@ func parseQueryParams(query url.Values) map[string]any {
deepObjectParams[mainKey] = make(map[string]string)
}
deepObjectParams[mainKey][subKey] = values[0]
} else {
// Extract regular (flat) parameters
result[key] = values[0]
}
}
for mainKey, subParams := range deepObjectParams {
result[mainKey] = subParams
return deepObjectParams
}

func parseSearchTerm(query url.Values) (searchTerm string, err error) {
searchTerm = query.Get(queryParam)
if searchTerm == "" {
err = fmt.Errorf("no search term provided, '%s' query parameter is required", queryParam)
}
return result
return
}

func newDatasource(e *engine.Engine, dbConn string, searchIndex string) ds.Datasource {
Expand All @@ -86,3 +118,85 @@ func newDatasource(e *engine.Engine, dbConn string, searchIndex string) ds.Datas
e.RegisterShutdownHook(datasource.Close)
return datasource
}

// log error, but send generic message to client to prevent possible information leakage from datasource
func handleQueryError(w http.ResponseWriter, err error) {
msg := "failed to fulfill search request"
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
// provide more context when user hits the query timeout
msg += ": querying took too long (timeout encountered). Simplify your request and try again, or contact support"
}
log.Printf("%s, error: %v\n", msg, err)
engine.RenderProblem(engine.ProblemServerError, w, msg) // don't include sensitive information in details msg
}

// implements req 7.6 (https://docs.ogc.org/is/17-069r4/17-069r4.html#query_parameters)
func validateNoUnknownParams(query url.Values) error {
copyParams := clone(query)
copyParams.Del(engine.FormatParam)
copyParams.Del(queryParam)
copyParams.Del(limitParam)
copyParams.Del(crsParam)
for key := range query {
if deepObjectParamRegex.MatchString(key) {
copyParams.Del(key)
}
}
if len(copyParams) > 0 {
return fmt.Errorf("unknown query parameter(s) found: %v", copyParams.Encode())
}
return nil
}

func clone(params url.Values) url.Values {
copyParams := url.Values{}
for k, v := range params {
copyParams[k] = v
}
return copyParams
}

func parseCrsToSRID(params url.Values, paramName string) (domain.SRID, error) {
param := params.Get(paramName)
if param == "" {
return domain.UndefinedSRID, nil
}
param = strings.TrimSpace(param)
if !strings.HasPrefix(param, domain.CrsURIPrefix) {
return domain.UndefinedSRID, fmt.Errorf("%s param should start with %s, got: %s", paramName, domain.CrsURIPrefix, param)
}
var srid domain.SRID
lastIndex := strings.LastIndex(param, "/")
if lastIndex != -1 {
crsCode := param[lastIndex+1:]
if crsCode == domain.WGS84CodeOGC {
return domain.WGS84SRIDPostgis, nil // CRS84 is WGS84, just like EPSG:4326 (only axis order differs but SRID is the same)
}
val, err := strconv.Atoi(crsCode)
if err != nil {
return 0, fmt.Errorf("expected numerical CRS code, received: %s", crsCode)
}
srid = domain.SRID(val)
}
return srid, nil
}

func parseLimit(params url.Values) (int, error) {
limit := limitDefault
var err error
if params.Get(limitParam) != "" {
limit, err = strconv.Atoi(params.Get(limitParam))
if err != nil {
err = errors.New("limit must be numeric")
}
// "If the value of the limit parameter is larger than the maximum value, this SHALL NOT result
// in an error (instead use the maximum as the parameter value)."
if limit > limitMax {
limit = limitMax
}
}
if limit < 0 {
err = errors.New("limit can't be negative")
}
return limit, err
}
Loading

0 comments on commit 927cd70

Please sign in to comment.