gbure

Graph-based approaches to unsupervised relation extraction, evaluated as a few-shot problem
git clone https://esimon.eu/repos/gbure.git

similarity.py (6630B)


from typing import List, Optional, Tuple, Union

import geomloss
import torch

import gbure.utils


class LinguisticSimilarity(torch.nn.Module):
    """
    Compute the similarity between two bertcoder representations.

    The bertcoder embeddings all point in roughly the same direction in space since they all correspond to the tags <e1> and <e2>.
    Thus, after a dot product, the activations are no longer standardized, even after scaling by √d.

    Config:
        linguistic_similarity: the function used to compute the similarity between embeddings
        linguistic_similarity_delta: an additive constant added to every similarity (useful to have a strictly positive cosine)
        latent_metric_scale: how to scale the similarity once computed
        latent_dot_mean: when latent_metric_scale=="standard", the value to subtract
        latent_dot_std: when latent_metric_scale=="standard", the value to divide by
    """
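
    # An illustrative configuration (hypothetical values, not taken from the
    # repository's experiments; assumes dotdict accepts the same constructor
    # arguments as a plain dict):
    #     config = gbure.utils.dotdict(
    #         linguistic_similarity="cosine",
    #         linguistic_similarity_delta=1.0,
    #         latent_metric_scale=None)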

    def __init__(self, config: gbure.utils.dotdict) -> None:
        super().__init__()
        self.config = config

    def forward(self, lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor:
        """ Compute the similarity between lhs and rhs. """
        if self.config.linguistic_similarity == "dot":
            similarity: torch.Tensor = torch.einsum('b...d,b...d->b...', lhs, rhs)
        elif self.config.linguistic_similarity == "cosine":
            similarity: torch.Tensor = torch.nn.functional.cosine_similarity(lhs, rhs, dim=-1)
        elif self.config.linguistic_similarity == "euclidean":
            similarity: torch.Tensor = - torch.sum((lhs - rhs)**2, dim=-1)
        else:
            raise RuntimeError(f"Unknown config value for linguistic_similarity: {self.config.linguistic_similarity}")
        similarity += self.config.get("linguistic_similarity_delta", 0)

        encoder_output_size: int = lhs.shape[-1]
        if self.config.get("latent_metric_scale") == "sqrt":
            # If X, Y ~ N(0, I_d):
            #     E[X·Y] = 0
            #     Var[X·Y] = d, hence Std[X·Y] = √d
            similarity = similarity / (encoder_output_size ** 0.5)
        elif self.config.get("latent_metric_scale") == "full":
            similarity = similarity / encoder_output_size
        elif self.config.get("latent_metric_scale") == "match":
            # If X ~ N(0, 1):
            #     E[X²] = 1
            #     Var[X²] = 2, hence Std[X²] = √2
            similarity = (similarity - encoder_output_size) / ((2**0.5) * encoder_output_size)
        elif self.config.get("latent_metric_scale") == "standard":
            similarity = (similarity - self.config.latent_dot_mean) / self.config.latent_dot_std
        elif self.config.get("latent_metric_scale") is not None:
            raise RuntimeError(f"Unsupported config value for latent_metric_scale: {self.config.latent_metric_scale}")
     58         return similarity
     59 
     60 
class TopologicalSimilarity(torch.nn.Module):
    """
    Compute the similarity between two neighborhood representations.

    This similarity is either an inner product in the case of fixed-size representations or the negative of a Sinkhorn divergence (an entropy-regularized Wasserstein distance) between neighborhoods.

    Config:
        gcn_aggregator: when "none", neighborhoods are compared with a Sinkhorn divergence
        wasserstein_underlying_distance: exponent p of the ground cost used by the Sinkhorn divergence
        sinkhorn_blur: blur (regularization scale) of the Sinkhorn divergence
    """

    def __init__(self, config: gbure.utils.dotdict) -> None:
        super().__init__()
        self.config = config
        if self.config.get("gcn_aggregator", "none") == "none":
            self.sinkhorn = geomloss.SamplesLoss(loss="sinkhorn", p=self.config.get("wasserstein_underlying_distance", 2), blur=self.config.get("sinkhorn_blur", 0.05))

    @staticmethod
    def merge_shapes(lhs: List[int], rhs: List[int]) -> List[int]:
        """ Find the shape of an elementwise operation between two tensors of the given shapes. """
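        # For example, merging shapes [5, 1] and [3, 1, 4] yields [3, 5, 4],
        # mirroring torch's broadcasting rules. Note that [1] * n is empty for
        # n <= 0, so only the shorter of the two shapes is left-padded below.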
        lhs = [1] * (len(rhs) - len(lhs)) + lhs
        rhs = [1] * (len(lhs) - len(rhs)) + rhs
        res = []
        for left, right in zip(lhs, rhs):
            if left == right:
                res.append(left)
            elif left == 1 or right == 1:
                res.append(left + right - 1)
            else:
                raise RuntimeError(f"Incompatible shapes {lhs} and {rhs}")
        return res

    def forward(self, lhs: Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], rhs: Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Compute the distance between two topological embeddings.

        The embeddings are either a single vector (per sample) which was pooled using a GCN, or a tuple of four tensors corresponding to:
            - the embeddings of the head neighborhood
            - the embeddings of the tail neighborhood
            - the mask of the head neighborhood
            - the mask of the tail neighborhood
        """
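        # Assumed shapes in the tuple case: neighborhood embeddings are
        # (*batch, n_neighbors, d) and masks are (*batch, n_neighbors); the two
        # batch prefixes are broadcast against each other below.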
        if isinstance(lhs, tuple):
            shape: List[int] = self.merge_shapes(list(lhs[0].shape)[:-2], list(rhs[0].shape)[:-2])

            def expand_tuple(t: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
                return (t[0].expand(tuple(shape)+t[0].shape[-2:]),
                        t[1].expand(tuple(shape)+t[1].shape[-2:]),
                        t[2].expand(tuple(shape)+t[2].shape[-1:]),
                        t[3].expand(tuple(shape)+t[3].shape[-1:]))

            lhs = expand_tuple(lhs)
            rhs = expand_tuple(rhs)

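            # Enumerate every multi-index of the broadcast batch shape:
            # cumulative_shape holds the stride of each dimension, so the k-th
            # index of flat element i is i // cumulative_shape[k] % shape[k].
            # Pairs of neighborhoods are compared one at a time because their
            # numbers of valid (unmasked) neighbors differ.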
            cumulative_shape: List[int] = []
            accumulator: int = 1
            for x in shape:
                cumulative_shape.append(accumulator)
                accumulator *= x
            result: torch.Tensor = lhs[0].new_zeros(shape+[2])
            result_mask: torch.Tensor = lhs[0].new_zeros(shape+[2], dtype=torch.bool)
            for i in range(accumulator):
                indices = []
                for denominator, modulo in zip(cumulative_shape, shape):
                    indices.append(i // denominator % modulo)
                indices = tuple(indices)
                for slot in range(2):
                    sample_lhs: torch.Tensor = lhs[slot][indices][lhs[2+slot][indices]]
                    sample_rhs: torch.Tensor = rhs[slot][indices][rhs[2+slot][indices]]
                    if sample_lhs.numel() > 0 and sample_rhs.numel() > 0:
                        result[indices+(slot,)] = - self.sinkhorn(sample_lhs, sample_rhs)
                        result_mask[indices+(slot,)] = True
                    else:
                        result_mask[indices+(slot,)] = False
            return result, result_mask
        else:
            # TODO implement other topological similarities
            return torch.einsum('b...d,b...d->b...', lhs, rhs), None
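

if __name__ == "__main__":
    # Minimal usage sketch: it assumes gbure.utils.dotdict behaves like a dict
    # with attribute access, so a local stand-in keeps the example self-contained.
    class _Config(dict):
        __getattr__ = dict.__getitem__

    lhs = torch.randn(2, 768)
    rhs = torch.randn(2, 768)

    linguistic = LinguisticSimilarity(_Config(linguistic_similarity="dot", latent_metric_scale="sqrt"))
    print(linguistic(lhs, rhs))  # (2,) tensor of dot products scaled by √768

    # With a fixed-size aggregator the topological similarity falls back to an
    # inner product ("mean" is a hypothetical aggregator name, used here only to
    # skip the Sinkhorn branch).
    topological = TopologicalSimilarity(_Config(gcn_aggregator="mean"))
    similarity, mask = topological(lhs, rhs)
    print(similarity, mask)  # (2,) tensor and None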