gbure

Graph-based approaches to unsupervised relation extraction evaluated as a few-shot problem
git clone https://esimon.eu/repos/gbure.git
Log | Files | Refs | README | LICENSE

outputs.py (2509B)


      1 from __future__ import annotations
      2 from typing import Dict, Optional, Type
      3 import pathlib
      4 import types
      5 
      6 import torch
      7 import transformers
      8 
      9 from gbure.data.dictionary import RelationDictionary
     10 
     11 
     12 class Outputs:
     13     """
     14     Class for outputing data about the learned model.
     15     """
     16     # TODO parametrize this class
     17 
     18     def __init__(self, logdir: pathlib.Path, tokenizer: transformers.PreTrainedTokenizer, relation_dictionary: RelationDictionary) -> None:
     19         """ Initialize the outputs variables, but do not acquire necessary resources yet. """
     20         self.logdir: pathlib.Path = logdir
     21         self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
     22         self.relation_dictionary: RelationDictionary = relation_dictionary
     23 
     24     def __enter__(self) -> Outputs:
     25         """ Acquire resources needed for outputing the data. """
     26         # TODO parametrize this file name with split and epoch
     27         self.target_prediction_file = (self.logdir / "target_prediction").open("w")
     28         return self
     29 
     30     def __exit__(self,
     31                  exc_type: Optional[Type[BaseException]],
     32                  exc_inst: Optional[BaseException],
     33                  exc_tb: Optional[types.TracebackType]) -> None:
     34         """ Free used resources. """
     35         self.target_prediction_file.close()
     36 
     37     def update(self, batch: Dict[str, torch.Tensor], loss: torch.Tensor, losses: Dict[str, torch.Tensor], variables: Dict[str, torch.Tensor]) -> None:
     38         """
     39         Update model output data files with the given batch and the outputs of the model on this batch.
     40 
     41         The variables dictionary returned by the model must contain a predicted_relation tensor.
     42 
     43         Args:
     44             batch: the input values used for evaluation
     45             loss: the loss optimized by the model
     46             losses: intermediary (unweighted) losses
     47             variables: internal variables used by the model to compute the loss
     48         """
     49         targets: torch.Tensor = batch.get("relation", batch.get("query_relation"))
     50         predictions: torch.Tensor = variables.get("predicted_relation")
     51         if targets is not None and predictions is not None:
     52             for prediction, target in zip(predictions, targets):
     53                 print(f"{self.relation_dictionary.decode(target)}\t{self.relation_dictionary.decode(prediction)}", file=self.target_prediction_file)
     54         elif "prediction_relative" in variables:
     55             print("\n".join(map(str, variables["prediction_relative"].tolist())), file=self.target_prediction_file)