A pure-functional implementation of a machine learning transformer model in Python/JAX

awf/functional-transformer

A fully functional (pun intended) implementation of a machine learning transformer model in Python/JAX. I do realize that 'pure functional' and 'Python' are not necessarily mots qui vont très bien ensemble, but I'm sure you'll agree on reading the code that it has una anima di pura programmazione funzionale. And a little macaronica appeals to the peasant soul. In other words, don't worry about the language...

Given only a few simple BLAS-like functions:

import jax
import jax.numpy as jnp
from jax import vmap

def linear(params, x: jnp.ndarray):
    return x @ params.weight + params.bias[None,:]

def elementwise_linear(params, x: jnp.ndarray):
    return params.gain[None,:] * x + params.bias[None,:]

def standardize(x, eps = 1e-5):
    return (x - x.mean())/(x.std() + eps)
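
As a rough illustration of how these helpers behave on an L x D array, the sketch below applies `linear` once and `standardize` row by row. `LinearParams` is a hypothetical container introduced only for this example; the repository's own parameter type may differ.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


# Hypothetical parameter container, used only for this illustration
class LinearParams(NamedTuple):
    weight: jnp.ndarray  # d_in x d_out
    bias: jnp.ndarray    # d_out


def linear(params, x: jnp.ndarray):
    return x @ params.weight + params.bias[None, :]


def standardize(x, eps=1e-5):
    return (x - x.mean()) / (x.std() + eps)


L, d_in, d_out = 4, 8, 3
x = jax.random.normal(jax.random.PRNGKey(0), (L, d_in))

p = LinearParams(weight=jnp.ones((d_in, d_out)), bias=jnp.zeros(d_out))
y = linear(p, x)                  # L x d_out

# standardize sees one vector at a time; vmap applies it to each row
z = jax.vmap(standardize)(x)      # L x d_in
row_means = z.mean(axis=1)        # approximately zero for every row
```

Note that `standardize` on its own normalizes over the whole array it is given, which is why the transformer code wraps it in `vmap` to normalize each position separately.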

then the entire transformer forward computation is 25 lines of code (excerpt from transformer.py):

def transformer(cfg, params, x: jnp.ndarray):
    """
    cfg: Config, from transformer_init, holds hyperparameters
    params: Current transformer parameters, initialized in init
    x: 1D array of L integers, representing the input sequence
    output: L x n_vocab logits
    """

    L, = x.shape # x is just 1D. Vmap/pmap will handle batching

    # Create mask: 0 to attend, -Inf to ignore
    mask = jnp.log(jnp.tril(jnp.ones((L, L))))

    # Start with token embeddings
    embeddings = cfg.lambda_e * params.embeddings[x, :]     # L x Dm

    # Add (learned) positional encodings
    embeddings += cfg.lambda_pe * params.positional_encodings[:L, :]

    # Apply the transformer layers
    for layer in params.layers:

        # Layer-normalize embeddings
        t1 = vmap(standardize)(embeddings)
        t1 = elementwise_linear(layer.norm_self_attn, t1)   # L x Dm

        # Multi-head self-attention
        for head in layer.heads:

            # Project into this head's query/key space
            query = linear(head.query, t1)                  # L x Dk
            key = linear(head.key, t1)                      # L x Dk

            # Compute L x L attention matrix
            score = query @ key.T + mask                    # L x L
            attn = jax.nn.softmax(cfg.tau * score, axis=1)  # L x L

            value = linear(head.value, t1)                  # L x Dm
            self_attn = attn @ value                        # L x Dm

            # Add this head's contribution into embeddings
            embeddings += self_attn                         # L x Dm

        # Layer-normalize embeddings
        t2 = vmap(standardize)(embeddings)
        t2 = elementwise_linear(layer.norm_ff, t2)          # L x Dm

        # Feedforward fully connected
        t2 = linear(layer.ffn1, t2)                         # L x Dff
        t2 = jax.nn.relu(t2)
        t2 = linear(layer.ffn2, t2)                         # L x Dm

        # Add this layer's contribution into embeddings
        embeddings += t2

    # Layer-normalize embeddings
    embeddings = vmap(standardize)(embeddings)
    embeddings = elementwise_linear(params.pre_output_norm, embeddings)

    # And linearly project to output dimension
    return linear(params.output, embeddings)                # L x n_vocab 
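
The causal mask is the least obvious line in the function, so here is a small sketch of what it produces. The all-zero scores are an assumption made purely so that the resulting weights are easy to read.

```python
import jax
import jax.numpy as jnp

L = 4

# Same construction as in transformer(): log(1) = 0 on and below the
# diagonal, log(0) = -inf above it
mask = jnp.log(jnp.tril(jnp.ones((L, L))))

# Hypothetical all-zero scores, so every allowed position gets equal weight
score = jnp.zeros((L, L)) + mask
attn = jax.nn.softmax(score, axis=1)

# Row i spreads its weight uniformly over positions 0..i:
# row 0 is [1, 0, 0, 0] and row 3 is [0.25, 0.25, 0.25, 0.25]
```

Because the diagonal is always unmasked, every row has at least one finite score, so the softmax never divides by zero.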

The loss and its gradient need a few more lines:

def crossentropy(output: jnp.ndarray, target: int):
    return -jax.nn.log_softmax(output)[target]

def seq_crossentropy(output: jnp.ndarray, targets: jnp.ndarray):
    return vmap(crossentropy)(output, targets).mean()

def transformer_loss(cfg, params, x):
    output = transformer(cfg, params, x)

    return seq_crossentropy(output[:-1], x[1:])

# Gradient wrt 'params'
grad_loss = jax.grad(transformer_loss, argnums=1)
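
To see the next-token shift and `argnums=1` in action without the full model, the sketch below swaps in a hypothetical bigram `toy_model` (a single logit table) for `transformer`. The loss functions are copied from above; everything named `toy_*` is an assumption for illustration.

```python
import jax
import jax.numpy as jnp
from jax import vmap


def crossentropy(output: jnp.ndarray, target: int):
    return -jax.nn.log_softmax(output)[target]


def seq_crossentropy(output: jnp.ndarray, targets: jnp.ndarray):
    return vmap(crossentropy)(output, targets).mean()


# Hypothetical stand-in for transformer(): a bigram model with one logit table
def toy_model(cfg, params, x):
    return params["table"][x, :]          # L x n_vocab


def toy_loss(cfg, params, x):
    output = toy_model(cfg, params, x)
    # The prediction at position i is scored against the token at position i+1
    return seq_crossentropy(output[:-1], x[1:])


n_vocab = 5
params = {"table": jnp.zeros((n_vocab, n_vocab))}
x = jnp.array([0, 3, 1, 4])

loss = toy_loss(None, params, x)          # uniform logits give log(n_vocab)

# argnums=1 differentiates with respect to the second argument, params;
# the result has the same pytree structure as params
grads = jax.grad(toy_loss, argnums=1)(None, params, x)
```

The last token of `x` is only ever used as a target, never as an input, so the table row for token 4 receives a zero gradient here.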

The random initialization is also short:

params = ParamsDict()

# Create embedding layer
rng,params.embeddings = rand(rng, jax.random.normal, (n_vocab, d_model))

# Positional encodings initialized to zeros
params.positional_encodings = jnp.zeros((max_len, d_model))

# For transformer layers
params.layers = []
for _ in range(n_layers):
    layer = ParamsDict()
    layer.norm_self_attn = layernorm_init_identity(d_model)

    layer.heads = []
    for _ in range(n_heads):
        head = ParamsDict()
        rng,head.query = linear_init_uniform(rng, d_model, d_k)
        rng,head.key = linear_init_uniform(rng, d_model, d_k)
        rng,head.value = linear_init_uniform(rng, d_model, d_model)
        
        layer.heads.append(head)

    layer.norm_ff = layernorm_init_identity(d_model)

    rng,layer.ffn1 = linear_init_uniform(rng, d_model, d_ff)
    rng,layer.ffn2 = linear_init_uniform(rng, d_ff, d_model)

    params.layers.append(layer)

# Final normalization and output layer
params.pre_output_norm = layernorm_init_identity(d_model)
rng,params.output = linear_init_uniform(rng, d_model, n_vocab)
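
The helpers used above (`ParamsDict`, `rand`, `linear_init_uniform`, `layernorm_init_identity`) are not shown in this excerpt. The following is a minimal sketch of what they might look like, inferred only from how they are called; the bodies, the uniform bound of 1/sqrt(in_features), and the zero bias are assumptions, not the repository's actual code.

```python
import jax
import jax.numpy as jnp


class ParamsDict(dict):
    """Hypothetical dict with attribute access, registered as a JAX pytree."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        self[name] = value


# Registration lets jax.grad and tree_map traverse ParamsDict instances
jax.tree_util.register_pytree_node(
    ParamsDict,
    lambda d: (tuple(d.values()), tuple(d.keys())),
    lambda keys, values: ParamsDict(zip(keys, values)),
)


def rand(rng, f, shape, **kwargs):
    """Split the key, draw from f with one half, hand the other half back."""
    rng, subkey = jax.random.split(rng)
    return rng, f(subkey, shape, **kwargs)


def linear_init_uniform(rng, in_features, out_features):
    # Assumed scheme: weights uniform in +/- 1/sqrt(in_features), zero bias
    bound = in_features ** -0.5
    params = ParamsDict()
    rng, params.weight = rand(rng, jax.random.uniform,
                              (in_features, out_features),
                              minval=-bound, maxval=bound)
    params.bias = jnp.zeros((out_features,))
    return rng, params


def layernorm_init_identity(shape):
    # Gain 1 and bias 0 make elementwise_linear the identity at initialization
    params = ParamsDict()
    params.gain = jnp.ones(shape)
    params.bias = jnp.zeros(shape)
    return params


rng = jax.random.PRNGKey(42)
rng, lin = linear_init_uniform(rng, 8, 3)
norm = layernorm_init_identity(8)
leaves = jax.tree_util.tree_leaves(ParamsDict(lin=lin, norm=norm))
```

Returning the updated key alongside the value is what allows the `rng, x = f(rng, ...)` threading style seen in the initialization code, with no hidden random state.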

Add an optimizer, and we are pronto a romblare.
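
For a sense of what "add an optimizer" involves, here is a minimal sketch using plain gradient descent over a parameter pytree. This is a deliberately simpler technique than Adam, and `sgd_step`, `toy_loss` and the learning rate are hypothetical names and values chosen for illustration.

```python
import jax
import jax.numpy as jnp


def sgd_step(loss_fn, params, x, lr=0.1):
    """One step of plain gradient descent on a pytree of parameters."""
    loss, grads = jax.value_and_grad(loss_fn)(params, x)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return loss, params


# Hypothetical toy objective standing in for transformer_loss
def toy_loss(params, x):
    return jnp.mean((params["w"] * x - 2.0 * x) ** 2)


params = {"w": jnp.array(0.0)}
x = jnp.array([1.0, 2.0, 3.0])

losses = []
for _ in range(20):
    loss, params = sgd_step(toy_loss, params, x)
    losses.append(float(loss))
```

To use the same step with the transformer, one option would be to pass `functools.partial(transformer_loss, cfg)` as `loss_fn`, so that `params` becomes the first argument.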

Running

$ export JAX_PLATFORM_NAME=gpu # or cpu
$ export JAX_LOG_COMPILES=1 # or 0
$ export XLA_FLAGS=--xla_dump_to=./xla-dumps/  # Also dumps jaxprs to this folder
$ python main.py -help
$ python main.py -layers 3 -dmodel 512 -heads 8 -dk 64 -dff 2048 

Results at https://wandb.ai/awfidius/pure-transformer

Acknowledgements

The model is based on https://github.com/vpj/jax_transformer/blob/master/transformer.py, and the Adam and Dataset classes in jaxutils are almost direct copies from https://github.com/vpj/jax_transformer
