Skip to content

Commit

Permalink
Fix GPT2 example. (#158)
Browse files Browse the repository at this point in the history
  • Loading branch information
schmmd authored Jan 26, 2022
1 parent 4011482 commit 5ff51d6
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 59 deletions.
6 changes: 6 additions & 0 deletions examples/train_gpt2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,9 @@
This Tango example showcases how you could fine-tune GPT2 from [transformers](https://github.com/huggingface/transformers) on WikiText2 or a similar dataset.
It's best that you run this experiment on a machine with a GPU and PyTorch [properly installed](https://pytorch.org/get-started/locally/#start-locally),
otherwise Tango will fall back to CPU-only and it will be extremely slow.

You can kick off a training job with the following command:

```
$ tango run config.jsonnet
```
118 changes: 59 additions & 59 deletions examples/train_gpt2/config.jsonnet
Original file line number Diff line number Diff line change
@@ -1,76 +1,76 @@
local pretrained_model = "gpt2";
local pretrained_model = 'gpt2';
local training_steps = 200;
local warmup_steps = 20;
local batch_size = 8;
local validate_every = 20;
local distributed = false; # Set to `true` to train on 2 (or more) GPUs.
local distributed = false; // Set to `true` to train on 2 (or more) GPUs.
local devices = if distributed then 2 else 1;
local grad_accum = if distributed then 2 else 4;

local distributed_dataloader = {
"batch_size": batch_size,
"collate_fn": {"type": "transformers_default"},
"sampler": {
"type": "torch::DistributedSampler",
"shuffle": true,
"drop_last": true,
}
batch_size: batch_size,
collate_fn: { type: 'transformers_default' },
sampler: {
type: 'torch::DistributedSampler',
shuffle: true,
drop_last: true,
},
};

local single_device_dataloader = {
"shuffle": true,
"batch_size": batch_size,
"collate_fn": {"type": "transformers_default"},
shuffle: true,
batch_size: batch_size,
collate_fn: { type: 'transformers_default' },
};

local dataloader = if distributed then distributed_dataloader else single_device_dataloader;

{
"steps": {
"raw_data": {
"type": "datasets::load",
"path": "wikitext",
"name": "wikitext-2-raw-v1",
},
"tokenized_data": {
"type": "tokenize_data",
"dataset": {"type": "ref", "ref": "raw_data"},
"pretrained_model_name": pretrained_model,
},
"trained_model": {
"type": "torch::train",
"model": {
"type": "gpt2",
"pretrained_model_name_or_path": pretrained_model,
},
"dataset_dict": {"type": "ref", "ref": "tokenized_data"},
"train_dataloader": dataloader,
"validation_split": "validation",
"optimizer": {
"type": "transformers_adamw",
"lr": 0.0007,
"betas": [0.9, 0.95],
"eps": 1e-6,
"correct_bias": false,
},
"lr_scheduler": {
"type": "linear_with_warmup",
"num_warmup_steps": warmup_steps,
"num_training_steps": training_steps,
},
"grad_accum": grad_accum,
"train_steps": training_steps,
"validate_every": validate_every,
"checkpoint_every": validate_every,
"log_every": 1,
"device_count": devices,
}
"final_metrics": {
"type": "torch::eval",
"model": {"type": "ref", "ref": "trained_model"},
"dataset_dict": {"type": "ref", "ref": "tokenized_data"},
"dataloader": single_device_dataloader,
"test_split": "test",
},
}
steps: {
raw_data: {
type: 'datasets::load',
path: 'wikitext',
name: 'wikitext-2-raw-v1',
},
tokenized_data: {
type: 'tokenize_data',
dataset: { type: 'ref', ref: 'raw_data' },
pretrained_model_name: pretrained_model,
},
trained_model: {
type: 'torch::train',
model: {
type: 'gpt2',
pretrained_model_name_or_path: pretrained_model,
},
dataset_dict: { type: 'ref', ref: 'tokenized_data' },
train_dataloader: dataloader,
validation_split: 'validation',
optimizer: {
type: 'transformers_adamw',
lr: 0.0007,
betas: [0.9, 0.95],
eps: 1e-6,
correct_bias: false,
},
lr_scheduler: {
type: 'linear_with_warmup',
num_warmup_steps: warmup_steps,
num_training_steps: training_steps,
},
grad_accum: grad_accum,
train_steps: training_steps,
validate_every: validate_every,
checkpoint_every: validate_every,
log_every: 1,
device_count: devices,
},
final_metrics: {
type: 'torch::eval',
model: { type: 'ref', ref: 'trained_model' },
dataset_dict: { type: 'ref', ref: 'tokenized_data' },
dataloader: single_device_dataloader,
test_split: 'test',
},
},
}

0 comments on commit 5ff51d6

Please sign in to comment.