gbure

Graph-based approaches to unsupervised relation extraction evaluated as a few-shot problem
git clone https://esimon.eu/repos/gbure.git

topological_encoder.py (3592B)


from typing import List, Tuple, Union
import functools
import operator

import torch

import gbure.utils


class TopologicalEncoder(torch.nn.Module):
    """
    Encoder for neighborhoods.

    Config:
        gcn_aggregator: aggregator used to pool the representations of several neighbors into a fixed-size one.
    """

    def __init__(self, config: gbure.utils.dotdict, linguistic_encoder: torch.nn.Module) -> None:
        """
        Instantiate a topological (neighborhood) encoder.

        Args:
            config: global config object
            linguistic_encoder: the model used to get a fixed-size representation of text
        """
        super().__init__()
        self.config: gbure.utils.dotdict = config
        self.linguistic_encoder: torch.nn.Module = linguistic_encoder

        # The mean and chebyshev aggregators apply a shared linear layer on top of the pooled neighbor representations.
        if self.config.get("gcn_aggregator", "") in ["mean", "chebyshev"]:
            self.gcn_layer: torch.nn.Module = torch.nn.Linear(
                    in_features=self.linguistic_encoder.output_size,
                    out_features=self.linguistic_encoder.output_size)

    def forward(self, prefix: str, loop: torch.Tensor, degree_delta: int = 0, **batch) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
        """
        Encode the neighborhood of the given prefix.

        When a gcn_aggregator is defined, this results in a fixed-size representation; otherwise it returns clouds of points to be compared using optimal transport.
        The degree_delta parameter changes the degrees used to compute the various GCN weightings. This can be useful when the sample comes from outside the graph, in which case 1 should be added to the degrees.
        """
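        # In summary, writing W for gcn_layer, n_k for the masked linguistic embedding of
        # neighbor sentence k, deg(.) for summed entity degrees and d for degree_delta:
        #   mean:      W((loop + sum_k n_k) / (number of neighbors + 1))
        #   chebyshev: W(loop / sqrt(2 * (deg(e1) + deg(e2) - 1 + d)) + sum_k n_k / sqrt(deg(n_k) - 1 + d))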
        linguistic_embeddings: List[torch.Tensor] = []
        masks: List[torch.Tensor] = []
        for slot in [1, 2]:
            # Flatten all batch dimensions so that every neighbor sentence is encoded in a single
            # call to the linguistic encoder, then restore the original batch shape afterwards.
            fake_batch_size: int = functools.reduce(operator.mul, batch[f"{prefix}e{slot}_neighborhood_text"].shape[:-1])
            # The first token of the attention mask indicates whether the neighbor slot is filled at all;
            # it is used to zero out the embeddings of padding neighbors.
            mask: torch.Tensor = batch[f"{prefix}e{slot}_neighborhood_mask"].view(fake_batch_size, -1)[:, 0].unsqueeze(1)
            linguistic_embeddings.append((self.linguistic_encoder(
                    batch[f"{prefix}e{slot}_neighborhood_text"].view(fake_batch_size, -1),
                    batch[f"{prefix}e{slot}_neighborhood_mask"].view(fake_batch_size, -1),
                    batch[f"{prefix}e{slot}_neighborhood_entity_positions"].view(fake_batch_size, 2)
                    )[0] * mask).view(*batch[f"{prefix}e{slot}_neighborhood_text"].shape[:-1], self.linguistic_encoder.output_size))
            masks.append(mask.view(*batch[f"{prefix}e{slot}_neighborhood_text"].shape[:-1]))

        if self.config.get("gcn_aggregator", "") == "mean":
            # Average the self-loop and all neighbor embeddings, then apply the GCN linear layer.
            head: torch.Tensor = linguistic_embeddings[0].sum(-2)
            tail: torch.Tensor = linguistic_embeddings[1].sum(-2)
            neighborhood_size: torch.Tensor = sum(mask.sum(-1, keepdim=True) for mask in masks)
            return self.gcn_layer((loop + head + tail) / (neighborhood_size + 1))
        elif self.config.get("gcn_aggregator", "") == "chebyshev":
            # Weight the self-loop and each neighbor by the inverse square root of the relevant summed degrees.
            pre_embedding: torch.Tensor = loop / torch.sqrt(2 * (batch[f"{prefix}entity_degrees"].sum(-1, keepdim=True) - 1 + degree_delta))
            for slot in [1, 2]:
                weights = 1 / torch.sqrt(batch[f"{prefix}e{slot}_neighborhood_entity_degrees"].sum(-1, keepdim=True) - 1 + degree_delta)
                pre_embedding += torch.sum(weights * linguistic_embeddings[slot - 1], dim=-2)
            return self.gcn_layer(pre_embedding)
        else:
            # No aggregator configured: return the raw clouds of neighbor embeddings and their masks,
            # to be compared with optimal transport.
            return (linguistic_embeddings[0], linguistic_embeddings[1], masks[0], masks[1])
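

Below is a minimal, self-contained sketch of how this encoder can be exercised with the "mean" aggregator, assuming the TopologicalEncoder class above is in scope. The stand-in linguistic encoder, the plain-dict config and all tensor shapes are illustrative assumptions only; the actual training code builds its config as a gbure.utils.dotdict and passes the repository's own linguistic encoder and batches.

import torch

class ConstantLinguisticEncoder(torch.nn.Module):
    """Stand-in for the real linguistic encoder: anything exposing output_size and
    returning a tuple whose first element is a (batch, output_size) embedding works."""
    output_size: int = 8

    def forward(self, text, mask, entity_positions):
        return torch.zeros(text.shape[0], self.output_size), None

batch_size, num_neighbors, seq_len = 2, 3, 5
template = {
    "neighborhood_text": torch.zeros(batch_size, num_neighbors, seq_len, dtype=torch.long),
    "neighborhood_mask": torch.ones(batch_size, num_neighbors, seq_len),
    "neighborhood_entity_positions": torch.zeros(batch_size, num_neighbors, 2, dtype=torch.long),
}
batch = {f"e{slot}_{key}": value.clone() for slot in [1, 2] for key, value in template.items()}

# A plain dict is enough here since the class only calls config.get();
# the real code passes a gbure.utils.dotdict.
encoder = TopologicalEncoder({"gcn_aggregator": "mean"}, ConstantLinguisticEncoder())
loop = torch.zeros(batch_size, encoder.linguistic_encoder.output_size)
output = encoder("", loop, **batch)  # empty prefix, so keys are "e1_neighborhood_text", etc.
print(output.shape)  # torch.Size([2, 8])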