gbure


commit bafbf5ff3a8f8f39573719da56083cacac5345f9
parent a2553cee69c02fca47496c1cb5c23dc99b5ac5c8
Author: Étienne Simon <esimon@esimon.eu>
Date:   Thu, 21 Nov 2019 12:38:01 +0100

Fix metrics

Diffstat:
  M README                                     |  30 +++++++++++++++++++++---------
  M fsre/__init__.py                           |   1 +
  M fsre/config/soares_supervised_kbp37.py     |   2 +-
  M fsre/config/soares_supervised_semeval.py   |   2 +-
  D fsre/data/prepare_fewrel.py                |  61 -------------------------------------------------------------
  M fsre/data/prepare_kbp37.py                 |   3 ++-
  M fsre/data/prepare_semeval.py               |  14 ++++++++++++--
  M fsre/data/relation_dictionary.py           |  85 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-----------
  A fsre/eval.py                               |  30 ++++++++++++++++++++++++++++++
  M fsre/metrics.py                            | 228 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-----------
  M fsre/train.py                              | 104 +++++++++++++++++++++++++++++++++++++++++++++++++++++--------------------------

11 files changed, 409 insertions(+), 151 deletions(-)

diff --git a/README b/README
@@ -1,23 +1,35 @@
 Reproduction of Matching the Blanks: Distributional Similarity for Relation Learning by Livio Baldini Soares, Nicholas FitzGerald, Jeffrey Ling, and Tom Kwiatkowski.
-This repository currently contains the supervised model for the Semeval and KBP37 datasets, the unsupervised MTB model and the FewRel dataset will be added latter on.
-In order to reproduce the results, you first need to manually download and extract the datasets (e.g. in /tmp), then do the following:
+This repository currently contains the supervised "entity markers-entity start" model for the Semeval and KBP37 datasets; the unsupervised MTB model and the FewRel dataset will be added later on.
+In order to reproduce the results, you first need to manually download and extract the datasets (e.g. in /tmp), then execute the following:
 
 $ export DATA_PATH="/tmp" # the directory containing the extracted datasets
 $ export LOG_PATH="/tmp" # the directory where the models will be saved
 $ git clone https://gitlab.lip6.fr/esimon/few-shot-relation-extraction
 $ cd few-shot-relation-extraction
-$ python -m fsre.data.prepare_semeval bert-large-cased
+$ python -m fsre.data.prepare_semeval
 $ python -m fsre.train fsre/config/soares_supervised_semeval.py
-$ python -m fsre.data.prepare_kbp37 bert-large-cased
+$ python -m fsre.data.prepare_kbp37
 $ python -m fsre.train fsre/config/soares_supervised_kbp37.py
 
-I must be missing something since I'm still far from the results reported in the paper:
+I must be missing something since I'm still a bit off from the results reported in the paper.
+Here are the scores reached by this repository (official SemEval macro F1 "taking directionality into account"):
 
 On Semeval:
-    valid macro F1: paper 82.1 vs us 81.5 (accuracy 84.7)
-    test macro F1: paper 89.2 vs us 81.8 (accuracy 84.8)
+    paper valid: 82.1
+    BERT cased valid: 88.55 (std: 0.70)
+    BERT uncased valid: 87.49 (std: 0.87)
+    paper test: 89.2
+    BERT cased test: 88.24 (std: 0.37)
+    BERT uncased test: 88.02 (std: 0.65)
 
 On KBP37:
-    valid macro F1: paper 70.0 vs us 65.7 (accuracy 65.5)
-    test macro F1: paper 68.3 vs us 63.2 (accuracy 64.2)
+    paper valid: 70.0
+    BERT cased valid: 66.92 (std: 0.61)
+    BERT uncased valid: 66.48 (std: 0.46)
+    paper test: 68.3
+    BERT cased test: 67.18 (std: 0.51)
+    BERT uncased test: 66.21 (std: 0.62)
+
+The mean and std are computed over 5 runs; all reported results are for "large" BERT models.
+Detailed results can be found at http://www-ia.lip6.fr/~esimon/results.xhtml (temporary link).
diff --git a/fsre/__init__.py b/fsre/__init__.py
@@ -1,3 +1,4 @@
 import fsre.data
 import fsre.utils
 import fsre.metrics
+import fsre.train

diff --git a/fsre/config/soares_supervised_kbp37.py b/fsre/config/soares_supervised_kbp37.py
@@ -11,7 +11,7 @@ learning_rate = 3e-5
 true_batch_size = 64
 
 # Guessed
-validation_metric = "f1"
+validation_metric = "half_directed_macro_f1"
 early_stopping_patience = 2
 
 # Implementation details

diff --git a/fsre/config/soares_supervised_semeval.py b/fsre/config/soares_supervised_semeval.py
@@ -11,7 +11,7 @@ learning_rate = 3e-5
 true_batch_size = 64
 
 # Guessed
-validation_metric = "f1"
+validation_metric = "half_directed_macro_f1"
 early_stopping_patience = 2
 
 # Implementation details

diff --git a/fsre/data/prepare_fewrel.py b/fsre/data/prepare_fewrel.py
@@ -1,61 +0,0 @@
-import argparse
-import numpy
-import transformers
-import tqdm
-
-from fsre.utils import DATA_PATH
-from fsre.data.relation_dictionary import RelationDictionary
-
-
-def load_fewrel_dataset(path, tokenizer, relation_dictionary):
-    pass
-
-
-def prepare_fewrel(args):
-    rng = numpy.random.RandomState(args.seed)
-    fewrel_path = DATA_PATH / "FewRel"
-    output_path = fewrel_path / args.tokenizer
-
-    if not output_path.is_dir():
-        output_path.mkdir()
-
-    relation_dictionary = RelationDictionary()
-
-    tokenizer = transformers.BertTokenizer.from_pretrained(args.tokenizer)
-    tokenizer.add_special_tokens({"additional_special_tokens": ["<e1>", "</e1>", "<e2>", "</e2>"]})
-    tokenizer_path = output_path / "tokenizer"
-    if not tokenizer_path.is_dir():
-        tokenizer_path.mkdir()
-    tokenizer.save_pretrained(tokenizer_path)
-
-    train = load_fewrel_dataset(
-            fewrel_path / "train.json",
-            tokenizer,
-            relation_dictionary)
-    rng.shuffle(train)
-
-    valid = load_fewrel_dataset(
-            fewrel_path / "val.json",
-            tokenizer,
-            relation_dictionary)
-
-    numpy.save(output_path / "train.npy", numpy.array(train))
-    numpy.save(output_path / "valid.npy", numpy.array(valid))
-
-    relation_dictionary.save(output_path / "relations")
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-            description="Prepare the FewRel dataset.")
-    parser.add_argument("tokenizer",
-                        type=str,
-                        nargs='?',
-                        default="bert-large-uncased",
-                        help="Name of the transformers tokenizer")
-    parser.add_argument("-s", "--seed",
-                        type=int,
-                        default=0,
-                        help="Seed of the RNG for shuffling the dataset")
-
-    prepare_fewrel(parser.parse_args())

diff --git a/fsre/data/prepare_kbp37.py b/fsre/data/prepare_kbp37.py
@@ -10,6 +10,7 @@ from fsre.data.prepare_semeval import load_semeval_dataset
 TRAIN_SIZE = 15917
 VALID_SIZE = 1724
 TEST_SIZE = 3405
+UNKNOWN_RELATION = "no_relation"
 
 
 def prepare_kbp37(args):
@@ -20,7 +21,7 @@ def prepare_kbp37(args):
     if not output_path.is_dir():
         output_path.mkdir()
 
-    relation_dictionary = RelationDictionary()
+    relation_dictionary = RelationDictionary(unknown=UNKNOWN_RELATION)
 
     tokenizer = transformers.BertTokenizer.from_pretrained(args.tokenizer)
     tokenizer.add_special_tokens({"additional_special_tokens": ["<e1>", "</e1>", "<e2>", "</e2>"]})

diff --git a/fsre/data/prepare_semeval.py b/fsre/data/prepare_semeval.py
@@ -8,6 +8,7 @@ from fsre.data.relation_dictionary import RelationDictionary
 
 TRAIN_SIZE = 8000
 TEST_SIZE = 2717
+UNKNOWN_RELATION = "Other"
 
 
 def load_semeval_dataset(path, tokenizer, relation_dictionary, size):
@@ -24,12 +25,21 @@ def load_semeval_dataset(path, tokenizer, relation_dictionary, size):
         id, raw_text = idtext_line.rstrip().split('\t')
         id = int(id)
+        raw_text = raw_text[1:-1]  # remove quotes around text
 
         text = tokenizer.encode(raw_text, add_special_tokens=True)
         e1_pos = text.index(be1_id)
         e2_pos = text.index(be2_id)
+        if len(text) > tokenizer.max_len:
+            text = text[:tokenizer.max_len]
+            e1_pos = min(tokenizer.max_len-1, e1_pos)
+            e2_pos = min(tokenizer.max_len-1, e2_pos)
         text = numpy.array(text, dtype=numpy.int32)
-        relation = relation_dictionary.encode(relation_line.rstrip())
+
+        relation_line = relation_line.rstrip()
+        dir_start = relation_line.find('(')
+        relation_base = relation_line[:dir_start] if dir_start >= 0 else relation_line
+        relation = relation_dictionary.encode(relation_line, relation_base)
 
         dataset.append([id, text, e1_pos, e2_pos, relation])
         infile.readline()  # Ignore Comment line
@@ -45,7 +55,7 @@ def prepare_semeval(args):
     if not output_path.is_dir():
         output_path.mkdir()
 
-    relation_dictionary = RelationDictionary()
+    relation_dictionary = RelationDictionary(unknown=UNKNOWN_RELATION)
 
     tokenizer = transformers.BertTokenizer.from_pretrained(args.tokenizer)
     tokenizer.add_special_tokens({"additional_special_tokens": ["<e1>", "</e1>", "<e2>", "</e2>"]})
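For illustration, the relation/base split added above behaves as follows on a sample SemEval label (standalone sketch, not code from the repository):

    relation_line = "Entity-Destination(e1,e2)"
    dir_start = relation_line.find('(')
    relation_base = relation_line[:dir_start] if dir_start >= 0 else relation_line
    # relation_base is now "Entity-Destination"; both directed variants of the
    # label map to this single base, which the undirected metrics operate on.
    # Assuming `relation_dictionary` is a RelationDictionary instance:
    relation = relation_dictionary.encode(relation_line, relation_base)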
""" + return self.id_to_bid[id] + def save(self, path): - with open(path, 'w') as file: - file.writelines(map(lambda x: f"{x}\n", self.decoder)) + with open(path, "wb") as file: + pickle.dump({ + "unknown": self.unknown, + "decoder": self.decoder, + "encoder": self.encoder, + "base_encoder": self.base_encoder, + "base_decoder": self.base_decoder, + "id_to_bid": self.id_to_bid, + }, file) def load(self, path): - with open(path, 'r') as file: - self.decoder = list(map(str.rstrip, file.readlines())) - self.encoder = dict(zip(self.decoder, range(len(self.decoder)))) + with open(path, "rb") as file: + data = pickle.load(file) + for key, value in data.items(): + setattr(self, key, value) diff --git a/fsre/eval.py b/fsre/eval.py @@ -0,0 +1,30 @@ +import torch + +import fsre + + +class Evaluator(fsre.train.Trainer): + """ + Evaluate a model. + """ + + def __init__(self, config, state_dicts): + self.eval_config = config + super().__init__(state_dicts["config"], None, state_dicts) + + def run(self): + self.epoch = self.state_dicts["epoch"] + self.info() + self.initialize_rng() + self.prepare_dataset() + self.build_model() + self.count_parameters() + self.evaluate("test") + + +if __name__ == "__main__": + fsre.utils.fix_transformers_logging_handler() + config = fsre.utils.parse_args() + + state_dicts = torch.load(config.load) + Evaluator(config, state_dicts).run() diff --git a/fsre/metrics.py b/fsre/metrics.py @@ -1,3 +1,4 @@ +import math import numpy import torch @@ -6,28 +7,49 @@ class Metrics: """ Class for computing metrics. - Five metrics are computed: + Twenty metrics are computed: - Accuracy - - Macro F1 - - Macro Precision - - Macro Recall - - Negative Log Likelihood + - Negative Log Likelihood (nll) + - {directed, undirected, half_directed} {micro, macro} {f1, precision, recall} + Note that the Accuracy is the true accuracy, taking directionality into account and scoring the unknown relation as any other relation. + The last 18 metrics follow the SemEval scorer: + - The unknown ("Other") relation is only scored indirectly + - Directed is equivalent to the metrics "USING DIRECTIONALITY" + - Undirected is equivalent to the metrics "IGNORING DIRECTIONALITY" + - Half-directed is equivalent to the metrics "TAKING DIRECTIONALITY INTO ACCOUNT -- OFFICIAL" + Note that the directed and half_directed micro metrics are equivalents. """ - def __init__(self, nclass): + def __init__(self, relation_dictionary): """ Initialize all metrics. 
diff --git a/fsre/metrics.py b/fsre/metrics.py
@@ -1,3 +1,4 @@
+import math
 import numpy
 import torch
 
@@ -6,28 +7,49 @@ class Metrics:
     """
     Class for computing metrics.
 
-    Five metrics are computed:
+    Twenty metrics are computed:
         - Accuracy
-        - Macro F1
-        - Macro Precision
-        - Macro Recall
-        - Negative Log Likelihood
+        - Negative Log Likelihood (nll)
+        - {directed, undirected, half_directed} {micro, macro} {f1, precision, recall}
+
+    Note that the Accuracy is the true accuracy, taking directionality into account and scoring the unknown relation as any other relation.
+    The last 18 metrics follow the SemEval scorer:
+        - The unknown ("Other") relation is only scored indirectly
+        - Directed is equivalent to the metrics "USING DIRECTIONALITY"
+        - Undirected is equivalent to the metrics "IGNORING DIRECTIONALITY"
+        - Half-directed is equivalent to the metrics "TAKING DIRECTIONALITY INTO ACCOUNT -- OFFICIAL"
+    Note that the directed and half_directed micro metrics are equivalent.
     """
 
-    def __init__(self, nclass):
+    def __init__(self, relation_dictionary):
         """
         Initialize all metrics.
 
         Args:
-            nclass: number of relations
+            relation_dictionary: see class RelationDictionary
         """
-        self.n = nclass
+        self.relation_dictionary = relation_dictionary
+        self.n = len(relation_dictionary)
+        self.m = relation_dictionary.base_size()
+        self.build_mask()
+        self.build_base_transition()
 
         self.crossentropy = torch.nn.CrossEntropyLoss(reduction="sum")
         self.size = 0
+        self.correct = 0
         self.ce_sum = 0
-        self.confusion = numpy.zeros((nclass, nclass), numpy.int64)
+        self.confusion = numpy.zeros((self.n, self.n), numpy.int64)
+
+    def build_mask(self):
+        self.mask = numpy.ones(self.n)
+        if self.relation_dictionary.unknown is not None:
+            assert(self.relation_dictionary.decode(0) == self.relation_dictionary.unknown)
+            self.mask[0] = 0
+
+    def build_base_transition(self):
+        self.base_transition = numpy.zeros((self.n, self.m))
+        for id, bid in enumerate(self.relation_dictionary.id_to_bid):
+            self.base_transition[id, bid] = 1
 
     def update(self, predictions, target):
         """
@@ -39,46 +61,190 @@ class Metrics:
         """
 
         self.size += predictions.shape[0]
-        self.ce_sum += self.crossentropy(predictions, target).cpu().item()
+        self.ce_sum += self.crossentropy(predictions, target).item()
 
         prediction = predictions.argmax(1)
-        for p, t in zip(prediction.cpu(), target.cpu()):
-            self.confusion[p.item(), t.item()] += 1
+        for p, t in zip(prediction.tolist(), target.tolist()):
+            self.confusion[p, t] += 1
+            self.correct += (p == t)
+
+    @property
+    def summary(self):
+        return {"accuracy": f"{self.accuracy*100:.2f}",
+                "nll": f"{self.nll:.2f}"}
+
+    @property
+    def all(self):
+        keys = ["accuracy", "nll"] + [
+                f"{direction}_{level}_{metric}"
+                for direction in ["directed", "undirected", "half_directed"]
+                for level in ["macro", "micro"]
+                for metric in ["f1", "precision", "recall"]]
+        return {key: getattr(self, key) for key in keys}
+
+    @property
+    def base_mask(self):
+        return self.mask.dot(self.base_transition).clip(0, 1)
+
+    @property
+    def base_confusion(self):
+        return self.base_transition.T.dot(self.confusion).dot(self.base_transition)
 
     @property
     def accuracy(self):
-        return self.confusion.diagonal().sum() / (self.confusion.sum() + 1e-12)
+        return math.nan if self.size == 0 else self.correct / self.size
 
     @property
-    def class_precision(self):
-        return self.confusion.diagonal() / (self.confusion.sum(1) + 1e-12)
+    def nll(self):
+        return math.nan if self.size == 0 else self.ce_sum / self.size
+
+    ##########################
+    # Directed macro metrics #
+    ##########################
 
     @property
-    def class_recall(self):
-        return self.confusion.diagonal() / (self.confusion.sum(0) + 1e-12)
+    def directed_class_precision(self):
+        norm = self.confusion.sum(1)
+        norm[norm == 0] = 1
+        return self.confusion.diagonal() / norm
 
     @property
-    def class_f1(self):
-        return 2 * self.class_precision * self.class_recall / (self.class_precision + self.class_recall + 1e-12)
+    def directed_class_recall(self):
+        norm = self.confusion.sum(0)
+        norm[norm == 0] = 1
+        return self.confusion.diagonal() / norm
 
     @property
-    def precision(self):
-        return self.class_precision.mean()
+    def directed_class_f1(self):
+        norm = self.directed_class_precision + self.directed_class_recall
+        norm[norm == 0] = 1
+        return 2 * self.directed_class_precision * self.directed_class_recall / norm
 
     @property
-    def recall(self):
-        return self.class_recall.mean()
+    def directed_macro_precision(self):
+        return numpy.sum(self.directed_class_precision * self.mask) / self.mask.sum()
 
     @property
-    def f1(self):
-        return self.class_f1.mean()
+    def directed_macro_recall(self):
+        return numpy.sum(self.directed_class_recall * self.mask) / self.mask.sum()
 
     @property
-    def nll(self):
-        return self.ce_sum / self.size
+    def directed_macro_f1(self):
+        return numpy.sum(self.directed_class_f1 * self.mask) / self.mask.sum()
+
+    ############################
+    # Undirected macro metrics #
+    ############################
 
     @property
-    def summary(self):
-        return {"accuracy": f"{self.accuracy*100:.2f}",
-                "f1": f"{self.f1*100:.2f}",
-                "nll": f"{self.nll:.2f}"}
+    def undirected_class_precision(self):
+        norm = self.base_confusion.sum(1)
+        norm[norm == 0] = 1
+        return self.base_confusion.diagonal() / norm
+
+    @property
+    def undirected_class_recall(self):
+        norm = self.base_confusion.sum(0)
+        norm[norm == 0] = 1
+        return self.base_confusion.diagonal() / norm
+
+    @property
+    def undirected_class_f1(self):
+        norm = self.undirected_class_precision + self.undirected_class_recall
+        norm[norm == 0] = 1
+        return 2 * self.undirected_class_precision * self.undirected_class_recall / norm
+
+    @property
+    def undirected_macro_precision(self):
+        return numpy.sum(self.undirected_class_precision * self.base_mask) / self.base_mask.sum()
+
+    @property
+    def undirected_macro_recall(self):
+        return numpy.sum(self.undirected_class_recall * self.base_mask) / self.base_mask.sum()
+
+    @property
+    def undirected_macro_f1(self):
+        return numpy.sum(self.undirected_class_f1 * self.base_mask) / self.base_mask.sum()
+
+    ###############################
+    # Half-directed macro metrics #
+    ###############################
+
+    @property
+    def half_directed_class_precision(self):
+        norm = self.base_confusion.sum(1)
+        norm[norm == 0] = 1
+        return self.confusion.diagonal().dot(self.base_transition) / norm
+
+    @property
+    def half_directed_class_recall(self):
+        norm = self.base_confusion.sum(0)
+        norm[norm == 0] = 1
+        return self.confusion.diagonal().dot(self.base_transition) / norm
+
+    @property
+    def half_directed_class_f1(self):
+        norm = self.half_directed_class_precision + self.half_directed_class_recall
+        norm[norm == 0] = 1
+        return 2 * self.half_directed_class_precision * self.half_directed_class_recall / norm
+
+    @property
+    def half_directed_macro_precision(self):
+        return numpy.sum(self.half_directed_class_precision * self.base_mask) / self.base_mask.sum()
+
+    @property
+    def half_directed_macro_recall(self):
+        return numpy.sum(self.half_directed_class_recall * self.base_mask) / self.base_mask.sum()
+
+    @property
+    def half_directed_macro_f1(self):
+        return numpy.sum(self.half_directed_class_f1 * self.base_mask) / self.base_mask.sum()
+
+    #################
+    # Micro metrics #
+    #################
+
+    @property
+    def directed_micro_precision(self):
+        norm = numpy.sum(self.confusion.sum(1) * self.mask)
+        return 0 if norm == 0 else numpy.sum(self.confusion.diagonal() * self.mask) / norm
+
+    @property
+    def directed_micro_recall(self):
+        norm = numpy.sum(self.confusion.sum(0) * self.mask)
+        return 0 if norm == 0 else numpy.sum(self.confusion.diagonal() * self.mask) / norm
+
+    @property
+    def directed_micro_f1(self):
+        norm = self.directed_micro_precision + self.directed_micro_recall
+        return 0 if norm == 0 else 2 * (self.directed_micro_precision * self.directed_micro_recall) / norm
+
+    @property
+    def half_directed_micro_precision(self):
+        norm = numpy.sum(self.confusion.sum(1) * self.mask)
+        return 0 if norm == 0 else numpy.sum(self.confusion.diagonal() * self.mask) / norm
+
+    @property
+    def half_directed_micro_recall(self):
+        norm = numpy.sum(self.confusion.sum(0) * self.mask)
+        return 0 if norm == 0 else numpy.sum(self.confusion.diagonal() * self.mask) / norm
+
+    @property
+    def half_directed_micro_f1(self):
+        norm = self.half_directed_micro_precision + self.half_directed_micro_recall
+        return 0 if norm == 0 else 2 * (self.half_directed_micro_precision * self.half_directed_micro_recall) / norm
+
+    @property
+    def undirected_micro_precision(self):
+        norm = numpy.sum(self.base_confusion.sum(1) * self.base_mask)
+        return 0 if norm == 0 else numpy.sum(self.base_confusion.diagonal() * self.base_mask) / norm
+
+    @property
+    def undirected_micro_recall(self):
+        norm = numpy.sum(self.base_confusion.sum(0) * self.base_mask)
+        return 0 if norm == 0 else numpy.sum(self.base_confusion.diagonal() * self.base_mask) / norm
+
+    @property
+    def undirected_micro_f1(self):
+        norm = self.undirected_micro_precision + self.undirected_micro_recall
+        return 0 if norm == 0 else 2 * (self.undirected_micro_precision * self.undirected_micro_recall) / norm
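A short sketch of how the new scorer is driven, mirroring the calls made in Trainer.evaluate (the random logits are placeholders for model outputs, and d is the dictionary from the sketch above):

    import torch
    from fsre.metrics import Metrics

    scorer = Metrics(d)                    # d: a RelationDictionary with an unknown relation at id 0
    logits = torch.randn(8, len(d))        # a batch of 8 predictions
    target = torch.randint(len(d), (8,))
    scorer.update(logits, target)
    print(scorer.summary)                              # {"accuracy": "...", "nll": "..."}
    print(scorer.all["half_directed_macro_f1"])        # the official SemEval metric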
diff --git a/fsre/train.py b/fsre/train.py
@@ -1,6 +1,8 @@
 import sys
 import os
 import math
+import time
+import contextlib
 import multiprocessing
 import signal
 import logging
@@ -19,15 +21,20 @@ class Trainer:
     Train a model.
 
     Config:
+        Model: the model class to use for training
         batch_size: the number of samples in the batch of data loaded
-        true_batch_size: the actual number of sample in a batch, the number of
-            sample seen before a backward (must be a multiple of batch_size)
+        bert_model: the model of transformer to use
         dataset_name: name of the dataset to load
         deterministic: run in deterministic mode
+        early_stopping_patience: how many epochs to train after the best validation score has been reached
         learning_rate: learning rate
        max_epoch: maximum number of epoch
+        no_initial_validation: do not run evaluation on the valid dataset before the first epoch
         seed: the seed for the random number generator
         sort_per_shuffle_bucket: the number of sort buckets in a shuffle bucket
+        test_output: path to a file where the test predictions will be written
+        true_batch_size: the actual number of samples in a batch, the number of samples seen before a backward (must be a multiple of batch_size)
+        validation_metric: metric used for early stopping
     """
 
     def __init__(self, config, logdir, state_dicts=None):
@@ -83,7 +90,11 @@
             logger.info(f"GPU{i} {gp.name} {gp.total_memory} {gp.major}.{gp.minor}")
 
     def info(self):
-        print(f"logdir is \033[1m\033[33m{self.logdir}\033[0m")
+        if self.logdir is None:
+            print(f"logdir is \033[1m\033[31mnot set\033[0m, log messages will be discarded")
+        else:
+            print(f"logdir is \033[1m\033[33m{self.logdir}\033[0m")
+
         self.environment_check()
         self.detect_gpus()
         print("")
@@ -103,7 +114,7 @@
         suffix = ""
         if self.state_dicts:
             suffix = time.strftime("%FT%H:%M:%S")
-        fsre.utils.save_patch(self.logdir / "patch{suffix}")
+        fsre.utils.save_patch(self.logdir / f"patch{suffix}")
 
     def initialize_rng(self):
         if self.state_dicts:
@@ -120,7 +131,7 @@
 
     def prepare_dataset(self):
         data_dir = fsre.utils.DATA_PATH / self.config.dataset_name / self.config.bert_model
-        self.relation_dictionary = fsre.data.RelationDictionary(data_dir / "relations")
+        self.relation_dictionary = fsre.data.RelationDictionary(path=data_dir / "relations")
         self.tokenizer = transformers.BertTokenizer.from_pretrained(data_dir / "tokenizer")
 
         self.dataset = {}
@@ -178,6 +189,22 @@
 
         signal.signal(signal.SIGINT, handler)
 
+    def eval_context(self, dataset):
+        self.test_output_file = None
+        if dataset == "test" and self.config.get("test_output"):
+            self.test_output_file = open(self.config.test_output, 'w')
+            return self.test_output_file
+        return contextlib.nullcontext()
+
+    def eval_handle_predictions(self, ids, predictions):
+        if self.test_output_file is None:
+            return
+        ids = ids.tolist()
+        predictions = predictions.argmax(1).tolist()
+        for id, prediction in zip(ids, predictions):
+            prediction = self.relation_dictionary.decode(prediction)
+            self.test_output_file.write(f"{id}\t{prediction}\n")
+
     def evaluate(self, dataset):
         loop = tqdm.tqdm(
                 iterable=self.iterator[dataset](),
@@ -188,36 +215,45 @@
                 leave=False)
 
         self.model.eval()
-        with torch.no_grad():
-            scorer = fsre.metrics.Metrics(len(self.relation_dictionary))
+        output_prediction_to_file = True
+        with torch.no_grad(), self.eval_context(dataset):
+            has_target = False
+            scorer = fsre.metrics.Metrics(self.relation_dictionary)
             correct_prediction = 0
             for batch in loop:
                 batch = {key: value.to(self.device) for key, value in batch.items()}
                 # Pop the target to ensure it's not used by the model
-                target = batch.pop("relation")
+                if "relation" in batch:
+                    target = batch.pop("relation")
+                    has_target = True
                 predictions = self.model(batch)
-                scorer.update(predictions, target)
-                loop.set_postfix(**scorer.summary, refresh=False)
-
-            print(f"Epoch {self.epoch} {dataset:5} accuracy: {scorer.accuracy*100:8.4f}% F1: {scorer.f1*100:8.4f}% (P: {scorer.precision*100:8.4f}% R: {scorer.recall*100:8.4f}%) NLL: {scorer.nll:8.4f}")
-            logger.info(f"epoch {self.epoch} {dataset} accuracy {scorer.accuracy} F1 {scorer.f1} precision {scorer.precision} recall {scorer.recall} NLL {scorer.nll}")
-            return getattr(scorer, self.config.validation_metric)
-
-    def save(self, path, full):
-        state_dicts = {"model": self.model.state_dict()}
-        if full:
-            state_dicts.update({
-                "logdir": self.logdir,
-                "optimizer": self.optimizer.state_dict(),
-                "train_rng": self.dataset["train"].state_dict(),
-                "torch_rng": torch.random.get_rng_state(),
-                "epoch": self.epoch,
-                "best_epoch": self.best_epoch,
-                "best_eval": self.best_eval,
-            })
-            if torch.cuda.is_available():
-                state_dicts["cuda_rng"] = torch.cuda.random.get_rng_state_all()
+                self.eval_handle_predictions(batch["id"], predictions)
+                if has_target:
+                    scorer.update(predictions, target)
+                    loop.set_postfix(**scorer.summary, refresh=False)
+
+            if has_target:
+                print(f"Epoch {self.epoch} {dataset:5} accuracy: {scorer.accuracy*100:8.4f}% Half-directed Macro F1: {scorer.half_directed_macro_f1*100:8.4f}% (P: {scorer.half_directed_macro_precision*100:8.4f}% R: {scorer.half_directed_macro_recall*100:8.4f}%) NLL: {scorer.nll:8.4f}")
+                for metric, value in scorer.all.items():
+                    logger.info(f"epoch {self.epoch} {dataset} {metric} {value}")
+                return getattr(scorer, self.config.validation_metric)
+            return None
+
+    def save(self, path):
+        state_dicts = {
+                "logdir": self.logdir,
+                "config": self.config,
+                "model": self.model.state_dict(),
+                "optimizer": self.optimizer.state_dict(),
+                "train_rng": self.dataset["train"].state_dict(),
+                "torch_rng": torch.random.get_rng_state(),
+                "epoch": self.epoch,
+                "best_epoch": self.best_epoch,
+                "best_eval": self.best_eval,
+                }
+        if torch.cuda.is_available():
+            state_dicts["cuda_rng"] = torch.cuda.random.get_rng_state_all()
         torch.save(state_dicts, path)
 
     def train(self):
@@ -275,14 +311,14 @@
             print(f"Epoch {self.epoch} train mean loss: {total_loss / total_sample:8.4f}")
             logger.info(f"epoch {self.epoch} train mean_loss {total_loss / total_sample}")
 
-            self.save(self.logdir / "checkpoint.new", full=True)
+            self.save(self.logdir / "checkpoint.new")
             os.rename(self.logdir / "checkpoint.new", self.logdir / "checkpoint")
 
             candidate = self.evaluate("valid")
             if candidate > self.best_eval:
                 self.best_epoch = self.epoch
                 self.best_eval = candidate
-                self.save(best_path, full=False)
+                self.save(best_path)
                 logger.info(f"Model saved to {best_path}")
             elif self.epoch - self.best_epoch > self.config.get("early_stopping_patience", self.config.max_epoch):
                 break
@@ -302,10 +338,10 @@ if __name__ == "__main__":
 
     state_dicts = None
     if config.get("load"):
-        state_dicts = torch.load(config.load, strict=False)
-
-    if config.get("load"):
+        state_dicts = torch.load(config.load)
         logdir = state_dicts["logdir"]
+        if config.get("reuse_config"):
+            config = state_dicts["config"]
     else:
         logdir = fsre.utils.logdir_name("FSRE")
         assert(not logdir.exists())