From 4d91a36a67501aaa6422e6b1d4d3b5e8cd8a1132 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 24 May 2024 09:15:52 +0000 Subject: [PATCH 1/8] 1 --- python/dgl/graphbolt/itemset.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/python/dgl/graphbolt/itemset.py b/python/dgl/graphbolt/itemset.py index cefb68964a0f..970653bbd0f1 100644 --- a/python/dgl/graphbolt/itemset.py +++ b/python/dgl/graphbolt/itemset.py @@ -365,18 +365,23 @@ def __getitem__(self, index: Union[int, slice, Iterable[int]]): elif isinstance(index, Iterable): if not isinstance(index, torch.Tensor): index = torch.tensor(index) - assert torch.all((index >= 0) & (index < self._length)) - key_indices = ( - torch.searchsorted(self._offsets, index, right=True) - 1 - ) + sorted_index, indices = index.sort() + assert sorted_index[0] >= 0 and sorted_index[-1] < self._length + index_offsets = torch.searchsorted(sorted_index, self._offsets) + # print(f"{index = }\n{sorted_index = }\n{indices = },\n{inv_p = }\n{self._offsets = }\n{index_offsets = }\n{num_per_key = }") + # assert 0 data = {} for key_id, key in enumerate(self._keys): - mask = (key_indices == key_id).nonzero().squeeze(1) - if len(mask) == 0: + if index_offsets[key_id] == index_offsets[key_id+1]: continue - data[key] = self._itemsets[key][ - index[mask] - self._offsets[key_id] - ] + current_indices, _ = indices[index_offsets[key_id]: index_offsets[key_id+1]].sort() + data[key] = self._itemsets[key][index[current_indices] - self._offsets[key_id]] + # mask = (key_indices == key_id).nonzero().squeeze(1) + # if len(mask) == 0: + # continue + # data[key] = self._itemsets[key][ + # index[mask] - self._offsets[key_id] + # ] return data else: raise TypeError( From e5b84c300fc6f878b432fd0ab9accc7c7d7a6e4c Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 24 May 2024 09:18:23 +0000 Subject: [PATCH 2/8] lint --- python/dgl/graphbolt/itemset.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/python/dgl/graphbolt/itemset.py b/python/dgl/graphbolt/itemset.py index 970653bbd0f1..6a5135bc639e 100644 --- a/python/dgl/graphbolt/itemset.py +++ b/python/dgl/graphbolt/itemset.py @@ -368,20 +368,16 @@ def __getitem__(self, index: Union[int, slice, Iterable[int]]): sorted_index, indices = index.sort() assert sorted_index[0] >= 0 and sorted_index[-1] < self._length index_offsets = torch.searchsorted(sorted_index, self._offsets) - # print(f"{index = }\n{sorted_index = }\n{indices = },\n{inv_p = }\n{self._offsets = }\n{index_offsets = }\n{num_per_key = }") - # assert 0 data = {} for key_id, key in enumerate(self._keys): - if index_offsets[key_id] == index_offsets[key_id+1]: + if index_offsets[key_id] == index_offsets[key_id + 1]: continue - current_indices, _ = indices[index_offsets[key_id]: index_offsets[key_id+1]].sort() - data[key] = self._itemsets[key][index[current_indices] - self._offsets[key_id]] - # mask = (key_indices == key_id).nonzero().squeeze(1) - # if len(mask) == 0: - # continue - # data[key] = self._itemsets[key][ - # index[mask] - self._offsets[key_id] - # ] + current_indices, _ = indices[ + index_offsets[key_id] : index_offsets[key_id + 1] + ].sort() + data[key] = self._itemsets[key][ + index[current_indices] - self._offsets[key_id] + ] return data else: raise TypeError( From 5968610ed85408edfa3d84ee03be8e57af5af0e6 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 24 May 2024 09:38:16 +0000 Subject: [PATCH 3/8] modify itemsampler --- python/dgl/graphbolt/item_sampler.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/python/dgl/graphbolt/item_sampler.py b/python/dgl/graphbolt/item_sampler.py index 13533f30788c..f7b5a0ccf131 100644 --- a/python/dgl/graphbolt/item_sampler.py +++ b/python/dgl/graphbolt/item_sampler.py @@ -323,15 +323,17 @@ def _collate_batch(self, buffer, indices, offsets=None): elif isinstance(buffer, Mapping): # For item set that's initialized with a dict of items, # `buffer` is a dict of tensors/lists/tuples. - keys = list(buffer.keys()) - key_indices = torch.searchsorted(offsets, indices, right=True) - 1 + sorted_indices, ind = indices.sort() + indices_offsets = torch.searchsorted(sorted_indices, offsets) batch = {} - for j, key in enumerate(keys): - mask = (key_indices == j).nonzero().squeeze(1) - if len(mask) == 0: + for key_id, key in enumerate(buffer.keys()): + if indices_offsets[key_id] == indices_offsets[key_id + 1]: continue + current_indices, _ = ind[ + indices_offsets[key_id] : indices_offsets[key_id + 1] + ].sort() batch[key] = self._collate_batch( - buffer[key], indices[mask] - offsets[j] + buffer[key], indices[current_indices] - offsets[key_id] ) return batch raise TypeError(f"Unsupported buffer type {type(buffer).__name__}.") From 187c8b8b12ac8d9747bb93720568873a40411407 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 26 May 2024 17:54:10 +0000 Subject: [PATCH 4/8] cls --- python/dgl/graphbolt/item_sampler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/dgl/graphbolt/item_sampler.py b/python/dgl/graphbolt/item_sampler.py index f7b5a0ccf131..812a0e241be2 100644 --- a/python/dgl/graphbolt/item_sampler.py +++ b/python/dgl/graphbolt/item_sampler.py @@ -310,7 +310,8 @@ def __init__( # used in every epoch. self._epoch = 0 - def _collate_batch(self, buffer, indices, offsets=None): + @classmethod + def _collate_batch(cls, buffer, indices, offsets=None): """Collate a batch from the buffer. For internal use only.""" if isinstance(buffer, torch.Tensor): # For item set that's initialized with integer or single tensor, @@ -332,7 +333,7 @@ def _collate_batch(self, buffer, indices, offsets=None): current_indices, _ = ind[ indices_offsets[key_id] : indices_offsets[key_id + 1] ].sort() - batch[key] = self._collate_batch( + batch[key] = cls._collate_batch( buffer[key], indices[current_indices] - offsets[key_id] ) return batch From 6aaa74be908a63f8582d3443eb02cc8048f1eb44 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 26 May 2024 18:18:33 +0000 Subject: [PATCH 5/8] itemsampler no need to sort --- python/dgl/graphbolt/item_sampler.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/python/dgl/graphbolt/item_sampler.py b/python/dgl/graphbolt/item_sampler.py index 812a0e241be2..075e2caf4af1 100644 --- a/python/dgl/graphbolt/item_sampler.py +++ b/python/dgl/graphbolt/item_sampler.py @@ -324,17 +324,18 @@ def _collate_batch(cls, buffer, indices, offsets=None): elif isinstance(buffer, Mapping): # For item set that's initialized with a dict of items, # `buffer` is a dict of tensors/lists/tuples. - sorted_indices, ind = indices.sort() - indices_offsets = torch.searchsorted(sorted_indices, offsets) + # sorted_indices, ind = indices.sort() + indices_offsets = torch.searchsorted(indices, offsets) batch = {} for key_id, key in enumerate(buffer.keys()): if indices_offsets[key_id] == indices_offsets[key_id + 1]: continue - current_indices, _ = ind[ - indices_offsets[key_id] : indices_offsets[key_id + 1] - ].sort() + # current_indices, _ = ind[ + # indices_offsets[key_id] : indices_offsets[key_id + 1] + # ].sort() + current_indices = indices[indices_offsets[key_id] : indices_offsets[key_id + 1]] - offsets[key_id] batch[key] = cls._collate_batch( - buffer[key], indices[current_indices] - offsets[key_id] + buffer[key], current_indices, ) return batch raise TypeError(f"Unsupported buffer type {type(buffer).__name__}.") From 6c3a7f203cb4b94964ea2cd4e1e172190f488ad1 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 27 May 2024 02:51:37 +0000 Subject: [PATCH 6/8] revert changes in itemsampler --- python/dgl/graphbolt/item_sampler.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/python/dgl/graphbolt/item_sampler.py b/python/dgl/graphbolt/item_sampler.py index 075e2caf4af1..13533f30788c 100644 --- a/python/dgl/graphbolt/item_sampler.py +++ b/python/dgl/graphbolt/item_sampler.py @@ -310,8 +310,7 @@ def __init__( # used in every epoch. self._epoch = 0 - @classmethod - def _collate_batch(cls, buffer, indices, offsets=None): + def _collate_batch(self, buffer, indices, offsets=None): """Collate a batch from the buffer. For internal use only.""" if isinstance(buffer, torch.Tensor): # For item set that's initialized with integer or single tensor, @@ -324,18 +323,15 @@ def _collate_batch(cls, buffer, indices, offsets=None): elif isinstance(buffer, Mapping): # For item set that's initialized with a dict of items, # `buffer` is a dict of tensors/lists/tuples. - # sorted_indices, ind = indices.sort() - indices_offsets = torch.searchsorted(indices, offsets) + keys = list(buffer.keys()) + key_indices = torch.searchsorted(offsets, indices, right=True) - 1 batch = {} - for key_id, key in enumerate(buffer.keys()): - if indices_offsets[key_id] == indices_offsets[key_id + 1]: + for j, key in enumerate(keys): + mask = (key_indices == j).nonzero().squeeze(1) + if len(mask) == 0: continue - # current_indices, _ = ind[ - # indices_offsets[key_id] : indices_offsets[key_id + 1] - # ].sort() - current_indices = indices[indices_offsets[key_id] : indices_offsets[key_id + 1]] - offsets[key_id] - batch[key] = cls._collate_batch( - buffer[key], current_indices, + batch[key] = self._collate_batch( + buffer[key], indices[mask] - offsets[j] ) return batch raise TypeError(f"Unsupported buffer type {type(buffer).__name__}.") From d70401c088156b564ea13ea40b93845c5aa5e03a Mon Sep 17 00:00:00 2001 From: Skeleton003 <799284168@qq.com> Date: Wed, 12 Jun 2024 21:59:16 +0000 Subject: [PATCH 7/8] threshold --- python/dgl/graphbolt/itemset.py | 45 ++++++++++++++++++++++++--------- 1 file changed, 33 insertions(+), 12 deletions(-) diff --git a/python/dgl/graphbolt/itemset.py b/python/dgl/graphbolt/itemset.py index 6a5135bc639e..eb4a1a33ab76 100644 --- a/python/dgl/graphbolt/itemset.py +++ b/python/dgl/graphbolt/itemset.py @@ -329,6 +329,8 @@ def __init__(self, itemsets: Dict[str, ItemSet]) -> None: self._offsets = torch.tensor(offset).cumsum(0) self._length = int(self._offsets[-1]) self._keys = list(self._itemsets.keys()) + self._num_types = len(self._keys) + self._threshold = 1 << self._num_types def __len__(self) -> int: return self._length @@ -365,19 +367,38 @@ def __getitem__(self, index: Union[int, slice, Iterable[int]]): elif isinstance(index, Iterable): if not isinstance(index, torch.Tensor): index = torch.tensor(index) - sorted_index, indices = index.sort() - assert sorted_index[0] >= 0 and sorted_index[-1] < self._length - index_offsets = torch.searchsorted(sorted_index, self._offsets) data = {} - for key_id, key in enumerate(self._keys): - if index_offsets[key_id] == index_offsets[key_id + 1]: - continue - current_indices, _ = indices[ - index_offsets[key_id] : index_offsets[key_id + 1] - ].sort() - data[key] = self._itemsets[key][ - index[current_indices] - self._offsets[key_id] - ] + if len(index) < self._threshold: + # Say N = len(index), and K = num_types. + # If logN < K, we use the algo with time complexity O(N*logN). + sorted_index, indices = index.sort() + assert sorted_index[0] >= 0 and sorted_index[-1] < self._length + index_offsets = torch.searchsorted( + sorted_index, self._offsets, right=False + ) + for key_id, key in enumerate(self._keys): + if index_offsets[key_id] == index_offsets[key_id + 1]: + continue + current_indices, _ = indices[ + index_offsets[key_id] : index_offsets[key_id + 1] + ].sort() + data[key] = self._itemsets[key][ + index[current_indices] - self._offsets[key_id] + ] + else: + # Say N = len(index), and K = num_types. + # If logN >= K, we use the algo with time complexity O(N*K). + assert torch.all((index >= 0) & (index < self._length)) + key_indices = ( + torch.searchsorted(self._offsets, index, right=True) - 1 + ) + for key_id, key in enumerate(self._keys): + mask = (key_indices == key_id).nonzero().squeeze(1) + if len(mask) == 0: + continue + data[key] = self._itemsets[key][ + index[mask] - self._offsets[key_id] + ] return data else: raise TypeError( From eb94470ea3d1eea82eb85073fe865d0a342f6d7e Mon Sep 17 00:00:00 2001 From: Skeleton003 <799284168@qq.com> Date: Wed, 12 Jun 2024 22:02:21 +0000 Subject: [PATCH 8/8] property --- python/dgl/graphbolt/itemset.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/dgl/graphbolt/itemset.py b/python/dgl/graphbolt/itemset.py index eb4a1a33ab76..4a9544ca022d 100644 --- a/python/dgl/graphbolt/itemset.py +++ b/python/dgl/graphbolt/itemset.py @@ -408,9 +408,14 @@ def __getitem__(self, index: Union[int, slice, Iterable[int]]): @property def names(self) -> Tuple[str]: - """Return the names of the items.""" + """Returns the names of the items.""" return self._names + @property + def num_types(self) -> int: + """Returns the number of types.""" + return self._num_types + def __repr__(self) -> str: ret = ( "{Classname}(\n"