# Paddle Backend Example: Model Fusion by Graph Matching

This example shows how to fuse different models into a single model with `pygmtools`. Model fusion merges multiple models into one so that the fused model can outperform each of its inputs. A neural network can be regarded as a graph: channels are nodes and the update functions between channels are edges, with biases as node features and weights as edge features. Fusing models is therefore equivalent to solving a graph matching problem. In this example, the two given models are trained on MNIST data drawn from different distributions, and the fused model combines the knowledge of both input models and can reach higher test accuracy.
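To see why fusion needs a matching step at all, the toy sketch below uses plain Python lists (no Paddle) and a tiny bias-free two-layer network. It is not the procedure used later in this example; the helper names (`forward`, `permute_hidden`, `average`) and the weights are hypothetical and chosen only for illustration. Reordering the hidden channels leaves the network's function unchanged, yet averaging the weights of two differently ordered copies breaks it unless the channels are aligned first.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def forward(w1, w2, x):
    """Bias-free two-layer network: relu(x @ w1) @ w2."""
    hidden = [[max(0.0, v) for v in row] for row in matmul(x, w1)]
    return matmul(hidden, w2)

def permute_hidden(w1, w2, perm):
    """Reorder hidden channels: columns of w1 and rows of w2 move together."""
    return [[row[p] for p in perm] for row in w1], [w2[p] for p in perm]

def average(a, b):
    """Element-wise mean of two equally shaped matrices."""
    return [[(x + y) / 2 for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

# Hypothetical weights: 2 inputs -> 3 hidden channels -> 1 output.
w1 = [[1.0, -2.0, 0.5], [0.0, 1.0, -1.0]]
w2 = [[1.0], [2.0], [-3.0]]
x = [[1.0, 2.0], [-1.0, 0.5]]
reference = forward(w1, w2, x)

# The same network with its hidden channels stored in a different order.
perm = [2, 0, 1]
p1, p2 = permute_hidden(w1, w2, perm)
permuted_out = forward(p1, p2, x)

# Naive fusion: average the weights position by position.
naive_out = forward(average(w1, p1), average(w2, p2), x)

# Aligned fusion: undo the permutation (the "matching") before averaging.
inverse = [0] * len(perm)
for position, channel in enumerate(perm):
    inverse[channel] = position
a1, a2 = permute_hidden(p1, p2, inverse)
aligned_out = forward(average(w1, a1), average(w2, a2), x)

print(reference, permuted_out, naive_out, aligned_out)
```

Here the correct alignment is known in advance. For independently trained models it is unknown, which is where a graph matching solver comes in.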

```
# Author: Chang Liu <only-changer@sjtu.edu.cn>
#         Runzhong Wang <runzhong.wang@sjtu.edu.cn>
#         Wenzheng Pan <pwz1121@sjtu.edu.cn>
```

Note

This is a simplified implementation of the ideas in Liu et al., "Deep Neural Network Fusion via Graph Matching with Applications to Model Ensemble and Federated Learning", ICML 2022. For more details, please refer to the paper and the official code repository.

Note

The following solvers are included in this example:

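```
import paddle
import paddle.nn as nn  # used by SimpleNet below
import paddle.nn.functional as F  # used by SimpleNet.forward
from paddle.vision import transforms  # used when loading MNIST
import time
from PIL import Image
import matplotlib.pyplot as plt
import pygmtools as pygm
import warnings
warnings.filterwarnings("ignore")

pygm.set_backend('paddle')  # select the Paddle backend of pygmtools
device = paddle.set_device('cpu')  # assumed from the printed "Place(cpu)"
print(device)
```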
```Place(cpu)
```

## Define a simple CNN classifier network

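```
class SimpleNet(nn.Layer):
    def __init__(self):
        super(SimpleNet, self).__init__()
        # The hyperparameters of conv1, conv2 and max_pool are assumptions: they are
        # chosen so that the flattened feature size is 64 * 7 * 7 = 3136, as fc1 expects.
        self.conv1 = nn.Conv2D(1, 32, 5, padding=1, padding_mode='replicate', bias_attr=False)
        self.max_pool = nn.MaxPool2D(2, padding=1)
        self.conv2 = nn.Conv2D(32, 64, 5, padding=1, padding_mode='replicate', bias_attr=False)
        self.fc1 = nn.Linear(3136, 32, bias_attr=False)
        self.fc2 = nn.Linear(32, 10, bias_attr=False)

    def forward(self, x):
        output = F.relu(self.conv1(x))
        output = self.max_pool(output)
        output = F.relu(self.conv2(output))
        output = self.max_pool(output)
        # Flatten everything except the batch dimension
        output = output.reshape((output.shape[0], -1))
        output = self.fc1(output)
        output = self.fc2(output)
        return output
```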
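The input size of `fc1` (3136) can be sanity-checked with the usual output-size formula for convolution and pooling layers. The sketch below assumes 28 x 28 MNIST inputs, 5 x 5 convolutions with padding 1, 2 x 2 max pooling with stride 2 and padding 1, and 64 output channels in the second convolution. These hyperparameters are assumptions for illustration, and `out_size` is a hypothetical helper.

```python
def out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28                                              # MNIST images are 28 x 28
size = out_size(size, kernel=5, padding=1)             # after conv1 (assumed 5x5, padding 1)
size = out_size(size, kernel=2, stride=2, padding=1)   # after the first max_pool
size = out_size(size, kernel=5, padding=1)             # after conv2 (assumed 5x5, padding 1)
size = out_size(size, kernel=2, stride=2, padding=1)   # after the second max_pool
flattened = 64 * size * size                           # 64 channels assumed for conv2
print(size, flattened)
```

Under these assumptions the final feature map is 7 x 7, so the flattened vector has 64 * 7 * 7 = 3136 entries, matching `nn.Linear(3136, 32)`.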

## Load the trained models to be fused

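```
model1 = SimpleNet()
model2 = SimpleNet()
# The trained weights of both models must be loaded before fusion, e.g. with
# model.set_state_dict(paddle.load(path)); the checkpoint paths are not given here.
model1.to(device)
model2.to(device)
# Note: the directory used to store the dataset cannot be modified.
test_dataset = paddle.vision.datasets.MNIST(
    mode='test',  # the dataset is used to test
    transform=transforms.ToTensor())  # the dataset is in the form of tensors
test_loader = paddle.io.DataLoader(
    dataset=test_dataset,
    batch_size=32,
    shuffle=False)
```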
```Cache file /home/wzever/.cache/paddle/dataset/mnist/t10k-images-idx3-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/t10k-images-idx3-ubyte.gz

item 229/403 [================>.............] - ETA: 0s - 5ms/item
item 230/403 [================>.............] - ETA: 0s - 5ms/item
item 231/403 [================>.............] - ETA: 0s - 5ms/item
item 232/403 [================>.............] - ETA: 0s - 5ms/item
item 233/403 [================>.............] - ETA: 0s - 5ms/item
item 234/403 [================>.............] - ETA: 0s - 5ms/item
item 235/403 [================>.............] - ETA: 0s - 5ms/item
item 236/403 [================>.............] - ETA: 0s - 5ms/item
item 237/403 [================>.............] - ETA: 0s - 5ms/item
item 238/403 [================>.............] - ETA: 0s - 5ms/item
item 239/403 [================>.............] - ETA: 0s - 5ms/item
item 240/403 [================>.............] - ETA: 0s - 5ms/item
item 240/403 [================>.............] - ETA: 0s - 5ms/item
item 241/403 [================>.............] - ETA: 0s - 5ms/item
item 242/403 [================>.............] - ETA: 0s - 5ms/item
item 243/403 [=================>............] - ETA: 0s - 5ms/item
item 244/403 [=================>............] - ETA: 0s - 5ms/item
item 245/403 [=================>............] - ETA: 0s - 5ms/item
item 246/403 [=================>............] - ETA: 0s - 5ms/item
item 247/403 [=================>............] - ETA: 0s - 5ms/item
item 248/403 [=================>............] - ETA: 0s - 5ms/item
item 249/403 [=================>............] - ETA: 0s - 5ms/item
item 250/403 [=================>............] - ETA: 0s - 5ms/item
item 251/403 [=================>............] - ETA: 0s - 5ms/item
item 252/403 [=================>............] - ETA: 0s - 5ms/item
item 253/403 [=================>............] - ETA: 0s - 5ms/item
item 254/403 [=================>............] - ETA: 0s - 5ms/item
item 255/403 [=================>............] - ETA: 0s - 5ms/item
item 256/403 [==================>...........] - ETA: 0s - 5ms/item
item 257/403 [==================>...........] - ETA: 0s - 5ms/item
item 258/403 [==================>...........] - ETA: 0s - 5ms/item
item 259/403 [==================>...........] - ETA: 0s - 5ms/item
item 260/403 [==================>...........] - ETA: 0s - 5ms/item
item 260/403 [==================>...........] - ETA: 0s - 5ms/item
item 261/403 [==================>...........] - ETA: 0s - 5ms/item
item 262/403 [==================>...........] - ETA: 0s - 5ms/item
item 263/403 [==================>...........] - ETA: 0s - 5ms/item
item 264/403 [==================>...........] - ETA: 0s - 5ms/item
item 265/403 [==================>...........] - ETA: 0s - 5ms/item
item 266/403 [==================>...........] - ETA: 0s - 5ms/item
item 267/403 [==================>...........] - ETA: 0s - 5ms/item
item 268/403 [==================>...........] - ETA: 0s - 5ms/item
item 269/403 [==================>...........] - ETA: 0s - 5ms/item
item 270/403 [===================>..........] - ETA: 0s - 5ms/item
item 271/403 [===================>..........] - ETA: 0s - 5ms/item
item 272/403 [===================>..........] - ETA: 0s - 4ms/item
item 273/403 [===================>..........] - ETA: 0s - 5ms/item
item 274/403 [===================>..........] - ETA: 0s - 5ms/item
item 275/403 [===================>..........] - ETA: 0s - 5ms/item
item 276/403 [===================>..........] - ETA: 0s - 5ms/item
item 277/403 [===================>..........] - ETA: 0s - 5ms/item
item 278/403 [===================>..........] - ETA: 0s - 5ms/item
item 279/403 [===================>..........] - ETA: 0s - 4ms/item
item 280/403 [===================>..........] - ETA: 0s - 4ms/item
item 280/403 [===================>..........] - ETA: 0s - 4ms/item
item 281/403 [===================>..........] - ETA: 0s - 4ms/item
item 282/403 [===================>..........] - ETA: 0s - 4ms/item
item 283/403 [====================>.........] - ETA: 0s - 4ms/item
item 284/403 [====================>.........] - ETA: 0s - 4ms/item
item 285/403 [====================>.........] - ETA: 0s - 4ms/item
item 286/403 [====================>.........] - ETA: 0s - 4ms/item
item 287/403 [====================>.........] - ETA: 0s - 4ms/item
item 288/403 [====================>.........] - ETA: 0s - 4ms/item
item 289/403 [====================>.........] - ETA: 0s - 4ms/item
item 290/403 [====================>.........] - ETA: 0s - 4ms/item
item 291/403 [====================>.........] - ETA: 0s - 4ms/item
item 292/403 [====================>.........] - ETA: 0s - 4ms/item
item 293/403 [====================>.........] - ETA: 0s - 4ms/item
item 294/403 [====================>.........] - ETA: 0s - 4ms/item
item 295/403 [====================>.........] - ETA: 0s - 4ms/item
item 296/403 [=====================>........] - ETA: 0s - 4ms/item
item 297/403 [=====================>........] - ETA: 0s - 4ms/item
item 298/403 [=====================>........] - ETA: 0s - 4ms/item
item 299/403 [=====================>........] - ETA: 0s - 4ms/item
item 300/403 [=====================>........] - ETA: 0s - 4ms/item
item 300/403 [=====================>........] - ETA: 0s - 4ms/item
item 301/403 [=====================>........] - ETA: 0s - 4ms/item
item 302/403 [=====================>........] - ETA: 0s - 4ms/item
item 303/403 [=====================>........] - ETA: 0s - 4ms/item
item 304/403 [=====================>........] - ETA: 0s - 4ms/item
item 305/403 [=====================>........] - ETA: 0s - 4ms/item
item 306/403 [=====================>........] - ETA: 0s - 4ms/item
item 307/403 [=====================>........] - ETA: 0s - 4ms/item
item 308/403 [=====================>........] - ETA: 0s - 4ms/item
item 309/403 [=====================>........] - ETA: 0s - 4ms/item
item 310/403 [======================>.......] - ETA: 0s - 4ms/item
item 311/403 [======================>.......] - ETA: 0s - 4ms/item
item 312/403 [======================>.......] - ETA: 0s - 4ms/item
item 313/403 [======================>.......] - ETA: 0s - 4ms/item
item 314/403 [======================>.......] - ETA: 0s - 4ms/item
item 315/403 [======================>.......] - ETA: 0s - 4ms/item
item 316/403 [======================>.......] - ETA: 0s - 4ms/item
item 317/403 [======================>.......] - ETA: 0s - 4ms/item
item 318/403 [======================>.......] - ETA: 0s - 4ms/item
item 319/403 [======================>.......] - ETA: 0s - 4ms/item
item 320/403 [======================>.......] - ETA: 0s - 4ms/item
item 320/403 [======================>.......] - ETA: 0s - 4ms/item
item 321/403 [======================>.......] - ETA: 0s - 4ms/item
item 322/403 [======================>.......] - ETA: 0s - 4ms/item
item 323/403 [=======================>......] - ETA: 0s - 4ms/item
item 324/403 [=======================>......] - ETA: 0s - 4ms/item
item 325/403 [=======================>......] - ETA: 0s - 4ms/item
item 326/403 [=======================>......] - ETA: 0s - 4ms/item
item 327/403 [=======================>......] - ETA: 0s - 4ms/item
item 328/403 [=======================>......] - ETA: 0s - 4ms/item
item 329/403 [=======================>......] - ETA: 0s - 4ms/item
item 330/403 [=======================>......] - ETA: 0s - 4ms/item
item 331/403 [=======================>......] - ETA: 0s - 4ms/item
item 332/403 [=======================>......] - ETA: 0s - 4ms/item
item 333/403 [=======================>......] - ETA: 0s - 4ms/item
item 334/403 [=======================>......] - ETA: 0s - 4ms/item
item 335/403 [=======================>......] - ETA: 0s - 4ms/item
item 336/403 [=======================>......] - ETA: 0s - 4ms/item
item 337/403 [========================>.....] - ETA: 0s - 4ms/item
item 338/403 [========================>.....] - ETA: 0s - 4ms/item
item 339/403 [========================>.....] - ETA: 0s - 4ms/item
item 340/403 [========================>.....] - ETA: 0s - 4ms/item
item 340/403 [========================>.....] - ETA: 0s - 4ms/item
item 341/403 [========================>.....] - ETA: 0s - 4ms/item
item 342/403 [========================>.....] - ETA: 0s - 4ms/item
item 343/403 [========================>.....] - ETA: 0s - 4ms/item
item 344/403 [========================>.....] - ETA: 0s - 4ms/item
item 345/403 [========================>.....] - ETA: 0s - 4ms/item
item 346/403 [========================>.....] - ETA: 0s - 4ms/item
item 347/403 [========================>.....] - ETA: 0s - 4ms/item
item 348/403 [========================>.....] - ETA: 0s - 4ms/item
item 349/403 [========================>.....] - ETA: 0s - 4ms/item
item 350/403 [=========================>....] - ETA: 0s - 4ms/item
item 351/403 [=========================>....] - ETA: 0s - 4ms/item
item 352/403 [=========================>....] - ETA: 0s - 4ms/item
item 353/403 [=========================>....] - ETA: 0s - 4ms/item
item 354/403 [=========================>....] - ETA: 0s - 4ms/item
item 355/403 [=========================>....] - ETA: 0s - 4ms/item
item 356/403 [=========================>....] - ETA: 0s - 4ms/item
item 357/403 [=========================>....] - ETA: 0s - 4ms/item
item 358/403 [=========================>....] - ETA: 0s - 4ms/item
item 359/403 [=========================>....] - ETA: 0s - 4ms/item
item 360/403 [=========================>....] - ETA: 0s - 4ms/item
item 360/403 [=========================>....] - ETA: 0s - 4ms/item
item 361/403 [=========================>....] - ETA: 0s - 4ms/item
item 362/403 [=========================>....] - ETA: 0s - 4ms/item
item 363/403 [=========================>....] - ETA: 0s - 4ms/item
item 364/403 [==========================>...] - ETA: 0s - 4ms/item
item 365/403 [==========================>...] - ETA: 0s - 4ms/item
item 366/403 [==========================>...] - ETA: 0s - 4ms/item
item 367/403 [==========================>...] - ETA: 0s - 4ms/item
item 368/403 [==========================>...] - ETA: 0s - 4ms/item
item 369/403 [==========================>...] - ETA: 0s - 4ms/item
item 370/403 [==========================>...] - ETA: 0s - 4ms/item
item 371/403 [==========================>...] - ETA: 0s - 4ms/item
item 372/403 [==========================>...] - ETA: 0s - 4ms/item
item 373/403 [==========================>...] - ETA: 0s - 4ms/item
item 374/403 [==========================>...] - ETA: 0s - 4ms/item
item 375/403 [==========================>...] - ETA: 0s - 4ms/item
item 376/403 [==========================>...] - ETA: 0s - 4ms/item
item 377/403 [===========================>..] - ETA: 0s - 4ms/item
item 378/403 [===========================>..] - ETA: 0s - 4ms/item
item 379/403 [===========================>..] - ETA: 0s - 4ms/item
item 380/403 [===========================>..] - ETA: 0s - 4ms/item
item 380/403 [===========================>..] - ETA: 0s - 4ms/item
item 381/403 [===========================>..] - ETA: 0s - 4ms/item
item 382/403 [===========================>..] - ETA: 0s - 4ms/item
item 383/403 [===========================>..] - ETA: 0s - 4ms/item
item 384/403 [===========================>..] - ETA: 0s - 4ms/item
item 385/403 [===========================>..] - ETA: 0s - 4ms/item
item 386/403 [===========================>..] - ETA: 0s - 4ms/item
item 387/403 [===========================>..] - ETA: 0s - 4ms/item
item 388/403 [===========================>..] - ETA: 0s - 4ms/item
item 389/403 [===========================>..] - ETA: 0s - 4ms/item
item 390/403 [===========================>..] - ETA: 0s - 4ms/item
item 391/403 [============================>.] - ETA: 0s - 4ms/item
item 392/403 [============================>.] - ETA: 0s - 4ms/item
item 393/403 [============================>.] - ETA: 0s - 4ms/item
item 394/403 [============================>.] - ETA: 0s - 4ms/item
item 395/403 [============================>.] - ETA: 0s - 4ms/item
item 396/403 [============================>.] - ETA: 0s - 4ms/item
item 397/403 [============================>.] - ETA: 0s - 4ms/item
item 398/403 [============================>.] - ETA: 0s - 4ms/item
item 399/403 [============================>.] - ETA: 0s - 4ms/item
item 400/403 [============================>.] - ETA: 0s - 4ms/item
item 400/403 [============================>.] - ETA: 0s - 4ms/item
item 401/403 [============================>.] - ETA: 0s - 4ms/item
item 402/403 [============================>.] - ETA: 0s - 4ms/item
item 403/403 [============================>.] - ETA: 0s - 4ms/item

item 2/2 [===========================>..] - ETA: 0s - 187us/item
```

Print the layers of the simple CNN model:

```python
print(model1)
```
```
SimpleNet(
  (fc1): Linear(in_features=3136, out_features=32, dtype=None)
  (fc2): Linear(in_features=32, out_features=10, dtype=None)
)
```

## Test the input models

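```python
with paddle.no_grad():
    n_correct1 = 0
    n_correct2 = 0
    n_samples = 0
    # assumption: `test_loader` is the MNIST test DataLoader built in the loading step
    for images, labels in test_loader:
        outputs1 = model1(images)
        outputs2 = model2(images)
        # predicted class = index of the largest logit
        predictions1 = paddle.argmax(outputs1, axis=1)
        predictions2 = paddle.argmax(outputs2, axis=1)
        n_samples += labels.shape[0]
        n_correct1 += (predictions1 == labels.t()).sum().item()
        n_correct2 += (predictions2 == labels.t()).sum().item()
    acc1 = 100 * n_correct1 / n_samples
    acc2 = 100 * n_correct2 / n_samples
```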

Testing results (two separate models):

```python
print(f'model1 accuracy = {acc1}%, model2 accuracy = {acc2}%')
```
```
model1 accuracy = 84.18%, model2 accuracy = 83.81%
```

## Build the affinity matrix for graph matching

As shown in the following plot, a neural network can be regarded as a graph: the weights correspond to edge features and the biases to node features. The network in this example has no bias terms, so only edge features are present.

```python
plt.figure(figsize=(8, 4))
img = Image.open('../data/model_fusion.png')
plt.imshow(img)
plt.axis('off')
st_time = time.perf_counter()
```
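
To make the graph view concrete, here is a minimal sketch that counts the nodes and edges of a small bias-free fully-connected network. The layer widths `[3, 4, 2]` and the helper name `network_as_graph` are hypothetical, chosen only for illustration; they are not the sizes of `SimpleNet`.

```python
def network_as_graph(layer_widths):
    """Return (num_nodes, edges) for a bias-free fully-connected network.

    Every unit (channel) is a node; every weight connecting unit i of one
    layer to unit j of the next layer is an edge (i, j).
    """
    num_nodes = sum(layer_widths)
    edges = []
    offset = 0  # global index of the first node in the current layer
    for width_in, width_out in zip(layer_widths[:-1], layer_widths[1:]):
        for i in range(width_in):
            for j in range(width_out):
                edges.append((offset + i, offset + width_in + j))
        offset += width_in
    return num_nodes, edges


num_nodes, edges = network_as_graph([3, 4, 2])
print(num_nodes, len(edges))  # 9 nodes, 3*4 + 4*2 = 20 edges
```

Two networks with the same architecture give two graphs with the same node count, so fusing them amounts to finding a node-to-node correspondence between the graphs.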

Define the affinity metric used for graph matching:

```python
class Ground_Metric_GM:
    def __init__(self,
                 model_1_param: paddle.Tensor = None,
                 model_2_param: paddle.Tensor = None,
                 conv_param: bool = False,
                 bias_param: bool = False,
                 pre_conv_param: bool = False,
                 pre_conv_image_size_squared: int = None):
        self.model_1_param = model_1_param
        self.model_2_param = model_2_param
        self.conv_param = conv_param
        self.bias_param = bias_param
        # bias, or fully-connected from linear
        if bias_param is True or (conv_param is False and pre_conv_param is False):
            self.model_1_param = self.model_1_param.reshape((1, -1, 1))
            self.model_2_param = self.model_2_param.reshape((1, -1, 1))
        # fully-connected from conv
        elif conv_param is False and pre_conv_param is True:
            self.model_1_param = self.model_1_param.reshape((1, -1, pre_conv_image_size_squared))
            self.model_2_param = self.model_2_param.reshape((1, -1, pre_conv_image_size_squared))
        # conv
        else:
            self.model_1_param = self.model_1_param.reshape((1, -1, model_1_param.shape[-1]))
            self.model_2_param = self.model_2_param.reshape((1, -1, model_2_param.shape[-1]))

    def process_distance(self, p: int = 2):
        # pairwise p-norm distance between every row of param_1 and every row of param_2
        dist = []
        param_1 = self.model_1_param.cast('float32')[0]
        param_2 = self.model_2_param.cast('float32')[0]
        for i in param_1:
            # assumption: the distance from one row to all rows of param_2 is a p-norm
            dist.append(paddle.norm(param_2 - i, p=float(p), axis=-1))
        return paddle.stack(dist)

    def process_soft_affinity(self, p: int = 2):
        # assumption: affinity = exp(-distance), so closer weights score higher
        return paddle.exp(0 - self.process_distance(p=p))
```
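
A common way to turn pairwise distances into affinities is to apply `exp(-d)`, so that identical weight vectors score 1 and distant ones approach 0. The sketch below assumes that form and uses plain Python lists instead of Paddle tensors; the function names `pairwise_distance` and `soft_affinity` are hypothetical and are not part of `pygmtools`.

```python
import math


def pairwise_distance(rows_1, rows_2, p=2):
    """p-norm distance between every row of rows_1 and every row of rows_2."""
    return [
        [sum(abs(x - y) ** p for x, y in zip(r1, r2)) ** (1.0 / p) for r2 in rows_2]
        for r1 in rows_1
    ]


def soft_affinity(rows_1, rows_2, p=2):
    """Map distances to (0, 1] with exp(-d); larger means more similar."""
    return [[math.exp(-d) for d in row] for row in pairwise_distance(rows_1, rows_2, p)]


weights_1 = [[1.0, 0.0], [0.0, 1.0]]
weights_2 = [[1.0, 0.0], [3.0, 4.0]]
aff = soft_affinity(weights_1, weights_2)
print(aff[0][0])  # identical vectors: distance 0, affinity 1.0
```

The exponential keeps every score positive, which suits solvers that maximize a total affinity.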

Define the affinity function between two neural networks. It takes a list of neural network modules and constructs the corresponding affinity matrix, which is then passed to the graph matching solver.
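
The affinity matrix has one row and one column for every candidate node pair (node `i` of the first graph matched to node `a` of the second), which is why its size is `n1 * n2` by `n1 * n2`. The sketch below illustrates the flattening `i * n2 + a`, which the index expressions in the function appear to follow: diagonal entries hold node-to-node affinities and off-diagonal entries hold edge-to-edge affinities. The helper `pair_index` and the toy values are hypothetical.

```python
def pair_index(i, a, n2):
    """Flatten the candidate match (node i in graph 1, node a in graph 2)."""
    return i * n2 + a


n1 = n2 = 2  # two toy graphs with two nodes each
K = [[0.0] * (n1 * n2) for _ in range(n1 * n2)]

# Node affinity: matching node 0 to node 0 is scored on the diagonal.
K[pair_index(0, 0, n2)][pair_index(0, 0, n2)] = 1.0

# Edge affinity: edge (0 -> 1) in graph 1 against edge (0 -> 1) in graph 2.
# The entry couples the matches (0 -> 0) and (1 -> 1).
K[pair_index(0, 0, n2)][pair_index(1, 1, n2)] = 0.8

print(pair_index(1, 1, n2))  # 3
print(K[0][3])               # 0.8
```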

```python
def graph_matching_fusion(networks: list):
    def total_node_num(network: nn.Layer):
        # count the total number of nodes in the network [network]
        num_nodes = 0
        for idx, (name, parameters) in enumerate(network.named_parameters()):
            if 'bias' in name:
                continue
            if idx == 0:
                num_nodes += parameters.shape[1]
            # Paddle stores Linear weights as (in_features, out_features),
            # so the output size of an fc layer is shape[1]
            num_nodes += parameters.shape[0] if 'fc' not in name else parameters.shape[1]
        return num_nodes

    n1 = total_node_num(network=networks[0])
    n2 = total_node_num(network=networks[1])
    assert (n1 == n2)
    affinity = paddle.zeros([n1 * n2, n1 * n2])
    num_layers = len(list(zip(networks[0].parameters(), networks[1].parameters())))
    num_nodes_before = 0
    num_nodes_incremental = []
    num_nodes_layers = []
    pre_conv_list = []
    cur_conv_list = []
    conv_kernel_size_list = []
    num_nodes_pre = 0
    is_conv = False
    pre_conv = False
    pre_conv_out_channel = 1
    is_final_bias = False
    perm_is_complete = True
    named_weight_list_0 = [named_parameter for named_parameter in networks[0].named_parameters()]
    for idx, ((name_0, fc_layer0_weight), (name_1, fc_layer1_weight)) in \
            enumerate(zip(networks[0].named_parameters(), networks[1].named_parameters())):
        assert fc_layer0_weight.shape == fc_layer1_weight.shape
        # transpose Paddle Linear weights to the conventional (out, in) shape
        if 'fc' in name_0:
            fc_layer0_weight = fc_layer0_weight.t()
            fc_layer1_weight = fc_layer1_weight.t()
        layer_shape = fc_layer0_weight.shape
        num_nodes_cur = fc_layer0_weight.shape[0]
        if len(layer_shape) > 1:
            if is_conv is True and len(layer_shape) == 2:
                num_nodes_pre = pre_conv_out_channel
            else:
                num_nodes_pre = fc_layer0_weight.shape[1]
        if idx >= 1 and len(named_weight_list_0[idx - 1][1].shape) == 1:
            pre_bias = True
        else:
            pre_bias = False
        if len(layer_shape) > 2:
            is_bias = False
            if not pre_bias:
                pre_conv = is_conv
                pre_conv_list.append(pre_conv)
            is_conv = True
            cur_conv_list.append(is_conv)
            fc_layer0_weight_data = fc_layer0_weight.detach().reshape(
                (fc_layer0_weight.shape[0], fc_layer0_weight.shape[1], -1))
            fc_layer1_weight_data = fc_layer1_weight.detach().reshape(
                (fc_layer1_weight.shape[0], fc_layer1_weight.shape[1], -1))
        elif len(layer_shape) == 2:
            is_bias = False
            if not pre_bias:
                pre_conv = is_conv
                pre_conv_list.append(pre_conv)
            is_conv = False
            cur_conv_list.append(is_conv)
            fc_layer0_weight_data = fc_layer0_weight.detach()
            fc_layer1_weight_data = fc_layer1_weight.detach()
        else:
            is_bias = True
            if not pre_bias:
                pre_conv = is_conv
                pre_conv_list.append(pre_conv)
            is_conv = False
            cur_conv_list.append(is_conv)
            fc_layer0_weight_data = fc_layer0_weight.detach()
            fc_layer1_weight_data = fc_layer1_weight.detach()
        if is_conv:
            pre_conv_out_channel = num_nodes_cur
        if is_bias is True and idx == num_layers - 1:
            is_final_bias = True
        if idx == 0:
            for a in range(num_nodes_pre):
                affinity[(num_nodes_before + a) * n2 + num_nodes_before + a, \
                         (num_nodes_before + a) * n2 + num_nodes_before + a] \
                    = 1
        if idx == num_layers - 2 and 'bias' in named_weight_list_0[idx + 1][0] or \
                idx == num_layers - 1 and 'bias' not in named_weight_list_0[idx][0]:
            for a in range(num_nodes_cur):
                affinity[(num_nodes_before + num_nodes_pre + a) * n2 + num_nodes_before + num_nodes_pre + a, \
                         (num_nodes_before + num_nodes_pre + a) * n2 + num_nodes_before + num_nodes_pre + a] \
                    = 1
        if is_bias is False:
            ground_metric = Ground_Metric_GM(
                fc_layer0_weight_data, fc_layer1_weight_data, is_conv, is_bias,
                pre_conv, int(fc_layer0_weight_data.shape[1] / pre_conv_out_channel))
        else:
            ground_metric = Ground_Metric_GM(
                fc_layer0_weight_data, fc_layer1_weight_data, is_conv, is_bias,
                pre_conv, 1)

        layer_affinity = ground_metric.process_soft_affinity(p=2)

        if is_bias is False:
            pre_conv_kernel_size = fc_layer0_weight.shape[3] if is_conv else None
            conv_kernel_size_list.append(pre_conv_kernel_size)
        if is_bias is True and is_final_bias is False:
            for a in range(num_nodes_cur):
                for c in range(num_nodes_cur):
                    affinity[(num_nodes_before + a) * n2 + num_nodes_before + c, \
                             (num_nodes_before + a) * n2 + num_nodes_before + c] \
                        = layer_affinity[a][c]
        elif is_final_bias is False:
            for a in range(num_nodes_pre):
                for b in range(num_nodes_cur):
                    affinity[
                        (num_nodes_before + a) * n2 + num_nodes_before:
                        (num_nodes_before + a) * n2 + num_nodes_before + num_nodes_pre,
                        (num_nodes_before + num_nodes_pre + b) * n2 + num_nodes_before + num_nodes_pre:
                        (num_nodes_before + num_nodes_pre + b) * n2 + num_nodes_before + num_nodes_pre + num_nodes_cur] \
                        = layer_affinity[a + b * num_nodes_pre].reshape((num_nodes_cur, num_nodes_pre)).t()
if is_bias is False:
num_nodes_before += num_nodes_pre
num_nodes_incremental.append(num_nodes_before)
num_nodes_layers.append(num_nodes_cur)
# affinity = (affinity + affinity.t()) / 2
return affinity, [n1, n2, num_nodes_incremental, num_nodes_layers, cur_conv_list, conv_kernel_size_list]
```

Get the affinity (similarity) matrix between model1 and model2.

```K, params = graph_matching_fusion([model1, model2])
```
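To make the layout of `K` easier to picture, here is a minimal standard-library sketch of the flat indexing that the construction code appears to use. It assumes the candidate match "node `a` of model1 to node `c` of model2" lives at index `a * n2 + c`, so that diagonal entries hold node (bias) affinities and off-diagonal entries hold edge (weight) affinities. The helper `match_index` and the toy matrix `K_toy` are hypothetical names introduced only for this illustration.

```python
# Sketch of the flat indexing of the affinity matrix (standard library only).
# Assumption: candidate match (a -> c) is stored at index a * n2 + c,
# as the index arithmetic in the construction code suggests.

def match_index(a, c, n2):
    """Flat index of the candidate match 'node a of model1 -> node c of model2'."""
    return a * n2 + c

n1 = n2 = 3
size = n1 * n2
K_toy = [[0.0] * size for _ in range(size)]

# Node affinity (e.g. a bias term): a diagonal entry for one candidate match.
i = match_index(0, 1, n2)
K_toy[i][i] = 0.5

# Edge affinity (e.g. a weight): an entry coupling two candidate matches,
# here (0 -> 1) and (2 -> 2).
j = match_index(2, 2, n2)
K_toy[i][j] = 0.9

print(i, j, len(K_toy))
```

With `n1 = n2 = n`, the full matrix has `n * n` rows and columns, which is why `K` grows quickly with the number of channels.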

## Align the models by graph matching

Align the channels of model1 and model2 by maximizing the affinity (similarity) via graph matching algorithms.

```n1 = params[0]
n2 = params[1]
X = pygm.sm(K, n1, n2)
```
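`pygm.sm` is a spectral matching solver. As a rough illustration of the idea only (not pygmtools' implementation), the sketch below runs power iteration to approximate the leading eigenvector of a toy affinity matrix and reshapes it into a soft matching matrix. The function `spectral_matching_sketch`, the toy matrix `K_toy`, and the row-major index convention `i * n2 + j` are assumptions of this sketch.

```python
def spectral_matching_sketch(K, n1, n2, iters=50):
    """Power iteration for the leading eigenvector of K, reshaped to n1 x n2."""
    size = n1 * n2
    v = [1.0 / size] * size
    for _ in range(iters):
        v = [sum(K[r][c] * v[c] for c in range(size)) for r in range(size)]
        norm = sum(x * x for x in v) ** 0.5
        v = [x / norm for x in v]
    # Row-major reshape: entry [i][j] scores the match i -> j (sketch assumption).
    return [v[i * n2:(i + 1) * n2] for i in range(n1)]

# Toy affinity whose dominant structure encodes the matching 0->1, 1->0, 2->2.
n = 3
target = [1, 0, 2]
x = [0.0] * (n * n)
for i, j in enumerate(target):
    x[i * n + j] = 1.0
K_toy = [[x[r] * x[c] for c in range(n * n)] for r in range(n * n)]

X_soft = spectral_matching_sketch(K_toy, n, n)
recovered = [max(range(n), key=lambda j: row[j]) for row in X_soft]
print(recovered)
```

The result is a soft (continuous) score matrix rather than a permutation, which is why a discretization step is still needed afterwards.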

Project `X` onto a neural network matching result. The neural network matching matrix is built by applying the Hungarian algorithm to small diagonal blocks of `X`, because only channels from the same neural network layer can be matched.

Note

In this example, we assume the last FC layer is aligned and need not be matched.

```new_X = paddle.zeros_like(X)
for start_idx, length in zip(params[2][:-1], params[3][:-1]):  # params[2]: start index of each layer, params[3]: number of channels in each layer
    slicing = slice(start_idx, start_idx + length)
    new_X[slicing, slicing] = pygm.hungarian(X[slicing, slicing])
# assume the last FC layer is aligned
slicing = slice(params[2][-1], params[2][-1] + params[3][-1])
new_X[slicing, slicing] = paddle.eye(params[3][-1])
X = new_X
```
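The block-wise projection can be sketched without any dependencies. In the illustration below, a brute-force search over permutations stands in for the Hungarian algorithm (reasonable only for tiny blocks), and the helpers `best_permutation` and `project_blockwise` as well as the toy matrix are hypothetical. Large scores outside the diagonal blocks are ignored, mirroring the constraint that channels can only be matched within the same layer.

```python
from itertools import permutations

def best_permutation(block):
    """Permutation matrix maximizing the summed scores of a small square block.
    Brute force stands in for the Hungarian algorithm here."""
    n = len(block)
    best = max(permutations(range(n)),
               key=lambda p: sum(block[i][p[i]] for i in range(n)))
    return [[1.0 if best[i] == j else 0.0 for j in range(n)] for i in range(n)]

def project_blockwise(X_soft, starts, lengths):
    """Discretize each diagonal block independently; everything else stays zero."""
    n = len(X_soft)
    out = [[0.0] * n for _ in range(n)]
    for start, length in zip(starts, lengths):
        block = [row[start:start + length] for row in X_soft[start:start + length]]
        perm = best_permutation(block)
        for i in range(length):
            for j in range(length):
                out[start + i][start + j] = perm[i][j]
    return out

# Two "layers" with two channels each; cross-layer scores are deliberately large.
X_soft = [
    [0.1, 0.8, 0.9, 0.0],
    [0.7, 0.2, 0.0, 0.9],
    [0.9, 0.0, 0.6, 0.3],
    [0.0, 0.9, 0.2, 0.5],
]
X_hard = project_blockwise(X_soft, starts=[0, 2], lengths=[2, 2])
for row in X_hard:
    print(row)
```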

Visualization of the matching result. The black lines separate the channels of different layers.

```plt.figure(figsize=(4, 4))
plt.imshow(X.cpu().numpy(), cmap='Blues')
for idx in params[2]:
    plt.axvline(x=idx, color='k')
    plt.axhline(y=idx, color='k')
```

Define the alignment function, which fuses the models based on the matching result.

```def align(solution, fusion_proportion, networks: list, params: list):
    [_, _, num_nodes_incremental, num_nodes_layers, cur_conv_list, conv_kernel_size_list] = params
    named_weight_list_0 = [named_parameter for named_parameter in networks[0].named_parameters()]
    aligned_wt_0 = [parameter.detach() if 'fc' not in name else parameter.detach().t() for name, parameter in named_weight_list_0]
    idx = 0
    num_layers = len(aligned_wt_0)
    for num_before, num_cur, cur_conv, cur_kernel_size in \
            zip(num_nodes_incremental, num_nodes_layers, cur_conv_list, conv_kernel_size_list):
        perm = solution[num_before:num_before + num_cur, num_before:num_before + num_cur]
        assert 'bias' not in named_weight_list_0[idx][0]
        # permute the output channels of the current layer
        if len(named_weight_list_0[idx][1].shape) == 4:
            aligned_wt_0[idx] = (perm.t().cast('float64') @
                                 aligned_wt_0[idx].cast('float64').transpose((2, 3, 0, 1))) \
                .transpose((2, 3, 0, 1))
        else:
            aligned_wt_0[idx] = perm.t().cast('float64') @ aligned_wt_0[idx].cast('float64')
        idx += 1
        if idx >= num_layers:
            continue
        if 'bias' in named_weight_list_0[idx][0]:
            # permute the bias of the current layer
            aligned_wt_0[idx] = aligned_wt_0[idx].cast('float64') @ perm.cast('float64')
            idx += 1
        if idx >= num_layers:
            continue
        # permute the input channels of the next layer
        if cur_conv and len(named_weight_list_0[idx][1].shape) == 2:
            aligned_wt_0[idx] = (aligned_wt_0[idx].cast('float64')
                                 .reshape((aligned_wt_0[idx].shape[0], 64, -1))
                                 .transpose((0, 2, 1))
                                 @ perm.cast('float64')) \
                .transpose((0, 2, 1)) \
                .reshape((aligned_wt_0[idx].shape[0], -1))
        elif len(named_weight_list_0[idx][1].shape) == 4:
            aligned_wt_0[idx] = (aligned_wt_0[idx].cast('float64')
                                 .transpose((2, 3, 0, 1))
                                 @ perm.cast('float64')) \
                .transpose((2, 3, 0, 1))
        else:
            aligned_wt_0[idx] = aligned_wt_0[idx].cast('float64') @ perm.cast('float64')
    assert idx == num_layers

    averaged_weights = []
    for idx, (named, parameter) in enumerate(networks[1].named_parameters()):
        parameter = parameter.t() if 'fc' in named else parameter
        averaged_weights.append((1 - fusion_proportion) * aligned_wt_0[idx].cast('float32') + fusion_proportion * parameter)
    return averaged_weights
```
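Alignment is valid because reordering the hidden channels of a network does not change the function it computes, as long as the output channels of one layer and the input channels of the next layer are permuted consistently. The following standard-library sketch checks this on a hypothetical two-layer toy network; `W1`, `W2`, `order`, and the helper functions are illustrative names, not part of the example's code.

```python
def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def relu(v):
    return [max(0.0, a) for a in v]

def forward(W1, W2, x):
    """Tiny two-layer network: y = W2 @ relu(W1 @ x)."""
    return matvec(W2, relu(matvec(W1, x)))

W1 = [[1.0, -2.0], [0.5, 1.5], [-1.0, 3.0]]  # 3 hidden units, 2 inputs
W2 = [[2.0, -1.0, 0.5]]                      # 1 output, 3 hidden units
order = [2, 0, 1]                            # new hidden unit k is old unit order[k]

W1_perm = [W1[k] for k in order]                   # permute rows (output channels)
W2_perm = [[row[k] for k in order] for row in W2]  # permute columns (input channels)

x = [1.0, 2.0]
y_original = forward(W1, W2, x)
y_permuted = forward(W1_perm, W2_perm, x)
print(y_original, y_permuted)
```

Both networks give the same output, so model1 can be re-expressed in model2's channel order before the weights are averaged.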

## Test the fused model

The `fusion_proportion` variable denotes the contribution of model2 to the fused model. For example, if `fusion_proportion=0.2`, the fused model = 80% model1 + 20% model2.
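As a small worked example of this weighting, the sketch below interpolates two hypothetical (already aligned) weight vectors; `fuse`, `w_model1`, and `w_model2` are illustrative names only.

```python
def fuse(w1, w2, fusion_proportion):
    """Elementwise (1 - p) * w1 + p * w2 for two equally shaped weight lists."""
    return [(1 - fusion_proportion) * a + fusion_proportion * b
            for a, b in zip(w1, w2)]

w_model1 = [1.0, 0.0, -2.0]
w_model2 = [0.0, 1.0, 2.0]

# The same grid as in the example: 0.0, 0.1, ..., 1.0
proportions = [k / 10 for k in range(11)]

fused = fuse(w_model1, w_model2, 0.2)  # 80% model1 + 20% model2
print([round(w, 2) for w in fused])
```

At `fusion_proportion=0.0` the fused weights equal those of model1, and at `1.0` they equal those of model2.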

```def align_model_and_test(X):
    acc_list = []
    for fusion_proportion in paddle.arange(0, 11, 1) / 10: # paddle arange accepts int step only
        fused_weights = align(X, fusion_proportion, [model1, model2], params)

        fused_model = SimpleNet()
        state_dict = fused_model.state_dict()
        for idx, (key, _) in enumerate(state_dict.items()):
            state_dict[key] = fused_weights[idx].t() if 'fc' in key else fused_weights[idx]
        fused_model.set_dict(state_dict)
        fused_model.to(device)
        test_loss = 0
        correct = 0
        # evaluate the fused model on the test set
        for data, target in test_loader:
            output = fused_model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.detach().argmax(1, keepdim=True)
            correct += pred.equal(target.detach().reshape(pred.shape)).sum()
        acc = 100. * correct / len(test_loader.dataset)
        print(
            f"{1 - fusion_proportion.item():.2f} model1 + {fusion_proportion.item():.2f} model2 -> fused model accuracy: {acc.item():.2f}%")
        acc_list.append(acc)
    return paddle.to_tensor([a.item() for a in acc_list])

print('Graph Matching Fusion')
gm_acc_list = align_model_and_test(X)
```
```Graph Matching Fusion
1.00 model1 + 0.00 model2 -> fused model accuracy: 84.18%
0.90 model1 + 0.10 model2 -> fused model accuracy: 85.12%
0.80 model1 + 0.20 model2 -> fused model accuracy: 85.21%
0.70 model1 + 0.30 model2 -> fused model accuracy: 82.52%
0.60 model1 + 0.40 model2 -> fused model accuracy: 71.11%
0.50 model1 + 0.50 model2 -> fused model accuracy: 53.74%
0.40 model1 + 0.60 model2 -> fused model accuracy: 63.26%
0.30 model1 + 0.70 model2 -> fused model accuracy: 78.51%
0.20 model1 + 0.80 model2 -> fused model accuracy: 82.81%
0.10 model1 + 0.90 model2 -> fused model accuracy: 83.97%
0.00 model1 + 1.00 model2 -> fused model accuracy: 83.81%
```

Compared with vanilla model fusion (no matching), the graph matching method stabilizes the fusion step: in the results shown here, its accuracy is at least as high at every fusion proportion.

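```print('No Matching Fusion')
vanilla_acc_list = align_model_and_test(paddle.eye(n1))  # identity matching, i.e. no channel alignment
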
plt.figure(figsize=(4, 4))
plt.title('Fused Model Accuracy')
plt.plot((paddle.arange(0, 11, 1) / 10).numpy(), gm_acc_list.cpu().numpy(), 'r*-', label='Graph Matching Fusion')
plt.plot((paddle.arange(0, 11, 1) / 10).numpy(), vanilla_acc_list.cpu().numpy(), 'b*-', label='No Matching Fusion')
plt.plot((paddle.arange(0, 11, 1) / 10).numpy(), [acc1] * 11, '--', color="gray", label='Model1 Accuracy')
plt.plot((paddle.arange(0, 11, 1) / 10).numpy(), [acc2] * 11, '--', color="brown", label='Model2 Accuracy')
plt.gca().set_xlabel('Fusion Proportion')
plt.gca().set_ylabel('Accuracy (%)')
plt.ylim((70, 87))
plt.legend(loc=3)
```
```No Matching Fusion
1.00 model1 + 0.00 model2 -> fused model accuracy: 84.18%
0.90 model1 + 0.10 model2 -> fused model accuracy: 84.01%
0.80 model1 + 0.20 model2 -> fused model accuracy: 81.91%
0.70 model1 + 0.30 model2 -> fused model accuracy: 74.67%
0.60 model1 + 0.40 model2 -> fused model accuracy: 60.39%
0.50 model1 + 0.50 model2 -> fused model accuracy: 47.16%
0.40 model1 + 0.60 model2 -> fused model accuracy: 55.34%
0.30 model1 + 0.70 model2 -> fused model accuracy: 72.86%
0.20 model1 + 0.80 model2 -> fused model accuracy: 79.64%
0.10 model1 + 0.90 model2 -> fused model accuracy: 82.56%
0.00 model1 + 1.00 model2 -> fused model accuracy: 83.81%

```
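To quantify the comparison, the short sketch below computes the accuracy gain of graph matching fusion over no-matching fusion at each fusion proportion. The numbers are copied from the two printed logs above; the variable names are illustrative.

```python
# Accuracies (%) copied from the two printed logs, ordered by fusion proportion 0.0 ... 1.0.
gm_acc = [84.18, 85.12, 85.21, 82.52, 71.11, 53.74, 63.26, 78.51, 82.81, 83.97, 83.81]
vanilla_acc = [84.18, 84.01, 81.91, 74.67, 60.39, 47.16, 55.34, 72.86, 79.64, 82.56, 83.81]
proportions = [k / 10 for k in range(11)]

# Gain in percentage points from matching before averaging.
gains = [round(g - v, 2) for g, v in zip(gm_acc, vanilla_acc)]
best = max(range(len(gains)), key=lambda k: gains[k])

for p, gain in zip(proportions, gains):
    print(f"fusion_proportion={p:.1f}: gain from matching = {gain:+.2f} points")
print(f"largest gain at fusion_proportion={proportions[best]:.1f}")
```

The two endpoints coincide because they correspond to the unfused model1 and model2; the gain is largest in the middle of the range, where plain averaging of unaligned channels hurts the most.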