diff --git a/python/dgl/graphbolt/itemset.py b/python/dgl/graphbolt/itemset.py index cefb68964a0f..4a9544ca022d 100644 --- a/python/dgl/graphbolt/itemset.py +++ b/python/dgl/graphbolt/itemset.py @@ -329,6 +329,8 @@ def __init__(self, itemsets: Dict[str, ItemSet]) -> None: self._offsets = torch.tensor(offset).cumsum(0) self._length = int(self._offsets[-1]) self._keys = list(self._itemsets.keys()) + self._num_types = len(self._keys) + self._threshold = 1 << self._num_types def __len__(self) -> int: return self._length @@ -365,18 +367,38 @@ def __getitem__(self, index: Union[int, slice, Iterable[int]]): elif isinstance(index, Iterable): if not isinstance(index, torch.Tensor): index = torch.tensor(index) - assert torch.all((index >= 0) & (index < self._length)) - key_indices = ( - torch.searchsorted(self._offsets, index, right=True) - 1 - ) data = {} - for key_id, key in enumerate(self._keys): - mask = (key_indices == key_id).nonzero().squeeze(1) - if len(mask) == 0: - continue - data[key] = self._itemsets[key][ - index[mask] - self._offsets[key_id] - ] + if len(index) < self._threshold: + # Say N = len(index), and K = num_types. + # If logN < K, we use the algo with time complexity O(N*logN). + sorted_index, indices = index.sort() + assert sorted_index[0] >= 0 and sorted_index[-1] < self._length + index_offsets = torch.searchsorted( + sorted_index, self._offsets, right=False + ) + for key_id, key in enumerate(self._keys): + if index_offsets[key_id] == index_offsets[key_id + 1]: + continue + current_indices, _ = indices[ + index_offsets[key_id] : index_offsets[key_id + 1] + ].sort() + data[key] = self._itemsets[key][ + index[current_indices] - self._offsets[key_id] + ] + else: + # Say N = len(index), and K = num_types. + # If logN >= K, we use the algo with time complexity O(N*K). + assert torch.all((index >= 0) & (index < self._length)) + key_indices = ( + torch.searchsorted(self._offsets, index, right=True) - 1 + ) + for key_id, key in enumerate(self._keys): + mask = (key_indices == key_id).nonzero().squeeze(1) + if len(mask) == 0: + continue + data[key] = self._itemsets[key][ + index[mask] - self._offsets[key_id] + ] return data else: raise TypeError( @@ -386,9 +408,14 @@ def __getitem__(self, index: Union[int, slice, Iterable[int]]): @property def names(self) -> Tuple[str]: - """Return the names of the items.""" + """Returns the names of the items.""" return self._names + @property + def num_types(self) -> int: + """Returns the number of types.""" + return self._num_types + def __repr__(self) -> str: ret = ( "{Classname}(\n"