In [None]:
%matplotlib inline


# Discovering Subgraphs

This example shows how to match a smaller graph to a subset of a larger graph.


In [None]:
# Author: Runzhong Wang <runzhong.wang@sjtu.edu.cn>
#
# License: Mulan PSL v2 License

<div class="alert alert-info"><h4>Note</h4><p>The following solvers are included in this example:

    * :func:`~pygmtools.classic_solvers.rrwm` (classic solver)

    * :func:`~pygmtools.classic_solvers.ipfp` (classic solver)

    * :func:`~pygmtools.classic_solvers.sm` (classic solver)

    * :func:`~pygmtools.neural_solvers.ngm` (neural network solver)</p></div>




In [None]:
import functools # for binding parameters of the affinity function

import matplotlib.pyplot as plt # for plotting
import networkx as nx # for plotting graphs
import pygmtools as pygm
import torch # pytorch backend
from matplotlib.patches import ConnectionPatch # for plotting matching result
# Route every pygmtools call in this notebook through the PyTorch backend.
pygm.BACKEND = 'pytorch' # set default backend for pygmtools
# Seed torch's global RNG so the torch.rand call that generates the graph is repeatable.
# The result is bound to `_` so the returned object is not echoed as the cell's output.
_ = torch.manual_seed(1) # fix random seed

## Generate the larger graph




In [None]:
num_nodes2 = 10
# Draw one random square matrix and symmetrize it by adding its transpose.
rand_weights = torch.rand(num_nodes2, num_nodes2)
sym_sum = rand_weights + rand_weights.t()
# Keep an edge only where the symmetric sum exceeds 1; its weight is half of that sum.
A2 = (sym_sum > 1.) * sym_sum / 2
# Remove self-loops.
A2.fill_diagonal_(0)
n2 = torch.tensor([num_nodes2])

## Generate the smaller graph




In [None]:
num_nodes1 = 5
G2 = nx.from_numpy_array(A2.numpy())
# Seed the layout explicitly: torch.manual_seed does not control networkx's layout RNG,
# so an unseeded spring_layout would change `selected` (and hence A1 / X_gt) on every run.
pos2 = nx.spring_layout(G2, seed=1)
# Stack the per-node coordinates into a single tensor with one row per node. Converting each
# array individually avoids torch's slow (and warning-emitting) list-of-ndarrays path.
pos2_t = torch.stack([torch.as_tensor(pos2[node]) for node in range(num_nodes2)])
selected = [0] # build G1 as a cluster in visualization
unselected = list(range(1, num_nodes2))
while len(selected) < num_nodes1:
    # L1 distance between every selected node (dim 0) and every candidate node (dim 1).
    pairwise_l1 = torch.abs(pos2_t[selected].unsqueeze(1) - pos2_t[unselected].unsqueeze(0)).sum(dim=-1)
    # Total distance from the current cluster to each candidate.
    dist = pairwise_l1.sum(dim=0)
    select_id = unselected[torch.argmin(dist).item()] # find the closest node from unselected
    selected.append(select_id)
    unselected.remove(select_id)
# Sort so that row i of the subgraph corresponds to the i-th smallest selected node id.
selected.sort()
# The subgraph's adjacency is the larger graph's adjacency restricted to the selected nodes.
A1 = A2[selected, :][:, selected]
# Ground truth: subgraph node i matches node selected[i] of the larger graph.
X_gt = torch.eye(num_nodes2)[selected, :]
n1 = torch.tensor([num_nodes1])

## Visualize the graphs




In [None]:
G1 = nx.from_numpy_array(A1.numpy())
# Each subgraph node is drawn at the layout position of the node it was taken from.
pos1 = {sub_idx: pos2[selected[sub_idx]] for sub_idx in range(num_nodes1)}
# Orange marks nodes that belong to the subgraph; blue marks the remaining nodes.
color1 = ['#FF5733'] * num_nodes1
color2 = ['#FF5733' if node in selected else '#1f78b4' for node in range(num_nodes2)]
fig = plt.figure(figsize=(8, 4))
# Left panel: the subgraph, with extra margins around it.
ax_sub = fig.add_subplot(1, 2, 1)
ax_sub.set_title('Subgraph 1')
ax_sub.margins(0.4)
nx.draw_networkx(G1, pos=pos1, node_color=color1, ax=ax_sub)
# Right panel: the full graph with the subgraph's nodes highlighted.
ax_full = fig.add_subplot(1, 2, 2)
ax_full.set_title('Graph 2')
nx.draw_networkx(G2, pos=pos2, node_color=color2, ax=ax_full)

We then show how to automatically discover the matching by graph matching.

## Build affinity matrix
To match the larger graph and the smaller graph, we follow the formulation of Quadratic Assignment Problem (QAP):

