.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/5.model_fusion/plot_model_fusion_paddle.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_5.model_fusion_plot_model_fusion_paddle.py:


======================================================
Paddle Backend Example: Model Fusion by Graph Matching
======================================================

This example shows how to fuse different models into a single model with ``pygmtools``.
Model fusion aims to merge multiple models into one so that the fused model can achieve higher performance.
A neural network can be regarded as a graph (channels are nodes and the update functions between channels
are edges; biases are node features and weights are edge features), so fusing models is equivalent to
solving a graph matching problem. In this example, the two given models are trained on MNIST data drawn
from different distributions; the fused model combines the knowledge of both input models and can reach
higher test accuracy.

.. GENERATED FROM PYTHON SOURCE LINES 14-21

.. code-block:: default


    # Author: Chang Liu
    #         Runzhong Wang
    #         Wenzheng Pan
    #
    # License: Mulan PSL v2 License


.. GENERATED FROM PYTHON SOURCE LINES 23-34

.. note::
    This is a simplified implementation of the ideas in `Liu et al. Deep Neural Network Fusion via Graph Matching with Applications to Model Ensemble and Federated Learning. ICML 2022. `_
    For more details, please refer to the paper and the `official code repository `_.

.. note::
    The following solvers are included in this example:

    * :func:`~pygmtools.classic_solvers.sm` (classic solver)

    * :func:`~pygmtools.linear_solvers.hungarian` (linear solver)

.. GENERATED FROM PYTHON SOURCE LINES 34-50

.. code-block:: default

    import paddle
    import paddle.nn as nn
    import paddle.nn.functional as F
    import paddle.vision.transforms as transforms
    import time
    from PIL import Image
    import matplotlib.pyplot as plt
    import pygmtools as pygm
    import warnings
    warnings.filterwarnings("ignore")

    pygm.set_backend('paddle')
    device = paddle.device.get_device()
    paddle.device.set_device(device)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    Place(cpu)


.. GENERATED FROM PYTHON SOURCE LINES 51-54

Define a simple CNN classifier network
---------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 54-74

.. code-block:: default

    class SimpleNet(nn.Layer):
        def __init__(self):
            super(SimpleNet, self).__init__()
            self.conv1 = nn.Conv2D(1, 32, 5, padding=1, padding_mode='replicate', bias_attr=False)
            self.max_pool = nn.MaxPool2D(2, padding=1)
            self.conv2 = nn.Conv2D(32, 64, 5, padding=1, padding_mode='replicate', bias_attr=False)
            self.fc1 = nn.Linear(3136, 32, bias_attr=False)
            self.fc2 = nn.Linear(32, 10, bias_attr=False)

        def forward(self, x):
            output = F.relu(self.conv1(x))
            output = self.max_pool(output)
            output = F.relu(self.conv2(output))
            output = self.max_pool(output)
            # flatten the feature maps before the fully-connected layers
            output = output.reshape((output.shape[0], -1))
            output = self.fc1(output)
            output = self.fc2(output)
            return output


.. GENERATED FROM PYTHON SOURCE LINES 75-78

Load the trained models to be fused
------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 78-95

.. code-block:: default

    model1 = SimpleNet()
    model2 = SimpleNet()
    model1.set_dict(paddle.load('../data/example_model_fusion_1_paddle.dat'))
    model2.set_dict(paddle.load('../data/example_model_fusion_2_paddle.dat'))
    model1.to(device)
    model2.to(device)
    test_dataset = paddle.vision.datasets.MNIST(
        # unable to modify the directory to store the dataset.
        # default: ~/.cache/paddle/dataset/mnist
        mode='test',  # the dataset is used to test
        transform=transforms.ToTensor(),  # the dataset is in the form of tensors
        download=True)
    test_loader = paddle.io.DataLoader(
        dataset=test_dataset,
        batch_size=32,
        shuffle=False)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Cache file /home/wzever/.cache/paddle/dataset/mnist/t10k-images-idx3-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/t10k-images-idx3-ubyte.gz
    Begin to download
    item 403/403 [============================>.] - ETA: 0s - 4ms/item
    Download finished
    Cache file /home/wzever/.cache/paddle/dataset/mnist/t10k-labels-idx1-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/t10k-labels-idx1-ubyte.gz
    Begin to download
    item 2/2 [===========================>..] - ETA: 0s - 187us/item
    Download finished


..
   GENERATED FROM PYTHON SOURCE LINES 96-98

Print the layers of the simple CNN model:

.. GENERATED FROM PYTHON SOURCE LINES 98-100

.. code-block:: default

    print(model1)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    SimpleNet(
      (conv1): Conv2D(1, 32, kernel_size=[5, 5], padding=1, padding_mode=replicate, data_format=NCHW)
      (max_pool): MaxPool2D(kernel_size=2, stride=None, padding=1)
      (conv2): Conv2D(32, 64, kernel_size=[5, 5], padding=1, padding_mode=replicate, data_format=NCHW)
      (fc1): Linear(in_features=3136, out_features=32, dtype=None)
      (fc2): Linear(in_features=32, out_features=10, dtype=None)
    )


.. GENERATED FROM PYTHON SOURCE LINES 101-104

Test the input models
------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 104-119

.. code-block:: default

    with paddle.no_grad():
        n_correct1 = 0
        n_correct2 = 0
        n_samples = 0
        for images, labels in test_loader:
            outputs1 = model1(images)
            outputs2 = model2(images)
            predictions1 = paddle.argmax(outputs1, 1)
            predictions2 = paddle.argmax(outputs2, 1)
            n_samples += labels.shape[0]
            n_correct1 += (predictions1 == labels.t()).sum().item()
            n_correct2 += (predictions2 == labels.t()).sum().item()
        acc1 = 100 * n_correct1 / n_samples
        acc2 = 100 * n_correct2 / n_samples


.. GENERATED FROM PYTHON SOURCE LINES 120-122

Testing results (two separate models):

.. GENERATED FROM PYTHON SOURCE LINES 122-124

.. code-block:: default

    print(f'model1 accuracy = {acc1}%, model2 accuracy = {acc2}%')


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    model1 accuracy = 84.18%, model2 accuracy = 83.81%


.. GENERATED FROM PYTHON SOURCE LINES 125-131

Build the affinity matrix for graph matching
---------------------------------------------
As shown in the following plot, the neural networks can be regarded as graphs. The weights correspond to
the edge features, and the biases correspond to the node features. In this example, the networks
have no bias terms, so there are only edge features.

..
   GENERATED FROM PYTHON SOURCE LINES 131-138

.. code-block:: default

    plt.figure(figsize=(8, 4))
    img = Image.open('../data/model_fusion.png')
    plt.imshow(img)
    plt.axis('off')
    st_time = time.perf_counter()


.. image-sg:: /auto_examples/5.model_fusion/images/sphx_glr_plot_model_fusion_paddle_001.png
   :alt: plot model fusion paddle
   :srcset: /auto_examples/5.model_fusion/images/sphx_glr_plot_model_fusion_paddle_001.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 139-141

Define the graph matching affinity metric function

.. GENERATED FROM PYTHON SOURCE LINES 141-179

.. code-block:: default

    class Ground_Metric_GM:
        def __init__(self,
                     model_1_param: paddle.Tensor = None,
                     model_2_param: paddle.Tensor = None,
                     conv_param: bool = False,
                     bias_param: bool = False,
                     pre_conv_param: bool = False,
                     pre_conv_image_size_squared: int = None):
            self.model_1_param = model_1_param
            self.model_2_param = model_2_param
            self.conv_param = conv_param
            self.bias_param = bias_param
            # bias, or fully-connected from linear
            if bias_param is True or (conv_param is False and pre_conv_param is False):
                self.model_1_param = self.model_1_param.reshape((1, -1, 1))
                self.model_2_param = self.model_2_param.reshape((1, -1, 1))
            # fully-connected from conv
            elif conv_param is False and pre_conv_param is True:
                self.model_1_param = self.model_1_param.reshape((1, -1, pre_conv_image_size_squared))
                self.model_2_param = self.model_2_param.reshape((1, -1, pre_conv_image_size_squared))
            # conv
            else:
                self.model_1_param = self.model_1_param.reshape((1, -1, model_1_param.shape[-1]))
                self.model_2_param = self.model_2_param.reshape((1, -1, model_2_param.shape[-1]))

        def process_distance(self, p: int = 2):
            dist = []
            cdist = paddle.nn.PairwiseDistance(p)
            param_1 = self.model_1_param.cast('float32')[0]
            param_2 = self.model_2_param.cast('float32')[0]
            for i in param_1:
                dist.append(cdist(i.broadcast_to(param_2.shape), param_2))
            return paddle.to_tensor(dist)

        def process_soft_affinity(self, p: int = 2):
            return paddle.exp(0 -
                              self.process_distance(p=p))


.. GENERATED FROM PYTHON SOURCE LINES 180-183

Define the affinity function between two neural networks. This function takes multiple neural network modules
and constructs the corresponding affinity matrix, which is further processed by the graph matching solver.

.. GENERATED FROM PYTHON SOURCE LINES 183-312

.. code-block:: default

    def graph_matching_fusion(networks: list):
        def total_node_num(network: paddle.nn.Layer):
            # count the total number of nodes in the network [network]
            num_nodes = 0
            for idx, (name, parameters) in enumerate(network.named_parameters()):
                if 'bias' in name:
                    continue
                if idx == 0:
                    num_nodes += parameters.shape[1]
                # transpose linear layers in paddle to conventional shape,
                num_nodes += parameters.shape[0] if 'fc' not in name else parameters.shape[1]
            return num_nodes

        n1 = total_node_num(network=networks[0])
        n2 = total_node_num(network=networks[1])
        assert (n1 == n2)
        affinity = paddle.zeros([n1 * n2, n1 * n2])
        num_layers = len(list(zip(networks[0].parameters(), networks[1].parameters())))
        num_nodes_before = 0
        num_nodes_incremental = []
        num_nodes_layers = []
        pre_conv_list = []
        cur_conv_list = []
        conv_kernel_size_list = []
        num_nodes_pre = 0
        is_conv = False
        pre_conv = False
        pre_conv_out_channel = 1
        is_final_bias = False
        perm_is_complete = True
        named_weight_list_0 = [named_parameter for named_parameter in networks[0].named_parameters()]
        for idx, ((name_0, fc_layer0_weight), (name_1, fc_layer1_weight)) in \
                enumerate(zip(networks[0].named_parameters(), networks[1].named_parameters())):
            assert fc_layer0_weight.shape == fc_layer1_weight.shape
            if 'fc' in name_0:
                fc_layer0_weight = fc_layer0_weight.t()
                fc_layer1_weight = fc_layer1_weight.t()
            layer_shape = fc_layer0_weight.shape
            num_nodes_cur = fc_layer0_weight.shape[0]
            if len(layer_shape) > 1:
                if is_conv is True and len(layer_shape) == 2:
                    num_nodes_pre = pre_conv_out_channel
                else:
                    num_nodes_pre = fc_layer0_weight.shape[1]
            if idx >= 1 and len(named_weight_list_0[idx - 1][1].shape) == 1:
                pre_bias = True
            else:
                pre_bias = False
            if len(layer_shape) > 2:
                is_bias = False
                if not pre_bias:
                    pre_conv = is_conv
                    pre_conv_list.append(pre_conv)
                is_conv = True
                cur_conv_list.append(is_conv)
                fc_layer0_weight_data = fc_layer0_weight.detach().reshape(
                    (fc_layer0_weight.shape[0], fc_layer0_weight.shape[1], -1))
                fc_layer1_weight_data = fc_layer1_weight.detach().reshape(
                    (fc_layer1_weight.shape[0], fc_layer1_weight.shape[1], -1))
            elif len(layer_shape) == 2:
                is_bias = False
                if not pre_bias:
                    pre_conv = is_conv
                    pre_conv_list.append(pre_conv)
                is_conv = False
                cur_conv_list.append(is_conv)
                fc_layer0_weight_data = fc_layer0_weight.detach()
                fc_layer1_weight_data = fc_layer1_weight.detach()
            else:
                is_bias = True
                if not pre_bias:
                    pre_conv = is_conv
                    pre_conv_list.append(pre_conv)
                is_conv = False
                cur_conv_list.append(is_conv)
                fc_layer0_weight_data = fc_layer0_weight.detach()
                fc_layer1_weight_data = fc_layer1_weight.detach()
            if is_conv:
                pre_conv_out_channel = num_nodes_cur
            if is_bias is True and idx == num_layers - 1:
                is_final_bias = True
            if idx == 0:
                for a in range(num_nodes_pre):
                    affinity[(num_nodes_before + a) * n2 + num_nodes_before + a, \
                             (num_nodes_before + a) * n2 + num_nodes_before + a] \
                        = 1
            if idx == num_layers - 2 and 'bias' in named_weight_list_0[idx + 1][0] or \
                    idx == num_layers - 1 and 'bias' not in named_weight_list_0[idx][0]:
                for a in range(num_nodes_cur):
                    affinity[(num_nodes_before + num_nodes_pre + a) * n2 + num_nodes_before + num_nodes_pre + a, \
                             (num_nodes_before + num_nodes_pre + a) * n2 + num_nodes_before + num_nodes_pre + a] \
                        = 1
            if is_bias is False:
                ground_metric = Ground_Metric_GM(
                    fc_layer0_weight_data, fc_layer1_weight_data, is_conv, is_bias,
                    pre_conv, int(fc_layer0_weight_data.shape[1] / pre_conv_out_channel))
            else:
                ground_metric = Ground_Metric_GM(
                    fc_layer0_weight_data, fc_layer1_weight_data, is_conv, is_bias,
                    pre_conv, 1)

            layer_affinity = ground_metric.process_soft_affinity(p=2)

            if is_bias is False:
                pre_conv_kernel_size = fc_layer0_weight.shape[3] if is_conv else None
                conv_kernel_size_list.append(pre_conv_kernel_size)
            if is_bias is True and is_final_bias is False:
                for a in range(num_nodes_cur):
                    for c in range(num_nodes_cur):
                        affinity[(num_nodes_before + a) * n2 + num_nodes_before + c, \
                                 (num_nodes_before + a) * n2 + num_nodes_before + c] \
                            = layer_affinity[a][c]
            elif is_final_bias is False:
                for a in range(num_nodes_pre):
                    for b in range(num_nodes_cur):
                        affinity[
                        (num_nodes_before + a) * n2 + num_nodes_before:
                        (num_nodes_before + a) * n2 + num_nodes_before + num_nodes_pre,
                        (num_nodes_before + num_nodes_pre + b) * n2 + num_nodes_before + num_nodes_pre:
                        (num_nodes_before + num_nodes_pre + b) * n2 + num_nodes_before + num_nodes_pre + num_nodes_cur] \
                            = layer_affinity[a + b * num_nodes_pre].reshape((num_nodes_cur, num_nodes_pre)).t()
            if is_bias is False:
                num_nodes_before += num_nodes_pre
                num_nodes_incremental.append(num_nodes_before)
                num_nodes_layers.append(num_nodes_cur)
        # affinity = (affinity + affinity.t()) / 2
        return affinity, [n1, n2, num_nodes_incremental, num_nodes_layers, cur_conv_list, conv_kernel_size_list]


.. GENERATED FROM PYTHON SOURCE LINES 313-315

Get the affinity (similarity) matrix between model1 and model2.

.. GENERATED FROM PYTHON SOURCE LINES 315-317

.. code-block:: default

    K, params = graph_matching_fusion([model1, model2])


.. GENERATED FROM PYTHON SOURCE LINES 318-322

Align the models by graph matching
-----------------------------------
Align the channels of model1 & model2 by maximizing the affinity (similarity) via graph matching algorithms.

.. GENERATED FROM PYTHON SOURCE LINES 322-326

.. code-block:: default

    n1 = params[0]
    n2 = params[1]
    X = pygm.sm(K, n1, n2)


.. GENERATED FROM PYTHON SOURCE LINES 327-334

Project ``X`` to a neural network matching result. The neural network matching matrix is built by applying
Hungarian to small blocks of ``X``, because only the channels from the same neural network layer can be matched.

.. note::
    In this example, we assume the last FC layer is aligned and need not be matched.

.. GENERATED FROM PYTHON SOURCE LINES 334-344

.. code-block:: default

    new_X = paddle.zeros_like(X)
    new_X[:params[2][0], :params[2][0]] = paddle.eye(params[2][0])
    # params[2] holds the start offset of each layer's nodes, params[3] the number of nodes per layer
    for start_idx, length in zip(params[2][:-1], params[3][:-1]):
        slicing = slice(start_idx, start_idx + length)
        new_X[slicing, slicing] = pygm.hungarian(X[slicing, slicing])
    # assume the last FC layer is aligned
    slicing = slice(params[2][-1], params[2][-1] + params[3][-1])
    new_X[slicing, slicing] = paddle.eye(params[3][-1])
    X = new_X


.. GENERATED FROM PYTHON SOURCE LINES 345-347

Visualization of the matching result. The black lines split the channels of different layers.

.. GENERATED FROM PYTHON SOURCE LINES 347-354

.. code-block:: default

    plt.figure(figsize=(4, 4))
    plt.imshow(X.cpu().numpy(), cmap='Blues')
    for idx in params[2]:
        plt.axvline(x=idx, color='k')
        plt.axhline(y=idx, color='k')


.. image-sg:: /auto_examples/5.model_fusion/images/sphx_glr_plot_model_fusion_paddle_002.png
   :alt: plot model fusion paddle
   :srcset: /auto_examples/5.model_fusion/images/sphx_glr_plot_model_fusion_paddle_002.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 355-357

Define the alignment function: fuse the models based on the matching result

.. GENERATED FROM PYTHON SOURCE LINES 357-404

.. code-block:: default

    def align(solution, fusion_proportion, networks: list, params: list):
        [_, _, num_nodes_incremental, num_nodes_layers, cur_conv_list, conv_kernel_size_list] = params
        named_weight_list_0 = [named_parameter for named_parameter in networks[0].named_parameters()]
        aligned_wt_0 = [parameter.detach() if 'fc' not in name else parameter.detach().t()
                        for name, parameter in named_weight_list_0]
        idx = 0
        num_layers = len(aligned_wt_0)
        for num_before, num_cur, cur_conv, cur_kernel_size in \
                zip(num_nodes_incremental, num_nodes_layers, cur_conv_list, conv_kernel_size_list):
            perm = solution[num_before:num_before + num_cur, num_before:num_before + num_cur]
            assert 'bias' not in named_weight_list_0[idx][0]
            if len(named_weight_list_0[idx][1].shape) == 4:
                aligned_wt_0[idx] = (perm.t().cast(paddle.float64) @
                                     aligned_wt_0[idx].cast(paddle.float64).transpose((2, 3, 0, 1))) \
                    .transpose((2, 3, 0, 1))
            else:
                aligned_wt_0[idx] = perm.t().cast(paddle.float64) @ aligned_wt_0[idx].cast(paddle.float64)
            idx += 1
            if idx >= num_layers:
                continue
            if 'bias' in named_weight_list_0[idx][0]:
                aligned_wt_0[idx] = aligned_wt_0[idx].cast(paddle.float64) @ perm.cast(paddle.float64)
                idx += 1
            if idx >= num_layers:
                continue
            if cur_conv and len(named_weight_list_0[idx][1].shape) == 2:
                aligned_wt_0[idx] = (aligned_wt_0[idx].cast(paddle.float64)
                                     .reshape((aligned_wt_0[idx].shape[0], 64, -1))
                                     .transpose((0, 2, 1))
                                     @ perm.cast(paddle.float64)) \
                    .transpose((0, 2, 1)) \
                    .reshape((aligned_wt_0[idx].shape[0], -1))
            elif len(named_weight_list_0[idx][1].shape) == 4:
                aligned_wt_0[idx] = (aligned_wt_0[idx].cast(paddle.float64)
                                     .transpose((2, 3, 0, 1))
                                     @ perm.cast(paddle.float64)) \
                    .transpose((2, 3, 0, 1))
            else:
                aligned_wt_0[idx] = aligned_wt_0[idx].cast(paddle.float64) @ perm.cast(paddle.float64)
        assert idx == num_layers

        averaged_weights = []
        for idx, (named, parameter) in enumerate(networks[1].named_parameters()):
            parameter = parameter.t() if 'fc' in named else parameter
            averaged_weights.append((1 -
                                     fusion_proportion) * aligned_wt_0[idx].cast('float32') + fusion_proportion * parameter)
        return averaged_weights


.. GENERATED FROM PYTHON SOURCE LINES 405-410

Test the fused model
---------------------
The ``fusion_proportion`` variable denotes the contribution of model2 to the fused model. For example,
if ``fusion_proportion=0.2``, the fused model = 80% model1 + 20% model2.

.. GENERATED FROM PYTHON SOURCE LINES 410-439

.. code-block:: default

    def align_model_and_test(X):
        acc_list = []
        for fusion_proportion in paddle.arange(0, 11, 1) / 10:  # paddle arange accepts int step only
            fused_weights = align(X, fusion_proportion, [model1, model2], params)
            fused_model = SimpleNet()
            state_dict = fused_model.state_dict()
            for idx, (key, _) in enumerate(state_dict.items()):
                state_dict[key] = fused_weights[idx].t() if 'fc' in key else fused_weights[idx]
            fused_model.set_dict(state_dict)
            fused_model.to(device)
            test_loss = 0
            correct = 0
            for data, target in test_loader:
                output = fused_model(data)
                test_loss += F.nll_loss(output, target, reduction='sum').item()
                pred = output.detach().argmax(1, keepdim=True)
                correct += pred.equal(target.detach().reshape(pred.shape)).sum()
            test_loss /= len(test_loader.dataset)
            acc = 100. * correct / len(test_loader.dataset)
            print(
                f"{1 - fusion_proportion.item():.2f} model1 + {fusion_proportion.item():.2f} model2 -> fused model accuracy: {acc.item():.2f}%")
            acc_list.append(acc)
        return paddle.to_tensor(acc_list)


    print('Graph Matching Fusion')
    gm_acc_list = align_model_and_test(X)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Graph Matching Fusion
    1.00 model1 + 0.00 model2 -> fused model accuracy: 84.18%
    0.90 model1 + 0.10 model2 -> fused model accuracy: 85.12%
    0.80 model1 + 0.20 model2 -> fused model accuracy: 85.21%
    0.70 model1 + 0.30 model2 -> fused model accuracy: 82.52%
    0.60 model1 + 0.40 model2 -> fused model accuracy: 71.11%
    0.50 model1 + 0.50 model2 -> fused model accuracy: 53.74%
    0.40 model1 + 0.60 model2 -> fused model accuracy: 63.26%
    0.30 model1 + 0.70 model2 -> fused model accuracy: 78.51%
    0.20 model1 + 0.80 model2 -> fused model accuracy: 82.81%
    0.10 model1 + 0.90 model2 -> fused model accuracy: 83.97%
    0.00 model1 + 1.00 model2 -> fused model accuracy: 83.81%


.. GENERATED FROM PYTHON SOURCE LINES 440-442

Compared with vanilla model fusion (no matching), the graph matching method stabilizes the fusion step,
giving higher accuracy at every intermediate fusion proportion in the run shown here:

.. GENERATED FROM PYTHON SOURCE LINES 442-456

.. code-block:: default

    print('No Matching Fusion')
    vanilla_acc_list = align_model_and_test(paddle.eye(n1))

    plt.figure(figsize=(4, 4))
    plt.title('Fused Model Accuracy')
    plt.plot((paddle.arange(0, 11, 1) / 10).numpy(), gm_acc_list.cpu().numpy(), 'r*-', label='Graph Matching Fusion')
    plt.plot((paddle.arange(0, 11, 1) / 10).numpy(), vanilla_acc_list.cpu().numpy(), 'b*-', label='No Matching Fusion')
    plt.plot((paddle.arange(0, 11, 1) / 10).numpy(), [acc1] * 11, '--', color="gray", label='Model1 Accuracy')
    plt.plot((paddle.arange(0, 11, 1) / 10).numpy(), [acc2] * 11, '--', color="brown", label='Model2 Accuracy')
    plt.gca().set_xlabel('Fusion Proportion')
    plt.gca().set_ylabel('Accuracy (%)')
    plt.ylim((70, 87))
    plt.legend(loc=3)


.. image-sg:: /auto_examples/5.model_fusion/images/sphx_glr_plot_model_fusion_paddle_003.png
   :alt: Fused Model Accuracy
   :srcset: /auto_examples/5.model_fusion/images/sphx_glr_plot_model_fusion_paddle_003.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    No Matching Fusion
    1.00 model1 + 0.00 model2 -> fused model accuracy: 84.18%
    0.90 model1 + 0.10 model2 -> fused model accuracy: 84.01%
    0.80 model1 + 0.20 model2 -> fused model accuracy: 81.91%
    0.70 model1 + 0.30 model2 -> fused model accuracy: 74.67%
    0.60 model1 + 0.40 model2 -> fused model accuracy: 60.39%
    0.50 model1 + 0.50 model2 -> fused model accuracy: 47.16%
    0.40 model1 + 0.60 model2 -> fused model accuracy: 55.34%
    0.30 model1 + 0.70 model2 -> fused model accuracy: 72.86%
    0.20 model1 + 0.80 model2 -> fused model accuracy: 79.64%
    0.10 model1 + 0.90 model2 -> fused model accuracy: 82.56%
    0.00 model1 + 1.00 model2 -> fused model accuracy: 83.81%


.. GENERATED FROM PYTHON SOURCE LINES 457-460

Print the result summary
------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 460-465

.. code-block:: default

    end_time = time.perf_counter()
    print(f'time consumed for model fusion: {end_time - st_time:.2f} seconds')
    print(f'model1 accuracy = {acc1}%, model2 accuracy = {acc2}%')
    print(f"best fused model accuracy: {(paddle.max(gm_acc_list)).item():.2f}%")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    time consumed for model fusion: 1836.18 seconds
    model1 accuracy = 84.18%, model2 accuracy = 83.81%
    best fused model accuracy: 85.21%


.. GENERATED FROM PYTHON SOURCE LINES 466-470

.. note::
    This example supports both GPU and CPU, and the online documentation is built on a CPU-only machine.
    The efficiency will be significantly improved if you run this code on GPU.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (30 minutes 44.030 seconds)


.. _sphx_glr_download_auto_examples_5.model_fusion_plot_model_fusion_paddle.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_model_fusion_paddle.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_model_fusion_paddle.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_
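
As a supplementary illustration of the alignment step performed by ``align``, the following minimal sketch
reproduces the idea on a single toy linear layer with NumPy: the output channels of the first weight matrix
are permuted to match the second before the two are interpolated. The helper ``fuse_linear``, the toy weights,
and the convention that ``perm[i, j] == 1`` matches channel ``i`` of model 1 to channel ``j`` of model 2 are
assumptions made for this illustration; they are not part of ``pygmtools``.

.. code-block:: python

    import numpy as np


    def fuse_linear(w1, w2, perm, proportion):
        """Align the output channels of w1 to w2, then interpolate (hypothetical helper).

        w1, w2: weight matrices of shape (out_channels, in_channels).
        perm: permutation matrix of shape (out_channels, out_channels).
        proportion: weight given to model 2 (0 keeps model 1, 1 keeps model 2).
        """
        # Row j of the result is row i of w1 whenever perm[i, j] == 1.
        aligned_w1 = perm.T @ w1
        return (1 - proportion) * aligned_w1 + proportion * w2


    # Toy data: model 2 holds the same two channels as model 1, in swapped order.
    w1 = np.array([[1.0, 0.0],
                   [0.0, 2.0]])
    w2 = w1[::-1].copy()
    perm = np.array([[0.0, 1.0],
                     [1.0, 0.0]])  # channel 0 of model 1 <-> channel 1 of model 2

    matched = fuse_linear(w1, w2, perm, 0.5)
    unmatched = fuse_linear(w1, w2, np.eye(2), 0.5)
    print(matched)    # identical to w2, because the aligned channels agree
    print(unmatched)  # each row mixes two unrelated channels

With the matching permutation the fused weights keep both channels intact, whereas averaging without matching
(``np.eye(2)``) blends unrelated channels, which mirrors the accuracy gap between the two fusion curves above.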