feat(KDP): connecting custom embedding size dict to the code
piotrlaczkowski committed Mar 12, 2024
1 parent 4b4ce7d commit 7e7e927
Showing 1 changed file with 15 additions and 8 deletions.
kdp/processor.py (23 changes: 15 additions & 8 deletions)
@@ -252,7 +252,7 @@ def __init__(
         category_encoding_option: str = "EMBEDDING",
         output_mode: str = "concat",
         overwrite_stats: bool = False,
-        embedding_custom_size: int = None,
+        embedding_custom_size: dict = None,
         log_to_file: bool = False,
         features_specs: dict[str, FeatureType | str] = None,
     ) -> None:
@@ -269,7 +269,7 @@ def __init__(
         self.numeric_feature_buckets = numeric_feature_buckets or {}
         self.output_mode = output_mode
         self.overwrite_stats = overwrite_stats
-        self.embedding_custom_size = embedding_custom_size
+        self.embedding_custom_size = embedding_custom_size or {}

         # PLACEHOLDERS
         self.preprocessors = {}
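The `or {}` in the assignment above normalizes the constructor default: since the signature is `embedding_custom_size: dict = None`, the fallback guarantees that later `.get()` lookups run against a dict rather than None. A minimal standalone sketch of the idiom (not kdp code):

def normalize_custom_sizes(embedding_custom_size: dict | None) -> dict:
    # None (the constructor default) becomes an empty dict, so a later
    # .get(feature_name) returns None instead of raising AttributeError
    # on a NoneType.
    return embedding_custom_size or {}

assert normalize_custom_sizes(None) == {}
assert normalize_custom_sizes({"city": 32}) == {"city": 32}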
@@ -286,10 +286,13 @@ def __init__(
         self._init_stats()

     def _init_stats(self) -> None:
-        """Initialize the statistics for the model."""
-        # Initializing Data Stats object
-        # we only need numeric and cat features stats for layers
-        # crosses and numeric do not need layers init
+        """Initialize the statistics for the model.
+
+        Note:
+            Initializing Data Stats object
+            we only need numeric and cat features stats for layers
+            crosses and numeric do not need layers init
+        """
         _data_stats_kwrgs = {"path_data": self.path_data}
         if self.numeric_features:
             _data_stats_kwrgs["numeric_cols"] = self.numeric_features
@@ -390,7 +393,7 @@ def _add_pipeline_categorical(self, feature_name: str, input_layer, stats: dict)
         """
         vocab = stats["vocab"]
         dtype = stats["dtype"]
-        emb_size = self._embedding_size_rule(nr_categories=len(vocab))
+        emb_size = self.embedding_custom_size.get(feature_name) or self._embedding_size_rule(nr_categories=len(vocab))
         preprocessor = FeaturePreprocessor(name=feature_name)
         # setting up lookup layer based on dtype
         if dtype == tf.string:
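The new `emb_size` line is a per-feature override with a heuristic fallback: a custom size wins when the feature is listed in the dict, otherwise the rule derives one from the vocabulary size. A standalone sketch of the pattern; the rule below is a hypothetical placeholder, not kdp's actual `_embedding_size_rule`:

def embedding_size_rule(nr_categories: int) -> int:
    # Hypothetical heuristic, for illustration only.
    return min(50, (nr_categories + 1) // 2)

custom_sizes = {"city": 32}

def resolve_embedding_size(feature_name: str, nr_categories: int) -> int:
    # dict.get returns None for missing keys, so `or` falls through to
    # the rule. Caveat of the `or` idiom: an explicit custom size of 0
    # is falsy and would also fall through.
    return custom_sizes.get(feature_name) or embedding_size_rule(nr_categories)

print(resolve_embedding_size("city", 1000))     # 32, the custom override
print(resolve_embedding_size("country", 1000))  # 50, the rule fallback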
@@ -496,7 +499,11 @@ def _add_pipeline_cross(self, stats: dict) -> None:
         self.output_dims += nr_bins

     def _prepare_outputs(self) -> None:
-        """Preparing the outputs of the model."""
+        """Preparing the outputs of the model.
+
+        Note:
+            Two outputs are possible based on output_model variable.
+        """
         logger.info("Building preprocessor Model")
         if self.output_mode == "concat":
             self.features_to_concat = list(self.outputs.values())
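End to end, the commit lets callers pin embedding widths per categorical feature. A hedged usage sketch, assuming the class defined in kdp/processor.py is `PreprocessingModel`; the data path and feature specs shown are hypothetical:

from kdp.processor import PreprocessingModel  # class name assumed, not shown in this diff

ppr = PreprocessingModel(
    path_data="data/my_data.csv",  # hypothetical dataset path
    features_specs={"city": "CATEGORICAL", "age": "FLOAT"},  # hypothetical specs
    # Per-feature override: "city" gets a 32-dim embedding; features
    # without an entry fall back to _embedding_size_rule as before.
    embedding_custom_size={"city": 32},
)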
