-
Notifications
You must be signed in to change notification settings - Fork 7
/
__init__.py
298 lines (236 loc) · 11.5 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
"""
Python library for interacting with a Rhasspy server over HTTP.
For more information on Rhasspy, please see:
https://rhasspy.readthedocs.io/
"""
import configparser
import io
import logging
import re
from collections import defaultdict
from typing import Any, Dict, List, Set, Tuple
from urllib.parse import urljoin
import aiohttp
from rhasspyclient.speech import Transcription, TranscriptionResult
from rhasspyclient.train import TrainingComplete, TrainingResult
# Module-level logger named after this module; RhasspyClient uses it to
# report failed requests (see train / speech_to_text).
_LOGGER = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
class RhasspyClient:
    """Client object for remote Rhasspy server.

    Each public method issues one HTTP request through the
    ``aiohttp.ClientSession`` given to the constructor. The session is
    never closed here; its lifetime belongs to the caller.
    """

    def __init__(self, api_url: str, session: aiohttp.ClientSession):
        """
        Remember the session and derive every end-point URL from api_url.

        api_url is the base URL of the Rhasspy HTTP API. A trailing slash is
        appended when missing, because urljoin would otherwise replace the
        last path segment instead of appending to it.
        """
        self.api_url = api_url
        if not self.api_url.endswith("/"):
            self.api_url += "/"

        # Construct URLs for end-points
        self.sentences_url = urljoin(self.api_url, "sentences")
        self.custom_words_url = urljoin(self.api_url, "custom-words")
        self.slots_url = urljoin(self.api_url, "slots")
        self.train_url = urljoin(self.api_url, "train")
        self.stt_url = urljoin(self.api_url, "speech-to-text")
        self.intent_url = urljoin(self.api_url, "text-to-intent")
        self.tts_url = urljoin(self.api_url, "text-to-speech")
        self.restart_url = urljoin(self.api_url, "restart")
        self.wakeup_url = urljoin(self.api_url, "listen-for-command")
        self.profile_url = urljoin(self.api_url, "profile")
        self.lookup_url = urljoin(self.api_url, "lookup")
        self.version_url = urljoin(self.api_url, "version")

        self.session = session
        assert self.session is not None, "ClientSession is required"

    # -------------------------------------------------------------------------

    async def get_sentences(self) -> Dict[str, List[str]]:
        """
        GET sentences.ini from server. Return sentences grouped by intent.

        Each ini section becomes one key. A line without "=" is returned
        verbatim as a sentence; a "name = body" line is returned as the
        string "name = body" (a rule).

        Raises aiohttp.ClientResponseError when the server answers with an
        HTTP error status.
        """
        async with self.session.get(self.sentences_url) as response:
            response.raise_for_status()

            # Parse ini. Interpolation is disabled so a literal "%" inside a
            # rule body is returned as-is instead of raising.
            parser = configparser.ConfigParser(
                allow_no_value=True,
                strict=False,
                delimiters=["="],
                interpolation=None,
            )

            # case sensitive
            parser.optionxform = str  # type: ignore
            parser.read_string(await response.text())

            # Group sentences by intent
            sentences: Dict[str, List[str]] = defaultdict(list)
            for intent_name in parser.sections():
                # .items() is required: iterating the section proxy directly
                # yields keys only.
                for key, value in parser[intent_name].items():
                    if value is None:
                        # Sentence: with allow_no_value the whole line is the key
                        sentences[intent_name].append(key)
                    else:
                        # Rule
                        sentences[intent_name].append(f"{key} = {value}")

            return sentences

    async def set_sentences(self, sentences: Dict[str, List[str]]) -> str:
        """
        POST sentences.ini to server from sentences grouped by intent.

        Intents and their sentences are written in sorted order. Sentences
        are stripped, and one that begins with "[" is prefixed with a
        backslash so it is not read back as a section header.

        Returns the response body. Raises aiohttp.ClientResponseError on an
        HTTP error status.
        """
        with io.StringIO() as sentences_file:
            for intent_name in sorted(sentences):
                print(f"[{intent_name}]", file=sentences_file)

                for sentence in sorted(sentences[intent_name]):
                    # Strip before testing for "[" so leading whitespace
                    # cannot hide a bracket that needs escaping.
                    sentence = sentence.strip()
                    if sentence.startswith("["):
                        # Escape initial [
                        sentence = "\\" + sentence

                    print(sentence, file=sentences_file)

                # Blank line
                print("", file=sentences_file)

            sentences_text = sentences_file.getvalue()

        # POST to server
        async with self.session.post(
            self.sentences_url, data=sentences_text
        ) as response:
            response.raise_for_status()
            return await response.text()

    # -------------------------------------------------------------------------

    async def get_custom_words(self) -> Dict[str, Set[str]]:
        """
        GET custom words from server. Return pronunciations grouped by word.

        Each non-blank response line is split on its first run of whitespace
        into a word and a pronunciation. Lines that contain a word but no
        pronunciation are logged and skipped.

        Raises aiohttp.ClientResponseError on an HTTP error status.
        """
        async with self.session.get(self.custom_words_url) as response:
            response.raise_for_status()

            # Group pronunciations by word
            pronunciations: Dict[str, Set[str]] = defaultdict(set)
            async for line_bytes in response.content:
                line = line_bytes.decode().strip()

                # Skip blank lines
                if not line:
                    continue

                parts = re.split(r"\s+", line, maxsplit=1)
                if len(parts) != 2:
                    _LOGGER.warning("Skipping malformed custom word line: %s", line)
                    continue

                word, pronunciation = parts
                pronunciations[word].add(pronunciation)

            return pronunciations

    async def set_custom_words(self, pronunciations: Dict[str, Set[str]]) -> str:
        """
        POST custom words to server from pronunciations grouped by word.

        One "word pronunciation" line is written per pronunciation, with
        words and pronunciations in sorted order. A bare string is accepted
        in place of a collection and treated as a single pronunciation.

        Returns the response body. Raises aiohttp.ClientResponseError on an
        HTTP error status.
        """
        with io.StringIO() as custom_words_file:
            for word in sorted(pronunciations):
                word_pronunciations = pronunciations[word]
                if isinstance(word_pronunciations, str):
                    # A lone string would otherwise be iterated per character
                    word_pronunciations = [word_pronunciations]

                for pronunciation in sorted(word_pronunciations):
                    print(word, pronunciation, file=custom_words_file)

            custom_words_text = custom_words_file.getvalue()

        # POST to server
        async with self.session.post(
            self.custom_words_url, data=custom_words_text
        ) as response:
            response.raise_for_status()
            return await response.text()

    # -------------------------------------------------------------------------

    async def train(self, no_cache: bool = False) -> TrainingComplete:
        """
        Train Rhasspy profile. Delete doit database when no_cache is True.

        Never raises for an HTTP error status: the failure is logged and
        reported as a FAILURE result carrying the response body as errors.
        """
        params: Dict[str, str] = {}
        if no_cache:
            params["no_cache"] = "true"

        async with self.session.post(self.train_url, params=params) as response:
            # Read the body first so it is available as the error text
            text = await response.text()

            try:
                response.raise_for_status()
                return TrainingComplete(result=TrainingResult.SUCCESS)
            except Exception:
                _LOGGER.exception("train")

            return TrainingComplete(result=TrainingResult.FAILURE, errors=text)

    # -------------------------------------------------------------------------

    async def speech_to_text(self, wav_data: bytes) -> Transcription:
        """
        Transcribe WAV audio.

        Returns a SUCCESS transcription with the response text, or a FAILURE
        transcription when the server reports an HTTP error or returns an
        empty body. Failures are logged rather than raised.
        """
        headers = {"Content-Type": "audio/wav"}

        async with self.session.post(
            self.stt_url, headers=headers, data=wav_data
        ) as response:
            text = await response.text()

            try:
                response.raise_for_status()
                if not text:
                    # Explicit raise (not assert) so the check survives -O
                    raise ValueError("Empty transcription")

                return Transcription(result=TranscriptionResult.SUCCESS, text=text)
            except Exception:
                _LOGGER.exception("speech_to_text")

            return Transcription(result=TranscriptionResult.FAILURE)

    # -------------------------------------------------------------------------

    async def text_to_intent(
        self, text: str, handle_intent: bool = False
    ) -> Dict[str, Any]:
        """
        Recognize intent from text.

        If handle_intent is True, Rhasspy will forward to Home Assistant.

        Returns the decoded JSON response. Raises aiohttp.ClientResponseError
        on an HTTP error status.
        """
        params = {"nohass": str(not handle_intent)}

        async with self.session.post(
            self.intent_url, params=params, data=text
        ) as response:
            response.raise_for_status()
            return await response.json()

    # -------------------------------------------------------------------------

    async def text_to_speech(self, text: str, repeat: bool = False) -> bytes:
        """
        Generate speech from text.

        If repeat is True, Rhasspy will repeat the last spoken sentence.

        Returns the raw response body. Raises aiohttp.ClientResponseError on
        an HTTP error status.
        """
        params = {"repeat": str(repeat)}

        async with self.session.post(
            self.tts_url, params=params, data=text
        ) as response:
            response.raise_for_status()
            return await response.read()

    # -------------------------------------------------------------------------

    async def get_slots(self) -> Dict[str, List[str]]:
        """
        GET slots/values from server. Return values grouped by slot.

        Raises aiohttp.ClientResponseError on an HTTP error status.
        """
        async with self.session.get(self.slots_url) as response:
            response.raise_for_status()
            return await response.json()

    async def set_slots(
        self, slots: Dict[str, List[str]], overwrite: bool = True
    ) -> str:
        """
        POST slots/values to server as values grouped by slot.

        If overwrite is False, values are appended to existing slot.

        Returns the response body. Raises aiohttp.ClientResponseError on an
        HTTP error status.
        """
        params = {"overwrite_all": str(overwrite)}

        async with self.session.post(
            self.slots_url, params=params, json=slots
        ) as response:
            response.raise_for_status()
            return await response.text()

    # -------------------------------------------------------------------------

    async def restart(self) -> str:
        """Restart Rhasspy server. Return the response body."""
        async with self.session.post(self.restart_url) as response:
            response.raise_for_status()
            return await response.text()

    # -------------------------------------------------------------------------

    async def version(self) -> str:
        """Get Rhasspy version as returned by the server."""
        async with self.session.get(self.version_url) as response:
            response.raise_for_status()
            return await response.text()

    # -------------------------------------------------------------------------

    async def wakeup_and_wait(self, handle_intent: bool = False) -> Dict[str, Any]:
        """
        Wake up Rhasspy so it starts listening for a voice command.

        If handle_intent is True, Rhasspy will forward to Home Assistant.

        Returns the decoded JSON response. Raises aiohttp.ClientResponseError
        on an HTTP error status.
        """
        params = {"nohass": str(not handle_intent)}

        async with self.session.post(self.wakeup_url, params=params) as response:
            response.raise_for_status()
            return await response.json()

    # -------------------------------------------------------------------------

    async def get_profile(self, defaults: bool = True) -> Dict[str, Any]:
        """
        GET current profile. Include default settings when defaults is True.

        Requests layers=all when defaults is True and layers=profile
        otherwise. Raises aiohttp.ClientResponseError on an HTTP error status.
        """
        params = {"layers": "all" if defaults else "profile"}

        async with self.session.get(self.profile_url, params=params) as response:
            response.raise_for_status()
            return await response.json()

    async def set_profile(self, profile: Dict[str, Any]) -> str:
        """
        POST profile settings to server as JSON.

        Returns the response body. Raises aiohttp.ClientResponseError on an
        HTTP error status.
        """
        async with self.session.post(self.profile_url, json=profile) as response:
            response.raise_for_status()
            return await response.text()

    # -------------------------------------------------------------------------

    async def get_pronunciations(self, word: str, n: int = 5) -> Tuple[bool, List[str]]:
        """
        Get pronunciations for a word.

        n is sent to the server as the "n" query parameter.

        Returns if the word was in the dictionary and the pronunciations.
        Raises aiohttp.ClientResponseError on an HTTP error status.
        """
        params = {"n": str(n)}

        async with self.session.post(
            self.lookup_url, params=params, data=word
        ) as response:
            response.raise_for_status()
            result = await response.json()
            return (result["in_dictionary"], result["pronunciations"])

    # -------------------------------------------------------------------------

    async def stream_to_text(self, raw_stream: aiohttp.StreamReader) -> str:
        """
        Stream raw 16-bit 16Khz mono audio to server. Return transcription.

        The request is sent with noheader=true. Raises
        aiohttp.ClientResponseError on an HTTP error status.
        """
        params = {"noheader": "true"}

        async with self.session.post(
            self.stt_url, params=params, data=raw_stream
        ) as response:
            response.raise_for_status()
            return await response.text()