Skip to content

Commit

Permalink
HTTPCLIENT-2350 - Refactored the connect method in DefaultHttpClientC…
Browse files Browse the repository at this point in the history
…onnectionOperator to enhance flexibility in address resolution, specifically allowing for direct handling of unresolved addresses. Updated DnsResolver to introduce a new resolve method supporting both standard and bypassed DNS lookups, enabling improved support for non-public resolvable hosts like .onion endpoints via SOCKS proxy. Adjusted related tests to align with the new resolution mechanism.
  • Loading branch information
arturobernalg committed Nov 15, 2024
1 parent a9da185 commit 71bf32b
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@
package org.apache.hc.client5.http;

import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.UnknownHostException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

import org.apache.hc.core5.annotation.Contract;
import org.apache.hc.core5.annotation.ThreadingBehavior;
Expand Down Expand Up @@ -61,4 +67,21 @@ public interface DnsResolver {
*/
String resolveCanonicalHostname(String host) throws UnknownHostException;

/**
* Returns a list of {@link SocketAddress} for the given host with the given port.
*
* @see SocketAddress
*
* @since 5.5
*/
default List<SocketAddress> resolve(String host, int port) throws UnknownHostException {
final InetAddress[] inetAddresses = resolve(host);
if (inetAddresses == null) {
return Collections.singletonList(InetSocketAddress.createUnresolved(host, port));
}
return Arrays.stream(inetAddresses)
.map(e -> new InetSocketAddress(e, port))
.collect(Collectors.toList());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,12 @@
package org.apache.hc.client5.http.impl.io;

import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.net.Socket;
import java.net.SocketAddress;
import java.net.UnknownHostException;
import java.util.Arrays;
import java.util.function.Function;
import java.util.Collections;
import java.util.List;

import javax.net.ssl.SSLSocket;

Expand Down Expand Up @@ -155,6 +153,7 @@ public void connect(
final SocketConfig socketConfig,
final Object attachment,
final HttpContext context) throws IOException {

Args.notNull(conn, "Connection");
Args.notNull(endpointHost, "Host");
Args.notNull(socketConfig, "Socket config");
Expand All @@ -163,41 +162,28 @@ public void connect(
final Timeout soTimeout = socketConfig.getSoTimeout();
final SocketAddress socksProxyAddress = socketConfig.getSocksProxyAddress();
final Proxy socksProxy = socksProxyAddress != null ? new Proxy(Proxy.Type.SOCKS, socksProxyAddress) : null;
final int port = this.schemePortResolver.resolve(endpointHost.getSchemeName(), endpointHost);

final Function<HttpHost, InetSocketAddress[]> addressResolver = host -> {
if (host.getAddress() != null) {
return new InetSocketAddress[]{new InetSocketAddress(host.getAddress(), port)};
} else {
if (LOG.isDebugEnabled()) {
LOG.debug("{} resolving remote address", host.getHostName());
}
try {
final InetAddress[] remoteAddresses = this.dnsResolver.resolve(host.getHostName());
if (remoteAddresses == null || remoteAddresses.length == 0) {
throw new UnknownHostException(host.getHostName());
}
return Arrays.stream(remoteAddresses)
.map(address -> new InetSocketAddress(address, port))
.toArray(InetSocketAddress[]::new);
} catch (final UnknownHostException e) {
return new InetSocketAddress[]{InetSocketAddress.createUnresolved(host.getHostName(), port)};
}
}
};

final InetSocketAddress[] remoteAddresses = addressResolver.apply(endpointHost);

for (int i = 0; i < remoteAddresses.length; i++) {
final InetSocketAddress remoteAddress = remoteAddresses[i];
final boolean last = i == remoteAddresses.length - 1;

final List<SocketAddress> remoteAddresses;
if (endpointHost.getAddress() != null) {
remoteAddresses = Collections.singletonList(
new InetSocketAddress(endpointHost.getAddress(), this.schemePortResolver.resolve(endpointHost.getSchemeName(), endpointHost)));
} else {
final int port = this.schemePortResolver.resolve(endpointHost.getSchemeName(), endpointHost);
remoteAddresses = this.dnsResolver.resolve(endpointHost.getHostName(), port);
}
for (int i = 0; i < remoteAddresses.size(); i++) {
final InetSocketAddress remoteAddress = (InetSocketAddress) remoteAddresses.get(i);
final boolean last = i == remoteAddresses.size() - 1;
onBeforeSocketConnect(context, endpointHost);
if (LOG.isDebugEnabled()) {
LOG.debug("{} connecting {}->{} ({})", endpointHost, localAddress, remoteAddress, connectTimeout);
}
final Socket socket = detachedSocketFactory.create(socksProxy);
try {
// Always bind to the local address if it's provided.
if (localAddress != null) {
socket.bind(localAddress);
}
conn.bind(socket);
if (soTimeout != null) {
socket.setSoTimeout(soTimeout.toMillisecondsIntBound());
Expand All @@ -216,17 +202,14 @@ public void connect(
if (linger >= 0) {
socket.setSoLinger(true, linger);
}

if (localAddress != null) {
socket.bind(localAddress);
}
socket.connect(remoteAddress, TimeValue.isPositive(connectTimeout) ? connectTimeout.toMillisecondsIntBound() : 0);
conn.bind(socket);
onAfterSocketConnect(context, endpointHost);

if (LOG.isDebugEnabled()) {
LOG.debug("{} {} connected {}->{}", ConnPoolSupport.getId(conn), endpointHost,
conn.getLocalAddress(), conn.getRemoteAddress());
LOG.debug("{} {} connected {}->{}", ConnPoolSupport.getId(conn), endpointHost, conn.getLocalAddress(), conn.getRemoteAddress());
}

conn.setSocketTimeout(soTimeout);
final TlsSocketStrategy tlsSocketStrategy = tlsSocketStrategyLookup != null ? tlsSocketStrategyLookup.lookup(endpointHost.getSchemeName()) : null;
if (tlsSocketStrategy != null) {
Expand All @@ -243,6 +226,7 @@ public void connect(
}
}
return;

} catch (final RuntimeException ex) {
Closer.closeQuietly(socket);
throw ex;
Expand All @@ -261,6 +245,8 @@ public void connect(
}
}



@Override
public void upgrade(
final ManagedHttpClientConnection conn,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.util.Collections;
import java.util.concurrent.TimeUnit;

import javax.net.ssl.SSLSocket;
Expand Down Expand Up @@ -384,7 +385,8 @@ void testTargetConnect() throws Exception {
.build();
mgr.setTlsConfig(tlsConfig);

Mockito.when(dnsResolver.resolve("somehost")).thenReturn(new InetAddress[] {remote});
// Updated to use the new dnsResolver method that returns List<SocketAddress>
Mockito.when(dnsResolver.resolve("somehost", 8443)).thenReturn(Collections.singletonList(new InetSocketAddress(remote, 8443)));
Mockito.when(schemePortResolver.resolve(target.getSchemeName(), target)).thenReturn(8443);
Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket);

Expand All @@ -398,21 +400,22 @@ void testTargetConnect() throws Exception {

mgr.connect(endpoint1, null, context);

Mockito.verify(dnsResolver, Mockito.times(1)).resolve("somehost");
Mockito.verify(dnsResolver, Mockito.times(1)).resolve("somehost", 8443);
Mockito.verify(schemePortResolver, Mockito.times(1)).resolve(target.getSchemeName(), target);
Mockito.verify(detachedSocketFactory, Mockito.times(1)).create(null);
Mockito.verify(socket, Mockito.times(1)).connect(new InetSocketAddress(remote, 8443), 234);
Mockito.verify(tlsSocketStrategy).upgrade(socket, "somehost", 443, tlsConfig, context);

mgr.connect(endpoint1, TimeValue.ofMilliseconds(123), context);

Mockito.verify(dnsResolver, Mockito.times(2)).resolve("somehost");
Mockito.verify(dnsResolver, Mockito.times(2)).resolve("somehost", 8443);
Mockito.verify(schemePortResolver, Mockito.times(2)).resolve(target.getSchemeName(), target);
Mockito.verify(detachedSocketFactory, Mockito.times(2)).create(null);
Mockito.verify(socket, Mockito.times(1)).connect(new InetSocketAddress(remote, 8443), 123);
Mockito.verify(tlsSocketStrategy, Mockito.times(2)).upgrade(socket, "somehost", 443, tlsConfig, context);
}


@Test
void testProxyConnectAndUpgrade() throws Exception {
final HttpHost target = new HttpHost("https", "somehost", 443);
Expand Down Expand Up @@ -441,15 +444,16 @@ void testProxyConnectAndUpgrade() throws Exception {
.build();
mgr.setTlsConfig(tlsConfig);

Mockito.when(dnsResolver.resolve("someproxy")).thenReturn(new InetAddress[] {remote});
// Updated mock to return a List<SocketAddress>
Mockito.when(dnsResolver.resolve("someproxy", 8080)).thenReturn(Collections.singletonList(new InetSocketAddress(remote, 8080)));
Mockito.when(schemePortResolver.resolve(proxy.getSchemeName(), proxy)).thenReturn(8080);
Mockito.when(schemePortResolver.resolve(target.getSchemeName(), target)).thenReturn(8443);
Mockito.when(tlsSocketStrategyLookup.lookup("https")).thenReturn(tlsSocketStrategy);
Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket);

mgr.connect(endpoint1, null, context);

Mockito.verify(dnsResolver, Mockito.times(1)).resolve("someproxy");
Mockito.verify(dnsResolver, Mockito.times(1)).resolve("someproxy", 8080);
Mockito.verify(schemePortResolver, Mockito.times(1)).resolve(proxy.getSchemeName(), proxy);
Mockito.verify(detachedSocketFactory, Mockito.times(1)).create(null);
Mockito.verify(socket, Mockito.times(1)).connect(new InetSocketAddress(remote, 8080), 234);
Expand Down
Loading

0 comments on commit 71bf32b

Please sign in to comment.