.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/3.discovering_subgraphs/plot_subgraphs_pytorch.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_3.discovering_subgraphs_plot_subgraphs_pytorch.py:


==============================================
PyTorch Backend Example: Discovering Subgraphs
==============================================

This example shows how to match a smaller graph to a subset of a larger graph.

.. GENERATED FROM PYTHON SOURCE LINES 9-14

.. code-block:: default

    # Author: Runzhong Wang
    #
    # License: Mulan PSL v2 License

.. GENERATED FROM PYTHON SOURCE LINES 16-27

.. note::
    The following solvers are included in this example:

    * :func:`~pygmtools.classic_solvers.rrwm` (classic solver)

    * :func:`~pygmtools.classic_solvers.ipfp` (classic solver)

    * :func:`~pygmtools.classic_solvers.sm` (classic solver)

    * :func:`~pygmtools.neural_solvers.ngm` (neural network solver)

.. GENERATED FROM PYTHON SOURCE LINES 27-35

.. code-block:: default

    import torch # pytorch backend
    import pygmtools as pygm
    import matplotlib.pyplot as plt # for plotting
    from matplotlib.patches import ConnectionPatch # for plotting matching result
    import networkx as nx # for plotting graphs
    pygm.set_backend('pytorch') # set default backend for pygmtools
    _ = torch.manual_seed(1) # fix random seed

.. GENERATED FROM PYTHON SOURCE LINES 36-39

Generate the larger graph
--------------------------

.. GENERATED FROM PYTHON SOURCE LINES 39-45

.. code-block:: default

    num_nodes2 = 10
    A2 = torch.rand(num_nodes2, num_nodes2)
    A2 = (A2 + A2.t() > 1.) * (A2 + A2.t()) / 2
    torch.diagonal(A2)[:] = 0
    n2 = torch.tensor([num_nodes2])

.. GENERATED FROM PYTHON SOURCE LINES 46-49

Generate the smaller graph
---------------------------

..
.. GENERATED FROM PYTHON SOURCE LINES 49-65

.. code-block:: default

    num_nodes1 = 5
    G2 = nx.from_numpy_array(A2.numpy())
    pos2 = nx.spring_layout(G2)
    pos2_t = torch.tensor([pos2[_] for _ in range(num_nodes2)])
    selected = [0] # build G1 as a cluster in visualization
    unselected = list(range(1, num_nodes2))
    while len(selected) < num_nodes1:
        dist = torch.sum(torch.sum(torch.abs(pos2_t[selected].unsqueeze(1) - pos2_t[unselected].unsqueeze(0)), dim=-1), dim=0)
        select_id = unselected[torch.argmin(dist).item()] # find the closest node from unselected
        selected.append(select_id)
        unselected.remove(select_id)
    selected.sort()
    A1 = A2[selected, :][:, selected]
    X_gt = torch.eye(num_nodes2)[selected, :]
    n1 = torch.tensor([num_nodes1])

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/wzever/pygmtools/examples/3.discovering_subgraphs/plot_subgraphs_pytorch.py:52: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:210.)
      pos2_t = torch.tensor([pos2[_] for _ in range(num_nodes2)])

.. GENERATED FROM PYTHON SOURCE LINES 66-69

Visualize the graphs
---------------------

.. GENERATED FROM PYTHON SOURCE LINES 69-82

.. code-block:: default

    G1 = nx.from_numpy_array(A1.numpy())
    pos1 = {_: pos2[selected[_]] for _ in range(num_nodes1)}
    color1 = ['#FF5733' for _ in range(num_nodes1)]
    color2 = ['#FF5733' if _ in selected else '#1f78b4' for _ in range(num_nodes2)]
    plt.figure(figsize=(8, 4))
    plt.subplot(1, 2, 1)
    plt.title('Subgraph 1')
    plt.gca().margins(0.4)
    nx.draw_networkx(G1, pos=pos1, node_color=color1)
    plt.subplot(1, 2, 2)
    plt.title('Graph 2')
    nx.draw_networkx(G2, pos=pos2, node_color=color2)

..
.. image-sg:: /auto_examples/3.discovering_subgraphs/images/sphx_glr_plot_subgraphs_pytorch_001.png
   :alt: Subgraph 1, Graph 2
   :srcset: /auto_examples/3.discovering_subgraphs/images/sphx_glr_plot_subgraphs_pytorch_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 83-96

We now show how graph matching discovers this correspondence automatically.

Build affinity matrix
----------------------
To match the smaller graph to the larger graph, we follow the Quadratic Assignment Problem (QAP) formulation:

.. math::

    &\max_{\mathbf{X}} \ \texttt{vec}(\mathbf{X})^\top \mathbf{K} \texttt{vec}(\mathbf{X})\\
    s.t. \quad &\mathbf{X} \in \{0, 1\}^{n_1\times n_2}, \ \mathbf{X}\mathbf{1} = \mathbf{1}, \ \mathbf{X}^\top\mathbf{1} \leq \mathbf{1}

The first step is to build the affinity matrix (:math:`\mathbf{K}`).

.. GENERATED FROM PYTHON SOURCE LINES 96-102

.. code-block:: default

    conn1, edge1 = pygm.utils.dense_to_sparse(A1)
    conn2, edge2 = pygm.utils.dense_to_sparse(A2)
    import functools
    gaussian_aff = functools.partial(pygm.utils.gaussian_aff_fn, sigma=.001) # set affinity function
    K = pygm.utils.build_aff_mat(None, edge1, conn1, None, edge2, conn2, n1, None, n2, None, edge_aff_fn=gaussian_aff)

.. GENERATED FROM PYTHON SOURCE LINES 103-110

Visualization of the affinity matrix. For a graph matching problem with :math:`n_1` and :math:`n_2` nodes, the affinity
matrix has size :math:`n_1n_2\times n_1n_2` (here :math:`50\times 50`), because the two graphs have up to
:math:`n_1^2` and :math:`n_2^2` edges, respectively, and :math:`\mathbf{K}` stores one affinity score for every pair of edges.

.. note::
    The diagonal elements of the affinity matrix are all zero because this example has no node features.

.. GENERATED FROM PYTHON SOURCE LINES 110-114

.. code-block:: default

    plt.figure(figsize=(4, 4))
    plt.title(f'Affinity Matrix (size: {K.shape[0]}$\\times${K.shape[1]})')
    plt.imshow(K.numpy(), cmap='Blues')

..
.. image-sg:: /auto_examples/3.discovering_subgraphs/images/sphx_glr_plot_subgraphs_pytorch_002.png
   :alt: Affinity Matrix (size: 50$\times$50)
   :srcset: /auto_examples/3.discovering_subgraphs/images/sphx_glr_plot_subgraphs_pytorch_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 115-119

Solve graph matching problem by RRWM solver
-------------------------------------------
See :func:`~pygmtools.classic_solvers.rrwm` for the API reference.

.. GENERATED FROM PYTHON SOURCE LINES 119-121

.. code-block:: default

    X = pygm.rrwm(K, n1, n2)

.. GENERATED FROM PYTHON SOURCE LINES 122-124

RRWM outputs a soft (continuous-valued) matching matrix, visualized here next to the ground truth:

.. GENERATED FROM PYTHON SOURCE LINES 124-132

.. code-block:: default

    plt.figure(figsize=(8, 4))
    plt.subplot(1, 2, 1)
    plt.title('RRWM Soft Matching Matrix')
    plt.imshow(X.numpy(), cmap='Blues')
    plt.subplot(1, 2, 2)
    plt.title('Ground Truth Matching Matrix')
    plt.imshow(X_gt.numpy(), cmap='Blues')

.. image-sg:: /auto_examples/3.discovering_subgraphs/images/sphx_glr_plot_subgraphs_pytorch_003.png
   :alt: RRWM Soft Matching Matrix, Ground Truth Matching Matrix
   :srcset: /auto_examples/3.discovering_subgraphs/images/sphx_glr_plot_subgraphs_pytorch_003.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 133-137

Get the discrete matching matrix
---------------------------------
The Hungarian algorithm is then applied to the soft matching matrix to obtain a discrete matching matrix.

.. GENERATED FROM PYTHON SOURCE LINES 137-139

.. code-block:: default

    X = pygm.hungarian(X)

.. GENERATED FROM PYTHON SOURCE LINES 140-142

Visualization of the discrete matching matrix:

.. GENERATED FROM PYTHON SOURCE LINES 142-150

..
.. code-block:: default

    plt.figure(figsize=(8, 4))
    plt.subplot(1, 2, 1)
    plt.title(f'RRWM Matching Matrix (acc={(X * X_gt).sum()/ X_gt.sum():.2f})')
    plt.imshow(X.numpy(), cmap='Blues')
    plt.subplot(1, 2, 2)
    plt.title('Ground Truth Matching Matrix')
    plt.imshow(X_gt.numpy(), cmap='Blues')

.. image-sg:: /auto_examples/3.discovering_subgraphs/images/sphx_glr_plot_subgraphs_pytorch_004.png
   :alt: RRWM Matching Matrix (acc=1.00), Ground Truth Matching Matrix
   :srcset: /auto_examples/3.discovering_subgraphs/images/sphx_glr_plot_subgraphs_pytorch_004.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 151-155

Match the subgraph
-------------------
Draw the matching:

.. GENERATED FROM PYTHON SOURCE LINES 155-170

.. code-block:: default

    plt.figure(figsize=(8, 4))
    plt.suptitle(f'RRWM Matching Result (acc={(X * X_gt).sum()/ X_gt.sum():.2f})')
    ax1 = plt.subplot(1, 2, 1)
    plt.title('Subgraph 1')
    plt.gca().margins(0.4)
    nx.draw_networkx(G1, pos=pos1, node_color=color1)
    ax2 = plt.subplot(1, 2, 2)
    plt.title('Graph 2')
    nx.draw_networkx(G2, pos=pos2, node_color=color2)
    for i in range(num_nodes1):
        j = torch.argmax(X[i]).item()
        con = ConnectionPatch(xyA=pos1[i], xyB=pos2[j], coordsA="data", coordsB="data",
                              axesA=ax1, axesB=ax2, color="green" if X_gt[i,j] == 1 else "red")
        plt.gca().add_artist(con)

.. image-sg:: /auto_examples/3.discovering_subgraphs/images/sphx_glr_plot_subgraphs_pytorch_005.png
   :alt: RRWM Matching Result (acc=1.00), Subgraph 1, Graph 2
   :srcset: /auto_examples/3.discovering_subgraphs/images/sphx_glr_plot_subgraphs_pytorch_005.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 171-178

Other solvers are also available
---------------------------------

Classic IPFP solver
^^^^^^^^^^^^^^^^^^^^^
See :func:`~pygmtools.classic_solvers.ipfp` for the API reference.

.. GENERATED FROM PYTHON SOURCE LINES 178-180

.. code-block:: default

    X = pygm.ipfp(K, n1, n2)

..
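
A discrete matching matrix should satisfy the constraints of the QAP formulation given earlier: every entry is 0 or 1, every node of the smaller graph is matched exactly once, and every node of the larger graph is matched at most once. The following is a minimal sketch of such a check. It assumes NumPy instead of the PyTorch backend so that it runs on its own, and ``is_valid_matching`` is a hypothetical helper, not part of pygmtools.

.. code-block:: python

    import numpy as np

    def is_valid_matching(X):
        """Hypothetical helper: check the QAP constraints on a matching matrix."""
        X = np.asarray(X)
        is_binary = np.isin(X, (0, 1)).all()   # X in {0, 1}^(n1 x n2)
        rows_ok = (X.sum(axis=1) == 1).all()   # X 1 = 1
        cols_ok = (X.sum(axis=0) <= 1).all()   # X^T 1 <= 1
        return bool(is_binary and rows_ok and cols_ok)

    # Toy matrices: 2 nodes in the smaller graph, 3 nodes in the larger graph
    good = np.array([[0, 1, 0],
                     [0, 0, 1]])
    bad = np.array([[0, 1, 0],
                    [0, 1, 0]])  # both rows claim the same column
    print(is_valid_matching(good), is_valid_matching(bad))

With the tensors of this example, the equivalent call would be ``is_valid_matching(X.numpy())``.
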
.. GENERATED FROM PYTHON SOURCE LINES 181-183

Visualization of IPFP matching result:

.. GENERATED FROM PYTHON SOURCE LINES 183-198

.. code-block:: default

    plt.figure(figsize=(8, 4))
    plt.suptitle(f'IPFP Matching Result (acc={(X * X_gt).sum()/ X_gt.sum():.2f})')
    ax1 = plt.subplot(1, 2, 1)
    plt.title('Subgraph 1')
    plt.gca().margins(0.4)
    nx.draw_networkx(G1, pos=pos1, node_color=color1)
    ax2 = plt.subplot(1, 2, 2)
    plt.title('Graph 2')
    nx.draw_networkx(G2, pos=pos2, node_color=color2)
    for i in range(num_nodes1):
        j = torch.argmax(X[i]).item()
        con = ConnectionPatch(xyA=pos1[i], xyB=pos2[j], coordsA="data", coordsB="data",
                              axesA=ax1, axesB=ax2, color="green" if X_gt[i,j] == 1 else "red")
        plt.gca().add_artist(con)

.. image-sg:: /auto_examples/3.discovering_subgraphs/images/sphx_glr_plot_subgraphs_pytorch_006.png
   :alt: IPFP Matching Result (acc=1.00), Subgraph 1, Graph 2
   :srcset: /auto_examples/3.discovering_subgraphs/images/sphx_glr_plot_subgraphs_pytorch_006.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 199-203

Classic SM solver
^^^^^^^^^^^^^^^^^^^^^
See :func:`~pygmtools.classic_solvers.sm` for the API reference.

.. GENERATED FROM PYTHON SOURCE LINES 203-206

.. code-block:: default

    X = pygm.sm(K, n1, n2)
    X = pygm.hungarian(X)

.. GENERATED FROM PYTHON SOURCE LINES 207-209

Visualization of SM matching result:

.. GENERATED FROM PYTHON SOURCE LINES 209-224

.. code-block:: default

    plt.figure(figsize=(8, 4))
    plt.suptitle(f'SM Matching Result (acc={(X * X_gt).sum()/ X_gt.sum():.2f})')
    ax1 = plt.subplot(1, 2, 1)
    plt.title('Subgraph 1')
    plt.gca().margins(0.4)
    nx.draw_networkx(G1, pos=pos1, node_color=color1)
    ax2 = plt.subplot(1, 2, 2)
    plt.title('Graph 2')
    nx.draw_networkx(G2, pos=pos2, node_color=color2)
    for i in range(num_nodes1):
        j = torch.argmax(X[i]).item()
        con = ConnectionPatch(xyA=pos1[i], xyB=pos2[j], coordsA="data", coordsB="data",
                              axesA=ax1, axesB=ax2, color="green" if X_gt[i,j] == 1 else "red")
        plt.gca().add_artist(con)

..
.. image-sg:: /auto_examples/3.discovering_subgraphs/images/sphx_glr_plot_subgraphs_pytorch_007.png
   :alt: SM Matching Result (acc=1.00), Subgraph 1, Graph 2
   :srcset: /auto_examples/3.discovering_subgraphs/images/sphx_glr_plot_subgraphs_pytorch_007.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 225-234

NGM neural network solver
^^^^^^^^^^^^^^^^^^^^^^^^^
See :func:`~pygmtools.neural_solvers.ngm` for the API reference.

.. note::
    The NGM solver is pretrained on a different problem setting, so its accuracy on this example can be lower
    than that of the classic solvers (0.60 in the result shown here, versus 1.00 for RRWM, IPFP and SM).
    To improve its accuracy, you may change how the affinity matrix is built, or try finetuning NGM on the
    new problem.

.. GENERATED FROM PYTHON SOURCE LINES 234-238

.. code-block:: default

    with torch.set_grad_enabled(False):
        X = pygm.ngm(K, n1, n2, pretrain='voc')
        X = pygm.hungarian(X)

.. GENERATED FROM PYTHON SOURCE LINES 239-241

Visualization of NGM matching result:

.. GENERATED FROM PYTHON SOURCE LINES 241-255

.. code-block:: default

    plt.figure(figsize=(8, 4))
    plt.suptitle(f'NGM Matching Result (acc={(X * X_gt).sum()/ X_gt.sum():.2f})')
    ax1 = plt.subplot(1, 2, 1)
    plt.title('Subgraph 1')
    plt.gca().margins(0.4)
    nx.draw_networkx(G1, pos=pos1, node_color=color1)
    ax2 = plt.subplot(1, 2, 2)
    plt.title('Graph 2')
    nx.draw_networkx(G2, pos=pos2, node_color=color2)
    for i in range(num_nodes1):
        j = torch.argmax(X[i]).item()
        con = ConnectionPatch(xyA=pos1[i], xyB=pos2[j], coordsA="data", coordsB="data",
                              axesA=ax1, axesB=ax2, color="green" if X_gt[i,j] == 1 else "red")
        plt.gca().add_artist(con)

.. image-sg:: /auto_examples/3.discovering_subgraphs/images/sphx_glr_plot_subgraphs_pytorch_008.png
   :alt: NGM Matching Result (acc=0.60), Subgraph 1, Graph 2
   :srcset: /auto_examples/3.discovering_subgraphs/images/sphx_glr_plot_subgraphs_pytorch_008.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.545 seconds)

..
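
The accuracy shown in each figure title, ``(X * X_gt).sum() / X_gt.sum()``, is the fraction of ground-truth correspondences that the predicted matching recovers. The following is a minimal sketch of that metric on toy data. It assumes NumPy instead of the PyTorch backend so that it runs on its own, and ``matching_accuracy`` is a hypothetical helper name, not a pygmtools function.

.. code-block:: python

    import numpy as np

    def matching_accuracy(X, X_gt):
        """Hypothetical helper: share of ground-truth matches present in X."""
        X = np.asarray(X, dtype=float)
        X_gt = np.asarray(X_gt, dtype=float)
        return float((X * X_gt).sum() / X_gt.sum())

    # Toy data: 3 subgraph nodes that truly correspond to nodes 0, 2, 4 of a
    # 5-node graph, built the same way as X_gt in this example (rows of eye).
    X_gt = np.eye(5)[[0, 2, 4]]
    X_pred = np.eye(5)[[0, 2, 3]]  # the last correspondence is wrong
    print(f"{matching_accuracy(X_pred, X_gt):.2f}")  # prints 0.67

Two of the three ground-truth correspondences are recovered, hence 0.67. A perfect prediction gives 1.00, as reported for RRWM, IPFP and SM in this example.
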
.. _sphx_glr_download_auto_examples_3.discovering_subgraphs_plot_subgraphs_pytorch.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_subgraphs_pytorch.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_subgraphs_pytorch.ipynb `

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_