Skip to content

Commit

Permalink
HTTPCLIENT-2350 - Refactored the connect method in DefaultHttpClientC…
Browse files Browse the repository at this point in the history
…onnectionOperator to enhance flexibility in address resolution, specifically allowing for direct handling of unresolved addresses. Updated DnsResolver to introduce a new resolve method supporting both standard and bypassed DNS lookups, enabling improved support for non-public resolvable hosts like .onion endpoints via SOCKS proxy. Adjusted related tests to align with the new resolution mechanism.
  • Loading branch information
arturobernalg committed Nov 15, 2024
1 parent 4b2a365 commit 7b678ce
Show file tree
Hide file tree
Showing 7 changed files with 358 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@
package org.apache.hc.client5.http;

import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.UnknownHostException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

import org.apache.hc.core5.annotation.Contract;
import org.apache.hc.core5.annotation.ThreadingBehavior;
Expand Down Expand Up @@ -61,4 +67,21 @@ public interface DnsResolver {
*/
String resolveCanonicalHostname(String host) throws UnknownHostException;

/**
* Returns a list of {@link SocketAddress} for the given host with the given port.
*
* @see SocketAddress
*
* @since 5.5
*/
default List<SocketAddress> resolve(String host, int port) throws UnknownHostException {
final InetAddress[] inetAddresses = resolve(host);
if (inetAddresses == null) {
return Collections.singletonList(InetSocketAddress.createUnresolved(host, port));
}
return Arrays.stream(inetAddresses)
.map(e -> new InetSocketAddress(e, port))
.collect(Collectors.toList());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,12 @@
package org.apache.hc.client5.http.impl.io;

import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.net.Socket;
import java.net.SocketAddress;
import java.net.UnknownHostException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import javax.net.ssl.SSLSocket;

Expand Down Expand Up @@ -154,43 +153,37 @@ public void connect(
final SocketConfig socketConfig,
final Object attachment,
final HttpContext context) throws IOException {

Args.notNull(conn, "Connection");
Args.notNull(endpointHost, "Host");
Args.notNull(socketConfig, "Socket config");
Args.notNull(context, "Context");
final InetAddress[] remoteAddresses;
if (endpointHost.getAddress() != null) {
remoteAddresses = new InetAddress[] { endpointHost.getAddress() };
} else {
if (LOG.isDebugEnabled()) {
LOG.debug("{} resolving remote address", endpointHost.getHostName());
}

remoteAddresses = this.dnsResolver.resolve(endpointHost.getHostName());

if (LOG.isDebugEnabled()) {
LOG.debug("{} resolved to {}", endpointHost.getHostName(), remoteAddresses == null ? "null" : Arrays.asList(remoteAddresses));
}

if (remoteAddresses == null || remoteAddresses.length == 0) {
throw new UnknownHostException(endpointHost.getHostName());
}
}

final Timeout soTimeout = socketConfig.getSoTimeout();
final SocketAddress socksProxyAddress = socketConfig.getSocksProxyAddress();
final Proxy socksProxy = socksProxyAddress != null ? new Proxy(Proxy.Type.SOCKS, socksProxyAddress) : null;
final int port = this.schemePortResolver.resolve(endpointHost.getSchemeName(), endpointHost);
for (int i = 0; i < remoteAddresses.length; i++) {
final InetAddress address = remoteAddresses[i];
final boolean last = i == remoteAddresses.length - 1;
final InetSocketAddress remoteAddress = new InetSocketAddress(address, port);

final List<SocketAddress> remoteAddresses;
if (endpointHost.getAddress() != null) {
remoteAddresses = Collections.singletonList(
new InetSocketAddress(endpointHost.getAddress(), this.schemePortResolver.resolve(endpointHost.getSchemeName(), endpointHost)));
} else {
final int port = this.schemePortResolver.resolve(endpointHost.getSchemeName(), endpointHost);
remoteAddresses = this.dnsResolver.resolve(endpointHost.getHostName(), port);
}
for (int i = 0; i < remoteAddresses.size(); i++) {
final InetSocketAddress remoteAddress = (InetSocketAddress) remoteAddresses.get(i);
final boolean last = i == remoteAddresses.size() - 1;
onBeforeSocketConnect(context, endpointHost);
if (LOG.isDebugEnabled()) {
LOG.debug("{} connecting {}->{} ({})", endpointHost, localAddress, remoteAddress, connectTimeout);
}
final Socket socket = detachedSocketFactory.create(socksProxy);
try {
// Always bind to the local address if it's provided.
if (localAddress != null) {
socket.bind(localAddress);
}
conn.bind(socket);
if (soTimeout != null) {
socket.setSoTimeout(soTimeout.toMillisecondsIntBound());
Expand All @@ -209,16 +202,11 @@ public void connect(
if (linger >= 0) {
socket.setSoLinger(true, linger);
}

if (localAddress != null) {
socket.bind(localAddress);
}
socket.connect(remoteAddress, TimeValue.isPositive(connectTimeout) ? connectTimeout.toMillisecondsIntBound() : 0);
conn.bind(socket);
onAfterSocketConnect(context, endpointHost);
if (LOG.isDebugEnabled()) {
LOG.debug("{} {} connected {}->{}", ConnPoolSupport.getId(conn), endpointHost,
conn.getLocalAddress(), conn.getRemoteAddress());
LOG.debug("{} {} connected {}->{}", ConnPoolSupport.getId(conn), endpointHost, conn.getLocalAddress(), conn.getRemoteAddress());
}
conn.setSocketTimeout(soTimeout);
final TlsSocketStrategy tlsSocketStrategy = tlsSocketStrategyLookup != null ? tlsSocketStrategyLookup.lookup(endpointHost.getSchemeName()) : null;
Expand All @@ -245,7 +233,7 @@ public void connect(
if (LOG.isDebugEnabled()) {
LOG.debug("{} connection to {} failed ({}); terminating operation", endpointHost, remoteAddress, ex.getClass());
}
throw ConnectExceptionSupport.enhance(ex, endpointHost, remoteAddresses);
throw ConnectExceptionSupport.enhance(ex, endpointHost);
}
if (LOG.isDebugEnabled()) {
LOG.debug("{} connection to {} failed ({}); retrying connection to the next address", endpointHost, remoteAddress, ex.getClass());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.UnknownHostException;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;

Expand Down Expand Up @@ -108,19 +108,19 @@ public void cancelled() {
LOG.debug("{} resolving remote address", remoteEndpoint.getHostName());
}

final InetAddress[] remoteAddresses;
final List<SocketAddress> remoteAddresses;
try {
remoteAddresses = dnsResolver.resolve(remoteEndpoint.getHostName());
if (remoteAddresses == null || remoteAddresses.length == 0) {
throw new UnknownHostException(remoteEndpoint.getHostName());
remoteAddresses = dnsResolver.resolve(remoteEndpoint.getHostName(), remoteEndpoint.getPort());
if (remoteAddresses == null || remoteAddresses.isEmpty()) {
throw new UnknownHostException(remoteEndpoint.getHostName());
}
} catch (final UnknownHostException ex) {
future.failed(ex);
return future;
}

if (LOG.isDebugEnabled()) {
LOG.debug("{} resolved to {}", remoteEndpoint.getHostName(), Arrays.asList(remoteAddresses));
LOG.debug("{} resolved to {}", remoteEndpoint.getHostName(), remoteAddresses);
}

final Runnable runnable = new Runnable() {
Expand All @@ -129,7 +129,7 @@ public void cancelled() {

void executeNext() {
final int index = attempt.getAndIncrement();
final InetSocketAddress remoteAddress = new InetSocketAddress(remoteAddresses[index], remoteEndpoint.getPort());
final InetSocketAddress remoteAddress = (InetSocketAddress) remoteAddresses.get(index);

if (LOG.isDebugEnabled()) {
LOG.debug("{}:{} connecting {}->{} ({})",
Expand All @@ -155,13 +155,17 @@ public void completed(final IOSession session) {

@Override
public void failed(final Exception cause) {
if (attempt.get() >= remoteAddresses.length) {
if (attempt.get() >= remoteAddresses.size()) {
if (LOG.isDebugEnabled()) {
LOG.debug("{}:{} connection to {} failed ({}); terminating operation",
remoteEndpoint.getHostName(), remoteEndpoint.getPort(), remoteAddress, cause.getClass());
}
if (cause instanceof IOException) {
future.failed(ConnectExceptionSupport.enhance((IOException) cause, remoteEndpoint, remoteAddresses));
final InetAddress[] addresses = remoteAddresses.stream()
.filter(addr -> addr instanceof InetSocketAddress)
.map(addr -> ((InetSocketAddress) addr).getAddress())
.toArray(InetAddress[]::new);
future.failed(ConnectExceptionSupport.enhance((IOException) cause, remoteEndpoint, addresses));
} else {
future.failed(cause);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.util.Collections;
import java.util.concurrent.TimeUnit;

import javax.net.ssl.SSLSocket;
Expand Down Expand Up @@ -384,7 +385,8 @@ void testTargetConnect() throws Exception {
.build();
mgr.setTlsConfig(tlsConfig);

Mockito.when(dnsResolver.resolve("somehost")).thenReturn(new InetAddress[] {remote});
// Updated to use the new dnsResolver method that returns List<SocketAddress>
Mockito.when(dnsResolver.resolve("somehost", 8443)).thenReturn(Collections.singletonList(new InetSocketAddress(remote, 8443)));
Mockito.when(schemePortResolver.resolve(target.getSchemeName(), target)).thenReturn(8443);
Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket);

Expand All @@ -398,21 +400,22 @@ void testTargetConnect() throws Exception {

mgr.connect(endpoint1, null, context);

Mockito.verify(dnsResolver, Mockito.times(1)).resolve("somehost");
Mockito.verify(dnsResolver, Mockito.times(1)).resolve("somehost", 8443);
Mockito.verify(schemePortResolver, Mockito.times(1)).resolve(target.getSchemeName(), target);
Mockito.verify(detachedSocketFactory, Mockito.times(1)).create(null);
Mockito.verify(socket, Mockito.times(1)).connect(new InetSocketAddress(remote, 8443), 234);
Mockito.verify(tlsSocketStrategy).upgrade(socket, "somehost", 443, tlsConfig, context);

mgr.connect(endpoint1, TimeValue.ofMilliseconds(123), context);

Mockito.verify(dnsResolver, Mockito.times(2)).resolve("somehost");
Mockito.verify(dnsResolver, Mockito.times(2)).resolve("somehost", 8443);
Mockito.verify(schemePortResolver, Mockito.times(2)).resolve(target.getSchemeName(), target);
Mockito.verify(detachedSocketFactory, Mockito.times(2)).create(null);
Mockito.verify(socket, Mockito.times(1)).connect(new InetSocketAddress(remote, 8443), 123);
Mockito.verify(tlsSocketStrategy, Mockito.times(2)).upgrade(socket, "somehost", 443, tlsConfig, context);
}


@Test
void testProxyConnectAndUpgrade() throws Exception {
final HttpHost target = new HttpHost("https", "somehost", 443);
Expand Down Expand Up @@ -441,15 +444,16 @@ void testProxyConnectAndUpgrade() throws Exception {
.build();
mgr.setTlsConfig(tlsConfig);

Mockito.when(dnsResolver.resolve("someproxy")).thenReturn(new InetAddress[] {remote});
// Updated mock to return a List<SocketAddress>
Mockito.when(dnsResolver.resolve("someproxy", 8080)).thenReturn(Collections.singletonList(new InetSocketAddress(remote, 8080)));
Mockito.when(schemePortResolver.resolve(proxy.getSchemeName(), proxy)).thenReturn(8080);
Mockito.when(schemePortResolver.resolve(target.getSchemeName(), target)).thenReturn(8443);
Mockito.when(tlsSocketStrategyLookup.lookup("https")).thenReturn(tlsSocketStrategy);
Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket);

mgr.connect(endpoint1, null, context);

Mockito.verify(dnsResolver, Mockito.times(1)).resolve("someproxy");
Mockito.verify(dnsResolver, Mockito.times(1)).resolve("someproxy", 8080);
Mockito.verify(schemePortResolver, Mockito.times(1)).resolve(proxy.getSchemeName(), proxy);
Mockito.verify(detachedSocketFactory, Mockito.times(1)).create(null);
Mockito.verify(socket, Mockito.times(1)).connect(new InetSocketAddress(remote, 8080), 234);
Expand Down
Loading

0 comments on commit 7b678ce

Please sign in to comment.