-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[GraphBolt] modify logic for HeteroItemSet
indexing
#7428
base: master
Are you sure you want to change the base?
Changes from all commits
4d91a36
e5b84c3
5968610
187c8b8
6aaa74b
6c3a7f2
e7ae261
d70401c
eb94470
50da718
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -329,6 +329,8 @@ def __init__(self, itemsets: Dict[str, ItemSet]) -> None: | |
self._offsets = torch.tensor(offset).cumsum(0) | ||
self._length = int(self._offsets[-1]) | ||
self._keys = list(self._itemsets.keys()) | ||
self._num_types = len(self._keys) | ||
self._threshold = 1 << self._num_types | ||
|
||
def __len__(self) -> int: | ||
return self._length | ||
|
@@ -365,18 +367,38 @@ def __getitem__(self, index: Union[int, slice, Iterable[int]]): | |
elif isinstance(index, Iterable): | ||
if not isinstance(index, torch.Tensor): | ||
index = torch.tensor(index) | ||
assert torch.all((index >= 0) & (index < self._length)) | ||
key_indices = ( | ||
torch.searchsorted(self._offsets, index, right=True) - 1 | ||
) | ||
data = {} | ||
for key_id, key in enumerate(self._keys): | ||
mask = (key_indices == key_id).nonzero().squeeze(1) | ||
if len(mask) == 0: | ||
continue | ||
data[key] = self._itemsets[key][ | ||
index[mask] - self._offsets[key_id] | ||
] | ||
if len(index) < self._threshold: | ||
# Say N = len(index), and K = num_types. | ||
# If logN < K, we use the algo with time complexity O(N*logN). | ||
sorted_index, indices = index.sort() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we try numpy.argsort here to get indices and index[indices] to get sorted_index? It looks like numpy might have a more efficient sorting implementation. When benchmarking, we should ensure that we have a recent version of numpy installed. It looks like numpy uses this efficient sorting implementation by intel: https://github.com/intel/x86-simd-sort There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is assuming that the sort is the bottleneck for this code. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for your info!
How recent is the very version? Because it seems that we have just diabled numpy>=2.0.0 in #7479 . There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. numpy/numpy#22315 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's use the latest 1.x version and see how the performance is. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Though the code changes were commited in numpy/numpy#22315 where the version is 1.25, this improvement was not officially announced until NumPy 2.0.0 Release Notes. Therefore, it is likely that they did not integrate the changes until version 2.0.0. I plan to move that we offer full support for numpy>=2.0.0 at the Monday meeting, and perform the benchmark after we do so. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See https://docs.google.com/document/d/1Bbmp8gMekiGIYYxEMVbmXSANRZlZ_nTNbhpWul4RaKA/edit?usp=sharing . The results seem to have changed little. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you for the updated numbers. I will profile the benchmark code and see if there is a potential improvement we can do. |
||
assert sorted_index[0] >= 0 and sorted_index[-1] < self._length | ||
index_offsets = torch.searchsorted( | ||
sorted_index, self._offsets, right=False | ||
) | ||
for key_id, key in enumerate(self._keys): | ||
if index_offsets[key_id] == index_offsets[key_id + 1]: | ||
continue | ||
current_indices, _ = indices[ | ||
index_offsets[key_id] : index_offsets[key_id + 1] | ||
].sort() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could use np.sort here as well. |
||
data[key] = self._itemsets[key][ | ||
index[current_indices] - self._offsets[key_id] | ||
] | ||
else: | ||
# Say N = len(index), and K = num_types. | ||
# If logN >= K, we use the algo with time complexity O(N*K). | ||
assert torch.all((index >= 0) & (index < self._length)) | ||
key_indices = ( | ||
torch.searchsorted(self._offsets, index, right=True) - 1 | ||
) | ||
for key_id, key in enumerate(self._keys): | ||
mask = (key_indices == key_id).nonzero().squeeze(1) | ||
if len(mask) == 0: | ||
continue | ||
data[key] = self._itemsets[key][ | ||
index[mask] - self._offsets[key_id] | ||
] | ||
return data | ||
else: | ||
raise TypeError( | ||
|
@@ -386,9 +408,14 @@ def __getitem__(self, index: Union[int, slice, Iterable[int]]): | |
|
||
@property | ||
def names(self) -> Tuple[str]: | ||
"""Return the names of the items.""" | ||
"""Returns the names of the items.""" | ||
return self._names | ||
|
||
@property | ||
def num_types(self) -> int: | ||
"""Returns the number of types.""" | ||
return self._num_types | ||
|
||
def __repr__(self) -> str: | ||
ret = ( | ||
"{Classname}(\n" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we need a benchmark before settings such a threshold only by looking at runtime complexity.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're right. I'll do it right away.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's use K values 1 2 4 8 16 32 etc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a little bit complex, I think the optimal complexity should be O(N*logK), with additional 2 helper array: offsets = [0, 10, 30, 60], etypes = ["A", "B", "C", "D"]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NVM, I see the points, to do the way I suggested, we need to implement our own C++ kernel with parallel optimization.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is any of the ops used in the new implementation single threaded?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems no. Maybe sorting op itself is a bit slow even it's multi-threaded?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How did you verify?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure, I'm just guessing. Did you find out anything from the benchmarking code?