Skip to content

Commit

Permalink
ZOOKEEPER-4804 Use daemon threads for Netty client
Browse files Browse the repository at this point in the history
  • Loading branch information
stoty committed Feb 22, 2024
1 parent 7074448 commit 4eca315
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,13 @@
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.util.concurrent.DefaultThreadFactory;
import io.netty.util.internal.StringUtil;

import java.net.InetAddress;
import java.net.NetworkInterface;
import java.net.SocketException;
import java.util.concurrent.ThreadFactory;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
Expand All @@ -47,6 +51,18 @@ public class NettyUtils {

private static final int DEFAULT_INET_ADDRESS_COUNT = 1;

/**
* Returns a ThreadFactory which generates daemon threads, and uses
* the passed class's name to generate the thread names.
*
* @param clazz Class to use for generating thread names
* @return Netty DefaultThreadFactory configured to create daemon threads
*/
private static ThreadFactory createThreadFactory(Class<? extends Object> clazz) {
String poolName = "zkNetty" + StringUtil.simpleClassName(clazz);
return new DefaultThreadFactory(poolName, true);
}

/**
* If {@link Epoll#isAvailable()} <code>== true</code>, returns a new
* {@link EpollEventLoopGroup}, otherwise returns a new
Expand All @@ -56,9 +72,9 @@ public class NettyUtils {
*/
public static EventLoopGroup newNioOrEpollEventLoopGroup() {
if (Epoll.isAvailable()) {
return new EpollEventLoopGroup();
return new EpollEventLoopGroup(createThreadFactory(EpollEventLoopGroup.class));
} else {
return new NioEventLoopGroup();
return new NioEventLoopGroup(createThreadFactory(NioEventLoopGroup.class));
}
}

Expand All @@ -72,9 +88,9 @@ public static EventLoopGroup newNioOrEpollEventLoopGroup() {
*/
public static EventLoopGroup newNioOrEpollEventLoopGroup(int nThreads) {
if (Epoll.isAvailable()) {
return new EpollEventLoopGroup(nThreads);
return new EpollEventLoopGroup(nThreads, createThreadFactory(EpollEventLoopGroup.class));
} else {
return new NioEventLoopGroup(nThreads);
return new NioEventLoopGroup(nThreads, createThreadFactory(NioEventLoopGroup.class));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,19 @@
import java.net.InetSocketAddress;
import java.net.ProtocolException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.zookeeper.AsyncCallback.DataCallback;
import org.apache.zookeeper.ClientCnxnSocketNetty;
import org.apache.zookeeper.CreateMode;
import org.apache.zookeeper.KeeperException;
import org.apache.zookeeper.ZooDefs.Ids;
import org.apache.zookeeper.client.ZKClientConfig;
import org.apache.zookeeper.ZooKeeper;
import org.apache.zookeeper.common.ClientX509Util;
import org.apache.zookeeper.data.Stat;
Expand All @@ -58,6 +62,7 @@
import org.apache.zookeeper.test.ClientBase;
import org.apache.zookeeper.test.SSLAuthTest;
import org.apache.zookeeper.test.TestByteBufAllocator;
import org.apache.zookeeper.test.TestUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -323,6 +328,39 @@ public void testEnableDisableThrottling_nonSecure_sequentially() throws Exceptio
runEnableDisableThrottling(false, false);
}

@Test
public void testNettyUsesDaemonThreads() throws Exception {
assertTrue(serverFactory instanceof NettyServerCnxnFactory,
"Didn't instantiate ServerCnxnFactory with NettyServerCnxnFactory!");

// Use Netty in the client to check the threads on both the client and server side
System.setProperty(ZKClientConfig.ZOOKEEPER_CLIENT_CNXN_SOCKET, ClientCnxnSocketNetty.class.getName());
try {
final ZooKeeper zk = createClient();
final ZooKeeperServer zkServer = serverFactory.getZooKeeperServer();
final String path = "/a";
try {
// make sure connection is established
zk.create(path, "test".getBytes(StandardCharsets.UTF_8), Ids.OPEN_ACL_UNSAFE, CreateMode.PERSISTENT);

List<Thread> threads = TestUtils.getAllThreads();
boolean foundThread = false;
for (Thread t : threads) {
if (t.getName().startsWith("zkNetty")) {
foundThread = true;
assertTrue(t.isDaemon(), "All Netty threads started by ZK must deamon threads");
}
}
assertTrue(foundThread, "Did not find any Netty ZK Threads");
} finally {
zk.close();
zkServer.shutdown();
}
} finally {
System.clearProperty(ZKClientConfig.ZOOKEEPER_CLIENT_CNXN_SOCKET);
}
}

private void runEnableDisableThrottling(boolean secure, boolean randomDisableEnable) throws Exception {
ClientX509Util x509Util = null;
if (secure) {
Expand Down Expand Up @@ -433,4 +471,5 @@ public void processResult(int rc, String path, Object ctx, byte[] data, Stat sta
}
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
import java.io.File;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.apache.zookeeper.WatchedEvent;

/**
Expand Down Expand Up @@ -71,4 +76,27 @@ public static void assertWatchedEventEquals(WatchedEvent expected, WatchedEvent
assertEquals(expected.getPath(), actual.getPath());
assertEquals(expected.getZxid(), actual.getZxid());
}

/**
* Return all threads
*
* Code based on commons-lang3 ThreadUtils
*
* @return all active threads
*/
public static List<Thread> getAllThreads() {
ThreadGroup threadGroup = Thread.currentThread().getThreadGroup();
while (threadGroup != null && threadGroup.getParent() != null) {
threadGroup = threadGroup.getParent();
}

int count = threadGroup.activeCount();
Thread[] threads;
do {
threads = new Thread[count + count / 2 + 1]; //slightly grow the array size
count = threadGroup.enumerate(threads, true);
//return value of enumerate() must be strictly less than the array size according to javadoc
} while (count >= threads.length);
return Collections.unmodifiableList(Stream.of(threads).limit(count).collect(Collectors.toList()));
}
}

0 comments on commit 4eca315

Please sign in to comment.