Note
Go to the end to download the full example code
PyTorch Backend Example: Multi-Graph Matching
This example shows how to match multiple graphs. Multi-graph matching means that more than two graphs are jointly matched.
# Author: Zetian Jiang <maple_jzt@sjtu.edu.cn>
# Ziao Guo <ziao.guo@sjtu.edu.cn>
#
# License: Mulan PSL v2 License
Note
The following solvers are included in this example:
rrwm() (classic solver)
cao() (classic solver)
mgm_floyd() (classic solver)
import os
import time
import math
import copy
import torch # pytorch backend
import itertools
import numpy as np
import pygmtools as pygm
import matplotlib.pyplot as plt # for plotting
import scipy.io as sio # for loading .mat file
import scipy.spatial as spa # for Delaunay triangulation
from PIL import Image
from matplotlib.patches import ConnectionPatch # for plotting matching result
pygm.set_backend('pytorch') # set default backend for pygmtools
Load the images
Images are from the Willow Object Class dataset (this dataset is also available through the Benchmark of pygmtools; see WillowObject).
The images are resized to 256x256.
# Load n_images 'Car' samples from the WillowObject benchmark.
obj_resize = (256, 256)  # size passed to the benchmark loader
n_images = 30  # number of images (graphs) to collect
n_outlier = 0  # subtracted from n when cal_accuracy normalizes the accuracy
img_list = []  # PIL images
kpts_list = []  # one (2, n) tensor per image: row 0 = x, row 1 = y
n_kpts_list = []  # number of keypoints of each image
perm_list = []  # ground-truth permutation (NumPy array) of each image
bm = pygm.benchmark.Benchmark(name='WillowObject',
                              sets='train',
                              obj_resize=obj_resize)
while len(img_list) < n_images:
    data_list, gt_dict, _ = bm.rand_get_data(cls='Car')
    for data in data_list:
        # One draw may contain several samples; stop at exactly n_images so
        # the lists stay aligned with the later loops over range(n_images).
        if len(img_list) >= n_images:
            break
        img = Image.fromarray(data['img'])
        # Sorting by label puts the keypoints of every image in the same
        # order, hence the identity matrix used as ground-truth permutation.
        coords = sorted(data['kpts'], key=lambda x: x['labels'])
        kpts = torch.tensor([[kpt['x'] for kpt in coords],
                             [kpt['y'] for kpt in coords]])
        perm = np.eye(kpts.shape[1])
        img_list.append(img)
        kpts_list.append(kpts)
        n_kpts_list.append(kpts.shape[1])
        perm_list.append(perm)
Visualize the images and keypoints
def plot_image_with_graph(img, kpt, A=None):
    """
    Show an image with its keypoints and, optionally, the graph edges.

    :param img: image handed to ``plt.imshow``
    :param kpt: keypoint coordinates, indexed as ``kpt[0]`` (x) and ``kpt[1]`` (y)
    :param A: optional adjacency tensor; one line is drawn per nonzero entry
    """
    plt.imshow(img)
    plt.scatter(kpt[0], kpt[1], c='w', edgecolors='k')
    if A is None:
        return
    for edge in torch.nonzero(A, as_tuple=False):
        src, dst = edge[0], edge[1]
        x_pair = (kpt[0, src], kpt[0, dst])
        y_pair = (kpt[1, src], kpt[1, dst])
        plt.plot(x_pair, y_pair, 'k-')
# Show all loaded images on a 5-row grid.  The column count is rounded up so
# the grid still has enough cells when n_images is not a multiple of 5.
n_rows = 5
n_cols = math.ceil(n_images / n_rows)
plt.figure(figsize=(20, 18))
for i in range(n_images):
    plt.subplot(n_rows, n_cols, i + 1)
    plt.title('Image {}'.format(i + 1))
    plot_image_with_graph(img_list[i], kpts_list[i])
# plt.savefig('image')
# plt.close()

Build the graphs
Graph structures are built based on the geometric structure of the keypoint set. In this example, we use Delaunay triangulation.
def delaunay_triangulation(kpt):
    """
    Build an adjacency matrix from the Delaunay triangulation of keypoints.

    :param kpt: torch tensor, (2, n) keypoint coordinates
    :return: torch tensor, (n, n); entry (i, j) is 1 when keypoints i and j
        belong to a common simplex of the triangulation, 0 otherwise
    """
    n_points = kpt.shape[1]
    # scipy expects one point per row, i.e. an (n, 2) array.
    triangulation = spa.Delaunay(kpt.numpy().T)
    A = torch.zeros(n_points, n_points)
    for simplex in triangulation.simplices:
        # Ordered pairs, so both (i, j) and (j, i) are set and A is symmetric.
        for src, dst in itertools.permutations(simplex, 2):
            A[src, dst] = 1
    return A
# One Delaunay adjacency matrix per loaded image.
adj_list = [delaunay_triangulation(kpts_list[i]) for i in range(n_images)]
Build affinity matrix
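We follow the formulation of the Quadratic Assignment Problem (QAP):
\[\max_{\mathbf{X}} \ \mathrm{vec}(\mathbf{X})^\top \mathbf{K} \, \mathrm{vec}(\mathbf{X}) \quad \text{s.t.} \quad \mathbf{X} \in \{0, 1\}^{n_1 \times n_2}, \ \mathbf{X}\mathbf{1} = \mathbf{1}, \ \mathbf{X}^\top\mathbf{1} \leq \mathbf{1}\]
where the first step is to build the affinity matrix (\(\mathbf{K}\)) for each pair of graphs.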
def get_feature(n, points, adj):
    """
    Compute pairwise edge-length and edge-angle features of one graph.

    :param n: points # of graph
    :param points: torch tensor, (n, 2); the tensor is only reshaped, never
        transposed, so a (2, n) layout must be transposed by the caller
    :param adj: torch tensor, (n, n) (not used by this function)
    :return: edge feat (n, n), pairwise distances divided by their maximum;
        angle feat (n, n), arctangent of the pairwise slope scaled by 2 / pi
    """
    # offset[i, j] = points[i] - points[j], obtained through broadcasting.
    offset = points.reshape(n, 1, 2) - points.reshape(1, n, 2)
    dx, dy = offset[:, :, 0], offset[:, :, 1]

    distance = torch.sqrt(torch.sum(offset ** 2, dim=2))
    edge_feat = distance / torch.max(distance)

    # The small constant keeps the slope finite when two points share x.
    slope = dy / (dx + 1e-8)
    angle_feat = 2 * torch.atan(slope) / math.pi
    return edge_feat, angle_feat
def get_pair_affinity(edge_feat_1, angle_feat_1, edge_feat_2, angle_feat_2, adj1, adj2):
    """
    Compute the edge-to-edge affinity between two graphs.

    :param edge_feat_1: torch tensor, (n1, n1) edge feat of graph 1
    :param angle_feat_1: torch tensor, (n1, n1) angle feat of graph 1
    :param edge_feat_2: torch tensor, (n2, n2) edge feat of graph 2
    :param angle_feat_2: torch tensor, (n2, n2) angle feat of graph 2
    :param adj1: torch tensor, (n1, n1) adjacency of graph 1
    :param adj2: torch tensor, (n2, n2) adjacency of graph 2
    :return: torch tensor, (n1, n2, n1, n2); entry [i, k, j, l] compares edge
        (i, j) of graph 1 with edge (k, l) of graph 2 and is zeroed unless
        both adjacency entries are nonzero
    """
    n1, n2 = edge_feat_1.shape[0], edge_feat_2.shape[0]
    assert n1 == angle_feat_1.shape[0] and n2 == angle_feat_2.shape[0]

    def feature_distance(feat_1, feat_2):
        # Broadcast graph-1 edges over graph-2 edges; the trailing axis holds
        # the feature channels that are reduced by the Euclidean norm.
        gap = feat_1.reshape(n1, n1, 1, 1, -1) - feat_2.reshape(1, 1, n2, n2, -1)
        return torch.sqrt(torch.sum(gap ** 2, dim=-1))

    # Nonzero only where both (i, j) and (k, l) are edges.
    edge_mask = adj1.reshape(n1, n1, 1, 1) * adj2.reshape(1, 1, n2, n2)

    edge_weight = feature_distance(edge_feat_1, edge_feat_2)
    angle_weight = feature_distance(angle_feat_1, angle_feat_2)
    cost = edge_weight * 0.9 + angle_weight * 0.1

    affinity = torch.exp(-cost / 0.1) * edge_mask
    # Reorder (i, j, k, l) -> (i, k, j, l).
    return affinity.transpose(1, 2)
def generate_affinity_matrix(n_points, points_list, adj_list):
    """
    Build the affinity matrices of every ordered pair of graphs.

    :param n_points: list with the number of points of each graph
    :param points_list: list of point tensors, each forwarded to ``get_feature``
    :param adj_list: list of (n, n) adjacency tensors
    :return: torch tensor, (m, m, n_max * n_max, n_max * n_max), where m is the
        number of graphs and n_max the largest number of points
    """
    m = len(n_points)
    n_max = max(n_points)

    # (edge feat, angle feat) of every graph, computed once and reused.
    features = [get_feature(n, points, adj)
                for n, points, adj in zip(n_points, points_list, adj_list)]

    affinity = torch.zeros(m, m, n_max, n_max, n_max, n_max)
    for i in range(m):
        edge_i, angle_i = features[i]
        for j in range(m):
            edge_j, angle_j = features[j]
            # Direct assignment: the pair affinity has to match the
            # (n_max, n_max, n_max, n_max) slot it is written into.
            affinity[i, j] = get_pair_affinity(edge_i, angle_i,
                                               edge_j, angle_j,
                                               adj_list[i], adj_list[j])

    side = n_max * n_max
    return affinity.permute(0, 1, 3, 2, 5, 4).reshape(m, m, side, side)
# get_feature() documents its points argument as (n, 2), whereas kpts_list
# stores (2, n) tensors (row 0 = x, row 1 = y).  Transpose before building the
# affinities: reshaping the (2, n) layout directly would pair up unrelated
# coordinates instead of forming (x, y) points.
affinity_mat = generate_affinity_matrix(n_kpts_list, [kpts.t() for kpts in kpts_list], adj_list)
m = len(kpts_list)  # number of graphs
n = int(torch.max(torch.tensor(n_kpts_list)))  # largest number of keypoints
# Node counts handed to the pairwise solver: n for each of the m * m pairs.
ns_src = torch.ones(m * m).int() * n
ns_tgt = torch.ones(m * m).int() * n
Calculate accuracy, consistency, and affinity
def cal_accuracy(mat, gt_mat, n, num_outlier=None):
    """
    Mean matching accuracy over all graph pairs.

    Rows and columns that carry no ground-truth match are dropped from both
    matrices before the prediction is compared with the ground truth.

    :param mat: torch tensor, (m, m, n, n) predicted matchings
    :param gt_mat: torch tensor, (m, m, n, n) ground-truth matchings
    :param n: number of nodes per graph
    :param num_outlier: outlier count subtracted from n in the normalization;
        defaults to the module-level ``n_outlier``
    :return: accuracy averaged over the m * m pairs
    """
    if num_outlier is None:
        num_outlier = n_outlier
    m = mat.shape[0]
    acc = 0
    for i in range(m):
        for j in range(m):
            _mat, _gt_mat = mat[i, j], gt_mat[i, j]
            # Summing over dim=1 gives one total per row, over dim=0 one total
            # per column; each mask must be applied to its own axis.
            row_sum = torch.sum(_gt_mat, dim=1)
            col_sum = torch.sum(_gt_mat, dim=0)
            row_idx = [k for k in range(n) if row_sum[k] != 0]
            col_idx = [k for k in range(n) if col_sum[k] != 0]
            _mat = _mat[row_idx, :][:, col_idx]
            _gt_mat = _gt_mat[row_idx, :][:, col_idx]
            # Every wrong match contributes two differing entries, hence / 2.
            acc += 1 - torch.sum(torch.abs(_mat - _gt_mat)) / 2 / (n - num_outlier)
    return acc / (m * m)
def cal_consistency(mat, gt_mat, m, n):
    """
    Mean pairwise consistency of a set of matchings.

    :param mat: torch tensor, (m, m, n, n) matching results
    :param gt_mat: not used
    :param m: not used
    :param n: not used
    :return: mean of the (m, m) consistency matrix from ``get_batch_pc_opt``
    """
    pairwise_consistency = get_batch_pc_opt(mat)
    return pairwise_consistency.mean()
def cal_affinity(X, X_gt, K, m, n):
    """
    Mean affinity score of X relative to the ground-truth affinity score.

    :param X: matching results, reshaped to (-1, n, n)
    :param X_gt: ground-truth matchings, reshaped to (-1, n, n)
    :param K: affinity matrices, reshaped to (-1, n * n, n * n)
    :param m: not used
    :param n: number of nodes per graph
    :return: mean over all pairs of score(X) / (score(X_gt) + 1e-8)
    """
    K_flat = K.reshape(-1, n * n, n * n)
    score = get_batch_affinity(X.reshape(-1, n, n), K_flat)
    score_gt = get_batch_affinity(X_gt.reshape(-1, n, n), K_flat)
    # The small constant guards against a zero ground-truth score.
    ratio = score / (score_gt + 1e-8)
    return ratio.mean()
def get_batch_affinity(X, K, norm=1):
    """
    Batched affinity score vec(X)^T K vec(X) / norm.

    :param X: (b, n, n)
    :param K: (b, n*n, n*n)
    :param norm: normalization term
    :return: affinity_score (b, 1, 1)
    """
    batch, _, _ = X.size()
    # Column-major vectorization: transpose first, then flatten.
    vec_x = X.transpose(1, 2).reshape(batch, -1, 1)  # (b, n*n, 1)
    vec_x_t = vec_x.transpose(1, 2)  # (b, 1, n*n)
    score = torch.bmm(torch.bmm(vec_x_t, K), vec_x)
    return score / norm
def get_single_affinity(X, K, norm=1):
    """
    Affinity score vec(X)^T K vec(X) / norm of a single matching.

    :param X: (n, n)
    :param K: (n*n, n*n)
    :param norm: normalization term
    :return: affinity_score, a (1, 1) tensor
    """
    _, _ = X.size()  # X has to be two-dimensional
    # Column-major vectorization: transpose first, then flatten.
    vec_x = X.transpose(0, 1).reshape(-1, 1)
    score = torch.matmul(torch.matmul(vec_x.transpose(0, 1), K), vec_x)
    return score / norm
def get_single_pc(X, i, j, Xij=None):
    """
    Consistency of one pairwise matching, computed with a loop over the
    intermediate graph k.

    :param X: (m, m, n, n) all the matching results
    :param i: index
    :param j: index
    :param Xij: (n, n) matching; defaults to X[i, j]
    :return: the consistency of X_ij
    """
    m, _, n, _ = X.size()
    target = X[i, j] if Xij is None else Xij
    # For each intermediate graph k, compare the direct matching with the
    # composed matching i -> k -> j.
    inconsistency = sum(
        torch.sum(torch.abs(target - torch.matmul(X[i, k], X[k, j]))) / (2 * n)
        for k in range(m)
    )
    return 1 - inconsistency / m
def get_single_pc_opt(X, i, j, Xij=None):
    """
    Consistency of one pairwise matching, computed with a single batched
    product over all intermediate graphs.

    :param X: (m, m, n, n) all the matching results
    :param i: index
    :param j: index
    :param Xij: (n, n) matching; defaults to X[i, j]
    :return: the consistency of X_ij
    """
    m, _, n, _ = X.size()
    if Xij is None:
        Xij = X[i, j]
    from_i = X[i].reshape(-1, n, n)  # X[i, k] for every k
    to_j = X[:, j].reshape(-1, n, n)  # X[k, j] for every k
    composed = torch.bmm(from_i, to_j)
    # Xij broadcasts against the m composed matchings.
    mismatch = torch.abs(Xij - composed).sum()
    return 1 - mismatch / (2 * n * m)
def get_batch_pc(X):
    """
    Pairwise consistency of every matching, computed pair by pair.

    :param X: (m, m, n, n) all the matching results
    :return: (m, m) the consistency of X, allocated on the same device as X
    """
    # Take m from X itself rather than from a module-level variable, and
    # allocate on X's device so the function also runs without CUDA.
    m = X.size(0)
    pair_con = torch.zeros(m, m, device=X.device)
    for i in range(m):
        for j in range(m):
            pair_con[i, j] = get_single_pc_opt(X, i, j)
    return pair_con
def get_batch_pc_opt(X):
    """
    Pairwise consistency of every matching, computed with one batched product.

    :param X: (m, m, n, n) all the matching results
    :return: (m, m) the consistency of X
    """
    m, _, n, _ = X.size()
    cube = (m, m, m, n, n)
    first_leg = X.unsqueeze(1).expand(*cube)  # first_leg[i, j, k] = X[i, k]
    second_leg = X.transpose(0, 1).unsqueeze(0).expand(*cube)  # second_leg[i, j, k] = X[k, j]
    composed = torch.bmm(first_leg.reshape(-1, n, n),
                         second_leg.reshape(-1, n, n)).reshape(*cube)
    direct = X.unsqueeze(2).expand(*cube)  # direct[i, j, k] = X[i, j]
    mismatch = torch.abs(composed - direct).sum(dim=(2, 3, 4))
    return 1 - mismatch / (2 * n * m)
def eval(mat, gt_mat, affinity, m, n):
    """
    Evaluate a set of matchings.

    NOTE: this name shadows the builtin ``eval`` within this module.

    :param mat: (m, m, n, n) matching results
    :param gt_mat: (m, m, n, n) ground-truth matchings
    :param affinity: affinity matrices passed to ``cal_affinity``
    :param m: number of graphs
    :param n: number of nodes per graph
    :return: accuracy, affinity score, consistency
    """
    accuracy = cal_accuracy(mat, gt_mat, n)
    affinity_score = cal_affinity(mat, gt_mat, affinity, m, n)
    consistency = cal_consistency(mat, gt_mat, m, n)
    return accuracy, affinity_score, consistency
Generate gt mat
# Ground-truth matching between graphs i and j: P_i^T P_j, where P_i is the
# permutation stored in perm_list[i].
gt_mat = torch.zeros(m, m, n, n)
for i in range(m):
    for j in range(m):
        # perm_list holds NumPy arrays.  ndarray.transpose(0, 1) keeps the axes
        # in the order (0, 1) and is therefore a no-op; .T really transposes.
        gt_mat[i, j] = torch.tensor(np.matmul(perm_list[i].T, perm_list[j]))
# print(perm_list[0])
# print(perm_list[1])
# print(gt_mat[1, 2])
# print(gt_mat[0, 1] - gt_mat[1, 0].transpose(0, 1))
Pairwise graph matching by RRWM
See rrwm()
for the API reference.
# Indices of the two graphs whose matching is visualized.
a = 0
b = 12

# Solve all m * m pairs in one batch with RRWM, then discretize the result
# with the Hungarian algorithm.
tic = time.time()
rrwm_mat = pygm.linear_solvers.hungarian(
    pygm.classic_solvers.rrwm(affinity_mat.reshape(-1, n * n, n * n), ns_src, ns_tgt)
)
toc = time.time()
rrwm_mat = rrwm_mat.reshape(m, m, n, n)
rrwm_acc, rrwm_src, rrwm_con = eval(rrwm_mat, gt_mat, affinity_mat, m, n)
rrwm_tim = toc - tic

# Plot the matching between graphs a and b.
plt.figure(figsize=(8, 4))
plt.suptitle('Multi-Graph Matching Result by RRWM')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img_list[a], kpts_list[a], adj_list[a])
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img_list[b], kpts_list[b], adj_list[b])
X = rrwm_mat[a, b]
for i, row in enumerate(X):
    j = torch.argmax(row).item()
    # Green when node i is matched to the node with the same index, red otherwise.
    line_color = "green" if i == j else "red"
    con = ConnectionPatch(xyA=kpts_list[a][:, i], xyB=kpts_list[b][:, j],
                          coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color=line_color)
    plt.gca().add_artist(con)
# plt.savefig("RRWM.png")
# plt.close()

Multi graph matching by multi-graph solvers
Multi graph matching: CAO-M
See cao()
for the API reference.
# Refine a copy of the RRWM matchings with CAO (mode='memory'), then
# discretize each pair with the Hungarian algorithm.
base_mat = copy.deepcopy(rrwm_mat)
tic = time.time()
cao_m_mat = pygm.multi_graph_solvers.cao(affinity_mat, base_mat, mode='memory')
cao_m_mat = pygm.linear_solvers.hungarian(cao_m_mat.reshape(-1, n, n))
cao_m_mat = cao_m_mat.reshape(m, m, n, n)
toc = time.time()
cao_m_acc, cao_m_src, cao_m_con = eval(cao_m_mat, gt_mat, affinity_mat, m, n)
# Reported time includes the RRWM initialization.
cao_m_tim = toc - tic + rrwm_tim

# Plot the matching between graphs a and b.
plt.figure(figsize=(8, 4))
plt.suptitle('Multi-Graph Matching Result by CAO-M')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img_list[a], kpts_list[a], adj_list[a])
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img_list[b], kpts_list[b], adj_list[b])
X = cao_m_mat[a, b]
for i, row in enumerate(X):
    j = torch.argmax(row).item()
    # Green when node i is matched to the node with the same index, red otherwise.
    line_color = "green" if i == j else "red"
    con = ConnectionPatch(xyA=kpts_list[a][:, i], xyB=kpts_list[b][:, j],
                          coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color=line_color)
    plt.gca().add_artist(con)
# plt.savefig("CAO-M.png")
# plt.close()

Multi graph matching: CAO-T
See cao()
for the API reference.
# Refine a copy of the RRWM matchings with CAO (mode='time'), then
# discretize each pair with the Hungarian algorithm.
base_mat = copy.deepcopy(rrwm_mat)
tic = time.time()
cao_t_mat = pygm.multi_graph_solvers.cao(affinity_mat, base_mat, mode='time')
cao_t_mat = pygm.linear_solvers.hungarian(cao_t_mat.reshape(-1, n, n))
cao_t_mat = cao_t_mat.reshape(m, m, n, n)
toc = time.time()
cao_t_acc, cao_t_src, cao_t_con = eval(cao_t_mat, gt_mat, affinity_mat, m, n)
# Reported time includes the RRWM initialization.
cao_t_tim = toc - tic + rrwm_tim

# Plot the matching between graphs a and b.
plt.figure(figsize=(8, 4))
plt.suptitle('Multi-Graph Matching Result by CAO-T')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img_list[a], kpts_list[a], adj_list[a])
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img_list[b], kpts_list[b], adj_list[b])
X = cao_t_mat[a, b]
for i, row in enumerate(X):
    j = torch.argmax(row).item()
    # Green when node i is matched to the node with the same index, red otherwise.
    line_color = "green" if i == j else "red"
    con = ConnectionPatch(xyA=kpts_list[a][:, i], xyB=kpts_list[b][:, j],
                          coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color=line_color)
    plt.gca().add_artist(con)
# plt.savefig("CAO-T.png")
# plt.close()

Multi graph matching: MGM-Floyd-M
See mgm_floyd()
for the API reference.
# Refine a copy of the RRWM matchings with MGM-Floyd (mode='memory'), then
# discretize each pair with the Hungarian algorithm.
base_mat = copy.deepcopy(rrwm_mat)
tic = time.time()
floyd_m_mat = pygm.multi_graph_solvers.mgm_floyd(affinity_mat, base_mat, param_lambda=0.4, mode='memory')
floyd_m_mat = pygm.linear_solvers.hungarian(floyd_m_mat.reshape(-1, n, n))
floyd_m_mat = floyd_m_mat.reshape(m, m, n, n)
toc = time.time()
floyd_m_acc, floyd_m_src, floyd_m_con = eval(floyd_m_mat, gt_mat, affinity_mat, m, n)
# Reported time includes the RRWM initialization.
floyd_m_tim = toc - tic + rrwm_tim

# Plot the matching between graphs a and b.
plt.figure(figsize=(8, 4))
plt.suptitle('Multi-Graph Matching Result by Floyd-M')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img_list[a], kpts_list[a], adj_list[a])
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img_list[b], kpts_list[b], adj_list[b])
X = floyd_m_mat[a, b]
for i, row in enumerate(X):
    j = torch.argmax(row).item()
    # Green when node i is matched to the node with the same index, red otherwise.
    line_color = "green" if i == j else "red"
    con = ConnectionPatch(xyA=kpts_list[a][:, i], xyB=kpts_list[b][:, j],
                          coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color=line_color)
    plt.gca().add_artist(con)
# plt.savefig("Floyd-M.png")
# plt.close()

Multi graph matching: MGM-Floyd-T
See mgm_floyd()
for the API reference.
# Refine a copy of the RRWM matchings with MGM-Floyd (mode='time'), then
# discretize each pair with the Hungarian algorithm.
base_mat = copy.deepcopy(rrwm_mat)
tic = time.time()
floyd_t_mat = pygm.multi_graph_solvers.mgm_floyd(affinity_mat, base_mat, param_lambda=0.6, mode='time')
floyd_t_mat = pygm.linear_solvers.hungarian(floyd_t_mat.reshape(-1, n, n))
floyd_t_mat = floyd_t_mat.reshape(m, m, n, n)
toc = time.time()
floyd_t_acc, floyd_t_src, floyd_t_con = eval(floyd_t_mat, gt_mat, affinity_mat, m, n)
# Reported time includes the RRWM initialization.
floyd_t_tim = toc - tic + rrwm_tim

# Plot the matching between graphs a and b.
plt.figure(figsize=(8, 4))
plt.suptitle('Multi-Graph Matching Result by Floyd-T')
ax1 = plt.subplot(1, 2, 1)
plot_image_with_graph(img_list[a], kpts_list[a], adj_list[a])
ax2 = plt.subplot(1, 2, 2)
plot_image_with_graph(img_list[b], kpts_list[b], adj_list[b])
X = floyd_t_mat[a, b]
for i, row in enumerate(X):
    j = torch.argmax(row).item()
    # Green when node i is matched to the node with the same index, red otherwise.
    line_color = "green" if i == j else "red"
    con = ConnectionPatch(xyA=kpts_list[a][:, i], xyB=kpts_list[b][:, j],
                          coordsA="data", coordsB="data",
                          axesA=ax1, axesB=ax2, color=line_color)
    plt.gca().add_artist(con)
# plt.savefig("Floyd-T.png")
# plt.close()

Total running time of the script: (0 minutes 22.959 seconds)