prepare_fewrel.py (2697B)
"""Prepare the few shot FewRel dataset.

When run as a script, download the FewRel train and val json files if they are
missing, then pass their content to `preprocessing.serialize_dataset`.
"""

from typing import Any, Dict, Iterable, Tuple
import argparse
import json
import pathlib

import tqdm

from gbure.utils import DATA_PATH
import gbure.data.preprocessing as preprocessing

DATASET_PATH: pathlib.Path = DATA_PATH / "FewRel"
DOWNLOAD_URL: str = "https://thunlp.oss-cn-qingdao.aliyuncs.com/fewrel/"
# Expected SHA-512 digest of each file, passed to `preprocessing.download`.
FILES_SHA512: Dict[str, str] = {
    "fewrel_train.json": "2ec687d16999bd59bbcac39fdfed319cee3bec14963717c6ee262da981ad64e58b93d5c95a6ea4f6c5fe9c3d09a57d098d25ad61955f4df5f910bc28718e8220",
    "fewrel_val.json": "32bdac8c9aba880484d00417d823a310657acca5604a06fa7c4c01f8dfb54b9e05ecf67c0c374aa61bd42cb9e90cd62c0d50377cce2d6bc8fa6d1fbbb61d0f5e"
}


def get_data() -> None:
    """ Download FewRel's train and val json files if needed.

    A file is only fetched when it does not already exist under DATASET_PATH.
    """
    for filename, sha512 in FILES_SHA512.items():
        if not (DATASET_PATH / filename).exists():
            preprocessing.download(DOWNLOAD_URL + filename, DATASET_PATH / filename, filename, sha512)


def read_data(path: pathlib.Path) -> Iterable[Tuple[str, str, str, str, str]]:
    """ Read a FewRel json file and return (text, relation, relation_base, e1, e2) tuples.

    The file is expected to hold a json object mapping each relation to a list
    of sentence objects. Tuples are produced lazily; the file itself is fully
    parsed on the first iteration step.
    """
    # JSON is UTF-8 encoded; do not depend on the platform's default encoding.
    with open(path, encoding="utf-8") as file:
        data = json.load(file)

    for relation, relset in tqdm.tqdm(data.items(), desc=f"loading {path.name}"):
        for sentence in relset:
            # We assume we know the direction of the relation, the same
            # identifier is therefore used for relation and relation_base.
            # The first character of each entity identifier is dropped
            # (presumably the "Q" prefix of a Wikidata id; verify against the data).
            yield (process_sentence(sentence), relation, relation, sentence["h"][1][1:], sentence["t"][1][1:])


def process_sentence(sentence: Dict[str, Any]) -> str:
    """ Transform a FewRel json sentence object to a tagged sentence string.

    The head entity is wrapped in <e1>...</e1> and the tail entity in
    <e2>...</e2>; only the first listed mention of each entity is tagged.
    The input sentence object is left unmodified.
    """
    # Work on a copy: tagging in place would corrupt the caller's token list
    # and tag the sentence twice if it were processed again.
    tokens = list(sentence["tokens"])
    head_span = sentence["h"][2][0]
    tail_span = sentence["t"][2][0]

    tokens[head_span[0]] = "<e1>" + tokens[head_span[0]]
    tokens[head_span[-1]] += "</e1>"
    tokens[tail_span[0]] = "<e2>" + tokens[tail_span[0]]
    tokens[tail_span[-1]] += "</e2>"
    return " ".join(tokens)


def read_splits() -> Dict[str, Iterable[Tuple[str, str, str, str, str]]]:
    """ Map each split name ("train", "valid") to the tuples read from its FewRel file. """
    return {"train": read_data(DATASET_PATH / "fewrel_train.json"),
            "valid": read_data(DATASET_PATH / "fewrel_val.json")}


if __name__ == "__main__":
    parser: argparse.ArgumentParser = preprocessing.base_argument_parser("Prepare the few shot FewRel dataset.")
    args: argparse.Namespace = parser.parse_args()
    name: str = preprocessing.dataset_name(args)

    get_data()
    preprocessing.serialize_dataset(
            supervision="fewshot",
            path=DATASET_PATH / name,
            splits=read_splits(),
            unknown_entity=None,
            unknown_relation=None,
            **preprocessing.args_to_serialize(args))