diff --git a/.gitignore b/.gitignore index 7e97adb..b19469a 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ tmp/* build.log venv *.egg-info +settings.py diff --git a/README.md b/README.md index 38051c8..6776b14 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,6 @@ # RunAI -Run AI allows you to run a threaded Stable Diffusion low level python socket -server. +Run AI allows you to run LLMs using a socket server. --- @@ -11,8 +10,6 @@ server. - **Sockets**: handles byte packets of an arbitrary size - **Threaded**: asynchronously handle requests and responses - **Queue**: requests and responses are handed off to a queue -- **Auto-shutdown**: server automatically shuts down after client disconnects -- Does not save images or logs to disc --- @@ -27,291 +24,28 @@ Stable Diffusion (and other AI models) locally. It was designed for use with the Krita Stable Diffusion plugin, but can work with any interface provided someone writes a client for it. -### Only uses float16 (half floats) +### Only works with Mistral -If someone wants to build in functionality for float32 I will merge the code but -currently this is not a priority feature. +This library was designed to work with the Mistral model, but it can be expanded +to work with any LLM. --- ## Installation -First run `sh bin/install.sh` to install the required models. These will be -placed in `~/stablediffusion`. See [Stable Diffusion directory structure](#stable-diffusion-directory-structure) for more information. - -## Docker - -Easiest method - -1. [Install docker](https://docs.docker.com/engine/install/) -2. [Install nvdia-container-runtime](https://nvidia.github.io/nvidia-container-runtime/) -3. `sudo apt install nvidia-container-toolkit` -4. Copy `daemon.json` to `/etc/docker/daemon.json` (if you already have a daemon.js file in that directory, just copy the contents) -5. `docker-compose up` - ---- - -### Docker commands - -All of the following commands are contained in `/bin/dc`, you can add it to your path or run it directly. - -- **Run** `./bin/dc start` run the server -- **Shell** `./bin/dc bash` enter shell -- **Update** `./bin/dc updatereqs` update pip -- **Build** `./bin/dc build` build the server -- **Clean** `./bin/dc /app/bin/clean.sh` - --- - -### More commands - -- Build and start the services `docker-compose up` -- Stop and remove all services `docker-compose down` -- Rebuild all services `docker-compose build` -- List all running containers `docker-compose ps` -- View the output from containers `docker-compose logs` -- Execute a command in a running container `docker-compose exec ` -- Replace with the name of the service defined in the docker-compose.yml file, and with the command you want to run. - ---- - -## Bare metal - -1. [Install CUDA Toolkit 11.7](https://developer.nvidia.com/cuda-11-7-0-download-archive?target_os=Linux&target_arch=x86_64) -2. [Install miniconda](https://docs.conda.io/en/latest/miniconda.html) -3. Activate environment `` -4. conda activate runai -5. 
Install requirements `pip install -r requirements.txt` - -Create a lib folder in the root of the project: - -`mkdir -r lib/torch` - -Copy the following into `lib/torch/`: - -- `lib/torch/bin/torch_shm_manager` -- `lib/torch/lib/libtorch_global_deps.so` - -Your directory structure may differ, but it will likely look something like this: - -``` -/home//miniconda3/envs/ksd-build/lib/python3.10/site-packages/torch/bin/torch_shm_manager -/home//miniconda3/envs/ksd-build/lib/python3.10/site-packages/torch/lib/libtorch_global_deps.so -``` - -![img.png](img.png) - -- git -- conda -- a cuda capable GPU - -Build the server -``` -./bin/buildlinux.sh -``` -The standalone server will be in the `dist` directory - -Run the server -``` -conda activate runai -python server.py -``` - ---- - -## Request structure - -Clients establish a connection with the server over a socket and send a JSON -object encoded as a byte string split into packets. An EOM (end of message) -signal is sent to indicate the end of the message. - - -![img_1.png](img_1.png) - -The server assembles the packets, decodes the JSON object and processes the -request. Once processing is complete the server will send a response -back to the client. - -It is up to the client to reassemble the packets, decode the byte string to JSON -and handle the message. - ---- - -## Client - -For an example client, take a look at the [connect.py file](https://github.com/w4ffl35/krita_stable_diffusion/blob/master/krita_stable_diffusion/connect.py) in the Krita Stable Diffusion [Plugin](https://github.com/w4ffl35/krita_stable_diffusion) which uses this server. - ---- - -## Stable Diffusion directory structure - - -This is the recommended and default setup for runai - -### Linux - -Default directory structure for runai Stable Diffusion - -#### Base models - -These models are required to run Stable Diffusion - -- **CLIP** files for CLIP -- **CompVis** safety checker model (used for NSWF filtering) -- **openai** clip-vit-large-patch14 model - -``` - ├── ~/stablediffusuion -    ├── CLIP -    ├── CompVis -    │   ├── stable-diffusion-safety-checker -    ├── openai -       ├── clip-vit-large-patch14 -``` - -#### Diffusers models - -These are the base models to run a particular version of Stable Diffusion. - -- **runwayml**: Base models for Stable Diffusion v1 -- **stabilityai**: Base models for Stable Diffusion v2 - -``` -├── ~/stablediffusuion -   ├── runwayml -      ├── stable-diffusion-inpainting -      ├── stable-diffusion-v1-5 -   ├── stabilityai -      ├── stable-diffusion-2-1-base -      ├── stable-diffusion-2-inpainting -``` - -#### Custom models - -- **v1** should be a directory containing models using stable diffusion v1 -- **v2** should be a directory containing models using stable diffusion v2 - -You may place diffusers folders, ckpt and safetensor files in these directories. - -``` -├── ~/stablediffusuion -   ├── v1 -   │   ├── (diffusers directory) -   │   ├── .ckpt -   │   ├── .safetensor -   ├── v2 -      ├── (diffusers directory) -      ├── .ckpt -      ├── .safetensor -``` - -### Automatic1111 existing files - -If you are using **Automatic1111** you can place your checkpoints in the -webui models folder as you typically would, however the directory structure -which includes v1 models separated from v2 models is required for now. - -This allows you to use the same checkpoints for both **Automatic1111 webui** -and this server. 
- -For example, if your `webui` directory looks like this - -``` -├── /home/USER/stable-diffusion-webui/models/Stable-diffusion -    ├── .ckpt -    ├── .ckpt -    ├── .ckpt -``` - -You would reorganize it like this: - -``` -├── /home/USER/stable-diffusion-webui/models/Stable-diffusion -    ├── v1 -       ├── .ckpt -       ├── .ckpt -    ├── v2 -       ├── .ckpt -``` - -You would then set BASE_DIR to `/home/USER/stable-diffusion-webui/models/Stable-diffusion` - ---- - -### Build - -First install `pyinstaller` - -`pip install pyinstaller` - -Then build the executable - +```bash +pip install runai +cp src/runai/default.settings.py src/runai/settings.py ``` -./bin/buildlinux.sh -``` - -Test - -``` -cd ./dist/runai -./runai -``` - -This should start a server. - -[Connect a client to see if it is working properly](https://github.com/w4ffl35/krita_stable_diffusion) - ---- - -## Running the server - -`python server.py` - -The following flags and options are available - -- `--port` (int) - port to run server on -- `--host` (str) - host to run server on -- `--timeout` - whether to timeout after failing to receive a client connection, pass this flag for true, otherwise the server will not timeout. -- `--packet-size` (int) - size of byte packets to transmit to and from the client -- `--model-base-path` (str) - base directory for checkpoints -- `--max-client-connections` (int) - maximum number of client connections to accept - -Example - -``` -python server.py --port 8080 --host https://0.0.0.0 --timeout -``` - -This will start a server listening on https://0.0.0.0:8080 and will timeout -after a set number of attempts when no failing to receive a client connection. - ---- - -### Request structure - -Requets are sent to the server as a JSON encoded byte string. The JSON object -should look as follows - -``` -{ - TODO -} -``` - ---- - -### Model loading -The server does not automatically load a model. It waits for the client to send -a request which contains a model path and name. The server will determine which -version of stable diffusion is in use and which model has been selected -to generate images. It will also determine the best model to load based on -the list of available types in the directory provided. +Modify `settings.py` as you see fit. --- -### Development notes +## Run server and client +See `src/runai/server.py` for an example of how to run the server and `src/runai/client.py` for an example of how to run +the client. Both of these files can be run directly from the command line. -- `StableDiffusionRequestQueueWorker.callback` Handle routes and dispatch to functions -- `socket_server.message_client` Send a message to the client \ No newline at end of file +The socket client will continuously attempt to connect to the server until it is successful. The server will accept +connections from any client on the given port. 
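As a companion to the "Run server and client" section of the new README above, here is a minimal sketch of driving the new client from Python rather than through `Client.run()`. It assumes a server is already listening locally (for example via `python src/runai/server.py`) on the default port from `default.settings.py`; the host value and prompt text are placeholders for this example.

```python
# Minimal sketch of using the new Client class programmatically.
# Assumes a RunAI server is already running locally on the default port;
# SocketClient will keep retrying the connection until the server accepts it.
from runai.client import Client

client = Client(host="127.0.0.1", port=50006)

# do_prompt() records the user message in client.history and yields the reply
# as it streams back from the server; each yield is the text received so far.
reply = ""
for partial in client.do_prompt("Hello, who are you?"):
    reply = partial
print(reply)
```

If you only want the interactive prompt loop, running `python src/runai/client.py` directly uses `Client.run()` and prints each streamed chunk as it arrives.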
diff --git a/install.sh b/install.sh index f59a121..1fdca63 100644 --- a/install.sh +++ b/install.sh @@ -1,3 +1,3 @@ #!/usr/bin/bash -pip install -r requirements.txt \ No newline at end of file +pip install -r requirements.txt diff --git a/run.sh b/run.sh deleted file mode 100755 index 3865609..0000000 --- a/run.sh +++ /dev/null @@ -1,5 +0,0 @@ -# python3 /app/server.py -# run a loop taking input -while true; do - read -p "Enter a number: " num -done \ No newline at end of file diff --git a/setup.py b/setup.py index 6da54f0..d9d99e0 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ "numpy==1.26.4", # Core application dependencies "accelerate==0.29.2", - "huggingface-hub==0.22.2", + "huggingface-hub==0.23.0", "torch==2.2.2", "optimum==1.19.1", @@ -42,6 +42,7 @@ "pycountry==23.12.11", "sounddevice==0.4.6", # Required for tts and stt "pyttsx3==2.90", # Required for tts + "peft==0.12.0", # Pyinstaller Dependencies "ninja==1.11.1.1", diff --git a/src/runai/agent.py b/src/runai/agent.py index 5e8a896..93f38c3 100644 --- a/src/runai/agent.py +++ b/src/runai/agent.py @@ -2,6 +2,23 @@ class Agent: def __init__(self, *args, **kwargs): self.conversation = kwargs.get("conversation", []) self.name = kwargs.get("name", "") + self.mood_stats = { + "happy": 0, + "sad": 0, + "neutral": 0, + "angry": 0, + "paranoid": 0, + "anxious": 0, + "excited": 0, + "bored": 0, + "confused": 0, + "relaxed": 0, + "curious": 0, + "frustrated": 0, + "hopeful": 0, + "disappointed": 0, + } + self.user = None def conversation_so_far(self, use_name=False): if use_name: @@ -15,3 +32,7 @@ def to_dict(self): "conversation": self.conversation, "name": self.name } + + def update_mood_stat(self, stat: str, amount: float): + if stat in self.mood_stats: + self.mood_stats[stat] += amount diff --git a/src/runai/client.py b/src/runai/client.py new file mode 100644 index 0000000..66e4ca2 --- /dev/null +++ b/src/runai/client.py @@ -0,0 +1,177 @@ +import json +import re +from datetime import datetime + +from runai.agent import Agent +from runai.llm_request import LLMRequest +from runai.settings import DEFAULT_HOST, DEFAULT_PORT, PACKET_SIZE, USER_NAME, BOT_NAME, LLM_INSTRUCTIONS +from runai.socket_client import SocketClient + + +class Client: + def __init__( + self, + host=DEFAULT_HOST, + port=DEFAULT_PORT, + packet_size=PACKET_SIZE, + user_name=USER_NAME, + bot_name=BOT_NAME + ): + self.host = host + self.port = port + self.packet_size = packet_size + self.socket_client = self.connect_socket() + self.bot_agent = Agent(name=bot_name) + self.user_agent = Agent(name=user_name) + + self.history = [] + + def connect_socket(self): + socket_client = SocketClient(host=self.host, port=self.port, packet_size=self.packet_size) + socket_client.connect() + return socket_client + + @property + def dialogue_instructions(self): + return LLM_INSTRUCTIONS["dialogue_instructions"].format( + dialogue_rules=self.dialogue_rules, + mood_stats=self.mood_stats, + contextual_information=self.contextual_information + ) + + @property + def contextual_information(self): + return LLM_INSTRUCTIONS["contextual_information"].format( + date_time=datetime.now().strftime("%m/%d/%Y, %I:%M:%S %p"), + weather="sunny" + ) + + @property + def update_mood_instructions(self): + return LLM_INSTRUCTIONS["update_mood_instructions"].format( + speaker_name=self.bot_agent.name, + python_rules=self.python_rules + ) + + @property + def mood_stats(self): + stats = ", ".join([f"{k}: {v}" for k, v in self.bot_agent.mood_stats.items()]) + return ( + f"{self.bot_agent.name}'s mood 
stats:\n" + f"{stats}\n" + ) + + @property + def dialogue_rules(self): + return LLM_INSTRUCTIONS["dialogue_rules"].format( + speaker_name=self.bot_agent.name, + listener_name=self.user_agent.name + ) + + @property + def json_rules(self): + return LLM_INSTRUCTIONS["json_rules"] + + @property + def python_rules(self): + return LLM_INSTRUCTIONS["python_rules"] + + def do_greeting(self): + return self.do_query( + LLM_INSTRUCTIONS["greeting_prompt"].format(speaker_name=self.bot_agent.name), + self.dialogue_instructions + ) + + def do_response(self): + return self.do_query( + LLM_INSTRUCTIONS["response_prompt"].format(speaker_name=self.bot_agent.name), + self.dialogue_instructions + ) + + def update_mood(self, agent: Agent): + stats = ", ".join([f'"{k}": {v}' for k, v in agent.mood_stats.items()]) + res = self.do_query( + LLM_INSTRUCTIONS["update_mood_prompt"].format( + agent_name=agent.name, + stats=stats + ), + self.update_mood_instructions + ) + python_code_match = self.find_python(res) + if python_code_match: + python_code = python_code_match.group(1) + exec(python_code, {}, {"agent": agent}) + return agent + + @staticmethod + def find_python(res: str): + return Client.find_code_block("python", res) + + @staticmethod + def find_json(res: str): + return Client.find_code_block("json", res) + + @staticmethod + def find_code_block(language: str, res: str) -> re.Match: + return re.search(r'```' + language + 'r\n(.*?)\n```', res, re.DOTALL) + + def do_prompt(self, user_prompt, update_speaker_mood=False): + self.update_history(self.user_agent.name, user_prompt) + + if update_speaker_mood: + self.bot_agent = self.update_mood(self.bot_agent) + + for res in self.do_response(): + yield res + + return "" + + def update_history(self, name: str, message: str): + self.history.append({ + "name": name, + "message": message + }) + + def do_query(self, user_prompt, instructions): + llm_request = LLMRequest( + history=self.history, + speaker=self.bot_agent, + listener=self.user_agent, + use_usernames=True, + prompt_prefix="", + instructions=instructions, + prompt=user_prompt + ) + self.socket_client.send_message(json.dumps(llm_request.to_dict())) + + server_response = "" + for res in self.socket_client.receive_message(): + res.replace(f"{self.bot_agent.name}: ", "") + server_response += res.replace('\x00', '') + yield server_response + return server_response.replace('\x00', '') + + def _handle_prompt(self, prompt: str): + response = "" + for txt in self.do_prompt(prompt): + response = str(txt) + yield response + + self.history.append({ + "name": self.bot_agent.name, + "message": response + }) + + def run(self): + while True: + prompt = input("Enter a prompt: ") + for res in self._handle_prompt(prompt): + self.handle_res(res) + + def handle_res(self, res): + print(res) + + +if __name__ == "__main__": + client = Client() + client.run() diff --git a/src/runai/default.settings.py b/src/runai/default.settings.py new file mode 100644 index 0000000..b7da43e --- /dev/null +++ b/src/runai/default.settings.py @@ -0,0 +1,65 @@ +PACKET_SIZE = 1024 +DEFAULT_PORT = 50006 +DEFAULT_HOST = "0.0.0.0" +USER_NAME = "User" +BOT_NAME = "AI Bot" +MAX_CLIENTS = 1 +DEBUG = True +DEFAULT_SERVER_TYPE = "LLM" +MODEL_BASE_PATH = "~/.airunner/text/models/causallm" +MODELS = { + "mistral_instruct": { + "path": "mistralai/Mistral-7B-Instruct-v0.3", + "chat_template": ( + "{% for message in messages %}" + "{% if message['role'] == 'system' %}" + "{{ '[INST] <>' + message['content'] + ' <>[/INST]' }}" + "{% elif message['role'] == 'user' %}" + 
"{{ '[INST]' + message['content'] + ' [/INST]' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ message['content'] + eosets_token + ' ' }}" + "{% endif %}" + "{% endfor %}" + ) + } +} +DEFAULT_MODEL_NAME = "mistral_instruct" +LLM_INSTRUCTIONS = { + "dialogue_instructions": "You are a chatbot. You will follow all of the rules in order to generate compelling and intriguing dialogue.\nThe Rules:\n{dialogue_rules}------\n{mood_stats}------\n{contextual_information}------\n", + "contextual_information": "Contextual Information:\n{date_time}\nThe weather is {weather}\n", + "update_mood_instructions": "Analyze the conversation and update {speaker_name}'s and determine what {speaker_name}'s mood stats should change to.\nThe Rules:\n{python_rules}------\n", + "dialogue_rules": ( + "You will ONLY return dialogue, nothing more.\n" + "Limit responses to a single sentence.\n" + "Only generate responses in pure dialogue form without including any actions, descriptions or stage directions in parentheses. " + "Only return spoken words.\n" + "Do not generate redundant dialogue. Examine the conversation and context close and keep responses interesting and creative.\n" + "Do not format the response with the character's name or any other text. Only return the dialogue.\n" + "{speaker_name} and {listener_name} are having a conversation. \n" + "Respond with dialogue for {speaker_name}.\n" + "Avoid repeating {speaker_name}'s previous dialogue or {listener_name}'s previous dialogue.\n" + ), + "json_rules": ( + "You will ONLY return JSON.\n" + "No other data types are allowed.\n" + "Never return instructions, information or dialogue.\n" + ), + "python_rules": ( + "You will ONLY return Python code.\n" + "No other data types are allowed.\n" + "Never return instructions, information or dialogue.\n" + "Never return comments.\n" + ), + "greeting_prompt": "Generate a greeting for {speaker_name}", + "response_prompt": "Generate a response for {speaker_name}", + "update_mood_prompt": ( + "Update the appropriate mood stats, incrementing or decrementing them by floating points.\n" + "Current mood stats for {agent_name}\n" + "{stats}" + "Return a block of python code updating whichever mood stats you think are appropriate based on the conversation.\n" + "call `agent.update_mood_stat` to update each appropriate mood stat.\n" + "The function takes two arguments: stat: str, and amount: float.\n" + "You may call the method on `agent` multiple times passing various stats that should be updated.\n" + "```python\n" + ) +} diff --git a/src/runai/llm_handler.py b/src/runai/llm_handler.py index 8e49834..03feb96 100644 --- a/src/runai/llm_handler.py +++ b/src/runai/llm_handler.py @@ -11,13 +11,15 @@ from runai.settings import MODEL_BASE_PATH, MODELS -class LLMHandler(RagMixin): +class LLMHandler():#RagMixin): def __init__(self, model_name: str = ""): - self._model_path = os.path.expanduser(MODEL_BASE_PATH) - self.model_name = MODELS[model_name]["path"] + self.model_name = model_name + self.model_path = os.path.join( + os.path.expanduser(MODEL_BASE_PATH), + MODELS[self.model_name]["path"] + ) - # RagMixin.__init__(self) - self.rendered_template = None + #RagMixin.__init__(self) self.model = self.load_model() self.tokenizer = self.load_tokenizer() self.streamer = self.load_streamer() @@ -32,8 +34,8 @@ def interrupt(self): self._do_interrupt_process = True @property - def model_path(self): - return os.path.join(self._model_path, self.model_name) + def quantized_model_path(self): + return self.model_path + "_quantized" @property def 
device(self): @@ -43,8 +45,12 @@ def do_interrupt_process(self): return self._do_interrupt_process def load_model(self): - return AutoModelForCausalLM.from_pretrained( - self.model_path, + model_path = self.quantized_model_path + if not os.path.exists(model_path): + model_path = self.model_path + + model = AutoModelForCausalLM.from_pretrained( + model_path, local_files_only=True, use_cache=True, trust_remote_code=False, @@ -61,41 +67,34 @@ def load_model(self): device_map=self.device, ) + if model_path != self.quantized_model_path: + model.save_pretrained(self.quantized_model_path) + return model + def load_tokenizer(self): return AutoTokenizer.from_pretrained(self.model_path) def load_streamer(self): return TextIteratorStreamer(self.tokenizer) + def rendered_template(self, conversation: list) -> str: + chat_template = MODELS[self.model_name]["chat_template"] + rendered_template = self.tokenizer.apply_chat_template( + chat_template=chat_template, + conversation=conversation, + tokenize=False + ) + return rendered_template + def query_model( self, llm_request: LLMRequest ): - chat_template = ( - "{% for message in messages %}" - "{% if message['role'] == 'system' %}" - "{{ '[INST] <>' + message['content'] + ' <>[/INST]' }}" - "{% elif message['role'] == 'user' %}" - "{{ '[INST]' + message['content'] + ' [/INST]' }}" - "{% elif message['role'] == 'assistant' %}" - "{{ message['content'] + eosets_token + ' ' }}" - "{% endif %}" - "{% endfor %}" - ) - rendered_template = self.tokenizer.apply_chat_template( - chat_template=chat_template, - conversation=llm_request.conversation, - tokenize=False - ) - self.rendered_template = rendered_template - model_inputs = self.tokenizer( - rendered_template, - return_tensors="pt" - ).to(self.device) - stopping_criteria = ExternalConditionStoppingCriteria( - self.do_interrupt_process - ) + rendered_template = self.rendered_template(llm_request.conversation) + model_inputs = self.tokenizer(rendered_template, return_tensors="pt").to(self.device) + stopping_criteria = ExternalConditionStoppingCriteria(self.do_interrupt_process) + print(rendered_template) self.generate_data = dict( model_inputs, max_new_tokens=llm_request.max_new_tokens, @@ -126,28 +125,44 @@ def query_model( ) self.generate_thread.start() - rendered_template = rendered_template.replace("", "") - strip_template = "" + rendered_template - # strip_template = strip_template.replace(" [INST]", " [INST]") - # strip_template = strip_template.replace(" [INST] <>", "[INST] <>") - - - strip_template = strip_template.replace("[INST] <>", "[INST] <>") - strip_template = strip_template.replace("<>[/INST][INST]", "<>[/INST][INST] ") + rendered_template = self.update_rendered_template(rendered_template) streamed_template = "" replaced = False for new_text in self.streamer: - streamed_template += new_text - streamed_template = streamed_template.replace("", "") - if streamed_template.find(strip_template) != -1: - replaced = True - streamed_template = streamed_template.replace(strip_template, "") + if not replaced: + replaced, streamed_template = self.update_streamed_template( + rendered_template, + streamed_template, + new_text + ) + if replaced: - parsed = new_text.replace("[/INST]", "") - parsed = parsed.replace("", "") - parsed = parsed.replace("<>", "") + parsed = self.strip_tags(new_text) yield parsed + @staticmethod + def update_streamed_template(rendered_template, streamed_template, new_text): + streamed_template += new_text + streamed_template = streamed_template.replace("", "") + replaced = 
streamed_template.find(rendered_template) != -1 + streamed_template = streamed_template.replace(rendered_template, "") + return replaced, streamed_template + + @staticmethod + def update_rendered_template(rendered_template) -> str: + rendered_template = rendered_template.replace("", "") + rendered_template = "" + rendered_template + rendered_template = rendered_template.replace("[INST] <>", "[INST] <>") + rendered_template = rendered_template.replace("<>[/INST][INST]", "<>[/INST][INST] ") + return rendered_template + + @staticmethod + def strip_tags(template: str) -> str: + template = template.replace("[/INST]", "") + template = template.replace("", "") + template = template.replace("<>", "") + return template + def generate(self, data): self.model.generate(**data) diff --git a/src/runai/llm_request.py b/src/runai/llm_request.py index afecd02..b7a4731 100644 --- a/src/runai/llm_request.py +++ b/src/runai/llm_request.py @@ -7,7 +7,7 @@ class LLMRequest: def __init__( self, - conversation: Optional[List[dict]] = None, + history: List[dict] = None, listener: Agent = None, speaker: Agent = None, use_usernames: bool = False, @@ -29,12 +29,12 @@ def __init__( length_penalty: float = 1.0, llm_query_type: Optional[LLMQueryType] = None ): - self.conversation = conversation + self.history = history if history else [] self.listener = listener self.speaker = speaker self.use_usernames = use_usernames self.prompt_prefix = prompt_prefix - self.instructions = instructions + self._instructions = instructions self.prompt = prompt self.max_new_tokens = max_new_tokens self.temperature = temperature @@ -51,9 +51,25 @@ def __init__( self.use_cache = use_cache self.length_penalty = length_penalty + @property + def instructions(self): + instructions = self._instructions + if len(self.history): + instructions += "\nThe conversation so far:\n" + for turn in self.history: + instructions += f"{turn['name']}: {turn['message']}\n" + return instructions + + @property + def conversation(self): + return [ + {"role": "system", "content": self.instructions}, + {"role": "user", "content": self.prompt} + ] + def to_dict(self): return { - "conversation": self.conversation, + "history": self.history, "listener": self.listener.to_dict() if self.listener else None, "speaker": self.speaker.to_dict() if self.speaker else None, "use_usernames": self.use_usernames, diff --git a/src/runai/llm_request_queue_worker.py b/src/runai/llm_request_queue_worker.py index d46b0f5..5beb41c 100644 --- a/src/runai/llm_request_queue_worker.py +++ b/src/runai/llm_request_queue_worker.py @@ -34,16 +34,8 @@ def callback(self, data): data = json.loads(data) except json.decoder.JSONDecodeError: logger.error(f"Improperly formatted request from client") - print(data) return - print(data) - - data["conversation"] = [ - {"role": "system", "content": data["instructions"]}, - {"role": "user", "content": data["prompt"]} - ] - data["listener"] = Agent(**data["listener"]) data["speaker"] = Agent(**data["speaker"]) diff --git a/src/runai/server.py b/src/runai/server.py index b9e4eae..0727665 100644 --- a/src/runai/server.py +++ b/src/runai/server.py @@ -2,44 +2,49 @@ from runai import settings -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument('--model-name', type=str, default=settings.DEFAULT_MODEL_NAME) - parser.add_argument('--server', type=str, default=settings.DEFAULT_SERVER_TYPE) - parser.add_argument('--port', type=int, default=settings.DEFAULT_PORT) - parser.add_argument('--host', type=str, default=settings.DEFAULT_HOST) - 
parser.add_argument('--timeout', action='store_true', default=False) - parser.add_argument('--packet-size', type=int, default=settings.PACKET_SIZE) - parser.add_argument('--max-client-connections', type=int, default=1) - parser.add_argument('--model-base-path', type=str, default='.') - - return parser.parse_args() +class Server: + def __init__(self): + # get command line arguments + args = self.parse_args() + + model_name = args.model_name + if model_name is None or model_name == "" or model_name not in settings.MODELS: + model_name = settings.DEFAULT_MODEL_NAME + + if args.server == "LLM": + from llm_request_queue_worker import LLMRequestQueueWorker + server_class_ = LLMRequestQueueWorker + elif args.server == "SD": + from stable_diffusion_request_queue_worker import StableDiffusionRequestQueueWorker + server_class_ = StableDiffusionRequestQueueWorker + else: + raise ValueError(f"Unknown server type: {args.server}") + + self.app = server_class_( + model_name=model_name, + port=args.port, + host=args.host, + do_timeout=args.timeout, + packet_size=args.packet_size, + + # future: + max_client_connections=args.max_client_connections, + model_base_path=args.model_base_path + ) + + def parse_args(self): + parser = argparse.ArgumentParser() + parser.add_argument('--model-name', type=str, default=settings.DEFAULT_MODEL_NAME) + parser.add_argument('--server', type=str, default=settings.DEFAULT_SERVER_TYPE) + parser.add_argument('--port', type=int, default=settings.DEFAULT_PORT) + parser.add_argument('--host', type=str, default=settings.DEFAULT_HOST) + parser.add_argument('--timeout', action='store_true', default=False) + parser.add_argument('--packet-size', type=int, default=settings.PACKET_SIZE) + parser.add_argument('--max-client-connections', type=int, default=1) + parser.add_argument('--model-base-path', type=str, default='.') + + return parser.parse_args() if __name__ == '__main__': - # get command line arguments - args = parse_args() - - model_name = args.model_name - if model_name is None or model_name == "" or model_name not in settings.MODELS: - model_name = settings.DEFAULT_MODEL_NAME - - if args.server == "LLM": - from llm_request_queue_worker import LLMRequestQueueWorker - server_class_ = LLMRequestQueueWorker - elif args.server == "SD": - from stable_diffusion_request_queue_worker import StableDiffusionRequestQueueWorker - server_class_ = StableDiffusionRequestQueueWorker - else: - raise ValueError(f"Unknown server type: {args.server}") - app = server_class_( - model_name=model_name, - port=args.port, - host=args.host, - do_timeout=args.timeout, - packet_size=args.packet_size, - - # future: - max_client_connections=args.max_client_connections, - model_base_path=args.model_base_path - ) + server = Server() diff --git a/src/runai/settings.py b/src/runai/settings.py deleted file mode 100644 index ee4594c..0000000 --- a/src/runai/settings.py +++ /dev/null @@ -1,14 +0,0 @@ -PACKET_SIZE = 1024 -DEFAULT_PORT = 50006 -DEFAULT_HOST = "0.0.0.0" -MAX_CLIENTS = 1 -DEBUG = True -DEFAULT_SERVER_TYPE = "LLM" -MODEL_BASE_PATH = "~/.airunner/text/models/causallm" -MODELS = { - "mistral_instruct": { - "path": "mistralai/Mistral-7B-Instruct-v0.3" - } -} -DEFAULT_MODEL_NAME = "mistral_instruct" -GAME_GENRE = "fantasy" diff --git a/src/runai/socket_client.py b/src/runai/socket_client.py index 622b9eb..0b21544 100644 --- a/src/runai/socket_client.py +++ b/src/runai/socket_client.py @@ -1,19 +1,25 @@ import json import socket +import time class SocketClient: - def __init__(self, host, port, packet_size): + 
def __init__(self, host, port, packet_size, retry_delay=2): self.packet_size = packet_size self.host = host self.port = port self.client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.retry_delay = retry_delay def connect(self): - try: - self.client_socket.connect((self.host, self.port)) - except ConnectionRefusedError: - print("Connection refused. Make sure the server is running.") + while True: + try: + self.client_socket.connect((self.host, self.port)) + print("Connected to server.") + return + except ConnectionRefusedError: + print(f"Connection refused. Retrying in {self.retry_delay} seconds...") + time.sleep(self.retry_delay) def send_message(self, message): message = message.encode('utf-8') # encode the message as UTF-8 diff --git a/src/runai/test.py b/src/runai/test.py index 0db48db..a133ba9 100644 --- a/src/runai/test.py +++ b/src/runai/test.py @@ -1,24 +1,25 @@ -import asyncio -import websockets - -# Define a callback function to handle incoming WebSocket messages -async def handle_websocket(websocket, path): - try: - while True: - message = await websocket.recv() - print(f"Received message: >{message}<") - - # XXX: Do some stuff here with the message. - - response = "Pong!" - await websocket.send(response) - except websockets.ConnectionClosed: - pass +from runai.client import Client if __name__ == "__main__": - # Start the WebSocket server - start_server = websockets.serve(handle_websocket, "localhost", 12345) - - asyncio.get_event_loop().run_until_complete(start_server) - asyncio.get_event_loop().run_forever() - + client = Client() + # client.do_greeting() + # #client.swap_speaker() + # print(client.history) + # client.do_prompt("What is the meaning of life?") + # print(client.history) + # client.do_prompt("How can I find true happiness?") + # print(client.history) + # client.do_prompt("What were my first two questions?") + # print(client.history) + # + # # client.do_response() + # # client.swap_speaker() + # # print(client.history) + # # + # # client.do_response() + # # client.swap_speaker() + # # print(client.history) + while True: + prompt = input("Enter a prompt: ") + client.do_prompt(prompt) + print(client.history[-1]["message"])
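A note on the data flow these changes introduce: the client no longer sends a prebuilt `conversation`; it sends `history`, `instructions` and `prompt`, and the reworked `LLMRequest` folds the history into the system message that the server then renders through the model's chat template. Below is a small sketch of that composition; the history entries, instructions and prompt are made up for illustration, and all other `LLMRequest` parameters are left at their defaults.

```python
# Sketch of how the reworked LLMRequest builds the conversation that the
# server renders through the chat template. The strings below are invented
# for illustration only.
from runai.llm_request import LLMRequest

request = LLMRequest(
    history=[
        {"name": "User", "message": "Hello."},
        {"name": "AI Bot", "message": "Hi there, how can I help?"},
    ],
    prompt_prefix="",
    instructions="You are a chatbot.",
    prompt="What did I just say?",
)

# The `instructions` property appends "The conversation so far:" plus each
# history turn, and `conversation` pairs that system message with the prompt:
#   [{"role": "system", "content": "You are a chatbot.\nThe conversation so far:\nUser: Hello.\n..."},
#    {"role": "user", "content": "What did I just say?"}]
print(request.conversation)
```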