Skip to content

Commit

Permalink
Runtime: Fix malformed WHERE clauses for security policies with `OR…
Browse files Browse the repository at this point in the history
…` in the row filter (#3752)

* Runtime: fix malformed SQL for security policies with 'OR'

* Self review

* Fix test
  • Loading branch information
begelundmuller authored Dec 28, 2023
1 parent ccec0d5 commit 737fce3
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 60 deletions.
2 changes: 1 addition & 1 deletion runtime/queries/metricsview.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ func buildFilterClauseForMetricsViewFilter(mv *runtimev1.MetricsViewSpec, filter
}

if policy != nil && policy.RowFilter != "" {
clauses = append(clauses, "AND "+policy.RowFilter)
clauses = append(clauses, fmt.Sprintf("AND (%s)", policy.RowFilter))
}

return strings.Join(clauses, " "), args, nil
Expand Down
16 changes: 9 additions & 7 deletions runtime/queries/metricsview_aggregation.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,14 +210,16 @@ func (q *MetricsViewAggregation) buildMetricsAggregationSQL(mv *runtimev1.Metric
}
whereClause += clause
}
if q.Filter != nil {
clause, clauseArgs, err := buildFilterClauseForMetricsViewFilter(mv, q.Filter, dialect, policy)
if err != nil {
return "", nil, err
}
whereClause += " " + clause
args = append(args, clauseArgs...)

filterClause, filterClauseArgs, err := buildFilterClauseForMetricsViewFilter(mv, q.Filter, dialect, policy)
if err != nil {
return "", nil, err
}
if filterClause != "" {
whereClause += " " + filterClause
args = append(args, filterClauseArgs...)
}

if len(whereClause) > 0 {
whereClause = "WHERE 1=1" + whereClause
}
Expand Down
42 changes: 18 additions & 24 deletions runtime/queries/metricsview_comparison_toplist.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,14 +291,13 @@ func (q *MetricsViewComparison) buildMetricsTopListSQL(mv *runtimev1.MetricsView
}
baseWhereClause += trc

if q.Filter != nil {
clause, clauseArgs, err := buildFilterClauseForMetricsViewFilter(mv, q.Filter, dialect, policy)
if err != nil {
return "", nil, err
}
baseWhereClause += " " + clause

args = append(args, clauseArgs...)
filterClause, filterClauseArgs, err := buildFilterClauseForMetricsViewFilter(mv, q.Filter, dialect, policy)
if err != nil {
return "", nil, err
}
if filterClause != "" {
baseWhereClause += " " + filterClause
args = append(args, filterClauseArgs...)
}

var orderClauses []string
Expand Down Expand Up @@ -484,20 +483,20 @@ func (q *MetricsViewComparison) buildMetricsComparisonTopListSQL(mv *runtimev1.M

td := safeName(mv.TimeDimension)

filterClause, filterClauseArgs, err := buildFilterClauseForMetricsViewFilter(mv, q.Filter, dialect, policy)
if err != nil {
return "", nil, err
}

trc, err := timeRangeClause(q.TimeRange, mv, dialect, td, &args)
if err != nil {
return "", nil, err
}
baseWhereClause += trc

if q.Filter != nil {
clause, clauseArgs, err := buildFilterClauseForMetricsViewFilter(mv, q.Filter, dialect, policy)
if err != nil {
return "", nil, err
}
baseWhereClause += " " + clause

args = append(args, clauseArgs...)
if filterClause != "" {
baseWhereClause += " " + filterClause
args = append(args, filterClauseArgs...)
}

trc, err = timeRangeClause(q.ComparisonTimeRange, mv, dialect, td, &args)
Expand All @@ -506,14 +505,9 @@ func (q *MetricsViewComparison) buildMetricsComparisonTopListSQL(mv *runtimev1.M
}
comparisonWhereClause += trc

if q.Filter != nil {
clause, clauseArgs, err := buildFilterClauseForMetricsViewFilter(mv, q.Filter, dialect, policy)
if err != nil {
return "", nil, err
}
comparisonWhereClause += " " + clause

args = append(args, clauseArgs...)
if filterClause != "" {
comparisonWhereClause += " " + filterClause
args = append(args, filterClauseArgs...)
}

err = validateSort(q.Sort)
Expand Down
14 changes: 7 additions & 7 deletions runtime/queries/metricsview_rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,13 +228,13 @@ func (q *MetricsViewRows) buildMetricsRowsSQL(mv *runtimev1.MetricsViewSpec, dia
}
}

if q.Filter != nil {
clause, clauseArgs, err := buildFilterClauseForMetricsViewFilter(mv, q.Filter, dialect, policy)
if err != nil {
return "", nil, err
}
whereClause += " " + clause
args = append(args, clauseArgs...)
filterClause, filterClauseArgs, err := buildFilterClauseForMetricsViewFilter(mv, q.Filter, dialect, policy)
if err != nil {
return "", nil, err
}
if filterClause != "" {
whereClause += " " + filterClause
args = append(args, filterClauseArgs...)
}

sortingCriteria := make([]string, 0, len(q.Sort))
Expand Down
14 changes: 7 additions & 7 deletions runtime/queries/metricsview_timeseries.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,13 @@ func (q *MetricsViewTimeSeries) buildMetricsTimeseriesSQL(olap drivers.OLAPStore
args = append(args, q.TimeEnd.AsTime())
}

if q.Filter != nil {
clause, clauseArgs, err := buildFilterClauseForMetricsViewFilter(mv, q.Filter, olap.Dialect(), policy)
if err != nil {
return "", "", nil, err
}
whereClause += " " + clause
args = append(args, clauseArgs...)
filterClause, filterClauseArgs, err := buildFilterClauseForMetricsViewFilter(mv, q.Filter, olap.Dialect(), policy)
if err != nil {
return "", "", nil, err
}
if filterClause != "" {
whereClause += " " + filterClause
args = append(args, filterClauseArgs...)
}

tsAlias := tempName("_ts_")
Expand Down
14 changes: 7 additions & 7 deletions runtime/queries/metricsview_toplist.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,13 +214,13 @@ func (q *MetricsViewToplist) buildMetricsTopListSQL(mv *runtimev1.MetricsViewSpe
}
}

if q.Filter != nil {
clause, clauseArgs, err := buildFilterClauseForMetricsViewFilter(mv, q.Filter, dialect, policy)
if err != nil {
return "", nil, err
}
whereClause += " " + clause
args = append(args, clauseArgs...)
filterClause, filterClauseArgs, err := buildFilterClauseForMetricsViewFilter(mv, q.Filter, dialect, policy)
if err != nil {
return "", nil, err
}
if filterClause != "" {
whereClause += " " + filterClause
args = append(args, filterClauseArgs...)
}

sortingCriteria := make([]string, 0, len(q.Sort))
Expand Down
14 changes: 7 additions & 7 deletions runtime/queries/metricsview_totals.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,13 @@ func (q *MetricsViewTotals) buildMetricsTotalsSQL(mv *runtimev1.MetricsViewSpec,
}
}

if q.Filter != nil {
clause, clauseArgs, err := buildFilterClauseForMetricsViewFilter(mv, q.Filter, dialect, policy)
if err != nil {
return "", nil, err
}
whereClause += " " + clause
args = append(args, clauseArgs...)
filterClause, filterClauseArgs, err := buildFilterClauseForMetricsViewFilter(mv, q.Filter, dialect, policy)
if err != nil {
return "", nil, err
}
if filterClause != "" {
whereClause += " " + filterClause
args = append(args, filterClauseArgs...)
}

sql := fmt.Sprintf(
Expand Down

1 comment on commit 737fce3

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎉 Published on https://ui.rilldata.com as production
🚀 Deployed on https://658d6a715b37ba00a38ef9de--rill-ui.netlify.app

Please sign in to comment.