pygmtools.utils.permutation_loss
- pygmtools.utils.permutation_loss(pred_dsmat, gt_perm, n1=None, n2=None, backend=None)
Binary cross entropy loss between two permutations, also known as “permutation loss”. Proposed by “Wang et al. Learning Combinatorial Embedding Networks for Deep Graph Matching. ICCV 2019.”
\[L_{perm} = - \sum_{i \in \mathcal{V}_1, j \in \mathcal{V}_2} \left(\mathbf{X}^{gt}_{i,j} \log \mathbf{S}_{i,j} + (1-\mathbf{X}^{gt}_{i,j}) \log (1-\mathbf{S}_{i,j}) \right)\]
where \(\mathcal{V}_1, \mathcal{V}_2\) are the vertex sets of the two graphs.
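The formula can be illustrated for a single instance with a short NumPy sketch. It is not the pygmtools implementation: the helper name `permutation_loss_single` and the clipping constant `eps` (used only to avoid \(\log 0\)) are assumptions of this example.

```python
import numpy as np

def permutation_loss_single(S, X_gt, eps=1e-12):
    """Summed binary cross entropy between a predicted matrix S and a 0/1 matrix X_gt.

    Both inputs have shape (n1, n2). S is clipped to [eps, 1 - eps] so that
    log(0) never occurs; the clipping constant is an illustrative choice.
    """
    S = np.clip(np.asarray(S, dtype=np.float64), eps, 1.0 - eps)
    X_gt = np.asarray(X_gt, dtype=np.float64)
    # Element-wise BCE term, summed over all (i, j) pairs in V1 x V2.
    return float(-np.sum(X_gt * np.log(S) + (1.0 - X_gt) * np.log(1.0 - S)))

# A uniform 2x2 prediction against the identity: four terms of -log(0.5).
X = np.eye(2)
S = np.full((2, 2), 0.5)
print(permutation_loss_single(S, X))  # ≈ 2.7726, i.e. 4 * log(2)
```

Every entry contributes, including the zeros of \(\mathbf{X}^{gt}\), which is what distinguishes this binary cross entropy from a plain negative log-likelihood over the matched pairs only.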
- Parameters
pred_dsmat – \((b\times n_1 \times n_2)\) predicted doubly-stochastic matrix \((\mathbf{S})\)
gt_perm – \((b\times n_1 \times n_2)\) ground truth permutation matrix \((\mathbf{X}^{gt})\)
n1 – (optional) \((b)\) exact number of nodes in the first graph of each instance.
n2 – (optional) \((b)\) exact number of nodes in the second graph of each instance.
backend – (default: pygmtools.BACKEND variable) the backend for computation.
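Inputs with the shapes listed above can be built as in the following NumPy sketch. The helper `perm_to_matrix` is hypothetical, and the "prediction" is simply a convex mix of the ground truth and the uniform matrix \(1/n\) (both have unit row and column sums, so the mix is doubly stochastic when \(n_1 = n_2\)); in practice `pred_dsmat` would come from a model.

```python
import numpy as np

def perm_to_matrix(assignment, n2):
    """Hypothetical helper: turn an assignment vector (node i of graph 1 is
    matched to node assignment[i] of graph 2) into an (n1, n2) 0/1 matrix."""
    assignment = np.asarray(assignment)
    X = np.zeros((len(assignment), n2))
    X[np.arange(len(assignment)), assignment] = 1.0
    return X

# Batch of b = 2 instances with n1 = n2 = 3.
gt_perm = np.stack([perm_to_matrix([2, 0, 1], 3), perm_to_matrix([0, 1, 2], 3)])

# Doubly-stochastic stand-in for a prediction: 0.8 * permutation + 0.2 * uniform.
n = 3
pred_dsmat = 0.8 * gt_perm + 0.2 * np.full_like(gt_perm, 1.0 / n)

n1 = np.array([3, 3])  # exact node counts per instance (no padding here)
n2 = np.array([3, 3])

print(gt_perm.shape, pred_dsmat.shape)  # (2, 3, 3) (2, 3, 3)
```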
- Returns
\((1)\) averaged permutation loss
Note
We support batched instances with different numbers of nodes; therefore, n1 and n2 are required if you want to specify the exact number of nodes of each instance in the batch.
Note
For batched input, this loss function computes the average loss over all instances in the batch. This function also supports non-batched input, in which case the batch dimension (\(b\)) is omitted.
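The two notes (per-instance node counts via n1/n2, and non-batched input) can be illustrated with the NumPy sketch below. It is not the pygmtools implementation: the function name `permutation_loss_sketch`, the `eps` clipping, and averaging the summed loss over instances (rather than, say, over nodes) are assumptions; the library's exact normalization may differ.

```python
import numpy as np

def permutation_loss_sketch(pred_dsmat, gt_perm, n1=None, n2=None, eps=1e-12):
    """Sketch of a batched permutation loss (not the pygmtools implementation).

    Accepts (b, N1, N2) or (N1, N2) arrays. For instance k only the top-left
    n1[k] x n2[k] block is used, so padded rows/columns are ignored. Returns the
    mean over instances of the summed BCE (the averaging rule is an assumption).
    """
    S = np.asarray(pred_dsmat, dtype=np.float64)
    X = np.asarray(gt_perm, dtype=np.float64)
    if S.ndim == 2:  # non-batched input: treat it as a batch of one
        S, X = S[None], X[None]
    b, N1, N2 = S.shape
    n1 = np.full(b, N1) if n1 is None else np.atleast_1d(n1)
    n2 = np.full(b, N2) if n2 is None else np.atleast_1d(n2)
    losses = []
    for k in range(b):
        s = np.clip(S[k, :n1[k], :n2[k]], eps, 1.0 - eps)  # avoid log(0)
        x = X[k, :n1[k], :n2[k]]
        losses.append(-np.sum(x * np.log(s) + (1.0 - x) * np.log(1.0 - s)))
    return float(np.mean(losses))

# Instance 0 has 3 nodes; instance 1 has only 2 real nodes, zero-padded to 3.
X = np.zeros((2, 3, 3)); X[0] = np.eye(3); X[1, :2, :2] = np.eye(2)
S = X.copy()
S[1, 2, :] = 0.5  # arbitrary values in the padded row of instance 1

print(permutation_loss_sketch(S, X, n1=[3, 2], n2=[3, 2]))  # close to 0: padding ignored
print(permutation_loss_sketch(S, X))  # larger: the padded row is now counted
```

Without n1 and n2, the sketch falls back to the full padded size, which is why values left in padded rows or columns leak into the loss.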