Skip to content

Commit

Permalink
Traffic Route support with QoL improvements (#7)
Browse files Browse the repository at this point in the history
- feat(config): update interval is now configurable
- feat(routes): added new api functions for traffic route support
- feat(routes): new device and switch entities
- feat(tests): added test coverage for new traffic route functions
  • Loading branch information
sirkirby authored Nov 5, 2024
1 parent f399426 commit 9cf9d31
Show file tree
Hide file tree
Showing 8 changed files with 449 additions and 54 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Unifi Network Rules Custom Integration

Pulls firewall and traffic rules from your Unifi Dream Machine and allows you to enable/disable them in Home Assistant.
Pulls firewall, traffic rules, and traffic routes from your Unifi Dream Machine and allows you to enable/disable them in Home Assistant.

## Installation

Expand Down Expand Up @@ -28,7 +28,7 @@ THEN

## Usage

Once you have configured the integration, you will be able to see the firewall rules configured on your Unifi Network as switches in Home Assistant. Add the switch to a custom dashboard or use it in automations just like any other Home Assistant switch.
Once you have configured the integration, you will be able to see the firewall rules and traffic routes configured on your Unifi Network as switches in Home Assistant. Add the switch to a custom dashboard or use it in automations just like any other Home Assistant switch.

## Local Development

Expand All @@ -48,7 +48,7 @@ pytest tests

## Limitations

The integration is currently limited to managing firewall and traffic rules. It does not currently support managing other types of rules.
The integration is currently limited to managing firewall, traffic rules, and traffic routes. It does not currently support managing other types of rules.

## Contributions

Expand Down
17 changes: 12 additions & 5 deletions custom_components/unifi_network_rules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@
from homeassistant.helpers import config_validation as cv
from homeassistant.exceptions import ConfigEntryNotReady

from .const import DOMAIN, CONF_MAX_RETRIES, CONF_RETRY_DELAY, DEFAULT_MAX_RETRIES, DEFAULT_RETRY_DELAY
from .const import DOMAIN, CONF_MAX_RETRIES, CONF_RETRY_DELAY, DEFAULT_MAX_RETRIES, DEFAULT_RETRY_DELAY, CONF_UPDATE_INTERVAL, DEFAULT_UPDATE_INTERVAL
from .udm_api import UDMAPI

_LOGGER = logging.getLogger(__name__)

PLATFORMS: list[str] = ["switch"]
UPDATE_INTERVAL = timedelta(minutes=5)

CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)

Expand All @@ -29,6 +28,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
host = entry.data[CONF_HOST]
username = entry.data[CONF_USERNAME]
password = entry.data[CONF_PASSWORD]
update_interval = entry.data.get(CONF_UPDATE_INTERVAL, DEFAULT_UPDATE_INTERVAL)
max_retries = entry.data.get(CONF_MAX_RETRIES, DEFAULT_MAX_RETRIES)
retry_delay = entry.data.get(CONF_RETRY_DELAY, DEFAULT_RETRY_DELAY)

Expand All @@ -44,13 +44,20 @@ async def async_update_data():
try:
traffic_success, traffic_rules, traffic_error = await api.get_traffic_rules()
firewall_success, firewall_rules, firewall_error = await api.get_firewall_rules()
routes_success, traffic_routes, routes_error = await api.get_traffic_routes()

if not traffic_success:
raise Exception(f"Failed to fetch traffic rules: {traffic_error}")
if not firewall_success:
raise Exception(f"Failed to fetch firewall rules: {firewall_error}")

return {"traffic_rules": traffic_rules, "firewall_rules": firewall_rules}
if not routes_success:
raise Exception(f"Failed to fetch traffic routes: {routes_error}")

return {
"traffic_rules": traffic_rules,
"firewall_rules": firewall_rules,
"traffic_routes": traffic_routes
}
except Exception as e:
_LOGGER.error(f"Error updating data: {str(e)}")
raise
Expand All @@ -60,7 +67,7 @@ async def async_update_data():
_LOGGER,
name="udm_rule_manager",
update_method=async_update_data,
update_interval=UPDATE_INTERVAL,
update_interval=timedelta(minutes=update_interval),
)

# Fetch initial data
Expand Down
139 changes: 98 additions & 41 deletions custom_components/unifi_network_rules/config_flow.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,112 @@
import voluptuous as vol
from homeassistant import config_entries
from homeassistant.core import HomeAssistant
from homeassistant.const import CONF_HOST, CONF_USERNAME, CONF_PASSWORD
from homeassistant.helpers import config_validation as cv
from homeassistant import config_entries, core, exceptions
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME
import homeassistant.helpers.config_validation as cv
from .const import DOMAIN, CONF_UPDATE_INTERVAL, DEFAULT_UPDATE_INTERVAL
import logging
from homeassistant.helpers.entity import EntityDescription
from ipaddress import ip_address
import re

