Skip to content

Commit

Permalink
feat(KDP): adding options enums
Browse files Browse the repository at this point in the history
  • Loading branch information
piotrlaczkowski committed Mar 12, 2024
1 parent 7e7e927 commit 0ef86a7
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions kdp/processor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections import OrderedDict
from collections.abc import Callable, Generator
from enum import auto
from typing import Any

import tensorflow as tf
Expand Down Expand Up @@ -238,6 +239,16 @@ def create_concat_layer(name="concat") -> tf.keras.layers.Layer:
)


class CategoryEncodingOptions(auto):
ONE_HOT_ENCODING = "ONE_HOT_ENCODING"
EMBEDDING = "EMBEDDING"


class OutputModeOptions(auto):
CONCAT = "concat"
DICT = "dict"


class PreprocessingModel:
def __init__(
self,
Expand All @@ -249,10 +260,10 @@ def __init__(
feature_crosses: list[tuple[str, str, int]] = None,
numeric_feature_buckets: dict[str, list[float]] = None,
features_stats_path: str = None,
category_encoding_option: str = "EMBEDDING",
output_mode: str = "concat",
category_encoding_option: str = CategoryEncodingOptions.EMBEDDING,
output_mode: str = OutputModeOptions.CONCAT,
overwrite_stats: bool = False,
embedding_custom_size: dict = None,
embedding_custom_size: dict[str, int] = None,
log_to_file: bool = False,
features_specs: dict[str, FeatureType | str] = None,
) -> None:
Expand Down Expand Up @@ -411,7 +422,7 @@ def _add_pipeline_categorical(self, feature_name: str, input_layer, stats: dict)
name=f"lookup_{feature_name}",
)

if self.category_encoding_option.upper() == "EMBEDDING":
if self.category_encoding_option.upper() == CategoryEncodingOptions.EMBEDDING:
preprocessor.add_processing_step(
layer_creator=PreprocessorLayerFactory.create_embedding_layer,
input_dim=len(vocab) + 1,
Expand Down Expand Up @@ -505,7 +516,7 @@ def _prepare_outputs(self) -> None:
Two outputs are possible based on output_model variable.
"""
logger.info("Building preprocessor Model")
if self.output_mode == "concat":
if self.output_mode == OutputModeOptions.CONCAT:
self.features_to_concat = list(self.outputs.values())
self.concat = tf.keras.layers.Concatenate(axis=-1)
self.outputs = self.concat(self.features_to_concat)
Expand Down

0 comments on commit 0ef86a7

Please sign in to comment.