pygmtools.utils.from_pyg

pygmtools.utils.from_pyg(G)

Convert a torch_geometric.data.Data object to an adjacency matrix.

Parameters

G – the graph to convert; it must be a torch_geometric.data.Data object.

Returns

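The adjacency matrix corresponding to G. Entry [i, j] holds the attribute of the edge from node i (edge_index[0]) to node j (edge_index[1]). Node pairs without an edge are zero. If G.edge_attr is 1-D with shape (E,), the result has shape (n, n), where n is the number of nodes. If G.edge_attr has shape (E, d), the result has shape (n, n, d).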

Example
>>> import torch
>>> from torch_geometric.data import Data
>>> import pygmtools as pygm
>>> pygm.set_backend('pytorch')

# Generate a Graph object (edge_attr holds 1-D edge weights)
>>> edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 3], [1, 2, 0, 2, 0, 3, 1]], dtype=torch.long)
>>> edge_attr = torch.rand((7), dtype=torch.float)
>>> G = Data(edge_index=edge_index, edge_attr=edge_attr)
>>> G
Data(edge_index=[2, 7], edge_attr=[7])

# Obtain the adjacency matrix (edge_attr is random, so your values will differ)
>>> pygm.utils.from_pyg(G)
tensor([[0.0000, 0.2872, 0.5249, 0.0000],
        [0.5386, 0.0000, 0.8801, 0.0000],
        [0.0966, 0.0000, 0.0000, 0.9825],
        [0.0000, 0.4994, 0.0000, 0.0000]])

# Generate a Graph object (edge_attr holds multi-dimensional edge features)
>>> edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 3], [1, 2, 0, 2, 0, 3, 1]], dtype=torch.long)
>>> edge_attr = torch.rand((7, 5), dtype=torch.float)
>>> G = Data(edge_index=edge_index, edge_attr=edge_attr)
>>> G
Data(edge_index=[2, 7], edge_attr=[7, 5])

# Obtain the adjacency matrix, of shape (4, 4, 5) (values are random and will differ)
>>> pygm.utils.from_pyg(G)
tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3776, 0.8405, 0.3963, 0.6111, 0.6220],
         [0.4824, 0.6115, 0.5169, 0.2558, 0.8300],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.4206, 0.4795, 0.0512, 0.1543, 0.0133],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1053, 0.9634, 0.1822, 0.8167, 0.4903],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.5127, 0.5046, 0.7905, 0.9613, 0.4695],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5535, 0.1592, 0.0363, 0.2447, 0.7754]],

        [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.9172, 0.6820, 0.7201, 0.4397, 0.0732],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]])