Skip to content

Commit

Permalink
Support TrinoQueryProperties for SQL containing WITH clauses
Browse files Browse the repository at this point in the history
  • Loading branch information
Chaho12 authored Nov 1, 2024
1 parent 5e2c203 commit a0de70f
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import io.trino.sql.tree.Node;
import io.trino.sql.tree.NodeLocation;
import io.trino.sql.tree.QualifiedName;
import io.trino.sql.tree.Query;
import io.trino.sql.tree.RenameMaterializedView;
import io.trino.sql.tree.RenameSchema;
import io.trino.sql.tree.RenameTable;
Expand All @@ -63,6 +64,7 @@
import io.trino.sql.tree.StringLiteral;
import io.trino.sql.tree.Table;
import io.trino.sql.tree.TableFunctionInvocation;
import io.trino.sql.tree.WithQuery;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.ws.rs.HttpMethod;

Expand All @@ -71,6 +73,7 @@
import java.net.URLDecoder;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -208,8 +211,9 @@ else if (statement instanceof ExecuteImmediate executeImmediate) {
ImmutableSet.Builder<String> catalogBuilder = ImmutableSet.builder();
ImmutableSet.Builder<String> schemaBuilder = ImmutableSet.builder();
ImmutableSet.Builder<String> catalogSchemaBuilder = ImmutableSet.builder();
Set<QualifiedName> temporaryTables = new HashSet<>();

visitNode(statement, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder);
visitNode(statement, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder, temporaryTables);
tables = tableBuilder.build();
catalogBuilder.addAll(tables.stream().map(q -> q.getParts().getFirst()).iterator());
catalogs = catalogBuilder.build();
Expand Down Expand Up @@ -273,7 +277,8 @@ private String decodePreparedStatementFromHeader(String headerValue)
private void visitNode(Node node, ImmutableSet.Builder<QualifiedName> tableBuilder,
ImmutableSet.Builder<String> catalogBuilder,
ImmutableSet.Builder<String> schemaBuilder,
ImmutableSet.Builder<String> catalogSchemaBuilder)
ImmutableSet.Builder<String> catalogSchemaBuilder,
Set<QualifiedName> temporaryTables)
throws RequestParsingException
{
switch (node) {
Expand All @@ -289,6 +294,7 @@ private void visitNode(Node node, ImmutableSet.Builder<QualifiedName> tableBuild
case DropCatalog s -> catalogBuilder.add(s.getCatalogName().getValue());
case DropSchema s -> setCatalogAndSchemaNameFromSchemaQualifiedName(Optional.of(s.getSchemaName()), catalogBuilder, schemaBuilder, catalogSchemaBuilder);
case DropTable s -> tableBuilder.add(qualifyName(s.getTableName()));
case Query q -> q.getWith().ifPresent(with -> temporaryTables.addAll(with.getQueries().stream().map(WithQuery::getName).map(Identifier::getValue).map(QualifiedName::of).toList()));
case RenameMaterializedView s -> {
tableBuilder.add(qualifyName(s.getSource()));
tableBuilder.add(qualifyName(s.getTarget()));
Expand Down Expand Up @@ -347,13 +353,18 @@ private void visitNode(Node node, ImmutableSet.Builder<QualifiedName> tableBuild
case SetSchemaAuthorization s -> setCatalogAndSchemaNameFromSchemaQualifiedName(Optional.of(s.getSource()), catalogBuilder, schemaBuilder, catalogSchemaBuilder);
case SetTableAuthorization s -> tableBuilder.add(qualifyName(s.getSource()));
case SetViewAuthorization s -> tableBuilder.add(qualifyName(s.getSource()));
case Table s -> tableBuilder.add(qualifyName(s.getName()));
case Table s -> {
// ignore temporary tables as they can have various table parts
if (!temporaryTables.contains(s.getName())) {
tableBuilder.add(qualifyName(s.getName()));
}
}
case TableFunctionInvocation s -> tableBuilder.add(qualifyName(s.getName()));
default -> {}
}

for (Node child : node.getChildren()) {
visitNode(child, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder);
visitNode(child, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder, temporaryTables);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,38 @@ void testTrinoQueryPropertiesTableExtraction(String query, Set<String> catalogs,
assertThat(trinoQueryProperties.getCatalogs()).isEqualTo(catalogs);
}

@Test
void testWithQueryNameExcluded()
throws IOException
{
String query = """
WITH dos AS (SELECT c1 from cat.schem.tbl1),
uno as (SELECT c1 FROM dos)
SELECT c1 FROM uno, dos
""";
HttpServletRequest mockRequestWithDefaults = prepareMockRequest();
when(mockRequestWithDefaults.getReader()).thenReturn(new BufferedReader(new StringReader(query)));
when(mockRequestWithDefaults.getHeader(TrinoQueryProperties.TRINO_CATALOG_HEADER_NAME)).thenReturn(DEFAULT_CATALOG);
when(mockRequestWithDefaults.getHeader(TrinoQueryProperties.TRINO_SCHEMA_HEADER_NAME)).thenReturn(DEFAULT_SCHEMA);

TrinoQueryProperties trinoQueryPropertiesWithDefaults = new TrinoQueryProperties(
mockRequestWithDefaults,
requestAnalyzerConfig.isClientsUseV2Format(),
requestAnalyzerConfig.getMaxBodySize());
Set<QualifiedName> tablesWithDefaults = trinoQueryPropertiesWithDefaults.getTables();
assertThat(tablesWithDefaults).containsExactly(QualifiedName.of("cat", "schem", "tbl1"));

HttpServletRequest mockRequestNoDefaults = prepareMockRequest();
when(mockRequestNoDefaults.getReader()).thenReturn(new BufferedReader(new StringReader(query)));

TrinoQueryProperties trinoQueryPropertiesNoDefaults = new TrinoQueryProperties(
mockRequestNoDefaults,
requestAnalyzerConfig.isClientsUseV2Format(),
requestAnalyzerConfig.getMaxBodySize());
Set<QualifiedName> tablesNoDefaults = trinoQueryPropertiesNoDefaults.getTables();
assertThat(tablesNoDefaults).containsExactly(QualifiedName.of("cat", "schem", "tbl1"));
}

private HttpServletRequest prepareMockRequest()
{
HttpServletRequest mockRequest = mock(HttpServletRequest.class);
Expand Down

0 comments on commit a0de70f

Please sign in to comment.