Skip to content

Commit

Permalink
Merge pull request #256 from emqx/0818-fix-wait-before-publish-signal
Browse files Browse the repository at this point in the history
fix: wait before publish signal should be collected per worker
  • Loading branch information
zmstone authored Aug 20, 2024
2 parents bbbeb07 + 110e198 commit 883ae33
Showing 1 changed file with 56 additions and 33 deletions.
89 changes: 56 additions & 33 deletions src/emqtt_bench.erl
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@
-define(cnt_map, cnt_map).
-define(hdr_cnt64, "cnt64").
-define(hdr_ts, "ts").
-define(GO_SIGNAL, go).

main(["sub"|Argv]) ->
{ok, {Opts, _Args}} = getopt:parse(?SUB_OPTS, Argv),
Expand Down Expand Up @@ -346,17 +347,9 @@ main(pub, Opts) ->
unicode:characters_to_binary(StrPayload)
end,
MsgLimit = pub_limit_fun_init(proplists:get_value(limit, Opts)),
PublishSignalPid =
case proplists:get_value(wait_before_publishing, Opts) of
true ->
spawn(fun() -> receive go -> exit(start_publishing) end end);
false ->
undefined
end,
start(pub, [ {payload, Payload}
, {payload_size, Size}
, {limit_fun, MsgLimit}
, {publish_signal_pid, PublishSignalPid}
| Opts]);

main(conn, Opts) ->
Expand All @@ -371,9 +364,15 @@ start(PubSub, Opts) ->
HostList = addr_to_list(Host),
AddrList = addr_to_list(IfAddr),
NoAddrs = length(AddrList),
NoWorkers = max(erlang:system_info(schedulers_online), ceil(Rate / 1000)),
NoWorkers0 = max(erlang:system_info(schedulers_online), ceil(Rate / 1000)),
Count = proplists:get_value(count, Opts),
CntPerWorker = Count div NoWorkers,
{CntPerWorker, NoWorkers} =
case Count div NoWorkers0 of
0 ->
{1, Count};
N ->
{N, NoWorkers0}
end,
Rem = Count rem NoWorkers,
Interval = case Rate of
0 -> %% conn_rate is not set
Expand All @@ -385,9 +384,21 @@ start(PubSub, Opts) ->
io:format("Start with ~p workers, addrs pool size: ~p and req interval: ~p ms ~n~n",
[NoWorkers, NoAddrs, Interval]),
true = (Interval >= 1),
PublishSignalPid =
case proplists:get_value(wait_before_publishing, Opts) of
true ->
spawn(fun() ->
collect_go_signals(NoWorkers),
io:format("Collected ~p 'go' signals, start publishing~n", [NoWorkers]),
exit(start_publishing)
end);
false ->
undefined
end,
lists:foreach(fun(P) ->
StartNumber = proplists:get_value(startnumber, Opts) + CntPerWorker*(P-1),
Count1 = case Rem =/= 0 andalso P == NoWorkers of
IsLastBatch = (P =:= NoWorkers),
Count1 = case IsLastBatch of
true ->
CntPerWorker + Rem;
false ->
Expand All @@ -398,12 +409,21 @@ start(PubSub, Opts) ->
{payload_hdrs, PayloadHdrs},
{count, Count1}
]),
proc_lib:spawn(?MODULE, run, [self(), PubSub, WOpts, AddrList, HostList])
WOpts1 = [{publish_signal_pid, PublishSignalPid} | WOpts],
proc_lib:spawn(?MODULE, run, [self(), PubSub, WOpts1, AddrList, HostList])
end, lists:seq(1, NoWorkers)),
timer:send_interval(1000, stats),
maybe_spawn_gc_enforcer(Opts),
main_loop(erlang:monotonic_time(millisecond), Count).

collect_go_signals(0) ->
ok;
collect_go_signals(N) ->
receive
?GO_SIGNAL ->
collect_go_signals(N - 1)
end.

prepare(PubSub, Opts) ->
Sname = list_to_atom(lists:flatten(io_lib:format("~p-~p-~p", [?MODULE, PubSub, rand:uniform(1000)]))),
case proplists:get_bool(dist, Opts) of
Expand Down Expand Up @@ -557,18 +577,11 @@ inc_counter(CntName, Inc) ->
cnt_ref() -> persistent_term:get(?MODULE).

run(Parent, PubSub, Opts, AddrList, HostList) ->
run(Parent, proplists:get_value(count, Opts), PubSub, Opts, AddrList, HostList).

ok = run(Parent, 0, proplists:get_value(count, Opts), PubSub, Opts, AddrList, HostList).

run(_Parent, 0, _PubSub, Opts, _AddrList, _HostList) ->
case proplists:get_value(publish_signal_pid, Opts) of
Pid when is_pid(Pid) ->
Pid ! go;
_ ->
ok
end,
done;
run(Parent, N, PubSub, Opts0, AddrList, HostList) ->
run(_Parent, N, N, _PubSub, _Opts, _AddrList, _HostList) ->
ok;
run(Parent, I, N, PubSub, Opts0, AddrList, HostList) ->
emqtt_logger:setup(Opts0),
SpawnOpts = case proplists:get_bool(lowmem, Opts0) of
true ->
Expand All @@ -580,23 +593,33 @@ run(Parent, N, PubSub, Opts0, AddrList, HostList) ->
[]
end,

Opts = replace_opts(Opts0, [ {ifaddr, shard_addr(N, AddrList)}
, {host, shard_addr(N, HostList)}
]),

spawn_opt(?MODULE, connect, [Parent, N+proplists:get_value(startnumber, Opts), PubSub, Opts],
SpawnOpts),
Opts1 = replace_opts(Opts0, [ {ifaddr, shard_addr(N, AddrList)}
, {host, shard_addr(N, HostList)}
]),
%% only the last one can send the 'go' signal
Opts = [{send_go_signal, I + 1 =:= N} | Opts1],
ID = I + 1 + proplists:get_value(startnumber, Opts),
spawn_opt(?MODULE, connect, [Parent, ID, PubSub, Opts], SpawnOpts),
timer:sleep(proplists:get_value(interval, Opts)),
run(Parent, N-1, PubSub, Opts, AddrList, HostList).
run(Parent, I + 1, N, PubSub, Opts, AddrList, HostList).

connect(Parent, N, PubSub, Opts) ->
process_flag(trap_exit, true),
rand:seed(exsplus, erlang:timestamp()),
MRef = case proplists:get_value(publish_signal_pid, Opts) of
Pid when is_pid(Pid) ->
monitor(process, Pid);
GoSignalPid = proplists:get_value(publish_signal_pid, Opts),
SendGoSignal = proplists:get_value(send_go_signal, Opts),
MRef = case is_pid(GoSignalPid) of
true -> monitor(process, GoSignalPid);
_ -> undefined
end,
%% this is the last client in one batch, send go signal when it's ready to connect
case is_pid(GoSignalPid) andalso true =:= SendGoSignal of
true ->
GoSignalPid ! ?GO_SIGNAL;
false ->
ok
end,

ClientId = client_id(PubSub, N, Opts),
MqttOpts = [{clientid, ClientId},
{tcp_opts, tcp_opts(Opts)},
Expand Down

0 comments on commit 883ae33

Please sign in to comment.