opensource more modules including vit so the tiny_vit_imagenet.py example works externally.

PiperOrigin-RevId: 691479221
Qwlouse authored and The kauldron Authors committed Oct 30, 2024
1 parent 53d078f commit bc8767c
Showing 8 changed files with 1,171 additions and 1 deletion.
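Since the point of this commit is that the example now runs from an external checkout, a minimal sketch of loading the example config is shown below. It assumes only what is visible in this diff (the file examples/tiny_vit_imagenet.py defines get_config()); the checkout path is hypothetical, and the training entry point is not part of this change.

# Minimal sketch: load the example config from an external checkout.
# The path below is hypothetical; only `get_config()` is assumed, as it is
# visible in the diff further down.
import importlib.util

spec = importlib.util.spec_from_file_location(
    "tiny_vit_imagenet", "/path/to/kauldron/examples/tiny_vit_imagenet.py"
)
tiny_vit_imagenet = importlib.util.module_from_spec(spec)
spec.loader.exec_module(tiny_vit_imagenet)

cfg = tiny_vit_imagenet.get_config()  # builds the config, including _make_ds
print(type(cfg))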
2 changes: 1 addition & 1 deletion examples/tiny_vit_imagenet.py
@@ -94,7 +94,7 @@ def get_config():


 def _make_ds(training: bool):
-  return kd.data.Tfds(
+  return kd.data.py.Tfds(
       name="imagenet_resized/64x64",
       split="train" if training else "validation",
       shuffle=True if training else False,
15 changes: 15 additions & 0 deletions kauldron/modules/__init__.py
@@ -44,3 +44,18 @@
from kauldron.modules.pos_embeddings import ZeroEmbedding
from kauldron.modules.models import FlatAutoencoder
from kauldron.modules.models import Sequential

# Models should be open-sourced on an individual basis
from kauldron.modules.attention import ImprovedMultiHeadDotProductAttention
from kauldron.modules.attention import MultiHeadDotProductAttention
from kauldron.modules.input_embeddings import Patchify
from kauldron.modules.input_embeddings import PatchifyEmbed

# transformer
from kauldron.modules.transformers import PreNormBlock
from kauldron.modules.transformers import PostNormBlock
from kauldron.modules.transformers import ParallelAttentionBlock
from kauldron.modules.transformers import TransformerMLP
# vit
from kauldron.modules.vit import Vit
from kauldron.modules.vit import VitEncoder
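To make the effect of these re-exports concrete, the snippet below imports every symbol added in this hunk through the kauldron.modules namespace. The names are taken verbatim from the diff above; constructor arguments and usage are not shown because they are not part of this change.

# All symbols re-exported by kauldron/modules/__init__.py in this commit.
from kauldron.modules import (
    ImprovedMultiHeadDotProductAttention,
    MultiHeadDotProductAttention,
    ParallelAttentionBlock,
    Patchify,
    PatchifyEmbed,
    PostNormBlock,
    PreNormBlock,
    TransformerMLP,
    Vit,
    VitEncoder,
)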
50 changes: 50 additions & 0 deletions kauldron/modules/adapter_test.py
@@ -0,0 +1,50 @@
# Copyright 2024 The kauldron Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test."""

from flax import linen as nn
import jax
import jax.numpy as jnp
from kauldron import kd
import numpy as np


def test_external():
model = kd.nn.ExternalModule(
model=nn.Dropout(0.5),
keys={
'inputs': 'a',
},
train_kwarg_name='~deterministic',
)

inputs = jnp.ones((5, 5))
input_kwargs = kd.kontext.get_from_keys_obj(
{'a': inputs, 'b': jnp.zeros(())}, model
)
out_train = model.apply(
{},
**input_kwargs,
is_training_property=True,
rngs={'dropout': jax.random.PRNGKey(0)},
)
out_eval = model.apply(
{},
**input_kwargs,
is_training_property=False,
)

assert not np.array_equal(out_train, inputs)
np.testing.assert_array_equal(out_eval, inputs)
241 changes: 241 additions & 0 deletions kauldron/modules/attention.py
@@ -0,0 +1,241 @@
# Copyright 2024 The kauldron Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Attention layers."""

from __future__ import annotations

from typing import Callable, Optional

import flax.linen as nn
import jax
import jax.numpy as jnp
import kauldron.modules as knn
from kauldron.typing import Axes, Bool, Float, Initializer, typechecked # pylint: disable=g-multiple-import,g-importing-member


def softmax(
x: Float['*a'], axis: Axes = -1, dtype: jnp.dtype | None = jnp.float32
) -> Float['*a']:
if dtype is None:
dtype = x.dtype
return jax.nn.softmax(x.astype(dtype), axis=axis).astype(x.dtype)


class MultiHeadDotProductAttention(nn.MultiHeadDotProductAttention):
"""Wrapper around `nn.MultiHeadDotProductAttention` using `knn.train_property`."""

is_training = knn.train_property()

def __post_init__(self):
super().__post_init__()
if self.deterministic is not None:
raise ValueError(
          '`kd.nn.MultiHeadDotProductAttention` should not use'
          ' `deterministic`. Instead, the training mode is set through'
          ' `is_training_property`. See `kd.nn.train_property`.'
)

@typechecked
@nn.compact
def __call__(
self,
inputs_q: Float['*b q dq'],
inputs_k: Optional[Float['*b k dk']] = None,
inputs_v: Optional[Float['*b k dv']] = None,
*,
mask: Optional[Bool['*b #heads #q #k']] = None,
) -> Float['*b q do']:
return super().__call__(
inputs_q=inputs_q,
inputs_k=inputs_k,
inputs_v=inputs_v,
deterministic=not self.is_training,
mask=mask,
)


@typechecked
def dot_product_attention_weights(
query: Float['*b q h d'],
key: Float['*b k h d'],
softmax_axis: Axes = -1,
bias: Optional[Float['*b #h #q #k']] = None,
mask: Optional[Bool['*b #h #q #k']] = None,
) -> Float['*b h q k']:
"""Computes dot-product attention weights given query and key.
q: number of queries, k: number of keys, h: number of heads
d: dimension of keys/queries
Args:
query: Queries for calculating attention
key: Keys for calculating attention.
softmax_axis: The axes over which the softmax is taken. defaults to -1 which
is the keys axis. For Slot-Attention set to -2 (queries).
bias: Bias for the attention weights. This should be broadcastable to the
shape `[*b h q k]`. This can be used for incorporating causal masks,
padding masks, proximity bias, etc.
mask: Mask for the attention weights. This should be broadcastable to the
shape `[*b h q k]`. This can be used for incorporating causal masks.
Attention weights are masked out if their corresponding mask value is
`False`.
Returns:
Attention weights of shape `[*b h q k]`.
"""
query = query / jnp.sqrt(query.shape[-1])
attn_weights = jnp.einsum('...qhd,...khd->...hqk', query, key)

if bias is not None:
attn_weights = attn_weights + bias

if mask is not None:
big_neg = jnp.finfo(query.dtype).min
attn_weights = jnp.where(mask, attn_weights, big_neg)

attn_weights = softmax(attn_weights, axis=softmax_axis)

return attn_weights


class ImprovedMultiHeadDotProductAttention(nn.Module):
"""Multi-head dot-product attention.
Simplified nn.MultiheadDotProductAttention with a few modifications:
- include a softmax axis
- accept an (additive) bias for the attention weights (in addition to mask)
- dropped support for dropout
- add attention weights to interms as "interms.PATH.TO.LAYER.attn_weights"
Attributes:
num_heads: Number of attention heads.
qk_size: Total dimension of the keys and queries.
v_size: Total dimension of the values. Defaults to qk_size.
softmax_axis: The axis over which the softmax is taken. defaults to -1 which
is the keys axis. For Slot-Attention set to -2 (queries).
"""

num_heads: int
qk_features: Optional[int] = None
v_features: Optional[int] = None
out_features: Optional[int] = None
softmax_axis: Axes = -1

normalize_qk: bool = False

kernel_init: Initializer = nn.initializers.lecun_normal()
bias_init: Initializer = nn.initializers.zeros_init()
use_bias: bool = True
attn_weights_fn: Callable[..., Float['...']] = dot_product_attention_weights
decode: bool = False

interms = knn.interms_property()

@typechecked
@nn.compact
def __call__(
self,
inputs_q: Float['*b q dq'],
inputs_k: Optional[Float['*b kv dk']] = None, # defaults to inputs_q
inputs_v: Optional[Float['*b kv dv']] = None, # defaults to inputs_k
*,
bias: Optional[Float['*b #num_heads #q #kv']] = None,
mask: Optional[Bool['*b #num_heads #q #kv']] = None,
) -> Float['*b q do']:
"""Applies multi-head dot product attention on the input data.
Projects the inputs into multi-headed query, key, and value vectors,
applies dot-product attention and project the results to an output vector.
Args:
inputs_q: Input tokens from which queries are computed.
inputs_k: Input tokens from which the keys are computed (defaults to
inputs_q).
inputs_v: Input tokens from which the values are computed (defaults to
inputs_k).
bias: Bias for the attention weights. This can be used for incorporating
causal masks, padding masks, proximity bias, etc.
mask: Attention mask, where attention weights are masked out if their mask
value is `False`.
Returns:
output tokens (linear projection of an attention weighted average of value
tokens per query).
"""
qk_features = self.qk_features or inputs_q.shape[-1]
v_features = self.v_features or qk_features

if qk_features % self.num_heads:
raise ValueError(f'{self.num_heads=} must divide {qk_features=}.')
if v_features % self.num_heads:
raise ValueError(f'{self.num_heads=} must divide {v_features=}.')

if inputs_k is None:
if inputs_v is not None:
raise ValueError('inputs_k cannot be None if inputs_v is given.')
inputs_k = inputs_q
if inputs_v is None:
inputs_v = inputs_k

# Project inputs_q to multi-headed queries and keys.
# dimensions are then [*b q h qk_size]
def dense(name, x, features):
return nn.DenseGeneral(
features=(self.num_heads, features // self.num_heads),
kernel_init=self.kernel_init,
bias_init=self.bias_init,
use_bias=self.use_bias,
dtype=x.dtype,
name=name,
)(x)

query = dense('query', inputs_q, qk_features)
key = dense('key', inputs_k, qk_features)
value = dense('value', inputs_v, v_features)

if self.normalize_qk:
# Normalizing query and key projections stabilizes training with higher
# LR. See ViT-22B paper http://arxiv.org/abs/2302.05442 for analysis.
query = nn.LayerNorm(
name='query_norm', use_bias=False, dtype=query.dtype)(query)
key = nn.LayerNorm(
name='key_norm', use_bias=False, dtype=key.dtype)(key)

# Compute attention weights.
attn_weights = self.attn_weights_fn( # pylint: disable=redundant-keyword-arg
query=query,
key=key,
softmax_axis=self.softmax_axis,
bias=bias,
mask=mask,
)

# accessible as `interms.[path.to.this.module].attn_weights[0]`
self.interms['attn_weights'] = attn_weights

# Return weighted sum over values for each query position.
x = jnp.einsum('...hqk,...khd->...qhd', attn_weights, value)

# Back to the original input dimensions.
return nn.DenseGeneral(
features=self.out_features or inputs_q.shape[-1],
axis=(-2, -1),
kernel_init=self.kernel_init,
bias_init=self.bias_init,
use_bias=self.use_bias,
dtype=x.dtype,
name='out',
)(x)
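
As a usage illustration for the new file above, here is a minimal sketch that exercises dot_product_attention_weights with dummy data. Everything it relies on is defined in the diff; the concrete shapes and the padding-style mask are invented for the example.

# Sketch: exercise dot_product_attention_weights with dummy inputs.
# Shapes follow the annotations above: query [*b q h d], key [*b k h d],
# mask broadcastable to [*b h q k]. The numbers here are arbitrary.
import jax
import jax.numpy as jnp

from kauldron.modules.attention import dot_product_attention_weights

rng_q, rng_k = jax.random.split(jax.random.PRNGKey(0))
query = jax.random.normal(rng_q, (2, 5, 4, 8))  # batch=2, q=5, heads=4, d=8
key = jax.random.normal(rng_k, (2, 7, 4, 8))    # batch=2, k=7, heads=4, d=8

# Padding-style mask that hides the last two key positions for every query.
mask = jnp.ones((2, 1, 1, 7), dtype=bool).at[..., -2:].set(False)

weights = dot_product_attention_weights(query, key, mask=mask)
print(weights.shape)            # (2, 4, 5, 7), i.e. [*b h q k]
print(weights.sum(axis=-1))     # ~1.0 everywhere: softmax over the keys axis
print(weights[..., -2:].max())  # ~0.0: masked keys receive no weight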