Skip to content

Commit

Permalink
Fix race condition in HttpProtocol while having multiple proxy settings
Browse files Browse the repository at this point in the history
In HttpProtocol implementation, the client builder was singleton and may
be accessed and modified by different threads at same time. The result
is that a wrong proxy will be used or a wrong proxy auth will be
configured.

To fix it, create a local builder insteand of having a field-level
builder.

Fixes #1247
  • Loading branch information
chhsiao90 committed Jul 15, 2024
1 parent ef0899e commit bf49afb
Showing 1 changed file with 78 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
Expand All @@ -43,8 +41,6 @@
import okhttp3.Connection;
import okhttp3.ConnectionPool;
import okhttp3.Credentials;
import okhttp3.EventListener;
import okhttp3.EventListener.Factory;
import okhttp3.Handshake;
import okhttp3.Headers;
import okhttp3.Interceptor;
Expand Down Expand Up @@ -87,6 +83,18 @@ public class HttpProtocol extends AbstractHttpProtocol {

private int completionTimeout = -1;

private int httpTimeout;

private boolean retryOnConnectionFailure;

private boolean followRedirects;

private boolean insecure;

private final List<okhttp3.Protocol> protocols = new ArrayList<>();

private ConnectionPool connectionPool = null;

/** Accept partially fetched content as trimmed content */
private boolean partialContentAsTrimmed = false;

Expand All @@ -95,8 +103,6 @@ public class HttpProtocol extends AbstractHttpProtocol {
// track the time spent for each URL in DNS resolution
private final Map<String, Long> DNStimes = new HashMap<>();

private OkHttpClient.Builder builder;

private static final TrustManager[] trustAllCerts =
new TrustManager[] {
new X509TrustManager() {
Expand Down Expand Up @@ -137,27 +143,21 @@ public void configure(Config conf) {

globalMaxContent = ConfUtils.getInt(conf, "http.content.limit", -1);

final int timeout = ConfUtils.getInt(conf, "http.timeout", 10000);
this.httpTimeout = ConfUtils.getInt(conf, "http.timeout", 10000);

this.completionTimeout =
ConfUtils.getInt(conf, "topology.message.timeout.secs", completionTimeout);

this.partialContentAsTrimmed =
ConfUtils.getBoolean(conf, "http.content.partial.as.trimmed", false);

builder =
new OkHttpClient.Builder()
.retryOnConnectionFailure(
ConfUtils.getBoolean(
conf, "http.retry.on.connection.failure", true))
.followRedirects(ConfUtils.getBoolean(conf, "http.allow.redirects", false))
.connectTimeout(timeout, TimeUnit.MILLISECONDS)
.writeTimeout(timeout, TimeUnit.MILLISECONDS)
.readTimeout(timeout, TimeUnit.MILLISECONDS);
this.retryOnConnectionFailure =
ConfUtils.getBoolean(conf, "http.retry.on.connection.failure", true);

this.followRedirects = ConfUtils.getBoolean(conf, "http.allow.redirects", false);

// protocols in order of preference, see
// https://square.github.io/okhttp/4.x/okhttp/okhttp3/-ok-http-client/-builder/protocols/
final List<okhttp3.Protocol> protocols = new ArrayList<>();
for (String pVersion : protocolVersions) {
switch (pVersion) {
case "h2":
Expand All @@ -181,10 +181,6 @@ public void configure(Config conf) {
break;
}
}
if (protocols.size() > 0) {
LOG.info("Using protocol versions: {}", protocols);
builder.protocols(protocols);
}

final String userAgent = getAgentString(conf);
if (StringUtils.isNotBlank(userAgent)) {
Expand Down Expand Up @@ -216,46 +212,23 @@ public void configure(Config conf) {

customHeaders.forEach(customRequestHeaders::add);

if (storeHTTPHeaders) {
builder.addNetworkInterceptor(new HTTPHeadersInterceptor());
}

if (ConfUtils.getBoolean(conf, "http.trust.everything", true)) {
builder.sslSocketFactory(trustAllSslSocketFactory, (X509TrustManager) trustAllCerts[0]);
builder.hostnameVerifier(
new HostnameVerifier() {
@Override
public boolean verify(String hostname, SSLSession session) {
return true;
}
});
}

builder.eventListenerFactory(
new Factory() {
@Override
public EventListener create(Call call) {
return new DNSResolutionListener(DNStimes);
}
});

// enable support for Brotli compression (Content-Encoding)
builder.addInterceptor(BrotliInterceptor.INSTANCE);
this.insecure = ConfUtils.getBoolean(conf, "http.trust.everything", true);

final Map<String, Object> connectionPoolConf =
(Map<String, Object>) conf.get("okhttp.protocol.connection.pool");
if (connectionPoolConf != null) {
final int size = ConfUtils.getInt(connectionPoolConf, "max.idle.connections", 5);
final int time = ConfUtils.getInt(connectionPoolConf, "connection.keep.alive", 300);
builder.connectionPool(new ConnectionPool(size, time, TimeUnit.SECONDS));
this.connectionPool = new ConnectionPool(size, time, TimeUnit.SECONDS);
LOG.info(
"Using connection pool with max. {} idle connections "
+ "and {} sec. connection keep-alive time",
size,
time);
}

client = builder.build();
// default client without proxy
client = createClient(null);
}

private void addCookiesToRequest(Builder rb, String url, Metadata md) {
Expand Down Expand Up @@ -292,46 +265,20 @@ public ProtocolResponse getProtocolOutput(String url, final Metadata metadata)
// conditionally add a dynamic proxy
if (proxyManager != null) {
// retrieve proxy from proxy manager
SCProxy prox = proxyManager.getProxy(metadata);

// conditionally configure proxy authentication
if (StringUtils.isNotBlank(prox.getAddress())) {
// format SCProxy into native Java proxy
Proxy proxy =
new Proxy(
Proxy.Type.valueOf(prox.getProtocol().toUpperCase(Locale.ROOT)),
new InetSocketAddress(
prox.getAddress(), Integer.parseInt(prox.getPort())));

// set proxy in builder
builder.proxy(proxy);

// conditionally add proxy authentication
if (StringUtils.isNotBlank(prox.getUsername())) {
// add proxy authentication header to builder
builder.proxyAuthenticator(
(Route route, Response response) -> {
String credential =
Credentials.basic(prox.getUsername(), prox.getPassword());
return response.request()
.newBuilder()
.header("Proxy-Authorization", credential)
.build();
});
}
SCProxy proxy = proxyManager.getProxy(metadata);
if (StringUtils.isNotBlank(proxy.getAddress())) {
// create new local client from builder using proxy
localClient = createClient(proxy);
}

// save start time for debugging speed impact of client build
long buildStart = System.currentTimeMillis();

// create new local client from builder using proxy
localClient = builder.build();

LOG.debug(
"time to build okhttp client with proxy: {}ms",
System.currentTimeMillis() - buildStart);

LOG.debug("fetching with proxy {} - {} ", url, prox.toString());
LOG.debug("fetching with proxy {} - {} ", url, proxy);
}

final Builder rb = new Request.Builder().url(url);
Expand Down Expand Up @@ -616,6 +563,58 @@ public Response intercept(Interceptor.Chain chain) throws IOException {
}
}

private OkHttpClient createClient(final SCProxy proxy) {
final OkHttpClient.Builder builder =
new OkHttpClient.Builder()
.retryOnConnectionFailure(retryOnConnectionFailure)
.followRedirects(followRedirects)
.connectTimeout(httpTimeout, TimeUnit.MILLISECONDS)
.writeTimeout(httpTimeout, TimeUnit.MILLISECONDS)
.readTimeout(httpTimeout, TimeUnit.MILLISECONDS);

if (protocols.size() > 0) {
LOG.info("Using protocol versions: {}", protocols);
builder.protocols(protocols);
}

if (storeHTTPHeaders) {
builder.addNetworkInterceptor(new HTTPHeadersInterceptor());
}

if (insecure) {
builder.sslSocketFactory(trustAllSslSocketFactory, (X509TrustManager) trustAllCerts[0]);
builder.hostnameVerifier((hostname, session) -> true);
}

builder.eventListenerFactory(call -> new DNSResolutionListener(DNStimes));

// enable support for Brotli compression (Content-Encoding)
builder.addInterceptor(BrotliInterceptor.INSTANCE);

builder.connectionPool(connectionPool);

if (proxy != null) {
builder.proxy(
new Proxy(
Proxy.Type.valueOf(proxy.getProtocol().toUpperCase(Locale.ROOT)),
new InetSocketAddress(
proxy.getAddress(), Integer.parseInt(proxy.getPort()))));
if (StringUtils.isNotBlank(proxy.getUsername())) {
builder.proxyAuthenticator(
(Route route, Response response) -> {
String credential =
Credentials.basic(proxy.getUsername(), proxy.getPassword());
return response.request()
.newBuilder()
.header("Proxy-Authorization", credential)
.build();
});
}
}

return builder.build();
}

public static void main(String args[]) throws Exception {
org.apache.stormcrawler.protocol.Protocol.main(new HttpProtocol(), args);
}
Expand Down

0 comments on commit bf49afb

Please sign in to comment.