[WIP] Add support for Glm
xenova committed Nov 29, 2024
1 parent 2c92943 commit 2d89de5
Showing 3 changed files with 73 additions and 2 deletions.
src/configs.js (1 addition & 0 deletions)
@@ -117,6 +117,7 @@ function getNormalizedConfig(config) {
             break;
         case 'gemma':
         case 'gemma2':
+        case 'glm':
             mapping['num_heads'] = 'num_key_value_heads';
             mapping['num_layers'] = 'num_hidden_layers';
             mapping['dim_kv'] = 'head_dim';
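Note: getNormalizedConfig maps each architecture's config field names onto a canonical set (num_heads, num_layers, dim_kv) that the runtime relies on, for example when shaping the per-layer past key/value cache. A minimal sketch of what the new 'glm' case resolves; the config values below are hypothetical, for illustration only:

// Sketch only — the GLM config values are hypothetical examples,
// not taken from a real checkpoint.
const glmConfig = {
    model_type: 'glm',
    num_key_value_heads: 2,   // hypothetical value
    num_hidden_layers: 28,    // hypothetical value
    head_dim: 128,            // hypothetical value
};

// With the 'glm' case added above, the normalized config resolves to:
const normalized = {
    num_heads: glmConfig.num_key_value_heads, // 2
    num_layers: glmConfig.num_hidden_layers,  // 28
    dim_kv: glmConfig.head_dim,               // 128
};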
src/models.js (19 additions & 0 deletions)
@@ -4037,6 +4037,23 @@ export class Gemma2Model extends Gemma2PreTrainedModel { }
 export class Gemma2ForCausalLM extends Gemma2PreTrainedModel { }
 //////////////////////////////////////////////////
 
+
+//////////////////////////////////////////////////
+// Glm models
+
+/**
+ * The bare Glm Model outputting raw hidden-states without any specific head on top.
+ */
+export class GlmPreTrainedModel extends PreTrainedModel { }
+/**
+ * The bare Glm Model outputting raw hidden-states without any specific head on top.
+ */
+export class GlmModel extends GlmPreTrainedModel { }
+
+export class GlmForCausalLM extends GlmPreTrainedModel { }
+//////////////////////////////////////////////////
+
+
 //////////////////////////////////////////////////
 export class OpenELMPreTrainedModel extends PreTrainedModel { }
 export class OpenELMModel extends OpenELMPreTrainedModel { }
@@ -6765,6 +6782,7 @@ const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
     ['cohere', ['CohereModel', CohereModel]],
     ['gemma', ['GemmaModel', GemmaModel]],
     ['gemma2', ['Gemma2Model', Gemma2Model]],
+    ['glm', ['GlmModel', GlmModel]],
     ['openelm', ['OpenELMModel', OpenELMModel]],
     ['qwen2', ['Qwen2Model', Qwen2Model]],
     ['phi', ['PhiModel', PhiModel]],
@@ -6856,6 +6874,7 @@ const MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = new Map([
     ['cohere', ['CohereForCausalLM', CohereForCausalLM]],
     ['gemma', ['GemmaForCausalLM', GemmaForCausalLM]],
     ['gemma2', ['Gemma2ForCausalLM', Gemma2ForCausalLM]],
+    ['glm', ['GlmForCausalLM', GlmForCausalLM]],
     ['openelm', ['OpenELMForCausalLM', OpenELMForCausalLM]],
     ['qwen2', ['Qwen2ForCausalLM', Qwen2ForCausalLM]],
     ['phi', ['PhiForCausalLM', PhiForCausalLM]],
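With the classes defined and both mapping entries registered, a checkpoint whose config.json declares model_type 'glm' can be resolved through the auto classes. A hypothetical usage sketch; the Hub repository id below is illustrative, not a real model:

import { AutoTokenizer, AutoModelForCausalLM } from '@huggingface/transformers';

// Hypothetical ONNX-exported GLM checkpoint; substitute a real repo id.
const model_id = 'your-org/glm-tiny-onnx';
const tokenizer = await AutoTokenizer.from_pretrained(model_id);
const model = await AutoModelForCausalLM.from_pretrained(model_id);

const inputs = tokenizer('hello');
const outputs = await model.generate({ ...inputs, max_length: 10 });
console.log(outputs.tolist());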
tests/tiny_random.test.js (53 additions & 2 deletions)
@@ -10,7 +10,6 @@ import {
   BertTokenizer,
   T5Tokenizer,
   WhisperTokenizer,
-  BartTokenizer,
   MarianTokenizer,
   PreTrainedTokenizer,
   AutoTokenizer,
@@ -29,6 +28,7 @@ import {
   CohereForCausalLM,
   GemmaForCausalLM,
   Gemma2ForCausalLM,
+  GlmForCausalLM,
   OPTForCausalLM,
   GPTNeoXForCausalLM,
   GPTJForCausalLM,
@@ -1366,7 +1366,7 @@ describe("Tiny random models", () => {
     });
   });
 
-  describe("gemma", () => {
+  describe("gemma2", () => {
     describe("Gemma2ForCausalLM", () => {
       const model_id = "hf-internal-testing/tiny-random-Gemma2ForCausalLM";
       /** @type {Gemma2ForCausalLM} */
@@ -1417,6 +1417,57 @@
     });
   });
 
+  describe("glm", () => {
+    describe("GlmForCausalLM", () => {
+      const model_id = "hf-internal-testing/tiny-random-GlmForCausalLM";
+      /** @type {GlmForCausalLM} */
+      let model;
+      /** @type {PreTrainedTokenizer} */
+      let tokenizer;
+      beforeAll(async () => {
+        model = await GlmForCausalLM.from_pretrained(model_id, {
+          // TODO move to config
+          ...DEFAULT_MODEL_OPTIONS,
+        });
+        tokenizer = await PreTrainedTokenizer.from_pretrained(model_id);
+        // tokenizer.padding_side = "left";
+      }, MAX_MODEL_LOAD_TIME);
+
+      it(
+        "batch_size=1",
+        async () => {
+          const inputs = tokenizer("hello");
+          const outputs = await model.generate({
+            ...inputs,
+            max_length: 10,
+          });
+          expect(outputs.tolist()).toEqual([[23582n, 5797n, 38238n, 24486n, 36539n, 34489n, 6948n, 34489n, 6948n, 34489n]]);
+        },
+        MAX_TEST_EXECUTION_TIME,
+      );
+
+      it(
+        "batch_size>1",
+        async () => {
+          const inputs = tokenizer(["hello", "hello world"], { padding: true });
+          const outputs = await model.generate({
+            ...inputs,
+            max_length: 10,
+          });
+          expect(outputs.tolist()).toEqual([
+            [59246n, 23582n, 5797n, 38238n, 24486n, 36539n, 34489n, 6948n, 34489n, 6948n],
+            [23582n, 2901n, 39936n, 25036n, 55411n, 10337n, 3424n, 39183n, 30430n, 37285n]
+          ]);
+        },
+        MAX_TEST_EXECUTION_TIME,
+      );
+
+      afterAll(async () => {
+        await model?.dispose();
+      }, MAX_MODEL_DISPOSE_TIME);
+    });
+  });
+
   describe("gpt_neo", () => {
     describe("GPTNeoForCausalLM", () => {
       const model_id = "hf-internal-testing/tiny-random-GPTNeoForCausalLM";