Skip to content

Commit

Permalink
Support proxy for plaintext HTTP/2 clients with prior-knowledge (#2716)
Browse files Browse the repository at this point in the history
Motivation:

Plaintext proxies operate as forwarding message proxies. They expect
clients to send requests with absolute-form request-target to determine
where to forward the message. This logic does not depend on the protocol
version, it already works for HTTP/1.1 and should work for HTTP/2 the
same way.

Modifications:

- `DefaultSingleAddressHttpClientBuilder`: remove check that does not
allow users to use HTTP/2 prior-knowledge with a proxy over plaintext
connections;
- Enhance `HttpProxyTest` to validate behavior consistency for both
protocols;

Result:

Users can use plaintext HTTP/2 prior-knowledge clients to communicate
via a forwarding message proxy.
  • Loading branch information
idelpivnitskiy authored Oct 2, 2023
1 parent 61b65b8 commit 6d38417
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,6 @@ public HttpExecutionStrategy executionStrategy() {
}
};
final SslContext sslContext = roConfig.tcpConfig().sslContext();
if (roConfig.hasProxy() && sslContext == null && roConfig.h2Config() != null) {
throw new IllegalStateException("Proxying is not yet supported with plaintext HTTP/2");
}

// Track resources that potentially need to be closed when an exception is thrown during buildStreaming
final CompositeCloseable closeOnException = newCompositeCloseable();
Expand Down Expand Up @@ -266,7 +263,9 @@ public HttpExecutionStrategy executionStrategy() {
ctx.builder.addIdleTimeoutConnectionFilter ?
appendConnectionFilter(ctx.builder.connectionFilterFactory, DEFAULT_IDLE_TIMEOUT_FILTER) :
ctx.builder.connectionFilterFactory;
if (!roConfig.hasProxy() && roConfig.isH2PriorKnowledge()) {
if (roConfig.isH2PriorKnowledge() &&
// Direct connection or HTTP proxy
(!roConfig.hasProxy() || sslContext == null)) {
H2ProtocolConfig h2Config = roConfig.h2Config();
assert h2Config != null;
connectionFactory = new H2LBHttpConnectionFactory<>(roConfig, executionContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,21 @@

import io.servicetalk.http.api.BlockingHttpClient;
import io.servicetalk.http.api.HttpClient;
import io.servicetalk.http.api.HttpProtocolVersion;
import io.servicetalk.http.api.HttpResponse;
import io.servicetalk.http.api.SingleAddressHttpClientBuilder;
import io.servicetalk.http.netty.HttpsProxyTest.TargetAddressCheckConnectionFactoryFilter;
import io.servicetalk.transport.api.HostAndPort;
import io.servicetalk.transport.api.ServerContext;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource;

import java.net.InetSocketAddress;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
Expand All @@ -39,10 +41,13 @@
import static io.servicetalk.http.api.HttpHeaderNames.HOST;
import static io.servicetalk.http.api.HttpResponseStatus.OK;
import static io.servicetalk.http.api.HttpSerializers.textSerializerUtf8;
import static io.servicetalk.http.netty.HttpProtocol.HTTP_1;
import static io.servicetalk.http.netty.HttpProtocol.HTTP_2;
import static io.servicetalk.http.netty.HttpsProxyTest.safeClose;
import static io.servicetalk.transport.netty.internal.AddressUtils.localAddress;
import static io.servicetalk.transport.netty.internal.AddressUtils.serverHostAndPort;
import static java.nio.charset.StandardCharsets.US_ASCII;
import static java.util.Arrays.asList;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
Expand All @@ -62,10 +67,9 @@ class HttpProxyTest {
private final AtomicInteger proxyRequestCount = new AtomicInteger();
private final AtomicReference<Object> targetAddress = new AtomicReference<>();

@BeforeEach
void setup() throws Exception {
startProxy();
startServer();
private void setUp(HttpProtocol clientProtocol, HttpProtocol serverProtocol) throws Exception {
startProxy(clientProtocol, serverProtocol);
startServer(serverProtocol);
}

@AfterEach
Expand All @@ -75,79 +79,104 @@ void tearDown() {
safeClose(serverContext);
}

void startProxy() throws Exception {
proxyClient = HttpClients.forMultiAddressUrl(getClass().getSimpleName()).build();
private void startProxy(HttpProtocol clientProtocol, HttpProtocol serverProtocol) throws Exception {
proxyClient = HttpClients.forMultiAddressUrl(getClass().getSimpleName())
.initializer((scheme, address, builder) -> builder.protocols(serverProtocol.config))
.build();
proxyContext = HttpServers.forAddress(localAddress(0))
.protocols(clientProtocol.config)
.listenAndAwait((ctx, request, responseFactory) -> {
proxyRequestCount.incrementAndGet();
return proxyClient.request(request);
return proxyClient.request(request.version(serverProtocol.version))
.map(response -> response.version(clientProtocol.version));
});
proxyAddress = serverHostAndPort(proxyContext);
}

void startServer() throws Exception {
private void startServer(HttpProtocol protocol) throws Exception {
serverContext = HttpServers.forAddress(localAddress(0))
.protocols(protocol.config)
.listenAndAwait((ctx, request, responseFactory) -> succeeded(responseFactory.ok()
.payloadBody("host: " + request.headers().get(HOST), textSerializerUtf8())));
serverAddress = serverHostAndPort(serverContext);
}

private enum ClientSource {
SINGLE(HttpClients::forSingleAddress),
RESOLVED(HttpClients::forResolvedAddress);
private static List<Arguments> protocols() {
return asList(Arguments.of(HTTP_1, HTTP_1), Arguments.of(HTTP_2, HTTP_2),
Arguments.of(HTTP_1, HTTP_2), Arguments.of(HTTP_2, HTTP_1));
}

private final Function<HostAndPort, SingleAddressHttpClientBuilder<HostAndPort, InetSocketAddress>>
clientBuilderFactory;
@ParameterizedTest(name = "[{index}] clientProtocol={0} serverProtocol={1}")
@MethodSource("protocols")
void testRequestForSingleAddress(HttpProtocol clientProtocol, HttpProtocol serverProtocol) throws Exception {
testRequest(clientProtocol, serverProtocol, HttpClients::forSingleAddress);
}

ClientSource(Function<HostAndPort, SingleAddressHttpClientBuilder<HostAndPort, InetSocketAddress>>
clientBuilderFactory) {
this.clientBuilderFactory = clientBuilderFactory;
}
@ParameterizedTest(name = "[{index}] clientProtocol={0} serverProtocol={1}")
@MethodSource("protocols")
void testRequestForResolvedAddress(HttpProtocol clientProtocol, HttpProtocol serverProtocol) throws Exception {
testRequest(clientProtocol, serverProtocol, HttpClients::forResolvedAddress);
}

@ParameterizedTest(name = "[{index}] client = {0}")
@EnumSource
void testRequest(ClientSource clientSource) throws Exception {
private void testRequest(
HttpProtocol clientProtocol, HttpProtocol serverProtocol,
Function<HostAndPort, SingleAddressHttpClientBuilder<HostAndPort, InetSocketAddress>> clientBuilderFactory)
throws Exception {
setUp(clientProtocol, serverProtocol);
assert serverAddress != null && proxyAddress != null;

final BlockingHttpClient client = clientSource.clientBuilderFactory.apply(serverAddress)
try (BlockingHttpClient client = clientBuilderFactory.apply(serverAddress)
.proxyAddress(proxyAddress)
.protocols(clientProtocol.config)
.appendConnectionFactoryFilter(new TargetAddressCheckConnectionFactoryFilter(targetAddress, false))
.buildBlocking();
.buildBlocking()) {

final HttpResponse httpResponse = client.request(client.get("/path"));
assertThat(httpResponse.status(), is(OK));
assertThat(proxyRequestCount.get(), is(1));
assertThat(httpResponse.payloadBody().toString(US_ASCII), is("host: " + serverAddress));
assertThat(targetAddress.get(), is(equalTo(serverAddress.toString())));
safeClose(client);
assertResponse(client.request(client.get("/path")), clientProtocol.version);
}
}

@Test
void testBuilderReuseEachClientUsesOwnProxy() throws Exception {
@ParameterizedTest(name = "[{index}] protocol={0}")
@EnumSource(HttpProtocol.class)
void testBuilderReuseEachClientUsesOwnProxy(HttpProtocol protocol) throws Exception {
setUp(protocol, protocol);
assert serverAddress != null && proxyAddress != null;

final SingleAddressHttpClientBuilder<HostAndPort, InetSocketAddress> builder =
HttpClients.forSingleAddress(serverAddress);
final BlockingHttpClient client = builder.proxyAddress(proxyAddress).buildBlocking();
HttpClients.forSingleAddress(serverAddress)
.protocols(protocol.config);

final HttpClient otherProxyClient = HttpClients.forMultiAddressUrl(getClass().getSimpleName()).build();
final AtomicInteger otherProxyRequestCount = new AtomicInteger();
try (ServerContext otherProxyContext = HttpServers.forAddress(localAddress(0))
try (BlockingHttpClient client = builder.proxyAddress(proxyAddress).buildBlocking();
HttpClient otherProxyClient = HttpClients.forMultiAddressUrl(getClass().getSimpleName())
.initializer((scheme, address, builder1) -> builder1.protocols(protocol.config))
.build();
ServerContext otherProxyContext = HttpServers.forAddress(localAddress(0))
.protocols(protocol.config)
.listenAndAwait((ctx, request, responseFactory) -> {
otherProxyRequestCount.incrementAndGet();
return otherProxyClient.request(request);
});
BlockingHttpClient otherClient = builder.proxyAddress(serverHostAndPort(otherProxyContext))
.protocols(protocol.config)
.appendConnectionFactoryFilter(new TargetAddressCheckConnectionFactoryFilter(targetAddress, false))
.buildBlocking()) {

final HttpResponse httpResponse = otherClient.request(client.get("/path"));
assertThat(httpResponse.status(), is(OK));
assertResponse(otherClient.request(client.get("/path")), protocol.version, otherProxyRequestCount);
assertThat(proxyRequestCount.get(), is(0));
assertResponse(client.request(client.get("/path")), protocol.version);
assertThat(otherProxyRequestCount.get(), is(1));
assertThat(httpResponse.payloadBody().toString(US_ASCII), is("host: " + serverAddress));
}
}

private void assertResponse(HttpResponse httpResponse, HttpProtocolVersion expectedVersion) {
assertResponse(httpResponse, expectedVersion, proxyRequestCount);
}

final HttpResponse httpResponse = client.request(client.get("/path"));
private void assertResponse(HttpResponse httpResponse, HttpProtocolVersion expectedVersion,
AtomicInteger proxyRequestCount) {
assert serverAddress != null;
assertThat(httpResponse.status(), is(OK));
assertThat(httpResponse.version(), is(expectedVersion));
assertThat(proxyRequestCount.get(), is(1));
assertThat(httpResponse.payloadBody().toString(US_ASCII), is("host: " + serverAddress));
assertThat(targetAddress.get(), is(equalTo(serverAddress.toString())));
Expand Down

0 comments on commit 6d38417

Please sign in to comment.