linguistic_encoder.py
from typing import Callable, Optional, Tuple

import torch
import transformers

import gbure.utils


class LinguisticEncoder(torch.nn.Module):
    """
    Transformer encoder from Soares et al.

    Corresponds to the left part of each subfigure of Figure 2 (Deep Transformer Encoder and the green layer above).
    We only implement the "entity markers, entity start" variant (which is the one with the best performance).

    Config:
        transformer_model: Which transformer to use (e.g. bert-large-uncased).
        post_transformer_layer: The transformation applied after the transformer (must be "linear", "layer_norm" or "none").
    """

    def __init__(self, config: gbure.utils.dotdict, tokenizer: transformers.PreTrainedTokenizer, transformer: Optional[transformers.PreTrainedModel] = None) -> None:
        """
        Instantiate a Soares et al. encoder.

        Args:
            config: global config object
            tokenizer: tokenizer used to create the vocabulary
            transformer: the transformer to use instead of loading a pre-trained one
        """
        super().__init__()

        self.config: gbure.utils.dotdict = config
        self.tokenizer: transformers.PreTrainedTokenizer = tokenizer

        self.transformer: transformers.PreTrainedModel
        if transformer is not None:
            self.transformer = transformer
        elif self.config.get("load") or self.config.get("pretrained"):
            # TODO introduce a config parameter to change the initialization of <tags> embeddings
            transformer_config = transformers.AutoConfig.from_pretrained(self.config.transformer_model)
            transformer_config.vocab_size = len(tokenizer)
            self.transformer = transformers.AutoModel.from_config(transformer_config)
        else:
            self.transformer = transformers.AutoModel.from_pretrained(self.config.transformer_model)
            self.transformer.resize_token_embeddings(len(tokenizer))

        self.post_transformer: Callable[[torch.Tensor], torch.Tensor]
        if self.config.post_transformer_layer == "linear":
            self.post_transformer_linear = torch.nn.Linear(in_features=self.output_size, out_features=self.output_size)
            self.post_transformer = lambda x: self.post_transformer_linear(x)
        elif self.config.post_transformer_layer == "layer_norm":
            # It is not clear whether a Linear should be added before the layer_norm, see Soares et al. section 3.3.
            # Setting elementwise_affine to True (the default) makes little sense when computing similarity scores.
            self.post_transformer_linear = torch.nn.Linear(in_features=self.output_size, out_features=self.output_size)
            self.post_transformer_activation = torch.nn.LayerNorm(self.output_size, elementwise_affine=False)
            self.post_transformer = lambda x: self.post_transformer_activation(self.post_transformer_linear(x))
        elif self.config.post_transformer_layer == "none":
            self.post_transformer = lambda x: x
        else:
            raise RuntimeError("Unsupported config value for post_transformer_layer")

    @property
    def output_size(self) -> int:
        """ Dimension of the representation returned by the model. """
        return 2 * self.transformer.config.hidden_size

    def forward(self, text: torch.Tensor, mask: torch.Tensor, entity_positions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Encode the given text into a fixed-size representation. """
        batch_size: int = text.shape[0]
        batch_ids: torch.Tensor = torch.arange(batch_size, device=text.device, dtype=torch.int64).unsqueeze(1)

        # The first element of the tuple is the Batch×Sentence×Hidden output matrix.
        transformer_out: torch.Tensor = self.transformer(text, attention_mask=mask)[0]
        sentence: torch.Tensor = transformer_out[batch_ids, entity_positions].view(batch_size, self.output_size)
        return self.post_transformer(sentence), transformer_out
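
A minimal usage sketch, assuming the entity-start markers are named "<e1>" and "<e2>" (illustrative names, not taken from this file), that gbure.utils.dotdict accepts keyword arguments like a plain dict, and that the class is importable from this module:

import torch
import transformers

import gbure.utils
from gbure.model.linguistic_encoder import LinguisticEncoder  # module path assumed

# Hypothetical config: only the two documented keys are set, so __init__
# takes the branch that loads pre-trained weights and resizes the embeddings.
config = gbure.utils.dotdict(transformer_model="bert-base-uncased", post_transformer_layer="layer_norm")

tokenizer = transformers.AutoTokenizer.from_pretrained(config.transformer_model)
tokenizer.add_tokens(["<e1>", "</e1>", "<e2>", "</e2>"])  # entity marker tokens (names assumed)

encoder = LinguisticEncoder(config, tokenizer)
encoder.eval()

encoded = tokenizer("<e1> Bert </e1> was created by <e2> Google </e2> .", return_tensors="pt")
text, mask = encoded["input_ids"], encoded["attention_mask"]

# Positions of the two entity-start markers in the token sequence (shape Batch×2),
# which forward() gathers and concatenates into the fixed-size representation.
e1_id, e2_id = tokenizer.convert_tokens_to_ids(["<e1>", "<e2>"])
entity_positions = torch.tensor([[(text[0] == e1_id).nonzero().item(),
                                  (text[0] == e2_id).nonzero().item()]])

with torch.no_grad():
    representation, transformer_out = encoder(text, mask, entity_positions)
print(representation.shape)  # (1, 2 * hidden_size), e.g. (1, 1536) for a base model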