diff --git a/src/models.js b/src/models.js index 3b8aed83c..0bab83253 100644 --- a/src/models.js +++ b/src/models.js @@ -3035,9 +3035,9 @@ export class LlamaPreTrainedModel extends PreTrainedModel { // config doesn't contain pad_token_id, so we assume it is the eos_token_id this.config.pad_token_id = this.config.eos_token_id - this.num_heads = this.config.num_attention_heads + this.num_heads = this.config.num_key_value_heads ?? this.config.num_attention_heads this.num_layers = this.config.num_hidden_layers - this.dim_kv = this.config.hidden_size / this.num_heads; + this.dim_kv = this.config.hidden_size / this.config.num_attention_heads } } /** diff --git a/src/processors.js b/src/processors.js index 31d93195c..74b740ec0 100644 --- a/src/processors.js +++ b/src/processors.js @@ -521,6 +521,7 @@ export class ImageFeatureExtractor extends FeatureExtractor { export class DPTFeatureExtractor extends ImageFeatureExtractor { } export class GLPNFeatureExtractor extends ImageFeatureExtractor { } +export class CLIPFeatureExtractor extends ImageFeatureExtractor { } export class ConvNextFeatureExtractor extends ImageFeatureExtractor { } export class ViTFeatureExtractor extends ImageFeatureExtractor { } export class MobileViTFeatureExtractor extends ImageFeatureExtractor { } @@ -1550,6 +1551,7 @@ export class AutoProcessor { WhisperFeatureExtractor, ViTFeatureExtractor, MobileViTFeatureExtractor, + CLIPFeatureExtractor, ConvNextFeatureExtractor, DPTFeatureExtractor, GLPNFeatureExtractor, diff --git a/tests/processors.test.js b/tests/processors.test.js index d3ae92583..face9e268 100644 --- a/tests/processors.test.js +++ b/tests/processors.test.js @@ -40,6 +40,7 @@ describe('Processors', () => { yolos: 'hustvl/yolos-small-300', dpt: 'Intel/dpt-hybrid-midas', glpn: 'vinvino02/glpn-kitti', + clip: 'openai/clip-vit-base-patch16', } const TEST_IMAGES = { @@ -174,7 +175,7 @@ describe('Processors', () => { it(MODELS.deit, async () => { const processor = await AutoProcessor.from_pretrained(m(MODELS.deit)) - { // Tests grayscale image + { const image = await load_image(TEST_IMAGES.tiger); const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); @@ -190,7 +191,7 @@ describe('Processors', () => { it(MODELS.beit, async () => { const processor = await AutoProcessor.from_pretrained(m(MODELS.beit)) - { // Tests grayscale image + { const image = await load_image(TEST_IMAGES.tiger); const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); @@ -207,7 +208,7 @@ describe('Processors', () => { it(MODELS.detr, async () => { const processor = await AutoProcessor.from_pretrained(m(MODELS.detr)) - { // Tests grayscale image + { const image = await load_image(TEST_IMAGES.tiger); const { pixel_values, original_sizes, reshaped_input_sizes, pixel_mask } = await processor(image); @@ -228,7 +229,7 @@ describe('Processors', () => { it(MODELS.yolos, async () => { const processor = await AutoProcessor.from_pretrained(m(MODELS.yolos)) - { // Tests grayscale image + { const image = await load_image(TEST_IMAGES.tiger); const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); @@ -241,7 +242,7 @@ describe('Processors', () => { }, MAX_TEST_EXECUTION_TIME); - // DPTFeatureExtractor + // DPTFeatureExtractor it(MODELS.dpt, async () => { const processor = await AutoProcessor.from_pretrained(m(MODELS.dpt)) @@ -281,6 +282,21 @@ describe('Processors', () => { compare(original_sizes, [[408, 612]]); compare(reshaped_input_sizes, [[384, 608]]); + + // CLIPFeatureExtractor + // - tests center crop (do_center_crop=true, crop_size=224) + it(MODELS.clip, async () => { + const processor = await AutoProcessor.from_pretrained(m(MODELS.clip)) + + { + const image = await load_image(TEST_IMAGES.tiger); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + compare(pixel_values.dims, [1, 3, 224, 224]); + compare(avg(pixel_values.data), -0.06678297738282096); + + compare(original_sizes, [[408, 612]]); + compare(reshaped_input_sizes, [[224, 224]]); } }, MAX_TEST_EXECUTION_TIME); });