Paddle Backend Example: Model Fusion by Graph Matching

This example shows how to fuse different models into a single model with pygmtools. Model fusion aims to combine multiple models into one, such that the fused model achieves higher performance than either input. A neural network can be regarded as a graph (channels are nodes, the update functions between channels are edges; node features are the biases, edge features are the weights), so fusing models is equivalent to solving a graph matching problem. In this example, the given models are trained on MNIST data from different distributions, and the fused model combines the knowledge of the two input models to reach higher accuracy at test time.

# Author: Chang Liu <only-changer@sjtu.edu.cn>
#         Runzhong Wang <runzhong.wang@sjtu.edu.cn>
#         Wenzheng Pan <pwz1121@sjtu.edu.cn>
#
# License: Mulan PSL v2 License

Note

This is a simplified implementation of the ideas in Liu et al. Deep Neural Network Fusion via Graph Matching with Applications to Model Ensemble and Federated Learning. ICML 2022. For more details, please refer to the paper and the official code repository.

Note

The following solvers are included in this example:

pygm.sm (classic solver)

pygm.hungarian (linear solver)

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.vision.transforms as transforms
import time
from PIL import Image
import matplotlib.pyplot as plt
import pygmtools as pygm
import warnings
warnings.filterwarnings("ignore")

pygm.set_backend('paddle')
device = paddle.device.get_device()
paddle.device.set_device(device)
Place(cpu)

Define a simple CNN classifier network

class SimpleNet(nn.Layer):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2D(1, 32, 5, padding=1, padding_mode='replicate', bias_attr=False)
        self.max_pool = nn.MaxPool2D(2, padding=1)
        self.conv2 = nn.Conv2D(32, 64, 5, padding=1, padding_mode='replicate', bias_attr=False)
        self.fc1 = nn.Linear(3136, 32, bias_attr=False)
        self.fc2 = nn.Linear(32, 10, bias_attr=False)

    def forward(self, x):
        output = F.relu(self.conv1(x))
        output = self.max_pool(output)
        output = F.relu(self.conv2(output))
        output = self.max_pool(output)
        output = output.reshape((output.shape[0], -1))
        output = self.fc1(output)
        output = self.fc2(output)
        return output
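
A brief, optional shape check (an addition to the original walkthrough, assuming the standard 1x28x28 MNIST input): with 5x5 kernels, padding 1 and two 2x2 max-pools (each padded by 1), the feature map entering fc1 flattens to 64 * 7 * 7 = 3136, which matches the Linear layer defined above.

# Optional sanity check of the architecture (assumes 1x28x28 inputs).
dummy = paddle.zeros([1, 1, 28, 28])
print(SimpleNet()(dummy).shape)  # expected: [1, 10], with 3136 features entering fc1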

Load the trained models to be fused

model1 = SimpleNet()
model2 = SimpleNet()
model1.set_dict(paddle.load('../data/example_model_fusion_1_paddle.dat'))
model2.set_dict(paddle.load('../data/example_model_fusion_2_paddle.dat'))
model1.to(device)
model2.to(device)
test_dataset = paddle.vision.datasets.MNIST(
    # the dataset download directory cannot be changed;
    # default: ~/.cache/paddle/dataset/mnist
    mode='test',  # the dataset is used to test
    transform=transforms.ToTensor(),  # the dataset is in the form of tensors
    download=True)
test_loader = paddle.io.DataLoader(
    dataset=test_dataset,
    batch_size=32,
    shuffle=False)
Cache file /home/wzever/.cache/paddle/dataset/mnist/t10k-images-idx3-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/t10k-images-idx3-ubyte.gz
Begin to download
Download finished
Cache file /home/wzever/.cache/paddle/dataset/mnist/t10k-labels-idx1-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/t10k-labels-idx1-ubyte.gz
Begin to download
Download finished

Print the layers of the simple CNN model:

print(model1)
SimpleNet(
  (conv1): Conv2D(1, 32, kernel_size=[5, 5], padding=1, padding_mode=replicate, data_format=NCHW)
  (max_pool): MaxPool2D(kernel_size=2, stride=None, padding=1)
  (conv2): Conv2D(32, 64, kernel_size=[5, 5], padding=1, padding_mode=replicate, data_format=NCHW)
  (fc1): Linear(in_features=3136, out_features=32, dtype=None)
  (fc2): Linear(in_features=32, out_features=10, dtype=None)
)

Test the input models

with paddle.no_grad():
    n_correct1 = 0
    n_correct2 = 0
    n_samples = 0
    for images, labels in test_loader:
        outputs1 = model1(images)
        outputs2 = model2(images)
        predictions1 = paddle.argmax(outputs1, 1)
        predictions2 = paddle.argmax(outputs2, 1)
        n_samples += labels.shape[0]
        n_correct1 += (predictions1 == labels.t()).sum().item()
        n_correct2 += (predictions2 == labels.t()).sum().item()
    acc1 = 100 * n_correct1 / n_samples
    acc2 = 100 * n_correct2 / n_samples

Testing results (two separate models):

print(f'model1 accuracy = {acc1}%, model2 accuracy = {acc2}%')
model1 accuracy = 84.18%, model2 accuracy = 83.81%

Build the affinity matrix for graph matching

As shown in the following plot, the neural networks can be regarded as graphs. The weights correspond to the edge features, and the biases correspond to the node features. In this example, the neural network has no bias terms, so there are only edge features.

plt.figure(figsize=(8, 4))
img = Image.open('../data/model_fusion.png')
plt.imshow(img)
plt.axis('off')
st_time = time.perf_counter()
plot model fusion paddle
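
To make the graph view concrete, the snippet below (an optional illustration, not part of the original pipeline) lists the parameters of model1. SimpleNet has no bias parameters, so every tensor is a weight tensor connecting the channels of two consecutive layers, i.e. an edge-feature tensor in the matching problem.

# Optional illustration: the tensors below become edge features of the graph.
for name, param in model1.named_parameters():
    print(name, list(param.shape))
# conv1.weight [32, 1, 5, 5]   edges between the input channel and conv1's 32 channels
# conv2.weight [64, 32, 5, 5]  edges between conv1's 32 and conv2's 64 channels
# fc1.weight   [3136, 32]      edges between the flattened conv2 output and fc1's 32 units
# fc2.weight   [32, 10]        edges between fc1's 32 units and the 10 output classes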

Define the graph matching affinity metric function

class Ground_Metric_GM:
    def __init__(self,
                 model_1_param: paddle.Tensor = None,
                 model_2_param: paddle.Tensor = None,
                 conv_param: bool = False,
                 bias_param: bool = False,
                 pre_conv_param: bool = False,
                 pre_conv_image_size_squared: int = None):
        self.model_1_param = model_1_param
        self.model_2_param = model_2_param
        self.conv_param = conv_param
        self.bias_param = bias_param
        # bias, or fully-connected from linear
        if bias_param is True or (conv_param is False and pre_conv_param is False):
            self.model_1_param = self.model_1_param.reshape((1, -1, 1))
            self.model_2_param = self.model_2_param.reshape((1, -1, 1))
        # fully-connected from conv
        elif conv_param is False and pre_conv_param is True:
            self.model_1_param = self.model_1_param.reshape((1, -1, pre_conv_image_size_squared))
            self.model_2_param = self.model_2_param.reshape((1, -1, pre_conv_image_size_squared))
        # conv
        else:
            self.model_1_param = self.model_1_param.reshape((1, -1, model_1_param.shape[-1]))
            self.model_2_param = self.model_2_param.reshape((1, -1, model_2_param.shape[-1]))

    def process_distance(self, p: int = 2):
        dist = []
        cdist = paddle.nn.PairwiseDistance(p)
        param_1 = self.model_1_param.cast('float32')[0]
        param_2 = self.model_2_param.cast('float32')[0]
        for i in param_1:
            dist.append(cdist(i.broadcast_to(param_2.shape), param_2))
        return paddle.to_tensor(dist)

    def process_soft_affinity(self, p: int = 2):
        return paddle.exp(0 - self.process_distance(p=p))
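
As a quick, hypothetical sanity check (not part of the original example), the ground metric can be probed on two random linear-like weight matrices. For non-convolutional, non-pre-conv parameters the class flattens each tensor to a list of scalar weights, so the soft affinity block has one row and one column per scalar weight, with values in (0, 1].

# Hypothetical sanity check of the ground metric on random toy weights.
w_a = paddle.randn([8, 16])  # 8 output units, 16 inputs (toy shapes)
w_b = paddle.randn([8, 16])
gm_check = Ground_Metric_GM(w_a, w_b, conv_param=False, bias_param=False, pre_conv_param=False)
aff_block = gm_check.process_soft_affinity(p=2)
print(aff_block.shape)                # expected: [128, 128], one entry per pair of scalar weights
print(float(aff_block.max()) <= 1.0)  # exp(-distance) never exceeds 1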

Define the affinity function between two neural networks. This function takes two neural network modules and constructs the corresponding affinity matrix, which is then processed by the graph matching solver.

def graph_matching_fusion(networks: list):
    def total_node_num(network: paddle.nn.Layer):
        # count the total number of nodes in the network [network]
        num_nodes = 0
        for idx, (name, parameters) in enumerate(network.named_parameters()):
            if 'bias' in name:
                continue
            if idx == 0:
                num_nodes += parameters.shape[1]
            # Paddle stores linear weights as (in_features, out_features),
            # so count the output nodes of 'fc' layers from shape[1]
            num_nodes += parameters.shape[0] if 'fc' not in name else parameters.shape[1]
        return num_nodes

    n1 = total_node_num(network=networks[0])
    n2 = total_node_num(network=networks[1])
    assert (n1 == n2)
    affinity = paddle.zeros([n1 * n2, n1 * n2])
    num_layers = len(list(zip(networks[0].parameters(), networks[1].parameters())))
    num_nodes_before = 0
    num_nodes_incremental = []
    num_nodes_layers = []
    pre_conv_list = []
    cur_conv_list = []
    conv_kernel_size_list = []
    num_nodes_pre = 0
    is_conv = False
    pre_conv = False
    pre_conv_out_channel = 1
    is_final_bias = False
    perm_is_complete = True
    named_weight_list_0 = [named_parameter for named_parameter in networks[0].named_parameters()]
    for idx, ((name_0, fc_layer0_weight), (name_1, fc_layer1_weight)) in \
            enumerate(zip(networks[0].named_parameters(), networks[1].named_parameters())):
        assert fc_layer0_weight.shape == fc_layer1_weight.shape
        if 'fc' in name_0:
            fc_layer0_weight = fc_layer0_weight.t()
            fc_layer1_weight = fc_layer1_weight.t()
        layer_shape = fc_layer0_weight.shape
        num_nodes_cur = fc_layer0_weight.shape[0]
        if len(layer_shape) > 1:
            if is_conv is True and len(layer_shape) == 2:
                num_nodes_pre = pre_conv_out_channel
            else:
                num_nodes_pre = fc_layer0_weight.shape[1]
        if idx >= 1 and len(named_weight_list_0[idx - 1][1].shape) == 1:
            pre_bias = True
        else:
            pre_bias = False
        if len(layer_shape) > 2:
            is_bias = False
            if not pre_bias:
                pre_conv = is_conv
                pre_conv_list.append(pre_conv)
            is_conv = True
            cur_conv_list.append(is_conv)
            fc_layer0_weight_data = fc_layer0_weight.detach().reshape(
                (fc_layer0_weight.shape[0], fc_layer0_weight.shape[1], -1))
            fc_layer1_weight_data = fc_layer1_weight.detach().reshape(
                (fc_layer1_weight.shape[0], fc_layer1_weight.shape[1], -1))
        elif len(layer_shape) == 2:
            is_bias = False
            if not pre_bias:
                pre_conv = is_conv
                pre_conv_list.append(pre_conv)
            is_conv = False
            cur_conv_list.append(is_conv)
            fc_layer0_weight_data = fc_layer0_weight.detach()
            fc_layer1_weight_data = fc_layer1_weight.detach()
        else:
            is_bias = True
            if not pre_bias:
                pre_conv = is_conv
                pre_conv_list.append(pre_conv)
            is_conv = False
            cur_conv_list.append(is_conv)
            fc_layer0_weight_data = fc_layer0_weight.detach()
            fc_layer1_weight_data = fc_layer1_weight.detach()
        if is_conv:
            pre_conv_out_channel = num_nodes_cur
        if is_bias is True and idx == num_layers - 1:
            is_final_bias = True
        if idx == 0:
            for a in range(num_nodes_pre):
                affinity[(num_nodes_before + a) * n2 + num_nodes_before + a, \
                         (num_nodes_before + a) * n2 + num_nodes_before + a] \
                        = 1
        if idx == num_layers - 2 and 'bias' in named_weight_list_0[idx + 1][0] or \
                idx == num_layers - 1 and 'bias' not in named_weight_list_0[idx][0]:
            for a in range(num_nodes_cur):
                affinity[(num_nodes_before + num_nodes_pre + a) * n2 + num_nodes_before + num_nodes_pre + a, \
                         (num_nodes_before + num_nodes_pre + a) * n2 + num_nodes_before + num_nodes_pre + a] \
                        = 1
        if is_bias is False:
            ground_metric = Ground_Metric_GM(
                fc_layer0_weight_data, fc_layer1_weight_data, is_conv, is_bias,
                pre_conv, int(fc_layer0_weight_data.shape[1] / pre_conv_out_channel))
        else:
            ground_metric = Ground_Metric_GM(
                fc_layer0_weight_data, fc_layer1_weight_data, is_conv, is_bias,
                pre_conv, 1)

        layer_affinity = ground_metric.process_soft_affinity(p=2)

        if is_bias is False:
            pre_conv_kernel_size = fc_layer0_weight.shape[3] if is_conv else None
            conv_kernel_size_list.append(pre_conv_kernel_size)
        if is_bias is True and is_final_bias is False:
            for a in range(num_nodes_cur):
                for c in range(num_nodes_cur):
                    affinity[(num_nodes_before + a) * n2 + num_nodes_before + c, \
                             (num_nodes_before + a) * n2 + num_nodes_before + c] \
                            = layer_affinity[a][c]
        elif is_final_bias is False:
            for a in range(num_nodes_pre):
                for b in range(num_nodes_cur):
                    affinity[
                    (num_nodes_before + a) * n2 + num_nodes_before:
                    (num_nodes_before + a) * n2 + num_nodes_before + num_nodes_pre,
                    (num_nodes_before + num_nodes_pre + b) * n2 + num_nodes_before + num_nodes_pre:
                    (num_nodes_before + num_nodes_pre + b) * n2 + num_nodes_before + num_nodes_pre + num_nodes_cur] \
                        = layer_affinity[a + b * num_nodes_pre].reshape((num_nodes_cur, num_nodes_pre)).t()
        if is_bias is False:
            num_nodes_before += num_nodes_pre
            num_nodes_incremental.append(num_nodes_before)
            num_nodes_layers.append(num_nodes_cur)
    # affinity = (affinity + affinity.t()) / 2
    return affinity, [n1, n2, num_nodes_incremental, num_nodes_layers, cur_conv_list, conv_kernel_size_list]

Get the affinity (similarity) matrix between model1 and model2.

K, params = graph_matching_fusion([model1, model2])
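
As an optional check (not in the original example), the affinity matrix and the layer bookkeeping returned in params can be inspected. K has shape (n1*n2, n1*n2), where n1 = n2 is the total number of channel nodes per network, and params also records where each layer's nodes start and how many nodes each layer contributes.

# Optional inspection of the affinity matrix and the layer bookkeeping.
print(K.shape)               # expected: [n1 * n2, n1 * n2]
print(params[0], params[1])  # total node counts of the two networks (equal)
print(params[2])             # starting node index of each (non-bias) layer
print(params[3])             # number of channel nodes contributed by each layer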

Align the models by graph matching

Align the channels of model1 & model2 by maximizing the affinity (similarity) via graph matching algorithms.

n1 = params[0]
n2 = params[1]
X = pygm.sm(K, n1, n2)
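
pygm.sm returns a relaxed, real-valued matching matrix; the optional check below (not part of the original example) only confirms its shape and value range before it is discretized in the next step.

# Optional check: the spectral matching result is a soft matching of shape (n1, n2).
print(X.shape)                         # expected: [n1, n2]
print(float(X.min()), float(X.max()))  # relaxed scores, projected to a permutation below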

Project X to the neural network matching result. The neural network matching matrix is built by applying Hungarian to small blocks of X, because only channels from the same neural network layer can be matched.

Note

In this example, we assume the last FC layer is aligned and need not be matched.

new_X = paddle.zeros_like(X)
new_X[:params[2][0], :params[2][0]] = paddle.eye(params[2][0])
for start_idx, length in zip(params[2][:-1], params[3][:-1]):  # params[2]: starting node index of each layer; params[3]: number of nodes per layer
    slicing = slice(start_idx, start_idx + length)
    new_X[slicing, slicing] = pygm.hungarian(X[slicing, slicing])
# assume the last FC layer is aligned
slicing = slice(params[2][-1], params[2][-1] + params[3][-1])
new_X[slicing, slicing] = paddle.eye(params[3][-1])
X = new_X
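
A small, optional verification (an addition, not part of the original example): after the block-wise Hungarian projection, every diagonal block of X should be a permutation matrix, i.e. each row and each column within a layer block sums to one.

# Optional verification that each layer block of X is a permutation matrix.
for start_idx, length in zip(params[2], params[3]):
    block = X[start_idx:start_idx + length, start_idx:start_idx + length]
    assert paddle.allclose(block.sum(axis=0), paddle.ones([length])).item()
    assert paddle.allclose(block.sum(axis=1), paddle.ones([length])).item()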

Visualization of the matching result. The black lines split the channels of different layers.

plt.figure(figsize=(4, 4))
plt.imshow(X.cpu().numpy(), cmap='Blues')
for idx in params[2]:
    plt.axvline(x=idx, color='k')
    plt.axhline(y=idx, color='k')
plot model fusion paddle

Define the alignment function: fuse the models based on matching result

def align(solution, fusion_proportion, networks: list, params: list):
    [_, _, num_nodes_incremental, num_nodes_layers, cur_conv_list, conv_kernel_size_list] = params
    named_weight_list_0 = [named_parameter for named_parameter in networks[0].named_parameters()]
    aligned_wt_0 = [parameter.detach() if 'fc' not in name else parameter.detach().t() for name, parameter in named_weight_list_0]
    idx = 0
    num_layers = len(aligned_wt_0)
    for num_before, num_cur, cur_conv, cur_kernel_size in \
            zip(num_nodes_incremental, num_nodes_layers, cur_conv_list, conv_kernel_size_list):
        perm = solution[num_before:num_before + num_cur, num_before:num_before + num_cur]
        assert 'bias' not in named_weight_list_0[idx][0]
        if len(named_weight_list_0[idx][1].shape) == 4:
            aligned_wt_0[idx] = (perm.t().cast(paddle.float64) @
                                 aligned_wt_0[idx].cast(paddle.float64).transpose((2, 3, 0, 1))) \
                .transpose((2, 3, 0, 1))
        else:
            aligned_wt_0[idx] = perm.t().cast(paddle.float64) @ aligned_wt_0[idx].cast(paddle.float64)
        idx += 1
        if idx >= num_layers:
            continue
        if 'bias' in named_weight_list_0[idx][0]:
            aligned_wt_0[idx] = aligned_wt_0[idx].cast(paddle.float64) @ perm.cast(paddle.float64)
            idx += 1
        if idx >= num_layers:
            continue
        if cur_conv and len(named_weight_list_0[idx][1].shape) == 2:
            aligned_wt_0[idx] = (aligned_wt_0[idx].cast(paddle.float64)
                                 .reshape((aligned_wt_0[idx].shape[0], 64, -1))
                                 .transpose((0, 2, 1))
                                 @ perm.cast(paddle.float64)) \
                .transpose((0, 2, 1)) \
                .reshape((aligned_wt_0[idx].shape[0], -1))
        elif len(named_weight_list_0[idx][1].shape) == 4:
            aligned_wt_0[idx] = (aligned_wt_0[idx].cast(paddle.float64)
                                 .transpose((2, 3, 0, 1))
                                 @ perm.cast(paddle.float64)) \
                .transpose((2, 3, 0, 1))
        else:
            aligned_wt_0[idx] = aligned_wt_0[idx].cast(paddle.float64) @ perm.cast(paddle.float64)
    assert idx == num_layers

    averaged_weights = []
    for idx, (named, parameter) in enumerate(networks[1].named_parameters()):
        parameter = parameter.t() if 'fc' in named else parameter
        averaged_weights.append((1 - fusion_proportion) * aligned_wt_0[idx].cast('float32') + fusion_proportion * parameter)
    return averaged_weights

Test the fused model

The fusion_proportion variable denotes the contribution of model2 to the fused model. For example, if fusion_proportion=0.2, the fused model = 80% model1 + 20% model2.
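
Before the full sweep, a single-proportion sketch (an optional illustration using the align function and matching result from above) shows what one fusion step produces: each fused tensor is a per-parameter interpolation between the aligned model1 weights and the model2 weights.

# Optional single-step illustration: fuse at fusion_proportion = 0.2,
# i.e. fused weight = 0.8 * (aligned model1 weight) + 0.2 * (model2 weight).
example_weights = align(X, 0.2, [model1, model2], params)
print([list(w.shape) for w in example_weights])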

def align_model_and_test(X):
    acc_list = []
    for fusion_proportion in paddle.arange(0, 11, 1) / 10: # paddle arange accepts int step only
        fused_weights = align(X, fusion_proportion, [model1, model2], params)

        fused_model = SimpleNet()
        state_dict = fused_model.state_dict()
        for idx, (key, _) in enumerate(state_dict.items()):
            state_dict[key] = fused_weights[idx].t() if 'fc' in key else fused_weights[idx]
        fused_model.set_dict(state_dict)
        fused_model.to(device)
        test_loss = 0
        correct = 0
        for data, target in test_loader:
            output = fused_model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.detach().argmax(1, keepdim=True)
            correct += pred.equal(target.detach().reshape(pred.shape)).sum()
        test_loss /= len(test_loader.dataset)
        acc = 100. * correct / len(test_loader.dataset)
        print(
            f"{1 - fusion_proportion.item():.2f} model1 + {fusion_proportion.item():.2f} model2 -> fused model accuracy: {acc.item():.2f}%")
        acc_list.append(acc)
    return paddle.to_tensor(acc_list)


print('Graph Matching Fusion')
gm_acc_list = align_model_and_test(X)
Graph Matching Fusion
1.00 model1 + 0.00 model2 -> fused model accuracy: 84.18%
0.90 model1 + 0.10 model2 -> fused model accuracy: 85.12%
0.80 model1 + 0.20 model2 -> fused model accuracy: 85.21%
0.70 model1 + 0.30 model2 -> fused model accuracy: 82.52%
0.60 model1 + 0.40 model2 -> fused model accuracy: 71.11%
0.50 model1 + 0.50 model2 -> fused model accuracy: 53.74%
0.40 model1 + 0.60 model2 -> fused model accuracy: 63.26%
0.30 model1 + 0.70 model2 -> fused model accuracy: 78.51%
0.20 model1 + 0.80 model2 -> fused model accuracy: 82.81%
0.10 model1 + 0.90 model2 -> fused model accuracy: 83.97%
0.00 model1 + 1.00 model2 -> fused model accuracy: 83.81%

Compared with vanilla model fusion (no matching), the graph matching method stabilizes the fusion step:

print('No Matching Fusion')
vanilla_acc_list = align_model_and_test(paddle.eye(n1))

plt.figure(figsize=(4, 4))
plt.title('Fused Model Accuracy')
plt.plot((paddle.arange(0, 11, 1) / 10).numpy(), gm_acc_list.cpu().numpy(), 'r*-', label='Graph Matching Fusion')
plt.plot((paddle.arange(0, 11, 1) / 10).numpy(), vanilla_acc_list.cpu().numpy(), 'b*-', label='No Matching Fusion')
plt.plot((paddle.arange(0, 11, 1) / 10).numpy(), [acc1] * 11, '--', color="gray", label='Model1 Accuracy')
plt.plot((paddle.arange(0, 11, 1) / 10).numpy(), [acc2] * 11, '--', color="brown", label='Model2 Accuracy')
plt.gca().set_xlabel('Fusion Proportion')
plt.gca().set_ylabel('Accuracy (%)')
plt.ylim((70, 87))
plt.legend(loc=3)
Fused Model Accuracy
No Matching Fusion
1.00 model1 + 0.00 model2 -> fused model accuracy: 84.18%
0.90 model1 + 0.10 model2 -> fused model accuracy: 84.01%
0.80 model1 + 0.20 model2 -> fused model accuracy: 81.91%
0.70 model1 + 0.30 model2 -> fused model accuracy: 74.67%
0.60 model1 + 0.40 model2 -> fused model accuracy: 60.39%
0.50 model1 + 0.50 model2 -> fused model accuracy: 47.16%
0.40 model1 + 0.60 model2 -> fused model accuracy: 55.34%
0.30 model1 + 0.70 model2 -> fused model accuracy: 72.86%
0.20 model1 + 0.80 model2 -> fused model accuracy: 79.64%
0.10 model1 + 0.90 model2 -> fused model accuracy: 82.56%
0.00 model1 + 1.00 model2 -> fused model accuracy: 83.81%

<matplotlib.legend.Legend object at 0x7fdef9dfde10>