From b1bf5a56c9b244a717e305f64ed21821584ccc54 Mon Sep 17 00:00:00 2001 From: Jackson Barbosa Date: Thu, 7 Dec 2023 14:28:48 -0300 Subject: [PATCH] adapt view and configure env vars --- .../api/v2/zeroshot/usecases/format_prompt.py | 57 +++++++++++++++++++ bothub/api/v2/zeroshot/views.py | 57 +++++++++++++------ bothub/settings.py | 16 +++++- 3 files changed, 111 insertions(+), 19 deletions(-) create mode 100644 bothub/api/v2/zeroshot/usecases/format_prompt.py diff --git a/bothub/api/v2/zeroshot/usecases/format_prompt.py b/bothub/api/v2/zeroshot/usecases/format_prompt.py new file mode 100644 index 00000000..d17ff288 --- /dev/null +++ b/bothub/api/v2/zeroshot/usecases/format_prompt.py @@ -0,0 +1,57 @@ + +class FormatPrompt: + const_prompt_data = { + "pt_br": { + "prompt_context": "Você é muito especialista em ", + "prompt_has_classes": ". Você possui as classes:\n\n", + "prompt_class_prefix": "Classe: ", + "prompt_context_classes": "Contexto da classe ", + "is_class_context_prefix": True, + "prompt_none_name": "Nenhuma", + "prompt_none_definition": "Classe: Nenhuma\nContexto da classe Nenhuma: A classe nenhuma é usada quando a frase não se relaciona com as outras classes definidas.\n\n", + "prompt_phrase_prefix": "Frase: ", + "prompt_analyse_text": "Pare, pense bem e analise qual é a melhor resposta de classe para a frase, responda só se você tiver muita certeza.\n\nClasse: " + }, + "en": { + "prompt_context": "You are very expert in ", + "prompt_has_classes": ". You have the following classes:\n\n", + "prompt_class_suffix": "Class: ", + "is_class_context_prefix": False, + "prompt_context_classes": " class context", + "prompt_none_name": "None", + "prompt_none_definition": "Class: None\n\nNone class context: Use when a sentence doesn't align with any of the classes above.\n", + "prompt_phrase_prefix": "Sentence: ", + "prompt_analyse_text": "Stop, think carefully and analyze what the best class answer to the sentence is, only answer if you are very sure.\n\nClass:" + }, + "es": { + "prompt_context": "Eres muy experto en ", + "prompt_has_classes": ". Usted posee las clases:\n\n", + "prompt_class_prefix": "Clase: ", + "is_class_context_prefix": True, + "prompt_context_classes": "Contexto de la clase ", + "prompt_none_name": "Ninguna", + "prompt_none_definition": "Clase: Ninguna\n\nContexto de la clase Ninguna: Aplicable si el tema no corresponde con las clases establecidas.", + "prompt_phrase_prefix": "Frase: ", + "prompt_analyse_text": "Detente, piensa detenidamente y analiza cuál es la mejor respuesta de clase a la frase, responde sólo si estás muy seguro.\n\nClase:" + } + } + + def generate_prompt(self, language: str, zeroshot_data: dict): + + translated_text = self.const_prompt_data[language] + prompt = translated_text.get("prompt_context") + zeroshot_data.get("context") + translated_text.get("prompt_has_classes") + for option in zeroshot_data.get("options"): + prompt += translated_text.get("prompt_class_prefix") + option.get("class", "").capitalize() + "\n" + if translated_text.get("is_class_context_prefix"): + prompt += translated_text.get("prompt_context_classes") + option.get("class", "").capitalize() + ": " + else: + prompt += option.get("class", "").capitalize() + translated_text.get("prompt_context_classes") + ": " + prompt += option.get("context") + "\n" + prompt += translated_text.get("prompt_none_definition") + translated_text.get("prompt_phrase_prefix") + prompt += zeroshot_data.get("text") + "\n" + translated_text.get("prompt_analyse_text") + + return prompt + + def get_none_class(self, language: str): + data = self.const_prompt_data.get(language) + return data.get("prompt_none_name", "Nenhuma") \ No newline at end of file diff --git a/bothub/api/v2/zeroshot/views.py b/bothub/api/v2/zeroshot/views.py index 4a497b52..154b06c4 100644 --- a/bothub/api/v2/zeroshot/views.py +++ b/bothub/api/v2/zeroshot/views.py @@ -16,6 +16,8 @@ ZeroshotLogs ) +from .usecases.format_prompt import FormatPrompt + from bothub.api.v2.zeroshot.permissions import ZeroshotTokenPermission logger = logging.getLogger(__name__) @@ -83,20 +85,24 @@ class ZeroShotFastPredictAPIView(APIView): def post(self, request): data = request.data - - classes = {} - - for categorie in data.get("categories"): - option = categorie.get("option") - classes[option] = [option] - for synonym in categorie.get("synonyms"): - classes[option].append(synonym) + formatter = FormatPrompt() + prompt = formatter.generate_prompt(data.get("language"), data) body = { "input": { "text": data.get("text"), "language": data.get("language"), - "classes": classes + "prompt": prompt, + "sampling_params": { + "max_tokens": settings.ZEROSHOT_MAX_TOKENS, + "n": settings.ZEROSHOT_N, + "top_p": settings.ZEROSHOT_TOP_P, + "tok_k": settings.ZEROSHOT_TOK_K, + "temperature": settings.ZEROSHOT_TEMPERATURE, + "do_sample": settings.ZEROSHOT_DO_SAMPLE, + "stop": settings.ZEROSHOT_STOP + } + } } @@ -114,17 +120,32 @@ def post(self, request): url=url, json=body ) + + response = {} + other = False + classification = None + if response_nlp.status_code == 200: classification_data = response_nlp.json().get("output") - ZeroshotLogs.objects.create( - text=data.get("text"), - classification=classification_data.get("classification"), - other=classification_data.get("other"), - categories=classes, - nlp_log=json.dumps(response_nlp.json()), - language=data.get("language"), - ) - return Response(status=response_nlp.status_code, data=response_nlp.json() if response_nlp.status_code == 200 else {"error": response_nlp.text}) + classification = classification_data.get("text")[0].strip() + other = formatter.get_none_class(language=data.get("language")) in classification + response = { + "output": { + "classification": classification, + "other": other + } + } + + ZeroshotLogs.objects.create( + text=data.get("text"), + classification=classification, + other=other, + options=data.get("options"), + nlp_log=str(response_nlp.json()), + language=data.get("language") + ) + + return Response(status=response_nlp.status_code, data=response if response_nlp.status_code == 200 else {"error": response_nlp.text}) except Exception as error: logger.error(f"[ - ] Zeroshot fast predict: {error}") return Response(status=response_nlp.status_code if response_nlp else 500, data={"error": error}) diff --git a/bothub/settings.py b/bothub/settings.py index cb5a0b9a..d046df29 100644 --- a/bothub/settings.py +++ b/bothub/settings.py @@ -123,7 +123,14 @@ ZEROSHOT_BASE_NLP_URL=(str, ""), FLOWS_TOKEN_ZEROSHOT=(str, ""), ZEROSHOT_SUFFIX=(str, ""), - ZEROSHOT_TOKEN=(str, "") + ZEROSHOT_TOKEN=(str, ""), + ZEROSHOT_MAX_TOKENS = (int, 20), + ZEROSHOT_N = (int, 1), + ZEROSHOT_TOP_P = (float, 0.95), + ZEROSHOT_TOK_K = (int, 10), + ZEROSHOT_TEMPERATURE = (float, 0.1), + ZEROSHOT_DO_SAMPLE = (bool, False), + ZEROSHOT_STOP = (str, "\n"), ) # Build paths inside the project like this: os.path.join(BASE_DIR, ...) @@ -710,3 +717,10 @@ FLOWS_TOKEN_ZEROSHOT = env.str("FLOWS_TOKEN_ZEROSHOT") ZEROSHOT_SUFFIX = env.str("ZEROSHOT_SUFFIX") ZEROSHOT_TOKEN = env.str("ZEROSHOT_TOKEN") +ZEROSHOT_MAX_TOKENS = env.int("ZEROSHOT_MAX_TOKENS") +ZEROSHOT_N = env.int("ZEROSHOT_N") +ZEROSHOT_TOP_P = env.float("ZEROSHOT_TOP_P") +ZEROSHOT_TOK_K = env.int("ZEROSHOT_TOK_K") +ZEROSHOT_TEMPERATURE = env.float("ZEROSHOT_TEMPERATURE") +ZEROSHOT_DO_SAMPLE = env.bool("ZEROSHOT_DO_SAMPLE") +ZEROSHOT_STOP = env.str("ZEROSHOT_STOP")