 # gbure

Graph-based approaches to unsupervised relation extraction, evaluated as a few-shot problem
git clone https://esimon.eu/repos/gbure.git
Log | Files | Refs | README | LICENSE

graph.py (6829B)

```      1 from typing import List, Optional, Tuple, Union
2 import pathlib
3
4 import torch
5 import tqdm
6 import transformers
7
8 from gbure.data.dictionary import Dictionary
9 from gbure.utils import SharedLongTensorList
10
11
12 class Graph:
13     """
14     Graph represented by an adjacency list.
15
16     The nodes of the Graph correspond to entities; an edge between node n1 and node n2 indicates that the two corresponding entities appear together in a sentence, this sentence being the label of the edge.
17     To avoid storing the same sentence several times, a list of sentences is stored and the edges are labeled with the sentence id and entity positions.
18
19     The data is stored in four objects:
20         sentences: sid -> list of tokens without tags (<e1>, etc)
21         entity_dictionary: KB entity identifier -> vid (vertex id)
22         adj: source vid -> [(destination vid, eid)] sorted in lexicographic order
23         edges: eid (edge id) -> (e1 vid, e2 vid, sid, e1 start position, e1 end position, e2 start position, e2 end position)
24     """
25
26     def __init__(self,
27                  sentences: Optional[List[torch.Tensor]] = None,
28                  entity_dictionary: Optional[Dictionary] = None,
29                  degrees: Optional[List[int]] = None,
30                  edges: Optional[List[Tuple[int, int, int, int, int, int, int]]] = None,
31                  *, path: Optional[pathlib.Path] = None) -> None:
32         """ Initialize a Graph, either with the provided data or by loading the given file. """
33         if path is not None:
34             assert(sentences is None and entity_dictionary is None and degrees is None and edges is None)
36         else:
37             assert(sentences is not None and entity_dictionary is not None and degrees is not None and edges is not None)
38             self.sentences: Union[List[torch.Tensor], SharedLongTensorList] = sentences
39             self.entity_dictionary: Dictionary = entity_dictionary
41
42     def compile_edges_adj(self, edges: List[Tuple[int, int, int, int, int, int, int]], degrees: List[int]) -> None:
43         """ Compile the edge and adjacency list of the graph. """
44         edge_order: List[int] = sorted(range(len(edges)), key=edges.__getitem__)
45
46         self.edges: torch.Tensor = torch.empty((len(edges), 7), dtype=torch.int32)
47         self.adj: Union[List[torch.Tensor], SharedLongTensorList] = [torch.empty((degree, 2), dtype=torch.int32) for degree in degrees]
48
49         for new_id, old_id in enumerate(tqdm.tqdm(edge_order, desc="compiling graph")):
50             edge: torch.Tensor = torch.tensor(edges[old_id], dtype=torch.int32)
51             self.edges[new_id] = edge
52             for source, destination in [(edge, edge), (edge, edge)]:
53                 degrees[source] -= 1
54                 self.adj[source][degrees[source], 0] = destination
55                 self.adj[source][degrees[source], 1] = new_id
56             # Free the memory along the way to avoid storing the graph twice
57             edges[old_id] = None
58
59         for i, vertex in enumerate(tqdm.tqdm(self.adj, desc="sorting adjacency list")):
60             self.adj[i] = torch.stack(sorted(vertex, key=torch.Tensor.tolist))  # pytype: disable=unsupported-operands
61
62     @property
63     def order(self) -> int:
64         """ Number of vertices. """
66
67     @property
68     def size(self) -> int:
69         """ Number of edges. """
70         return self.edges.shape
71
72     def degree(self, vertex: int) -> int:
73         """ Number of edges connected to a given vertex. """
75
76     def eid_simple_adjacency(self, eid: int) -> bool:
77         """ Decide whether a given edge is the sole edge between two nodes. """
78         if eid > 0 and (self.edges[eid, :2] == self.edges[eid-1, :2]).all():
79             return False
80         if eid < self.size-1 and (self.edges[eid, :2] == self.edges[eid+1, :2]).all():
81             return False
82         return True
83
84     def eid_adjacency_range(self, eid: int, prefix: int = 2) -> Tuple[int, int]:
85         """ Return the range of edges (in the global edge list) sharing the same end points as eid. """
86         range_start: int = eid
87         while range_start > 0 and (self.edges[eid, :prefix] == self.edges[range_start-1, :prefix]).all():
88             range_start -= 1
89
90         range_end: int = eid+1
91         while range_end < self.size and (self.edges[eid, :prefix] == self.edges[range_end, :prefix]).all():
92             range_end += 1
93
94         return range_start, range_end
95
96     def reid_adjacency_begin(self, source: int, destination: int) -> int:
97         """ Return the first edge from source to destination as relative index in source's adjacency list. """
98         left: int = 0
99         right: int = self.degree(source)
100         while left < right:
101             middle: int = (left + right) // 2
102             if self.adj[source][middle, 0] < destination:
103                 left = middle + 1
104             else:
105                 right = middle
106
107         return left
108
109     def tagged_sentence(self, eid: int, tokenizer: transformers.PreTrainedTokenizer, invert: bool = False) -> Tuple[List[int], int, int]:
110         """ Get the tagged sentence corresponding to an edge. """
111         edge: torch.Tensor = self.edges[eid]
112         text: List[int] = self.sentences[edge].tolist()
113         # Abuse the fact that "</e1>" < "<e1>"
114         if invert:
115             tags: List[Tuple[int, str]] = [(edge, "<e1>"), (edge, "</e1>"), (edge, "<e2>"), (edge, "</e2>")]
116         else:
117             tags: List[Tuple[int, str]] = [(edge, "<e1>"), (edge, "</e1>"), (edge, "<e2>"), (edge, "</e2>")]
118         tags.sort(reverse=True)
119         for position, tag in tags:
120             text.insert(position, tokenizer.convert_tokens_to_ids(tag))
121         e1_pos: int = self.edges[eid, 5 if invert else 3].item()
122         e2_pos: int = self.edges[eid, 3 if invert else 5].item()
123         if e1_pos < e2_pos:
124             e2_pos += 2
125         else:
126             e1_pos += 2
127         return text, e1_pos, e2_pos
128
129     def save(self, path: pathlib.Path) -> None:
130         """ Save the graph to the given directory. """
131         if not path.is_dir():
132             path.mkdir()
133
134         self.entity_dictionary.save(path / "entities")
135         for attribute in ["sentences", "edges", "adj"]:
136             torch.save(getattr(self, attribute), path / attribute)
137
138     def load(self, path: pathlib.Path) -> None:
139         """ Load a graph from the given directory. """
140         self.entity_dictionary = Dictionary(path=path / "entities")
141         for attribute in ["sentences", "edges", "adj"]:
142             setattr(self, attribute, torch.load(path / attribute))
143
144     def share_memory(self) -> None:
145         self.sentences = SharedLongTensorList(self.sentences)