Skip to content
This repository has been archived by the owner on Dec 4, 2024. It is now read-only.

Commit

Permalink
Fix #26: Implement Entity as HomeAssistant requires
Browse files Browse the repository at this point in the history
  • Loading branch information
m50 committed Sep 16, 2024
1 parent c066e5c commit 27aaace
Show file tree
Hide file tree
Showing 4 changed files with 230 additions and 154 deletions.
169 changes: 24 additions & 145 deletions custom_components/fallback_conversation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,167 +3,46 @@

import logging

from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.util import ulid
from home_assistant_intents import get_languages


from homeassistant.helpers import (
config_validation as cv,
intent,
)

from .const import (
CONF_DEBUG_LEVEL,
CONF_PRIMARY_AGENT,
CONF_FALLBACK_AGENT,
DEBUG_LEVEL_NO_DEBUG,
DEBUG_LEVEL_LOW_DEBUG,
DEBUG_LEVEL_VERBOSE_DEBUG,
DOMAIN,
STRANGE_ERROR_RESPONSES,
)
from homeassistant.const import Platform
from homeassistant.helpers import config_validation as cv

from .const import DOMAIN

_LOGGER = logging.getLogger(__name__)

CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)

PLATFORMS = (Platform.CONVERSATION,)

# hass.data key for agent.
DATA_AGENT = "agent"


async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up Fallback Conversation Agent from a config entry."""
agent = FallbackConversationAgent(hass, entry)

conversation.async_set_agent(hass, entry, agent)
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = entry

await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)

return True

async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload."""
if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
return False
hass.data[DOMAIN].pop(entry.entry_id)
return True

async def async_migrate_entry(hass, config_entry: ConfigEntry):
"""Migrate old entry."""
_LOGGER.debug("Migrating from version %s", config_entry.version)

if config_entry.version == 1:
_LOGGER.error("Cannot upgrade models that were created prior to v0.3. Please delete and re-create them.")
return False

class FallbackConversationAgent(conversation.AbstractConversationAgent):
"""Fallback Conversation Agent."""

last_used_agent: str | None

