From 3769ae710c7dccb158b30174afbffc8c49b4c9e1 Mon Sep 17 00:00:00 2001
From: whatcloud
Date: Wed, 2 Oct 2024 20:48:00 +1000
Subject: [PATCH v2] fix token exceed llm issue

---
Review notes (everything between the "---" line and the diff is ignored by
`git am`):

* Models without an entry in max_token_model leave token_limit as None; the
  prompt is then sent in one request instead of failing on `int <= None`.
* divide_list() returns [] for an empty list instead of raising ValueError
  from range() with a zero step.
* Slices are paired with zip_longest(), so surplus slices of the longer list
  are no longer dropped when the two lists split into different counts.
* Removed unused locals (chunk_size, keys_no_meaning, sub_tokens) and the
  duplicate 'questionID' entry; added docstrings.
* The "index" line below still carries the v1 post-image hash.

 agents/application/executor.py | 99 +++++++++++++++++++++++++++++++++--
 1 file changed, 94 insertions(+), 5 deletions(-)

diff --git a/agents/application/executor.py b/agents/application/executor.py
index e6d1f75..ff0d996 100644
--- a/agents/application/executor.py
+++ b/agents/application/executor.py
@@ -2,6 +2,9 @@
 import json
 import ast
 import re
+import math
+from itertools import zip_longest
+from typing import List, Dict, Any
 
 from dotenv import load_dotenv
 from langchain_core.messages import HumanMessage, SystemMessage
@@ -13,14 +16,35 @@
 from agents.application.prompts import Prompter
 from agents.polymarket.polymarket import Polymarket
 
+
+def retain_keys(data, keys_to_retain):
+    """Recursively drop every dict key that is not in ``keys_to_retain``.
+
+    Lists are rebuilt element by element; any other value is returned as is.
+    """
+    if isinstance(data, dict):
+        return {
+            key: retain_keys(value, keys_to_retain)
+            for key, value in data.items()
+            if key in keys_to_retain
+        }
+    if isinstance(data, list):
+        return [retain_keys(item, keys_to_retain) for item in data]
+    return data
+
 
 class Executor:
-    def __init__(self) -> None:
+    def __init__(self, default_model='gpt-3.5-turbo-16k') -> None:
+        """Load the environment and build the LLM client for ``default_model``."""
         load_dotenv()
+        # Prompt budgets (estimated tokens) per supported model. Models that are
+        # not listed get ``None``, which disables prompt splitting.
+        max_token_model = {'gpt-3.5-turbo-16k': 15000, 'gpt-4-1106-preview': 95000}
+        self.token_limit = max_token_model.get(default_model)
         self.prompter = Prompter()
         self.openai_api_key = os.getenv("OPENAI_API_KEY")
         self.llm = ChatOpenAI(
-            model="gpt-3.5-turbo",
+            model=default_model,
             temperature=0,
         )
         self.gamma = Gamma()
@@ -43,9 +67,21 @@ def get_superforecast(
         result = self.llm.invoke(messages)
         return result.content
 
-    def get_polymarket_llm(self, user_input: str) -> str:
-        data1 = self.gamma.get_current_events()
-        data2 = self.gamma.get_current_markets()
+    def estimate_tokens(self, text: str) -> int:
+        """Roughly estimate the token count of ``text``.
+
+        Assumes about four characters per token; use a real tokenizer when an
+        accurate count is needed.
+        """
+        return len(text) // 4
+
+    def process_data_chunk(
+        self,
+        data1: List[Dict[Any, Any]],
+        data2: List[Dict[Any, Any]],
+        user_input: str,
+    ) -> str:
+        """Build the Polymarket system prompt from one data slice and query the LLM."""
         system_message = SystemMessage(
             content=str(self.prompter.prompts_polymarket(data1=data1, data2=data2))
         )
@@ -54,6 +90,59 @@ def get_polymarket_llm(self, user_input: str) -> str:
         result = self.llm.invoke(messages)
         return result.content
 
+    def divide_list(self, original_list, i):
+        """Split ``original_list`` into at most ``i`` contiguous sublists.
+
+        An empty list yields no sublists: the computed sublist size would be
+        zero, and ``range()`` rejects a zero step with ``ValueError``.
+        """
+        if not original_list:
+            return []
+        sublist_size = math.ceil(len(original_list) / i)
+        return [
+            original_list[j:j + sublist_size]
+            for j in range(0, len(original_list), sublist_size)
+        ]
+
+    def get_polymarket_llm(self, user_input: str) -> str:
+        """Answer ``user_input`` from the current events and markets.
+
+        When the estimated prompt size exceeds ``self.token_limit`` the data is
+        split into slices, each slice is sent to the LLM separately, and the
+        answers are joined with a single space.
+        """
+        data1 = self.gamma.get_current_events()
+        data2 = self.gamma.get_current_markets()
+
+        combined_data = str(self.prompter.prompts_polymarket(data1=data1, data2=data2))
+        total_tokens = self.estimate_tokens(combined_data)
+
+        # No known budget for this model: send everything in one request.
+        token_limit = self.token_limit
+        if token_limit is None or total_tokens <= token_limit:
+            return self.process_data_chunk(data1, data2, user_input)
+
+        print(f'total tokens {total_tokens} exceeding llm capacity, now will split and answer')
+        group_size = (total_tokens // token_limit) + 1
+        useful_keys = {
+            'id', 'questionID', 'description', 'liquidity', 'clobTokenIds',
+            'outcomes', 'outcomePrices', 'volume', 'startDate', 'endDate',
+            'question', 'events',
+        }
+        # NOTE(review): only data1 is trimmed to useful_keys; data2 is sent
+        # unfiltered. Confirm that this is intended.
+        data1 = retain_keys(data1, useful_keys)
+        cut_1 = self.divide_list(data1, group_size)
+        cut_2 = self.divide_list(data2, group_size)
+
+        # divide_list() may return a different number of slices for each list;
+        # zip() would silently drop the surplus slices of the longer one.
+        results = [
+            self.process_data_chunk(sub_data1, sub_data2, user_input)
+            for sub_data1, sub_data2 in zip_longest(cut_1, cut_2, fillvalue=[])
+        ]
+        return " ".join(results)
+
     def filter_events(self, events: "list[SimpleEvent]") -> str:
         prompt = self.prompter.filter_events(events)
         result = self.llm.invoke(prompt)