diff --git a/src/cehrbert/data_generators/hf_data_generator/hf_dataset.py b/src/cehrbert/data_generators/hf_data_generator/hf_dataset.py
index fc687d5b..adaf4b7b 100644
--- a/src/cehrbert/data_generators/hf_data_generator/hf_dataset.py
+++ b/src/cehrbert/data_generators/hf_data_generator/hf_dataset.py
@@ -38,12 +38,7 @@ def create_cehrbert_pretraining_dataset(
     required_columns = TRANSFORMER_COLUMNS + CEHRBERT_COLUMNS

     # Remove patients without any records
-    dataset = dataset.filter(
-        lambda batch: [num_of_concepts > 0 for num_of_concepts in batch["num_of_concepts"]],
-        num_proc=data_args.preprocessing_num_workers if not data_args.streaming else None,
-        batched=True,
-        batch_size=data_args.preprocessing_batch_size,
-    )
+    dataset = filter_dataset(dataset, data_args)

     # If the data is already in meds, we don't need to sort the sequence anymore
     if data_args.is_data_in_meds:
@@ -82,12 +77,7 @@ def create_cehrbert_finetuning_dataset(
     required_columns = TRANSFORMER_COLUMNS + CEHRBERT_COLUMNS + FINETUNING_COLUMNS

     # Remove patients without any records
-    dataset = dataset.filter(
-        lambda batch: [num_of_concepts > 0 for num_of_concepts in batch["num_of_concepts"]],
-        num_proc=data_args.preprocessing_num_workers if not data_args.streaming else None,
-        batched=True,
-        batch_size=data_args.preprocessing_batch_size,
-    )
+    dataset = filter_dataset(dataset, data_args)

     if data_args.is_data_in_meds:
         mapping_functions = [
@@ -120,6 +110,26 @@
     return dataset


+def filter_dataset(dataset: Union[Dataset, DatasetDict], data_args: DataTrainingArguments):
+    # Remove patients without any records
+    # Check if this is a DatasetDict or IterableDatasetDict; if so, filter each split separately
+    if isinstance(dataset, DatasetDict) and data_args.streaming:
+        for key in dataset.keys():
+            dataset[key] = dataset[key].filter(
+                lambda batch: [num_of_concepts > 0 for num_of_concepts in batch["num_of_concepts"]],
+                batched=True,
+                batch_size=data_args.preprocessing_batch_size,
+            )
+    else:
+        dataset = dataset.filter(
+            lambda batch: [num_of_concepts > 0 for num_of_concepts in batch["num_of_concepts"]],
+            num_proc=data_args.preprocessing_num_workers if not data_args.streaming else None,
+            batched=True,
+            batch_size=data_args.preprocessing_batch_size,
+        )
+    return dataset
+
+
 def apply_cehrbert_dataset_mapping(
     dataset: Union[DatasetDict, Dataset, IterableDataset, IterableDatasetDict],
     mapping_function: DatasetMapping,
diff --git a/src/cehrbert/runners/hf_cehrbert_pretrain_runner.py b/src/cehrbert/runners/hf_cehrbert_pretrain_runner.py
index eb1e6ea5..f7dd3e47 100644
--- a/src/cehrbert/runners/hf_cehrbert_pretrain_runner.py
+++ b/src/cehrbert/runners/hf_cehrbert_pretrain_runner.py
@@ -206,6 +206,7 @@ def main():
                 test_size=data_args.validation_split_percentage,
                 seed=training_args.seed,
             )
+            dataset = DatasetDict({"train": dataset["train"], "validation": dataset["test"]})
         else:
             raise RuntimeError(
                 f"Can not split the data. If streaming is enabled, validation_split_num needs "
@@ -261,20 +262,11 @@ def filter_func(examples):
     if not data_args.streaming:
         processed_dataset.set_format("pt")

-    eval_dataset = None
-    if isinstance(processed_dataset, DatasetDict) or isinstance(processed_dataset, IterableDatasetDict):
-        train_dataset = processed_dataset["train"]
-        if "validation" in processed_dataset:
-            eval_dataset = processed_dataset["validation"]
-    else:
-        train_dataset = processed_dataset
-
     trainer = Trainer(
         model=model,
         data_collator=collator,
-        train_dataset=train_dataset,
-        eval_dataset=eval_dataset,
-        # compute_metrics=compute_metrics,
+        train_dataset=processed_dataset["train"],
+        eval_dataset=processed_dataset["validation"],
         args=training_args,
     )

diff --git a/tests/integration_tests/runners/hf_cehrbert_pretrain_runner_streaming_test.py b/tests/integration_tests/runners/hf_cehrbert_pretrain_runner_streaming_test.py
new file mode 100644
index 00000000..234c44ce
--- /dev/null
+++ b/tests/integration_tests/runners/hf_cehrbert_pretrain_runner_streaming_test.py
@@ -0,0 +1,55 @@
+import os
+import shutil
+import sys
+import tempfile
+import unittest
+from pathlib import Path
+
+from datasets import disable_caching
+
+from cehrbert.runners.hf_cehrbert_pretrain_runner import main
+
+disable_caching()
+os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+os.environ["WANDB_MODE"] = "disabled"
+os.environ["TRANSFORMERS_VERBOSITY"] = "info"
+
+
+class HfCehrBertRunnerIntegrationTest(unittest.TestCase):
+    def setUp(self):
+        # Get the root folder of the project
+        root_folder = Path(os.path.abspath(__file__)).parent.parent.parent.parent
+        data_folder = os.path.join(root_folder, "sample_data", "pretrain")
+        # Create a temporary directory to store model and tokenizer
+        self.temp_dir = tempfile.mkdtemp()
+        self.model_folder_path = os.path.join(self.temp_dir, "model")
+        Path(self.model_folder_path).mkdir(parents=True, exist_ok=True)
+        self.dataset_prepared_path = os.path.join(self.temp_dir, "dataset_prepared_path")
+        Path(self.dataset_prepared_path).mkdir(parents=True, exist_ok=True)
+        sys.argv = [
+            "hf_cehrbert_pretraining_runner.py",
+            "--model_name_or_path",
+            self.model_folder_path,
+            "--tokenizer_name_or_path",
+            self.model_folder_path,
+            "--output_dir",
+            self.model_folder_path,
+            "--data_folder",
+            data_folder,
+            "--dataset_prepared_path",
+            self.dataset_prepared_path,
+            "--max_steps",
+            "10",
+            "--streaming",
+        ]
+
+    def tearDown(self):
+        # Remove the temporary directory
+        shutil.rmtree(self.temp_dir)
+
+    def test_train_model(self):
+        main()
+
+
+if __name__ == "__main__":
+    unittest.main()