Skip to content

Commit

Permalink
Extract TestingDirectTrinoClient
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasz-stec authored and kokosing committed Oct 11, 2024
1 parent ab88df1 commit 7d896cc
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 108 deletions.
114 changes: 10 additions & 104 deletions core/trino-main/src/main/java/io/trino/testing/DirectTrinoClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,56 +17,41 @@
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.slice.Slice;
import io.opentelemetry.api.trace.Span;
import io.trino.Session;
import io.trino.dispatcher.DispatchManager;
import io.trino.dispatcher.DispatchQuery;
import io.trino.exchange.DirectExchangeInput;
import io.trino.execution.QueryInfo;
import io.trino.execution.QueryManager;
import io.trino.execution.QueryState;
import io.trino.execution.buffer.PageDeserializer;
import io.trino.execution.buffer.PagesSerdeFactory;
import io.trino.memory.context.SimpleLocalMemoryContext;
import io.trino.operator.DirectExchangeClient;
import io.trino.operator.DirectExchangeClientSupplier;
import io.trino.server.ResultQueryInfo;
import io.trino.server.SessionContext;
import io.trino.server.protocol.ProtocolUtil;
import io.trino.server.protocol.Slug;
import io.trino.spi.Page;
import io.trino.spi.QueryId;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockEncodingSerde;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.exchange.ExchangeId;
import io.trino.spi.type.Type;
import org.intellij.lang.annotations.Language;

import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicReference;

import static com.google.common.base.Preconditions.checkArgument;
import static io.airlift.concurrent.MoreFutures.whenAnyComplete;
import static io.trino.SystemSessionProperties.getRetryPolicy;
import static io.trino.execution.QueryState.FAILED;
import static io.trino.execution.QueryState.FINISHED;
import static io.trino.execution.QueryState.FINISHING;
import static io.trino.execution.buffer.CompressionCodec.NONE;
import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext;
import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.testing.MaterializedResult.DEFAULT_PRECISION;
import static io.trino.util.MoreLists.mappedCopy;
import static java.util.Objects.requireNonNull;

