Source code for sgmatch.modules.scoring

from typing import Optional, List

import torch
from torch.functional import Tensor

[docs]class NeuralTensorNetwork(torch.nn.Module): r""" Neural Tensor Network layer from the `"SimGNN: A Neural Network Approach to Fast Graph Similarity Computation" <https://arxiv.org/pdf/1808.05689.pdf>`_ paper TODO: Include latex formula for NTN interaction score computation Args: input_dim: Input dimension of the graph-level embeddings slices. That is, number of slices (K) the weight tensor possesses. Often interpreted as the number of entity-pair (in this use case - pairwise node) relations the data might possess. activation: Non-linearity applied on the computed output of the layer """ def __init__(self, input_dim: int, slices: int = 16, activation: str = "tanh"): super(NeuralTensorNetwork, self).__init__() self.input_dim = input_dim self.slices = slices # K: hyperparameter self.activation = activation self.initialize_parameters() self.reset_parameters() def initialize_parameters(self): # XXX: Will arranging weight tensor as (k, d, d) cause problems in batching at dim 0? self.weight_tensor = torch.nn.Parameter(torch.Tensor(self.slices, self.input_dim, self.input_dim)) self.weight_matrix = torch.nn.Parameter(torch.Tensor(self.slices, 2 * self.input_dim)) self.bias = torch.nn.Parameter(torch.Tensor(self.slices, 1)) def reset_parameters(self): torch.nn.init.xavier_uniform_(self.weight_tensor) torch.nn.init.xavier_uniform_(self.weight_matrix) torch.nn.init.xavier_uniform_(self.bias) def forward(self, h_i: Tensor, h_j: Tensor) -> torch.Tensor: r""" Args: h_i: Graph-level Embedding of the Source/Query graph h_j: Graph-level Embedding of the Corpus/Target graph Returns: (K, 1) graph-graph interaction score vector """ scores = torch.matmul(h_i.unsqueeze(-2).unsqueeze(-2), self.weight_tensor) scores = torch.matmul(scores, h_j.unsqueeze(-2).unsqueeze(-1)).squeeze(-1) scores += torch.matmul(self.weight_matrix, torch.cat([h_i, h_j], dim=-1).unsqueeze(-1)) scores += self.bias # TODO: need to remove this from here and include it in a function in utils if self.activation == "tanh": _activation = torch.nn.functional.tanh elif self.activation == "sigmoid": _activation = torch.nn.functional.sigmoid scores = _activation(scores) return scores
class GumbelSinkhornNetwork(torch.nn.Module): r""" Implementation of the Gumbel-Sinkhorn Network from the `"Learning Latent Permutations with Gumbel-Sinkhorn Networks" <https://arxiv.org/pdf/1802.08665.pdf>`_ paper. Args: temp (float, optional): Temperature parameter for softmax distribution. Lower the value causes softmax probabilities to approach a one-hot vector. Thus, this is a differentiable way to approach a categorical distribution. (default: :obj:`0.1`) eps (float, optional): Small value for numerical stability and precision. (default: :obj:`1e-20`) noise_factor (float, optional): Parameter which controls the magnitude of the effect of sampled Gumbel Noise (default: :obj:`1`) n_iters (int, optional): Number of iterations of Sinkhorn Row and Column scaling (in practice, as little as 20 iterations are needed to achieve decent convergence for N~100). (default: :obj:`20`) """ def __init__(self, temp: float = 0.1, eps: float = 1e-20, noise_factor: float = 1, n_iters: int = 20): super().__init__() self.temp = temp self.eps = eps self.noise_factor = noise_factor self.n_iters = n_iters def sample_from_gumbel(self, shape:List[int], eps:int=1e-20): U = torch.rand(shape).float() return -torch.log(eps - torch.log(U + eps)) def forward(self, log_alpha:Tensor) -> torch.Tensor: r""" Args: log_alpha (torch.Tensor): A Tensor with elements being the logarithms of assignment probabilites Shapes: - **input:** log_alpha: 2D Tensor of shape :math:`(N, N)` or 3D tensor of shape :math:`(\mbox{batch_size}, N, N)` - **output:** Doubly Stochastic Matrix :math:`(\mbox{batch_size}, N, N)` Returns: A 3D tensor of close-to-doubly-stochastic matrices (2D tensors are converted to 3D tensors with batch_size equals to 1) :rtype: :class:`torch.Tensor` """ batch_size = log_alpha.size()[0] n = log_alpha.size()[1] log_alpha = log_alpha.view(-1, n, n) # Sampling from Gumbel Distribution shape = [batch_size, n, n] noise = self.sample_from_gumbel(shape, self.eps)*self.noise_factor log_alpha = log_alpha + noise # Iteration 0 of GS Network log_alpha = log_alpha/self.temp for i in range(self.n_iters): # Row Scaling log_alpha = log_alpha - (torch.logsumexp(log_alpha, dim=2, keepdim=True)).view(-1, n, 1) # Column Scaling log_alpha = log_alpha - (torch.logsumexp(log_alpha, dim=1, keepdim=True)).view(-1, 1, n) return torch.exp(log_alpha)
[docs]def similarity(h_i, h_j, mode:str = "cosine"): # BUG: similarity may not return the correct product # BUG: cosine returns 0-1 normalised similarity and others return unnormalized distance if mode == "cosine": return torch.nn.functional.cosine_similarity(h_i, h_j, dim=-1) if mode == "euclidean": return torch.cdist(h_i, h_j, p=2) if mode == "manhattan": return torch.cdist(h_i, h_j, p=1) if mode == "hamming": return torch.cdist(h_i, h_j, p=0) if mode == "hinge": prod = 1 - torch.matmul(h_i, h_j) return torch.max(torch.zeros(prod.shape), prod) if mode == "min_sum": return torch.sum(torch.min(h_i, h_j), dim=-1)