Skip to content

Latest commit

 

History

History
89 lines (73 loc) · 2.84 KB

README.md

File metadata and controls

89 lines (73 loc) · 2.84 KB

Fused Network Gromov-Wasserstein Distance

Implementation of the TMLR paper Exploiting Edge Features in Graph-based Learning with Fused Network Gromov-Wasserstein Distance.

Attention: In the paper, we use E to denote edge feature tensor, in our code, we use C instead for some cases.

Environment Preparation

Before running the code, please create the following conda environment:

conda create -n fngw_env python=3.8.13
conda activate fngw_env
pip install POT==0.8.2 \
            numpy==1.24.4 \
            rdkit==2022.3.5 \
            scipy==1.8.1 \
            scikit-learn==1.1.1 \
            joblib==1.1.0 \
            networkx==3.1 \
            matplotlib==3.5.2 \
            torch-geometric==2.1.0.post1 \
            tqdm==4.64.0 \
            pandas==2.0.3 \
            chainer-chemistry==0.7.1

Distance Computation

To compute the fngw distance between two node-featured and edge-featured graphs, please call the following function:

from fngw import fused_network_gromov_wasserstein2

Here is an simple example of how to compute the FNGW distance between the two above graphs:

from ot.utils import unif, dist
import numpy as np

# Create two graphs
ones_F = np.eye(2, dtype=float)
ones_E = np.eye(3, dtype=float)
F1 = np.array([0, 0, 1])
F1 = ones_F[F1]
A1 = np.array([[0, 1, 1], [1, 0, 0], [1, 0, 0]], dtype=float)
E1 = np.array([[0, 1, 2], [1, 0, 0], [2, 0, 0]])
E1 = ones_E[E1]

F2 = np.array([0, 1, 0])
F2 = ones_F[F2]
A2 = np.array([[0, 1, 1], [1, 0, 0], [1, 0, 0]], dtype=float)
E2 = np.array([[0, 1, 2], [1, 0, 0], [2, 0, 0]])
E2= ones_E[E2]

# Compute FNGW distance
M = dist(F1, F2)
dist = fused_network_gromov_wasserstein2(M, E1, E2, A1, A2,
                                         p=unif(3), q=unif(3),
                                         alpha=1/3, beta=1/3)
print(dist)
# Out: 0.2962962962962963 

Barycenter Computation

Given a set of graphs, to compute the FNGW barycenter between them, please call the following function:

from fngw import fngw_barycenters

It should be noted that the number of nodes in the barycenter graph need to be specified. A circle barycenter example can be found in notebook barycenter_example.ipynb.

Graph Classification

Supervised Graph Prediction: Fingerprint to Molecule

1. Create dataset

In order to create the Fin2Mol dataset of the first Supervised Graph Prediction task presented in our paper, please run the following script:

cd sgp_fin2mol
python create_data.py

Normally, you can found the created input fingerprint and output graph files in the directory sgp_fin2mol.

Supervised Graph Prediction: Metabolite Identification