diff --git a/shard.lock b/shard.lock index 1f6944db97..255637522c 100644 --- a/shard.lock +++ b/shard.lock @@ -16,6 +16,10 @@ shards: git: https://github.com/84codes/lz4.cr.git version: 1.0.0+git.commit.96d714f7593c66ca7425872fd26c7b1286806d3d + mqtt-protocol: + git: https://github.com/84codes/mqtt-protocol.cr.git + version: 0.2.0+git.commit.3f82ee85d029e6d0505cbe261b108e156df4e598 + systemd: git: https://github.com/84codes/systemd.cr.git version: 2.0.0 diff --git a/shard.yml b/shard.yml index 8709ab05c1..18a31fd4f7 100644 --- a/shard.yml +++ b/shard.yml @@ -32,6 +32,8 @@ dependencies: github: 84codes/systemd.cr lz4: github: 84codes/lz4.cr + mqtt-protocol: + github: 84codes/mqtt-protocol.cr development_dependencies: ameba: diff --git a/spec/clustering_spec.cr b/spec/clustering_spec.cr index d1859ee6bb..1c82b80992 100644 --- a/spec/clustering_spec.cr +++ b/spec/clustering_spec.cr @@ -2,6 +2,8 @@ require "./spec_helper" require "../src/lavinmq/clustering/client" require "../src/lavinmq/clustering/controller" +alias IndexTree = LavinMQ::MQTT::TopicTree(String) + describe LavinMQ::Clustering::Client do follower_data_dir = "/tmp/lavinmq-follower" @@ -72,6 +74,48 @@ describe LavinMQ::Clustering::Client do end end + it "replicates and streams retained messages to followers" do + replicator = LavinMQ::Clustering::Server.new(LavinMQ::Config.instance, LavinMQ::Etcd.new, 0) + tcp_server = TCPServer.new("localhost", 0) + + spawn(replicator.listen(tcp_server), name: "repli server spec") + config = LavinMQ::Config.new.tap &.data_dir = follower_data_dir + repli = LavinMQ::Clustering::Client.new(config, 1, replicator.password, proxy: false) + done = Channel(Nil).new + spawn(name: "follow spec") do + repli.follow("localhost", tcp_server.local_address.port) + done.send nil + end + wait_for { replicator.followers.size == 1 } + + retain_store = LavinMQ::MQTT::RetainStore.new("#{LavinMQ::Config.instance.data_dir}/retain_store", replicator) + wait_for { replicator.followers.first?.try &.lag_in_bytes == 0 } + + props = LavinMQ::AMQP::Properties.new + msg1 = LavinMQ::Message.new(100, "test", "rk", props, 10, IO::Memory.new("body1")) + msg2 = LavinMQ::Message.new(100, "test", "rk", props, 10, IO::Memory.new("body2")) + retain_store.retain("topic1", msg1.body_io, msg1.bodysize) + retain_store.retain("topic2", msg2.body_io, msg2.bodysize) + + wait_for { replicator.followers.first?.try &.lag_in_bytes == 0 } + repli.close + done.receive + + follower_retain_store = LavinMQ::MQTT::RetainStore.new("#{follower_data_dir}/retain_store", LavinMQ::Clustering::NoopServer.new) + a = Array(String).new(2) + b = Array(String).new(2) + follower_retain_store.each("#") do |topic, bytes| + a << topic + b << String.new(bytes) + end + + a.sort!.should eq(["topic1", "topic2"]) + b.sort!.should eq(["body1", "body2"]) + follower_retain_store.retained_messages.should eq(2) + ensure + replicator.try &.close + end + it "can stream full file" do replicator = LavinMQ::Clustering::Server.new(LavinMQ::Config.instance, LavinMQ::Etcd.new, 0) tcp_server = TCPServer.new("localhost", 0) diff --git a/spec/message_routing_spec.cr b/spec/message_routing_spec.cr index da3ab9bdbd..de32981ef6 100644 --- a/spec/message_routing_spec.cr +++ b/spec/message_routing_spec.cr @@ -421,3 +421,42 @@ describe LavinMQ::Exchange do end end end + +describe LavinMQ::MQTT::Exchange do + it "should only allow Session to bind" do + with_amqp_server do |s| + vhost = s.vhosts.create("x") + q1 = LavinMQ::AMQP::Queue.new(vhost, "q1") + s1 = LavinMQ::MQTT::Session.new(vhost, "q1") + index = LavinMQ::MQTT::TopicTree(String).new + store = LavinMQ::MQTT::RetainStore.new("tmp/retain_store", LavinMQ::Clustering::NoopServer.new, index) + x = LavinMQ::MQTT::Exchange.new(vhost, "", store) + x.bind(s1, "s1", LavinMQ::AMQP::Table.new) + expect_raises(LavinMQ::Exchange::AccessRefused) do + x.bind(q1, "q1", LavinMQ::AMQP::Table.new) + end + end + end + + it "publish messages to queues with it's own publish method" do + with_amqp_server do |s| + vhost = s.vhosts.create("x") + s1 = LavinMQ::MQTT::Session.new(vhost, "session 1") + index = LavinMQ::MQTT::TopicTree(String).new + store = LavinMQ::MQTT::RetainStore.new("tmp/retain_store", LavinMQ::Clustering::NoopServer.new, index) + x = LavinMQ::MQTT::Exchange.new(vhost, "mqtt.default", store) + x.bind(s1, "s1", LavinMQ::AMQP::Table.new) + pub_args = { + packet_id: 1u16, + payload: Bytes.new(0), + dup: false, + qos: 0u8, + retain: false, + topic: "s1", + } + msg = MQTT::Protocol::Publish.new(**pub_args) + x.publish(msg) + s1.message_count.should eq 1 + end + end +end diff --git a/spec/mqtt/integrations/connect_spec.cr b/spec/mqtt/integrations/connect_spec.cr new file mode 100644 index 0000000000..d5f3215d1c --- /dev/null +++ b/spec/mqtt/integrations/connect_spec.cr @@ -0,0 +1,320 @@ +require "../spec_helper" + +module MqttSpecs + extend MqttHelpers + extend MqttMatchers + describe "connect [MQTT-3.1.4-1]" do + describe "when client already connected" do + it "should replace the already connected client [MQTT-3.1.4-2]" do + with_server do |server| + with_client_io(server) do |io| + connect(io) + with_client_io(server) do |io2| + connect(io2) + io.should be_closed + end + end + end + end + end + + describe "receives connack" do + describe "with expected flags set" do + it "no session present when reconnecting a non-clean session with a clean session [MQTT-3.1.2-6]" do + with_server do |server| + with_client_io(server) do |io| + connect(io, clean_session: false) + + # LavinMQ won't save sessions without subscriptions + subscribe(io, + topic_filters: [subtopic("a/topic", 0u8)], + packet_id: 1u16 + ) + disconnect(io) + end + with_client_io(server) do |io| + connack = connect(io, clean_session: true) + connack.should be_a(MQTT::Protocol::Connack) + connack = connack.as(MQTT::Protocol::Connack) + connack.session_present?.should be_false + end + end + end + + it "no session present when reconnecting a clean session with a non-clean session [MQTT-3.1.2-6]" do + with_server do |server| + with_client_io(server) do |io| + connect(io, clean_session: true) + subscribe(io, + topic_filters: [subtopic("a/topic", 0u8)], + packet_id: 1u16 + ) + disconnect(io) + end + with_client_io(server) do |io| + connack = connect(io, clean_session: false) + connack.should be_a(MQTT::Protocol::Connack) + connack = connack.as(MQTT::Protocol::Connack) + connack.session_present?.should be_false + end + end + end + + it "no session present when reconnecting a clean session [MQTT-3.1.2-6]" do + with_server do |server| + with_client_io(server) do |io| + connect(io, clean_session: true) + subscribe(io, + topic_filters: [subtopic("a/topic", 0u8)], + packet_id: 1u16 + ) + disconnect(io) + end + with_client_io(server) do |io| + connack = connect(io, clean_session: true) + connack.should be_a(MQTT::Protocol::Connack) + connack = connack.as(MQTT::Protocol::Connack) + connack.session_present?.should be_false + end + end + end + + it "session present when reconnecting a non-clean session [MQTT-3.1.2-4]" do + with_server do |server| + with_client_io(server) do |io| + connect(io, clean_session: false) + subscribe(io, + topic_filters: [subtopic("a/topic", 0u8)], + packet_id: 1u16 + ) + disconnect(io) + end + with_client_io(server) do |io| + connack = connect(io, clean_session: false) + connack.should be_a(MQTT::Protocol::Connack) + connack = connack.as(MQTT::Protocol::Connack) + connack.session_present?.should be_true + end + end + end + end + + describe "with expected return code" do + it "for valid credentials [MQTT-3.1.4-4]" do + with_server do |server| + with_client_io(server) do |io| + connack = connect(io) + connack.should be_a(MQTT::Protocol::Connack) + connack = connack.as(MQTT::Protocol::Connack) + connack.return_code.should eq(MQTT::Protocol::Connack::ReturnCode::Accepted) + end + end + end + + # pending "for invalid credentials" do + # auth = SpecAuth.new({"a" => {password: "b", acls: ["a", "a/b", "/", "/a"] of String}}) + # with_server(auth: auth) do |server| + # with_client_io(server) do |io| + # connack = connect(io, username: "nouser") + + # connack.should be_a(MQTT::Protocol::Connack) + # connack = connack.as(MQTT::Protocol::Connack) + # connack.return_code.should eq(MQTT::Protocol::Connack::ReturnCode::NotAuthorized) + # # Verify that connection is closed [MQTT-3.1.4-1] + # io.should be_closed + # end + # end + # end + + it "for invalid protocol version [MQTT-3.1.2-2]" do + with_server do |server| + with_client_io(server) do |io| + temp_io = IO::Memory.new + temp_mqtt_io = MQTT::Protocol::IO.new(temp_io) + connect(temp_mqtt_io, expect_response: false) + temp_io.rewind + connect_pkt = temp_io.to_slice + # This will overwrite the protocol level byte + connect_pkt[8] = 9u8 + io.write_bytes_raw connect_pkt + + connack = MQTT::Protocol::Packet.from_io(io) + + connack.should be_a(MQTT::Protocol::Connack) + connack = connack.as(MQTT::Protocol::Connack) + connack.return_code.should eq(MQTT::Protocol::Connack::ReturnCode::UnacceptableProtocolVersion) + # Verify that connection is closed [MQTT-3.1.4-1] + io.should be_closed + end + end + end + + it "client_id must be the first field of the connect packet [MQTT-3.1.3-3]" do + with_server do |server| + with_client_io(server) do |io| + connect = MQTT::Protocol::Connect.new( + client_id: "client_id", + clean_session: true, + keepalive: 30u16, + username: "valid_user", + password: "valid_password".to_slice, + will: nil + ).to_slice + connect[0] = 'x'.ord.to_u8 + io.write_bytes_raw connect + io.should be_closed + end + end + end + + it "accepts zero byte client_id but is assigned a unique client_id [MQTT-3.1.3-6]" do + with_server do |server| + with_client_io(server) do |io| + connect(io, client_id: "", clean_session: true) + server.vhosts["/"].connections.select(LavinMQ::MQTT::Client).first.client_id.should_not eq("") + end + end + end + + it "accepts zero-byte ClientId with CleanSession set to 1 [MQTT-3.1.3-7]" do + with_server do |server| + with_client_io(server) do |io| + connack = connect(io, client_id: "", clean_session: true) + connack.should be_a(MQTT::Protocol::Connack) + connack = connack.as(MQTT::Protocol::Connack) + connack.return_code.should eq(MQTT::Protocol::Connack::ReturnCode::Accepted) + io.should_not be_closed + end + end + end + + it "for empty client id with non-clean session [MQTT-3.1.3-8]" do + with_server do |server| + with_client_io(server) do |io| + connack = connect(io, client_id: "", clean_session: false) + connack.should be_a(MQTT::Protocol::Connack) + connack = connack.as(MQTT::Protocol::Connack) + connack.return_code.should eq(MQTT::Protocol::Connack::ReturnCode::IdentifierRejected) + io.should be_closed + end + end + end + + it "for password flag set without username flag set [MQTT-3.1.2-22]" do + with_server do |server| + with_client_io(server) do |io| + connect = MQTT::Protocol::Connect.new( + client_id: "client_id", + clean_session: true, + keepalive: 30u16, + username: nil, + password: "valid_password".to_slice, + will: nil + ).to_slice + # Set password flag + connect[9] |= 0b0100_0000 + io.write_bytes_raw connect + + # Verify that connection is closed [MQTT-3.1.4-1] + io.should be_closed + end + end + end + end + + describe "tcp socket is closed [MQTT-3.1.4-1]" do + it "if first packet is not a CONNECT [MQTT-3.1.0-1]" do + with_server do |server| + with_client_io(server) do |io| + payload = Bytes[1, 254, 200, 197, 123, 4, 87] + publish(io, topic: "test", payload: payload, qos: 0u8) + io.should be_closed + end + end + end + + it "for a second CONNECT packet [MQTT-3.1.0-2]" do + with_server do |server| + with_client_io(server) do |io| + connect(io) + connect(io, expect_response: false) + + io.should be_closed + end + end + end + + it "for invalid client id [MQTT-3.1.3-4]." do + with_server do |server| + with_client_io(server) do |io| + MQTT::Protocol::Connect.new( + client_id: "client\u0000_id", + clean_session: true, + keepalive: 30u16, + username: "valid_user", + password: "valid_user".to_slice, + will: nil + ).to_io(io) + + io.should be_closed + end + end + end + + it "for invalid protocol name [MQTT-3.1.2-1]" do + with_server do |server| + with_client_io(server) do |io| + connect = MQTT::Protocol::Connect.new( + client_id: "client_id", + clean_session: true, + keepalive: 30u16, + username: "valid_user", + password: "valid_password".to_slice, + will: nil + ).to_slice + + # This will overwrite the last "T" in MQTT + connect[7] = 'x'.ord.to_u8 + io.write_bytes_raw connect + + io.should be_closed + end + end + end + + it "for reserved bit set [MQTT-3.1.2-3]" do + with_server do |server| + with_client_io(server) do |io| + connect = MQTT::Protocol::Connect.new( + client_id: "client_id", + clean_session: true, + keepalive: 30u16, + username: "valid_user", + password: "valid_password".to_slice, + will: nil + ).to_slice + connect[9] |= 0b0000_0001 + io.write_bytes_raw connect + + io.should be_closed + end + end + end + + it "should not publish after disconnect" do + with_server do |server| + # Create a non-clean session with an active subscription + with_client_io(server) do |io| + connect(io, clean_session: false) + topics = mk_topic_filters({"a/b", 1}) + subscribe(io, topic_filters: topics) + disconnect(io) + end + sleep 100.milliseconds + server.vhosts["/"].queues["mqtt.client_id"].consumers.should be_empty + end + end + end + end + end +end diff --git a/spec/mqtt/integrations/duplicate_message_spec.cr b/spec/mqtt/integrations/duplicate_message_spec.cr new file mode 100644 index 0000000000..da7ac20c30 --- /dev/null +++ b/spec/mqtt/integrations/duplicate_message_spec.cr @@ -0,0 +1,90 @@ +require "../spec_helper.cr" + +module MqttSpecs + extend MqttHelpers + extend MqttMatchers + + describe "duplicate messages" do + it "dup must not be set if qos is 0 [MQTT-3.3.1-2]" do + with_server do |server| + with_client_io(server) do |io| + connect(io) + # Subscribe with qos=0 means downgrade messages to qos=0 + topic_filter = MQTT::Protocol::Subscribe::TopicFilter.new("a/b", 0u8) + subscribe(io, topic_filters: [topic_filter]) + + with_client_io(server) do |publisher_io| + connect(publisher_io, client_id: "publisher") + publish(publisher_io, topic: "a/b", qos: 0u8) + publish(publisher_io, topic: "a/b", qos: 1u8) + disconnect(publisher_io) + end + + pub1 = MQTT::Protocol::Packet.from_io(io).as(MQTT::Protocol::Publish) + pub1.qos.should eq(0u8) + pub1.dup?.should be_false + pub2 = MQTT::Protocol::Packet.from_io(io).as(MQTT::Protocol::Publish) + pub2.qos.should eq(0u8) + pub2.dup?.should be_false + + disconnect(io) + end + end + end + + it "dup is set when a message is being redelivered [MQTT-3.3.1.-1]" do + with_server do |server| + with_client_io(server) do |io| + connect(io) + topic_filter = MQTT::Protocol::Subscribe::TopicFilter.new("a/b", 1u8) + subscribe(io, topic_filters: [topic_filter]) + + with_client_io(server) do |publisher_io| + connect(publisher_io, client_id: "publisher") + publish(publisher_io, topic: "a/b", qos: 1u8) + disconnect(publisher_io) + end + + pub = MQTT::Protocol::Packet.from_io(io).as(MQTT::Protocol::Publish) + pub.dup?.should be_false + disconnect(io) + end + + with_client_io(server) do |io| + connect(io) + pub = MQTT::Protocol::Packet.from_io(io).as(MQTT::Protocol::Publish) + pub.dup?.should be_true + disconnect(io) + end + end + end + + it "dup on incoming messages is not propagated to other clients [MQTT-3.3.1-3]" do + with_server do |server| + with_client_io(server) do |io| + connect(io) + # Subscribe with qos=0 means downgrade messages to qos=0 + topic_filter = MQTT::Protocol::Subscribe::TopicFilter.new("a/b", 1u8) + subscribe(io, topic_filters: [topic_filter]) + + with_client_io(server) do |publisher_io| + connect(publisher_io, client_id: "publisher") + publish(publisher_io, topic: "a/b", qos: 1u8, dup: true) + publish(publisher_io, topic: "a/b", qos: 1u8, dup: true) + disconnect(publisher_io) + end + + pub1 = MQTT::Protocol::Packet.from_io(io).as(MQTT::Protocol::Publish) + pub1.dup?.should be_false + pub2 = MQTT::Protocol::Packet.from_io(io).as(MQTT::Protocol::Publish) + pub2.dup?.should be_false + + puback(io, pub1.packet_id) + puback(io, pub2.packet_id) + + disconnect(io) + end + end + end + end +end diff --git a/spec/mqtt/integrations/message_qos_spec.cr b/spec/mqtt/integrations/message_qos_spec.cr new file mode 100644 index 0000000000..6109f76376 --- /dev/null +++ b/spec/mqtt/integrations/message_qos_spec.cr @@ -0,0 +1,264 @@ +require "../spec_helper.cr" + +module MqttSpecs + extend MqttHelpers + extend MqttMatchers + describe "message qos" do + it "both qos bits can't be set [MQTT-3.3.1-4]" do + with_server do |server| + with_client_io(server) do |io| + connect(io) + temp_io = IO::Memory.new + publish(MQTT::Protocol::IO.new(temp_io), topic: "a/b", qos: 1u8, expect_response: false) + pub_pkt = temp_io.to_slice + pub_pkt[0] |= 0b0000_0110u8 + io.write pub_pkt + + io.should be_closed + end + end + end + + it "qos is set according to subscription qos [LavinMQ non-normative]" do + with_server do |server| + with_client_io(server) do |io| + connect(io) + # Subscribe with qos=0 means downgrade messages to qos=0 + topic_filters = mk_topic_filters({"a/b", 0u8}) + subscribe(io, topic_filters: topic_filters) + + with_client_io(server) do |publisher_io| + connect(publisher_io, client_id: "publisher") + publish(publisher_io, topic: "a/b", qos: 0u8) + publish(publisher_io, topic: "a/b", qos: 1u8) + disconnect(publisher_io) + end + + pub1 = MQTT::Protocol::Packet.from_io(io).as(MQTT::Protocol::Publish) + pub1.qos.should eq(0u8) + pub1.dup?.should be_false + pub2 = MQTT::Protocol::Packet.from_io(io).as(MQTT::Protocol::Publish) + pub2.qos.should eq(0u8) + pub2.dup?.should be_false + + disconnect(io) + end + end + end + + it "qos1 messages are stored for offline sessions [MQTT-3.1.2-5]" do + with_server do |server| + with_client_io(server) do |io| + connect(io) + topic_filters = mk_topic_filters({"a/b", 1u8}) + subscribe(io, topic_filters: topic_filters) + disconnect(io) + end + + with_client_io(server) do |publisher_io| + connect(publisher_io, client_id: "publisher") + 100.times do + # qos doesnt matter here + publish(publisher_io, topic: "a/b", qos: 0u8) + end + disconnect(publisher_io) + end + + with_client_io(server) do |io| + connect(io) + 100.times do + pkt = read_packet(io) + pkt.should be_a(MQTT::Protocol::Publish) + if pub = pkt.as?(MQTT::Protocol::Publish) + puback(io, pub.packet_id) + end + end + disconnect(io) + end + end + end + + it "acked qos1 message won't be sent again" do + with_server do |server| + with_client_io(server) do |io| + connect(io) + topic_filters = mk_topic_filters({"a/b", 1u8}) + subscribe(io, topic_filters: topic_filters) + + with_client_io(server) do |publisher_io| + connect(publisher_io, client_id: "publisher") + publish(publisher_io, topic: "a/b", payload: "1".to_slice, qos: 0u8) + publish(publisher_io, topic: "a/b", payload: "2".to_slice, qos: 0u8) + disconnect(publisher_io) + end + + pkt = read_packet(io) + if pub = pkt.as?(MQTT::Protocol::Publish) + pub.payload.should eq("1".to_slice) + puback(io, pub.packet_id) + end + disconnect(io) + end + + with_client_io(server) do |io| + connect(io) + pkt = read_packet(io) + if pub = pkt.as?(MQTT::Protocol::Publish) + pub.payload.should eq("2".to_slice) + puback(io, pub.packet_id) + end + disconnect(io) + end + end + end + + it "acks must not be ordered" do + with_server do |server| + with_client_io(server) do |io| + connect(io) + topic_filters = mk_topic_filters({"a/b", 1u8}) + subscribe(io, topic_filters: topic_filters) + + with_client_io(server) do |publisher_io| + connect(publisher_io, client_id: "publisher") + 10.times do |i| + publish(publisher_io, topic: "a/b", payload: "#{i}".to_slice, qos: 0u8) + end + disconnect(publisher_io) + end + + pubs = Array(MQTT::Protocol::Publish).new(9) + # Read all but one + 9.times do + pubs << read_packet(io).as(MQTT::Protocol::Publish) + end + [1, 3, 4, 0, 2, 7, 5, 6, 8].each do |i| + puback(io, pubs[i].packet_id) + end + disconnect(io) + end + with_client_io(server) do |io| + connect(io) + pub = read_packet(io).as(MQTT::Protocol::Publish) + pub.dup?.should be_true + pub.payload.should eq("9".to_slice) + disconnect(io) + end + end + end + + it "cannot ack invalid packet id" do + with_server do |server| + with_client_io(server) do |io| + connect(io) + # we need to subscribe in order to have a session + topic_filters = mk_topic_filters({"a/b", 1u8}) + subscribe(io, topic_filters: topic_filters) + puback(io, 123u16) + + expect_raises(IO::Error) do + read_packet(io) + end + end + end + end + + it "cannot ack a message twice" do + with_server do |server| + with_client_io(server) do |io| + connect(io) + topic_filters = mk_topic_filters({"a/b", 1u8}) + subscribe(io, topic_filters: topic_filters) + + with_client_io(server) do |publisher_io| + connect(publisher_io, client_id: "publisher") + publish(publisher_io, topic: "a/b", qos: 0u8) + disconnect(publisher_io) + end + + pub = read_packet(io).as(MQTT::Protocol::Publish) + + puback(io, pub.packet_id) + + # Sending the second ack make the server close the connection + puback(io, pub.packet_id) + + io.should be_closed + end + end + end + + it "qos1 unacked messages re-sent in the initial order [MQTT-4.6.0-1]" do + max_inflight_messages = 10 + # We'll only ACK odd packet ids, and the first id is 1, so if we don't + # do -1 the last packet (id=20) won't be sent because we've reached max + # inflight with all odd ids. + number_of_messages = (max_inflight_messages * 2 - 1).to_u16 + with_server do |server| + with_client_io(server) do |io| + connect(io, client_id: "subscriber") + topic_filters = mk_topic_filters({"a/b", 1u8}) + subscribe(io, topic_filters: topic_filters) + + with_client_io(server) do |publisher_io| + connect(publisher_io, client_id: "publisher") + number_of_messages.times do |i| + data = Bytes.new(sizeof(UInt16)) + IO::ByteFormat::SystemEndian.encode(i, data) + # qos doesnt matter here + publish(publisher_io, topic: "a/b", payload: data, qos: 0u8) + end + disconnect(publisher_io) + end + + # Read all messages, but only ack every second + # sync = Spectator::Synchronizer.new + sync = Channel(Bool).new(1) + spawn(name: "read msgs") do + number_of_messages.times do |i| + pkt = read_packet(io) + pub = pkt.should be_a(MQTT::Protocol::Publish) + # We only ack odd packet ids + puback(io, pub.packet_id) if (i % 2) > 0 + end + sync.send true + # sync.done + end + select + when sync.receive + when timeout(3.seconds) + fail "Timeout first read" + end + # sync.synchronize(timeout: 3.second, msg: "Timeout first read") + disconnect(io) + end + + # We should now get the 50 messages we didn't ack previously, and in order + with_client_io(server) do |io| + connect(io, client_id: "subscriber") + # sync = Spectator::Synchronizer.new + sync = Channel(Bool).new(1) + spawn(name: "read msgs") do + (number_of_messages // 2).times do |i| + pkt = read_packet(io) + pkt.should be_a(MQTT::Protocol::Publish) + pub = pkt.as(MQTT::Protocol::Publish) + puback(io, pub.packet_id) + data = IO::ByteFormat::SystemEndian.decode(UInt16, pub.payload) + data.should eq(i * 2) + end + sync.send true + # sync.done + end + select + when sync.receive + when timeout(3.seconds) + puts "Timeout second read" + end + # sync.synchronize(timeout: 3.second, msg: "Timeout second read") + disconnect(io) + end + end + end + end +end diff --git a/spec/mqtt/integrations/ping_spec.cr b/spec/mqtt/integrations/ping_spec.cr new file mode 100644 index 0000000000..35a2c80eb2 --- /dev/null +++ b/spec/mqtt/integrations/ping_spec.cr @@ -0,0 +1,17 @@ +require "../spec_helper" + +module MqttSpecs + extend MqttHelpers + describe "ping" do + it "responds to ping [MQTT-3.12.4-1]" do + with_server do |server| + with_client_io(server) do |io| + connect(io) + ping(io) + resp = read_packet(io) + resp.should be_a(MQTT::Protocol::PingResp) + end + end + end + end +end diff --git a/spec/mqtt/integrations/publish_spec.cr b/spec/mqtt/integrations/publish_spec.cr new file mode 100644 index 0000000000..6bde29afab --- /dev/null +++ b/spec/mqtt/integrations/publish_spec.cr @@ -0,0 +1,31 @@ +require "../spec_helper" + +module MqttSpecs + extend MqttHelpers + extend MqttMatchers + + describe "publish" do + it "should return PubAck for QoS=1" do + with_server do |server| + with_client_io(server) do |io| + connect(io) + payload = Bytes[1, 254, 200, 197, 123, 4, 87] + ack = publish(io, topic: "test", payload: payload, qos: 1u8) + ack.should be_a(MQTT::Protocol::PubAck) + end + end + end + + it "shouldn't return anything for QoS=0" do + with_server do |server| + with_client_io(server) do |io| + connect(io) + + payload = Bytes[1, 254, 200, 197, 123, 4, 87] + ack = publish(io, topic: "test", payload: payload, qos: 0u8) + ack.should be_nil + end + end + end + end +end diff --git a/spec/mqtt/integrations/retain_store_spec.cr b/spec/mqtt/integrations/retain_store_spec.cr new file mode 100644 index 0000000000..5edb0483bc --- /dev/null +++ b/spec/mqtt/integrations/retain_store_spec.cr @@ -0,0 +1,106 @@ +require "../spec_helper" + +module MqttSpecs + extend MqttHelpers + alias IndexTree = LavinMQ::MQTT::TopicTree(String) + + context "retain_store" do + after_each do + # Clear out the retain_store directory + FileUtils.rm_rf("tmp/retain_store") + end + + describe "retain" do + it "adds to index and writes msg file" do + index = IndexTree.new + store = LavinMQ::MQTT::RetainStore.new("tmp/retain_store", LavinMQ::Clustering::NoopServer.new, index) + props = LavinMQ::AMQP::Properties.new + msg = LavinMQ::Message.new(100, "test", "rk", props, 10, IO::Memory.new("body")) + store.retain("a", msg.body_io, msg.bodysize) + + index.size.should eq(1) + index.@leafs.has_key?("a").should be_true + + entry = index["a"]?.should be_a String + File.exists?(File.join("tmp/retain_store", entry)).should be_true + end + + it "empty body deletes" do + index = IndexTree.new + store = LavinMQ::MQTT::RetainStore.new("tmp/retain_store", LavinMQ::Clustering::NoopServer.new, index) + props = LavinMQ::AMQP::Properties.new + msg = LavinMQ::Message.new(100, "test", "rk", props, 10, IO::Memory.new("body")) + + store.retain("a", msg.body_io, msg.bodysize) + index.size.should eq(1) + entry = index["a"]?.should be_a String + + store.retain("a", msg.body_io, 0) + index.size.should eq(0) + File.exists?(File.join("tmp/retain_store", entry)).should be_false + end + end + + describe "each" do + it "calls block with correct arguments" do + index = IndexTree.new + store = LavinMQ::MQTT::RetainStore.new("tmp/retain_store", LavinMQ::Clustering::NoopServer.new, index) + props = LavinMQ::AMQP::Properties.new + msg = LavinMQ::Message.new(100, "test", "rk", props, 10, IO::Memory.new("body")) + store.retain("a", msg.body_io, msg.bodysize) + store.retain("b", msg.body_io, msg.bodysize) + + called = [] of Tuple(String, Bytes) + store.each("a") do |topic, bytes| + called << {topic, bytes} + end + + called.size.should eq(1) + called[0][0].should eq("a") + String.new(called[0][1]).should eq("body") + end + + it "handles multiple subscriptions" do + index = IndexTree.new + store = LavinMQ::MQTT::RetainStore.new("tmp/retain_store", LavinMQ::Clustering::NoopServer.new, index) + props = LavinMQ::AMQP::Properties.new + msg1 = LavinMQ::Message.new(100, "test", "rk", props, 10, IO::Memory.new("body")) + msg2 = LavinMQ::Message.new(100, "test", "rk", props, 10, IO::Memory.new("body")) + store.retain("a", msg1.body_io, msg1.bodysize) + store.retain("b", msg2.body_io, msg2.bodysize) + + called = [] of Tuple(String, Bytes) + store.each("a") do |topic, bytes| + called << {topic, bytes} + end + store.each("b") do |topic, bytes| + called << {topic, bytes} + end + + called.size.should eq(2) + called[0][0].should eq("a") + String.new(called[0][1]).should eq("body") + called[1][0].should eq("b") + String.new(called[1][1]).should eq("body") + end + end + + describe "restore_index" do + it "restores the index from a file" do + index = IndexTree.new + store = LavinMQ::MQTT::RetainStore.new("tmp/retain_store", LavinMQ::Clustering::NoopServer.new, index) + props = LavinMQ::AMQP::Properties.new + msg = LavinMQ::Message.new(100, "test", "rk", props, 10, IO::Memory.new("body")) + + store.retain("a", msg.body_io, msg.bodysize) + store.close + + new_index = IndexTree.new + LavinMQ::MQTT::RetainStore.new("tmp/retain_store", LavinMQ::Clustering::NoopServer.new, new_index) + + new_index.size.should eq(1) + new_index.@leafs.has_key?("a").should be_true + end + end + end +end diff --git a/spec/mqtt/integrations/retained_messages_spec.cr b/spec/mqtt/integrations/retained_messages_spec.cr new file mode 100644 index 0000000000..d260be6447 --- /dev/null +++ b/spec/mqtt/integrations/retained_messages_spec.cr @@ -0,0 +1,79 @@ +require "../spec_helper.cr" + +module MqttSpecs + extend MqttHelpers + extend MqttMatchers + describe "retained messages" do + it "retained messages are received on subscribe" do + with_server do |server| + with_client_io(server) do |io| + connect(io, client_id: "publisher") + publish(io, topic: "a/b", qos: 0u8, retain: true) + disconnect(io) + end + + with_client_io(server) do |io| + connect(io, client_id: "subscriber") + subscribe(io, topic_filters: [subtopic("a/b")]) + pub = read_packet(io).as(MQTT::Protocol::Publish) + pub.topic.should eq("a/b") + pub.retain?.should eq(true) + disconnect(io) + end + end + end + + it "retained messages are redelivered for subscriptions with qos1" do + with_server do |server| + with_client_io(server) do |io| + connect(io, client_id: "publisher") + publish(io, topic: "a/b", qos: 0u8, retain: true) + disconnect(io) + end + + with_client_io(server) do |io| + connect(io, client_id: "subscriber") + subscribe(io, topic_filters: [subtopic("a/b", 1u8)]) + # Dont ack + pub = read_packet(io).as(MQTT::Protocol::Publish) + pub.qos.should eq(1u8) + pub.topic.should eq("a/b") + pub.retain?.should eq(true) + pub.dup?.should eq(false) + end + + with_client_io(server) do |io| + connect(io, client_id: "subscriber") + pub = read_packet(io).as(MQTT::Protocol::Publish) + pub.qos.should eq(1u8) + pub.topic.should eq("a/b") + pub.retain?.should eq(true) + pub.dup?.should eq(true) + puback(io, pub.packet_id) + end + end + end + + it "retain is set in PUBLISH for retained messages" do + with_server do |server| + with_client_io(server) do |io| + connect(io) + publish(io, topic: "a/b", qos: 0u8, retain: true) + disconnect(io) + end + + with_client_io(server) do |io| + connect(io) + # Subscribe with qos=0 means downgrade messages to qos=0 + topic_filters = mk_topic_filters({"a/b", 0u8}) + subscribe(io, topic_filters: topic_filters) + + pub = read_packet(io).as(MQTT::Protocol::Publish) + pub.retain?.should eq(true) + + disconnect(io) + end + end + end + end +end diff --git a/spec/mqtt/integrations/subscribe_spec.cr b/spec/mqtt/integrations/subscribe_spec.cr new file mode 100644 index 0000000000..94b5ffda91 --- /dev/null +++ b/spec/mqtt/integrations/subscribe_spec.cr @@ -0,0 +1,135 @@ +require "../spec_helper" + +module MqttSpecs + extend MqttHelpers + extend MqttMatchers + describe "subscribe" do + it "pub/sub" do + with_server do |server| + with_client_io(server) do |sub_io| + connect(sub_io, client_id: "sub") + + topic_filters = mk_topic_filters({"test", 0}) + subscribe(sub_io, topic_filters: topic_filters) + + with_client_io(server) do |pub_io| + connect(pub_io, client_id: "pub") + + payload = Bytes[1, 254, 200, 197, 123, 4, 87] + packet_id = next_packet_id + ack = publish(pub_io, + topic: "test", + payload: payload, + qos: 0u8, + packet_id: packet_id + ) + ack.should be_nil + + msg = read_packet(sub_io) + msg.should be_a(MQTT::Protocol::Publish) + msg = msg.as(MQTT::Protocol::Publish) + msg.payload.should eq payload + msg.packet_id.should be_nil # QoS=0 + end + end + end + end + + it "bits 3,2,1,0 must be set to 0,0,1,0 [MQTT-3.8.1-1]" do + with_server do |server| + with_client_io(server) do |io| + connect(io) + + temp_io = IO::Memory.new + topic_filters = mk_topic_filters({"a/b", 0}) + subscribe(MQTT::Protocol::IO.new(temp_io), topic_filters: topic_filters, expect_response: false) + temp_io.rewind + subscribe_pkt = temp_io.to_slice + # This will overwrite the protocol level byte + subscribe_pkt[0] |= 0b0000_1010u8 + io.write_bytes_raw subscribe_pkt + + # Verify that connection is closed + io.should be_closed + end + end + end + + it "must contain at least one topic filter [MQTT-3.8.3-3]" do + with_server do |server| + with_client_io(server) do |io| + connect(io) + + topic_filters = mk_topic_filters({"a/b", 0}) + temp_io = IO::Memory.new + subscribe(MQTT::Protocol::IO.new(temp_io), topic_filters: topic_filters, expect_response: false) + temp_io.rewind + sub_pkt = temp_io.to_slice + sub_pkt[1] = 2u8 # Override remaning length + io.write_bytes_raw sub_pkt + + # Verify that connection is closed + io.should be_closed + end + end + end + + it "should not allow any payload reserved bits to be set [MQTT-3-8.3-4]" do + with_server do |server| + with_client_io(server) do |io| + connect(io) + + topic_filters = mk_topic_filters({"a/b", 0}) + temp_io = IO::Memory.new + subscribe(MQTT::Protocol::IO.new(temp_io), topic_filters: topic_filters, expect_response: false) + temp_io.rewind + sub_pkt = temp_io.to_slice + sub_pkt[sub_pkt.size - 1] |= 0b1010_0100u8 + io.write_bytes_raw sub_pkt + + # Verify that connection is closed + io.should be_closed + end + end + end + + it "should replace old subscription with new [MQTT-3.8.4-3]" do + with_server do |server| + with_client_io(server) do |io| + connect(io) + + topic_filters = mk_topic_filters({"a/b", 0}) + suback = subscribe(io, topic_filters: topic_filters) + suback.should be_a(MQTT::Protocol::SubAck) + suback = suback.as(MQTT::Protocol::SubAck) + # Verify that we subscribed as qos0 + suback.return_codes.first.should eq(MQTT::Protocol::SubAck::ReturnCode::QoS0) + + # Publish something to the topic we're subscribed to... + publish(io, topic: "a/b", payload: "a".to_slice, qos: 1u8) + # ... consume it... + packet = read_packet(io).as(MQTT::Protocol::Publish) + # ... and verify it be qos0 (i.e. our subscribe is correct) + packet.qos.should eq(0u8) + + # Now do a second subscribe with another qos and do the same verification + topic_filters = mk_topic_filters({"a/b", 1}) + suback = subscribe(io, topic_filters: topic_filters) + suback.should be_a(MQTT::Protocol::SubAck) + suback = suback.as(MQTT::Protocol::SubAck) + # Verify that we subscribed as qos1 + suback.return_codes.should eq([MQTT::Protocol::SubAck::ReturnCode::QoS1]) + + # Publish something to the topic we're subscribed to... + publish(io, topic: "a/b", payload: "a".to_slice, qos: 1u8) + # ... consume it... + packet = read_packet(io).as(MQTT::Protocol::Publish) + # ... and verify it be qos1 (i.e. our second subscribe is correct) + packet.qos.should eq(1u8) + + io.should be_drained + end + end + end + end +end diff --git a/spec/mqtt/integrations/unsubscribe_spec.cr b/spec/mqtt/integrations/unsubscribe_spec.cr new file mode 100644 index 0000000000..ceb7de992c --- /dev/null +++ b/spec/mqtt/integrations/unsubscribe_spec.cr @@ -0,0 +1,96 @@ +require "../spec_helper" + +module MqttSpecs + extend MqttHelpers + extend MqttMatchers + + describe "unsubscribe" do + it "bits 3,2,1,0 must be set to 0,0,1,0 [MQTT-3.10.1-1]" do + with_server do |server| + with_client_io(server) do |io| + connect(io) + + temp_io = IO::Memory.new + unsubscribe(MQTT::Protocol::IO.new(temp_io), topics: ["a/b"], expect_response: false) + temp_io.rewind + unsubscribe_pkt = temp_io.to_slice + # This will overwrite the protocol level byte + unsubscribe_pkt[0] |= 0b0000_1010u8 + io.write_bytes_raw unsubscribe_pkt + + io.should be_closed + end + end + end + + it "must contain at least one topic filter [MQTT-3.10.3-2]" do + with_server do |server| + with_client_io(server) do |io| + connect(io) + + temp_io = IO::Memory.new + unsubscribe(MQTT::Protocol::IO.new(temp_io), topics: ["a/b"], expect_response: false) + temp_io.rewind + unsubscribe_pkt = temp_io.to_slice + # Overwrite remaining length + unsubscribe_pkt[1] = 2u8 + io.write_bytes_raw unsubscribe_pkt + + io.should be_closed + end + end + end + + it "must stop adding any new messages for delivery to the Client, but completes delivery of previous messages [MQTT-3.10.4-2] and [MQTT-3.10.4-3]" do + with_server do |server| + with_client_io(server) do |pubio| + connect(pubio, client_id: "publisher") + + # Create a non-clean session with an active subscription + with_client_io(server) do |io| + connect(io, clean_session: false) + topics = mk_topic_filters({"a/b", 1}) + subscribe(io, topic_filters: topics) + disconnect(io) + end + + # Publish messages that will be stored for the subscriber + 2.times { |i| publish(pubio, topic: "a/b", payload: i.to_s.to_slice, qos: 0u8) } + + # Let the subscriber connect and read the messages, but don't ack. Then unsubscribe. + # We must read the Publish packets before unsubscribe, else the "suback" will be stuck. + with_client_io(server) do |io| + connect(io, clean_session: false) + 2.times do + pkt = read_packet(io) + pkt.should be_a(MQTT::Protocol::Publish) + # dont ack + end + + unsubscribe(io, topics: ["a/b"]) + disconnect(io) + end + + # Publish more messages + 2.times { |i| publish(pubio, topic: "a/b", payload: (2 + i).to_s.to_slice, qos: 0u8) } + + # Now, if unsubscribed worked, the last two publish packets shouldn't be held for the + # session. Read the two we expect, then test that there is nothing more to read. + with_client_io(server) do |io| + connect(io, clean_session: false) + 2.times do |i| + pkt = read_packet(io) + pkt.should be_a(MQTT::Protocol::Publish) + pkt = pkt.as(MQTT::Protocol::Publish) + pkt.payload.should eq(i.to_s.to_slice) + end + + io.should be_drained + disconnect(io) + end + disconnect(pubio) + end + end + end + end +end diff --git a/spec/mqtt/integrations/various_spec.cr b/spec/mqtt/integrations/various_spec.cr new file mode 100644 index 0000000000..8c9a0d4c3a --- /dev/null +++ b/spec/mqtt/integrations/various_spec.cr @@ -0,0 +1,31 @@ +require "../spec_helper" + +module MqttSpecs + extend MqttHelpers + extend MqttMatchers + + describe "session handling" do + it "messages are delivered to client that connects to a existing session" do + with_server do |server| + with_client_io(server) do |io| + connect(io, clean_session: false) + subscribe(io, topic_filters: [subtopic("a/b/c", 1u8)]) + disconnect(io) + end + + with_client_io(server) do |io| + connect(io, clean_session: false, client_id: "pub") + publish(io, topic: "a/b/c", qos: 0u8) + end + + with_client_io(server) do |io| + connect(io, clean_session: false) + packet = read_packet(io).should be_a(MQTT::Protocol::Publish) + packet.topic.should eq "a/b/c" + rescue + fail "timeout; message not routed" + end + end + end + end +end diff --git a/spec/mqtt/integrations/will_spec.cr b/spec/mqtt/integrations/will_spec.cr new file mode 100644 index 0000000000..b43e10ae39 --- /dev/null +++ b/spec/mqtt/integrations/will_spec.cr @@ -0,0 +1,187 @@ +require "../spec_helper" + +module MqttSpecs + extend MqttHelpers + extend MqttMatchers + + describe "client will" do + it "is not delivered on graceful disconnect [MQTT-3.14.4-3]" do + with_server do |server| + with_client_io(server) do |io| + connect(io) + topic_filters = mk_topic_filters({"#", 0}) + subscribe(io, topic_filters: topic_filters) + + with_client_io(server) do |io2| + will = MQTT::Protocol::Will.new( + topic: "will/t", payload: "dead".to_slice, qos: 0u8, retain: false) + connect(io2, client_id: "will_client", will: will, keepalive: 1u16) + disconnect(io2) + end + + # If the will has been published it should be received before this + publish(io, topic: "a/b", payload: "alive".to_slice) + + pub = read_packet(io).should be_a(MQTT::Protocol::Publish) + pub.payload.should eq("alive".to_slice) + pub.topic.should eq("a/b") + + disconnect(io) + end + end + end + + describe "is delivered on ungraceful disconnect" do + it "when client unexpected closes tcp connection" do + with_server do |server| + with_client_io(server) do |io| + connect(io) + topic_filters = mk_topic_filters({"will/t", 0}) + subscribe(io, topic_filters: topic_filters) + + with_client_io(server) do |io2| + will = MQTT::Protocol::Will.new( + topic: "will/t", payload: "dead".to_slice, qos: 0u8, retain: false) + connect(io2, client_id: "will_client", will: will, keepalive: 1u16) + end + + pub = read_packet(io).should be_a(MQTT::Protocol::Publish) + pub.payload.should eq("dead".to_slice) + pub.topic.should eq("will/t") + + disconnect(io) + end + end + end + + it "when server closes connection because protocol error" do + with_server do |server| + with_client_io(server) do |io| + connect(io) + topic_filters = mk_topic_filters({"will/t", 0}) + subscribe(io, topic_filters: topic_filters) + + with_client_io(server) do |io2| + will = MQTT::Protocol::Will.new( + topic: "will/t", payload: "dead".to_slice, qos: 0u8, retain: false) + connect(io2, client_id: "will_client", will: will, keepalive: 20u16) + + broken_packet_io = IO::Memory.new + publish(MQTT::Protocol::IO.new(broken_packet_io), topic: "foo", qos: 1u8, expect_response: false) + broken_packet = broken_packet_io.to_slice + broken_packet[0] |= 0b0000_0110u8 # set both qos bits to 1 + io2.write broken_packet + end + + pub = read_packet(io).should be_a(MQTT::Protocol::Publish) + pub.payload.should eq("dead".to_slice) + pub.topic.should eq("will/t") + + disconnect(io) + end + end + end + end + + it "can be retained [MQTT-3.1.2-17]" do + with_server do |server| + with_client_io(server) do |io2| + will = MQTT::Protocol::Will.new( + topic: "will/t", payload: "dead".to_slice, qos: 0u8, retain: true) + connect(io2, client_id: "will_client", will: will, keepalive: 1u16) + end + + with_client_io(server) do |io| + connect(io) + topic_filters = mk_topic_filters({"will/t", 0}) + subscribe(io, topic_filters: topic_filters) + + pub = read_packet(io).should be_a(MQTT::Protocol::Publish) + pub.payload.should eq("dead".to_slice) + pub.topic.should eq("will/t") + pub.retain?.should eq(true) + + disconnect(io) + end + end + end + + it "won't be published if missing permission" do + with_server do |server| + with_client_io(server) do |io| + connect(io) + topic_filters = mk_topic_filters({"topic-without-permission/t", 0}) + subscribe(io, topic_filters: topic_filters) + + with_client_io(server) do |io2| + will = MQTT::Protocol::Will.new( + topic: "will/t", payload: "dead".to_slice, qos: 0u8, retain: false) + connect(io2, client_id: "will_client", will: will, keepalive: 1u16) + end + + # Send a ping to ensure we can read at least one packet, so we're not stuck + # waiting here (since this spec verifies that nothing is sent) + ping(io) + + pkt = read_packet(io) + pkt.should be_a(MQTT::Protocol::PingResp) + + disconnect(io) + end + end + end + + it "qos can't be set of will flag is unset [MQTT-3.1.2-13]" do + with_server do |server| + with_client_io(server) do |io| + temp_io = IO::Memory.new + connect(MQTT::Protocol::IO.new(temp_io), client_id: "will_client", keepalive: 1u16, expect_response: false) + temp_io.rewind + connect_pkt = temp_io.to_slice + connect_pkt[9] |= 0b0001_0000u8 + io.write connect_pkt + + expect_raises(IO::Error) do + read_packet(io) + end + end + end + end + + it "qos must not be 3 [MQTT-3.1.2-14]" do + with_server do |server| + with_client_io(server) do |io| + temp_io = IO::Memory.new + will = MQTT::Protocol::Will.new( + topic: "will/t", payload: "dead".to_slice, qos: 0u8, retain: false) + connect(MQTT::Protocol::IO.new(temp_io), will: will, client_id: "will_client", keepalive: 1u16, expect_response: false) + temp_io.rewind + connect_pkt = temp_io.to_slice + connect_pkt[9] |= 0b0001_1000u8 + io.write connect_pkt + + expect_raises(IO::Error) do + read_packet(io) + end + end + end + end + + it "retain can't be set of will flag is unset [MQTT-3.1.2-15]" do + with_server do |server| + with_client_io(server) do |io| + temp_io = IO::Memory.new + connect(MQTT::Protocol::IO.new(temp_io), client_id: "will_client", keepalive: 1u16, expect_response: false) + temp_io.rewind + connect_pkt = temp_io.to_slice + connect_pkt[9] |= 0b0010_0000u8 + io.write connect_pkt + + expect_raises(IO::Error) do + read_packet(io) + end + end + end + end + end +end diff --git a/spec/mqtt/multi_vhost_spec.cr b/spec/mqtt/multi_vhost_spec.cr new file mode 100644 index 0000000000..d960712f20 --- /dev/null +++ b/spec/mqtt/multi_vhost_spec.cr @@ -0,0 +1,40 @@ +require "./spec_helper" + +module MqttSpecs + extend MqttHelpers + describe LavinMQ::MQTT do + describe "multi-vhost" do + it "should create mqtt exchange when vhost is created" do + with_amqp_server do |server| + server.vhosts.create("new") + server.vhosts["new"].exchanges[LavinMQ::MQTT::EXCHANGE]?.should_not be_nil + end + end + + describe "authentication" do + it "should deny mqtt access for user lacking vhost permissions" do + with_server do |server| + server.users.create("foo", "bar") + with_client_io(server) do |io| + resp = connect io, username: "foo", password: "bar".to_slice + resp = resp.should be_a(MQTT::Protocol::Connack) + resp.return_code.should eq MQTT::Protocol::Connack::ReturnCode::NotAuthorized + end + end + end + + it "should allow mqtt access for user with vhost permissions" do + with_server do |server| + server.users.create("foo", "bar") + server.users.add_permission "foo", "/", /.*/, /.*/, /.*/ + with_client_io(server) do |io| + resp = connect io, username: "foo", password: "bar".to_slice + resp = resp.should be_a(MQTT::Protocol::Connack) + resp.return_code.should eq MQTT::Protocol::Connack::ReturnCode::Accepted + end + end + end + end + end + end +end diff --git a/spec/mqtt/routing_spec.cr b/spec/mqtt/routing_spec.cr new file mode 100644 index 0000000000..dbb059b52e --- /dev/null +++ b/spec/mqtt/routing_spec.cr @@ -0,0 +1,72 @@ +require "./spec_helper" + +module MqttSpecs + extend MqttHelpers + extend MqttMatchers + + describe "message routing" do + topic = "a/b/c" + positive_topic_filters = { + "a/b/c", + "#", + "a/#", + "a/b/#", + "a/b/+", + "a/+/+", + "+/+/+", + "+/+/c", + "+/b/c", + "+/#", + "+/+/#", + "a/+/#", + "a/+/c", + } + negative_topic_filters = { + "c/a/b", + "c/#", + "+/a/+", + "c/+/#", + "+/+/d", + } + positive_topic_filters.each do |topic_filter| + it "should route #{topic} to #{topic_filter}" do + with_server do |server| + with_client_io(server) do |sub| + connect(sub, client_id: "sub") + subscribe(sub, topic_filters: [subtopic(topic_filter, 1u8)]) + + with_client_io(server) do |pub_io| + connect(pub_io, client_id: "pub") + publish(pub_io, topic: "a/b/c", qos: 0u8) + end + + begin + packet = read_packet(sub).should be_a(MQTT::Protocol::Publish) + packet.topic.should eq "a/b/c" + rescue + fail "timeout; message not routed" + end + end + end + end + end + + negative_topic_filters.each do |topic_filter| + it "should not route #{topic} to #{topic_filter}" do + with_server do |server| + with_client_io(server) do |sub| + connect(sub, client_id: "sub") + subscribe(sub, topic_filters: [subtopic(topic_filter, 1u8)]) + + with_client_io(server) do |pub_io| + connect(pub_io, client_id: "pub") + publish(pub_io, topic: "a/b/c", qos: 0u8) + end + + expect_raises(::IO::TimeoutError) { MQTT::Protocol::Packet.from_io(sub) } + end + end + end + end + end +end diff --git a/spec/mqtt/spec_helper.cr b/spec/mqtt/spec_helper.cr new file mode 100644 index 0000000000..4a3b767606 --- /dev/null +++ b/spec/mqtt/spec_helper.cr @@ -0,0 +1,3 @@ +require "./spec_helper/mqtt_helpers_spec" +require "./spec_helper/mqtt_matchers_spec" +require "./spec_helper/mqtt_protocol_spec" diff --git a/spec/mqtt/spec_helper/mqtt_client_spec.cr b/spec/mqtt/spec_helper/mqtt_client_spec.cr new file mode 100644 index 0000000000..e028ac6d32 --- /dev/null +++ b/spec/mqtt/spec_helper/mqtt_client_spec.cr @@ -0,0 +1,111 @@ +require "mqtt-protocol" + +module MqttHelpers + class MqttClient + def next_packet_id + @packet_id_generator.next.as(UInt16) + end + + @packet_id_generator : Iterator(UInt16) + + getter client_id + + def initialize(io : IO) + @client_id = "" + @io = MQTT::Protocol::IO.new(io) + @packet_id_generator = (0u16..).each + end + + def connect( + expect_response = true, + username = "valid_user", + password = "valid_password", + client_id = "spec_client", + keepalive = 30u16, + will = nil, + clean_session = true, + **args + ) + connect_args = { + client_id: client_id, + clean_session: clean_session, + keepalive: keepalive, + will: will, + username: username, + password: password.to_slice, + }.merge(args) + @client_id = connect_args.fetch(:client_id, "").to_s + MQTT::Protocol::Connect.new(**connect_args).to_io(@io) + read_packet if expect_response + end + + def disconnect + MQTT::Protocol::Disconnect.new.to_io(@io) + true + rescue IO::Error + false + end + + def subscribe(topic : String, qos : UInt8 = 0u8, expect_response = true) + filter = MQTT::Protocol::Subscribe::TopicFilter.new(topic, qos) + MQTT::Protocol::Subscribe.new([filter], packet_id: next_packet_id).to_io(@io) + read_packet if expect_response + end + + def unsubscribe(*topics : String, expect_response = true) + MQTT::Protocol::Unsubscribe.new(topics.to_a, next_packet_id).to_io(@io) + read_packet if expect_response + end + + def publish( + topic : String, + payload : String, + qos = 0, + retain = false, + packet_id : UInt16? = next_packet_id, + expect_response = true + ) + pub_args = { + packet_id: packet_id, + payload: payload.to_slice, + topic: topic, + dup: false, + qos: qos.to_u8, + retain: retain, + } + MQTT::Protocol::Publish.new(**pub_args).to_io(@io) + read_packet if pub_args[:qos].positive? && expect_response + end + + def puback(packet_id : UInt16?) + return if packet_id.nil? + MQTT::Protocol::PubAck.new(packet_id).to_io(@io) + end + + def puback(packet : MQTT::Protocol::Publish) + if packet_id = packet.packet_id + MQTT::Protocol::PubAck.new(packet_id).to_io(@io) + end + end + + def ping(expect_response = true) + MQTT::Protocol::PingReq.new.to_io(@io) + read_packet if expect_response + end + + def read_packet + MQTT::Protocol::Packet.from_io(@io) + rescue ex : IO::Error + @io.close + raise ex + end + + def close + @io.close + end + + def closed? + @io.closed? + end + end +end diff --git a/spec/mqtt/spec_helper/mqtt_helpers_spec.cr b/spec/mqtt/spec_helper/mqtt_helpers_spec.cr new file mode 100644 index 0000000000..795c5548c7 --- /dev/null +++ b/spec/mqtt/spec_helper/mqtt_helpers_spec.cr @@ -0,0 +1,127 @@ +require "mqtt-protocol" +require "./mqtt_client_spec" +require "../../spec_helper" + +module MqttHelpers + GENERATOR = (0u16..).each + + def next_packet_id + GENERATOR.next.as(UInt16) + end + + def with_client_socket(server) + listener = server.listeners.find(&.[:protocol].mqtt?) + tcp_listener = listener.as(NamedTuple(ip_address: String, protocol: LavinMQ::Server::Protocol, port: Int32)) + + socket = TCPSocket.new( + tcp_listener[:ip_address], + tcp_listener[:port], + connect_timeout: 30) + socket.keepalive = true + socket.tcp_nodelay = false + socket.tcp_keepalive_idle = 60 + socket.tcp_keepalive_count = 3 + socket.tcp_keepalive_interval = 10 + socket.sync = true + socket.read_buffering = true + socket.buffer_size = 16384 + socket.read_timeout = 1.seconds + socket + end + + def with_client_socket(server, &) + socket = with_client_socket(server) + yield socket + ensure + socket.try &.close + end + + def with_server(& : LavinMQ::Server -> Nil) + mqtt_server = TCPServer.new("localhost", 0) + amqp_server = TCPServer.new("localhost", 0) + s = LavinMQ::Server.new(LavinMQ::Config.instance.data_dir, LavinMQ::Clustering::NoopServer.new) + begin + spawn(name: "amqp tcp listen") { s.listen(amqp_server, LavinMQ::Server::Protocol::AMQP) } + spawn(name: "mqtt tcp listen") { s.listen(mqtt_server, LavinMQ::Server::Protocol::MQTT) } + Fiber.yield + yield s + ensure + s.close + FileUtils.rm_rf(LavinMQ::Config.instance.data_dir) + end + end + + def with_client_io(server) + socket = with_client_socket(server) + MQTT::Protocol::IO.new(socket) + end + + def with_client_io(server, &) + with_client_socket(server) do |io| + with MqttHelpers yield MQTT::Protocol::IO.new(io) + end + end + + def connect(io, expect_response = true, **args) + MQTT::Protocol::Connect.new(**{ + client_id: "client_id", + clean_session: false, + keepalive: 30u16, + username: "guest", + password: "guest".to_slice, + will: nil, + }.merge(args)).to_io(io) + MQTT::Protocol::Packet.from_io(io) if expect_response + end + + def disconnect(io) + MQTT::Protocol::Disconnect.new.to_io(io) + end + + def mk_topic_filters(*args) : Array(MQTT::Protocol::Subscribe::TopicFilter) + ret = Array(MQTT::Protocol::Subscribe::TopicFilter).new + args.each { |topic, qos| ret << subtopic(topic, qos) } + ret + end + + def subscribe(io, expect_response = true, **args) + MQTT::Protocol::Subscribe.new(**{packet_id: next_packet_id}.merge(args)).to_io(io) + MQTT::Protocol::Packet.from_io(io) if expect_response + end + + def unsubscribe(io, topics : Array(String), expect_response = true, packet_id = next_packet_id) + MQTT::Protocol::Unsubscribe.new(topics, packet_id).to_io(io) + MQTT::Protocol::Packet.from_io(io) if expect_response + end + + def subtopic(topic : String, qos = 0) + MQTT::Protocol::Subscribe::TopicFilter.new(topic, qos.to_u8) + end + + def publish(io, expect_response = true, **args) + pub_args = { + packet_id: next_packet_id, + payload: "data".to_slice, + dup: false, + qos: 0u8, + retain: false, + }.merge(args) + MQTT::Protocol::Publish.new(**pub_args).to_io(io) + MQTT::Protocol::PubAck.from_io(io) if pub_args[:qos].positive? && expect_response + end + + def puback(io, packet_id : UInt16?) + return if packet_id.nil? + MQTT::Protocol::PubAck.new(packet_id).to_io(io) + end + + def ping(io) + MQTT::Protocol::PingReq.new.to_io(io) + end + + def read_packet(io) + MQTT::Protocol::Packet.from_io(io) + rescue IO::TimeoutError + nil + end +end diff --git a/spec/mqtt/spec_helper/mqtt_matchers_spec.cr b/spec/mqtt/spec_helper/mqtt_matchers_spec.cr new file mode 100644 index 0000000000..2789744e16 --- /dev/null +++ b/spec/mqtt/spec_helper/mqtt_matchers_spec.cr @@ -0,0 +1,47 @@ +module MqttMatchers + struct ClosedExpectation + include MqttHelpers + + def match(actual : MQTT::Protocol::IO) + return true if actual.closed? + read_packet(actual) + false + rescue e : IO::Error + true + end + + def failure_message(actual_value) + "Expected socket to be closed" + end + + def negative_failure_message(actual_value) + "Expected socket to be open" + end + end + + def be_closed + ClosedExpectation.new + end + + struct EmptyMatcher + include MqttHelpers + + def match(actual) + ping(actual) + resp = read_packet(actual) + resp.is_a?(MQTT::Protocol::PingResp) + end + + def failure_message(actual_value) + "Expected socket to be drained" + end + + def negative_failure_message(actual_value) + "Expected socket to not be drained" + end + end + + def be_drained + EmptyMatcher.new + end +end diff --git a/spec/mqtt/spec_helper/mqtt_protocol_spec.cr b/spec/mqtt/spec_helper/mqtt_protocol_spec.cr new file mode 100644 index 0000000000..5bf73e4ba9 --- /dev/null +++ b/spec/mqtt/spec_helper/mqtt_protocol_spec.cr @@ -0,0 +1,12 @@ +module MQTT + module Protocol + abstract struct Packet + def to_slice + io = ::IO::Memory.new + self.to_io(IO.new(io)) + io.rewind + io.to_slice + end + end + end +end diff --git a/spec/mqtt/string_token_iterator_spec.cr b/spec/mqtt/string_token_iterator_spec.cr new file mode 100644 index 0000000000..c4c0249f2b --- /dev/null +++ b/spec/mqtt/string_token_iterator_spec.cr @@ -0,0 +1,33 @@ +require "./spec_helper" +require "../../src/lavinmq/mqtt/string_token_iterator" + +def strings + [ + # { input, expected } + {"a", ["a"]}, + {"/", ["", ""]}, + {"a/", ["a", ""]}, + {"/a", ["", "a"]}, + {"a/b/c", ["a", "b", "c"]}, + {"a//c", ["a", "", "c"]}, + {"a//b/c/aa", ["a", "", "b", "c", "aa"]}, + {"long name here/and another long here", + ["long name here", "and another long here"]}, + ] +end + +describe LavinMQ::MQTT::StringTokenIterator do + strings.each do |testdata| + it "is iterated correctly" do + itr = LavinMQ::MQTT::StringTokenIterator.new(testdata[0], '/') + res = Array(String).new + while itr.next? + if val = itr.next + res << val + end + end + itr.next?.should be_false + res.should eq testdata[1] + end + end +end diff --git a/spec/mqtt/subscription_tree_spec.cr b/spec/mqtt/subscription_tree_spec.cr new file mode 100644 index 0000000000..1918edeb8a --- /dev/null +++ b/spec/mqtt/subscription_tree_spec.cr @@ -0,0 +1,205 @@ +require "./spec_helper" +require "../../src/lavinmq/mqtt/subscription_tree" + +describe LavinMQ::MQTT::SubscriptionTree do + describe "#any?" do + it "returns false for empty tree" do + tree = LavinMQ::MQTT::SubscriptionTree(String).new + tree.any?("a").should be_false + end + + describe "with subs" do + tree = LavinMQ::MQTT::SubscriptionTree(String).new + before_each do + test_data = [ + "a/b", + "a/+/b", + "a/b/c/d/#", + "a/+/c/d/#", + ] + target = "target" + + test_data.each do |topic| + tree.subscribe(topic, target, 0u8) + end + end + + it "returns false for no matching subscriptions" do + tree.any?("a").should be_false + end + + it "returns true for matching non-wildcard subs" do + tree.any?("a/b").should be_true + end + + it "returns true for matching '+'-wildcard subs" do + tree.any?("a/r/b").should be_true + end + + it "returns true for matching '#'-wildcard subs" do + tree.any?("a/b/c/d/e/f").should be_true + end + end + end + + describe "#empty?" do + it "returns true before any subscribe" do + tree = LavinMQ::MQTT::SubscriptionTree(String).new + tree.empty?.should be_true + end + + it "returns false after a non-wildcard subscribe" do + tree = LavinMQ::MQTT::SubscriptionTree(String).new + session = "target" + tree.subscribe("topic", session, 0u8) + tree.empty?.should be_false + end + + it "returns false after a +-wildcard subscribe" do + tree = LavinMQ::MQTT::SubscriptionTree(String).new + session = "target" + tree.subscribe("a/+/topic", session, 0u8) + tree.empty?.should be_false + end + + it "returns false after a #-wildcard subscribe" do + tree = LavinMQ::MQTT::SubscriptionTree(String).new + session = "session" + tree.subscribe("a/#/topic", session, 0u8) + tree.empty?.should be_false + end + + it "returns true after unsubscribing only existing non-wildcard subscription" do + tree = LavinMQ::MQTT::SubscriptionTree(String).new + session = "session" + tree.subscribe("topic", session, 0u8) + tree.unsubscribe("topic", session) + tree.empty?.should be_true + end + + it "returns true after unsubscribing only existing +-wildcard subscription" do + tree = LavinMQ::MQTT::SubscriptionTree(String).new + session = "session" + tree.subscribe("a/+/topic", session, 0u8) + tree.unsubscribe("a/+/topic", session) + tree.empty?.should be_true + end + + it "returns true after unsubscribing only existing #+-wildcard subscription" do + tree = LavinMQ::MQTT::SubscriptionTree(String).new + session = "session" + tree.subscribe("a/b/#", session, 0u8) + tree.unsubscribe("a/b/#", session) + tree.empty?.should be_true + end + + it "returns true after unsubscribing many different subscriptions" do + tree = LavinMQ::MQTT::SubscriptionTree(String).new + test_data = [ + {"session", "a/b"}, + {"session", "a/+/b"}, + {"session", "a/b/c/d#"}, + {"session", "a/+/c/d/#"}, + {"session", "#"}, + ] + + test_data.each do |session, topic| + tree.subscribe(topic, session, 0u8) + end + + test_data.shuffle.each do |session, topic| + tree.unsubscribe(topic, session) + end + + tree.empty?.should be_true + end + end + + it "subscriptions is found" do + tree = LavinMQ::MQTT::SubscriptionTree(String).new + test_data = [ + {"session1", [{"a/b", 0u8}]}, + {"session2", [{"a/b", 0u8}]}, + {"session3", [{"a/c", 0u8}]}, + {"session4", [{"a/+", 0u8}]}, + {"session5", [{"#", 0u8}]}, + ] + + test_data.each do |s| + session, subscriptions = s + subscriptions.each do |tq| + t, q = tq + tree.subscribe(t, session, q) + end + end + + calls = 0 + tree.each_entry "a/b" do |_session, qos| + qos.should eq 0u8 + calls += 1 + end + calls.should eq 4 + end + + it "unsubscribe unsubscribes" do + tree = LavinMQ::MQTT::SubscriptionTree(String).new + test_data = [ + {"session1", [{"a/b", 0u8}]}, + {"session2", [{"a/b", 0u8}]}, + {"session3", [{"a/c", 0u8}]}, + {"session4", [{"a/+", 0u8}]}, + {"session5", [{"#", 0u8}]}, + ] + + test_data.each do |session, subscriptions| + subscriptions.each do |topic, qos| + tree.subscribe(topic, session, qos) + end + end + + test_data[1, 3].each do |session, subscriptions| + subscriptions.each do |topic, _qos| + tree.unsubscribe(topic, session) + end + end + calls = 0 + tree.each_entry "a/b" do |_session, _qos| + calls += 1 + end + calls.should eq 2 + end + + it "changes qos level" do + tree = LavinMQ::MQTT::SubscriptionTree(String).new + session = "session" + tree.subscribe("a/b", session, 0u8) + tree.each_entry "a/b" { |_sess, qos| qos.should eq 0u8 } + tree.subscribe("a/b", session, 1u8) + tree.each_entry "a/b" { |_sess, qos| qos.should eq 1u8 } + end + + it "can iterate all entries" do + tree = LavinMQ::MQTT::SubscriptionTree(String).new + test_data = [ + {"session", [{"a/b", 0u8}]}, + {"session", [{"a/b/c/d/e", 0u8}]}, + {"session", [{"+/c", 0u8}]}, + {"session", [{"a/+", 0u8}]}, + {"session", [{"#", 0u8}]}, + {"session", [{"a/b/#", 0u8}]}, + {"session", [{"a/+/c", 0u8}]}, + ] + + test_data.each do |session, subscriptions| + subscriptions.each do |topic, qos| + tree.subscribe(topic, session, qos) + end + end + + calls = 0 + tree.each_entry do |_session, _qos| + calls += 1 + end + calls.should eq 7 + end +end diff --git a/spec/spec_helper.cr b/spec/spec_helper.cr index f4bf0652e1..be6b512239 100644 --- a/spec/spec_helper.cr +++ b/spec/spec_helper.cr @@ -29,6 +29,13 @@ end def with_channel(s : LavinMQ::Server, file = __FILE__, line = __LINE__, **args, &) name = "lavinmq-spec-#{file}:#{line}" + s.@listeners + .select { |k, v| k.is_a?(TCPServer) && v.amqp? } + .keys + .select(TCPServer) + .first + .local_address + .port args = {port: amqp_port(s), name: name}.merge(args) conn = AMQP::Client.new(**args).connect ch = conn.channel @@ -80,9 +87,9 @@ def with_amqp_server(tls = false, replicator = LavinMQ::Clustering::NoopServer.n ctx = OpenSSL::SSL::Context::Server.new ctx.certificate_chain = "spec/resources/server_certificate.pem" ctx.private_key = "spec/resources/server_key.pem" - spawn(name: "amqp tls listen") { s.listen_tls(tcp_server, ctx) } + spawn(name: "amqp tls listen") { s.listen_tls(tcp_server, ctx, LavinMQ::Server::Protocol::AMQP) } else - spawn(name: "amqp tcp listen") { s.listen(tcp_server) } + spawn(name: "amqp tcp listen") { s.listen(tcp_server, LavinMQ::Server::Protocol::AMQP) } end Fiber.yield yield s diff --git a/src/lavinmq/amqp/client.cr b/src/lavinmq/amqp/client.cr index 23fc2063d7..9588e04bbe 100644 --- a/src/lavinmq/amqp/client.cr +++ b/src/lavinmq/amqp/client.cr @@ -4,6 +4,7 @@ require "./channel" require "../client" require "../error" require "../logger" +require "../name_validator.cr" module LavinMQ module AMQP @@ -508,7 +509,7 @@ module LavinMQ end private def declare_exchange(frame) - if !valid_entity_name(frame.exchange_name) + if !NameValidator.valid_entity_name(frame.exchange_name) send_precondition_failed(frame, "Exchange name isn't valid") elsif frame.exchange_name.empty? send_access_refused(frame, "Not allowed to declare the default exchange") @@ -516,8 +517,8 @@ module LavinMQ redeclare_exchange(e, frame) elsif frame.passive send_not_found(frame, "Exchange '#{frame.exchange_name}' doesn't exists") - elsif frame.exchange_name.starts_with? "amq." - send_access_refused(frame, "Not allowed to use the amq. prefix") + elsif NameValidator.reserved_prefix?(frame.exchange_name) + send_access_refused(frame, "Prefix #{NameValidator::PREFIX_LIST} forbidden, please choose another name") else ae = frame.arguments["x-alternate-exchange"]?.try &.as?(String) ae_ok = ae.nil? || (@user.can_write?(@vhost.name, ae) && @user.can_read?(@vhost.name, frame.exchange_name)) @@ -545,12 +546,12 @@ module LavinMQ end private def delete_exchange(frame) - if !valid_entity_name(frame.exchange_name) + if !NameValidator.valid_entity_name(frame.exchange_name) send_precondition_failed(frame, "Exchange name isn't valid") elsif frame.exchange_name.empty? - send_access_refused(frame, "Not allowed to delete the default exchange") - elsif frame.exchange_name.starts_with? "amq." - send_access_refused(frame, "Not allowed to use the amq. prefix") + send_access_refused(frame, "Prefix #{NameValidator::PREFIX_LIST} forbidden, please choose another name") + elsif NameValidator.reserved_prefix?(frame.exchange_name) + send_access_refused(frame, "Prefix #{NameValidator::PREFIX_LIST} forbidden, please choose another name") elsif !@vhost.exchanges.has_key? frame.exchange_name # should return not_found according to spec but we make it idempotent send AMQP::Frame::Exchange::DeleteOk.new(frame.channel) unless frame.no_wait @@ -569,7 +570,7 @@ module LavinMQ if frame.queue_name.empty? && @last_queue_name frame.queue_name = @last_queue_name.not_nil! end - if !valid_entity_name(frame.queue_name) + if !NameValidator.valid_entity_name(frame.queue_name) send_precondition_failed(frame, "Queue name isn't valid") return end @@ -592,17 +593,12 @@ module LavinMQ end end - private def valid_entity_name(name) : Bool - return true if name.empty? - name.matches?(/\A[ -~]*\z/) - end - def queue_exclusive_to_other_client?(q) q.exclusive? && !@exclusive_queues.includes?(q) end private def declare_queue(frame) - if !frame.queue_name.empty? && !valid_entity_name(frame.queue_name) + if !frame.queue_name.empty? && !NameValidator.valid_entity_name(frame.queue_name) send_precondition_failed(frame, "Queue name isn't valid") elsif q = @vhost.queues.fetch(frame.queue_name, nil) redeclare_queue(frame, q) @@ -619,8 +615,8 @@ module LavinMQ end elsif frame.passive send_not_found(frame, "Queue '#{frame.queue_name}' doesn't exists") - elsif frame.queue_name.starts_with? "amq." - send_access_refused(frame, "Not allowed to use the amq. prefix") + elsif NameValidator.reserved_prefix?(frame.queue_name) + send_access_refused(frame, "Prefix #{NameValidator::PREFIX_LIST} forbidden, please choose another name") elsif @vhost.max_queues.try { |max| @vhost.queues.size >= max } send_access_refused(frame, "queue limit in vhost '#{@vhost.name}' (#{@vhost.max_queues}) is reached") else @@ -736,10 +732,10 @@ module LavinMQ end private def valid_q_bind_unbind?(frame) : Bool - if !valid_entity_name(frame.queue_name) + if !NameValidator.valid_entity_name(frame.queue_name) send_precondition_failed(frame, "Queue name isn't valid") return false - elsif !valid_entity_name(frame.exchange_name) + elsif !NameValidator.valid_entity_name(frame.exchange_name) send_precondition_failed(frame, "Exchange name isn't valid") return false end @@ -757,8 +753,8 @@ module LavinMQ send_access_refused(frame, "User doesn't have read permissions to exchange '#{frame.source}'") elsif !@user.can_write?(@vhost.name, frame.destination) send_access_refused(frame, "User doesn't have write permissions to exchange '#{frame.destination}'") - elsif frame.source.empty? || frame.destination.empty? - send_access_refused(frame, "Not allowed to bind to the default exchange") + # elsif source.is_a?(LavinMQ::MQTT::Exchange) || destination.is_a?(LavinMQ::MQTT::Exchange) + # send_access_refused(frame, "Not allowed to bind to an MQTT Exchange") else @vhost.apply(frame) send AMQP::Frame::Exchange::BindOk.new(frame.channel) unless frame.no_wait @@ -794,7 +790,7 @@ module LavinMQ send_access_refused(frame, "User doesn't have write permissions to queue '#{frame.queue_name}'") return end - if !valid_entity_name(frame.queue_name) + if !NameValidator.valid_entity_name(frame.queue_name) send_precondition_failed(frame, "Queue name isn't valid") elsif q = @vhost.queues.fetch(frame.queue_name, nil) if queue_exclusive_to_other_client?(q) @@ -820,7 +816,7 @@ module LavinMQ if frame.queue.empty? && @last_queue_name frame.queue = @last_queue_name.not_nil! end - if !valid_entity_name(frame.queue) + if !NameValidator.valid_entity_name(frame.queue) send_precondition_failed(frame, "Queue name isn't valid") return end @@ -835,7 +831,7 @@ module LavinMQ if frame.queue.empty? && @last_queue_name frame.queue = @last_queue_name.not_nil! end - if !valid_entity_name(frame.queue) + if !NameValidator.valid_entity_name(frame.queue) send_precondition_failed(frame, "Queue name isn't valid") return end diff --git a/src/lavinmq/amqp/connection_factory.cr b/src/lavinmq/amqp/connection_factory.cr index a427ba5eee..b182550db1 100644 --- a/src/lavinmq/amqp/connection_factory.cr +++ b/src/lavinmq/amqp/connection_factory.cr @@ -1,6 +1,8 @@ require "../version" require "../logger" require "./client" +require "../user_store" +require "../vhost_store" require "../client/connection_factory" module LavinMQ @@ -8,16 +10,19 @@ module LavinMQ class ConnectionFactory < LavinMQ::ConnectionFactory Log = LavinMQ::Log.for "amqp.connection_factory" - def start(socket, connection_info, vhosts, users) : Client? + def initialize(@users : UserStore, @vhosts : VHostStore) + end + + def start(socket, connection_info : ConnectionInfo) : Client? remote_address = connection_info.src socket.read_timeout = 15.seconds metadata = ::Log::Metadata.build({address: remote_address.to_s}) logger = Logger.new(Log, metadata) if confirm_header(socket, logger) if start_ok = start(socket, logger) - if user = authenticate(socket, remote_address, users, start_ok, logger) + if user = authenticate(socket, remote_address, start_ok, logger) if tune_ok = tune(socket, logger) - if vhost = open(socket, vhosts, user, logger) + if vhost = open(socket, user, logger) socket.read_timeout = heartbeat_timeout(tune_ok) return LavinMQ::AMQP::Client.new(socket, connection_info, vhost, user, tune_ok, start_ok) end @@ -71,7 +76,7 @@ module LavinMQ }, }) - def start(socket, log) + def start(socket, log : Logger) start = AMQP::Frame::Connection::Start.new(server_properties: SERVER_PROPERTIES) socket.write_bytes start, ::IO::ByteFormat::NetworkEndian socket.flush @@ -100,9 +105,9 @@ module LavinMQ end end - def authenticate(socket, remote_address, users, start_ok, log) + def authenticate(socket, remote_address, start_ok, log) username, password = credentials(start_ok) - user = users[username]? + user = @users[username]? return user if user && user.password && user.password.not_nil!.verify(password) && guest_only_loopback?(remote_address, user) @@ -150,10 +155,10 @@ module LavinMQ tune_ok end - def open(socket, vhosts, user, log) + def open(socket, user, log) open = AMQP::Frame.from_io(socket) { |f| f.as(AMQP::Frame::Connection::Open) } vhost_name = open.vhost.empty? ? "/" : open.vhost - if vhost = vhosts[vhost_name]? + if vhost = @vhosts[vhost_name]? if user.permissions[vhost_name]? if vhost.max_connections.try { |max| vhost.connections.size >= max } log.warn { "Max connections (#{vhost.max_connections}) reached for vhost #{vhost_name}" } diff --git a/src/lavinmq/amqp/queue/queue.cr b/src/lavinmq/amqp/queue/queue.cr index fc1d221613..6f120ac4b4 100644 --- a/src/lavinmq/amqp/queue/queue.cr +++ b/src/lavinmq/amqp/queue/queue.cr @@ -744,6 +744,18 @@ module LavinMQ::AMQP end end + def unacked_messages + unacked_messages = consumers.each.select(AMQP::Consumer).flat_map do |c| + c.unacked_messages.each.compact_map do |u| + next unless u.queue == self + if consumer = u.consumer + UnackedMessage.new(c.channel, u.tag, u.delivered_at, consumer.tag) + end + end + end + unacked_messages.chain(self.basic_get_unacked.each) + end + private def with_delivery_count_header(env) : Envelope? if limit = @delivery_limit sp = env.segment_position diff --git a/src/lavinmq/clustering/client.cr b/src/lavinmq/clustering/client.cr index 7cce19a60e..7f6a75de80 100644 --- a/src/lavinmq/clustering/client.cr +++ b/src/lavinmq/clustering/client.cr @@ -11,8 +11,10 @@ module LavinMQ @closed = false @amqp_proxy : Proxy? @http_proxy : Proxy? + @mqtt_proxy : Proxy? @unix_amqp_proxy : Proxy? @unix_http_proxy : Proxy? + @unix_mqtt_proxy : Proxy? @socket : TCPSocket? def initialize(@config : Config, @id : Int32, @password : String, proxy = true) @@ -30,8 +32,10 @@ module LavinMQ if proxy @amqp_proxy = Proxy.new(@config.amqp_bind, @config.amqp_port) @http_proxy = Proxy.new(@config.http_bind, @config.http_port) + @mqtt_proxy = Proxy.new(@config.mqtt_bind, @config.mqtt_port) @unix_amqp_proxy = Proxy.new(@config.unix_path) unless @config.unix_path.empty? @unix_http_proxy = Proxy.new(@config.http_unix_path) unless @config.http_unix_path.empty? + @unix_mqtt_proxy = Proxy.new(@config.mqtt_unix_path) unless @config.mqtt_unix_path.empty? end HTTP::Server.follower_internal_socket_http_server @@ -64,12 +68,18 @@ module LavinMQ if http_proxy = @http_proxy spawn http_proxy.forward_to(host, @config.http_port), name: "HTTP proxy" end + if mqtt_proxy = @mqtt_proxy + spawn mqtt_proxy.forward_to(host, @config.mqtt_port), name: "MQTT proxy" + end if unix_amqp_proxy = @unix_amqp_proxy spawn unix_amqp_proxy.forward_to(host, @config.amqp_port), name: "AMQP proxy" end if unix_http_proxy = @unix_http_proxy spawn unix_http_proxy.forward_to(host, @config.http_port), name: "HTTP proxy" end + if unix_mqtt_proxy = @unix_mqtt_proxy + spawn unix_mqtt_proxy.forward_to(host, @config.mqtt_port), name: "MQTT proxy" + end loop do @socket = socket = TCPSocket.new(host, port) socket.sync = true @@ -274,8 +284,10 @@ module LavinMQ @closed = true @amqp_proxy.try &.close @http_proxy.try &.close + @mqtt_proxy.try &.close @unix_amqp_proxy.try &.close @unix_http_proxy.try &.close + @unix_mqtt_proxy.try &.close @files.each_value &.close @data_dir_lock.release @socket.try &.close diff --git a/src/lavinmq/config.cr b/src/lavinmq/config.cr index 0d55b88b0c..64d26c6121 100644 --- a/src/lavinmq/config.cr +++ b/src/lavinmq/config.cr @@ -17,6 +17,10 @@ module LavinMQ property amqp_bind = "127.0.0.1" property amqp_port = 5672 property amqps_port = -1 + property mqtt_bind = "127.0.0.1" + property mqtt_port = 1883 + property mqtts_port = 8883 + property mqtt_unix_path = "" property unix_path = "" property unix_proxy_protocol = 1_u8 # PROXY protocol version on unix domain socket connections property tcp_proxy_protocol = 0_u8 # PROXY protocol version on amqp tcp connections @@ -39,6 +43,7 @@ module LavinMQ property socket_buffer_size = 16384 # bytes property? tcp_nodelay = false # bool property segment_size : Int32 = 8 * 1024**2 # bytes + property max_inflight_messages : UInt16 = 65_535 property? raise_gc_warn : Bool = false property? data_dir_lock : Bool = true property tcp_keepalive : Tuple(Int32, Int32, Int32)? = {60, 10, 3} # idle, interval, probes/count @@ -87,6 +92,12 @@ module LavinMQ p.on("--amqp-bind=BIND", "IP address that the AMQP server will listen on (default: 127.0.0.1)") do |v| @amqp_bind = v end + p.on("-m PORT", "--mqtt-port=PORT", "MQTT port to listen on (default: 1883)") do |v| + @mqtt_port = v.to_i + end + p.on("--mqtts-port=PORT", "MQTTS port to listen on (default: 8883)") do |v| + @mqtts_port = v.to_i + end p.on("--http-port=PORT", "HTTP port to listen on (default: 15672)") do |v| @http_port = v.to_i end @@ -102,6 +113,9 @@ module LavinMQ p.on("--http-unix-path=PATH", "HTTP UNIX path to listen to") do |v| @http_unix_path = v end + p.on("--mqtt-unix-path=PATH", "MQTT UNIX path to listen to") do |v| + @mqtt_unix_path = v + end p.on("--cert FILE", "TLS certificate (including chain)") { |v| @tls_cert_path = v } p.on("--key FILE", "Private key for the TLS certificate") { |v| @tls_key_path = v } p.on("--ciphers CIPHERS", "List of TLS ciphers to allow") { |v| @tls_ciphers = v } @@ -166,6 +180,7 @@ module LavinMQ case section when "main" then parse_main(settings) when "amqp" then parse_amqp(settings) + when "mqtt" then parse_mqtt(settings) when "mgmt", "http" then parse_mgmt(settings) when "clustering" then parse_clustering(settings) when "replication" then abort("#{file}: [replication] is deprecated and replaced with [clustering], see the README for more information") @@ -276,6 +291,22 @@ module LavinMQ end end + private def parse_mqtt(settings) + settings.each do |config, v| + case config + when "bind" then @mqtt_bind = v + when "port" then @mqtt_port = v.to_i32 + when "tls_port" then @mqtts_port = v.to_i32 + when "tls_cert" then @tls_cert_path = v + when "tls_key" then @tls_key_path = v + when "mqtt_unix_path" then @mqtt_unix_path = v + when "max_inflight_messages" then @max_inflight_messages = v.to_u16 + else + STDERR.puts "WARNING: Unrecognized configuration 'mqtt/#{config}'" + end + end + end + private def parse_mgmt(settings) settings.each do |config, v| case config diff --git a/src/lavinmq/exchange/direct.cr b/src/lavinmq/exchange/direct.cr index af559859f1..4cbc96d8a5 100644 --- a/src/lavinmq/exchange/direct.cr +++ b/src/lavinmq/exchange/direct.cr @@ -27,6 +27,10 @@ module LavinMQ true end + def bind(destination : MQTT::Session, routing_key : String, headers = nil) : Bool + raise LavinMQ::Exchange::AccessRefused.new(self) + end + def unbind(destination : Destination, routing_key, headers = nil) : Bool rk_bindings = @bindings[routing_key] return false unless rk_bindings.delete destination diff --git a/src/lavinmq/exchange/exchange.cr b/src/lavinmq/exchange/exchange.cr index 0a271b0395..1ddb5a0e86 100644 --- a/src/lavinmq/exchange/exchange.cr +++ b/src/lavinmq/exchange/exchange.cr @@ -5,6 +5,7 @@ require "../sortable_json" require "../observable" require "./event" require "../amqp/queue" +require "../mqtt/session" module LavinMQ alias Destination = Queue | Exchange diff --git a/src/lavinmq/exchange/fanout.cr b/src/lavinmq/exchange/fanout.cr index 623b59a2a3..14a3729f61 100644 --- a/src/lavinmq/exchange/fanout.cr +++ b/src/lavinmq/exchange/fanout.cr @@ -23,6 +23,10 @@ module LavinMQ true end + def bind(destination : MQTT::Session, routing_key : String, headers = nil) : Bool + raise LavinMQ::Exchange::AccessRefused.new(self) + end + def unbind(destination : Destination, routing_key, headers = nil) return false unless @bindings.delete destination binding_key = BindingKey.new("") diff --git a/src/lavinmq/exchange/headers.cr b/src/lavinmq/exchange/headers.cr index 98165aa930..bd2e358b58 100644 --- a/src/lavinmq/exchange/headers.cr +++ b/src/lavinmq/exchange/headers.cr @@ -36,6 +36,10 @@ module LavinMQ true end + def bind(destination : MQTT::Session, routing_key : String, headers = nil) : Bool + raise LavinMQ::Exchange::AccessRefused.new(self) + end + def unbind(destination : Destination, routing_key, headers) args = headers ? @arguments.clone.merge!(headers) : @arguments bds = @bindings[args] diff --git a/src/lavinmq/exchange/topic.cr b/src/lavinmq/exchange/topic.cr index 117a782143..ad0f67863c 100644 --- a/src/lavinmq/exchange/topic.cr +++ b/src/lavinmq/exchange/topic.cr @@ -27,6 +27,10 @@ module LavinMQ true end + def bind(destination : MQTT::Session, routing_key : String, headers = nil) : Bool + raise LavinMQ::Exchange::AccessRefused.new(self) + end + def unbind(destination : Destination, routing_key, headers = nil) rks = routing_key.split(".") bds = @bindings[routing_key.split(".")] diff --git a/src/lavinmq/http/controller/exchanges.cr b/src/lavinmq/http/controller/exchanges.cr index 8a6dbc0c5e..1edf4a49f3 100644 --- a/src/lavinmq/http/controller/exchanges.cr +++ b/src/lavinmq/http/controller/exchanges.cr @@ -69,8 +69,8 @@ module LavinMQ bad_request(context, "Not allowed to publish to internal exchange") end context.response.status_code = 204 - elsif name.starts_with? "amq." - bad_request(context, "Not allowed to use the amq. prefix") + elsif NameValidator.reserved_prefix?(name) + bad_request(context, "Prefix #{NameValidator::PREFIX_LIST} forbidden, please choose another name") elsif name.bytesize > UInt8::MAX bad_request(context, "Exchange name too long, can't exceed 255 characters") else diff --git a/src/lavinmq/http/controller/queues.cr b/src/lavinmq/http/controller/queues.cr index 7c7f054b18..6880598d87 100644 --- a/src/lavinmq/http/controller/queues.cr +++ b/src/lavinmq/http/controller/queues.cr @@ -2,6 +2,7 @@ require "uri" require "../controller" require "../binding_helpers" require "../../unacked_message" +require "../../name_validator" module LavinMQ module HTTP @@ -46,15 +47,7 @@ module LavinMQ with_vhost(context, params) do |vhost| refuse_unless_management(context, user(context), vhost) q = queue(context, params, vhost) - unacked_messages = q.consumers.each.flat_map do |c| - c.unacked_messages.each.compact_map do |u| - next unless u.queue == q - if consumer = u.consumer - UnackedMessage.new(c.channel, u.tag, u.delivered_at, consumer.tag) - end - end - end - unacked_messages = unacked_messages.chain(q.basic_get_unacked.each) + unacked_messages = q.unacked_messages page(context, unacked_messages) end end @@ -80,8 +73,8 @@ module LavinMQ bad_request(context, "Existing queue declared with other arguments arg") end context.response.status_code = 204 - elsif name.starts_with? "amq." - bad_request(context, "Not allowed to use the amq. prefix") + elsif NameValidator.reserved_prefix?(name) + bad_request(context, "Prefix #{NameValidator::PREFIX_LIST} forbidden, please choose another name") elsif name.bytesize > UInt8::MAX bad_request(context, "Queue name too long, can't exceed 255 characters") else diff --git a/src/lavinmq/http/handler/websocket.cr b/src/lavinmq/http/handler/websocket.cr index 4a8fb131bd..6f58784d5f 100644 --- a/src/lavinmq/http/handler/websocket.cr +++ b/src/lavinmq/http/handler/websocket.cr @@ -11,7 +11,7 @@ module LavinMQ Socket::IPAddress.new("127.0.0.1", 0) # Fake when UNIXAddress connection_info = ConnectionInfo.new(remote_address, local_address) io = WebSocketIO.new(ws) - spawn amqp_server.handle_connection(io, connection_info), name: "HandleWSconnection #{remote_address}" + spawn amqp_server.handle_connection(io, connection_info, Server::Protocol::AMQP), name: "HandleWSconnection #{remote_address}" end end end diff --git a/src/lavinmq/launcher.cr b/src/lavinmq/launcher.cr index a683bd9cd2..b6c888989c 100644 --- a/src/lavinmq/launcher.cr +++ b/src/lavinmq/launcher.cr @@ -116,25 +116,25 @@ module LavinMQ end end - private def listen + private def listen # ameba:disable Metrics/CyclomaticComplexity + if clustering_bind = @config.clustering_bind + spawn @amqp_server.listen_clustering(clustering_bind, @config.clustering_port), name: "Clustering listener" + end + if @config.amqp_port > 0 - spawn @amqp_server.listen(@config.amqp_bind, @config.amqp_port), + spawn @amqp_server.listen(@config.amqp_bind, @config.amqp_port, Server::Protocol::AMQP), name: "AMQP listening on #{@config.amqp_port}" end if @config.amqps_port > 0 if ctx = @tls_context - spawn @amqp_server.listen_tls(@config.amqp_bind, @config.amqps_port, ctx), + spawn @amqp_server.listen_tls(@config.amqp_bind, @config.amqps_port, ctx, Server::Protocol::AMQP), name: "AMQPS listening on #{@config.amqps_port}" end end - if clustering_bind = @config.clustering_bind - spawn @amqp_server.listen_clustering(clustering_bind, @config.clustering_port), name: "Clustering listener" - end - unless @config.unix_path.empty? - spawn @amqp_server.listen_unix(@config.unix_path), name: "AMQP listening at #{@config.unix_path}" + spawn @amqp_server.listen_unix(@config.unix_path, Server::Protocol::AMQP), name: "AMQP listening at #{@config.unix_path}" end if @config.http_port > 0 @@ -153,6 +153,21 @@ module LavinMQ spawn(name: "HTTP listener") do @http_server.not_nil!.listen end + + if @config.mqtt_port > 0 + spawn @amqp_server.listen(@config.mqtt_bind, @config.mqtt_port, Server::Protocol::MQTT), + name: "MQTT listening on #{@config.mqtt_port}" + end + + if @config.mqtts_port > 0 + if ctx = @tls_context + spawn @amqp_server.listen_tls(@config.mqtt_bind, @config.mqtts_port, ctx, Server::Protocol::MQTT), + name: "MQTTS listening on #{@config.mqtts_port}" + end + end + unless @config.mqtt_unix_path.empty? + spawn @amqp_server.listen_unix(@config.mqtt_unix_path, Server::Protocol::MQTT), name: "MQTT listening at #{@config.mqtt_unix_path}" + end end private def dump_debug_info @@ -179,7 +194,7 @@ module LavinMQ STDOUT.flush @amqp_server.vhosts.each_value do |vhost| vhost.queues.each_value do |q| - if q = q.as(LavinMQ::AMQP::Queue) + if q = (q.as(LavinMQ::AMQP::Queue) || q.as?(LavinMQ::MQTT::Session)) msg_store = q.@msg_store msg_store.@segments.each_value &.unmap msg_store.@acks.each_value &.unmap diff --git a/src/lavinmq/mqtt/broker.cr b/src/lavinmq/mqtt/broker.cr new file mode 100644 index 0000000000..9352046259 --- /dev/null +++ b/src/lavinmq/mqtt/broker.cr @@ -0,0 +1,110 @@ +require "./client" +require "./consts" +require "./exchange" +require "./protocol" +require "./session" +require "./sessions" +require "./retain_store" +require "../vhost" + +module LavinMQ + module MQTT + class Broker + getter vhost, sessions + + # The Broker class acts as an intermediary between the MQTT client and the Vhost & Server, + # It is initialized when starting a connection and it manages a clients connections, + # sessions, and message exchange. + # The broker is responsible for: + # - Handling client connections and disconnections + # - Managing client sessions, including clean and persistent sessions + # - Publishing messages to the exchange + # - Subscribing and unsubscribing clients to/from topics + # - Handling the retain_store + # - Interfacing with the virtual host (vhost) and the exchange to route messages. + # The Broker class helps keep the MQTT Client concise and focused on the protocol. + + def initialize(@vhost : VHost, @replicator : Clustering::Replicator) + @sessions = Sessions.new(@vhost) + @clients = Hash(String, Client).new + @retain_store = RetainStore.new(Path[@vhost.data_dir].join("mqtt_reatined_store").to_s, @replicator) + @exchange = MQTT::Exchange.new(@vhost, EXCHANGE, @retain_store) + @vhost.exchanges[EXCHANGE] = @exchange + end + + def session_present?(client_id : String, clean_session) : Bool + return false if clean_session + session = sessions[client_id]? + return false if session.nil? || session.clean_session? + true + end + + def connect_client(socket, connection_info, user, packet) + if prev_client = @clients[packet.client_id]? + Log.trace { "Found previous client connected with client_id: #{packet.client_id}, closing" } + prev_client.close + end + client = MQTT::Client.new(socket, connection_info, user, @vhost, self, packet.client_id, packet.clean_session?, packet.will) + if session = sessions[client.client_id]? + if session.clean_session? + sessions.delete session + else + session.client = client + end + end + @clients[packet.client_id] = client + client + end + + def disconnect_client(client) + client_id = client.client_id + if session = sessions[client_id]? + session.client = nil + sessions.delete(client_id) if session.clean_session? + end + @clients.delete client_id + vhost.rm_connection(client) + end + + def publish(packet : MQTT::Publish) + @exchange.publish(packet) + end + + def subscribe(client, packet) + unless session = sessions[client.client_id]? + session = sessions.declare(client.client_id, client.@clean_session) + session.client = client + end + qos = Array(MQTT::SubAck::ReturnCode).new(packet.topic_filters.size) + packet.topic_filters.each do |tf| + qos << MQTT::SubAck::ReturnCode.from_int(tf.qos) + session.subscribe(tf.topic, tf.qos) + @retain_store.each(tf.topic) do |topic, body| + headers = AMQP::Table.new + headers[RETAIN_HEADER] = true + msg = Message.new(EXCHANGE, topic, String.new(body), + AMQP::Properties.new(headers: headers, + delivery_mode: tf.qos)) + session.publish(msg) + end + end + qos + end + + def unsubscribe(client, packet) + session = sessions[client.client_id] + packet.topics.each do |tf| + session.unsubscribe(tf) + end + end + + def clear_session(client_id) + sessions.delete client_id + end + + def close + @retain_store.close + end + end + end +end diff --git a/src/lavinmq/mqtt/brokers.cr b/src/lavinmq/mqtt/brokers.cr new file mode 100644 index 0000000000..4853e8bb6c --- /dev/null +++ b/src/lavinmq/mqtt/brokers.cr @@ -0,0 +1,37 @@ +require "./broker" +require "../clustering/replicator" +require "../observable" +require "../vhost_store" + +module LavinMQ + module MQTT + class Brokers + include Observer(VHostStore::Event) + + def initialize(@vhosts : VHostStore, @replicator : Clustering::Replicator) + @brokers = Hash(String, Broker).new(initial_capacity: @vhosts.size) + @vhosts.each do |(name, vhost)| + @brokers[name] = Broker.new(vhost, @replicator) + end + @vhosts.register_observer(self) + end + + def []?(vhost : String) : Broker? + @brokers[vhost]? + end + + def on(event : VHostStore::Event, data : Object?) + return if data.nil? + vhost = data.to_s + case event + in VHostStore::Event::Added + @brokers[vhost] = Broker.new(@vhosts[vhost], @replicator) + in VHostStore::Event::Deleted + @brokers.delete(vhost) + in VHostStore::Event::Closed + @brokers[vhost].close + end + end + end + end +end diff --git a/src/lavinmq/mqtt/client.cr b/src/lavinmq/mqtt/client.cr new file mode 100644 index 0000000000..4515a08776 --- /dev/null +++ b/src/lavinmq/mqtt/client.cr @@ -0,0 +1,240 @@ +require "openssl" +require "socket" +require "../client" +require "../error" +require "./session" +require "./protocol" + +module LavinMQ + module MQTT + class Client < LavinMQ::Client + include Stats + include SortableJSON + + getter vhost, channels, log, name, user, client_id, socket, remote_address, connection_info + @connected_at = RoughTime.unix_ms + @channels = Hash(UInt16, Client::Channel).new + @session : MQTT::Session? + rate_stats({"send_oct", "recv_oct"}) + Log = ::Log.for "mqtt.client" + + def initialize(@socket : ::IO, + @connection_info : ConnectionInfo, + @user : User, + @vhost : VHost, + @broker : MQTT::Broker, + @client_id : String, + @clean_session = false, + @will : MQTT::Will? = nil) + @io = MQTT::IO.new(@socket) + @lock = Mutex.new + @remote_address = @connection_info.src + @local_address = @connection_info.dst + @name = "#{@remote_address} -> #{@local_address}" + @metadata = ::Log::Metadata.new(nil, {vhost: @broker.vhost.name, address: @remote_address.to_s, client_id: client_id}) + @log = Logger.new(Log, @metadata) + @broker.vhost.add_connection(self) + @log.info { "Connection established for user=#{@user.name}" } + spawn read_loop + end + + def client_name + "mqtt-client" + end + + private def read_loop + loop do + @log.trace { "waiting for packet" } + packet = read_and_handle_packet + # The disconnect packet has been handled and the socket has been closed. + # If we dont breakt the loop here we'll get a IO/Error on next read. + break if packet.is_a?(MQTT::Disconnect) + end + rescue ex : ::MQTT::Protocol::Error::PacketDecode + @log.warn(exception: ex) { "Packet decode error" } + publish_will if @will + rescue ex : MQTT::Error::Connect + @log.warn { "Connect error: #{ex.message}" } + rescue ex : ::IO::Error + @log.warn { "Client unexpectedly closed connection" } unless @closed + publish_will if @will + rescue ex + @log.warn(exception: ex) { "Read Loop error" } + publish_will if @will + ensure + @broker.disconnect_client(self) + close_socket + end + + def read_and_handle_packet + packet : MQTT::Packet = MQTT::Packet.from_io(@io) + @log.trace { "Recieved packet: #{packet.inspect}" } + @recv_oct_count += packet.bytesize + + case packet + when MQTT::Publish then recieve_publish(packet) + when MQTT::PubAck then recieve_puback(packet) + when MQTT::Subscribe then recieve_subscribe(packet) + when MQTT::Unsubscribe then recieve_unsubscribe(packet) + when MQTT::PingReq then receive_pingreq(packet) + when MQTT::Disconnect then return packet + else raise "received unexpected packet: #{packet}" + end + packet + end + + def send(packet) + @lock.synchronize do + packet.to_io(@io) + @socket.flush + end + @send_oct_count += packet.bytesize + end + + def receive_pingreq(packet : MQTT::PingReq) + send MQTT::PingResp.new + end + + def recieve_publish(packet : MQTT::Publish) + @broker.publish(packet) + # Ok to not send anything if qos = 0 (fire and forget) + if packet.qos > 0 && (packet_id = packet.packet_id) + send(MQTT::PubAck.new(packet_id)) + end + end + + def recieve_puback(packet : MQTT::PubAck) + @broker.sessions[@client_id].ack(packet) + end + + def recieve_subscribe(packet : MQTT::Subscribe) + qos = @broker.subscribe(self, packet) + send(MQTT::SubAck.new(qos, packet.packet_id)) + end + + def recieve_unsubscribe(packet : MQTT::Unsubscribe) + @broker.unsubscribe(self, packet) + send(MQTT::UnsubAck.new(packet.packet_id)) + end + + def details_tuple + { + vhost: @broker.vhost.name, + user: @user.name, + protocol: "MQTT", + client_id: @client_id, + connected_at: @connected_at, + }.merge(stats_details) + end + + private def publish_will + return unless will = @will + packet = MQTT::Publish.new( + topic: will.topic, + payload: will.payload, + packet_id: nil, + qos: will.qos, + retain: will.retain?, + dup: false, + ) + @broker.publish(packet) + rescue ex + @log.warn { "Failed to publish will: #{ex.message}" } + end + + def update_rates + end + + def close(reason = "") + @log.debug { "Client#close" } + @closed = true + @socket.close + end + + def force_close + end + + private def close_socket + socket = @socket + if socket.responds_to?(:"write_timeout=") + socket.write_timeout = 1.seconds + end + socket.close + rescue ::IO::Error + end + end + + class Consumer < LavinMQ::Client::Channel::Consumer + getter unacked = 0_u32 + getter tag : String = "mqtt" + property prefetch_count = 1 + + def initialize(@client : Client, @session : MQTT::Session) + @has_capacity.try_send? true + end + + def details_tuple + { + queue: { + name: "mqtt.#{@client.client_id}", + vhost: @client.vhost.name, + }, + channel_details: { + peer_host: "#{@client.remote_address}", + peer_port: "#{@client.connection_info.src}", + connection_name: "mqtt.#{@client.client_id}", + user: "#{@client.user}", + number: "", + name: "mqtt.#{@client.client_id}", + }, + prefetch_count: prefetch_count, + consumer_tag: @client.client_id, + } + end + + def no_ack? + true + end + + def accepts? : Bool + true + end + + def deliver(msg : MQTT::Publish) + @client.send(msg) + end + + def deliver(msg, sp, redelivered = false, recover = false) + end + + def exclusive? + true + end + + def cancel + end + + def close + end + + def closed? + false + end + + def flow(active : Bool) + end + + getter has_capacity = ::Channel(Bool).new + + def ack(sp) + end + + def reject(sp, requeue = false) + end + + def priority + 0 + end + end + end +end diff --git a/src/lavinmq/mqtt/connection_factory.cr b/src/lavinmq/mqtt/connection_factory.cr new file mode 100644 index 0000000000..4d0c41c82b --- /dev/null +++ b/src/lavinmq/mqtt/connection_factory.cr @@ -0,0 +1,80 @@ +require "log" +require "socket" +require "./protocol" +require "./client" +require "./brokers" +require "../user" +require "../client/connection_factory" + +module LavinMQ + module MQTT + class ConnectionFactory < LavinMQ::ConnectionFactory + def initialize(@users : UserStore, + @vhosts : VHostStore, + @brokers : Brokers, + replicator : Clustering::Replicator) + end + + def start(socket : ::IO, connection_info : ConnectionInfo) + io = MQTT::IO.new(socket) + if packet = Packet.from_io(socket).as?(Connect) + Log.trace { "recv #{packet.inspect}" } + if user_and_broker = authenticate(io, packet) + user, broker = user_and_broker + packet = assign_client_id(packet) if packet.client_id.empty? + session_present = broker.session_present?(packet.client_id, packet.clean_session?) + connack io, session_present, Connack::ReturnCode::Accepted + return broker.connect_client(socket, connection_info, user, packet) + else + Log.warn { "Authentication failure for user \"#{packet.username}\"" } + connack io, false, Connack::ReturnCode::NotAuthorized + end + end + rescue ex : MQTT::Error::Connect + Log.warn { "Connect error #{ex.inspect}" } + if io + connack io, false, Connack::ReturnCode.new(ex.return_code) + end + socket.close + rescue ex + Log.warn { "Recieved invalid Connect packet: #{ex.inspect}" } + socket.close + end + + private def connack(io : MQTT::IO, session_present : Bool, return_code : Connack::ReturnCode) + Connack.new(session_present, return_code).to_io(io) + io.flush + end + + def authenticate(io, packet) + return unless (username = packet.username) && (password = packet.password) + + vhost = "/" + if split_pos = username.index(':') + vhost = username[0, split_pos] + username = username[split_pos + 1..] + end + + user = @users[username]? + return unless user + return unless user.password && user.password.try(&.verify(String.new(password))) + has_vhost_permissions = user.try &.permissions.has_key?(vhost) + return unless has_vhost_permissions + broker = @brokers[vhost]? + return unless broker + + {user, broker} + end + + def assign_client_id(packet) + client_id = Random::DEFAULT.base64(32) + Connect.new(client_id, + packet.clean_session?, + packet.keepalive, + packet.username, + packet.password, + packet.will) + end + end + end +end diff --git a/src/lavinmq/mqtt/consts.cr b/src/lavinmq/mqtt/consts.cr new file mode 100644 index 0000000000..c51ddfb0c8 --- /dev/null +++ b/src/lavinmq/mqtt/consts.cr @@ -0,0 +1,7 @@ +module LavinMQ + module MQTT + EXCHANGE = "mqtt.default" + QOS_HEADER = "mqtt.qos" + RETAIN_HEADER = "mqtt.retain" + end +end diff --git a/src/lavinmq/mqtt/exchange.cr b/src/lavinmq/mqtt/exchange.cr new file mode 100644 index 0000000000..b2e9a0eaa2 --- /dev/null +++ b/src/lavinmq/mqtt/exchange.cr @@ -0,0 +1,134 @@ +require "../exchange" +require "./consts" +require "./subscription_tree" +require "./session" +require "./retain_store" + +module LavinMQ + module MQTT + class Exchange < Exchange + struct BindingKey + def initialize(routing_key : String, arguments : AMQP::Table? = nil) + @binding_key = LavinMQ::BindingKey.new(routing_key, arguments) + end + + def inner + @binding_key + end + + def hash + @binding_key.routing_key.hash + end + end + + @bindings = Hash(BindingKey, Set(MQTT::Session)).new do |h, k| + h[k] = Set(MQTT::Session).new + end + @tree = MQTT::SubscriptionTree(MQTT::Session).new + + def type : String + "mqtt" + end + + def initialize(vhost : VHost, name : String, @retain_store : MQTT::RetainStore) + super(vhost, name, true, false, true) + end + + def publish(msg : Message, immediate : Bool, + queues : Set(Queue) = Set(Queue).new, + exchanges : Set(Exchange) = Set(Exchange).new) : Int32 + raise LavinMQ::Exchange::AccessRefused.new(self) + end + + def publish(packet : MQTT::Publish) : Int32 + @publish_in_count += 1 + + headers = AMQP::Table.new.tap do |h| + h[RETAIN_HEADER] = true if packet.retain? + end + properties = AMQP::Properties.new(headers: headers).tap do |p| + p.delivery_mode = packet.qos if packet.responds_to?(:qos) + end + + timestamp = RoughTime.unix_ms + bodysize = packet.payload.size.to_u64 + body = ::IO::Memory.new(bodysize) + body.write(packet.payload) + body.rewind + + @retain_store.retain(packet.topic, body, bodysize) if packet.retain? + + body.rewind + msg = Message.new(timestamp, EXCHANGE, packet.topic, properties, bodysize, body) + + count = 0 + @tree.each_entry(packet.topic) do |queue, qos| + msg.properties.delivery_mode = qos + if queue.publish(msg) + count += 1 + msg.body_io.seek(-msg.bodysize.to_i64, ::IO::Seek::Current) # rewind + end + end + @unroutable_count += 1 if count.zero? + @publish_out_count += count + count + end + + def bindings_details : Iterator(BindingDetails) + @bindings.each.flat_map do |binding_key, ds| + ds.each.map do |d| + BindingDetails.new(name, vhost.name, binding_key.inner, d) + end + end + end + + # Only here to make superclass happy + protected def bindings(routing_key, headers) : Iterator(Destination) + Iterator(Destination).empty + end + + def bind(destination : MQTT::Session, routing_key : String, headers = nil) : Bool + qos = headers.try { |h| h[QOS_HEADER]?.try(&.as(UInt8)) } || 0u8 + binding_key = BindingKey.new(routing_key, headers) + @bindings[binding_key].add destination + @tree.subscribe(routing_key, destination, qos) + + data = BindingDetails.new(name, vhost.name, binding_key.inner, destination) + notify_observers(ExchangeEvent::Bind, data) + true + end + + def unbind(destination : MQTT::Session, routing_key, headers = nil) : Bool + binding_key = BindingKey.new(routing_key, headers) + rk_bindings = @bindings[binding_key] + rk_bindings.delete destination + @bindings.delete binding_key if rk_bindings.empty? + + @tree.unsubscribe(routing_key, destination) + + data = BindingDetails.new(name, vhost.name, binding_key.inner, destination) + notify_observers(ExchangeEvent::Unbind, data) + + delete if @auto_delete && @bindings.each_value.all?(&.empty?) + true + end + + def bind(destination : Destination, routing_key : String, headers = nil) : Bool + raise LavinMQ::Exchange::AccessRefused.new(self) + end + + def unbind(destination : Destination, routing_key, headers = nil) : Bool + raise LavinMQ::Exchange::AccessRefused.new(self) + end + + def apply_policy(policy : Policy?, operator_policy : OperatorPolicy?) + end + + def clear_policy + end + + def handle_arguments + end + end + end +end diff --git a/src/lavinmq/mqtt/protocol.cr b/src/lavinmq/mqtt/protocol.cr new file mode 100644 index 0000000000..9349f9560f --- /dev/null +++ b/src/lavinmq/mqtt/protocol.cr @@ -0,0 +1,7 @@ +require "mqtt-protocol" + +module LavinMQ + module MQTT + include ::MQTT::Protocol + end +end diff --git a/src/lavinmq/mqtt/retain_store.cr b/src/lavinmq/mqtt/retain_store.cr new file mode 100644 index 0000000000..e4e0db3225 --- /dev/null +++ b/src/lavinmq/mqtt/retain_store.cr @@ -0,0 +1,172 @@ +require "./topic_tree" +require "digest/md5" + +module LavinMQ + module MQTT + class RetainStore + Log = LavinMQ::Log.for("retainstore") + + MESSAGE_FILE_SUFFIX = ".msg" + INDEX_FILE_NAME = "index" + + alias IndexTree = TopicTree(String) + + def initialize(@dir : String, @replicator : Clustering::Replicator, @index = IndexTree.new) + Dir.mkdir_p @dir + @files = Hash(String, File).new do |files, file_name| + file_path = File.join(@dir, file_name) + unless File.exists?(file_path) + File.open(file_path, "w").close + end + f = files[file_name] = File.new(file_path, "r+") + f.sync = true + f + end + @index_file_name = File.join(@dir, INDEX_FILE_NAME) + @index_file = File.new(@index_file_name, "a+") + @replicator.register_file(@index_file) + @lock = Mutex.new + @lock.synchronize do + if @index.empty? + restore_index(@index, @index_file) + write_index + @index_file = File.new(@index_file_name, "a+") + end + end + end + + def close + @lock.synchronize do + write_index + end + end + + private def restore_index(index : IndexTree, index_file : ::IO) + Log.info { "restoring index" } + dir = @dir + msg_count = 0 + msg_file_segments = Set(String).new( + Dir[Path[dir, "*#{MESSAGE_FILE_SUFFIX}"]].compact_map do |fname| + File.basename(fname) + end + ) + + while topic = index_file.gets + msg_file_name = make_file_name(topic) + unless msg_file_segments.delete(msg_file_name) + Log.warn { "msg file for topic #{topic} missing, dropping from index" } + next + end + index.insert(topic, msg_file_name) + Log.debug { "restored #{topic}" } + msg_count += 1 + end + + # TODO: Device what's the truth: index file or msgs file. Mybe drop the index file and rebuild + # index from msg files? + unless msg_file_segments.empty? + Log.warn { "unreferenced messages will be deleted: #{msg_file_segments.join(",")}" } + msg_file_segments.each do |file_name| + File.delete? File.join(dir, file_name) + end + end + Log.info { "restoring index done, msg_count = #{msg_count}" } + end + + def retain(topic : String, body_io : ::IO, size : UInt64) : Nil + @lock.synchronize do + Log.debug { "retain topic=#{topic} body.bytesize=#{size}" } + # An empty message with retain flag means clear the topic from retained messages + if size.zero? + delete_from_index(topic) + return + end + + unless msg_file_name = @index[topic]? + msg_file_name = make_file_name(topic) + add_to_index(topic, msg_file_name) + end + + tmp_file = File.join(@dir, "#{msg_file_name}.tmp") + File.open(tmp_file, "w+") do |f| + f.sync = true + ::IO.copy(body_io, f) + end + final_file_path = File.join(@dir, msg_file_name) + File.rename(tmp_file, final_file_path) + @files.delete(final_file_path) + @files[final_file_path] = File.new(final_file_path, "r+") + @replicator.replace_file(final_file_path) + ensure + FileUtils.rm_rf tmp_file unless tmp_file.nil? + end + end + + private def write_index + tmp_file = File.join(@dir, "#{INDEX_FILE_NAME}.next") + File.open(tmp_file, "w+") do |f| + @index.each do |topic, _filename| + f.puts topic + end + end + File.rename tmp_file, @index_file_name + @replicator.replace_file(@index_file_name) + ensure + FileUtils.rm_rf tmp_file unless tmp_file.nil? + end + + private def add_to_index(topic : String, file_name : String) : Nil + @index.insert topic, file_name + @index_file.puts topic + @index_file.flush + bytes = Bytes.new(topic.bytesize + 1) + bytes.copy_from(topic.to_slice) + bytes[-1] = 10u8 + @replicator.append(@index_file_name, bytes) + end + + private def delete_from_index(topic : String) : Nil + if file_name = @index.delete topic + Log.trace { "deleted '#{topic}' from index, deleting file #{file_name}" } + if file = @files[file_name] + @files.delete(file) + file.close + file.delete + end + @replicator.delete_file(File.join(@dir, file_name)) + end + end + + def each(subscription : String, &block : String, Bytes -> Nil) : Nil + @lock.synchronize do + @index.each(subscription) do |topic, file_name| + block.call(topic, read(file_name)) + end + end + end + + private def read(file_name : String) : Bytes + File.open(File.join(@dir, file_name), "r") do |f| + body = Bytes.new(f.size) + f.read_fully(body) + body + end + end + + def retained_messages + @lock.synchronize do + @index.size + end + end + + @hasher = Digest::MD5.new + + def make_file_name(topic : String) : String + @hasher.update topic.to_slice + "#{@hasher.hexfinal}#{MESSAGE_FILE_SUFFIX}" + ensure + @hasher.reset + end + end + end +end diff --git a/src/lavinmq/mqtt/session.cr b/src/lavinmq/mqtt/session.cr new file mode 100644 index 0000000000..7961ea988e --- /dev/null +++ b/src/lavinmq/mqtt/session.cr @@ -0,0 +1,216 @@ +require "../amqp/queue/queue" +require "../error" +require "./consts" + +module LavinMQ + module MQTT + class Session < LavinMQ::AMQP::Queue + include SortableJSON + Log = ::LavinMQ::Log.for "mqtt.session" + + @clean_session : Bool = false + getter clean_session + getter max_inflight_messages : UInt16? = Config.instance.max_inflight_messages + + def initialize(@vhost : VHost, + @name : String, + @auto_delete = false, + arguments : ::AMQ::Protocol::Table = AMQP::Table.new) + @count = 0u16 + @unacked = Hash(UInt16, SegmentPosition).new + + super(@vhost, @name, false, @auto_delete, arguments) + + @log = Logger.new(Log, @metadata) + spawn deliver_loop, name: "Session#deliver_loop", same_thread: true + end + + def clean_session? + @auto_delete + end + + private def deliver_loop + i = 0 + loop do + break if @closed + if @msg_store.empty? || @consumers.empty? + select + when @msg_store.empty_change.receive? + when @consumers_empty_change.receive? + end + next + end + consumer = consumers.first.as(MQTT::Consumer) + get_packet(false) do |pub_packet| + consumer.deliver(pub_packet) + end + Fiber.yield if (i &+= 1) % 32768 == 0 + rescue ::IO::Error + rescue ArgumentError + rescue ex + @log.error(exception: ex) { "Unexpected error in deliver loop" } + end + rescue ::Channel::ClosedError + return + rescue ex + @log.error(exception: ex) { "deliver loop exited unexpectedly" } + end + + def client=(client : MQTT::Client?) + return if @closed + @last_get_time = RoughTime.monotonic + consumers.each do |c| + c.close + rm_consumer c + end + + @msg_store_lock.synchronize do + @unacked.values.each do |sp| + @msg_store.requeue(sp) + end + end + @unacked.clear + + if c = client + add_consumer MQTT::Consumer.new(c, self) + end + @log.debug { "client set to '#{client.try &.name}'" } + end + + def durable? + !clean_session? + end + + def unacked_messages + Iterator(UnackedMessage).empty + end + + def subscribe(tf, qos) + arguments = AMQP::Table.new + arguments[QOS_HEADER] = qos + if binding = find_binding(tf) + return if binding.binding_key.arguments == arguments + unbind(tf, binding.binding_key.arguments) + end + @vhost.bind_queue(@name, EXCHANGE, tf, arguments) + end + + def unsubscribe(tf) + if binding = find_binding(tf) + unbind(tf, binding.binding_key.arguments) + end + end + + private def find_binding(rk) + bindings.find { |b| b.binding_key.routing_key == rk } + end + + private def unbind(rk, arguments) + @vhost.unbind_queue(@name, EXCHANGE, rk, arguments || AMQP::Table.new) + end + + private def get_packet(no_ack : Bool, & : MQTT::Publish -> Nil) : Bool + raise ClosedError.new if @closed + loop do + env = @msg_store_lock.synchronize { @msg_store.shift? } || break + sp = env.segment_position + no_ack = env.message.properties.delivery_mode == 0 + if no_ack + packet = build_packet(env, nil) + begin + yield packet + rescue ex + @msg_store_lock.synchronize { @msg_store.requeue(sp) } + raise ex + end + delete_message(sp) + else + id = next_id + return false unless id + packet = build_packet(env, id) + mark_unacked(sp) do + yield packet + @unacked[id] = sp + end + end + return true + end + false + rescue ex : MessageStore::Error + @log.error(ex) { "Queue closed due to error" } + close + raise ClosedError.new(cause: ex) + end + + def build_packet(env, packet_id) : MQTT::Publish + msg = env.message + retained = msg.properties.try &.headers.try &.["mqtt.retain"]? == true + + qos = case msg.properties.delivery_mode + when 2u8 + 1u8 + else + msg.properties.delivery_mode || 0u8 + end + MQTT::Publish.new( + packet_id: packet_id, + payload: msg.body, + dup: env.redelivered, + qos: qos, + retain: retained, + topic: msg.routing_key + ) + end + + def apply_policy(policy : Policy?, operator_policy : OperatorPolicy?) # ameba:disable Metrics/CyclomaticComplexity + clear_policy + Policy.merge_definitions(policy, operator_policy).each do |k, v| + @log.debug { "Applying policy #{k}: #{v}" } + case k + when "max-length" + unless @max_length.try &.< v.as_i64 + @max_length = v.as_i64 + drop_overflow + end + when "max-length-bytes" + unless @max_length_bytes.try &.< v.as_i64 + @max_length_bytes = v.as_i64 + drop_overflow + end + when "overflow" + @reject_on_overflow ||= v.as_s == "reject-publish" + end + end + @policy = policy + @operator_policy = operator_policy + end + + def ack(packet : MQTT::PubAck) : Nil + # TODO: maybe risky to not have lock around this + id = packet.packet_id + sp = @unacked[id] + @unacked.delete id + super sp + rescue + raise ::IO::Error.new("Could not acknowledge package with id: #{id}") + end + + private def message_expire_loop; end + + private def queue_expire_loop; end + + private def next_id : UInt16? + return nil if @unacked.size == max_inflight_messages + start_id = @count + next_id : UInt16 = start_id &+ 1_u16 + while @unacked.has_key?(next_id) + next_id &+= 1u16 + next_id = 1u16 if next_id == 0 + return nil if next_id == start_id + end + @count = next_id + next_id + end + end + end +end diff --git a/src/lavinmq/mqtt/sessions.cr b/src/lavinmq/mqtt/sessions.cr new file mode 100644 index 0000000000..3127226699 --- /dev/null +++ b/src/lavinmq/mqtt/sessions.cr @@ -0,0 +1,37 @@ +require "./session" +require "../vhost" + +module LavinMQ + module MQTT + struct Sessions + @queues : Hash(String, Queue) + + def initialize(@vhost : VHost) + @queues = @vhost.queues + end + + def []?(client_id : String) : Session? + @queues["mqtt.#{client_id}"]?.try &.as(Session) + end + + def [](client_id : String) : Session + @queues["mqtt.#{client_id}"].as(Session) + end + + def declare(client_id : String, clean_session : Bool) + self[client_id]? || begin + @vhost.declare_queue("mqtt.#{client_id}", !clean_session, clean_session, AMQP::Table.new({"x-queue-type": "mqtt"})) + self[client_id] + end + end + + def delete(client_id : String) + @vhost.delete_queue("mqtt.#{client_id}") + end + + def delete(session : Session) + session.delete + end + end + end +end diff --git a/src/lavinmq/mqtt/string_token_iterator.cr b/src/lavinmq/mqtt/string_token_iterator.cr new file mode 100644 index 0000000000..c62d39eb95 --- /dev/null +++ b/src/lavinmq/mqtt/string_token_iterator.cr @@ -0,0 +1,50 @@ +# +# str = "my/example/string" +# it = StringTokenIterator.new(str, '/') +# while substr = it.next +# puts substr +# end +# outputs: +# my +# example +# string +# +# Note that "empty" parts will also be returned +# str = "/" will result in two "" +# str "a//b" will result in "a", "" and "b" +# +module LavinMQ + module MQTT + struct StringTokenIterator + def initialize(@str : String, @delimiter : Char = '/') + @reader = Char::Reader.new(@str) + @iteration = 0 + end + + def next : String? + return if @reader.pos >= @str.size + # This is to make sure we return an empty string first iteration if @str starts with @delimiter + @reader.next_char unless @iteration.zero? + @iteration += 1 + head = @reader.pos + while @reader.has_next? && @reader.current_char != @delimiter + @reader.next_char + end + tail = @reader.pos + @str[head, tail - head] + end + + def next? + @reader.pos < @str.size + end + + def to_s + @str + end + + def inspect + "#{self.class.name}(@str=#{@str} @reader.pos=#{@reader.pos} @reader.current_char=#{@reader.current_char} @iteration=#{@iteration})" + end + end + end +end diff --git a/src/lavinmq/mqtt/subscription_tree.cr b/src/lavinmq/mqtt/subscription_tree.cr new file mode 100644 index 0000000000..b0285c539e --- /dev/null +++ b/src/lavinmq/mqtt/subscription_tree.cr @@ -0,0 +1,144 @@ +require "./session" +require "./string_token_iterator" + +module LavinMQ + module MQTT + class SubscriptionTree(T) + @wildcard_rest = Hash(T, UInt8).new + @plus : SubscriptionTree(T)? + @leafs = Hash(T, UInt8).new + # Non wildcards may be an unnecessary "optimization". We store all subscriptions without + # wildcard in the first level. No need to make a tree out of them. + @non_wildcards = Hash(String, Hash(T, UInt8)).new do |h, k| + h[k] = Hash(T, UInt8).new + h[k].compare_by_identity + h[k] + end + @sublevels = Hash(String, SubscriptionTree(T)).new + + def initialize + @wildcard_rest.compare_by_identity + @leafs.compare_by_identity + end + + def subscribe(filter : String, session : T, qos : UInt8) + if filter.index('#').nil? && filter.index('+').nil? + @non_wildcards[filter][session] = qos + return + end + subscribe(StringTokenIterator.new(filter), session, qos) + end + + protected def subscribe(filter : StringTokenIterator, session : T, qos : UInt8) + unless current = filter.next + @leafs[session] = qos + return + end + if current == "#" + @wildcard_rest[session] = qos + return + end + if current == "+" + plus = (@plus ||= SubscriptionTree(T).new) + plus.subscribe filter, session, qos + return + end + if !(sublevels = @sublevels[current]?) + sublevels = @sublevels[current] = SubscriptionTree(T).new + end + sublevels.subscribe filter, session, qos + return + end + + def unsubscribe(filter : String, session : T) + if subs = @non_wildcards[filter]? + return unless subs.delete(session).nil? + end + unsubscribe(StringTokenIterator.new(filter), session) + end + + protected def unsubscribe(filter : StringTokenIterator, session : T) + unless current = filter.next + @leafs.delete session + return + end + if current == "#" + @wildcard_rest.delete session + end + if (plus = @plus) && current == "+" + plus.unsubscribe filter, session + end + if sublevel = @sublevels[current]? + sublevel.unsubscribe filter, session + if sublevel.empty? + @sublevels.delete current + end + end + end + + # Returns wether any subscription matches the given filter + def any?(filter : String) : Bool + if subs = @non_wildcards[filter]? + return !subs.empty? + end + any?(StringTokenIterator.new(filter)) + end + + protected def any?(filter : StringTokenIterator) + return !@leafs.empty? unless current = filter.next + return true if !@wildcard_rest.empty? + return true if @plus.try &.any?(filter) + return true if @sublevels[current]?.try &.any?(filter) + false + end + + def empty? + return false unless @non_wildcards.empty? || @non_wildcards.values.all? &.empty? + return false unless @leafs.empty? + return false unless @wildcard_rest.empty? + if plus = @plus + return false unless plus.empty? + end + if sublevels = @sublevels + return false unless sublevels.empty? + end + true + end + + def each_entry(topic : String, &block : (T, UInt8) -> _) + if subs = @non_wildcards[topic]? + subs.each &block + end + each_entry(StringTokenIterator.new(topic), &block) + end + + protected def each_entry(topic : StringTokenIterator, &block : (T, UInt8) -> _) + unless current = topic.next + @leafs.each &block + return + end + @wildcard_rest.each &block + @plus.try &.each_entry topic, &block + if sublevel = @sublevels.fetch(current, nil) + sublevel.each_entry topic, &block + end + end + + def each_entry(&block : (T, UInt8) -> _) + @non_wildcards.each do |_, entries| + entries.each &block + end + @leafs.each &block + @wildcard_rest.each &block + @plus.try &.each_entry &block + @sublevels.each do |_, sublevel| + sublevel.each_entry &block + end + end + + def inspect + "#{self.class.name}(@wildcard_rest=#{@wildcard_rest.inspect}, @non_wildcards=#{@non_wildcards.inspect}, @plus=#{@plus.inspect}, @sublevels=#{@sublevels.inspect}, @leafs=#{@leafs.inspect})" + end + end + end +end diff --git a/src/lavinmq/mqtt/topic_tree.cr b/src/lavinmq/mqtt/topic_tree.cr new file mode 100644 index 0000000000..85611a9c73 --- /dev/null +++ b/src/lavinmq/mqtt/topic_tree.cr @@ -0,0 +1,126 @@ +require "./string_token_iterator" + +module LavinMQ + module MQTT + class TopicTree(TEntity) + @sublevels = Hash(String, TopicTree(TEntity)).new do |h, k| + h[k] = TopicTree(TEntity).new + end + + @leafs = Hash(String, Tuple(String, TEntity)).new + + def initialize + end + + def insert(topic : String, entity : TEntity) : TEntity? + insert(StringTokenIterator.new(topic, '/'), entity) + end + + def insert(topic : StringTokenIterator, entity : TEntity) : TEntity? + current = topic.next + raise ArgumentError.new "topic cannot be empty" unless current + if topic.next? + @sublevels[current].insert(topic, entity) + else + old_value = @leafs[current]? + @leafs[current] = {topic.to_s, entity} + old_value.try &.last + end + end + + def []?(topic : String) : (TEntity | Nil) + self[StringTokenIterator.new(topic, '/')]? + end + + def []?(topic : StringTokenIterator) : (TEntity | Nil) + current = topic.next + if topic.next? + return unless @sublevels.has_key?(current) + @sublevels[current][topic]? + else + @leafs[current]?.try &.last + end + end + + def [](topic : String) : TEntity + self[StringTokenIterator.new(topic, '/')] + rescue KeyError + raise KeyError.new "#{topic} not found" + end + + def [](topic : StringTokenIterator) : TEntity + current = topic.next + if topic.next? + raise KeyError.new unless @sublevels.has_key?(current) + @sublevels[current][topic] + else + @leafs[current].last + end + end + + def delete(topic : String) + delete(StringTokenIterator.new(topic, '/')) + end + + def delete(topic : StringTokenIterator) + current = topic.next + if topic.next? + return unless @sublevels.has_key?(current) + deleted = @sublevels[current].delete(topic) + if @sublevels[current].empty? + @sublevels.delete(current) + end + deleted + else + @leafs.delete(current).try &.last + end + end + + def empty? + @leafs.empty? && @sublevels.empty? + end + + def size + @leafs.size + @sublevels.values.sum(0, &.size) + end + + def each(filter : String, &blk : (String, TEntity) -> _) + each(StringTokenIterator.new(filter, '/'), &blk) + end + + def each(filter : StringTokenIterator, &blk : (String, TEntity) -> _) + current = filter.next + if current == "#" + each &blk + return + end + if current == "+" + if filter.next? + @sublevels.values.each(&.each(filter, &blk)) + else + @leafs.values.each &blk + end + return + end + if filter.next? + if sublevel = @sublevels.fetch(current, nil) + sublevel.each filter, &blk + end + else + if leaf = @leafs.fetch(current, nil) + yield leaf.first, leaf.last + end + end + end + + def each(&blk : (String, TEntity) -> _) + @leafs.values.each &blk + @sublevels.values.each(&.each(&blk)) + end + + def inspect + "#{self.class.name}(@sublevels=#{@sublevels.inspect} @leafs=#{@leafs.inspect})" + end + end + end +end diff --git a/src/lavinmq/name_validator.cr b/src/lavinmq/name_validator.cr new file mode 100644 index 0000000000..5e98129860 --- /dev/null +++ b/src/lavinmq/name_validator.cr @@ -0,0 +1,14 @@ +require "./error" + +class NameValidator + PREFIX_LIST = ["mqtt.", "amq."] + + def self.reserved_prefix?(name) + PREFIX_LIST.any? { |prefix| name.starts_with? prefix } + end + + def self.valid_entity_name(name) : Bool + return true if name.empty? + name.matches?(/\A[ -~]*\z/) + end +end diff --git a/src/lavinmq/queue_factory.cr b/src/lavinmq/queue_factory.cr index 6a0513fd9b..20d45cebc3 100644 --- a/src/lavinmq/queue_factory.cr +++ b/src/lavinmq/queue_factory.cr @@ -2,6 +2,8 @@ require "./amqp/queue" require "./amqp/queue/priority_queue" require "./amqp/queue/durable_queue" require "./amqp/queue/stream_queue" +require "./mqtt/session" +require "./name_validator" module LavinMQ class QueueFactory @@ -25,6 +27,8 @@ module LavinMQ raise Error::PreconditionFailed.new("A stream queue cannot be auto-delete") end AMQP::StreamQueue.new(vhost, frame.queue_name, frame.exclusive, frame.auto_delete, frame.arguments) + elsif mqtt_session? frame + MQTT::Session.new(vhost, frame.queue_name, frame.auto_delete, frame.arguments) else warn_if_unsupported_queue_type frame AMQP::DurableQueue.new(vhost, frame.queue_name, frame.exclusive, frame.auto_delete, frame.arguments) @@ -36,6 +40,8 @@ module LavinMQ AMQP::PriorityQueue.new(vhost, frame.queue_name, frame.exclusive, frame.auto_delete, frame.arguments) elsif stream_queue? frame raise Error::PreconditionFailed.new("A stream queue cannot be non-durable") + elsif mqtt_session? frame + MQTT::Session.new(vhost, frame.queue_name, frame.auto_delete, frame.arguments) else warn_if_unsupported_queue_type frame AMQP::Queue.new(vhost, frame.queue_name, frame.exclusive, frame.auto_delete, frame.arguments) @@ -60,5 +66,9 @@ module LavinMQ Log.info { "The queue type #{frame.arguments["x-queue-type"]} is not supported by LavinMQ and will be changed to the default queue type" } end end + + private def self.mqtt_session?(frame) : Bool + frame.arguments["x-queue-type"]? == "mqtt" + end end end diff --git a/src/lavinmq/reporter.cr b/src/lavinmq/reporter.cr index c979216b02..3a1d10b3cd 100644 --- a/src/lavinmq/reporter.cr +++ b/src/lavinmq/reporter.cr @@ -17,7 +17,7 @@ module LavinMQ puts_size_capacity vh.@queues, 4 vh.queues.each do |_, q| puts " #{q.name} #{q.durable? ? "durable" : ""} args=#{q.arguments}" - if q = q.as(LavinMQ::AMQP::Queue) + if q = (q.as(LavinMQ::AMQP::Queue) || q.as(LavinMQ::MQTT::Session)) puts_size_capacity q.@consumers, 6 puts_size_capacity q.@deliveries, 6 puts_size_capacity q.@msg_store.@segments, 6 diff --git a/src/lavinmq/server.cr b/src/lavinmq/server.cr index 69b729079d..4c5bf6b8e1 100644 --- a/src/lavinmq/server.cr +++ b/src/lavinmq/server.cr @@ -2,6 +2,7 @@ require "socket" require "openssl" require "systemd" require "./amqp" +require "./mqtt/protocol" require "./rough_time" require "../stdlib/*" require "./vhost_store" @@ -15,10 +16,16 @@ require "./proxy_protocol" require "./client/client" require "./client/connection_factory" require "./amqp/connection_factory" +require "./mqtt/connection_factory" require "./stats" module LavinMQ class Server + enum Protocol + AMQP + MQTT + end + getter vhosts, users, data_dir, parameters getter? closed, flow include ParameterTarget @@ -26,7 +33,8 @@ module LavinMQ @start = Time.monotonic @closed = false @flow = true - @listeners = Hash(Socket::Server, Symbol).new # Socket => protocol + @listeners = Hash(Socket::Server, Protocol).new # Socket => protocol + @connection_factories = Hash(Protocol, ConnectionFactory).new @replicator : Clustering::Replicator Log = LavinMQ::Log.for "server" @@ -35,8 +43,10 @@ module LavinMQ Schema.migrate(@data_dir, @replicator) @users = UserStore.new(@data_dir, @replicator) @vhosts = VHostStore.new(@data_dir, @users, @replicator) + @brokers = MQTT::Brokers.new(@vhosts, @replicator) @parameters = ParameterStore(Parameter).new(@data_dir, "parameters.json", @replicator) - @amqp_connection_factory = LavinMQ::AMQP::ConnectionFactory.new + @connection_factories[Protocol::AMQP] = AMQP::ConnectionFactory.new(@users, @vhosts) + @connection_factories[Protocol::MQTT] = MQTT::ConnectionFactory.new(@users, @vhosts, @brokers, @replicator) apply_parameter spawn stats_loop, name: "Server#stats_loop" end @@ -46,7 +56,12 @@ module LavinMQ end def amqp_url - addr = @listeners.each_key.select(TCPServer).first.local_address + addr = @listeners + .select { |k, v| k.is_a?(TCPServer) && v.amqp? } + .keys + .select(TCPServer) + .first + .local_address "amqp://#{addr}" end @@ -64,6 +79,8 @@ module LavinMQ Schema.migrate(@data_dir, @replicator) @users = UserStore.new(@data_dir, @replicator) @vhosts = VHostStore.new(@data_dir, @users, @replicator) + @connection_factories[Protocol::AMQP] = AMQP::ConnectionFactory.new(@users, @vhosts) + @connection_factories[Protocol::MQTT] = MQTT::ConnectionFactory.new(@users, @vhosts, @replicator) @parameters = ParameterStore(Parameter).new(@data_dir, "parameters.json", @replicator) apply_parameter @closed = false @@ -74,9 +91,9 @@ module LavinMQ Iterator(Client).chain(@vhosts.each_value.map(&.connections.each)) end - def listen(s : TCPServer) - @listeners[s] = :amqp - Log.info { "Listening on #{s.local_address}" } + def listen(s : TCPServer, protocol : Protocol) + @listeners[s] = protocol + Log.info { "Listening for #{protocol} on #{s.local_address}" } loop do client = s.accept? || break next client.close if @closed @@ -85,7 +102,7 @@ module LavinMQ set_socket_options(client) set_buffer_size(client) conn_info = extract_conn_info(client) - handle_connection(client, conn_info) + handle_connection(client, conn_info, protocol) rescue ex Log.warn(exception: ex) { "Error accepting connection from #{remote_address}" } client.close rescue nil @@ -120,9 +137,9 @@ module LavinMQ end end - def listen(s : UNIXServer) - @listeners[s] = :amqp - Log.info { "Listening on #{s.local_address}" } + def listen(s : UNIXServer, protocol : Protocol) + @listeners[s] = protocol + Log.info { "Listening for #{protocol} on #{s.local_address}" } loop do # do not try to use while client = s.accept? || break next client.close if @closed @@ -135,7 +152,7 @@ module LavinMQ when 2 then ProxyProtocol::V2.parse(client) else ConnectionInfo.local # TODO: use unix socket address, don't fake local end - handle_connection(client, conn_info) + handle_connection(client, conn_info, protocol) rescue ex Log.warn(exception: ex) { "Error accepting connection from #{remote_address}" } client.close rescue nil @@ -147,14 +164,14 @@ module LavinMQ @listeners.delete(s) end - def listen(bind = "::", port = 5672) + def listen(bind = "::", port = 5672, protocol : Protocol = :amqp) s = TCPServer.new(bind, port) - listen(s) + listen(s, protocol) end - def listen_tls(s : TCPServer, context) - @listeners[s] = :amqps - Log.info { "Listening on #{s.local_address} (TLS)" } + def listen_tls(s : TCPServer, context, protocol : Protocol) + @listeners[s] = protocol + Log.info { "Listening for #{protocol} on #{s.local_address} (TLS)" } loop do # do not try to use while client = s.accept? || break next client.close if @closed @@ -168,7 +185,7 @@ module LavinMQ conn_info.ssl = true conn_info.ssl_version = ssl_client.tls_version conn_info.ssl_cipher = ssl_client.cipher - handle_connection(ssl_client, conn_info) + handle_connection(ssl_client, conn_info, protocol) rescue ex Log.warn(exception: ex) { "Error accepting TLS connection from #{remote_addr}" } client.close rescue nil @@ -180,15 +197,15 @@ module LavinMQ @listeners.delete(s) end - def listen_tls(bind, port, context) - listen_tls(TCPServer.new(bind, port), context) + def listen_tls(bind, port, context, protocol : Protocol = :amqp) + listen_tls(TCPServer.new(bind, port), context, protocol) end - def listen_unix(path : String) + def listen_unix(path : String, protocol : Protocol) File.delete?(path) s = UNIXServer.new(path) File.chmod(path, 0o666) - listen(s) + listen(s, protocol) end def listen_clustering(bind, port) @@ -244,8 +261,8 @@ module LavinMQ end end - def handle_connection(socket, connection_info) - client = @amqp_connection_factory.start(socket, connection_info, @vhosts, @users) + def handle_connection(socket, connection_info, protocol : Protocol) + client = @connection_factories[protocol].start(socket, connection_info) ensure socket.close if client.nil? end diff --git a/src/lavinmq/vhost.cr b/src/lavinmq/vhost.cr index 98e3332564..d1428bd454 100644 --- a/src/lavinmq/vhost.cr +++ b/src/lavinmq/vhost.cr @@ -14,6 +14,7 @@ require "./schema" require "./event_type" require "./stats" require "./queue_factory" +require "./mqtt/session" module LavinMQ class VHost diff --git a/src/lavinmq/vhost_store.cr b/src/lavinmq/vhost_store.cr index fb52bab78f..32fab952a6 100644 --- a/src/lavinmq/vhost_store.cr +++ b/src/lavinmq/vhost_store.cr @@ -1,10 +1,21 @@ require "json" require "./vhost" require "./user" +require "./observable" module LavinMQ + class VHostStore + enum Event + Added + Deleted + Closed + end + end + class VHostStore include Enumerable({String, VHost}) + include Observable(Event) + Log = LavinMQ::Log.for "vhost_store" def initialize(@data_dir : String, @users : UserStore, @replicator : Clustering::Replicator) @@ -30,14 +41,16 @@ module LavinMQ @users.add_permission(UserStore::DIRECT_USER, name, /.*/, /.*/, /.*/) @vhosts[name] = vhost save! if save + notify_observers(Event::Added, name) vhost end def delete(name) : Nil if vhost = @vhosts.delete name - Log.info { "Deleted vhost #{name}" } @users.rm_vhost_permissions_for_all(name) vhost.delete + notify_observers(Event::Deleted, name) + Log.info { "Deleted vhost #{name}" } save! end end @@ -45,7 +58,10 @@ module LavinMQ def close WaitGroup.wait do |wg| @vhosts.each_value do |vhost| - wg.spawn &->vhost.close + wg.spawn do + vhost.close + notify_observers(Event::Closed, vhost.name) + end end end end diff --git a/static/js/connections.js b/static/js/connections.js index 611d4c5065..44ca8d6667 100644 --- a/static/js/connections.js +++ b/static/js/connections.js @@ -19,7 +19,7 @@ Table.renderTable('table', tableOptions, function (tr, item, all) { if (all) { const connectionLink = document.createElement('a') connectionLink.href = `connection#name=${encodeURIComponent(item.name)}` - if (item.client_properties.connection_name) { + if (item?.client_properties?.connection_name) { connectionLink.appendChild(document.createElement('span')).textContent = item.name connectionLink.appendChild(document.createElement('br')) connectionLink.appendChild(document.createElement('small')).textContent = item.client_properties.connection_name @@ -33,13 +33,15 @@ Table.renderTable('table', tableOptions, function (tr, item, all) { Table.renderCell(tr, 5, item.tls_version, 'center') Table.renderCell(tr, 6, item.cipher, 'center') Table.renderCell(tr, 7, item.protocol, 'center') + Table.renderCell(tr, 8, item.auth_mechanism) Table.renderCell(tr, 9, item.channel_max, 'right') Table.renderCell(tr, 10, item.timeout, 'right') - // Table.renderCell(tr, 8, item.auth_mechanism) const clientDiv = document.createElement('span') - clientDiv.textContent = `${item.client_properties.product} / ${item.client_properties.platform || ''}` - clientDiv.appendChild(document.createElement('br')) - clientDiv.appendChild(document.createElement('small')).textContent = item.client_properties.version + if (item?.client_properties) { + clientDiv.textContent = `${item.client_properties.product} / ${item.client_properties.platform || ''}` + clientDiv.appendChild(document.createElement('br')) + clientDiv.appendChild(document.createElement('small')).textContent = item.client_properties.version + } Table.renderCell(tr, 11, clientDiv) Table.renderCell(tr, 12, new Date(item.connected_at).toLocaleString(), 'center') }