preprocessing.py (17972B)
1 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast 2 import argparse 3 import collections 4 import hashlib 5 import math 6 import os 7 import pathlib 8 import random 9 import urllib.request 10 import zipfile 11 12 import torch 13 import tqdm 14 import transformers 15 16 from gbure.data.dictionary import Dictionary, RelationDictionary 17 from gbure.data.graph import Graph 18 19 20 def hash_file(path: pathlib.Path, filename: Optional[str] = None, filesize: Optional[int] = None) -> str: 21 """ Get a unique identifier for the file. """ 22 hasher = hashlib.sha512() 23 with path.open("rb") as file: 24 loop = iter(lambda: file.read(2**16), b"") 25 if filename is not None and filesize is not None: 26 loop = tqdm.tqdm(loop, 27 desc=f"checking {filename} hash", 28 total=math.ceil(filesize / 2**16), 29 unit_scale=2**16, unit="B", unit_divisor=1024) 30 for chunk in loop: 31 hasher.update(chunk) 32 return hasher.hexdigest() 33 34 35 def download(url: str, path: pathlib.Path, filename: str, sha512: str) -> None: 36 """ Download a file at the given path and check its hash. """ 37 if not path.parent.is_dir(): 38 path.parent.mkdir(parents=True) 39 40 unchecked: pathlib.Path = pathlib.Path(f"{path}.unchecked") 41 with tqdm.tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=f"downloading {filename}") as progress: 42 def report_hook(num_blocks: int, chunk_size: int, total_size: int): 43 if progress.total is None: 44 progress.total = total_size 45 progress.update(num_blocks * chunk_size - progress.n) 46 urllib.request.urlretrieve(url, unchecked, report_hook) 47 48 unchecked_hash = hash_file(unchecked, filename, progress.total) 49 if unchecked_hash != sha512: 50 raise RuntimeError(f"Downloaded file \"{filename}\" has wrong hash.") 51 os.rename(unchecked, path) 52 53 54 def get_zip_data(dataset_path: pathlib.Path, directory_name: str, archive_name: str, archive_sha512: str, download_url: str, unzip_directory: bool = False) -> None: 55 """ Download and extract data zip archive if needed. """ 56 if not (dataset_path / directory_name).exists(): 57 if not (dataset_path / archive_name).exists(): 58 download(download_url, dataset_path / archive_name, archive_name, archive_sha512) 59 60 with zipfile.ZipFile(str(dataset_path / archive_name), "r") as archive: 61 archive.extractall(dataset_path / directory_name if unzip_directory else dataset_path) 62 63 64 def base_argument_parser(description: str = "", deterministic: bool = False, parser: Optional[argparse.ArgumentParser] = None) -> argparse.ArgumentParser: 65 assert(description != "" or parser is not None) 66 """ Return an argument parser with standard command line arguments used by preprocessing functions. """ 67 parser: argparse.ArgumentParser = argparse.ArgumentParser(description=description) if parser is None else parser 68 parser.add_argument("tokenizer", 69 type=str, 70 nargs='?', 71 default="bert-base-cased", 72 help="Name of the transformers tokenizer") 73 if not deterministic: 74 parser.add_argument("-s", "--seed", 75 type=int, 76 default=0, 77 help="Seed of the RNG for shuffling the dataset") 78 return parser 79 80 81 def dataset_name(args: argparse.Namespace, infix: str = "") -> str: 82 """ Returns the dataset name with suffix containing non-standard preprocessing parameters. """ 83 suffix: str = "" 84 if "seed" in args and args.seed != 0: 85 suffix = f"-s{args.seed}" 86 return f"{args.tokenizer}{infix}{suffix}" 87 88 89 def args_to_serialize(args: argparse.Namespace) -> Dict[str, Any]: 90 """ Map standard preprocessing command line arguments defined in base_argument_parser to serialize_supervised_dataset parameters. """ 91 kwargs = {"tokenizer_name": args.tokenizer} 92 if "seed" in args: 93 kwargs["seed"] = args.seed 94 return kwargs 95 96 97 def make_tokenizer(name: str, path: pathlib.Path) -> transformers.PreTrainedTokenizer: 98 """ Build the given tokenizer and save it. """ 99 if not path.is_dir(): 100 path.mkdir() 101 102 tokenizer = transformers.AutoTokenizer.from_pretrained(name) 103 special_tokens = ["<e1>", "</e1>", "<e2>", "</e2>", "<blank/>"] 104 tokenizer.add_special_tokens({"additional_special_tokens": special_tokens}) 105 tokenizer.save_pretrained(path) 106 107 # fix huggingface transformers issue #6368 108 config_file = transformers.AutoConfig.from_pretrained(name) 109 config_file.save_pretrained(path) 110 111 return tokenizer 112 113 114 def process_text_2(raw_text: str, tokenizer: transformers.PreTrainedTokenizer) -> Tuple[torch.Tensor, int, int]: 115 """ 116 Transform a string with two entities tagged to a list of token ids together with the positions of the two entities. 117 118 The returned token list contains the token corresponding to the tags. 119 The two returned positions are the positions of <e1> and <e2>. 120 """ 121 be1_id: int = tokenizer.convert_tokens_to_ids("<e1>") 122 be2_id: int = tokenizer.convert_tokens_to_ids("<e2>") 123 124 text: List[int] = tokenizer.encode(raw_text, add_special_tokens=True) 125 e1_pos: int = text.index(be1_id) 126 e2_pos: int = text.index(be2_id) 127 if len(text) > tokenizer.model_max_length: 128 text = text[:tokenizer.model_max_length] 129 e1_pos = min(tokenizer.model_max_length-1, e1_pos) 130 e2_pos = min(tokenizer.model_max_length-1, e2_pos) 131 132 return torch.tensor(text, dtype=torch.int32), e1_pos, e2_pos 133 134 135 def process_text_n(raw_text: str, raw_entities: List[Tuple[str, int, int]], tokenizer: transformers.PreTrainedTokenizer) -> Tuple[torch.Tensor, List[Tuple[str, int, int]]]: 136 """ 137 Transform a string with several entities tagged to a list of token id together with the positions of entities. 138 139 The returned token list does not contain the token corresponding to the tags. 140 The returned postions, are where the tags should be inserted. 141 If the leftmost tag is inserted first, the position of subsequent inserts should be shifted accordingly. 142 """ 143 be1_id: int = tokenizer.convert_tokens_to_ids("<e1>") 144 145 # If one entity end at a position, and another entity start at the same position, we want to close the first entity before starting the sencond one, the second field "1 - extremity" has this function since the list is sorted in lexicographic order. 146 tag_positions: List[Tuple[int, int, int]] = [ 147 (cast(int, entity[1 + extremity]), # Position of the tag (start or end of entity) in the sentence. 148 cast(int, 1 - extremity), # Whether this is a start or end of entity. 149 i) # The index of the entity used to rebuild the list at the end. 150 for i, entity in enumerate(raw_entities) for extremity in [0, 1]] 151 tag_positions.sort() 152 153 # We insert the tag <e1> at every tag postion in order to be able to convert postions in the raw text to positions in the token list. 154 pieces: List[str] = [] 155 for piece_start, piece_end in zip([(0,)] + tag_positions, tag_positions + [(len(raw_text),)]): 156 pieces.append(raw_text[piece_start[0]:piece_end[0]]) 157 pieces.append("<e1>") 158 # Remove the last <e1> added at the end of the sentence. 159 pieces.pop() 160 161 text: List[int] = tokenizer.encode("".join(pieces), add_special_tokens=True) 162 if len(text) > tokenizer.model_max_length: 163 text = text[:tokenizer.model_max_length] 164 165 # New entity list, with converted positions. 166 entities: List[List[Union[str, int]]] = [[entity[0], -1, -1] for entity in raw_entities] 167 168 j: int = 0 # Counter on the tags. 169 for i, token in enumerate(text): 170 if token == be1_id: 171 # The order of the <e1> in the text match the one in tag_positions. 172 tag_position: Tuple[int, int, int] = tag_positions[j] 173 174 # tag_position[2] is the index of the entity in raw_entities (and thus entities). 175 # tag_position[1] is 0 for the end of the entity and 1 for its start. 176 # Since the returned token list will be pruned of all the <e1>, the position of the tag should be shifted by the number of <e1> already met, thus "i - j". 177 entities[tag_position[2]][2 - tag_position[1]] = i - j 178 179 j += 1 # Move to the next tag. 180 181 # Remove all tags 182 text = list(filter(lambda x: x != be1_id, text)) 183 184 # Remove entities which didn't fit inside tokenizer.model_max_length tokens. 185 entities = list(filter(lambda x: x[1] >= 0 and x[2] >= 0, entities)) 186 187 tuple_entities: List[Tuple[str, int, int]] = list(map(tuple, entities)) 188 return torch.tensor(text, dtype=torch.int32), tuple_entities 189 190 191 def serialize_supervised_split( 192 path: pathlib.Path, 193 split: Iterable[Tuple[str, str, str, str, str]], 194 tokenizer: transformers.PreTrainedTokenizer, 195 entity_dictionary: Dictionary, 196 relation_dictionary: RelationDictionary) -> None: 197 """ 198 Serialize a supervised split to a given path. 199 200 split is an iterable containing (text, directed relation, undirected relation, e1, e2) tuples. 201 Entities are ignored. 202 The relations are raw values (e.g. P42). This function performs the encoding. 203 """ 204 data: List[Tuple[torch.Tensor, int, int, int]] = [] 205 206 # TODO handle entities 207 for raw_text, relation, relation_base, _, _ in split: 208 text, e1_pos, e2_pos = process_text_2(raw_text, tokenizer) 209 relation_id: int = relation_dictionary.encode(relation, relation_base) 210 data.append((text, e1_pos, e2_pos, relation_id)) 211 212 torch.save(("supervised", data), path) 213 214 215 def serialize_fewshot_split( 216 path: pathlib.Path, 217 split: Iterable[Tuple[str, str, str, str, str]], 218 tokenizer: transformers.PreTrainedTokenizer, 219 entity_dictionary: Dictionary, 220 relation_dictionary: RelationDictionary) -> None: 221 """ 222 Serialize a fewshot split to a given path. 223 224 split is an iterable containing (text, directed relation, undirected relation, e1, e2) tuples. 225 The relations and entities are raw values (e.g. P42, Q42). This function performs the encoding. 226 """ 227 data: Dict[int, List[Tuple[torch.Tensor, int, int, int, int, int]]] = collections.defaultdict(list) 228 229 for raw_text, relation, relation_base, e1, e2 in split: 230 text, e1_pos, e2_pos = process_text_2(raw_text, tokenizer) 231 relation_id: int = relation_dictionary.encode(relation, relation_base) 232 e1_id: int = entity_dictionary.encode(e1) 233 e2_id: int = entity_dictionary.encode(e2) 234 data[relation_id].append((text, e1_pos, e2_pos, relation_id, e1_id, e2_id)) 235 236 torch.save(("fewshot", list(data.values())), path) 237 238 239 def serialize_fewshot_sampled_split( 240 path: pathlib.Path, 241 name: str, 242 split: Iterable[Tuple[Tuple[str, str, str], List[List[Tuple[str, str, str]]], int]], 243 tokenizer_name: str) -> None: 244 """ 245 Serialize a sampled fewshot split. 246 247 split is an iterable of (query, candidates, answer) tuples. 248 In these tuples, query is a tuple (text, e1, e2). 249 The relations are not given. 250 """ 251 tokenizer: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(str(path / "tokenizer")) 252 entity_dictionary = Dictionary() 253 254 data: List[Tuple[torch.Tensor, int, int, int, int, List[List[torch.Tensor]], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]] = [] 255 for train, test, answer in split: 256 query_text, query_e1_pos, query_e2_pos = process_text_2(train[0], tokenizer) 257 query_e1 = entity_dictionary.encode(train[1]) 258 query_e2 = entity_dictionary.encode(train[2]) 259 260 way = len(test) 261 shot = len(test[0]) 262 candidates_processed_text: List[List[Tuple[torch.Tensor, int, int]]] = list(map(lambda relation: list(map(lambda candidate: process_text_2(candidate[0], tokenizer), relation)), test)) 263 candidates_text_len = max(map(lambda relation: max(map(lambda candidate: candidate[0].shape[0], relation)), candidates_processed_text)) 264 265 candidates_text = [[None]*shot for _ in range(way)] 266 candidates_e1_pos = torch.empty((way, shot), dtype=torch.int64) 267 candidates_e2_pos = torch.empty((way, shot), dtype=torch.int64) 268 candidates_e1 = torch.empty((way, shot), dtype=torch.int64) 269 candidates_e2 = torch.empty((way, shot), dtype=torch.int64) 270 271 for n, (relation, relation_processed) in enumerate(zip(test, candidates_processed_text)): 272 for k, (candidate, candidate_processed) in enumerate(zip(relation, relation_processed)): 273 candidates_text[n][k] = candidate_processed[0] 274 candidates_e1_pos[n, k] = candidate_processed[1] 275 candidates_e2_pos[n, k] = candidate_processed[2] 276 candidates_e1[n, k] = entity_dictionary.encode(candidate[1]) 277 candidates_e2[n, k] = entity_dictionary.encode(candidate[2]) 278 279 data.append((query_text, query_e1_pos, query_e2_pos, query_e1, query_e2, candidates_text, candidates_e1_pos, candidates_e2_pos, candidates_e1, candidates_e2, answer)) 280 entity_dictionary.save(path / f"{name}.entities") 281 torch.save(("sampled fewshot", data), path / name) 282 283 284 def serialize_dataset( 285 supervision: str, 286 path: pathlib.Path, 287 splits: Dict[str, Iterable[Tuple[str, str, str, str, str]]], 288 tokenizer_name: str, 289 unknown_entity: Optional[str] = None, 290 unknown_relation: Optional[str] = None, 291 seed: Optional[int] = None) -> None: 292 """ 293 Serialize a dataset to a given path. 294 295 The splits must be given as iterables of (text, relation, relation_base, e1, e2) tuples. 296 supervision must be one of "supervised" or "fewshot". 297 """ 298 if not path.is_dir(): 299 path.mkdir() 300 301 tokenizer: transformers.PreTrainedTokenizer = make_tokenizer(tokenizer_name, path / "tokenizer") 302 entity_dictionary = Dictionary(unknown=unknown_entity) 303 relation_dictionary = RelationDictionary(unknown=unknown_relation) 304 305 serialize_split = serialize_supervised_split if supervision == "supervised" else serialize_fewshot_split 306 for split_name in ["train", "valid", "test"]: 307 if split_name not in splits: 308 continue 309 310 split = list(splits[split_name]) 311 if split_name == "train": 312 rng = random.Random(seed) 313 rng.shuffle(split) 314 split = tqdm.tqdm(split, desc=f"{split_name} tokenization") 315 serialize_split(path / split_name, split, tokenizer, entity_dictionary, relation_dictionary) 316 entity_dictionary.save(path / "entities") 317 relation_dictionary.save(path / "relations") 318 319 320 def build_edge_list(data: Iterable[Tuple[str, List[Tuple[str, int, int]]]], tokenizer: transformers.PreTrainedTokenizer) -> Tuple[List[torch.Tensor], Dictionary, List[int], List[Tuple[int, int, int, int, int, int, int]]]: 321 """ 322 Build a list of edges and nodes corresponding to the given data. 323 324 The tuples in the returned edge list are composed of the following elements: 325 (entity 1, entity 2, sentence id, entity 1 start, entity 1 end, entity 2 start, entity 2 end) 326 """ 327 sentences: List[str] = [] 328 entity_dictionary = Dictionary() 329 degrees: List[int] = [] 330 edges: List[Tuple[int, int, int, int, int, int, int]] = [] 331 332 for raw_sentence, raw_entities in data: 333 sentence: torch.Tensor 334 entities: List[Tuple[str, int, int]] 335 sentence, entities = process_text_n(raw_sentence, raw_entities, tokenizer) 336 337 # Buffer the ids to avoid re-hashing the entities 338 entity_ids: List[Optional[int]] = [None] * len(entities) 339 edge_added: bool = False 340 341 # Add all edges appearing in the clique corresponding to this sentence 342 for i, (e1_name, e1_start, e1_end) in enumerate(entities): 343 for j, (e2_name, e2_start, e2_end) in enumerate(entities[:i]): 344 # Soares et al. footnote 2 "We use a window of 40 tokens" 345 if max(e2_end - e1_start, e1_end - e2_start) < 40: 346 if entity_ids[i] is None: 347 entity_ids[i] = entity_dictionary.encode(e1_name) 348 if entity_ids[i] >= len(degrees): 349 degrees.append(0) 350 e1_id: int = cast(int, entity_ids[i]) 351 352 if entity_ids[j] is None: 353 entity_ids[j] = entity_dictionary.encode(e2_name) 354 if entity_ids[j] >= len(degrees): 355 degrees.append(0) 356 e2_id: int = cast(int, entity_ids[j]) 357 358 if e1_id <= e2_id: 359 edges.append((e1_id, e2_id, len(sentences), e1_start, e1_end, e2_start, e2_end)) 360 else: 361 edges.append((e2_id, e1_id, len(sentences), e2_start, e2_end, e1_start, e1_end)) 362 degrees[e1_id] += 1 363 degrees[e2_id] += 1 364 edge_added = True 365 366 if edge_added: 367 sentences.append(sentence) 368 369 return sentences, entity_dictionary, degrees, edges 370 371 372 def serialize_unsupervised_dataset( 373 path: pathlib.Path, 374 data: Iterable[Tuple[str, List[Tuple[str, int, int]]]], 375 tokenizer_name: str, 376 seed: int) -> None: 377 """ 378 Serialize an unsupervised dataset to a given path. 379 380 The data must be given as an iterable of (sentence, list of entities) tuples. 381 Where entities are tuples of (identifier, start indice in sentence, end indice in sentence). 382 """ 383 if not path.is_dir(): 384 path.mkdir() 385 386 tokenizer: transformers.PreTrainedTokenizer = make_tokenizer(tokenizer_name, path / "tokenizer") 387 388 sentences: List[str] 389 entities: Dictionary 390 edges: List[Tuple[int, int, int, int, int, int, int]] 391 graph = Graph(*build_edge_list(data, tokenizer)) 392 graph.save(path / "train")