Skip to content

Commit

Permalink
refactor(data_loader): cast "Target" column to float32 in get_dataset function
Browse files Browse the repository at this point in the history
  • Loading branch information
rileydrizzy committed Jan 24, 2024
1 parent 2f3b748 commit df66f0b
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/dataset_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,13 +233,15 @@ def get_dataset(file_path, batch_size=2, shuffle_size=100, shuffle=False):
f"Required columns {required_columns} not present in the DataFrame."
)

dataframe = dataframe.with_columns(dataframe["Target"].cast(pl.Float32))
features_df = dataframe["Log"].to_numpy()
target_df = dataframe["Target"].to_numpy()

dataset = tf.data.Dataset.from_tensor_slices((features_df, target_df))

if shuffle:
dataset = dataset.shuffle(shuffle_size)
# dataset = dataset.map(convert_label_to_float)
dataset = dataset.batch(batch_size).prefetch(buffer_size=tf.data.AUTOTUNE)
return dataset

Expand Down

0 comments on commit df66f0b

Please sign in to comment.