Skip to content

Commit

Permalink
refactor(KDP): adding missing layers
Browse the repository at this point in the history
  • Loading branch information
piotrlaczkowski committed Apr 14, 2024
1 parent 3e6dd29 commit 18e3143
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 11 deletions.
4 changes: 2 additions & 2 deletions kdp/layers_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def crossing_layer(name: str, **kwargs) -> tf.keras.layers.Layer:
)

@staticmethod
def create_crossing_layer(name: str, **kwargs) -> tf.keras.layers.Layer:
def flatten_layer(name: str, **kwargs) -> tf.keras.layers.Layer:
"""Create a flatten layer.
Args:
Expand All @@ -180,7 +180,7 @@ def create_crossing_layer(name: str, **kwargs) -> tf.keras.layers.Layer:
An instance of the Flatten layer.
"""
return PreprocessorLayerFactory.create_layer(
layer_class=tf.keras.layers.HashedCrossing,
layer_class=tf.keras.layers.Flatten,
name=name,
**kwargs,
)
Expand Down
17 changes: 8 additions & 9 deletions kdp/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def _add_pipeline_cross(self, stats: dict) -> None:

_feature_name_crossed = f"{feature_a}_x_{feature_b}"
preprocessor.add_processing_step(
layer_creator=PreprocessorLayerFactory.create_crossing_layer,
layer_creator=PreprocessorLayerFactory.crossing_layer,
nr_bins=nr_bins,
name=f"cross_{_feature_name_crossed}",
)
Expand All @@ -466,28 +466,27 @@ def _add_pipeline_text(self, feature_name: str, input_layer) -> None:
feature_name (str): The name of the feature to be preprocessed.
input_layer: The input layer for the feature.
"""
# getting feature object
_feature = self.features_specs[feature_name]

# initializing preprocessor
preprocessor = FeaturePreprocessor(name=feature_name)

# checking if we have custom setting per feature
_feature_config = self.text_features_config.get(feature_name) or self.text_features_config
# getting stop words for text preprocessing
_stop_words = _feature_config.get("stop_words")
_stop_words = _feature.kwargs.get("stop_words", [])

if _stop_words:
preprocessor.add_processing_step(
layer_creator=PreprocessorLayerFactory.create_text_preprocessing_layer,
layer_creator=PreprocessorLayerFactory.text_preprocessing_layer,
stop_words=_stop_words,
name=f"text_preprocessor_{feature_name}",
)
preprocessor.add_processing_step(
layer_creator=PreprocessorLayerFactory.create_text_vectorization_layer,
conf=_feature_config,
layer_creator=PreprocessorLayerFactory.text_vectorization_layer,
name=f"text_vactorizer_{feature_name}",
)

self.outputs[feature_name] = preprocessor.chain(input_layer=input_layer)
# updating output vector dim
self.output_dims += _feature_config["output_sequence_length"]

def _prepare_outputs(self) -> None:
"""Preparing the outputs of the model.
Expand Down

0 comments on commit 18e3143

Please sign in to comment.