diff --git a/examples/train_gpt2/README.md b/examples/train_gpt2/README.md index c50591ee4..f1025343e 100644 --- a/examples/train_gpt2/README.md +++ b/examples/train_gpt2/README.md @@ -3,3 +3,9 @@ This Tango example showcases how you could fine-tune GPT2 from [transformers](https://github.com/huggingface/transformers) on WikiText2 or a similar dataset. It's best that you run this experiment on a machine with a GPU and PyTorch [properly installed](https://pytorch.org/get-started/locally/#start-locally), otherwise Tango will fall back to CPU-only and it will be extremely slow. + +You can kick off a training job with the following command: + +``` +$ tango run config.jsonnet +``` diff --git a/examples/train_gpt2/config.jsonnet b/examples/train_gpt2/config.jsonnet index cbf4743d3..9f3824cff 100644 --- a/examples/train_gpt2/config.jsonnet +++ b/examples/train_gpt2/config.jsonnet @@ -1,76 +1,76 @@ -local pretrained_model = "gpt2"; +local pretrained_model = 'gpt2'; local training_steps = 200; local warmup_steps = 20; local batch_size = 8; local validate_every = 20; -local distributed = false; # Set to `true` to train on 2 (or more) GPUs. +local distributed = false; // Set to `true` to train on 2 (or more) GPUs. local devices = if distributed then 2 else 1; local grad_accum = if distributed then 2 else 4; local distributed_dataloader = { - "batch_size": batch_size, - "collate_fn": {"type": "transformers_default"}, - "sampler": { - "type": "torch::DistributedSampler", - "shuffle": true, - "drop_last": true, - } + batch_size: batch_size, + collate_fn: { type: 'transformers_default' }, + sampler: { + type: 'torch::DistributedSampler', + shuffle: true, + drop_last: true, + }, }; local single_device_dataloader = { - "shuffle": true, - "batch_size": batch_size, - "collate_fn": {"type": "transformers_default"}, + shuffle: true, + batch_size: batch_size, + collate_fn: { type: 'transformers_default' }, }; local dataloader = if distributed then distributed_dataloader else single_device_dataloader; { - "steps": { - "raw_data": { - "type": "datasets::load", - "path": "wikitext", - "name": "wikitext-2-raw-v1", - }, - "tokenized_data": { - "type": "tokenize_data", - "dataset": {"type": "ref", "ref": "raw_data"}, - "pretrained_model_name": pretrained_model, - }, - "trained_model": { - "type": "torch::train", - "model": { - "type": "gpt2", - "pretrained_model_name_or_path": pretrained_model, - }, - "dataset_dict": {"type": "ref", "ref": "tokenized_data"}, - "train_dataloader": dataloader, - "validation_split": "validation", - "optimizer": { - "type": "transformers_adamw", - "lr": 0.0007, - "betas": [0.9, 0.95], - "eps": 1e-6, - "correct_bias": false, - }, - "lr_scheduler": { - "type": "linear_with_warmup", - "num_warmup_steps": warmup_steps, - "num_training_steps": training_steps, - }, - "grad_accum": grad_accum, - "train_steps": training_steps, - "validate_every": validate_every, - "checkpoint_every": validate_every, - "log_every": 1, - "device_count": devices, - } - "final_metrics": { - "type": "torch::eval", - "model": {"type": "ref", "ref": "trained_model"}, - "dataset_dict": {"type": "ref", "ref": "tokenized_data"}, - "dataloader": single_device_dataloader, - "test_split": "test", - }, - } + steps: { + raw_data: { + type: 'datasets::load', + path: 'wikitext', + name: 'wikitext-2-raw-v1', + }, + tokenized_data: { + type: 'tokenize_data', + dataset: { type: 'ref', ref: 'raw_data' }, + pretrained_model_name: pretrained_model, + }, + trained_model: { + type: 'torch::train', + model: { + type: 'gpt2', + pretrained_model_name_or_path: pretrained_model, + }, + dataset_dict: { type: 'ref', ref: 'tokenized_data' }, + train_dataloader: dataloader, + validation_split: 'validation', + optimizer: { + type: 'transformers_adamw', + lr: 0.0007, + betas: [0.9, 0.95], + eps: 1e-6, + correct_bias: false, + }, + lr_scheduler: { + type: 'linear_with_warmup', + num_warmup_steps: warmup_steps, + num_training_steps: training_steps, + }, + grad_accum: grad_accum, + train_steps: training_steps, + validate_every: validate_every, + checkpoint_every: validate_every, + log_every: 1, + device_count: devices, + }, + final_metrics: { + type: 'torch::eval', + model: { type: 'ref', ref: 'trained_model' }, + dataset_dict: { type: 'ref', ref: 'tokenized_data' }, + dataloader: single_device_dataloader, + test_split: 'test', + }, + }, }