gbure

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

commit 984e8ea40d1b0ce5788af392cfa5b9b67b6f4ea4
Author: Étienne Simon <esimon@esimon.eu>
Date:   Thu, 14 Nov 2019 18:23:16 +0100

Initial commit

Diffstat:
ALICENSE | 5+++++
AREADME | 23+++++++++++++++++++++++
Afsre/__init__.py | 3+++
Afsre/config/soares_supervised_kbp37.py | 21+++++++++++++++++++++
Afsre/config/soares_supervised_semeval.py | 21+++++++++++++++++++++
Afsre/data/__init__.py | 2++
Afsre/data/dataset.py | 116+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Afsre/data/prepare_fewrel.py | 61+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Afsre/data/prepare_kbp37.py | 71+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Afsre/data/prepare_semeval.py | 96+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Afsre/data/relation_dictionary.py | 31+++++++++++++++++++++++++++++++
Afsre/metrics.py | 84+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Afsre/model/mtb_classifier.py | 58++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Afsre/model/mtb_supervised.py | 39+++++++++++++++++++++++++++++++++++++++
Afsre/train.py | 317+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Afsre/utils.py | 143+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Arequirements.txt | 4++++
17 files changed, 1095 insertions(+), 0 deletions(-)

diff --git a/LICENSE b/LICENSE @@ -0,0 +1,5 @@ +Copyright 2019 Étienne Simon + +Permission to use, copy, modify, and/or distribute this software for any purpose with or without fee is hereby granted, provided that the above copyright notice and this permission notice appear in all copies. + +THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. diff --git a/README b/README @@ -0,0 +1,23 @@ +Reproduction of Matching the Blanks: Distributional Similarity for Relation Learning by Livio Baldini Soares, Nicholas FitzGerald, Jeffrey Ling, and Tom Kwiatkowski. + +This repository currently contains the supervised model for the Semeval and KBP37 datasets, the unsupervised MTB model and the FewRel dataset will be added latter on. +In order to reproduce the results, you first need to manually download and extract the datasets (e.g. in /tmp), then do the following: + +$ export DATA_PATH="/tmp" # the directory containing the extracted datasets +$ export LOG_PATH="/tmp" # the directory where the models will be saved +$ git clone https://gitlab.lip6.fr/esimon/few-shot-relation-extraction +$ cd few-shot-relation-extraction +$ python -m fsre.data.prepare_semeval bert-base-cased +$ python -m fsre.train fsre/config/soares_supervised_semeval.py +$ python -m fsre.data.prepare_kbp37 bert-base-cased +$ python -m fsre.train fsre/config/soares_supervised_kbp37.py + +I must be missing something since I'm still far from the results reported in the paper: + +On Semeval: + valid macro F1: paper 82.1 vs us 81.5 (accuracy 84.7) + test macro F1: paper 89.2 vs us 81.8 (accuracy 84.8) + +On KBP37: + valid macro F1: paper 70.0 vs us 65.7 (accuracy 65.5) + test macro F1: paper 68.3 vs us 63.2 (accuracy 64.2) diff --git a/fsre/__init__.py b/fsre/__init__.py @@ -0,0 +1,3 @@ +import fsre.data +import fsre.utils +import fsre.metrics diff --git a/fsre/config/soares_supervised_kbp37.py b/fsre/config/soares_supervised_kbp37.py @@ -0,0 +1,21 @@ +from fsre.model.mtb_supervised import Model + + +dataset_name = "KBP37" + +# From Table 1 +bert_model = "bert-large-cased" +post_transformer_layer = "linear" +max_epoch = 10 +learning_rate = 3e-5 +true_batch_size = 64 + +# Guessed +validation_metric = "f1" +early_stopping_patience = 2 + +# Implementation details +seed = 0 +batch_size = 2 +batch_per_sort_bucket = 8 +sort_per_shuffle_bucket = 8 diff --git a/fsre/config/soares_supervised_semeval.py b/fsre/config/soares_supervised_semeval.py @@ -0,0 +1,21 @@ +from fsre.model.mtb_supervised import Model + + +dataset_name = "SemEval2010_task8_all_data" + +# From Table 1 +bert_model = "bert-large-cased" +post_transformer_layer = "layer_norm" +max_epoch = 10 +learning_rate = 3e-5 +true_batch_size = 64 + +# Guessed +validation_metric = "f1" +early_stopping_patience = 2 + +# Implementation details +seed = 0 +batch_size = 2 +batch_per_sort_bucket = 8 +sort_per_shuffle_bucket = 8 diff --git a/fsre/data/__init__.py b/fsre/data/__init__.py @@ -0,0 +1,2 @@ +from fsre.data.dataset import RelationExtractionDataset +from fsre.data.relation_dictionary import RelationDictionary diff --git a/fsre/data/dataset.py b/fsre/data/dataset.py @@ -0,0 +1,116 @@ +import math +import numpy +import torch + + +class RelationExtractionDataset(torch.utils.data.IterableDataset): + """ + Read a preprocessed Relation Extraction dataset from a .npy file. + + When generating data, we first read a large shuffle bucket which is + shuffled (unless the dataset is for evaluation). This shuffle_bucket + is then cut down into several sort buckets, each of them is sorted + so that sentences of similar length end up next to each other. The + sort buckets are then cut down into batches which pad the sentences. + However this class generate samples, the proper batching should be + done by a DataLoader. + A preprocessed dataset can be created from the fsre.data.prepare_* + modules. + + Config: + batch_per_sort_bucket: the number of batches in a sort bucket + batch_size: the number of samples in a batch + seed: the seed for the random number generator + sort_per_shuffle_bucket: the number of sort buckets in a shuffle bucket + """ + + def __init__(self, config, path, pad, evaluation, rng=None): + """ + Initialize a Relation Extraction dataset and load the data in RAM. + + Args: + config: global config object + path: path to the dataset to load, this should be a .npy file + pad: the value used to pad text in a batch + evaluation: whether this dataset is an evaluation one (no need to sort then) + rng: the random number generator to use for shuffling + """ + + super().__init__() + + self.config = config + self.pad = pad + self.evaluation = evaluation + + self.data = numpy.load(path, allow_pickle=True) + if not evaluation: + self.rng = rng if rng is not None else numpy.random.RandomState(config.seed) + + self.batch_size = config.batch_size + self.sort_bucket_size = config.batch_size * config.batch_per_sort_bucket + self.shuffle_bucket_size = self.sort_bucket_size * config.sort_per_shuffle_bucket + + def __len__(self): + return len(self.data) + + def pad_text(self, text, diff, pad): + """ Append diff tokens pad to the end of text """ + text = torch.tensor(text, dtype=torch.int64) + padding = text.new_full((diff,), pad) + return torch.cat((text, padding)) + + def iter_sample(self, samples): + """ Generate samples from a batch """ + lengths = [sample[1].shape[0] for sample in samples] + max_len = max(lengths) + + for sample, length in zip(samples, lengths): + ds = dict(zip(["id", "text", "e1_pos", "e2_pos", "relation"], sample)) + ds["length"] = length + ds["mask"] = self.pad_text(numpy.ones_like(ds["text"]), max_len - length, 0) + ds["text"] = self.pad_text(ds["text"], max_len - length, self.pad) + yield ds + + def iter_batch(self, sort_bucket): + """ Generate batches from a sort_bucket """ + for batch_start in range(0, len(sort_bucket), self.batch_size): + batch_end = batch_start + self.batch_size + batch_end = min(batch_end, len(sort_bucket)) + + yield from self.iter_sample([self.data[i] for i in sort_bucket[batch_start:batch_end]]) + + def iter_shuffle_bucket(self, shuffle_bucket): + """ Generate sort_buckets from a shuffle_bucket """ + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + start = 0 + end = len(shuffle_bucket) + else: + work_size = len(shuffle_bucket) / self.sort_bucket_size + per_worker = int(math.ceil(work_size / float(worker_info.num_workers))) * self.sort_bucket_size + start = worker_info.id * per_worker + end = min(start + per_worker, len(shuffle_bucket)) + + for sort_start in range(start, end, self.sort_bucket_size): + sort_end = sort_start + self.sort_bucket_size + sort_end = min(sort_end, end) + + sort_bucket = shuffle_bucket[sort_start:sort_end] + + sort_bucket.sort(key=lambda x: self.data[x, 1].shape[0]) + yield from self.iter_batch(sort_bucket) + + def __iter__(self): + """ Generate shuffle_buckets from the dataset """ + for shuffle_start in range(0, len(self), self.shuffle_bucket_size): + shuffle_end = shuffle_start + self.shuffle_bucket_size + shuffle_end = min(shuffle_end, len(self)) + + shuffle_bucket = list(range(shuffle_start, shuffle_end)) + + if not self.evaluation: + self.rng.shuffle(shuffle_bucket) + yield from self.iter_shuffle_bucket(shuffle_bucket) + + def state_dict(self): + return {"rng": self.rng} diff --git a/fsre/data/prepare_fewrel.py b/fsre/data/prepare_fewrel.py @@ -0,0 +1,61 @@ +import argparse +import numpy +import transformers +import tqdm + +from fsre.utils import DATA_PATH +from fsre.data.relation_dictionary import RelationDictionary + + +def load_fewrel_dataset(path, tokenizer, relation_dictionary): + pass + + +def prepare_fewrel(args): + rng = numpy.random.RandomState(args.seed) + fewrel_path = DATA_PATH / "FewRel" + output_path = fewrel_path / args.tokenizer + + if not output_path.is_dir(): + output_path.mkdir() + + relation_dictionary = RelationDictionary() + + tokenizer = transformers.BertTokenizer.from_pretrained(args.tokenizer) + tokenizer.add_special_tokens({"additional_special_tokens": ["<e1>", "</e1>", "<e2>", "</e2>"]}) + tokenizer_path = output_path / "tokenizer" + if not tokenizer_path.is_dir(): + tokenizer_path.mkdir() + tokenizer.save_pretrained(tokenizer_path) + + train = load_fewrel_dataset( + fewrel_path / "train.json", + tokenizer, + relation_dictionary) + rng.shuffle(train) + + valid = load_fewrel_dataset( + fewrel_path / "val.json", + tokenizer, + relation_dictionary) + + numpy.save(output_path / "train.npy", numpy.array(train)) + numpy.save(output_path / "valid.npy", numpy.array(valid)) + + relation_dictionary.save(output_path / "relations") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Prepare the FewRel dataset.") + parser.add_argument("tokenizer", + type=str, + nargs='?', + default="bert-large-uncased", + help="Name of the transformers tokenizer") + parser.add_argument("-s", "--seed", + type=int, + default=0, + help="Seed of the RNG for shuffling the dataset") + + prepare_fewrel(parser.parse_args()) diff --git a/fsre/data/prepare_kbp37.py b/fsre/data/prepare_kbp37.py @@ -0,0 +1,71 @@ +import argparse +import numpy +import transformers +import tqdm + +from fsre.utils import DATA_PATH +from fsre.data.relation_dictionary import RelationDictionary +from fsre.data.prepare_semeval import load_semeval_dataset + +TRAIN_SIZE = 15917 +VALID_SIZE = 1724 +TEST_SIZE = 3405 + + +def prepare_kbp37(args): + rng = numpy.random.RandomState(args.seed) + kbp37_path = DATA_PATH / "KBP37" + output_path = kbp37_path / args.tokenizer + + if not output_path.is_dir(): + output_path.mkdir() + + relation_dictionary = RelationDictionary() + + tokenizer = transformers.BertTokenizer.from_pretrained(args.tokenizer) + tokenizer.add_special_tokens({"additional_special_tokens": ["<e1>", "</e1>", "<e2>", "</e2>"]}) + tokenizer_path = output_path / "tokenizer" + if not tokenizer_path.is_dir(): + tokenizer_path.mkdir() + tokenizer.save_pretrained(tokenizer_path) + + train = load_semeval_dataset( + kbp37_path / "train.txt", + tokenizer, + relation_dictionary, + TRAIN_SIZE) + rng.shuffle(train) + + valid = load_semeval_dataset( + kbp37_path / "dev.txt", + tokenizer, + relation_dictionary, + VALID_SIZE) + + test = load_semeval_dataset( + kbp37_path / "test.txt", + tokenizer, + relation_dictionary, + TEST_SIZE) + + numpy.save(output_path / "train.npy", numpy.array(train)) + numpy.save(output_path / "valid.npy", numpy.array(valid)) + numpy.save(output_path / "test.npy", numpy.array(test)) + + relation_dictionary.save(output_path / "relations") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Prepare the KBP37 dataset.") + parser.add_argument("tokenizer", + type=str, + nargs='?', + default="bert-large-uncased", + help="Name of the transformers tokenizer") + parser.add_argument("-s", "--seed", + type=int, + default=0, + help="Seed of the RNG for shuffling the dataset") + + prepare_kbp37(parser.parse_args()) diff --git a/fsre/data/prepare_semeval.py b/fsre/data/prepare_semeval.py @@ -0,0 +1,96 @@ +import argparse +import numpy +import transformers +import tqdm + +from fsre.utils import DATA_PATH +from fsre.data.relation_dictionary import RelationDictionary + +TRAIN_SIZE = 8000 +TEST_SIZE = 2717 + + +def load_semeval_dataset(path, tokenizer, relation_dictionary, size): + be1_id = tokenizer.added_tokens_encoder["<e1>"] + be2_id = tokenizer.added_tokens_encoder["<e2>"] + + dataset = [] + with open(path) as infile: + for _ in tqdm.trange(size): + idtext_line = infile.readline() + relation_line = infile.readline() + if not (idtext_line and relation_line): + break + + id, raw_text = idtext_line.rstrip().split('\t') + id = int(id) + raw_text = raw_text[1:-1] # remove quotes around text + text = tokenizer.encode(raw_text, add_special_tokens=True) + e1_pos = text.index(be1_id) + e2_pos = text.index(be2_id) + text = numpy.array(text, dtype=numpy.int32) + relation = relation_dictionary.encode(relation_line.rstrip()) + + dataset.append([id, text, e1_pos, e2_pos, relation]) + infile.readline() # Ignore Comment line + infile.readline() # Ignore empty line + return dataset + + +def prepare_semeval(args): + rng = numpy.random.RandomState(args.seed) + semeval_path = DATA_PATH / "SemEval2010_task8_all_data" + output_path = semeval_path / args.tokenizer + + if not output_path.is_dir(): + output_path.mkdir() + + relation_dictionary = RelationDictionary() + + tokenizer = transformers.BertTokenizer.from_pretrained(args.tokenizer) + tokenizer.add_special_tokens({"additional_special_tokens": ["<e1>", "</e1>", "<e2>", "</e2>"]}) + tokenizer_path = output_path / "tokenizer" + if not tokenizer_path.is_dir(): + tokenizer_path.mkdir() + tokenizer.save_pretrained(tokenizer_path) + + dataset = load_semeval_dataset( + semeval_path / "SemEval2010_task8_training" / "TRAIN_FILE.TXT", + tokenizer, + relation_dictionary, + TRAIN_SIZE) + rng.shuffle(dataset) + train = dataset[args.valid_size:] + valid = dataset[:args.valid_size] + + test = load_semeval_dataset( + semeval_path / "SemEval2010_task8_testing_keys" / "TEST_FILE_FULL.TXT", + tokenizer, + relation_dictionary, + TEST_SIZE) + + numpy.save(output_path / "train.npy", numpy.array(train)) + numpy.save(output_path / "valid.npy", numpy.array(valid)) + numpy.save(output_path / "test.npy", numpy.array(test)) + + relation_dictionary.save(output_path / "relations") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Prepare the SemEval 2010 Task 8 dataset.") + parser.add_argument("tokenizer", + type=str, + nargs='?', + default="bert-large-uncased", + help="Name of the transformers tokenizer") + parser.add_argument("-s", "--seed", + type=int, + default=0, + help="Seed of the RNG for shuffling the dataset") + parser.add_argument("-v", "--valid-size", + type=int, + default=1500, + help="Size of the validation set") + + prepare_semeval(parser.parse_args()) diff --git a/fsre/data/relation_dictionary.py b/fsre/data/relation_dictionary.py @@ -0,0 +1,31 @@ +class RelationDictionary: + """ A very simple dictionary to be used for relations. """ + def __init__(self, path=None): + self.encoder = {} + self.decoder = [] + if path is not None: + self.load(path) + + def __len__(self): + return len(self.decoder) + + def encode(self, token): + id = self.encoder.get(token) + if id is not None: + return id + id = len(self.decoder) + self.encoder[token] = id + self.decoder.append(token) + return id + + def decode(self, id): + return self.decoder[id] + + def save(self, path): + with open(path, 'w') as file: + file.writelines(map(lambda x: f"{x}\n", self.decoder)) + + def load(self, path): + with open(path, 'r') as file: + self.decoder = list(map(str.rstrip, file.readlines())) + self.encoder = dict(zip(self.decoder, range(len(self.decoder)))) diff --git a/fsre/metrics.py b/fsre/metrics.py @@ -0,0 +1,84 @@ +import numpy +import torch + + +class Metrics: + """ + Class for computing metrics. + + Five metrics are computed: + - Accuracy + - Macro F1 + - Macro Precision + - Macro Recall + - Negative Log Likelihood + """ + + def __init__(self, nclass): + """ + Initialize all metrics. + + Args: + nclass: number of relations + """ + + self.n = nclass + self.crossentropy = torch.nn.CrossEntropyLoss(reduction="sum") + + self.size = 0 + self.ce_sum = 0 + self.confusion = numpy.zeros((nclass, nclass), numpy.int64) + + def update(self, predictions, target): + """ + Update metrics with a batch of predictions and corresponding targets. + + Args: + predictions: the predicted logits (before softmax) + target: the gold relations + """ + + self.size += predictions.shape[0] + self.ce_sum += self.crossentropy(predictions, target).cpu().item() + + prediction = predictions.argmax(1) + for p, t in zip(prediction.cpu(), target.cpu()): + self.confusion[p.item(), t.item()] += 1 + + @property + def accuracy(self): + return self.confusion.diagonal().sum() / (self.confusion.sum() + 1e-12) + + @property + def class_precision(self): + return self.confusion.diagonal() / (self.confusion.sum(1) + 1e-12) + + @property + def class_recall(self): + return self.confusion.diagonal() / (self.confusion.sum(0) + 1e-12) + + @property + def class_f1(self): + return 2 * self.class_precision * self.class_recall / (self.class_precision + self.class_recall + 1e-12) + + @property + def precision(self): + return self.class_precision.mean() + + @property + def recall(self): + return self.class_recall.mean() + + @property + def f1(self): + return self.class_f1.mean() + + @property + def nll(self): + return self.ce_sum / self.size + + @property + def summary(self): + return {"accuracy": f"{self.accuracy*100:.2f}", + "f1": f"{self.f1*100:.2f}", + "nll": f"{self.nll:.2f}"} diff --git a/fsre/model/mtb_classifier.py b/fsre/model/mtb_classifier.py @@ -0,0 +1,58 @@ +import torch +import torch.nn as nn +import transformers + + +class Classifier(nn.Module): + """ + Transformer classifier from Soares et al. + + Correspond to the left part of each subfigure of Figure 2 (Deep Transformer Encoder and the green layer above). + + Config: + bert_model: The version of BERT to use (e.g. bert-large-uncased). + post_transformer_layer: The transformation applied after BERT (must be "linear" or "layer_norm") + """ + + def __init__(self, config, tokenizer): + """ + Instantiate a Soares et al. classifier. + + Args: + config: global config object + tokenizer: tokenizer used to create the vocabulary + """ + + super().__init__() + + self.config = config + self.tokenizer = tokenizer + + if self.config.get("load"): + bert_config = transformers.BertConfig.from_pretrained(self.config.bert_model) + bert_config.vocab_size = len(tokenizer) + self.bert = transformers.BertModel(bert_config) + else: + self.bert = transformers.BertModel.from_pretrained(self.config.bert_model) + self.bert.resize_token_embeddings(len(tokenizer)) + + if self.config.post_transformer_layer == "linear": + self.post_transformer = nn.Linear( + in_features=self.output_size, + out_features=self.output_size) + elif self.config.post_transformer_layer == "layer_norm": + self.post_transformer = torch.nn.LayerNorm(self.output_size) + else: + assert(False) + + @property + def output_size(self): + return self.bert.config.hidden_size * 2 + + def forward(self, inputs): + bert_out = self.bert(inputs["text"], attention_mask=inputs["mask"])[0] + batch_ids = torch.arange(bert_out.shape[0], device=bert_out.device, dtype=torch.int64) + e1_out = bert_out[batch_ids, inputs["e1_pos"]] + e2_out = bert_out[batch_ids, inputs["e2_pos"]] + sentence = torch.cat((e1_out, e2_out), dim=1) + return self.post_transformer(sentence) diff --git a/fsre/model/mtb_supervised.py b/fsre/model/mtb_supervised.py @@ -0,0 +1,39 @@ +import torch +import torch.nn as nn +import transformers + +from fsre.model.mtb_classifier import Classifier + + +class Model(nn.Module): + """ + Supervised model from Soares et al. + + Correspond to the left subfigure of Figure 2. + """ + + def __init__(self, config, tokenizer, relation_dictionary): + """ + Instantiate a Soares et al. supervised model. + + Args: + config: global config object + tokenizer: tokenizer used to create the vocabulary + relation_dictionary: dictionary of all relations + """ + + super().__init__() + + self.config = config + self.tokenizer = tokenizer + self.relation_dictionary = relation_dictionary + + self.classifier = Classifier(config, tokenizer) + self.relation_classifier = nn.Linear( + in_features=self.classifier.output_size, + out_features=len(relation_dictionary), + bias=False) + + def forward(self, inputs): + latent = self.classifier(inputs) + return self.relation_classifier(latent) diff --git a/fsre/train.py b/fsre/train.py @@ -0,0 +1,317 @@ +import sys +import os +import math +import multiprocessing +import signal +import logging + +import tqdm +import torch +import transformers + +import fsre + +logger = logging.getLogger(__name__) + + +class Trainer: + """ + Train a model. + + Config: + batch_size: the number of samples in the batch of data loaded + true_batch_size: the actual number of sample in a batch, the number of + sample seen before a backward (must be a multiple of batch_size) + dataset_name: name of the dataset to load + deterministic: run in deterministic mode + learning_rate: learning rate + max_epoch: maximum number of epoch + seed: the seed for the random number generator + sort_per_shuffle_bucket: the number of sort buckets in a shuffle bucket + """ + + def __init__(self, config, logdir, state_dicts=None): + self.config = config + self.logdir = logdir + self.state_dicts = state_dicts + + def run(self): + self.info() + self.log_patch() + self.initialize_rng() + self.prepare_dataset() + self.build_model() + self.count_parameters() + self.setup_optimizer() + self.hook_signals() + self.train() + + def environment_check(self): + python_version = '.'.join(map(str, sys.version_info[:3])) + torch_version = torch.__version__ + cuda_available = torch.cuda.is_available() + + logger.info(f"python version {python_version}") + logger.info(f"torch version {torch_version}") + logger.info(f"cuda available {cuda_available}") + + def problem(str): + return f"\033[1m\033[31m{str}\033[0m" + + if sys.version_info < (3, 7): + python_version = problem(python_version) + if list(map(int, torch_version.split('.'))) < [1, 3]: + torch_version = problem(torch_version) + if not cuda_available: + cuda_available = problem(cuda_available) + + print(f"python version: {python_version}, torch version: {torch_version}, cuda available: {cuda_available}") + + def detect_gpus(self): + count = torch.cuda.device_count() + + if count == 0: + print(f"\033[1m\033[31mNo GPU available\033[0m") + logger.warning("no GPU available") + self.device = torch.device("cpu") + else: + self.device = torch.device("cuda:0") + + for i in range(count): + gp = torch.cuda.get_device_properties(i) + print(f"GPU{i}: \033[33m{gp.name}\033[0m (Mem: {gp.total_memory/2**30:.2f}GiB CC: {gp.major}.{gp.minor})") + logger.info(f"GPU{i} {gp.name} {gp.total_memory} {gp.major}.{gp.minor}") + + def info(self): + print(f"logdir is \033[1m\033[33m{self.logdir}\033[0m") + self.environment_check() + self.detect_gpus() + print("") + + print("\033[1m\033[33mConfiguration\033[0m") + fsre.utils.print_dict(self.config) + fsre.utils.log_dict(logging.getLogger("config"), self.config) + print("") + + def log_patch(self): + version = fsre.utils.get_repo_version() + logger.info(f"repository_version {version}") + if version == "release": + print(f"\033[41mRelease version\033[0m\n") + elif version.endswith('+'): + print(f"\033[31mUncommited changes detected, saving patch to logdir.\033[0m\n") + suffix = "" + if self.state_dicts: + suffix = time.strftime("%FT%H:%M:%S") + fsre.utils.save_patch(self.logdir / "patch{suffix}") + + def initialize_rng(self): + if self.state_dicts: + torch.random.set_rng_state(self.state_dicts["torch_rng"]) + assert(("cuda_rng" in self.state_dicts) == torch.cuda.is_available()) + if "cuda_rng" in self.state_dicts: + torch.cuda.random.set_rng_state_all(self.state_dicts["cuda_rng"]) + else: + torch.manual_seed(self.config.seed) + + if self.config.get("deterministic"): + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + def prepare_dataset(self): + data_dir = fsre.utils.DATA_PATH / self.config.dataset_name / self.config.bert_model + self.relation_dictionary = fsre.data.RelationDictionary(data_dir / "relations") + self.tokenizer = transformers.BertTokenizer.from_pretrained(data_dir / "tokenizer") + + self.dataset = {} + self.iterator = {} + + for dataset in ["train", "valid", "test"]: + if dataset == "train": + kwargs = {"rng": self.state_dicts["train_rng"]} if self.state_dicts else {} + + self.dataset[dataset] = fsre.data.RelationExtractionDataset( + self.config, + data_dir / f"{dataset}.npy", + pad=self.tokenizer.pad_token_id, + evaluation=(dataset != "train"), + **kwargs) + + self.iterator[dataset] = lambda dataset=dataset: torch.utils.data.DataLoader( + dataset=self.dataset[dataset], + batch_size=self.config.batch_size, + num_workers=self.config.sort_per_shuffle_bucket, + pin_memory=(self.device.type == "cuda")) + + def build_model(self): + self.model = self.config.Model(self.config, self.tokenizer, self.relation_dictionary) + + if self.state_dicts: + self.model.load_state_dict(self.state_dicts["model"]) + + self.loss = torch.nn.CrossEntropyLoss() + if self.device.type == "cuda": + self.model.to(self.device) + self.loss.to(self.device) + + def count_parameters(self): + total = 0 + for parameter in self.model.parameters(): + total += parameter.shape.numel() + print(f"\033[33mNumber of parameters: {total:,}\033[0m") + + def setup_optimizer(self): + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.learning_rate) + if self.state_dicts: + self.optimizer.load_state_dict(self.state_dicts["optimizer"]) + + def hook_signals(self): + self.interrupted = False + + def handler(sig, frame): + if multiprocessing.current_process().name != "MainProcess": + return + + print("\n\033[31mInterrupted, training will stop at the end of this epoch.\n\033[1mNEXT ^C WILL KILL THE PROCESS!\033[0m\n", file=sys.stderr) + self.interrupted = True + signal.signal(signal.SIGINT, signal.SIG_DFL) + + signal.signal(signal.SIGINT, handler) + + def evaluate(self, dataset): + loop = tqdm.tqdm( + iterable=self.iterator[dataset](), + desc=f"Epoch {self.epoch:2} {dataset:5}", + unit="samples", + unit_scale=self.config.batch_size, + total=math.ceil(len(self.dataset[dataset]) / self.config.batch_size), + leave=False) + + self.model.eval() + with torch.no_grad(): + scorer = fsre.metrics.Metrics(len(self.relation_dictionary)) + correct_prediction = 0 + for batch in loop: + batch = {key: value.to(self.device) for key, value in batch.items()} + + # Pop the target to ensure it's not used by the model + target = batch.pop("relation") + predictions = self.model(batch) + scorer.update(predictions, target) + loop.set_postfix(**scorer.summary, refresh=False) + + print(f"Epoch {self.epoch} {dataset:5} accuracy: {scorer.accuracy*100:8.4f}% F1: {scorer.f1*100:8.4f}% (P: {scorer.precision*100:8.4f}% R: {scorer.recall*100:8.4f}%) NLL: {scorer.nll:8.4f}") + logger.info(f"epoch {self.epoch} {dataset} accuracy {scorer.accuracy} F1 {scorer.f1} precision {scorer.precision} recall {scorer.recall} NLL {scorer.nll}") + return getattr(scorer, self.config.validation_metric) + + def save(self, path, full): + state_dicts = {"model": self.model.state_dict()} + if full: + state_dicts.update({ + "logdir": self.logdir, + "optimizer": self.optimizer.state_dict(), + "train_rng": self.dataset["train"].state_dict(), + "torch_rng": torch.random.get_rng_state(), + "epoch": self.epoch, + "best_epoch": self.best_epoch, + "best_eval": self.best_eval, + }) + if torch.cuda.is_available(): + state_dicts["cuda_rng"] = torch.cuda.random.get_rng_state_all() + torch.save(state_dicts, path) + + def train(self): + self.epoch = 0 + self.best_epoch = 0 + self.best_eval = -float("inf") + + if self.state_dicts: + self.epoch = self.state_dicts["epoch"] + self.best_epoch = self.state_dicts["best_epoch"] + self.best_eval = self.state_dicts["best_eval"] + + best_path = self.logdir / "best" + batch_per_epoch = int(math.ceil(len(self.dataset["train"]) / self.config.batch_size)) + assert(self.config.true_batch_size % self.config.batch_size == 0) + data_batch_per_true_batch = self.config.true_batch_size // self.config.batch_size + + if not self.config.get("no_initial_validation"): + self.best_eval = self.evaluate("valid") + + for self.epoch in range(self.epoch+1, self.config.max_epoch+1): + if self.interrupted: + break + + loop = tqdm.tqdm( + iterable=self.iterator["train"](), + desc=f"Epoch {self.epoch} train", + unit="samples", + unit_scale=self.config.batch_size, + total=math.ceil(len(self.dataset["train"]) / self.config.batch_size), + leave=False) + + self.model.train() + self.optimizer.zero_grad() + total_loss = 0 + total_sample = 0 + for batch_id, batch in enumerate(loop): + batch = {key: value.to(self.device) for key, value in batch.items()} + + # Pop the target to ensure it's not used by the model + target = batch.pop("relation") + + output = self.model(batch) + loss = self.loss(output, target) + loss.backward() + + total_loss += loss.item() + total_sample += target.shape[0] + loop.set_postfix(loss=f"{total_loss / total_sample:.2f}", refresh=False) + + if batch_id % data_batch_per_true_batch == data_batch_per_true_batch - 1: + self.optimizer.step() + self.optimizer.zero_grad() + + print(f"Epoch {self.epoch} train mean loss: {total_loss / total_sample:8.4f}") + logger.info(f"epoch {self.epoch} train mean_loss {total_loss / total_sample}") + + self.save(self.logdir / "checkpoint.new", full=True) + os.rename(self.logdir / "checkpoint.new", self.logdir / "checkpoint") + + candidate = self.evaluate("valid") + if candidate > self.best_eval: + self.best_epoch = self.epoch + self.best_eval = candidate + self.save(best_path, full=False) + logger.info(f"Model saved to {best_path}") + elif self.epoch - self.best_epoch > self.config.get("early_stopping_patience", self.config.max_epoch): + break + + if self.best_eval != -float("inf") and best_path.exists(): + print(f"Loading best model from {best_path}…", end="", flush=True) + self.model.load_state_dict(torch.load(best_path)["model"]) + print(" done") + logger.info(f"{best_path} loaded for evaluation on test set") + + self.evaluate("test") + + +if __name__ == "__main__": + fsre.utils.fix_transformers_logging_handler() + config = fsre.utils.parse_args() + + state_dicts = None + if config.get("load"): + state_dicts = torch.load(config.load, strict=False) + + if config.get("load"): + logdir = state_dicts["logdir"] + else: + logdir = fsre.utils.logdir_name("FSRE") + assert(not logdir.exists()) + logdir.mkdir() + + logfile = logdir / "log" + logging.basicConfig(format="%(asctime)s\t%(levelname)s:%(name)s:%(message)s", filename=logfile, filemode='a', level=logging.INFO) + + Trainer(config, logdir, state_dicts).run() diff --git a/fsre/utils.py b/fsre/utils.py @@ -0,0 +1,143 @@ +import os +import sys +import types +import importlib +import subprocess +import logging +import time +import hashlib +import pathlib + + +def import_environment(name, cast=str): + try: + globals()[name] = cast(os.environ[name]) + except KeyError: + print(f"ERROR: {name} environment variable is not set.", + file=sys.stderr) + sys.exit(1) + + +import_environment("DATA_PATH", pathlib.Path) +import_environment("LOG_PATH", pathlib.Path) + + +class dotdict(dict): + def __getattr__(self, name): + if name not in self: + raise AttributeError(f"Config key {name} not found") + return dotdict(self[name]) if type(self[name]) is dict else self[name] + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + +def eval_arg(config, arg): + if '=' in arg: + key, value = arg.split('=', maxsplit=1) + value = eval(value, config) + else: + key, value = arg, True + path = key.split('.') + for d in path[:-1]: + config = config[d] + config[path[-1]] = value + config.pop("__builtins__", None) + + +def import_arg(config, arg): + if arg.endswith(".py"): + arg = arg[:-3].replace('/', '.') + module = importlib.import_module(arg) + for key, value in vars(module).items(): + if key not in module.__builtins__ \ + and not key.startswith("__") \ + and not isinstance(value, types.ModuleType): + config[key] = value + + +def parse_args(): + config = {} + for arg in sys.argv[1:]: + if arg.startswith("--"): + eval_arg(config, arg[2:]) + else: + import_arg(config, arg) + return dotdict(config) + + +def map_dict(output, input, depth=0): + for key, value in input.items(): + indent = '\t'*depth + output(f"{indent}{key}:") + if isinstance(value, dict): + output('\n') + map_dict(output, value, depth+1) + else: + output(f" {value}\n") + + +def print_dict(input): + map_dict(lambda x: print(x, end=""), input) + + +def log_dict(logger, input): + class log: + buf = "" + + def __call__(self, x): + self.buf += x + if self.buf.endswith('\n'): + logger.info(self.buf[:-1]) + self.buf = "" + map_dict(log(), input) + + +def get_repo_version(): + repo_dir = pathlib.Path(__file__).parents[0] + result = subprocess.run(["hg", "id", "-i"], + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + encoding="utf-8", + cwd=repo_dir) + + if result.returncode != 0: + return "release" + return result.stdout.rstrip() + + +def experiment_name(name): + args = ' '.join(sys.argv[1:]) + version = get_repo_version() + stime = time.strftime("%FT%H:%M:%S") + return f"{name} {version} {args} {stime}" + + +def logdir_name(name): + subdir = experiment_name(name).replace('/', '_') + if len(subdir) > 255: + sha1 = hashlib.sha1(subdir.encode("utf-8")).digest().hex()[:16] + subdir = subdir[:255-17] + ' ' + sha1 + return LOG_PATH / subdir + + +def fix_transformers_logging_handler(): + """ + The transformers package from huggingface install its own logger on import, + I don't want it. + """ + logger = logging.getLogger() + for handler in logger.handlers: + logger.removeHandler(handler) + + +def save_patch(outpath): + repo_dir = pathlib.Path(__file__).parents[0] + + with open(outpath, "w") as outfile: + result = subprocess.run(["hg", "diff"], + stdout=outfile, + stderr=subprocess.DEVNULL, + encoding="utf-8", + cwd=repo_dir) + + assert(result.returncode == 0) diff --git a/requirements.txt b/requirements.txt @@ -0,0 +1,4 @@ +numpy>=1.16 +torch==1.3.0 +tqdm>=4 +transformers==2.0.0