diff --git a/drivers/Aqara/aqara-presence-sensor/config.yml b/drivers/Aqara/aqara-presence-sensor/config.yml new file mode 100644 index 0000000000..83bc690de3 --- /dev/null +++ b/drivers/Aqara/aqara-presence-sensor/config.yml @@ -0,0 +1,5 @@ +name: 'Aqara Presence Sensor' +packageKey: 'aqara-presence-sensor' +permissions: + lan: {} + discovery: {} diff --git a/drivers/Aqara/aqara-presence-sensor/profiles/aqara-fp2-fallDetection.yml b/drivers/Aqara/aqara-presence-sensor/profiles/aqara-fp2-fallDetection.yml new file mode 100644 index 0000000000..0455fb3946 --- /dev/null +++ b/drivers/Aqara/aqara-presence-sensor/profiles/aqara-fp2-fallDetection.yml @@ -0,0 +1,20 @@ +name: aqara-fp2-fallDetection +components: + - id: main + capabilities: + - id: activitySensor + version: 1 + - id: presenceSensor + version: 1 + - id: illuminanceMeasurement + version: 1 + - id: refresh + version: 1 + categories: + - name: PresenceSensor + - id: mode + capabilities: + - id: stse.deviceMode + version: 1 + categories: + - name: PresenceSensor diff --git a/drivers/Aqara/aqara-presence-sensor/profiles/aqara-fp2-sleepMonitoring.yml b/drivers/Aqara/aqara-presence-sensor/profiles/aqara-fp2-sleepMonitoring.yml new file mode 100644 index 0000000000..cb71dc61f9 --- /dev/null +++ b/drivers/Aqara/aqara-presence-sensor/profiles/aqara-fp2-sleepMonitoring.yml @@ -0,0 +1,16 @@ +name: aqara-fp2-sleepMonitoring +components: + - id: main + capabilities: + - id: illuminanceMeasurement + version: 1 + - id: refresh + version: 1 + categories: + - name: PresenceSensor + - id: mode + capabilities: + - id: stse.deviceMode + version: 1 + categories: + - name: PresenceSensor diff --git a/drivers/Aqara/aqara-presence-sensor/profiles/aqara-fp2-zoneDetection.yml b/drivers/Aqara/aqara-presence-sensor/profiles/aqara-fp2-zoneDetection.yml new file mode 100644 index 0000000000..f0d466c16e --- /dev/null +++ b/drivers/Aqara/aqara-presence-sensor/profiles/aqara-fp2-zoneDetection.yml @@ -0,0 +1,22 @@ +name: aqara-fp2-zoneDetection +components: + - id: main + capabilities: + - id: presenceSensor + version: 1 + - id: movementSensor + version: 1 + - id: multipleZonePresence + version: 1 + - id: illuminanceMeasurement + version: 1 + - id: refresh + version: 1 + categories: + - name: PresenceSensor + - id: mode + capabilities: + - id: stse.deviceMode + version: 1 + categories: + - name: PresenceSensor diff --git a/drivers/Aqara/aqara-presence-sensor/search-parameters.yml b/drivers/Aqara/aqara-presence-sensor/search-parameters.yml new file mode 100644 index 0000000000..fc5290a0f1 --- /dev/null +++ b/drivers/Aqara/aqara-presence-sensor/search-parameters.yml @@ -0,0 +1,2 @@ +mdns: + - service: "_Aqara-FP2._tcp" diff --git a/drivers/Aqara/aqara-presence-sensor/src/discovery.lua b/drivers/Aqara/aqara-presence-sensor/src/discovery.lua new file mode 100644 index 0000000000..20b2aa64e1 --- /dev/null +++ b/drivers/Aqara/aqara-presence-sensor/src/discovery.lua @@ -0,0 +1,93 @@ +local log = require "log" +local discovery = {} +local fields = require "fields" +local discovery_mdns = require "discovery_mdns" +local socket = require "cosock.socket" + +-- mapping from device DNI to info needed at discovery/init time +local device_discovery_cache = {} + +local function set_device_field(driver, device) + local device_cache_value = device_discovery_cache[device.device_network_id] + + -- persistent fields + if device_cache_value ~= nil then + device:set_field(fields.DEVICE_IPV4, device_cache_value.ip, { persist = true }) + device:set_field(fields.DEVICE_INFO, device_cache_value.device_info, { persist = true }) + device:set_field(fields.CREDENTIAL, device_cache_value.credential, { persist = true }) + end +end + +local function update_device_discovery_cache(driver, dni, ip, credential) + local device_info = driver.discovery_helper.get_device_info(driver, dni, ip) + device_discovery_cache[dni] = { + ip = ip, + device_info = device_info, + credential = credential, + } +end + +local function try_add_device(driver, device_dni, device_ip) + log.trace(string.format("try_add_device : dni= %s, ip= %s", device_dni, device_ip)) + + local credential = driver.discovery_helper.get_credential(driver, device_dni, device_ip) + + if not credential then + log.error(string.format("failed to get credential. dni= %s, ip= %s", device_dni, device_ip)) + return + end + + update_device_discovery_cache(driver, device_dni, device_ip, credential) + local create_device_msg = driver.discovery_helper.get_device_create_msg(driver, device_dni, device_ip) + driver:try_create_device(create_device_msg) +end + +function discovery.device_added(driver, device) + set_device_field(driver, device) + device_discovery_cache[device.device_network_id] = nil + driver.lifecycle_handlers.init(driver, device) +end + +function discovery.find_ip_table(driver) + return discovery_mdns.find_ip_table_by_mdns(driver) +end + +local function discovery_device(driver) + local unknown_discovered_devices = {} + local known_discovered_devices = {} + local known_devices = {} + + for _, device in pairs(driver:get_devices()) do + known_devices[device.device_network_id] = device + end + + local ip_table = discovery.find_ip_table(driver) + + for dni, ip in pairs(ip_table) do + if not known_devices[dni] then + unknown_discovered_devices[dni] = ip + else + known_discovered_devices[dni] = ip + end + end + + for dni, ip in pairs(known_discovered_devices) do + log.trace(string.format("known dni= %s, ip= %s", dni, ip)) + end + + for dni, ip in pairs(unknown_discovered_devices) do + log.trace(string.format("unknown dni= %s, ip= %s", dni, ip)) + if not device_discovery_cache[dni] then + try_add_device(driver, dni, ip) + end + end +end + +function discovery.do_network_discovery(driver, _, should_continue) + while should_continue() do + discovery_device(driver) + socket.sleep(0.2) + end +end + +return discovery diff --git a/drivers/Aqara/aqara-presence-sensor/src/discovery_mdns.lua b/drivers/Aqara/aqara-presence-sensor/src/discovery_mdns.lua new file mode 100644 index 0000000000..c00267da65 --- /dev/null +++ b/drivers/Aqara/aqara-presence-sensor/src/discovery_mdns.lua @@ -0,0 +1,143 @@ +local log = require "log" +local mdns = require "st.mdns" +local net_utils = require "st.net_utils" + +local discovery_mdns = {} + +local function byte_array_to_plain_text(byte_array) + return string.char(table.unpack(byte_array)) +end + +local function get_text_by_srvname(srvname, discovery_responses) + for _, answer_item in pairs(discovery_responses.answers or {}) do + if answer_item.kind.TxtRecord ~= nil and answer_item.name == srvname then + return answer_item.kind.TxtRecord.text + end + end +end + +local function get_srvname_by_hostname(hostname, discovery_responses) + for _, answer_item in pairs(discovery_responses.answers or {}) do + if answer_item.kind.SrvRecord ~= nil and answer_item.kind.SrvRecord.target == hostname then + return answer_item.name + end + end +end + +local function get_hostname_by_ip(ip, discovery_responses) + for _, answer_item in pairs(discovery_responses.answers or {}) do + if answer_item.kind.ARecord ~= nil and answer_item.kind.ARecord.ipv4 == ip then + return answer_item.name + end + end +end + + +local function find_text_in_answers_by_ip(ip, discovery_responses) + local hostname = get_hostname_by_ip(ip, discovery_responses) + local srvname = get_srvname_by_hostname(hostname, discovery_responses) + local text = get_text_by_srvname(srvname, discovery_responses) + + return text +end + +function discovery_mdns.find_text_list_in_mdns_response(driver, ip, discovery_responses) + local text_list = {} + + for _, found_item in pairs(discovery_responses.found or {}) do + if found_item.host_info.address == ip then + for _, raw_text_array in pairs(found_item.txt.text or {}) do + local text_item = byte_array_to_plain_text(raw_text_array) + table.insert(text_list, text_item) + end + end + end + + local answer_text = find_text_in_answers_by_ip(ip, discovery_responses) + for _, text_item in pairs(answer_text or {}) do + table.insert(text_list, text_item) + end + return text_list +end + +local function filter_response_by_service_name(service_type, domain, discovery_responses) + local filtered_responses = { + answers = {}, + found = {} + } + + for _, answer in pairs(discovery_responses.answers or {}) do + table.insert(filtered_responses.answers, answer) + end + + for _, additional in pairs(discovery_responses.additional or {}) do + table.insert(filtered_responses.answers, additional) + end + + for _, found in pairs(discovery_responses.found or {}) do + if found.service_info.service_type == service_type then + table.insert(filtered_responses.found, found) + end + end + + return filtered_responses +end + +local function insert_dni_ip_from_answers(driver, filtered_responses, target_table) + for _, answer in pairs(filtered_responses.answers) do + local dni, ip + log.info("answer_name, arecod = " .. tostring(answer.name) .. ", " .. tostring(answer.kind.ARecord)) + + if answer.kind.ARecord ~= nil then + ip = answer.kind.ARecord.ipv4 + end + + if ip ~= nil then + dni = driver.discovery_helper.get_dni(driver, ip, filtered_responses) + + if dni ~= nil then + target_table[dni] = ip + end + end + end +end + +local function insert_dni_ip_from_found(driver, filtered_responses, target_table) + for _, found in pairs(filtered_responses.found) do + local dni, ip + log.info("found_name = " .. tostring(found.service_info.service_type)) + if found.host_info.address ~= nil and net_utils.validate_ipv4_string(found.host_info.address) then + log.info("ip = " .. tostring(found.host_info.address)) + ip = found.host_info.address + end + + if ip ~= nil then + dni = driver.discovery_helper.get_dni(driver, ip, filtered_responses) + + if dni ~= nil then + target_table[dni] = ip + end + end + end +end + +local function get_dni_ip_table_from_mdns_responses(driver, service_type, domain, discovery_responses) + local dni_ip_table = {} + + local filtered_responses = filter_response_by_service_name(service_type, domain, discovery_responses) + + insert_dni_ip_from_answers(driver, filtered_responses, dni_ip_table) + insert_dni_ip_from_found(driver, filtered_responses, dni_ip_table) + + return dni_ip_table +end + +function discovery_mdns.find_ip_table_by_mdns(driver) + log.info("discovery_mdns.find_device_ips") + local service_type, domain = driver.discovery_helper.get_service_type_and_domain() + local discovery_responses = mdns.discover(service_type, domain) or { found = {} } + local dni_ip_table = get_dni_ip_table_from_mdns_responses(driver, service_type, domain, discovery_responses) + return dni_ip_table +end + +return discovery_mdns diff --git a/drivers/Aqara/aqara-presence-sensor/src/fields.lua b/drivers/Aqara/aqara-presence-sensor/src/fields.lua new file mode 100644 index 0000000000..2df3fb0708 --- /dev/null +++ b/drivers/Aqara/aqara-presence-sensor/src/fields.lua @@ -0,0 +1,16 @@ +--- Table of constants used to index in to device store fields +--- @module "fields" +--- @class table +--- @field IPV4 string the ipV4 address of the device + +local fields = { + DEVICE_IPV4 = "device_ipv4", + DEVICE_INFO = "device_info", + CONN_INFO = "conn_info", + EVENT_SOURCE = "eventsource", + MONITORING_TIMER = "monitoring_timer", + CREDENTIAL = "credential", + _INIT = "init" +} + +return fields diff --git a/drivers/Aqara/aqara-presence-sensor/src/fp2/api.lua b/drivers/Aqara/aqara-presence-sensor/src/fp2/api.lua new file mode 100644 index 0000000000..1c0a2ebb2b --- /dev/null +++ b/drivers/Aqara/aqara-presence-sensor/src/fp2/api.lua @@ -0,0 +1,98 @@ +local log = require "log" +local json = require "st.json" +local RestClient = require "lunchbox.rest" +local utils = require "utils" + +local fp2_api = {} +fp2_api.__index = fp2_api + +local SSL_CONFIG = { + mode = "client", + protocol = "any", + verify = "peer", + options = "all", + cafile = "./selfSignedRootByAqaraLife.crt" +} + +local ADDITIONAL_HEADERS = { + ["Accept"] = "application/json", + ["Content-Type"] = "application/json", +} + +function fp2_api.labeled_socket_builder(label) + local socket_builder = utils.labeled_socket_builder(label, SSL_CONFIG) + return socket_builder +end + +local function get_base_url(device_ip) + return "https://" .. device_ip .. ":443" +end + +local function process_rest_response(response, err, partial) + if err ~= nil then + return response, err, nil + elseif response ~= nil then + local _, decoded_json = pcall(json.decode, response:get_body()) + return decoded_json, nil, response.status + else + return nil, "no response or error received", nil + end +end + +local function retry_fn(retry_attempts) + local count = 0 + return function() + count = count + 1 + return count < retry_attempts + end +end + +local function do_get(api_instance, path) + return process_rest_response(api_instance.client:get(path, api_instance.headers, retry_fn(5))) +end + +function fp2_api.new_device_manager(device_ip, bridge_info, socket_builder) + local base_url = get_base_url(device_ip) + + return setmetatable( + { + headers = ADDITIONAL_HEADERS, + client = RestClient.new(base_url, socket_builder), + base_url = base_url, + }, fp2_api + ) +end + +function fp2_api:add_header(key, value) + self.headers[key] = value +end + +function fp2_api.get_credential(device_ip, socket_builder) + local response, error, status = process_rest_response(RestClient.one_shot_get(get_base_url(device_ip) .. "/authcode", + ADDITIONAL_HEADERS, socket_builder)) + if not error and status == 200 then + local token = response + return token + else + log.error(string.format("get_credential : ip = %s, failed to get token, error = %s", device_ip, error)) + end +end + +function fp2_api.get_info(device_ip, socket_builder) + return process_rest_response(RestClient.one_shot_get(get_base_url(device_ip) .. "/info", ADDITIONAL_HEADERS, + socket_builder)) +end + +function fp2_api:get_attr() + return do_get(self, "/attr") +end + +function fp2_api:get_remove() + return do_get(self, "/remove") +end + +function fp2_api:get_sse_url() + return self.base_url .. "/status" +end + +return fp2_api diff --git a/drivers/Aqara/aqara-presence-sensor/src/fp2/device_manager.lua b/drivers/Aqara/aqara-presence-sensor/src/fp2/device_manager.lua new file mode 100644 index 0000000000..1e678d80c9 --- /dev/null +++ b/drivers/Aqara/aqara-presence-sensor/src/fp2/device_manager.lua @@ -0,0 +1,278 @@ +local log = require "log" +local json = require "st.json" +local fields = require "fields" +local capabilities = require "st.capabilities" +local multipleZonePresence = require "multipleZonePresence" + +local PresenceSensor = capabilities.presenceSensor +local MovementSensor = capabilities.movementSensor +local ActivitySensor = capabilities.activitySensor +local DeviceMode = capabilities["stse.deviceMode"] + +local MOVEMENT_TIMER = "movement_timer" +local MOVEMENT_TIME = 5 +local COMP_MODE = "mode" + +local device_manager = {} +device_manager.__index = device_manager + +local FP2_MODES = { "zoneDetection", "fallDetection", "sleepMonitoring" } + +function device_manager.presence_handler(driver, device, zone, evt_value) + if not device:supports_capability(PresenceSensor) then return end + local evt_action = "not present" + if evt_value == 1 then evt_action = "present" end + device:emit_event(PresenceSensor.presence(evt_action)) +end + +function device_manager.movement_handler(driver, device, zone, evt_value) + if not device:supports_capability(MovementSensor) then return end + + local val = evt_value + local no_movement = function() + if not device:supports_capability(MovementSensor) then return end + device:emit_event(MovementSensor.movement("inactive")) + end + device:set_field(MOVEMENT_TIMER, device.thread:call_with_delay(MOVEMENT_TIME, no_movement)) + + if val == 0 then + device:emit_event(MovementSensor.movement("entering")) + elseif val == 1 then + device:emit_event(MovementSensor.movement("leaving")) + elseif val == 2 then + device:emit_event(MovementSensor.movement("enteringLeft")) + elseif val == 3 then + device:emit_event(MovementSensor.movement("leavingRight")) + elseif val == 4 then + device:emit_event(MovementSensor.movement("enteringRight")) + elseif val == 5 then + device:emit_event(MovementSensor.movement("leavingLeft")) + elseif val == 6 then + device:emit_event(MovementSensor.movement("approaching")) + elseif val == 7 then + device:emit_event(MovementSensor.movement("movingAway")) + end +end + +function device_manager.zone_presence_handler(driver, device, zone, evt_value) + if not device:supports_capability(capabilities.multipleZonePresence) then return end + + local zoneInfo = multipleZonePresence.findZoneById(driver, device, zone) + if not zoneInfo then + multipleZonePresence.createZone(driver, device, "zone" .. zone, zone) + end + local evt_action = multipleZonePresence.notPresent + if evt_value == 1 then evt_action = multipleZonePresence.present end + multipleZonePresence.changeState(driver, device, zone, evt_action) + multipleZonePresence.updateAttribute(driver, device) +end + +function device_manager.illuminance_handler(driver, device, zone, evt_value) + device:emit_event(capabilities.illuminanceMeasurement.illuminance(evt_value)) +end + +function device_manager.work_mode_handler(driver, device, zone, evt_value) + local cur_mode = device:get_latest_state(COMP_MODE, DeviceMode.ID, DeviceMode.mode.NAME) or 0 + local mode = 1 + local profile_name = "aqara-fp2-zoneDetection" + if not cur_mode then + device:emit_component_event(device.profile.components[COMP_MODE], DeviceMode.mode(FP2_MODES[mode])) + return + elseif evt_value == 0x03 then + if cur_mode == FP2_MODES[1] then return end + elseif evt_value == 0x05 then + if cur_mode == FP2_MODES[2] then return end + mode = 2 + profile_name = "aqara-fp2-fallDetection" + elseif evt_value == 0x09 then + if cur_mode == FP2_MODES[3] then return end + mode = 3 + profile_name = "aqara-fp2-sleepMonitoring" + end + device:emit_component_event(device.profile.components[COMP_MODE], DeviceMode.mode(FP2_MODES[mode])) + device:try_update_metadata({ profile = profile_name }) +end + +function device_manager.zone_quantities_handler(driver, device, zone, evt_value) + if not device:supports_capability(capabilities.multipleZonePresence) then return end + + for i = 0, 29 do + local zonePos = tostring(i + 1) + local zoneInfo = multipleZonePresence.findZoneById(driver, device, zonePos) + local curStatus = 0x1 & (evt_value >> i) + if zoneInfo and curStatus == 0 then -- delete + multipleZonePresence.deleteZone(driver, device, zonePos) + elseif not zoneInfo and curStatus == 1 then -- create + multipleZonePresence.createZone(driver, device, "zone" .. zonePos, zonePos) + multipleZonePresence.changeState(driver, device, zonePos, multipleZonePresence.notPresent) + end + end + multipleZonePresence.updateAttribute(driver, device) +end + +function device_manager.monitoring_mode_handler(driver, device, zone, evt_value) + if not device:supports_capability(MovementSensor) then return end + + local supportedEnum = { "inactive", "approaching", "movingAway", "entering", "leaving" } + if evt_value == 0x01 then + local additional = { "enteringLeft", "enteringRight", "leavingLeft", "leavingRight" } + table.move(additional, 1, #additional, 4, supportedEnum) + end + + device:emit_event(MovementSensor.supportedMovements(supportedEnum, { visibility = { displayed = false } })) +end + +function device_manager.fall_event_handler(driver, device, zone, evt_value) + if not device:supports_capability(ActivitySensor) then return end + + local event_name = "noActivity" + if evt_value == 0x01 then + event_name = "falling" + end + + device:emit_event(ActivitySensor.activity(event_name)) +end + +local resource_id = { + ["3.51.85"] = { zone = "", event_handler = device_manager.presence_handler }, + ["13.27.85"] = { zone = "", event_handler = device_manager.movement_handler }, + ["3.1.85"] = { zone = "1", event_handler = device_manager.zone_presence_handler }, + ["3.2.85"] = { zone = "2", event_handler = device_manager.zone_presence_handler }, + ["3.3.85"] = { zone = "3", event_handler = device_manager.zone_presence_handler }, + ["3.4.85"] = { zone = "4", event_handler = device_manager.zone_presence_handler }, + ["3.5.85"] = { zone = "5", event_handler = device_manager.zone_presence_handler }, + ["3.6.85"] = { zone = "6", event_handler = device_manager.zone_presence_handler }, + ["3.7.85"] = { zone = "7", event_handler = device_manager.zone_presence_handler }, + ["3.8.85"] = { zone = "8", event_handler = device_manager.zone_presence_handler }, + ["3.9.85"] = { zone = "9", event_handler = device_manager.zone_presence_handler }, + ["3.10.85"] = { zone = "10", event_handler = device_manager.zone_presence_handler }, + ["3.11.85"] = { zone = "11", event_handler = device_manager.zone_presence_handler }, + ["3.12.85"] = { zone = "12", event_handler = device_manager.zone_presence_handler }, + ["3.13.85"] = { zone = "13", event_handler = device_manager.zone_presence_handler }, + ["3.14.85"] = { zone = "14", event_handler = device_manager.zone_presence_handler }, + ["3.15.85"] = { zone = "15", event_handler = device_manager.zone_presence_handler }, + ["3.16.85"] = { zone = "16", event_handler = device_manager.zone_presence_handler }, + ["3.17.85"] = { zone = "17", event_handler = device_manager.zone_presence_handler }, + ["3.18.85"] = { zone = "18", event_handler = device_manager.zone_presence_handler }, + ["3.19.85"] = { zone = "19", event_handler = device_manager.zone_presence_handler }, + ["3.20.85"] = { zone = "20", event_handler = device_manager.zone_presence_handler }, + ["3.21.85"] = { zone = "21", event_handler = device_manager.zone_presence_handler }, + ["3.22.85"] = { zone = "22", event_handler = device_manager.zone_presence_handler }, + ["3.23.85"] = { zone = "23", event_handler = device_manager.zone_presence_handler }, + ["3.24.85"] = { zone = "24", event_handler = device_manager.zone_presence_handler }, + ["3.25.85"] = { zone = "25", event_handler = device_manager.zone_presence_handler }, + ["3.26.85"] = { zone = "26", event_handler = device_manager.zone_presence_handler }, + ["3.27.85"] = { zone = "27", event_handler = device_manager.zone_presence_handler }, + ["3.28.85"] = { zone = "28", event_handler = device_manager.zone_presence_handler }, + ["3.29.85"] = { zone = "29", event_handler = device_manager.zone_presence_handler }, + ["3.30.85"] = { zone = "30", event_handler = device_manager.zone_presence_handler }, + ["0.4.85"] = { zone = "", event_handler = device_manager.illuminance_handler }, + ["14.49.85"] = { zone = "", event_handler = device_manager.work_mode_handler }, + ["14.55.85"] = { zone = "", event_handler = device_manager.monitoring_mode_handler }, + ["4.31.85"] = { zone = "", event_handler = device_manager.fall_event_handler }, + ["200.2.20000"] = { zone = "", event_handler = device_manager.zone_quantities_handler } +} + +function device_manager.init_presence(driver, device) + if device:supports_capability(PresenceSensor) + and device:get_latest_state("main", PresenceSensor.ID, PresenceSensor.presence.NAME) == nil then + device:emit_event(PresenceSensor.presence("not present")) + end +end + +function device_manager.init_movement(driver, device) + if not device:supports_capability(MovementSensor) then return end + device:emit_event(MovementSensor.movement("inactive")) +end + +function device_manager.init_activity(driver, device) + if not device:supports_capability(ActivitySensor) then return end + device:emit_event(ActivitySensor.activity("noActivity")) +end + +function device_manager.set_zone_info_to_latest_state(driver, device) + if not device:supports_capability(capabilities.multipleZonePresence) then return end + + local zoneInfoTable = device:get_latest_state("main", capabilities.multipleZonePresence.ID, + capabilities.multipleZonePresence.zoneState.NAME, {}) + multipleZonePresence.setZoneInfo(driver, device, zoneInfoTable) +end + +function device_manager.handle_status(driver, device, status) + if not status then + log.warn("device_manager.handle_status : status is nil") + return + end + + for k, _ in pairs(status) do + if resource_id[k] then + resource_id[k].event_handler(driver, device, resource_id[k].zone, tonumber(status[k])) + else + log.warn("device_manager.handle_status : resource id status is nil") + end + end +end + +function device_manager.update_status(driver, device) + local conn_info = device:get_field(fields.CONN_INFO) + + if not conn_info then + log.warn(string.format("device_manager.update_status : failed to find conn_info, dni = %s", + device.device_network_id)) + return + end + + local _, err, status = conn_info:get_attr() + + if err or status ~= 200 then + log.error(string.format("device_manager.update_status : failed to get status, dni= %s, err= %s, status= %s", + device.device_network_id, err, status)) + if status == 404 then + device:offline() + end + return + end +end + +local sse_event_handlers = { + ["message"] = device_manager.handle_status +} + +function device_manager.handle_sse_event(driver, device, event_type, data) + local _, device_json = pcall(json.decode, data) + + local event_handler = sse_event_handlers[event_type] + if event_handler and device_json then + event_handler(driver, device, device_json) + else + log.error(string.format("handle_sse_event : unknown event type. dni = %s, event_type = '%s'", + device.device_network_id, event_type)) + end +end + +function device_manager.is_valid_connection(driver, device, conn_info) + if not conn_info then + log.warn(string.format("device_manager.is_valid_connection : failed to find conn_info, dni = %s", + device.device_network_id)) + return false + end + local _, err, status = conn_info:get_attr() + if err or status ~= 200 then + log.warn(string.format( + "device_manager.is_valid_connection : failed to connect to device, dni = %s, err= %s, status= %s", + device.device_network_id, err, status)) + return false + end + + return true +end + +function device_manager.device_monitor(driver, device, device_info) + device_manager.update_status(driver, device) +end + +function device_manager.get_sse_url(driver, device, conn_info) + return conn_info:get_sse_url() +end + +return device_manager diff --git a/drivers/Aqara/aqara-presence-sensor/src/fp2/discovery_helper.lua b/drivers/Aqara/aqara-presence-sensor/src/fp2/discovery_helper.lua new file mode 100644 index 0000000000..7a0e28917a --- /dev/null +++ b/drivers/Aqara/aqara-presence-sensor/src/fp2/discovery_helper.lua @@ -0,0 +1,79 @@ +local log = require "log" +local discovery_helper = {} + +local SERVICE_TYPE = "_Aqara-FP2._tcp" +local DOMAIN = "local" + +local fp2_api = require "fp2.api" +local discovery_mdns = require "discovery_mdns" + +function discovery_helper.get_dni(driver, ip, discovery_responses) + local text_list = discovery_mdns.find_text_list_in_mdns_response(driver, ip, discovery_responses) + for _, text in ipairs(text_list) do + for key, value in string.gmatch(text, "(%S+)=(%S+)") do + if key == "mac" then + return value + end + end + end + + log.error("discovery_helper.get_dni : failed to find dni") + return nil +end + +function discovery_helper.get_service_type_and_domain() + return SERVICE_TYPE, DOMAIN +end + +function discovery_helper.get_device_create_msg(driver, device_dni, device_ip) + local device_info = fp2_api.get_info(device_ip, fp2_api.labeled_socket_builder(device_dni)) + + if not device_info then + log.warn("failed to create device create msg. device_info is nil.") + return nil + end + + local create_device_msg = { + type = "LAN", + device_network_id = device_dni, + label = device_info.label, + profile = "aqara-fp2-zoneDetection", + manufacturer = device_info.manufacturerName, + model = device_info.modelName, + vendor_provided_label = device_info.label, + } + return create_device_msg +end + +function discovery_helper.get_credential(driver, bridge_dni, bridge_ip) + local credential = fp2_api.get_credential(bridge_ip, fp2_api.labeled_socket_builder(bridge_dni)) + + if not credential then + log.warn("credential is nil") + return nil + end + + return "Bearer " .. credential.token +end + +function discovery_helper.get_connection_info(driver, device_dni, device_ip, device_info) + local conn_info = fp2_api.new_device_manager(device_ip, device_info, fp2_api.labeled_socket_builder(device_dni)) + + if conn_info == nil then + log.warn("conn_info is nil") + end + + return conn_info +end + +function discovery_helper.get_device_info(driver, device_dni, device_ip) + local device_info = fp2_api.get_info(device_ip, fp2_api.labeled_socket_builder(device_dni)) + + if device_info == nil then + log.warn("device_info is nil") + end + + return device_info +end + +return discovery_helper diff --git a/drivers/Aqara/aqara-presence-sensor/src/init.lua b/drivers/Aqara/aqara-presence-sensor/src/init.lua new file mode 100644 index 0000000000..f83561279a --- /dev/null +++ b/drivers/Aqara/aqara-presence-sensor/src/init.lua @@ -0,0 +1,219 @@ +local log = require "log" +local capabilities = require "st.capabilities" +local Driver = require "st.driver" +local discovery = require "discovery" +local fields = require "fields" +local fp2_discovery_helper = require "fp2.discovery_helper" +local fp2_device_manager = require "fp2.device_manager" +local fp2_api = require "fp2.api" +local multipleZonePresence = require "multipleZonePresence" +local EventSource = require "lunchbox.sse.eventsource" + +local DEFAULT_MONITORING_INTERVAL = 300 +local CREDENTIAL_KEY_HEADER = "Authorization" + +local function handle_sse_event(driver, device, msg) + driver.device_manager.handle_sse_event(driver, device, msg.type, msg.data) +end + +local function status_update(driver, device) + local conn_info = device:get_field(fields.CONN_INFO) + if not conn_info then + log.warn(string.format("refresh : failed to find conn_info, dni = %s", device.device_network_id)) + else + local resp, err, status = conn_info:get_attr() + + if err or status ~= 200 then + log.error(string.format("refresh : failed to get attr, dni= %s, err= %s, status= %s", device.device_network_id, err, + status)) + if status == 404 then + device:offline() + end + else + driver.device_manager.handle_status(driver, device, resp) + end + end +end + +local function create_sse(driver, device, credential) + local conn_info = device:get_field(fields.CONN_INFO) + + if not driver.device_manager.is_valid_connection(driver, device, conn_info) then + log.warn("create_sse : invalid connection") + return + end + + local sse_url = driver.device_manager.get_sse_url(driver, device, conn_info) + if not sse_url then + log.error("failed to get sse_url") + else + log.trace(string.format("Creating SSE EventSource for %s, sse_url= %s", device.device_network_id, sse_url)) + local label = string.format("%s-SSE", device.device_network_id) + local eventsource = EventSource.new(sse_url, { [CREDENTIAL_KEY_HEADER] = credential }, + fp2_api.labeled_socket_builder(label)) + + eventsource.onmessage = function(msg) + if msg then + handle_sse_event(driver, device, msg) + end + end + + eventsource.onerror = function() + log.error(string.format("Eventsource error: dni= %s", device.device_network_id)) + device:offline() + end + + eventsource.onopen = function() + device:online() + end + + local old_eventsource = device:get_field(fields.EVENT_SOURCE) + if old_eventsource then + old_eventsource:close() + end + device:set_field(fields.EVENT_SOURCE, eventsource) + end +end + +local function update_connection(driver, device, device_ip, device_info) + local device_dni = device.device_network_id + local conn_info = driver.discovery_helper.get_connection_info(driver, device_dni, device_ip, device_info) + local credential = device:get_field(fields.CREDENTIAL) + + conn_info:add_header(CREDENTIAL_KEY_HEADER, credential) + + if driver.device_manager.is_valid_connection(driver, device, conn_info) then + device:set_field(fields.CONN_INFO, conn_info) + + create_sse(driver, device, credential) + end +end + + +local function find_new_connection(driver, device) + local ip_table = discovery.find_ip_table(driver) + local ip = ip_table[device.device_network_id] + if ip then + device:set_field(fields.DEVICE_IPV4, ip, { persist = true }) + local device_info = device:get_field(fields.DEVICE_INFO) + update_connection(driver, device, ip, device_info) + else + log.warn("find new conneciton : ip is nil") + end +end + +local function check_and_update_connection(driver, device) + local conn_info = device:get_field(fields.CONN_INFO) + if not driver.device_manager.is_valid_connection(driver, device, conn_info) then + device:offline() + find_new_connection(driver, device) + conn_info = device:get_field(fields.CONN_INFO) + end + + if driver.device_manager.is_valid_connection(driver, device, conn_info) then + device:online() + end +end + +local function create_monitoring_thread(driver, device, device_info) + local old_timer = device:get_field(fields.MONITORING_TIMER) + if old_timer ~= nil then + device.thread:cancel_timer(old_timer) + end + + local monitoring_interval = DEFAULT_MONITORING_INTERVAL + local new_timer = device.thread:call_on_schedule(monitoring_interval, function() + check_and_update_connection(driver, device) + driver.device_manager.device_monitor(driver, device, device_info) + end, "monitor_timer") + device:set_field(fields.MONITORING_TIMER, new_timer) +end + + + +local function do_refresh(driver, device, cmd) + check_and_update_connection(driver, device) + status_update(driver, device) + driver.device_manager.init_presence(driver, device) + driver.device_manager.init_movement(driver, device) + driver.device_manager.init_activity(driver, device) +end + +local function device_removed(driver, device) + local conn_info = device:get_field(fields.CONN_INFO) + if not conn_info then + log.warn(string.format("remove : failed to find conn_info, dni = %s", device.device_network_id)) + else + local _, err, status = conn_info:get_remove() + + if err or status ~= 200 then + log.error(string.format("remove : failed to get remove, dni= %s, err= %s, status= %s", device.device_network_id, + err, + status)) + end + end + + local eventsource = device:get_field(fields.EVENT_SOURCE) + if eventsource then + eventsource:close() + end +end + +local function device_init(driver, device) + if device:get_field(fields._INIT) then + return + end + + local device_dni = device.device_network_id + driver.controlled_devices[device_dni] = device + + local device_ip = device:get_field(fields.DEVICE_IPV4) + local device_info = device:get_field(fields.DEVICE_INFO) + local credential = device:get_field(fields.CREDENTIAL) + + if not credential then + log.error("failed to find credential.") + device:offline() + return + end + + log.trace(string.format("Creating device monitoring for %s", device.device_network_id)) + create_monitoring_thread(driver, device, device_info) + update_connection(driver, device, device_ip, device_info) + + driver.device_manager.set_zone_info_to_latest_state(driver, device) + + do_refresh(driver, device, nil) + device:set_field(fields._INIT, true, { persist = false }) +end + +local function device_info_changed(driver, device, event, args) + do_refresh(driver, device, nil) +end + +local lan_driver = Driver("aqara-fp2", + { + discovery = discovery.do_network_discovery, + lifecycle_handlers = { + added = discovery.device_added, + init = device_init, + infoChanged = device_info_changed, + removed = device_removed + }, + capability_handlers = { + [capabilities.refresh.ID] = { + [capabilities.refresh.commands.refresh.NAME] = do_refresh, + }, + [multipleZonePresence.id] = { + [multipleZonePresence.commands.createZone.name] = multipleZonePresence.commands.createZone.handler, + [multipleZonePresence.commands.deleteZone.name] = multipleZonePresence.commands.deleteZone.handler, + [multipleZonePresence.commands.updateZoneName.name] = multipleZonePresence.commands.updateZoneName.handler, + } + }, + discovery_helper = fp2_discovery_helper, + device_manager = fp2_device_manager, + controlled_devices = {}, + } +) + +lan_driver:run() diff --git a/drivers/Aqara/aqara-presence-sensor/src/lunchbox/init.lua b/drivers/Aqara/aqara-presence-sensor/src/lunchbox/init.lua new file mode 100644 index 0000000000..4767995ec2 --- /dev/null +++ b/drivers/Aqara/aqara-presence-sensor/src/lunchbox/init.lua @@ -0,0 +1,4 @@ +local RestClient = require "lunchbox.rest" +local EventSource = require "lunchbox.sse.eventsource" + +return { RestClient = RestClient, EventSource = EventSource } diff --git a/drivers/Aqara/aqara-presence-sensor/src/lunchbox/rest.lua b/drivers/Aqara/aqara-presence-sensor/src/lunchbox/rest.lua new file mode 100644 index 0000000000..1715cde6ec --- /dev/null +++ b/drivers/Aqara/aqara-presence-sensor/src/lunchbox/rest.lua @@ -0,0 +1,378 @@ +---@class ChunkedResponse : Response +---@field package _received_body boolean +---@field package _parsed_headers boolean +---@field public new fun(status_code: number, socket: table?): ChunkedResponse +---@field public fill_body fun(self: ChunkedResponse): string? +---@field public append_body fun(self: ChunkedResponse, next_chunk_body: string): ChunkedResponse + +local socket = require "cosock.socket" +local utils = require "utils" +local lb_utils = require "lunchbox.util" +local Request = require "luncheon.request" +local Response = require "luncheon.response" --[[@as ChunkedResponse]] +local api_version = require("version").api + +local RestCallStates = { + SEND = "Send", + RECEIVE = "Receive", + RETRY = "Retry", + RECONNECT = "Reconnect", + COMPLETE = "Complete", +} + +local function connect(client) + local port = 80 + local use_ssl = false + + if client.base_url.scheme == "https" then + port = 443 + use_ssl = true + end + + local sock, err = client.socket_builder(client.base_url.host, port, use_ssl) + + if sock == nil then + client.socket = nil + return false, err + end + + client.socket = sock + return true +end + +local function reconnect(client) + if client.socket ~= nil then + client.socket:close() + client.socket = nil + end + return connect(client) +end + +---comment +---@param client RestClient +---@param request HttpMessage +---@return integer? bytes_sent +---@return string? err_msg +---@return integer idx +local function send_request(client, request) + if client.socket == nil then + return nil, "no socket available", 0 + end + local payload = request:serialize() + + local bytes, err, idx = nil, nil, 0 + + repeat bytes, err, idx = client.socket:send(payload, idx + 1, #payload) until (bytes == #payload) + or (err ~= nil) + + return bytes, err, idx +end + +---@param original_response Response +---@param sock table +---@return Response? +---@return string? err +local function parse_chunked_response(original_response, sock) + local ChunkedTransferStates = { + EXPECTING_CHUNK_LENGTH = "ExpectingChunkLength", + EXPECTING_BODY_CHUNK = "ExpectingBodyChunk", + } + + local full_response = Response.new(original_response.status, nil) --[[@as ChunkedResponse]] + + for header in original_response.headers:iter() do full_response.headers:append_chunk(header) end + + local original_body, err = original_response:get_body() + if original_body == nil or err ~= nil then + return nil, (err or "unexpected nil in error position") + end + local next_chunk_bytes = tonumber(original_body, 16) + local next_chunk_body = "" + local bytes_read = 0; + + local state = ChunkedTransferStates.EXPECTING_BODY_CHUNK + + repeat + local pat = nil + local next_recv, next_err, partial = nil, nil, nil + + if state == ChunkedTransferStates.EXPECTING_BODY_CHUNK then + pat = next_chunk_bytes + else + pat = "*l" + end + + next_recv, next_err, partial = sock:receive(pat) + + if next_err ~= nil then + if string.lower(next_err) == "closed" then + if partial ~= nil and #partial >= 1 then + full_response:append_body(partial) + next_chunk_bytes = 0 + else + return nil, next_err + end + else + return nil, ("unexpected error reading chunked transfer: " .. next_err) + end + end + + if next_recv ~= nil and #next_recv >= 1 then + if state == ChunkedTransferStates.EXPECTING_BODY_CHUNK then + bytes_read = bytes_read + #next_recv + next_chunk_body = next_chunk_body .. next_recv + + if bytes_read >= next_chunk_bytes then + full_response = full_response:append_body(next_chunk_body) + next_chunk_body = "" + bytes_read = 0 + + state = ChunkedTransferStates.EXPECTING_CHUNK_LENGTH + end + elseif state == ChunkedTransferStates.EXPECTING_CHUNK_LENGTH then + next_chunk_bytes = tonumber(next_recv, 16) + + state = ChunkedTransferStates.EXPECTING_BODY_CHUNK + end + end + until next_chunk_bytes == 0 + + local _ = sock:receive("*l") -- clear the trailing CRLF + + full_response._received_body = true + full_response._parsed_headers = true + + return full_response +end + +---@param sock table +---@return Response|nil +---@return string? err +---@return string? partial +local function handle_response(sock) + if api_version >= 9 then + local response, err = Response.tcp_source(sock) + if err or (not response) then return response, (err or "unknown error") end + return response, response:fill_body() + end + -- called select right before passing in so we receive immediately + local initial_recv, initial_err, partial = Response.source(function() return sock:receive('*l') end) + + local full_response = nil + + if initial_recv ~= nil then + local headers = initial_recv:get_headers() + + if headers and headers:get_one("Transfer-Encoding") == "chunked" then + local response, err = parse_chunked_response(initial_recv, sock) + if err ~= nil then + return nil, err + end + full_response = response + else + full_response = initial_recv + end + + return full_response + else + return nil, initial_err, partial + end +end + +---@param client RestClient +---@param request HttpMessage +---@param retry_fn nil|fun(): boolean +---@return Response? response nil on error +---@return string? err nil on success +---@return string? partial +local function execute_request(client, request, retry_fn) + if not client._active then + return nil, "Called `execute request` on a terminated REST Client", nil + end + + if client.socket == nil then + local success, err = connect(client) + if not success then return nil, err, nil end + end + + local should_retry = retry_fn + + if type(should_retry) ~= "function" then + should_retry = function() return false end + end + + -- send output + local bytes_sent, send_err, _idx = nil, nil, 0 + -- recv output + local response, recv_err, partial = nil, nil, nil + -- return values + local ret, err = nil, nil + + local backoff = utils.backoff_builder(60, 1, 0.1) + local current_state = RestCallStates.SEND + + repeat + local retry = should_retry() + if current_state == RestCallStates.SEND then + backoff = utils.backoff_builder(60, 1, 0.1) + bytes_sent, send_err, _idx = send_request(client, request) + + if not send_err then + current_state = RestCallStates.RECEIVE + elseif retry then + if string.lower(send_err) == "closed" or string.lower(send_err):match("broken pipe") then + current_state = RestCallStates.RECONNECT + else + current_state = RestCallStates.RETRY + end + else + ret = nil + err = send_err + current_state = RestCallStates.COMPLETE + end + elseif current_state == RestCallStates.RECEIVE then + response, recv_err, partial = handle_response(client.socket) + + if not recv_err then + ret = response + err = nil + current_state = RestCallStates.COMPLETE + elseif retry then + if string.lower(recv_err) == "closed" or string.lower(recv_err):match("broken pipe") then + current_state = RestCallStates.RECONNECT + else + current_state = RestCallStates.RETRY + end + else + ret = nil + err = recv_err + current_state = RestCallStates.COMPLETE + end + elseif current_state == RestCallStates.RECONNECT then + local success, reconn_err = reconnect(client) + if success then + current_state = RestCallStates.RETRY + elseif not retry then + ret = nil + err = reconn_err + current_state = RestCallStates.COMPLETE + else + socket.sleep(backoff()) + end + elseif current_state == RestCallStates.RETRY then + bytes_sent, send_err, _idx = nil, nil, 0 + response, recv_err, partial = nil, nil, nil + current_state = RestCallStates.SEND + socket.sleep(backoff()) + end + until current_state == RestCallStates.COMPLETE + + return ret, err, partial +end + +---@class RestClient +--- +---@field base_url table `net.url` URL table +---@field socket table `cosock` TCP socket +local RestClient = {} +RestClient.__index = RestClient + +function RestClient.one_shot_get(full_url, additional_headers, socket_builder) + local url_table = lb_utils.force_url_table(full_url) + local client = RestClient.new(url_table.scheme .. "://" .. url_table.host, socket_builder) + local ret, err = client:get(url_table.path, additional_headers) + client:shutdown() + return ret, err +end + +function RestClient.one_shot_post(full_url, body, additional_headers, socket_builder) + local url_table = lb_utils.force_url_table(full_url) + local client = RestClient.new(url_table.scheme .. "://" .. url_table.host, socket_builder) + local ret, err = client:post(url_table.path, body, additional_headers) + client:shutdown() + return ret, err +end + +function RestClient:close_socket() + if self.socket ~= nil and self._active then + self.socket:close() + self.socket = nil + end +end + +function RestClient:shutdown() + self:close_socket() + self._active = false +end + +function RestClient:update_base_url(new_url) + if self.socket ~= nil then + self.socket:close() + self.socket = nil + end + + self.base_url = lb_utils.force_url_table(new_url) +end + +---@param path string +---@param additional_headers table +---@param retry_fn nil|fun(): boolean +---@return Response? the response, nil on error +---@return string? err error, nil on success +---@return string? partial +function RestClient:get(path, additional_headers, retry_fn) + local request = Request.new("GET", path, nil):add_header( + "user-agent", "smartthings-lua-edge-driver" + ):add_header("host", string.format("%s", self.base_url.host)):add_header( + "connection", "keep-alive" + ) + + if additional_headers ~= nil and type(additional_headers) == "table" then + for k, v in pairs(additional_headers) do request = request:add_header(k, v) end + end + + return execute_request(self, request, retry_fn) +end + +function RestClient:post(path, body_string, additional_headers, retry_fn) + local request = Request.new("POST", path, nil):add_header( + "user-agent", "smartthings-lua-edge-driver" + ):add_header("host", string.format("%s", self.base_url.host)):add_header( + "connection", "keep-alive" + ) + + if additional_headers ~= nil and type(additional_headers) == "table" then + for k, v in pairs(additional_headers) do request = request:add_header(k, v) end + end + + request = request:append_body(body_string) + + return execute_request(self, request, retry_fn) +end + +function RestClient:put(path, body_string, additional_headers, retry_fn) + local request = Request.new("PUT", path, nil):add_header( + "user-agent", "smartthings-lua-edge-driver" + ):add_header("host", string.format("%s", self.base_url.host)):add_header( + "connection", "keep-alive" + ) + + if additional_headers ~= nil and type(additional_headers) == "table" then + for k, v in pairs(additional_headers) do request = request:add_header(k, v) end + end + + request = request:append_body(body_string) + + return execute_request(self, request, retry_fn) +end + +function RestClient.new(base_url, sock_builder) + base_url = lb_utils.force_url_table(base_url) + + if type(sock_builder) ~= "function" then sock_builder = utils.labeled_socket_builder() end + + return + setmetatable({ base_url = base_url, socket_builder = sock_builder, socket = nil, _active = true }, RestClient) +end + +return RestClient diff --git a/drivers/Aqara/aqara-presence-sensor/src/lunchbox/sse/eventsource.lua b/drivers/Aqara/aqara-presence-sensor/src/lunchbox/sse/eventsource.lua new file mode 100644 index 0000000000..0b5244bcd1 --- /dev/null +++ b/drivers/Aqara/aqara-presence-sensor/src/lunchbox/sse/eventsource.lua @@ -0,0 +1,526 @@ +local cosock = require "cosock" +local socket = require "cosock.socket" +local ssl = require "cosock.ssl" + +local log = require "log" +local util = require "lunchbox.util" +local Request = require "luncheon.request" +local Response = require "luncheon.response" + +--- A pure Lua implementation of the EventSource interface. +--- The EventSource interface represents the client end of an HTTP(S) +--- connection that receives an event stream following the Server-Sent events +--- specification. +--- +--- MDN Documentation for EventSource: https://developer.mozilla.org/en-US/docs/Web/API/EventSource +--- HTML Spec: https://html.spec.whatwg.org/multipage/server-sent-events.html +--- +--- @class EventSource +--- @field public url table A `net.url` table representing the URL for the connection +--- @field public ready_state number Enumeration of the ready states outlined in the spec. +--- @field public onopen function in-line callback for on-open events +--- @field public onmessage function in-line callback for on-message events +--- @field public onerror function in-line callback for on-error events; error callbacks will fire +--- @field package _reconnect boolean flag that says whether or not the client should attempt to reconnect on close. +--- @field package _reconnect_time_millis number The amount of time to wait between reconnects, in millis. Can be sent by the server. +--- @field package _sock_builder function|nil optional. If this function exists, it will be called to create a new TCP socket on connection. +--- @field package _sock table? the TCP socket for the connection +--- @field package _needs_more boolean flag to track whether or not we're still expecting mroe on this source before we dispatch +--- @field package _last_field string the last field the parsing path saw, in case it needs to append more to its value +--- @field package _extra_headers table a table of string:string key-value pairs that will be inserted in to the initial requests's headers. +--- @field package _parse_buffers table inner state, keeps track of the various event stream buffers in between dispatches. +--- @field package _listeners table event listeners attached using the add_event_listener API instead of the inline callbacks. +local EventSource = {} +EventSource.__index = EventSource + +--- The Ready States that an EventSource can be in. We use base 0 to match the specification. +EventSource.ReadyStates = util.read_only { + CONNECTING = 0, -- The connection has not yet been established + OPEN = 1, -- The connection is open + CLOSED = 2 -- The connection has closed +} + +--- The event types supported by this source, patterned after their values in JavaScript. +EventSource.EventTypes = util.read_only { + ON_OPEN = "open", + ON_MESSAGE = "message", + ON_ERROR = "error", +} + +--- Helper function that creates the initial Request to start the stream. +--- @function create_request +--- @local +--- @param url_table table a net.url table +--- @param extra_headers table a set of key/value pairs (strings) to capture any extra HTTP headers needed. +local function create_request(url_table, extra_headers) + local request = Request.new("GET", url_table.path, nil) + :add_header("user-agent", "smartthings-lua-edge-driver") + :add_header("host", string.format("%s", url_table.host)) + :add_header("connection", "keep-alive") + :add_header("accept", "text/event-stream") + + if type(extra_headers) == "table" then + for k, v in pairs(extra_headers) do + request = request:add_header(k, v) + end + end + + return request +end + +--- Helper function to send the request and kick off the stream. +--- @function send_stream_start_request +--- @local +--- @param payload string the entire string buffer to send +--- @param sock table the TCP socket to send it over +local function send_stream_start_request(payload, sock) + local bytes, err, idx = nil, nil, 0 + + repeat + bytes, err, idx = sock:send(payload, idx + 1, #payload) + until (bytes == #payload) or (err ~= nil) + + if err then + log.error_with({ hub_logs = true }, "send error: " .. err) + end + + return bytes, err, idx +end + +--- Helper function to create an table representing an event from the source's parse buffers. +--- @function make_event +--- @local +--- @param source EventSource +local function make_event(source) + local event_type = nil + + if #source._parse_buffers["event"] > 0 then + event_type = source._parse_buffers["event"] + end + + return { + type = event_type or "message", + data = source._parse_buffers["data"], + origin = source.url.scheme .. "://" .. source.url.host, + lastEventId = source._parse_buffers["id"] + } +end + +--- SSE spec for dispatching an event: +--- https://html.spec.whatwg.org/multipage/server-sent-events.html#dispatchMessage +--- @function dispatch_event +--- @local +--- @param source EventSource +local function dispatch_event(source) + local data_buffer = source._parse_buffers["data"] + local is_blank_line = data_buffer ~= nil and + (#data_buffer == 0) or + data_buffer == "\n" or + data_buffer == "\r" or + data_buffer == "\r\n" + if data_buffer ~= nil and not is_blank_line then + local event = util.read_only(make_event(source)) + + if type(source.onmessage) == "function" then + source.onmessage(event) + end + + for _, listener in ipairs(source._listeners[EventSource.EventTypes.ON_MESSAGE]) do + if type(listener) == "function" then + listener(event) + end + end + end + + source._parse_buffers["event"] = "" + source._parse_buffers["data"] = "" +end + +local valid_fields = util.read_only { + ["event"] = true, + ["data"] = true, + ["id"] = true, + ["retry"] = true +} + +-- An event stream "line" can end in more than one way; from the spec: +-- Lines must be separated by either +-- a U+000D CARRIAGE RETURN U+000A LINE FEED (CRLF) character pair, +-- a single U+000A LINE FEED (LF) character, +-- or a single U+000D CARRIAGE RETURN (CR) character. +-- +-- util.iter_string_lines won't suffice here because: +-- a.) it assumes \n, and +-- b.) it doesn't differentiate between a "line" that ends without a newline and one that does. +-- +-- h/t to github.com/FreeMasen for the suggestions on the efficient implementation of this +local function find_line_endings(chunk) + local r_idx, n_idx = string.find(chunk, "[\r\n]+") + if r_idx == nil or r_idx == n_idx then + -- 1 character or no match + return r_idx, n_idx + end + local slice = string.sub(chunk, r_idx, n_idx) + if slice == "\r\n" then + return r_idx, n_idx + end + -- invalid multi character match, return first character only + return r_idx, r_idx +end + +local function event_lines(chunk) + local remaining = chunk + local line_end, rn_end + local remainder_sent = false + return function() + line_end, rn_end = find_line_endings(remaining) + if not line_end then + if remainder_sent or (not remaining) or #remaining == 0 then + return nil + else + remainder_sent = true + return remaining, false + end + end + local next_line = string.sub(remaining, 1, line_end - 1) + remaining = string.sub(remaining, rn_end + 1) + return next_line, true + end +end +--- SSE spec for interpreting an event stream: +--- https://html.spec.whatwg.org/multipage/server-sent-events.html#the-eventsource-interface +--- @function parse +--- @local +--- @param source EventSource +--- @param recv string the received payload from the last socket receive +local function sse_parse_chunk(source, recv) + for line, complete in event_lines(recv) do + if not source._needs_more and (#line == 0 or (not line:match("([%w%p]+)"))) then -- empty/blank lines indicate dispatch + dispatch_event(source) + elseif source._needs_more then + local append = line + if source._last_field == "data" and complete then append = append .. "\n" end + if complete then source._needs_more = false end + source._parse_buffers[source._last_field] = source._parse_buffers[source._last_field] .. append + else + if line:sub(1, 1) ~= ":" then -- ignore any complete lines that start w/ a colon + local matches = line:gmatch("(%w*)(:*)(.*)") -- colon after field is optional, in that case it's a field w/ no value + + for field, _colon, value in matches do + value = value:gsub("^[^%g]", "", 1) -- trim a single leading space character + + if valid_fields[field] then + source._last_field = field + if field == "retry" then + local new_time = tonumber(value, 10) + if type(new_time) == "number" then + source._reconnect_time_millis = new_time + end + elseif field == "data" then + local append = (value or "") + if complete then append = append .. "\n" end + source._parse_buffers[field] = source._parse_buffers[field] .. append + elseif field == "id" then + -- skip ID's if they contain the NULL character + if not string.find(value, '\0') then + source._parse_buffers[field] = value + end + else + source._parse_buffers[field] = value + end + end + source._needs_more = source._needs_more or (not complete) + end + end + end + end +end + +--- Helper function that captures the cyclic logic of the EventSource while in the CONNECTING state. +--- @function connecting_action +--- @local +--- @param source EventSource +local function connecting_action(source) + if not source._sock then + if type(source._sock_builder) == "function" then + -- source._sock = source._sock_builder() + local use_ssl = false + if source.url.scheme == "https" then + use_ssl = true + end + + source._sock = source._sock_builder(source.url.host, source.url.port, use_ssl) + else + source._sock, err = socket.tcp() + if err ~= nil then return nil, err end + + _, err = source._sock:settimeout(60) + if err ~= nil then return nil, err end + + _, err = source._sock:connect(source.url.host, source.url.port) + if err ~= nil then return nil, err end + + _, err = source._sock:setoption("keepalive", true) + if err ~= nil then return nil, err end + + if source.url.scheme == "https" then + source._sock, err = ssl.wrap(source._sock, { + mode = "client", + protocol = "any", + verify = "none", + options = "all" + }) + if err ~= nil then return nil, err end + _, err = source._sock:dohandshake() + if err ~= nil then return nil, err end + end + end + end + + local request = create_request(source.url, source._extra_headers) + + local last_event_id = source._parse_buffers["id"] + + if last_event_id ~= nil and #last_event_id > 0 then + request = request:add_header("Last-Event-ID", last_event_id) + end + + if not source._sock then + return nil, "source._sock is nil" + end + + local _, err, _ = send_stream_start_request(request:serialize(), source._sock) + + if err ~= nil then + return nil, err + end + + local response + response, err = Response.tcp_source(source._sock) + + if not response or err ~= nil then + return nil, err or "nil response from Response.tcp_source" + end + + if response.status ~= 200 then + return nil, "Server responded with status other than 200 OK", { response.status, response.status_msg } + end + + local headers, err = response:get_headers() + if err ~= nil then + return nil, err + end + local content_type = string.lower((headers and headers:get_one('content-type') or "none")) + if not content_type:find("text/event-stream", 1, true) then + local err_msg = "Expected content type of text/event-stream in response headers, received: " .. content_type + return nil, err_msg + end + + source.ready_state = EventSource.ReadyStates.OPEN + + if type(source.onopen) == "function" then + source.onopen() + end + + for _, listener in ipairs(source._listeners[EventSource.EventTypes.ON_OPEN]) do + if type(listener) == "function" then + listener() + end + end +end +--- Helper function that captures the cyclic logic of the EventSource while in the OPEN state. +--- @function open_action +--- @local +--- @param source EventSource +local function open_action(source) + local recv, err, partial = source._sock:receive('*l') + + if err then + --- connection is fine but there was nothing + --- to be read from the other end so we just + --- early return. + if err == "timeout" or err == "wantread" then + return + else + --- real error, close the connection. + source._sock:close() + source._sock = nil + source.ready_state = EventSource.ReadyStates.CLOSED + return nil, err, partial + end + end + + -- the number of bytes to read per the chunked encoding spec + local recv_as_num = tonumber(recv, 16) + + if recv_as_num ~= nil then + recv, err, partial = source._sock:receive(recv_as_num) + if err then + if err == "timeout" or err == "wantread" then + return + else + --- real error, close the connection. + source._sock:close() + source._sock = nil + source.ready_state = EventSource.ReadyStates.CLOSED + return nil, err, partial + end + end + local _, err, partial = source._sock:receive('*l') -- clear the final line + + if err then + if err == "timeout" or err == "wantread" then + return + else + --- real error, close the connection. + source._sock:close() + source._sock = nil + source.ready_state = EventSource.ReadyStates.CLOSED + return nil, err, partial + end + end + sse_parse_chunk(source, recv) + else + local recv_dbg = recv or "" + if #recv_dbg == 0 then recv_dbg = "" end + recv_dbg = recv_dbg:gsub("\r\n", ""):gsub("\n", ""):gsub("\r", "") + log.error_with({ hub_logs = true }, + string.format("Received %s while expecting a chunked encoding payload length (hex number)\n", recv_dbg)) + end +end + +--- Helper function that captures the cyclic logic of the EventSource while in the CLOSED state. +--- @function closed_action +--- @local +--- @param source EventSource +local function closed_action(source) + if source._sock ~= nil then + source._sock:close() + source._sock = nil + end + + if source._reconnect then + if type(source.onerror) == "function" then + source.onerror() + end + + for _, listener in ipairs(source._listeners[EventSource.EventTypes.ON_ERROR]) do + if type(listener) == "function" then + listener() + end + end + + local sleep_time_secs = source._reconnect_time_millis / 1000.0 + socket.sleep(sleep_time_secs) + + source.ready_state = EventSource.ReadyStates.CONNECTING + end +end + +local state_actions = { + [EventSource.ReadyStates.CONNECTING] = connecting_action, + [EventSource.ReadyStates.OPEN] = open_action, + [EventSource.ReadyStates.CLOSED] = closed_action +} + +--- Create a new EventSource. The only required parameter is the URL, which can +--- be a string or a net.url table. The string form will be converted to a net.url table. +--- +--- @param url string|table a string or a net.url table representing the complete URL (minimally a scheme/host/path, port optional) for the event stream. +--- @param extra_headers table|nil an optional table of key-value pairs (strings) to be added to the initial GET request +--- @param sock_builder function|nil an optional function to be used to create the TCP socket for the stream. If nil, a set of defaults will be used to create a new TCP socket. +--- @return EventSource a new EventSource +function EventSource.new(url, extra_headers, sock_builder) + local url_table = util.force_url_table(url) + + local use_ssl = false + if url_table.scheme == "https" then + use_ssl = true + end + + if not url_table.port then + if url_table.scheme == "http" then + url_table.port = 80 + elseif url_table.scheme == "https" then + url_table.port = 443 + end + end + + local sock = nil + + if type(sock_builder) == "function" then + -- sock = sock_builder() + sock = sock_builder(url_table.host, url_table.port, use_ssl) + end + + local source = setmetatable({ + url = url_table, + ready_state = EventSource.ReadyStates.CONNECTING, + onopen = nil, + onmessage = nil, + onerror = nil, + _needs_more = false, + _last_field = nil, + _reconnect = true, + _reconnect_time_millis = 1000, + _sock_builder = sock_builder, + _sock = sock, + _extra_headers = extra_headers, + _parse_buffers = { + ["data"] = "", + ["id"] = "", + ["event"] = "", + }, + _listeners = { + [EventSource.EventTypes.ON_OPEN] = {}, + [EventSource.EventTypes.ON_MESSAGE] = {}, + [EventSource.EventTypes.ON_ERROR] = {} + }, + }, EventSource) + + cosock.spawn(function() + local st_utils = require "st.utils" + while true do + if source.ready_state == EventSource.ReadyStates.CLOSED and + not source._reconnect + then + return + end + local _, action_err, partial = state_actions[source.ready_state](source) + if action_err ~= nil then + if action_err ~= "timeout" or action_err ~= "wantread" then + log.error_with({ hub_logs = true }, "Event Source Coroutine State Machine error: " .. action_err) + if partial ~= nil and #partial > 0 then + log.error_with({ hub_logs = true }, st_utils.stringify_table(partial, "\tReceived Partial", true)) + end + source.ready_state = EventSource.ReadyStates.CLOSED + end + end + end + end) + + return source +end + +--- Close the event source, signalling that a reconnect is not desired +function EventSource:close() + self._reconnect = false + if self._sock ~= nil then + self._sock:close() + end + self._sock = nil + self.ready_state = EventSource.ReadyStates.CLOSED +end + +--- Add a callback to the event source +---@param listener_type string One of "message", "open", or "error" +---@param listener function the callback to be called in case of an event. Open and Error events have no payload. The message event will have a single argument, a table. +function EventSource:add_event_listener(listener_type, listener) + local list = self._listeners[listener_type] + + if list then + table.insert(list, listener) + end +end + +return EventSource diff --git a/drivers/Aqara/aqara-presence-sensor/src/lunchbox/util.lua b/drivers/Aqara/aqara-presence-sensor/src/lunchbox/util.lua new file mode 100644 index 0000000000..3bb5231a96 --- /dev/null +++ b/drivers/Aqara/aqara-presence-sensor/src/lunchbox/util.lua @@ -0,0 +1,45 @@ +local net_url = require "net.url" +local util = {} + +util.force_url_table = function(url) + if type(url) ~= "table" then url = net_url.parse(url) end + + if not url.port then + if url.scheme == "http" then + url.port = 80 + elseif url.scheme == "https" then + url.port = 443 + end + end + + return url +end + +util.read_only = function(tbl) + if type(tbl) == "table" then + local proxy = {} + local mt = { -- create metatable + __index = tbl, + __newindex = function(t, k, v) error("attempt to update a read-only table", 2) end, + } + setmetatable(proxy, mt) + return proxy + else + return tbl + end +end + +util.iter_string_lines = function(str) + if str:sub(-1) ~= "\n" then str = str .. "\n" end + + return str:gmatch("(.-)\n") +end + +util.copy_data = function(tbl) + local ret = {} + for k, v in pairs(tbl) do ret[k] = v end + + return ret +end + +return util diff --git a/drivers/Aqara/aqara-presence-sensor/src/multipleZonePresence.lua b/drivers/Aqara/aqara-presence-sensor/src/multipleZonePresence.lua new file mode 100644 index 0000000000..7436ce4c59 --- /dev/null +++ b/drivers/Aqara/aqara-presence-sensor/src/multipleZonePresence.lua @@ -0,0 +1,154 @@ +local capabilities = require "st.capabilities" +local mzp = {} + +mzp.capability = capabilities["multipleZonePresence"] +mzp.id = "multipleZonePresence" +mzp.commands = {} + +mzp.present = "present" +mzp.notPresent = "not present" + +local ZONE_INFO_KEY = "zoneInfo" + +function mzp.findZoneById(driver, device, id) + local zoneInfoTable = device:get_field(ZONE_INFO_KEY) or {} + for index, zoneInfo in pairs(zoneInfoTable) do + if zoneInfo.id == id then + return zoneInfo, index + end + end + return nil, nil +end + +function mzp.findNewZoneId(driver, device) + local maxId = 0 + local zoneInfoTable = device:get_field(ZONE_INFO_KEY) or {} + for _, zoneInfo in pairs(zoneInfoTable) do + local intId = tonumber(zoneInfo.id) + if intId and intId > maxId then + maxId = intId + end + end + return tostring(maxId + 1) +end + +function mzp.createZone(driver, device, name, id) + local err, createdId = nil, nil + local zoneInfo = {} + local zoneInfoTable = device:get_field(ZONE_INFO_KEY) or {} + if id == nil then + id = mzp.findNewZoneId(driver, device) + end + if mzp.findZoneById(driver, device, id) then + err = string.format("id %s already exists", id) + return err, createdId + end + zoneInfo.id = id + zoneInfo.name = name + zoneInfo.state = mzp.notPresent + zoneInfoTable["zone"..id] = zoneInfo + createdId = id + + device:set_field(ZONE_INFO_KEY, zoneInfoTable, { persist = true }) + + return err, createdId +end + +function mzp.deleteZone(driver, device, id) + local err, deletedId = nil, nil + local _, index = mzp.findZoneById(driver, device, id) + if index then + local zoneInfoTable = device:get_field(ZONE_INFO_KEY) or {} + zoneInfoTable[index] = nil + deletedId = id + device:set_field(ZONE_INFO_KEY, zoneInfoTable, { persist = true }) + else + err = string.format("id %s doesn't exist", id) + end + return err, deletedId +end + +function mzp.renameZone(driver, device, id, name) + local err, changedId = nil, nil + local _, index = mzp.findZoneById(driver, device, id) + if index then + local zoneInfoTable = device:get_field(ZONE_INFO_KEY) or {} + zoneInfoTable[index].name = name + changedId = id + device:set_field(ZONE_INFO_KEY, zoneInfoTable, { persist = true }) + else + err = string.format("id %s doesn't exist", id) + end + return err, changedId +end + +function mzp.changeState(driver, device, id, state) + local err, changedId = nil, nil + local zoneInfo, index = mzp.findZoneById(driver, device, id) + if zoneInfo then + local zoneInfoTable = device:get_field(ZONE_INFO_KEY) or {} + zoneInfoTable[index].state = state + changedId = id + device:set_field(ZONE_INFO_KEY, zoneInfoTable, { persist = true }) + else + err = string.format("id %s doesn't exist", id) + end + return err, changedId +end + +function mzp.setZoneInfo(driver, device, inputZoneInfoTable) + --prevents overwriting with a default name ("zone%d"). + local zoneInfoTable = device:get_field(ZONE_INFO_KEY) or {} + for __, inputZoneInfo in pairs(inputZoneInfoTable) do + local zoneInfo, index = mzp.findZoneById(driver, device, inputZoneInfo.id) + if zoneInfo then + if inputZoneInfo.name ~= "zone" .. inputZoneInfo.id then + zoneInfoTable[index].name = inputZoneInfo.name + end + else + local newZoneInfo = {} + newZoneInfo.id = inputZoneInfo.id + newZoneInfo.name = inputZoneInfo.name + newZoneInfo.state = mzp.notPresent + table.insert(zoneInfoTable, newZoneInfo) + end + end + device:set_field(ZONE_INFO_KEY, zoneInfoTable, { persist = true }) +end + +mzp.commands.updateZoneName = {} +mzp.commands.updateZoneName.name = "updateZoneName" +function mzp.commands.updateZoneName.handler(driver, device, args) + local name = args.args.name + local id = args.args.id + mzp.renameZone(driver, device, id, name) + mzp.updateAttribute(driver, device) +end + +mzp.commands.deleteZone = {} +mzp.commands.deleteZone.name = "deleteZone" +function mzp.commands.deleteZone.handler(driver, device, args) + local id = args.args.id + mzp.deleteZone(driver, device, id) + mzp.updateAttribute(driver, device) +end + +mzp.commands.createZone = {} +mzp.commands.createZone.name = "createZone" +function mzp.commands.createZone.handler(driver, device, args) + local name = args.args.name + local id = args.args.id + mzp.createZone(driver, device, name, id) + mzp.updateAttribute(driver, device) +end + +function mzp.updateAttribute(driver, device) + local zoneInfoTable = device:get_field(ZONE_INFO_KEY) or {} + local zoneStatePayload = {} + for _, zoneInfo in pairs(zoneInfoTable) do + table.insert(zoneStatePayload, zoneInfo) + end + device:emit_event(mzp.capability.zoneState({ value = zoneStatePayload })) +end + +return mzp diff --git a/drivers/Aqara/aqara-presence-sensor/src/selfSignedRootByAqaraLife.crt b/drivers/Aqara/aqara-presence-sensor/src/selfSignedRootByAqaraLife.crt new file mode 100644 index 0000000000..2dd6dd333c --- /dev/null +++ b/drivers/Aqara/aqara-presence-sensor/src/selfSignedRootByAqaraLife.crt @@ -0,0 +1,20 @@ +-----BEGIN CERTIFICATE----- +MIIDRjCCAi6gAwIBAgIBATANBgkqhkiG9w0BAQsFADA6MQswCQYDVQQGEwJLUjET +MBEGA1UECgwKQXFhcmEgTGlmZTEWMBQGA1UEAwwNQXFhcmEgTGlmZSBDQTAgFw0y +NDA1MjAwNjEyMjFaGA8yMTI0MDQyNjA2MTIyMVowOjELMAkGA1UEBhMCS1IxEzAR +BgNVBAoMCkFxYXJhIExpZmUxFjAUBgNVBAMMDUFxYXJhIExpZmUgQ0EwggEiMA0G +CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC6rq5QcW/3saXs+gaRaqmR87KcABnW +7gNhjAodBJzv8TNCxUyXAP+xwSnCKsHoW4I3koA0Cze0FJ/f9u7exuyLr8cF0Sjv +k5EeRD4AQ81OeacS79D3bxgUs7WrU//WJo8HIyrs0qdVNv9p2kDzFuY1Jw0+CEir +SDAA8wpyxKylQhPQalEQIn44WlB9JAIifIK+oRq389JZjWutFXL+TvR/uSOjjLqO +A0bsRV9SFjmXg3YViPI6z09qwQb+CI2qV+6B+5EdB8YSF8SxA6z5p7FTCi+IcEr1 +fGvfopOZAg9KWk+a8p5GyLCh0cl50XlsBxrJyfP5xX1yjkFHvuIL3lNDAgMBAAGj +VTBTMBIGA1UdEwEB/wQIMAYBAf8CAQAwHQYDVR0OBBYEFNFpFOdYz0nxQDLDbl1X +fk530i/jMAsGA1UdDwQEAwIBBjARBglghkgBhvhCAQEEBAMCAAcwDQYJKoZIhvcN +AQELBQADggEBAKn6/oIfP0EOA8aMo65OZEta8MGwsUm7dNgtG8WkHjeU/y6XZm52 +I/LuYDtXIxaIXj2jspz5GLUxE12PMIkJZVwaDb4g0kACm3X5JKO09W7G9zTn/Jwp +UxpbaHVLfoNUNdQr7Ye3hBtcIcdgH569QGuQVIkDSI3bxEBDgL4v5vZLeqPQV4Il +PnsyuLMxyARpYmj0vYqbSm8brbpTlyeQbOOtgb+sYygDIoSSjmB4Ihk1sm9l+9oP +xPc5vrFx9C/88sweX9/wUzMdTAiGYaXd/w8YDoVEAtQui7oTDKpINnU9gZ481ERr +u9qH5VTGcSDqTc9QdYCf2wqYwKmbkWpUR8g= +-----END CERTIFICATE----- diff --git a/drivers/Aqara/aqara-presence-sensor/src/utils.lua b/drivers/Aqara/aqara-presence-sensor/src/utils.lua new file mode 100644 index 0000000000..4e4f8c2c14 --- /dev/null +++ b/drivers/Aqara/aqara-presence-sensor/src/utils.lua @@ -0,0 +1,175 @@ +local log = require "log" +---@module 'utils' +local utils = {} + + +function utils.str_starts_with(str, start) + return str:sub(1, #start) == start +end + +function utils.is_nan(number) + -- IEEE 754 dictates that NaN compares falsey to everything, including itself. + if number ~= number then + return true + end + + -- If someone passes in something that isn't a Number type, it'll pass the above check. + -- Philosophical question: Something that isn't a number can't technicaly have the value + -- of "nan" but "nan" stands for "not a number", so what do we do here? + if type(number) ~= "number" then + log.warn(string.format("utils.is_nan received value of type %s as argument, returning true", type(number))) + return true + end + + -- In the event that something goes wrong with the above two things, + -- we simply compare the tostring against a known NaN value. + return tostring(number) == tostring(0 / 0) +end + +-- build a exponential backoff time value generator +-- +-- max: the maximum wait interval (not including `rand factor`) +-- inc: the rate at which to exponentially back off +-- rand: a randomization range of (-rand, rand) to be added to each interval +function utils.backoff_builder(max, inc, rand) + local count = 0 + inc = inc or 1 + return function() + local randval = 0 + if rand then + randval = math.random() * rand * 2 - rand + end + + local base = inc * (2 ^ count - 1) + count = count + 1 + + -- ensure base backoff (not including random factor) is less than max + if max then base = math.min(base, max) end + + -- ensure total backoff is >= 0 + return math.max(base + randval, 0) + end +end + +function utils.labeled_socket_builder(label, ssl_config) + local log = require "log" + local socket = require "cosock.socket" + local ssl = require "cosock.ssl" + + label = (label or "") + if #label > 0 then + label = label .. " " + end + + if not ssl_config then + ssl_config = { mode = "client", protocol = "any", verify = "none", options = "all" } + end + + local function make_socket(host, port, wrap_ssl) + log.info( + string.format( + "%sCreating TCP socket for REST Connection", label + ) + ) + local _ = nil + local sock, err = socket.tcp() + + if err ~= nil or (not sock) then + return nil, (err or "unknown error creating TCP socket") + end + log.info( + string.format( + "%sSetting TCP socket timeout for REST Connection", label + ) + ) + _, err = sock:settimeout(60) + if err ~= nil then + return nil, "settimeout error: " .. err + end + log.info( + string.format( + "%sConnecting TCP socket for REST Connection", label + ) + ) + _, err = sock:connect(host, port) + if err ~= nil then + return nil, "Connect error: " .. err + end + log.info( + string.format( + "%sSet Keepalive for TCP socket for REST Connection", label + ) + ) + _, err = sock:setoption("keepalive", true) + if err ~= nil then + return nil, "Setoption error: " .. err + end + + if wrap_ssl then + log.info( + string.format( + "%sCreating SSL wrapper for for REST Connection", label + ) + ) + sock, err = + ssl.wrap(sock, ssl_config) + if err ~= nil then + return nil, "SSL wrap error: " .. err + end + log.info( + string.format( + "%sPerforming SSL handshake for for REST Connection", label + ) + ) + _, err = sock:dohandshake() + if err ~= nil then + return nil, "Error with SSL handshake: " .. err + end + end + log.info( + string.format( + "%sSuccessfully created TCP connection", label + ) + ) + return sock, err + end + return make_socket +end + +--- From https://gist.github.com/sapphyrus/fd9aeb871e3ce966cc4b0b969f62f539 +--- MIT licensed +function utils.deep_table_eq(tbl1, tbl2) + if tbl1 == tbl2 then + return true + elseif type(tbl1) == "table" and type(tbl2) == "table" then + for key1, value1 in pairs(tbl1) do + local value2 = tbl2[key1] + + if value2 == nil then + -- avoid the type call for missing keys in tbl2 by directly comparing with nil + return false + elseif value1 ~= value2 then + if type(value1) == "table" and type(value2) == "table" then + if not utils.deep_table_eq(value1, value2) then + return false + end + else + return false + end + end + end + + -- check for missing keys in tbl1 + for key2, _ in pairs(tbl2) do + if tbl1[key2] == nil then + return false + end + end + + return true + end + + return false +end + +return utils