Skip to content

Commit

Permalink
Fix non-deterministic output of gnn sampler (#1677)
Browse files Browse the repository at this point in the history
Fix the issues by updating the sampler during inference to full sampling from subsampling.



Closes #1676 

## By Submitting this PR I confirm:
- I am familiar with the [Contributing Guidelines](https://github.com/nv-morpheus/Morpheus/blob/main/docs/source/developer_guide/contributing.md).
- When the PR is ready for review, new or existing tests cover these changes.
- When the PR is ready for review, the documentation is up to date with these changes.

Authors:
  - Tad ZeMicheal (https://github.com/tzemicheal)

Approvers:
  - David Gardner (https://github.com/dagardner-nv)
  - https://github.com/raykallen

URL: #1677
  • Loading branch information
tzemicheal authored May 21, 2024
1 parent 08e40dc commit 2fe4dd3
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion examples/gnn_fraud_detection_pipeline/stages/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def inference(self,
"""

# create sampler and test dataloaders
full_sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts=[4, 3])
full_sampler = dgl.dataloading.MultiLayerFullNeighborSampler(num_layers=3)
test_dataloader = dgl.dataloading.DataLoader(input_graph, {target_node: test_idx},
full_sampler,
batch_size=batch_size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
from morpheus.messages import MessageMeta
from morpheus.messages import MultiMessage

# pylint: disable=no-name-in-module


# pylint: disable=no-name-in-module
@pytest.mark.usefixtures("manual_seed")
@pytest.mark.use_python
class TestGraphSageStage:

Expand Down Expand Up @@ -68,4 +68,4 @@ def test_process_message(self,
cols = results.inductive_embedding_column_names + ['index']
assert sorted(cols) == sorted(expected_df.columns)
ind_emb_df = results.get_meta(cols)
dataset_pandas.assert_compare_df(ind_emb_df.to_pandas(), expected_df)
dataset_pandas.assert_compare_df(ind_emb_df.to_pandas(), expected_df, abs_tol=1, rel_tol=1)
Git LFS file not shown

0 comments on commit 2fe4dd3

Please sign in to comment.