from .const import DOMAIN, CONF_MAX_RETRIES, CONF_RETRY_DELAY, DEFAULT_MAX_RETRIES, DEFAULT_RETRY_DELAY
from .udm_api import UDMAPI
_LOGGER = logging.getLogger(__name__)

class UDMRuleManagerConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Unifi Network Rules."""
# Define entity descriptions for entities used in this integration
ENTITY_DESCRIPTIONS = {
"update_interval": EntityDescription(
key="update_interval",
name="Update Interval",
icon="mdi:update",
entity_category="config",
)
}

# Define a schema for configuration, adding basic validation
DATA_SCHEMA = vol.Schema({
vol.Required(CONF_HOST): cv.string,
vol.Required(CONF_USERNAME): cv.string,
vol.Required(CONF_PASSWORD): cv.string,
vol.Optional(CONF_UPDATE_INTERVAL, default=DEFAULT_UPDATE_INTERVAL): vol.All(vol.Coerce(int), vol.Range(min=1, max=1440)),
})

async def validate_input(hass: core.HomeAssistant, data: dict):
"""
Validate the user input allows us to connect.
Data has the keys from DATA_SCHEMA with values provided by the user.
"""
host = data[CONF_HOST]
username = data[CONF_USERNAME]
password = data[CONF_PASSWORD]

# Validate host (IP address or domain name)
try:
ip_address(host)
except ValueError:
# If it's not a valid IP address, check if it's a valid domain name
if not re.match(r'^[a-zA-Z0-9-]+(\.[a-zA-Z0-9-]+)*\.[a-zA-Z]{2,}$', host):
raise InvalidHost

return {"title": f"Unifi Network Manager ({host})"}

class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
"""
Handle a config flow for Unifi Network Rule Manager.
"""
VERSION = 1
CONNECTION_CLASS = config_entries.CONN_CLASS_LOCAL_POLL

async def async_step_user(self, user_input=None):
"""Handle the initial step."""
"""
Handle the initial step of the config flow.
"""
errors = {}

if user_input is not None:
api = UDMAPI(
user_input[CONF_HOST],
user_input[CONF_USERNAME],
user_input[CONF_PASSWORD],
max_retries=user_input.get(CONF_MAX_RETRIES, DEFAULT_MAX_RETRIES),
retry_delay=user_input.get(CONF_RETRY_DELAY, DEFAULT_RETRY_DELAY)
)
success, error_message = await api.login()
if success:
return self.async_create_entry(title="Unifi Network Rules", data=user_input)
else:
try:
if CONF_UPDATE_INTERVAL in user_input:
update_interval = user_input[CONF_UPDATE_INTERVAL]
if not isinstance(update_interval, int) or update_interval < 1 or update_interval > 1440:
raise InvalidUpdateInterval
info = await validate_input(self.hass, user_input)
return self.async_create_entry(title=info["title"], data=user_input)
except InvalidAuth:
errors["base"] = "invalid_auth"
except CannotConnect:
errors["base"] = "cannot_connect"
if error_message:
errors["base_info"] = error_message
except InvalidUpdateInterval:
errors["base"] = "invalid_update_interval"
except InvalidHost:
errors["base"] = "invalid_host"
except vol.Invalid as vol_error:
_LOGGER.error("Validation error: %s", vol_error)
errors["base"] = "invalid_format"
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected exception")
errors["base"] = "unknown"

return self.async_show_form(
step_id="user",
data_schema=vol.Schema(
{
vol.Required(CONF_HOST): str,
vol.Required(CONF_USERNAME): str,
vol.Required(CONF_PASSWORD): str,
vol.Optional(CONF_MAX_RETRIES, default=DEFAULT_MAX_RETRIES): vol.All(
vol.Coerce(int), vol.Range(min=1, max=10)
),
vol.Optional(CONF_RETRY_DELAY, default=DEFAULT_RETRY_DELAY): vol.All(
vol.Coerce(int), vol.Range(min=1, max=60)
),
}
),
errors=errors,
step_id="user", data_schema=DATA_SCHEMA, errors=errors
)

async def async_step_import(self, import_config):
"""Handle import from configuration.yaml."""
return await self.async_step_user(import_config)
class CannotConnect(exceptions.HomeAssistantError):
"""
Error to indicate we cannot connect.
"""
pass

class InvalidAuth(exceptions.HomeAssistantError):
"""
Error to indicate there is invalid auth.
"""
pass

class InvalidHost(exceptions.HomeAssistantError):
"""
Error to indicate there is invalid host address.
"""
pass

class InvalidUpdateInterval(exceptions.HomeAssistantError):
"""
Error to indicate the update interval is invalid.
"""
pass
5 changes: 4 additions & 1 deletion custom_components/unifi_network_rules/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,7 @@
CONF_RETRY_DELAY = "retry_delay"

DEFAULT_MAX_RETRIES = 3
DEFAULT_RETRY_DELAY = 1
DEFAULT_RETRY_DELAY = 1

CONF_UPDATE_INTERVAL = "update_interval"
DEFAULT_UPDATE_INTERVAL = 5
2 changes: 1 addition & 1 deletion custom_components/unifi_network_rules/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@
"iot_class": "local_polling",
"issue_tracker": "https://github.com/sirkirby/ha_udm_rule_manager/issues",
"requirements": ["aiohttp"],
"version": "0.2.0"
"version": "0.3.0"
}
106 changes: 105 additions & 1 deletion custom_components/unifi_network_rules/switch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry, async_add_e
for rule in coordinator.data['firewall_rules']:
entities.append(UDMFirewallRuleSwitch(coordinator, api, rule))

if coordinator.data.get('traffic_routes'):
for route in coordinator.data['traffic_routes']:
entities.append(UDMTrafficRouteSwitch(coordinator, api, route))

async_add_entities(entities, True)

class UDMRuleSwitch(CoordinatorEntity, SwitchEntity):
Expand Down Expand Up @@ -103,4 +107,104 @@ class UDMFirewallRuleSwitch(UDMRuleSwitch):

def __init__(self, coordinator, api, rule):
"""Initialize the UDM Firewall Rule Switch."""
super().__init__(coordinator, api, rule, 'firewall')
super().__init__(coordinator, api, rule, 'firewall')

class UDMTrafficRouteSwitch(CoordinatorEntity, SwitchEntity):
"""Representation of a UDM Traffic Route Switch."""

def __init__(self, coordinator, api, route):
"""Initialize the UDM Traffic Route Switch."""
super().__init__(coordinator)
self._api = api
self._attr_unique_id = f"traffic_route_{route['_id']}"
self._attr_name = f"Traffic Route: {route.get('description', 'Unnamed')}"

# Store route details for device info
self._route = route

@property
def is_on(self):
"""Return true if the switch is on."""
route = self._get_route()
return route['enabled'] if route else False

@property
def device_info(self):
"""Return device info for this traffic route."""
return {
"identifiers": {(DOMAIN, self._attr_unique_id)},
"name": self._attr_name,
"manufacturer": "Ubiquiti",
"model": "Traffic Route",
"sw_version": None,
}

async def async_turn_on(self, **kwargs):
"""Turn the switch on."""
await self._toggle(True)

async def async_turn_off(self, **kwargs):
"""Turn the switch off."""
await self._toggle(False)

async def _toggle(self, new_state):
"""Toggle the route state."""
route = self._get_route()
if not route:
raise HomeAssistantError("Traffic route not found")

_LOGGER.debug(f"Attempting to set traffic route {route['_id']} to {'on' if new_state else 'off'}")

try:
success, error_message = await self._api.toggle_traffic_route(route['_id'], new_state)
if success:
_LOGGER.info(f"Successfully set traffic route {route['_id']} to {'on' if new_state else 'off'}")
await self.coordinator.async_request_refresh()
else:
_LOGGER.error(f"Failed to set traffic route {route['_id']} to {'on' if new_state else 'off'}. Error: {error_message}")
raise HomeAssistantError(f"Failed to toggle traffic route: {error_message}")

except Exception as e:
_LOGGER.error(f"Error toggling traffic route {route['_id']}: {str(e)}")
raise HomeAssistantError(f"Error toggling traffic route: {str(e)}")

def _get_route(self):
"""Get the current route from the coordinator data."""
routes = self.coordinator.data.get('traffic_routes', [])
route_id = self._attr_unique_id.split('_')[-1]
for route in routes:
if route['_id'] == route_id:
return route
return None

@property
def extra_state_attributes(self):
"""Return additional state attributes."""
route = self._get_route()
if not route:
return {}

attributes = {
"description": route.get("description", ""),
"matching_target": route.get("matching_target", ""),
"network_id": route.get("network_id", ""),
"kill_switch_enabled": route.get("kill_switch_enabled", False),
}

# Add domain information if available
if route.get("domains"):
attributes["domains"] = [d.get("domain") for d in route["domains"]]

# Add target devices information
if route.get("target_devices"):
devices = []
for device in route["target_devices"]:
if device.get("type") == "ALL_CLIENTS":
devices.append("ALL_CLIENTS")
elif device.get("type") == "NETWORK":
devices.append(f"NETWORK: {device.get('network_id')}")
else:
devices.append(device.get("client_mac", ""))
attributes["target_devices"] = devices

return attributes
Loading

0 comments on commit 9cf9d31

Please sign in to comment.