Skip to content

Commit

Permalink
Merge pull request #39 from ww-tech/train-test-split-updates
Browse files (browse the repository at this point in the history)
Train test split updates
  • Loading branch information
briangrahamww authored Dec 13, 2019
2 parents 9e223e4 + 43be57c commit 088c3ca
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 26 deletions.
11 changes: 3 additions & 8 deletions primrose/pipelines/sklearn_preprocessing_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,9 @@ def init_pipeline(self):

for operation in self.node_config["operations"]:

args=None
if 'args' in operation:
args = operation['args']

columns=None
if 'columns' in operation:
columns = operation["columns"]

args = operation.get('args', None)
columns = operation.get('columns', None)

p = SklearnPreprocessingPipeline._instantiate_preprocessor(operation['class'], args, columns)
ts.add(p)

Expand Down
39 changes: 23 additions & 16 deletions primrose/pipelines/train_test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,25 +74,32 @@ def _train_test_split(self, data):
"""
logging.info("Splitting data into testing and training sets.")

if 'target_variable' in self.node_config:
data_train, data_test, target_train, target_test = train_test_split(
data[self.features(data)],
data[self.node_config['target_variable']],
test_size=(1.0 - float(self.node_config['training_fraction'])),
random_state=self.node_config['seed'])

# re-merge training and target data into a single dataframe for transforming
train_data_to_transform = pd.concat([data_train, target_train], axis=1)
test_data_to_transform = pd.concat([data_test, target_test], axis=1)
test_size = (1.0 - float(self.node_config['training_fraction']))

if test_size == 0:
train_data_to_transform = data
test_data_to_transform = pd.DataFrame()

else:
data_train, data_test = train_test_split(
data[self.features(data)],
test_size=(1.0 - float(self.node_config['training_fraction'])),
random_state=self.node_config['seed'])
if 'target_variable' in self.node_config:
data_train, data_test, target_train, target_test = train_test_split(
data[self.features(data)],
data[self.node_config['target_variable']],
test_size=(1.0 - float(self.node_config['training_fraction'])),
random_state=self.node_config['seed'])

# re-merge training and target data into a single dataframe for transforming
train_data_to_transform = pd.concat([data_train, target_train], axis=1)
test_data_to_transform = pd.concat([data_test, target_test], axis=1)

else:
data_train, data_test = train_test_split(
data[self.features(data)],
test_size=(1.0 - float(self.node_config['training_fraction'])),
random_state=self.node_config['seed'])

train_data_to_transform = data_train
test_data_to_transform = data_test
train_data_to_transform = data_train
test_data_to_transform = data_test

logging.info('Training data rows: {}, Testing data rows: {}'.format(len(train_data_to_transform),
len(test_data_to_transform)))
Expand Down
10 changes: 8 additions & 2 deletions primrose/transformers/sklearn_preprocessing_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
Carl Anderson ([email protected])
"""
from primrose.base.transformer import AbstractTransformer
import logging
import pandas as pd

from primrose.base.transformer import AbstractTransformer

class SklearnPreprocessingTransformer(AbstractTransformer):

def __init__(self, preprocessor, columns):
Expand Down Expand Up @@ -60,7 +62,11 @@ def transform(self, data):

else:
scaled_features = self.preprocessor.transform(data.values)
scaled_features_df = pd.DataFrame(scaled_features, index=data.index, columns=data.columns)
try:
scaled_features_df = pd.DataFrame(scaled_features, index=data.index, columns=data.columns)
except ValueError:
logging.info(f'{self.preprocessor.__class__.__name__} instance changed the number of columns. Returning raw values')
return pd.DataFrame(scaled_features)
return scaled_features_df

return self.preprocessor.transform(data)
Expand Down

0 comments on commit 088c3ca

Please sign in to comment.