supervised.py (1884B)
from typing import Dict, Tuple

import torch
import transformers

from gbure.data.dictionary import RelationDictionary
from gbure.model.linguistic_encoder import LinguisticEncoder
import gbure.utils


class Model(torch.nn.Module):
    """
    Supervised relation classification model from Soares et al.

    Corresponds to the left subfigure of Figure 2.

    The input is passed through a LinguisticEncoder, the first element of the
    encoder output is projected by a bias-free linear layer onto one logit per
    entry of the relation dictionary, and the logits are trained with a
    mean-reduced cross-entropy loss.
    """

    def __init__(
            self,
            config: gbure.utils.dotdict,
            tokenizer: transformers.PreTrainedTokenizer,
            relation_dictionary: RelationDictionary,
    ) -> None:
        """
        Build the encoder, the relation classification layer and the loss.

        Args:
            config: global config object
            tokenizer: tokenizer used to create the vocabulary
            relation_dictionary: dictionary of all relations; its length sets
                the number of output logits
        """
        super().__init__()

        # Keep references to the objects this model was built from.
        self.config: gbure.utils.dotdict = config
        self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
        self.relation_dictionary: RelationDictionary = relation_dictionary

        # Submodules are created in a fixed order (encoder, then classifier)
        # so that parameter registration and initialisation stay reproducible.
        self.encoder: torch.nn.Module = LinguisticEncoder(config, tokenizer)
        num_relations = len(relation_dictionary)
        self.relation_encoder = torch.nn.Linear(self.encoder.output_size, num_relations, bias=False)
        self.loss_module = torch.nn.CrossEntropyLoss(reduction="mean")

    def forward(
            self,
            text: torch.Tensor,
            mask: torch.Tensor,
            entity_positions: torch.Tensor,
            relation: torch.Tensor,
            **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """
        Compute the supervised loss on the given text and target relation.

        Args:
            text: input forwarded to the encoder
            mask: mask forwarded to the encoder alongside `text`
            entity_positions: entity positions forwarded to the encoder
            relation: target relation passed to the cross-entropy loss
            **kwargs: accepted and ignored

        Returns:
            A triple made of the loss, an empty dictionary, and a dictionary
            holding the raw logits ("prediction_logits") and their argmax
            along dimension 1 ("predicted_relation").
        """
        # Only the first element of the encoder output is used here.
        encoded = self.encoder(text, mask, entity_positions)
        scores = self.relation_encoder(encoded[0])
        objective = self.loss_module(scores, relation)

        # NOTE(review): argmax over dim 1 presumes scores is (batch, relations) - confirm.
        outputs = {
            "prediction_logits": scores,
            "predicted_relation": torch.argmax(scores, dim=1),
        }
        return objective, {}, outputs