From 7d23c252b889a0402affb06af7c55ae2af2f8b4e Mon Sep 17 00:00:00 2001
From: Jonathan Lounsbury
Date: Tue, 7 Mar 2023 15:43:21 -0500
Subject: [PATCH 1/5] Add OpenAI Embeddings Primitive

Adds a primitive for natural language logical types that uses the OpenAI
Embeddings API to calculate embeddings features. The model to use is
configurable, but text-embedding-ada-002 is used by default.
---
 nlp_primitives/__init__.py             |   4 +
 nlp_primitives/openai/__init__.py      |   3 +
 nlp_primitives/openai/api_requester.py | 261 +++++++++++++++++++++++++
 nlp_primitives/openai/embeddings.py    | 190 ++++++++++++++++++
 nlp_primitives/openai/model.py         |  17 ++
 nlp_primitives/openai/request.py       |  54 +++++
 nlp_primitives/openai/response.py      |  13 ++
 nlp_primitives/utils/__init__.py       |   1 +
 nlp_primitives/utils/event_loop.py     |  22 +++
 pyproject.toml                         |  12 +-
 10 files changed, 576 insertions(+), 1 deletion(-)
 create mode 100644 nlp_primitives/openai/__init__.py
 create mode 100644 nlp_primitives/openai/api_requester.py
 create mode 100644 nlp_primitives/openai/embeddings.py
 create mode 100644 nlp_primitives/openai/model.py
 create mode 100644 nlp_primitives/openai/request.py
 create mode 100644 nlp_primitives/openai/response.py
 create mode 100644 nlp_primitives/utils/__init__.py
 create mode 100644 nlp_primitives/utils/event_loop.py

diff --git a/nlp_primitives/__init__.py b/nlp_primitives/__init__.py
index 6357b029..3adb7961 100644
--- a/nlp_primitives/__init__.py
+++ b/nlp_primitives/__init__.py
@@ -3,6 +3,7 @@
 import inspect
 from importlib.util import find_spec
+import os
 
 import nltk.data
 import pkg_resources
@@ -19,6 +20,9 @@
 if find_spec("tensorflow") and find_spec("tensorflow_hub"):
     from nlp_primitives.tensorflow import Elmo, UniversalSentenceEncoder
 
+if find_spec("openai") and "OPENAI_API_KEY" in os.environ:
+    from nlp_primitives.openai import OpenAIEmbeddings
+
 NLP_PRIMITIVES = [
     obj
     for obj in globals().values()
diff --git a/nlp_primitives/openai/__init__.py b/nlp_primitives/openai/__init__.py
new file mode 100644
index 00000000..4b6e2a4b
--- /dev/null
+++ b/nlp_primitives/openai/__init__.py
@@ -0,0 +1,3 @@
+from nlp_primitives.openai.embeddings import (
+    OpenAIEmbeddings,
+)
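With the guard above, the new primitive is only importable when the openai extra is installed and an API key is in the environment. A minimal usage sketch for reviewers (hypothetical key and data; not part of this patch):

# Reviewer sketch -- hypothetical, not part of this patch.
import os

os.environ["OPENAI_API_KEY"] = "sk-..."  # must be set before nlp_primitives is imported

import pandas as pd
from nlp_primitives import OpenAIEmbeddings

primitive = OpenAIEmbeddings()
embed = primitive.get_function()
# Returns one output per embedding dimension (1536 for text-embedding-ada-002),
# each holding that dimension's value for every input string.
embeddings = embed(pd.Series(["This is a test file", "This is second line"]))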
diff --git a/nlp_primitives/openai/api_requester.py b/nlp_primitives/openai/api_requester.py
new file mode 100644
index 00000000..412be4bf
--- /dev/null
+++ b/nlp_primitives/openai/api_requester.py
@@ -0,0 +1,261 @@
+# Adapted from
+# https://github.com/openai/openai-cookbook/blob/66b988407d8d13cad5060a881dc8c892141f2d5c/examples/api_request_parallel_processor.py
+
+"""
+API REQUESTER
+
+Using the OpenAI API to process lots of text quickly takes some care.
+If you trickle in a million API requests one by one, they'll take days to complete.
+If you flood a million API requests in parallel, they'll exceed the rate limits and fail with errors.
+To maximize throughput, parallel requests need to be throttled to stay under rate limits.
+
+This script parallelizes requests to the OpenAI API while throttling to stay under rate limits.
+
+Features:
+- Makes requests concurrently, to maximize throughput
+- Throttles request and token usage, to stay under rate limits
+- Retries failed requests up to {max_attempts} times, to avoid missing data
+- Logs errors, to diagnose problems with requests
+
+Inputs:
+- request_list : List[OpenAIRequest[RESPONSE]]
+    - a list of requests to process
+- max_requests_per_minute : float, optional
+    - target number of requests to make per minute (will make fewer if limited by tokens)
+    - leave headroom by setting this to 50% or 75% of your limit
+    - if requests are limiting you, try batching multiple embeddings or completions into one request
+    - if omitted, will default to 1,500
+- max_tokens_per_minute : float, optional
+    - target number of tokens to use per minute (will use fewer if limited by requests)
+    - leave headroom by setting this to 50% or 75% of your limit
+    - if omitted, will default to 125,000
+- max_attempts : int, optional
+    - number of times to retry a failed request before giving up
+    - if omitted, will default to 5
+
+The script is structured as follows:
+    - Imports
+    - Define process_api_requests()
+        - Initialize things
+        - In main loop:
+            - Get next request if one is not already waiting for capacity
+            - Update available token & request capacity
+            - If enough capacity available, call API
+            - The loop pauses if a rate limit error is hit
+            - The loop breaks when no tasks remain
+    - Define dataclasses
+        - StatusTracker (stores script metadata counters; only one instance is created)
+        - RetryableRequest (stores API inputs, outputs, metadata; one method to call API)
+    - Define functions
+        - request_id_generator_function (yields 0, 1, 2, ...)
+"""
+
+# imports
+import asyncio  # for running API calls concurrently
+import logging  # for logging rate limit warnings and other messages
+import time  # for sleeping after rate limit is hit
+from dataclasses import dataclass, field
+from typing import Generic, List, Optional
+
+from nlp_primitives.openai.request import RESPONSE, OpenAIRequest
+from openai.error import OpenAIError, RateLimitError
+
+
+async def process_api_requests(
+    request_list: List[OpenAIRequest[RESPONSE]],
+    max_requests_per_minute: float = 3_000 * 0.5,
+    max_tokens_per_minute: float = 250_000 * 0.5,
+    max_attempts: int = 5,
+    seconds_to_pause_after_rate_limit_error: int = 15,
+    seconds_to_sleep_each_loop: float = 0.001,  # 1 ms limits max throughput to 1,000 requests per second
+) -> List[RESPONSE]:
+    """Processes API requests in parallel, throttling to stay under rate limits."""
+    logging.debug("Initializing requester.")
+
+    # initialize trackers
+    queue_of_requests_to_retry = asyncio.Queue()
+    request_id_generator = (
+        request_id_generator_function()
+    )  # generates integer IDs of 0, 1, 2, ...
+    status_tracker = (
+        StatusTracker()
+    )  # single instance to track a collection of variables
+    next_request = None  # variable to hold the next request to run
+
+    # initialize available capacity counts
+    available_request_capacity = max_requests_per_minute
+    available_token_capacity = max_tokens_per_minute
+    last_update_time = time.time()
+
+    # initialize flags
+    requests_remaining = True
+    logging.debug("Initialization complete.")
+
+    # `requests` will provide requests one at a time
+    requests = iter(request_list)
+    logging.debug("Iteration started. Entering main loop")
+
+    sent_requests = []
+    while True:
+        # get next task (if one is not already waiting for capacity)
+        if next_request is None:
+            if not queue_of_requests_to_retry.empty():
+                next_request = queue_of_requests_to_retry.get_nowait()
+                logging.debug(
+                    f"Retrying request {next_request.request_id}: {next_request}"
+                )
+            elif requests_remaining:
+                try:
+                    # get new request
+                    request = next(requests)
+                    next_request = RetryableRequest(
+                        request_id=next(request_id_generator),
+                        request=request,
+                        attempts_left=max_attempts,
+                    )
+                    sent_requests.append(next_request)
+                    status_tracker.num_tasks_started += 1
+                    status_tracker.num_tasks_in_progress += 1
+                    logging.debug(
+                        f"Created request {next_request.request_id}: {next_request}"
+                    )
+                except StopIteration:
+                    # if requests list runs out, set flag to stop iterating
+                    logging.debug("Requests list exhausted")
+                    requests_remaining = False
+
+        # update available capacity
+        current_time = time.time()
+        seconds_since_update = current_time - last_update_time
+        available_request_capacity = min(
+            available_request_capacity
+            + max_requests_per_minute * seconds_since_update / 60.0,
+            max_requests_per_minute,
+        )
+        available_token_capacity = min(
+            available_token_capacity
+            + max_tokens_per_minute * seconds_since_update / 60.0,
+            max_tokens_per_minute,
+        )
+        last_update_time = current_time
+
+        # if enough capacity available, call API
+        if next_request:
+            next_request_tokens = next_request.request.token_consumption
+            if (
+                available_request_capacity >= 1
+                and available_token_capacity >= next_request_tokens
+            ):
+                # update counters
+                available_request_capacity -= 1
+                available_token_capacity -= next_request_tokens
+                next_request.attempts_left -= 1
+
+                # call API
+                asyncio.create_task(
+                    next_request.execute(
+                        retry_queue=queue_of_requests_to_retry,
+                        status_tracker=status_tracker,
+                    )
+                )
+                next_request = None  # reset next_request to empty
+
+        # if all tasks are finished, break
+        if status_tracker.num_tasks_in_progress == 0:
+            break
+
+        # main loop sleeps briefly so concurrent tasks can run
+        await asyncio.sleep(seconds_to_sleep_each_loop)
+
+        # if a rate limit error was hit recently, pause to cool down
+        seconds_since_rate_limit_error = (
+            time.time() - status_tracker.time_of_last_rate_limit_error
+        )
+        if seconds_since_rate_limit_error < seconds_to_pause_after_rate_limit_error:
+            remaining_seconds_to_pause = (
+                seconds_to_pause_after_rate_limit_error - seconds_since_rate_limit_error
+            )
+            # e.g., if the pause is 15 seconds and the final rate limit error was
+            # hit 5 seconds ago, sleep for the remaining 10 seconds
+            until = time.ctime(
+                status_tracker.time_of_last_rate_limit_error
+                + seconds_to_pause_after_rate_limit_error
+            )
+            logging.warning(f"Pausing to cool down until {until}")
+            await asyncio.sleep(remaining_seconds_to_pause)
+
+    # after finishing, log final status
+    logging.info("Parallel processing complete.")
+    if status_tracker.num_rate_limit_errors > 0:
+        logging.warning(
+            f"{status_tracker.num_rate_limit_errors} rate limit errors received."
+            " Consider running at a lower rate."
+        )
+
+    return [req.response for req in sent_requests]
+
+
+# dataclasses
+
+
+@dataclass
+class StatusTracker:
+    """Stores metadata about the script's progress. Only one instance is created."""
+
+    num_tasks_started: int = 0
+    num_tasks_in_progress: int = 0  # script ends when this reaches 0
+    num_tasks_succeeded: int = 0
+    num_rate_limit_errors: int = 0
+    time_of_last_rate_limit_error: float = 0.0  # used to cool off after hitting rate limits
+
+
+@dataclass
+class RetryableRequest(Generic[RESPONSE]):
+    """Stores an API request's inputs, outputs, and other metadata.
+
+    Contains a method to make an API call.
+    """
+
+    request_id: int
+    request: OpenAIRequest
+    attempts_left: int
+    # per-instance error list; a bare class-level `errors = []` would be
+    # shared across every request
+    errors: List[OpenAIError] = field(default_factory=list)
+    response: Optional[RESPONSE] = None
+
+    async def execute(
+        self,
+        retry_queue: asyncio.Queue,
+        status_tracker: StatusTracker,
+    ):
+        """Calls the OpenAI API and saves results."""
+        logging.info(f"Starting request #{self.request_id}")
+        error = None
+        try:
+            self.response = await self.request.execute()
+        except RateLimitError as e:
+            status_tracker.time_of_last_rate_limit_error = time.time()
+            status_tracker.num_rate_limit_errors += 1
+            error = e  # mark rate-limited requests as failed so they are retried
+        except OpenAIError as e:
+            logging.warning(f"Request {self.request_id} failed with Exception {e}")
+            error = e
+
+        if error:
+            self.errors.append(error)
+            if self.attempts_left:
+                retry_queue.put_nowait(self)
+            else:
+                logging.error(
+                    f"Request {self.request_id} failed after all attempts:"
+                    f" {self.errors}"
+                )
+                status_tracker.num_tasks_in_progress -= 1
+                raise error
+        else:
+            logging.debug(f"Request {self.request_id} succeeded")
+            status_tracker.num_tasks_in_progress -= 1
+            status_tracker.num_tasks_succeeded += 1
+
+
+def request_id_generator_function():
+    """Generate integers 0, 1, 2, and so on."""
+    request_id = 0
+    while True:
+        yield request_id
+        request_id += 1
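A sketch of how process_api_requests is expected to be driven (illustrative only; the real caller is async_get_embeddings in embeddings.py below):

# Illustrative driver -- assumes `requests` was built elsewhere; not part of this patch.
import asyncio

async def run(requests):
    # Responses come back in submission order, one per request.
    return await process_api_requests(
        request_list=requests,
        max_requests_per_minute=1_500,  # ~50% of a typical limit, for headroom
        max_tokens_per_minute=125_000,
        max_attempts=5,
    )

# responses = asyncio.run(run(my_requests))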
diff --git a/nlp_primitives/openai/embeddings.py b/nlp_primitives/openai/embeddings.py
new file mode 100644
index 00000000..ce184a70
--- /dev/null
+++ b/nlp_primitives/openai/embeddings.py
@@ -0,0 +1,190 @@
+import itertools
+from typing import List, Optional
+
+import numpy as np
+import pandas as pd
+import tiktoken
+from featuretools.primitives.base import TransformPrimitive
+from woodwork.column_schema import ColumnSchema
+from woodwork.logical_types import Double, NaturalLanguage
+
+from nlp_primitives.openai.api_requester import process_api_requests
+from nlp_primitives.openai.model import (
+    OpenAIEmbeddingModel,
+)
+from nlp_primitives.openai.request import (
+    OpenAIEmbeddingRequest,
+    OpenAIRequest,
+    StaticOpenAIEmbeddingRequest,
+)
+from nlp_primitives.openai.response import OpenAIEmbeddingResponse
+from nlp_primitives.utils import CurrentEventLoop
+
+DEFAULT_MODEL = OpenAIEmbeddingModel(
+    name="text-embedding-ada-002",
+    encoding="cl100k_base",
+    max_tokens=8191,
+    output_dimensions=1536,
+)
+
+
+class OpenAIEmbeddings(TransformPrimitive):
+    """Generates embeddings using OpenAI.
+
+    Description:
+        Given a list of strings, determine the embeddings for each string, using
+        the OpenAI model.
+
+    Args:
+        model (OpenAIEmbeddingModel, optional): The model to use to produce embeddings.
+            Defaults to "text-embedding-ada-002" if not specified.
+        max_tokens_per_batch (int, optional): The maximum number of tokens to send in a batched request to OpenAI.
+            Defaults to 10 * model.max_tokens if not specified.
+
+    Examples:
+        >>> x = ['This is a test file', 'This is second line', None]
+        >>> openai_embeddings = OpenAIEmbeddings()
+        >>> openai_embeddings(x).tolist()
+        [
+            [
+                -0.007940744049847126,
+                0.007481361739337444,
+                ...
+                0.009351702407002449,
+                -0.016065239906311035
+            ],
+            [
+                -0.001055666827596724,
+                0.01066350657492876,
+                ...
+                -0.024650879204273224,
+                0.009666346944868565
+            ],
+            [
+                nan,
+                nan,
+                ...
+                nan,
+                nan
+            ],
+        ]
+    """
+
+    name = "openai_embeddings"
+    input_types = [ColumnSchema(logical_type=NaturalLanguage)]
+    return_type = ColumnSchema(logical_type=Double, semantic_tags={"numeric"})
+
+    def __init__(
+        self,
+        model: OpenAIEmbeddingModel = DEFAULT_MODEL,
+        max_tokens_per_batch: Optional[int] = None,
+    ):
+        self.model = model
+        self.number_output_features = model.output_dimensions
+        if max_tokens_per_batch is None:
+            self.max_tokens_per_batch = model.max_tokens * 10
+        else:
+            self.max_tokens_per_batch = max_tokens_per_batch
+
+    def _is_too_many_tokens(self, element, encoding) -> bool:
+        """Return whether a data element has too many tokens and should be ignored"""
+        return len(encoding.encode(element)) > self.model.max_tokens
+
+    def _create_request_batches(
+        self, series
+    ) -> List[OpenAIRequest[OpenAIEmbeddingResponse]]:
+        """Group elements of a series into batches of requests to send to OpenAI"""
+        # the encoding used by the model
+        encoding = tiktoken.get_encoding(self.model.encoding)
+
+        # a static embeddings response for an invalid element
+        invalid_request = StaticOpenAIEmbeddingRequest(
+            result=[np.nan] * self.number_output_features
+        )
+
+        # mutable variables to track the request batching process
+        # a running list of the requests to make
+        requests: List[OpenAIRequest[OpenAIEmbeddingResponse]] = []
+        # a list of elements that should be batched into the next request
+        elements_in_batch = []
+        # a running total of tokens that will be sent in the next request
+        tokens_in_batch = 0
+
+        def add_batched_request() -> int:
+            """Create a batched request from the currently staged elements and add it to the request list
+            """
+            if elements_in_batch:
+                elements_copy = list(elements_in_batch)
+                elements_in_batch.clear()
+                requests.append(
+                    OpenAIEmbeddingRequest(
+                        list_of_text=elements_copy,
+                        model=self.model,
+                        token_consumption=tokens_in_batch,
+                    )
+                )
+                return 0
+            else:
+                return tokens_in_batch
+
+        def can_fit_in_batch(tokens) -> bool:
+            return (
+                len(elements_in_batch) < 2048
+                and tokens_in_batch + tokens <= self.max_tokens_per_batch
+            )
+
+        # loop through the input data series to create request batches
+        for element in series:
+            if pd.isnull(element) or self._is_too_many_tokens(element, encoding):
+                # invalid element
+                # create a request from any pending elements
+                tokens_in_batch = add_batched_request()
+                # add a static request that returns the invalid results
+                requests.append(invalid_request)
+            else:
+                # valid element
+                # check how many tokens are in it
+                next_tokens = len(encoding.encode(element))
+
+                # can this element fit in the batch?
+                if can_fit_in_batch(next_tokens):
+                    # can't fit -- construct a request with existing elements
+                    tokens_in_batch = add_batched_request()
+
+                # add to next batch
+                elements_in_batch.append(element)
+                tokens_in_batch += next_tokens
+
+        # collect any remaining elements into one last request
+        add_batched_request()
+        return requests
+
+    async def async_get_embeddings(self, series):
+        """Get the embeddings for an input data series"""
+        # batch the requests
+        requests = self._create_request_batches(series)
+
+        # process the batched requests
+        responses = await process_api_requests(requests)
+
+        # get the embeddings from the responses
+        embeddings = [response.embeddings for response in responses]
+
+        # flatten them
+        result = list(itertools.chain.from_iterable(embeddings))
+
+        # convert to series
+        result = np.array(result).T.tolist()
+        return pd.Series(result)
+
+    def get_function(self):
+        def get_embeddings(series):
+            current_loop = CurrentEventLoop()
+            try:
+                return current_loop.loop.run_until_complete(
+                    self.async_get_embeddings(series)
+                )
+            finally:
+                current_loop.close()
+
+        return get_embeddings
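To make the batching rules concrete: with the default model, max_tokens_per_batch is 10 * 8191 = 81,910, elements are packed into one request until 2,048 elements are staged (the per-request input cap the code enforces) or the token total would be exceeded, and a null or over-long element flushes the pending batch and contributes a static all-NaN response instead. A worked instance of the fit check (illustrative only; note the branch above is inverted and is corrected in patch 5 below):

# Worked example of the batch-fit rule -- illustrative only.
max_tokens_per_batch = 8191 * 10   # 81,910 with the default model
tokens_in_batch = 81_900           # tokens already staged
next_tokens = 15                   # tokens in the incoming element
fits = tokens_in_batch + next_tokens <= max_tokens_per_batch  # False -> flush first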
diff --git a/nlp_primitives/openai/model.py b/nlp_primitives/openai/model.py
new file mode 100644
index 00000000..71c26949
--- /dev/null
+++ b/nlp_primitives/openai/model.py
@@ -0,0 +1,17 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class OpenAIModel(object):
+    """A model accessible via the OpenAI API."""
+
+    name: str
+    encoding: str
+    max_tokens: int
+
+
+@dataclass
+class OpenAIEmbeddingModel(OpenAIModel):
+    """A model accessible via the OpenAI API that can produce embeddings."""
+
+    output_dimensions: int
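For reference, the default model constant in embeddings.py is built from these dataclasses (values copied from DEFAULT_MODEL above):

# Matches DEFAULT_MODEL in embeddings.py.
ada_002 = OpenAIEmbeddingModel(
    name="text-embedding-ada-002",
    encoding="cl100k_base",
    max_tokens=8191,
    output_dimensions=1536,
)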
diff --git a/nlp_primitives/openai/request.py b/nlp_primitives/openai/request.py
new file mode 100644
index 00000000..bd44fa51
--- /dev/null
+++ b/nlp_primitives/openai/request.py
@@ -0,0 +1,54 @@
+from dataclasses import dataclass
+from typing import Generic, List, TypeVar
+
+import openai
+from nlp_primitives.openai.model import OpenAIEmbeddingModel
+from nlp_primitives.openai.response import OpenAIEmbeddingResponse, OpenAIResponse
+
+RESPONSE = TypeVar("RESPONSE", bound=OpenAIResponse)
+
+
+class OpenAIRequest(Generic[RESPONSE]):
+    """A request to the OpenAI API."""
+
+    token_consumption: int
+
+    async def execute(self) -> RESPONSE:
+        raise NotImplementedError("Subclass must implement")
+
+
+class OpenAIEmbeddingRequest(OpenAIRequest[OpenAIEmbeddingResponse]):
+    """A request to the OpenAI Embeddings API."""
+
+    def __init__(
+        self,
+        list_of_text: List[str],
+        model: OpenAIEmbeddingModel,
+        token_consumption: int,
+    ):
+        self.list_of_text = [text.replace("\n", " ") for text in list_of_text]
+        self.model = model
+        self.token_consumption = token_consumption
+
+    async def execute(self) -> OpenAIEmbeddingResponse:
+        data = (
+            await openai.Embedding.acreate(
+                input=self.list_of_text, engine=self.model.name
+            )
+        ).data
+        data = sorted(
+            data, key=lambda x: x["index"]
+        )  # maintain the same order as input
+        return OpenAIEmbeddingResponse(embeddings=[d["embedding"] for d in data])
+
+
+@dataclass
+class StaticOpenAIEmbeddingRequest(OpenAIRequest[OpenAIEmbeddingResponse]):
+    """A request to the OpenAI Embeddings API that immediately returns a static value."""
+
+    result: List[List[float]]
+    token_consumption = 0
+
+    async def execute(self) -> OpenAIEmbeddingResponse:
+        return OpenAIEmbeddingResponse(embeddings=self.result)
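A standalone sketch of executing one embedding request outside the requester (hypothetical input; requires OPENAI_API_KEY to be set):

# Hypothetical sketch -- not part of this patch.
import asyncio

from nlp_primitives.openai.embeddings import DEFAULT_MODEL

request = OpenAIEmbeddingRequest(
    list_of_text=["first document", "second document"],
    model=DEFAULT_MODEL,
    token_consumption=4,  # the caller counts tokens with tiktoken before batching
)
response = asyncio.run(request.execute())
# response.embeddings[0] is the 1536-float vector for "first document"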
diff --git a/nlp_primitives/openai/response.py b/nlp_primitives/openai/response.py
new file mode 100644
index 00000000..02113393
--- /dev/null
+++ b/nlp_primitives/openai/response.py
@@ -0,0 +1,13 @@
+from dataclasses import dataclass
+from typing import List
+
+
+class OpenAIResponse(object):
+    """A response from the OpenAI API."""
+
+
+@dataclass
+class OpenAIEmbeddingResponse(OpenAIResponse):
+    """A response from the OpenAI Embeddings API."""
+
+    embeddings: List[List[float]]
diff --git a/nlp_primitives/utils/__init__.py b/nlp_primitives/utils/__init__.py
new file mode 100644
index 00000000..44e75274
--- /dev/null
+++ b/nlp_primitives/utils/__init__.py
@@ -0,0 +1 @@
+from nlp_primitives.utils.event_loop import CurrentEventLoop
diff --git a/nlp_primitives/utils/event_loop.py b/nlp_primitives/utils/event_loop.py
new file mode 100644
index 00000000..4d352270
--- /dev/null
+++ b/nlp_primitives/utils/event_loop.py
@@ -0,0 +1,22 @@
+import asyncio
+
+
+class CurrentEventLoop(object):
+    """Represents the current event loop for a thread.
+
+    Description:
+        Gets the running event loop for a thread, or creates a new one.
+        If a new event loop is created, the close method will close it;
+        otherwise close is a no-op.
+    """
+
+    def __init__(self):
+        try:
+            self.loop = asyncio.get_running_loop()
+            self.should_close = False
+        except RuntimeError:
+            self.loop = asyncio.new_event_loop()
+            self.should_close = True
+
+    def close(self):
+        if self.should_close:
+            self.loop.close()
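The intended CurrentEventLoop pattern mirrors get_function in embeddings.py (sketch; some_coroutine is a hypothetical placeholder):

# Illustrative only -- mirrors the pattern used in embeddings.py.
loop_holder = CurrentEventLoop()
try:
    result = loop_holder.loop.run_until_complete(some_coroutine())
finally:
    loop_holder.close()  # only closes loops this wrapper itself created

Note that if the thread's loop is already running (e.g., inside a Jupyter cell), run_until_complete will still raise; patch 2 below replaces this approach with nest_asyncio for exactly that reason.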
diff --git a/pyproject.toml b/pyproject.toml
index 7177fa5d..9cfd4d2f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -64,13 +64,23 @@
 dev = [
     "pre-commit == 2.20.0",
     "nlp_primitives[test]"
 ]
-complete = [
+
+tensorflow = [
     "tensorflow >= 1.14.0; sys_platform!='darwin' or platform_machine!='arm64'",
     "tensorflow-metal >= 0.4.0; sys_platform=='darwin' and platform_machine=='arm64'",
     "tensorflow-macos >= 2.8.0; sys_platform=='darwin' and platform_machine=='arm64'",
     "tensorflow_hub >= 0.4.0",
 ]
 
+openai = [
+    "openai[embeddings] >= 0.26.5",
+    "tiktoken >= 0.3.0",
+]
+
+complete = [
+    "nlp_primitives[tensorflow,openai]",
+]
+
 [project.entry-points."featuretools_plugin"]
 nlp_primitives = "nlp_primitives"
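With these extras defined, enabling the feature end to end is a two-step setup (hypothetical commands; assumes a release of this package carrying the new extra):

# Setup sketch -- hypothetical, not part of these patches.
#   pip install "nlp_primitives[openai]"   # pulls in openai[embeddings] and tiktoken
#   export OPENAI_API_KEY="sk-..."         # the guard in __init__.py requires this
# The primitive stays hidden unless both conditions hold.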
From 04236495a7b317d756103c44a0afad72bbec4b59 Mon Sep 17 00:00:00 2001
From: Jonathan Lounsbury
Date: Tue, 7 Mar 2023 23:46:16 -0500
Subject: [PATCH 2/5] attempt to fix issue with event loop nesting with
 nest_asyncio

---
 nlp_primitives/openai/embeddings.py | 12 ++++--------
 pyproject.toml                      |  3 ++-
 2 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/nlp_primitives/openai/embeddings.py b/nlp_primitives/openai/embeddings.py
index ce184a70..c88432d9 100644
--- a/nlp_primitives/openai/embeddings.py
+++ b/nlp_primitives/openai/embeddings.py
@@ -1,6 +1,8 @@
+import asyncio
 import itertools
 from typing import List, Optional
 
+import nest_asyncio
 import numpy as np
 import pandas as pd
 import tiktoken
@@ -18,7 +20,6 @@
     StaticOpenAIEmbeddingRequest,
 )
 from nlp_primitives.openai.response import OpenAIEmbeddingResponse
-from nlp_primitives.utils import CurrentEventLoop
 
 DEFAULT_MODEL = OpenAIEmbeddingModel(
     name="text-embedding-ada-002",
@@ -179,12 +180,7 @@
 
     def get_function(self):
         def get_embeddings(series):
-            current_loop = CurrentEventLoop()
-            try:
-                return current_loop.loop.run_until_complete(
-                    self.async_get_embeddings(series)
-                )
-            finally:
-                current_loop.close()
+            nest_asyncio.apply()
+            return asyncio.run(self.async_get_embeddings(series))
 
         return get_embeddings
diff --git a/pyproject.toml b/pyproject.toml
index 9cfd4d2f..717133e3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -73,8 +73,9 @@ tensorflow = [
 ]
 
 openai = [
-    "openai[embeddings] >= 0.26.5",
+    "openai[embeddings] >= 0.27.0",
     "tiktoken >= 0.3.0",
+    "nest_asyncio >= 1.5.6"
 ]
 
 complete = [
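The failure mode this patch works around, as a minimal illustration (assuming code that may already be inside a running loop, such as a Jupyter cell):

# Illustration -- not part of this patch.
import asyncio

import nest_asyncio

async def compute():
    return 42

def compute_sync():
    # Without nest_asyncio, asyncio.run raises "asyncio.run() cannot be called
    # from a running event loop" when invoked from code already inside a loop.
    nest_asyncio.apply()  # patches asyncio to allow re-entrant loops
    return asyncio.run(compute())

print(compute_sync())  # 42, whether or not a loop is already running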
""" if elements_in_batch: requests.append( @@ -147,7 +148,7 @@ def can_fit_in_batch(tokens) -> bool: next_tokens = len(encoding.encode(element)) # can this element fit in the batch? - if can_fit_in_batch(next_tokens): + if not can_fit_in_batch(next_tokens): # can't fit -- construct a request with existing elements tokens_in_batch = add_batched_request()