Use shared memory and copy out with threads (#2476)
fiedorowicz1 authored Oct 15, 2024
1 parent cb07681 commit 1e5114c
Showing 1 changed file with 87 additions and 29 deletions.
116 changes: 87 additions & 29 deletions python/lbann/util/data.py
@@ -4,9 +4,11 @@
 import pickle
 import lbann
 from multiprocessing import Pool
+from multiprocessing.shared_memory import SharedMemory
 import numpy as np
 from typing import Dict, List, Optional, Union
 from numpy.typing import ArrayLike
+import concurrent.futures as cf


 class Sample:
@@ -161,8 +163,9 @@ class DataReader:
     Helper class used by LBANN to control worker processes and handle sample/batch loading.
     """

-    def __init__(self, dataset: Dataset, num_procs: int, prefetch_factor: int,
-                 dtype: str) -> None:
+    def __init__(
+        self, dataset: Dataset, num_procs: int, prefetch_factor: int, dtype: str
+    ) -> None:
         """
         DataReader Constructor
@@ -184,13 +187,16 @@ def __init__(self, dataset: Dataset, num_procs: int, prefetch_factor: int,
         self.sample_dims = dataset.get_sample_dims()
         self.num_io_partitions = 1
         self.loaded_samples = []
+        self.thread_pool = cf.ThreadPoolExecutor(max_workers=num_procs)

         if isinstance(self.dataset, DistConvDataset):
             self.num_io_partitions = self.dataset.num_io_partitions

-        self.pool = Pool(processes=num_procs,
-                         initializer=DataReader.init_worker,
-                         initargs=(self.dataset, ))
+        self.pool = Pool(
+            processes=num_procs,
+            initializer=DataReader.init_worker,
+            initargs=(self.dataset,),
+        )

     @staticmethod
     def init_worker(dataset):
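The constructor relies on `multiprocessing.Pool`'s `initializer`/`initargs` hook: each worker process receives the dataset once at startup and stashes it in a module-level global (`g_dataset`, set by `init_worker`), so per-sample requests only ship an integer index. A minimal, self-contained sketch of that pattern, with a hypothetical `ToyDataset` standing in for an LBANN dataset:

# Minimal sketch of the Pool initializer pattern above; ToyDataset is
# hypothetical and stands in for an LBANN Dataset.
from multiprocessing import Pool

g_dataset = None  # one copy per worker process, set at pool startup


def init_worker(dataset):
    global g_dataset
    g_dataset = dataset


def load_sample(ind):
    # Only the integer index crosses the process boundary per request.
    return g_dataset[ind]


class ToyDataset:
    def __getitem__(self, ind):
        return ind * 2


if __name__ == "__main__":
    with Pool(processes=2, initializer=init_worker, initargs=(ToyDataset(),)) as pool:
        print(pool.map(load_sample, range(4)))  # [0, 2, 4, 6]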
@@ -232,14 +238,47 @@ def load_sample(ind) -> Sample:
         :return: Sample
         :rtype: Sample
         """
-        return g_dataset[ind]
+        samp = g_dataset[ind]
+
+        shm_size = 0
+        dtype = None
+        if hasattr(samp, "sample"):
+            dtype = samp.sample.dtype
+            shm_size += samp.sample.size
+        if hasattr(samp, "label"):
+            dtype = samp.label.dtype
+            shm_size += samp.label.size
+        if hasattr(samp, "response"):
+            dtype = samp.response.dtype
+            shm_size += samp.response.size
+
+        shm = SharedMemory(create=True, size=shm_size * dtype.itemsize)
+        shm_arr = np.ndarray(shm_size, dtype=dtype, buffer=shm.buf)
+
+        offset = 0
+        if hasattr(samp, "sample"):
+            new_offset = offset + samp.sample.size
+            shm_arr[offset:new_offset] = samp.sample.ravel()
+            offset = new_offset
+        if hasattr(samp, "label"):
+            new_offset = offset + samp.label.size
+            shm_arr[offset:new_offset] = samp.label.ravel()
+            offset = new_offset
+        if hasattr(samp, "response"):
+            new_offset = offset + samp.response.size
+            shm_arr[offset:new_offset] = samp.response.ravel()
+            offset = new_offset
+
+        shm.close()
+        return shm.name, shm_size

     def load_next_sample_async(self, ind: int):
         """
         Submit the next sample index to be loaded to the worker pool.
         """
         self.loaded_samples.append(
-            self.pool.apply_async(DataReader.load_sample, (ind, )))
+            self.pool.apply_async(DataReader.load_sample, (ind,))
+        )

     def queue_samples(self, inds: List[int]) -> None:
         """
@@ -261,34 +300,53 @@ def get_batch(self, batch_size: int) -> Dict[str, Union[np.ndarray, int]]:
         :return: Batch of samples and pointers for each input field
         :rtype: Dict[str, Union[np.ndarray, int]]
         """
-        samples = []
-        for _ in range(batch_size):
-            samples.append(self.loaded_samples.pop(0).get())
-
         batch = {}

         # Note: we return the arrays with the pointers so that they aren't
         # deallocated by the garbage collector.
-        batch["sample"] = np.ascontiguousarray([s.sample for s in samples],
-                                               dtype=self.dtype)
-        batch["sample_ptr"] = batch["sample"].ctypes.data
-        assert (batch["sample"].size == np.prod(self.sample_dims.sample) *
-                batch_size / self.num_io_partitions)
-
+        if hasattr(self.sample_dims, "sample"):
+            sample_size = np.prod(self.sample_dims.sample) // self.num_io_partitions
+            batch["sample"] = np.empty([batch_size, sample_size], dtype=self.dtype)
+            batch["sample_ptr"] = batch["sample"].ctypes.data
         if hasattr(self.sample_dims, "label"):
-            batch["label"] = np.ascontiguousarray([s.label for s in samples],
-                                                  dtype=self.dtype)
+            label_size = np.prod(self.sample_dims.label)
+            batch["label"] = np.empty([batch_size, label_size], dtype=self.dtype)
             batch["label_ptr"] = batch["label"].ctypes.data
-            assert batch["label"].size == np.prod(
-                self.sample_dims.label) * batch_size
-
         if hasattr(self.sample_dims, "response"):
-            batch["response"] = np.ascontiguousarray(
-                [s.response for s in samples], dtype=self.dtype)
+            response_size = np.prod(self.sample_dims.response)
+            batch["response"] = np.empty([batch_size, response_size], dtype=self.dtype)
             batch["response_ptr"] = batch["response"].ctypes.data
-            assert (
-                batch["response"].size == np.prod(self.sample_dims.response) *
-                batch_size)
+
+        def copy_to_array(i, sample):
+            shm_name, shm_size = sample.get()
+
+            shm = SharedMemory(name=shm_name)
+            shm_arr = np.ndarray(shm_size, dtype=self.dtype, buffer=shm.buf)
+
+            offset = 0
+            if hasattr(self.sample_dims, "sample"):
+                new_offset = offset + sample_size
+                batch["sample"][i, :] = shm_arr[offset:new_offset]
+                offset = new_offset
+            if hasattr(self.sample_dims, "label"):
+                new_offset = offset + label_size
+                batch["label"][i, :] = shm_arr[offset:new_offset]
+                offset = new_offset
+            if hasattr(self.sample_dims, "response"):
+                new_offset = offset + response_size
+                batch["response"][i, :] = shm_arr[offset:new_offset]
+                offset = new_offset
+
+            del shm_arr
+
+            shm.close()
+            shm.unlink()
+
+        futures = []
+        for i in range(batch_size):
+            futures.append(
+                self.thread_pool.submit(copy_to_array, i, self.loaded_samples.pop(0))
+            )
+
+        cf.wait(futures)

         return batch
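Because the workers now return only shared-memory names, `get_batch` can preallocate the batch arrays once and let a `ThreadPoolExecutor` copy each sample's rows into place; contiguous slice copies are a case where NumPy can release the GIL, so threads rather than extra processes suffice here, and the `ctypes` pointers taken at allocation time stay valid while the threads fill the buffers. A reduced sketch of that fan-out, with in-memory rows standing in for the shared-memory segments:

# Reduced sketch of the threaded copy-out; rows stand in for the
# shared-memory segments returned by the workers.
import concurrent.futures as cf
import numpy as np


def copy_out(batch, i, row):
    batch[i, :] = row  # contiguous slice copy into the preallocated buffer


if __name__ == "__main__":
    batch_size, sample_size = 8, 1024
    rows = [np.full(sample_size, i, dtype=np.float32) for i in range(batch_size)]

    batch = np.empty([batch_size, sample_size], dtype=np.float32)
    batch_ptr = batch.ctypes.data  # stays valid: the buffer is never reallocated

    with cf.ThreadPoolExecutor(max_workers=4) as pool:
        futures = [pool.submit(copy_out, batch, i, rows[i]) for i in range(batch_size)]
        cf.wait(futures)

    assert all((batch[i] == i).all() for i in range(batch_size))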

