diff --git a/.github/workflows/testing.yaml b/.github/workflows/testing.yaml index a223817..181ce3f 100644 --- a/.github/workflows/testing.yaml +++ b/.github/workflows/testing.yaml @@ -34,6 +34,10 @@ jobs: pip install --upgrade pip pip install tox + - name: Install Home Assistant testing platform + run: | + ./scripts/download_fixtures.sh $(curl --silent "https://api.github.com/repos/home-assistant/core/releases/latest" | grep -Po "(?<=\"tag_name\": \").*(?=\")") + - name: Test with tox environments run: tox diff --git a/.gitignore b/.gitignore index ec31ee6..880ad43 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ -# Home Assistant configuration +# Home Assistant config/ +tests/hass/ # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/README.md b/README.md index dbd4c90..e368dba 100644 --- a/README.md +++ b/README.md @@ -134,13 +134,21 @@ To create a virtual environment and install the project and its dependencies, ex terminal: ```bash -# Create and activate a new virtual environment -python3 -m venv venv +# Initialize the environment with the latest version of Home Assistant +E_HASS_VERSION=$(curl --silent "https://api.github.com/repos/home-assistant/core/releases/latest" | grep -Po "(?<=\"tag_name\": \").*(?=\")") +./scripts/init $E_HASS_VERSION source venv/bin/activate -# Upgrade pip and install all projects and their dependencies -pip install --upgrade pip -pip install -e '.[all]' +# Install pre-commit hooks +pre-commit install +``` + +Instead, if you want to develop and test this integration with a different Home Assistant version, just pass the +version to the init script: +```bash +# Initialize the environment Home Assistant 2024.1.1 +./scripts/init 2024.1.1 +source venv/bin/activate # Install pre-commit hooks pre-commit install diff --git a/scripts/download_fixtures.sh b/scripts/download_fixtures.sh new file mode 100755 index 0000000..b545a90 --- /dev/null +++ b/scripts/download_fixtures.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# +# This script downloads a specified version of Home Assistant, +# extracts its test suite, processes the files, and moves them into a +# local 'tests/hass' directory for further use. +# +# Usage: ./scripts/download_fixtures.sh +# Example: ./scripts/download_fixtures.sh 2021.3.4 +set -e + +# Parameters +VERSION=$1 + +# Abort if no version is specified +if [ -z "$VERSION" ]; then + echo "Usage: ./scripts/download_fixtures.sh " + exit 1 +fi + +# Variables +DOWNLOAD_FOLDER=$(mktemp -d) +HASS_TESTS_FOLDER=$DOWNLOAD_FOLDER/core-$VERSION/tests/ + +# Remove previous folder if exists +if [ -d "tests/hass" ]; then + echo "Removing previous tests/hass/ folder" + rm -rf tests/hass +fi + +# Download HASS version +echo "Downloading Home Assistant $VERSION in $DOWNLOAD_FOLDER" +curl -L https://github.com/home-assistant/core/archive/refs/tags/$VERSION.tar.gz -o $DOWNLOAD_FOLDER/$VERSION.tar.gz + +# Extract HASS fixtures and tests helpers, excluding all components and actual tests +echo "Extracting tests/ folder from $VERSION.tar.gz" +tar -C $DOWNLOAD_FOLDER --exclude='*/components/*' --exclude='*/pylint/*' -xf $DOWNLOAD_FOLDER/$VERSION.tar.gz core-$VERSION/tests +find $HASS_TESTS_FOLDER -type f -name "test_*.py" -delete + +# Recursively find and update imports +find $HASS_TESTS_FOLDER -type f -exec sed -i 's/from tests\./from tests.hass./g' {} + +mv $HASS_TESTS_FOLDER/conftest.py $HASS_TESTS_FOLDER/fixtures.py + +# Copy Home Assistant fixtures +mv $HASS_TESTS_FOLDER ./tests/hass +echo "Home Assistant $VERSION tests are now in tests/hass/" diff --git a/scripts/init b/scripts/init new file mode 100755 index 0000000..c1ab4f9 --- /dev/null +++ b/scripts/init @@ -0,0 +1,29 @@ +#!/bin/bash +set -e + +# Parameters +VERSION=$1 + +# Abort if no version is specified +if [ -z "$VERSION" ]; then + echo "Usage: ./scripts/init.sh " + exit 1 +fi + +# Abort if `venv` folder already exists +if [ -d "venv" ]; then + echo "venv/ folder already exists. Deactivate your venv and remove venv/ folder." + exit 1 +fi + +# Create and activate a new virtual environment +python3 -m venv venv +source venv/bin/activate + +# Upgrade pip and install all projects and their dependencies +pip install --upgrade pip +pip install -e '.[all]' + +# Override Home Assistant version +pip install homeassistant==$VERSION +./scripts/download_fixtures.sh $VERSION diff --git a/tests/hass/__init__.py b/tests/hass/__init__.py deleted file mode 100644 index 35d25f2..0000000 --- a/tests/hass/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for Home Assistant.""" diff --git a/tests/hass/common.py b/tests/hass/common.py deleted file mode 100644 index 48bb383..0000000 --- a/tests/hass/common.py +++ /dev/null @@ -1,1447 +0,0 @@ -"""Test the helper method for writing tests.""" -from __future__ import annotations - -import asyncio -from collections import OrderedDict -from collections.abc import Generator, Mapping, Sequence -from contextlib import contextmanager -from datetime import UTC, datetime, timedelta -import functools as ft -from functools import lru_cache -from io import StringIO -import json -import logging -import os -import pathlib -import threading -import time -from typing import Any, NoReturn -from unittest.mock import AsyncMock, Mock, patch - -from aiohttp.test_utils import unused_port as get_test_instance_port # noqa: F401 -import voluptuous as vol - -from homeassistant import auth, bootstrap, config_entries, loader -from homeassistant.auth import ( - auth_store, - models as auth_models, - permissions as auth_permissions, - providers as auth_providers, -) -from homeassistant.auth.permissions import system_policies -from homeassistant.components import device_automation, persistent_notification as pn -from homeassistant.components.device_automation import ( # noqa: F401 - _async_get_device_automation_capabilities as async_get_device_automation_capabilities, -) -from homeassistant.config import async_process_component_config -from homeassistant.config_entries import ConfigFlow -from homeassistant.const import ( - DEVICE_DEFAULT_NAME, - EVENT_HOMEASSISTANT_CLOSE, - EVENT_HOMEASSISTANT_STOP, - EVENT_STATE_CHANGED, - STATE_OFF, - STATE_ON, -) -from homeassistant.core import ( - CoreState, - Event, - HomeAssistant, - ServiceCall, - ServiceResponse, - State, - SupportsResponse, - callback, -) -from homeassistant.helpers import ( - area_registry as ar, - device_registry as dr, - entity, - entity_platform, - entity_registry as er, - event, - intent, - issue_registry as ir, - recorder as recorder_helper, - restore_state, - restore_state as rs, - storage, -) -from homeassistant.helpers.dispatcher import async_dispatcher_connect -from homeassistant.helpers.json import JSONEncoder, _orjson_default_encoder -from homeassistant.helpers.typing import ConfigType, StateType -from homeassistant.setup import setup_component -from homeassistant.util.async_ import run_callback_threadsafe -import homeassistant.util.dt as dt_util -from homeassistant.util.json import ( - JsonArrayType, - JsonObjectType, - JsonValueType, - json_loads, - json_loads_array, - json_loads_object, -) -from homeassistant.util.unit_system import METRIC_SYSTEM -import homeassistant.util.uuid as uuid_util -import homeassistant.util.yaml.loader as yaml_loader - -_LOGGER = logging.getLogger(__name__) -INSTANCES = [] -CLIENT_ID = "https://example.com/app" -CLIENT_REDIRECT_URI = "https://example.com/app/callback" - - -async def async_get_device_automations( - hass: HomeAssistant, - automation_type: device_automation.DeviceAutomationType, - device_id: str, -) -> Any: - """Get a device automation for a single device id.""" - automations = await device_automation.async_get_device_automations( - hass, automation_type, [device_id] - ) - return automations.get(device_id) - - -def threadsafe_callback_factory(func): - """Create threadsafe functions out of callbacks. - - Callback needs to have `hass` as first argument. - """ - - @ft.wraps(func) - def threadsafe(*args, **kwargs): - """Call func threadsafe.""" - hass = args[0] - return run_callback_threadsafe( - hass.loop, ft.partial(func, *args, **kwargs) - ).result() - - return threadsafe - - -def threadsafe_coroutine_factory(func): - """Create threadsafe functions out of coroutine. - - Callback needs to have `hass` as first argument. - """ - - @ft.wraps(func) - def threadsafe(*args, **kwargs): - """Call func threadsafe.""" - hass = args[0] - return asyncio.run_coroutine_threadsafe( - func(*args, **kwargs), hass.loop - ).result() - - return threadsafe - - -def get_test_config_dir(*add_path): - """Return a path to a test config dir.""" - return os.path.join(os.path.dirname(__file__), "testing_config", *add_path) - - -def get_test_home_assistant(): - """Return a Home Assistant object pointing at test config directory.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - hass = loop.run_until_complete(async_test_home_assistant(loop)) - - loop_stop_event = threading.Event() - - def run_loop(): - """Run event loop.""" - - loop._thread_ident = threading.get_ident() - loop.run_forever() - loop_stop_event.set() - - orig_stop = hass.stop - hass._stopped = Mock(set=loop.stop) - - def start_hass(*mocks): - """Start hass.""" - asyncio.run_coroutine_threadsafe(hass.async_start(), loop).result() - - def stop_hass(): - """Stop hass.""" - orig_stop() - loop_stop_event.wait() - loop.close() - - hass.start = start_hass - hass.stop = stop_hass - - threading.Thread(name="LoopThread", target=run_loop, daemon=False).start() - - return hass - - -async def async_test_home_assistant(event_loop, load_registries=True): - """Return a Home Assistant object pointing at test config dir.""" - hass = HomeAssistant(get_test_config_dir()) - store = auth_store.AuthStore(hass) - hass.auth = auth.AuthManager(hass, store, {}, {}) - ensure_auth_manager_loaded(hass.auth) - INSTANCES.append(hass) - - orig_async_add_job = hass.async_add_job - orig_async_add_executor_job = hass.async_add_executor_job - orig_async_create_task = hass.async_create_task - - def async_add_job(target, *args): - """Add job.""" - check_target = target - while isinstance(check_target, ft.partial): - check_target = check_target.func - - if isinstance(check_target, Mock) and not isinstance(target, AsyncMock): - fut = asyncio.Future() - fut.set_result(target(*args)) - return fut - - return orig_async_add_job(target, *args) - - def async_add_executor_job(target, *args): - """Add executor job.""" - check_target = target - while isinstance(check_target, ft.partial): - check_target = check_target.func - - if isinstance(check_target, Mock): - fut = asyncio.Future() - fut.set_result(target(*args)) - return fut - - return orig_async_add_executor_job(target, *args) - - def async_create_task(coroutine, name=None): - """Create task.""" - if isinstance(coroutine, Mock) and not isinstance(coroutine, AsyncMock): - fut = asyncio.Future() - fut.set_result(None) - return fut - - return orig_async_create_task(coroutine, name) - - hass.async_add_job = async_add_job - hass.async_add_executor_job = async_add_executor_job - hass.async_create_task = async_create_task - - hass.data[loader.DATA_CUSTOM_COMPONENTS] = {} - - hass.config.location_name = "test home" - hass.config.latitude = 32.87336 - hass.config.longitude = -117.22743 - hass.config.elevation = 0 - hass.config.set_time_zone("US/Pacific") - hass.config.units = METRIC_SYSTEM - hass.config.media_dirs = {"local": get_test_config_dir("media")} - hass.config.skip_pip = True - hass.config.skip_pip_packages = [] - - hass.config_entries = config_entries.ConfigEntries( - hass, - { - "_": ( - "Not empty or else some bad checks for hass config in discovery.py" - " breaks" - ) - }, - ) - hass.bus.async_listen_once( - EVENT_HOMEASSISTANT_STOP, hass.config_entries._async_shutdown - ) - - # Load the registries - entity.async_setup(hass) - loader.async_setup(hass) - if load_registries: - with patch( - "homeassistant.helpers.storage.Store.async_load", return_value=None - ), patch( - "homeassistant.helpers.restore_state.RestoreStateData.async_setup_dump", - return_value=None, - ), patch( - "homeassistant.helpers.restore_state.start.async_at_start" - ): - await asyncio.gather( - ar.async_load(hass), - dr.async_load(hass), - er.async_load(hass), - ir.async_load(hass), - rs.async_load(hass), - ) - hass.data[bootstrap.DATA_REGISTRIES_LOADED] = None - - hass.state = CoreState.running - - @callback - def clear_instance(event): - """Clear global instance.""" - INSTANCES.remove(hass) - - hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, clear_instance) - - return hass - - -def async_mock_service( - hass: HomeAssistant, - domain: str, - service: str, - schema: vol.Schema | None = None, - response: ServiceResponse = None, - supports_response: SupportsResponse | None = None, -) -> list[ServiceCall]: - """Set up a fake service & return a calls log list to this service.""" - calls = [] - - @callback - def mock_service_log(call): # pylint: disable=unnecessary-lambda - """Mock service call.""" - calls.append(call) - return response - - if supports_response is None and response is not None: - supports_response = SupportsResponse.OPTIONAL - - hass.services.async_register( - domain, - service, - mock_service_log, - schema=schema, - supports_response=supports_response, - ) - - return calls - - -mock_service = threadsafe_callback_factory(async_mock_service) - - -@callback -def async_mock_intent(hass, intent_typ): - """Set up a fake intent handler.""" - intents = [] - - class MockIntentHandler(intent.IntentHandler): - intent_type = intent_typ - - async def async_handle(self, intent): - """Handle the intent.""" - intents.append(intent) - return intent.create_response() - - intent.async_register(hass, MockIntentHandler()) - - return intents - - -@callback -def async_fire_mqtt_message( - hass: HomeAssistant, - topic: str, - payload: bytes | str, - qos: int = 0, - retain: bool = False, -) -> None: - """Fire the MQTT message.""" - # Local import to avoid processing MQTT modules when running a testcase - # which does not use MQTT. - - # pylint: disable-next=import-outside-toplevel - from paho.mqtt.client import MQTTMessage - - # pylint: disable-next=import-outside-toplevel - from homeassistant.components.mqtt.models import MqttData - - if isinstance(payload, str): - payload = payload.encode("utf-8") - - msg = MQTTMessage(topic=topic.encode("utf-8")) - msg.payload = payload - msg.qos = qos - msg.retain = retain - - mqtt_data: MqttData = hass.data["mqtt"] - assert mqtt_data.client - mqtt_data.client._mqtt_handle_message(msg) - - -fire_mqtt_message = threadsafe_callback_factory(async_fire_mqtt_message) - - -@callback -def async_fire_time_changed_exact( - hass: HomeAssistant, datetime_: datetime | None = None, fire_all: bool = False -) -> None: - """Fire a time changed event at an exact microsecond. - - Consider that it is not possible to actually achieve an exact - microsecond in production as the event loop is not precise enough. - If your code relies on this level of precision, consider a different - approach, as this is only for testing. - """ - if datetime_ is None: - utc_datetime = datetime.now(UTC) - else: - utc_datetime = dt_util.as_utc(datetime_) - - _async_fire_time_changed(hass, utc_datetime, fire_all) - - -@callback -def async_fire_time_changed( - hass: HomeAssistant, datetime_: datetime | None = None, fire_all: bool = False -) -> None: - """Fire a time changed event. - - If called within the first 500 ms of a second, time will be bumped to exactly - 500 ms to match the async_track_utc_time_change event listeners and - DataUpdateCoordinator which spreads all updates between 0.05..0.50. - Background in PR https://github.com/home-assistant/core/pull/82233 - - As asyncio is cooperative, we can't guarantee that the event loop will - run an event at the exact time we want. If you need to fire time changed - for an exact microsecond, use async_fire_time_changed_exact. - """ - if datetime_ is None: - utc_datetime = datetime.now(UTC) - else: - utc_datetime = dt_util.as_utc(datetime_) - - # Increase the mocked time by 0.5 s to account for up to 0.5 s delay - # added to events scheduled by update_coordinator and async_track_time_interval - utc_datetime += timedelta(microseconds=event.RANDOM_MICROSECOND_MAX) - - _async_fire_time_changed(hass, utc_datetime, fire_all) - - -_MONOTONIC_RESOLUTION = time.get_clock_info("monotonic").resolution - - -@callback -def _async_fire_time_changed( - hass: HomeAssistant, utc_datetime: datetime | None, fire_all: bool -) -> None: - timestamp = dt_util.utc_to_timestamp(utc_datetime) - for task in list(hass.loop._scheduled): - if not isinstance(task, asyncio.TimerHandle): - continue - if task.cancelled(): - continue - - mock_seconds_into_future = timestamp - time.time() - future_seconds = task.when() - (hass.loop.time() + _MONOTONIC_RESOLUTION) - - if fire_all or mock_seconds_into_future >= future_seconds: - with patch( - "homeassistant.helpers.event.time_tracker_utcnow", - return_value=utc_datetime, - ), patch( - "homeassistant.helpers.event.time_tracker_timestamp", - return_value=timestamp, - ): - task._run() - task.cancel() - - -fire_time_changed = threadsafe_callback_factory(async_fire_time_changed) - - -def get_fixture_path(filename: str, integration: str | None = None) -> pathlib.Path: - """Get path of fixture.""" - if integration is None and "/" in filename and not filename.startswith("helpers/"): - integration, filename = filename.split("/", 1) - - if integration is None: - return pathlib.Path(__file__).parent.joinpath("fixtures", filename) - - return pathlib.Path(__file__).parent.joinpath( - "components", integration, "fixtures", filename - ) - - -@lru_cache -def load_fixture(filename: str, integration: str | None = None) -> str: - """Load a fixture.""" - return get_fixture_path(filename, integration).read_text() - - -def load_json_value_fixture( - filename: str, integration: str | None = None -) -> JsonValueType: - """Load a JSON value from a fixture.""" - return json_loads(load_fixture(filename, integration)) - - -def load_json_array_fixture( - filename: str, integration: str | None = None -) -> JsonArrayType: - """Load a JSON array from a fixture.""" - return json_loads_array(load_fixture(filename, integration)) - - -def load_json_object_fixture( - filename: str, integration: str | None = None -) -> JsonObjectType: - """Load a JSON object from a fixture.""" - return json_loads_object(load_fixture(filename, integration)) - - -def mock_state_change_event( - hass: HomeAssistant, new_state: State, old_state: State | None = None -) -> None: - """Mock state change envent.""" - event_data = {"entity_id": new_state.entity_id, "new_state": new_state} - - if old_state: - event_data["old_state"] = old_state - - hass.bus.fire(EVENT_STATE_CHANGED, event_data, context=new_state.context) - - -@callback -def mock_component(hass: HomeAssistant, component: str) -> None: - """Mock a component is setup.""" - if component in hass.config.components: - AssertionError(f"Integration {component} is already setup") - - hass.config.components.add(component) - - -def mock_registry( - hass: HomeAssistant, - mock_entries: dict[str, er.RegistryEntry] | None = None, -) -> er.EntityRegistry: - """Mock the Entity Registry. - - This should only be used if you need to mock/re-stage a clean mocked - entity registry in your current hass object. It can be useful to, - for example, pre-load the registry with items. - - This mock will thus replace the existing registry in the running hass. - - If you just need to access the existing registry, use the `entity_registry` - fixture instead. - """ - registry = er.EntityRegistry(hass) - if mock_entries is None: - mock_entries = {} - registry.deleted_entities = {} - registry.entities = er.EntityRegistryItems() - registry._entities_data = registry.entities.data - for key, entry in mock_entries.items(): - registry.entities[key] = entry - - hass.data[er.DATA_REGISTRY] = registry - return registry - - -def mock_area_registry( - hass: HomeAssistant, mock_entries: dict[str, ar.AreaEntry] | None = None -) -> ar.AreaRegistry: - """Mock the Area Registry. - - This should only be used if you need to mock/re-stage a clean mocked - area registry in your current hass object. It can be useful to, - for example, pre-load the registry with items. - - This mock will thus replace the existing registry in the running hass. - - If you just need to access the existing registry, use the `area_registry` - fixture instead. - """ - registry = ar.AreaRegistry(hass) - registry.areas = mock_entries or OrderedDict() - - hass.data[ar.DATA_REGISTRY] = registry - return registry - - -def mock_device_registry( - hass: HomeAssistant, - mock_entries: dict[str, dr.DeviceEntry] | None = None, -) -> dr.DeviceRegistry: - """Mock the Device Registry. - - This should only be used if you need to mock/re-stage a clean mocked - device registry in your current hass object. It can be useful to, - for example, pre-load the registry with items. - - This mock will thus replace the existing registry in the running hass. - - If you just need to access the existing registry, use the `device_registry` - fixture instead. - """ - registry = dr.DeviceRegistry(hass) - registry.devices = dr.DeviceRegistryItems() - registry._device_data = registry.devices.data - if mock_entries is None: - mock_entries = {} - for key, entry in mock_entries.items(): - registry.devices[key] = entry - registry.deleted_devices = dr.DeviceRegistryItems() - - hass.data[dr.DATA_REGISTRY] = registry - return registry - - -class MockGroup(auth_models.Group): - """Mock a group in Home Assistant.""" - - def __init__(self, id=None, name="Mock Group", policy=system_policies.ADMIN_POLICY): - """Mock a group.""" - kwargs = {"name": name, "policy": policy} - if id is not None: - kwargs["id"] = id - - super().__init__(**kwargs) - - def add_to_hass(self, hass): - """Test helper to add entry to hass.""" - return self.add_to_auth_manager(hass.auth) - - def add_to_auth_manager(self, auth_mgr): - """Test helper to add entry to hass.""" - ensure_auth_manager_loaded(auth_mgr) - auth_mgr._store._groups[self.id] = self - return self - - -class MockUser(auth_models.User): - """Mock a user in Home Assistant.""" - - def __init__( - self, - id=None, - is_owner=False, - is_active=True, - name="Mock User", - system_generated=False, - groups=None, - ): - """Initialize mock user.""" - kwargs = { - "is_owner": is_owner, - "is_active": is_active, - "name": name, - "system_generated": system_generated, - "groups": groups or [], - "perm_lookup": None, - } - if id is not None: - kwargs["id"] = id - super().__init__(**kwargs) - - def add_to_hass(self, hass): - """Test helper to add entry to hass.""" - return self.add_to_auth_manager(hass.auth) - - def add_to_auth_manager(self, auth_mgr): - """Test helper to add entry to hass.""" - ensure_auth_manager_loaded(auth_mgr) - auth_mgr._store._users[self.id] = self - return self - - def mock_policy(self, policy): - """Mock a policy for a user.""" - self._permissions = auth_permissions.PolicyPermissions(policy, self.perm_lookup) - - -async def register_auth_provider( - hass: HomeAssistant, config: ConfigType -) -> auth_providers.AuthProvider: - """Register an auth provider.""" - provider = await auth_providers.auth_provider_from_config( - hass, hass.auth._store, config - ) - assert provider is not None, "Invalid config specified" - key = (provider.type, provider.id) - providers = hass.auth._providers - - if key in providers: - raise ValueError("Provider already registered") - - providers[key] = provider - return provider - - -@callback -def ensure_auth_manager_loaded(auth_mgr): - """Ensure an auth manager is considered loaded.""" - store = auth_mgr._store - if store._users is None: - store._set_defaults() - - -class MockModule: - """Representation of a fake module.""" - - def __init__( - self, - domain=None, - dependencies=None, - setup=None, - requirements=None, - config_schema=None, - platform_schema=None, - platform_schema_base=None, - async_setup=None, - async_setup_entry=None, - async_unload_entry=None, - async_migrate_entry=None, - async_remove_entry=None, - partial_manifest=None, - async_remove_config_entry_device=None, - ): - """Initialize the mock module.""" - self.__name__ = f"homeassistant.components.{domain}" - self.__file__ = f"homeassistant/components/{domain}" - self.DOMAIN = domain - self.DEPENDENCIES = dependencies or [] - self.REQUIREMENTS = requirements or [] - # Overlay to be used when generating manifest from this module - self._partial_manifest = partial_manifest - - if config_schema is not None: - self.CONFIG_SCHEMA = config_schema - - if platform_schema is not None: - self.PLATFORM_SCHEMA = platform_schema - - if platform_schema_base is not None: - self.PLATFORM_SCHEMA_BASE = platform_schema_base - - if setup: - # We run this in executor, wrap it in function - self.setup = lambda *args: setup(*args) - - if async_setup is not None: - self.async_setup = async_setup - - if setup is None and async_setup is None: - self.async_setup = AsyncMock(return_value=True) - - if async_setup_entry is not None: - self.async_setup_entry = async_setup_entry - - if async_unload_entry is not None: - self.async_unload_entry = async_unload_entry - - if async_migrate_entry is not None: - self.async_migrate_entry = async_migrate_entry - - if async_remove_entry is not None: - self.async_remove_entry = async_remove_entry - - if async_remove_config_entry_device is not None: - self.async_remove_config_entry_device = async_remove_config_entry_device - - def mock_manifest(self): - """Generate a mock manifest to represent this module.""" - return { - **loader.manifest_from_legacy_module(self.DOMAIN, self), - **(self._partial_manifest or {}), - } - - -class MockPlatform: - """Provide a fake platform.""" - - __name__ = "homeassistant.components.light.bla" - __file__ = "homeassistant/components/blah/light" - - def __init__( - self, - setup_platform=None, - dependencies=None, - platform_schema=None, - async_setup_platform=None, - async_setup_entry=None, - scan_interval=None, - ): - """Initialize the platform.""" - self.DEPENDENCIES = dependencies or [] - - if platform_schema is not None: - self.PLATFORM_SCHEMA = platform_schema - - if scan_interval is not None: - self.SCAN_INTERVAL = scan_interval - - if setup_platform is not None: - # We run this in executor, wrap it in function - self.setup_platform = lambda *args: setup_platform(*args) - - if async_setup_platform is not None: - self.async_setup_platform = async_setup_platform - - if async_setup_entry is not None: - self.async_setup_entry = async_setup_entry - - if setup_platform is None and async_setup_platform is None: - self.async_setup_platform = AsyncMock(return_value=None) - - -class MockEntityPlatform(entity_platform.EntityPlatform): - """Mock class with some mock defaults.""" - - def __init__( - self, - hass: HomeAssistant, - logger=None, - domain="test_domain", - platform_name="test_platform", - platform=None, - scan_interval=timedelta(seconds=15), - entity_namespace=None, - ): - """Initialize a mock entity platform.""" - if logger is None: - logger = logging.getLogger("homeassistant.helpers.entity_platform") - - # Otherwise the constructor will blow up. - if isinstance(platform, Mock) and isinstance(platform.PARALLEL_UPDATES, Mock): - platform.PARALLEL_UPDATES = 0 - - super().__init__( - hass=hass, - logger=logger, - domain=domain, - platform_name=platform_name, - platform=platform, - scan_interval=scan_interval, - entity_namespace=entity_namespace, - ) - - async def _async_on_stop(_: Event) -> None: - await self.async_shutdown() - - hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _async_on_stop) - - -class MockToggleEntity(entity.ToggleEntity): - """Provide a mock toggle device.""" - - def __init__(self, name, state, unique_id=None): - """Initialize the mock entity.""" - self._name = name or DEVICE_DEFAULT_NAME - self._state = state - self.calls = [] - - @property - def name(self): - """Return the name of the entity if any.""" - self.calls.append(("name", {})) - return self._name - - @property - def state(self): - """Return the state of the entity if any.""" - self.calls.append(("state", {})) - return self._state - - @property - def is_on(self): - """Return true if entity is on.""" - self.calls.append(("is_on", {})) - return self._state == STATE_ON - - def turn_on(self, **kwargs): - """Turn the entity on.""" - self.calls.append(("turn_on", kwargs)) - self._state = STATE_ON - - def turn_off(self, **kwargs): - """Turn the entity off.""" - self.calls.append(("turn_off", kwargs)) - self._state = STATE_OFF - - def last_call(self, method=None): - """Return the last call.""" - if not self.calls: - return None - if method is None: - return self.calls[-1] - try: - return next(call for call in reversed(self.calls) if call[0] == method) - except StopIteration: - return None - - -class MockConfigEntry(config_entries.ConfigEntry): - """Helper for creating config entries that adds some defaults.""" - - def __init__( - self, - *, - domain="test", - data=None, - version=1, - entry_id=None, - source=config_entries.SOURCE_USER, - title="Mock Title", - state=None, - options={}, - pref_disable_new_entities=None, - pref_disable_polling=None, - unique_id=None, - disabled_by=None, - reason=None, - ): - """Initialize a mock config entry.""" - kwargs = { - "entry_id": entry_id or uuid_util.random_uuid_hex(), - "domain": domain, - "data": data or {}, - "pref_disable_new_entities": pref_disable_new_entities, - "pref_disable_polling": pref_disable_polling, - "options": options, - "version": version, - "title": title, - "unique_id": unique_id, - "disabled_by": disabled_by, - } - if source is not None: - kwargs["source"] = source - if state is not None: - kwargs["state"] = state - super().__init__(**kwargs) - if reason is not None: - self.reason = reason - - def add_to_hass(self, hass): - """Test helper to add entry to hass.""" - hass.config_entries._entries[self.entry_id] = self - hass.config_entries._domain_index.setdefault(self.domain, []).append( - self.entry_id - ) - - def add_to_manager(self, manager): - """Test helper to add entry to entry manager.""" - manager._entries[self.entry_id] = self - manager._domain_index.setdefault(self.domain, []).append(self.entry_id) - - -def patch_yaml_files(files_dict, endswith=True): - """Patch load_yaml with a dictionary of yaml files.""" - # match using endswith, start search with longest string - matchlist = sorted(files_dict.keys(), key=len) if endswith else [] - - def mock_open_f(fname, **_): - """Mock open() in the yaml module, used by load_yaml.""" - # Return the mocked file on full match - if isinstance(fname, pathlib.Path): - fname = str(fname) - - if fname in files_dict: - _LOGGER.debug("patch_yaml_files match %s", fname) - res = StringIO(files_dict[fname]) - setattr(res, "name", fname) - return res - - # Match using endswith - for ends in matchlist: - if fname.endswith(ends): - _LOGGER.debug("patch_yaml_files end match %s: %s", ends, fname) - res = StringIO(files_dict[ends]) - setattr(res, "name", fname) - return res - - # Fallback for hass.components (i.e. services.yaml) - if "homeassistant/components" in fname: - _LOGGER.debug("patch_yaml_files using real file: %s", fname) - return open(fname, encoding="utf-8") - - # Not found - raise FileNotFoundError(f"File not found: {fname}") - - return patch.object(yaml_loader, "open", mock_open_f, create=True) - - -@contextmanager -def assert_setup_component(count, domain=None): - """Collect valid configuration from setup_component. - - - count: The amount of valid platforms that should be setup - - domain: The domain to count is optional. It can be automatically - determined most of the time - - Use as a context manager around setup.setup_component - with assert_setup_component(0) as result_config: - setup_component(hass, domain, start_config) - # using result_config is optional - """ - config = {} - - async def mock_psc(hass, config_input, integration): - """Mock the prepare_setup_component to capture config.""" - domain_input = integration.domain - res = await async_process_component_config(hass, config_input, integration) - config[domain_input] = None if res is None else res.get(domain_input) - _LOGGER.debug( - "Configuration for %s, Validated: %s, Original %s", - domain_input, - config[domain_input], - config_input.get(domain_input), - ) - return res - - assert isinstance(config, dict) - with patch("homeassistant.config.async_process_component_config", mock_psc): - yield config - - if domain is None: - assert len(config) == 1, "assert_setup_component requires DOMAIN: {}".format( - list(config.keys()) - ) - domain = list(config.keys())[0] - - res = config.get(domain) - res_len = 0 if res is None else len(res) - assert ( - res_len == count - ), f"setup_component failed, expected {count} got {res_len}: {res}" - - -def init_recorder_component(hass, add_config=None, db_url="sqlite://"): - """Initialize the recorder.""" - # Local import to avoid processing recorder and SQLite modules when running a - # testcase which does not use the recorder. - from homeassistant.components import recorder - - config = dict(add_config) if add_config else {} - if recorder.CONF_DB_URL not in config: - config[recorder.CONF_DB_URL] = db_url - if recorder.CONF_COMMIT_INTERVAL not in config: - config[recorder.CONF_COMMIT_INTERVAL] = 0 - - with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True): - if recorder.DOMAIN not in hass.data: - recorder_helper.async_initialize_recorder(hass) - assert setup_component(hass, recorder.DOMAIN, {recorder.DOMAIN: config}) - assert recorder.DOMAIN in hass.config.components - _LOGGER.info( - "Test recorder successfully started, database location: %s", - config[recorder.CONF_DB_URL], - ) - - -def mock_restore_cache(hass: HomeAssistant, states: Sequence[State]) -> None: - """Mock the DATA_RESTORE_CACHE.""" - key = restore_state.DATA_RESTORE_STATE - data = restore_state.RestoreStateData(hass) - now = dt_util.utcnow() - - last_states = {} - for state in states: - restored_state = state.as_dict() - restored_state = { - **restored_state, - "attributes": json.loads( - json.dumps(restored_state["attributes"], cls=JSONEncoder) - ), - } - last_states[state.entity_id] = restore_state.StoredState.from_dict( - {"state": restored_state, "last_seen": now} - ) - data.last_states = last_states - _LOGGER.debug("Restore cache: %s", data.last_states) - assert len(data.last_states) == len(states), f"Duplicate entity_id? {states}" - - hass.data[key] = data - - -def mock_restore_cache_with_extra_data( - hass: HomeAssistant, states: Sequence[tuple[State, Mapping[str, Any]]] -) -> None: - """Mock the DATA_RESTORE_CACHE.""" - key = restore_state.DATA_RESTORE_STATE - data = restore_state.RestoreStateData(hass) - now = dt_util.utcnow() - - last_states = {} - for state, extra_data in states: - restored_state = state.as_dict() - restored_state = { - **restored_state, - "attributes": json.loads( - json.dumps(restored_state["attributes"], cls=JSONEncoder) - ), - } - last_states[state.entity_id] = restore_state.StoredState.from_dict( - {"state": restored_state, "extra_data": extra_data, "last_seen": now} - ) - data.last_states = last_states - _LOGGER.debug("Restore cache: %s", data.last_states) - assert len(data.last_states) == len(states), f"Duplicate entity_id? {states}" - - hass.data[key] = data - - -async def async_mock_restore_state_shutdown_restart( - hass: HomeAssistant, -) -> restore_state.RestoreStateData: - """Mock shutting down and saving restore state and restoring.""" - data = restore_state.async_get(hass) - await data.async_dump_states() - await async_mock_load_restore_state_from_storage(hass) - return data - - -async def async_mock_load_restore_state_from_storage( - hass: HomeAssistant, -) -> None: - """Mock loading restore state from storage. - - hass_storage must already be mocked. - """ - await restore_state.async_get(hass).async_load() - - -class MockEntity(entity.Entity): - """Mock Entity class.""" - - def __init__(self, **values: Any) -> None: - """Initialize an entity.""" - self._values = values - - if "entity_id" in values: - self.entity_id = values["entity_id"] - - @property - def available(self) -> bool: - """Return True if entity is available.""" - return self._handle("available") - - @property - def capability_attributes(self) -> Mapping[str, Any] | None: - """Info about capabilities.""" - return self._handle("capability_attributes") - - @property - def device_class(self) -> str | None: - """Info how device should be classified.""" - return self._handle("device_class") - - @property - def device_info(self) -> dr.DeviceInfo | None: - """Info how it links to a device.""" - return self._handle("device_info") - - @property - def entity_category(self) -> entity.EntityCategory | None: - """Return the entity category.""" - return self._handle("entity_category") - - @property - def extra_state_attributes(self) -> Mapping[str, Any] | None: - """Return entity specific state attributes.""" - return self._handle("extra_state_attributes") - - @property - def has_entity_name(self) -> bool: - """Return the has_entity_name name flag.""" - return self._handle("has_entity_name") - - @property - def entity_registry_enabled_default(self) -> bool: - """Return if the entity should be enabled when first added to the entity registry.""" - return self._handle("entity_registry_enabled_default") - - @property - def entity_registry_visible_default(self) -> bool: - """Return if the entity should be visible when first added to the entity registry.""" - return self._handle("entity_registry_visible_default") - - @property - def icon(self) -> str | None: - """Return the suggested icon.""" - return self._handle("icon") - - @property - def name(self) -> str | None: - """Return the name of the entity.""" - return self._handle("name") - - @property - def should_poll(self) -> bool: - """Return the ste of the polling.""" - return self._handle("should_poll") - - @property - def state(self) -> StateType: - """Return the state of the entity.""" - return self._handle("state") - - @property - def supported_features(self) -> int | None: - """Info about supported features.""" - return self._handle("supported_features") - - @property - def translation_key(self) -> str | None: - """Return the translation key.""" - return self._handle("translation_key") - - @property - def unique_id(self) -> str | None: - """Return the unique ID of the entity.""" - return self._handle("unique_id") - - @property - def unit_of_measurement(self) -> str | None: - """Info on the units the entity state is in.""" - return self._handle("unit_of_measurement") - - def _handle(self, attr: str) -> Any: - """Return attribute value.""" - if attr in self._values: - return self._values[attr] - return getattr(super(), attr) - - -@contextmanager -def mock_storage( - data: dict[str, Any] | None = None -) -> Generator[dict[str, Any], None, None]: - """Mock storage. - - Data is a dict {'key': {'version': version, 'data': data}} - - Written data will be converted to JSON to ensure JSON parsing works. - """ - if data is None: - data = {} - - orig_load = storage.Store._async_load - - async def mock_async_load( - store: storage.Store, - ) -> dict[str, Any] | list[Any] | None: - """Mock version of load.""" - if store._data is None: - # No data to load - if store.key not in data: - # Make sure the next attempt will still load - store._load_task = None - return None - - mock_data = data.get(store.key) - - if "data" not in mock_data or "version" not in mock_data: - _LOGGER.error('Mock data needs "version" and "data"') - raise ValueError('Mock data needs "version" and "data"') - - store._data = mock_data - - # Route through original load so that we trigger migration - loaded = await orig_load(store) - _LOGGER.debug("Loading data for %s: %s", store.key, loaded) - return loaded - - async def mock_write_data( - store: storage.Store, path: str, data_to_write: dict[str, Any] - ) -> None: - """Mock version of write data.""" - # To ensure that the data can be serialized - _LOGGER.debug("Writing data to %s: %s", store.key, data_to_write) - raise_contains_mocks(data_to_write) - encoder = store._encoder - if encoder and encoder is not JSONEncoder: - # If they pass a custom encoder that is not the - # default JSONEncoder, we use the slow path of json.dumps - dump = ft.partial(json.dumps, cls=store._encoder) - else: - dump = _orjson_default_encoder - data[store.key] = json.loads(dump(data_to_write)) - - async def mock_remove(store: storage.Store) -> None: - """Remove data.""" - data.pop(store.key, None) - - with patch( - "homeassistant.helpers.storage.Store._async_load", - side_effect=mock_async_load, - autospec=True, - ), patch( - "homeassistant.helpers.storage.Store._async_write_data", - side_effect=mock_write_data, - autospec=True, - ), patch( - "homeassistant.helpers.storage.Store.async_remove", - side_effect=mock_remove, - autospec=True, - ): - yield data - - -async def flush_store(store: storage.Store) -> None: - """Make sure all delayed writes of a store are written.""" - if store._data is None: - return - - store._async_cleanup_final_write_listener() - store._async_cleanup_delay_listener() - await store._async_handle_write_data() - - -async def get_system_health_info(hass: HomeAssistant, domain: str) -> dict[str, Any]: - """Get system health info.""" - return await hass.data["system_health"][domain].info_callback(hass) - - -@contextmanager -def mock_config_flow(domain: str, config_flow: type[ConfigFlow]) -> None: - """Mock a config flow handler.""" - assert domain not in config_entries.HANDLERS - config_entries.HANDLERS[domain] = config_flow - _LOGGER.info("Adding mock config flow: %s", domain) - yield - config_entries.HANDLERS.pop(domain) - - -def mock_integration( - hass: HomeAssistant, module: MockModule, built_in: bool = True -) -> loader.Integration: - """Mock an integration.""" - integration = loader.Integration( - hass, - f"{loader.PACKAGE_BUILTIN}.{module.DOMAIN}" - if built_in - else f"{loader.PACKAGE_CUSTOM_COMPONENTS}.{module.DOMAIN}", - None, - module.mock_manifest(), - ) - - def mock_import_platform(platform_name: str) -> NoReturn: - raise ImportError( - f"Mocked unable to import platform '{integration.pkg_path}.{platform_name}'", - name=f"{integration.pkg_path}.{platform_name}", - ) - - integration._import_platform = mock_import_platform - - _LOGGER.info("Adding mock integration: %s", module.DOMAIN) - integration_cache = hass.data[loader.DATA_INTEGRATIONS] - integration_cache[module.DOMAIN] = integration - - module_cache = hass.data[loader.DATA_COMPONENTS] - module_cache[module.DOMAIN] = module - - return integration - - -def mock_entity_platform( - hass: HomeAssistant, platform_path: str, module: MockPlatform | None -) -> None: - """Mock a entity platform. - - platform_path is in form light.hue. Will create platform - hue.light. - """ - domain, platform_name = platform_path.split(".") - mock_platform(hass, f"{platform_name}.{domain}", module) - - -def mock_platform( - hass: HomeAssistant, platform_path: str, module: Mock | MockPlatform | None = None -) -> None: - """Mock a platform. - - platform_path is in form hue.config_flow. - """ - domain = platform_path.split(".")[0] - integration_cache = hass.data[loader.DATA_INTEGRATIONS] - module_cache = hass.data[loader.DATA_COMPONENTS] - - if domain not in integration_cache: - mock_integration(hass, MockModule(domain)) - - _LOGGER.info("Adding mock integration platform: %s", platform_path) - module_cache[platform_path] = module or Mock() - - -def async_capture_events(hass: HomeAssistant, event_name: str) -> list[Event]: - """Create a helper that captures events.""" - events = [] - - @callback - def capture_events(event: Event) -> None: - events.append(event) - - hass.bus.async_listen(event_name, capture_events, run_immediately=True) - - return events - - -@callback -def async_mock_signal(hass: HomeAssistant, signal: str) -> list[tuple[Any]]: - """Catch all dispatches to a signal.""" - calls = [] - - @callback - def mock_signal_handler(*args: Any) -> None: - """Mock service call.""" - calls.append(args) - - async_dispatcher_connect(hass, signal, mock_signal_handler) - - return calls - - -_SENTINEL = object() - - -class _HA_ANY: - """A helper object that compares equal to everything. - - Based on unittest.mock.ANY, but modified to not show up in pytest's equality - assertion diffs. - """ - - _other = _SENTINEL - - def __eq__(self, other: Any) -> bool: - """Test equal.""" - self._other = other - return True - - def __ne__(self, other: Any) -> bool: - """Test not equal.""" - self._other = other - return False - - def __repr__(self) -> str: - """Return repr() other to not show up in pytest quality diffs.""" - if self._other is _SENTINEL: - return "" - return repr(self._other) - - -ANY = _HA_ANY() - - -def raise_contains_mocks(val: Any) -> None: - """Raise for mocks.""" - if isinstance(val, Mock): - raise TypeError - - if isinstance(val, dict): - for dict_value in val.values(): - raise_contains_mocks(dict_value) - - if isinstance(val, list): - for dict_value in val: - raise_contains_mocks(dict_value) - - -@callback -def async_get_persistent_notifications( - hass: HomeAssistant, -) -> dict[str, pn.Notification]: - """Get the current persistent notifications.""" - return pn._async_get_or_create_notifications(hass) diff --git a/tests/hass/fixtures.py b/tests/hass/fixtures.py deleted file mode 100644 index 40fd1c2..0000000 --- a/tests/hass/fixtures.py +++ /dev/null @@ -1,1602 +0,0 @@ -"""Set up some common test helper things.""" -from __future__ import annotations - -import asyncio -from collections.abc import AsyncGenerator, Callable, Coroutine, Generator -from contextlib import asynccontextmanager -import datetime -import functools -import gc -import itertools -import logging -import os -import sqlite3 -import ssl -import threading -from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast -from unittest.mock import AsyncMock, MagicMock, Mock, patch - -from aiohttp import client -from aiohttp.test_utils import ( - BaseTestServer, - TestClient, - TestServer, - make_mocked_request, -) -from aiohttp.typedefs import JSONDecoder -from aiohttp.web import Application -import freezegun -import multidict -import pytest -import pytest_socket -import requests_mock -from syrupy.assertion import SnapshotAssertion - -from homeassistant import core as ha, loader, runner -from homeassistant.auth.const import GROUP_ID_ADMIN, GROUP_ID_READ_ONLY -from homeassistant.auth.models import Credentials -from homeassistant.auth.providers import homeassistant, legacy_api_password -from homeassistant.components.device_tracker.legacy import Device -from homeassistant.components.network.models import Adapter, IPv4ConfiguredAddress -from homeassistant.components.websocket_api.auth import ( - TYPE_AUTH, - TYPE_AUTH_OK, - TYPE_AUTH_REQUIRED, -) -from homeassistant.components.websocket_api.http import URL -from homeassistant.config import YAML_CONFIG_FILE -from homeassistant.config_entries import ConfigEntries, ConfigEntry -from homeassistant.const import HASSIO_USER_NAME -from homeassistant.core import CoreState, HassJob, HomeAssistant -from homeassistant.helpers import ( - area_registry as ar, - config_entry_oauth2_flow, - device_registry as dr, - entity_registry as er, - event, - issue_registry as ir, - recorder as recorder_helper, -) -from homeassistant.helpers.typing import ConfigType -from homeassistant.setup import BASE_PLATFORMS, async_setup_component -from homeassistant.util import dt as dt_util, location -from homeassistant.util.json import json_loads - -from .ignore_uncaught_exceptions import IGNORE_UNCAUGHT_EXCEPTIONS -from .syrupy import HomeAssistantSnapshotExtension -from .typing import ( - ClientSessionGenerator, - MockHAClientWebSocket, - MqttMockHAClient, - MqttMockHAClientGenerator, - MqttMockPahoClient, - RecorderInstanceGenerator, - WebSocketGenerator, -) - -if TYPE_CHECKING: - # Local import to avoid processing recorder and SQLite modules when running a - # testcase which does not use the recorder. - from homeassistant.components import recorder - -pytest.register_assert_rewrite("tests.common") - -from .common import ( # noqa: E402, isort:skip - CLIENT_ID, - INSTANCES, - MockConfigEntry, - MockUser, - async_fire_mqtt_message, - async_test_home_assistant, - get_test_home_assistant, - init_recorder_component, - mock_storage, - patch_yaml_files, -) -from .test_util.aiohttp import ( # noqa: E402, isort:skip - AiohttpClientMocker, - mock_aiohttp_client, -) - - -_LOGGER = logging.getLogger(__name__) - -logging.basicConfig(level=logging.INFO) -logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO) - -asyncio.set_event_loop_policy(runner.HassEventLoopPolicy(False)) -# Disable fixtures overriding our beautiful policy -asyncio.set_event_loop_policy = lambda policy: None - - -def _utcnow() -> datetime.datetime: - """Make utcnow patchable by freezegun.""" - return datetime.datetime.now(datetime.UTC) - - -dt_util.utcnow = _utcnow # type: ignore[assignment] -event.time_tracker_utcnow = _utcnow # type: ignore[assignment] - - -def pytest_addoption(parser: pytest.Parser) -> None: - """Register custom pytest options.""" - parser.addoption("--dburl", action="store", default="sqlite://") - - -def pytest_configure(config: pytest.Config) -> None: - """Register marker for tests that log exceptions.""" - config.addinivalue_line( - "markers", "no_fail_on_log_exception: mark test to not fail on logged exception" - ) - if config.getoption("verbose") > 0: - logging.getLogger().setLevel(logging.DEBUG) - - -def pytest_runtest_setup() -> None: - """Prepare pytest_socket and freezegun. - - pytest_socket: - Throw if tests attempt to open sockets. - - allow_unix_socket is set to True because it's needed by asyncio. - Important: socket_allow_hosts must be called before disable_socket, otherwise all - destinations will be allowed. - - freezegun: - Modified to include https://github.com/spulec/freezegun/pull/424 - """ - pytest_socket.socket_allow_hosts(["127.0.0.1"]) - pytest_socket.disable_socket(allow_unix_socket=True) - - freezegun.api.datetime_to_fakedatetime = ha_datetime_to_fakedatetime # type: ignore[attr-defined] - freezegun.api.FakeDatetime = HAFakeDatetime # type: ignore[attr-defined] - - def adapt_datetime(val): - return val.isoformat(" ") - - # Setup HAFakeDatetime converter for sqlite3 - sqlite3.register_adapter(HAFakeDatetime, adapt_datetime) - - # Setup HAFakeDatetime converter for pymysql - try: - # pylint: disable-next=import-outside-toplevel - import MySQLdb.converters as MySQLdb_converters - except ImportError: - pass - else: - MySQLdb_converters.conversions[ - HAFakeDatetime - ] = MySQLdb_converters.DateTime2literal - - -def ha_datetime_to_fakedatetime(datetime) -> freezegun.api.FakeDatetime: # type: ignore[name-defined] - """Convert datetime to FakeDatetime. - - Modified to include https://github.com/spulec/freezegun/pull/424. - """ - return freezegun.api.FakeDatetime( # type: ignore[attr-defined] - datetime.year, - datetime.month, - datetime.day, - datetime.hour, - datetime.minute, - datetime.second, - datetime.microsecond, - datetime.tzinfo, - fold=datetime.fold, - ) - - -class HAFakeDatetime(freezegun.api.FakeDatetime): # type: ignore[name-defined] - """Modified to include https://github.com/spulec/freezegun/pull/424.""" - - @classmethod - def now(cls, tz=None): - """Return frozen now.""" - now = cls._time_to_freeze() or freezegun.api.real_datetime.now() - if tz: - result = tz.fromutc(now.replace(tzinfo=tz)) - else: - result = now - - # Add the _tz_offset only if it's non-zero to preserve fold - if cls._tz_offset(): - result += cls._tz_offset() - - return ha_datetime_to_fakedatetime(result) - - -_R = TypeVar("_R") -_P = ParamSpec("_P") - - -def check_real(func: Callable[_P, Coroutine[Any, Any, _R]]): - """Force a function to require a keyword _test_real to be passed in.""" - - @functools.wraps(func) - async def guard_func(*args: _P.args, **kwargs: _P.kwargs) -> _R: - real = kwargs.pop("_test_real", None) - - if not real: - raise RuntimeError( - f'Forgot to mock or pass "_test_real=True" to {func.__name__}' - ) - - return await func(*args, **kwargs) - - return guard_func - - -# Guard a few functions that would make network connections -location.async_detect_location_info = check_real(location.async_detect_location_info) - - -@pytest.fixture(name="caplog") -def caplog_fixture(caplog: pytest.LogCaptureFixture) -> pytest.LogCaptureFixture: - """Set log level to debug for tests using the caplog fixture.""" - caplog.set_level(logging.DEBUG) - return caplog - - -@pytest.fixture(autouse=True, scope="module") -def garbage_collection() -> None: - """Run garbage collection at known locations. - - This is to mimic the behavior of pytest-aiohttp, and is - required to avoid warnings during garbage collection from - spilling over into next test case. We run it per module which - handles the most common cases and let each module override - to run per test case if needed. - """ - gc.collect() - - -@pytest.fixture(autouse=True) -def expected_lingering_tasks() -> bool: - """Temporary ability to bypass test failures. - - Parametrize to True to bypass the pytest failure. - @pytest.mark.parametrize("expected_lingering_tasks", [True]) - - This should be removed when all lingering tasks have been cleaned up. - """ - return False - - -@pytest.fixture(autouse=True) -def expected_lingering_timers() -> bool: - """Temporary ability to bypass test failures. - - Parametrize to True to bypass the pytest failure. - @pytest.mark.parametrize("expected_lingering_timers", [True]) - - This should be removed when all lingering timers have been cleaned up. - """ - current_test = os.getenv("PYTEST_CURRENT_TEST") - if ( - current_test - and current_test.startswith("tests/components/") - and current_test.split("/")[2] not in BASE_PLATFORMS - ): - # As a starting point, we ignore non-platform components - return True - return False - - -@pytest.fixture -def wait_for_stop_scripts_after_shutdown() -> bool: - """Add ability to bypass _schedule_stop_scripts_after_shutdown. - - _schedule_stop_scripts_after_shutdown leaves a lingering timer. - - Parametrize to True to bypass the pytest failure. - @pytest.mark.parametrize("wait_for_stop_scripts_at_shutdown", [True]) - """ - return False - - -@pytest.fixture(autouse=True) -def skip_stop_scripts( - wait_for_stop_scripts_after_shutdown: bool, -) -> Generator[None, None, None]: - """Add ability to bypass _schedule_stop_scripts_after_shutdown.""" - if wait_for_stop_scripts_after_shutdown: - yield - return - with patch( - "homeassistant.helpers.script._schedule_stop_scripts_after_shutdown", - Mock(), - ): - yield - - -@pytest.fixture(autouse=True) -def verify_cleanup( - event_loop: asyncio.AbstractEventLoop, - expected_lingering_tasks: bool, - expected_lingering_timers: bool, -) -> Generator[None, None, None]: - """Verify that the test has cleaned up resources correctly.""" - threads_before = frozenset(threading.enumerate()) - tasks_before = asyncio.all_tasks(event_loop) - yield - - event_loop.run_until_complete(event_loop.shutdown_default_executor()) - - if len(INSTANCES) >= 2: - count = len(INSTANCES) - for inst in INSTANCES: - inst.stop() - pytest.exit(f"Detected non stopped instances ({count}), aborting test run") - - # Warn and clean-up lingering tasks and timers - # before moving on to the next test. - tasks = asyncio.all_tasks(event_loop) - tasks_before - for task in tasks: - if expected_lingering_tasks: - _LOGGER.warning("Lingering task after test %r", task) - else: - pytest.fail(f"Lingering task after test {repr(task)}") - task.cancel() - if tasks: - event_loop.run_until_complete(asyncio.wait(tasks)) - - for handle in event_loop._scheduled: # type: ignore[attr-defined] - if not handle.cancelled(): - if expected_lingering_timers: - _LOGGER.warning("Lingering timer after test %r", handle) - elif handle._args and isinstance(job := handle._args[0], HassJob): - pytest.fail(f"Lingering timer after job {repr(job)}") - else: - pytest.fail(f"Lingering timer after test {repr(handle)}") - handle.cancel() - - # Verify no threads where left behind. - threads = frozenset(threading.enumerate()) - threads_before - for thread in threads: - assert isinstance(thread, threading._DummyThread) or thread.name.startswith( - "waitpid-" - ) - - -@pytest.fixture(autouse=True) -def bcrypt_cost() -> Generator[None, None, None]: - """Run with reduced rounds during tests, to speed up uses.""" - import bcrypt - - gensalt_orig = bcrypt.gensalt - - def gensalt_mock(rounds=12, prefix=b"2b"): - return gensalt_orig(4, prefix) - - bcrypt.gensalt = gensalt_mock - yield - bcrypt.gensalt = gensalt_orig - - -@pytest.fixture -def hass_storage() -> Generator[dict[str, Any], None, None]: - """Fixture to mock storage.""" - with mock_storage() as stored_data: - yield stored_data - - -@pytest.fixture -def load_registries() -> bool: - """Fixture to control the loading of registries when setting up the hass fixture. - - To avoid loading the registries, tests can be marked with: - @pytest.mark.parametrize("load_registries", [False]) - """ - return True - - -class CoalescingResponse(client.ClientWebSocketResponse): - """ClientWebSocketResponse client that mimics the websocket js code.""" - - def __init__(self, *args: Any, **kwargs: Any) -> None: - """Init the ClientWebSocketResponse.""" - super().__init__(*args, **kwargs) - self._recv_buffer: list[Any] = [] - - async def receive_json( - self, - *, - loads: JSONDecoder = json_loads, - timeout: float | None = None, - ) -> Any: - """receive_json or from buffer.""" - if self._recv_buffer: - return self._recv_buffer.pop(0) - data = await self.receive_str(timeout=timeout) - decoded = loads(data) - if isinstance(decoded, list): - self._recv_buffer = decoded - return self._recv_buffer.pop(0) - return decoded - - -class CoalescingClient(TestClient): - """Client that mimics the websocket js code.""" - - def __init__(self, *args: Any, **kwargs: Any) -> None: - """Init TestClient.""" - super().__init__(*args, ws_response_class=CoalescingResponse, **kwargs) - - -@pytest.fixture -def aiohttp_client_cls() -> type[CoalescingClient]: - """Override the test class for aiohttp.""" - return CoalescingClient - - -@pytest.fixture -def aiohttp_client( - event_loop: asyncio.AbstractEventLoop, -) -> Generator[ClientSessionGenerator, None, None]: - """Override the default aiohttp_client since 3.x does not support aiohttp_client_cls. - - Remove this when upgrading to 4.x as aiohttp_client_cls - will do the same thing - - aiohttp_client(app, **kwargs) - aiohttp_client(server, **kwargs) - aiohttp_client(raw_server, **kwargs) - """ - loop = event_loop - clients = [] - - async def go( - __param: Application | BaseTestServer, - *args: Any, - server_kwargs: dict[str, Any] | None = None, - **kwargs: Any, - ) -> TestClient: - if isinstance(__param, Callable) and not isinstance( # type: ignore[arg-type] - __param, (Application, BaseTestServer) - ): - __param = __param(loop, *args, **kwargs) - kwargs = {} - else: - assert not args, "args should be empty" - - client: TestClient - if isinstance(__param, Application): - server_kwargs = server_kwargs or {} - server = TestServer(__param, loop=loop, **server_kwargs) - client = CoalescingClient(server, loop=loop, **kwargs) - elif isinstance(__param, BaseTestServer): - client = TestClient(__param, loop=loop, **kwargs) - else: - raise TypeError("Unknown argument type: %r" % type(__param)) - - await client.start_server() - clients.append(client) - return client - - yield go - - async def finalize() -> None: - while clients: - await clients.pop().close() - - loop.run_until_complete(finalize()) - - -@pytest.fixture -def hass_fixture_setup() -> list[bool]: - """Fixture which is truthy if the hass fixture has been setup.""" - return [] - - -@pytest.fixture -async def hass( - hass_fixture_setup: list[bool], - event_loop: asyncio.AbstractEventLoop, - load_registries: bool, - hass_storage: dict[str, Any], - request: pytest.FixtureRequest, -) -> AsyncGenerator[HomeAssistant, None]: - """Create a test instance of Home Assistant.""" - - loop = event_loop - hass_fixture_setup.append(True) - - orig_tz = dt_util.DEFAULT_TIME_ZONE - - def exc_handle(loop, context): - """Handle exceptions by rethrowing them, which will fail the test.""" - # Most of these contexts will contain an exception, but not all. - # The docs note the key as "optional" - # See https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.call_exception_handler - if "exception" in context: - exceptions.append(context["exception"]) - else: - exceptions.append( - Exception( - "Received exception handler without exception, but with message: %s" - % context["message"] - ) - ) - orig_exception_handler(loop, context) - - exceptions: list[Exception] = [] - hass = await async_test_home_assistant(loop, load_registries) - - orig_exception_handler = loop.get_exception_handler() - loop.set_exception_handler(exc_handle) - - yield hass - - # Config entries are not normally unloaded on HA shutdown. They are unloaded here - # to ensure that they could, and to help track lingering tasks and timers. - await asyncio.gather( - *( - config_entry.async_unload(hass) - for config_entry in hass.config_entries.async_entries() - ) - ) - - await hass.async_stop(force=True) - - # Restore timezone, it is set when creating the hass object - dt_util.DEFAULT_TIME_ZONE = orig_tz - - for ex in exceptions: - if ( - request.module.__name__, - request.function.__name__, - ) in IGNORE_UNCAUGHT_EXCEPTIONS: - continue - raise ex - - -@pytest.fixture -async def stop_hass( - event_loop: asyncio.AbstractEventLoop, -) -> AsyncGenerator[None, None]: - """Make sure all hass are stopped.""" - orig_hass = ha.HomeAssistant - - created = [] - - def mock_hass(): - hass_inst = orig_hass() - created.append(hass_inst) - return hass_inst - - with patch("homeassistant.core.HomeAssistant", mock_hass): - yield - - for hass_inst in created: - if hass_inst.state == ha.CoreState.stopped: - continue - - with patch.object(hass_inst.loop, "stop"): - await hass_inst.async_block_till_done() - await hass_inst.async_stop(force=True) - await event_loop.shutdown_default_executor() - - -@pytest.fixture(name="requests_mock") -def requests_mock_fixture() -> Generator[requests_mock.Mocker, None, None]: - """Fixture to provide a requests mocker.""" - with requests_mock.mock() as m: - yield m - - -@pytest.fixture -def aioclient_mock() -> Generator[AiohttpClientMocker, None, None]: - """Fixture to mock aioclient calls.""" - with mock_aiohttp_client() as mock_session: - yield mock_session - - -@pytest.fixture -def mock_device_tracker_conf() -> Generator[list[Device], None, None]: - """Prevent device tracker from reading/writing data.""" - devices: list[Device] = [] - - async def mock_update_config(path: str, dev_id: str, entity: Device) -> None: - devices.append(entity) - - with patch( - ( - "homeassistant.components.device_tracker.legacy" - ".DeviceTracker.async_update_config" - ), - side_effect=mock_update_config, - ), patch( - "homeassistant.components.device_tracker.legacy.async_load_config", - side_effect=lambda *args: devices, - ): - yield devices - - -@pytest.fixture -async def hass_admin_credential( - hass: HomeAssistant, local_auth: homeassistant.HassAuthProvider -) -> Credentials: - """Provide credentials for admin user.""" - return Credentials( - id="mock-credential-id", - auth_provider_type="homeassistant", - auth_provider_id=None, - data={"username": "admin"}, - is_new=False, - ) - - -@pytest.fixture -async def hass_access_token( - hass: HomeAssistant, hass_admin_user: MockUser, hass_admin_credential: Credentials -) -> str: - """Return an access token to access Home Assistant.""" - await hass.auth.async_link_user(hass_admin_user, hass_admin_credential) - - refresh_token = await hass.auth.async_create_refresh_token( - hass_admin_user, CLIENT_ID, credential=hass_admin_credential - ) - return hass.auth.async_create_access_token(refresh_token) - - -@pytest.fixture -def hass_owner_user( - hass: HomeAssistant, local_auth: homeassistant.HassAuthProvider -) -> MockUser: - """Return a Home Assistant admin user.""" - return MockUser(is_owner=True).add_to_hass(hass) - - -@pytest.fixture -async def hass_admin_user( - hass: HomeAssistant, local_auth: homeassistant.HassAuthProvider -) -> MockUser: - """Return a Home Assistant admin user.""" - admin_group = await hass.auth.async_get_group(GROUP_ID_ADMIN) - return MockUser(groups=[admin_group]).add_to_hass(hass) - - -@pytest.fixture -async def hass_read_only_user( - hass: HomeAssistant, local_auth: homeassistant.HassAuthProvider -) -> MockUser: - """Return a Home Assistant read only user.""" - read_only_group = await hass.auth.async_get_group(GROUP_ID_READ_ONLY) - return MockUser(groups=[read_only_group]).add_to_hass(hass) - - -@pytest.fixture -async def hass_read_only_access_token( - hass: HomeAssistant, - hass_read_only_user: MockUser, - local_auth: homeassistant.HassAuthProvider, -) -> str: - """Return a Home Assistant read only user.""" - credential = Credentials( - id="mock-readonly-credential-id", - auth_provider_type="homeassistant", - auth_provider_id=None, - data={"username": "readonly"}, - is_new=False, - ) - hass_read_only_user.credentials.append(credential) - - refresh_token = await hass.auth.async_create_refresh_token( - hass_read_only_user, CLIENT_ID, credential=credential - ) - return hass.auth.async_create_access_token(refresh_token) - - -@pytest.fixture -async def hass_supervisor_user( - hass: HomeAssistant, local_auth: homeassistant.HassAuthProvider -) -> MockUser: - """Return the Home Assistant Supervisor user.""" - admin_group = await hass.auth.async_get_group(GROUP_ID_ADMIN) - return MockUser( - name=HASSIO_USER_NAME, groups=[admin_group], system_generated=True - ).add_to_hass(hass) - - -@pytest.fixture -async def hass_supervisor_access_token( - hass: HomeAssistant, - hass_supervisor_user, - local_auth: homeassistant.HassAuthProvider, -) -> str: - """Return a Home Assistant Supervisor access token.""" - refresh_token = await hass.auth.async_create_refresh_token(hass_supervisor_user) - return hass.auth.async_create_access_token(refresh_token) - - -@pytest.fixture -def legacy_auth( - hass: HomeAssistant, -) -> legacy_api_password.LegacyApiPasswordAuthProvider: - """Load legacy API password provider.""" - prv = legacy_api_password.LegacyApiPasswordAuthProvider( - hass, - hass.auth._store, - {"type": "legacy_api_password", "api_password": "test-password"}, - ) - hass.auth._providers[(prv.type, prv.id)] = prv - return prv - - -@pytest.fixture -async def local_auth(hass: HomeAssistant) -> homeassistant.HassAuthProvider: - """Load local auth provider.""" - prv = homeassistant.HassAuthProvider( - hass, hass.auth._store, {"type": "homeassistant"} - ) - await prv.async_initialize() - hass.auth._providers[(prv.type, prv.id)] = prv - return prv - - -@pytest.fixture -def hass_client( - hass: HomeAssistant, - aiohttp_client: ClientSessionGenerator, - hass_access_token: str, - socket_enabled: None, -) -> ClientSessionGenerator: - """Return an authenticated HTTP client.""" - - async def auth_client() -> TestClient: - """Return an authenticated client.""" - return await aiohttp_client( - hass.http.app, headers={"Authorization": f"Bearer {hass_access_token}"} - ) - - return auth_client - - -@pytest.fixture -def hass_client_no_auth( - hass: HomeAssistant, - aiohttp_client: ClientSessionGenerator, - socket_enabled: None, -) -> ClientSessionGenerator: - """Return an unauthenticated HTTP client.""" - - async def client() -> TestClient: - """Return an authenticated client.""" - return await aiohttp_client(hass.http.app) - - return client - - -@pytest.fixture -def current_request() -> Generator[MagicMock, None, None]: - """Mock current request.""" - with patch("homeassistant.components.http.current_request") as mock_request_context: - mocked_request = make_mocked_request( - "GET", - "/some/request", - headers={"Host": "example.com"}, - sslcontext=ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT), - ) - mock_request_context.get.return_value = mocked_request - yield mock_request_context - - -@pytest.fixture -def current_request_with_host(current_request: MagicMock) -> None: - """Mock current request with a host header.""" - new_headers = multidict.CIMultiDict(current_request.get.return_value.headers) - new_headers[config_entry_oauth2_flow.HEADER_FRONTEND_BASE] = "https://example.com" - current_request.get.return_value = current_request.get.return_value.clone( - headers=new_headers - ) - - -@pytest.fixture -def hass_ws_client( - aiohttp_client: ClientSessionGenerator, - hass_access_token: str | None, - hass: HomeAssistant, - socket_enabled: None, -) -> WebSocketGenerator: - """Websocket client fixture connected to websocket server.""" - - async def create_client( - hass: HomeAssistant = hass, access_token: str | None = hass_access_token - ) -> MockHAClientWebSocket: - """Create a websocket client.""" - assert await async_setup_component(hass, "websocket_api", {}) - client = await aiohttp_client(hass.http.app) - websocket = await client.ws_connect(URL) - auth_resp = await websocket.receive_json() - assert auth_resp["type"] == TYPE_AUTH_REQUIRED - - if access_token is None: - await websocket.send_json({"type": TYPE_AUTH, "access_token": "incorrect"}) - else: - await websocket.send_json({"type": TYPE_AUTH, "access_token": access_token}) - - auth_ok = await websocket.receive_json() - assert auth_ok["type"] == TYPE_AUTH_OK - - def _get_next_id() -> Generator[int, None, None]: - i = 0 - while True: - yield (i := i + 1) - - id_generator = _get_next_id() - - def _send_json_auto_id(data: dict[str, Any]) -> Coroutine[Any, Any, None]: - data["id"] = next(id_generator) - return websocket.send_json(data) - - # wrap in client - wrapped_websocket = cast(MockHAClientWebSocket, websocket) - wrapped_websocket.client = client - wrapped_websocket.send_json_auto_id = _send_json_auto_id - return wrapped_websocket - - return create_client - - -@pytest.fixture(autouse=True) -def fail_on_log_exception( - request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch -) -> None: - """Fixture to fail if a callback wrapped by catch_log_exception or coroutine wrapped by async_create_catching_coro throws.""" - if "no_fail_on_log_exception" in request.keywords: - return - - def log_exception(format_err, *args): - raise - - monkeypatch.setattr("homeassistant.util.logging.log_exception", log_exception) - - -@pytest.fixture -def mqtt_config_entry_data() -> dict[str, Any] | None: - """Fixture to allow overriding MQTT config.""" - return None - - -@pytest.fixture -def mqtt_client_mock(hass: HomeAssistant) -> Generator[MqttMockPahoClient, None, None]: - """Fixture to mock MQTT client.""" - - mid: int = 0 - - def get_mid() -> int: - nonlocal mid - mid += 1 - return mid - - class FakeInfo: - """Class to fake MQTT info.""" - - def __init__(self, mid: int) -> None: - self.mid = mid - self.rc = 0 - - with patch("paho.mqtt.client.Client") as mock_client: - - @ha.callback - def _async_fire_mqtt_message(topic, payload, qos, retain): - async_fire_mqtt_message(hass, topic, payload, qos, retain) - mid = get_mid() - mock_client.on_publish(0, 0, mid) - return FakeInfo(mid) - - def _subscribe(topic, qos=0): - mid = get_mid() - mock_client.on_subscribe(0, 0, mid) - return (0, mid) - - def _unsubscribe(topic): - mid = get_mid() - mock_client.on_unsubscribe(0, 0, mid) - return (0, mid) - - mock_client = mock_client.return_value - mock_client.connect.return_value = 0 - mock_client.subscribe.side_effect = _subscribe - mock_client.unsubscribe.side_effect = _unsubscribe - mock_client.publish.side_effect = _async_fire_mqtt_message - yield mock_client - - -@pytest.fixture -async def mqtt_mock( - hass: HomeAssistant, - mock_hass_config: None, - mqtt_client_mock: MqttMockPahoClient, - mqtt_config_entry_data: dict[str, Any] | None, - mqtt_mock_entry: MqttMockHAClientGenerator, -) -> AsyncGenerator[MqttMockHAClient, None]: - """Fixture to mock MQTT component.""" - with patch("homeassistant.components.mqtt.PLATFORMS", []): - return await mqtt_mock_entry() - - -@asynccontextmanager -async def _mqtt_mock_entry( - hass: HomeAssistant, - mqtt_client_mock: MqttMockPahoClient, - mqtt_config_entry_data: dict[str, Any] | None, -) -> AsyncGenerator[MqttMockHAClientGenerator, None]: - """Fixture to mock a delayed setup of the MQTT config entry.""" - # Local import to avoid processing MQTT modules when running a testcase - # which does not use MQTT. - from homeassistant.components import mqtt # pylint: disable=import-outside-toplevel - - if mqtt_config_entry_data is None: - mqtt_config_entry_data = { - mqtt.CONF_BROKER: "mock-broker", - mqtt.CONF_BIRTH_MESSAGE: {}, - } - - await hass.async_block_till_done() - - entry = MockConfigEntry( - data=mqtt_config_entry_data, - domain=mqtt.DOMAIN, - title="MQTT", - ) - entry.add_to_hass(hass) - - real_mqtt = mqtt.MQTT - real_mqtt_instance = None - mock_mqtt_instance = None - - async def _setup_mqtt_entry( - setup_entry: Callable[[HomeAssistant, ConfigEntry], Coroutine[Any, Any, bool]] - ) -> MagicMock: - """Set up the MQTT config entry.""" - assert await setup_entry(hass, entry) - - # Assert that MQTT is setup - assert real_mqtt_instance is not None, "MQTT was not setup correctly" - mock_mqtt_instance.conf = real_mqtt_instance.conf # For diagnostics - mock_mqtt_instance._mqttc = mqtt_client_mock - - # connected set to True to get a more realistic behavior when subscribing - mock_mqtt_instance.connected = True - - hass.helpers.dispatcher.async_dispatcher_send(mqtt.MQTT_CONNECTED) - await hass.async_block_till_done() - - return mock_mqtt_instance - - def create_mock_mqtt(*args, **kwargs) -> MqttMockHAClient: - """Create a mock based on mqtt.MQTT.""" - nonlocal mock_mqtt_instance - nonlocal real_mqtt_instance - real_mqtt_instance = real_mqtt(*args, **kwargs) - spec = dir(real_mqtt_instance) + ["_mqttc"] - mock_mqtt_instance = MqttMockHAClient( - return_value=real_mqtt_instance, - spec_set=spec, - wraps=real_mqtt_instance, - ) - return mock_mqtt_instance - - with patch("homeassistant.components.mqtt.MQTT", side_effect=create_mock_mqtt): - yield _setup_mqtt_entry - - -@pytest.fixture -def hass_config() -> ConfigType: - """Fixture to parametrize the content of main configuration using mock_hass_config. - - To set a configuration, tests can be marked with: - @pytest.mark.parametrize("hass_config", [{integration: {...}}]) - Add the `mock_hass_config: None` fixture to the test. - """ - return {} - - -@pytest.fixture -def mock_hass_config( - hass: HomeAssistant, hass_config: ConfigType -) -> Generator[None, None, None]: - """Fixture to mock the content of main configuration. - - Patches homeassistant.config.load_yaml_config_file and hass.config_entries - with `hass_config` as parameterized. - """ - if hass_config: - hass.config_entries = ConfigEntries(hass, hass_config) - with patch("homeassistant.config.load_yaml_config_file", return_value=hass_config): - yield - - -@pytest.fixture -def hass_config_yaml() -> str: - """Fixture to parametrize the content of configuration.yaml file. - - To set yaml content, tests can be marked with: - @pytest.mark.parametrize("hass_config_yaml", ["..."]) - Add the `mock_hass_config_yaml: None` fixture to the test. - """ - return "" - - -@pytest.fixture -def hass_config_yaml_files(hass_config_yaml: str) -> dict[str, str]: - """Fixture to parametrize multiple yaml configuration files. - - To set the YAML files to patch, tests can be marked with: - @pytest.mark.parametrize( - "hass_config_yaml_files", [{"configuration.yaml": "..."}] - ) - Add the `mock_hass_config_yaml: None` fixture to the test. - """ - return {YAML_CONFIG_FILE: hass_config_yaml} - - -@pytest.fixture -def mock_hass_config_yaml( - hass: HomeAssistant, hass_config_yaml_files: dict[str, str] -) -> Generator[None, None, None]: - """Fixture to mock the content of the yaml configuration files. - - Patches yaml configuration files using the `hass_config_yaml` - and `hass_config_yaml_files` fixtures. - """ - with patch_yaml_files(hass_config_yaml_files): - yield - - -@pytest.fixture -async def mqtt_mock_entry( - hass: HomeAssistant, - mqtt_client_mock: MqttMockPahoClient, - mqtt_config_entry_data: dict[str, Any] | None, -) -> AsyncGenerator[MqttMockHAClientGenerator, None]: - """Set up an MQTT config entry.""" - - async def _async_setup_config_entry( - hass: HomeAssistant, entry: ConfigEntry - ) -> bool: - """Help set up the config entry.""" - assert await hass.config_entries.async_setup(entry.entry_id) - await hass.async_block_till_done() - return True - - async def _setup_mqtt_entry() -> MqttMockHAClient: - """Set up the MQTT config entry.""" - return await mqtt_mock_entry(_async_setup_config_entry) - - async with _mqtt_mock_entry( - hass, mqtt_client_mock, mqtt_config_entry_data - ) as mqtt_mock_entry: - yield _setup_mqtt_entry - - -@pytest.fixture(autouse=True) -def mock_network() -> Generator[None, None, None]: - """Mock network.""" - mock_adapter = Adapter( - name="eth0", - index=0, - enabled=True, - auto=True, - default=True, - ipv4=[IPv4ConfiguredAddress(address="10.10.10.10", network_prefix=24)], - ipv6=[], - ) - with patch( - "homeassistant.components.network.network.async_load_adapters", - return_value=[mock_adapter], - ): - yield - - -@pytest.fixture(autouse=True) -def mock_get_source_ip() -> Generator[None, None, None]: - """Mock network util's async_get_source_ip.""" - with patch( - "homeassistant.components.network.util.async_get_source_ip", - return_value="10.10.10.10", - ): - yield - - -@pytest.fixture -def mock_zeroconf() -> Generator[None, None, None]: - """Mock zeroconf.""" - from zeroconf import DNSCache # pylint: disable=import-outside-toplevel - - with patch( - "homeassistant.components.zeroconf.HaZeroconf", autospec=True - ) as mock_zc, patch( - "homeassistant.components.zeroconf.HaAsyncServiceBrowser", autospec=True - ): - zc = mock_zc.return_value - # DNSCache has strong Cython type checks, and MagicMock does not work - # so we must mock the class directly - zc.cache = DNSCache() - yield mock_zc - - -@pytest.fixture -def mock_async_zeroconf(mock_zeroconf: None) -> Generator[None, None, None]: - """Mock AsyncZeroconf.""" - from zeroconf import DNSCache # pylint: disable=import-outside-toplevel - - with patch("homeassistant.components.zeroconf.HaAsyncZeroconf") as mock_aiozc: - zc = mock_aiozc.return_value - zc.async_unregister_service = AsyncMock() - zc.async_register_service = AsyncMock() - zc.async_update_service = AsyncMock() - zc.zeroconf.async_wait_for_start = AsyncMock() - # DNSCache has strong Cython type checks, and MagicMock does not work - # so we must mock the class directly - zc.zeroconf.cache = DNSCache() - zc.zeroconf.done = False - zc.async_close = AsyncMock() - zc.ha_async_close = AsyncMock() - yield zc - - -@pytest.fixture -def enable_custom_integrations(hass: HomeAssistant) -> None: - """Enable custom integrations defined in the test dir.""" - hass.data.pop(loader.DATA_CUSTOM_COMPONENTS) - - -@pytest.fixture -def enable_statistics() -> bool: - """Fixture to control enabling of recorder's statistics compilation. - - To enable statistics, tests can be marked with: - @pytest.mark.parametrize("enable_statistics", [True]) - """ - return False - - -@pytest.fixture -def enable_schema_validation() -> bool: - """Fixture to control enabling of recorder's statistics table validation. - - To enable statistics table validation, tests can be marked with: - @pytest.mark.parametrize("enable_schema_validation", [True]) - """ - return False - - -@pytest.fixture -def enable_nightly_purge() -> bool: - """Fixture to control enabling of recorder's nightly purge job. - - To enable nightly purging, tests can be marked with: - @pytest.mark.parametrize("enable_nightly_purge", [True]) - """ - return False - - -@pytest.fixture -def enable_migrate_context_ids() -> bool: - """Fixture to control enabling of recorder's context id migration. - - To enable context id migration, tests can be marked with: - @pytest.mark.parametrize("enable_migrate_context_ids", [True]) - """ - return False - - -@pytest.fixture -def enable_migrate_event_type_ids() -> bool: - """Fixture to control enabling of recorder's event type id migration. - - To enable context id migration, tests can be marked with: - @pytest.mark.parametrize("enable_migrate_event_type_ids", [True]) - """ - return False - - -@pytest.fixture -def enable_migrate_entity_ids() -> bool: - """Fixture to control enabling of recorder's entity_id migration. - - To enable context id migration, tests can be marked with: - @pytest.mark.parametrize("enable_migrate_entity_ids", [True]) - """ - return False - - -@pytest.fixture -def recorder_config() -> dict[str, Any] | None: - """Fixture to override recorder config. - - To override the config, tests can be marked with: - @pytest.mark.parametrize("recorder_config", [{...}]) - """ - return None - - -@pytest.fixture -def recorder_db_url( - pytestconfig: pytest.Config, - hass_fixture_setup: list[bool], -) -> Generator[str, None, None]: - """Prepare a default database for tests and return a connection URL.""" - assert not hass_fixture_setup - - db_url = cast(str, pytestconfig.getoption("dburl")) - if db_url.startswith("mysql://"): - # pylint: disable-next=import-outside-toplevel - import sqlalchemy_utils - - charset = "utf8mb4' COLLATE = 'utf8mb4_unicode_ci" - assert not sqlalchemy_utils.database_exists(db_url) - sqlalchemy_utils.create_database(db_url, encoding=charset) - elif db_url.startswith("postgresql://"): - # pylint: disable-next=import-outside-toplevel - import sqlalchemy_utils - - assert not sqlalchemy_utils.database_exists(db_url) - sqlalchemy_utils.create_database(db_url, encoding="utf8") - yield db_url - if db_url.startswith("mysql://"): - # pylint: disable-next=import-outside-toplevel - import sqlalchemy as sa - - made_url = sa.make_url(db_url) - db = made_url.database - engine = sa.create_engine(db_url) - # Check for any open connections to the database before dropping it - # to ensure that InnoDB does not deadlock. - with engine.begin() as connection: - query = sa.text( - "select id FROM information_schema.processlist WHERE db=:db and id != CONNECTION_ID()" - ) - rows = connection.execute(query, parameters={"db": db}).fetchall() - if rows: - raise RuntimeError( - f"Unable to drop database {db} because it is in use by {rows}" - ) - engine.dispose() - sqlalchemy_utils.drop_database(db_url) - elif db_url.startswith("postgresql://"): - sqlalchemy_utils.drop_database(db_url) - - -@pytest.fixture -def hass_recorder( - recorder_db_url: str, - enable_nightly_purge: bool, - enable_statistics: bool, - enable_schema_validation: bool, - enable_migrate_context_ids: bool, - enable_migrate_event_type_ids: bool, - enable_migrate_entity_ids: bool, - hass_storage, -) -> Generator[Callable[..., HomeAssistant], None, None]: - """Home Assistant fixture with in-memory recorder.""" - # pylint: disable-next=import-outside-toplevel - from homeassistant.components import recorder - - # pylint: disable-next=import-outside-toplevel - from homeassistant.components.recorder import migration - - original_tz = dt_util.DEFAULT_TIME_ZONE - - hass = get_test_home_assistant() - nightly = recorder.Recorder.async_nightly_tasks if enable_nightly_purge else None - stats = recorder.Recorder.async_periodic_statistics if enable_statistics else None - schema_validate = ( - migration._find_schema_errors - if enable_schema_validation - else itertools.repeat(set()) - ) - migrate_states_context_ids = ( - recorder.Recorder._migrate_states_context_ids - if enable_migrate_context_ids - else None - ) - migrate_events_context_ids = ( - recorder.Recorder._migrate_events_context_ids - if enable_migrate_context_ids - else None - ) - migrate_event_type_ids = ( - recorder.Recorder._migrate_event_type_ids - if enable_migrate_event_type_ids - else None - ) - migrate_entity_ids = ( - recorder.Recorder._migrate_entity_ids if enable_migrate_entity_ids else None - ) - with patch( - "homeassistant.components.recorder.Recorder.async_nightly_tasks", - side_effect=nightly, - autospec=True, - ), patch( - "homeassistant.components.recorder.Recorder.async_periodic_statistics", - side_effect=stats, - autospec=True, - ), patch( - "homeassistant.components.recorder.migration._find_schema_errors", - side_effect=schema_validate, - autospec=True, - ), patch( - "homeassistant.components.recorder.Recorder._migrate_events_context_ids", - side_effect=migrate_events_context_ids, - autospec=True, - ), patch( - "homeassistant.components.recorder.Recorder._migrate_states_context_ids", - side_effect=migrate_states_context_ids, - autospec=True, - ), patch( - "homeassistant.components.recorder.Recorder._migrate_event_type_ids", - side_effect=migrate_event_type_ids, - autospec=True, - ), patch( - "homeassistant.components.recorder.Recorder._migrate_entity_ids", - side_effect=migrate_entity_ids, - autospec=True, - ): - - def setup_recorder(config: dict[str, Any] | None = None) -> HomeAssistant: - """Set up with params.""" - init_recorder_component(hass, config, recorder_db_url) - hass.start() - hass.block_till_done() - hass.data[recorder.DATA_INSTANCE].block_till_done() - return hass - - yield setup_recorder - hass.stop() - - # Restore timezone, it is set when creating the hass object - dt_util.DEFAULT_TIME_ZONE = original_tz - - -async def _async_init_recorder_component( - hass: HomeAssistant, - add_config: dict[str, Any] | None = None, - db_url: str | None = None, -) -> None: - """Initialize the recorder asynchronously.""" - # pylint: disable-next=import-outside-toplevel - from homeassistant.components import recorder - - config = dict(add_config) if add_config else {} - if recorder.CONF_DB_URL not in config: - config[recorder.CONF_DB_URL] = db_url - if recorder.CONF_COMMIT_INTERVAL not in config: - config[recorder.CONF_COMMIT_INTERVAL] = 0 - - with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True): - if recorder.DOMAIN not in hass.data: - recorder_helper.async_initialize_recorder(hass) - assert await async_setup_component( - hass, recorder.DOMAIN, {recorder.DOMAIN: config} - ) - assert recorder.DOMAIN in hass.config.components - _LOGGER.info( - "Test recorder successfully started, database location: %s", - config[recorder.CONF_DB_URL], - ) - - -@pytest.fixture -async def async_setup_recorder_instance( - recorder_db_url: str, - enable_nightly_purge: bool, - enable_statistics: bool, - enable_schema_validation: bool, - enable_migrate_context_ids: bool, - enable_migrate_event_type_ids: bool, - enable_migrate_entity_ids: bool, -) -> AsyncGenerator[RecorderInstanceGenerator, None]: - """Yield callable to setup recorder instance.""" - # pylint: disable-next=import-outside-toplevel - from homeassistant.components import recorder - - # pylint: disable-next=import-outside-toplevel - from homeassistant.components.recorder import migration - - # pylint: disable-next=import-outside-toplevel - from .components.recorder.common import async_recorder_block_till_done - - nightly = recorder.Recorder.async_nightly_tasks if enable_nightly_purge else None - stats = recorder.Recorder.async_periodic_statistics if enable_statistics else None - schema_validate = ( - migration._find_schema_errors - if enable_schema_validation - else itertools.repeat(set()) - ) - migrate_states_context_ids = ( - recorder.Recorder._migrate_states_context_ids - if enable_migrate_context_ids - else None - ) - migrate_events_context_ids = ( - recorder.Recorder._migrate_events_context_ids - if enable_migrate_context_ids - else None - ) - migrate_event_type_ids = ( - recorder.Recorder._migrate_event_type_ids - if enable_migrate_event_type_ids - else None - ) - migrate_entity_ids = ( - recorder.Recorder._migrate_entity_ids if enable_migrate_entity_ids else None - ) - with patch( - "homeassistant.components.recorder.Recorder.async_nightly_tasks", - side_effect=nightly, - autospec=True, - ), patch( - "homeassistant.components.recorder.Recorder.async_periodic_statistics", - side_effect=stats, - autospec=True, - ), patch( - "homeassistant.components.recorder.migration._find_schema_errors", - side_effect=schema_validate, - autospec=True, - ), patch( - "homeassistant.components.recorder.Recorder._migrate_events_context_ids", - side_effect=migrate_events_context_ids, - autospec=True, - ), patch( - "homeassistant.components.recorder.Recorder._migrate_states_context_ids", - side_effect=migrate_states_context_ids, - autospec=True, - ), patch( - "homeassistant.components.recorder.Recorder._migrate_event_type_ids", - side_effect=migrate_event_type_ids, - autospec=True, - ), patch( - "homeassistant.components.recorder.Recorder._migrate_entity_ids", - side_effect=migrate_entity_ids, - autospec=True, - ): - - async def async_setup_recorder( - hass: HomeAssistant, config: ConfigType | None = None - ) -> recorder.Recorder: - """Setup and return recorder instance.""" # noqa: D401 - await _async_init_recorder_component(hass, config, recorder_db_url) - await hass.async_block_till_done() - instance = hass.data[recorder.DATA_INSTANCE] - # The recorder's worker is not started until Home Assistant is running - if hass.state == CoreState.running: - await async_recorder_block_till_done(hass) - return instance - - yield async_setup_recorder - - -@pytest.fixture -async def recorder_mock( - recorder_config: dict[str, Any] | None, - async_setup_recorder_instance: RecorderInstanceGenerator, - hass: HomeAssistant, -) -> recorder.Recorder: - """Fixture with in-memory recorder.""" - return await async_setup_recorder_instance(hass, recorder_config) - - -@pytest.fixture -def mock_integration_frame() -> Generator[Mock, None, None]: - """Mock as if we're calling code from inside an integration.""" - correct_frame = Mock( - filename="/home/paulus/homeassistant/components/hue/light.py", - lineno="23", - line="self.light.is_on", - ) - with patch( - "homeassistant.helpers.frame.extract_stack", - return_value=[ - Mock( - filename="/home/paulus/homeassistant/core.py", - lineno="23", - line="do_something()", - ), - correct_frame, - Mock( - filename="/home/paulus/aiohue/lights.py", - lineno="2", - line="something()", - ), - ], - ): - yield correct_frame - - -@pytest.fixture(name="enable_bluetooth") -async def mock_enable_bluetooth( - hass: HomeAssistant, - mock_bleak_scanner_start: MagicMock, - mock_bluetooth_adapters: None, -) -> AsyncGenerator[None, None]: - """Fixture to mock starting the bleak scanner.""" - entry = MockConfigEntry(domain="bluetooth", unique_id="00:00:00:00:00:01") - entry.add_to_hass(hass) - await hass.config_entries.async_setup(entry.entry_id) - await hass.async_block_till_done() - yield - await hass.config_entries.async_unload(entry.entry_id) - await hass.async_block_till_done() - - -@pytest.fixture -def mock_bluetooth_adapters() -> Generator[None, None, None]: - """Fixture to mock bluetooth adapters.""" - with patch( - "bluetooth_adapters.systems.platform.system", return_value="Linux" - ), patch("bluetooth_adapters.systems.linux.LinuxAdapters.refresh"), patch( - "bluetooth_adapters.systems.linux.LinuxAdapters.adapters", - { - "hci0": { - "address": "00:00:00:00:00:01", - "hw_version": "usb:v1D6Bp0246d053F", - "passive_scan": False, - "sw_version": "homeassistant", - "manufacturer": "ACME", - "product": "Bluetooth Adapter 5.0", - "product_id": "aa01", - "vendor_id": "cc01", - }, - }, - ): - yield - - -@pytest.fixture -def mock_bleak_scanner_start() -> Generator[MagicMock, None, None]: - """Fixture to mock starting the bleak scanner.""" - - # Late imports to avoid loading bleak unless we need it - - # pylint: disable-next=import-outside-toplevel - from homeassistant.components.bluetooth import scanner as bluetooth_scanner - - # We need to drop the stop method from the object since we patched - # out start and this fixture will expire before the stop method is called - # when EVENT_HOMEASSISTANT_STOP is fired. - bluetooth_scanner.OriginalBleakScanner.stop = AsyncMock() # type: ignore[assignment] - with patch( - "homeassistant.components.bluetooth.scanner.OriginalBleakScanner.start", - ) as mock_bleak_scanner_start: - yield mock_bleak_scanner_start - - -@pytest.fixture -def mock_bluetooth( - mock_bleak_scanner_start: MagicMock, mock_bluetooth_adapters: None -) -> None: - """Mock out bluetooth from starting.""" - - -@pytest.fixture -def area_registry(hass: HomeAssistant) -> ar.AreaRegistry: - """Return the area registry from the current hass instance.""" - return ar.async_get(hass) - - -@pytest.fixture -def device_registry(hass: HomeAssistant) -> dr.DeviceRegistry: - """Return the device registry from the current hass instance.""" - return dr.async_get(hass) - - -@pytest.fixture -def entity_registry(hass: HomeAssistant) -> er.EntityRegistry: - """Return the entity registry from the current hass instance.""" - return er.async_get(hass) - - -@pytest.fixture -def issue_registry(hass: HomeAssistant) -> ir.IssueRegistry: - """Return the issue registry from the current hass instance.""" - return ir.async_get(hass) - - -@pytest.fixture -def snapshot(snapshot: SnapshotAssertion) -> SnapshotAssertion: - """Return snapshot assertion fixture with the Home Assistant extension.""" - return snapshot.use_extension(HomeAssistantSnapshotExtension) diff --git a/tests/hass/ignore_uncaught_exceptions.py b/tests/hass/ignore_uncaught_exceptions.py deleted file mode 100644 index e9327c0..0000000 --- a/tests/hass/ignore_uncaught_exceptions.py +++ /dev/null @@ -1,27 +0,0 @@ -"""List of tests that have uncaught exceptions today. Will be shrunk over time.""" -IGNORE_UNCAUGHT_EXCEPTIONS = [ - ( - # This test explicitly throws an uncaught exception - # and should not be removed. - "tests.test_runner", - "test_unhandled_exception_traceback", - ), - ( - "test_homeassistant_bridge", - "test_homeassistant_bridge_fan_setup", - ), - ( - "tests.components.owntracks.test_device_tracker", - "test_mobile_multiple_async_enter_exit", - ), - ( - "tests.components.smartthings.test_init", - "test_event_handler_dispatches_updated_devices", - ), - ( - "tests.components.unifi.test_controller", - "test_wireless_client_event_calls_update_wireless_devices", - ), - ("tests.components.iaqualink.test_config_flow", "test_with_invalid_credentials"), - ("tests.components.iaqualink.test_config_flow", "test_with_existing_config"), -] diff --git a/tests/hass/syrupy.py b/tests/hass/syrupy.py deleted file mode 100644 index 9433eb1..0000000 --- a/tests/hass/syrupy.py +++ /dev/null @@ -1,237 +0,0 @@ -"""Home Assistant extension for Syrupy.""" -from __future__ import annotations - -from contextlib import suppress -import dataclasses -from enum import IntFlag -from pathlib import Path -from typing import Any - -import attr -import attrs -from syrupy.extensions.amber import AmberDataSerializer, AmberSnapshotExtension -from syrupy.location import PyTestLocation -from syrupy.types import ( - PropertyFilter, - PropertyMatcher, - PropertyPath, - SerializableData, - SerializedData, -) -import voluptuous as vol -import voluptuous_serialize - -from homeassistant.config_entries import ConfigEntry -from homeassistant.core import State -from homeassistant.data_entry_flow import FlowResult -from homeassistant.helpers import ( - area_registry as ar, - device_registry as dr, - entity_registry as er, - issue_registry as ir, -) - - -class _ANY: - """Represent any value.""" - - def __repr__(self) -> str: - return "" - - -ANY = _ANY() - -__all__ = ["HomeAssistantSnapshotExtension"] - - -class AreaRegistryEntrySnapshot(dict): - """Tiny wrapper to represent an area registry entry in snapshots.""" - - -class ConfigEntrySnapshot(dict): - """Tiny wrapper to represent a config entry in snapshots.""" - - -class DeviceRegistryEntrySnapshot(dict): - """Tiny wrapper to represent a device registry entry in snapshots.""" - - -class EntityRegistryEntrySnapshot(dict): - """Tiny wrapper to represent an entity registry entry in snapshots.""" - - -class FlowResultSnapshot(dict): - """Tiny wrapper to represent a flow result in snapshots.""" - - -class IssueRegistryItemSnapshot(dict): - """Tiny wrapper to represent an entity registry entry in snapshots.""" - - -class StateSnapshot(dict): - """Tiny wrapper to represent an entity state in snapshots.""" - - -class HomeAssistantSnapshotSerializer(AmberDataSerializer): - """Home Assistant snapshot serializer for Syrupy. - - Handles special cases for Home Assistant data structures. - """ - - @classmethod - def _serialize( - cls, - data: SerializableData, - *, - depth: int = 0, - exclude: PropertyFilter | None = None, - matcher: PropertyMatcher | None = None, - path: PropertyPath = (), - visited: set[Any] | None = None, - ) -> SerializedData: - """Pre-process data before serializing. - - This allows us to handle specific cases for Home Assistant data structures. - """ - if isinstance(data, State): - serializable_data = cls._serializable_state(data) - elif isinstance(data, ar.AreaEntry): - serializable_data = cls._serializable_area_registry_entry(data) - elif isinstance(data, dr.DeviceEntry): - serializable_data = cls._serializable_device_registry_entry(data) - elif isinstance(data, er.RegistryEntry): - serializable_data = cls._serializable_entity_registry_entry(data) - elif isinstance(data, ir.IssueEntry): - serializable_data = cls._serializable_issue_registry_entry(data) - elif isinstance(data, dict) and "flow_id" in data and "handler" in data: - serializable_data = cls._serializable_flow_result(data) - elif isinstance(data, vol.Schema): - serializable_data = voluptuous_serialize.convert(data) - elif isinstance(data, ConfigEntry): - serializable_data = cls._serializable_config_entry(data) - elif dataclasses.is_dataclass(data): - serializable_data = dataclasses.asdict(data) - elif isinstance(data, IntFlag): - # The repr of an enum.IntFlag has changed between Python 3.10 and 3.11 - # so we normalize it here. - serializable_data = _IntFlagWrapper(data) - else: - serializable_data = data - with suppress(TypeError): - if attr.has(data): - serializable_data = attrs.asdict(data) - - return super()._serialize( - serializable_data, - depth=depth, - exclude=exclude, - matcher=matcher, - path=path, - visited=visited, - ) - - @classmethod - def _serializable_area_registry_entry(cls, data: ar.AreaEntry) -> SerializableData: - """Prepare a Home Assistant area registry entry for serialization.""" - serialized = AreaRegistryEntrySnapshot(attrs.asdict(data) | {"id": ANY}) - serialized.pop("_json_repr") - return serialized - - @classmethod - def _serializable_config_entry(cls, data: ConfigEntry) -> SerializableData: - """Prepare a Home Assistant config entry for serialization.""" - return ConfigEntrySnapshot(data.as_dict() | {"entry_id": ANY}) - - @classmethod - def _serializable_device_registry_entry( - cls, data: dr.DeviceEntry - ) -> SerializableData: - """Prepare a Home Assistant device registry entry for serialization.""" - serialized = DeviceRegistryEntrySnapshot( - attrs.asdict(data) - | { - "config_entries": ANY, - "id": ANY, - } - ) - if serialized["via_device_id"] is not None: - serialized["via_device_id"] = ANY - serialized.pop("_json_repr") - return serialized - - @classmethod - def _serializable_entity_registry_entry( - cls, data: er.RegistryEntry - ) -> SerializableData: - """Prepare a Home Assistant entity registry entry for serialization.""" - serialized = EntityRegistryEntrySnapshot( - attrs.asdict(data) - | { - "config_entry_id": ANY, - "device_id": ANY, - "id": ANY, - "options": {k: dict(v) for k, v in data.options.items()}, - } - ) - serialized.pop("_partial_repr") - serialized.pop("_display_repr") - return serialized - - @classmethod - def _serializable_flow_result(cls, data: FlowResult) -> SerializableData: - """Prepare a Home Assistant flow result for serialization.""" - return FlowResultSnapshot(data | {"flow_id": ANY}) - - @classmethod - def _serializable_issue_registry_entry( - cls, data: ir.IssueEntry - ) -> SerializableData: - """Prepare a Home Assistant issue registry entry for serialization.""" - return IssueRegistryItemSnapshot(data.to_json() | {"created": ANY}) - - @classmethod - def _serializable_state(cls, data: State) -> SerializableData: - """Prepare a Home Assistant State for serialization.""" - return StateSnapshot( - data.as_dict() - | { - "context": ANY, - "last_changed": ANY, - "last_updated": ANY, - } - ) - - -class _IntFlagWrapper: - def __init__(self, flag: IntFlag) -> None: - self._flag = flag - - def __repr__(self) -> str: - # 3.10: - # 3.11: - # Syrupy: - return f"<{self._flag.__class__.__name__}: {self._flag.value}>" - - -class HomeAssistantSnapshotExtension(AmberSnapshotExtension): - """Home Assistant extension for Syrupy.""" - - VERSION = "1" - """Current version of serialization format. - - Need to be bumped when we change the HomeAssistantSnapshotSerializer. - """ - - serializer_class: type[AmberDataSerializer] = HomeAssistantSnapshotSerializer - - @classmethod - def dirname(cls, *, test_location: PyTestLocation) -> str: - """Return the directory for the snapshot files. - - Syrupy, by default, uses the `__snapshosts__` directory in the same - folder as the test file. For Home Assistant, this is changed to just - `snapshots` in the same folder as the test file, to match our `fixtures` - folder structure. - """ - test_dir = Path(test_location.filepath).parent - return str(test_dir.joinpath("snapshots")) diff --git a/tests/hass/test_util/__init__.py b/tests/hass/test_util/__init__.py deleted file mode 100644 index b849967..0000000 --- a/tests/hass/test_util/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for the test utilities.""" diff --git a/tests/hass/test_util/aiohttp.py b/tests/hass/test_util/aiohttp.py deleted file mode 100644 index 356240d..0000000 --- a/tests/hass/test_util/aiohttp.py +++ /dev/null @@ -1,332 +0,0 @@ -"""Aiohttp test utils.""" -import asyncio -from contextlib import contextmanager -from http import HTTPStatus -import re -from unittest import mock -from urllib.parse import parse_qs - -from aiohttp import ClientSession -from aiohttp.client_exceptions import ClientError, ClientResponseError -from aiohttp.streams import StreamReader -from multidict import CIMultiDict -from yarl import URL - -from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE -from homeassistant.helpers.json import json_dumps -from homeassistant.util.json import json_loads - -RETYPE = type(re.compile("")) - - -def mock_stream(data): - """Mock a stream with data.""" - protocol = mock.Mock(_reading_paused=False) - stream = StreamReader(protocol, limit=2**16) - stream.feed_data(data) - stream.feed_eof() - return stream - - -class AiohttpClientMocker: - """Mock Aiohttp client requests.""" - - def __init__(self): - """Initialize the request mocker.""" - self._mocks = [] - self._cookies = {} - self.mock_calls = [] - - def request( - self, - method, - url, - *, - auth=None, - status=HTTPStatus.OK, - text=None, - data=None, - content=None, - json=None, - params=None, - headers={}, - exc=None, - cookies=None, - side_effect=None, - ): - """Mock a request.""" - if not isinstance(url, RETYPE): - url = URL(url) - if params: - url = url.with_query(params) - - self._mocks.append( - AiohttpClientMockResponse( - method=method, - url=url, - status=status, - response=content, - json=json, - text=text, - cookies=cookies, - exc=exc, - headers=headers, - side_effect=side_effect, - ) - ) - - def get(self, *args, **kwargs): - """Register a mock get request.""" - self.request("get", *args, **kwargs) - - def put(self, *args, **kwargs): - """Register a mock put request.""" - self.request("put", *args, **kwargs) - - def post(self, *args, **kwargs): - """Register a mock post request.""" - self.request("post", *args, **kwargs) - - def delete(self, *args, **kwargs): - """Register a mock delete request.""" - self.request("delete", *args, **kwargs) - - def options(self, *args, **kwargs): - """Register a mock options request.""" - self.request("options", *args, **kwargs) - - def patch(self, *args, **kwargs): - """Register a mock patch request.""" - self.request("patch", *args, **kwargs) - - @property - def call_count(self): - """Return the number of requests made.""" - return len(self.mock_calls) - - def clear_requests(self): - """Reset mock calls.""" - self._mocks.clear() - self._cookies.clear() - self.mock_calls.clear() - - def create_session(self, loop): - """Create a ClientSession that is bound to this mocker.""" - session = ClientSession(loop=loop, json_serialize=json_dumps) - # Setting directly on `session` will raise deprecation warning - object.__setattr__(session, "_request", self.match_request) - return session - - async def match_request( - self, - method, - url, - *, - data=None, - auth=None, - params=None, - headers=None, - allow_redirects=None, - timeout=None, - json=None, - cookies=None, - **kwargs, - ): - """Match a request against pre-registered requests.""" - data = data or json - url = URL(url) - if params: - url = url.with_query(params) - - for response in self._mocks: - if response.match_request(method, url, params): - self.mock_calls.append((method, url, data, headers)) - if response.side_effect: - response = await response.side_effect(method, url, data) - if response.exc: - raise response.exc - return response - - raise AssertionError(f"No mock registered for {method.upper()} {url} {params}") - - -class AiohttpClientMockResponse: - """Mock Aiohttp client response.""" - - def __init__( - self, - method, - url, - status=HTTPStatus.OK, - response=None, - json=None, - text=None, - cookies=None, - exc=None, - headers=None, - side_effect=None, - ): - """Initialize a fake response.""" - if json is not None: - text = json_dumps(json) - if text is not None: - response = text.encode("utf-8") - if response is None: - response = b"" - - self.charset = "utf-8" - self.method = method - self._url = url - self.status = status - self.response = response - self.exc = exc - self.side_effect = side_effect - self._headers = CIMultiDict(headers or {}) - self._cookies = {} - - if cookies: - for name, data in cookies.items(): - cookie = mock.MagicMock() - cookie.value = data - self._cookies[name] = cookie - - def match_request(self, method, url, params=None): - """Test if response answers request.""" - if method.lower() != self.method.lower(): - return False - - # regular expression matching - if isinstance(self._url, RETYPE): - return self._url.search(str(url)) is not None - - if ( - self._url.scheme != url.scheme - or self._url.host != url.host - or self._url.path != url.path - ): - return False - - # Ensure all query components in matcher are present in the request - request_qs = parse_qs(url.query_string) - matcher_qs = parse_qs(self._url.query_string) - for key, vals in matcher_qs.items(): - for val in vals: - try: - request_qs.get(key, []).remove(val) - except ValueError: - return False - - return True - - @property - def headers(self): - """Return content_type.""" - return self._headers - - @property - def cookies(self): - """Return dict of cookies.""" - return self._cookies - - @property - def url(self): - """Return yarl of URL.""" - return self._url - - @property - def content_type(self): - """Return yarl of URL.""" - return self._headers.get("content-type") - - @property - def content(self): - """Return content.""" - return mock_stream(self.response) - - async def read(self): - """Return mock response.""" - return self.response - - async def text(self, encoding="utf-8", errors="strict"): - """Return mock response as a string.""" - return self.response.decode(encoding, errors=errors) - - async def json(self, encoding="utf-8", content_type=None, loads=json_loads): - """Return mock response as a json.""" - return loads(self.response.decode(encoding)) - - def release(self): - """Mock release.""" - - def raise_for_status(self): - """Raise error if status is 400 or higher.""" - if self.status >= 400: - request_info = mock.Mock(real_url="http://example.com") - raise ClientResponseError( - request_info=request_info, - history=None, - status=self.status, - headers=self.headers, - ) - - def close(self): - """Mock close.""" - - -@contextmanager -def mock_aiohttp_client(): - """Context manager to mock aiohttp client.""" - mocker = AiohttpClientMocker() - - def create_session(hass, *args, **kwargs): - session = mocker.create_session(hass.loop) - - async def close_session(event): - """Close session.""" - await session.close() - - hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, close_session) - - return session - - with mock.patch( - "homeassistant.helpers.aiohttp_client._async_create_clientsession", - side_effect=create_session, - ): - yield mocker - - -class MockLongPollSideEffect: - """Imitate a long_poll request. - - It should be created and used as a side effect for a GET/PUT/etc. request. - Once created, actual responses are queued with queue_response - If queue is empty, will await until done. - """ - - def __init__(self): - """Initialize the queue.""" - self.semaphore = asyncio.Semaphore(0) - self.response_list = [] - self.stopping = False - - async def __call__(self, method, url, data): - """Fetch the next response from the queue or wait until the queue has items.""" - if self.stopping: - raise ClientError() - await self.semaphore.acquire() - kwargs = self.response_list.pop(0) - return AiohttpClientMockResponse(method=method, url=url, **kwargs) - - def queue_response(self, **kwargs): - """Add a response to the long_poll queue.""" - self.response_list.append(kwargs) - self.semaphore.release() - - def stop(self): - """Stop the current request and future ones. - - This avoids an exception if there is someone waiting when exiting test. - """ - self.stopping = True - self.queue_response(exc=ClientError()) diff --git a/tests/hass/typing.py b/tests/hass/typing.py deleted file mode 100644 index 7c5391d..0000000 --- a/tests/hass/typing.py +++ /dev/null @@ -1,33 +0,0 @@ -"""Typing helpers for Home Assistant tests.""" -from __future__ import annotations - -from collections.abc import Callable, Coroutine -from typing import TYPE_CHECKING, Any, TypeAlias -from unittest.mock import MagicMock - -from aiohttp import ClientWebSocketResponse -from aiohttp.test_utils import TestClient - -if TYPE_CHECKING: - # Local import to avoid processing recorder module when running a - # testcase which does not use the recorder. - from homeassistant.components.recorder import Recorder - - -class MockHAClientWebSocket(ClientWebSocketResponse): - """Protocol for a wrapped ClientWebSocketResponse.""" - - client: TestClient - send_json_auto_id: Callable[[dict[str, Any]], Coroutine[Any, Any, None]] - - -ClientSessionGenerator = Callable[..., Coroutine[Any, Any, TestClient]] -MqttMockPahoClient = MagicMock -"""MagicMock for `paho.mqtt.client.Client`""" -MqttMockHAClient = MagicMock -"""MagicMock for `homeassistant.components.mqtt.MQTT`.""" -MqttMockHAClientGenerator = Callable[..., Coroutine[Any, Any, MqttMockHAClient]] -"""MagicMock generator for `homeassistant.components.mqtt.MQTT`.""" -RecorderInstanceGenerator: TypeAlias = Callable[..., Coroutine[Any, Any, "Recorder"]] -"""Instance generator for `homeassistant.components.recorder.Recorder`.""" -WebSocketGenerator = Callable[..., Coroutine[Any, Any, MockHAClientWebSocket]]