gbure

Graph-based approaches to unsupervised relation extraction, evaluated as a few-shot problem
git clone https://esimon.eu/repos/gbure.git

supervised.py


from typing import Dict, Tuple

import torch
import transformers

from gbure.data.dictionary import RelationDictionary
from gbure.model.linguistic_encoder import LinguisticEncoder
import gbure.utils


class Model(torch.nn.Module):
    """
    Supervised model from Soares et al. (2019), "Matching the Blanks".

    Corresponds to the left subfigure of Figure 2 in that paper.
    """

    def __init__(self, config: gbure.utils.dotdict, tokenizer: transformers.PreTrainedTokenizer, relation_dictionary: RelationDictionary) -> None:
        """
        Instantiate a Soares et al. supervised model.

        Args:
            config: global config object
            tokenizer: tokenizer used to create the vocabulary
            relation_dictionary: dictionary of all relations
        """
        super().__init__()

        self.config: gbure.utils.dotdict = config
        self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
        self.relation_dictionary: RelationDictionary = relation_dictionary

        self.encoder: torch.nn.Module = LinguisticEncoder(config, tokenizer)
        # Linear classification head: maps the encoded relation representation
        # to one logit per relation in the dictionary (no bias term).
        self.relation_encoder = torch.nn.Linear(
                in_features=self.encoder.output_size,
                out_features=len(relation_dictionary),
                bias=False)
        # Mean cross-entropy over the batch.
        self.loss_module = torch.nn.CrossEntropyLoss(reduction="mean")

    def forward(self, text: torch.Tensor, mask: torch.Tensor, entity_positions: torch.Tensor, relation: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """ Compute the supervised loss on the given text and target relation. """
        # Encode the batch; the encoder returns a tuple whose first element is
        # the relation representation of each sentence.
        latent: torch.Tensor = self.encoder(text, mask, entity_positions)[0]
        logits: torch.Tensor = self.relation_encoder(latent)
        loss: torch.Tensor = self.loss_module(logits, relation)
        # Return the loss, an empty dictionary (no auxiliary losses for this
        # model) and a dictionary of variables exposed for evaluation.
        return loss, {}, {"prediction_logits": logits, "predicted_relation": logits.argmax(1)}
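
For reference, the classification head used in forward() boils down to the following self-contained sketch. The toy dimensions, the random tensor standing in for LinguisticEncoder's output, and the variable names are assumptions for illustration; only the encode → linear → cross-entropy pattern is taken from the file above.

import torch

# Assumed toy dimensions: batch size, encoder output size, number of relations.
batch, hidden, n_relations = 4, 32, 7

# Stand-in for the first element returned by the encoder: one relation
# representation per sentence in the batch.
latent = torch.randn(batch, hidden)

relation_encoder = torch.nn.Linear(hidden, n_relations, bias=False)
loss_module = torch.nn.CrossEntropyLoss(reduction="mean")

logits = relation_encoder(latent)                # shape (batch, n_relations)
relation = torch.randint(n_relations, (batch,))  # gold relation ids
loss = loss_module(logits, relation)             # scalar training loss
predicted = logits.argmax(1)                     # predicted relation ids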