Skip to content

Commit

Permalink
Add CLIPFeatureExtractor (and tests) (#387)
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova authored Nov 15, 2023
1 parent c980730 commit 35d61f5
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/processors.js
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@ export class ImageFeatureExtractor extends FeatureExtractor {

}

// Model-specific image feature extractors. Each is an empty subclass that
// inherits all preprocessing behavior from ImageFeatureExtractor; the distinct
// class names exist so they can be looked up by name (they are listed in
// AutoProcessor's feature-extractor registry below). Per-checkpoint differences
// (resize/crop/normalization) presumably come from each model's preprocessor
// config rather than from code — confirm against ImageFeatureExtractor.
export class CLIPFeatureExtractor extends ImageFeatureExtractor { }
export class ConvNextFeatureExtractor extends ImageFeatureExtractor { }
export class ViTFeatureExtractor extends ImageFeatureExtractor { }
export class MobileViTFeatureExtractor extends ImageFeatureExtractor { }
Expand Down Expand Up @@ -1538,6 +1539,7 @@ export class AutoProcessor {
WhisperFeatureExtractor,
ViTFeatureExtractor,
MobileViTFeatureExtractor,
CLIPFeatureExtractor,
ConvNextFeatureExtractor,
BeitFeatureExtractor,
DeiTFeatureExtractor,
Expand Down
26 changes: 22 additions & 4 deletions tests/processors.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ describe('Processors', () => {
beit: 'microsoft/beit-base-patch16-224-pt22k-ft22k',
detr: 'facebook/detr-resnet-50',
yolos: 'hustvl/yolos-small-300',
clip: 'openai/clip-vit-base-patch16',
}

const TEST_IMAGES = {
Expand Down Expand Up @@ -171,7 +172,7 @@ describe('Processors', () => {
it(MODELS.deit, async () => {
const processor = await AutoProcessor.from_pretrained(m(MODELS.deit))

{ // Tests grayscale image
{
const image = await load_image(TEST_IMAGES.tiger);
const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image);

Expand All @@ -187,7 +188,7 @@ describe('Processors', () => {
it(MODELS.beit, async () => {
const processor = await AutoProcessor.from_pretrained(m(MODELS.beit))

{ // Tests grayscale image
{
const image = await load_image(TEST_IMAGES.tiger);
const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image);

Expand All @@ -204,7 +205,7 @@ describe('Processors', () => {
it(MODELS.detr, async () => {
const processor = await AutoProcessor.from_pretrained(m(MODELS.detr))

{ // Tests grayscale image
{
const image = await load_image(TEST_IMAGES.tiger);
const { pixel_values, original_sizes, reshaped_input_sizes, pixel_mask } = await processor(image);

Expand All @@ -225,7 +226,7 @@ describe('Processors', () => {
it(MODELS.yolos, async () => {
const processor = await AutoProcessor.from_pretrained(m(MODELS.yolos))

{ // Tests grayscale image
{
const image = await load_image(TEST_IMAGES.tiger);
const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image);

Expand All @@ -236,5 +237,22 @@ describe('Processors', () => {
compare(reshaped_input_sizes, [[888, 1333]]);
}
}, MAX_TEST_EXECUTION_TIME);

// CLIPFeatureExtractor
// Exercises the center-crop path (do_center_crop=true, crop_size=224):
// the 408x612 input must come out as a single 3x224x224 tensor.
it(MODELS.clip, async () => {
    const clipProcessor = await AutoProcessor.from_pretrained(m(MODELS.clip))

    const tigerImage = await load_image(TEST_IMAGES.tiger);
    const {
        pixel_values: pixelValues,
        original_sizes: originalSizes,
        reshaped_input_sizes: reshapedInputSizes,
    } = await clipProcessor(tigerImage);

    // Shape and mean pinned against the reference implementation.
    compare(pixelValues.dims, [1, 3, 224, 224]);
    compare(avg(pixelValues.data), -0.06678297738282096);

    compare(originalSizes, [[408, 612]]);
    compare(reshapedInputSizes, [[224, 224]]);
}, MAX_TEST_EXECUTION_TIME);
});
});

0 comments on commit 35d61f5

Please sign in to comment.