added llamaspeak package

dusty-nv committed Aug 13, 2023
1 parent 8709478 commit 5946c11
Showing 7 changed files with 568 additions and 0 deletions.
14 changes: 14 additions & 0 deletions packages/llm/llamaspeak/Dockerfile
@@ -0,0 +1,14 @@
#---
# name: llamaspeak
# group: llm
# depends: [riva-client:python]
# requires: '>=34.1.0'
# docs: docs.md
#---
ARG BASE_IMAGE
FROM ${BASE_IMAGE}

COPY requirements.txt /opt/llamaspeak/
RUN pip3 install --no-cache-dir --verbose -r /opt/llamaspeak/requirements.txt

COPY *.py /opt/llamaspeak/
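
The `#---` block is a YAML metadata header that the jetson-containers build system reads to name the package, chain it onto its dependencies (here `riva-client:python`), and gate it on the L4T version (`requires: '>=34.1.0'`). A hypothetical sketch of extracting such a header, for illustration only (the real build-system parser may differ):

```python
#!/usr/bin/env python3
# hypothetical helper: extract the '#---' YAML metadata header from a package
# Dockerfile; the actual jetson-containers parser may differ from this sketch
import yaml

def read_package_header(dockerfile_path):
    header_lines = []
    in_header = False
    with open(dockerfile_path) as f:
        for line in f:
            stripped = line.strip()
            if stripped == '#---':
                if in_header:
                    break              # closing marker: header complete
                in_header = True       # opening marker: start collecting
                continue
            if in_header:
                header_lines.append(stripped.lstrip('#').strip())
    return yaml.safe_load('\n'.join(header_lines))

print(read_package_header('packages/llm/llamaspeak/Dockerfile'))
# e.g. {'name': 'llamaspeak', 'group': 'llm', 'depends': ['riva-client:python'], ...}
```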
59 changes: 59 additions & 0 deletions packages/llm/llamaspeak/asr.py
@@ -0,0 +1,59 @@
#!/usr/bin/env python3
import pprint
import threading

import riva.client
import riva.client.audio_io


class ASR(threading.Thread):
"""
Streaming ASR service
"""
def __init__(self, auth, input_device=0, sample_rate_hz=44100, audio_chunk=1600, audio_channels=1,
automatic_punctuation=True, verbatim_transcripts=True, profanity_filter=False,
language_code='en-US', boosted_lm_words=None, boosted_lm_score=4.0, callback=None, **kwargs):

super(ASR, self).__init__()

self.callback = callback
self.language_code = language_code

self.asr_service = riva.client.ASRService(auth)

self.asr_config = riva.client.StreamingRecognitionConfig(
config=riva.client.RecognitionConfig(
encoding=riva.client.AudioEncoding.LINEAR_PCM,
language_code=language_code,
max_alternatives=1,
profanity_filter=profanity_filter,
enable_automatic_punctuation=automatic_punctuation,
verbatim_transcripts=verbatim_transcripts,
sample_rate_hertz=sample_rate_hz,
audio_channel_count=audio_channels,
),
interim_results=True,
)

riva.client.add_word_boosting_to_config(self.asr_config, boosted_lm_words, boosted_lm_score)

self.mic_stream = riva.client.audio_io.MicrophoneStream(
sample_rate_hz,
audio_chunk,
device=input_device,
).__enter__()

self.responses = self.asr_service.streaming_response_generator(
audio_chunks=self.mic_stream, streaming_config=self.asr_config
)

def run(self):
print(f"-- running ASR service ({self.language_code})")

for response in self.responses:
if not response.results:
continue

for result in response.results:
if self.callback is not None:
self.callback(result)
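
A minimal usage sketch for the class above, assuming a Riva server reachable at `localhost:50051` and microphone device 0 (use `chat.py --list-devices` to find your device index):

```python
#!/usr/bin/env python3
# standalone test of the ASR class: print interim and final transcripts
import time

import riva.client

from asr import ASR

def on_transcript(result):
    text = result.alternatives[0].transcript.strip()
    print(('FINAL: ' if result.is_final else 'interim: ') + text)

auth = riva.client.Auth(uri='localhost:50051')  # pass ssl_cert/use_ssl if needed
asr = ASR(auth, input_device=0, sample_rate_hz=44100, callback=on_transcript)
asr.start()

while True:
    time.sleep(1.0)  # transcription happens on the ASR thread
```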
107 changes: 107 additions & 0 deletions packages/llm/llamaspeak/chat.py
@@ -0,0 +1,107 @@
#!/usr/bin/env python3
# python3 chat.py --input-device=24 --output-device=24 --sample-rate-hz=48000
import sys
import time
import pprint
import argparse
import threading

import riva.client
import riva.client.audio_io

from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters

from asr import ASR
from tts import TTS
from llm import LLM


def parse_args():
"""
Parse command-line arguments for configuring the chatbot.
"""
default_device_info = riva.client.audio_io.get_default_input_device_info()
default_device_index = None if default_device_info is None else default_device_info['index']

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# audio I/O
parser.add_argument("--list-devices", action="store_true", help="List output audio devices indices.")
parser.add_argument("--input-device", type=int, default=default_device_index, help="An input audio device to use.")
parser.add_argument("--output-device", type=int, help="Output device to use.")
parser.add_argument("--sample-rate-hz", type=int, default=44100, help="Number of audio frames per second in synthesized audio.")
parser.add_argument("--audio-chunk", type=int, default=1600, help="A maximum number of frames in a audio chunk sent to server.")
parser.add_argument("--audio-channels", type=int, default=1, help="The number of audio channels to use")

# ASR/TTS settings
parser.add_argument("--voice", type=str, default='English-US.Female-1', help="A voice name to use for TTS")
parser.add_argument("--no-punctuation", action='store_true', help="Disable ASR automatic punctuation")

# LLM settings
parser.add_argument("--llm-server", type=str, default='0.0.0.0', help="hostname of the LLM server (text-generation-webui)")
parser.add_argument("--llm-api-port", type=int, default=5000, help="port of the blocking API on the LLM server")
parser.add_argument("--llm-streaming-port", type=int, default=5005, help="port of the streaming websocket API on the LLM server")

parser = add_asr_config_argparse_parameters(parser, profanity_filter=True)
parser = add_connection_argparse_parameters(parser)

args = parser.parse_args()

args.automatic_punctuation = not args.no_punctuation
args.verbatim_transcripts = not args.no_verbatim_transcripts

print(args)
return args


class Chatbot(threading.Thread):
"""
LLM-based chatbot with streaming ASR/TTS
"""
def __init__(self, args, **kwargs):
super(Chatbot, self).__init__()

self.args = args
self.auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server)

self.asr = ASR(self.auth, callback=self.on_asr_transcript, **vars(args))
self.tts = TTS(self.auth, **vars(args))
self.llm = LLM(**vars(args))

self.asr_history = ""

def on_asr_transcript(self, result):
"""
Receive new ASR responses
"""
transcript = result.alternatives[0].transcript.strip()

if result.is_final:
print(f"## {transcript} ({result.alternatives[0].confidence})")
self.tts.generate(transcript)
self.llm.generate(transcript)
else:
if transcript != self.asr_history:
print(f">> {transcript}")

self.asr_history = transcript

def run(self):
self.asr.start()
self.tts.start()
self.llm.start()

while True:
time.sleep(1.0)


if __name__ == '__main__':
args = parse_args()

if args.list_devices:
riva.client.audio_io.list_output_devices()
sys.exit(0)

chatbot = Chatbot(args)
chatbot.start()
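
With the pieces above, interim transcripts print with a `>>` prefix while the user is still speaking, and the final punctuated transcript prints with `##` plus its confidence before being handed to the TTS and LLM threads. A hypothetical session trace (illustrative only, not captured output):

```
>> what is
>> what is the weather
## What is the weather today? (0.94)
```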

35 changes: 35 additions & 0 deletions packages/llm/llamaspeak/docs.md
@@ -0,0 +1,35 @@

* Talk live with LLMs using [RIVA](/packages/riva-client) ASR and TTS!
* Requires the [RIVA server](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/riva/resources/riva_quickstart_arm64) and [`text-generation-webui`](/packages/llm/text-generation-webui) to be running

### Audio Check

First, it's recommended to test your microphone/speaker with RIVA ASR/TTS. Follow the steps from the [`riva-client:python`](/packages/riva-client) package:

1. Start the RIVA server running on your Jetson by following [`riva_quickstart_arm64`](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/riva/resources/riva_quickstart_arm64)
2. List your [audio devices](/packages/riva-client/README.md#list-audio-devices) (see the sketch after this list)
3. Perform the ASR/TTS [loopback test](/packages/riva-client/README.md#loopback)
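
Step 2 can also be run directly from Python; a minimal sketch using the same `riva.client.audio_io` helper that `chat.py` invokes for `--list-devices`:

```python
#!/usr/bin/env python3
# print the audio output devices (and their indices) visible to the Riva client;
# these indices are what chat.py's --input-device/--output-device flags expect
import riva.client.audio_io

riva.client.audio_io.list_output_devices()
```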

### Load LLM

Next, start [`text-generation-webui`](/packages/llm/text-generation-webui) with the `--api` flag and load your chat model of choice through the web UI:

```bash
./run.sh --workdir /opt/text-generation-webui $(./autotag text-generation-webui) \
python3 server.py --listen --verbose --api \
--model-dir=/data/models/text-generation-webui
```

Alternatively, you can manually specify the model that you want to load without needing to use the web UI:

```bash
./run.sh --workdir /opt/text-generation-webui $(./autotag text-generation-webui) \
python3 server.py --listen --verbose --api \
--model-dir=/data/models/text-generation-webui \
--model=llama-2-13b-chat.ggmlv3.q4_0.bin \
--loader=llamacpp \
--n-gpu-layers=128
```
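
Once a model is loaded, it can help to sanity-check the API before launching llamaspeak. A minimal sketch, assuming the legacy blocking endpoint `/api/v1/generate` that text-generation-webui's `--api` flag exposes on port 5000 (verify the endpoint shape against your version):

```python
#!/usr/bin/env python3
# quick sanity check of the text-generation-webui blocking API; the
# /api/v1/generate endpoint and response shape are assumptions to verify
import requests

response = requests.post(
    'http://0.0.0.0:5000/api/v1/generate',
    json={'prompt': 'Hello! How are you?', 'max_new_tokens': 32},
)
response.raise_for_status()
print(response.json()['results'][0]['text'])
```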

Command-line arguments for `text-generation-webui` are documented here: https://github.com/oobabooga/text-generation-webui/tree/main#basic-settings