diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoQueryProperties.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoQueryProperties.java index 791d92bbd..3946ab6b3 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoQueryProperties.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoQueryProperties.java @@ -47,6 +47,7 @@ import io.trino.sql.tree.Node; import io.trino.sql.tree.NodeLocation; import io.trino.sql.tree.QualifiedName; +import io.trino.sql.tree.Query; import io.trino.sql.tree.RenameMaterializedView; import io.trino.sql.tree.RenameSchema; import io.trino.sql.tree.RenameTable; @@ -63,6 +64,7 @@ import io.trino.sql.tree.StringLiteral; import io.trino.sql.tree.Table; import io.trino.sql.tree.TableFunctionInvocation; +import io.trino.sql.tree.WithQuery; import jakarta.servlet.http.HttpServletRequest; import jakarta.ws.rs.HttpMethod; @@ -71,6 +73,7 @@ import java.net.URLDecoder; import java.util.ArrayList; import java.util.Enumeration; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; @@ -208,8 +211,9 @@ else if (statement instanceof ExecuteImmediate executeImmediate) { ImmutableSet.Builder catalogBuilder = ImmutableSet.builder(); ImmutableSet.Builder schemaBuilder = ImmutableSet.builder(); ImmutableSet.Builder catalogSchemaBuilder = ImmutableSet.builder(); + Set temporaryTables = new HashSet<>(); - visitNode(statement, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder); + visitNode(statement, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder, temporaryTables); tables = tableBuilder.build(); catalogBuilder.addAll(tables.stream().map(q -> q.getParts().getFirst()).iterator()); catalogs = catalogBuilder.build(); @@ -273,7 +277,8 @@ private String decodePreparedStatementFromHeader(String headerValue) private void visitNode(Node node, ImmutableSet.Builder tableBuilder, ImmutableSet.Builder catalogBuilder, ImmutableSet.Builder schemaBuilder, - ImmutableSet.Builder catalogSchemaBuilder) + ImmutableSet.Builder catalogSchemaBuilder, + Set temporaryTables) throws RequestParsingException { switch (node) { @@ -289,6 +294,7 @@ private void visitNode(Node node, ImmutableSet.Builder tableBuild case DropCatalog s -> catalogBuilder.add(s.getCatalogName().getValue()); case DropSchema s -> setCatalogAndSchemaNameFromSchemaQualifiedName(Optional.of(s.getSchemaName()), catalogBuilder, schemaBuilder, catalogSchemaBuilder); case DropTable s -> tableBuilder.add(qualifyName(s.getTableName())); + case Query q -> q.getWith().ifPresent(with -> temporaryTables.addAll(with.getQueries().stream().map(WithQuery::getName).map(Identifier::getValue).map(QualifiedName::of).toList())); case RenameMaterializedView s -> { tableBuilder.add(qualifyName(s.getSource())); tableBuilder.add(qualifyName(s.getTarget())); @@ -347,13 +353,18 @@ private void visitNode(Node node, ImmutableSet.Builder tableBuild case SetSchemaAuthorization s -> setCatalogAndSchemaNameFromSchemaQualifiedName(Optional.of(s.getSource()), catalogBuilder, schemaBuilder, catalogSchemaBuilder); case SetTableAuthorization s -> tableBuilder.add(qualifyName(s.getSource())); case SetViewAuthorization s -> tableBuilder.add(qualifyName(s.getSource())); - case Table s -> tableBuilder.add(qualifyName(s.getName())); + case Table s -> { + // ignore temporary tables as they can have various table parts + if (!temporaryTables.contains(s.getName())) { + tableBuilder.add(qualifyName(s.getName())); + } + } case TableFunctionInvocation s -> tableBuilder.add(qualifyName(s.getName())); default -> {} } for (Node child : node.getChildren()) { - visitNode(child, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder); + visitNode(child, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder, temporaryTables); } } diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingGroupSelector.java b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingGroupSelector.java index 1066d19a7..38ec43bde 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingGroupSelector.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingGroupSelector.java @@ -407,6 +407,38 @@ void testTrinoQueryPropertiesTableExtraction(String query, Set catalogs, assertThat(trinoQueryProperties.getCatalogs()).isEqualTo(catalogs); } + @Test + void testWithQueryNameExcluded() + throws IOException + { + String query = """ + WITH dos AS (SELECT c1 from cat.schem.tbl1), + uno as (SELECT c1 FROM dos) + SELECT c1 FROM uno, dos + """; + HttpServletRequest mockRequestWithDefaults = prepareMockRequest(); + when(mockRequestWithDefaults.getReader()).thenReturn(new BufferedReader(new StringReader(query))); + when(mockRequestWithDefaults.getHeader(TrinoQueryProperties.TRINO_CATALOG_HEADER_NAME)).thenReturn(DEFAULT_CATALOG); + when(mockRequestWithDefaults.getHeader(TrinoQueryProperties.TRINO_SCHEMA_HEADER_NAME)).thenReturn(DEFAULT_SCHEMA); + + TrinoQueryProperties trinoQueryPropertiesWithDefaults = new TrinoQueryProperties( + mockRequestWithDefaults, + requestAnalyzerConfig.isClientsUseV2Format(), + requestAnalyzerConfig.getMaxBodySize()); + Set tablesWithDefaults = trinoQueryPropertiesWithDefaults.getTables(); + assertThat(tablesWithDefaults).containsExactly(QualifiedName.of("cat", "schem", "tbl1")); + + HttpServletRequest mockRequestNoDefaults = prepareMockRequest(); + when(mockRequestNoDefaults.getReader()).thenReturn(new BufferedReader(new StringReader(query))); + + TrinoQueryProperties trinoQueryPropertiesNoDefaults = new TrinoQueryProperties( + mockRequestNoDefaults, + requestAnalyzerConfig.isClientsUseV2Format(), + requestAnalyzerConfig.getMaxBodySize()); + Set tablesNoDefaults = trinoQueryPropertiesNoDefaults.getTables(); + assertThat(tablesNoDefaults).containsExactly(QualifiedName.of("cat", "schem", "tbl1")); + } + private HttpServletRequest prepareMockRequest() { HttpServletRequest mockRequest = mock(HttpServletRequest.class);