pygmtools: Python Graph Matching Tools
pygmtools provides graph matching solvers in Python and is easily accessible via:
$ pip install pygmtools
Official documentation: https://pygmtools.readthedocs.io
Source code: https://github.com/Thinklab-SJTU/pygmtools
Graph matching is a fundamental yet challenging problem in pattern recognition, data mining, and other areas. Graph matching aims to find node-to-node correspondence among multiple graphs by solving an NP-hard combinatorial optimization problem.
Doing graph matching in Python used to be difficult, and this library wants to make researchers’ lives easier.
To highlight, pygmtools has the following features:
Support various solvers, including traditional combinatorial solvers (linear, quadratic, and multi-graph) and novel deep learning-based solvers;
Support various backends, including numpy which is universally accessible, and state-of-the-art deep learning architectures with GPU support: pytorch, paddle, jittor;
Deep learning friendly: the operations are designed to best preserve gradients during computation, and batched operations are supported for the best performance.
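As a quick illustration of the deep-learning-friendly design, here is a minimal sketch (assuming the PyTorch backend is installed and that the Sinkhorn solver from the linear solver list below is exposed as pygm.sinkhorn; check the API reference to confirm) showing that gradients flow through a solver call:
import torch
import pygmtools as pygm
pygm.BACKEND = 'pytorch'  # switch from the default numpy backend

s = torch.rand(4, 4, requires_grad=True)  # e.g. a similarity matrix predicted by a network
x = pygm.sinkhorn(s)                      # differentiable projection to a doubly-stochastic matrix
loss = ((x - torch.eye(4)) ** 2).sum()    # any downstream loss
loss.backward()                           # gradients flow back to s
print(s.grad.shape)                       # torch.Size([4, 4])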
Installation
You can install the stable release on PyPI:
$ pip install pygmtools
or get the latest version by running:
$ pip install -U https://github.com/Thinklab-SJTU/pygmtools/archive/master.zip # with --user for user install (no root)
Now pygmtools is available with the numpy backend.
The following packages are required, and shall be automatically installed by pip:
Python >= 3.5
requests >= 2.25.1
scipy >= 1.4.1
Pillow >= 7.2.0
numpy >= 1.18.5
easydict >= 1.7
appdirs >= 1.4.4
tqdm >= 4.64.1
Available Graph Matching Solvers
This library offers user-friendly APIs for the following solvers:
- Linear assignment solvers, including the differentiable soft Sinkhorn algorithm [1] and the exact Hungarian solver [2].
- Soft and differentiable quadratic assignment solvers, including spectral graph matching [3] and random-walk-based graph matching [4].
- Discrete (non-differentiable) quadratic assignment solver: the integer projected fixed point method [5].
- Composition-based Affinity Optimization (CAO) solver [6], which optimizes the affinity score while gradually infusing consistency.
- Multi-graph matching based on the Floyd shortest path algorithm [7].
- Graduated-assignment-based multi-graph matching solvers [8][9], which use graduated annealing of Sinkhorn’s temperature.
- Intra-graph and cross-graph embedding based neural graph matching solvers PCA-GM and IPCA-GM [10] for matching individual graphs.
- Channel independent embedding (CIE) [11] based neural graph matching solver for matching individual graphs.
- Neural graph matching solver (NGM) [12] for the general quadratic assignment formulation.
Available Backends
This library is designed to support multiple backends with the same set of APIs. Please follow the official instructions to install your backend.
The following backends are available:
Numpy (default backend, CPU only)
PyTorch (recommended backend, GPU friendly, deep learning friendly)
PaddlePaddle (GPU friendly, deep learning friendly)
Jittor (GPU friendly, deep learning friendly)
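The same calls work unchanged across backends; only pygm.BACKEND and the input array/tensor type differ. A minimal sketch (assuming both numpy and pytorch are installed), using the hungarian solver that appears in the examples below:
import numpy as np
import torch
import pygmtools as pygm

s = np.random.rand(4, 4)                       # a score matrix

pygm.BACKEND = 'numpy'                         # default backend
x_np = pygm.hungarian(s)                       # numpy array in, numpy array out

pygm.BACKEND = 'pytorch'                       # switch backend globally
x_torch = pygm.hungarian(torch.from_numpy(s))  # torch tensor in, torch tensor out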
For more details, please read the documentation.
The Deep Graph Matching Benchmark
pygmtools also features a standard data interface for several graph matching benchmarks. We also maintain a repository containing non-trivial implementations of deep graph matching models; please check out ThinkMatch if you are interested!
Chat with the Community
If you have any questions, or if you are experiencing any issues, feel free to raise an issue on GitHub.
We also offer the following chat rooms if you are more comfortable with them:
Discord (for English users):
https://discord.gg/8m6n7rRz9T
QQ Group (for Chinese users)/QQ群(中文用户): 696401889
Contributing
Any contributions/ideas/suggestions from the community are welcome! Before starting your contribution, please read the Contributing Guide.
Developers and Maintainers
pygmtools is currently developed and maintained by members from ThinkLab at Shanghai Jiao Tong University.
References
[1] Sinkhorn, Richard, and Paul Knopp. “Concerning nonnegative matrices and doubly stochastic matrices.” Pacific Journal of Mathematics 21.2 (1967): 343-348.
[2] Munkres, James. “Algorithms for the assignment and transportation problems.” Journal of the Society for Industrial and Applied Mathematics 5.1 (1957): 32-38.
[3] Leordeanu, Marius, and Martial Hebert. “A spectral technique for correspondence problems using pairwise constraints.” International Conference on Computer Vision (2005).
[4] Cho, Minsu, Jungmin Lee, and Kyoung Mu Lee. “Reweighted random walks for graph matching.” European Conference on Computer Vision. Springer, Berlin, Heidelberg, 2010.
[5] Leordeanu, Marius, Martial Hebert, and Rahul Sukthankar. “An integer projected fixed point method for graph matching and MAP inference.” Advances in Neural Information Processing Systems 22 (2009).
[6] Yan, Junchi, et al. “Multi-graph matching via affinity optimization with graduated consistency regularization.” IEEE Transactions on Pattern Analysis and Machine Intelligence 38.6 (2015): 1228-1242.
[7] Jiang, Zetian, Tianzhe Wang, and Junchi Yan. “Unifying offline and online multi-graph matching via finding shortest paths on supergraph.” IEEE Transactions on Pattern Analysis and Machine Intelligence 43.10 (2020): 3648-3663.
[8] Solé-Ribalta, Albert, and Francesc Serratosa. “Graduated assignment algorithm for multiple graph matching based on a common labeling.” International Journal of Pattern Recognition and Artificial Intelligence 27.01 (2013): 1350001.
[9] Wang, Runzhong, Junchi Yan, and Xiaokang Yang. “Graduated assignment for joint multi-graph matching and clustering with application to unsupervised graph matching network learning.” Advances in Neural Information Processing Systems 33 (2020): 19908-19919.
[10] Wang, Runzhong, Junchi Yan, and Xiaokang Yang. “Combinatorial learning of robust deep graph matching: an embedding based approach.” IEEE Transactions on Pattern Analysis and Machine Intelligence (2020).
[11] Yu, Tianshu, et al. “Learning deep graph matching with channel-independent embedding and Hungarian attention.” International Conference on Learning Representations. 2019.
[12] Wang, Runzhong, Junchi Yan, and Xiaokang Yang. “Neural graph matching network: Learning Lawler’s quadratic assignment problem with extension to hypergraph and multiple-graph matching.” IEEE Transactions on Pattern Analysis and Machine Intelligence (2021).
Contents of Official Documentation
Introduction and Guidelines
This page provides a brief introduction to graph matching and some guidelines for using pygmtools.
If you are seeking some background information, this is the right place!
Note
For more technical details, we recommend the following two surveys.
About learning-based deep graph matching: Junchi Yan, Shuang Yang, Edwin Hancock. “Learning Graph Matching and Related Combinatorial Optimization Problems.” IJCAI 2020.
About non-learning two-graph matching and multi-graph matching: Junchi Yan, Xu-Cheng Yin, Weiyao Lin, Cheng Deng, Hongyuan Zha, Xiaokang Yang. “A Short Survey of Recent Advances in Graph Matching.” ICMR 2016.
Why Graph Matching?
Graph Matching (GM) is a fundamental yet challenging problem in pattern recognition, data mining, and other areas. GM aims to find node-to-node correspondence among multiple graphs by solving an NP-hard combinatorial problem. Recently, there has been growing interest in developing deep learning-based graph matching methods.
Compared to other straightforward matching methods, e.g. greedy matching, graph matching methods are more reliable because they are based on an optimization formulation. Besides, graph matching methods exploit both node affinity and edge affinity, and are thus usually more robust to noise and outliers. The recent line of deep graph matching methods also enables many graph matching solvers to be integrated into deep learning pipelines.
Graph matching techniques have been applied to the following applications:
- Model ensemble and federated learning
- and more…
If your task involves matching two or more graphs, you should try the solvers in pygmtools!
What is Graph Matching?
The Graph Matching Pipeline
Solving a real-world graph matching problem may involve the following steps:
1. Extract node/edge features from the graphs you want to match.
2. Build an affinity matrix from the node/edge features.
3. Solve the graph matching problem with GM solvers.
Step 1 may be done by methods depending on your application, while Steps 2 and 3 can be handled by pygmtools, as sketched below.
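A rough sketch of how Steps 2 and 3 map onto the library with the numpy backend (Step 1 is application-specific, so random adjacency matrices stand in for real features here; the Get Started section below gives the full runnable walkthrough):
import functools
import numpy as np
import pygmtools as pygm
pygm.BACKEND = 'numpy'

# Step 1 (application-specific): extract node/edge features; here we fake two graphs
A1 = np.random.rand(1, 5, 5)   # adjacency (edge features) of graph 1, batched
A2 = A1.copy()                 # graph 2: an already-aligned copy, for illustration only
n1 = n2 = np.array([5])

# Step 2: build the affinity matrix K from the edge features
conn1, edge1, ne1 = pygm.utils.dense_to_sparse(A1)
conn2, edge2, ne2 = pygm.utils.dense_to_sparse(A2)
aff_fn = functools.partial(pygm.utils.gaussian_aff_fn, sigma=1.)
K = pygm.utils.build_aff_mat(None, edge1, conn1, None, edge2, conn2, n1, ne1, n2, ne2, edge_aff_fn=aff_fn)

# Step 3: solve the QAP with a graph matching solver and discretize the soft result
X = pygm.hungarian(pygm.rrwm(K, n1, n2))   # X is the predicted permutation matrix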
The following plot illustrates a standard deep graph matching pipeline.

The Math Form
Let’s involve a little bit of math to better understand the graph matching pipeline. In general, graph matching is of the following form, known as the Quadratic Assignment Problem (QAP):
\[\max_{\mathbf{X}} \ \mathtt{vec}(\mathbf{X})^\top \mathbf{K}\, \mathtt{vec}(\mathbf{X}), \quad s.t. \quad \mathbf{X} \in \{0, 1\}^{n_1\times n_2},\ \mathbf{X}\mathbf{1} = \mathbf{1},\ \mathbf{X}^\top\mathbf{1} \leq \mathbf{1}\]
The notations are explained as follows:
\(\mathbf{X}\) is known as the permutation matrix which encodes the matching result. It is also the decision variable in the graph matching problem. \(\mathbf{X}_{i,a}=1\) means node \(i\) in graph 1 is matched to node \(a\) in graph 2, and \(\mathbf{X}_{i,a}=0\) means non-matched. Without loss of generality, it is assumed that \(n_1\leq n_2\). \(\mathbf{X}\) has the following constraints:
The sum of each row must be equal to 1: \(\mathbf{X}\mathbf{1} = \mathbf{1}\);
The sum of each column must be equal to, or smaller than, 1: \(\mathbf{X}^\top\mathbf{1} \leq \mathbf{1}\).
\(\mathtt{vec}(\mathbf{X})\) means the column-wise vectorization form of \(\mathbf{X}\).
\(\mathbf{1}\) means a column vector whose elements are all 1s.
\(\mathbf{K}\) is known as the affinity matrix which encodes the information of the input graphs. Both node-wise and edge-wise affinities are encoded in \(\mathbf{K}\):
The diagonal element \(\mathbf{K}_{i + a\times n_1, i + a\times n_1}\) means the node-wise affinity of node \(i\) in graph 1 and node \(a\) in graph 2;
The off-diagonal element \(\mathbf{K}_{i + a\times n_1, j + b\times n_1}\) means the edge-wise affinity of edge \(ij\) in graph 1 and edge \(ab\) in graph 2.
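To make this indexing concrete, below is a tiny numpy sketch (illustrative only; in practice the utility function pygm.utils.build_aff_mat shown in the Get Started section builds \(\mathbf{K}\) for you) that places node affinities at the positions described above and evaluates the QAP objective:
import numpy as np

n1, n2 = 3, 4                       # graph 1 has n1 nodes, graph 2 has n2 nodes
node_aff = np.random.rand(n1, n2)   # node_aff[i, a]: affinity of node i (graph 1) and node a (graph 2)
K = np.zeros((n1 * n2, n1 * n2))

# diagonal entries hold node-wise affinities, following column-wise vectorization of X
for i in range(n1):
    for a in range(n2):
        K[i + a * n1, i + a * n1] = node_aff[i, a]
# an off-diagonal entry K[i + a*n1, j + b*n1] would hold the affinity of
# edge (i, j) in graph 1 and edge (a, b) in graph 2 (omitted in this sketch)

X = np.zeros((n1, n2))
X[np.arange(n1), np.arange(n1)] = 1   # an arbitrary feasible matching
vec_X = X.reshape(-1, order='F')      # column-wise vectorization of X
score = vec_X @ K @ vec_X             # the QAP objective; here the sum of matched node affinities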
Graph Matching Best Practice
We need to understand the advantages and limitations of graph matching solvers. As discussed above, the major advantage of graph matching solvers is that they are more robust to noise and outliers. Graph matching also utilizes edge information, which is usually ignored in linear matching methods. The major drawback of graph matching solvers is their efficiency and scalability, since the optimization problem is NP-hard. Therefore, to decide which matching method is most suitable, you need to balance the required matching accuracy against the affordable time and memory cost of your application.
Note
Anyway, it does no harm to try graph matching first!
When to use pygmtools
pygmtools is recommended for the following cases, and you could benefit from the friendly API:
If you want to integrate graph matching as a step of your pipeline (either learning or non-learning).
If you want quick benchmarking and profiling of the graph matching solvers available in pygmtools.
If you do not want to dive too deep into the algorithm details and do not need to modify the algorithm.
We offer the following guidelines for your reference:
If you want to integrate graph matching solvers into your end-to-end supervised deep learning pipeline, try neural_solvers.
If no ground truth label is available for the matching step, try classic_solvers.
If there are multiple graphs to be jointly matched, try multi_graph_solvers.
If the time and memory cost of the above methods is unacceptable for your task, try linear_solvers.
When not to use pygmtools
As a highly packaged toolkit, pygmtools lacks some flexibility in the implementation details, especially for experts in graph matching. If you are researching new graph matching algorithms or developing next-generation deep graph matching neural networks, pygmtools may not be suitable. We recommend ThinkMatch as the protocol for academic research.
What’s Next
Please read the Get Started guide.
Get Started
Basic Install by pip
You can install the stable release on PyPI:
$ pip install pygmtools
or get the latest version by running:
$ pip install -U https://github.com/Thinklab-SJTU/pygmtools/archive/master.zip # with --user for user install (no root)
Now pygmtools is available with the numpy backend:

You may jump to Example: Matching Isomorphic Graphs if you do not need other backends.
The following packages are required, and shall be automatically installed by pip:
Python >= 3.5
requests >= 2.25.1
scipy >= 1.4.1
Pillow >= 7.2.0
numpy >= 1.18.5
easydict >= 1.7
appdirs >= 1.4.4
tqdm >= 4.64.1
Install Other Backends
Currently, we also support the deep learning architectures pytorch, paddle, and jittor, which are GPU-friendly and deep learning-friendly.
Once the backend is ready, you may switch to the backend globally by the following command:
>>> import pygmtools as pygm
>>> pygm.BACKEND = 'pytorch' # replace 'pytorch' by other backend names
PyTorch Backend

PyTorch is an open-source machine learning framework developed and maintained by Meta Inc./Linux Foundation.
PyTorch is popular, especially among the deep learning research community.
The PyTorch backend of pygmtools is designed to support GPU devices and facilitate deep learning research.
Please follow the official PyTorch installation guide.
This package is developed with torch==1.6.0 and shall work with any PyTorch version >=1.6.0.
How to enable PyTorch backend:
>>> import pygmtools as pygm
>>> import torch
>>> pygm.BACKEND = 'pytorch'
Paddle Backend

PaddlePaddle is an open-source deep learning platform originating from industrial practice, which is developed and maintained by Baidu Inc.
The Paddle backend of pygmtools is designed to support GPU devices and deep learning applications.
Please follow the official PaddlePaddle installation guide.
This package is developed with paddlepaddle==2.3.1 and shall work with any PaddlePaddle version >=2.3.1.
How to enable Paddle backend:
>>> import pygmtools as pygm
>>> import paddle
>>> pygm.BACKEND = 'paddle'
Jittor Backend

Jittor is an open-source deep learning platform based on just-in-time (JIT) compilation for high performance, which is developed and maintained by the CSCG group from Tsinghua University.
The Jittor backend of pygmtools is designed to support GPU devices and deep learning applications.
Please follow the official Jittor installation guide.
This package is developed with jittor==1.3.4.16 and shall work with any Jittor version >=1.3.4.16.
How to enable Jittor backend:
>>> import pygmtools as pygm
>>> import jittor
>>> pygm.BACKEND = 'jittor'
Example: Matching Isomorphic Graphs
Here we provide a basic example of matching two isomorphic graphs (i.e., two graphs that have the same nodes and edges, but whose node correspondence is unknown).
Step 0: Import packages and set backend
>>> import numpy as np
>>> import pygmtools as pygm
>>> pygm.BACKEND = 'numpy'
>>> np.random.seed(1)
Step 1: Generate a batch of isomorphic graphs
>>> batch_size = 3
>>> X_gt = np.zeros((batch_size, 4, 4))
>>> X_gt[:, np.arange(0, 4, dtype=np.int64), np.random.permutation(4)] = 1
>>> A1 = np.random.rand(batch_size, 4, 4)
>>> A2 = np.matmul(np.matmul(X_gt.transpose((0, 2, 1)), A1), X_gt)
>>> n1 = n2 = np.repeat([4], batch_size)
Step 2: Build an affinity matrix and select an affinity function
>>> conn1, edge1, ne1 = pygm.utils.dense_to_sparse(A1)
>>> conn2, edge2, ne2 = pygm.utils.dense_to_sparse(A2)
>>> import functools
>>> gaussian_aff = functools.partial(pygm.utils.gaussian_aff_fn, sigma=1.) # set affinity function
>>> K = pygm.utils.build_aff_mat(None, edge1, conn1, None, edge2, conn2, n1, ne1, n2, ne2, edge_aff_fn=gaussian_aff)
Step 3: Solve graph matching by RRWM
>>> X = pygm.rrwm(K, n1, n2, beta=100)
>>> X = pygm.hungarian(X)
>>> X # X is the permutation matrix
[[[0. 0. 0. 1.]
[0. 0. 1. 0.]
[1. 0. 0. 0.]
[0. 1. 0. 0.]]
[[0. 0. 0. 1.]
[0. 0. 1. 0.]
[1. 0. 0. 0.]
[0. 1. 0. 0.]]
[[0. 0. 0. 1.]
[0. 0. 1. 0.]
[1. 0. 0. 0.]
[0. 1. 0. 0.]]]
Final Step: Evaluate the accuracy
>>> (X * X_gt).sum() / X_gt.sum()
1.0
What’s Next
Please checkout Examples Gallery to see how to apply pygmtools
to tackle real-world problems.
You may see API and Modules for the API documentation.
Graph Matching Benchmark
pygmtools also provides a protocol to fairly compare existing deep graph matching algorithms under different datasets & experiment settings.
The Benchmark module provides a unified data interface and an evaluation platform for different datasets.
If you are interested in the performance and the full deep learning pipeline, please refer to our ThinkMatch project.
Evaluation Metrics and Results
Our evaluation metrics include matching_precision (p), matching_recall (r) and f1_score (f1). Also, to measure the reliability of the evaluation result, we define coverage (cvg) for each class in the dataset as the number of evaluated pairs in the class divided by the number of all possible pairs in the class. Therefore, larger coverage indicates higher reliability.
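For example (hypothetical numbers), if a class has 1,000 possible pairs and 250 of them are evaluated, then cvg = 250 / 1000 = 0.25 for that class.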
An example of an evaluation result (p==r==f1 because this evaluation does not involve partial matching/outliers):
Matching accuracy
Car: p = 0.8395±0.2280, r = 0.8395±0.2280, f1 = 0.8395±0.2280, cvg = 1.0000
Duck: p = 0.7713±0.2255, r = 0.7713±0.2255, f1 = 0.7713±0.2255, cvg = 1.0000
Face: p = 0.9656±0.0913, r = 0.9656±0.0913, f1 = 0.9656±0.0913, cvg = 0.2612
Motorbike: p = 0.8821±0.1821, r = 0.8821±0.1821, f1 = 0.8821±0.1821, cvg = 1.0000
Winebottle: p = 0.8929±0.1569, r = 0.8929±0.1569, f1 = 0.8929±0.1569, cvg = 0.9662
average accuracy: p = 0.8703±0.1767, r = 0.8703±0.1767, f1 = 0.8703±0.1767
Evaluation complete in 1m 55s
Available Datasets
Datasets can be automatically downloaded and unzipped, but you can also download them yourself; make sure they are placed in the right path.
PascalVOC-Keypoint Dataset
Download VOC2011 dataset and make sure it looks like
data/PascalVOC/TrainVal/VOCdevkit/VOC2011
Download keypoint annotation for VOC2011 from Berkeley server or google drive and make sure it looks like
data/PascalVOC/annotations
Download the train/test split file and make sure it looks like
data/PascalVOC/voc2011_pairs.npz
Please cite the following papers if you use the PascalVOC-Keypoint dataset:
@article{EveringhamIJCV10,
title={The pascal visual object classes (voc) challenge},
author={Everingham, Mark and Van Gool, Luc and Williams, Christopher KI and Winn, John and Zisserman, Andrew},
journal={International Journal of Computer Vision},
volume={88},
pages={303--338},
year={2010}
}
@inproceedings{BourdevICCV09,
title={Poselets: Body part detectors trained using 3d human pose annotations},
author={Bourdev, L. and Malik, J.},
booktitle={International Conference on Computer Vision},
pages={1365--1372},
year={2009},
organization={IEEE}
}
Willow-Object-Class Dataset
Download Willow-ObjectClass dataset
Unzip the dataset and make sure it looks like
data/WillowObject/WILLOW-ObjectClass
Please cite the following paper if you use the Willow-Object-Class dataset:
@inproceedings{ChoICCV13,
author={Cho, Minsu and Alahari, Karteek and Ponce, Jean},
title = {Learning Graphs to Match},
booktitle = {International Conference on Computer Vision},
pages={25--32},
year={2013}
}
CUB2011 Dataset
Download CUB-200-2011 dataset.
Unzip the dataset and make sure it looks like
data/CUB_200_2011/CUB_200_2011
Please cite the following report if you use the CUB2011 dataset:
@techreport{CUB2011,
Title = {{The Caltech-UCSD Birds-200-2011 Dataset}},
Author = {Wah, C. and Branson, S. and Welinder, P. and Perona, P. and Belongie, S.},
Year = {2011},
Institution = {California Institute of Technology},
Number = {CNS-TR-2011-001}
}
IMC-PT-SparseGM Dataset
Download the IMC-PT-SparseGM dataset from google drive or baidu drive (code: 0576)
Unzip the dataset and make sure it looks like
data/IMC_PT_SparseGM/annotations
Please cite the following papers if you use the IMC-PT-SparseGM dataset:
@article{JinIJCV21,
title={Image Matching across Wide Baselines: From Paper to Practice},
author={Jin, Yuhe and Mishkin, Dmytro and Mishchuk, Anastasiia and Matas, Jiri and Fua, Pascal and Yi, Kwang Moo and Trulls, Eduard},
journal={International Journal of Computer Vision},
pages={517--547},
year={2021}
}
API Reference
See the API doc of Benchmark module and the API doc of datasets for details.
File Organization
dataset.py: Includes 5 dataset classes, used to automatically download a dataset, process it into a json file, and save the training and testing sets.
benchmark.py: Includes the Benchmark class, which can be used to fetch data from the json file and evaluate prediction results.
dataset_config.py: The default dataset settings, mostly dataset paths and classes.
Example
import pygmtools as pygm
from pygmtools.benchmark import Benchmark
# Define Benchmark on PascalVOC.
bm = Benchmark(name='PascalVOC', sets='train',
obj_resize=(256, 256), problem='2GM',
filter='intersection')
# Randomly fetch data and ground truth.
data_list, gt_dict, _ = bm.rand_get_data(cls=None, num=2)
API and Modules
linear_solvers: Classic (learning-free) linear assignment problem solvers.
classic_solvers: Classic (learning-free) two-graph matching solvers.
multi_graph_solvers: Classic (learning-free) multi-graph matching solvers.
neural_solvers: Neural network-based graph matching solvers.
utils: Utility functions: problem formulating, data processing, and beyond.
benchmark: The Benchmark module with a unified data interface to evaluate graph matching methods.
dataset: The implementations of data loading and data processing.
Warning
By default, the API functions and modules run on the numpy backend. You can change the default backend by setting pygm.BACKEND. If you enable a backend other than numpy, the corresponding package should be installed. See the installation guide for details.
Examples Gallery
Below is the gallery of pygmtools examples (categorized by backend).
Warning
The examples are under construction and will be updated very soon.
PyTorch Backend Examples

Note
Click here to download the full example code
Introduction: Matching Isomorphic Graphs
This example is an introduction to pygmtools which shows how to match isomorphic graphs. Isomorphic graphs are graphs whose structures are identical but whose node correspondence is unknown.
# Author: Runzhong Wang <runzhong.wang@sjtu.edu.cn>
#
# License: Mulan PSL v2 License
Note
The following solvers support the QAP formulation and are included in this example: rrwm(), ipfp(), sm(), ngm().
import torch # pytorch backend
import pygmtools as pygm
import matplotlib.pyplot as plt # for plotting
from matplotlib.patches import ConnectionPatch # for plotting matching result
import networkx as nx # for plotting graphs
pygm.BACKEND = 'pytorch' # set default backend for pygmtools
_ = torch.manual_seed(1) # fix random seed
Generate two isomorphic graphs
num_nodes = 10
X_gt = torch.zeros(num_nodes, num_nodes)
X_gt[torch.arange(0, num_nodes, dtype=torch.int64), torch.randperm(num_nodes)] = 1
A1 = torch.rand(num_nodes, num_nodes)
A1 = (A1 + A1.t() > 1.) * (A1 + A1.t()) / 2
torch.diagonal(A1)[:] = 0
A2 = torch.mm(torch.mm(X_gt.t(), A1), X_gt)
n1 = torch.tensor([num_nodes])
n2 = torch.tensor([num_nodes])
Visualize the graphs
plt.figure(figsize=(8, 4))
G1 = nx.from_numpy_array(A1.numpy())
G2 = nx.from_numpy_array(A2.numpy())
pos1 = nx.spring_layout(G1)
pos2 = nx.spring_layout(G2)
plt.subplot(1, 2, 1)
plt.title('Graph 1')
nx.draw_networkx(G1, pos=pos1)
plt.subplot(1, 2, 2)
plt.title('Graph 2')
nx.draw_networkx(G2, pos=pos2)

These two graphs look dissimilar because they are not aligned. We then align these two graphs by graph matching.
Build affinity matrix
To match isomorphic graphs by graph matching, we follow the Quadratic Assignment Problem (QAP) formulation introduced above, where the first step is to build the affinity matrix (\(\mathbf{K}\)).
conn1, edge1 = pygm.utils.dense_to_sparse(A1)
conn2, edge2 = pygm.utils.dense_to_sparse(A2)
import functools
gaussian_aff = functools.partial(pygm.utils.gaussian_aff_fn, sigma=.1) # set affinity function
K = pygm.utils.build_aff_mat(None, edge1, conn1, None, edge2, conn2, n1, None, n2, None, edge_aff_fn=gaussian_aff)
Visualization of the affinity matrix. For a graph matching problem with \(N\) nodes, the affinity matrix has \(N^2\times N^2\) elements because there are \(N^2\) edges in each graph.
Note
The diagonal elements of the affinity matrix are empty because there are no node features in this example.
plt.figure(figsize=(4, 4))
plt.title(f'Affinity Matrix (size: {K.shape[0]}$\\times${K.shape[1]})')
plt.imshow(K.numpy(), cmap='Blues')

Solve graph matching problem by RRWM solver
See rrwm()
for the API reference.
X = pygm.rrwm(K, n1, n2)
The output of RRWM is a soft matching matrix. Visualization:
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title('RRWM Soft Matching Matrix')
plt.imshow(X.numpy(), cmap='Blues')
plt.subplot(1, 2, 2)
plt.title('Ground Truth Matching Matrix')
plt.imshow(X_gt.numpy(), cmap='Blues')

Get the discrete matching matrix
The Hungarian algorithm is then adopted to reach a discrete matching matrix:
X = pygm.hungarian(X)
Visualization of the discrete matching matrix:
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title(f'RRWM Matching Matrix (acc={(X * X_gt).sum()/ X_gt.sum():.2f})')
plt.imshow(X.numpy(), cmap='Blues')
plt.subplot(1, 2, 2)
plt.title('Ground Truth Matching Matrix')
plt.imshow(X_gt.numpy(), cmap='Blues')

Align the original graphs
Draw the matching (green lines for correct matching, red lines for wrong matching):
plt.figure(figsize=(8, 4))
ax1 = plt.subplot(1, 2, 1)
plt.title('Graph 1')
nx.draw_networkx(G1, pos=pos1)
ax2 = plt.subplot(1, 2, 2)
plt.title('Graph 2')
nx.draw_networkx(G2, pos=pos2)
for i in range(num_nodes):
    j = torch.argmax(X[i]).item()
    con = ConnectionPatch(xyA=pos1[i], xyB=pos2[j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="green" if X_gt[i, j] else "red")
    plt.gca().add_artist(con)

Align the nodes:
align_A2 = torch.mm(torch.mm(X, A2), X.t())
plt.figure(figsize=(8, 4))
ax1 = plt.subplot(1, 2, 1)
plt.title('Graph 1')
nx.draw_networkx(G1, pos=pos1)
ax2 = plt.subplot(1, 2, 2)
plt.title('Aligned Graph 2')
align_pos2 = {}
for i in range(num_nodes):
    j = torch.argmax(X[i]).item()
    align_pos2[j] = pos1[i]
    con = ConnectionPatch(xyA=pos1[i], xyB=align_pos2[j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="green" if X_gt[i, j] else "red")
    plt.gca().add_artist(con)
nx.draw_networkx(G2, pos=align_pos2)

Other solvers are also available
See ipfp()
for the API reference.
X = pygm.ipfp(K, n1, n2)
Visualization of IPFP matching result:
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title(f'IPFP Matching Matrix (acc={(X * X_gt).sum()/ X_gt.sum():.2f})')
plt.imshow(X.numpy(), cmap='Blues')
plt.subplot(1, 2, 2)
plt.title('Ground Truth Matching Matrix')
plt.imshow(X_gt.numpy(), cmap='Blues')

See sm()
for the API reference.
X = pygm.sm(K, n1, n2)
X = pygm.hungarian(X)
Visualization of SM matching result:
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title(f'SM Matching Matrix (acc={(X * X_gt).sum()/ X_gt.sum():.2f})')
plt.imshow(X.numpy(), cmap='Blues')
plt.subplot(1, 2, 2)
plt.title('Ground Truth Matching Matrix')
plt.imshow(X_gt.numpy(), cmap='Blues')

See ngm()
for the API reference.
with torch.set_grad_enabled(False):
    X = pygm.ngm(K, n1, n2, pretrain='voc')
    X = pygm.hungarian(X)
Downloading to /home/docs/.cache/pygmtools/ngm_voc_pytorch.pt...
0%| | 0/23119 [00:00<?, ?it/s]
100%|##########| 22.6k/22.6k [00:00<00:00, 1.62MB/s]
Visualization of NGM matching result:
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title(f'NGM Matching Matrix (acc={(X * X_gt).sum()/ X_gt.sum():.2f})')
plt.imshow(X.numpy(), cmap='Blues')
plt.subplot(1, 2, 2)
plt.title('Ground Truth Matching Matrix')
plt.imshow(X_gt.numpy(), cmap='Blues')

Total running time of the script: ( 0 minutes 3.455 seconds)
Note
Click here to download the full example code
Seeded Graph Matching
Seeded graph matching means that part of the matching result is already known, and the known matching results are called “seeds”. In this example, we show how to exploit such priors with pygmtools.
# Author: Runzhong Wang <runzhong.wang@sjtu.edu.cn>
#
# License: Mulan PSL v2 License
Note
How to perform seeded graph matching is still an open research problem. In this example, we show a simple yet effective approach that works with pygmtools.
Note
The following solvers are included in this example: rrwm(), ipfp(), sm(), ngm().
import torch # pytorch backend
import pygmtools as pygm
import matplotlib.pyplot as plt # for plotting
from matplotlib.patches import ConnectionPatch # for plotting matching result
import networkx as nx # for plotting graphs
pygm.BACKEND = 'pytorch' # set default backend for pygmtools
_ = torch.manual_seed(1) # fix random seed
Generate two isomorphic graphs (with seeds)
In this example, we assume the first three nodes are already aligned. Firstly, we generate the seed matching matrix:
num_nodes = 10
num_seeds = 3
seed_mat = torch.zeros(num_nodes, num_nodes)
seed_mat[:num_seeds, :num_seeds] = torch.eye(num_seeds)
Then we generate the isomorphic graphs:
X_gt = seed_mat.clone()
X_gt[num_seeds:, num_seeds:][torch.arange(0, num_nodes-num_seeds, dtype=torch.int64), torch.randperm(num_nodes-num_seeds)] = 1
A1 = torch.rand(num_nodes, num_nodes)
A1 = (A1 + A1.t() > 1.) * (A1 + A1.t()) / 2
torch.diagonal(A1)[:] = 0
A2 = torch.mm(torch.mm(X_gt.t(), A1), X_gt)
n1 = torch.tensor([num_nodes])
n2 = torch.tensor([num_nodes])
Visualize the graphs and seeds
The seed matching matrix:
plt.figure(figsize=(4, 4))
plt.title('Seed Matching Matrix')
plt.imshow(seed_mat.numpy(), cmap='Blues')

The blue lines denote the matching seeds.
plt.figure(figsize=(8, 4))
G1 = nx.from_numpy_array(A1.numpy())
G2 = nx.from_numpy_array(A2.numpy())
pos1 = nx.spring_layout(G1)
pos2 = nx.spring_layout(G2)
ax1 = plt.subplot(1, 2, 1)
plt.title('Graph 1')
nx.draw_networkx(G1, pos=pos1)
ax2 = plt.subplot(1, 2, 2)
plt.title('Graph 2')
nx.draw_networkx(G2, pos=pos2)
for i in range(num_seeds):
    j = torch.argmax(seed_mat[i]).item()
    con = ConnectionPatch(xyA=pos1[i], xyB=pos2[j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="blue")
    plt.gca().add_artist(con)

Now these two graphs look dissimilar because they are not aligned. We then align these two graphs by graph matching.
Build affinity matrix with seed prior
We follow the Quadratic Assignment Problem (QAP) formulation introduced above, where the first step is to build the affinity matrix (\(\mathbf{K}\)). We first build a “standard” affinity matrix:
conn1, edge1 = pygm.utils.dense_to_sparse(A1)
conn2, edge2 = pygm.utils.dense_to_sparse(A2)
import functools
gaussian_aff = functools.partial(pygm.utils.gaussian_aff_fn, sigma=.1) # set affinity function
K = pygm.utils.build_aff_mat(None, edge1, conn1, None, edge2, conn2, n1, None, n2, None, edge_aff_fn=gaussian_aff)
The next step is to add the seed matching information as priors to the affinity matrix. The matching priors are treated as node affinities, and the corresponding node affinity is increased by 10 if there is a matching prior.
Note
The node affinity matrix is transposed because in the graph matching formulation followed by pygmtools, \(\texttt{vec}(\mathbf{X})\) means column-wise vectorization. The node affinity should also be column-vectorized.
torch.diagonal(K)[:] += seed_mat.t().reshape(-1) * 10
Visualization of the affinity matrix.
Note
In this example, the diagonal elements reflect the matching prior.
plt.figure(figsize=(4, 4))
plt.title(f'Affinity Matrix (size: {K.shape[0]}$\\times${K.shape[1]})')
plt.imshow(K.numpy(), cmap='Blues')

Solve graph matching problem by RRWM solver
See rrwm()
for the API reference.
X = pygm.rrwm(K, n1, n2)
The output of RRWM is a soft matching matrix. The matching prior is well-preserved:
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title('RRWM Soft Matching Matrix')
plt.imshow(X.numpy(), cmap='Blues')
plt.subplot(1, 2, 2)
plt.title('Ground Truth Matching Matrix')
plt.imshow(X_gt.numpy(), cmap='Blues')

Get the discrete matching matrix
The Hungarian algorithm is then adopted to reach a discrete matching matrix:
X = pygm.hungarian(X)
Visualization of the discrete matching matrix:
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title(f'RRWM Matching Matrix (acc={(X * X_gt).sum()/ X_gt.sum():.2f})')
plt.imshow(X.numpy(), cmap='Blues')
plt.subplot(1, 2, 2)
plt.title('Ground Truth Matching Matrix')
plt.imshow(X_gt.numpy(), cmap='Blues')

Align the original graphs
Draw the matching (green lines for correct matching, red lines for wrong matching, blue lines for seed matching):
plt.figure(figsize=(8, 4))
ax1 = plt.subplot(1, 2, 1)
plt.title('Graph 1')
nx.draw_networkx(G1, pos=pos1)
ax2 = plt.subplot(1, 2, 2)
plt.title('Graph 2')
nx.draw_networkx(G2, pos=pos2)
for i in range(num_nodes):
    j = torch.argmax(X[i]).item()
    if seed_mat[i, j]:
        line_color = "blue"
    elif X_gt[i, j]:
        line_color = "green"
    else:
        line_color = "red"
    con = ConnectionPatch(xyA=pos1[i], xyB=pos2[j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color=line_color)
    plt.gca().add_artist(con)

Align the nodes:
align_A2 = torch.mm(torch.mm(X, A2), X.t())
plt.figure(figsize=(8, 4))
ax1 = plt.subplot(1, 2, 1)
plt.title('Graph 1')
nx.draw_networkx(G1, pos=pos1)
ax2 = plt.subplot(1, 2, 2)
plt.title('Aligned Graph 2')
align_pos2 = {}
for i in range(num_nodes):
    j = torch.argmax(X[i]).item()
    align_pos2[j] = pos1[i]
    if seed_mat[i, j]:
        line_color = "blue"
    elif X_gt[i, j]:
        line_color = "green"
    else:
        line_color = "red"
    con = ConnectionPatch(xyA=pos1[i], xyB=align_pos2[j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color=line_color)
    plt.gca().add_artist(con)
nx.draw_networkx(G2, pos=align_pos2)

Other solvers are also available
Only the affinity matrix is modified to encode matching priors, thus other graph matching solvers are also available to handle this seeded graph matching setting.
See ipfp()
for the API reference.
X = pygm.ipfp(K, n1, n2)
Visualization of IPFP matching result:
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title(f'IPFP Matching Matrix (acc={(X * X_gt).sum()/ X_gt.sum():.2f})')
plt.imshow(X.numpy(), cmap='Blues')
plt.subplot(1, 2, 2)
plt.title('Ground Truth Matching Matrix')
plt.imshow(X_gt.numpy(), cmap='Blues')

See sm()
for the API reference.
X = pygm.sm(K, n1, n2)
X = pygm.hungarian(X)
Visualization of SM matching result:
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title(f'SM Matching Matrix (acc={(X * X_gt).sum()/ X_gt.sum():.2f})')
plt.imshow(X.numpy(), cmap='Blues')
plt.subplot(1, 2, 2)
plt.title('Ground Truth Matching Matrix')
plt.imshow(X_gt.numpy(), cmap='Blues')

See ngm()
for the API reference.
with torch.set_grad_enabled(False):
    X = pygm.ngm(K, n1, n2, pretrain='voc')
    X = pygm.hungarian(X)
Visualization of NGM matching result:
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title(f'NGM Matching Matrix (acc={(X * X_gt).sum()/ X_gt.sum():.2f})')
plt.imshow(X.numpy(), cmap='Blues')
plt.subplot(1, 2, 2)
plt.title('Ground Truth Matching Matrix')
plt.imshow(X_gt.numpy(), cmap='Blues')

Total running time of the script: ( 0 minutes 1.651 seconds)
Note
Click here to download the full example code
Discovering Subgraphs
This example shows how to match a smaller graph to a subset of a larger graph.
# Author: Runzhong Wang <runzhong.wang@sjtu.edu.cn>
#
# License: Mulan PSL v2 License
Note
The following solvers are included in this example: rrwm(), ipfp(), sm(), ngm().
import torch # pytorch backend
import pygmtools as pygm
import matplotlib.pyplot as plt # for plotting
from matplotlib.patches import ConnectionPatch # for plotting matching result
import networkx as nx # for plotting graphs
pygm.BACKEND = 'pytorch' # set default backend for pygmtools
_ = torch.manual_seed(1) # fix random seed
Generate the larger graph
num_nodes2 = 10
A2 = torch.rand(num_nodes2, num_nodes2)
A2 = (A2 + A2.t() > 1.) * (A2 + A2.t()) / 2
torch.diagonal(A2)[:] = 0
n2 = torch.tensor([num_nodes2])
Generate the smaller graph
num_nodes1 = 5
G2 = nx.from_numpy_array(A2.numpy())
pos2 = nx.spring_layout(G2)
pos2_t = torch.tensor([pos2[_] for _ in range(num_nodes2)])
selected = [0] # build G1 as a cluster in visualization
unselected = list(range(1, num_nodes2))
while len(selected) < num_nodes1:
    dist = torch.sum(torch.sum(torch.abs(pos2_t[selected].unsqueeze(1) - pos2_t[unselected].unsqueeze(0)), dim=-1), dim=0)
    select_id = unselected[torch.argmin(dist).item()] # find the closest node from unselected
    selected.append(select_id)
    unselected.remove(select_id)
selected.sort()
A1 = A2[selected, :][:, selected]
X_gt = torch.eye(num_nodes2)[selected, :]
n1 = torch.tensor([num_nodes1])
Visualize the graphs
G1 = nx.from_numpy_array(A1.numpy())
pos1 = {_: pos2[selected[_]] for _ in range(num_nodes1)}
color1 = ['#FF5733' for _ in range(num_nodes1)]
color2 = ['#FF5733' if _ in selected else '#1f78b4' for _ in range(num_nodes2)]
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title('Subgraph 1')
plt.gca().margins(0.4)
nx.draw_networkx(G1, pos=pos1, node_color=color1)
plt.subplot(1, 2, 2)
plt.title('Graph 2')
nx.draw_networkx(G2, pos=pos2, node_color=color2)

We then show how to automatically discover the matching by graph matching.
Build affinity matrix
To match the larger graph and the smaller graph, we follow the Quadratic Assignment Problem (QAP) formulation introduced above, where the first step is to build the affinity matrix (\(\mathbf{K}\)).
conn1, edge1 = pygm.utils.dense_to_sparse(A1)
conn2, edge2 = pygm.utils.dense_to_sparse(A2)
import functools
gaussian_aff = functools.partial(pygm.utils.gaussian_aff_fn, sigma=.001) # set affinity function
K = pygm.utils.build_aff_mat(None, edge1, conn1, None, edge2, conn2, n1, None, n2, None, edge_aff_fn=gaussian_aff)
Visualization of the affinity matrix. For a graph matching problem with \(N_1\) and \(N_2\) nodes, the affinity matrix has \(N_1N_2\times N_1N_2\) elements because there are \(N_1^2\) and \(N_2^2\) edges in each graph, respectively.
Note
The diagonal elements of the affinity matrix are empty because there are no node features in this example.
plt.figure(figsize=(4, 4))
plt.title(f'Affinity Matrix (size: {K.shape[0]}$\\times${K.shape[1]})')
plt.imshow(K.numpy(), cmap='Blues')

Solve graph matching problem by RRWM solver
See rrwm()
for the API reference.
X = pygm.rrwm(K, n1, n2)
The output of RRWM is a soft matching matrix. Visualization:
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title('RRWM Soft Matching Matrix')
plt.imshow(X.numpy(), cmap='Blues')
plt.subplot(1, 2, 2)
plt.title('Ground Truth Matching Matrix')
plt.imshow(X_gt.numpy(), cmap='Blues')

Get the discrete matching matrix
The Hungarian algorithm is then adopted to reach a discrete matching matrix:
X = pygm.hungarian(X)
Visualization of the discrete matching matrix:
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title(f'RRWM Matching Matrix (acc={(X * X_gt).sum()/ X_gt.sum():.2f})')
plt.imshow(X.numpy(), cmap='Blues')
plt.subplot(1, 2, 2)
plt.title('Ground Truth Matching Matrix')
plt.imshow(X_gt.numpy(), cmap='Blues')

Match the subgraph
Draw the matching:
plt.figure(figsize=(8, 4))
plt.suptitle(f'RRWM Matching Result (acc={(X * X_gt).sum()/ X_gt.sum():.2f})')
ax1 = plt.subplot(1, 2, 1)
plt.title('Subgraph 1')
plt.gca().margins(0.4)
nx.draw_networkx(G1, pos=pos1, node_color=color1)
ax2 = plt.subplot(1, 2, 2)
plt.title('Graph 2')
nx.draw_networkx(G2, pos=pos2, node_color=color2)
for i in range(num_nodes1):
    j = torch.argmax(X[i]).item()
    con = ConnectionPatch(xyA=pos1[i], xyB=pos2[j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="green" if X_gt[i,j] == 1 else "red")
    plt.gca().add_artist(con)

Other solvers are also available
See ipfp()
for the API reference.
X = pygm.ipfp(K, n1, n2)
Visualization of IPFP matching result:
plt.figure(figsize=(8, 4))
plt.suptitle(f'IPFP Matching Result (acc={(X * X_gt).sum()/ X_gt.sum():.2f})')
ax1 = plt.subplot(1, 2, 1)
plt.title('Subgraph 1')
plt.gca().margins(0.4)
nx.draw_networkx(G1, pos=pos1, node_color=color1)
ax2 = plt.subplot(1, 2, 2)
plt.title('Graph 2')
nx.draw_networkx(G2, pos=pos2, node_color=color2)
for i in range(num_nodes1):
    j = torch.argmax(X[i]).item()
    con = ConnectionPatch(xyA=pos1[i], xyB=pos2[j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="green" if X_gt[i,j] == 1 else "red")
    plt.gca().add_artist(con)

See sm()
for the API reference.
X = pygm.sm(K, n1, n2)
X = pygm.hungarian(X)
Visualization of SM matching result:
plt.figure(figsize=(8, 4))
plt.suptitle(f'SM Matching Result (acc={(X * X_gt).sum()/ X_gt.sum():.2f})')
ax1 = plt.subplot(1, 2, 1)
plt.title('Subgraph 1')
plt.gca().margins(0.4)
nx.draw_networkx(G1, pos=pos1, node_color=color1)
ax2 = plt.subplot(1, 2, 2)
plt.title('Graph 2')
nx.draw_networkx(G2, pos=pos2, node_color=color2)
for i in range(num_nodes1):
    j = torch.argmax(X[i]).item()
    con = ConnectionPatch(xyA=pos1[i], xyB=pos2[j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="green" if X_gt[i,j] == 1 else "red")
    plt.gca().add_artist(con)

See ngm()
for the API reference.
Note
The NGM solvers are pretrained on a different problem setting, so their performance may seem inferior. To improve their performance, you may change the way of building affinity matrices, or try finetuning NGM on the new problem.
with torch.set_grad_enabled(False):
    X = pygm.ngm(K, n1, n2, pretrain='voc')
    X = pygm.hungarian(X)
Visualization of NGM matching result:
plt.figure(figsize=(8, 4))
plt.suptitle(f'NGM Matching Result (acc={(X * X_gt).sum()/ X_gt.sum():.2f})')
ax1 = plt.subplot(1, 2, 1)
plt.title('Subgraph 1')
plt.gca().margins(0.4)
nx.draw_networkx(G1, pos=pos1, node_color=color1)
ax2 = plt.subplot(1, 2, 2)
plt.title('Graph 2')
nx.draw_networkx(G2, pos=pos2, node_color=color2)
for i in range(num_nodes1):
    j = torch.argmax(X[i]).item()
    con = ConnectionPatch(xyA=pos1[i], xyB=pos2[j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="green" if X_gt[i,j] == 1 else "red")
    plt.gca().add_artist(con)

Total running time of the script: ( 0 minutes 1.396 seconds)
Note
Click here to download the full example code
Matching Image Keypoints by QAP Solvers
This example shows how to match image keypoints by graph matching solvers provided by pygmtools. These solvers follow the Quadratic Assignment Problem formulation and can generally work out of the box.
The matched images can be further processed for other downstream tasks.
# Author: Runzhong Wang <runzhong.wang@sjtu.edu.cn>
#
# License: Mulan PSL v2 License
Note
The following solvers support the QAP formulation and are included in this example: rrwm(), ipfp(), sm(), ngm().
import torch # pytorch backend
import torchvision # CV models
import pygmtools as pygm
import matplotlib.pyplot as plt # for plotting
from matplotlib.patches import ConnectionPatch # for plotting matching result
import scipy.io as sio # for loading .mat file
import scipy.spatial as spa # for Delaunay triangulation
from sklearn.decomposition import PCA as PCAdimReduc
import itertools
import numpy as np
from PIL import Image
pygm.BACKEND = 'pytorch' # set default backend for pygmtools
Load the images
Images are from the Willow Object Class dataset (this dataset is also available with the Benchmark of pygmtools; see WillowObject). The images are resized to 256x256.
obj_resize = (256, 256)
img1 = Image.open('../data/willow_duck_0001.png')
img2 = Image.open('../data/willow_duck_0002.png')
kpts1 = torch.tensor(sio.loadmat('../data/willow_duck_0001.mat')['pts_coord'])
kpts2 = torch.tensor(sio.loadmat('../data/willow_duck_0002.mat')['pts_coord'])
kpts1[0] = kpts1[0] * obj_resize[0] / img1.size[0]
kpts1[1] = kpts1[1] * obj_resize[1] / img1.size[1]
kpts2[0] = kpts2[0] * obj_resize[0] / img2.size[0]
kpts2[1] = kpts2[1] * obj_resize[1] / img2.size[1]
img1 = img1.resize(obj_resize, resample=Image.BILINEAR)
img2 = img2.resize(obj_resize, resample=Image.BILINEAR)
/home/docs/checkouts/readthedocs.org/user_builds/pygmtools/checkouts/0.3.0/examples/pytorch/plot_image_matching.py:59: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
img1 = img1.resize(obj_resize, resample=Image.BILINEAR)
/home/docs/checkouts/readthedocs.org/user_builds/pygmtools/checkouts/0.3.0/examples/pytorch/plot_image_matching.py:60: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
img2 = img2.resize(obj_resize, resample=Image.BILINEAR)
Visualize the images and keypoints
def plot_image_with_graph(img, kpt, A=None):
    plt.imshow(img)
    plt.scatter(kpt[0], kpt[1], c='w', edgecolors='k')
    if A is not None:
        for idx in torch.nonzero(A, as_tuple=False):
            plt.plot((kpt[0, idx[0]], kpt[0, idx[1]]), (kpt[1, idx[0]], kpt[1, idx[1]]), 'k-')
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title('Image 1')
plot_image_with_graph(img1, kpts1)
plt.subplot(1, 2, 2)
plt.title('Image 2')
plot_image_with_graph(img2, kpts2)

Build the graphs
Graph structures are built based on the geometric structure of the keypoint set. In this example, we use Delaunay triangulation.
def delaunay_triangulation(kpt):
    d = spa.Delaunay(kpt.numpy().transpose())
    A = torch.zeros(len(kpt[0]), len(kpt[0]))
    for simplex in d.simplices:
        for pair in itertools.permutations(simplex, 2):
            A[pair] = 1
    return A
A1 = delaunay_triangulation(kpts1)
A2 = delaunay_triangulation(kpts2)
We encode the length of edges as edge features
A1 = ((kpts1.unsqueeze(1) - kpts1.unsqueeze(2)) ** 2).sum(dim=0) * A1
A1 = (A1 / A1.max()).to(dtype=torch.float32)
A2 = ((kpts2.unsqueeze(1) - kpts2.unsqueeze(2)) ** 2).sum(dim=0) * A2
A2 = (A2 / A2.max()).to(dtype=torch.float32)
Visualize the graphs
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title('Image 1 with Graphs')
plot_image_with_graph(img1, kpts1, A1)
plt.subplot(1, 2, 2)
plt.title('Image 2 with Graphs')
plot_image_with_graph(img2, kpts2, A2)

Extract node features
Let’s adopt the VGG16 CNN model to extract node features.
vgg16_cnn = torchvision.models.vgg16_bn(True)
torch_img1 = torch.from_numpy(np.array(img1, dtype=np.float32) / 256).permute(2, 0, 1).unsqueeze(0) # shape: BxCxHxW
torch_img2 = torch.from_numpy(np.array(img2, dtype=np.float32) / 256).permute(2, 0, 1).unsqueeze(0) # shape: BxCxHxW
with torch.set_grad_enabled(False):
    feat1 = vgg16_cnn.features(torch_img1)
    feat2 = vgg16_cnn.features(torch_img2)
Downloading: "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth" to /home/docs/.cache/torch/hub/checkpoints/vgg16_bn-6c64b313.pth
100%|##########| 528M/528M [00:06<00:00, 81.2MB/s]
Normalize the features
num_features = feat1.shape[1]
def l2norm(node_feat):
    return torch.nn.functional.local_response_norm(
        node_feat, node_feat.shape[1] * 2, alpha=node_feat.shape[1] * 2, beta=0.5, k=0)
feat1 = l2norm(feat1)
feat2 = l2norm(feat2)
Up-sample the features to the original image size
feat1_upsample = torch.nn.functional.interpolate(feat1, obj_resize, mode='bilinear')
feat2_upsample = torch.nn.functional.interpolate(feat2, obj_resize, mode='bilinear')
/home/docs/checkouts/readthedocs.org/user_builds/pygmtools/envs/0.3.0/lib/python3.7/site-packages/torch/nn/functional.py:3121: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.
"See the documentation of nn.Upsample for details.".format(mode))
Visualize the extracted CNN features (dimensionality reduction via principal component analysis)
pca_dim_reduc = PCAdimReduc(n_components=3, whiten=True)
feat_dim_reduc = pca_dim_reduc.fit_transform(
np.concatenate((
feat1_upsample.permute(0, 2, 3, 1).reshape(-1, num_features).numpy(),
feat2_upsample.permute(0, 2, 3, 1).reshape(-1, num_features).numpy()
), axis=0)
)
feat_dim_reduc = feat_dim_reduc / np.max(np.abs(feat_dim_reduc), axis=0, keepdims=True) / 2 + 0.5
feat1_dim_reduc = feat_dim_reduc[:obj_resize[0] * obj_resize[1], :]
feat2_dim_reduc = feat_dim_reduc[obj_resize[0] * obj_resize[1]:, :]
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title('Image 1 with CNN features')
plot_image_with_graph(img1, kpts1, A1)
plt.imshow(feat1_dim_reduc.reshape(obj_resize[0], obj_resize[1], 3), alpha=0.5)
plt.subplot(1, 2, 2)
plt.title('Image 2 with CNN features')
plot_image_with_graph(img2, kpts2, A2)
plt.imshow(feat2_dim_reduc.reshape(obj_resize[0], obj_resize[1], 3), alpha=0.5)

Extract node features by nearest interpolation
rounded_kpts1 = torch.round(kpts1).to(dtype=torch.long)
rounded_kpts2 = torch.round(kpts2).to(dtype=torch.long)
node1 = feat1_upsample[0, :, rounded_kpts1[0], rounded_kpts1[1]].t() # shape: NxC
node2 = feat2_upsample[0, :, rounded_kpts2[0], rounded_kpts2[1]].t() # shape: NxC
Build affinity matrix
We follow the Quadratic Assignment Problem (QAP) formulation introduced above, where the first step is to build the affinity matrix (\(\mathbf{K}\)).
conn1, edge1 = pygm.utils.dense_to_sparse(A1)
conn2, edge2 = pygm.utils.dense_to_sparse(A2)
import functools
gaussian_aff = functools.partial(pygm.utils.gaussian_aff_fn, sigma=1) # set affinity function
K = pygm.utils.build_aff_mat(node1, edge1, conn1, node2, edge2, conn2, edge_aff_fn=gaussian_aff)
Visualization of the affinity matrix. For a graph matching problem with \(N\) nodes, the affinity matrix has \(N^2\times N^2\) elements because there are \(N^2\) edges in each graph.
Note
The diagonal elements are node affinities, and the off-diagonal elements are edge affinities.
plt.figure(figsize=(4, 4))
plt.title(f'Affinity Matrix (size: {K.shape[0]}$\\times${K.shape[1]})')
plt.imshow(K.numpy(), cmap='Blues')

Solve graph matching problem by RRWM solver
See rrwm()
for the API reference.
X = pygm.rrwm(K, kpts1.shape[1], kpts2.shape[1])
The output of RRWM is a soft matching matrix. The Hungarian algorithm is then adopted to reach a discrete matching matrix:
X = pygm.hungarian(X)
Plot the matching
The correct matchings are marked by green, and wrong matchings are marked by red. In this example, the nodes are ordered by their ground truth classes (i.e. the ground truth matching matrix is a diagonal matrix).
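Since the ground truth matching matrix is the identity here, the matching accuracy can also be read off the diagonal of X (a small additional check, not part of the original script):
acc = ((X * torch.eye(X.shape[0])).sum() / X.shape[0]).item()  # fraction of keypoints matched to themselves
print(f'matching accuracy: {acc:.2f}')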
plt.figure(figsize=(8, 4))
plt.suptitle('Image Matching Result by RRWM')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img1, kpts1, A1)
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img2, kpts2, A2)
for i in range(X.shape[0]):
    j = torch.argmax(X[i]).item()
    con = ConnectionPatch(xyA=kpts1[:, i], xyB=kpts2[:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    plt.gca().add_artist(con)

Solve by other solvers
We could also run a quick benchmark of other solvers on this specific problem.
See ipfp()
for the API reference.
X = pygm.ipfp(K, kpts1.shape[1], kpts2.shape[1])
plt.figure(figsize=(8, 4))
plt.suptitle('Image Matching Result by IPFP')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img1, kpts1, A1)
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img2, kpts2, A2)
for i in range(X.shape[0]):
    j = torch.argmax(X[i]).item()
    con = ConnectionPatch(xyA=kpts1[:, i], xyB=kpts2[:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    plt.gca().add_artist(con)

See sm()
for the API reference.
X = pygm.sm(K, kpts1.shape[1], kpts2.shape[1])
X = pygm.hungarian(X)
plt.figure(figsize=(8, 4))
plt.suptitle('Image Matching Result by SM')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img1, kpts1, A1)
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img2, kpts2, A2)
for i in range(X.shape[0]):
    j = torch.argmax(X[i]).item()
    con = ConnectionPatch(xyA=kpts1[:, i], xyB=kpts2[:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    plt.gca().add_artist(con)

See ngm()
for the API reference.
Note
The NGM solvers are pretrained on a different problem setting, so their performance may seem inferior. To improve their performance, you may change the way of building affinity matrices, or try finetuning NGM on the new problem.
The NGM solver pretrained on Willow dataset:
X = pygm.ngm(K, kpts1.shape[1], kpts2.shape[1], pretrain='willow')
X = pygm.hungarian(X)
plt.figure(figsize=(8, 4))
plt.suptitle('Image Matching Result by NGM (willow pretrain)')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img1, kpts1, A1)
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img2, kpts2, A2)
for i in range(X.shape[0]):
    j = torch.argmax(X[i]).item()
    con = ConnectionPatch(xyA=kpts1[:, i], xyB=kpts2[:, j], coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color="red" if i != j else "green")
    plt.gca().add_artist(con)

Downloading to /home/docs/.cache/pygmtools/ngm_willow_pytorch.pt...
The NGM solver pretrained on VOC dataset:
X = pygm.ngm(K, kpts1.shape[1], kpts2.shape[1], pretrain='voc')
X = pygm.hungarian(X)
plt.figure(figsize=(8, 4))
plt.suptitle('Image Matching Result by NGM (voc pretrain)')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img1, kpts1, A1)
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img2, kpts2, A2)
for i in range(X.shape[0]):
j = torch.argmax(X[i]).item()
con = ConnectionPatch(xyA=kpts1[:, i], xyB=kpts2[:, j], coordsA="data", coordsB="data",
axesA=ax1, axesB=ax2, color="red" if i != j else "green")
plt.gca().add_artist(con)
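The note above mentions finetuning as one way to adapt NGM to a new problem. Below is a minimal, illustrative finetuning sketch, not part of the original example; it assumes that pygm.utils.get_network() and the network= argument work for ngm() in the same way as shown for pca_gm() in the next example, and it reuses the diagonal ground-truth matching of this problem as supervision.
# Minimal finetuning sketch (illustrative). Assumptions: pygm.utils.get_network()
# also accepts pygm.ngm, and pygm.ngm takes a network= argument like the other
# neural solvers shown in this documentation.
ngm_net = pygm.utils.get_network(pygm.ngm, pretrain='willow')  # reusable network object
optim = torch.optim.Adam(ngm_net.parameters(), lr=1e-4)
X_gt = torch.eye(kpts1.shape[1])  # the ground truth matching is diagonal in this example
for _ in range(10):
    X_soft = pygm.ngm(K, kpts1.shape[1], kpts2.shape[1], network=ngm_net)
    loss = pygm.utils.permutation_loss(X_soft, X_gt)
    optim.zero_grad()
    loss.backward()
    optim.step()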

Total running time of the script: ( 0 minutes 17.204 seconds)
Matching Image Keypoints by Graph Matching Neural Networks
This example shows how to match image keypoints by neural network-based graph matching solvers. These graph matching solvers are designed to match two individual graphs. The matching results can be further used to tackle downstream tasks.
# Author: Runzhong Wang <runzhong.wang@sjtu.edu.cn>
#
# License: Mulan PSL v2 License
Note
The following solvers are based on matching two individual graphs, and are included in this example:
pca_gm() (neural network solver)
ipca_gm() (neural network solver)
cie() (neural network solver)
import torch # pytorch backend
import torchvision # CV models
import pygmtools as pygm
import matplotlib.pyplot as plt # for plotting
from matplotlib.patches import ConnectionPatch # for plotting matching result
import scipy.io as sio # for loading .mat file
import scipy.spatial as spa # for Delaunay triangulation
from sklearn.decomposition import PCA as PCAdimReduc
import itertools
import numpy as np
from PIL import Image
pygm.BACKEND = 'pytorch' # set default backend for pygmtools
Predicting Matching by Graph Matching Neural Networks
In this section we show how to do predictions (inference) by graph matching neural networks.
Let’s take PCA-GM (pca_gm()
) as an example.
Images are from the Willow Object Class dataset (this dataset is also available with the Benchmark of pygmtools
,
see WillowObject
).
The images are resized to 256x256.
obj_resize = (256, 256)
img1 = Image.open('../data/willow_duck_0001.png')
img2 = Image.open('../data/willow_duck_0002.png')
kpts1 = torch.tensor(sio.loadmat('../data/willow_duck_0001.mat')['pts_coord'])
kpts2 = torch.tensor(sio.loadmat('../data/willow_duck_0002.mat')['pts_coord'])
kpts1[0] = kpts1[0] * obj_resize[0] / img1.size[0]
kpts1[1] = kpts1[1] * obj_resize[1] / img1.size[1]
kpts2[0] = kpts2[0] * obj_resize[0] / img2.size[0]
kpts2[1] = kpts2[1] * obj_resize[1] / img2.size[1]
img1 = img1.resize(obj_resize, resample=Image.BILINEAR)
img2 = img2.resize(obj_resize, resample=Image.BILINEAR)
torch_img1 = torch.from_numpy(np.array(img1, dtype=np.float32) / 256).permute(2, 0, 1).unsqueeze(0) # shape: BxCxHxW
torch_img2 = torch.from_numpy(np.array(img2, dtype=np.float32) / 256).permute(2, 0, 1).unsqueeze(0) # shape: BxCxHxW
Visualize the images and keypoints
def plot_image_with_graph(img, kpt, A=None):
plt.imshow(img)
plt.scatter(kpt[0], kpt[1], c='w', edgecolors='k')
if A is not None:
for idx in torch.nonzero(A, as_tuple=False):
plt.plot((kpt[0, idx[0]], kpt[0, idx[1]]), (kpt[1, idx[0]], kpt[1, idx[1]]), 'k-')
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title('Image 1')
plot_image_with_graph(img1, kpts1)
plt.subplot(1, 2, 2)
plt.title('Image 2')
plot_image_with_graph(img2, kpts2)

Graph structures are built based on the geometric structure of the keypoint set. In this example, we use Delaunay triangulation.
def delaunay_triangulation(kpt):
d = spa.Delaunay(kpt.numpy().transpose())
A = torch.zeros(len(kpt[0]), len(kpt[0]))
for simplex in d.simplices:
for pair in itertools.permutations(simplex, 2):
A[pair] = 1
return A
A1 = delaunay_triangulation(kpts1)
A2 = delaunay_triangulation(kpts2)
Visualize the graphs
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title('Image 1 with Graphs')
plot_image_with_graph(img1, kpts1, A1)
plt.subplot(1, 2, 2)
plt.title('Image 2 with Graphs')
plot_image_with_graph(img2, kpts2, A2)

Deep graph matching solvers can be fused with CNN feature extractors to build an end-to-end learning pipeline.
In this example, we adopt the deep graph matching solvers based on matching two individual graphs.
The image features are taken from two intermediate layers of the VGG16 CNN model, following
existing deep graph matching papers (such as pca_gm()
).
Let’s firstly fetch and download the VGG16 model:
vgg16_cnn = torchvision.models.vgg16_bn(True)
List of layers of VGG16:
print(vgg16_cnn.features)
Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(9): ReLU(inplace=True)
(10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(12): ReLU(inplace=True)
(13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(16): ReLU(inplace=True)
(17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(19): ReLU(inplace=True)
(20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(22): ReLU(inplace=True)
(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(24): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(25): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(26): ReLU(inplace=True)
(27): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(29): ReLU(inplace=True)
(30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(32): ReLU(inplace=True)
(33): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(35): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(36): ReLU(inplace=True)
(37): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(38): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(39): ReLU(inplace=True)
(40): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(42): ReLU(inplace=True)
(43): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
Let’s define the CNN feature extractor, which outputs the features of layer (30)
and
layer (37).
class CNNNet(torch.nn.Module):
def __init__(self, vgg16_module):
super(CNNNet, self).__init__()
# The naming of the layers follows the ThinkMatch convention to load pretrained models.
self.node_layers = torch.nn.Sequential(*[_ for _ in vgg16_module.features[:31]])
self.edge_layers = torch.nn.Sequential(*[_ for _ in vgg16_module.features[31:38]])
def forward(self, inp_img):
feat_local = self.node_layers(inp_img)
feat_global = self.edge_layers(feat_local)
return feat_local, feat_global
Download pretrained CNN weights (from ThinkMatch), load the weights and then extract the CNN features
cnn = CNNNet(vgg16_cnn)
path = pygm.utils.download('vgg16_pca_voc_pytorch.pt', 'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1JnX3cSPvRYBSrDKVwByzp7CADgVCJCO_')
if torch.cuda.is_available():
map_location = torch.device('cuda:0')
else:
map_location = torch.device('cpu')
cnn.load_state_dict(torch.load(path, map_location=map_location), strict=False)
with torch.set_grad_enabled(False):
feat1_local, feat1_global = cnn(torch_img1)
feat2_local, feat2_global = cnn(torch_img2)
Downloading to /home/docs/.cache/pygmtools/vgg16_pca_voc_pytorch.pt...
Normalize the features
def l2norm(node_feat):
return torch.nn.functional.local_response_norm(
node_feat, node_feat.shape[1] * 2, alpha=node_feat.shape[1] * 2, beta=0.5, k=0)
feat1_local = l2norm(feat1_local)
feat1_global = l2norm(feat1_global)
feat2_local = l2norm(feat2_local)
feat2_global = l2norm(feat2_global)
Up-sample the features to the original image size and concatenate
feat1_local_upsample = torch.nn.functional.interpolate(feat1_local, obj_resize, mode='bilinear')
feat1_global_upsample = torch.nn.functional.interpolate(feat1_global, obj_resize, mode='bilinear')
feat2_local_upsample = torch.nn.functional.interpolate(feat2_local, obj_resize, mode='bilinear')
feat2_global_upsample = torch.nn.functional.interpolate(feat2_global, obj_resize, mode='bilinear')
feat1_upsample = torch.cat((feat1_local_upsample, feat1_global_upsample), dim=1)
feat2_upsample = torch.cat((feat2_local_upsample, feat2_global_upsample), dim=1)
num_features = feat1_upsample.shape[1]
Visualize the extracted CNN feature (dimensionality reduction via principal component analysis)
pca_dim_reduc = PCAdimReduc(n_components=3, whiten=True)
feat_dim_reduc = pca_dim_reduc.fit_transform(
np.concatenate((
feat1_upsample.permute(0, 2, 3, 1).reshape(-1, num_features).numpy(),
feat2_upsample.permute(0, 2, 3, 1).reshape(-1, num_features).numpy()
), axis=0)
)
feat_dim_reduc = feat_dim_reduc / np.max(np.abs(feat_dim_reduc), axis=0, keepdims=True) / 2 + 0.5
feat1_dim_reduc = feat_dim_reduc[:obj_resize[0] * obj_resize[1], :]
feat2_dim_reduc = feat_dim_reduc[obj_resize[0] * obj_resize[1]:, :]
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title('Image 1 with CNN features')
plot_image_with_graph(img1, kpts1, A1)
plt.imshow(feat1_dim_reduc.reshape(obj_resize[0], obj_resize[1], 3), alpha=0.5)
plt.subplot(1, 2, 2)
plt.title('Image 2 with CNN features')
plot_image_with_graph(img2, kpts2, A2)
plt.imshow(feat2_dim_reduc.reshape(obj_resize[0], obj_resize[1], 3), alpha=0.5)

Extract node features by nearest interpolation
rounded_kpts1 = torch.round(kpts1).to(dtype=torch.long)
rounded_kpts2 = torch.round(kpts2).to(dtype=torch.long)
node1 = feat1_upsample[0, :, rounded_kpts1[0], rounded_kpts1[1]].t() # shape: NxC
node2 = feat2_upsample[0, :, rounded_kpts2[0], rounded_kpts2[1]].t() # shape: NxC
See pca_gm()
for the API reference.
X = pygm.pca_gm(node1, node2, A1, A2, pretrain='voc')
X = pygm.hungarian(X)
plt.figure(figsize=(8, 4))
plt.suptitle('Image Matching Result by PCA-GM')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img1, kpts1, A1)
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img2, kpts2, A2)
for i in range(X.shape[0]):
j = torch.argmax(X[i]).item()
con = ConnectionPatch(xyA=kpts1[:, i], xyB=kpts2[:, j], coordsA="data", coordsB="data",
axesA=ax1, axesB=ax2, color="red" if i != j else "green")
plt.gca().add_artist(con)

Downloading to /home/docs/.cache/pygmtools/pca_gm_voc_pytorch.pt...
Matching images with other neural networks
The above pipeline also works for other deep graph matching networks. Here we give examples of
ipca_gm()
and cie()
.
See ipca_gm()
for the API reference.
path = pygm.utils.download('vgg16_ipca_voc_pytorch.pt', 'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1TGrbSQRmUkClH3Alz2OCwqjl8r8gf5yI')
cnn.load_state_dict(torch.load(path, map_location=map_location), strict=False)
with torch.set_grad_enabled(False):
feat1_local, feat1_global = cnn(torch_img1)
feat2_local, feat2_global = cnn(torch_img2)
Downloading to /home/docs/.cache/pygmtools/vgg16_ipca_voc_pytorch.pt...
Normalize the features
def l2norm(node_feat):
return torch.nn.functional.local_response_norm(
node_feat, node_feat.shape[1] * 2, alpha=node_feat.shape[1] * 2, beta=0.5, k=0)
feat1_local = l2norm(feat1_local)
feat1_global = l2norm(feat1_global)
feat2_local = l2norm(feat2_local)
feat2_global = l2norm(feat2_global)
Up-sample the features to the original image size and concatenate
feat1_local_upsample = torch.nn.functional.interpolate(feat1_local, obj_resize, mode='bilinear')
feat1_global_upsample = torch.nn.functional.interpolate(feat1_global, obj_resize, mode='bilinear')
feat2_local_upsample = torch.nn.functional.interpolate(feat2_local, obj_resize, mode='bilinear')
feat2_global_upsample = torch.nn.functional.interpolate(feat2_global, obj_resize, mode='bilinear')
feat1_upsample = torch.cat((feat1_local_upsample, feat1_global_upsample), dim=1)
feat2_upsample = torch.cat((feat2_local_upsample, feat2_global_upsample), dim=1)
num_features = feat1_upsample.shape[1]
Extract node features by nearest interpolation
rounded_kpts1 = torch.round(kpts1).to(dtype=torch.long)
rounded_kpts2 = torch.round(kpts2).to(dtype=torch.long)
node1 = feat1_upsample[0, :, rounded_kpts1[0], rounded_kpts1[1]].t() # shape: NxC
node2 = feat2_upsample[0, :, rounded_kpts2[0], rounded_kpts2[1]].t() # shape: NxC
Build edge features as edge lengths
# kpts have shape 2xN, so pairwise differences are taken along the node dimension
kpts1_dis = (kpts1.unsqueeze(1) - kpts1.unsqueeze(2))
kpts1_dis = torch.norm(kpts1_dis, p=2, dim=0).detach()
kpts2_dis = (kpts2.unsqueeze(1) - kpts2.unsqueeze(2))
kpts2_dis = torch.norm(kpts2_dis, p=2, dim=0).detach()
Q1 = torch.exp(-kpts1_dis / obj_resize[0])
Q2 = torch.exp(-kpts2_dis / obj_resize[0])
Matching by IPCA-GM model
X = pygm.ipca_gm(node1, node2, A1, A2, pretrain='voc')
X = pygm.hungarian(X)
plt.figure(figsize=(8, 4))
plt.suptitle('Image Matching Result by IPCA-GM')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img1, kpts1, A1)
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img2, kpts2, A2)
for i in range(X.shape[0]):
j = torch.argmax(X[i]).item()
con = ConnectionPatch(xyA=kpts1[:, i], xyB=kpts2[:, j], coordsA="data", coordsB="data",
axesA=ax1, axesB=ax2, color="red" if i != j else "green")
plt.gca().add_artist(con)

Downloading to /home/docs/.cache/pygmtools/ipca_gm_voc_pytorch.pt...
See cie()
for the API reference.
path = pygm.utils.download('vgg16_cie_voc_pytorch.pt', 'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1oRwcnw06t1rCbrIN_7p8TJZY-XkBOFEp')
cnn.load_state_dict(torch.load(path, map_location=map_location), strict=False)
with torch.set_grad_enabled(False):
feat1_local, feat1_global = cnn(torch_img1)
feat2_local, feat2_global = cnn(torch_img2)
Downloading to /home/docs/.cache/pygmtools/vgg16_cie_voc_pytorch.pt...
Normalize the features
def l2norm(node_feat):
return torch.nn.functional.local_response_norm(
node_feat, node_feat.shape[1] * 2, alpha=node_feat.shape[1] * 2, beta=0.5, k=0)
feat1_local = l2norm(feat1_local)
feat1_global = l2norm(feat1_global)
feat2_local = l2norm(feat2_local)
feat2_global = l2norm(feat2_global)
Up-sample the features to the original image size and concatenate
feat1_local_upsample = torch.nn.functional.interpolate(feat1_local, obj_resize, mode='bilinear')
feat1_global_upsample = torch.nn.functional.interpolate(feat1_global, obj_resize, mode='bilinear')
feat2_local_upsample = torch.nn.functional.interpolate(feat2_local, obj_resize, mode='bilinear')
feat2_global_upsample = torch.nn.functional.interpolate(feat2_global, obj_resize, mode='bilinear')
feat1_upsample = torch.cat((feat1_local_upsample, feat1_global_upsample), dim=1)
feat2_upsample = torch.cat((feat2_local_upsample, feat2_global_upsample), dim=1)
num_features = feat1_upsample.shape[1]
Extract node features by nearest interpolation
rounded_kpts1 = torch.round(kpts1).to(dtype=torch.long)
rounded_kpts2 = torch.round(kpts2).to(dtype=torch.long)
node1 = feat1_upsample[0, :, rounded_kpts1[0], rounded_kpts1[1]].t() # shape: NxC
node2 = feat2_upsample[0, :, rounded_kpts2[0], rounded_kpts2[1]].t() # shape: NxC
Build edge features as edge lengths
kpts1_dis = (kpts1.unsqueeze(1) - kpts1.unsqueeze(2))
kpts1_dis = torch.norm(kpts1_dis, p=2, dim=0).detach()
kpts2_dis = (kpts2.unsqueeze(1) - kpts2.unsqueeze(2))
kpts2_dis = torch.norm(kpts2_dis, p=2, dim=0).detach()
Q1 = torch.exp(-kpts1_dis / obj_resize[0]).unsqueeze(-1).to(dtype=torch.float32)
Q2 = torch.exp(-kpts2_dis / obj_resize[0]).unsqueeze(-1).to(dtype=torch.float32)
Call CIE matching model
X = pygm.cie(node1, node2, A1, A2, Q1, Q2, pretrain='voc')
X = pygm.hungarian(X)
plt.figure(figsize=(8, 4))
plt.suptitle('Image Matching Result by CIE')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img1, kpts1, A1)
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img2, kpts2, A2)
for i in range(X.shape[0]):
j = torch.argmax(X[i]).item()
con = ConnectionPatch(xyA=kpts1[:, i], xyB=kpts2[:, j], coordsA="data", coordsB="data",
axesA=ax1, axesB=ax2, color="red" if i != j else "green")
plt.gca().add_artist(con)

Downloading to /home/docs/.cache/pygmtools/cie_voc_pytorch.pt...
Training a deep graph matching model
In this section, we show how to build a deep graph matching model which supports end-to-end training. For the image matching problem considered here, the model is composed of a CNN feature extractor and a learnable matching module. Take the PCA-GM model as an example.
Note
This simple example is intended to show you how to do the basic forward and backward pass when training an end-to-end deep graph matching neural network. A ‘more formal’ deep learning pipeline should involve an asynchronous data loader, batched operations, CUDA support and so on, which are all omitted here for simplicity. You may refer to ThinkMatch, which is a research protocol with all these advanced features.
Let’s first define the neural network model. The matching network object of pca_gm()
is obtained via pygm.utils.get_network()
, and it is reused in every forward pass.
class GMNet(torch.nn.Module):
def __init__(self):
super(GMNet, self).__init__()
self.gm_net = pygm.utils.get_network(pygm.pca_gm, pretrain=False) # fetch the network object
self.cnn = CNNNet(vgg16_cnn)
def forward(self, img1, img2, kpts1, kpts2, A1, A2):
# CNN feature extractor layers
feat1_local, feat1_global = self.cnn(img1)
feat2_local, feat2_global = self.cnn(img2)
feat1_local = l2norm(feat1_local)
feat1_global = l2norm(feat1_global)
feat2_local = l2norm(feat2_local)
feat2_global = l2norm(feat2_global)
# upsample feature map
feat1_local_upsample = torch.nn.functional.interpolate(feat1_local, obj_resize, mode='bilinear')
feat1_global_upsample = torch.nn.functional.interpolate(feat1_global, obj_resize, mode='bilinear')
feat2_local_upsample = torch.nn.functional.interpolate(feat2_local, obj_resize, mode='bilinear')
feat2_global_upsample = torch.nn.functional.interpolate(feat2_global, obj_resize, mode='bilinear')
feat1_upsample = torch.cat((feat1_local_upsample, feat1_global_upsample), dim=1)
feat2_upsample = torch.cat((feat2_local_upsample, feat2_global_upsample), dim=1)
# assign node features
rounded_kpts1 = torch.round(kpts1).to(dtype=torch.long)
rounded_kpts2 = torch.round(kpts2).to(dtype=torch.long)
node1 = feat1_upsample[0, :, rounded_kpts1[0], rounded_kpts1[1]].t() # shape: NxC
node2 = feat2_upsample[0, :, rounded_kpts2[0], rounded_kpts2[1]].t() # shape: NxC
# PCA-GM matching layers
X = pygm.pca_gm(node1, node2, A1, A2, network=self.gm_net) # the network object is reused
return X
model = GMNet()
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
X = model(torch_img1, torch_img2, kpts1, kpts2, A1, A2)
In this example, the ground truth matching matrix is a diagonal matrix. We calculate the loss function via
permutation_loss()
X_gt = torch.eye(X.shape[0])
loss = pygm.utils.permutation_loss(X, X_gt)
print(f'loss={loss:.4f}')
loss=3.0041
loss.backward()
Visualize the gradients
plt.figure(figsize=(4, 4))
plt.title('Gradient Sizes of PCA-GM and VGG16 layers')
plt.gca().set_xlabel('Layer Index')
plt.gca().set_ylabel('Average Gradient Size')
grad_size = []
for param in model.parameters():
grad_size.append(torch.abs(param.grad).mean().item())
print(grad_size)
plt.stem(grad_size)

[0.00010975256009260193, 0.0025820708833634853, 0.0001791776594473049, 0.003091256134212017, 0.00019184016855433583, 0.004451952874660492, 9.98683572106529e-06, 4.676436947192997e-05, 9.653159213485196e-05, 0.0037215095944702625, 0.00012330709432717413, 0.0025161956436932087, 0.0004079950740560889, 9.495846775564587e-09, 0.0008160851430147886, 0.0006923370528966188, 0.00016938714543357491, 8.247080351964087e-09, 0.0018779300153255463, 0.0011980452109128237, 0.00021563969494309276, 2.068021531798081e-09, 0.001510141883045435, 0.0009055724949575961, 0.0001880597264971584, 3.660260761151335e-09, 0.0017704766942188144, 0.0009336725925095379, 0.0001807412481866777, 1.0506436831647648e-09, 0.0014650675002485514, 0.001176448306068778, 0.00015496666310355067, 1.5817932519368583e-09, 0.001682713278569281, 0.0012110973475500941, 0.00016743903688620776, 1.7945213093284451e-09, 0.0018167807720601559, 0.0009177665924653411, 0.0001519177749287337, 4.470540881928997e-10, 0.0016938832122832537, 0.001109698903746903, 0.00011667542275972664, 6.756070702884642e-10, 0.0018534527625888586, 0.001221839222125709, 0.00011477413499960676, 0.0004964273539371789, 0.0017710058018565178, 0.0009307056316174567, 9.207324183080345e-05, 2.3413559979701404e-10, 0.001390282646752894, 0.0009812040952965617, 8.179757423931733e-05, 0.0007362457108683884]
Update the model parameters. A deep learning pipeline iterates the forward and backward passes until convergence (a minimal loop sketch is given after the optimizer step below).
optim.step()
optim.zero_grad()
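As an illustration of iterating these steps, one could wrap the forward and backward passes into a short loop over the same image pair; this is only a sketch, and a real pipeline would loop over a dataset instead.
# Minimal training-loop sketch (illustrative): repeat forward and backward passes
# on the same image pair defined above; a real pipeline iterates over a dataset.
for it in range(5):
    X = model(torch_img1, torch_img2, kpts1, kpts2, A1, A2)
    loss = pygm.utils.permutation_loss(X, X_gt)
    optim.zero_grad()
    loss.backward()
    optim.step()
    print(f'iter {it}: loss={loss.item():.4f}')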
Note
This example supports both GPU and CPU, and the online documentation is built by a CPU-only machine. The efficiency will be significantly improved if you run this code on GPU.
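For reference, a hedged sketch of how one might move this example onto a GPU (assuming a CUDA device is available; the model and tensor names are the ones defined above, and the solvers are assumed to follow the device of their input tensors):
# Illustrative GPU usage: move the model and all input tensors to the CUDA device
# before the forward pass.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    model = model.to(device)
    X = model(torch_img1.to(device), torch_img2.to(device),
              kpts1.to(device), kpts2.to(device), A1.to(device), A2.to(device))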
Total running time of the script: ( 1 minutes 2.536 seconds)
Model Fusion by Graph Matching
This example shows how to fuse different models into a single model by pygmtools
.
Model fusion aims to fuse multiple models into one, such that the fused model has higher performance.
A neural network can be regarded as a graph (channels - nodes, update functions between channels - edges;
node features - biases, edge features - weights), and fusing the models is equivalent to solving a graph matching
problem. In this example, the given models are trained on MNIST data from different distributions, and the
fused model combines the knowledge of the two input models and reaches higher accuracy at test time.
# Author: Chang Liu <only-changer@sjtu.edu.cn>
# Runzhong Wang <runzhong.wang@sjtu.edu.cn>
#
# License: Mulan PSL v2 License
Note
This is a simplified implementation of the ideas in Liu et al. Deep Neural Network Fusion via Graph Matching with Applications to Model Ensemble and Federated Learning. ICML 2022. For more details, please refer to the paper and the official code repository.
Note
The following solvers are included in this example:
sm() (classic solver)
hungarian() (linear solver)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import time
from PIL import Image
import matplotlib.pyplot as plt
import pygmtools as pygm
pygm.BACKEND = 'pytorch'
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
Define a simple CNN classifier network
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 5, padding=1, padding_mode='replicate', bias=False)
self.max_pool = nn.MaxPool2d(2, padding=1)
self.conv2 = nn.Conv2d(32, 64, 5, padding=1, padding_mode='replicate', bias=False)
self.fc1 = nn.Linear(3136, 32, bias=False)
self.fc2 = nn.Linear(32, 10, bias=False)
def forward(self, x):
output = F.relu(self.conv1(x))
output = self.max_pool(output)
output = F.relu(self.conv2(output))
output = self.max_pool(output)
output = output.view(output.shape[0], -1)
output = self.fc1(output)
output = self.fc2(output)
return output
Load the trained models to be fused
model1 = SimpleNet()
model2 = SimpleNet()
model1.load_state_dict(torch.load('../data/example_model_fusion_1.dat', map_location=device))
model2.load_state_dict(torch.load('../data/example_model_fusion_2.dat', map_location=device))
model1.to(device)
model2.to(device)
test_dataset = torchvision.datasets.MNIST(
root='../data/mnist_data', # the directory to store the dataset
train=False, # the dataset is used to test
transform=transforms.ToTensor(), # the dataset is in the form of tensors
download=True)
test_loader = torch.utils.data.DataLoader(
dataset=test_dataset,
batch_size=32,
shuffle=False)
Using downloaded and verified file: ../data/mnist_data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting ../data/mnist_data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/mnist_data/MNIST/raw
Using downloaded and verified file: ../data/mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ../data/mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/mnist_data/MNIST/raw
Using downloaded and verified file: ../data/mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ../data/mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/mnist_data/MNIST/raw
Using downloaded and verified file: ../data/mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ../data/mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/mnist_data/MNIST/raw
Processing...
Done!
Print the layers of the simple CNN model:
print(model1)
SimpleNet(
(conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1), bias=False, padding_mode=replicate)
(max_pool): MaxPool2d(kernel_size=2, stride=2, padding=1, dilation=1, ceil_mode=False)
(conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1), bias=False, padding_mode=replicate)
(fc1): Linear(in_features=3136, out_features=32, bias=False)
(fc2): Linear(in_features=32, out_features=10, bias=False)
)
Test the input models
with torch.no_grad():
n_correct1 = 0
n_correct2 = 0
n_samples = 0
for images, labels in test_loader:
images = images.to(device)
labels = labels.to(device)
outputs1 = model1(images)
outputs2 = model2(images)
_, predictions1 = torch.max(outputs1, 1)
_, predictions2 = torch.max(outputs2, 1)
n_samples += labels.shape[0]
n_correct1 += (predictions1 == labels).sum().item()
n_correct2 += (predictions2 == labels).sum().item()
acc1 = 100 * n_correct1 / n_samples
acc2 = 100 * n_correct2 / n_samples
Testing results (two separate models):
print(f'model1 accuracy = {acc1}%, model2 accuracy = {acc2}%')
model1 accuracy = 84.18%, model2 accuracy = 83.81%
Build the affinity matrix for graph matching
As shown in the following plot, the neural networks can be regarded as graphs. The weights correspond to the edge features, and the biases correspond to the node features. In this example, the neural network does not have bias terms, so there are only edge features.
plt.figure(figsize=(8, 4))
img = Image.open('../data/model_fusion.png')
plt.imshow(img)
plt.axis('off')
st_time = time.perf_counter()

Define the graph matching affinity metric function
class Ground_Metric_GM:
def __init__(self,
model_1_param: torch.tensor = None,
model_2_param: torch.tensor = None,
conv_param: bool = False,
bias_param: bool = False,
pre_conv_param: bool = False,
pre_conv_image_size_squared: int = None):
self.model_1_param = model_1_param
self.model_2_param = model_2_param
self.conv_param = conv_param
self.bias_param = bias_param
# bias, or fully-connected from linear
if bias_param is True or (conv_param is False and pre_conv_param is False):
self.model_1_param = self.model_1_param.reshape(1, -1, 1)
self.model_2_param = self.model_2_param.reshape(1, -1, 1)
# fully-connected from conv
elif conv_param is False and pre_conv_param is True:
self.model_1_param = self.model_1_param.reshape(1, -1, pre_conv_image_size_squared)
self.model_2_param = self.model_2_param.reshape(1, -1, pre_conv_image_size_squared)
# conv
else:
self.model_1_param = self.model_1_param.reshape(1, -1, model_1_param.shape[-1])
self.model_2_param = self.model_2_param.reshape(1, -1, model_2_param.shape[-1])
def process_distance(self, p: int = 2):
return torch.cdist(
self.model_1_param.to(torch.float),
self.model_2_param.to(torch.float),
p=p)[0]
def process_soft_affinity(self, p: int = 2):
return torch.exp(0 - self.process_distance(p=p))
Define the affinity function between two neural networks. This function takes multiple neural network modules, and constructs the corresponding affinity matrix which is further processed by the graph matching solver.
def graph_matching_fusion(networks: list):
def total_node_num(network: torch.nn.Module):
# count the total number of nodes in the given network
num_nodes = 0
for idx, (name, parameters) in enumerate(network.named_parameters()):
if 'bias' in name:
continue
if idx == 0:
num_nodes += parameters.shape[1]
num_nodes += parameters.shape[0]
return num_nodes
n1 = total_node_num(network=networks[0])
n2 = total_node_num(network=networks[1])
assert (n1 == n2)
affinity = torch.zeros([n1 * n2, n1 * n2], device=device)
num_layers = len(list(zip(networks[0].parameters(), networks[1].parameters())))
num_nodes_before = 0
num_nodes_incremental = []
num_nodes_layers = []
pre_conv_list = []
cur_conv_list = []
conv_kernel_size_list = []
num_nodes_pre = 0
is_conv = False
pre_conv = False
pre_conv_out_channel = 1
is_final_bias = False
perm_is_complete = True
named_weight_list_0 = [named_parameter for named_parameter in networks[0].named_parameters()]
for idx, ((_, fc_layer0_weight), (_, fc_layer1_weight)) in \
enumerate(zip(networks[0].named_parameters(), networks[1].named_parameters())):
assert fc_layer0_weight.shape == fc_layer1_weight.shape
layer_shape = fc_layer0_weight.shape
num_nodes_cur = fc_layer0_weight.shape[0]
if len(layer_shape) > 1:
if is_conv is True and len(layer_shape) == 2:
num_nodes_pre = pre_conv_out_channel
else:
num_nodes_pre = fc_layer0_weight.shape[1]
if idx >= 1 and len(named_weight_list_0[idx - 1][1].shape) == 1:
pre_bias = True
else:
pre_bias = False
if len(layer_shape) > 2:
is_bias = False
if not pre_bias:
pre_conv = is_conv
pre_conv_list.append(pre_conv)
is_conv = True
cur_conv_list.append(is_conv)
fc_layer0_weight_data = fc_layer0_weight.data.view(
fc_layer0_weight.shape[0], fc_layer0_weight.shape[1], -1)
fc_layer1_weight_data = fc_layer1_weight.data.view(
fc_layer1_weight.shape[0], fc_layer1_weight.shape[1], -1)
elif len(layer_shape) == 2:
is_bias = False
if not pre_bias:
pre_conv = is_conv
pre_conv_list.append(pre_conv)
is_conv = False
cur_conv_list.append(is_conv)
fc_layer0_weight_data = fc_layer0_weight.data
fc_layer1_weight_data = fc_layer1_weight.data
else:
is_bias = True
if not pre_bias:
pre_conv = is_conv
pre_conv_list.append(pre_conv)
is_conv = False
cur_conv_list.append(is_conv)
fc_layer0_weight_data = fc_layer0_weight.data
fc_layer1_weight_data = fc_layer1_weight.data
if is_conv:
pre_conv_out_channel = num_nodes_cur
if is_bias is True and idx == num_layers - 1:
is_final_bias = True
if idx == 0:
for a in range(num_nodes_pre):
affinity[(num_nodes_before + a) * n2 + num_nodes_before + a] \
[(num_nodes_before + a) * n2 + num_nodes_before + a] \
= 1
if idx == num_layers - 2 and 'bias' in named_weight_list_0[idx + 1][0] or \
idx == num_layers - 1 and 'bias' not in named_weight_list_0[idx][0]:
for a in range(num_nodes_cur):
affinity[(num_nodes_before + num_nodes_pre + a) * n2 + num_nodes_before + num_nodes_pre + a] \
[(num_nodes_before + num_nodes_pre + a) * n2 + num_nodes_before + num_nodes_pre + a] \
= 1
if is_bias is False:
ground_metric = Ground_Metric_GM(
fc_layer0_weight_data, fc_layer1_weight_data, is_conv, is_bias,
pre_conv, int(fc_layer0_weight_data.shape[1] / pre_conv_out_channel))
else:
ground_metric = Ground_Metric_GM(
fc_layer0_weight_data, fc_layer1_weight_data, is_conv, is_bias,
pre_conv, 1)
layer_affinity = ground_metric.process_soft_affinity(p=2)
if is_bias is False:
pre_conv_kernel_size = fc_layer0_weight.shape[3] if is_conv else None
conv_kernel_size_list.append(pre_conv_kernel_size)
if is_bias is True and is_final_bias is False:
for a in range(num_nodes_cur):
for c in range(num_nodes_cur):
affinity[(num_nodes_before + a) * n2 + num_nodes_before + c] \
[(num_nodes_before + a) * n2 + num_nodes_before + c] \
= layer_affinity[a][c]
elif is_final_bias is False:
for a in range(num_nodes_pre):
for b in range(num_nodes_cur):
affinity[
(num_nodes_before + a) * n2 + num_nodes_before:
(num_nodes_before + a) * n2 + num_nodes_before + num_nodes_pre,
(num_nodes_before + num_nodes_pre + b) * n2 + num_nodes_before + num_nodes_pre:
(num_nodes_before + num_nodes_pre + b) * n2 + num_nodes_before + num_nodes_pre + num_nodes_cur] \
= layer_affinity[a + b * num_nodes_pre].view(num_nodes_cur, num_nodes_pre).transpose(0, 1)
if is_bias is False:
num_nodes_before += num_nodes_pre
num_nodes_incremental.append(num_nodes_before)
num_nodes_layers.append(num_nodes_cur)
# affinity = (affinity + affinity.t()) / 2
return affinity, [n1, n2, num_nodes_incremental, num_nodes_layers, cur_conv_list, conv_kernel_size_list]
Get the affinity (similarity) matrix between model1 and model2.
K, params = graph_matching_fusion([model1, model2])
Align the models by graph matching
Align the channels of model1 & model2 by maximizing the affinity (similarity) via graph matching algorithms.
n1 = params[0]
n2 = params[1]
X = pygm.sm(K, n1, n2)
Project X
to the neural network matching result. The neural network matching matrix is built by applying
the Hungarian algorithm to small blocks of X
, because only the channels from the same neural network layer can be
matched. A small sanity check of the projected result is given after the code below.
Note
In this example, we assume the last FC layer is aligned and need not be matched.
new_X = torch.zeros_like(X)
new_X[:params[2][0], :params[2][0]] = torch.eye(params[2][0], device=device)
for start_idx, length in zip(params[2][:-1], params[3][:-1]): # params[2]: start index of each layer's block; params[3]: number of channels in each layer
slicing = slice(start_idx, start_idx + length)
new_X[slicing, slicing] = pygm.hungarian(X[slicing, slicing])
# assume the last FC layer is aligned
slicing = slice(params[2][-1], params[2][-1] + params[3][-1])
new_X[slicing, slicing] = torch.eye(params[3][-1], device=device)
X = new_X
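As an optional, illustrative sanity check (not part of the original example), each per-layer block of the projected X should now be a permutation matrix, i.e. every row and column of the block sums to one:
# Illustrative check: params[2] gives the start index of each layer's block and
# params[3] its number of channels; each such block of X should be a permutation matrix.
for start_idx, length in zip(params[2], params[3]):
    block = X[start_idx:start_idx + length, start_idx:start_idx + length]
    assert torch.all(block.sum(dim=0) == 1)
    assert torch.all(block.sum(dim=1) == 1)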
Visualization of the matching result. The black lines split the channels of different layers.
plt.figure(figsize=(4, 4))
plt.imshow(X.cpu().numpy(), cmap='Blues')
for idx in params[2]:
plt.axvline(x=idx, color='k')
plt.axhline(y=idx, color='k')

Define the alignment function: fuse the models based on the matching result
def align(solution, fusion_proportion, networks: list, params: list):
[_, _, num_nodes_incremental, num_nodes_layers, cur_conv_list, conv_kernel_size_list] = params
named_weight_list_0 = [named_parameter for named_parameter in networks[0].named_parameters()]
aligned_wt_0 = [parameter.data for name, parameter in named_weight_list_0]
idx = 0
num_layers = len(aligned_wt_0)
for num_before, num_cur, cur_conv, cur_kernel_size in \
zip(num_nodes_incremental, num_nodes_layers, cur_conv_list, conv_kernel_size_list):
perm = solution[num_before:num_before + num_cur, num_before:num_before + num_cur]
assert 'bias' not in named_weight_list_0[idx][0]
if len(named_weight_list_0[idx][1].shape) == 4:
aligned_wt_0[idx] = (perm.transpose(0, 1).to(torch.float64) @
aligned_wt_0[idx].to(torch.float64).permute(2, 3, 0, 1)) \
.permute(2, 3, 0, 1)
else:
aligned_wt_0[idx] = perm.transpose(0, 1).to(torch.float64) @ aligned_wt_0[idx].to(torch.float64)
idx += 1
if idx >= num_layers:
continue
if 'bias' in named_weight_list_0[idx][0]:
aligned_wt_0[idx] = aligned_wt_0[idx].to(torch.float64) @ perm.to(torch.float64)
idx += 1
if idx >= num_layers:
continue
if cur_conv and len(named_weight_list_0[idx][1].shape) == 2:
aligned_wt_0[idx] = (aligned_wt_0[idx].to(torch.float64)
.reshape(aligned_wt_0[idx].shape[0], 64, -1)
.permute(0, 2, 1)
@ perm.to(torch.float64)) \
.permute(0, 2, 1) \
.reshape(aligned_wt_0[idx].shape[0], -1)
elif len(named_weight_list_0[idx][1].shape) == 4:
aligned_wt_0[idx] = (aligned_wt_0[idx].to(torch.float64)
.permute(2, 3, 0, 1)
@ perm.to(torch.float64)) \
.permute(2, 3, 0, 1)
else:
aligned_wt_0[idx] = aligned_wt_0[idx].to(torch.float64) @ perm.to(torch.float64)
assert idx == num_layers
averaged_weights = []
for idx, parameter in enumerate(networks[1].parameters()):
averaged_weights.append((1 - fusion_proportion) * aligned_wt_0[idx] + fusion_proportion * parameter)
return averaged_weights
Test the fused model
The fusion_proportion
variable denotes the contribution of model2 to the fused model. For example, if fusion_proportion=0.2
,
the fused model = 80% model1 + 20% model2.
def align_model_and_test(X):
acc_list = []
for fusion_proportion in torch.arange(0, 1.1, 0.1):
fused_weights = align(X, fusion_proportion, [model1, model2], params)
fused_model = SimpleNet()
state_dict = fused_model.state_dict()
for idx, (key, _) in enumerate(state_dict.items()):
state_dict[key] = fused_weights[idx]
fused_model.load_state_dict(state_dict)
fused_model.to(device)
test_loss = 0
correct = 0
for data, target in test_loader:
data = data.to(device)
target = target.to(device)
output = fused_model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item()
pred = output.data.max(1, keepdim=True)[1]
correct += pred.eq(target.data.view_as(pred)).sum()
test_loss /= len(test_loader.dataset)
acc = 100. * correct / len(test_loader.dataset)
print(
f"{1 - fusion_proportion:.2f} model1 + {fusion_proportion:.2f} model2 -> fused model accuracy: {acc:.2f}%")
acc_list.append(acc)
return torch.tensor(acc_list)
print('Graph Matching Fusion')
gm_acc_list = align_model_and_test(X)
Graph Matching Fusion
1.00 model1 + 0.00 model2 -> fused model accuracy: 84.18%
0.90 model1 + 0.10 model2 -> fused model accuracy: 85.12%
0.80 model1 + 0.20 model2 -> fused model accuracy: 85.21%
0.70 model1 + 0.30 model2 -> fused model accuracy: 82.52%
0.60 model1 + 0.40 model2 -> fused model accuracy: 71.11%
0.50 model1 + 0.50 model2 -> fused model accuracy: 53.74%
0.40 model1 + 0.60 model2 -> fused model accuracy: 63.26%
0.30 model1 + 0.70 model2 -> fused model accuracy: 78.51%
0.20 model1 + 0.80 model2 -> fused model accuracy: 82.81%
0.10 model1 + 0.90 model2 -> fused model accuracy: 83.97%
0.00 model1 + 1.00 model2 -> fused model accuracy: 83.81%
Compared with vanilla model fusion (no matching), the graph matching method stabilizes the fusion step:
print('No Matching Fusion')
vanilla_acc_list = align_model_and_test(torch.eye(n1, device=device))
plt.figure(figsize=(4, 4))
plt.title('Fused Model Accuracy')
plt.plot(torch.arange(0, 1.1, 0.1).numpy(), gm_acc_list.cpu().numpy(), 'r*-', label='Graph Matching Fusion')
plt.plot(torch.arange(0, 1.1, 0.1).numpy(), vanilla_acc_list.cpu().numpy(), 'b*-', label='No Matching Fusion')
plt.plot(torch.arange(0, 1.1, 0.1).numpy(), [acc1] * 11, '--', color="gray", label='Model1 Accuracy')
plt.plot(torch.arange(0, 1.1, 0.1).numpy(), [acc2] * 11, '--', color="brown", label='Model2 Accuracy')
plt.gca().set_xlabel('Fusion Proportion')
plt.gca().set_ylabel('Accuracy (%)')
plt.ylim((70, 87))
plt.legend(loc=3)
plt.show()

No Matching Fusion
1.00 model1 + 0.00 model2 -> fused model accuracy: 84.18%
0.90 model1 + 0.10 model2 -> fused model accuracy: 84.01%
0.80 model1 + 0.20 model2 -> fused model accuracy: 81.91%
0.70 model1 + 0.30 model2 -> fused model accuracy: 74.67%
0.60 model1 + 0.40 model2 -> fused model accuracy: 60.39%
0.50 model1 + 0.50 model2 -> fused model accuracy: 47.16%
0.40 model1 + 0.60 model2 -> fused model accuracy: 55.34%
0.30 model1 + 0.70 model2 -> fused model accuracy: 72.86%
0.20 model1 + 0.80 model2 -> fused model accuracy: 79.64%
0.10 model1 + 0.90 model2 -> fused model accuracy: 82.56%
0.00 model1 + 1.00 model2 -> fused model accuracy: 83.81%
Print the result summary
end_time = time.perf_counter()
print(f'time consumed for model fusion: {end_time - st_time:.2f} seconds')
print(f'model1 accuracy = {acc1}%, model2 accuracy = {acc2}%')
print(f"best fused model accuracy: {torch.max(gm_acc_list):.2f}%")
time consumed for model fusion: 105.46 seconds
model1 accuracy = 84.18%, model2 accuracy = 83.81%
best fused model accuracy: 85.21%
Note
This example supports both GPU and CPU, and the online documentation is built by a CPU-only machine. The efficiency will be significantly improved if you run this code on GPU.
Total running time of the script: ( 1 minutes 54.211 seconds)
Contributing to pygmtools
First, thank you for contributing to pygmtools
!
How to contribute
The preferred workflow for contributing to pygmtools
is to fork the
main repository on
GitHub, clone, and develop on a branch. Steps:
Fork the project repository by clicking on the ‘Fork’ button near the top right of the page. This creates a copy of the code under your GitHub user account. For more details on how to fork a repository see this guide.
Clone your fork of the repo from your GitHub account to your local disk:
$ git clone git@github.com:YourUserName/pygmtools.git
$ cd pygmtools
Create a feature branch to hold your development changes:
$ git checkout -b my-feature
Always use a feature branch. It is good practice to never work on the master branch!
Develop the feature on your feature branch. Add changed files using git add and then git commit files:
$ git add modified_files
$ git commit
to record your changes in Git, then push the changes to your GitHub account with:
$ git push -u origin my-feature
Follow these instructions to create a pull request from your fork. This will email the committers and an automatic check will run.
(If any of the above seems like magic to you, please look up the Git documentation on the web, or ask a friend or another contributor for help.)
Pull Request Checklist
We recommend that your contribution complies with the following rules before you submit a pull request:
Follow the PEP8 Guidelines.
If your pull request addresses an issue, please use the pull request title to describe the issue and mention the issue number in the pull request description. This will make sure a link back to the original issue is created.
All public methods should have informative docstrings with sample usage presented as doctests when appropriate.
When adding additional functionality, provide at least one example script in the
examples/
folder. Have a look at other examples for reference. Examples should demonstrate why the new functionality is useful in practice and, if possible, compare it to other methods available in pygmtools
.
Documentation and high-coverage tests are necessary for enhancements to be accepted. Bug-fixes or new features should be provided with non-regression tests. These tests verify the correct behavior of the fix or feature. In this manner, further modifications of the code base are guaranteed to be consistent with the desired behavior. For bug-fixes, at the time of the PR, these tests should fail for the code base in master and pass for the PR code.
At least one paragraph of narrative documentation with links to references in the literature and the example.
You can also check for common programming errors with the following tools:
No pyflakes warnings, check with:
$ pip install pyflakes
$ pyflakes path/to/module.py
No PEP8 warnings, check with:
$ pip install pep8
$ pep8 path/to/module.py
AutoPEP8 can help you fix some of the easy redundant errors:
$ pip install autopep8
$ autopep8 path/to/pep8.py
Filing bugs
We use Github issues to track all bugs and feature requests; feel free to open an issue if you have found a bug or wish to see a feature implemented.
It is recommended to check that your issue complies with the following rules before submitting:
Verify that your issue is not being currently addressed by other issues or pull requests.
Please ensure all code snippets and error messages are formatted in appropriate code blocks. See Creating and highlighting code blocks.
Please include your operating system type and version number, as well as your Python, pygmtools, numpy, and scipy versions. Please also provide the name of your running backend, and the GPU/CUDA versions if you are using GPU. This information can be found by running the following environment report (
pygmtools>=0.2.9
):
$ python3 -c 'import pygmtools; pygmtools.env_report()'
If you are using GPU, make sure to install
pynvml
before running the above script: pip install pynvml.
Please be specific about what solvers and/or functions are involved and the shape of the data, as appropriate; please include a reproducible code snippet or link to a gist. If an exception is raised, please provide the traceback.
Documentation
We are glad to accept any sort of documentation: function docstrings,
reStructuredText documents, tutorials, examples, etc.
reStructuredText documents live in the source code repository under the
docs/
directory.
You can edit the documentation using any text editor and then generate
the HTML output by typing make html
from the docs/
directory.
The resulting HTML files are in docs/_build/
and are viewable in
any web browser. The example files in examples/
are also built.
If you want to skip building the examples, please use the command
make html-noplot
.
For building the documentation, you will need the packages listed in
docs/requirements.txt
. Please use python>=3.7
because the packages
for earlier Python versions are outdated.
When you are writing documentation, it is important to keep a good compromise between mathematical and algorithmic details, and give intuition to the reader on what the algorithm does. It is best to always start with a small paragraph with a hand-waving explanation of what the method does to the data.
This contribution guide is strongly inspired by the one of the scikit-learn team.