graph.py (6829B)
from typing import List, Optional, Tuple, Union
import pathlib

import torch
import tqdm
import transformers

from gbure.data.dictionary import Dictionary
from gbure.utils import SharedLongTensorList


class Graph:
    """
    Graph represented by an adjacency list.

    The nodes of the graph correspond to entities; an edge between node n1 and node n2 indicates that the two corresponding entities appear together in a sentence, this sentence being the label of the edge.
    To avoid storing the same sentence several times, a list of sentences is stored and the edges are labeled with the sentence id and the entity positions.

    The data is stored in four objects:
        sentences: sid -> list of tokens without tags (<e1>, etc)
        entity_dictionary: KB entity identifier -> vid (vertex id)
        adj: source vid -> [(destination vid, eid)] sorted in lexicographic order
        edges: eid (edge id) -> (e1 vid, e2 vid, sid, e1 start position, e1 end position, e2 start position, e2 end position)
    """

    def __init__(self,
                 sentences: Optional[List[torch.Tensor]] = None,
                 entity_dictionary: Optional[Dictionary] = None,
                 degrees: Optional[List[int]] = None,
                 edges: Optional[List[Tuple[int, int, int, int, int, int, int]]] = None,
                 *, path: Optional[pathlib.Path] = None) -> None:
        """
        Initialize a Graph, either with the provided data or by loading the given file.

        Exactly one of the two modes must be used: either `path` alone, or all four of `sentences`, `entity_dictionary`, `degrees` and `edges`.
        When built from data, `edges` is consumed (its entries are replaced by None) by `compile_edges_adj`.
        """
        if path is not None:
            assert sentences is None and entity_dictionary is None and degrees is None and edges is None, \
                "path is mutually exclusive with sentences, entity_dictionary, degrees and edges"
            self.load(path)
        else:
            assert sentences is not None and entity_dictionary is not None and degrees is not None and edges is not None, \
                "sentences, entity_dictionary, degrees and edges are all required when path is not given"
            self.sentences: Union[List[torch.Tensor], SharedLongTensorList] = sentences
            self.entity_dictionary: Dictionary = entity_dictionary
            self.compile_edges_adj(edges, degrees)

    def compile_edges_adj(self, edges: List[Tuple[int, int, int, int, int, int, int]], degrees: List[int]) -> None:
        """
        Compile the edge and adjacency list of the graph.

        Edges are renumbered according to the lexicographic order of their 7-tuple, then each edge is registered in the adjacency list of both of its end points.
        `degrees[vid]` gives the number of rows allocated for the adjacency list of vertex `vid`.

        The entries of `edges` are replaced by None as they are processed; `degrees` is left untouched.
        """
        edge_order: List[int] = sorted(range(len(edges)), key=edges.__getitem__)

        self.edges: torch.Tensor = torch.empty((len(edges), 7), dtype=torch.int32)
        self.adj: Union[List[torch.Tensor], SharedLongTensorList] = [torch.empty((degree, 2), dtype=torch.int32) for degree in degrees]

        # Number of rows still unfilled for each vertex; rows are filled from the last one down to the first.
        # Work on a copy so that the list given by the caller is not modified.
        remaining: List[int] = list(degrees)

        for new_id, old_id in enumerate(tqdm.tqdm(edge_order, desc="compiling graph")):
            edge: Tuple[int, int, int, int, int, int, int] = edges[old_id]
            self.edges[new_id] = torch.tensor(edge, dtype=torch.int32)
            # Read the end points from the tuple so that plain ints (not 0-d tensors) are used as indices.
            for source, destination in ((edge[0], edge[1]), (edge[1], edge[0])):
                remaining[source] -= 1
                self.adj[source][remaining[source], 0] = destination
                self.adj[source][remaining[source], 1] = new_id
            # Free the memory along the way to avoid storing the graph twice
            edges[old_id] = None

        for i, vertex in enumerate(tqdm.tqdm(self.adj, desc="sorting adjacency list")):
            # torch.stack rejects an empty list, a vertex without any edge is already trivially sorted.
            if vertex.shape[0] > 0:
                self.adj[i] = torch.stack(sorted(vertex, key=torch.Tensor.tolist))  # pytype: disable=unsupported-operands

    @property
    def order(self) -> int:
        """ Number of vertices. """
        return len(self.adj)

    @property
    def size(self) -> int:
        """ Number of edges. """
        return self.edges.shape[0]

    def degree(self, vertex: int) -> int:
        """ Number of edges connected to a given vertex. """
        return self.adj[vertex].shape[0]

    def eid_simple_adjacency(self, eid: int) -> bool:
        """
        Decide whether a given edge is the sole edge between two nodes.

        Only the previous and next rows of the edge list are compared, this relies on the edge list being sorted so that edges sharing the same (e1 vid, e2 vid) pair are contiguous.
        """
        if eid > 0 and (self.edges[eid, :2] == self.edges[eid-1, :2]).all():
            return False
        if eid < self.size-1 and (self.edges[eid, :2] == self.edges[eid+1, :2]).all():
            return False
        return True

    def eid_adjacency_range(self, eid: int, prefix: int = 2) -> Tuple[int, int]:
        """
        Return the range of edges (in the global edge list) sharing the same end points as eid.

        More generally, the edges of the range share their first `prefix` fields with eid.
        The returned range is half-open: the start is included, the end is excluded.
        """
        range_start: int = eid
        while range_start > 0 and (self.edges[eid, :prefix] == self.edges[range_start-1, :prefix]).all():
            range_start -= 1

        range_end: int = eid+1
        while range_end < self.size and (self.edges[eid, :prefix] == self.edges[range_end, :prefix]).all():
            range_end += 1

        return range_start, range_end

    def reid_adjacency_begin(self, source: int, destination: int) -> int:
        """
        Return the first edge from source to destination as relative index in source's adjacency list.

        This is a binary search for the first row whose destination is not smaller than the given one.
        If source has no edge to destination, the returned index is where such an edge would be inserted, which is the degree of source when all its neighbors are smaller than destination.
        """
        left: int = 0
        right: int = self.degree(source)
        while left < right:
            middle: int = (left + right) // 2
            if self.adj[source][middle, 0] < destination:
                left = middle + 1
            else:
                right = middle

        return left

    def tagged_sentence(self, eid: int, tokenizer: transformers.PreTrainedTokenizer, invert: bool = False) -> Tuple[List[int], int, int]:
        """
        Get the tagged sentence corresponding to an edge.

        The ids of the "<e1>", "</e1>", "<e2>" and "</e2>" tokens (as given by the tokenizer) are inserted in the sentence at the positions stored in the edge.
        When invert is set, the roles of the two entities are swapped.

        Returns the tagged sentence, followed by the start positions of e1 and e2, the start of the entity beginning last being shifted by 2.
        NOTE(review): the shift by 2 presumably accounts for the two tags of the other entity, which assumes that the entities do not overlap -- confirm.
        """
        # Convert to python ints once, to avoid sorting and indexing with 0-d tensors.
        _, _, sid, e1_start, e1_end, e2_start, e2_end = self.edges[eid].tolist()
        if invert:
            e1_start, e1_end, e2_start, e2_end = e2_start, e2_end, e1_start, e1_end

        text: List[int] = self.sentences[sid].tolist()
        tags: List[Tuple[int, str]] = [(e1_start, "<e1>"), (e1_end, "</e1>"), (e2_start, "<e2>"), (e2_end, "</e2>")]
        # Insert from the end of the sentence so that the positions still to be processed remain valid.
        # Abuse the fact that "</e1>" < "<e1>": on equal positions, the opening tag is inserted first and the closing tag then lands in front of it.
        tags.sort(reverse=True)
        for position, tag in tags:
            text.insert(position, tokenizer.convert_tokens_to_ids(tag))

        e1_pos: int = e1_start
        e2_pos: int = e2_start
        if e1_pos < e2_pos:
            e2_pos += 2
        else:
            e1_pos += 2
        return text, e1_pos, e2_pos

    def save(self, path: pathlib.Path) -> None:
        """
        Save the graph to the given directory.

        The directory is created if needed, along with its missing parents.
        """
        path.mkdir(parents=True, exist_ok=True)

        self.entity_dictionary.save(path / "entities")
        for attribute in ["sentences", "edges", "adj"]:
            torch.save(getattr(self, attribute), path / attribute)

    def load(self, path: pathlib.Path) -> None:
        """ Load a graph from the given directory. """
        self.entity_dictionary = Dictionary(path=path / "entities")
        for attribute in ["sentences", "edges", "adj"]:
            # SECURITY: torch.load relies on pickle, only load graphs from trusted directories.
            setattr(self, attribute, torch.load(path / attribute))

    def share_memory(self) -> None:
        """
        Replace the sentences and the adjacency list by SharedLongTensorList objects.

        NOTE(review): presumably this allows the data to be shared between processes, and [-1, 2] is the shape under which the adjacency rows are viewed -- confirm against SharedLongTensorList.
        """
        self.sentences = SharedLongTensorList(self.sentences)
        self.adj = SharedLongTensorList(self.adj, [-1, 2])