Source code for sgmatch.models.NeuroMatch

from typing import Optional, List

import torch
import torch_geometric.nn as pyg_nn
from torch.functional import Tensor
from ..utils.utility import setup_LRL_nn
from ..utils.constants import CONVS

class SkipLastGNN(torch.nn.Module):
    r"""
    Graph encoder used by NeuroMatch, from the `"Neural Subgraph Matching"
    <https://arxiv.org/abs/2007.03092>`_ paper.

    The network applies a pre-MLP to the node features, runs ``num_layers``
    graph convolutions (optionally with skip connections), concatenates the
    pre-MLP output with the output of every convolution layer, sums the
    result over the node dimension and feeds it to a post-MLP.

    Args:
        input_dim (int): Dimension of the input node feature vectors.
        hidden_dim (int): Output width of every convolution layer. It is
            also the width of the vector returned by :meth:`forward`,
            because the last layer of the post-MLP projects to
            :obj:`hidden_dim`.
        output_dim (int): Width of the intermediate projection inside the
            post-MLP (second linear layer).
        num_layers (int): Number of graph convolution layers.
        conv_type (str, optional): Key used to look up the convolution
            constructor in :obj:`CONVS` (:obj:`"Neuro-PNA"` or :obj:`"PNA"`
            or :obj:`"GCN"` or :obj:`"GAT"` or :obj:`"SAGE"` or :obj:`"GIN"`
            or :obj:`"graph"` or :obj:`"gated"`). With :obj:`"Neuro-PNA"`
            three convolutions are run per layer and their outputs are
            concatenated, so all intermediate embeddings are three times
            as wide. (default: :obj:`"Neuro-PNA"`)
        dropout (float, optional): Dropout probability applied after every
            convolution layer and inside the post-MLP.
            (default: :obj:`0.0`)
        skip (str, optional): Skip-connection mode. :obj:`"learnable"`
            feeds each layer all earlier embeddings, each scaled by the
            sigmoid of a learned weight; :obj:`"all"` feeds each layer the
            plain concatenation of all earlier embeddings; any other value
            feeds each layer only the previous layer's output.
            (default: :obj:`"learnable"`)
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int,
                 num_layers: int, conv_type: str = "Neuro-PNA",
                 dropout: float = 0.0, skip: str = "learnable"):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.conv_type = conv_type
        self.dropout = dropout
        self.skip = skip

        # XXX: Does feature preprocessing need to be moved elsewhere, and
        # XXX: should it even be included in the library?

        # "Neuro-PNA" concatenates three convolution outputs per layer, so
        # every per-node embedding is three times as wide as hidden_dim.
        multi_aggr = self.conv_type == "Neuro-PNA"
        width = 3 if multi_aggr else 1

        # Using setup_LRL_nn over setup_linear_nn for a single linear layer
        # to get torch.nn.Sequential behaviour.
        self.pre_mlp = setup_LRL_nn(self.input_dim,
                                    hidden_sizes=[width * self.hidden_dim])

        # TODO: Include error checking for invalid convolution type
        self.conv_model = CONVS[self.conv_type]
        if multi_aggr:
            self.convs_sum = torch.nn.ModuleList()
            self.convs_mean = torch.nn.ModuleList()
            self.convs_max = torch.nn.ModuleList()
        else:
            self.convs = torch.nn.ModuleList()

        if self.skip == 'learnable':
            # Row i holds the (pre-sigmoid) weights layer i applies to the
            # embeddings 0..i it receives; only [i, :i+1] is ever read.
            self.learnable_skip = torch.nn.Parameter(
                torch.ones(self.num_layers, self.num_layers))

        for layer in range(self.num_layers):
            if self.skip in ('all', 'learnable'):
                # The layer sees the pre-MLP output plus every earlier
                # layer's output.
                hidden_input_dim = self.hidden_dim * (layer + 1)
            else:
                hidden_input_dim = self.hidden_dim

            if multi_aggr:
                self.convs_sum.append(
                    self.conv_model(3 * hidden_input_dim, self.hidden_dim))
                self.convs_mean.append(
                    self.conv_model(3 * hidden_input_dim, self.hidden_dim))
                self.convs_max.append(
                    self.conv_model(3 * hidden_input_dim, self.hidden_dim))
            else:
                self.convs.append(
                    self.conv_model(hidden_input_dim, self.hidden_dim))

        # forward() concatenates the pre-MLP output with every layer's
        # output, each `width * hidden_dim` wide. The widening therefore
        # has to follow "Neuro-PNA" (the only type that is widened above),
        # not "PNA".
        post_input_dim = width * self.hidden_dim * (self.num_layers + 1)

        # NOTE(review): the last layer projects back to hidden_dim, so the
        # returned embedding is hidden_dim wide regardless of output_dim --
        # confirm this is intended.
        self.post_mlp = torch.nn.Sequential(
            torch.nn.Linear(post_input_dim, self.hidden_dim),
            torch.nn.Dropout(self.dropout),
            torch.nn.LeakyReLU(0.1),
            torch.nn.Linear(self.hidden_dim, self.output_dim),
            torch.nn.ReLU(),
            # Consumes the output_dim-wide activation produced above; using
            # hidden_dim here breaks whenever output_dim != hidden_dim.
            torch.nn.Linear(self.output_dim, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, self.hidden_dim))

    def _apply_conv(self, layer: int, x: Tensor, edge_index: Tensor) -> Tensor:
        """Run convolution layer ``layer`` on ``x`` for the configured conv type."""
        if self.conv_type == "Neuro-PNA":
            return torch.cat((self.convs_sum[layer](x, edge_index),
                              self.convs_mean[layer](x, edge_index),
                              self.convs_max[layer](x, edge_index)), dim=-1)
        return self.convs[layer](x, edge_index)

    def forward(self, node_features: Tensor, edge_index: Tensor) -> Tensor:
        """
        Encode the given nodes into a single embedding vector.

        Args:
            node_features (Tensor): Node feature matrix; it is passed to the
                pre-MLP and then treated as ``(num_nodes, features)``.
            edge_index (Tensor): Edge index forwarded unchanged to every
                convolution layer.

        Returns:
            Tensor: Output of the post-MLP applied to the sum, over the node
            dimension, of the concatenated per-layer embeddings. Every node
            passed in is pooled into the same vector.
        """
        node_features = self.pre_mlp(node_features)
        # all_emb: (num_nodes, layers_so_far, width); used by the learnable
        # skip so each earlier embedding can be scaled separately.
        all_emb = node_features.unsqueeze(1)
        # emb: (num_nodes, layers_so_far * width); flat concatenation.
        emb = node_features

        if self.conv_type == "Neuro-PNA":
            n_convs = len(self.convs_sum)
        else:
            n_convs = len(self.convs)

        for i in range(n_convs):
            if self.skip == 'learnable':
                # Shape (1, i + 1, 1): one gate per earlier embedding,
                # broadcast over nodes and features.
                skip_vals = self.learnable_skip[i, :i + 1].unsqueeze(0).unsqueeze(-1)
                conv_input = all_emb * torch.sigmoid(skip_vals)
                conv_input = conv_input.view(node_features.size(0), -1)
            elif self.skip == 'all':
                conv_input = emb
            else:
                conv_input = node_features

            node_features = self._apply_conv(i, conv_input, edge_index)
            node_features = torch.nn.functional.relu(node_features)
            node_features = torch.nn.functional.dropout(
                node_features, p=self.dropout, training=self.training)

            emb = torch.cat((emb, node_features), 1)
            if self.skip == 'learnable':
                all_emb = torch.cat((all_emb, node_features.unsqueeze(1)), 1)

        # Sum-pool over the node dimension.
        emb = torch.sum(emb, dim=-2)
        emb = self.post_mlp(emb)
        return emb

    def loss(self, pred, label):
        """
        Return the negative log-likelihood loss of ``pred`` against ``label``.

        NOTE(review): ``nll_loss`` expects log-probabilities, while
        :meth:`forward` returns the raw post-MLP output -- confirm that
        callers apply a log-softmax first.
        """
        return torch.nn.functional.nll_loss(pred, label)