forked from dusty-nv/jetson-containers
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
568 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
#---
# name: llamaspeak
# group: llm
# depends: [riva-client:python]
# requires: '>=34.1.0'
# docs: docs.md
#---

# BASE_IMAGE is injected by the jetson-containers build system
ARG BASE_IMAGE
FROM ${BASE_IMAGE}

# install the Python dependencies in their own layer first, so that
# source-only changes below do not invalidate the pip install cache
COPY requirements.txt /opt/llamaspeak/
RUN pip3 install --no-cache-dir --verbose -r /opt/llamaspeak/requirements.txt

# copy the llamaspeak Python sources into the image
COPY *.py /opt/llamaspeak/
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
#!/usr/bin/env python3 | ||
import pprint | ||
import threading | ||
|
||
import riva.client | ||
import riva.client.audio_io | ||
|
||
|
||
class ASR(threading.Thread):
    """
    Streaming ASR (speech-to-text) service.

    Captures audio from a microphone, streams it to a Riva ASR server,
    and invokes a user-supplied callback for every recognition result
    (both interim and final).
    """
    def __init__(self, auth, input_device=0, sample_rate_hz=44100, audio_chunk=1600, audio_channels=1,
                 automatic_punctuation=True, verbatim_transcripts=True, profanity_filter=False,
                 language_code='en-US', boosted_lm_words=None, boosted_lm_score=4.0, callback=None, **kwargs):
        """
        Parameters:
          auth -- riva.client.Auth used to connect to the Riva server
          input_device -- index of the audio input device to capture from
          sample_rate_hz -- capture / recognition sample rate
          audio_chunk -- maximum number of frames per chunk sent to the server
          audio_channels -- number of input audio channels
          automatic_punctuation -- enable automatic punctuation in transcripts
          verbatim_transcripts -- return verbatim (non-normalized) transcripts
          profanity_filter -- filter profanity from transcripts
          language_code -- BCP-47 language code for recognition
          boosted_lm_words -- optional list of words to boost in the LM
          boosted_lm_score -- boosting score applied to boosted_lm_words
          callback -- callable invoked with each streaming recognition result
        """
        super().__init__()

        self.callback = callback
        self.language_code = language_code

        # recognition settings for the streaming session (interim results on)
        recognition = riva.client.RecognitionConfig(
            encoding=riva.client.AudioEncoding.LINEAR_PCM,
            language_code=language_code,
            max_alternatives=1,
            profanity_filter=profanity_filter,
            enable_automatic_punctuation=automatic_punctuation,
            verbatim_transcripts=verbatim_transcripts,
            sample_rate_hertz=sample_rate_hz,
            audio_channel_count=audio_channels,
        )

        self.asr_config = riva.client.StreamingRecognitionConfig(config=recognition, interim_results=True)
        riva.client.add_word_boosting_to_config(self.asr_config, boosted_lm_words, boosted_lm_score)

        self.asr_service = riva.client.ASRService(auth)

        # NOTE(review): the microphone stream is entered manually and never
        # exited -- presumably held open for the lifetime of the process
        mic = riva.client.audio_io.MicrophoneStream(sample_rate_hz, audio_chunk, device=input_device)
        self.mic_stream = mic.__enter__()

        # lazy generator of streaming responses, consumed by run()
        self.responses = self.asr_service.streaming_response_generator(
            audio_chunks=self.mic_stream, streaming_config=self.asr_config
        )

    def run(self):
        """
        Thread entry point: consume the streaming response generator and
        forward each recognition result to the callback (if one was given).
        """
        print(f"-- running ASR service ({self.language_code})")

        for response in self.responses:
            for result in response.results:
                if self.callback is not None:
                    self.callback(result)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
#!/usr/bin/env python3 | ||
# python3 chat.py --input-device=24 --output-device=24 --sample-rate-hz=48000 | ||
import sys | ||
import time | ||
import pprint | ||
import argparse | ||
import threading | ||
|
||
import riva.client | ||
import riva.client.audio_io | ||
|
||
from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters | ||
|
||
from asr import ASR | ||
from tts import TTS | ||
from llm import LLM | ||
|
||
|
||
def parse_args():
    """
    Build and parse the command-line configuration for the chatbot.

    Returns the argparse namespace, augmented with the derived
    `automatic_punctuation` and `verbatim_transcripts` settings that
    the ASR service expects.
    """
    # default the input device to the system default, if one is reported
    device_info = riva.client.audio_io.get_default_input_device_info()
    default_input = device_info['index'] if device_info is not None else None

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # audio I/O
    parser.add_argument("--list-devices", action="store_true", help="List output audio devices indices.")
    parser.add_argument("--input-device", type=int, default=default_input, help="An input audio device to use.")
    parser.add_argument("--output-device", type=int, help="Output device to use.")
    parser.add_argument("--sample-rate-hz", type=int, default=44100, help="Number of audio frames per second in synthesized audio.")
    parser.add_argument("--audio-chunk", type=int, default=1600, help="A maximum number of frames in a audio chunk sent to server.")
    parser.add_argument("--audio-channels", type=int, default=1, help="The number of audio channels to use")

    # ASR/TTS settings
    parser.add_argument("--voice", type=str, default='English-US.Female-1', help="A voice name to use for TTS")
    parser.add_argument("--no-punctuation", action='store_true', help="Disable ASR automatic punctuation")

    # LLM settings
    parser.add_argument("--llm-server", type=str, default='0.0.0.0', help="hostname of the LLM server (text-generation-webui)")
    parser.add_argument("--llm-api-port", type=int, default=5000, help="port of the blocking API on the LLM server")
    parser.add_argument("--llm-streaming-port", type=int, default=5005, help="port of the streaming websocket API on the LLM server")

    # shared Riva ASR / connection flags from the riva client helpers
    parser = add_asr_config_argparse_parameters(parser, profanity_filter=True)
    parser = add_connection_argparse_parameters(parser)

    args = parser.parse_args()

    # convert the negative CLI flags into the positive settings ASR expects
    args.automatic_punctuation = not args.no_punctuation
    args.verbatim_transcripts = not args.no_verbatim_transcripts

    print(args)
    return args
