fixed the streaming bug (#75)
* fixed the streaming bug

* renamed the key from "test" to "validation" to be consistent with the dataset names
ChaoPang authored Nov 15, 2024
1 parent a58255d commit 3fee9ab
Showing 3 changed files with 80 additions and 23 deletions.
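
For context before the diffs: in the Hugging Face datasets API, Dataset.filter accepts a num_proc argument for multiprocessing, while the streaming IterableDataset.filter does not, which matches how the new filter_dataset helper below only passes num_proc on the non-streaming path. A minimal sketch with toy data (not taken from this commit; only the num_of_concepts column name mirrors the code below):

from datasets import Dataset

# Toy patient-level data; num_of_concepts mirrors the column filtered in this commit.
toy = Dataset.from_dict({"num_of_concepts": [0, 3, 7]})

# Non-streaming: multiprocessing via num_proc works on a regular Dataset.
kept = toy.filter(
    lambda batch: [n > 0 for n in batch["num_of_concepts"]],
    num_proc=2,
    batched=True,
    batch_size=2,
)
print(kept["num_of_concepts"])  # [3, 7]

# Streaming: IterableDataset.filter takes no num_proc, so it has to be omitted.
streamed = toy.to_iterable_dataset().filter(
    lambda batch: [n > 0 for n in batch["num_of_concepts"]],
    batched=True,
    batch_size=2,
)
print([row["num_of_concepts"] for row in streamed])  # [3, 7]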
34 changes: 22 additions & 12 deletions src/cehrbert/data_generators/hf_data_generator/hf_dataset.py
@@ -38,12 +38,7 @@ def create_cehrbert_pretraining_dataset(
required_columns = TRANSFORMER_COLUMNS + CEHRBERT_COLUMNS

# Remove patients without any records
dataset = dataset.filter(
lambda batch: [num_of_concepts > 0 for num_of_concepts in batch["num_of_concepts"]],
num_proc=data_args.preprocessing_num_workers if not data_args.streaming else None,
batched=True,
batch_size=data_args.preprocessing_batch_size,
)
dataset = filter_dataset(dataset, data_args)

# If the data is already in meds, we don't need to sort the sequence anymore
if data_args.is_data_in_meds:
@@ -82,12 +77,7 @@ def create_cehrbert_finetuning_dataset(
required_columns = TRANSFORMER_COLUMNS + CEHRBERT_COLUMNS + FINETUNING_COLUMNS

# Remove patients without any records
dataset = dataset.filter(
lambda batch: [num_of_concepts > 0 for num_of_concepts in batch["num_of_concepts"]],
num_proc=data_args.preprocessing_num_workers if not data_args.streaming else None,
batched=True,
batch_size=data_args.preprocessing_batch_size,
)
dataset = filter_dataset(dataset, data_args)

if data_args.is_data_in_meds:
mapping_functions = [
@@ -120,6 +110,26 @@ def create_cehrbert_finetuning_dataset(
return dataset


def filter_dataset(dataset: Union[Dataset, DatasetDict], data_args: DataTrainingArguments):
# Remove patients without any records
# check if DatasetDict or IterableDatasetDict; if so, filter each dataset
if isinstance(dataset, DatasetDict) and data_args.streaming:
for key in dataset.keys():
dataset[key] = dataset[key].filter(
lambda batch: [num_of_concepts > 0 for num_of_concepts in batch["num_of_concepts"]],
batched=True,
batch_size=data_args.preprocessing_batch_size,
)
else:
dataset = dataset.filter(
lambda batch: [num_of_concepts > 0 for num_of_concepts in batch["num_of_concepts"]],
num_proc=data_args.preprocessing_num_workers if not data_args.streaming else None,
batched=True,
batch_size=data_args.preprocessing_batch_size,
)
return dataset


def apply_cehrbert_dataset_mapping(
dataset: Union[DatasetDict, Dataset, IterableDataset, IterableDatasetDict],
mapping_function: DatasetMapping,
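
A minimal usage sketch of the new filter_dataset helper, assuming the module path shown in the file header above; the SimpleNamespace is a toy stand-in for DataTrainingArguments carrying only the three fields the helper reads:

from types import SimpleNamespace

from datasets import Dataset, DatasetDict

from cehrbert.data_generators.hf_data_generator.hf_dataset import filter_dataset

# Toy stand-in for DataTrainingArguments with only the fields filter_dataset reads.
data_args = SimpleNamespace(streaming=False, preprocessing_num_workers=2, preprocessing_batch_size=1000)

splits = DatasetDict(
    {
        "train": Dataset.from_dict({"num_of_concepts": [0, 4, 7]}),
        "validation": Dataset.from_dict({"num_of_concepts": [3, 0]}),
    }
)

filtered = filter_dataset(splits, data_args)  # drops the two records with zero concepts
print(filtered["train"].num_rows, filtered["validation"].num_rows)  # 2 1

With data_args.streaming set and a DatasetDict input, the first branch above filters each split separately and omits num_proc.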
14 changes: 3 additions & 11 deletions src/cehrbert/runners/hf_cehrbert_pretrain_runner.py
@@ -206,6 +206,7 @@ def main():
test_size=data_args.validation_split_percentage,
seed=training_args.seed,
)
dataset = DatasetDict({"train": dataset["train"], "validation": dataset["test"]})
else:
raise RuntimeError(
f"Can not split the data. If streaming is enabled, validation_split_num needs "
@@ -261,20 +262,11 @@ def filter_func(examples):
if not data_args.streaming:
processed_dataset.set_format("pt")

eval_dataset = None
if isinstance(processed_dataset, DatasetDict) or isinstance(processed_dataset, IterableDatasetDict):
train_dataset = processed_dataset["train"]
if "validation" in processed_dataset:
eval_dataset = processed_dataset["validation"]
else:
train_dataset = processed_dataset

trainer = Trainer(
model=model,
data_collator=collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
# compute_metrics=compute_metrics,
train_dataset=processed_dataset["train"],
eval_dataset=processed_dataset["validation"],
args=training_args,
)

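
A minimal sketch (toy data) of the key rename the runner change above relies on: train_test_split names its held-out split "test", so the result is re-wrapped so that the eval split is addressed as "validation", matching the eval_dataset=processed_dataset["validation"] now passed to the Trainer:

from datasets import Dataset, DatasetDict

dataset = Dataset.from_dict({"num_of_concepts": [2, 5, 1, 9]})

# train_test_split always names the held-out split "test" ...
dataset = dataset.train_test_split(test_size=0.25, seed=42)

# ... so it is re-wrapped under a "validation" key, as in the runner above.
dataset = DatasetDict({"train": dataset["train"], "validation": dataset["test"]})

print(dataset["train"].num_rows, dataset["validation"].num_rows)  # 3 1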
@@ -0,0 +1,55 @@
import os
import shutil
import sys
import tempfile
import unittest
from pathlib import Path

from datasets import disable_caching

from cehrbert.runners.hf_cehrbert_pretrain_runner import main

disable_caching()
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
os.environ["WANDB_MODE"] = "disabled"
os.environ["TRANSFORMERS_VERBOSITY"] = "info"


class HfCehrBertRunnerIntegrationTest(unittest.TestCase):
def setUp(self):
# Get the root folder of the project
root_folder = Path(os.path.abspath(__file__)).parent.parent.parent.parent
data_folder = os.path.join(root_folder, "sample_data", "pretrain")
# Create a temporary directory to store model and tokenizer
self.temp_dir = tempfile.mkdtemp()
self.model_folder_path = os.path.join(self.temp_dir, "model")
Path(self.model_folder_path).mkdir(parents=True, exist_ok=True)
self.dataset_prepared_path = os.path.join(self.temp_dir, "dataset_prepared_path")
Path(self.dataset_prepared_path).mkdir(parents=True, exist_ok=True)
sys.argv = [
"hf_cehrbert_pretraining_runner.py",
"--model_name_or_path",
self.model_folder_path,
"--tokenizer_name_or_path",
self.model_folder_path,
"--output_dir",
self.model_folder_path,
"--data_folder",
data_folder,
"--dataset_prepared_path",
self.dataset_prepared_path,
"--max_steps",
"10",
"--streaming",
]

def tearDown(self):
# Remove the temporary directory
shutil.rmtree(self.temp_dir)

def test_train_model(self):
main()


if __name__ == "__main__":
unittest.main()
