diff --git a/.flake8 b/.flake8 index cf360dcde2f..f0c26726e47 100644 --- a/.flake8 +++ b/.flake8 @@ -141,7 +141,6 @@ exclude = mindsdb/integrations/handlers/ingres_handler/* mindsdb/integrations/handlers/palm_handler/* mindsdb/integrations/handlers/binance_handler/* - mindsdb/integrations/handlers/openai_handler/* mindsdb/integrations/handlers/mediawiki_handler/* mindsdb/integrations/handlers/mendeley_handler/* mindsdb/integrations/handlers/databend_handler/* diff --git a/.github/workflows/test_on_push.yml b/.github/workflows/test_on_push.yml index d88aa7edde5..7d3d37e34c0 100644 --- a/.github/workflows/test_on_push.yml +++ b/.github/workflows/test_on_push.yml @@ -3,7 +3,7 @@ name: Test on Push on: pull_request: branches: [main] - + defaults: run: shell: bash @@ -14,7 +14,6 @@ concurrency: group: ${{ github.workflow_ref }} cancel-in-progress: true - jobs: # Run all of our static code checks here code_checking: @@ -27,8 +26,8 @@ jobs: with: python-version: ${{ vars.CI_PYTHON_VERSION }} cache: pip - cache-dependency-path: '**/requirements*.txt' - + cache-dependency-path: "**/requirements*.txt" + # Checks the codebase for print() statements and fails if any are found # We should be using loggers instead - name: Check for print statements @@ -46,7 +45,7 @@ jobs: uses: pre-commit/action@v3.0.0 with: extra_args: --files ${{ steps.changed-files.outputs.all_changed_files }} - + # Runs a few different checks against our many requirements files # to make sure they're in order - name: Check requirements files @@ -68,8 +67,8 @@ jobs: - id: set-matrix uses: JoshuaTheMiller/conditional-build-matrix@v2.0.1 with: - filter: '[?runOnBranch==`${{ github.ref }}` || runOnBranch==`always`]' - + filter: "[?runOnBranch==`${{ github.ref }}` || runOnBranch==`always`]" + # Check that our pip package is able to be installed in all of our supported environments check_install: name: Check pip installation @@ -78,23 +77,23 @@ jobs: matrix: ${{fromJson(needs.matrix_prep.outputs.matrix)}} runs-on: ${{ matrix.runs_on }} steps: - - uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v5.1.0 - with: - python-version: ${{ matrix.python-version }} - cache: pip - cache-dependency-path: '**/requirements*.txt' - - name: Check requirements files are installable - run: | - # Install dev requirements and build our pip package - pip install -r requirements/requirements-dev.txt - python setup.py sdist + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5.1.0 + with: + python-version: ${{ matrix.python-version }} + cache: pip + cache-dependency-path: "**/requirements*.txt" + - name: Check requirements files are installable + run: | + # Install dev requirements and build our pip package + pip install -r requirements/requirements-dev.txt + python setup.py sdist - # Install from the pip package - # If we install from source, we don't know if the pip package is installable. - cd dist - pip install --ignore-installed *.tar.gz + # Install from the pip package + # If we install from source, we don't know if the pip package is installable. + cd dist + pip install --ignore-installed *.tar.gz unit_tests: name: Run Unit Tests @@ -104,43 +103,44 @@ jobs: runs-on: ${{ matrix.runs_on }} if: github.ref_type == 'branch' steps: - - uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v5.1.0 - with: - python-version: ${{ matrix.python-version }} - cache: pip - cache-dependency-path: '**/requirements*.txt' - - name: Install dependencies - run: | - pip install . 
- pip install -r requirements/requirements-test.txt - pip install .[lightwood] # TODO: for now some tests rely on lightwood - pip install .[mssql] - pip install .[clickhouse] - pip install .[snowflake] - pip install .[web] - pip freeze - - name: Run unit tests - run: | - if [ "$RUNNER_OS" == "Linux" ]; then - env PYTHONPATH=./ pytest tests/unit/test_executor.py - env PYTHONPATH=./ pytest tests/unit/test_project_structure.py - env PYTHONPATH=./ pytest tests/unit/test_predictor_params.py - env PYTHONPATH=./ pytest tests/unit/test_mongodb_handler.py - env PYTHONPATH=./ pytest tests/unit/test_mongodb_server.py - env PYTHONPATH=./ pytest tests/unit/test_cache.py - env PYTHONPATH=./ pytest tests/unit/test_llm_utils.py - env PYTHONPATH=./ pytest tests/unit/ml_handlers/test_mindsdb_inference.py - fi - - name: Run Handlers tests and submit Coverage to coveralls - run: | - handlers=("mysql" "postgres" "mssql" "clickhouse" "snowflake" "web") - for handler in "${handlers[@]}" - do - pytest --cov=mindsdb/integrations/handlers/${handler}_handler tests/unit/handlers/test_${handler}.py - done - coveralls --service=github --basedir=mindsdb/integrations/handlers - env: - COVERALLS_REPO_TOKEN: ${{ secrets.COVERALLS_REPO_TOKEN }} - github_token: ${{ secrets.REPO_DISPATCH_PAT_TOKEN }} \ No newline at end of file + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5.1.0 + with: + python-version: ${{ matrix.python-version }} + cache: pip + cache-dependency-path: "**/requirements*.txt" + - name: Install dependencies + run: | + pip install . + pip install -r requirements/requirements-test.txt + pip install .[lightwood] # TODO: for now some tests rely on lightwood + pip install .[mssql] + pip install .[clickhouse] + pip install .[snowflake] + pip install .[web] + pip freeze + - name: Run unit tests + run: | + if [ "$RUNNER_OS" == "Linux" ]; then + env PYTHONPATH=./ pytest tests/unit/test_executor.py + env PYTHONPATH=./ pytest tests/unit/test_project_structure.py + env PYTHONPATH=./ pytest tests/unit/test_predictor_params.py + env PYTHONPATH=./ pytest tests/unit/test_mongodb_handler.py + env PYTHONPATH=./ pytest tests/unit/test_mongodb_server.py + env PYTHONPATH=./ pytest tests/unit/test_cache.py + env PYTHONPATH=./ pytest tests/unit/test_llm_utils.py + env PYTHONPATH=./ pytest tests/unit/ml_handlers/test_mindsdb_inference.py + env PYTHONPATH=./ pytest tests/unit/ml_handlers/test_openai.py + fi + - name: Run Handlers tests and submit Coverage to coveralls + run: | + handlers=("mysql" "postgres" "mssql" "clickhouse" "snowflake" "web") + for handler in "${handlers[@]}" + do + pytest --cov=mindsdb/integrations/handlers/${handler}_handler tests/unit/handlers/test_${handler}.py + done + coveralls --service=github --basedir=mindsdb/integrations/handlers + env: + COVERALLS_REPO_TOKEN: ${{ secrets.COVERALLS_REPO_TOKEN }} + github_token: ${{ secrets.REPO_DISPATCH_PAT_TOKEN }} diff --git a/mindsdb/integrations/handlers/openai_handler/helpers.py b/mindsdb/integrations/handlers/openai_handler/helpers.py index a7f19fcc5c3..23c3fc1005a 100644 --- a/mindsdb/integrations/handlers/openai_handler/helpers.py +++ b/mindsdb/integrations/handlers/openai_handler/helpers.py @@ -1,5 +1,4 @@ -import os -from typing import List +from typing import Text, List, Dict import random import time import math @@ -10,11 +9,14 @@ import tiktoken import mindsdb.utilities.profiler as profiler -from mindsdb.integrations.handlers.openai_handler.constants import OPENAI_API_BASE class PendingFT(openai.OpenAIError): + """ + 
Custom exception to handle pending fine-tuning status.
+    """
     message: str
+
     def __init__(self, message) -> None:
         super().__init__()
         self.message = message
@@ -44,6 +46,18 @@ def _retry_with_exponential_backoff(func):

         Slight changes in the implementation, but originally from:
         https://github.com/openai/openai-cookbook/blob/main/examples/How_to_handle_rate_limits.ipynb
+
+        Args:
+            func: Function to be wrapped
+            initial_delay: Initial delay in seconds
+            hour_budget: Hourly budget in seconds
+            jitter: Adds randomness to the delay
+            exponential_base: Base for the exponential backoff
+            wait_errors: Tuple of errors to retry on
+            status_errors: Tuple of status errors to raise
+
+        Returns:
+            Wrapper function with exponential backoff
         """  # noqa

         def wrapper(*args, **kwargs):
@@ -94,10 +108,22 @@ def wrapper(*args, **kwargs):
     return _retry_with_exponential_backoff


-def truncate_msgs_for_token_limit(messages, model_name, max_tokens, truncate='first'):
+def truncate_msgs_for_token_limit(messages: List[Dict], model_name: Text, max_tokens: int, truncate: Text = 'first'):
     """
     Truncates message list to fit within the token limit.
-    Note: first message for chat completion models are general directives with the system role, which will ideally be kept at all times.
+    The first message for chat completion models is a general directive with the system role, which should ideally be kept at all times.
+
+    Slight changes in the implementation, but originally from:
+    https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+
+    Args:
+        messages (List[Dict]): List of messages
+        model_name (Text): Model name
+        max_tokens (int): Maximum token limit
+        truncate (Text): Truncate strategy, either 'first' or 'last'
+
+    Returns:
+        List[Dict]: Truncated message list
     """  # noqa
     encoder = tiktoken.encoding_for_model(model_name)
     sys_priming = messages[0:1]
@@ -119,8 +145,15 @@ def truncate_msgs_for_token_limit(messages, model_name, max_tokens, truncate='fi
     return messages


-def count_tokens(messages, encoder, model_name='gpt-3.5-turbo-0301'):
-    """Original token count implementation can be found in the OpenAI cookbook."""
+def count_tokens(messages: List[Dict], encoder: tiktoken.core.Encoding, model_name: Text = 'gpt-3.5-turbo-0301'):
+    """
+    Counts the number of tokens in a list of messages.
+
+    Args:
+        messages: List of messages
+        encoder: Tokenizer
+        model_name: Model name
+    """
     if (
         "gpt-3.5-turbo" in model_name
     ):  # note: future models may deviate from this (only 0301 really complies)
@@ -141,9 +174,16 @@ def count_tokens(messages, encoder, model_name='gpt-3.5-turbo-0301'):
     )


