pygmtools.neural_solvers.pca_gm

pygmtools.neural_solvers.pca_gm(feat1, feat2, A1, A2, n1=None, n2=None, in_channel=1024, hidden_channel=2048, out_channel=2048, num_layers=2, sk_max_iter=20, sk_tau=0.05, network=None, return_network=False, pretrain='voc', backend=None)

PCA-GM (Permutation loss and Cross-graph Affinity Graph Matching) is a neural network model for matching two individual graphs (KB-QAP). Its graph matching module is composed of several intra-graph embedding layers, a cross-graph embedding layer, and a Sinkhorn matching layer. Only the second-to-last layer is followed by a cross-graph update.
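
The layer ordering described above can be sketched in plain NumPy for a single (non-batched) pair of graphs with num_layers=2. This is a simplified illustration, not pygmtools' implementation: the weights are random, the helpers `gconv`, `sinkhorn` and `pca_gm_sketch` are hypothetical, the learnable Affinity Metric is replaced by a plain inner product, and the cross-graph update is assumed to concatenate each node's embedding with its Sinkhorn-weighted counterpart from the other graph.

```python
import numpy as np


def log_sum_exp(x, axis):
    # numerically stable log(sum(exp(x))) along one axis, keeping dims
    m = x.max(axis=axis, keepdims=True)
    return m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))


def sinkhorn(scores, tau=0.05, max_iter=20):
    # alternate row/column normalisation in the log domain
    log_p = scores / tau
    for _ in range(max_iter):
        log_p = log_p - log_sum_exp(log_p, axis=1)  # rows
        log_p = log_p - log_sum_exp(log_p, axis=0)  # columns
    return np.exp(log_p)


def gconv(A, X, W_msg, W_self):
    # hypothetical intra-graph layer: mean over neighbours plus a self term, then ReLU
    A_norm = A / np.maximum(A.sum(axis=1, keepdims=True), 1.0)
    return np.maximum(A_norm @ X @ W_msg + X @ W_self, 0.0)


def pca_gm_sketch(feat1, feat2, A1, A2, params, tau=0.05, max_iter=20):
    # layer 1 (second-to-last): intra-graph embedding on each graph
    X1 = gconv(A1, feat1, *params["gc1"])
    X2 = gconv(A2, feat2, *params["gc1"])
    # cross-graph update after the second-to-last layer (assumed form)
    S = sinkhorn(X1 @ X2.T, tau, max_iter)
    X1, X2 = (np.concatenate([X1, S @ X2], axis=1) @ params["cross"],
              np.concatenate([X2, S.T @ X1], axis=1) @ params["cross"])
    # layer 2 (last): intra-graph embedding
    X1 = gconv(A1, X1, *params["gc2"])
    X2 = gconv(A2, X2, *params["gc2"])
    # affinity (plain inner product in this sketch) followed by Sinkhorn
    return sinkhorn(X1 @ X2.T, tau, max_iter)


rng = np.random.default_rng(0)
d, h = 8, 16
params = {
    "gc1": (rng.normal(0, 0.1, (d, h)), rng.normal(0, 0.1, (d, h))),
    "cross": rng.normal(0, 0.1, (2 * h, h)),
    "gc2": (rng.normal(0, 0.1, (h, h)), rng.normal(0, 0.1, (h, h))),
}
A = 1.0 * (rng.random((4, 4)) > 0.5)
np.fill_diagonal(A, 0)
feat = rng.random((4, d)) - 0.5
P = pca_gm_sketch(feat, feat, A, A, params)
print(P.shape)  # (4, 4)
```

The real model learns all weights end to end; the sketch only shows where each stage sits in the pipeline.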

See the following pipeline for an example, with application to visual graph matching (layers in the gray box + Affinity Metric + Sinkhorn are implemented by pygmtools):

[Figure: PCA-GM pipeline for visual graph matching]

See the following paper for more technical details: “Wang et al. Combinatorial Learning of Robust Deep Graph Matching: an Embedding based Approach. TPAMI 2020.”

You may also be interested in the extended version, IPCA-GM (see ipca_gm()).

Parameters
  • feat1 – \((b\times n_1 \times d)\) input feature of graph1

  • feat2 – \((b\times n_2 \times d)\) input feature of graph2

  • A1 – \((b\times n_1 \times n_1)\) input adjacency matrix of graph1

  • A2 – \((b\times n_2 \times n_2)\) input adjacency matrix of graph2

  • n1 – \((b)\) number of nodes in graph1. Optional if all equal to \(n_1\)

  • n2 – \((b)\) number of nodes in graph2. Optional if all equal to \(n_2\)

  • in_channel – (default: 1024) Channel size of the input layer. It must match the feature dimension \((d)\) of feat1 and feat2. Ignored if a network object is given (network!=None).

  • hidden_channel – (default: 2048) Channel size of the hidden layers. Ignored if a network object is given (network!=None).

  • out_channel – (default: 2048) Channel size of the output layer. Ignored if a network object is given (network!=None).

  • num_layers – (default: 2) Number of graph embedding layers. Must be >=2. Ignored if a network object is given (network!=None).

  • sk_max_iter – (default: 20) Max number of iterations of Sinkhorn. See sinkhorn() for more details about this argument.

  • sk_tau – (default: 0.05) The temperature parameter of Sinkhorn. See sinkhorn() for more details about this argument.

  • network – (default: None) The network object. If None, a new network object is created and the model weights specified by the pretrain argument are loaded.

  • return_network – (default: False) Whether to also return the network object. Reusing it saves model construction time when the model is called multiple times.

  • pretrain – (default: ‘voc’) If network==None, the pretrained model weights to be loaded. Available pretrained weights: voc (on Pascal VOC Keypoint dataset), willow (on Willow Object Class dataset), voc-all (on Pascal VOC Keypoint dataset, without filtering), or False (no pretraining).

  • backend – (default: pygmtools.BACKEND variable) the backend for computation.
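
sk_max_iter and sk_tau control the Sinkhorn layer. As a rough illustration (a textbook log-domain Sinkhorn in NumPy, not pygmtools' sinkhorn() itself; `log_sum_exp` and `sinkhorn` here are hypothetical helpers), a smaller temperature pushes the output toward a hard permutation, while more iterations bring row and column sums closer to one:

```python
import numpy as np


def log_sum_exp(x, axis):
    # numerically stable log(sum(exp(x))) along one axis, keeping dims
    m = x.max(axis=axis, keepdims=True)
    return m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))


def sinkhorn(scores, tau=0.05, max_iter=20):
    # divide by the temperature, then alternately normalise rows and columns
    log_p = scores / tau
    for _ in range(max_iter):
        log_p -= log_sum_exp(log_p, axis=1)  # rows sum to 1
        log_p -= log_sum_exp(log_p, axis=0)  # columns sum to 1
    return np.exp(log_p)


scores = np.array([[3.0, 1.0], [1.0, 3.0]])
soft = sinkhorn(scores, tau=1.0)    # soft[0, 0] is about 0.881
sharp = sinkhorn(scores, tau=0.05)  # sharp[0, 0] is about 1.0
```

With tau=1.0 the entry is \(e^3/(e^3+e^1)\approx 0.881\); with tau=0.05 the gap between scores is amplified twentyfold, so the output is almost a permutation matrix.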

Returns

if return_network==False, \((b\times n_1 \times n_2)\) the doubly-stochastic matching matrix

if return_network==True, \((b\times n_1 \times n_2)\) the doubly-stochastic matching matrix, the network object
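
The returned matrix is soft. The examples below turn it into a discrete matching with pygm.hungarian(). A comparable sketch using SciPy's linear_sum_assignment (maximising the total matching score), together with a quick doubly-stochastic check, might look like this; the helper names `is_doubly_stochastic` and `hard_assignment` are hypothetical.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def is_doubly_stochastic(X, atol=1e-3):
    # non-negative, with rows and columns summing to 1 (within a tolerance)
    X = np.asarray(X, dtype=float)
    return bool((X >= 0).all()
                and np.allclose(X.sum(axis=-1), 1.0, atol=atol)
                and np.allclose(X.sum(axis=-2), 1.0, atol=atol))


def hard_assignment(X):
    # X: (n1, n2) soft matching matrix, or a batch of shape (b, n1, n2)
    X = np.asarray(X, dtype=float)
    if X.ndim == 3:
        return np.stack([hard_assignment(x) for x in X])
    rows, cols = linear_sum_assignment(-X)  # negate to maximise the score
    P = np.zeros_like(X)
    P[rows, cols] = 1.0
    return P


X = np.array([[0.7, 0.2, 0.1],
              [0.2, 0.1, 0.7],
              [0.1, 0.7, 0.2]])
print(is_doubly_stochastic(X))  # True
print(hard_assignment(X))       # node 0 -> 0, node 1 -> 2, node 2 -> 1
```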

Note

You may need a proxy to load the pretrained weights if Google Drive is not accessible in your country/region. You may also download the pretrained models manually and put them in ~/.cache/pygmtools (on Linux).

[Google Drive] [Baidu Drive]
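
When placing a manually downloaded file, the file names appear to follow the download messages printed in the examples on this page (pca_gm_<pretrain>_<backend> plus a backend-specific extension). That pattern is an assumption inferred from those messages, so check the message on your machine. A hypothetical helper that computes the expected Linux path:

```python
import os

# Extensions as printed in the example download messages (assumption).
_EXT = {"numpy": ".npy", "pytorch": ".pt", "jittor": ".pt", "paddle": ".pdparams"}


def expected_weight_path(pretrain="voc", backend="pytorch"):
    # hypothetical helper: ~/.cache/pygmtools/pca_gm_<pretrain>_<backend><ext>
    cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "pygmtools")
    return os.path.join(cache_dir, f"pca_gm_{pretrain}_{backend}{_EXT[backend]}")


path = expected_weight_path("willow", "numpy")
print(path, os.path.exists(path))
```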

Note

This function also supports non-batched input: simply omit all batch dimensions from the input tensors.
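
Conversely, graphs of different sizes can be stacked into one batch by zero-padding them to the largest size and recording the true node counts in n1 / n2. The helper `pad_graphs` below is hypothetical; it assumes, consistent with the description of n1 and n2 above, that entries beyond each graph's recorded node count are ignored.

```python
import numpy as np


def pad_graphs(feats, adjs):
    # hypothetical helper: stack variable-sized graphs into zero-padded batches
    n = np.array([f.shape[0] for f in feats])
    b, n_max, d = len(feats), int(n.max()), feats[0].shape[1]
    F = np.zeros((b, n_max, d))
    A = np.zeros((b, n_max, n_max))
    for i, (f, a) in enumerate(zip(feats, adjs)):
        k = f.shape[0]
        F[i, :k] = f          # real node features
        A[i, :k, :k] = a      # real edges; padded rows/columns stay zero
    return F, A, n


rng = np.random.default_rng(0)
feats = [rng.random((3, 1024)), rng.random((5, 1024))]
adjs = [np.ones((3, 3)) - np.eye(3), np.ones((5, 5)) - np.eye(5)]
F, A, n = pad_graphs(feats, adjs)
print(F.shape, A.shape, n)  # (2, 5, 1024) (2, 5, 5) [3 5]
```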

NumPy Example
>>> import numpy as np
>>> import pygmtools as pygm
>>> pygm.set_backend('numpy')
>>> np.random.seed(1)

# Generate a batch of isomorphic graphs
>>> batch_size = 10
>>> X_gt = np.zeros((batch_size, 4, 4))
>>> X_gt[:, np.arange(0, 4, dtype='i4'), np.random.permutation(4)] = 1
>>> A1 = 1. * (np.random.rand(batch_size, 4, 4) > 0.5)
>>> for i in np.arange(4): # discard self-loop edges
...    for j in np.arange(batch_size):
...        A1[j][i][i] = 0
>>> A2 = np.matmul(np.matmul(X_gt.swapaxes(1, 2), A1), X_gt)
>>> feat1 = np.random.rand(batch_size, 4, 1024) - 0.5
>>> feat2 = np.matmul(X_gt.swapaxes(1, 2), feat1)
>>> n1 = n2 = np.array([4] * batch_size)

# Match by PCA-GM (load pretrained model)
>>> X, net = pygm.pca_gm(feat1, feat2, A1, A2, n1, n2, return_network=True)
Downloading to ~/.cache/pygmtools/pca_gm_voc_numpy.npy...
>>> (pygm.hungarian(X) * X_gt).sum() / X_gt.sum() # accuracy
1.0

# Pass the net object to avoid rebuilding the model again
>>> X = pygm.pca_gm(feat1, feat2, A1, A2, n1, n2, network=net)
>>> (pygm.hungarian(X) * X_gt).sum() / X_gt.sum() # accuracy
1.0

# You may also load other pretrained weights
>>> X, net = pygm.pca_gm(feat1, feat2, A1, A2, n1, n2, return_network=True, pretrain='willow')
Downloading to ~/.cache/pygmtools/pca_gm_willow_numpy.npy...
>>> (pygm.hungarian(X) * X_gt).sum() / X_gt.sum() # accuracy
1.0

# You may configure your own model and integrate the model into a deep learning pipeline. For example:
>>> net = pygm.utils.get_network(pygm.pca_gm, in_channel=1024, hidden_channel=2048, out_channel=512, num_layers=3, pretrain=False)
# feat1/feat2 may be outputs of other neural networks
>>> X = pygm.pca_gm(feat1, feat2, A1, A2, n1, n2, network=net)
>>> (pygm.hungarian(X) * X_gt).sum() / X_gt.sum() # accuracy
1.0
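
The examples build graph2 as a relabelled copy of graph1: X_gt maps node i of graph1 to node perm[i] of graph2, so \(A_2 = X_{gt}^\top A_1 X_{gt}\) and \(feat_2 = X_{gt}^\top feat_1\), and the accuracy line counts the fraction of ground-truth matches that are recovered. A non-batched NumPy check of both facts (`matching_accuracy` is a hypothetical helper name):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 4, 8
perm = rng.permutation(n)
X_gt = np.zeros((n, n))
X_gt[np.arange(n), perm] = 1          # node i of graph1 <-> node perm[i] of graph2
A1 = 1.0 * (rng.random((n, n)) > 0.5)
np.fill_diagonal(A1, 0)               # discard self-loop edges
A2 = X_gt.T @ A1 @ X_gt               # relabel the nodes of graph1
feat1 = rng.random((n, d)) - 0.5
feat2 = X_gt.T @ feat1


def matching_accuracy(P, X_gt):
    # fraction of ground-truth matches that P recovers
    return (P * X_gt).sum() / X_gt.sum()


# graph2 is graph1 with node i renamed to perm[i]
assert np.allclose(A2[np.ix_(perm, perm)], A1)
assert np.allclose(feat2[perm], feat1)
print(matching_accuracy(X_gt, X_gt))  # 1.0
```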

PyTorch Example
>>> import torch
>>> import pygmtools as pygm
>>> pygm.set_backend('pytorch')
>>> _ = torch.manual_seed(1)

# Generate a batch of isomorphic graphs
>>> batch_size = 10
>>> X_gt = torch.zeros(batch_size, 4, 4)
>>> X_gt[:, torch.arange(0, 4, dtype=torch.int64), torch.randperm(4)] = 1
>>> A1 = 1. * (torch.rand(batch_size, 4, 4) > 0.5)
>>> torch.diagonal(A1, dim1=1, dim2=2)[:] = 0 # discard self-loop edges
>>> A2 = torch.bmm(torch.bmm(X_gt.transpose(1, 2), A1), X_gt)
>>> feat1 = torch.rand(batch_size, 4, 1024) - 0.5
>>> feat2 = torch.bmm(X_gt.transpose(1, 2), feat1)
>>> n1 = n2 = torch.tensor([4] * batch_size)

# Match by PCA-GM (load pretrained model)
>>> X, net = pygm.pca_gm(feat1, feat2, A1, A2, n1, n2, return_network=True)
Downloading to ~/.cache/pygmtools/pca_gm_voc_pytorch.pt...
>>> (pygm.hungarian(X) * X_gt).sum() / X_gt.sum() # accuracy
tensor(1.)

# Pass the net object to avoid rebuilding the model again
>>> X = pygm.pca_gm(feat1, feat2, A1, A2, n1, n2, network=net)

# You may also load other pretrained weights
>>> X, net = pygm.pca_gm(feat1, feat2, A1, A2, n1, n2, return_network=True, pretrain='willow')
Downloading to ~/.cache/pygmtools/pca_gm_willow_pytorch.pt...

# You may configure your own model and integrate the model into a deep learning pipeline. For example:
>>> net = pygm.utils.get_network(pygm.pca_gm, in_channel=1024, hidden_channel=2048, out_channel=512, num_layers=3, pretrain=False)
>>> optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# feat1/feat2 may be outputs of other neural networks
>>> X = pygm.pca_gm(feat1, feat2, A1, A2, n1, n2, network=net)
>>> loss = pygm.utils.permutation_loss(X, X_gt)
>>> loss.backward()
>>> optimizer.step()
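
pygm.utils.permutation_loss trains the soft matching matrix against X_gt. In the PCA-GM papers, the permutation loss is a cross-entropy between the predicted doubly-stochastic matrix and the ground-truth permutation. The NumPy sketch below illustrates that idea and is not the library's code: `permutation_loss_sketch` is a hypothetical name, and both the normalisation (dividing by the number of ground-truth matches) and the clipping constant are assumptions.

```python
import numpy as np


def permutation_loss_sketch(X, X_gt, eps=1e-7):
    # element-wise binary cross-entropy between soft matching X and permutation X_gt
    X = np.clip(X, eps, 1.0 - eps)   # avoid log(0)
    bce = -(X_gt * np.log(X) + (1.0 - X_gt) * np.log(1.0 - X))
    # normalise by the number of ground-truth matches (assumption)
    return bce.sum() / X_gt.sum()


X_gt = np.eye(4)
uniform = np.full((4, 4), 0.25)
print(round(permutation_loss_sketch(uniform, X_gt), 4))  # 2.2493
print(permutation_loss_sketch(X_gt, X_gt) < 1e-5)        # True
```

A perfect prediction drives the loss to (almost) zero, while a uniform guess costs \(\ln 4 + 3\ln(4/3)\approx 2.249\) per node.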

Jittor Example
>>> import jittor as jt
>>> import pygmtools as pygm
>>> pygm.set_backend('jittor')
>>> _ = jt.seed(1)

# Generate a batch of isomorphic graphs
>>> batch_size = 10
>>> X_gt = jt.zeros((batch_size, 4, 4))
>>> X_gt[:, jt.arange(0, 4, dtype=jt.int64), jt.randperm(4)] = 1
>>> A1 = 1. * (jt.rand(batch_size, 4, 4) > 0.5)
>>> for i in range(batch_size):
...     for j in range(4):
...         A1.data[i][j][j] = 0  # discard self-loop edges
>>> A2 = jt.bmm(jt.bmm(X_gt.transpose(1, 2), A1), X_gt)
>>> feat1 = jt.rand(batch_size, 4, 1024) - 0.5
>>> feat2 = jt.bmm(X_gt.transpose(1, 2), feat1)
>>> n1 = n2 = jt.Var([4] * batch_size)

# Match by PCA-GM (load pretrained model)
>>> X, net = pygm.pca_gm(feat1, feat2, A1, A2, n1, n2, return_network=True)
Downloading to ~/.cache/pygmtools/pca_gm_voc_jittor.pt...

# Pass the net object to avoid rebuilding the model again
>>> X = pygm.pca_gm(feat1, feat2, A1, A2, n1, n2, network=net)

# You may also load other pretrained weights
>>> X, net = pygm.pca_gm(feat1, feat2, A1, A2, n1, n2, return_network=True, pretrain='willow')
Downloading to ~/.cache/pygmtools/pca_gm_willow_jittor.pt...
>>> (pygm.hungarian(X) * X_gt).sum() / X_gt.sum() # accuracy
jt.Var([1.], dtype=float32)

# You may configure your own model and integrate the model into a deep learning pipeline. For example:
>>> net = pygm.utils.get_network(pygm.pca_gm, in_channel=1024, hidden_channel=2048, out_channel=512, num_layers=3, pretrain=False)
>>> optimizer = jt.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# feat1/feat2 may be outputs of other neural networks
>>> X = pygm.pca_gm(feat1, feat2, A1, A2, n1, n2, network=net)
>>> loss = pygm.utils.permutation_loss(X, X_gt)
>>> optimizer.backward(loss)
>>> optimizer.step()

Paddle Example
>>> import paddle
>>> import pygmtools as pygm
>>> pygm.set_backend('paddle')
>>> _ = paddle.seed(4)

# Generate a batch of isomorphic graphs
>>> batch_size = 10
>>> X_gt = paddle.zeros((batch_size, 4, 4))
>>> X_gt[:, paddle.arange(0, 4, dtype=paddle.int64), paddle.randperm(4)] = 1
>>> A1 = 1. * (paddle.rand((batch_size, 4, 4)) > 0.5)
>>> A1 = A1 * (1 - paddle.eye(4)) # discard self-loop edges
>>> A2 = paddle.bmm(paddle.bmm(X_gt.transpose((0, 2, 1)), A1), X_gt)
>>> feat1 = paddle.rand((batch_size, 4, 1024)) - 0.5
>>> feat2 = paddle.bmm(X_gt.transpose((0, 2, 1)), feat1)
>>> n1 = n2 = paddle.to_tensor([4] * batch_size)

# Match by PCA-GM (load pretrained model)
>>> X, net = pygm.pca_gm(feat1, feat2, A1, A2, n1, n2, return_network=True)
Downloading to ~/.cache/pygmtools/pca_gm_voc_paddle.pdparams...
>>> (pygm.hungarian(X) * X_gt).sum() / X_gt.sum() # accuracy
Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
    [1.])

# Pass the net object to avoid rebuilding the model again
>>> X = pygm.pca_gm(feat1, feat2, A1, A2, n1, n2, network=net)

# You may also load other pretrained weights
>>> X, net = pygm.pca_gm(feat1, feat2, A1, A2, n1, n2, return_network=True, pretrain='willow')
Downloading to ~/.cache/pygmtools/pca_gm_willow_paddle.pdparams...

# You may configure your own model and integrate the model into a deep learning pipeline. For example:
>>> net = pygm.utils.get_network(pygm.pca_gm, in_channel=1024, hidden_channel=2048, out_channel=512, num_layers=3, pretrain=False)
>>> optimizer = paddle.optimizer.SGD(parameters=net.parameters(), learning_rate=0.001)
# feat1/feat2 may be outputs of other neural networks
>>> X = pygm.pca_gm(feat1, feat2, A1, A2, n1, n2, network=net)
>>> loss = pygm.utils.permutation_loss(X, X_gt)
>>> loss.backward()
>>> optimizer.step()

Note

If you find this model useful in your research, please cite:

@article{WangPAMI20,
  author = {Wang, Runzhong and Yan, Junchi and Yang, Xiaokang},
  title = {Combinatorial Learning of Robust Deep Graph Matching: an Embedding based Approach},
  journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence},
  year = {2020}
}