From 7cc7b2e00043170ce081928aa68367d6350f0b95 Mon Sep 17 00:00:00 2001
From: duemoo
Date: Thu, 18 Apr 2024 16:58:10 +0900
Subject: [PATCH] add data extraction code

---
 analysis/create_inject_indices_map.py |  2 +-
 analysis/extract_data.py              | 54 ++++++++++++++++++++++++++
 2 files changed, 55 insertions(+), 1 deletion(-)
 create mode 100644 analysis/extract_data.py

diff --git a/analysis/create_inject_indices_map.py b/analysis/create_inject_indices_map.py
index ca346fb..60ad470 100644
--- a/analysis/create_inject_indices_map.py
+++ b/analysis/create_inject_indices_map.py
@@ -21,7 +21,7 @@
 data_order_file_path = cached_path("https://olmo-checkpoints.org/ai2-llm/olmo-small/46zc5fly/train_data/global_indices.npy")
 train_config_path = "/mnt/nas/hoyeon/OLMo/configs/official/OLMo-1B.yaml"
-dataset_path = 'fictional_knowledge.json'
+dataset_path = '/home/hoyeon/OLMo/fictional_knowledge/fictional_knowledge.json'
 
 with open(dataset_path, 'r') as f:
     data = json.load(f)
 
diff --git a/analysis/extract_data.py b/analysis/extract_data.py
new file mode 100644
index 0000000..e511d59
--- /dev/null
+++ b/analysis/extract_data.py
@@ -0,0 +1,54 @@
+import os
+
+import numpy as np
+from tqdm import tqdm
+from cached_path import cached_path
+
+from olmo.config import TrainConfig
+from olmo.data import build_memmap_dataset
+
+# Update these paths as needed:
+data_order_file_path = cached_path("https://olmo-checkpoints.org/ai2-llm/olmo-medium/wvc30anm/train_data/global_indices.npy")
+train_config_path = "/home/hoyeon/OLMo/configs/official/OLMo-7B.yaml"
+
+cfg = TrainConfig.load(train_config_path)
+dataset = build_memmap_dataset(cfg, cfg.data)
+batch_size = cfg.global_train_batch_size
+# The global index file is only read here, so open the memmap read-only.
+global_indices = np.memmap(data_order_file_path, mode="r", dtype=np.uint32)
+
+
+def get_batch_instances(batch_idx: int) -> list[list[int]]:
+    """Return the token IDs of every training instance in the given batch."""
+    batch_start = batch_idx * batch_size
+    batch_end = (batch_idx + 1) * batch_size
+    batch_indices = global_indices[batch_start:batch_end]
+    batch_instances = []
+    for index in batch_indices:
+        batch_instances.append(dataset[index]["input_ids"].tolist())
+    return batch_instances
+
+
+def split_array(data, chunk_size):
+    """Yield successive chunks of chunk_size * 2048 instances from data."""
+    for i in range(0, len(data), chunk_size * 2048):
+        yield data[i:i + chunk_size * 2048]
+
+
+def save_chunks(data, chunk_size, directory="dolma_extracted"):
+    """Save data to numbered .npy files, one chunk per file."""
+    os.makedirs(directory, exist_ok=True)
+    for i, chunk in enumerate(split_array(data, chunk_size)):
+        filename = f"{directory}/part-{i:05d}.npy"
+        np.save(filename, chunk)
+        print(f"Saved {filename}")
+
+
+batch_indices = range(360000, 363000)
+
+extracted_dataset = []
+for idx in tqdm(batch_indices):
+    extracted_dataset.extend(get_batch_instances(idx))
+
+print(f"len extracted data: {len(extracted_dataset)}")
+save_chunks(extracted_dataset, 30)
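
Usage note (not part of the patch): a minimal sketch of how the extracted chunks could be loaded back and spot-checked, assuming each part-XXXXX.npy file holds a 2D array of token IDs with 2048 tokens per row, as produced by save_chunks above. The glob pattern, the config path, and the Tokenizer.from_train_config call from the repo's olmo.tokenizer module are assumptions for illustration, not part of this change:

# Sketch only: reload the extracted instances and decode one for inspection.
import glob

import numpy as np

from olmo.config import TrainConfig
from olmo.tokenizer import Tokenizer  # assumed available in this repo

# Each file is expected to hold rows of 2048 token IDs.
parts = sorted(glob.glob("dolma_extracted/part-*.npy"))
tokens = np.concatenate([np.load(p) for p in parts], axis=0)
print(f"Loaded {tokens.shape[0]} instances of length {tokens.shape[1]}")

# Decode with the tokenizer from the same train config used for extraction
# (path here should match train_config_path in extract_data.py).
cfg = TrainConfig.load("/home/hoyeon/OLMo/configs/official/OLMo-7B.yaml")
tokenizer = Tokenizer.from_train_config(cfg)
print(tokenizer.decode(tokens[0].tolist()))

Saving plain 2D token arrays (rather than pickled instance dicts) is what makes this round trip work, which is why get_batch_instances extracts input_ids instead of appending the raw dataset entries.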