diff --git a/docs/source/pipelines.md b/docs/source/pipelines.md
index 3e1ad6b15..1e2e1d439 100644
--- a/docs/source/pipelines.md
+++ b/docs/source/pipelines.md
@@ -148,6 +148,70 @@
 Cheddar is my go-to for any occasion or mood; It adds depth and richness without being overpowering its taste buds alone
 ```
+### Streaming
+
+Some pipelines, such as `text-generation` or `automatic-speech-recognition`, support streaming output. This is achieved using the `TextStreamer` class. For example, when using a chat model like `Qwen2.5-Coder-0.5B-Instruct`, you can specify a callback function that is called with each piece of generated text (if unset, new tokens are printed to the console).
+
+```js
+import { pipeline, TextStreamer } from "@huggingface/transformers";
+
+// Create a text generation pipeline
+const generator = await pipeline(
+  "text-generation",
+  "onnx-community/Qwen2.5-Coder-0.5B-Instruct",
+  { dtype: "q4" },
+);
+
+// Define the list of messages
+const messages = [
+  { role: "system", content: "You are a helpful assistant." },
+  { role: "user", content: "Write a quick sort algorithm." },
+];
+
+// Create a text streamer
+const streamer = new TextStreamer(generator.tokenizer, {
+  skip_prompt: true,
+  // Optionally, do something with the text (e.g., write to a textbox)
+  // callback_function: (text) => { /* Do something with text */ },
+});
+
+// Generate a response
+const result = await generator(messages, { max_new_tokens: 512, do_sample: false, streamer });
+```
+
+Logging `result[0].generated_text` to the console gives:
+
+<details>
+<summary>Click to view the console output</summary>
+<pre>
+Here's a simple implementation of the quick sort algorithm in Python:
+```python
+def quick_sort(arr):
+    if len(arr) <= 1:
+        return arr
+    pivot = arr[len(arr) // 2]
+    left = [x for x in arr if x < pivot]
+    middle = [x for x in arr if x == pivot]
+    right = [x for x in arr if x > pivot]
+    return quick_sort(left) + middle + quick_sort(right)
+# Example usage:
+arr = [3, 6, 8, 10, 1, 2]
+sorted_arr = quick_sort(arr)
+print(sorted_arr)
+```
+### Explanation:
+- **Base Case**: If the array has less than or equal to one element (i.e., `len(arr)` is less than or equal to `1`), it is already sorted and can be returned as is.
+- **Pivot Selection**: The pivot is chosen as the middle element of the array.
+- **Partitioning**: The array is partitioned into three parts: elements less than the pivot (`left`), elements equal to the pivot (`middle`), and elements greater than the pivot (`right`). These partitions are then recursively sorted.
+- **Recursive Sorting**: The subarrays are sorted recursively using `quick_sort`.
+This approach ensures that each recursive call reduces the problem size by half until it reaches a base case.
+</pre>
+</details>
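+
+To consume the streamed text yourself (for example, to update a textbox or forward tokens over a network connection), pass a `callback_function` when constructing the streamer. A minimal sketch, reusing the `generator` and `messages` from above; the `customStreamer` name and `output` variable are purely illustrative:
+
+```js
+let output = "";
+const customStreamer = new TextStreamer(generator.tokenizer, {
+  skip_prompt: true,
+  // Called with each piece of decoded text as soon as it is ready
+  callback_function: (text) => {
+    output += text; // e.g., also render `output` in a UI element here
+  },
+});
+
+await generator(messages, { max_new_tokens: 512, do_sample: false, streamer: customStreamer });
+```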
+
+This streaming feature allows you to process the output as it is generated, rather than waiting for generation to finish before acting on it.
+
 
 For more information on the available options for each pipeline, refer to the [API Reference](./api/pipelines).
 If you would like more control over the inference process, you can use the [`AutoModel`](./api/models), [`AutoTokenizer`](./api/tokenizers), or [`AutoProcessor`](./api/processors) classes instead.
diff --git a/src/generation/configuration_utils.js b/src/generation/configuration_utils.js
index 33a6fbe81..8474057da 100644
--- a/src/generation/configuration_utils.js
+++ b/src/generation/configuration_utils.js
@@ -259,6 +259,13 @@ export class GenerationConfig {
      */
     suppress_tokens = null;
 
+    /**
+     * A streamer that will be used to stream the generated sequences as they are produced.
+     * @type {import('./streamers.js').TextStreamer}
+     * @default null
+     */
+    streamer = null;
+
     /**
      * A list of tokens that will be suppressed at the beginning of the generation.
      * The `SuppressBeginTokens` logit processor will set their log probs to `-inf` so that they are not sampled.
diff --git a/src/generation/streamers.js b/src/generation/streamers.js
index 64afc71c7..effe2e01f 100644
--- a/src/generation/streamers.js
+++ b/src/generation/streamers.js
@@ -34,7 +34,12 @@ const stdout_write = apis.IS_PROCESS_AVAILABLE
 export class TextStreamer extends BaseStreamer {
     /**
      *
-     * @param {import('../tokenizers.js').PreTrainedTokenizer} tokenizer
+     * @param {import('../tokenizers.js').PreTrainedTokenizer} tokenizer The tokenizer used to decode token ids into text
+     * @param {Object} options
+     * @param {boolean} [options.skip_prompt=false] Whether to skip the prompt tokens
+     * @param {function(string): void} [options.callback_function=null] Function to call when a piece of text is ready to display
+     * @param {function(bigint[]): void} [options.token_callback_function=null] Function to call when a new token is generated
+     * @param {Object} [options.decode_kwargs={}] Additional keyword arguments to pass to the tokenizer's decode method
      */
     constructor(tokenizer, {
         skip_prompt = false,
@@ -143,7 +148,7 @@ export class WhisperTextStreamer extends TextStreamer {
      * @param {Object} options
      * @param {boolean} [options.skip_prompt=false] Whether to skip the prompt tokens
      * @param {function(string): void} [options.callback_function=null] Function to call when a piece of text is ready to display
-     * @param {function(string): void} [options.token_callback_function=null] Function to call when a new token is generated
+     * @param {function(bigint[]): void} [options.token_callback_function=null] Function to call when a new token is generated
      * @param {function(number): void} [options.on_chunk_start=null] Function to call when a new chunk starts
      * @param {function(number): void} [options.on_chunk_end=null] Function to call when a chunk ends
      * @param {function(): void} [options.on_finalize=null] Function to call when the stream is finalized
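
For context, here is how the `WhisperTextStreamer` options documented in the last hunk might be used with an `automatic-speech-recognition` pipeline. This is a sketch only, not part of the diff: the model choice and audio URL are illustrative, `process.stdout.write` assumes a Node.js environment, and the `number` passed to `on_chunk_start`/`on_chunk_end` is assumed to be a timestamp in seconds.

```js
import { pipeline, WhisperTextStreamer } from "@huggingface/transformers";

// Create an automatic speech recognition pipeline (illustrative model choice)
const transcriber = await pipeline(
  "automatic-speech-recognition",
  "onnx-community/whisper-tiny.en",
);

// Stream the transcription, with callbacks at chunk boundaries
const streamer = new WhisperTextStreamer(transcriber.tokenizer, {
  on_chunk_start: (t) => console.log(`\n[chunk started at ${t}s]`),
  callback_function: (text) => process.stdout.write(text),
  on_chunk_end: (t) => console.log(`\n[chunk ended at ${t}s]`),
  on_finalize: () => console.log("\n[stream finalized]"),
});

// Transcribe an audio file, processing it in 30-second chunks
const output = await transcriber(
  "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav",
  { chunk_length_s: 30, streamer },
);
```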