From f35a7c933319d99a1944bc858c0802d15fc0369d Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 29 Jun 2024 12:55:59 +0200 Subject: [PATCH] Add `CLIPTextModel` and `CLIPVisionModel` --- src/models.js | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/src/models.js b/src/models.js index a8112912e..7c0cb0991 100644 --- a/src/models.js +++ b/src/models.js @@ -3099,6 +3099,18 @@ export class CLIPPreTrainedModel extends PreTrainedModel { } */ export class CLIPModel extends CLIPPreTrainedModel { } +/** + * The text model from CLIP without any head or projection on top. + */ +export class CLIPTextModel extends CLIPPreTrainedModel { + /** @type {PreTrainedModel.from_pretrained} */ + static async from_pretrained(pretrained_model_name_or_path, options = {}) { + // Update default model file name if not provided + options.model_file_name ??= 'text_model'; + return super.from_pretrained(pretrained_model_name_or_path, options); + } +} + /** * CLIP Text Model with a projection layer on top (a linear layer on top of the pooled output) * @@ -3126,7 +3138,6 @@ export class CLIPModel extends CLIPPreTrainedModel { } * ``` */ export class CLIPTextModelWithProjection extends CLIPPreTrainedModel { - /** @type {PreTrainedModel.from_pretrained} */ static async from_pretrained(pretrained_model_name_or_path, options = {}) { // Update default model file name if not provided @@ -3135,6 +3146,18 @@ export class CLIPTextModelWithProjection extends CLIPPreTrainedModel { } } +/** + * The vision model from CLIP without any head or projection on top. + */ +export class CLIPVisionModel extends CLIPPreTrainedModel { + /** @type {PreTrainedModel.from_pretrained} */ + static async from_pretrained(pretrained_model_name_or_path, options = {}) { + // Update default model file name if not provided + options.model_file_name ??= 'vision_model'; + return super.from_pretrained(pretrained_model_name_or_path, options); + } +} + /** * CLIP Vision Model with a projection layer on top (a linear layer on top of the pooled output) *