Add an example and type enhancement for TextStreamer (#1066)
* typing: GenerationConfig option for TextStreamer

* docs: streaming example following the existing style

* docs: streaming description from @xenova's suggestion

Co-authored-by: Joshua Lochner <[email protected]>

* fix: streaming example from @xenova's suggestion

Co-authored-by: Joshua Lochner <[email protected]>

* fix: <pre> tag by wrapping it in a <details> tag

* fix: remove newlines for proper rendering

---------

Co-authored-by: Joshua Lochner <[email protected]>
seonglae and xenova authored Dec 3, 2024
1 parent 31ce759 commit 0dc1d8b
Showing 3 changed files with 78 additions and 2 deletions.
64 changes: 64 additions & 0 deletions docs/source/pipelines.md
@@ -148,6 +148,70 @@ Cheddar is my go-to for any occasion or mood;
It adds depth and richness without being overpowering its taste buds alone
```

### Streaming

Some pipelines, such as `text-generation` and `automatic-speech-recognition`, support streaming output via the `TextStreamer` class. For example, when using a chat model like `Qwen2.5-Coder-0.5B-Instruct`, you can specify a callback function that is invoked with each piece of generated text as it is decoded (if unset, new tokens are printed to the console).

```js
import { pipeline, TextStreamer } from "@huggingface/transformers";

// Create a text generation pipeline
const generator = await pipeline(
  "text-generation",
  "onnx-community/Qwen2.5-Coder-0.5B-Instruct",
  { dtype: "q4" },
);

// Define the list of messages
const messages = [
  { role: "system", content: "You are a helpful assistant." },
  { role: "user", content: "Write a quick sort algorithm." },
];

// Create text streamer
const streamer = new TextStreamer(generator.tokenizer, {
  skip_prompt: true,
  // Optionally, do something with the text (e.g., write to a textbox)
  // callback_function: (text) => { /* Do something with text */ },
});

// Generate a response
const result = await generator(messages, { max_new_tokens: 512, do_sample: false, streamer });
```

Logging `result[0].generated_text` to the console gives:


<details>
<summary>Click to view the console output</summary>
<pre>
Here's a simple implementation of the quick sort algorithm in Python:
```python
def quick_sort(arr):
    if len(arr) <= 1:
        return arr
    pivot = arr[len(arr) // 2]
    left = [x for x in arr if x < pivot]
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]
    return quick_sort(left) + middle + quick_sort(right)
# Example usage:
arr = [3, 6, 8, 10, 1, 2]
sorted_arr = quick_sort(arr)
print(sorted_arr)
```
### Explanation:
- **Base Case**: If the array has less than or equal to one element (i.e., `len(arr)` is less than or equal to `1`), it is already sorted and can be returned as is.
- **Pivot Selection**: The pivot is chosen as the middle element of the array.
- **Partitioning**: The array is partitioned into three parts: elements less than the pivot (`left`), elements equal to the pivot (`middle`), and elements greater than the pivot (`right`). These partitions are then recursively sorted.
- **Recursive Sorting**: The subarrays are sorted recursively using `quick_sort`.
This approach ensures that each recursive call reduces the problem size by half until it reaches a base case.
</pre>
</details>

This streaming feature lets you process the output as it is generated, rather than waiting for the entire sequence to complete before doing anything with it.
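For instance, here is a minimal sketch of that pattern, reusing the same model as above; the prompt and the `streamed` variable are illustrative placeholders. The callback accumulates the streamed text in a string instead of printing it:

```js
import { pipeline, TextStreamer } from "@huggingface/transformers";

const generator = await pipeline(
  "text-generation",
  "onnx-community/Qwen2.5-Coder-0.5B-Instruct",
  { dtype: "q4" },
);

// Accumulate the streamed text instead of printing it to the console
let streamed = "";
const streamer = new TextStreamer(generator.tokenizer, {
  skip_prompt: true,
  callback_function: (text) => {
    streamed += text; // e.g., assign to a textbox value here instead
  },
});

await generator("Explain streaming in one sentence.", { max_new_tokens: 128, streamer });
console.log(streamed);
```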


For more information on the available options for each pipeline, refer to the [API Reference](./api/pipelines).
If you would like more control over the inference process, you can use the [`AutoModel`](./api/models), [`AutoTokenizer`](./api/tokenizers), or [`AutoProcessor`](./api/processors) classes instead.

7 changes: 7 additions & 0 deletions src/generation/configuration_utils.js
@@ -259,6 +259,13 @@ export class GenerationConfig {
*/
suppress_tokens = null;

/**
* A streamer that will be used to stream the generation.
* @type {import('./streamers.js').TextStreamer}
* @default null
*/
streamer = null;

/**
* A list of tokens that will be suppressed at the beginning of the generation.
* The `SuppressBeginTokens` logit processor will set their log probs to `-inf` so that they are not sampled.
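Since `streamer` is now part of the generation options, it can also be passed when calling `generate` on a model directly. The following is a minimal sketch of that usage, mirroring the pipeline example above; treat the exact shape of the `generate` call as an assumption based on the options object used there:

```js
import { AutoTokenizer, AutoModelForCausalLM, TextStreamer } from "@huggingface/transformers";

const model_id = "onnx-community/Qwen2.5-Coder-0.5B-Instruct";
const tokenizer = await AutoTokenizer.from_pretrained(model_id);
const model = await AutoModelForCausalLM.from_pretrained(model_id, { dtype: "q4" });

// Stream tokens from a direct generate() call
const streamer = new TextStreamer(tokenizer, { skip_prompt: true });
const inputs = tokenizer("Write a haiku about autumn.");

// `streamer` sits alongside the other generation options
const output = await model.generate({ ...inputs, max_new_tokens: 64, streamer });
```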
9 changes: 7 additions & 2 deletions src/generation/streamers.js
@@ -34,7 +34,12 @@ const stdout_write = apis.IS_PROCESS_AVAILABLE
export class TextStreamer extends BaseStreamer {
/**
*
* @param {import('../tokenizers.js').PreTrainedTokenizer} tokenizer
* @param {import('../tokenizers.js').PreTrainedTokenizer} tokenizer
* @param {Object} options
* @param {boolean} [options.skip_prompt=false] Whether to skip the prompt tokens
* @param {function(string): void} [options.callback_function=null] Function to call when a piece of text is ready to display
* @param {function(bigint[]): void} [options.token_callback_function=null] Function to call when a new token is generated
* @param {Object} [options.decode_kwargs={}] Additional keyword arguments to pass to the tokenizer's decode method
*/
constructor(tokenizer, {
skip_prompt = false,
@@ -143,7 +148,7 @@ export class WhisperTextStreamer extends TextStreamer {
* @param {Object} options
* @param {boolean} [options.skip_prompt=false] Whether to skip the prompt tokens
* @param {function(string): void} [options.callback_function=null] Function to call when a piece of text is ready to display
* @param {function(string): void} [options.token_callback_function=null] Function to call when a new token is generated
* @param {function(bigint[]): void} [options.token_callback_function=null] Function to call when a new token is generated
* @param {function(number): void} [options.on_chunk_start=null] Function to call when a new chunk starts
* @param {function(number): void} [options.on_chunk_end=null] Function to call when a chunk ends
* @param {function(): void} [options.on_finalize=null] Function to call when the stream is finalized
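As a companion to the JSDoc above, here is a short sketch wiring up both callbacks on a plain `TextStreamer` (the model and prompt are placeholders); note that `token_callback_function` receives raw token ids, while `callback_function` receives decoded text:

```js
import { pipeline, TextStreamer } from "@huggingface/transformers";

const generator = await pipeline(
  "text-generation",
  "onnx-community/Qwen2.5-Coder-0.5B-Instruct",
  { dtype: "q4" },
);

const streamer = new TextStreamer(generator.tokenizer, {
  skip_prompt: true,
  // Called with each decoded piece of text as soon as it is ready to display
  callback_function: (text) => console.log("text:", text),
  // Called with the newly generated token ids (bigint[], per the typing above)
  token_callback_function: (tokens) => console.log("tokens:", tokens),
});

await generator("Count to five.", { max_new_tokens: 32, streamer });
```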
