deep_question_answering

Implementation of the models proposed in "Teaching Machines to Read and Comprehend" by Google DeepMind
git clone https://esimon.eu/repos/deep_question_answering.git

data.py (6723B)
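
data.py builds the Fuel data pipeline: QADataset turns DeepMind CNN/Daily Mail question files into word-id arrays with per-question anonymised entities, and setup_datastream wraps it into sorted, padded batches.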


import logging
import os
import random

import numpy

from picklable_itertools import iter_

from fuel.datasets import Dataset
from fuel.streams import DataStream
from fuel.schemes import IterationScheme, ConstantScheme
from fuel.transformers import Batch, Mapping, SortMapping, Unpack, Padding, Transformer

logging.basicConfig(level='INFO')
logger = logging.getLogger(__name__)

class QADataset(Dataset):
    """Provides one (context, question, answer, candidates) example per
    requested question file."""

    def __init__(self, path, vocab_file, n_entities, need_sep_token, **kwargs):
        self.provides_sources = ('context', 'question', 'answer', 'candidates')

        self.path = path

        # Entity slots first, then the word list, then the special tokens.
        self.vocab = ['@entity%d' % i for i in range(n_entities)] + \
                     [w.rstrip('\n') for w in open(vocab_file)] + \
                     ['<UNK>', '@placeholder'] + \
                     (['<SEP>'] if need_sep_token else [])

        self.n_entities = n_entities
        self.vocab_size = len(self.vocab)
        self.reverse_vocab = {w: i for i, w in enumerate(self.vocab)}

        super(QADataset, self).__init__(**kwargs)

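    # Note: ids [0, n_entities) are reserved for the anonymised entity
    # slots, so answer and candidate ids always fall below n_entities;
    # get_data checks this invariant.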
    def to_word_id(self, w, cand_mapping):
        # Candidate entities take their per-question shuffled slot; any
        # other '@entity' token means the question file is inconsistent.
        if w in cand_mapping:
            return cand_mapping[w]
        elif w[:7] == '@entity':
            raise ValueError("Unmapped entity token: %s" % w)
        elif w in self.reverse_vocab:
            return self.reverse_vocab[w]
        else:
            return self.reverse_vocab['<UNK>']

    def to_word_ids(self, s, cand_mapping):
        return numpy.array([self.to_word_id(x, cand_mapping) for x in s.split(' ')], dtype=numpy.int32)

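    # Example (hypothetical mapping): with
    #   cand_mapping = {'@entity3': 0, '@entity7': 1}
    # to_word_ids('@entity3 met @entity7', cand_mapping) returns the int32
    # array [0, <id of 'met', or <UNK> if out of vocabulary>, 1].
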
    def get_data(self, state=None, request=None):
        if request is None or state is not None:
            raise ValueError("Expected a request (name of a question file) and no state.")

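        # Question files follow the released DeepMind Q&A layout
        # (0-indexed): line 0 is the article URL, line 2 the context,
        # line 4 the question, line 6 the answer, and lines 8+ hold one
        # "@entityN:original text" line per candidate; odd lines are blank.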
        lines = [l.rstrip('\n') for l in open(os.path.join(self.path, request))]

        ctx = lines[2]
        q = lines[4]
        a = lines[6]
        cand = [s.split(':')[0] for s in lines[8:]]

        # Assign each candidate a randomly shuffled entity slot, so that
        # entity ids are not consistent across questions.
        entities = range(self.n_entities)
        while len(cand) > len(entities):
            logger.warning("Too many entities (%d) for question: %s, using duplicate entity identifiers"
                           % (len(cand), request))
            entities = entities + entities
        random.shuffle(entities)
        cand_mapping = {t: k for t, k in zip(cand, entities)}

        ctx = self.to_word_ids(ctx, cand_mapping)
        q = self.to_word_ids(q, cand_mapping)
        cand = numpy.array([self.to_word_id(x, cand_mapping) for x in cand], dtype=numpy.int32)
        a = numpy.int32(self.to_word_id(a, cand_mapping))

        # Sanity checks: the answer and candidates must be entity ids, and
        # every word id must fall inside the vocabulary.
        if not a < self.n_entities:
            raise ValueError("Invalid answer token %d" % a)
        if not numpy.all(cand < self.n_entities):
            raise ValueError("Invalid candidate in list %s" % repr(cand))
        if not numpy.all(ctx < self.vocab_size):
            raise ValueError("Context word id out of bounds: %d" % int(ctx.max()))
        if not numpy.all(ctx >= 0):
            raise ValueError("Context word id negative: %d" % int(ctx.min()))
        if not numpy.all(q < self.vocab_size):
            raise ValueError("Question word id out of bounds: %d" % int(q.max()))
        if not numpy.all(q >= 0):
            raise ValueError("Question word id negative: %d" % int(q.min()))

        return (ctx, q, a, cand)

class QAIterator(IterationScheme):
    """Requests one question-file name at a time, optionally shuffled."""
    requests_examples = True

    def __init__(self, path, shuffle=False, **kwargs):
        self.path = path
        self.shuffle = shuffle

        super(QAIterator, self).__init__(**kwargs)

    def get_request_iterator(self):
        l = [f for f in os.listdir(self.path)
             if os.path.isfile(os.path.join(self.path, f))]
        if self.shuffle:
            random.shuffle(l)
        return iter_(l)

# -------------- DATASTREAM SETUP --------------------

class ConcatCtxAndQuestion(Transformer):
    """Merges 'context' and 'question' into a single sequence, optionally
    joined by a separator token (e.g. [question <SEP> context] when
    concat_question_before is set)."""
    produces_examples = True

    def __init__(self, stream, concat_question_before, separator_token=None, **kwargs):
        assert stream.sources == ('context', 'question', 'answer', 'candidates')
        self.sources = ('question', 'answer', 'candidates')

        self.sep = numpy.array([separator_token] if separator_token is not None else [],
                               dtype=numpy.int32)
        self.concat_question_before = concat_question_before

        super(ConcatCtxAndQuestion, self).__init__(stream, **kwargs)

    def get_data(self, request=None):
        if request is not None:
            raise ValueError('Unsupported: request')

        ctx, q, a, cand = next(self.child_epoch_iterator)

        if self.concat_question_before:
            return (numpy.concatenate([q, self.sep, ctx]), a, cand)
        else:
            return (numpy.concatenate([ctx, self.sep, q]), a, cand)

class _balanced_batch_helper(object):
    """Sort key: length of the sequence at index `key` in an example."""
    def __init__(self, key):
        self.key = key
    def __call__(self, data):
        return data[self.key].shape[0]

def setup_datastream(path, vocab_file, config):
    ds = QADataset(path, vocab_file, config.n_entities, need_sep_token=config.concat_ctx_and_question)
    it = QAIterator(path, shuffle=config.shuffle_questions)

    stream = DataStream(ds, iteration_scheme=it)

    if config.concat_ctx_and_question:
        stream = ConcatCtxAndQuestion(stream, config.concat_question_before, ds.reverse_vocab['<SEP>'])

    # Group sort_batch_count batches worth of examples, sort them by sequence
    # length, and split them up again, so that each final batch contains
    # examples of similar length and padding is kept to a minimum.
    stream = Batch(stream, iteration_scheme=ConstantScheme(config.batch_size * config.sort_batch_count))
    comparison = _balanced_batch_helper(stream.sources.index('question' if config.concat_ctx_and_question else 'context'))
    stream = Mapping(stream, SortMapping(comparison))
    stream = Unpack(stream)

    stream = Batch(stream, iteration_scheme=ConstantScheme(config.batch_size))
    stream = Padding(stream, mask_sources=['context', 'question', 'candidates'], mask_dtype='int32')

    return ds, stream

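# After Padding, every masked source present in the stream gains a companion
# '<source>_mask' entry (e.g. 'context' and 'context_mask'), with sequences
# zero-padded to the longest example in each batch.
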
if __name__ == "__main__":
    # Smoke test: build the stream on the CNN training set and print a
    # few batches.
    class DummyConfig:
        def __init__(self):
            self.n_entities = 550  # assumed upper bound on entities per question
            self.shuffle_entities = True
            self.shuffle_questions = False
            self.concat_ctx_and_question = False
            self.concat_question_before = False
            self.batch_size = 2
            self.sort_batch_count = 1000

    ds, stream = setup_datastream(os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/questions/training"),
                                  os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/stats/training/vocab.txt"),
                                  DummyConfig())

    for i, d in enumerate(stream.get_epoch_iterator()):
        print '--'
        print d
        if i > 2: break

# vim: set sts=4 ts=4 sw=4 tw=0 et :