similarity.py
from typing import Dict, List, Optional, Tuple, Union

import geomloss
import torch

import gbure.utils


class LinguisticSimilarity(torch.nn.Module):
    """
    Compute the similarity between two bertcoder representations.

    The bertcoder embeddings all point in the same direction of the space since they all correspond to the tags <e1> and <e2>.
    Thus, after a dot product, the activations are no longer standardized, even after scaling by √d.

    Config:
        linguistic_similarity: the function used to compute the similarity between embeddings
        linguistic_similarity_delta: an additive constant added to all similarities (useful to obtain a strictly positive cosine)
        latent_metric_scale: how to scale the similarity once computed
        latent_dot_mean: when latent_metric_scale=="standard", the value to subtract
        latent_dot_std: when latent_metric_scale=="standard", the value to divide by
    """

    def __init__(self, config: gbure.utils.dotdict) -> None:
        super().__init__()
        self.config = config

    def forward(self, lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor:
        """ Compute the similarity between lhs and rhs. """
        if self.config.linguistic_similarity == "dot":
            similarity: torch.Tensor = torch.einsum('b...d,b...d->b...', lhs, rhs)
        elif self.config.linguistic_similarity == "cosine":
            similarity: torch.Tensor = torch.nn.functional.cosine_similarity(lhs, rhs, dim=-1)
        elif self.config.linguistic_similarity == "euclidean":
            similarity: torch.Tensor = - torch.sum((lhs - rhs)**2, dim=-1)
        else:
            raise RuntimeError(f"Unknown config value for linguistic_similarity: {self.config.linguistic_similarity}")
        similarity += self.config.get("linguistic_similarity_delta", 0)

        encoder_output_size: int = lhs.shape[-1]
        if self.config.get("latent_metric_scale") == "sqrt":
            # If X, Y ~ N(0, I_d):
            #     E[X·Y] = 0
            #     Var[X·Y] = d, so dividing by √d gives unit variance.
            similarity = similarity / (encoder_output_size ** 0.5)
        elif self.config.get("latent_metric_scale") == "full":
            similarity = similarity / encoder_output_size
        elif self.config.get("latent_metric_scale") == "match":
            # If X ~ N(0, 1):
            #     E[X²] = 1
            #     Var[X²] = 2, hence the √2 factor below.
            similarity = (similarity - encoder_output_size) / ((2**0.5) * encoder_output_size)
        elif self.config.get("latent_metric_scale") == "standard":
            similarity = (similarity - self.config.latent_dot_mean) / self.config.latent_dot_std
        elif self.config.get("latent_metric_scale") is not None:
            raise RuntimeError("Unsupported config value for latent_metric_scale")

        return similarity


class TopologicalSimilarity(torch.nn.Module):
    """
    Compute the similarity between two neighborhood representations.

    This similarity is either an inner product in the case of fixed-size representations or the negative of the 1-Wasserstein distance.
    """

    def __init__(self, config: gbure.utils.dotdict) -> None:
        super().__init__()
        self.config = config
        if self.config.get("gcn_aggregator", "none") == "none":
            self.sinkhorn = geomloss.SamplesLoss(loss="sinkhorn", p=self.config.get("wasserstein_underlying_distance", 2), blur=self.config.get("sinkhorn_blur", 0.05))

    @staticmethod
    def merge_shapes(lhs: List[int], rhs: List[int]) -> List[int]:
        """ Find the shape of an elementwise operation between two tensors of the given shapes. """
""" 77 lhs = [1] * (len(rhs) - len(lhs)) + lhs 78 rhs = [1] * (len(lhs) - len(rhs)) + rhs 79 res = [] 80 for left, right in zip(lhs, rhs): 81 if left == right: 82 res.append(left) 83 elif left == 1 or right == 1: 84 res.append(left + right - 1) 85 else: 86 raise RuntimeError(f"Incompatible shapes {lhs} and {rhs}") 87 return res 88 89 def forward(self, lhs: Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], rhs: Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: 90 """ 91 Compute the distance between two topological embeddings. 92 93 The embeddings are either a single vector (per sample) which was pooled using a GCN, or four matrices corresponding to: 94 - the embeddings of the head neighborhood 95 - the embeddings of the tail neighborhood 96 - the mask of the head neighborhood 97 - the mask of the tail neighborhood 98 """ 99 if isinstance(lhs, tuple): 100 shape: List[int] = self.merge_shapes(list(lhs[0].shape)[:-2], list(rhs[0].shape)[:-2]) 101 102 def expand_tuple(t: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 103 return (t[0].expand(tuple(shape)+t[0].shape[-2:]), 104 t[1].expand(tuple(shape)+t[1].shape[-2:]), 105 t[2].expand(tuple(shape)+t[2].shape[-1:]), 106 t[3].expand(tuple(shape)+t[3].shape[-1:])) 107 108 lhs = expand_tuple(lhs) 109 rhs = expand_tuple(rhs) 110 111 cumulative_shape: List[int] = [] 112 accumulator: int = 1 113 for x in shape: 114 cumulative_shape.append(accumulator) 115 accumulator *= x 116 result: torch.Tensor = lhs[0].new_zeros(shape+[2]) 117 result_mask: torch.Tensor = lhs[0].new_zeros(shape+[2], dtype=torch.bool) 118 for i in range(accumulator): 119 indices = [] 120 for denominator, modulo in zip(cumulative_shape, shape): 121 indices.append(i // denominator % modulo) 122 indices = tuple(indices) 123 for slot in range(2): 124 sample_lhs: torch.Tensor = lhs[slot][indices][lhs[2+slot][indices]] 125 sample_rhs: torch.Tensor = rhs[slot][indices][rhs[2+slot][indices]] 126 if sample_lhs.numel() > 0 and sample_rhs.numel() > 0: 127 result[indices+(slot,)] = - self.sinkhorn(sample_lhs, sample_rhs) 128 result_mask[indices+(slot,)] = True 129 else: 130 result_mask[indices+(slot,)] = False 131 return result, result_mask 132 else: 133 # TODO implement other topological similarities 134 return torch.einsum('b...d,b...d->b...', lhs, rhs), None