gbure

Graph-based approaches on unsupervised relation extraction evaluated as a few-shot problem
git clone https://esimon.eu/repos/gbure.git
Log | Files | Refs | README | LICENSE

prepare_semeval.py (3451B)


from typing import Dict, Iterable, Optional, Tuple
import argparse
import pathlib
import random

import tqdm

from gbure.utils import DATA_PATH
import gbure.data.preprocessing as preprocessing
     10 
# Directory, below the project-wide DATA_PATH, holding everything related to this dataset.
DATASET_PATH: pathlib.Path = DATA_PATH / "SemEval 2010 Task 8"
# Name of the sub-directory of DATASET_PATH from which the train and test files are read.
DIRECTORY_NAME: str = "SemEval2010_task8_all_data"
# Archive file name, derived from the directory name.
ARCHIVE_NAME: str = f"{DIRECTORY_NAME}.zip"
# Digest handed to preprocessing.get_zip_data together with the archive name (presumably used to check the download; confirm in gbure.data.preprocessing).
ARCHIVE_SHA512: str = "7ac2d71ba1772105c1f73e4278e4b85cebd9fb95187fb8a153c83215d890f0d2b98929fb2363e8c117d2e2c8e7f9926d7997e38fc57b5730fb00912ff376b66b"
# Location handed to preprocessing.get_zip_data as the download URL of the archive.
DOWNLOAD_URL: str = f"https://esimon.eu/GBURE/{ARCHIVE_NAME}"

# Maximum number of samples read from TRAIN_FILE.TXT, before it is split into train and valid.
TRAIN_VALID_SIZE: int = 8000
# Maximum number of samples read from TEST_FILE_FULL.TXT.
TEST_SIZE: int = 2717
# Relation label passed to preprocessing.serialize_dataset as `unknown_relation`.
UNKNOWN_RELATION: str = "Other"
     20 
     21 
     22 def read_data(path: pathlib.Path, size: int) -> Iterable[Tuple[str, str, str, str, str]]:
     23     """
     24     Read a file in SemEval format and return (text, relation, relation_base, e1, e2) tuples.
     25 
     26     For now, the entities are empty, we could encode them using their surface form or use a true entity linker if we want to use this information.
     27     """
     28     with path.open() as file:
     29         for _ in tqdm.trange(size, desc=f"loading {path.name}"):
     30             idtext_line: str = file.readline()
     31             relation_line: str = file.readline()
     32             file.readline()  # Ignore Comment line
     33             file.readline()  # Ignore empty line
     34 
     35             if not (idtext_line and relation_line):
     36                 break
     37 
     38             id, raw_text = idtext_line.rstrip().split('\t')
     39             text = raw_text[1:-1]  # remove quotes around text
     40             relation = relation_line.rstrip()
     41 
     42             dir_start: int = relation.find('(')
     43             relation_base: str = relation[:dir_start] if dir_start >= 0 else relation
     44 
     45             # TODO handle entities
     46             yield (text, relation, relation_base, "", "")
     47 
     48 
     49 def split_train_valid(data: Iterable[Tuple[str, str, str, str, str]], valid_size: int, seed: int) -> Tuple[Iterable[Tuple[str, str, str, str, str]], Iterable[Tuple[str, str, str, str, str]]]:
     50     data = list(data)
     51     rng = random.Random(seed)
     52     rng.shuffle(data)
     53 
     54     train = data[valid_size:]
     55     valid = data[:valid_size]
     56     return train, valid
     57 
     58 
     59 def read_splits(valid_size: int) -> Dict[str, Iterable[Tuple[str, str, str, str, str]]]:
     60     splits = {}
     61     train_valid = read_data(DATASET_PATH / DIRECTORY_NAME / "SemEval2010_task8_training" / "TRAIN_FILE.TXT", TRAIN_VALID_SIZE)
     62     splits["train"], splits["valid"] = split_train_valid(train_valid, args.valid_size, args.seed)
     63     splits["test"] = read_data(DATASET_PATH / DIRECTORY_NAME / "SemEval2010_task8_testing_keys" / "TEST_FILE_FULL.TXT", TEST_SIZE)
     64     return splits
     65 
     66 
     67 if __name__ == "__main__":
     68     parser: argparse.ArgumentParser = preprocessing.base_argument_parser("Prepare the supervised SemEval 2010 Task 8 dataset.")
     69     parser.add_argument("-v", "--valid-size",
     70                         type=int,
     71                         default=1500,
     72                         help="Size of the validation set")
     73 
     74     args: argparse.Namespace = parser.parse_args()
     75     name: str = preprocessing.dataset_name(args, f"-v{args.valid_size}" if args.valid_size != 1500 else "")
     76 
     77     preprocessing.get_zip_data(DATASET_PATH, DIRECTORY_NAME, ARCHIVE_NAME, ARCHIVE_SHA512, DOWNLOAD_URL)
     78     preprocessing.serialize_dataset(
     79             supervision="supervised",
     80             path=DATASET_PATH / name,
     81             splits=read_splits(args.valid_size),
     82             unknown_relation=UNKNOWN_RELATION,
     83             **preprocessing.args_to_serialize(args))