Skip to content

Commit

Permalink
Added feature flag named protect_direct_memory to control the usage o…
Browse files Browse the repository at this point in the history
…f OOM checking or not. Enabled by default.
  • Loading branch information
andsel committed Sep 26, 2023
1 parent 173410f commit 2eb55c2
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 12 deletions.
12 changes: 12 additions & 0 deletions docs/index.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ This plugin supports the following configuration options plus the <<plugins-{typ
| <<plugins-{type}s-{plugin}-host>> |<<string,string>>|No
| <<plugins-{type}s-{plugin}-include_codec_tag>> |<<boolean,boolean>>|__Deprecated__
| <<plugins-{type}s-{plugin}-port>> |<<number,number>>|Yes
| <<plugins-{type}s-{plugin}-protect_direct_memory>> |<<boolean,boolean>>|No
| <<plugins-{type}s-{plugin}-ssl>> |<<boolean,boolean>>|__Deprecated__
| <<plugins-{type}s-{plugin}-ssl_certificate>> |a valid filesystem path|No
| <<plugins-{type}s-{plugin}-ssl_certificate_authorities>> |<<array,array>>|No
Expand Down Expand Up @@ -384,6 +385,17 @@ deprecated[6.5.0, Replaced by <<plugins-{type}s-{plugin}-enrich>>]

The port to listen on.

[id="plugins-{type}s-{plugin}-protect_direct_memory"]
===== `protect_direct_memory`

* Value type is <<boolean,boolean>>
* Default value is `true`

If enabled, actively check native memory used by network part to do parsing and avoid
out of memory conditions. When the consumption of native memory used is close to
the maximum limit, connections are being closed in undetermined order until the safe
memory condition is reestablished.

[id="plugins-{type}s-{plugin}-ssl"]
===== `ssl`
deprecated[6.6.0, Replaced by <<plugins-{type}s-{plugin}-ssl_enabled>>]
Expand Down
6 changes: 5 additions & 1 deletion lib/logstash/inputs/beats.rb
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ class LogStash::Inputs::Beats < LogStash::Inputs::Base
# The port to listen on.
config :port, :validate => :number, :required => true

# Proactive checks that keep the beats input active when the memory used by protocol parser and network
# related operations is going to terminate.
config :protect_direct_memory, :validate => :boolean, :default => true

# Events are by default sent in plain text. You can
# enable encryption by setting `ssl` to true and configuring
# the `ssl_certificate` and `ssl_key` options.
Expand Down Expand Up @@ -243,7 +247,7 @@ def register
end # def register

def create_server
server = org.logstash.beats.Server.new(@host, @port, @client_inactivity_timeout, @executor_threads)
server = org.logstash.beats.Server.new(@host, @port, @client_inactivity_timeout, @executor_threads, @protect_direct_memory)
server.setSslHandlerProvider(new_ssl_handshake_provider(new_ssl_context_builder)) if @ssl_enabled
server
end
Expand Down
7 changes: 4 additions & 3 deletions spec/inputs/beats_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
let(:port) { BeatsInputTest.random_port }
let(:client_inactivity_timeout) { 400 }
let(:threads) { 1 + rand(9) }
let(:protect_direct_memory) { true }
let(:queue) { Queue.new }
let(:config) do
{
Expand All @@ -36,7 +37,7 @@
let(:port) { 9001 }

it "sends the required options to the server" do
expect(org.logstash.beats.Server).to receive(:new).with(host, port, client_inactivity_timeout, threads)
expect(org.logstash.beats.Server).to receive(:new).with(host, port, client_inactivity_timeout, threads, protect_direct_memory)
subject.register
end
end
Expand Down Expand Up @@ -529,8 +530,8 @@
subject(:plugin) { LogStash::Inputs::Beats.new(config) }

before do
@server = org.logstash.beats.Server.new(host, port, client_inactivity_timeout, threads)
expect( org.logstash.beats.Server ).to receive(:new).with(host, port, client_inactivity_timeout, threads).and_return @server
@server = org.logstash.beats.Server.new(host, port, client_inactivity_timeout, threads, protect_direct_memory)
expect( org.logstash.beats.Server ).to receive(:new).with(host, port, client_inactivity_timeout, threads, protect_direct_memory).and_return @server
expect( @server ).to receive(:listen)

subject.register
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/logstash/beats/Runner.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ static public void main(String[] args) throws Exception {
// Check for leaks.
// ResourceLeakDetector.setLevel(ResourceLeakDetector.Level.PARANOID);

Server server = new Server("0.0.0.0", DEFAULT_PORT, 15, Runtime.getRuntime().availableProcessors());
Server server = new Server("0.0.0.0", DEFAULT_PORT, 15, Runtime.getRuntime().availableProcessors(), true);

if(args.length > 0 && args[0].equals("ssl")) {
logger.debug("Using SSL");
Expand Down
18 changes: 14 additions & 4 deletions src/main/java/org/logstash/beats/Server.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,24 @@ public class Server {
private final int port;
private final String host;
private final int beatsHeandlerThreadCount;
private final boolean protectDirectMemory;
private NioEventLoopGroup workGroup;
private IMessageListener messageListener = new MessageListener();
private SslHandlerProvider sslHandlerProvider;
private BeatsInitializer beatsInitializer;

private final int clientInactivityTimeoutSeconds;

public Server(String host, int p, int clientInactivityTimeoutSeconds, int threadCount) {
// public Server(String host, int p, int clientInactivityTimeoutSeconds, int threadCount) {
// this(host, p, clientInactivityTimeoutSeconds, threadCount, true);
// }

public Server(String host, int p, int clientInactivityTimeoutSeconds, int threadCount, boolean protectDirectMemory) {
this.host = host;
port = p;
this.clientInactivityTimeoutSeconds = clientInactivityTimeoutSeconds;
beatsHeandlerThreadCount = threadCount;
this.protectDirectMemory = protectDirectMemory;
}

public void setSslHandlerProvider(SslHandlerProvider sslHandlerProvider){
Expand Down Expand Up @@ -130,7 +136,9 @@ private class BeatsInitializer extends ChannelInitializer<SocketChannel> {

public void initChannel(SocketChannel socket){
ChannelPipeline pipeline = socket.pipeline();
pipeline.addLast(new OOMConnectionCloser());
if (protectDirectMemory) {
pipeline.addLast(new OOMConnectionCloser());
}

if (isSslEnabled()) {
pipeline.addLast(SSL_HANDLER, sslHandlerProvider.sslHandlerForChannel(socket));
Expand All @@ -139,8 +147,10 @@ public void initChannel(SocketChannel socket){
new IdleStateHandler(localClientInactivityTimeoutSeconds, IDLESTATE_WRITER_IDLE_TIME_SECONDS, localClientInactivityTimeoutSeconds));
pipeline.addLast(BEATS_ACKER, new AckEncoder());
pipeline.addLast(CONNECTION_HANDLER, new ConnectionHandler());
pipeline.addLast(new FlowLimiterHandler());
pipeline.addLast(new ThunderingGuardHandler());
if (protectDirectMemory) {
pipeline.addLast(new FlowLimiterHandler());
pipeline.addLast(new ThunderingGuardHandler());
}
pipeline.addLast(beatsHandlerExecutorGroup, new BeatsParser());
pipeline.addLast(beatsHandlerExecutorGroup, new BeatsHandler(localMessageListener));
}
Expand Down
6 changes: 3 additions & 3 deletions src/test/java/org/logstash/beats/ServerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public void testServerShouldTerminateConnectionWhenExceptionHappen() throws Inte

final CountDownLatch latch = new CountDownLatch(concurrentConnections);

final Server server = new Server(host, randomPort, inactivityTime, threadCount);
final Server server = new Server(host, randomPort, inactivityTime, threadCount, true);
final AtomicBoolean otherCause = new AtomicBoolean(false);
server.setMessageListener(new MessageListener() {
public void onNewConnection(ChannelHandlerContext ctx) {
Expand Down Expand Up @@ -114,7 +114,7 @@ public void testServerShouldTerminateConnectionIdleForTooLong() throws Interrupt

final CountDownLatch latch = new CountDownLatch(concurrentConnections);
final AtomicBoolean exceptionClose = new AtomicBoolean(false);
final Server server = new Server(host, randomPort, inactivityTime, threadCount);
final Server server = new Server(host, randomPort, inactivityTime, threadCount, true);
server.setMessageListener(new MessageListener() {
@Override
public void onNewConnection(ChannelHandlerContext ctx) {
Expand Down Expand Up @@ -170,7 +170,7 @@ public void run() {

@Test
public void testServerShouldAcceptConcurrentConnection() throws InterruptedException {
final Server server = new Server(host, randomPort, 30, threadCount);
final Server server = new Server(host, randomPort, 30, threadCount, true);
SpyListener listener = new SpyListener();
server.setMessageListener(listener);
Runnable serverTask = new Runnable() {
Expand Down

0 comments on commit 2eb55c2

Please sign in to comment.