gbure

Graph-based approaches to unsupervised relation extraction, evaluated as a few-shot problem
git clone https://esimon.eu/repos/gbure.git

prepare_trex.py (3335B)


from typing import Any, Dict, Iterable, List, Optional, Tuple
import argparse
import json
import os
import pathlib
import tqdm

from gbure.utils import DATA_PATH
import gbure.data.preprocessing as preprocessing

DATASET_PATH: pathlib.Path = DATA_PATH / "T-REx"
DIRECTORY_NAME: str = "raw_data"
ARCHIVE_NAME: str = "T-REx.zip"
ARCHIVE_SHA512: str = "30349fa6f01c1928ce15325521ebd05643787220f9a545eb23b280f9209cb1615f4a855b08604f943a1affb4d1f4f17b94f8434698f347a1cb7a0d820fa9de9f"
DOWNLOAD_URL: str = f"https://esimon.eu/GBURE/{ARCHIVE_NAME}"


def process_json_object(data: List[Dict[str, Any]]) -> Iterable[Tuple[str, List[Tuple[str, int, int]]]]:
    """ Process a T-REx json object and return (sentence, list of entities) tuples. """
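    # A minimal, hypothetical input (only the fields read below; real T-REx
    # articles carry more keys), with one article containing one sentence:
    #
    #     [{"text": "Paris is in France.",
    #       "sentences_boundaries": [[0, 19]],
    #       "entities": [{"uri": "http://www.wikidata.org/entity/Q90",
    #                     "boundaries": [0, 5],
    #                     "annotator": "Wikidata_Spotlight_Entity_Linker"},
    #                    {"uri": "http://www.wikidata.org/entity/Q142",
    #                     "boundaries": [12, 18],
    #                     "annotator": "Wikidata_Spotlight_Entity_Linker"}]}]
    #
    # would yield ("Paris is in France.", [("90", 0, 5), ("142", 12, 18)]):
    # entity boundaries become sentence-relative and the Wikidata prefix
    # (including its "Q") is stripped from the uris.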
    for article in data:
        # Both "entities" and "sentences_boundaries" are assumed sorted by
        # position: the eid cursor only ever moves forward within an article.
        eid: int = 0
        for sbs in article["sentences_boundaries"]:
            entities: List[Tuple[str, int, int]] = []
            # Skip entities starting before the current sentence.
            while eid < len(article["entities"]) and article["entities"][eid]["boundaries"][0] < sbs[0]:
                eid += 1

            # Consume every entity ending inside the current sentence.
            while eid < len(article["entities"]) and article["entities"][eid]["boundaries"][1] <= sbs[1]:
                entity: Dict[str, Any] = article["entities"][eid]
                eid += 1

                # Ignore date entities (keep only annotations made by the
                # Wikidata entity linker).
                if entity["annotator"] != "Wikidata_Spotlight_Entity_Linker":
                    continue

                # Strip the Wikidata prefix; the leading "Q" is part of the
                # prefix, so e.g. ".../entity/Q90" becomes "90".
                uri: str = entity["uri"]
                prefix: str = "http://www.wikidata.org/entity/Q"
                assert uri.startswith(prefix)
                uri = uri[len(prefix):]
                entities.append((uri, entity["boundaries"][0] - sbs[0], entity["boundaries"][1] - sbs[0]))

            # Ignore sentences with fewer than two entities.
            if len(entities) < 2:
                continue

            sentence = article["text"][sbs[0]:sbs[1]]
            yield (sentence, entities)


def read_data(subset: Optional[int]) -> Iterable[Tuple[str, List[Tuple[str, int, int]]]]:
    """ Read all T-REx files and return (sentence, list of entities) tuples. """
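    # The raw files are expected under DATASET_PATH / DIRECTORY_NAME, i.e.
    # DATA_PATH/"T-REx"/"raw_data", as left there by get_zip_data in the main
    # block below; any *.json file in that directory is picked up
    # (hypothetical example: raw_data/re-nlg_0-10000.json).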
    filenames: List[str] = list(filter(lambda filename: filename.endswith(".json"), os.listdir(DATASET_PATH / DIRECTORY_NAME)))

    # Make the order deterministic (os.listdir guarantees no order).
    filenames.sort()
    if subset is not None:
        filenames = filenames[:subset]

    for filename in tqdm.tqdm(filenames, desc="loading"):
        with open(DATASET_PATH / DIRECTORY_NAME / filename, "r") as file:
            data: List[Dict[str, Any]] = json.load(file)
        yield from process_json_object(data)


if __name__ == "__main__":
    parser: argparse.ArgumentParser = preprocessing.base_argument_parser("Prepare the unsupervised T-REx dataset.")
    parser.add_argument("-S", "--subset",
                        type=int,
                        help="Number of files to process (defaults to all; only used to create a debug dataset)")
    args: argparse.Namespace = parser.parse_args()
    name: str = preprocessing.dataset_name(args, "" if args.subset is None else f"-ss{args.subset}")

    # Fetch and unpack the raw archive if needed (get_zip_data presumably
    # verifies the download against ARCHIVE_SHA512), then serialize the
    # (sentence, entities) tuples.
    preprocessing.get_zip_data(DATASET_PATH, DIRECTORY_NAME, ARCHIVE_NAME, ARCHIVE_SHA512, DOWNLOAD_URL, unzip_directory=True)
    preprocessing.serialize_unsupervised_dataset(
            path=DATASET_PATH / name,
            data=read_data(args.subset),
            **preprocessing.args_to_serialize(args))
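

# Usage sketch (hypothetical: everything except -S/--subset is defined by
# gbure.data.preprocessing.base_argument_parser, which is not shown here).
# Assuming this file sits at gbure/data/prepare_trex.py, a small debug dataset
# built from the first two raw files would be:
#
#     python -m gbure.data.prepare_trex --subset 2
#
# On first run this downloads T-REx.zip from DOWNLOAD_URL (presumably checking
# it against ARCHIVE_SHA512) and unpacks it, then writes the serialized
# dataset under DATASET_PATH with an "-ss2" suffix in its name.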