commit f31caf61be87850f3afcd367d6eb9521b2f613da
Author: Thomas Mesnard <thomas.mesnard@ens.fr>
Date:   Tue,  1 Mar 2016 00:27:15 +0100
Initial commit
Diffstat:
14 files changed, 841 insertions(+), 0 deletions(-)
diff --git a/LICENSE b/LICENSE
@@ -0,0 +1,21 @@
+The MIT License (MIT)
+
+Copyright (c) 2016 Thomas Mesnard
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
@@ -0,0 +1,21 @@
+DeepMind : Teaching Machines to Read and Comprehend
+=========================================
+
+This repository contains an implementation of the two models (the Deep LSTM and the Attentive Reader) described in *Teaching Machines to Read and Comprehend* by Karl Moritz Hermann and al., NIPS, 2015. This repository also contains an implementation of a Deep Bidirectional LSTM. 
+
+Models are implemented using [Theano](https://github.com/Theano/Theano) and [Blocks](https://github.com/mila-udem/blocks). Datasets are implemented using [Fuel](https://github.com/mila-udem/fuel).
+
+The corresponding dataset is provided by [DeepMind](https://github.com/deepmind/rc-data) but if the script does not work you can check [http://cs.nyu.edu/~kcho/DMQA/](http://cs.nyu.edu/~kcho/DMQA/) by [Kyunghyun Cho](http://www.kyunghyuncho.me/).
+
+Reference
+=========
+[Teaching Machines to Read and Comprehend](https://papers.nips.cc/paper/5945-teaching-machines-to-read-and-comprehend.pdf), by Karl Moritz Hermann, Tomáš Kočiský, Edward Grefenstette, Lasse Espeholt, Will Kay, Mustafa Suleyman and Phil Blunsom, Neural Information Processing Systems, 2015.
+
+
+Credits
+=======
+[Thomas Mesnard](https://github.com/thomasmesnard)
+
+[Alex Auvolat](https://github.com/Alexis211)
+
+[Étienne Simon](https://github.com/ejls)+
\ No newline at end of file
diff --git a/__init__.py b/__init__.py
diff --git a/config/__init__.py b/config/__init__.py
diff --git a/config/deep_bidir_lstm_2x128.py b/config/deep_bidir_lstm_2x128.py
@@ -0,0 +1,37 @@
+from blocks.algorithms import BasicMomentum, AdaDelta, RMSProp, Adam, CompositeRule, StepClipping
+from blocks.initialization import IsotropicGaussian, Constant
+from blocks.bricks import Tanh
+
+from model.deep_bidir_lstm import Model
+
+
+batch_size = 32
+sort_batch_count = 20
+
+shuffle_questions = True
+shuffle_entities = True
+
+concat_ctx_and_question = True
+concat_question_before = True		## should not matter for bidirectionnal network
+
+embed_size = 200
+
+lstm_size = [128, 128]
+skip_connections = True
+
+n_entities = 550
+out_mlp_hidden = []
+out_mlp_activations = []
+
+step_rule = CompositeRule([RMSProp(decay_rate=0.95, learning_rate=5e-5),
+                           BasicMomentum(momentum=0.9)])
+
+dropout = 0.1
+w_noise = 0.05
+
+valid_freq = 1000
+save_freq = 1000
+print_freq = 100
+
+weights_init = IsotropicGaussian(0.01)
+biases_init = Constant(0.)
diff --git a/config/deepmind_attentive_reader.py b/config/deepmind_attentive_reader.py
@@ -0,0 +1,42 @@
+from blocks.bricks import Tanh
+from blocks.algorithms import BasicMomentum, AdaDelta, RMSProp, Adam, CompositeRule, StepClipping, Momentum
+from blocks.initialization import IsotropicGaussian, Constant
+
+from model.attentive_reader import Model
+
+
+batch_size = 32
+sort_batch_count = 20
+
+shuffle_questions = True
+
+concat_ctx_and_question = False
+
+n_entities = 550
+embed_size = 200
+
+ctx_lstm_size = [256]
+ctx_skip_connections = True
+
+question_lstm_size = [256]
+question_skip_connections = True
+
+attention_mlp_hidden = [100]
+attention_mlp_activations = [Tanh()]
+
+out_mlp_hidden = []
+out_mlp_activations = []
+
+step_rule = CompositeRule([RMSProp(decay_rate=0.95, learning_rate=5e-5),
+                           BasicMomentum(momentum=0.9)])
+
+dropout = 0.2
+w_noise = 0.
+
+valid_freq = 1000
+save_freq = 1000
+print_freq = 100
+
+weights_init = IsotropicGaussian(0.01)
+biases_init = Constant(0.)
+
diff --git a/config/deepmind_deep_lstm.py b/config/deepmind_deep_lstm.py
@@ -0,0 +1,33 @@
+from blocks.algorithms import BasicMomentum, AdaDelta, RMSProp, Adam, CompositeRule, StepClipping
+from blocks.initialization import IsotropicGaussian, Constant
+
+from model.deep_lstm import Model
+
+
+batch_size = 32
+sort_batch_count = 20
+
+shuffle_questions = True
+
+concat_ctx_and_question = True
+concat_question_before = True
+
+embed_size = 200
+
+lstm_size = [256, 256]
+skip_connections = True
+
+out_mlp_hidden = []
+out_mlp_activations = []
+
+step_rule = CompositeRule([RMSProp(decay_rate=0.95, learning_rate=1e-4),
+                           BasicMomentum(momentum=0.9)])
+
+dropout = 0.1
+
+valid_freq = 1000
+save_freq = 1000
+print_freq = 100
+
+weights_init = IsotropicGaussian(0.01)
+biases_init = Constant(0.)
diff --git a/data.py b/data.py
@@ -0,0 +1,177 @@
+import logging
+import random
+import numpy
+
+import cPickle
+
+from picklable_itertools import iter_
+
+from fuel.datasets import Dataset
+from fuel.streams import DataStream
+from fuel.schemes import IterationScheme, ConstantScheme
+from fuel.transformers import Batch, Mapping, SortMapping, Unpack, Padding, Transformer
+
+import sys
+import os
+
+logging.basicConfig(level='INFO')
+logger = logging.getLogger(__name__)
+
+class QADataset(Dataset):
+    def __init__(self, path, vocab_file, n_entities, need_sep_token, **kwargs):
+        self.provides_sources = ('context', 'question', 'answer', 'candidates')
+
+        self.path = path
+
+        self.vocab = ['@entity%d' % i for i in range(n_entities)] + \
+                     [w.rstrip('\n') for w in open(vocab_file)] + \
+                     ['<UNK>', '@placeholder'] + \
+                     (['<SEP>'] if need_sep_token else [])
+
+        self.n_entities = n_entities
+        self.vocab_size = len(self.vocab)
+        self.reverse_vocab = {w: i for i, w in enumerate(self.vocab)}
+
+        super(QADataset, self).__init__(**kwargs)
+
+    def to_word_id(self, w, cand_mapping):
+        if w in cand_mapping:
+            return cand_mapping[w]
+        elif w[:7] == '@entity':
+            raise ValueError("Unmapped entity token: %s"%w)
+        elif w in self.reverse_vocab:
+            return self.reverse_vocab[w]
+        else:
+            return self.reverse_vocab['<UNK>']
+
+    def to_word_ids(self, s, cand_mapping):
+        return numpy.array([self.to_word_id(x, cand_mapping) for x in s.split(' ')], dtype=numpy.int32)
+
+    def get_data(self, state=None, request=None):
+        if request is None or state is not None:
+            raise ValueError("Expected a request (name of a question file) and no state.")
+
+        lines = [l.rstrip('\n') for l in open(os.path.join(self.path, request))]
+
+        ctx = lines[2]
+        q = lines[4]
+        a = lines[6]
+        cand = [s.split(':')[0] for s in lines[8:]]
+
+        entities = range(self.n_entities)
+        while len(cand) > len(entities):
+            logger.warning("Too many entities (%d) for question: %s, using duplicate entity identifiers"
+                %(len(cand), request))
+            entities = entities + entities
+        random.shuffle(entities)
+        cand_mapping = {t: k for t, k in zip(cand, entities)}
+
+        ctx = self.to_word_ids(ctx, cand_mapping)
+        q = self.to_word_ids(q, cand_mapping)
+        cand = numpy.array([self.to_word_id(x, cand_mapping) for x in cand], dtype=numpy.int32)
+        a = numpy.int32(self.to_word_id(a, cand_mapping))
+
+        if not a < self.n_entities:
+            raise ValueError("Invalid answer token %d"%a)
+        if not numpy.all(cand < self.n_entities):
+            raise ValueError("Invalid candidate in list %s"%repr(cand))
+        if not numpy.all(ctx < self.vocab_size):
+            raise ValueError("Context word id out of bounds: %d"%int(ctx.max()))
+        if not numpy.all(ctx >= 0):
+            raise ValueError("Context word id negative: %d"%int(ctx.min()))
+        if not numpy.all(q < self.vocab_size):
+            raise ValueError("Question word id out of bounds: %d"%int(q.max()))
+        if not numpy.all(q >= 0):
+            raise ValueError("Question word id negative: %d"%int(q.min()))
+
+        return (ctx, q, a, cand)
+
+class QAIterator(IterationScheme):
+    requests_examples = True
+    def __init__(self, path, shuffle=False, **kwargs):
+        self.path = path
+        self.shuffle = shuffle
+
+        super(QAIterator, self).__init__(**kwargs)
+    
+    def get_request_iterator(self):
+        l = [f for f in os.listdir(self.path)
+               if os.path.isfile(os.path.join(self.path, f))]
+        if self.shuffle:
+            random.shuffle(l)
+        return iter_(l)
+
+# -------------- DATASTREAM SETUP --------------------
+
+
+class ConcatCtxAndQuestion(Transformer):
+    produces_examples = True
+    def __init__(self, stream, concat_question_before, separator_token=None, **kwargs):
+        assert stream.sources == ('context', 'question', 'answer', 'candidates')
+        self.sources = ('question', 'answer', 'candidates')
+
+        self.sep = numpy.array([separator_token] if separator_token is not None else [],
+                               dtype=numpy.int32)
+        self.concat_question_before = concat_question_before
+
+        super(ConcatCtxAndQuestion, self).__init__(stream, **kwargs)
+
+    def get_data(self, request=None):
+        if request is not None:
+            raise ValueError('Unsupported: request')
+
+        ctx, q, a, cand = next(self.child_epoch_iterator)
+
+        if self.concat_question_before:
+            return (numpy.concatenate([q, self.sep, ctx]), a, cand)
+        else:
+            return (numpy.concatenate([ctx, self.sep, q]), a, cand)
+        
+class _balanced_batch_helper(object):
+    def __init__(self, key):
+        self.key = key
+    def __call__(self, data):
+        return data[self.key].shape[0]
+
+def setup_datastream(path, vocab_file, config):
+    ds = QADataset(path, vocab_file, config.n_entities, need_sep_token=config.concat_ctx_and_question)
+    it = QAIterator(path, shuffle=config.shuffle_questions)
+
+    stream = DataStream(ds, iteration_scheme=it)
+
+    if config.concat_ctx_and_question:
+        stream = ConcatCtxAndQuestion(stream, config.concat_question_before, ds.reverse_vocab['<SEP>'])
+
+    # Sort sets of multiple batches to make batches of similar sizes
+    stream = Batch(stream, iteration_scheme=ConstantScheme(config.batch_size * config.sort_batch_count))
+    comparison = _balanced_batch_helper(stream.sources.index('question' if config.concat_ctx_and_question else 'context'))
+    stream = Mapping(stream, SortMapping(comparison))
+    stream = Unpack(stream)
+
+    stream = Batch(stream, iteration_scheme=ConstantScheme(config.batch_size))
+    stream = Padding(stream, mask_sources=['context', 'question', 'candidates'], mask_dtype='int32')
+
+    return ds, stream
+
+if __name__ == "__main__":
+    # Test
+    class DummyConfig:
+        def __init__(self):
+            self.shuffle_entities = True
+            self.shuffle_questions = False
+            self.concat_ctx_and_question = False
+            self.concat_question_before = False
+            self.batch_size = 2
+            self.sort_batch_count = 1000
+
+    ds, stream = setup_datastream(os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/questions/training"),
+                                  os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/stats/training/vocab.txt"),
+                                  DummyConfig())
+    it = stream.get_epoch_iterator()
+
+    for i, d in enumerate(stream.get_epoch_iterator()):
+        print '--'
+        print d
+        if i > 2: break
+
+# vim: set sts=4 ts=4 sw=4 tw=0 et :
diff --git a/model/__init__.py b/model/__init__.py
diff --git a/model/attentive_reader.py b/model/attentive_reader.py
@@ -0,0 +1,152 @@
+import theano
+from theano import tensor
+import numpy
+
+from blocks.bricks import Tanh, Softmax, Linear, MLP, Identity, Rectifier
+from blocks.bricks.lookup import LookupTable
+from blocks.bricks.recurrent import LSTM
+
+from blocks.filter import VariableFilter
+from blocks.roles import WEIGHT
+from blocks.graph import ComputationGraph, apply_dropout, apply_noise
+
+def make_bidir_lstm_stack(seq, seq_dim, mask, sizes, skip=True, name=''):
+    bricks = []
+
+    curr_dim = [seq_dim]
+    curr_hidden = [seq]
+
+    hidden_list = []
+    for k, dim in enumerate(sizes):
+        fwd_lstm_ins = [Linear(input_dim=d, output_dim=4*dim, name='%s_fwd_lstm_in_%d_%d'%(name,k,l)) for l, d in enumerate(curr_dim)]
+        fwd_lstm = LSTM(dim=dim, activation=Tanh(), name='%s_fwd_lstm_%d'%(name,k))
+
+        bwd_lstm_ins = [Linear(input_dim=d, output_dim=4*dim, name='%s_bwd_lstm_in_%d_%d'%(name,k,l)) for l, d in enumerate(curr_dim)]
+        bwd_lstm = LSTM(dim=dim, activation=Tanh(), name='%s_bwd_lstm_%d'%(name,k))
+
+        bricks = bricks + [fwd_lstm, bwd_lstm] + fwd_lstm_ins + bwd_lstm_ins
+
+        fwd_tmp = sum(x.apply(v) for x, v in zip(fwd_lstm_ins, curr_hidden))
+        bwd_tmp = sum(x.apply(v) for x, v in zip(bwd_lstm_ins, curr_hidden))
+        fwd_hidden, _ = fwd_lstm.apply(fwd_tmp, mask=mask)
+        bwd_hidden, _ = bwd_lstm.apply(bwd_tmp[::-1], mask=mask[::-1])
+        hidden_list = hidden_list + [fwd_hidden, bwd_hidden]
+        if skip:
+            curr_hidden = [seq, fwd_hidden, bwd_hidden[::-1]]
+            curr_dim = [seq_dim, dim, dim]
+        else:
+            curr_hidden = [fwd_hidden, bwd_hidden[::-1]]
+            curr_dim = [dim, dim]
+
+    return bricks, hidden_list
+
+class Model():
+    def __init__(self, config, vocab_size):
+        question = tensor.imatrix('question')
+        question_mask = tensor.imatrix('question_mask')
+        context = tensor.imatrix('context')
+        context_mask = tensor.imatrix('context_mask')
+        answer = tensor.ivector('answer')
+        candidates = tensor.imatrix('candidates')
+        candidates_mask = tensor.imatrix('candidates_mask')
+
+        bricks = []
+
+        question = question.dimshuffle(1, 0)
+        question_mask = question_mask.dimshuffle(1, 0)
+        context = context.dimshuffle(1, 0)
+        context_mask = context_mask.dimshuffle(1, 0)
+
+        # Embed questions and cntext
+        embed = LookupTable(vocab_size, config.embed_size, name='question_embed')
+        bricks.append(embed)
+
+        qembed = embed.apply(question)
+        cembed = embed.apply(context)
+
+        qlstms, qhidden_list = make_bidir_lstm_stack(qembed, config.embed_size, question_mask.astype(theano.config.floatX),
+                                                     config.question_lstm_size, config.question_skip_connections, 'q')
+        clstms, chidden_list = make_bidir_lstm_stack(cembed, config.embed_size, context_mask.astype(theano.config.floatX),
+                                                     config.ctx_lstm_size, config.ctx_skip_connections, 'ctx')
+        bricks = bricks + qlstms + clstms
+
+        # Calculate question encoding (concatenate layer1)
+        if config.question_skip_connections:
+            qenc_dim = 2*sum(config.question_lstm_size)
+            qenc = tensor.concatenate([h[-1,:,:] for h in qhidden_list], axis=1)
+        else:
+            qenc_dim = 2*config.question_lstm_size[-1]
+            qenc = tensor.concatenate([h[-1,:,:] for h in qhidden_list[-2:]], axis=1)
+        qenc.name = 'qenc'
+
+        # Calculate context encoding (concatenate layer1)
+        if config.ctx_skip_connections:
+            cenc_dim = 2*sum(config.ctx_lstm_size)
+            cenc = tensor.concatenate(chidden_list, axis=2)
+        else:
+            cenc_dim = 2*config.ctx_lstm_size[-1]
+            cenc = tensor.concatenate(chidden_list[-2:], axis=2)
+        cenc.name = 'cenc'
+
+        # Attention mechanism MLP
+        attention_mlp = MLP(dims=config.attention_mlp_hidden + [1],
+                            activations=config.attention_mlp_activations[1:] + [Identity()],
+                            name='attention_mlp')
+        attention_qlinear = Linear(input_dim=qenc_dim, output_dim=config.attention_mlp_hidden[0], name='attq')
+        attention_clinear = Linear(input_dim=cenc_dim, output_dim=config.attention_mlp_hidden[0], use_bias=False, name='attc')
+        bricks += [attention_mlp, attention_qlinear, attention_clinear]
+        layer1 = Tanh().apply(attention_clinear.apply(cenc.reshape((cenc.shape[0]*cenc.shape[1], cenc.shape[2])))
+                                        .reshape((cenc.shape[0],cenc.shape[1],config.attention_mlp_hidden[0]))
+                             + attention_qlinear.apply(qenc)[None, :, :])
+        layer1.name = 'layer1'
+        att_weights = attention_mlp.apply(layer1.reshape((layer1.shape[0]*layer1.shape[1], layer1.shape[2])))
+        att_weights.name = 'att_weights_0'
+        att_weights = att_weights.reshape((layer1.shape[0], layer1.shape[1]))
+        att_weights.name = 'att_weights'
+
+        attended = tensor.sum(cenc * tensor.nnet.softmax(att_weights.T).T[:, :, None], axis=0)
+        attended.name = 'attended'
+
+        # Now we can calculate our output
+        out_mlp = MLP(dims=[cenc_dim + qenc_dim] + config.out_mlp_hidden + [config.n_entities],
+                      activations=config.out_mlp_activations + [Identity()],
+                      name='out_mlp')
+        bricks += [out_mlp]
+        probs = out_mlp.apply(tensor.concatenate([attended, qenc], axis=1))
+        probs.name = 'probs'
+
+        is_candidate = tensor.eq(tensor.arange(config.n_entities, dtype='int32')[None, None, :],
+                                 tensor.switch(candidates_mask, candidates, -tensor.ones_like(candidates))[:, :, None]).sum(axis=1)
+        probs = tensor.switch(is_candidate, probs, -1000 * tensor.ones_like(probs))
+
+        # Calculate prediction, cost and error rate
+        pred = probs.argmax(axis=1)
+        cost = Softmax().categorical_cross_entropy(answer, probs).mean()
+        error_rate = tensor.neq(answer, pred).mean()
+
+        # Apply dropout
+        cg = ComputationGraph([cost, error_rate])
+        if config.w_noise > 0:
+            noise_vars = VariableFilter(roles=[WEIGHT])(cg)
+            cg = apply_noise(cg, noise_vars, config.w_noise)
+        if config.dropout > 0:
+            cg = apply_dropout(cg, qhidden_list + chidden_list, config.dropout)
+        [cost_reg, error_rate_reg] = cg.outputs
+
+        # Other stuff
+        cost_reg.name = cost.name = 'cost'
+        error_rate_reg.name = error_rate.name = 'error_rate'
+
+        self.sgd_cost = cost_reg
+        self.monitor_vars = [[cost_reg], [error_rate_reg]]
+        self.monitor_vars_valid = [[cost], [error_rate]]
+
+        # Initialize bricks
+        for brick in bricks:
+            brick.weights_init = config.weights_init
+            brick.biases_init = config.biases_init
+            brick.initialize()
+
+        
+
+#  vim: set sts=4 ts=4 sw=4 tw=0 et :
diff --git a/model/deep_bidir_lstm.py b/model/deep_bidir_lstm.py
@@ -0,0 +1,109 @@
+import theano
+from theano import tensor
+import numpy
+
+from blocks.bricks import Tanh, Softmax, Linear, MLP, Identity, Rectifier
+from blocks.bricks.lookup import LookupTable
+from blocks.bricks.recurrent import LSTM
+
+from blocks.filter import VariableFilter
+from blocks.roles import WEIGHT
+from blocks.graph import ComputationGraph, apply_dropout, apply_noise
+
+class Model():
+    def __init__(self, config, vocab_size):
+        question = tensor.imatrix('question')
+        question_mask = tensor.imatrix('question_mask')
+        answer = tensor.ivector('answer')
+        candidates = tensor.imatrix('candidates')
+        candidates_mask = tensor.imatrix('candidates_mask')
+
+        bricks = []
+
+
+        # set time as first dimension
+        question = question.dimshuffle(1, 0)
+        question_mask = question_mask.dimshuffle(1, 0)
+
+        # Embed questions
+        embed = LookupTable(vocab_size, config.embed_size, name='question_embed')
+        bricks.append(embed)
+        qembed = embed.apply(question)
+
+        # Create and apply LSTM stack
+        curr_dim = [config.embed_size]
+        curr_hidden = [qembed]
+
+        hidden_list = []
+        for k, dim in enumerate(config.lstm_size):
+            fwd_lstm_ins = [Linear(input_dim=d, output_dim=4*dim, name='fwd_lstm_in_%d_%d'%(k,l)) for l, d in enumerate(curr_dim)]
+            fwd_lstm = LSTM(dim=dim, activation=Tanh(), name='fwd_lstm_%d'%k)
+
+            bwd_lstm_ins = [Linear(input_dim=d, output_dim=4*dim, name='bwd_lstm_in_%d_%d'%(k,l)) for l, d in enumerate(curr_dim)]
+            bwd_lstm = LSTM(dim=dim, activation=Tanh(), name='bwd_lstm_%d'%k)
+
+            bricks = bricks + [fwd_lstm, bwd_lstm] + fwd_lstm_ins + bwd_lstm_ins
+
+            fwd_tmp = sum(x.apply(v) for x, v in zip(fwd_lstm_ins, curr_hidden))
+            bwd_tmp = sum(x.apply(v) for x, v in zip(bwd_lstm_ins, curr_hidden))
+            fwd_hidden, _ = fwd_lstm.apply(fwd_tmp, mask=question_mask.astype(theano.config.floatX))
+            bwd_hidden, _ = bwd_lstm.apply(bwd_tmp[::-1], mask=question_mask.astype(theano.config.floatX)[::-1])
+            hidden_list = hidden_list + [fwd_hidden, bwd_hidden]
+            if config.skip_connections:
+                curr_hidden = [qembed, fwd_hidden, bwd_hidden[::-1]]
+                curr_dim = [config.embed_size, dim, dim]
+            else:
+                curr_hidden = [fwd_hidden, bwd_hidden[::-1]]
+                curr_dim = [dim, dim]
+
+        # Create and apply output MLP
+        if config.skip_connections:
+            out_mlp = MLP(dims=[2*sum(config.lstm_size)] + config.out_mlp_hidden + [config.n_entities],
+                          activations=config.out_mlp_activations + [Identity()],
+                          name='out_mlp')
+            bricks.append(out_mlp)
+
+            probs = out_mlp.apply(tensor.concatenate([h[-1,:,:] for h in hidden_list], axis=1))
+        else:
+            out_mlp = MLP(dims=[2*config.lstm_size[-1]] + config.out_mlp_hidden + [config.n_entities],
+                          activations=config.out_mlp_activations + [Identity()],
+                          name='out_mlp')
+            bricks.append(out_mlp)
+
+            probs = out_mlp.apply(tensor.concatenate([h[-1,:,:] for h in hidden_list[-2:]], axis=1))
+
+        is_candidate = tensor.eq(tensor.arange(config.n_entities, dtype='int32')[None, None, :],
+                                 tensor.switch(candidates_mask, candidates, -tensor.ones_like(candidates))[:, :, None]).sum(axis=1)
+        probs = tensor.switch(is_candidate, probs, -1000 * tensor.ones_like(probs))
+
+        # Calculate prediction, cost and error rate
+        pred = probs.argmax(axis=1)
+        cost = Softmax().categorical_cross_entropy(answer, probs).mean()
+        error_rate = tensor.neq(answer, pred).mean()
+
+        # Apply dropout
+        cg = ComputationGraph([cost, error_rate])
+        if config.w_noise > 0:
+            noise_vars = VariableFilter(roles=[WEIGHT])(cg)
+            cg = apply_noise(cg, noise_vars, config.w_noise)
+        if config.dropout > 0:
+            cg = apply_dropout(cg, hidden_list, config.dropout)
+        [cost_reg, error_rate_reg] = cg.outputs
+
+        # Other stuff
+        cost_reg.name = cost.name = 'cost'
+        error_rate_reg.name = error_rate.name = 'error_rate'
+
+        self.sgd_cost = cost_reg
+        self.monitor_vars = [[cost_reg], [error_rate_reg]]
+        self.monitor_vars_valid = [[cost], [error_rate]]
+
+        # Initialize bricks
+        for brick in bricks:
+            brick.weights_init = config.weights_init
+            brick.biases_init = config.biases_init
+            brick.initialize()
+
+        
+
+#  vim: set sts=4 ts=4 sw=4 tw=0 et :
diff --git a/model/deep_lstm.py b/model/deep_lstm.py
@@ -0,0 +1,99 @@
+import theano
+from theano import tensor
+import numpy
+
+from blocks.bricks import Tanh, Softmax, Linear, MLP, Identity, Rectifier
+from blocks.bricks.lookup import LookupTable
+from blocks.bricks.recurrent import LSTM
+
+from blocks.graph import ComputationGraph, apply_dropout
+
+
+class Model():
+    def __init__(self, config, vocab_size):
+        question = tensor.imatrix('question')
+        question_mask = tensor.imatrix('question_mask')
+        answer = tensor.ivector('answer')
+        candidates = tensor.imatrix('candidates')
+        candidates_mask = tensor.imatrix('candidates_mask')
+
+        bricks = []
+
+
+        # set time as first dimension
+        question = question.dimshuffle(1, 0)
+        question_mask = question_mask.dimshuffle(1, 0)
+
+        # Embed questions
+        embed = LookupTable(vocab_size, config.embed_size, name='question_embed')
+        bricks.append(embed)
+        qembed = embed.apply(question)
+
+        # Create and apply LSTM stack
+        curr_dim = config.embed_size
+        curr_hidden = qembed
+
+        hidden_list = []
+        for k, dim in enumerate(config.lstm_size):
+            lstm_in = Linear(input_dim=curr_dim, output_dim=4*dim, name='lstm_in_%d'%k)
+            lstm = LSTM(dim=dim, activation=Tanh(), name='lstm_%d'%k)
+            bricks = bricks + [lstm_in, lstm]
+
+            tmp = lstm_in.apply(curr_hidden)
+            hidden, _ = lstm.apply(tmp, mask=question_mask.astype(theano.config.floatX))
+            hidden_list.append(hidden)
+            if config.skip_connections:
+                curr_hidden = tensor.concatenate([hidden, qembed], axis=2)
+                curr_dim =  dim + config.embed_size
+            else:
+                curr_hidden = hidden
+                curr_dim = dim
+
+        # Create and apply output MLP
+        if config.skip_connections:
+            out_mlp = MLP(dims=[sum(config.lstm_size)] + config.out_mlp_hidden + [config.n_entities],
+                          activations=config.out_mlp_activations + [Identity()],
+                          name='out_mlp')
+            bricks.append(out_mlp)
+
+            probs = out_mlp.apply(tensor.concatenate([h[-1,:,:] for h in hidden_list], axis=1))
+        else:
+            out_mlp = MLP(dims=[config.lstm_size[-1]] + config.out_mlp_hidden + [config.n_entities],
+                          activations=config.out_mlp_activations + [Identity()],
+                          name='out_mlp')
+            bricks.append(out_mlp)
+
+            probs = out_mlp.apply(hidden_list[-1][-1,:,:])
+
+        is_candidate = tensor.eq(tensor.arange(config.n_entities, dtype='int32')[None, None, :],
+                                 tensor.switch(candidates_mask, candidates, -tensor.ones_like(candidates))[:, :, None]).sum(axis=1)
+        probs = tensor.switch(is_candidate, probs, -1000 * tensor.ones_like(probs))
+
+        # Calculate prediction, cost and error rate
+        pred = probs.argmax(axis=1)
+        cost = Softmax().categorical_cross_entropy(answer, probs).mean()
+        error_rate = tensor.neq(answer, pred).mean()
+
+        # Apply dropout
+        cg = ComputationGraph([cost, error_rate])
+        if config.dropout > 0:
+            cg = apply_dropout(cg, hidden_list, config.dropout)
+        [cost_reg, error_rate_reg] = cg.outputs
+
+        # Other stuff
+        cost_reg.name = cost.name = 'cost'
+        error_rate_reg.name = error_rate.name = 'error_rate'
+
+        self.sgd_cost = cost_reg
+        self.monitor_vars = [[cost_reg], [error_rate_reg]]
+        self.monitor_vars_valid = [[cost], [error_rate]]
+
+        # Initialize bricks
+        for brick in bricks:
+            brick.weights_init = config.weights_init
+            brick.biases_init = config.biases_init
+            brick.initialize()
+
+        
+
+#  vim: set sts=4 ts=4 sw=4 tw=0 et :
diff --git a/paramsaveload.py b/paramsaveload.py
@@ -0,0 +1,37 @@
+import logging
+
+import numpy
+
+import cPickle
+
+from blocks.extensions import SimpleExtension
+
+logging.basicConfig(level='INFO')
+logger = logging.getLogger('extensions.SaveLoadParams')
+
+class SaveLoadParams(SimpleExtension):
+	def __init__(self, path, model, **kwargs):
+		super(SaveLoadParams, self).__init__(**kwargs)
+
+		self.path = path
+		self.model = model
+	
+	def do_save(self):
+		with open(self.path, 'w') as f:
+			logger.info('Saving parameters to %s...'%self.path)
+			cPickle.dump(self.model.get_parameter_values(), f, protocol=cPickle.HIGHEST_PROTOCOL)
+	
+	def do_load(self):
+		try:
+			with open(self.path, 'r') as f:
+				logger.info('Loading parameters from %s...'%self.path)
+				self.model.set_parameter_values(cPickle.load(f))
+		except IOError:
+			pass
+
+	def do(self, which_callback, *args):
+		if which_callback == 'before_training':
+			self.do_load()
+		else:
+			self.do_save()
+		
diff --git a/train.py b/train.py
@@ -0,0 +1,112 @@
+#!/usr/bin/env python
+
+import logging
+import numpy
+import sys
+import os
+import importlib
+
+import theano
+from theano import tensor
+
+from blocks.extensions import Printing, SimpleExtension, FinishAfter, ProgressBar
+from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
+from blocks.graph import ComputationGraph
+from blocks.main_loop import MainLoop
+from blocks.model import Model
+from blocks.algorithms import GradientDescent
+
+try:
+    from blocks.extras.extensions.plot import Plot
+    plot_avail = True
+except ImportError:
+    plot_avail = False
+    print "No plotting extension available."
+
+import data
+from paramsaveload import SaveLoadParams
+
+logging.basicConfig(level='INFO')
+logger = logging.getLogger(__name__)
+
+sys.setrecursionlimit(500000)
+
+if __name__ == "__main__":
+    if len(sys.argv) != 2:
+        print >> sys.stderr, 'Usage: %s config' % sys.argv[0]
+        sys.exit(1)
+    model_name = sys.argv[1]
+    config = importlib.import_module('.%s' % model_name, 'config')
+
+    # Build datastream
+    path = os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/questions/training")
+    valid_path = os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/questions/validation")
+    vocab_path = os.path.join(os.getenv("DATAPATH"), "deepmind-qa/cnn/stats/training/vocab.txt")
+
+    ds, train_stream = data.setup_datastream(path, vocab_path, config)
+    _, valid_stream = data.setup_datastream(valid_path, vocab_path, config)
+
+    dump_path = os.path.join("model_params", model_name+".pkl")
+
+    # Build model
+    m = config.Model(config, ds.vocab_size)
+
+    # Build the Blocks stuff for training
+    model = Model(m.sgd_cost)
+
+    algorithm = GradientDescent(cost=m.sgd_cost,
+                                step_rule=config.step_rule,
+                                parameters=model.parameters)
+
+    extensions = [
+            TrainingDataMonitoring(
+                [v for l in m.monitor_vars for v in l],
+                prefix='train',
+                every_n_batches=config.print_freq)
+    ]
+    if config.save_freq is not None and dump_path is not None:
+        extensions += [
+            SaveLoadParams(path=dump_path,
+                           model=model,
+                           before_training=True,
+                           after_training=True,
+                           after_epoch=True,
+                           every_n_batches=config.save_freq)
+        ]
+    if valid_stream is not None and config.valid_freq != -1:
+        extensions += [
+            DataStreamMonitoring(
+                [v for l in m.monitor_vars_valid for v in l],
+                valid_stream,
+                prefix='valid',
+                every_n_batches=config.valid_freq),
+        ]
+    if plot_avail:
+        plot_channels = [['train_' + v.name for v in lt] + ['valid_' + v.name for v in lv]
+                         for lt, lv in zip(m.monitor_vars, m.monitor_vars_valid)]
+        extensions += [
+            Plot(document='deepmind_qa_'+model_name,
+                 channels=plot_channels,
+                 # server_url='http://localhost:5006/', # If you need, change this
+                 every_n_batches=config.print_freq)
+        ]
+    extensions += [
+            Printing(every_n_batches=config.print_freq,
+                     after_epoch=True),
+            ProgressBar()
+    ]
+
+    main_loop = MainLoop(
+        model=model,
+        data_stream=train_stream,
+        algorithm=algorithm,
+        extensions=extensions
+    )
+
+    # Run the model !
+    main_loop.run()
+    main_loop.profile.report()
+
+
+
+#  vim: set sts=4 ts=4 sw=4 tw=0 et :