gbure

Graph-based approaches to unsupervised relation extraction evaluated as a few-shot problem
git clone https://esimon.eu/repos/gbure.git
Log | Files | Refs | README | LICENSE

graph.py (6829B)


      1 from typing import List, Optional, Tuple, Union
      2 import pathlib
      3 
      4 import torch
      5 import tqdm
      6 import transformers
      7 
      8 from gbure.data.dictionary import Dictionary
      9 from gbure.utils import SharedLongTensorList
     10 
     11 
     12 class Graph:
     13     """
     14     Graph represented by an adjacency list.
     15 
     16     The node of the Graph correspond to entities, an edge between node n1 and node n2 indicates that the two corresponding entities appear together in a sentence, this sentence being the label of the edge.
     17     To avoid storing the same sentence several time, a list of sentences is stored and the edges are labeled with the sentence id and entities postions.
     18 
     19     The data is stored in four objects:
     20         sentences: sid -> list of tokens without tags (<e1>, etc)
     21         entity_dictionary: KB entity identifier -> vid (vertex id)
     22         adj: source vid -> [(destination vid, eid)] sorted in lexicographic order
     23         edges: eid (edge id) -> (e1 vid, e2 vid, sid, e1 start position, e1 end position, e2 start position, e2 end position)
     24     """
     25 
     26     def __init__(self,
     27                  sentences: Optional[List[torch.Tensor]] = None,
     28                  entity_dictionary: Optional[Dictionary] = None,
     29                  degrees: Optional[List[int]] = None,
     30                  edges: Optional[List[Tuple[int, int, int, int, int, int, int]]] = None,
     31                  *, path: Optional[pathlib.Path] = None) -> None:
     32         """ Initialize a Graph, either with the provided data or by loading the given file. """
     33         if path is not None:
     34             assert(sentences is None and entity_dictionary is None and degrees is None and edges is None)
     35             self.load(path)
     36         else:
     37             assert(sentences is not None and entity_dictionary is not None and degrees is not None and edges is not None)
     38             self.sentences: Union[List[torch.Tensor], SharedLongTensorList] = sentences
     39             self.entity_dictionary: Dictionary = entity_dictionary
     40             self.compile_edges_adj(edges, degrees)
     41 
     42     def compile_edges_adj(self, edges: List[Tuple[int, int, int, int, int, int, int]], degrees: List[int]) -> None:
     43         """ Compile the edge and adjacency list of the graph. """
     44         edge_order: List[int] = sorted(range(len(edges)), key=edges.__getitem__)
     45 
     46         self.edges: torch.Tensor = torch.empty((len(edges), 7), dtype=torch.int32)
     47         self.adj: Union[List[torch.Tensor], SharedLongTensorList] = [torch.empty((degree, 2), dtype=torch.int32) for degree in degrees]
     48 
     49         for new_id, old_id in enumerate(tqdm.tqdm(edge_order, desc="compiling graph")):
     50             edge: torch.Tensor = torch.tensor(edges[old_id], dtype=torch.int32)
     51             self.edges[new_id] = edge
     52             for source, destination in [(edge[0], edge[1]), (edge[1], edge[0])]:
     53                 degrees[source] -= 1
     54                 self.adj[source][degrees[source], 0] = destination
     55                 self.adj[source][degrees[source], 1] = new_id
     56             # Free the memory along the way to avoid storing the graph twice
     57             edges[old_id] = None
     58 
     59         for i, vertex in enumerate(tqdm.tqdm(self.adj, desc="sorting adjacency list")):
     60             self.adj[i] = torch.stack(sorted(vertex, key=torch.Tensor.tolist))  # pytype: disable=unsupported-operands
     61 
     62     @property
     63     def order(self) -> int:
     64         """ Number of vertices. """
     65         return len(self.adj)
     66 
     67     @property
     68     def size(self) -> int:
     69         """ Number of edges. """
     70         return self.edges.shape[0]
     71 
     72     def degree(self, vertex: int) -> int:
     73         """ Number of edges connected to a given vertex. """
     74         return self.adj[vertex].shape[0]
     75 
     76     def eid_simple_adjacency(self, eid: int) -> bool:
     77         """ Decide whether a given edge is the sole edge between two nodes. """
     78         if eid > 0 and (self.edges[eid, :2] == self.edges[eid-1, :2]).all():
     79             return False
     80         if eid < self.size-1 and (self.edges[eid, :2] == self.edges[eid+1, :2]).all():
     81             return False
     82         return True
     83 
     84     def eid_adjacency_range(self, eid: int, prefix: int = 2) -> Tuple[int, int]:
     85         """ Return the range of edges (in the global edge list) sharing the same end points as eid. """
     86         range_start: int = eid
     87         while range_start > 0 and (self.edges[eid, :prefix] == self.edges[range_start-1, :prefix]).all():
     88             range_start -= 1
     89 
     90         range_end: int = eid+1
     91         while range_end < self.size and (self.edges[eid, :prefix] == self.edges[range_end, :prefix]).all():
     92             range_end += 1
     93 
     94         return range_start, range_end
     95 
     96     def reid_adjacency_begin(self, source: int, destination: int) -> int:
     97         """ Return the first edge from source to destination as relative index in source's adjacency list. """
     98         left: int = 0
     99         right: int = self.degree(source)
    100         while left < right:
    101             middle: int = (left + right) // 2
    102             if self.adj[source][middle, 0] < destination:
    103                 left = middle + 1
    104             else:
    105                 right = middle
    106 
    107         return left
    108 
    109     def tagged_sentence(self, eid: int, tokenizer: transformers.PreTrainedTokenizer, invert: bool = False) -> Tuple[List[int], int, int]:
    110         """ Get the tagged sentence corresponding to an edge. """
    111         edge: torch.Tensor = self.edges[eid]
    112         text: List[int] = self.sentences[edge[2]].tolist()
    113         # Abuse the fact that "</e1>" < "<e1>"
    114         if invert:
    115             tags: List[Tuple[int, str]] = [(edge[5], "<e1>"), (edge[6], "</e1>"), (edge[3], "<e2>"), (edge[4], "</e2>")]
    116         else:
    117             tags: List[Tuple[int, str]] = [(edge[3], "<e1>"), (edge[4], "</e1>"), (edge[5], "<e2>"), (edge[6], "</e2>")]
    118         tags.sort(reverse=True)
    119         for position, tag in tags:
    120             text.insert(position, tokenizer.convert_tokens_to_ids(tag))
    121         e1_pos: int = self.edges[eid, 5 if invert else 3].item()
    122         e2_pos: int = self.edges[eid, 3 if invert else 5].item()
    123         if e1_pos < e2_pos:
    124             e2_pos += 2
    125         else:
    126             e1_pos += 2
    127         return text, e1_pos, e2_pos
    128 
    129     def save(self, path: pathlib.Path) -> None:
    130         """ Save the graph to the given directory. """
    131         if not path.is_dir():
    132             path.mkdir()
    133 
    134         self.entity_dictionary.save(path / "entities")
    135         for attribute in ["sentences", "edges", "adj"]:
    136             torch.save(getattr(self, attribute), path / attribute)
    137 
    138     def load(self, path: pathlib.Path) -> None:
    139         """ Load a graph from the given directory. """
    140         self.entity_dictionary = Dictionary(path=path / "entities")
    141         for attribute in ["sentences", "edges", "adj"]:
    142             setattr(self, attribute, torch.load(path / attribute))
    143 
    144     def share_memory(self) -> None:
    145         self.sentences = SharedLongTensorList(self.sentences)
    146         self.adj = SharedLongTensorList(self.adj, [-1, 2])