diff --git a/src/generation/streamers.js b/src/generation/streamers.js index 099b2d3d4..5dcbd84df 100644 --- a/src/generation/streamers.js +++ b/src/generation/streamers.js @@ -38,11 +38,13 @@ export class TextStreamer extends BaseStreamer { */ constructor(tokenizer, { skip_prompt = false, + callback_function = null, ...decode_kwargs } = {}) { super(); this.tokenizer = tokenizer; this.skip_prompt = skip_prompt; + this.callback_function = callback_function ?? stdout_write; this.decode_kwargs = decode_kwargs; // variables used in the streaming process @@ -115,10 +117,10 @@ export class TextStreamer extends BaseStreamer { */ on_finalized_text(text, stream_end) { if (text.length > 0) { - stdout_write(text); + this.callback_function(text); } if (stream_end) { - stdout_write('\n'); + this.callback_function('\n'); } } }