-def get_available_models(api_key: str, api_base: str) -> List[str]:
+def get_available_models(api_key: Text, api_base: Text) -> List[Text]:
     """
     Returns a list of available openai models for the given API key.
+ + Args: + api_key (Text): OpenAI API key + api_base (Text): OpenAI API base URL + + Returns: + List[Text]: List of available models """ res = OpenAI(api_key=api_key, base_url=api_base).models.list() diff --git a/mindsdb/integrations/handlers/openai_handler/openai_handler.py b/mindsdb/integrations/handlers/openai_handler/openai_handler.py index f08da97972c..35167caf3f4 100644 --- a/mindsdb/integrations/handlers/openai_handler/openai_handler.py +++ b/mindsdb/integrations/handlers/openai_handler/openai_handler.py @@ -7,7 +7,7 @@ import textwrap import subprocess import concurrent.futures -from typing import Optional, Dict +from typing import Text, Tuple, Dict, List, Optional, Any import openai from openai import OpenAI, NotFoundError, AuthenticationError import numpy as np @@ -35,6 +35,10 @@ class OpenAIHandler(BaseMLEngine): + """ + This handler handles connection and inference with the OpenAI API. + """ + name = 'openai' def __init__(self, *args, **kwargs): @@ -56,14 +60,24 @@ def __init__(self, *args, **kwargs): self.max_batch_size = 20 self.default_max_tokens = 100 self.chat_completion_models = CHAT_MODELS - self.supported_ft_models = FINETUNING_MODELS # base models compatible with finetuning + self.supported_ft_models = FINETUNING_MODELS # base models compatible with finetuning # For now this are only used for handlers that inherits OpenAIHandler and don't need to override base methods self.api_key_name = getattr(self, 'api_key_name', self.name) self.api_base = getattr(self, 'api_base', OPENAI_API_BASE) - def create_engine(self, connection_args): - '''check api key if provided - ''' + def create_engine(self, connection_args: Dict) -> None: + """ + Validate the OpenAI API credentials on engine creation. + + Args: + connection_args (Dict): Parameters for the engine. + + Raises: + Exception: If the handler is not configured with valid API credentials. + + Returns: + None + """ connection_args = {k.lower(): v for k, v in connection_args.items()} api_key = connection_args.get('openai_api_key') if api_key is not None: @@ -73,15 +87,19 @@ def create_engine(self, connection_args): OpenAIHandler._check_client_connection(client) @staticmethod - def _check_client_connection(client: OpenAI): - '''try to connect to api + def _check_client_connection(client: OpenAI) -> None: + """ + Check the OpenAI engine client connection by retrieving a model. Args: - client (OpenAI): + client (openai.OpenAI): OpenAI client configured with the API credentials. Raises: - Exception: if there is AuthenticationError - ''' + Exception: If the client connection (API key) is invalid or something else goes wrong. + + Returns: + None + """ try: client.models.retrieve('test') except NotFoundError: @@ -92,7 +110,21 @@ def _check_client_connection(client: OpenAI): raise Exception(f'Something went wrong: {e}') @staticmethod - def create_validation(target, args=None, **kwargs): + def create_validation(target: Text, args: Dict = None, **kwargs: Any) -> None: + """ + Validate the OpenAI API credentials on model creation. + + Args: + target (Text): Target column name. + args (Dict): Parameters for the model. + kwargs (Any): Other keyword arguments. + + Raises: + Exception: If the handler is not configured with valid API credentials. + + Returns: + None + """ if 'using' not in args: raise Exception( "OpenAI engine requires a USING clause! Refer to its documentation for more details." 
@@ -173,7 +205,21 @@ def create_validation(target, args=None, **kwargs): client = OpenAIHandler._get_client(api_key=api_key, base_url=api_base, org=org) OpenAIHandler._check_client_connection(client) - def create(self, target, args=None, **kwargs): + def create(self, target, args: Dict = None, **kwargs: Any) -> None: + """ + Create a model by connecting to the OpenAI API. + + Args: + target (Text): Target column name. + args (Dict): Parameters for the model. + kwargs (Any): Other keyword arguments. + + Raises: + Exception: If the model is not configured with valid parameters. + + Returns: + None + """ args = args['using'] args['target'] = target try: @@ -201,7 +247,17 @@ def create(self, target, args=None, **kwargs): def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame: """ - If there is a prompt template, we use it. Otherwise, we use the concatenation of `context_column` (optional) and `question_column` to ask for a completion. + Make predictions using a model connected to the OpenAI API. + + Args: + df (pd.DataFrame): Input data to make predictions on. + args (Dict): Parameters passed when making predictions. + + Raises: + Exception: If the model is not configured with valid parameters or if the input data is not in the expected format. + + Returns: + pd.DataFrame: Input data with the predicted values in a new column. """ # noqa # TODO: support for edits, embeddings and moderation @@ -209,11 +265,11 @@ def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame args = self.model_storage.json_get('args') connection_args = self.engine_storage.get_connection_args() - args['api_base'] = (pred_args.get('api_base') or - self.api_base or - connection_args.get('api_base') or - args.get('api_base') or - os.environ.get('OPENAI_API_BASE', OPENAI_API_BASE)) + args['api_base'] = (pred_args.get('api_base') + or self.api_base + or connection_args.get('api_base') + or args.get('api_base') + or os.environ.get('OPENAI_API_BASE', OPENAI_API_BASE)) if pred_args.get('api_organization'): args['api_organization'] = pred_args['api_organization'] df = df.reset_index(drop=True) @@ -237,7 +293,7 @@ def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame else: base_template = None - # Embedding Mode + # Embedding mode if args.get('mode', self.default_mode) == 'embedding': api_args = { 'question_column': pred_args.get('question_column', None), @@ -253,7 +309,6 @@ def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame raise Exception('Embedding mode needs a question_column') # Image mode - elif args.get('mode', self.default_mode) == 'image': api_args = { 'n': pred_args.get('n', None), @@ -295,7 +350,7 @@ def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame f"This model expects context in the '{args['context_column']}' column." ) - # api argument validation + # API argument validation model_name = args.get('model_name', self.default_model) api_args = { 'max_tokens': pred_args.get( @@ -424,10 +479,18 @@ def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame return pred_df def _completion( - self, model_name, prompts, api_key, api_args, args, df, parallel=True - ): + self, model_name: Text, prompts: List[Text], api_key: Text, api_args: Dict, args: Dict, df: pd.DataFrame, parallel: bool = True + ) -> List[Any]: """ - Handles completion for an arbitrary amount of rows. + Handles completion for an arbitrary amount of rows using a model connected to the OpenAI API. 
+ + This method consists of several inner methods: + - _submit_completion: Submit a request to the relevant completion endpoint of the OpenAI API based on the type of task. + - _submit_normal_completion: Submit a request to the completion endpoint of the OpenAI API. + - _submit_embedding_completion: Submit a request to the embeddings endpoint of the OpenAI API. + - _submit_chat_completion: Submit a request to the chat completion endpoint of the OpenAI API. + - _submit_image_completion: Submit a request to the image completion endpoint of the OpenAI API. + - _log_api_call: Log the API call made to the OpenAI API. There are a couple checks that should be done when calling OpenAI's API: - account max batch size, to maximize batch size first @@ -435,10 +498,35 @@ def _completion( Additionally, single completion calls are done with exponential backoff to guarantee all prompts are processed, because even with previous checks the tokens-per-minute limit may apply. + + Args: + model_name (Text): OpenAI Model name. + prompts (List[Text]): List of prompts. + api_key (Text): OpenAI API key. + api_args (Dict): OpenAI API arguments. + args (Dict): Parameters for the model. + df (pd.DataFrame): Input data to run completion on. + parallel (bool): Whether to use parallel processing. + + Returns: + List[Any]: List of completions. The type of completion depends on the task type. """ @retry_with_exponential_backoff() - def _submit_completion(model_name, prompts, api_args, args, df): + def _submit_completion(model_name: Text, prompts: List[Text], api_args: Dict, args: Dict, df: pd.DataFrame) -> List[Text]: + """ + Submit a request to the relevant completion endpoint of the OpenAI API based on the type of task. + + Args: + model_name (Text): OpenAI Model name. + prompts (List[Text]): List of prompts. + api_args (Dict): OpenAI API arguments. + args (Dict): Parameters for the model. + df (pd.DataFrame): Input data to run completion on. + + Returns: + List[Text]: List of completions. + """ kwargs = { 'model': model_name, } @@ -457,7 +545,17 @@ def _submit_completion(model_name, prompts, api_args, args, df): else: return _submit_normal_completion(kwargs, prompts, api_args) - def _log_api_call(params, response): + def _log_api_call(params: Dict, response: Any) -> None: + """ + Log the API call made to the OpenAI API. + + Args: + params (Dict): Parameters for the API call. + response (Any): Response from the API. + + Returns: + None + """ after_openai_query(params, response) params2 = params.copy() @@ -465,11 +563,35 @@ def _log_api_call(params, response): params2.pop('user', None) logger.debug(f'>>>openai call: {params2}:\n{response}') - def _submit_normal_completion(kwargs, prompts, api_args): - def _tidy(comp): + def _submit_normal_completion(kwargs: Dict, prompts: List[Text], api_args: Dict) -> List[Text]: + """ + Submit a request to the completion endpoint of the OpenAI API. + + This method consists of an inner method: + - _tidy: Parse and tidy up the response from the completion endpoint of the OpenAI API. + + Args: + kwargs (Dict): OpenAI API arguments, including the model to use. + prompts (List[Text]): List of prompts. + api_args (Dict): Other OpenAI API arguments. + + Returns: + List[Text]: List of text completions. + """ + + def _tidy(comp: openai.types.completion.Completion) -> List[Text]: + """ + Parse and tidy up the response from the completion endpoint of the OpenAI API. + + Args: + comp (openai.types.completion.Completion): Completion object. + + Returns: + List[Text]: List of completions as text. 
+ """ tidy_comps = [] for c in comp.choices: - if hasattr(c,'text'): + if hasattr(c, 'text'): tidy_comps.append(c.text.strip('\n').strip('')) return tidy_comps @@ -481,11 +603,35 @@ def _tidy(comp): _log_api_call(kwargs, resp) return resp - def _submit_embedding_completion(kwargs, prompts, api_args): - def _tidy(comp): + def _submit_embedding_completion(kwargs: Dict, prompts: List[Text], api_args: Dict) -> List[float]: + """ + Submit a request to the embeddings endpoint of the OpenAI API. + + This method consists of an inner method: + - _tidy: Parse and tidy up the response from the embeddings endpoint of the OpenAI API. + + Args: + kwargs (Dict): OpenAI API arguments, including the model to use. + prompts (List[Text]): List of prompts. + api_args (Dict): Other OpenAI API arguments. + + Returns: + List[float]: List of embeddings as numbers. + """ + + def _tidy(comp: openai.types.create_embedding_response.CreateEmbeddingResponse) -> List[float]: + """ + Parse and tidy up the response from the embeddings endpoint of the OpenAI API. + + Args: + comp (openai.types.create_embedding_response.CreateEmbeddingResponse): Embedding object. + + Returns: + List[float]: List of embeddings as numbers. + """ tidy_comps = [] for c in comp.data: - if hasattr(c,'embedding'): + if hasattr(c, 'embedding'): tidy_comps.append([c.embedding]) return tidy_comps @@ -497,13 +643,37 @@ def _tidy(comp): _log_api_call(kwargs, resp) return resp - def _submit_chat_completion( - kwargs, prompts, api_args, df, mode='conversational' - ): - def _tidy(comp): + def _submit_chat_completion(kwargs: Dict, prompts: List[Text], api_args: Dict, df: pd.DataFrame, mode: Text = 'conversational') -> List[Text]: + """ + Submit a request to the chat completion endpoint of the OpenAI API. + + This method consists of an inner method: + - _tidy: Parse and tidy up the response from the chat completion endpoint of the OpenAI API. + + Args: + kwargs (Dict): OpenAI API arguments, including the model to use. + prompts (List[Text]): List of prompts. + api_args (Dict): Other OpenAI API arguments. + df (pd.DataFrame): Input data to run chat completion on. + mode (Text): Mode of operation. + + Returns: + List[Text]: List of chat completions as text. + """ + + def _tidy(comp: openai.types.chat.chat_completion.ChatCompletion) -> List[Text]: + """ + Parse and tidy up the response from the chat completion endpoint of the OpenAI API. + + Args: + comp (openai.types.chat.chat_completion.ChatCompletion): Chat completion object. + + Returns: + List[Text]: List of chat completions as text. + """ tidy_comps = [] for c in comp.choices: - if hasattr(c,'message'): + if hasattr(c, 'message'): tidy_comps.append(c.message.content.strip('\n').strip('')) return tidy_comps @@ -575,10 +745,37 @@ def _tidy(comp): return completions - def _submit_image_completion(kwargs, prompts, api_args): - def _tidy(comp): + def _submit_image_completion(kwargs: Dict, prompts: List[Text], api_args: Dict) -> List[Text]: + """ + Submit a request to the image generation endpoint of the OpenAI API. + + This method consists of an inner method: + - _tidy: Parse and tidy up the response from the image generation endpoint of the OpenAI API. + + Args: + kwargs (Dict): OpenAI API arguments, including the model to use. + prompts (List[Text]): List of prompts. + api_args (Dict): Other OpenAI API arguments. + + Raises: + Exception: If the maximum batch size is reached. + + Returns: + List[Text]: List of image completions as URLs or base64 encoded images. 
+ """ + + def _tidy(comp: List[openai.types.image.Image]) -> List[Text]: + """ + Parse and tidy up the response from the image generation endpoint of the OpenAI API. + + Args: + comp (List[openai.types.image.Image]): Image completion objects. + + Returns: + List[Text]: List of image completions as URLs or base64 encoded images. + """ return [ - c.url if hasattr(c,'url') else c.b64_json + c.url if hasattr(c, 'url') else c.b64_json for c in comp ] @@ -587,11 +784,13 @@ def _tidy(comp): for p in prompts ] return _tidy(completions) + client = self._get_client( api_key=api_key, base_url=args.get('api_base'), org=args.pop('api_organization') if 'api_organization' in args else None, - ) + ) + try: # check if simple completion works completion = _submit_completion( @@ -602,7 +801,7 @@ def _tidy(comp): # else, we get the max batch size if 'you can currently request up to at most a total of' in str(e): pattern = 'a total of' - max_batch_size = int(e[e.find(pattern) + len(pattern) :].split(').')[0]) + max_batch_size = int(e[e.find(pattern) + len(pattern):].split(').')[0]) else: max_batch_size = ( self.max_batch_size @@ -613,7 +812,7 @@ def _tidy(comp): for i in range(math.ceil(len(prompts) / max_batch_size)): partial = _submit_completion( model_name, - prompts[i * max_batch_size : (i + 1) * max_batch_size], + prompts[i * max_batch_size: (i + 1) * max_batch_size], api_args, args, df, @@ -634,7 +833,7 @@ def _tidy(comp): future = executor.submit( _submit_completion, model_name, - prompts[i * max_batch_size : (i + 1) * max_batch_size], + prompts[i * max_batch_size: (i + 1) * max_batch_size], api_args, args, df, @@ -649,7 +848,16 @@ def _tidy(comp): return completion - def describe(self, attribute: Optional[str] = None) -> pd.DataFrame: + def describe(self, attribute: Optional[Text] = None) -> pd.DataFrame: + """ + Get the metadata or arguments of a model. + + Args: + attribute (Optional[Text]): Attribute to describe. Can be 'args' or 'metadata'. + + Returns: + pd.DataFrame: Model metadata or model arguments. + """ # TODO: Update to use update() artifacts args = self.model_storage.json_get('args') @@ -659,7 +867,7 @@ def describe(self, attribute: Optional[str] = None) -> pd.DataFrame: elif attribute == 'metadata': model_name = args.get('model_name', self.default_model) try: - client= self._get_client( + client = self._get_client( api_key=api_key, base_url=args.get('api_base'), org=args.get('api_organization') @@ -672,11 +880,10 @@ def describe(self, attribute: Optional[str] = None) -> pd.DataFrame: tables = ['args', 'metadata'] return pd.DataFrame(tables, columns=['tables']) - def finetune( - self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None - ) -> None: + def finetune(self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None: """ - Fine-tune OpenAI GPT models. Steps are roughly: + Fine-tune OpenAI GPT models via a MindsDB model connected to the OpenAI API. + Steps are roughly: - Analyze input data and modify it according to suggestions made by the OpenAI utility tool - Get a training and validation file - Determine base model to use @@ -688,19 +895,27 @@ def finetune( Caveats: - As base fine-tuning models, OpenAI only supports the original GPT ones: `ada`, `babbage`, `curie`, `davinci`. This means if you fine-tune successively more than once, any fine-tuning other than the most recent one is lost. - A bunch of helper methods exist to be overridden in other handlers that follow the OpenAI API, e.g. 
Anyscale - """ # noqa + Args: + df (Optional[pd.DataFrame]): Input data to fine-tune on. + args (Optional[Dict]): Parameters for the fine-tuning process. + + Raises: + Exception: If the model does not support fine-tuning. + + Returns: + None + """ # noqa args = args if args else {} api_key = get_api_key(self.api_key_name, args, self.engine_storage) using_args = args.pop('using') if 'using' in args else {} - + api_base = using_args.get('api_base', os.environ.get('OPENAI_API_BASE', OPENAI_API_BASE)) org = using_args.get('api_organization') client = self._get_client(api_key=api_key, base_url=api_base, org=org) - args = {**using_args, **args} prev_model_name = self.base_model_storage.json_get('args').get('model_name', '') @@ -727,8 +942,10 @@ def finetune( jsons = {k: None for k in file_names.keys()} for split, file_name in file_names.items(): if os.path.isfile(os.path.join(temp_storage_path, file_name)): - jsons[split] = client.files.create(file=open(f"{temp_storage_path}/{file_name}", "rb"), - purpose='fine-tune') + jsons[split] = client.files.create( + file=open(f"{temp_storage_path}/{file_name}", "rb"), + purpose='fine-tune' + ) if type(jsons['train']) is openai.types.FileObject: train_file_id = jsons['train'].id @@ -789,7 +1006,18 @@ def finetune( shutil.rmtree(temp_storage_path) @staticmethod - def _prepare_ft_jsonl(df, _, temp_filename, temp_model_path): + def _prepare_ft_jsonl(df: pd.DataFrame, _, temp_filename: Text, temp_model_path: Text) -> Dict: + """ + Prepare the input data for fine-tuning. + + Args: + df (pd.DataFrame): Input data to fine-tune on. + temp_filename (Text): Temporary filename. + temp_model_path (Text): Temporary model path. + + Returns: + Dict: File names for the fine-tuning process. + """ df.to_json(temp_model_path, orient='records', lines=True) # TODO avoid subprocess usage once OpenAI enables non-CLI access, or refactor to use our own LLM utils instead @@ -815,14 +1043,33 @@ def _prepare_ft_jsonl(df, _, temp_filename, temp_model_path): } return file_names - def _get_ft_model_type(self, model_name: str): + def _get_ft_model_type(self, model_name: Text) -> Text: + """ + Get the model to use for fine-tuning. If the model is not supported, the default model (babbage-002) is used. + + Args: + model_name (Text): Model name. + + Returns: + Text: Model to use for fine-tuning. + """ for model_type in self.supported_ft_models: if model_type in model_name.lower(): return model_type return 'babbage-002' @staticmethod - def _add_extra_ft_params(ft_params, using_args): + def _add_extra_ft_params(ft_params: Dict, using_args: Dict) -> Dict: + """ + Add extra parameters to the fine-tuning process. + + Args: + ft_params (Dict): Parameters for the fine-tuning process required by the OpenAI API. + using_args (Dict): Parameters passed when calling the fine-tuning process via a model. + + Returns: + Dict: Fine-tuning parameters with extra parameters. + """ extra_params = { 'n_epochs': using_args.get('n_epochs', None), 'batch_size': using_args.get('batch_size', None), @@ -843,11 +1090,25 @@ def _add_extra_ft_params(ft_params, using_args): } return {**ft_params, **extra_params} - def _ft_call(self, ft_params, client, hour_budget): + def _ft_call(self, ft_params: Dict, client: OpenAI, hour_budget: int) -> Tuple[openai.types.fine_tuning.FineTuningJob, Text]: """ - Separate method to account for both legacy and new endpoints. - Currently, `OpenAIHandler` uses the legacy endpoint. - Others, like `AnyscaleEndpointsHandler`, use the new endpoint. 
+ Submit a fine-tuning job via the OpenAI API. + This method handles requests to both the legacy and new endpoints. + Currently, `OpenAIHandler` uses the legacy endpoint. Others, like `AnyscaleEndpointsHandler`, use the new endpoint. + + This method consists of an inner method: + - _check_ft_status: Check the status of a fine-tuning job via the OpenAI API. + + Args: + ft_params (Dict): Fine-tuning parameters. + client (openai.OpenAI): OpenAI client. + hour_budget (int): Hour budget for the fine-tuning process. + + Raises: + PendingFT: If the fine-tuning process is still pending. + + Returns: + Tuple[openai.types.fine_tuning.FineTuningJob, Text]: Fine-tuning stats and result file ID. """ ft_result = client.fine_tuning.jobs.create( **{k: v for k, v in ft_params.items() if v is not None} @@ -856,8 +1117,20 @@ def _ft_call(self, ft_params, client, hour_budget): @retry_with_exponential_backoff( hour_budget=hour_budget, ) - def _check_ft_status(model_id): - ft_retrieved = client.fine_tuning.jobs.retrieve(fine_tuning_job_id=model_id) + def _check_ft_status(job_id: Text) -> openai.types.fine_tuning.FineTuningJob: + """ + Check the status of a fine-tuning job via the OpenAI API. + + Args: + job_id (Text): Fine-tuning job ID. + + Raises: + PendingFT: If the fine-tuning process is still pending. + + Returns: + openai.types.fine_tuning.FineTuningJob: Fine-tuning stats. + """ + ft_retrieved = client.fine_tuning.jobs.retrieve(fine_tuning_job_id=job_id) if ft_retrieved.status in ('succeeded', 'failed', 'cancelled'): return ft_retrieved else: @@ -877,7 +1150,18 @@ def _check_ft_status(model_id): result_file_id = result_file_id.id # legacy endpoint return ft_stats, result_file_id - + @staticmethod - def _get_client(api_key, base_url, org=None): + def _get_client(api_key: Text, base_url: Text, org: Optional[Text] = None) -> OpenAI: + """ + Get an OpenAI client with the given API key, base URL, and organization. + + Args: + api_key (Text): OpenAI API key. + base_url (Text): OpenAI base URL. + org (Optional[Text]): OpenAI organization. + + Returns: + openai.OpenAI: OpenAI client. + """ return OpenAI(api_key=api_key, base_url=base_url, organization=org) diff --git a/mindsdb/integrations/handlers/openai_handler/tests/__init__.py b/mindsdb/integrations/handlers/openai_handler/tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/mindsdb/integrations/handlers/openai_handler/tests/test_openai_handler.py b/mindsdb/integrations/handlers/openai_handler/tests/test_openai_handler.py new file mode 100644 index 00000000000..d4c09a68f3e --- /dev/null +++ b/mindsdb/integrations/handlers/openai_handler/tests/test_openai_handler.py @@ -0,0 +1,509 @@ +import os +import pytest +import pandas as pd +from unittest.mock import patch + +from tests.unit.ml_handlers.base_ml_test import BaseMLAPITest + + +@pytest.mark.skipif(os.environ.get('MDB_TEST_MDB_OPENAI_API_KEY') is None, reason='Missing API key!') +class TestOpenAI(BaseMLAPITest): + """ + Integration tests for the OpenAI handler. + """ + + def setup_method(self): + """ + Setup test environment by creating a project and an OpenAI engine. + """ + super().setup_method() + self.run_sql("CREATE DATABASE proj") + self.run_sql( + f""" + CREATE ML_ENGINE openai_engine + FROM openai + USING + openai_api_key = '{self.get_api_key('MDB_TEST_MDB_OPENAI_API_KEY')}'; + """ + ) + + def test_create_model_with_unsupported_model_raises_exception(self): + """ + Test if CREATE MODEL raises an exception with an unsupported model. 
+ """ + self.run_sql( + """ + CREATE MODEL proj.test_openaai_unsupported_model_model + PREDICT answer + USING + engine='openai_engine', + model_name='this-model-does-not-exist', + prompt_template='dummy_prompt_template'; + """ + ) + with pytest.raises(Exception) as excinfo: + self.wait_predictor("proj", "test_openaai_unsupported_model_model") + + assert "Invalid model name." in str(excinfo.value) + + def test_full_flow_in_default_mode_with_question_column_for_single_prediction_runs_no_errors(self): + """ + Test the full flow in default mode with a question column for a single prediction. + """ + self.run_sql( + """ + CREATE MODEL proj.test_openai_single_full_flow_default_mode_question_column + PREDICT answer + USING + engine='openai_engine', + question_column='question'; + """ + ) + + self.wait_predictor("proj", "test_openai_single_full_flow_default_mode_question_column") + + result_df = self.run_sql( + """ + SELECT answer + FROM proj.test_openai_single_full_flow_default_mode_question_column + WHERE question='What is the capital of Sweden?' + """ + ) + assert "stockholm" in result_df["answer"].iloc[0].lower() + + @patch("mindsdb.integrations.handlers.postgres_handler.Handler") + def test_full_flow_in_default_mode_with_question_column_for_bulk_predictions_runs_no_errors(self, mock_handler): + """ + Test the full flow in default mode with a question column for bulk predictions. + """ + df = pd.DataFrame.from_dict({"question": [ + "What is the capital of Sweden?", + "What is the second planet of the solar system?" + ]}) + self.set_handler(mock_handler, name="pg", tables={"df": df}) + + self.run_sql( + """ + CREATE MODEL proj.test_openai_bulk_full_flow_default_mode_question_column + PREDICT answer + USING + engine='openai_engine', + question_column='question'; + """ + ) + + self.wait_predictor("proj", "test_openai_bulk_full_flow_default_mode_question_column") + + result_df = self.run_sql( + """ + SELECT p.answer + FROM pg.df as t + JOIN proj.test_openai_bulk_full_flow_default_mode_question_column as p; + """ + ) + assert "stockholm" in result_df["answer"].iloc[0].lower() + assert "venus" in result_df["answer"].iloc[1].lower() + + def test_full_flow_in_default_mode_with_prompt_template_for_single_prediction_runs_no_errors(self): + """ + Test the full flow in default mode with a prompt template for a single prediction. + """ + self.run_sql( + """ + CREATE MODEL proj.test_openai_single_full_flow_default_mode_prompt_template + PREDICT answer + USING + engine='openai_engine', + prompt_template='Answer this question and add "Boom!" to the end of the answer: {{{{question}}}}'; + """ + ) + + self.wait_predictor("proj", "test_openai_single_full_flow_default_mode_prompt_template") + + result_df = self.run_sql( + """ + SELECT answer + FROM proj.test_openai_single_full_flow_default_mode_prompt_template + WHERE question='What is the capital of Sweden?' + """ + ) + assert "stockholm" in result_df["answer"].iloc[0].lower() + assert "boom!" in result_df["answer"].iloc[0].lower() + + @patch("mindsdb.integrations.handlers.postgres_handler.Handler") + def test_full_flow_in_default_mode_with_prompt_template_for_bulk_predictions_runs_no_errors(self, mock_handler): + """ + Test the full flow in default mode with a prompt template for bulk predictions. + """ + df = pd.DataFrame.from_dict({"question": [ + "What is the capital of Sweden?", + "What is the second planet of the solar system?" 
+ ]}) + self.set_handler(mock_handler, name="pg", tables={"df": df}) + + self.run_sql( + """ + CREATE MODEL proj.test_openai_bulk_full_flow_default_mode_prompt_template + PREDICT answer + USING + engine='openai_engine', + prompt_template='Answer this question and add "Boom!" to the end of the answer: {{{{question}}}}'; + """ + ) + + self.wait_predictor("proj", "test_openai_bulk_full_flow_default_mode_prompt_template") + + result_df = self.run_sql( + """ + SELECT p.answer + FROM pg.df as t + JOIN proj.test_openai_bulk_full_flow_default_mode_prompt_template as p; + """ + ) + assert "stockholm" in result_df["answer"].iloc[0].lower() + assert "boom!" in result_df["answer"].iloc[0].lower() + assert "venus" in result_df["answer"].iloc[1].lower() + assert "boom!" in result_df["answer"].iloc[1].lower() + + def test_full_flow_in_embedding_mode_for_single_prediction_runs_no_errors(self): + """ + Test the full flow in embedding mode for a single prediction. + """ + self.run_sql( + """ + CREATE MODEL proj.test_openai_single_full_flow_embedding_mode + PREDICT answer + USING + engine='openai_engine', + mode='embedding', + model_name = 'text-embedding-ada-002', + question_column = 'text'; + """ + ) + + self.wait_predictor("proj", "test_openai_single_full_flow_embedding_mode") + + result_df = self.run_sql( + """ + SELECT answer + FROM proj.test_openai_single_full_flow_embedding_mode + WHERE text='Sweden' + """ + ) + assert type(result_df["answer"].iloc[0]) == list + assert type(result_df["answer"].iloc[0][0]) == float + + @patch("mindsdb.integrations.handlers.postgres_handler.Handler") + def test_full_flow_in_embedding_mode_for_bulk_predictions_runs_no_errors(self, mock_handler): + """ + Test the full flow in embedding mode for bulk predictions. + """ + df = pd.DataFrame.from_dict({"text": [ + "Sweden", + "Venus" + ]}) + self.set_handler(mock_handler, name="pg", tables={"df": df}) + + self.run_sql( + """ + CREATE MODEL proj.test_openai_bulk_full_flow_embedding_mode + PREDICT answer + USING + engine='openai_engine', + mode='embedding', + model_name = 'text-embedding-ada-002', + question_column = 'text'; + """ + ) + + self.wait_predictor("proj", "test_openai_bulk_full_flow_embedding_mode") + + result_df = self.run_sql( + """ + SELECT p.answer + FROM pg.df as t + JOIN proj.test_openai_bulk_full_flow_embedding_mode as p; + """ + ) + assert type(result_df["answer"].iloc[0]) == list + assert type(result_df["answer"].iloc[0][0]) == float + assert type(result_df["answer"].iloc[1]) == list + assert type(result_df["answer"].iloc[1][0]) == float + + def test_full_flow_in_image_mode_for_single_prediction_runs_no_errors(self): + """ + Test the full flow in image mode for a single prediction. + """ + self.run_sql( + """ + CREATE MODEL proj.test_openai_single_full_flow_image_mode + PREDICT answer + USING + engine='openai_engine', + mode='image', + prompt_template='Generate an image for: {{{{text}}}}' + """ + ) + + self.wait_predictor("proj", "test_openai_single_full_flow_image_mode") + + result_df = self.run_sql( + """ + SELECT answer + FROM proj.test_openai_single_full_flow_image_mode + WHERE text='Leopard clubs playing in the jungle' + """ + ) + assert type(result_df["answer"].iloc[0]) == str + + @patch("mindsdb.integrations.handlers.postgres_handler.Handler") + def test_full_flow_in_image_mode_for_bulk_predictions_runs_no_errors(self, mock_handler): + """ + Test the full flow in image mode for bulk predictions. 
+ """ + df = pd.DataFrame.from_dict({"text": [ + "Leopard clubs playing in the jungle", + "A beautiful sunset over the ocean" + ]}) + self.set_handler(mock_handler, name="pg", tables={"df": df}) + + self.run_sql( + """ + CREATE MODEL proj.test_openai_bulk_full_flow_image_mode + PREDICT answer + USING + engine='openai_engine', + mode='image', + prompt_template='Generate an image for: {{{{text}}}}' + """ + ) + + self.wait_predictor("proj", "test_openai_bulk_full_flow_image_mode") + + result_df = self.run_sql( + """ + SELECT p.answer + FROM pg.df as t + JOIN proj.test_openai_bulk_full_flow_image_mode as p; + """ + ) + assert type(result_df["answer"].iloc[0]) == str + assert type(result_df["answer"].iloc[1]) == str + + def test_full_flow_in_conversational_for_single_prediction_mode_runs_no_errors(self): + """ + Test the full flow in conversational mode for a single prediction. + """ + self.run_sql( + """ + CREATE MODEL proj.test_openai_single_full_flow_conversational_mode + PREDICT answer + USING + engine='openai_engine', + mode='conversational', + user_column='question', + prompt='you are a helpful assistant', + assistant_column='answer'; + """ + ) + + self.wait_predictor("proj", "test_openai_single_full_flow_conversational_mode") + + result_df = self.run_sql( + """ + SELECT answer + FROM proj.test_openai_single_full_flow_conversational_mode + WHERE question='What is the capital of Sweden?' + """ + ) + assert "stockholm" in result_df["answer"].iloc[0].lower() + + @patch("mindsdb.integrations.handlers.postgres_handler.Handler") + def test_full_flow_in_conversational_mode_for_bulk_predictions_runs_no_errors(self, mock_handler): + """ + Test the full flow in conversational mode for bulk predictions. + """ + df = pd.DataFrame.from_dict({"question": [ + "What is the capital of Sweden?", + "What are some cool places to visit there?" + ]}) + self.set_handler(mock_handler, name="pg", tables={"df": df}) + + self.run_sql( + """ + CREATE MODEL proj.test_openai_bulk_full_flow_conversational_mode + PREDICT answer + USING + engine='openai_engine', + mode='conversational', + user_column='question', + prompt='you are a helpful assistant', + assistant_column='answer'; + """ + ) + + self.wait_predictor("proj", "test_openai_bulk_full_flow_conversational_mode") + + result_df = self.run_sql( + """ + SELECT p.answer + FROM pg.df as t + JOIN proj.test_openai_bulk_full_flow_conversational_mode as p; + """ + ) + assert result_df["answer"].iloc[0] == "" + assert "gamla stan" in result_df["answer"].iloc[1].lower() + + def test_full_flow_in_conversational_full_mode_for_single_prediction_runs_no_errors(self): + """ + Test the full flow in conversational-full mode for a single prediction. + """ + self.run_sql( + """ + CREATE MODEL proj.test_openai_single_full_flow_conversational_full_mode + PREDICT answer + USING + engine='openai_engine', + mode='conversational-full', + user_column='question', + prompt='you are a helpful assistant', + assistant_column='answer'; + """ + ) + + self.wait_predictor("proj", "test_openai_single_full_flow_conversational_full_mode") + + result_df = self.run_sql( + """ + SELECT answer + FROM proj.test_openai_single_full_flow_conversational_full_mode + WHERE question='What is the capital of Sweden?' 
+ """ + ) + assert "stockholm" in result_df["answer"].iloc[0].lower() + + @patch("mindsdb.integrations.handlers.postgres_handler.Handler") + def test_full_flow_in_conversational_full_mode_for_bulk_predictions_runs_no_errors(self, mock_handler): + """ + Test the full flow in conversational-full mode for bulk predictions. + """ + df = pd.DataFrame.from_dict({"question": [ + "What is the capital of Sweden?", + "What are some cool places to visit there?" + ]}) + self.set_handler(mock_handler, name="pg", tables={"df": df}) + + self.run_sql( + """ + CREATE MODEL proj.test_openai_bulk_full_flow_conversational_full_mode + PREDICT answer + USING + engine='openai_engine', + mode='conversational-full', + user_column='question', + prompt='you are a helpful assistant', + assistant_column='answer'; + """ + ) + + self.wait_predictor("proj", "test_openai_bulk_full_flow_conversational_full_mode") + + result_df = self.run_sql( + """ + SELECT p.answer + FROM pg.df as t + JOIN proj.test_openai_bulk_full_flow_conversational_full_mode as p; + """ + ) + assert "stockholm" in result_df["answer"].iloc[0].lower() + assert "gamla stan" in result_df["answer"].iloc[1].lower() + + # TODO: Fix this test for fine-tuning + # @patch("mindsdb.integrations.handlers.postgres_handler.Handler") + # def test_full_flow_finetune_runs_no_errors(self, mock_handler): + # """ + # Test the full flow for finetuning a model and making a prediction. + # """ + # df = pd.DataFrame.from_dict( + # { + # "prompt": [ + # "What is the SQL syntax to connect a database to MindsDB?", + # "What is the SQL command to connect to the demo postgres database for MindsDB learning hub examples?", + # "What is the SQL syntax to create a MindsDB machine learning model?", + # "What is the SQL syntax to join input data with predictions from a MindsDB machine learning model?" + # ], + # "completion": [ + # """ + # CREATE DATABASE datasource_name + # [WITH] [ENGINE [=] engine_name] [,] + # [PARAMETERS [=] { + # "key": "value", + # ... + # }]; + # """, + # """ + # CREATE DATABASE example_db + # WITH + # ENGINE = "postgres", + # PARAMETERS = { + # "user": "demo_user", + # "password": "demo_password", + # "host": "samples.mindsdb.com", + # "port": "5432", + # "database": "demo" + # }; + # """, + # """ + # CREATE MODEL + # mindsdb.home_rentals_model + # FROM example_db + # (SELECT * FROM demo_data.home_rentals) + # PREDICT rental_price; + # """, + # """ + # SELECT t.column_name, p.column_name, ... + # FROM integration_name.table_name [AS] t + # JOIN project_name.model_name [AS] p; + # """ + # ] + + # } + # ) + # self.set_handler(mock_handler, name="pg", tables={"df": df}) + + # self.run_sql( + # f""" + # CREATE MODEL proj.test_openai_full_flow_finetune + # PREDICT completion + # USING + # engine = 'openai_engine', + # model_name = 'davinci-002', + # prompt_template = 'Return a valid SQL string for the following question about MindsDB in-database machine learning: {{{{prompt}}}}'; + # """ + # ) + + # self.wait_predictor("proj", "test_openai_full_flow_finetune") + + # self.run_sql( + # """ + # FINETUNE proj.test_openai_full_flow_finetune + # FROM pg + # (SELECT prompt, completion FROM df); + # """ + # ) + + # self.wait_predictor("proj", "test_openai_full_flow_finetune", finetune=True) + + # result_df = self.run_sql( + # """ + # SELECT prompt, completion + # FROM proj.test_openai_full_flow_finetune + # WHERE prompt = 'What is the SQL syntax to join input data with predictions from a MindsDB machine learning model?' 
+ # USING max_tokens=400; + # """ + # ) + # assert "SELECT t.column_name, p.column_name, ..." in result_df["completion"].iloc[0].lower() + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/unit/ml_handlers/test_openai.py b/tests/unit/ml_handlers/test_openai.py index 330369a8ade..97ac6ff185f 100644 --- a/tests/unit/ml_handlers/test_openai.py +++ b/tests/unit/ml_handlers/test_openai.py @@ -1,262 +1,674 @@ -import os -import time -import pytest -import pandas as pd -from unittest.mock import patch -from mindsdb_sql import parse_sql +import pandas +import unittest +from collections import OrderedDict +from unittest.mock import patch, MagicMock + from mindsdb.integrations.handlers.openai_handler.openai_handler import OpenAIHandler -from ..executor_test_base import BaseExecutorTest - -OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") - - -@pytest.mark.skipif(OPENAI_API_KEY is None, reason='Missing API key!') -class TestOpenAI(BaseExecutorTest): - """Test Class for OpenAI Integration Testing""" - - @staticmethod - def get_api_key(): - """Retrieve OpenAI API key from environment variables""" - return os.environ.get("OPENAI_API_KEY") - - def setup_method(self, method): - """Setup test environment, creating a project""" - super().setup_method() - self.run_sql("create database proj") - - def wait_predictor(self, project, name, timeout=100): - """ - Wait for the predictor to be created, - raising an exception if predictor creation fails or exceeds timeout - """ - for attempt in range(timeout): - ret = self.run_sql(f"select * from {project}.models where name='{name}'") - if not ret.empty: - status = ret["STATUS"][0] - if status == "complete": - return - elif status == "error": - raise RuntimeError("Predictor failed", ret["ERROR"][0]) - time.sleep(0.5) - raise RuntimeError("Predictor wasn't created") - - def run_sql(self, sql): - """Execute SQL and return a DataFrame, raising an AssertionError if an error occurs""" - ret = self.command_executor.execute_command(parse_sql(sql, dialect="mindsdb")) - assert ret.error_code is None, f"SQL execution failed with error: {ret.error_code}" - if ret.data is not None: - columns = [col.alias if col.alias else col.name for col in ret.columns] - return pd.DataFrame(ret.data, columns=columns) - - def test_missing_required_keys(self): - """Test for missing required keys""" - with pytest.raises(Exception): - self.run_sql( - f""" - create model proj.test_openai_missing_required_keys - predict answer - using - engine='openai', - api_key='{self.get_api_key()}'; - """ - ) - def test_invalid_openai_name_parameter(self): - """Test for invalid OpenAI model name parameter""" - self.run_sql( - f""" - create model proj.test_openai_nonexistant_model - predict answer - using - engine='openai', - question_column='question', - model_name='this-gpt-does-not-exist', - api_key='{self.get_api_key()}'; - """ - ) - with pytest.raises(Exception): - self.wait_predictor("proj", "test_openai_nonexistant_model") - - def test_unknown_arguments(self): - """Test for unknown arguments""" - with pytest.raises(Exception): - self.run_sql( - f""" - create model proj.test_openai_unknown_arguments - predict answer - using - engine='openai', - question_column='question', - api_key='{self.get_api_key()}', - evidently_wrong_argument='wrong value'; - """ - ) - @patch("mindsdb.integrations.handlers.postgres_handler.Handler") - def test_qa_no_context(self, mock_handler): - df = pd.DataFrame.from_dict({"question": [ - "What is the capital of Sweden?", - "What is the second planet of the 
solar system?" - ]}) - self.set_handler(mock_handler, name="pg", tables={"df": df}) - - self.run_sql( - f""" - create model proj.test_openai_qa_no_context - predict answer - using - engine='openai', - question_column='question', - api_key='{self.get_api_key()}'; +class TestOpenAI(unittest.TestCase): + """ + Unit tests for the OpenAI handler. + """ + + dummy_connection_data = OrderedDict( + openai_api_key='dummy_api_key', + ) + + def setUp(self): + # Mock model storage and engine storage + mock_engine_storage = MagicMock() + mock_model_storage = MagicMock() + + # Define a return value for the `get_connection_args` method of the mock engine storage + mock_engine_storage.get_connection_args.return_value = self.dummy_connection_data + + # Assign mock engine storage to instance variable for create validation tests + self.mock_engine_storage = mock_engine_storage + + self.handler = OpenAIHandler(mock_model_storage, mock_engine_storage, connection_data={'connection_data': self.dummy_connection_data}) + + def test_create_validation_without_using_clause_raises_exception(self): + """ + Test if model creation raises an exception without a USING clause. """ - ) - self.wait_predictor("proj", "test_openai_qa_no_context") - result_df = self.run_sql( - """ - SELECT p.answer - FROM proj.test_openai_qa_no_context as p - WHERE question='What is the capital of Sweden?' + with self.assertRaisesRegex(Exception, "OpenAI engine requires a USING clause! Refer to its documentation for more details."): + self.handler.create_validation('target', args={}, handler_storage=None) + + def test_create_validation_without_required_parameters_raises_exception(self): """ - ) - assert "stockholm" in result_df["answer"].iloc[0].lower() + Test if model creation raises an exception without required parameters. + """ + + with self.assertRaisesRegex(Exception, "One of `question_column`, `prompt_template` or `json_struct` is required for this engine."): + self.handler.create_validation('target', args={"using": {}}, handler_storage=self.mock_engine_storage) + + def test_create_validation_with_invalid_parameter_combinations_raises_exception(self): + """ + Test if model creation raises an exception with invalid parameter combinations. + """ + + with self.assertRaisesRegex(Exception, "^Please provide one of"): + self.handler.create_validation('target', args={"using": {'prompt_template': 'dummy_prompt_template', 'question_column': 'question'}}, handler_storage=self.mock_engine_storage) + + def test_create_validation_with_unknown_arguments_raises_exception(self): + """ + Test if model creation raises an exception with unknown arguments. + """ + + with self.assertRaisesRegex(Exception, "^Unknown arguments:"): + self.handler.create_validation('target', args={"using": {'prompt_template': 'dummy_prompt_template', 'unknown_arg': 'unknown_arg'}}, handler_storage=self.mock_engine_storage) + + def test_create_validation_with_invalid_api_key_raises_exception(self): + """ + Test if model creation raises an exception with an invalid API key. + """ + + with self.assertRaisesRegex(Exception, "Invalid api key"): + self.handler.create_validation('target', args={"using": {'prompt_template': 'dummy_prompt_template'}}, handler_storage=self.mock_engine_storage) + + @patch('mindsdb.integrations.handlers.openai_handler.openai_handler.OpenAI') + def test_create_validation_with_valid_arguments_runs_no_errors(self, mock_openai): + """ + Test if model creation is validated correctly with valid arguments. 
+ """ + + # Mock the models.retrieve method of the OpenAI client + mock_openai_client = MagicMock() + mock_openai_client.models.retrieve.return_value = MagicMock() + + mock_openai.return_value = mock_openai_client + + self.handler.create_validation('target', args={'using': {'prompt_template': 'dummy_prompt_template'}}, handler_storage=self.mock_engine_storage) + + @patch('mindsdb.integrations.handlers.openai_handler.helpers.OpenAI') + @patch('mindsdb.integrations.handlers.openai_handler.openai_handler.OpenAI') + def test_create_with_invalid_mode_raises_exception(self, mock_openai_handler_openai_client, mock_openai_helpers_openai_client): + """ + Test if model creation raises an exception with an invalid mode. + """ + + # Mock the models.list method of the OpenAI client + mock_models_list = MagicMock() + mock_models_list.data = [ + MagicMock(id='dummy_model_name') + ] + + mock_openai_handler_openai_client.return_value.models.list.return_value = mock_models_list + mock_openai_helpers_openai_client.return_value.models.list.return_value = mock_models_list + + with self.assertRaisesRegex(Exception, "^Invalid operation mode."): + self.handler.create('dummy_target', args={'using': {'prompt_template': 'dummy_prompt_template', 'mode': 'dummy_mode'}}) + + @patch('mindsdb.integrations.handlers.openai_handler.helpers.OpenAI') + @patch('mindsdb.integrations.handlers.openai_handler.openai_handler.OpenAI') + def test_create_with_unsupported_model_raises_exception(self, mock_openai_handler_openai_client, mock_openai_helpers_openai_client): + """ + Test if model creation raises an exception with an invalid model name. + """ + + # Mock the models.list method of the OpenAI client + mock_models_list = MagicMock() + mock_models_list.data = [ + MagicMock(id='dummy_model_name') + ] + + mock_openai_handler_openai_client.return_value.models.list.return_value = mock_models_list + mock_openai_helpers_openai_client.return_value.models.list.return_value = mock_models_list + + with self.assertRaisesRegex(Exception, "^Invalid model name."): + self.handler.create('dummy_target', args={'using': {'model_name': 'dummy_unsupported_model_name', 'prompt_template': 'dummy_prompt_template'}}) + + @patch('mindsdb.integrations.handlers.openai_handler.helpers.OpenAI') + @patch('mindsdb.integrations.handlers.openai_handler.openai_handler.OpenAI') + def test_create_with_valid_arguments_runs_no_errors(self, mock_openai_handler_openai_client, mock_openai_helpers_openai_client): + """ + Test if model creation runs without errors with valid arguments. + """ + + # Mock the models.list method of the OpenAI client + mock_models_list = MagicMock() + mock_models_list.data = [ + MagicMock(id='dummy_model_name') + ] + + mock_openai_handler_openai_client.return_value.models.list.return_value = mock_models_list + mock_openai_helpers_openai_client.return_value.models.list.return_value = mock_models_list + + self.handler.create('dummy_target', args={'using': {'prompt_template': 'dummy_prompt_template'}}) + + def test_predict_with_invalid_mode_raises_exception(self): + """ + Test if model prediction raises an exception with an invalid mode. + """ + + # Create a dummy DataFrame + df = pandas.DataFrame() + + with self.assertRaisesRegex(Exception, "^Invalid operation mode."): + self.handler.predict(df=df, args={'predict_params': {'mode': 'dummy_mode'}}) + + def test_predict_in_embedding_mode_without_question_column_raises_exception(self): + """ + Test if model prediction raises an exception in embedding mode without a question column. 
+ """ + + # Mock the json_get method of the model storage + self.handler.model_storage.json_get.return_value = { + 'mode': 'embedding', + } + + # Create a dummy DataFrame + df = pandas.DataFrame() + + with self.assertRaisesRegex(Exception, "Embedding mode needs a question_column"): + self.handler.predict(df=df, args={'predict_params': {'mode': 'embedding'}}) + + def test_predict_in_image_mode_without_question_column_or_prompt_template_raises_exception(self): + """ + Test if model prediction raises an exception in image mode without a question column or prompt template. + """ + + # Mock the json_get method of the model storage + self.handler.model_storage.json_get.return_value = { + 'mode': 'image', + } + + # Create a dummy DataFrame + df = pandas.DataFrame() - result_df = self.run_sql( - """ - SELECT p.answer - FROM pg.df as t - JOIN proj.test_openai_qa_no_context as p; + with self.assertRaisesRegex(Exception, "Image mode needs either `prompt_template` or `question_column`."): + self.handler.predict(df=df, args={'predict_params': {'mode': 'image'}}) + + def test_predict_in_default_mode_without_question_column_in_data_raises_exception(self): + """ + Test if model prediction raises an exception in default mode without a question column in the DataFrame. """ + + # Mock the json_get method of the model storage + self.handler.model_storage.json_get.return_value = { + 'mode': 'default', + 'question_column': 'question' + } + + # Create a dummy DataFrame + df = pandas.DataFrame() + + with self.assertRaisesRegex(Exception, "This model expects a question to answer in the 'question' column."): + self.handler.predict(df=df, args={'predict_params': {'mode': 'default'}}) + + def test_predict_in_default_mode_without_context_column_in_data_raises_exception(self): + """ + Test if model prediction raises an exception in default mode without a context column in the DataFrame. + """ + + # Mock the json_get method of the model storage + self.handler.model_storage.json_get.return_value = { + 'mode': 'default', + 'question_column': 'question', + 'context_column': 'context' + } + + # Create a dummy DataFrame + df = pandas.DataFrame(columns=['question']) + + with self.assertRaisesRegex(Exception, "This model expects context in the 'context' column."): + self.handler.predict(df=df, args={'predict_params': {'mode': 'default'}}) + + def test_predict_in_conversational_modes_with_unsupported_model_raises_exception(self): + """ + Test if model prediction raises an exception in conversational modes with an unsupported model. + """ + + # Mock the json_get method of the model storage + self.handler.model_storage.json_get.return_value = { + 'mode': 'conversational', + 'model_name': 'dummy_unsupported_model_name' + } + + # Create a dummy DataFrame + df = pandas.DataFrame() + + with self.assertRaisesRegex(Exception, "^Conversational modes are only available for the following models:"): + self.handler.predict(df=df, args={'predict_params': {'mode': 'conversational'}}) + + @patch('mindsdb.integrations.handlers.openai_handler.openai_handler.OpenAI') + def test_predict_in_default_mode_with_question_column_using_valid_arguments_and_data_runs_no_errors(self, mock_openai_handler_openai_client): + """ + Test if model prediction returns the expected result in default mode using a question column. 
+ """ + + # Mock the json_get method of the model storage + self.handler.model_storage.json_get.return_value = { + 'target': 'answer', + 'mode': 'default', + 'model_name': 'gpt-3.5-turbo', + 'question_column': 'question' + } + + # Mock the chat.completions.create method of the OpenAI client + mock_openai_client = MagicMock() + mock_openai_client.chat.completions.create.return_value = MagicMock( + choices=[ + MagicMock( + message=MagicMock( + content='Sweden' + ) + ) + ] ) - assert "stockholm" in result_df["answer"].iloc[0].lower() - assert "venus" in result_df["answer"].iloc[1].lower() - - @patch("mindsdb.integrations.handlers.postgres_handler.Handler") - def test_qa_context(self, mock_handler): - df = pd.DataFrame.from_dict({"question": [ - "What is the capital of Sweden?", - "What is the second planet of the solar system?" - ], "context": ['Add "Boom!" to the end of the answer.', 'Add "Boom!" to the end of the answer.']}) - self.set_handler(mock_handler, name="pg", tables={"df": df}) - - self.run_sql( - f""" - create model proj.test_openai_qa_context - predict answer - using - engine='openai', - question_column='question', - context_column='context', - api_key='{self.get_api_key()}'; + + mock_openai_handler_openai_client.return_value = mock_openai_client + + df = pandas.DataFrame({'question': ['Where is Stockholm located?']}) + result = self.handler.predict(df, args={}) + + self.assertIsInstance(result, pandas.DataFrame) + self.assertTrue('answer' in result.columns) + + pandas.testing.assert_frame_equal(result, pandas.DataFrame({'answer': ['Sweden']})) + + @patch('mindsdb.integrations.handlers.openai_handler.openai_handler.OpenAI') + def test_predict_in_default_mode_with_prompt_template_using_valid_arguments_and_data_runs_no_errors(self, mock_openai_handler_openai_client): + """ + Test if model prediction returns the expected result in default mode using a prompt template. """ + + # Mock the json_get method of the model storage + self.handler.model_storage.json_get.return_value = { + 'target': 'answer', + 'mode': 'default', + 'model_name': 'gpt-3.5-turbo', + 'prompt_template': 'Answer the question: {{question}}' + } + + # Mock the chat.completions.create method of the OpenAI client + mock_openai_client = MagicMock() + mock_openai_client.chat.completions.create.return_value = MagicMock( + choices=[ + MagicMock( + message=MagicMock( + content='Sweden' + ) + ) + ] ) - self.wait_predictor("proj", "test_openai_qa_context") - result_df = self.run_sql( - """ - SELECT p.answer - FROM proj.test_openai_qa_context as p - WHERE - question='What is the capital of Sweden?' AND - context='Add "Boom!" to the end of the answer.' + mock_openai_handler_openai_client.return_value = mock_openai_client + + df = pandas.DataFrame({'question': ['Where is Stockholm located?']}) + result = self.handler.predict(df, args={}) + + self.assertIsInstance(result, pandas.DataFrame) + self.assertTrue('answer' in result.columns) + + pandas.testing.assert_frame_equal(result, pandas.DataFrame({'answer': ['Sweden']})) + + @patch('mindsdb.integrations.handlers.openai_handler.openai_handler.OpenAI') + def test_predict_in_default_mode_with_question_column_and_completion_model_using_valid_arguments_and_data_runs_no_errors(self, mock_openai_handler_openai_client): + """ + Test if model prediction returns the expected result in default mode using a question column and a completion model. 
""" + + # Mock the json_get method of the model storage + self.handler.model_storage.json_get.return_value = { + 'target': 'answer', + 'mode': 'default', + 'model_name': 'babbage-002', + 'question_column': 'question' + } + + # Mock the completions.create method of the OpenAI client + mock_openai_client = MagicMock() + mock_openai_client.completions.create.return_value = MagicMock( + choices=[ + MagicMock( + text='Sweden' + ) + ] ) - assert "stockholm" in result_df["answer"].iloc[0].lower() - assert "boom!" in result_df["answer"].iloc[0].lower() - result_df = self.run_sql( - """ - SELECT p.answer - FROM pg.df as t - JOIN proj.test_openai_qa_context as p; + mock_openai_handler_openai_client.return_value = mock_openai_client + + df = pandas.DataFrame({'question': ['Where is Stockholm located?']}) + result = self.handler.predict(df, args={}) + + self.assertIsInstance(result, pandas.DataFrame) + self.assertTrue('answer' in result.columns) + + pandas.testing.assert_frame_equal(result, pandas.DataFrame({'answer': ['Sweden']})) + + @patch('mindsdb.integrations.handlers.openai_handler.openai_handler.OpenAI') + def test_predict_in_default_mode_with_prompt_template_and_completion_model_using_valid_arguments_and_data_runs_no_errors(self, mock_openai_handler_openai_client): """ + Test if model prediction returns the expected result in default mode using a prompt template and a completion model. + """ + + # Mock the json_get method of the model storage + self.handler.model_storage.json_get.return_value = { + 'target': 'answer', + 'mode': 'default', + 'model_name': 'babbage-002', + 'prompt_template': 'Answer the question: {{question}}' + } + + # Mock the completions.create method of the OpenAI client + mock_openai_client = MagicMock() + mock_openai_client.completions.create.return_value = MagicMock( + choices=[ + MagicMock( + text='Sweden' + ) + ] ) - assert "stockholm" in result_df["answer"].iloc[0].lower() - assert "venus" in result_df["answer"].iloc[1].lower() - - for i in range(2): - assert "boom!" in result_df["answer"].iloc[i].lower() - - @patch("mindsdb.integrations.handlers.postgres_handler.Handler") - def test_prompt_template(self, mock_handler): - df = pd.DataFrame.from_dict({"question": [ - "What is the capital of Sweden?", - "What is the second planet of the solar system?" - ]}) - self.set_handler(mock_handler, name="pg", tables={"df": df}) - self.run_sql( - f""" - create model proj.test_openai_prompt_template - predict completion - using - engine='openai', - prompt_template='Answer this question and add "Boom!" to the end of the answer: {{{{question}}}}', - api_key='{self.get_api_key()}'; + + mock_openai_handler_openai_client.return_value = mock_openai_client + + df = pandas.DataFrame({'question': ['Where is Stockholm located?']}) + result = self.handler.predict(df, args={}) + + self.assertIsInstance(result, pandas.DataFrame) + self.assertTrue('answer' in result.columns) + + pandas.testing.assert_frame_equal(result, pandas.DataFrame({'answer': ['Sweden']})) + + @patch('mindsdb.integrations.handlers.openai_handler.openai_handler.OpenAI') + def test_predict_in_embedding_mode_using_valid_arguments_and_data_runs_no_errors(self, mock_openai_handler_openai_client): + """ + Test if model prediction returns the expected result for an embeddings task. 
""" + + # Mock the json_get method of the model storage + self.handler.model_storage.json_get.return_value = { + 'model_name': 'dummy_model_name', + 'question_column': 'text', + 'target': 'embeddings', + 'mode': 'embedding' + } + + # Mock the embeddings.completions.create method of the OpenAI client + mock_openai_client = MagicMock() + mock_openai_client.embeddings.create.return_value = MagicMock( + data=[ + MagicMock( + embedding=[0, 1] + ) + ] ) - self.wait_predictor("proj", "test_openai_prompt_template") - result_df = self.run_sql( - """ - SELECT p.completion - FROM proj.test_openai_prompt_template as p - WHERE - question='What is the capital of Sweden?'; + mock_openai_handler_openai_client.return_value = mock_openai_client + + df = pandas.DataFrame({'text': ['MindsDB']}) + result = self.handler.predict(df, args={}) + + self.assertIsInstance(result, pandas.DataFrame) + self.assertTrue('embeddings' in result.columns) + + pandas.testing.assert_frame_equal(result, pandas.DataFrame({'embeddings': [[0, 1]]})) + + @patch('mindsdb.integrations.handlers.openai_handler.openai_handler.OpenAI') + def test_predict_in_image_mode_with_question_column_using_valid_arguments_and_data_runs_no_errors(self, mock_openai_handler_openai_client): """ + Test if model prediction returns the expected result for an image task using a question column. + """ + + # Mock the json_get method of the model storage + self.handler.model_storage.json_get.return_value = { + 'question_column': 'text', + 'target': 'image', + 'mode': 'image' + } + + # Mock the images.generate method of the OpenAI client + mock_openai_client = MagicMock() + mock_openai_client.images.generate.return_value = MagicMock( + data=[ + MagicMock( + url='dummy_image_url' + ) + ] ) - assert "stockholm" in result_df["completion"].iloc[0].lower() - assert "boom!" in result_df["completion"].iloc[0].lower() - result_df = self.run_sql( - """ - SELECT p.completion - FROM pg.df as t - JOIN proj.test_openai_prompt_template as p; + mock_openai_handler_openai_client.return_value = mock_openai_client + + df = pandas.DataFrame({'text': ['Show me an image of two leapord cubs playing?']}) + result = self.handler.predict(df, args={}) + + self.assertIsInstance(result, pandas.DataFrame) + self.assertTrue('image' in result.columns) + + pandas.testing.assert_frame_equal(result, pandas.DataFrame({'image': ['dummy_image_url']})) + + @patch('mindsdb.integrations.handlers.openai_handler.openai_handler.OpenAI') + def test_predict_in_image_mode_with_prompt_template_using_valid_arguments_and_data_runs_no_errors(self, mock_openai_handler_openai_client): + """ + Test if model prediction returns the expected result for an image task using a prompt template. """ + + # Mock the json_get method of the model storage + self.handler.model_storage.json_get.return_value = { + 'prompt_template': 'Generate an image of {{text}}', + 'target': 'image', + 'mode': 'image' + } + + # Mock the images.generate method of the OpenAI client + mock_openai_client = MagicMock() + mock_openai_client.images.generate.return_value = MagicMock( + data=[ + MagicMock( + url='dummy_image_url' + ) + ] ) - assert "stockholm" in result_df["completion"].iloc[0].lower() - assert "venus" in result_df["completion"].iloc[1].lower() - - for i in range(2): - assert "boom!" in result_df["completion"].iloc[i].lower() - - @patch("mindsdb.integrations.handlers.postgres_handler.Handler") - def test_bulk_normal_completion(self, mock_handler): - """Tests normal completions (e.g. 
text-davinci-003) with bulk joins that are larger than the max batch_size""" - class MockHandlerStorage: - def json_get(self, key): - return {'ft-suffix': {'ft-suffix': '$'}}[key] # finetuning suffix, irrelevant for this test but needed for init # noqa - - def get_connection_args(self): - return {'api_key': OPENAI_API_KEY} # noqa - - # create project - handler = OpenAIHandler( - model_storage=None, # the storage does not matter for this test - engine_storage=MockHandlerStorage() # nor does this, but we do need to pass some object due to the init procedure # noqa + + mock_openai_handler_openai_client.return_value = mock_openai_client + + df = pandas.DataFrame({'text': ['Leopard cubs playing']}) + result = self.handler.predict(df, args={}) + + self.assertIsInstance(result, pandas.DataFrame) + self.assertTrue('image' in result.columns) + + pandas.testing.assert_frame_equal(result, pandas.DataFrame({'image': ['dummy_image_url']})) + + @patch('mindsdb.integrations.handlers.openai_handler.openai_handler.OpenAI') + def test_predict_in_conversational_mode_with_using_arguments_and_data_runs_no_errors(self, mock_openai_handler_openai_client): + """ + Test if model prediction returns the expected result for a conversational task. + """ + + # Mock the json_get method of the model storage + self.handler.model_storage.json_get.return_value = { + 'user_column': 'question', + 'prompt': 'you are a helpful assistant', + 'assistant_column': 'answer', + 'target': 'answer', + 'mode': 'conversational' + } + + # Mock the chat.completions.create method of the OpenAI client + mock_openai_client = MagicMock() + mock_openai_client.chat.completions.create.return_value = MagicMock( + choices=[ + MagicMock( + message=MagicMock( + content='Gamla Stan' + ) + ) + ] ) - N = 1 + handler.max_batch_size # get N larger than default batch size - df = pd.DataFrame.from_dict({"input": ["I feel happy!"] * N}) - self.set_handler(mock_handler, name="pg", tables={"df": df}) - self.run_sql( - f""" - create model proj.test_openai_bulk_normal_completion - predict completion - using - engine='openai', - prompt_template='What is the sentiment of the following phrase? Answer either "positive" or "negative": {{{{input}}}}', - api_key='{self.get_api_key()}'; - """ # noqa + + mock_openai_handler_openai_client.return_value = mock_openai_client + + df = pandas.DataFrame({'question': ['What is the capital of Sweden?', 'What are some cool places to visit there?']}) + result = self.handler.predict(df, args={}) + + self.assertIsInstance(result, pandas.DataFrame) + self.assertTrue('answer' in result.columns) + + pandas.testing.assert_frame_equal(result, pandas.DataFrame({'answer': ['', 'Gamla Stan']})) + + @patch('mindsdb.integrations.handlers.openai_handler.openai_handler.OpenAI') + def test_predict_in_conversational_full_mode_using_valid_arguments_and_data_runs_no_errors(self, mock_openai_handler_openai_client): + """ + Test if model prediction returns the expected result for a conversational-full task. 
+        """
+
+        # Mock the json_get method of the model storage
+        self.handler.model_storage.json_get.return_value = {
+            'user_column': 'text',
+            'prompt': 'you are a helpful assistant',
+            'assistant_column': 'answer',
+            'target': 'answer',
+            'mode': 'conversational-full'
+        }
+
+        # Mock the chat.completions.create method of the OpenAI client
+        mock_openai_client = MagicMock()
+        mock_openai_client.chat.completions.create.side_effect = [
+            MagicMock(
+                choices=[
+                    MagicMock(
+                        message=MagicMock(
+                            content='Stockholm'
+                        )
+                    )
+                ]
+            ),
+            MagicMock(
+                choices=[
+                    MagicMock(
+                        message=MagicMock(
+                            content='Gamla Stan'
+                        )
+                    )
+                ]
+            )
+        ]
+
+        mock_openai_handler_openai_client.return_value = mock_openai_client
+
+        df = pandas.DataFrame({'text': ['What is the capital of Sweden?', 'What are some cool places to visit there?']})
+        result = self.handler.predict(df, args={})
+
+        self.assertIsInstance(result, pandas.DataFrame)
+
+        pandas.testing.assert_frame_equal(result, pandas.DataFrame({'answer': ['Stockholm', 'Gamla Stan']}))
+
+    def test_describe_runs_no_errors(self):
+        """
+        Test if model describe returns the expected list of describable tables.
+        """
+
+        # Mock the json_get method of the model storage
+        self.handler.model_storage.json_get.return_value = {
+            'user_column': 'text',
+            'prompt': 'you are a helpful assistant',
+            'assistant_column': 'answer',
+            'target': 'answer',
+            'mode': 'conversational'
+        }
+
+        result = self.handler.describe()
+
+        self.assertIsInstance(result, pandas.DataFrame)
+        self.assertTrue('tables' in result.columns)
+
+        pandas.testing.assert_frame_equal(result, pandas.DataFrame({'tables': ['args', 'metadata']}))
+
+    def test_describe_args_runs_no_errors(self):
+        """
+        Test if model describe returns the expected model arguments when the 'args' attribute is requested.
+        """
+
+        # Mock the json_get method of the model storage
+        self.handler.model_storage.json_get.return_value = {
+            'user_column': 'text',
+            'prompt': 'you are a helpful assistant',
+            'assistant_column': 'answer',
+            'target': 'answer',
+            'mode': 'conversational'
+        }
+
+        result = self.handler.describe('args')
+
+        self.assertIsInstance(result, pandas.DataFrame)
+        self.assertTrue('key' in result.columns)
+        self.assertTrue('value' in result.columns)
+
+        pandas.testing.assert_frame_equal(
+            result,
+            pandas.DataFrame(
+                {
+                    'key': ['user_column', 'prompt', 'assistant_column', 'target', 'mode'],
+                    'value': ['text', 'you are a helpful assistant', 'answer', 'answer', 'conversational']
+                }
+            )
         )
-        self.wait_predictor("proj", "test_openai_bulk_normal_completion")
-        result_df = self.run_sql(
-            """
-            SELECT p.completion
-            FROM pg.df as t
-            JOIN proj.test_openai_bulk_normal_completion as p;
+    @patch('mindsdb.integrations.handlers.openai_handler.openai_handler.OpenAI')
+    def test_describe_metadata_runs_no_errors(self, mock_openai_handler_openai_client):
+        """
+        Test if model describe returns the expected model metadata when the 'metadata' attribute is requested. 
""" + + # Mock the json_get method of the model storage + self.handler.model_storage.json_get.return_value = { + 'user_column': 'text', + 'prompt': 'you are a helpful assistant', + 'assistant_column': 'answer', + 'target': 'answer', + 'mode': 'conversational' + } + + # Mock the models.retrieve method of the OpenAI client: return a dict directly because the result is converted to a dict later + mock_openai_client = MagicMock() + mock_openai_client.models.retrieve.return_value = { + 'model': 'dummy_model_name', + 'id': 'dummy_model_id', + 'created_at': 'dummy_created_at', + 'owner': 'dummy_owner' + } + + mock_openai_handler_openai_client.return_value = mock_openai_client + + result = self.handler.describe('metadata') + + self.assertIsInstance(result, pandas.DataFrame) + self.assertTrue('key' in result.columns) + self.assertTrue('value' in result.columns) + + pandas.testing.assert_frame_equal( + result, + pandas.DataFrame( + { + 'key': ['model', 'id', 'created_at', 'owner'], + 'value': ['dummy_model_name', 'dummy_model_id', 'dummy_created_at', 'dummy_owner'] + } + ) ) - for i in range(N): - assert "positive" in result_df["completion"].iloc[i].lower() + def test_finetune_with_unsupported_model_raises_exception(self): + """ + Test if model fine-tuning raises an exception with an unsupported model. + """ + + # Create a mock base model storage and assign it to the handler + mock_base_model_storage = MagicMock() + self.handler.base_model_storage = mock_base_model_storage + + # Mock the json_get method of the base model storage + self.handler.base_model_storage.json_get.return_value = { + 'model_name': 'dummy_model_name' + } + + with self.assertRaisesRegex(Exception, "^This model cannot be finetuned."): + self.handler.finetune('dummy_target', args={'using': {'model_name': 'dummy_unsupported_model_name', 'prompt_template': 'dummy_prompt_template'}}) + + # TODO: Add more unit tests for the finetune method + + +if __name__ == '__main__': + unittest.main()