[GraphBolt] modify logic for HeteroItemSet indexing #7428

Open · wants to merge 10 commits into master
51 changes: 39 additions & 12 deletions python/dgl/graphbolt/itemset.py
@@ -329,6 +329,8 @@ def __init__(self, itemsets: Dict[str, ItemSet]) -> None:
self._offsets = torch.tensor(offset).cumsum(0)
self._length = int(self._offsets[-1])
self._keys = list(self._itemsets.keys())
self._num_types = len(self._keys)
self._threshold = 1 << self._num_types

def __len__(self) -> int:
return self._length
@@ -365,18 +367,38 @@ def __getitem__(self, index: Union[int, slice, Iterable[int]]):
elif isinstance(index, Iterable):
if not isinstance(index, torch.Tensor):
index = torch.tensor(index)
assert torch.all((index >= 0) & (index < self._length))
key_indices = (
torch.searchsorted(self._offsets, index, right=True) - 1
)
data = {}
for key_id, key in enumerate(self._keys):
mask = (key_indices == key_id).nonzero().squeeze(1)
if len(mask) == 0:
continue
data[key] = self._itemsets[key][
index[mask] - self._offsets[key_id]
]
if len(index) < self._threshold:
Collaborator @mfbalin commented on Jun 17, 2024:
I think we need a benchmark before setting such a threshold only by looking at runtime complexity.

Collaborator Author:
You're right. I'll do it right away.

Collaborator:
Let's use K values 1, 2, 4, 8, 16, 32, etc.

Collaborator:
This is a little complex. I think the optimal complexity should be O(N*logK), with two additional helper arrays: offsets = [0, 10, 30, 60] and etypes = ["A", "B", "C", "D"].

Collaborator:
Never mind, I see the point. To do it the way I suggested, we would need to implement our own C++ kernel with parallel optimization.

Collaborator:
Are any of the ops used in the new implementation single-threaded?

Collaborator Author:
Seems not. Maybe the sorting op itself is a bit slow even though it's multi-threaded?

Collaborator @mfbalin commented on Jun 19, 2024:
How did you verify?

Collaborator Author @Skeleton003 commented on Jun 19, 2024:
I'm not sure, I'm just guessing. Did you find out anything from the benchmarking code?
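
Editor's note: to ground the benchmark request in this thread (including the suggested K values 1, 2, 4, 8, 16, 32), here is a minimal, hypothetical timing sketch of the two indexing strategies in isolation. The helper names (sort_based_lookup, mask_based_lookup), the synthetic offsets, and the N values are assumptions for illustration; this is not the PR's benchmark and it does not measure the final per-type ItemSet indexing.

# Hypothetical micro-benchmark for the threshold discussed above; not part of this PR.
import time
import torch

def sort_based_lookup(index, offsets):
    # Mirrors the new branch: O(N*logN) sort, then bucket the sorted queries
    # against the cumulative per-type offsets.
    sorted_index, perm = index.sort()
    bounds = torch.searchsorted(sorted_index, offsets, right=False)
    return [perm[bounds[k] : bounds[k + 1]] for k in range(len(offsets) - 1)]

def mask_based_lookup(index, offsets):
    # Mirrors the existing branch: O(N*K), one boolean mask per type.
    key_indices = torch.searchsorted(offsets, index, right=True) - 1
    return [
        (key_indices == k).nonzero().squeeze(1)
        for k in range(len(offsets) - 1)
    ]

def bench(fn, index, offsets, repeats=10):
    start = time.perf_counter()
    for _ in range(repeats):
        fn(index, offsets)
    return (time.perf_counter() - start) / repeats

if __name__ == "__main__":
    per_type = 1_000_000  # synthetic number of items per type
    for K in (1, 2, 4, 8, 16, 32):
        offsets = torch.arange(K + 1, dtype=torch.int64) * per_type
        for N in (16, 1_024, 65_536):
            index = torch.randint(0, K * per_type, (N,))
            t_sort = bench(sort_based_lookup, index, offsets)
            t_mask = bench(mask_based_lookup, index, offsets)
            print(f"K={K:2d} N={N:6d} sort={t_sort:.6f}s mask={t_mask:.6f}s")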

# Say N = len(index), and K = num_types.
# If logN < K, we use the algo with time complexity O(N*logN).
sorted_index, indices = index.sort()
Collaborator @mfbalin commented on Jun 28, 2024:
Can we try numpy.argsort here to get indices and index[indices] to get sorted_index? It looks like numpy might have a more efficient sorting implementation. When benchmarking, we should ensure that we have a recent version of numpy installed. It looks like numpy uses this efficient sorting implementation by Intel: https://github.com/intel/x86-simd-sort

Collaborator:
This is assuming that the sort is the bottleneck for this code.

Collaborator Author:
Thanks for the info!

> When benchmarking, we should ensure that we have a recent version of numpy installed.

How recent does the version need to be? It seems that we have just disabled numpy>=2.0.0 in #7479.

Collaborator:
numpy/numpy#22315
They added the faster sort in this PR. Looks like the version number is 1.25 or later.
https://github.com/search?q=repo%3Anumpy%2Fnumpy%20%2322315&type=code

Collaborator:
Let's use the latest 1.x version and see how the performance is.

Collaborator Author:
Though the code changes were committed in numpy/numpy#22315, where the version is 1.25, this improvement was not officially announced until the NumPy 2.0.0 release notes. Therefore, it is likely that they did not integrate the changes until version 2.0.0.

I plan to propose that we offer full support for numpy>=2.0.0 at the Monday meeting, and perform the benchmark after we do so.

Collaborator:
numpy>=2 is compatible with DGL. I think we can perform this benchmark. I just ran the graphbolt tests with numpy>=2 installed and they all passed.

Collaborator:
Thank you for the updated numbers. I will profile the benchmark code and see if there is a potential improvement we can make.
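
Editor's note: a minimal sketch of the numpy-based sort suggested in this thread, assuming index is the 1-D CPU int64 tensor inside __getitem__; numpy_argsort_index is an illustrative helper name, not part of this PR, and whether it beats torch.Tensor.sort is exactly what the benchmark above should decide.

import numpy as np
import torch

def numpy_argsort_index(index: torch.Tensor):
    # Let numpy do the argsort, then index back into the original tensor,
    # as suggested above. .numpy() is a zero-copy view for CPU tensors.
    index_np = index.numpy()
    indices = torch.from_numpy(np.argsort(index_np))
    sorted_index = index[indices]
    return sorted_index, indices

# Inside __getitem__, the line `sorted_index, indices = index.sort()`
# would then become: sorted_index, indices = numpy_argsort_index(index)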

assert sorted_index[0] >= 0 and sorted_index[-1] < self._length
index_offsets = torch.searchsorted(
sorted_index, self._offsets, right=False
)
for key_id, key in enumerate(self._keys):
if index_offsets[key_id] == index_offsets[key_id + 1]:
continue
current_indices, _ = indices[
index_offsets[key_id] : index_offsets[key_id + 1]
].sort()
Collaborator:
We could use np.sort here as well.
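
Editor's note: the corresponding np.sort swap for the per-bucket sort above might look like the following two lines (only a sketch, reusing indices, index_offsets, and key_id from the diff, and assuming the slice lives on CPU):

bucket = indices[index_offsets[key_id] : index_offsets[key_id + 1]].numpy()
current_indices = torch.from_numpy(np.sort(bucket))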

data[key] = self._itemsets[key][
index[current_indices] - self._offsets[key_id]
]
else:
# Say N = len(index), and K = num_types.
# If logN >= K, we use the algo with time complexity O(N*K).
assert torch.all((index >= 0) & (index < self._length))
key_indices = (
torch.searchsorted(self._offsets, index, right=True) - 1
)
for key_id, key in enumerate(self._keys):
mask = (key_indices == key_id).nonzero().squeeze(1)
if len(mask) == 0:
continue
data[key] = self._itemsets[key][
index[mask] - self._offsets[key_id]
]
return data
else:
raise TypeError(
@@ -386,9 +408,14 @@ def __getitem__(self, index: Union[int, slice, Iterable[int]]):

@property
def names(self) -> Tuple[str]:
"""Return the names of the items."""
"""Returns the names of the items."""
return self._names

@property
def num_types(self) -> int:
"""Returns the number of types."""
return self._num_types

def __repr__(self) -> str:
ret = (
"{Classname}(\n"
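
Editor's note: for reference, a hypothetical usage sketch of the indexing path this PR changes. The type names and sizes are made up, and the class is assumed to be exposed as HeteroItemSet per the PR title (older DGL versions expose the same container as ItemSetDict).

import torch
import dgl.graphbolt as gb

hetero = gb.HeteroItemSet({
    "user": gb.ItemSet(torch.arange(0, 5), names="seeds"),  # global indices 0-4
    "item": gb.ItemSet(torch.arange(0, 3), names="seeds"),  # global indices 5-7
})

# Indexing with an iterable of global indices returns one entry per type
# that is actually hit; expected result under this PR's semantics:
# {"user": tensor([0, 4]), "item": tensor([0, 2])}
print(hetero[[0, 4, 5, 7]])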