|
||
|
||
class Chatbot(threading.Thread):
    """
    LLM-based voice chatbot that wires streaming ASR -> LLM/TTS together.
    """
    def __init__(self, args, **kwargs):
        """
        Create the ASR, TTS, and LLM services from the parsed CLI args.

        Parameters:
          args -- argparse namespace produced by parse_args()
        """
        super().__init__()

        self.args = args
        self.auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server)

        # each service receives the full option set; ASR reports results
        # back through on_asr_transcript()
        options = vars(args)
        self.asr = ASR(self.auth, callback=self.on_asr_transcript, **options)
        self.tts = TTS(self.auth, **options)
        self.llm = LLM(**options)

        # last interim transcript seen, used to suppress duplicate printing
        self.asr_history = ""

    def on_asr_transcript(self, result):
        """
        Receive new ASR responses.

        Final transcripts are forwarded to both the TTS and LLM services;
        interim transcripts are only printed (and only when they changed).
        """
        transcript = result.alternatives[0].transcript.strip()

        if result.is_final:
            print(f"## {transcript} ({result.alternatives[0].confidence})")
            self.tts.generate(transcript)
            self.llm.generate(transcript)
        else:
            if transcript != self.asr_history:
                print(f">> {transcript}")

            self.asr_history = transcript

    def run(self):
        """
        Start the worker threads, then idle -- the services run forever.
        """
        for service in (self.asr, self.tts, self.llm):
            service.start()

        while True:
            time.sleep(1.0)
|
||
|
||
if __name__ == '__main__':
    args = parse_args()

    # --list-devices only prints the available output audio devices, then exits
    if args.list_devices:
        riva.client.audio_io.list_output_devices()
        sys.exit(0)

    # Chatbot.run() loops forever, so the non-daemon thread keeps the
    # process alive after the main thread returns
    chatbot = Chatbot(args)
    chatbot.start()
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
|
||
* Talk live with LLMs using [RIVA](/packages/riva-client) ASR and TTS!
* Requires the [RIVA server](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/riva/resources/riva_quickstart_arm64) and [`text-generation-webui`](/packages/llm/text-generation-webui) to be running | ||
* | ||
|
||
### Audio Check | ||
|
||
First, it's recommended to test your microphone/speaker with RIVA ASR/TTS. Follow the steps from the [`riva-client:python`](/packages/riva-client) package: | ||
|
||
1. Start the RIVA server running on your Jetson by following [`riva_quickstart_arm64`](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/riva/resources/riva_quickstart_arm64) | ||
2. List your [audio devices](/packages/riva-client/README.md#list-audio-devices) | ||
3. Perform the ASR/TTS [loopback test](/packages/riva-client/README.md#loopback) | ||
|
||
### Load LLM | ||
|
||
Next, start [`text-generation-webui`](/packages/llm/text-generation-webui) with the `--api` flag and load your chat model of choice through the web UI: | ||
|
||
```bash | ||
./run.sh --workdir /opt/text-generation-webui $(./autotag text-generation-webui) \ | ||
python3 server.py --listen --verbose --api \ | ||
--model-dir=/data/models/text-generation-webui | ||
``` | ||
|
||
Alternatively, you can manually specify the model that you want to load without needing to use the web UI: | ||
|
||
```bash | ||
./run.sh --workdir /opt/text-generation-webui $(./autotag text-generation-webui) \ | ||
python3 server.py --listen --verbose --api \ | ||
--model-dir=/data/models/text-generation-webui \ | ||
--model=llama-2-13b-chat.ggmlv3.q4_0.bin \ | ||
--loader=llamacpp \ | ||
--n-gpu-layers=128 | ||
``` | ||
|
||
See here for command-line arguments: https://github.com/oobabooga/text-generation-webui/tree/main#basic-settings |
Oops, something went wrong.