pygmtools.utils.permutation_loss

pygmtools.utils.permutation_loss(pred_dsmat, gt_perm, n1=None, n2=None, backend=None)[source]

Binary cross entropy loss between a predicted doubly-stochastic matrix and the ground-truth permutation matrix, also known as “permutation loss”. Proposed by “Wang et al. Learning Combinatorial Embedding Networks for Deep Graph Matching. ICCV 2019.”

\[L_{perm} = - \sum_{i \in \mathcal{V}_1, j \in \mathcal{V}_2} \left(\mathbf{X}^{gt}_{i,j} \log \mathbf{S}_{i,j} + (1-\mathbf{X}^{gt}_{i,j}) \log (1-\mathbf{S}_{i,j}) \right)\]

where \(\mathcal{V}_1, \mathcal{V}_2\) are vertex sets for two graphs.
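
To make the formula concrete, the following is a minimal reference sketch (not the library's internal implementation) that evaluates the sum above for a single pair of graphs using PyTorch; the small eps added inside the logarithms is an assumption for illustration, to guard against \(\log 0\):

>>> import torch
>>> def perm_loss_reference(S, X_gt, eps=1e-8):
...     # BCE between a doubly-stochastic S and a permutation X_gt, both (n1, n2);
...     # eps is an illustrative guard against log(0), not part of the formula
...     return -(X_gt * torch.log(S + eps) + (1 - X_gt) * torch.log(1 - S + eps)).sum()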

Parameters
  • pred_dsmat – \((b\times n_1 \times n_2)\) predicted doubly-stochastic matrix \((\mathbf{S})\)

  • gt_perm – \((b\times n_1 \times n_2)\) ground truth permutation matrix \((\mathbf{X}^{gt})\)

  • n1 – (optional) \((b)\) the exact number of nodes of each instance in the first graph.

  • n2 – (optional) \((b)\) the exact number of nodes of each instance in the second graph.

  • backend – (default: pygmtools.BACKEND variable) the backend for computation.

Returns

\((1)\) averaged permutation loss
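
A hedged usage sketch, assuming the PyTorch backend (other supported backends work the same way); shapes follow the parameter list above:

>>> import torch
>>> import pygmtools as pygm
>>> pygm.BACKEND = 'pytorch'
>>> b, n = 4, 5
>>> X_gt = torch.stack([torch.eye(n)[torch.randperm(n)] for _ in range(b)])  # (b, n, n) ground-truth permutations
>>> S = pygm.sinkhorn(torch.rand(b, n, n))  # a doubly-stochastic prediction
>>> loss = pygm.utils.permutation_loss(S, X_gt)  # scalar, averaged over the batch

In training, the prediction typically comes from a differentiable solver such as Sinkhorn, so the loss can be backpropagated to the network that produced the matching scores.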

Note

We support batched instances with different numbers of nodes; therefore, n1 and n2 are required if you want to specify the exact number of nodes of each instance in the batch.
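
For example (a sketch under the same PyTorch-backend assumption): each instance below is padded to \(5 \times 5\), and n1/n2 mark the valid size of every instance so that padded entries do not contribute to the loss:

>>> import torch
>>> import pygmtools as pygm
>>> pygm.BACKEND = 'pytorch'
>>> n1 = n2 = torch.tensor([3, 5])  # the first instance has 3 nodes, the second has 5
>>> S = pygm.sinkhorn(torch.rand(2, 5, 5), n1=n1, n2=n2)  # padded doubly-stochastic predictions
>>> X_gt = torch.eye(5).unsqueeze(0).repeat(2, 1, 1)  # identity matching as a dummy ground truth
>>> loss = pygm.utils.permutation_loss(S, X_gt, n1=n1, n2=n2)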

Note

For batched input, this loss function computes the loss averaged over all instances in the batch. This function also supports non-batched input if the batch dimension (\(b\)) is omitted.
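
Accordingly, a non-batched call is simply made with 2-D inputs of shape \((n_1 \times n_2)\) (same assumptions as the sketches above):

>>> import torch
>>> import pygmtools as pygm
>>> pygm.BACKEND = 'pytorch'
>>> S = pygm.sinkhorn(torch.rand(4, 4))  # a single doubly-stochastic matrix
>>> X_gt = torch.eye(4)  # a single ground-truth permutation
>>> loss = pygm.utils.permutation_loss(S, X_gt)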