outputs.py (2509B)
from __future__ import annotations
from typing import Dict, Optional, Type
import pathlib
import types

import torch
import transformers

from gbure.data.dictionary import RelationDictionary


class Outputs:
    """
    Class for outputting data about the learned model.

    Instances are context managers: the output file is opened by `__enter__`
    and closed by `__exit__`, and `update` must only be called in between.
    """
    # TODO parametrize this class

    def __init__(self, logdir: pathlib.Path, tokenizer: transformers.PreTrainedTokenizer, relation_dictionary: RelationDictionary) -> None:
        """
        Initialize the outputs variables, but do not acquire necessary resources yet.

        Args:
            logdir: directory in which the output file is created on `__enter__`
            tokenizer: stored on the instance; not used by the methods of this class
            relation_dictionary: dictionary whose `decode` method is applied to targets and predictions
        """
        self.logdir: pathlib.Path = logdir
        self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
        self.relation_dictionary: RelationDictionary = relation_dictionary

    def __enter__(self) -> Outputs:
        """
        Acquire resources needed for outputting the data.

        Opens `<logdir>/target_prediction` for writing, truncating any existing file.
        """
        # TODO parametrize this file name with split and epoch
        # The encoding is explicit so that the file content does not depend on the platform locale.
        self.target_prediction_file = (self.logdir / "target_prediction").open("w", encoding="utf-8")
        return self

    def __exit__(self,
                 exc_type: Optional[Type[BaseException]],
                 exc_inst: Optional[BaseException],
                 exc_tb: Optional[types.TracebackType]) -> None:
        """ Free used resources; returns None, so an exception raised in the with-body is not suppressed. """
        self.target_prediction_file.close()

    def update(self, batch: Dict[str, torch.Tensor], loss: torch.Tensor, losses: Dict[str, torch.Tensor], variables: Dict[str, torch.Tensor]) -> None:
        """
        Update model output data files with the given batch and the outputs of the model on this batch.

        Three cases are handled:
            - if the batch holds targets ("relation", or "query_relation" as a fallback) and the variables
              hold "predicted_relation", one line "<decoded target>\\t<decoded prediction>" is written per
              (prediction, target) pair;
            - otherwise, if the variables hold "prediction_relative", its `tolist()` is written with one
              element per line;
            - otherwise nothing is written.

        Args:
            batch: the input values used for evaluation
            loss: the loss optimized by the model (currently unused here)
            losses: intermediary (unweighted) losses (currently unused here)
            variables: internal variables used by the model to compute the loss
        """
        # Both lookups may come back empty, hence the Optional annotations.
        targets: Optional[torch.Tensor] = batch.get("relation", batch.get("query_relation"))
        predictions: Optional[torch.Tensor] = variables.get("predicted_relation")
        if targets is not None and predictions is not None:
            decode = self.relation_dictionary.decode
            # zip stops at the shorter of the two sequences; each pair yields exactly one line.
            self.target_prediction_file.writelines(
                f"{decode(target)}\t{decode(prediction)}\n"
                for prediction, target in zip(predictions, targets))
        elif "prediction_relative" in variables:
            print("\n".join(map(str, variables["prediction_relative"].tolist())), file=self.target_prediction_file)