class DirectTrinoClient
public class DirectTrinoClient
{
private final DispatchManager dispatchManager;
private final QueryManager queryManager;
Expand All @@ -81,32 +66,23 @@ public DirectTrinoClient(DispatchManager dispatchManager, QueryManager queryMana
this.blockEncodingSerde = requireNonNull(blockEncodingSerde, "blockEncodingSerde is null");
}

public Result execute(Session session, @Language("SQL") String sql)
{
return execute(SessionContext.fromSession(session), sql);
}

public Result execute(SessionContext sessionContext, @Language("SQL") String sql)
public DispatchQuery execute(SessionContext sessionContext, @Language("SQL") String sql, QueryResultsListener queryResultsListener)
{
// create the query and wait for it to be dispatched
QueryId queryId = dispatchManager.createQueryId();
getQueryFuture(dispatchManager.createQuery(queryId, Span.getInvalid(), Slug.createNew(), sessionContext, sql));
getQueryFuture(dispatchManager.waitForDispatched(queryId));
DispatchQuery dispatchQuery = dispatchManager.getQuery(queryId);
if (dispatchQuery.getState().isDone()) {
return new Result(queryId, toMaterializedRows(dispatchQuery, ImmutableList.of(), ImmutableList.of(), ImmutableList.of()));
return dispatchQuery;
}

// read all output data
AtomicReference<List<String>> columnNames = new AtomicReference<>();
AtomicReference<List<Type>> columnTypes = new AtomicReference<>();
List<Page> pages = new ArrayList<>();
try (DirectExchangeClient exchangeClient = createExchangeClient(dispatchQuery)) {
queryManager.setOutputInfoListener(queryId, outputInfo -> {
// the listener is executed concurrently, so the call back must be synchronized to avoid a race between adding locations and setting no more locations
synchronized (this) {
columnNames.compareAndSet(null, outputInfo.getColumnNames());
columnTypes.compareAndSet(null, outputInfo.getColumnTypes());
queryResultsListener.setOutputColumns(outputInfo.getColumnNames(), outputInfo.getColumnTypes());

outputInfo.drainInputs(input -> {
DirectExchangeInput exchangeInput = (DirectExchangeInput) input;
Expand All @@ -126,7 +102,7 @@ public Result execute(SessionContext sessionContext, @Language("SQL") String sql
state = queryManager.getQueryState(queryId)) {
for (Slice serializedPage = exchangeClient.pollPage(); serializedPage != null; serializedPage = exchangeClient.pollPage()) {
Page page = pageDeserializer.deserialize(serializedPage);
pages.add(page);
queryResultsListener.consumeOutputPage(page);
}
getQueryFuture(whenAnyComplete(ImmutableList.of(queryManager.getStateChange(queryId, state), exchangeClient.isBlocked())));
}
Expand All @@ -138,7 +114,7 @@ public Result execute(SessionContext sessionContext, @Language("SQL") String sql
getQueryFuture(queryManager.getStateChange(queryId, queryState));
}

return new Result(queryId, toMaterializedRows(dispatchQuery, columnTypes.get(), columnNames.get(), pages));
return dispatchQuery;
}

private DirectExchangeClient createExchangeClient(DispatchQuery dispatchQuery)
Expand All @@ -152,74 +128,6 @@ private DirectExchangeClient createExchangeClient(DispatchQuery dispatchQuery)
getRetryPolicy(dispatchQuery.getSession()));
}

private static MaterializedResult toMaterializedRows(DispatchQuery dispatchQuery, List<Type> columnTypes, List<String> columnNames, List<Page> pages)
{
QueryInfo queryInfo = dispatchQuery.getFullQueryInfo();
ConnectorSession session = dispatchQuery.getSession().toConnectorSession();

if (queryInfo.getState() != FINISHED) {
if (queryInfo.getFailureInfo() == null) {
throw new QueryFailedException(queryInfo.getQueryId(), "Query failed without failure info");
}
RuntimeException remoteException = queryInfo.getFailureInfo().toException();
throw new QueryFailedException(queryInfo.getQueryId(), Optional.ofNullable(remoteException.getMessage()).orElseGet(remoteException::toString), remoteException);
}
if (pages.isEmpty() && columnTypes == null) {
// the query did not produce any output
return new MaterializedResult(
ImmutableList.of(),
ImmutableList.of(),
ImmutableList.of(),
queryInfo.getSetSessionProperties(),
queryInfo.getResetSessionProperties(),
Optional.ofNullable(queryInfo.getUpdateType()),
OptionalLong.empty(),
mappedCopy(queryInfo.getWarnings(), ProtocolUtil::toClientWarning),
Optional.of(ProtocolUtil.toStatementStats(new ResultQueryInfo(queryInfo))));
}

List<MaterializedRow> materializedRows = toMaterializedRows(session, columnTypes, pages);

OptionalLong updateCount = OptionalLong.empty();
if (queryInfo.getUpdateType() != null && materializedRows.size() == 1 && columnTypes.size() == 1 && columnTypes.get(0).equals(BIGINT)) {
Number value = (Number) materializedRows.get(0).getField(0);
if (value != null) {
updateCount = OptionalLong.of(value.longValue());
}
}

return new MaterializedResult(
materializedRows,
columnTypes,
columnNames,
queryInfo.getSetSessionProperties(),
queryInfo.getResetSessionProperties(),
Optional.ofNullable(queryInfo.getUpdateType()),
updateCount,
mappedCopy(queryInfo.getWarnings(), ProtocolUtil::toClientWarning),
Optional.of(ProtocolUtil.toStatementStats(new ResultQueryInfo(queryInfo))));
}

private static List<MaterializedRow> toMaterializedRows(ConnectorSession session, List<Type> types, List<Page> pages)
{
ImmutableList.Builder<MaterializedRow> rows = ImmutableList.builder();
for (Page page : pages) {
checkArgument(page.getChannelCount() == types.size(), "Expected a page with %s columns, but got %s columns", types.size(), page.getChannelCount());
for (int position = 0; position < page.getPositionCount(); position++) {
List<Object> values = new ArrayList<>(page.getChannelCount());
for (int channel = 0; channel < page.getChannelCount(); channel++) {
Type type = types.get(channel);
Block block = page.getBlock(channel);
values.add(type.getObjectValue(session, block, position));
}
values = Collections.unmodifiableList(values);

rows.add(new MaterializedRow(DEFAULT_PRECISION, values));
}
}
return rows.build();
}

private static <T> void getQueryFuture(ListenableFuture<T> future)
{
try {
Expand All @@ -234,12 +142,10 @@ private static <T> void getQueryFuture(ListenableFuture<T> future)
}
}

record Result(QueryId queryId, MaterializedResult result)
public interface QueryResultsListener
{
Result
{
requireNonNull(queryId, "queryId is null");
requireNonNull(result, "result is null");
}
void setOutputColumns(List<String> columnNames, List<Type> columnTypes);

void consumeOutputPage(Page page);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public final class StandaloneQueryRunner
{
private final Session defaultSession;
private final TestingTrinoServer server;
private final DirectTrinoClient trinoClient;
private final TestingDirectTrinoClient trinoClient;
private final InMemorySpanExporter spanExporter = InMemorySpanExporter.create();

private final ReadWriteLock lock = new ReentrantReadWriteLock();
Expand All @@ -88,7 +88,7 @@ public StandaloneQueryRunner(Session defaultSession, Consumer<TestingTrinoServer
serverProcessor.accept(builder);
this.server = builder.build();

this.trinoClient = new DirectTrinoClient(
this.trinoClient = new TestingDirectTrinoClient(
server.getDispatchManager(),
server.getQueryManager(),
server.getInstance(Key.get(DirectExchangeClientSupplier.class)),
Expand All @@ -111,11 +111,11 @@ public MaterializedResult execute(Session session, @Language("SQL") String sql)
@Override
public MaterializedResultWithPlan executeWithPlan(Session session, String sql)
{
DirectTrinoClient.Result result = executeInternal(session, sql);
TestingDirectTrinoClient.Result result = executeInternal(session, sql);
return new MaterializedResultWithPlan(result.queryId(), server.getQueryPlan(result.queryId()), result.result());
}

private DirectTrinoClient.Result executeInternal(Session session, @Language("SQL") String sql)
private TestingDirectTrinoClient.Result executeInternal(Session session, @Language("SQL") String sql)
{
lock.readLock().lock();
try {
Expand Down
Loading

0 comments on commit 7d896cc

Please sign in to comment.