Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add binary embedding quantization support to FeatureExtraction pipeline #691

Merged
merged 6 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/pipelines.js
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ import {
Tensor,
mean_pooling,
interpolate,
quantize_embeddings,
} from './utils/tensor.js';
import { RawImage } from './utils/image.js';

Expand Down Expand Up @@ -1112,6 +1113,8 @@ export class ZeroShotClassificationPipeline extends (/** @type {new (options: Te
* @typedef {Object} FeatureExtractionPipelineOptions Parameters specific to feature extraction pipelines.
* @property {'none'|'mean'|'cls'} [pooling="none"] The pooling method to use.
* @property {boolean} [normalize=false] Whether or not to normalize the embeddings in the last dimension.
* @property {boolean} [quantize=false] Whether or not to quantize the embeddings.
* @property {'binary'|'ubinary'} [precision='binary'] The precision to use for quantization.
*
* @callback FeatureExtractionPipelineCallback Extract the features of the input(s).
* @param {string|string[]} texts One or several texts (or one list of texts) to get the features of.
Expand Down Expand Up @@ -1157,6 +1160,16 @@ export class ZeroShotClassificationPipeline extends (/** @type {new (options: Te
* // dims: [1, 384]
* // }
* ```
* **Example:** Calculating binary embeddings with `sentence-transformers` models.
* ```javascript
* const extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2');
* const output = await extractor('This is a simple test.', { pooling: 'mean', normalize: true, quantize: true, precision: 'binary' });
xenova marked this conversation as resolved.
Show resolved Hide resolved
* // Tensor {
* // type: 'int8',
* // data: Int8Array [-13, -78, 21, ...],
xenova marked this conversation as resolved.
Show resolved Hide resolved
* // dims: [1, 48]
* // }
* ```
*/
export class FeatureExtractionPipeline extends (/** @type {new (options: TextPipelineConstructorArgs) => FeatureExtractionPipelineType} */ (Pipeline)) {
/**
Expand All @@ -1171,6 +1184,8 @@ export class FeatureExtractionPipeline extends (/** @type {new (options: TextPip
async _call(texts, {
pooling = /** @type {'none'} */('none'),
normalize = false,
quantize = false,
precision = /** @type {'binary'|'ubinary'} */('binary'),
xenova marked this conversation as resolved.
Show resolved Hide resolved
} = {}) {

// Run tokenization
Expand Down Expand Up @@ -1203,6 +1218,10 @@ export class FeatureExtractionPipeline extends (/** @type {new (options: TextPip
result = result.normalize(2, -1);
}

if (quantize) {
result = quantize_embeddings(result, precision);
}

return result;
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/tokenizers.js
Original file line number Diff line number Diff line change
Expand Up @@ -2653,8 +2653,8 @@ export class PreTrainedTokenizer extends Callable {
}

} else {
if (text === null) {
throw Error('text may not be null')
if (text === null || text === undefined) {
throw Error('text may not be null or undefined')
}

if (Array.isArray(text_pair)) {
Expand Down
44 changes: 44 additions & 0 deletions src/utils/tensor.js
Original file line number Diff line number Diff line change
Expand Up @@ -1193,3 +1193,47 @@ export function ones(size) {
/**
 * Returns a tensor of ones with the same shape as the input tensor.
 * Delegates to `ones`, passing along the input's dimensions.
 * @param {Tensor} tensor The reference tensor whose shape is copied.
 * @returns {Tensor} A new tensor with dims equal to `tensor.dims`, filled by `ones`.
 */
export function ones_like(tensor) {
    return ones(tensor.dims);
}

/**
* Quantizes the embeddings tensor to binary or unsigned binary precision.
* @param {Tensor} tensor The tensor to quantize.
* @param {'binary'|'ubinary'} precision The precision to use for quantization.
* @returns {Tensor} The quantized tensor.
*/
/**
 * Quantizes an embeddings tensor to binary (`int8`) or unsigned binary (`uint8`) precision.
 *
 * Each embedding value is thresholded at zero (positive → 1, otherwise 0) and the
 * resulting bits are packed 8-per-byte, most-significant bit first. For `'binary'`
 * precision, each packed byte is additionally shifted from the [0, 255] range down
 * to [-128, 127] (i.e. `byte - 128`), matching sentence-transformers' convention.
 *
 * @param {Tensor} tensor The 2D tensor to quantize; its last dimension must be a multiple of 8.
 * @param {'binary'|'ubinary'} precision The precision to use for quantization.
 * @returns {Tensor} The quantized tensor with dims `[rows, cols / 8]`.
 * @throws {Error} If the tensor is not 2D, its last dimension is not a multiple of 8,
 * or the precision is unrecognized.
 */
export function quantize_embeddings(tensor, precision) {
    if (tensor.dims.length !== 2) {
        throw new Error("The tensor must have 2 dimensions");
    }
    if (tensor.dims.at(-1) % 8 !== 0) {
        throw new Error("The last dimension of the tensor must be a multiple of 8");
    }
    if (!['binary', 'ubinary'].includes(precision)) {
        throw new Error("The precision must be either 'binary' or 'ubinary'");
    }

    const signed = precision === 'binary';
    const dtype = signed ? 'int8' : 'uint8';

    const inputData = tensor.data;
    const numBytes = inputData.length / 8;
    const outputData = signed ? new Int8Array(numBytes) : new Uint8Array(numBytes);

    // Build one output byte (= 8 thresholded input values) per iteration.
    for (let byteIndex = 0; byteIndex < numBytes; ++byteIndex) {
        const offset = byteIndex * 8;
        let byte = 0;
        for (let bit = 0; bit < 8; ++bit) {
            // Shift previous bits up and append the next bit (MSB-first packing).
            byte <<= 1;
            if (inputData[offset + bit] > 0) {
                byte |= 1;
            }
        }
        // For signed output, map [0, 255] → [-128, 127]; the Int8Array
        // assignment stores the already-in-range value exactly.
        outputData[byteIndex] = signed ? byte - 128 : byte;
    }

    return new Tensor(dtype, outputData, [tensor.dims[0], tensor.dims[1] / 8]);
}
Loading