From 257c8ad0352f42465052a4f87efc3674cfcbd031 Mon Sep 17 00:00:00 2001 From: Tony Baeg Date: Thu, 16 Nov 2023 11:54:36 -0500 Subject: [PATCH] Update TableFunctionProcessorProvider.getDataProcessor to include ConnectorSession parameter --- .../trino/operator/TableFunctionOperator.java | 20 +++++++-- .../trino/operator/table/ExcludeColumns.java | 2 +- .../sql/planner/LocalExecutionPlanner.java | 1 + .../connector/TestingTableFunctions.java | 18 ++++---- .../table/TableFunctionProcessorProvider.java | 2 +- ...sLoaderSafeTableFunctionDataProcessor.java | 43 +++++++++++++++++++ ...derSafeTableFunctionProcessorProvider.java | 4 +- .../TestClassLoaderSafeWrappers.java | 2 + 8 files changed, 75 insertions(+), 17 deletions(-) create mode 100644 lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeTableFunctionDataProcessor.java diff --git a/core/trino-main/src/main/java/io/trino/operator/TableFunctionOperator.java b/core/trino-main/src/main/java/io/trino/operator/TableFunctionOperator.java index 7172673412a6..f309a6d145c5 100644 --- a/core/trino-main/src/main/java/io/trino/operator/TableFunctionOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/TableFunctionOperator.java @@ -22,6 +22,8 @@ import io.trino.memory.context.LocalMemoryContext; import io.trino.operator.RegularTableFunctionPartition.PassThroughColumnSpecification; import io.trino.spi.Page; +import io.trino.spi.connector.CatalogHandle; +import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.SortOrder; import io.trino.spi.function.table.ConnectorTableFunctionHandle; import io.trino.spi.function.table.TableFunctionProcessorProvider; @@ -54,6 +56,7 @@ public static class TableFunctionOperatorFactory // a provider of table function processor to be called once per partition private final TableFunctionProcessorProvider tableFunctionProvider; + private final CatalogHandle catalogHandle; // all information necessary to execute the table function collected during analysis private final ConnectorTableFunctionHandle functionHandle; @@ -106,6 +109,7 @@ public TableFunctionOperatorFactory( int operatorId, PlanNodeId planNodeId, TableFunctionProcessorProvider tableFunctionProvider, + CatalogHandle catalogHandle, ConnectorTableFunctionHandle functionHandle, int properChannelsCount, int passThroughSourcesCount, @@ -124,6 +128,7 @@ public TableFunctionOperatorFactory( { requireNonNull(planNodeId, "planNodeId is null"); requireNonNull(tableFunctionProvider, "tableFunctionProvider is null"); + requireNonNull(catalogHandle, "catalogHandle is null"); requireNonNull(functionHandle, "functionHandle is null"); requireNonNull(requiredChannels, "requiredChannels is null"); requireNonNull(markerChannels, "markerChannels is null"); @@ -142,6 +147,7 @@ public TableFunctionOperatorFactory( this.operatorId = operatorId; this.planNodeId = planNodeId; this.tableFunctionProvider = tableFunctionProvider; + this.catalogHandle = catalogHandle; this.functionHandle = functionHandle; this.properChannelsCount = properChannelsCount; this.passThroughSourcesCount = passThroughSourcesCount; @@ -170,6 +176,7 @@ public Operator createOperator(DriverContext driverContext) return new TableFunctionOperator( operatorContext, tableFunctionProvider, + catalogHandle, functionHandle, properChannelsCount, passThroughSourcesCount, @@ -200,6 +207,7 @@ public OperatorFactory duplicate() operatorId, planNodeId, tableFunctionProvider, + catalogHandle, functionHandle, properChannelsCount, passThroughSourcesCount, @@ -219,7 +227,7 @@ public OperatorFactory duplicate() } private final OperatorContext operatorContext; - + private final ConnectorSession session; private final PageBuffer pageBuffer = new PageBuffer(); private final WorkProcessor outputPages; private final boolean processEmptyInput; @@ -227,6 +235,7 @@ public OperatorFactory duplicate() public TableFunctionOperator( OperatorContext operatorContext, TableFunctionProcessorProvider tableFunctionProvider, + CatalogHandle catalogHandle, ConnectorTableFunctionHandle functionHandle, int properChannelsCount, int passThroughSourcesCount, @@ -245,6 +254,7 @@ public TableFunctionOperator( { requireNonNull(operatorContext, "operatorContext is null"); requireNonNull(tableFunctionProvider, "tableFunctionProvider is null"); + requireNonNull(catalogHandle, "catalogHandle is null"); requireNonNull(functionHandle, "functionHandle is null"); requireNonNull(requiredChannels, "requiredChannels is null"); requireNonNull(markerChannels, "markerChannels is null"); @@ -261,7 +271,7 @@ public TableFunctionOperator( requireNonNull(pagesIndexFactory, "pagesIndexFactory is null"); this.operatorContext = operatorContext; - + this.session = operatorContext.getSession().toConnectorSession(catalogHandle); this.processEmptyInput = !pruneWhenEmpty; PagesIndex pagesIndex = pagesIndexFactory.newPagesIndex(sourceTypes, expectedPositions); @@ -273,6 +283,7 @@ public TableFunctionOperator( groupPagesIndex, hashStrategies, tableFunctionProvider, + session, functionHandle, properChannelsCount, passThroughSourcesCount, @@ -517,6 +528,7 @@ private WorkProcessor pagesIndexToTableFunctionPartition PagesIndex pagesIndex, HashStrategies hashStrategies, TableFunctionProcessorProvider tableFunctionProvider, + ConnectorSession session, ConnectorTableFunctionHandle functionHandle, int properChannelsCount, int passThroughSourcesCount, @@ -542,7 +554,7 @@ public WorkProcessor.ProcessState process() // empty PagesIndex can only be passed once as the result of PartitionAndSort. Neither this nor any future instance of Process will ever get an empty PagesIndex again. processEmpty = false; return WorkProcessor.ProcessState.ofResult(new EmptyTableFunctionPartition( - tableFunctionProvider.getDataProcessor(functionHandle), + tableFunctionProvider.getDataProcessor(session, functionHandle), properChannelsCount, passThroughSourcesCount, passThroughSpecifications.stream() @@ -562,7 +574,7 @@ public WorkProcessor.ProcessState process() pagesIndex, partitionStart, partitionEnd, - tableFunctionProvider.getDataProcessor(functionHandle), + tableFunctionProvider.getDataProcessor(session, functionHandle), properChannelsCount, passThroughSourcesCount, requiredChannels, diff --git a/core/trino-main/src/main/java/io/trino/operator/table/ExcludeColumns.java b/core/trino-main/src/main/java/io/trino/operator/table/ExcludeColumns.java index 2ef635f3a946..d650c97b2b98 100644 --- a/core/trino-main/src/main/java/io/trino/operator/table/ExcludeColumns.java +++ b/core/trino-main/src/main/java/io/trino/operator/table/ExcludeColumns.java @@ -154,7 +154,7 @@ public static TableFunctionProcessorProvider getExcludeColumnsFunctionProcessorP return new TableFunctionProcessorProvider() { @Override - public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + public TableFunctionDataProcessor getDataProcessor(ConnectorSession session, ConnectorTableFunctionHandle handle) { return input -> { if (input == null) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index b00cc38fab58..a4d67d8e88d4 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -1723,6 +1723,7 @@ public PhysicalOperation visitTableFunctionProcessor(TableFunctionProcessorNode context.getNextOperatorId(), node.getId(), processorProvider, + node.getHandle().getCatalogHandle(), node.getHandle().getFunctionHandle(), properChannelsCount, toIntExact(passThroughSourcesCount), diff --git a/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java b/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java index 8929d8758d4f..81ddc2f40a4e 100644 --- a/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java +++ b/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java @@ -610,7 +610,7 @@ public static class IdentityFunctionProcessorProvider implements TableFunctionProcessorProvider { @Override - public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + public TableFunctionDataProcessor getDataProcessor(ConnectorSession session, ConnectorTableFunctionHandle handle) { return input -> { if (input == null) { @@ -659,7 +659,7 @@ public static class IdentityPassThroughFunctionProcessorProvider implements TableFunctionProcessorProvider { @Override - public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + public TableFunctionDataProcessor getDataProcessor(ConnectorSession session, ConnectorTableFunctionHandle handle) { return new IdentityPassThroughFunctionProcessor(); } @@ -750,7 +750,7 @@ public static class RepeatFunctionProcessorProvider implements TableFunctionProcessorProvider { @Override - public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + public TableFunctionDataProcessor getDataProcessor(ConnectorSession session, ConnectorTableFunctionHandle handle) { return new RepeatFunctionProcessor(((RepeatFunctionHandle) handle).getCount()); } @@ -848,7 +848,7 @@ public static class EmptyOutputProcessorProvider implements TableFunctionProcessorProvider { @Override - public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + public TableFunctionDataProcessor getDataProcessor(ConnectorSession session, ConnectorTableFunctionHandle handle) { return new EmptyOutputProcessor(); } @@ -906,7 +906,7 @@ public static class EmptyOutputWithPassThroughProcessorProvider implements TableFunctionProcessorProvider { @Override - public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + public TableFunctionDataProcessor getDataProcessor(ConnectorSession session, ConnectorTableFunctionHandle handle) { return new EmptyOutputWithPassThroughProcessor(); } @@ -982,7 +982,7 @@ public static class TestInputsFunctionProcessorProvider implements TableFunctionProcessorProvider { @Override - public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + public TableFunctionDataProcessor getDataProcessor(ConnectorSession session, ConnectorTableFunctionHandle handle) { BlockBuilder resultBuilder = BOOLEAN.createBlockBuilder(null, 1); BOOLEAN.writeBoolean(resultBuilder, true); @@ -1043,7 +1043,7 @@ public static class PassThroughInputProcessorProvider implements TableFunctionProcessorProvider { @Override - public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + public TableFunctionDataProcessor getDataProcessor(ConnectorSession session, ConnectorTableFunctionHandle handle) { return new PassThroughInputProcessor(); } @@ -1142,7 +1142,7 @@ public static class TestInputProcessorProvider implements TableFunctionProcessorProvider { @Override - public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + public TableFunctionDataProcessor getDataProcessor(ConnectorSession session, ConnectorTableFunctionHandle handle) { return new TestInputProcessor(); } @@ -1206,7 +1206,7 @@ public static class TestSingleInputFunctionProcessorProvider implements TableFunctionProcessorProvider { @Override - public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + public TableFunctionDataProcessor getDataProcessor(ConnectorSession session, ConnectorTableFunctionHandle handle) { BlockBuilder builder = BOOLEAN.createBlockBuilder(null, 1); BOOLEAN.writeBoolean(builder, true); diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/table/TableFunctionProcessorProvider.java b/core/trino-spi/src/main/java/io/trino/spi/function/table/TableFunctionProcessorProvider.java index 0be4cd2ed585..d0d70448ead0 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/function/table/TableFunctionProcessorProvider.java +++ b/core/trino-spi/src/main/java/io/trino/spi/function/table/TableFunctionProcessorProvider.java @@ -24,7 +24,7 @@ public interface TableFunctionProcessorProvider * This method returns a {@code TableFunctionDataProcessor}. All the necessary information collected during analysis is available * in the form of {@link ConnectorTableFunctionHandle}. It is called once per each partition processed by the table function. */ - default TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + default TableFunctionDataProcessor getDataProcessor(ConnectorSession session, ConnectorTableFunctionHandle handle) { throw new UnsupportedOperationException("this table function does not process input data"); } diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeTableFunctionDataProcessor.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeTableFunctionDataProcessor.java new file mode 100644 index 000000000000..ccf073d671f3 --- /dev/null +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeTableFunctionDataProcessor.java @@ -0,0 +1,43 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.base.classloader; + +import io.trino.spi.Page; +import io.trino.spi.classloader.ThreadContextClassLoader; +import io.trino.spi.function.table.TableFunctionDataProcessor; +import io.trino.spi.function.table.TableFunctionProcessorState; + +import java.util.List; +import java.util.Optional; + +public class ClassLoaderSafeTableFunctionDataProcessor + implements TableFunctionDataProcessor +{ + private final TableFunctionDataProcessor delegate; + private final ClassLoader classLoader; + + public ClassLoaderSafeTableFunctionDataProcessor(TableFunctionDataProcessor delegate, ClassLoader classLoader) + { + this.delegate = delegate; + this.classLoader = classLoader; + } + + @Override + public TableFunctionProcessorState process(List> input) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.process(input); + } + } +} diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeTableFunctionProcessorProvider.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeTableFunctionProcessorProvider.java index 51f54e99ea67..c95483774291 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeTableFunctionProcessorProvider.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeTableFunctionProcessorProvider.java @@ -36,10 +36,10 @@ public ClassLoaderSafeTableFunctionProcessorProvider(TableFunctionProcessorProvi } @Override - public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + public TableFunctionDataProcessor getDataProcessor(ConnectorSession session, ConnectorTableFunctionHandle handle) { try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { - return delegate.getDataProcessor(handle); + return delegate.getDataProcessor(session, handle); } } diff --git a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/classloader/TestClassLoaderSafeWrappers.java b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/classloader/TestClassLoaderSafeWrappers.java index d0834ec00fa6..66a1412e61a5 100644 --- a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/classloader/TestClassLoaderSafeWrappers.java +++ b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/classloader/TestClassLoaderSafeWrappers.java @@ -27,6 +27,7 @@ import io.trino.spi.connector.SystemTable; import io.trino.spi.eventlistener.EventListener; import io.trino.spi.function.table.ConnectorTableFunction; +import io.trino.spi.function.table.TableFunctionDataProcessor; import io.trino.spi.function.table.TableFunctionProcessorProvider; import io.trino.spi.function.table.TableFunctionSplitProcessor; import org.junit.jupiter.api.Test; @@ -61,6 +62,7 @@ public void test() testClassLoaderSafe(EventListener.class, ClassLoaderSafeEventListener.class); testClassLoaderSafe(ConnectorTableFunction.class, ClassLoaderSafeConnectorTableFunction.class); testClassLoaderSafe(TableFunctionSplitProcessor.class, ClassLoaderSafeTableFunctionSplitProcessor.class); + testClassLoaderSafe(TableFunctionDataProcessor.class, ClassLoaderSafeTableFunctionDataProcessor.class); testClassLoaderSafe(TableFunctionProcessorProvider.class, ClassLoaderSafeTableFunctionProcessorProvider.class); }