\begin{align}&\max_{\mathbf{X}} \ \texttt{vec}(\mathbf{X})^\top \mathbf{K} \texttt{vec}(\mathbf{X})\\
    s.t. \quad &\mathbf{X} \in \{0, 1\}^{n_1\times n_2}, \ \mathbf{X}\mathbf{1} = \mathbf{1}, \ \mathbf{X}^\top\mathbf{1} \leq \mathbf{1}\end{align}

where the first step is to build the affinity matrix ($\mathbf{K}$)




In [None]:
# Convert both dense adjacency matrices into the (connectivity, edge feature) pairs
# that build_aff_mat takes as input.
conn1, edge1 = pygm.utils.dense_to_sparse(A1)
conn2, edge2 = pygm.utils.dense_to_sparse(A2)
# Bind sigma up front so the affinity function can be passed as a plain callable.
# NOTE(review): sigma=.001 is presumably chosen so that only near-identical edge weights
# get a high affinity -- confirm against the gaussian_aff_fn reference.
gaussian_aff = functools.partial(pygm.utils.gaussian_aff_fn, sigma=.001) # set affinity function
# Both node-feature arguments are None (this example has no node features), so the
# affinity matrix is built from edge affinities only.
K = pygm.utils.build_aff_mat(None, edge1, conn1, None, edge2, conn2, n1, None, n2, None, edge_aff_fn=gaussian_aff)

Visualization of the affinity matrix. For a graph matching problem with $N_1$ and $N_2$ nodes,
the affinity matrix has $N_1N_2\times N_1N_2$ elements because there are up to $N_1^2$ and
$N_2^2$ edges in the two graphs, respectively.

<div class="alert alert-info"><h4>Note</h4><p>The diagonal elements of the affinity matrix are empty because there are no node features in this example.</p></div>




In [None]:
# Show the affinity matrix as a heat map, with its dimensions in the title.
affinity_img = K.numpy()
plt.figure(figsize=(4, 4))
plt.title(f'Affinity Matrix (size: {K.shape[0]}$\\times${K.shape[1]})')
plt.imshow(affinity_img, cmap='Blues')

## Solve graph matching problem by RRWM solver
See :func:`~pygmtools.classic_solvers.rrwm` for the API reference.




In [None]:
# Run the RRWM solver on the affinity matrix; n1 and n2 carry the node counts of the
# smaller and the larger graph.
X = pygm.rrwm(K, n1, n2)

The output of RRWM is a soft matching matrix. Visualization:




In [None]:
# Side-by-side heat maps: the solver's soft matching (left) against the ground truth (right).
soft_matching = X.numpy()
gt_matching = X_gt.numpy()
plt.figure(figsize=(8, 4))
plt.subplot(121)
plt.title('RRWM Soft Matching Matrix')
plt.imshow(soft_matching, cmap='Blues')
plt.subplot(122)
plt.title('Ground Truth Matching Matrix')
plt.imshow(gt_matching, cmap='Blues')

## Get the discrete matching matrix
The Hungarian algorithm is then adopted to obtain a discrete matching matrix.




In [None]:
# Turn the soft matching into a discrete matching with the Hungarian algorithm.
# X is rebound here, so the soft matrix is no longer available after this cell.
X = pygm.hungarian(X)

Visualization of the discrete matching matrix:




In [None]:
# Side-by-side heat maps: the discrete matching (left) against the ground truth (right).
# The accuracy in the title is the fraction of ground-truth matches that X recovers.
discrete_matching = X.numpy()
gt_matching = X_gt.numpy()
plt.figure(figsize=(8, 4))
plt.subplot(121)
plt.title(f'RRWM Matching Matrix (acc={(X * X_gt).sum()/ X_gt.sum():.2f})')
plt.imshow(discrete_matching, cmap='Blues')
plt.subplot(122)
plt.title('Ground Truth Matching Matrix')
plt.imshow(gt_matching, cmap='Blues')

## Match the subgraph
Draw the matching:




In [None]:
# Draw both graphs and connect every subgraph node to the node it was matched to.
fig = plt.figure(figsize=(8, 4))
fig.suptitle(f'RRWM Matching Result (acc={(X * X_gt).sum()/ X_gt.sum():.2f})')
ax1 = fig.add_subplot(1, 2, 1)
ax1.set_title('Subgraph 1')
ax1.margins(0.4)
nx.draw_networkx(G1, pos=pos1, node_color=color1, ax=ax1)
ax2 = fig.add_subplot(1, 2, 2)
ax2.set_title('Graph 2')
nx.draw_networkx(G2, pos=pos2, node_color=color2, ax=ax2)
for src in range(num_nodes1):
    # Column holding the largest entry of row `src` is the matched node in graph 2.
    dst = X[src].argmax().item()
    # Green when the match agrees with the ground truth, red otherwise.
    line_color = "green" if X_gt[src, dst] == 1 else "red"
    link = ConnectionPatch(xyA=pos1[src], xyB=pos2[dst], coordsA="data", coordsB="data",
                           axesA=ax1, axesB=ax2, color=line_color)
    ax2.add_artist(link)

## Other solvers are also available

