.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/5.model_fusion/plot_model_fusion_pytorch.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_5.model_fusion_plot_model_fusion_pytorch.py:


=======================================================
PyTorch Backend Example: Model Fusion by Graph Matching
=======================================================

This example shows how to fuse different models into a single model with ``pygmtools``.
Model fusion combines multiple models into one so that the fused model can achieve higher performance.
A neural network can be regarded as a graph (channels are nodes and the update functions between channels
are edges; biases are node features and weights are edge features), so fusing models is equivalent to
solving a graph matching problem. In this example, the two given models are trained on MNIST data from
different distributions. The fused model combines the knowledge of both input models and can reach
higher accuracy at test time.

.. GENERATED FROM PYTHON SOURCE LINES 14-20

.. code-block:: default


    # Author: Chang Liu
    #         Runzhong Wang
    #
    # License: Mulan PSL v2 License

.. GENERATED FROM PYTHON SOURCE LINES 22-33

.. note::
    This is a simplified implementation of the ideas in `Liu et al. Deep Neural Network Fusion via Graph Matching with Applications to Model Ensemble and Federated Learning. ICML 2022. `_
    For more details, please refer to the paper and the `official code repository `_.

.. note::
    The following solvers are included in this example:

    * :func:`~pygmtools.classic_solvers.sm` (classic solver)

    * :func:`~pygmtools.linear_solvers.hungarian` (linear solver)

.. GENERATED FROM PYTHON SOURCE LINES 33-48

.. code-block:: default

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torchvision
    import torchvision.transforms as transforms
    import time
    from PIL import Image
    import matplotlib.pyplot as plt
    import pygmtools as pygm

    pygm.set_backend('pytorch')

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

.. GENERATED FROM PYTHON SOURCE LINES 49-52

Define a simple CNN classifier network
---------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 52-72

.. code-block:: default

    class SimpleNet(nn.Module):
        def __init__(self):
            super(SimpleNet, self).__init__()
            self.conv1 = nn.Conv2d(1, 32, 5, padding=1, padding_mode='replicate', bias=False)
            self.max_pool = nn.MaxPool2d(2, padding=1)
            self.conv2 = nn.Conv2d(32, 64, 5, padding=1, padding_mode='replicate', bias=False)
            self.fc1 = nn.Linear(3136, 32, bias=False)
            self.fc2 = nn.Linear(32, 10, bias=False)

        def forward(self, x):
            output = F.relu(self.conv1(x))
            output = self.max_pool(output)
            output = F.relu(self.conv2(output))
            output = self.max_pool(output)
            output = output.view(output.shape[0], -1)  # flatten to (batch, 3136)
            output = self.fc1(output)
            output = self.fc2(output)
            return output

.. GENERATED FROM PYTHON SOURCE LINES 73-76

Load the trained models to be fused
------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 76-92

.. code-block:: default

    model1 = SimpleNet()
    model2 = SimpleNet()
    model1.load_state_dict(torch.load('../data/example_model_fusion_1.dat', map_location=device))
    model2.load_state_dict(torch.load('../data/example_model_fusion_2.dat', map_location=device))
    model1.to(device)
    model2.to(device)
    test_dataset = torchvision.datasets.MNIST(
        root='../data/mnist_data',  # the directory to store the dataset
        train=False,  # the dataset is used to test
        transform=transforms.ToTensor(),  # the dataset is in the form of tensors
        download=True)
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=32,
        shuffle=False)

.. GENERATED FROM PYTHON SOURCE LINES 93-95

Print the layers of the simple CNN model:

.. GENERATED FROM PYTHON SOURCE LINES 95-97

.. code-block:: default

    print(model1)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    SimpleNet(
      (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1), bias=False, padding_mode=replicate)
      (max_pool): MaxPool2d(kernel_size=2, stride=2, padding=1, dilation=1, ceil_mode=False)
      (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1), bias=False, padding_mode=replicate)
      (fc1): Linear(in_features=3136, out_features=32, bias=False)
      (fc2): Linear(in_features=32, out_features=10, bias=False)
    )

.. GENERATED FROM PYTHON SOURCE LINES 98-101

Test the input models
------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 101-118

.. code-block:: default

    with torch.no_grad():
        n_correct1 = 0
        n_correct2 = 0
        n_samples = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs1 = model1(images)
            outputs2 = model2(images)
            # the predicted class is the index of the largest logit
            _, predictions1 = torch.max(outputs1, 1)
            _, predictions2 = torch.max(outputs2, 1)
            n_samples += labels.shape[0]
            n_correct1 += (predictions1 == labels).sum().item()
            n_correct2 += (predictions2 == labels).sum().item()
        acc1 = 100 * n_correct1 / n_samples
        acc2 = 100 * n_correct2 / n_samples

.. GENERATED FROM PYTHON SOURCE LINES 119-121

Testing results (two separate models):

.. GENERATED FROM PYTHON SOURCE LINES 121-123

.. code-block:: default

    print(f'model1 accuracy = {acc1}%, model2 accuracy = {acc2}%')

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    model1 accuracy = 84.18%, model2 accuracy = 83.8%

.. GENERATED FROM PYTHON SOURCE LINES 124-130

Build the affinity matrix for graph matching
---------------------------------------------
As shown in the following plot, the neural networks can be regarded as graphs. The weights correspond to
the edge features, and the bias corresponds to the node features.
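
To make the graph view concrete, the following is a minimal sketch that counts the nodes of the
``SimpleNet`` graph from the weight shapes in the printed model above, assuming one node per channel.
The names ``weight_shapes`` and ``count_nodes`` are illustrative only and are not part of ``pygmtools``.

.. code-block:: python

    # Weight shapes of SimpleNet, read off the printed model:
    # (out_channels, in_channels, kernel_h, kernel_w) for conv layers,
    # (out_features, in_features) for linear layers.
    weight_shapes = [
        (32, 1, 5, 5),   # conv1
        (64, 32, 5, 5),  # conv2
        (32, 3136),      # fc1
        (10, 32),        # fc2
    ]


    def count_nodes(shapes):
        """Count nodes: input channels of the first layer plus output channels of every layer."""
        num_nodes = shapes[0][1]
        for shape in shapes:
            num_nodes += shape[0]
        return num_nodes


    n = count_nodes(weight_shapes)
    affinity_side = n * n                    # the affinity matrix is (n1 * n2) x (n1 * n2)
    affinity_bytes = affinity_side ** 2 * 4  # assuming float32 entries
    print(n, affinity_side)  # 139 19321
    print(f'{affinity_bytes / 1e9:.2f} GB')

Under these assumptions each network has 139 nodes, so a dense affinity matrix between two such networks
has 19321 rows and columns and takes roughly 1.5 GB in float32. The number of entries grows with the
fourth power of the node count.
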
In this example, the neural networks have no bias terms, so there are only edge features.

.. GENERATED FROM PYTHON SOURCE LINES 130-137

.. code-block:: default

    plt.figure(figsize=(8, 4))
    img = Image.open('../data/model_fusion.png')
    plt.imshow(img)
    plt.axis('off')
    st_time = time.perf_counter()

.. image-sg:: /auto_examples/5.model_fusion/images/sphx_glr_plot_model_fusion_pytorch_001.png
   :alt: plot model fusion pytorch
   :srcset: /auto_examples/5.model_fusion/images/sphx_glr_plot_model_fusion_pytorch_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 138-140

Define the graph matching affinity metric function

.. GENERATED FROM PYTHON SOURCE LINES 140-175

.. code-block:: default

    class Ground_Metric_GM:
        def __init__(self,
                     model_1_param: torch.Tensor = None,
                     model_2_param: torch.Tensor = None,
                     conv_param: bool = False,
                     bias_param: bool = False,
                     pre_conv_param: bool = False,
                     pre_conv_image_size_squared: int = None):
            self.model_1_param = model_1_param
            self.model_2_param = model_2_param
            self.conv_param = conv_param
            self.bias_param = bias_param
            # bias, or fully-connected from linear
            if bias_param is True or (conv_param is False and pre_conv_param is False):
                self.model_1_param = self.model_1_param.reshape(1, -1, 1)
                self.model_2_param = self.model_2_param.reshape(1, -1, 1)
            # fully-connected from conv
            elif conv_param is False and pre_conv_param is True:
                self.model_1_param = self.model_1_param.reshape(1, -1, pre_conv_image_size_squared)
                self.model_2_param = self.model_2_param.reshape(1, -1, pre_conv_image_size_squared)
            # conv
            else:
                self.model_1_param = self.model_1_param.reshape(1, -1, model_1_param.shape[-1])
                self.model_2_param = self.model_2_param.reshape(1, -1, model_2_param.shape[-1])

        def process_distance(self, p: int = 2):
            # pairwise p-norm distance between the edge features of the two models
            return torch.cdist(
                self.model_1_param.to(torch.float),
                self.model_2_param.to(torch.float),
                p=p)[0]

        def process_soft_affinity(self, p: int = 2):
            # smaller distance -> larger affinity, in (0, 1]
            return torch.exp(0 - self.process_distance(p=p))

.. GENERATED FROM PYTHON SOURCE LINES 176-179

Define the affinity function between two neural networks. This function takes multiple neural network
modules and constructs the corresponding affinity matrix, which is then processed by the graph matching solver.

.. GENERATED FROM PYTHON SOURCE LINES 179-304

.. code-block:: default

    def graph_matching_fusion(networks: list):
        def total_node_num(network: torch.nn.Module):
            # count the total number of nodes in the network [network]
            num_nodes = 0
            for idx, (name, parameters) in enumerate(network.named_parameters()):
                if 'bias' in name:
                    continue
                if idx == 0:
                    num_nodes += parameters.shape[1]
                num_nodes += parameters.shape[0]
            return num_nodes

        n1 = total_node_num(network=networks[0])
        n2 = total_node_num(network=networks[1])
        assert (n1 == n2)
        affinity = torch.zeros([n1 * n2, n1 * n2], device=device)
        num_layers = len(list(zip(networks[0].parameters(), networks[1].parameters())))
        num_nodes_before = 0
        num_nodes_incremental = []
        num_nodes_layers = []
        pre_conv_list = []
        cur_conv_list = []
        conv_kernel_size_list = []
        num_nodes_pre = 0
        is_conv = False
        pre_conv = False
        pre_conv_out_channel = 1
        is_final_bias = False
        perm_is_complete = True
        named_weight_list_0 = [named_parameter for named_parameter in networks[0].named_parameters()]
        for idx, ((_, fc_layer0_weight), (_, fc_layer1_weight)) in \
                enumerate(zip(networks[0].named_parameters(), networks[1].named_parameters())):
            assert fc_layer0_weight.shape == fc_layer1_weight.shape
            layer_shape = fc_layer0_weight.shape
            num_nodes_cur = fc_layer0_weight.shape[0]
            if len(layer_shape) > 1:
                if is_conv is True and len(layer_shape) == 2:
                    num_nodes_pre = pre_conv_out_channel
                else:
                    num_nodes_pre = fc_layer0_weight.shape[1]
            if idx >= 1 and len(named_weight_list_0[idx - 1][1].shape) == 1:
                pre_bias = True
            else:
                pre_bias = False
            # convolutional weight
            if len(layer_shape) > 2:
                is_bias = False
                if not pre_bias:
                    pre_conv = is_conv
                    pre_conv_list.append(pre_conv)
                is_conv = True
                cur_conv_list.append(is_conv)
                fc_layer0_weight_data = fc_layer0_weight.data.view(
                    fc_layer0_weight.shape[0], fc_layer0_weight.shape[1], -1)
                fc_layer1_weight_data = fc_layer1_weight.data.view(
                    fc_layer1_weight.shape[0], fc_layer1_weight.shape[1], -1)
            # fully-connected weight
            elif len(layer_shape) == 2:
                is_bias = False
                if not pre_bias:
                    pre_conv = is_conv
                    pre_conv_list.append(pre_conv)
                is_conv = False
                cur_conv_list.append(is_conv)
                fc_layer0_weight_data = fc_layer0_weight.data
                fc_layer1_weight_data = fc_layer1_weight.data
            # bias
            else:
                is_bias = True
                if not pre_bias:
                    pre_conv = is_conv
                    pre_conv_list.append(pre_conv)
                is_conv = False
                cur_conv_list.append(is_conv)
                fc_layer0_weight_data = fc_layer0_weight.data
                fc_layer1_weight_data = fc_layer1_weight.data
            if is_conv:
                pre_conv_out_channel = num_nodes_cur
            if is_bias is True and idx == num_layers - 1:
                is_final_bias = True
            # input nodes of the first layer
            if idx == 0:
                for a in range(num_nodes_pre):
                    affinity[(num_nodes_before + a) * n2 + num_nodes_before + a] \
                        [(num_nodes_before + a) * n2 + num_nodes_before + a] \
                        = 1
            # output nodes of the last layer
            if idx == num_layers - 2 and 'bias' in named_weight_list_0[idx + 1][0] or \
                    idx == num_layers - 1 and 'bias' not in named_weight_list_0[idx][0]:
                for a in range(num_nodes_cur):
                    affinity[(num_nodes_before + num_nodes_pre + a) * n2 + num_nodes_before + num_nodes_pre + a] \
                        [(num_nodes_before + num_nodes_pre + a) * n2 + num_nodes_before + num_nodes_pre + a] \
                        = 1
            if is_bias is False:
                ground_metric = Ground_Metric_GM(
                    fc_layer0_weight_data, fc_layer1_weight_data, is_conv, is_bias,
                    pre_conv, int(fc_layer0_weight_data.shape[1] / pre_conv_out_channel))
            else:
                ground_metric = Ground_Metric_GM(
                    fc_layer0_weight_data, fc_layer1_weight_data, is_conv, is_bias,
                    pre_conv, 1)

            layer_affinity = ground_metric.process_soft_affinity(p=2)

            if is_bias is False:
                pre_conv_kernel_size = fc_layer0_weight.shape[3] if is_conv else None
                conv_kernel_size_list.append(pre_conv_kernel_size)
            # node (bias) affinities
            if is_bias is True and is_final_bias is False:
                for a in range(num_nodes_cur):
                    for c in range(num_nodes_cur):
                        affinity[(num_nodes_before + a) * n2 + num_nodes_before + c] \
                            [(num_nodes_before + a) * n2 + num_nodes_before + c] \
                            = layer_affinity[a][c]
            # edge (weight) affinities
            elif is_final_bias is False:
                for a in range(num_nodes_pre):
                    for b in range(num_nodes_cur):
                        affinity[
                        (num_nodes_before + a) * n2 + num_nodes_before:
                        (num_nodes_before + a) * n2 + num_nodes_before + num_nodes_pre,
                        (num_nodes_before + num_nodes_pre + b) * n2 + num_nodes_before + num_nodes_pre:
                        (num_nodes_before + num_nodes_pre + b) * n2 + num_nodes_before + num_nodes_pre + num_nodes_cur] \
                            = layer_affinity[a + b * num_nodes_pre].view(num_nodes_cur, num_nodes_pre).transpose(0, 1)
            if is_bias is False:
                num_nodes_before += num_nodes_pre
                num_nodes_incremental.append(num_nodes_before)
                num_nodes_layers.append(num_nodes_cur)
        # affinity = (affinity + affinity.t()) / 2
        return affinity, [n1, n2, num_nodes_incremental, num_nodes_layers, cur_conv_list, conv_kernel_size_list]

.. GENERATED FROM PYTHON SOURCE LINES 305-307

Get the affinity (similarity) matrix between model1 and model2.

.. GENERATED FROM PYTHON SOURCE LINES 307-309

.. code-block:: default

    K, params = graph_matching_fusion([model1, model2])

.. GENERATED FROM PYTHON SOURCE LINES 310-314

Align the models by graph matching
-----------------------------------
Align the channels of model1 and model2 by maximizing the affinity (similarity) via graph matching algorithms.

.. GENERATED FROM PYTHON SOURCE LINES 314-318

.. code-block:: default

    n1 = params[0]
    n2 = params[1]
    X = pygm.sm(K, n1, n2)

.. GENERATED FROM PYTHON SOURCE LINES 319-326

Project ``X`` to a neural network matching result. The neural network matching matrix is built by applying
the Hungarian algorithm to small blocks of ``X``, because only the channels from the same neural network
layer can be matched.

.. note::
    In this example, we assume the last FC layer is aligned and need not be matched.

.. GENERATED FROM PYTHON SOURCE LINES 326-336

.. code-block:: default

    new_X = torch.zeros_like(X)
    new_X[:params[2][0], :params[2][0]] = torch.eye(params[2][0], device=device)
    for start_idx, length in zip(params[2][:-1], params[3][:-1]):  # params[2] and params[3] are the indices of layers
        slicing = slice(start_idx, start_idx + length)
        new_X[slicing, slicing] = pygm.hungarian(X[slicing, slicing])
    # assume the last FC layer is aligned
    slicing = slice(params[2][-1], params[2][-1] + params[3][-1])
    new_X[slicing, slicing] = torch.eye(params[3][-1], device=device)
    X = new_X

.. GENERATED FROM PYTHON SOURCE LINES 337-339

Visualization of the matching result. The black lines separate the channels of different layers.

.. GENERATED FROM PYTHON SOURCE LINES 339-346

.. code-block:: default

    plt.figure(figsize=(4, 4))
    plt.imshow(X.cpu().numpy(), cmap='Blues')
    for idx in params[2]:
        plt.axvline(x=idx, color='k')
        plt.axhline(y=idx, color='k')

.. image-sg:: /auto_examples/5.model_fusion/images/sphx_glr_plot_model_fusion_pytorch_002.png
   :alt: plot model fusion pytorch
   :srcset: /auto_examples/5.model_fusion/images/sphx_glr_plot_model_fusion_pytorch_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 347-349

Define the alignment function: fuse the models based on the matching result

.. GENERATED FROM PYTHON SOURCE LINES 349-395

.. code-block:: default

    def align(solution, fusion_proportion, networks: list, params: list):
        [_, _, num_nodes_incremental, num_nodes_layers, cur_conv_list, conv_kernel_size_list] = params
        named_weight_list_0 = [named_parameter for named_parameter in networks[0].named_parameters()]
        aligned_wt_0 = [parameter.data for name, parameter in named_weight_list_0]
        idx = 0
        num_layers = len(aligned_wt_0)
        for num_before, num_cur, cur_conv, cur_kernel_size in \
                zip(num_nodes_incremental, num_nodes_layers, cur_conv_list, conv_kernel_size_list):
            # permutation block of the current layer's output channels
            perm = solution[num_before:num_before + num_cur, num_before:num_before + num_cur]
            assert 'bias' not in named_weight_list_0[idx][0]
            # permute the output channels of the current layer
            if len(named_weight_list_0[idx][1].shape) == 4:
                aligned_wt_0[idx] = (perm.transpose(0, 1).to(torch.float64) @
                                     aligned_wt_0[idx].to(torch.float64).permute(2, 3, 0, 1)) \
                    .permute(2, 3, 0, 1)
            else:
                aligned_wt_0[idx] = perm.transpose(0, 1).to(torch.float64) @ aligned_wt_0[idx].to(torch.float64)
            idx += 1
            if idx >= num_layers:
                continue
            if 'bias' in named_weight_list_0[idx][0]:
                aligned_wt_0[idx] = aligned_wt_0[idx].to(torch.float64) @ perm.to(torch.float64)
                idx += 1
            if idx >= num_layers:
                continue
            # permute the input channels of the next layer accordingly
            if cur_conv and len(named_weight_list_0[idx][1].shape) == 2:
                aligned_wt_0[idx] = (aligned_wt_0[idx].to(torch.float64)
                                     .reshape(aligned_wt_0[idx].shape[0], 64, -1)
                                     .permute(0, 2, 1)
                                     @ perm.to(torch.float64)) \
                    .permute(0, 2, 1) \
                    .reshape(aligned_wt_0[idx].shape[0], -1)
            elif len(named_weight_list_0[idx][1].shape) == 4:
                aligned_wt_0[idx] = (aligned_wt_0[idx].to(torch.float64)
                                     .permute(2, 3, 0, 1)
                                     @ perm.to(torch.float64)) \
                    .permute(2, 3, 0, 1)
            else:
                aligned_wt_0[idx] = aligned_wt_0[idx].to(torch.float64) @ perm.to(torch.float64)
        assert idx == num_layers

        # weighted average of the aligned model1 weights and the model2 weights
        averaged_weights = []
        for idx, parameter in enumerate(networks[1].parameters()):
            averaged_weights.append((1 - fusion_proportion) * aligned_wt_0[idx] + fusion_proportion * parameter)
        return averaged_weights

.. GENERATED FROM PYTHON SOURCE LINES 396-401

Test the fused model
---------------------
The ``fusion_proportion`` variable denotes the contribution of model2 to the fused model. For example,
if ``fusion_proportion=0.2``, the fused model = 80% model1 + 20% model2.

.. GENERATED FROM PYTHON SOURCE LINES 401-432

.. code-block:: default

    def align_model_and_test(X):
        acc_list = []
        for fusion_proportion in torch.arange(0, 1.1, 0.1):
            fused_weights = align(X, fusion_proportion, [model1, model2], params)
            fused_model = SimpleNet()
            state_dict = fused_model.state_dict()
            for idx, (key, _) in enumerate(state_dict.items()):
                state_dict[key] = fused_weights[idx]
            fused_model.load_state_dict(state_dict)
            fused_model.to(device)
            test_loss = 0
            correct = 0
            for data, target in test_loader:
                data = data.to(device)
                target = target.to(device)
                output = fused_model(data)
                test_loss += F.nll_loss(output, target, reduction='sum').item()
                pred = output.data.max(1, keepdim=True)[1]
                correct += pred.eq(target.data.view_as(pred)).sum()
            test_loss /= len(test_loader.dataset)
            acc = 100. * correct / len(test_loader.dataset)
            print(
                f"{1 - fusion_proportion:.2f} model1 + {fusion_proportion:.2f} model2 -> fused model accuracy: {acc:.2f}%")
            acc_list.append(acc)
        return torch.tensor(acc_list)


    print('Graph Matching Fusion')
    gm_acc_list = align_model_and_test(X)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Graph Matching Fusion
    1.00 model1 + 0.00 model2 -> fused model accuracy: 84.18%
    0.90 model1 + 0.10 model2 -> fused model accuracy: 85.46%
    0.80 model1 + 0.20 model2 -> fused model accuracy: 86.92%
    0.70 model1 + 0.30 model2 -> fused model accuracy: 88.38%
    0.60 model1 + 0.40 model2 -> fused model accuracy: 86.43%
    0.50 model1 + 0.50 model2 -> fused model accuracy: 74.11%
    0.40 model1 + 0.60 model2 -> fused model accuracy: 72.45%
    0.30 model1 + 0.70 model2 -> fused model accuracy: 78.12%
    0.20 model1 + 0.80 model2 -> fused model accuracy: 81.65%
    0.10 model1 + 0.90 model2 -> fused model accuracy: 83.29%
    0.00 model1 + 1.00 model2 -> fused model accuracy: 83.80%

.. GENERATED FROM PYTHON SOURCE LINES 433-435

Compared with vanilla model fusion (no matching), the graph matching method stabilizes the fusion step:

.. GENERATED FROM PYTHON SOURCE LINES 435-449

.. code-block:: default

    print('No Matching Fusion')
    vanilla_acc_list = align_model_and_test(torch.eye(n1, device=device))

    plt.figure(figsize=(4, 4))
    plt.title('Fused Model Accuracy')
    plt.plot(torch.arange(0, 1.1, 0.1).numpy(), gm_acc_list.cpu().numpy(), 'r*-', label='Graph Matching Fusion')
    plt.plot(torch.arange(0, 1.1, 0.1).numpy(), vanilla_acc_list.cpu().numpy(), 'b*-', label='No Matching Fusion')
    plt.plot(torch.arange(0, 1.1, 0.1).numpy(), [acc1] * 11, '--', color="gray", label='Model1 Accuracy')
    plt.plot(torch.arange(0, 1.1, 0.1).numpy(), [acc2] * 11, '--', color="brown", label='Model2 Accuracy')
    plt.gca().set_xlabel('Fusion Proportion')
    plt.gca().set_ylabel('Accuracy (%)')
    plt.ylim((70, 87))
    plt.legend(loc=3)
    plt.show()

.. image-sg:: /auto_examples/5.model_fusion/images/sphx_glr_plot_model_fusion_pytorch_003.png
   :alt: Fused Model Accuracy
   :srcset: /auto_examples/5.model_fusion/images/sphx_glr_plot_model_fusion_pytorch_003.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    No Matching Fusion
    1.00 model1 + 0.00 model2 -> fused model accuracy: 84.18%
    0.90 model1 + 0.10 model2 -> fused model accuracy: 84.01%
    0.80 model1 + 0.20 model2 -> fused model accuracy: 81.91%
    0.70 model1 + 0.30 model2 -> fused model accuracy: 74.67%
    0.60 model1 + 0.40 model2 -> fused model accuracy: 60.39%
    0.50 model1 + 0.50 model2 -> fused model accuracy: 47.15%
    0.40 model1 + 0.60 model2 -> fused model accuracy: 55.36%
    0.30 model1 + 0.70 model2 -> fused model accuracy: 72.87%
    0.20 model1 + 0.80 model2 -> fused model accuracy: 79.64%
    0.10 model1 + 0.90 model2 -> fused model accuracy: 82.55%
    0.00 model1 + 1.00 model2 -> fused model accuracy: 83.80%

.. GENERATED FROM PYTHON SOURCE LINES 450-453

Print the result summary
------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 453-458

.. code-block:: default

    end_time = time.perf_counter()
    print(f'time consumed for model fusion: {end_time - st_time:.2f} seconds')
    print(f'model1 accuracy = {acc1}%, model2 accuracy = {acc2}%')
    print(f"best fused model accuracy: {torch.max(gm_acc_list):.2f}%")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    time consumed for model fusion: 16.75 seconds
    model1 accuracy = 84.18%, model2 accuracy = 83.8%
    best fused model accuracy: 88.38%

.. GENERATED FROM PYTHON SOURCE LINES 459-463

.. note::
    This example supports both GPU and CPU, and the online documentation is built on a CPU-only machine.
    The efficiency will be significantly improved if you run this code on a GPU.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 18.178 seconds)


.. _sphx_glr_download_auto_examples_5.model_fusion_plot_model_fusion_pytorch.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_model_fusion_pytorch.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_model_fusion_pytorch.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_
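
As a supplementary illustration of why channels are aligned before weights are averaged, the following is
a toy sketch in plain Python (no ``pygmtools`` or PyTorch calls). It assumes a two-layer network without
bias and uses hypothetical helper names (``matvec``, ``forward``, ``permute_hidden``, ``average``) that do
not appear in the example above. Swapping the hidden channels of a network leaves its output unchanged,
yet averaging the swapped and unswapped weights directly yields a different function.

.. code-block:: python

    def matvec(weight, x):
        """Multiply a matrix (list of rows) by a vector."""
        return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


    def forward(w1, w2, x):
        """Two-layer network without bias: w2 @ relu(w1 @ x)."""
        hidden = [max(0.0, h) for h in matvec(w1, x)]
        return matvec(w2, hidden)


    def permute_hidden(w1, w2, perm):
        """Reorder hidden channels: rows of w1 and columns of w2 move together."""
        return [w1[j] for j in perm], [[row[j] for j in perm] for row in w2]


    def average(a, b, proportion):
        """Element-wise (1 - proportion) * a + proportion * b."""
        return [[(1 - proportion) * u + proportion * v for u, v in zip(ra, rb)]
                for ra, rb in zip(a, b)]


    w1_a, w2_a = [[1.0, -1.0], [0.5, 2.0]], [[1.0, -2.0]]
    # Model B computes the same function as model A, with its hidden channels swapped.
    w1_b, w2_b = permute_hidden(w1_a, w2_a, [1, 0])
    x = [1.0, 2.0]

    # Averaging without matching mixes unrelated channels.
    naive = forward(average(w1_a, w1_b, 0.5), average(w2_a, w2_b, 0.5), x)
    # Undo the swap first (the role played by the matching matrix), then average.
    w1_b_aligned, w2_b_aligned = permute_hidden(w1_b, w2_b, [1, 0])
    aligned = forward(average(w1_a, w1_b_aligned, 0.5), average(w2_a, w2_b_aligned, 0.5), x)

    print(forward(w1_a, w2_a, x), forward(w1_b, w2_b, x))  # [-9.0] [-9.0]
    print(naive, aligned)  # [-1.75] [-9.0]

In this toy case the permutation is known in advance. In the example above it is unknown and is estimated
by graph matching, and the two models are trained separately rather than being exact permutations of each
other, so the sketch only conveys the intuition.
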