Merge pull request mindsdb#9225 from mindsdb/eri/openai
OpenAI Integration Improvements
ZoranPandovski authored May 30, 2024
2 parents 20cc8cc + 8f31520 commit a1362ed
Showing 7 changed files with 1,613 additions and 369 deletions.
1 change: 0 additions & 1 deletion .flake8
@@ -141,7 +141,6 @@ exclude =
mindsdb/integrations/handlers/ingres_handler/*
mindsdb/integrations/handlers/palm_handler/*
mindsdb/integrations/handlers/binance_handler/*
- mindsdb/integrations/handlers/openai_handler/*
mindsdb/integrations/handlers/mediawiki_handler/*
mindsdb/integrations/handlers/mendeley_handler/*
mindsdb/integrations/handlers/databend_handler/*
126 changes: 63 additions & 63 deletions .github/workflows/test_on_push.yml
@@ -3,7 +3,7 @@ name: Test on Push
on:
pull_request:
branches: [main]

defaults:
run:
shell: bash
@@ -14,7 +14,6 @@ concurrency:
group: ${{ github.workflow_ref }}
cancel-in-progress: true


jobs:
# Run all of our static code checks here
code_checking:
@@ -27,8 +26,8 @@ jobs:
with:
python-version: ${{ vars.CI_PYTHON_VERSION }}
cache: pip
- cache-dependency-path: '**/requirements*.txt'
+ cache-dependency-path: "**/requirements*.txt"

# Checks the codebase for print() statements and fails if any are found
# We should be using loggers instead
- name: Check for print statements
@@ -46,7 +45,7 @@
uses: pre-commit/[email protected]
with:
extra_args: --files ${{ steps.changed-files.outputs.all_changed_files }}

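The "Check for print statements" step itself is collapsed in this view. As an illustration only (the actual script this workflow runs is not shown in the diff), such a check could be an AST scan that fails the build when any print() call is found:

```python
# Hypothetical sketch of a print-statement checker; the real CI script is not shown in this diff.
import ast
import pathlib
import sys


def find_print_calls(root: str = "mindsdb") -> list:
    """Return (file, line) pairs for every print() call under `root`."""
    hits = []
    for path in pathlib.Path(root).rglob("*.py"):
        tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path))
        for node in ast.walk(tree):
            if (isinstance(node, ast.Call)
                    and isinstance(node.func, ast.Name)
                    and node.func.id == "print"):
                hits.append((str(path), node.lineno))
    return hits


if __name__ == "__main__":
    hits = find_print_calls()
    for filename, lineno in hits:
        sys.stderr.write(f"{filename}:{lineno}: use a logger instead of print()\n")
    sys.exit(1 if hits else 0)
```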
# Runs a few different checks against our many requirements files
# to make sure they're in order
- name: Check requirements files
@@ -68,8 +67,8 @@
- id: set-matrix
uses: JoshuaTheMiller/[email protected]
with:
- filter: '[?runOnBranch==`${{ github.ref }}` || runOnBranch==`always`]'
+ filter: "[?runOnBranch==`${{ github.ref }}` || runOnBranch==`always`]"

# Check that our pip package is able to be installed in all of our supported environments
check_install:
name: Check pip installation
@@ -78,23 +77,23 @@
matrix: ${{fromJson(needs.matrix_prep.outputs.matrix)}}
runs-on: ${{ matrix.runs_on }}
steps:
- - uses: actions/checkout@v4
- - name: Set up Python
- uses: actions/[email protected]
- with:
- python-version: ${{ matrix.python-version }}
- cache: pip
- cache-dependency-path: '**/requirements*.txt'
- - name: Check requirements files are installable
- run: |
- # Install dev requirements and build our pip package
- pip install -r requirements/requirements-dev.txt
- python setup.py sdist
+ - uses: actions/checkout@v4
+ - name: Set up Python
+ uses: actions/[email protected]
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: pip
+ cache-dependency-path: "**/requirements*.txt"
+ - name: Check requirements files are installable
+ run: |
+ # Install dev requirements and build our pip package
+ pip install -r requirements/requirements-dev.txt
+ python setup.py sdist
- # Install from the pip package
- # If we install from source, we don't know if the pip package is installable.
- cd dist
- pip install --ignore-installed *.tar.gz
+ # Install from the pip package
+ # If we install from source, we don't know if the pip package is installable.
+ cd dist
+ pip install --ignore-installed *.tar.gz
unit_tests:
name: Run Unit Tests
@@ -104,43 +103,44 @@
runs-on: ${{ matrix.runs_on }}
if: github.ref_type == 'branch'
steps:
- - uses: actions/checkout@v4
- - name: Set up Python
- uses: actions/[email protected]
- with:
- python-version: ${{ matrix.python-version }}
- cache: pip
- cache-dependency-path: '**/requirements*.txt'
- - name: Install dependencies
- run: |
- pip install .
- pip install -r requirements/requirements-test.txt
- pip install .[lightwood] # TODO: for now some tests rely on lightwood
- pip install .[mssql]
- pip install .[clickhouse]
- pip install .[snowflake]
- pip install .[web]
- pip freeze
- - name: Run unit tests
- run: |
- if [ "$RUNNER_OS" == "Linux" ]; then
- env PYTHONPATH=./ pytest tests/unit/test_executor.py
- env PYTHONPATH=./ pytest tests/unit/test_project_structure.py
- env PYTHONPATH=./ pytest tests/unit/test_predictor_params.py
- env PYTHONPATH=./ pytest tests/unit/test_mongodb_handler.py
- env PYTHONPATH=./ pytest tests/unit/test_mongodb_server.py
- env PYTHONPATH=./ pytest tests/unit/test_cache.py
- env PYTHONPATH=./ pytest tests/unit/test_llm_utils.py
- env PYTHONPATH=./ pytest tests/unit/ml_handlers/test_mindsdb_inference.py
- fi
- - name: Run Handlers tests and submit Coverage to coveralls
- run: |
- handlers=("mysql" "postgres" "mssql" "clickhouse" "snowflake" "web")
- for handler in "${handlers[@]}"
- do
- pytest --cov=mindsdb/integrations/handlers/${handler}_handler tests/unit/handlers/test_${handler}.py
- done
- coveralls --service=github --basedir=mindsdb/integrations/handlers
- env:
- COVERALLS_REPO_TOKEN: ${{ secrets.COVERALLS_REPO_TOKEN }}
- github_token: ${{ secrets.REPO_DISPATCH_PAT_TOKEN }}
+ - uses: actions/checkout@v4
+ - name: Set up Python
+ uses: actions/[email protected]
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: pip
+ cache-dependency-path: "**/requirements*.txt"
+ - name: Install dependencies
+ run: |
+ pip install .
+ pip install -r requirements/requirements-test.txt
+ pip install .[lightwood] # TODO: for now some tests rely on lightwood
+ pip install .[mssql]
+ pip install .[clickhouse]
+ pip install .[snowflake]
+ pip install .[web]
+ pip freeze
+ - name: Run unit tests
+ run: |
+ if [ "$RUNNER_OS" == "Linux" ]; then
+ env PYTHONPATH=./ pytest tests/unit/test_executor.py
+ env PYTHONPATH=./ pytest tests/unit/test_project_structure.py
+ env PYTHONPATH=./ pytest tests/unit/test_predictor_params.py
+ env PYTHONPATH=./ pytest tests/unit/test_mongodb_handler.py
+ env PYTHONPATH=./ pytest tests/unit/test_mongodb_server.py
+ env PYTHONPATH=./ pytest tests/unit/test_cache.py
+ env PYTHONPATH=./ pytest tests/unit/test_llm_utils.py
+ env PYTHONPATH=./ pytest tests/unit/ml_handlers/test_mindsdb_inference.py
+ env PYTHONPATH=./ pytest tests/unit/ml_handlers/test_openai.py
+ fi
+ - name: Run Handlers tests and submit Coverage to coveralls
+ run: |
+ handlers=("mysql" "postgres" "mssql" "clickhouse" "snowflake" "web")
+ for handler in "${handlers[@]}"
+ do
+ pytest --cov=mindsdb/integrations/handlers/${handler}_handler tests/unit/handlers/test_${handler}.py
+ done
+ coveralls --service=github --basedir=mindsdb/integrations/handlers
+ env:
+ COVERALLS_REPO_TOKEN: ${{ secrets.COVERALLS_REPO_TOKEN }}
+ github_token: ${{ secrets.REPO_DISPATCH_PAT_TOKEN }}
56 changes: 48 additions & 8 deletions mindsdb/integrations/handlers/openai_handler/helpers.py
@@ -1,5 +1,4 @@
- import os
- from typing import List
+ from typing import Text, List, Dict
import random
import time
import math
@@ -10,11 +9,14 @@
import tiktoken

import mindsdb.utilities.profiler as profiler
from mindsdb.integrations.handlers.openai_handler.constants import OPENAI_API_BASE


class PendingFT(openai.OpenAIError):
+ """
+ Custom exception to handle pending fine-tuning status.
+ """
message: str

def __init__(self, message) -> None:
super().__init__()
self.message = message
@@ -44,6 +46,18 @@ def _retry_with_exponential_backoff(func):
Slight changes in the implementation, but originally from:
https://github.com/openai/openai-cookbook/blob/main/examples/How_to_handle_rate_limits.ipynb
+ Args:
+ func: Function to be wrapped
+ initial_delay: Initial delay in seconds
+ hour_budget: Hourly budget in seconds
+ jitter: Adds randomness to the delay
+ exponential_base: Base for the exponential backoff
+ wait_errors: Tuple of errors to retry on
+ status_errors: Tuple of status errors to raise
+ Returns:
+ Wrapper function with exponential backoff
""" # noqa

def wrapper(*args, **kwargs):
@@ -94,10 +108,22 @@ def wrapper(*args, **kwargs):
return _retry_with_exponential_backoff

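The decorator's body is mostly collapsed in this view. The following is a simplified sketch of the retry-with-exponential-backoff pattern it describes (parameter names follow the docstring above; the real implementation also handles profiling, status errors, and pending fine-tuning via PendingFT):

```python
# Simplified sketch of the retry pattern described in the docstring above;
# the actual helpers.py implementation differs in details.
import random
import time


def retry_with_exponential_backoff(initial_delay=1, hour_budget=0.3,
                                   jitter=False, exponential_base=2,
                                   wait_errors=(Exception,)):
    def decorator(func):
        def wrapper(*args, **kwargs):
            delay = initial_delay
            deadline = time.time() + hour_budget * 3600
            while True:
                try:
                    return func(*args, **kwargs)
                except wait_errors:
                    # Grow the delay geometrically, with optional jitter.
                    delay *= exponential_base * (1 + jitter * random.random())
                    if time.time() + delay > deadline:
                        raise  # hourly budget exhausted, surface the error
                    time.sleep(delay)
        return wrapper
    return decorator
```

In helpers.py the wrapped callables are presumably OpenAI API requests, so rate-limit errors are the natural choice of wait_errors.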

- def truncate_msgs_for_token_limit(messages, model_name, max_tokens, truncate='first'):
+ def truncate_msgs_for_token_limit(messages: List[Dict], model_name: Text, max_tokens: int, truncate: Text = 'first'):
"""
Truncates message list to fit within the token limit.
- Note: first message for chat completion models are general directives with the system role, which will ideally be kept at all times.
+ The first message for chat completion models is a general directive with the system role, which will ideally be kept at all times.
Slight changes in the implementation, but originally from:
https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+ Args:
+ messages (List[Dict]): List of messages
+ model_name (Text): Model name
+ max_tokens (int): Maximum token limit
+ truncate (Text): Truncate strategy, either 'first' or 'last'
+ Returns:
+ List[Dict]: Truncated message list
""" # noqa
encoder = tiktoken.encoding_for_model(model_name)
sys_priming = messages[0:1]
@@ -119,8 +145,15 @@ def truncate_msgs_for_token_limit(messages, model_name, max_tokens, truncate='first'):
return messages

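Given the docstring, the truncation strategy can be pictured as: always keep the system message, then drop messages from the front (truncate='first') or the back ('last') until the encoded conversation fits. A rough sketch, assuming each message is a dict with a 'content' field (the real function relies on count_tokens() below and more careful accounting):

```python
# Rough sketch of the truncation strategy described above; assumes each
# message is a dict with a 'content' field and ignores per-role overheads.
import tiktoken


def truncate_sketch(messages, model_name, max_tokens, truncate='first'):
    encoder = tiktoken.encoding_for_model(model_name)
    sys_priming = messages[0:1]  # system directive, kept whenever possible
    candidates = messages[1:]
    while candidates:
        kept = sys_priming + candidates
        n_tokens = sum(len(encoder.encode(m['content'])) for m in kept)
        if n_tokens <= max_tokens:
            return kept
        # Drop the oldest ('first') or the newest ('last') remaining message.
        candidates = candidates[1:] if truncate == 'first' else candidates[:-1]
    return sys_priming
```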

- def count_tokens(messages, encoder, model_name='gpt-3.5-turbo-0301'):
- """Original token count implementation can be found in the OpenAI cookbook."""
+ def count_tokens(messages: List[Dict], encoder: tiktoken.core.Encoding, model_name: Text = 'gpt-3.5-turbo-0301'):
+ """
+ Counts the number of tokens in a list of messages.
+ Args:
+ messages: List of messages
+ encoder: Tokenizer
+ model_name: Model name
+ """
if (
"gpt-3.5-turbo" in model_name
): # note: future models may deviate from this (only 0301 really complies)
@@ -141,9 +174,16 @@ def count_tokens(messages, encoder, model_name='gpt-3.5-turbo-0301'):
)

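For reference, the cookbook algorithm referenced by this function charges a fixed overhead per message plus the encoded length of every field; a sketch with the published constants for gpt-3.5-turbo-0301 (the collapsed body above presumably implements something close to this):

```python
# Sketch of the OpenAI-cookbook token count for gpt-3.5-turbo-0301.
def count_tokens_sketch(messages, encoder):
    tokens_per_message = 4  # each message is wrapped in <|start|>{role}\n{content}<|end|>\n
    tokens_per_name = -1    # if 'name' is set, the role token is omitted
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(encoder.encode(value))
            if key == 'name':
                num_tokens += tokens_per_name
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens
```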

- def get_available_models(api_key: str, api_base: str) -> List[str]:
+ def get_available_models(api_key: Text, api_base: Text) -> List[Text]:
"""
Returns a list of available openai models for the given API key.
+ Args:
+ api_key (Text): OpenAI API key
+ api_base (Text): OpenAI API base URL
+ Returns:
+ List[Text]: List of available models
"""
res = OpenAI(api_key=api_key, base_url=api_base).models.list()

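With the v1 OpenAI Python client, models.list() returns a page object whose .data attribute holds the model entries, so extracting the ids is a one-liner. A usage sketch (the API key and base URL here are placeholders):

```python
# Hypothetical usage sketch; requires a valid API key.
from openai import OpenAI

client = OpenAI(api_key='sk-...', base_url='https://api.openai.com/v1')
model_ids = [model.id for model in client.models.list().data]
print(model_ids[:5])  # e.g. a handful of chat and embedding model names
```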