def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Initialize the agent."""
self.hass = hass
self.entry = entry
self.last_used_agent = None

@property
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""
return get_languages()

async def async_process(
self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
"""Process a sentence."""
agent_manager = conversation.get_agent_manager(self.hass)
default_agent = conversation.default_agent.async_get_default_agent(self.hass)
agent_names = self._convert_agent_info_to_dict(
agent_manager.async_get_agent_info()
)
agent_names[conversation.const.HOME_ASSISTANT_AGENT] = default_agent.name
agent_names[conversation.const.OLD_HOME_ASSISTANT_AGENT] = default_agent.name
agents = [
self.entry.options.get(CONF_PRIMARY_AGENT, default_agent),
self.entry.options.get(CONF_FALLBACK_AGENT, default_agent),
]

debug_level = self.entry.options.get(CONF_DEBUG_LEVEL, DEBUG_LEVEL_NO_DEBUG)

if user_input.conversation_id is None:
user_input.conversation_id = ulid.ulid()

all_results = []
result = None
for agent_id in agents:
agent_name = "[unknown]"
if agent_id in agent_names:
agent_name = agent_names[agent_id]
else:
_LOGGER.warning("agent_name not found for agent_id %s", agent_id)

result = await self._async_process_agent(
agent_manager,
agent_id,
agent_name,
user_input,
debug_level,
result,
)
if result.response.response_type != intent.IntentResponseType.ERROR and result.response.speech['plain']['original_speech'].lower() not in STRANGE_ERROR_RESPONSES:
return result
all_results.append(result)

intent_response = intent.IntentResponse(language=user_input.language)
err = "Complete fallback failure. No Conversation Agent was able to respond."
if debug_level == DEBUG_LEVEL_LOW_DEBUG:
r = all_results[-1].response.speech['plain']
err += f"\n{r.get('agent_name', 'UNKNOWN')} responded with: {r.get('original_speech', r['speech'])}"
elif debug_level == DEBUG_LEVEL_VERBOSE_DEBUG:
for res in all_results:
r = res.response.speech['plain']
err += f"\n{r.get('agent_name', 'UNKNOWN')} responded with: {r.get('original_speech', r['speech'])}"
intent_response.async_set_error(
intent.IntentResponseErrorCode.NO_INTENT_MATCH,
err,
)
result = conversation.ConversationResult(
conversation_id=result.conversation_id,
response=intent_response
)

return result

async def _async_process_agent(
self,
agent_manager: conversation.AgentManager,
agent_id: str,
agent_name: str,
user_input: conversation.ConversationInput,
debug_level: int,
previous_result,
) -> conversation.ConversationResult:
"""Process a specified agent."""
agent = conversation.agent_manager.async_get_agent(self.hass, agent_id)

_LOGGER.debug("Processing in %s using %s with debug level %s: %s", user_input.language, agent_id, debug_level, user_input.text)

result = await agent.async_process(user_input)
r = result.response.speech['plain']['speech']
result.response.speech['plain']['original_speech'] = r
result.response.speech['plain']['agent_name'] = agent_name
result.response.speech['plain']['agent_id'] = agent_id
if debug_level == DEBUG_LEVEL_LOW_DEBUG:
result.response.speech['plain']['speech'] = f"{agent_name} responded with: {r}"
elif debug_level == DEBUG_LEVEL_VERBOSE_DEBUG:
if previous_result is not None:
pr = previous_result.response.speech['plain'].get('original_speech', previous_result.response.speech['plain']['speech'])
result.response.speech['plain']['speech'] = f"{previous_result.response.speech['plain'].get('agent_name', 'UNKNOWN')} failed with response: {pr} Then {agent_name} responded with {r}"
else:
result.response.speech['plain']['speech'] = f"{agent_name} responded with: {r}"

return result

def _convert_agent_info_to_dict(self, agents_info: list[conversation.AgentInfo]) -> dict[str, str]:
"""Takes a list of AgentInfo and makes it a dict of ID -> Name."""

agent_manager = conversation.get_agent_manager(self.hass)

r = {}
for agent_info in agents_info:
agent = agent_manager.async_get_agent(agent_info.id)
agent_id = agent_info.id
if hasattr(agent, "registry_entry"):
agent_id = agent.registry_entry.entity_id
r[agent_id] = agent_info.name
_LOGGER.debug("agent_id %s has name %s", agent_id, agent_info.name)
return r
_LOGGER.debug("Migration to version %s successful", config_entry.version)

return True
16 changes: 8 additions & 8 deletions custom_components/fallback_conversation/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

from homeassistant import config_entries
from homeassistant.const import CONF_NAME
from homeassistant.core import HomeAssistant, async_get_hass, callback
from homeassistant.core import callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.selector import (
ConversationAgentSelector,
ConversationAgentSelector,
ConversationAgentSelectorConfig,
SelectSelector,
SelectSelectorConfig,
Expand Down Expand Up @@ -53,19 +53,19 @@
class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
"""Fallback Agent config flow."""

VERSION = 1
VERSION = 2

async def async_step_user(self, user_input: dict[str, Any] | None = None) -> FlowResult:
"""Handle the initial step."""
_LOGGER.debug("ConfigFlow::user_input %s", user_input)
if user_input is None:
return self.async_show_form(
step_id="user",
data_schema=STEP_USER_DATA_SCHEMA,
step_id="user",
data_schema=STEP_USER_DATA_SCHEMA,
)

return self.async_create_entry(
title=user_input.get(CONF_NAME, DEFAULT_NAME),
title=user_input.get(CONF_NAME, DEFAULT_NAME),
data=user_input,
)

Expand All @@ -91,7 +91,7 @@ async def async_step_init(
if user_input is not None:
self._options.update(user_input)
return self.async_create_entry(
title=user_input.get(CONF_NAME, DEFAULT_NAME),
title=user_input.get(CONF_NAME, DEFAULT_NAME),
data=self._options,
)

Expand All @@ -106,7 +106,7 @@ async def fallback_config_option_schema(self, options: dict) -> dict:
"""Return a schema for Fallback options."""
return {
vol.Required(
CONF_DEBUG_LEVEL,
CONF_DEBUG_LEVEL,
description={"suggested_value": options.get(CONF_DEBUG_LEVEL, DEFAULT_DEBUG_LEVEL)},
default=DEFAULT_DEBUG_LEVEL,
): SelectSelector(
Expand Down
Loading

0 comments on commit 27aaace

Please sign in to comment.