Skip to content

Commit

Permalink
Merge branch 'main' into add-depth-estimation
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova authored Nov 16, 2023
2 parents accbc9c + 4e4148c commit 96d24c9
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -3035,9 +3035,9 @@ export class LlamaPreTrainedModel extends PreTrainedModel {
// config doesn't contain pad_token_id, so we assume it is the eos_token_id
this.config.pad_token_id = this.config.eos_token_id

this.num_heads = this.config.num_attention_heads
this.num_heads = this.config.num_key_value_heads ?? this.config.num_attention_heads
this.num_layers = this.config.num_hidden_layers
this.dim_kv = this.config.hidden_size / this.num_heads;
this.dim_kv = this.config.hidden_size / this.config.num_attention_heads
}
}
/**
Expand Down
2 changes: 2 additions & 0 deletions src/processors.js
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,7 @@ export class ImageFeatureExtractor extends FeatureExtractor {

// Model-specific feature extractor classes. Each one is an empty subclass:
// it extends ImageFeatureExtractor and declares no members of its own, so
// everything it does comes from the parent class (defined outside this view).
export class DPTFeatureExtractor extends ImageFeatureExtractor { }
export class GLPNFeatureExtractor extends ImageFeatureExtractor { }
export class CLIPFeatureExtractor extends ImageFeatureExtractor { }
export class ConvNextFeatureExtractor extends ImageFeatureExtractor { }
export class ViTFeatureExtractor extends ImageFeatureExtractor { }
export class MobileViTFeatureExtractor extends ImageFeatureExtractor { }
Expand Down Expand Up @@ -1550,6 +1551,7 @@ export class AutoProcessor {
WhisperFeatureExtractor,
ViTFeatureExtractor,
MobileViTFeatureExtractor,
CLIPFeatureExtractor,
ConvNextFeatureExtractor,
DPTFeatureExtractor,
GLPNFeatureExtractor,
Expand Down
26 changes: 21 additions & 5 deletions tests/processors.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ describe('Processors', () => {
yolos: 'hustvl/yolos-small-300',
dpt: 'Intel/dpt-hybrid-midas',
glpn: 'vinvino02/glpn-kitti',
clip: 'openai/clip-vit-base-patch16',
}

const TEST_IMAGES = {
Expand Down Expand Up @@ -174,7 +175,7 @@ describe('Processors', () => {
it(MODELS.deit, async () => {
const processor = await AutoProcessor.from_pretrained(m(MODELS.deit))

{ // Tests grayscale image
{
const image = await load_image(TEST_IMAGES.tiger);
const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image);

Expand All @@ -190,7 +191,7 @@ describe('Processors', () => {
it(MODELS.beit, async () => {
const processor = await AutoProcessor.from_pretrained(m(MODELS.beit))

{ // Tests grayscale image
{
const image = await load_image(TEST_IMAGES.tiger);
const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image);

Expand All @@ -207,7 +208,7 @@ describe('Processors', () => {
it(MODELS.detr, async () => {
const processor = await AutoProcessor.from_pretrained(m(MODELS.detr))

{ // Tests grayscale image
{
const image = await load_image(TEST_IMAGES.tiger);
const { pixel_values, original_sizes, reshaped_input_sizes, pixel_mask } = await processor(image);

Expand All @@ -228,7 +229,7 @@ describe('Processors', () => {
it(MODELS.yolos, async () => {
const processor = await AutoProcessor.from_pretrained(m(MODELS.yolos))

{ // Tests grayscale image
{
const image = await load_image(TEST_IMAGES.tiger);
const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image);

Expand All @@ -241,7 +242,7 @@ describe('Processors', () => {
}, MAX_TEST_EXECUTION_TIME);


// DPTFeatureExtractor
// DPTFeatureExtractor
it(MODELS.dpt, async () => {
const processor = await AutoProcessor.from_pretrained(m(MODELS.dpt))

Expand Down Expand Up @@ -281,6 +282,21 @@ describe('Processors', () => {

compare(original_sizes, [[408, 612]]);
compare(reshaped_input_sizes, [[384, 608]]);

// CLIPFeatureExtractor
// - tests center crop (do_center_crop=true, crop_size=224)
// The test case is named after the model id stored in MODELS.clip.
it(MODELS.clip, async () => {
// NOTE(review): `m(...)` is defined outside this view; presumably it maps the
// model id to the location the tests load from — confirm.
const processor = await AutoProcessor.from_pretrained(m(MODELS.clip))

{
// Run the processor on the tiger test image and unpack the three outputs checked below.
const image = await load_image(TEST_IMAGES.tiger);
const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image);

// Output dims are expected to be [1, 3, 224, 224]; 224 matches the crop_size noted above.
compare(pixel_values.dims, [1, 3, 224, 224]);
// The result of `avg` over the pixel data is pinned to a reference value.
// NOTE(review): whether `compare` applies a float tolerance is not visible here — confirm.
compare(avg(pixel_values.data), -0.06678297738282096);

// Sizes reported by the processor: [408, 612] for the original image,
// [224, 224] for the reshaped input.
compare(original_sizes, [[408, 612]]);
compare(reshaped_input_sizes, [[224, 224]]);
}
// NOTE(review): third argument is presumably the per-test timeout — confirm against the test runner.
}, MAX_TEST_EXECUTION_TIME);
});
Expand Down

0 comments on commit 96d24c9

Please sign in to comment.