# Source code for sgmatch.models.ISONET

from typing import List, Optional, Dict
import torch
from torch.functional import Tensor
from torch_geometric.utils import to_dense_adj

from ..modules.encoder import MLPEncoder
from ..modules.scoring import GumbelSinkhornNetwork
from ..modules.propagation import GraphProp
from ..utils.utility import setup_LRL_nn

class ISONET(torch.nn.Module):
    r"""
    End-to-End implementation of the ISONET model from the `"Interpretable
    Neural Subgraph Matching for Graph Retrieval"
    <https://ojs.aaai.org/index.php/AAAI/article/view/20784>`_ paper.

    Args:
        node_feature_dim (int): Input dimension of node feature embedding
            vectors.
        enc_node_hidden_sizes ([int]): Number of hidden neurons in each linear
            layer for transforming the node features.
        prop_node_hidden_sizes ([int]): Number of hidden neurons in each linear
            layer of node update MLP :obj:`f_node`. :obj:`node_feature_dim` is
            appended as the size of the final linear layer to maintain node
            embedding dimensionality.
        prop_message_hidden_sizes ([int]): Number of hidden neurons in each
            linear layer of the message computation MLP. Note that the message
            vector dimensionality (:obj:`prop_message_hidden_sizes[-1]`) may
            not be equal to :obj:`node_feature_dim`.
        edge_feature_dim (int, optional): Input dimension of edge feature
            embedding vectors. (default: :obj:`None`)
        enc_edge_hidden_sizes ([int], optional): Number of hidden neurons in
            each linear layer for transforming the edge features.
            (default: :obj:`None`)
        message_net_init_scale (float, optional): Initialisation scale for the
            message net output vectors. (default: :obj:`0.1`)
        node_update_type (str, optional): Type of update applied to node
            feature vectors (:obj:`"GRU"` or :obj:`"MLP"` or
            :obj:`"residual"`). (default: :obj:`"GRU"`)
        use_reverse_direction (bool, optional): Flag for whether or not to use
            the reverse message aggregation for propagation step.
            (default: :obj:`True`)
        reverse_dir_param_different (bool, optional): Flag for whether or not
            message computation model parameters should be shared by forward
            and reverse messages in propagation step. (default: :obj:`True`)
        layer_norm (bool, optional): Flag for applying layer normalization in
            propagation step. (default: :obj:`False`)
        lrl_hidden_sizes ([int], optional): List containing the sizes for LRL
            network to pass edge features of input graphs.
            (default: :obj:`[16,16]`)
        temp (float, optional): Temperature parameter in the Gumbel-Sinkhorn
            Network. (default: :obj:`0.1`)
        eps (float, optional): Small value for numerical stability and
            precision in the Gumbel-Sinkhorn Network. (default: :obj:`1e-20`)
        noise_factor (float, optional): Parameter which controls the magnitude
            of the effect of sampled Gumbel Noise. (default: :obj:`1`)
        gs_num_iters (int, optional): Number of iterations of Sinkhorn Row and
            Column scaling (in practice, as little as 20 iterations are needed
            to achieve decent convergence for N~100). (default: :obj:`20`)
    """

    def __init__(self, node_feature_dim: int, enc_node_hidden_sizes: List[int],
                 prop_node_hidden_sizes: List[int],
                 prop_message_hidden_sizes: List[int],
                 edge_feature_dim: Optional[int] = None,
                 enc_edge_hidden_sizes: Optional[List[int]] = None,
                 message_net_init_scale: float = 0.1,
                 node_update_type: str = 'GRU',
                 use_reverse_direction: bool = True,
                 reverse_dir_param_different: bool = True,
                 layer_norm: bool = False,
                 lrl_hidden_sizes: List[int] = [16, 16],
                 temp: float = 0.1, eps: float = 1e-20,
                 noise_factor: float = 1, gs_num_iters: int = 20):
        super(ISONET, self).__init__()
        self.node_feature_dim = node_feature_dim
        self.edge_feature_dim = edge_feature_dim

        # Encoder Module
        self.enc_node_layers = enc_node_hidden_sizes
        self.enc_edge_layers = enc_edge_hidden_sizes

        # Propagation Module
        self.prop_node_layers = prop_node_hidden_sizes
        self.prop_message_layers = prop_message_hidden_sizes
        self.message_net_init_scale = message_net_init_scale
        self.node_update_type = node_update_type
        self.use_reverse_direction = use_reverse_direction
        self.reverse_dir_param_different = reverse_dir_param_different
        self.layer_norm = layer_norm
        self.prop_type = "embedding"

        # Gumbel-Sinkhorn Network Parameters
        # Copy the list so the shared default argument object is never
        # aliased by (and mutated through) an instance attribute.
        self.lrl_hidden_sizes = list(lrl_hidden_sizes)
        self.temp = temp
        self.eps = eps
        self.noise_factor = noise_factor
        self.gs_num_iters = gs_num_iters

        # Build the sub-modules right away: forward()/embed_edges() read
        # self._init, self._message_agg_comb, self.LRL and
        # self.gumbel_sinkhorn, which only exist after setup_layers() ran.
        self.setup_layers()

    def setup_layers(self):
        """Instantiate the encoder, propagation, LRL and Gumbel-Sinkhorn
        sub-modules from the hyper-parameters stored on the instance."""
        # Used the same nomenclature present in the ISONET paper
        self._init = MLPEncoder(self.node_feature_dim, self.enc_node_layers,
                                edge_feature_dim=self.edge_feature_dim,
                                edge_hidden_sizes=self.enc_edge_layers)
        self._message_agg_comb = GraphProp(
            self.enc_node_layers[-1], self.prop_node_layers,
            self.prop_message_layers,
            edge_feature_dim=self.edge_feature_dim,
            message_net_init_scale=self.message_net_init_scale,
            node_update_type=self.node_update_type,
            use_reverse_direction=self.use_reverse_direction,
            reverse_dir_param_different=self.reverse_dir_param_different,
            layer_norm=self.layer_norm, prop_type=self.prop_type)
        self.LRL = setup_LRL_nn(input_dim=self.prop_message_layers[-1],
                                hidden_sizes=self.lrl_hidden_sizes)
        self.gumbel_sinkhorn = GumbelSinkhornNetwork(
            self.temp, self.eps, self.noise_factor, self.gs_num_iters)

    def reset_parameters(self):
        """Reset the parameters of the encoder, the propagation module and
        every second entry of the LRL network."""
        self._init.reset_parameters()
        self._message_agg_comb.reset_parameters()
        # NOTE(review): the stride of 2 presumably skips the parameter-free
        # activations interleaved with the linear layers — confirm against
        # setup_LRL_nn.
        for layer in self.LRL[::2]:
            layer.reset_parameters()

    def embed_edges(self, node_features: Tensor, edge_index: Tensor,
                    edge_features: Optional[Tensor] = None,
                    num_prop: int = 5):
        """Encode the inputs, run :obj:`num_prop` propagation steps and return
        the per-edge messages computed from the final node embeddings.

        Args:
            node_features (Tensor): Input node features.
            edge_index (Tensor): Edge indices. A 3-dimensional tensor is
                indexed as :obj:`edge_index[:, 0]` / :obj:`edge_index[:, 1]`
                for source / target indices, anything else as
                :obj:`edge_index[0]` / :obj:`edge_index[1]`.
            edge_features (Tensor, optional): Input edge features.
                (default: :obj:`None`)
            num_prop (int, optional): Number of propagation steps.
                (default: :obj:`5`)

        Returns:
            The output of the propagation module's message computation for
            the (source, target) index pairs.
        """
        is_batched = len(edge_index.shape) == 3
        from_idx = edge_index[:, 0] if is_batched else edge_index[0]
        to_idx = edge_index[:, 1] if is_batched else edge_index[1]

        if edge_features is not None:
            node_features, edge_features = self._init(node_features,
                                                      edge_features)
        else:
            node_features = self._init(node_features)

        # This calculates h_u(K)
        for _ in range(num_prop):
            # TODO: Can include a list keeping track of propagation layer outputs
            node_features = self._message_agg_comb(node_features, from_idx,
                                                   to_idx, edge_features)

        # Computes r_(u,v)_(K); only the forward-direction messages are used.
        edge_message = self._message_agg_comb._compute_messages(
            node_features, from_idx, to_idx,
            self._message_agg_comb.message_net,
            edge_features=edge_features)

        return edge_message

    def forward(self, node_features_q: Tensor, node_features_c: Tensor,
                edge_index_q: Tensor, edge_index_c: Tensor,
                edge_features_q: Optional[Tensor] = None,
                edge_features_c: Optional[Tensor] = None,
                batch_q: Optional[Tensor] = None,
                batch_c: Optional[Tensor] = None,
                num_prop: int = 5) -> Tensor:
        """Compute the distance between query and corpus graphs.

        Edge embeddings of both graphs are computed with :meth:`embed_edges`,
        arranged into the matrices R_q and R_c, aligned through a soft
        permutation matrix produced by the Gumbel-Sinkhorn network, and the
        rectified residual ``R_q - P @ R_c`` is summed per batch element.

        Args:
            node_features_q (Tensor): Node features of the query graph(s).
            node_features_c (Tensor): Node features of the corpus graph(s).
            edge_index_q (Tensor): Edge indices of the query graph(s).
            edge_index_c (Tensor): Edge indices of the corpus graph(s).
            edge_features_q (Tensor, optional): Edge features of the query
                graph(s). (default: :obj:`None`)
            edge_features_c (Tensor, optional): Edge features of the corpus
                graph(s). (default: :obj:`None`)
            batch_q (Tensor, optional): Batch vector forwarded to
                :obj:`to_dense_adj` for the query graphs.
                (default: :obj:`None`)
            batch_c (Tensor, optional): Batch vector forwarded to
                :obj:`to_dense_adj` for the corpus graphs.
                (default: :obj:`None`)
            num_prop (int, optional): Number of propagation steps.
                (default: :obj:`5`)

        Returns:
            Tensor: The residual summed over dimensions 1 and 2.
        """
        # Computes r_(u,v)_(K)
        edge_features_q = self.embed_edges(node_features_q, edge_index_q,
                                           edge_features_q, num_prop)
        edge_features_c = self.embed_edges(node_features_c, edge_index_c,
                                           edge_features_c, num_prop)

        # Once we have the Node and Edge Embeddings, we create R_q and R_c Matrices
        if len(edge_index_q.shape) == 3 and len(edge_index_c.shape) == 3:
            # Finding out the maximum num of edges in any graph - query /
            # corpus in the batch. Shape entries are plain ints, so no
            # `.item()` call is needed (or possible).
            max_num_edges = max(
                max(edge_index.shape[1] for edge_index in edge_index_q),
                max(edge_index.shape[1] for edge_index in edge_index_c),
            )
            # Zero-pad every per-graph edge matrix along its first dimension
            # up to max_num_edges rows; the last dimension is left untouched.
            edge_features_q_batched = torch.stack([
                torch.nn.functional.pad(
                    x, pad=(0, 0, 0, max_num_edges - x.shape[0]))
                for x in edge_features_q
            ])
            edge_features_c_batched = torch.stack([
                torch.nn.functional.pad(
                    x, pad=(0, 0, 0, max_num_edges - x.shape[0]))
                for x in edge_features_c
            ])
        else:
            # NOTE(review): per the PyG documentation `to_dense_adj` returns a
            # (batch, N, N, *) tensor when `edge_attr` is multi-dimensional,
            # while `permute(0, 2, 1)` below expects a 3-D tensor — confirm
            # the intended layout for this path.
            edge_features_q_batched = to_dense_adj(edge_index=edge_index_q,
                                                   batch=batch_q,
                                                   edge_attr=edge_features_q)
            edge_features_c_batched = to_dense_adj(edge_index=edge_index_c,
                                                   batch=batch_c,
                                                   edge_attr=edge_features_c)

        # Passing R_q and R_c through the LRL and the Gumbel-Sinkhorn Network
        edge_features_q_batched_lrl = self.LRL(edge_features_q_batched)
        edge_features_c_batched_lrl = self.LRL(edge_features_c_batched)
        soft_permutation_matrix = self.gumbel_sinkhorn(
            torch.matmul(edge_features_q_batched_lrl,
                         edge_features_c_batched_lrl.permute(0, 2, 1)))

        # Calculating the Distance between corpus and query graph using the
        # Soft Permutation Matrix. The functional ReLU is applied to the
        # tensor; `torch.nn.ReLU(...)` would merely construct a module.
        d_cq = torch.nn.functional.relu(
            edge_features_q_batched
            - torch.matmul(soft_permutation_matrix, edge_features_c_batched))
        d_cq = torch.sum(d_cq, dim=(1, 2))

        return d_cq

    def __repr__(self) -> str:
        # TODO: Update __repr__ with information
        return super().__repr__()