Skip to content

Commit

Permalink
Improve typing of pipeline helper function
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova committed Dec 29, 2023
1 parent 61459e3 commit b2d3dd9
Showing 1 changed file with 25 additions and 13 deletions.
38 changes: 25 additions & 13 deletions src/pipelines.js
Original file line number Diff line number Diff line change
Expand Up @@ -2335,7 +2335,7 @@ export class DepthEstimationPipeline extends Pipeline {
}
}

const SUPPORTED_TASKS = {
const SUPPORTED_TASKS = Object.freeze({
"text-classification": {
"tokenizer": AutoTokenizer,
"pipeline": TextClassificationPipeline,
Expand Down Expand Up @@ -2609,24 +2609,35 @@ const SUPPORTED_TASKS = {
},
"type": "text",
},
}
})


const TASK_ALIASES = {
// TODO: Add types for TASK_ALIASES
const TASK_ALIASES = Object.freeze({
"sentiment-analysis": "text-classification",
"ner": "token-classification",
"vqa": "visual-question-answering",
// "vqa": "visual-question-answering", // TODO: Add
"asr": "automatic-speech-recognition",
"text-to-speech": "text-to-audio",

// Add for backwards compatibility
"embeddings": "feature-extraction",
}
});

/**
* Utility factory method to build a [`Pipeline`] object.
*
* @param {string} task The task defining which pipeline will be returned. Currently accepted tasks are:
* @typedef {keyof typeof SUPPORTED_TASKS} TaskType
* @typedef {keyof typeof TASK_ALIASES} AliasType
* @typedef {TaskType | AliasType} PipelineType All possible pipeline types.
* @typedef {{[K in TaskType]: InstanceType<typeof SUPPORTED_TASKS[K]["pipeline"]>}} SupportedTasks A mapping of pipeline names to their corresponding pipeline classes.
* @typedef {{[K in AliasType]: InstanceType<typeof SUPPORTED_TASKS[TASK_ALIASES[K]]["pipeline"]>}} AliasTasks A mapping from pipeline aliases to their corresponding pipeline classes.
* @typedef {SupportedTasks & AliasTasks} AllTasks A mapping from all pipeline names and aliases to their corresponding pipeline classes.
*/

/**
* Utility factory method to build a `Pipeline` object.
*
* @template {PipelineType} T The type of pipeline to return.
* @param {T} task The task defining which pipeline will be returned. Currently accepted tasks are:
* - `"audio-classification"`: will return a `AudioClassificationPipeline`.
* - `"automatic-speech-recognition"`: will return a `AutomaticSpeechRecognitionPipeline`.
* - `"depth-estimation"`: will return a `DepthEstimationPipeline`.
Expand All @@ -2651,7 +2662,7 @@ const TASK_ALIASES = {
* - `"zero-shot-object-detection"`: will return a `ZeroShotObjectDetectionPipeline`.
* @param {string} [model=null] The name of the pre-trained model to use. If not specified, the default model for the task will be used.
* @param {import('./utils/hub.js').PretrainedOptions} [options] Optional parameters for the pipeline.
* @returns {Promise<Pipeline>} A Pipeline object for the specified task.
* @returns {Promise<AllTasks[T]>} A Pipeline object for the specified task.
* @throws {Error} If an unsupported pipeline is requested.
*/
export async function pipeline(
Expand All @@ -2669,10 +2680,11 @@ export async function pipeline(
// Helper method to construct pipeline

// Apply aliases
// @ts-ignore
task = TASK_ALIASES[task] ?? task;

// Get pipeline info
let pipelineInfo = SUPPORTED_TASKS[task.split('_', 1)[0]];
const pipelineInfo = SUPPORTED_TASKS[task.split('_', 1)[0]];
if (!pipelineInfo) {
throw Error(`Unsupported pipeline: ${task}. Must be one of [${Object.keys(SUPPORTED_TASKS)}]`)
}
Expand All @@ -2683,7 +2695,7 @@ export async function pipeline(
console.log(`No model specified. Using default model: "${model}".`);
}

let pretrainedOptions = {
const pretrainedOptions = {
quantized,
progress_callback,
config,
Expand All @@ -2699,7 +2711,7 @@ export async function pipeline(
]);

// Load model, tokenizer, and processor (if they exist)
let results = await loadItems(classes, model, pretrainedOptions);
const results = await loadItems(classes, model, pretrainedOptions);
results.task = task;

dispatchCallback(progress_callback, {
Expand All @@ -2708,7 +2720,7 @@ export async function pipeline(
'model': model,
});

let pipelineClass = pipelineInfo.pipeline;
const pipelineClass = pipelineInfo.pipeline;
return new pipelineClass(results);
}

Expand Down

0 comments on commit b2d3dd9

Please sign in to comment.