pygmtools.utils.get_network
- pygmtools.utils.get_network(nn_solver_func, **params)
Get the network object of a neural network solver.
- Parameters
nn_solver_func – the neural network solver function, for example pygm.pca_gm
params – keyword parameters that define the neural network (for example in_channel, hidden_channel, out_channel, num_layers, pretrain)
- Returns
the network object
PyTorch Example
>>> import pygmtools as pygm
>>> import torch
>>> pygm.set_backend('pytorch')
>>> pygm.utils.get_network(pygm.pca_gm, pretrain='willow')
PCA_GM_Net(
  (gnn_layer_0): Siamese_Gconv(
    (gconv): Gconv(
      (a_fc): Linear(in_features=1024, out_features=2048, bias=True)
      (u_fc): Linear(in_features=1024, out_features=2048, bias=True)
    )
  )
  (cross_graph_0): Linear(in_features=4096, out_features=2048, bias=True)
  (affinity_0): WeightedInnerProdAffinity()
  (affinity_1): WeightedInnerProdAffinity()
  (gnn_layer_1): Siamese_Gconv(
    (gconv): Gconv(
      (a_fc): Linear(in_features=2048, out_features=2048, bias=True)
      (u_fc): Linear(in_features=2048, out_features=2048, bias=True)
    )
  )
)

# the neural network can be integrated into a deep learning pipeline
>>> net = pygm.utils.get_network(pygm.pca_gm, in_channel=1024, hidden_channel=2048, out_channel=512, num_layers=3, pretrain=False)
>>> optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
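The call pattern above (a solver function selects the architecture, keyword parameters configure it) can be illustrated with a small, library-independent sketch. This is NOT pygmtools' implementation: ToyNet, toy_solver, _NETWORK_BUILDERS and get_network_sketch are hypothetical names, and the default hyperparameters are made up for illustration. It only shows a plain dispatch-by-function registry under those assumptions.

```python
from typing import Any, Callable, Dict


# Hypothetical placeholder network: it only records its hyperparameters.
# It is NOT a pygmtools class; the defaults below are illustrative only.
class ToyNet:
    def __init__(self, in_channel: int = 1024, hidden_channel: int = 2048,
                 out_channel: int = 2048, num_layers: int = 2) -> None:
        self.config = {
            "in_channel": in_channel,
            "hidden_channel": hidden_channel,
            "out_channel": out_channel,
            "num_layers": num_layers,
        }


# Hypothetical stand-in for a neural solver such as pygm.pca_gm.
def toy_solver(*args: Any, **kwargs: Any) -> None:
    raise NotImplementedError("placeholder solver for this sketch")


# Registry mapping each (hypothetical) solver function to its network builder.
_NETWORK_BUILDERS: Dict[Callable[..., Any], Callable[..., Any]] = {
    toy_solver: ToyNet,
}


def get_network_sketch(nn_solver_func: Callable[..., Any], **params: Any) -> Any:
    """Build the network for nn_solver_func, configured by keyword params."""
    try:
        builder = _NETWORK_BUILDERS[nn_solver_func]
    except KeyError:
        raise ValueError(f"unsupported solver: {nn_solver_func!r}") from None
    # Keyword parameters are forwarded unchanged to the network constructor.
    return builder(**params)


if __name__ == "__main__":
    net = get_network_sketch(toy_solver, out_channel=512, num_layers=3)
    print(net.config)
```

Keying the registry on the function object itself mirrors how get_network is called with pygm.pca_gm rather than a string name; an unsupported function fails loudly instead of returning a wrong network.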