pygmtools.utils.get_network

pygmtools.utils.get_network(nn_solver_func, **params)

Build and return the network object underlying a neural network solver (e.g. pygm.pca_gm), configured by the given keyword parameters.

Parameters
  • nn_solver_func – the neural network solver function, for example pygm.pca_gm

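  • params – keyword parameters that define the neural network. The accepted names depend on nn_solver_func; for pygm.pca_gm they include in_channel, hidden_channel, out_channel, num_layers and pretrain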

Returns

the network object (with the PyTorch backend, a torch.nn.Module such as PCA_GM_Net)
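Conceptually, a neural solver function both builds a network from its keyword parameters and runs matching with it, while get_network hands back only the network. The pure-Python sketch below illustrates that pattern with a toy solver. ToyNet, toy_solver, get_network_sketch and the network/return_network arguments are hypothetical names chosen for illustration; this is not pygmtools' actual implementation.

```python
class ToyNet:
    """Hypothetical stand-in for a network such as PCA_GM_Net."""

    def __init__(self, in_channel=4, out_channel=2):
        self.in_channel = in_channel
        self.out_channel = out_channel

    def __call__(self, feats):
        # Toy "forward pass": keep the first out_channel entries of each feature vector.
        return [f[: self.out_channel] for f in feats]


def toy_solver(feats, network=None, return_network=False, **params):
    """Hypothetical solver: builds a network from params unless one is passed in."""
    if network is None:
        network = ToyNet(**params)
    # Skip the forward pass when called without input data (network-only mode).
    result = network(feats) if feats is not None else None
    return (result, network) if return_network else result


def get_network_sketch(nn_solver_func, **params):
    """Hypothetical analogue of get_network: call the solver without data,
    ask it to also return its network, and discard the empty result."""
    _, net = nn_solver_func(None, return_network=True, **params)
    return net


net = get_network_sketch(toy_solver, in_channel=4, out_channel=3)
print(type(net).__name__, net.in_channel, net.out_channel)  # ToyNet 4 3
# The same network object can then be reused for later solver calls.
print(toy_solver([[1, 2, 3, 4]], network=net))  # [[1, 2, 3]]
```

Separating construction from use this way lets the caller keep one network object, inspect it, or hand its parameters to an optimizer.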

PyTorch Example
>>> import pygmtools as pygm
>>> import torch
>>> pygm.set_backend('pytorch')
>>> pygm.utils.get_network(pygm.pca_gm, pretrain='willow')
PCA_GM_Net(
  (gnn_layer_0): Siamese_Gconv(
    (gconv): Gconv(
      (a_fc): Linear(in_features=1024, out_features=2048, bias=True)
      (u_fc): Linear(in_features=1024, out_features=2048, bias=True)
    )
  )
  (cross_graph_0): Linear(in_features=4096, out_features=2048, bias=True)
  (affinity_0): WeightedInnerProdAffinity()
  (affinity_1): WeightedInnerProdAffinity()
  (gnn_layer_1): Siamese_Gconv(
    (gconv): Gconv(
      (a_fc): Linear(in_features=2048, out_features=2048, bias=True)
      (u_fc): Linear(in_features=2048, out_features=2048, bias=True)
    )
  )
)

# An untrained network (pretrain=False) with custom hyperparameters can be integrated into a deep learning pipeline
>>> net = pygm.utils.get_network(pygm.pca_gm, in_channel=1024, hidden_channel=2048, out_channel=512, num_layers=3, pretrain=False)
>>> optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)