Skip to content

Commit

Permalink
Update TableFunctionProcessorProvider.getDataProcessor to include Con…
Browse files Browse the repository at this point in the history
…nectorSession parameter
  • Loading branch information
tbaeg committed Nov 22, 2023
1 parent c2fb0c6 commit 49f56b5
Show file tree
Hide file tree
Showing 8 changed files with 75 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import io.trino.memory.context.LocalMemoryContext;
import io.trino.operator.RegularTableFunctionPartition.PassThroughColumnSpecification;
import io.trino.spi.Page;
import io.trino.spi.connector.CatalogHandle;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.SortOrder;
import io.trino.spi.function.table.ConnectorTableFunctionHandle;
import io.trino.spi.function.table.TableFunctionProcessorProvider;
Expand Down Expand Up @@ -54,6 +56,7 @@ public static class TableFunctionOperatorFactory

// a provider of table function processor to be called once per partition
private final TableFunctionProcessorProvider tableFunctionProvider;
private final CatalogHandle catalogHandle;

// all information necessary to execute the table function collected during analysis
private final ConnectorTableFunctionHandle functionHandle;
Expand Down Expand Up @@ -106,6 +109,7 @@ public TableFunctionOperatorFactory(
int operatorId,
PlanNodeId planNodeId,
TableFunctionProcessorProvider tableFunctionProvider,
CatalogHandle catalogHandle,
ConnectorTableFunctionHandle functionHandle,
int properChannelsCount,
int passThroughSourcesCount,
Expand All @@ -124,6 +128,7 @@ public TableFunctionOperatorFactory(
{
requireNonNull(planNodeId, "planNodeId is null");
requireNonNull(tableFunctionProvider, "tableFunctionProvider is null");
requireNonNull(catalogHandle, "catalogHandle is null");
requireNonNull(functionHandle, "functionHandle is null");
requireNonNull(requiredChannels, "requiredChannels is null");
requireNonNull(markerChannels, "markerChannels is null");
Expand All @@ -142,6 +147,7 @@ public TableFunctionOperatorFactory(
this.operatorId = operatorId;
this.planNodeId = planNodeId;
this.tableFunctionProvider = tableFunctionProvider;
this.catalogHandle = catalogHandle;
this.functionHandle = functionHandle;
this.properChannelsCount = properChannelsCount;
this.passThroughSourcesCount = passThroughSourcesCount;
Expand Down Expand Up @@ -170,6 +176,7 @@ public Operator createOperator(DriverContext driverContext)
return new TableFunctionOperator(
operatorContext,
tableFunctionProvider,
catalogHandle,
functionHandle,
properChannelsCount,
passThroughSourcesCount,
Expand Down Expand Up @@ -200,6 +207,7 @@ public OperatorFactory duplicate()
operatorId,
planNodeId,
tableFunctionProvider,
catalogHandle,
functionHandle,
properChannelsCount,
passThroughSourcesCount,
Expand All @@ -219,14 +227,15 @@ public OperatorFactory duplicate()
}

private final OperatorContext operatorContext;

private final ConnectorSession session;
private final PageBuffer pageBuffer = new PageBuffer();
private final WorkProcessor<Page> outputPages;
private final boolean processEmptyInput;

public TableFunctionOperator(
OperatorContext operatorContext,
TableFunctionProcessorProvider tableFunctionProvider,
CatalogHandle catalogHandle,
ConnectorTableFunctionHandle functionHandle,
int properChannelsCount,
int passThroughSourcesCount,
Expand All @@ -245,6 +254,7 @@ public TableFunctionOperator(
{
requireNonNull(operatorContext, "operatorContext is null");
requireNonNull(tableFunctionProvider, "tableFunctionProvider is null");
requireNonNull(catalogHandle, "catalogHandle is null");
requireNonNull(functionHandle, "functionHandle is null");
requireNonNull(requiredChannels, "requiredChannels is null");
requireNonNull(markerChannels, "markerChannels is null");
Expand All @@ -261,7 +271,7 @@ public TableFunctionOperator(
requireNonNull(pagesIndexFactory, "pagesIndexFactory is null");

this.operatorContext = operatorContext;

this.session = operatorContext.getSession().toConnectorSession(catalogHandle);
this.processEmptyInput = !pruneWhenEmpty;

PagesIndex pagesIndex = pagesIndexFactory.newPagesIndex(sourceTypes, expectedPositions);
Expand All @@ -273,6 +283,7 @@ public TableFunctionOperator(
groupPagesIndex,
hashStrategies,
tableFunctionProvider,
session,
functionHandle,
properChannelsCount,
passThroughSourcesCount,
Expand Down Expand Up @@ -517,6 +528,7 @@ private WorkProcessor<TableFunctionPartition> pagesIndexToTableFunctionPartition
PagesIndex pagesIndex,
HashStrategies hashStrategies,
TableFunctionProcessorProvider tableFunctionProvider,
ConnectorSession session,
ConnectorTableFunctionHandle functionHandle,
int properChannelsCount,
int passThroughSourcesCount,
Expand All @@ -542,7 +554,7 @@ public WorkProcessor.ProcessState<TableFunctionPartition> process()
// empty PagesIndex can only be passed once as the result of PartitionAndSort. Neither this nor any future instance of Process will ever get an empty PagesIndex again.
processEmpty = false;
return WorkProcessor.ProcessState.ofResult(new EmptyTableFunctionPartition(
tableFunctionProvider.getDataProcessor(functionHandle),
tableFunctionProvider.getDataProcessor(session, functionHandle),
properChannelsCount,
passThroughSourcesCount,
passThroughSpecifications.stream()
Expand All @@ -562,7 +574,7 @@ public WorkProcessor.ProcessState<TableFunctionPartition> process()
pagesIndex,
partitionStart,
partitionEnd,
tableFunctionProvider.getDataProcessor(functionHandle),
tableFunctionProvider.getDataProcessor(session, functionHandle),
properChannelsCount,
passThroughSourcesCount,
requiredChannels,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ public static TableFunctionProcessorProvider getExcludeColumnsFunctionProcessorP
return new TableFunctionProcessorProvider()
{
@Override
public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle)
public TableFunctionDataProcessor getDataProcessor(ConnectorSession session, ConnectorTableFunctionHandle handle)
{
return input -> {
if (input == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1723,6 +1723,7 @@ public PhysicalOperation visitTableFunctionProcessor(TableFunctionProcessorNode
context.getNextOperatorId(),
node.getId(),
processorProvider,
node.getHandle().getCatalogHandle(),
node.getHandle().getFunctionHandle(),
properChannelsCount,
toIntExact(passThroughSourcesCount),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ public static class IdentityFunctionProcessorProvider
implements TableFunctionProcessorProvider
{
@Override
public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle)
public TableFunctionDataProcessor getDataProcessor(ConnectorSession session, ConnectorTableFunctionHandle handle)
{
return input -> {
if (input == null) {
Expand Down Expand Up @@ -659,7 +659,7 @@ public static class IdentityPassThroughFunctionProcessorProvider
implements TableFunctionProcessorProvider
{
@Override
public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle)
public TableFunctionDataProcessor getDataProcessor(ConnectorSession session, ConnectorTableFunctionHandle handle)
{
return new IdentityPassThroughFunctionProcessor();
}
Expand Down Expand Up @@ -750,7 +750,7 @@ public static class RepeatFunctionProcessorProvider
implements TableFunctionProcessorProvider
{
@Override
public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle)
public TableFunctionDataProcessor getDataProcessor(ConnectorSession session, ConnectorTableFunctionHandle handle)
{
return new RepeatFunctionProcessor(((RepeatFunctionHandle) handle).getCount());
}
Expand Down Expand Up @@ -848,7 +848,7 @@ public static class EmptyOutputProcessorProvider
implements TableFunctionProcessorProvider
{
@Override
public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle)
public TableFunctionDataProcessor getDataProcessor(ConnectorSession session, ConnectorTableFunctionHandle handle)
{
return new EmptyOutputProcessor();
}
Expand Down Expand Up @@ -906,7 +906,7 @@ public static class EmptyOutputWithPassThroughProcessorProvider
implements TableFunctionProcessorProvider
{
@Override
public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle)
public TableFunctionDataProcessor getDataProcessor(ConnectorSession session, ConnectorTableFunctionHandle handle)
{
return new EmptyOutputWithPassThroughProcessor();
}
Expand Down Expand Up @@ -982,7 +982,7 @@ public static class TestInputsFunctionProcessorProvider
implements TableFunctionProcessorProvider
{
@Override
public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle)
public TableFunctionDataProcessor getDataProcessor(ConnectorSession session, ConnectorTableFunctionHandle handle)
{
BlockBuilder resultBuilder = BOOLEAN.createBlockBuilder(null, 1);
BOOLEAN.writeBoolean(resultBuilder, true);
Expand Down Expand Up @@ -1043,7 +1043,7 @@ public static class PassThroughInputProcessorProvider
implements TableFunctionProcessorProvider
{
@Override
public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle)
public TableFunctionDataProcessor getDataProcessor(ConnectorSession session, ConnectorTableFunctionHandle handle)
{
return new PassThroughInputProcessor();
}
Expand Down Expand Up @@ -1142,7 +1142,7 @@ public static class TestInputProcessorProvider
implements TableFunctionProcessorProvider
{
@Override
public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle)
public TableFunctionDataProcessor getDataProcessor(ConnectorSession session, ConnectorTableFunctionHandle handle)
{
return new TestInputProcessor();
}
Expand Down Expand Up @@ -1206,7 +1206,7 @@ public static class TestSingleInputFunctionProcessorProvider
implements TableFunctionProcessorProvider
{
@Override
public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle)
public TableFunctionDataProcessor getDataProcessor(ConnectorSession session, ConnectorTableFunctionHandle handle)
{
BlockBuilder builder = BOOLEAN.createBlockBuilder(null, 1);
BOOLEAN.writeBoolean(builder, true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public interface TableFunctionProcessorProvider
* This method returns a {@code TableFunctionDataProcessor}. All the necessary information collected during analysis is available
* in the form of {@link ConnectorTableFunctionHandle}. It is called once per each partition processed by the table function.
*/
default TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle)
default TableFunctionDataProcessor getDataProcessor(ConnectorSession session, ConnectorTableFunctionHandle handle)
{
throw new UnsupportedOperationException("this table function does not process input data");
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.base.classloader;

import io.trino.spi.Page;
import io.trino.spi.classloader.ThreadContextClassLoader;
import io.trino.spi.function.table.TableFunctionDataProcessor;
import io.trino.spi.function.table.TableFunctionProcessorState;

import java.util.List;
import java.util.Optional;

public class ClassLoaderSafeTableFunctionDataProcessor
implements TableFunctionDataProcessor
{
private final TableFunctionDataProcessor delegate;
private final ClassLoader classLoader;

public ClassLoaderSafeTableFunctionDataProcessor(TableFunctionDataProcessor delegate, ClassLoader classLoader)
{
this.delegate = delegate;
this.classLoader = classLoader;
}

@Override
public TableFunctionProcessorState process(List<Optional<Page>> input)
{
try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) {
return delegate.process(input);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ public ClassLoaderSafeTableFunctionProcessorProvider(TableFunctionProcessorProvi
}

@Override
public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle)
public TableFunctionDataProcessor getDataProcessor(ConnectorSession session, ConnectorTableFunctionHandle handle)
{
try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) {
return delegate.getDataProcessor(handle);
return delegate.getDataProcessor(session, handle);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import io.trino.spi.connector.SystemTable;
import io.trino.spi.eventlistener.EventListener;
import io.trino.spi.function.table.ConnectorTableFunction;
import io.trino.spi.function.table.TableFunctionDataProcessor;
import io.trino.spi.function.table.TableFunctionProcessorProvider;
import io.trino.spi.function.table.TableFunctionSplitProcessor;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -61,6 +62,7 @@ public void test()
testClassLoaderSafe(EventListener.class, ClassLoaderSafeEventListener.class);
testClassLoaderSafe(ConnectorTableFunction.class, ClassLoaderSafeConnectorTableFunction.class);
testClassLoaderSafe(TableFunctionSplitProcessor.class, ClassLoaderSafeTableFunctionSplitProcessor.class);
testClassLoaderSafe(TableFunctionDataProcessor.class, ClassLoaderSafeTableFunctionDataProcessor.class);
testClassLoaderSafe(TableFunctionProcessorProvider.class, ClassLoaderSafeTableFunctionProcessorProvider.class);
}

Expand Down

0 comments on commit 49f56b5

Please sign in to comment.