Commit
feat(KDP): adding TransformerBlocks layers to categorical and text variables
piotrlaczkowski committed Apr 30, 2024
1 parent b8608f7 commit 0338fb4
Showing 5 changed files with 140 additions and 11 deletions.
Binary file added docs/imgs/TransformerBlocks.png
67 changes: 67 additions & 0 deletions kdp/custom_layers.py
@@ -81,3 +81,70 @@ def call(self, inputs: tf.Tensor) -> tf.Tensor:
"""
output = tf.cast(inputs, tf.float32)
return output


class TransformerBlock(tf.keras.layers.Layer):
"""Class that implements a transformer block."""

def __init__(
self,
dim_model: int = 32,
num_heads: int = 3,
ff_units: int = 16,
dropout_rate: float = 0.2,
**kwargs,
):
"""Initializes the transformer block.
Args:
dim_model (int): Dimension of the model.
num_heads (int): Number of attention heads.
ff_units (int): Units in the feed-forward layer.
dropout_rate (float): Dropout rate to apply.
kwargs: Additional keyword arguments.
"""
super().__init__(**kwargs)
self.d_model = dim_model
self.num_heads = num_heads
self.ff_units = ff_units
self.dropout_rate = dropout_rate

# Define layers
self.multihead_attention = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=dim_model)
self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
self.add1 = tf.keras.layers.Add()
self.layer_norm1 = tf.keras.layers.LayerNormalization()

self.ff1 = tf.keras.layers.Dense(ff_units, activation="relu")
self.dropout2 = tf.keras.layers.Dropout(dropout_rate)
self.ff2 = tf.keras.layers.Dense(dim_model)
self.add2 = tf.keras.layers.Add()
self.layer_norm2 = tf.keras.layers.LayerNormalization()

def call(self, inputs: tf.Tensor) -> tf.Tensor:
"""Defines the forward pass for the transformer block.
Args:
inputs (tf.Tensor): Input tensor for the block.
Returns:
tf.Tensor: Output tensor after processing.
"""
# Reshape if needed
if len(inputs.shape) == 2:
inputs = tf.expand_dims(inputs, axis=1)

# Multi-head attention
attention = self.multihead_attention(inputs, inputs)
attention = self.dropout1(attention)
attention = self.add1([inputs, attention])
attention_norm = self.layer_norm1(attention)

# Feed-forward layers
ff = self.ff1(attention_norm)
ff = self.dropout2(ff)
ff = self.ff2(ff)
ff = self.add2([attention_norm, ff])
ff_norm = self.layer_norm2(ff)

return ff_norm
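
A quick sanity check of the new layer (a minimal sketch; the input shape and hyper-parameters below are illustrative, not part of the commit):

import tensorflow as tf
from kdp.custom_layers import TransformerBlock

# a batch of 8 examples with 32 already-encoded categorical features
x = tf.random.normal((8, 32))
block = TransformerBlock(dim_model=32, num_heads=4, ff_units=16, dropout_rate=0.1)
y = block(x)  # 2D inputs are expanded to (batch, 1, features) inside call()
print(y.shape)  # (8, 1, 32)

Because both residual additions require matching shapes, dim_model should equal the last dimension of the input (here 32).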
19 changes: 18 additions & 1 deletion kdp/layers_factory.py
@@ -2,7 +2,7 @@

import tensorflow as tf

from kdp.custom_layers import CastToFloat32Layer, TextPreprocessingLayer
from kdp.custom_layers import CastToFloat32Layer, TextPreprocessingLayer, TransformerBlock


class PreprocessorLayerFactory:
@@ -252,3 +252,20 @@ def cast_to_float32_layer(name: str = "cast_to_float32", **kwargs: dict) -> tf.k
name=name,
**kwargs,
)

@staticmethod
def transformer_block_layer(name: str = "transformer", **kwargs: dict) -> tf.keras.layers.Layer:
"""Create a TransformerBlock layer.
Args:
name: The name of the layer.
**kwargs: Additional keyword arguments to pass to the layer constructor.
Returns:
An instance of the TransformerBlock layer.
"""
return PreprocessorLayerFactory.create_layer(
layer_class=TransformerBlock,
name=name,
**kwargs,
)
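
For completeness, a hedged sketch of how the new factory helper might be called (assuming, as with the other helpers, that create_layer instantiates layer_class with the forwarded kwargs):

from kdp.layers_factory import PreprocessorLayerFactory

# build a TransformerBlock through the factory; hyper-parameters are illustrative
transformer = PreprocessorLayerFactory.transformer_block_layer(
    name="transformer_block_demo",
    dim_model=64,
    num_heads=4,
    ff_units=32,
    dropout_rate=0.1,
)
# the result is a regular Keras layer and can be chained like any other preprocessing layer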
63 changes: 54 additions & 9 deletions kdp/processor.py
@@ -106,6 +106,10 @@ def __init__(
overwrite_stats: bool = False,
log_to_file: bool = False,
features_specs: dict[str, FeatureType | str] = None,
transfo_nr_blocks: int = None,
transfo_nr_heads: int = 3,
transfo_ff_units: int = 16,
transfo_dropout_rate: float = 0.25,
) -> None:
"""Initialize a preprocessing model.
@@ -121,6 +125,11 @@ def __init__(
overwrite_stats (bool): A boolean indicating whether to overwrite the statistics.
log_to_file (bool): A boolean indicating whether to log to a file.
features_specs (dict[str, FeatureType | str]): A dictionary containing the features and their types.
transfo_nr_blocks (int): The number of transformer blocks to stack on the categorical features
(default=None, transformer blocks are disabled).
transfo_nr_heads (int): The number of attention heads for the transformer block (categorical variables).
transfo_ff_units (int): The number of feed-forward units for the transformer block.
transfo_dropout_rate (float): The dropout rate for the transformer block (default=0.25).
"""
self.path_data = path_data
self.batch_size = batch_size or 50_000
@@ -130,12 +139,18 @@ def __init__(
self.feature_crosses = feature_crosses or []
self.output_mode = output_mode
self.overwrite_stats = overwrite_stats
# transformer blocks control
self.transfo_nr_blocks = transfo_nr_blocks
self.transfo_nr_heads = transfo_nr_heads
self.transfo_ff_units = transfo_ff_units
self.transfo_dropout_rate = transfo_dropout_rate

# PLACEHOLDERS
self.preprocessors = {}
self.inputs = {}
self.signature = {}
self.outputs = {}
self.outputs_categorical = {}

if log_to_file:
logger.info("Logging to file enabled 🗂️")
@@ -322,11 +337,6 @@ def _add_pipeline_numeric(self, feature_name: str, input_layer, stats: dict) ->
# defining the pipeline input layer
_output_pipeline = preprocessor.chain(input_layer=input_layer)

# adjusting output
# if _feature.feature_type == FeatureType.FLOAT_DISCRETIZED:
# Cast the crossed feature to float32
# _output_pipeline = tf.cast(_output_pipeline, tf.float32)

# defining output
self.outputs[feature_name] = _output_pipeline

@@ -402,8 +412,9 @@ def _add_pipeline_categorical(self, feature_name: str, input_layer, stats: dict)
layer_creator=PreprocessorLayerFactory.flatten_layer,
name=f"flatten_{feature_name}",
)

# adding outputs
self.outputs[feature_name] = preprocessor.chain(input_layer=input_layer)
self.outputs_categorical[feature_name] = preprocessor.chain(input_layer=input_layer)

def _add_pipeline_text(self, feature_name: str, input_layer, stats: dict) -> None:
"""Add a text preprocessing step to the pipeline.
@@ -455,7 +466,9 @@ def _add_pipeline_text(self, feature_name: str, input_layer, stats: dict) -> Non
layer_creator=PreprocessorLayerFactory.cast_to_float32_layer,
name=f"cast_to_float_{feature_name}",
)
self.outputs[feature_name] = preprocessor.chain(input_layer=input_layer)
# adding outputs
self.outputs_categorical[feature_name] = preprocessor.chain(input_layer=input_layer)
# self.outputs[feature_name] = preprocessor.chain(input_layer=input_layer)

def _add_pipeline_cross(self) -> None:
"""Add a crossing preprocessing step to the pipeline.
@@ -500,9 +513,41 @@ def _prepare_outputs(self) -> None:
"""
logger.info("Building preprocessor Model")
if self.output_mode == OutputModeOptions.CONCAT:
# getting all features to concatenate
self.features_to_concat = list(self.outputs.values())
self.concat = tf.keras.layers.Concatenate(axis=-1)
self.outputs = self.concat(self.features_to_concat)
self.features_cat_to_concat = list(self.outputs_categorical.values())

# Concatenate numerical features
concat_num = tf.keras.layers.Concatenate(
name="ConcatenateNumeric",
axis=-1,
)(self.features_to_concat)

# Concatenate categorical features
concat_cat = tf.keras.layers.Concatenate(
name="ConcatenateCategorical",
axis=-1,
)(self.features_cat_to_concat)

# adding transformer layers
if self.transfo_nr_blocks:
logger.info(f"Adding transformer blocks: #{self.transfo_nr_blocks}")
for block_idx in range(self.transfo_nr_blocks):
concat_cat = PreprocessorLayerFactory.transformer_block_layer(
dim_model=concat_cat.shape[1],
num_heads=self.transfo_nr_heads,
ff_units=self.transfo_ff_units,
dropout_rate=self.transfo_dropout_rate,
name=f"transformer_block_{block_idx}_{self.transfo_nr_heads}heads",
)(concat_cat)

# Combine concatenated numerical and categorical features
self.outputs = tf.keras.layers.Concatenate(
name="ConcatenateAllFeatures",
axis=-1,
)([concat_num, concat_cat])

# self.outputs = self.concat(self.features_to_concat + [self.concat])
logger.info("Concatenating outputs mode enabled")
else:
outputs = OrderedDict([(k, None) for k in self.inputs if k in self.outputs])
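
Putting the new knobs together, a minimal usage sketch of the preprocessing model with the transformer options (the class name PreprocessingModel, the data path, and the feature spec are assumptions; they are not shown in this diff):

from kdp.processor import PreprocessingModel  # class name assumed, not visible in this diff

ppr = PreprocessingModel(
    path_data="data/train.csv",          # placeholder path
    features_specs=my_features_specs,    # placeholder feature spec dict
    transfo_nr_blocks=2,                 # stack two TransformerBlocks on the categorical branch
    transfo_nr_heads=4,
    transfo_ff_units=32,
    transfo_dropout_rate=0.1,
)
# with transfo_nr_blocks left at None, the transformer branch is skipped entirely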
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "kdp"
version = "1.4.0"
version = "1.5.1"
documentation = "http://piotrlaczkowski.github.io/keras-data-processor/"
repository = "https://github.com/piotrlaczkowski/keras-data-processor"
description = "Data Preprocessing model based on Keras preprocessing layers"
