batcher.py (8378B)
from typing import Any, Dict, List, Tuple
import collections

import torch


# Must be kept prefix-sorted (most specific prefixes first): __call__ takes
# the FIRST entry matching a feature key, and the final ("", 0) entry is the
# catch-all that always matches.
# (prefix, list nesting depth between the batch axis and the sentence tensors)
FEATURE_PREFIXES: List[Tuple[str, int]] = [
    ("query_e1_neighborhood_", 1),
    ("query_e2_neighborhood_", 1),
    ("candidates_e1_neighborhood_", 3),
    ("candidates_e2_neighborhood_", 3),
    ("first_e1_neighborhood_", 1),
    ("first_e2_neighborhood_", 1),
    ("second_e1_neighborhood_", 1),
    ("second_e2_neighborhood_", 1),
    ("third_e1_neighborhood_", 1),
    ("third_e2_neighborhood_", 1),
    ("query_", 0),
    ("candidates_", 2),
    ("first_", 0),
    ("second_", 0),
    ("third_", 0),
    ("", 0)]


class Batcher:
    """
    Batch a group of samples together.

    Two new features are derived from each "text" feature: its length
    (f"{prefix}length") and a boolean padding mask (f"{prefix}mask").
    """

    def __init__(self, pad_value: int) -> None:
        """ Initialize a Batcher, using the provided value to pad text. """
        self.pad_value: int = pad_value

    @staticmethod
    def _nested_size(text: List[Any], depth: int) -> Tuple[int, ...]:
        """ Infer the leading batched-tensor size (without the sequence axis)
        from the first element at each nesting level.

        NOTE(review): assumes the nested lists are rectangular up to `depth`;
        ragged *shorter* inner lists are tolerated because the fill routines
        below zero-initialize the tensors, but *longer* ones would raise,
        exactly as in the original implementation.
        """
        size: List[int] = [len(text)]
        node: Any = text
        for _ in range(depth):
            node = node[0]
            size.append(len(node))
        return tuple(size)

    def add_length_field(self, batch: Dict[str, Any], prefix: str, depth: int) -> None:
        """ Add the f"{prefix}length" LongTensor holding the token count of
        every sentence found under f"{prefix}text".

        depth is the number of list levels between the batch axis and the
        sentence tensors (0: list of sentences, 1: list of lists, ...).
        The tensor is zero-initialized (not torch.empty) so that entries
        missing from ragged inner lists report length 0 instead of
        uninitialized memory.
        """
        text: List[Any] = batch[f"{prefix}text"]
        lengths: torch.Tensor = torch.zeros(self._nested_size(text, depth), dtype=torch.int64)

        def fill(dst: torch.Tensor, node: Any, level: int) -> None:
            # dst is a view of `lengths`; node is the matching nested list.
            if level == 0:
                for i, sentence in enumerate(node):
                    dst[i] = sentence.shape[0]
            else:
                for i, sub in enumerate(node):
                    fill(dst[i], sub, level - 1)

        fill(lengths, text, depth)
        batch[f"{prefix}length"] = lengths

    def process_text(self, batch: Dict[str, Any], prefix: str, depth: int, key: str) -> None:
        """ Pad the sentences under f"{prefix}{key}" into one LongTensor and
        record a boolean f"{prefix}mask" (True on real tokens, False on
        padding).

        "text", "mlm_input" and "mlm_target" share a single mask per prefix:
        only the first one processed stores it.
        """
        in_text: List[Any] = batch[f"{prefix}{key}"]
        if isinstance(batch[f"{prefix}length"], list):
            # Lengths were not computed yet (the defaultdict returned its
            # default empty list); derive them from the "text" feature.
            self.add_length_field(batch, prefix, depth)
        # Plain int (0-dim tensors should not leak into shapes); at least 1
        # so all-empty sentences still yield a valid sequence axis.
        max_seq_len: int = max(int(batch[f"{prefix}length"].max()), 1)

        shape: Tuple[int, ...] = self._nested_size(in_text, depth) + (max_seq_len,)
        # Pre-fill with pad_value / False: slots never written below (ragged
        # or empty neighborhoods) are then correctly padded and masked out,
        # instead of holding uninitialized torch.empty memory.
        text: torch.Tensor = torch.full(shape, self.pad_value, dtype=torch.int64)
        mask: torch.Tensor = torch.zeros(shape, dtype=torch.bool)

        def fill(dst_text: torch.Tensor, dst_mask: torch.Tensor, node: Any, level: int) -> None:
            # dst_text / dst_mask are views of `text` / `mask`.
            if level == 0:
                for i, sentence in enumerate(node):
                    n = sentence.shape[0]
                    dst_text[i, :n] = sentence
                    dst_mask[i, :n] = True
            else:
                for i, sub in enumerate(node):
                    fill(dst_text[i], dst_mask[i], sub, level - 1)

        fill(text, mask, in_text, depth)

        batch[f"{prefix}{key}"] = text
        if f"{prefix}mask" not in batch:
            batch[f"{prefix}mask"] = mask

    def process_int_feature(self, batch: Dict[str, Any], prefix: str, feature: str) -> None:
        """ Transform a list of integers (or stack a list of tensors) into a
        torch LongTensor. """
        # TODO handle neighborhoods of different sizes
        values = batch[f"{prefix}{feature}"]
        if isinstance(values[0], torch.Tensor):
            batch[f"{prefix}{feature}"] = torch.stack(values)
        else:
            batch[f"{prefix}{feature}"] = torch.tensor(values, dtype=torch.int64)

    def __call__(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
        """ Batch the provided samples into a dict of batched tensors. """
        batch: Dict[str, Any] = collections.defaultdict(list)
        for sample in samples:
            for key, value in sample.items():
                batch[key].append(value)

        # Snapshot the keys: processing adds "length" / "mask" entries.
        for key in list(batch.keys()):
            # First (most specific) matching prefix wins; the trailing
            # ("", 0) catch-all guarantees the loop always breaks.
            for prefix, depth in FEATURE_PREFIXES:
                if key.startswith(prefix):
                    break
            feature: str = key[len(prefix):]
            if feature in ["text", "mlm_input", "mlm_target"]:
                self.process_text(batch, prefix, depth, feature)
            if feature in ["relation", "entity_positions", "entity_identifiers", "entity_degrees", "edge_identifier", "polarity", "answer", "eid"]:
                self.process_int_feature(batch, prefix, feature)

        return batch