Paddle Backend Example: Model Fusion by Graph Matching
This example shows how to fuse different models into a single model with pygmtools.
Model fusion aims to fuse multiple models into one, such that the fused model achieves higher performance than the individual input models. A neural network can be regarded as a graph (channels are nodes, update functions between channels are edges; node features are the biases, edge features are the weights), so fusing two models is equivalent to solving a graph matching problem. In this example, the given models are trained on MNIST data drawn from different distributions, and the fused model combines the knowledge of both input models and reaches higher accuracy at test time.
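In graph matching terms, this is the standard quadratic assignment formulation (stated here for reference; see the paper cited below for the full treatment): find a channel permutation matrix X that maximizes the total affinity,

\max_{\mathbf{X}}\ \mathrm{vec}(\mathbf{X})^\top \mathbf{K}\, \mathrm{vec}(\mathbf{X}) \quad \text{s.t.}\quad \mathbf{X}\in\{0,1\}^{n\times n},\ \mathbf{X}\mathbf{1}=\mathbf{1},\ \mathbf{X}^\top\mathbf{1}=\mathbf{1},

where n is the total number of channels (nodes) and K is the n^2-by-n^2 affinity matrix built from the models' weights.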
# Author: Chang Liu <only-changer@sjtu.edu.cn>
# Runzhong Wang <runzhong.wang@sjtu.edu.cn>
# Wenzheng Pan <pwz1121@sjtu.edu.cn>
#
# License: Mulan PSL v2 License
Note
This is a simplified implementation of the ideas in Liu et al. Deep Neural Network Fusion via Graph Matching with Applications to Model Ensemble and Federated Learning. ICML 2022. For more details, please refer to the paper and the official code repository.
Note
The following solvers are included in this example:
sm() (classic solver)
hungarian() (linear solver)
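For a quick, self-contained feel of how these two solvers are typically combined, here is a toy sketch on a random affinity matrix (made up purely for illustration; it repeats the backend setup so it can be run on its own):

import paddle
import pygmtools as pygm
pygm.set_backend('paddle')

n = 3                                  # 3 nodes in each toy graph
K_toy = paddle.rand((n * n, n * n))    # random quadratic affinity matrix (illustration only)
K_toy = (K_toy + K_toy.t()) / 2        # symmetrize, as affinity matrices are symmetric
X_soft = pygm.sm(K_toy, n, n)          # spectral matching returns a soft matching matrix
X_hard = pygm.hungarian(X_soft)        # hungarian projects it onto a 0/1 permutation
print(X_hard.sum(axis=0), X_hard.sum(axis=1))  # every row and column sums to 1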
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.vision.transforms as transforms
import time
from PIL import Image
import matplotlib.pyplot as plt
import pygmtools as pygm
import warnings
warnings.filterwarnings("ignore")
pygm.set_backend('paddle')
device = paddle.device.get_device()
paddle.device.set_device(device)
Place(cpu)
Define a simple CNN classifier network
class SimpleNet(nn.Layer):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2D(1, 32, 5, padding=1, padding_mode='replicate', bias_attr=False)
        self.max_pool = nn.MaxPool2D(2, padding=1)
        self.conv2 = nn.Conv2D(32, 64, 5, padding=1, padding_mode='replicate', bias_attr=False)
        self.fc1 = nn.Linear(3136, 32, bias_attr=False)
        self.fc2 = nn.Linear(32, 10, bias_attr=False)

    def forward(self, x):
        output = F.relu(self.conv1(x))
        output = self.max_pool(output)
        output = F.relu(self.conv2(output))
        output = self.max_pool(output)
        output = output.reshape((output.shape[0], -1))
        output = self.fc1(output)
        output = self.fc2(output)
        return output
Load the trained models to be fused
model1 = SimpleNet()
model2 = SimpleNet()
model1.set_dict(paddle.load('../data/example_model_fusion_1_paddle.dat'))
model2.set_dict(paddle.load('../data/example_model_fusion_2_paddle.dat'))
model1.to(device)
model2.to(device)
test_dataset = paddle.vision.datasets.MNIST(
    # the storage directory of the dataset cannot be customized;
    # default: ~/.cache/paddle/dataset/mnist
    mode='test',  # use the test split
    transform=transforms.ToTensor(),  # return the samples as tensors
    download=True)
test_loader = paddle.io.DataLoader(
    dataset=test_dataset,
    batch_size=32,
    shuffle=False)
Cache file /home/wzever/.cache/paddle/dataset/mnist/t10k-images-idx3-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/t10k-images-idx3-ubyte.gz
Begin to download
Download finished
Cache file /home/wzever/.cache/paddle/dataset/mnist/t10k-labels-idx1-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/t10k-labels-idx1-ubyte.gz
Begin to download
Download finished
Print the layers of the simple CNN model:
print(model1)
SimpleNet(
(conv1): Conv2D(1, 32, kernel_size=[5, 5], padding=1, padding_mode=replicate, data_format=NCHW)
(max_pool): MaxPool2D(kernel_size=2, stride=None, padding=1)
(conv2): Conv2D(32, 64, kernel_size=[5, 5], padding=1, padding_mode=replicate, data_format=NCHW)
(fc1): Linear(in_features=3136, out_features=32, dtype=None)
(fc2): Linear(in_features=32, out_features=10, dtype=None)
)
Test the input models
with paddle.no_grad():
    n_correct1 = 0
    n_correct2 = 0
    n_samples = 0
    for images, labels in test_loader:
        outputs1 = model1(images)
        outputs2 = model2(images)
        predictions1 = paddle.argmax(outputs1, 1)
        predictions2 = paddle.argmax(outputs2, 1)
        n_samples += labels.shape[0]
        n_correct1 += (predictions1 == labels.t()).sum().item()
        n_correct2 += (predictions2 == labels.t()).sum().item()
    acc1 = 100 * n_correct1 / n_samples
    acc2 = 100 * n_correct2 / n_samples
Testing results (two separate models):
print(f'model1 accuracy = {acc1}%, model2 accuracy = {acc2}%')
model1 accuracy = 84.18%, model2 accuracy = 83.81%
Build the affinity matrix for graph matching
As shown in the following plot, the neural networks can be regarded as graphs: the weights correspond to the edge features, and the biases correspond to the node features. In this example, the neural networks have no bias terms, so there are only edge features.
plt.figure(figsize=(8, 4))
img = Image.open('../data/model_fusion.png')
plt.imshow(img)
plt.axis('off')
st_time = time.perf_counter()

Define the graph matching affinity metric function
class Ground_Metric_GM:
    def __init__(self,
                 model_1_param: paddle.Tensor = None,
                 model_2_param: paddle.Tensor = None,
                 conv_param: bool = False,
                 bias_param: bool = False,
                 pre_conv_param: bool = False,
                 pre_conv_image_size_squared: int = None):
        self.model_1_param = model_1_param
        self.model_2_param = model_2_param
        self.conv_param = conv_param
        self.bias_param = bias_param
        # bias, or fully-connected from linear
        if bias_param is True or (conv_param is False and pre_conv_param is False):
            self.model_1_param = self.model_1_param.reshape((1, -1, 1))
            self.model_2_param = self.model_2_param.reshape((1, -1, 1))
        # fully-connected from conv
        elif conv_param is False and pre_conv_param is True:
            self.model_1_param = self.model_1_param.reshape((1, -1, pre_conv_image_size_squared))
            self.model_2_param = self.model_2_param.reshape((1, -1, pre_conv_image_size_squared))
        # conv
        else:
            self.model_1_param = self.model_1_param.reshape((1, -1, model_1_param.shape[-1]))
            self.model_2_param = self.model_2_param.reshape((1, -1, model_2_param.shape[-1]))

    def process_distance(self, p: int = 2):
        dist = []
        cdist = paddle.nn.PairwiseDistance(p)
        param_1 = self.model_1_param.cast('float32')[0]
        param_2 = self.model_2_param.cast('float32')[0]
        for i in param_1:
            dist.append(cdist(i.broadcast_to(param_2.shape), param_2))
        return paddle.to_tensor(dist)

    def process_soft_affinity(self, p: int = 2):
        return paddle.exp(0 - self.process_distance(p=p))
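A small usage sketch of this class (the conv-style toy weights below are made up; it reuses the imports above): every flattened kernel of model 1 is compared with every kernel of model 2, and the soft affinity is the exponential of the negative pairwise distance.

# Toy check: two conv-style weight tensors with 4 output channels, 3 input channels
# and 3x3 kernels flattened to length 9 (values are arbitrary).
toy_w1 = paddle.rand((4, 3, 9))
toy_w2 = paddle.rand((4, 3, 9))
toy_metric = Ground_Metric_GM(toy_w1, toy_w2, conv_param=True, bias_param=False,
                              pre_conv_param=False, pre_conv_image_size_squared=None)
print(toy_metric.process_soft_affinity(p=2).shape)  # [12, 12]: one affinity value per kernel pair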
Define the affinity function between two neural networks. This function takes a list of neural network modules and constructs the corresponding affinity matrix, which is then processed by the graph matching solver.
def graph_matching_fusion(networks: list):
    def total_node_num(network: paddle.nn.Layer):
        # count the total number of nodes in the network [network]
        num_nodes = 0
        for idx, (name, parameters) in enumerate(network.named_parameters()):
            if 'bias' in name:
                continue
            if idx == 0:
                num_nodes += parameters.shape[1]
            # transpose linear layers in paddle to conventional shape,
            num_nodes += parameters.shape[0] if 'fc' not in name else parameters.shape[1]
        return num_nodes

    n1 = total_node_num(network=networks[0])
    n2 = total_node_num(network=networks[1])
    assert (n1 == n2)
    affinity = paddle.zeros([n1 * n2, n1 * n2])
    num_layers = len(list(zip(networks[0].parameters(), networks[1].parameters())))
    num_nodes_before = 0
    num_nodes_incremental = []
    num_nodes_layers = []
    pre_conv_list = []
    cur_conv_list = []
    conv_kernel_size_list = []
    num_nodes_pre = 0
    is_conv = False
    pre_conv = False
    pre_conv_out_channel = 1
    is_final_bias = False
    perm_is_complete = True
    named_weight_list_0 = [named_parameter for named_parameter in networks[0].named_parameters()]
    for idx, ((name_0, fc_layer0_weight), (name_1, fc_layer1_weight)) in \
            enumerate(zip(networks[0].named_parameters(), networks[1].named_parameters())):
        assert fc_layer0_weight.shape == fc_layer1_weight.shape
        if 'fc' in name_0:
            fc_layer0_weight = fc_layer0_weight.t()
            fc_layer1_weight = fc_layer1_weight.t()
        layer_shape = fc_layer0_weight.shape
        num_nodes_cur = fc_layer0_weight.shape[0]
        if len(layer_shape) > 1:
            if is_conv is True and len(layer_shape) == 2:
                num_nodes_pre = pre_conv_out_channel
            else:
                num_nodes_pre = fc_layer0_weight.shape[1]
        if idx >= 1 and len(named_weight_list_0[idx - 1][1].shape) == 1:
            pre_bias = True
        else:
            pre_bias = False
        if len(layer_shape) > 2:
            is_bias = False
            if not pre_bias:
                pre_conv = is_conv
                pre_conv_list.append(pre_conv)
            is_conv = True
            cur_conv_list.append(is_conv)
            fc_layer0_weight_data = fc_layer0_weight.detach().reshape(
                (fc_layer0_weight.shape[0], fc_layer0_weight.shape[1], -1))
            fc_layer1_weight_data = fc_layer1_weight.detach().reshape(
                (fc_layer1_weight.shape[0], fc_layer1_weight.shape[1], -1))
        elif len(layer_shape) == 2:
            is_bias = False
            if not pre_bias:
                pre_conv = is_conv
                pre_conv_list.append(pre_conv)
            is_conv = False
            cur_conv_list.append(is_conv)
            fc_layer0_weight_data = fc_layer0_weight.detach()
            fc_layer1_weight_data = fc_layer1_weight.detach()
        else:
            is_bias = True
            if not pre_bias:
                pre_conv = is_conv
                pre_conv_list.append(pre_conv)
            is_conv = False
            cur_conv_list.append(is_conv)
            fc_layer0_weight_data = fc_layer0_weight.detach()
            fc_layer1_weight_data = fc_layer1_weight.detach()
        if is_conv:
            pre_conv_out_channel = num_nodes_cur
        if is_bias is True and idx == num_layers - 1:
            is_final_bias = True
        if idx == 0:
            for a in range(num_nodes_pre):
                affinity[(num_nodes_before + a) * n2 + num_nodes_before + a, \
                         (num_nodes_before + a) * n2 + num_nodes_before + a] \
                    = 1
        if idx == num_layers - 2 and 'bias' in named_weight_list_0[idx + 1][0] or \
                idx == num_layers - 1 and 'bias' not in named_weight_list_0[idx][0]:
            for a in range(num_nodes_cur):
                affinity[(num_nodes_before + num_nodes_pre + a) * n2 + num_nodes_before + num_nodes_pre + a, \
                         (num_nodes_before + num_nodes_pre + a) * n2 + num_nodes_before + num_nodes_pre + a] \
                    = 1
        if is_bias is False:
            ground_metric = Ground_Metric_GM(
                fc_layer0_weight_data, fc_layer1_weight_data, is_conv, is_bias,
                pre_conv, int(fc_layer0_weight_data.shape[1] / pre_conv_out_channel))
        else:
            ground_metric = Ground_Metric_GM(
                fc_layer0_weight_data, fc_layer1_weight_data, is_conv, is_bias,
                pre_conv, 1)
        layer_affinity = ground_metric.process_soft_affinity(p=2)
        if is_bias is False:
            pre_conv_kernel_size = fc_layer0_weight.shape[3] if is_conv else None
            conv_kernel_size_list.append(pre_conv_kernel_size)
        if is_bias is True and is_final_bias is False:
            for a in range(num_nodes_cur):
                for c in range(num_nodes_cur):
                    affinity[(num_nodes_before + a) * n2 + num_nodes_before + c, \
                             (num_nodes_before + a) * n2 + num_nodes_before + c] \
                        = layer_affinity[a][c]
        elif is_final_bias is False:
            for a in range(num_nodes_pre):
                for b in range(num_nodes_cur):
                    affinity[
                        (num_nodes_before + a) * n2 + num_nodes_before:
                        (num_nodes_before + a) * n2 + num_nodes_before + num_nodes_pre,
                        (num_nodes_before + num_nodes_pre + b) * n2 + num_nodes_before + num_nodes_pre:
                        (num_nodes_before + num_nodes_pre + b) * n2 + num_nodes_before + num_nodes_pre + num_nodes_cur] \
                        = layer_affinity[a + b * num_nodes_pre].reshape((num_nodes_cur, num_nodes_pre)).t()
        if is_bias is False:
            num_nodes_before += num_nodes_pre
            num_nodes_incremental.append(num_nodes_before)
            num_nodes_layers.append(num_nodes_cur)
    # affinity = (affinity + affinity.t()) / 2
    return affinity, [n1, n2, num_nodes_incremental, num_nodes_layers, cur_conv_list, conv_kernel_size_list]
Get the affinity (similarity) matrix between model1 and model2.
K, params = graph_matching_fusion([model1, model2])
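Optionally, inspect the sizes before solving (a quick sanity-check sketch; the exact numbers follow from the channel counts of SimpleNet):

print(K.shape)               # [n1 * n2, n1 * n2]: affinity between pairs of channel assignments
print(params[2], params[3])  # per-layer starting channel index and per-layer channel count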
Align the models by graph matching
Align the channels of model1 & model2 by maximizing the affinity (similarity) via graph matching algorithms.
n1 = params[0]
n2 = params[1]
X = pygm.sm(K, n1, n2)
Project X to the neural network matching result. The neural network matching matrix is built by applying Hungarian to small blocks of X, because only the channels from the same neural network layer can be matched.
Note
In this example, we assume the last FC layer is aligned and need not be matched.
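The projection below relies on hungarian() turning a soft score block into a hard one-to-one assignment with maximum total score; a toy sketch (the 3x3 scores are made up):

toy_block = paddle.to_tensor([[0.9, 0.1, 0.2],
                              [0.2, 0.8, 0.3],
                              [0.1, 0.4, 0.7]])
print(pygm.hungarian(toy_block))  # picks the diagonal here: the score-maximizing permutation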
new_X = paddle.zeros_like(X)
new_X[:params[2][0], :params[2][0]] = paddle.eye(params[2][0])
for start_idx, length in zip(params[2][:-1], params[3][:-1]):  # params[2] and params[3] are the indices of layers
    slicing = slice(start_idx, start_idx + length)
    new_X[slicing, slicing] = pygm.hungarian(X[slicing, slicing])
# assume the last FC layer is aligned
slicing = slice(params[2][-1], params[2][-1] + params[3][-1])
new_X[slicing, slicing] = paddle.eye(params[3][-1])
X = new_X
Visualization of the matching result. The black lines split the channels of different layers.
plt.figure(figsize=(4, 4))
plt.imshow(X.cpu().numpy(), cmap='Blues')
for idx in params[2]:
    plt.axvline(x=idx, color='k')
    plt.axhline(y=idx, color='k')

Define the alignment function: fuse the models based on the matching result
def align(solution, fusion_proportion, networks: list, params: list):
    [_, _, num_nodes_incremental, num_nodes_layers, cur_conv_list, conv_kernel_size_list] = params
    named_weight_list_0 = [named_parameter for named_parameter in networks[0].named_parameters()]
    aligned_wt_0 = [parameter.detach() if 'fc' not in name else parameter.detach().t() for name, parameter in named_weight_list_0]
    idx = 0
    num_layers = len(aligned_wt_0)
    for num_before, num_cur, cur_conv, cur_kernel_size in \
            zip(num_nodes_incremental, num_nodes_layers, cur_conv_list, conv_kernel_size_list):
        perm = solution[num_before:num_before + num_cur, num_before:num_before + num_cur]
        assert 'bias' not in named_weight_list_0[idx][0]
        if len(named_weight_list_0[idx][1].shape) == 4:
            aligned_wt_0[idx] = (perm.t().cast(paddle.float64) @
                                 aligned_wt_0[idx].cast(paddle.float64).transpose((2, 3, 0, 1))) \
                .transpose((2, 3, 0, 1))
        else:
            aligned_wt_0[idx] = perm.t().cast(paddle.float64) @ aligned_wt_0[idx].cast(paddle.float64)
        idx += 1
        if idx >= num_layers:
            continue
        if 'bias' in named_weight_list_0[idx][0]:
            aligned_wt_0[idx] = aligned_wt_0[idx].cast(paddle.float64) @ perm.cast(paddle.float64)
            idx += 1
            if idx >= num_layers:
                continue
        if cur_conv and len(named_weight_list_0[idx][1].shape) == 2:
            aligned_wt_0[idx] = (aligned_wt_0[idx].cast(paddle.float64)
                                 .reshape((aligned_wt_0[idx].shape[0], 64, -1))
                                 .transpose((0, 2, 1))
                                 @ perm.cast(paddle.float64)) \
                .transpose((0, 2, 1)) \
                .reshape((aligned_wt_0[idx].shape[0], -1))
        elif len(named_weight_list_0[idx][1].shape) == 4:
            aligned_wt_0[idx] = (aligned_wt_0[idx].cast(paddle.float64)
                                 .transpose((2, 3, 0, 1))
                                 @ perm.cast(paddle.float64)) \
                .transpose((2, 3, 0, 1))
        else:
            aligned_wt_0[idx] = aligned_wt_0[idx].cast(paddle.float64) @ perm.cast(paddle.float64)
    assert idx == num_layers
    averaged_weights = []
    for idx, (named, parameter) in enumerate(networks[1].named_parameters()):
        parameter = parameter.t() if 'fc' in named else parameter
        averaged_weights.append((1 - fusion_proportion) * aligned_wt_0[idx].cast('float32') + fusion_proportion * parameter)
    return averaged_weights
Test the fused model
The fusion_proportion variable denotes the proportion of model2 in the new model. For example, if fusion_proportion=0.2, the fused model = 80% model1 + 20% model2.
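Numerically, the fusion inside align() is a convex combination of the aligned weights; a minimal sketch of that single step (toy tensors standing in for real parameters):

aligned_w1 = paddle.rand((4, 4))   # stands in for a channel-permuted model1 weight
w2 = paddle.rand((4, 4))           # stands in for the corresponding model2 weight
fusion_proportion = 0.2
fused_w = (1 - fusion_proportion) * aligned_w1 + fusion_proportion * w2  # 80% model1 + 20% model2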
def align_model_and_test(X):
    acc_list = []
    for fusion_proportion in paddle.arange(0, 11, 1) / 10:  # paddle arange accepts int step only
        fused_weights = align(X, fusion_proportion, [model1, model2], params)
        fused_model = SimpleNet()
        state_dict = fused_model.state_dict()
        for idx, (key, _) in enumerate(state_dict.items()):
            state_dict[key] = fused_weights[idx].t() if 'fc' in key else fused_weights[idx]
        fused_model.set_dict(state_dict)
        fused_model.to(device)
        test_loss = 0
        correct = 0
        for data, target in test_loader:
            output = fused_model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.detach().argmax(1, keepdim=True)
            correct += pred.equal(target.detach().reshape(pred.shape)).sum()
        test_loss /= len(test_loader.dataset)
        acc = 100. * correct / len(test_loader.dataset)
        print(
            f"{1 - fusion_proportion.item():.2f} model1 + {fusion_proportion.item():.2f} model2 -> fused model accuracy: {acc.item():.2f}%")
        acc_list.append(acc)
    return paddle.to_tensor(acc_list)
print('Graph Matching Fusion')
gm_acc_list = align_model_and_test(X)
Graph Matching Fusion
1.00 model1 + 0.00 model2 -> fused model accuracy: 84.18%
0.90 model1 + 0.10 model2 -> fused model accuracy: 85.12%
0.80 model1 + 0.20 model2 -> fused model accuracy: 85.21%
0.70 model1 + 0.30 model2 -> fused model accuracy: 82.52%
0.60 model1 + 0.40 model2 -> fused model accuracy: 71.11%
0.50 model1 + 0.50 model2 -> fused model accuracy: 53.74%
0.40 model1 + 0.60 model2 -> fused model accuracy: 63.26%
0.30 model1 + 0.70 model2 -> fused model accuracy: 78.51%
0.20 model1 + 0.80 model2 -> fused model accuracy: 82.81%
0.10 model1 + 0.90 model2 -> fused model accuracy: 83.97%
0.00 model1 + 1.00 model2 -> fused model accuracy: 83.81%
Compared with vanilla model fusion (no matching), the graph matching method stabilizes the fusion step:
print('No Matching Fusion')
vanilla_acc_list = align_model_and_test(paddle.eye(n1))
plt.figure(figsize=(4, 4))
plt.title('Fused Model Accuracy')
plt.plot((paddle.arange(0, 11, 1) / 10).numpy(), gm_acc_list.cpu().numpy(), 'r*-', label='Graph Matching Fusion')
plt.plot((paddle.arange(0, 11, 1) / 10).numpy(), vanilla_acc_list.cpu().numpy(), 'b*-', label='No Matching Fusion')
plt.plot((paddle.arange(0, 11, 1) / 10).numpy(), [acc1] * 11, '--', color="gray", label='Model1 Accuracy')
plt.plot((paddle.arange(0, 11, 1) / 10).numpy(), [acc2] * 11, '--', color="brown", label='Model2 Accuracy')
plt.gca().set_xlabel('Fusion Proportion')
plt.gca().set_ylabel('Accuracy (%)')
plt.ylim((70, 87))
plt.legend(loc=3)

No Matching Fusion
1.00 model1 + 0.00 model2 -> fused model accuracy: 84.18%
0.90 model1 + 0.10 model2 -> fused model accuracy: 84.01%
0.80 model1 + 0.20 model2 -> fused model accuracy: 81.91%
0.70 model1 + 0.30 model2 -> fused model accuracy: 74.67%
0.60 model1 + 0.40 model2 -> fused model accuracy: 60.39%
0.50 model1 + 0.50 model2 -> fused model accuracy: 47.16%
0.40 model1 + 0.60 model2 -> fused model accuracy: 55.34%
0.30 model1 + 0.70 model2 -> fused model accuracy: 72.86%
0.20 model1 + 0.80 model2 -> fused model accuracy: 79.64%
0.10 model1 + 0.90 model2 -> fused model accuracy: 82.56%
0.00 model1 + 1.00 model2 -> fused model accuracy: 83.81%
<matplotlib.legend.Legend object at 0x7fdef9dfde10>
Print the result summary
end_time = time.perf_counter()
print(f'time consumed for model fusion: {end_time - st_time:.2f} seconds')
print(f'model1 accuracy = {acc1}%, model2 accuracy = {acc2}%')
print(f"best fused model accuracy: {(paddle.max(gm_acc_list)).item():.2f}%")
time consumed for model fusion: 1836.18 seconds
model1 accuracy = 84.18%, model2 accuracy = 83.81%
best fused model accuracy: 85.21%
Note
This example supports both GPU and CPU, and the online documentation is built on a CPU-only machine. The efficiency will be significantly improved if you run this code on a GPU.
Total running time of the script: (30 minutes 44.030 seconds)