### Classic IPFP solver
See :func:`~pygmtools.classic_solvers.ipfp` for the API reference.




In [None]:
# Solve the same problem with the IPFP solver. Unlike the RRWM and SM cells, no Hungarian
# step follows. NOTE(review): presumably IPFP already returns a discrete matching --
# confirm against the ipfp API reference.
X = pygm.ipfp(K, n1, n2)

Visualization of IPFP matching result:




In [None]:
# Draw both graphs and connect every subgraph node to the node it was matched to.
fig = plt.figure(figsize=(8, 4))
fig.suptitle(f'IPFP Matching Result (acc={(X * X_gt).sum()/ X_gt.sum():.2f})')
ax1 = fig.add_subplot(1, 2, 1)
ax1.set_title('Subgraph 1')
ax1.margins(0.4)
nx.draw_networkx(G1, pos=pos1, node_color=color1, ax=ax1)
ax2 = fig.add_subplot(1, 2, 2)
ax2.set_title('Graph 2')
nx.draw_networkx(G2, pos=pos2, node_color=color2, ax=ax2)
for src in range(num_nodes1):
    # Column holding the largest entry of row `src` is the matched node in graph 2.
    dst = X[src].argmax().item()
    # Green when the match agrees with the ground truth, red otherwise.
    line_color = "green" if X_gt[src, dst] == 1 else "red"
    link = ConnectionPatch(xyA=pos1[src], xyB=pos2[dst], coordsA="data", coordsB="data",
                           axesA=ax1, axesB=ax2, color=line_color)
    ax2.add_artist(link)

### Classic SM solver
See :func:`~pygmtools.classic_solvers.sm` for the API reference.




In [None]:
# Solve the same problem with the SM solver, then discretize its output with the
# Hungarian algorithm (the same post-processing used for RRWM above).
X = pygm.sm(K, n1, n2)
X = pygm.hungarian(X)

Visualization of SM matching result:




In [None]:
# Draw both graphs and connect every subgraph node to the node it was matched to.
fig = plt.figure(figsize=(8, 4))
fig.suptitle(f'SM Matching Result (acc={(X * X_gt).sum()/ X_gt.sum():.2f})')
ax1 = fig.add_subplot(1, 2, 1)
ax1.set_title('Subgraph 1')
ax1.margins(0.4)
nx.draw_networkx(G1, pos=pos1, node_color=color1, ax=ax1)
ax2 = fig.add_subplot(1, 2, 2)
ax2.set_title('Graph 2')
nx.draw_networkx(G2, pos=pos2, node_color=color2, ax=ax2)
for src in range(num_nodes1):
    # Column holding the largest entry of row `src` is the matched node in graph 2.
    dst = X[src].argmax().item()
    # Green when the match agrees with the ground truth, red otherwise.
    line_color = "green" if X_gt[src, dst] == 1 else "red"
    link = ConnectionPatch(xyA=pos1[src], xyB=pos2[dst], coordsA="data", coordsB="data",
                           axesA=ax1, axesB=ax2, color=line_color)
    ax2.add_artist(link)

### NGM neural network solver
See :func:`~pygmtools.neural_solvers.ngm` for the API reference.

<div class="alert alert-info"><h4>Note</h4><p>The NGM solvers are pretrained on a different problem setting, so their performance may seem inferior.
    To improve their performance, you may change the way of building affinity matrices, or try finetuning
    NGM on the new problem.</p></div>




In [None]:
# Gradient tracking is switched off for the whole block, so no autograd graph is recorded.
with torch.no_grad():
    # NOTE(review): pretrain='voc' presumably selects the pretrained weight set -- confirm
    # against the ngm API reference.
    X = pygm.ngm(K, n1, n2, pretrain='voc')
    # Discretize the network's output with the Hungarian algorithm.
    X = pygm.hungarian(X)

Visualization of NGM matching result:




In [None]:
# Draw both graphs and connect every subgraph node to the node it was matched to.
fig = plt.figure(figsize=(8, 4))
fig.suptitle(f'NGM Matching Result (acc={(X * X_gt).sum()/ X_gt.sum():.2f})')
ax1 = fig.add_subplot(1, 2, 1)
ax1.set_title('Subgraph 1')
ax1.margins(0.4)
nx.draw_networkx(G1, pos=pos1, node_color=color1, ax=ax1)
ax2 = fig.add_subplot(1, 2, 2)
ax2.set_title('Graph 2')
nx.draw_networkx(G2, pos=pos2, node_color=color2, ax=ax2)
for src in range(num_nodes1):
    # Column holding the largest entry of row `src` is the matched node in graph 2.
    dst = X[src].argmax().item()
    # Green when the match agrees with the ground truth, red otherwise.
    line_color = "green" if X_gt[src, dst] == 1 else "red"
    link = ConnectionPatch(xyA=pos1[src], xyB=pos2[dst], coordsA="data", coordsB="data",
                           axesA=ax1, axesB=ax2, color=line_color)
    ax2.add_artist(link)