From 5946c116e7cfc084fa5139fefa3d039131f0b8cd Mon Sep 17 00:00:00 2001 From: Dustin Franklin Date: Sun, 13 Aug 2023 19:13:58 -0400 Subject: [PATCH] added llamaspeak package --- packages/llm/llamaspeak/Dockerfile | 14 ++ packages/llm/llamaspeak/asr.py | 59 +++++ packages/llm/llamaspeak/chat.py | 107 +++++++++ packages/llm/llamaspeak/docs.md | 35 +++ packages/llm/llamaspeak/llm.py | 284 +++++++++++++++++++++++ packages/llm/llamaspeak/requirements.txt | 2 + packages/llm/llamaspeak/tts.py | 67 ++++++ 7 files changed, 568 insertions(+) create mode 100644 packages/llm/llamaspeak/Dockerfile create mode 100755 packages/llm/llamaspeak/asr.py create mode 100755 packages/llm/llamaspeak/chat.py create mode 100644 packages/llm/llamaspeak/docs.md create mode 100755 packages/llm/llamaspeak/llm.py create mode 100644 packages/llm/llamaspeak/requirements.txt create mode 100755 packages/llm/llamaspeak/tts.py diff --git a/packages/llm/llamaspeak/Dockerfile b/packages/llm/llamaspeak/Dockerfile new file mode 100644 index 000000000..ac6076c34 --- /dev/null +++ b/packages/llm/llamaspeak/Dockerfile @@ -0,0 +1,14 @@ +#--- +# name: llamaspeak +# group: llm +# depends: [riva-client:python] +# requires: '>=34.1.0' +# docs: docs.md +#--- +ARG BASE_IMAGE +FROM ${BASE_IMAGE} + +COPY requirements.txt /opt/llamaspeak/ +RUN pip3 install --no-cache-dir --verbose -r /opt/llamaspeak/requirements.txt + +COPY *.py /opt/llamaspeak/ diff --git a/packages/llm/llamaspeak/asr.py b/packages/llm/llamaspeak/asr.py new file mode 100755 index 000000000..fbb41c00f --- /dev/null +++ b/packages/llm/llamaspeak/asr.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +import pprint +import threading + +import riva.client +import riva.client.audio_io + + +class ASR(threading.Thread): + """ + Streaming ASR service + """ + def __init__(self, auth, input_device=0, sample_rate_hz=44100, audio_chunk=1600, audio_channels=1, + automatic_punctuation=True, verbatim_transcripts=True, profanity_filter=False, + language_code='en-US', boosted_lm_words=None, boosted_lm_score=4.0, callback=None, **kwargs): + + super(ASR, self).__init__() + + self.callback=callback + self.language_code = language_code + + self.asr_service = riva.client.ASRService(auth) + + self.asr_config = riva.client.StreamingRecognitionConfig( + config=riva.client.RecognitionConfig( + encoding=riva.client.AudioEncoding.LINEAR_PCM, + language_code=language_code, + max_alternatives=1, + profanity_filter=profanity_filter, + enable_automatic_punctuation=automatic_punctuation, + verbatim_transcripts=verbatim_transcripts, + sample_rate_hertz=sample_rate_hz, + audio_channel_count=audio_channels, + ), + interim_results=True, + ) + + riva.client.add_word_boosting_to_config(self.asr_config, boosted_lm_words, boosted_lm_score) + + self.mic_stream = riva.client.audio_io.MicrophoneStream( + sample_rate_hz, + audio_chunk, + device=input_device, + ).__enter__() + + self.responses = self.asr_service.streaming_response_generator( + audio_chunks=self.mic_stream, streaming_config=self.asr_config + ) + + def run(self): + print(f"-- running ASR service ({self.language_code})") + + for response in self.responses: + if not response.results: + continue + + for result in response.results: + if self.callback is not None: + self.callback(result) \ No newline at end of file diff --git a/packages/llm/llamaspeak/chat.py b/packages/llm/llamaspeak/chat.py new file mode 100755 index 000000000..23724f3a3 --- /dev/null +++ b/packages/llm/llamaspeak/chat.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 +# python3 chat.py --input-device=24 --output-device=24 --sample-rate-hz=48000 +import sys +import time +import pprint +import argparse +import threading + +import riva.client +import riva.client.audio_io + +from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters + +from asr import ASR +from tts import TTS +from llm import LLM + + +def parse_args(): + """ + Parse command-line arguments for configuring the chatbot. + """ + default_device_info = riva.client.audio_io.get_default_input_device_info() + default_device_index = None if default_device_info is None else default_device_info['index'] + + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + # audio I/O + parser.add_argument("--list-devices", action="store_true", help="List output audio devices indices.") + parser.add_argument("--input-device", type=int, default=default_device_index, help="An input audio device to use.") + parser.add_argument("--output-device", type=int, help="Output device to use.") + parser.add_argument("--sample-rate-hz", type=int, default=44100, help="Number of audio frames per second in synthesized audio.") + parser.add_argument("--audio-chunk", type=int, default=1600, help="A maximum number of frames in a audio chunk sent to server.") + parser.add_argument("--audio-channels", type=int, default=1, help="The number of audio channels to use") + + # ASR/TTS settings + parser.add_argument("--voice", type=str, default='English-US.Female-1', help="A voice name to use for TTS") + parser.add_argument("--no-punctuation", action='store_true', help="Disable ASR automatic punctuation") + + # LLM settings + parser.add_argument("--llm-server", type=str, default='0.0.0.0', help="hostname of the LLM server (text-generation-webui)") + parser.add_argument("--llm-api-port", type=int, default=5000, help="port of the blocking API on the LLM server") + parser.add_argument("--llm-streaming-port", type=int, default=5005, help="port of the streaming websocket API on the LLM server") + + parser = add_asr_config_argparse_parameters(parser, profanity_filter=True) + parser = add_connection_argparse_parameters(parser) + + args = parser.parse_args() + + args.automatic_punctuation = not args.no_punctuation + args.verbatim_transcripts = not args.no_verbatim_transcripts + + print(args) + return args + + +class Chatbot(threading.Thread): + """ + LLM-based chatbot with streaming ASR/TTS + """ + def __init__(self, args, **kwargs): + super(Chatbot, self).__init__() + + self.args = args + self.auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server) + + self.asr = ASR(self.auth, callback=self.on_asr_transcript, **vars(args)) + self.tts = TTS(self.auth, **vars(args)) + self.llm = LLM(**vars(args)) + + self.asr_history = "" + + def on_asr_transcript(self, result): + """ + Recieve new ASR responses + """ + transcript = result.alternatives[0].transcript.strip() + + if result.is_final: + print(f"## {transcript} ({result.alternatives[0].confidence})") + self.tts.generate(transcript) + self.llm.generate(transcript) + else: + if transcript != self.asr_history: + print(f">> {transcript}") + + self.asr_history = transcript + + def run(self): + self.asr.start() + self.tts.start() + self.llm.start() + + while True: + time.sleep(1.0) + + +if __name__ == '__main__': + args = parse_args() + + if args.list_devices: + riva.client.audio_io.list_output_devices() + sys.exit(0) + + chatbot = Chatbot(args) + chatbot.start() + \ No newline at end of file diff --git a/packages/llm/llamaspeak/docs.md b/packages/llm/llamaspeak/docs.md new file mode 100644 index 000000000..e29c32bba --- /dev/null +++ b/packages/llm/llamaspeak/docs.md @@ -0,0 +1,35 @@ + +* Talk live with LLM's using [RIVA](/packages/riva-client) ASR and TTS! +* Requires the [RIVA server](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/riva/resources/riva_quickstart_arm64) and [`text-generation-webui`](/packages/llm/text-generation-webui) to be running +* + +### Audio Check + +First, it's recommended to test your microphone/speaker with RIVA ASR/TTS. Follow the steps from the [`riva-client:python`](/packages/riva-client) package: + +1. Start the RIVA server running on your Jetson by following [`riva_quickstart_arm64`](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/riva/resources/riva_quickstart_arm64) +2. List your [audio devices](/packages/riva-client/README.md#list-audio-devices) +3. Perform the ASR/TTS [loopback test](/packages/riva-client/README.md#loopback) + +### Load LLM + +Next, start [`text-generation-webui`](/packages/llm/text-generation-webui) with the `--api` flag and load your chat model of choice through the web UI: + +```bash +./run.sh --workdir /opt/text-generation-webui $(./autotag text-generation-webui) \ + python3 server.py --listen --verbose --api \ + --model-dir=/data/models/text-generation-webui +``` + +Alternatively, you can manually specify the model that you want to load without needing to use the web UI: + +```bash +./run.sh --workdir /opt/text-generation-webui $(./autotag text-generation-webui) \ + python3 server.py --listen --verbose --api \ + --model-dir=/data/models/text-generation-webui \ + --model=llama-2-13b-chat.ggmlv3.q4_0.bin \ + --loader=llamacpp \ + --n-gpu-layers=128 +``` + +See here for command-line arguments: https://github.com/oobabooga/text-generation-webui/tree/main#basic-settings \ No newline at end of file diff --git a/packages/llm/llamaspeak/llm.py b/packages/llm/llamaspeak/llm.py new file mode 100755 index 000000000..344f5816f --- /dev/null +++ b/packages/llm/llamaspeak/llm.py @@ -0,0 +1,284 @@ +#!/usr/bin/env python3 +import sys +import json +import queue +import pprint +import asyncio +import argparse +import requests +import threading + +from websockets.sync.client import connect as websocket_connect + + +class LLM(threading.Thread): + """ + LLM service using text-generation-webui API + """ + def __init__(self, llm_server='0.0.0.0', llm_api_port=5000, llm_streaming_port=5005, **kwargs): + + super(LLM, self).__init__() + + self.queue = queue.Queue() + + self.server = llm_server + self.blocking_port = llm_api_port + self.streaming_port = llm_streaming_port + + self.request_count = 0 + + pprint.pprint(self.model_list()) + pprint.pprint(self.model_info()) + + model_name = self.model_name().lower() + + # find default chat template based on the model + self.instruction_template = None + + if any(x in model_name for x in ['llama2', 'llama_2', 'llama-2']): + self.instruction_template = 'Llama-v2' + elif 'vicuna' in model_name: + self.instruction_template = 'Vicuna-v1.1' + + def model_info(self): + """ + Returns info about the model currently loaded on the server. + """ + return self.model_api({'action': 'info'})['result'] + + def model_name(self): + """ + Return the list of models available on the server. + """ + return self.model_info()['model_name'] + + def model_list(self): + """ + Return the list of models available on the server. + """ + return self.model_api({'action': 'list'})['result'] + + def model_api(self, request): + """ + Call the text-generation-webui model API with one of these requests: + + {'action': 'info'} + {'action': 'list'} + + See model_list() and model_info() for using these requests. + """ + return requests.post(f'http://{self.server}:{self.blocking_port}/api/v1/model', json=request).json() + + def generate(self, prompt, callback=None, **kwargs): + """ + Generate an asynchronous text completion request to run on the LLM server. + You can set optional parameters for the request through the kwargs (e.g. max_new_tokens=50) + If the callback function is provided, it will be called as the generated tokens are streamed in. + This function returns the request that was queued. + """ + params = { + 'prompt': prompt, + 'max_new_tokens': 250, + 'auto_max_new_tokens': False, + + # Generation params. If 'preset' is set to different than 'None', the values + # in presets/preset-name.yaml are used instead of the individual numbers. + 'preset': 'None', + 'do_sample': True, + 'temperature': 0.7, + 'top_p': 0.1, + 'typical_p': 1, + 'epsilon_cutoff': 0, # In units of 1e-4 + 'eta_cutoff': 0, # In units of 1e-4 + 'tfs': 1, + 'top_a': 0, + 'repetition_penalty': 1.18, + 'repetition_penalty_range': 0, + 'top_k': 40, + 'min_length': 0, + 'no_repeat_ngram_size': 0, + 'num_beams': 1, + 'penalty_alpha': 0, + 'length_penalty': 1, + 'early_stopping': False, + 'mirostat_mode': 0, + 'mirostat_tau': 5, + 'mirostat_eta': 0.1, + 'guidance_scale': 1, + 'negative_prompt': '', + + 'seed': -1, + 'add_bos_token': True, + 'truncation_length': 2048, + 'ban_eos_token': False, + 'skip_special_tokens': True, + 'stopping_strings': [] + } + + params.update(kwargs) + + request = { + 'id': self.request_count, + 'type': 'completion', + 'params': params, + 'callback': callback + } + + self.request_count += 1 + self.queue.put(request) + return request + + def generate_chat(self, user_input, history, callback=None, **kwargs): + """ + Generate an asynchronous chat request to run on the LLM server. + You can set optional parameters for the request through the kwargs (e.g. max_new_tokens=50) + If the callback function is provided, it will be called as the generated tokens are streamed in. + This function returns the request that was queued. + """ + params = { + 'user_input': user_input, + 'max_new_tokens': 250, + 'auto_max_new_tokens': False, + 'history': history, + 'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct' + 'character': 'Example', + #'instruction_template': 'Llama-v2', # Will get autodetected if unset (see below) + 'your_name': 'You', + # 'name1': 'name of user', # Optional + # 'name2': 'name of character', # Optional + # 'context': 'character context', # Optional + # 'greeting': 'greeting', # Optional + # 'name1_instruct': 'You', # Optional + # 'name2_instruct': 'Assistant', # Optional + # 'context_instruct': 'context_instruct', # Optional + # 'turn_template': 'turn_template', # Optional + 'regenerate': False, + '_continue': False, + 'stop_at_newline': False, + 'chat_generation_attempts': 1, + 'chat_instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>', + + # Generation params. If 'preset' is set to different than 'None', the values + # in presets/preset-name.yaml are used instead of the individual numbers. + 'preset': 'None', + 'do_sample': True, + 'temperature': 0.7, + 'top_p': 0.1, + 'typical_p': 1, + 'epsilon_cutoff': 0, # In units of 1e-4 + 'eta_cutoff': 0, # In units of 1e-4 + 'tfs': 1, + 'top_a': 0, + 'repetition_penalty': 1.18, + 'repetition_penalty_range': 0, + 'top_k': 40, + 'min_length': 0, + 'no_repeat_ngram_size': 0, + 'num_beams': 1, + 'penalty_alpha': 0, + 'length_penalty': 1, + 'early_stopping': False, + 'mirostat_mode': 0, + 'mirostat_tau': 5, + 'mirostat_eta': 0.1, + 'guidance_scale': 1, + 'negative_prompt': '', + + 'seed': -1, + 'add_bos_token': True, + 'truncation_length': 2048, + 'ban_eos_token': False, + 'skip_special_tokens': True, + 'stopping_strings': [] + } + + params.update(kwargs) + + if 'instruction_template' not in params and self.instruction_template: + params['instruction_template'] = self.instruction_template + + request = { + 'id': self.request_count, + 'type': 'chat', + 'params': params, + 'callback': callback + } + + self.request_count += 1 + self.queue.put(request) + return request + + def run(self): + print(f"-- running LLM service ({self.model_name()})") + + while True: + request = self.queue.get() + + print("-- LLM:") + pprint.pprint(request) + + if request['type'] == 'completion': + url = f"ws://{self.server}:{self.streaming_port}/api/v1/stream" + elif request['type'] == 'chat': + url = f"ws://{self.server}:{self.streaming_port}/api/v1/chat-stream" + + with websocket_connect(url) as websocket: + websocket.send(json.dumps(request['params'])) + + while True: + incoming_data = websocket.recv() + incoming_data = json.loads(incoming_data) + + if request['callback'] is None: + continue + + if incoming_data['event'] == 'text_stream': + key = 'history' if request['type'] is 'chat' else 'text' + request['callback'](incoming_data[key], request=request, end=False) + elif incoming_data['event'] == 'stream_end': + request['callback'](None, request=request, end=True) + return + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument("--llm-server", type=str, default='0.0.0.0', help="hostname of the LLM server (text-generation-webui)") + parser.add_argument("--llm-api-port", type=int, default=5000, help="port of the blocking API on the LLM server") + parser.add_argument("--llm-streaming-port", type=int, default=5005, help="port of the streaming websocket API on the LLM server") + parser.add_argument("--max-new-tokens", type=int, default=250, help="the maximum number of new tokens for the LLM to generate") + parser.add_argument("--prompt", type=str, default="") + parser.add_argument("--chat", action="store_true") + + args = parser.parse_args() + + if not args.prompt: + if args.chat: + args.prompt = "Please give me a step-by-step guide on how to plant a tree in my backyard." + else: + args.prompt = "Once upon a time," + + print(args) + + llm = LLM(**vars(args)) + llm.start() + + def on_llm_reply(response, request, end): + if not end: + if request['type'] == 'completion': + print(response, end='') + sys.stdout.flush() + elif request['type'] == 'chat': + history = response['visible'][-1][1] + print(history) + else: + print("\n") + + if args.chat: + history = {'internal': [], 'visible': []} + llm.generate_chat(args.prompt, history, max_new_tokens=args.max_new_tokens, callback=on_llm_reply) + else: + llm.generate(args.prompt, max_new_tokens=args.max_new_tokens, callback=on_llm_reply) + + \ No newline at end of file diff --git a/packages/llm/llamaspeak/requirements.txt b/packages/llm/llamaspeak/requirements.txt new file mode 100644 index 000000000..cf31e57eb --- /dev/null +++ b/packages/llm/llamaspeak/requirements.txt @@ -0,0 +1,2 @@ +nvidia-riva-client +websockets \ No newline at end of file diff --git a/packages/llm/llamaspeak/tts.py b/packages/llm/llamaspeak/tts.py new file mode 100755 index 000000000..8c7a863c9 --- /dev/null +++ b/packages/llm/llamaspeak/tts.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +import queue +import pprint +import threading + +import riva.client +import riva.client.audio_io + + +class TTS(threading.Thread): + """ + Streaming TTS service + """ + def __init__(self, auth, output_device=0, sample_rate_hz=44100, audio_channels=1, + language_code='en-US', voice='English-US.Female-1', **kwargs): + + super(TTS, self).__init__() + + self.queue = queue.Queue() + self.voice = voice + + self.language_code = language_code + self.sample_rate_hz = sample_rate_hz + self.request_count = 0 + + self.tts_service = riva.client.SpeechSynthesisService(auth) + + self.output_stream = riva.client.audio_io.SoundCallBack( + output_device, nchannels=audio_channels, sampwidth=2, framerate=sample_rate_hz + ).__enter__() + + def generate(self, text, voice=None, callback=None): + """ + Generate an asynchronous request to synthesize speech from the given text. + The voice can be changed for each request if one is provided (otherwise the default will be used) + If the callback function is provided, it will be called as the audio chunks are streamed in. + This function returns the request that was queued. + """ + request = { + 'id': self.request_count, + 'text': text, + 'voice': voice if voice else self.voice, + 'callback': callback + } + + self.request_count += 1 + self.queue.put(request) + return request + + def run(self): + print(f"-- running TTS service ({self.language_code}, {self.voice})") + + while True: + request = self.queue.get() + + print(f"-- TTS: {request['text']}") + + responses = self.tts_service.synthesize_online( + request['text'], request['voice'], self.language_code, sample_rate_hz=self.sample_rate_hz + ) + + for response in responses: + self.output_stream(response.audio) + + if request['callback'] is not None: + request['callback'](response, request) + \ No newline at end of file