taxi

Winning entry to the Kaggle taxi competition
git clone https://esimon.eu/repos/taxi.git
Log | Files | Refs | README

__init__.py (1267B)


      1 from blocks.bricks import application, Initializable
      2 from blocks.bricks.lookup import LookupTable
      3 
      4 
      5 class ContextEmbedder(Initializable):
      6     def __init__(self, config, **kwargs):
      7         super(ContextEmbedder, self).__init__(**kwargs)
      8         self.dim_embeddings = config.dim_embeddings
      9         self.embed_weights_init = config.embed_weights_init
     10 
     11         self.inputs = [ name for (name, _, _) in self.dim_embeddings ]
     12         self.outputs = [ '%s_embedded' % name for name in self.inputs ]
     13 
     14         self.lookups = { name: LookupTable(name='%s_lookup' % name) for name in self.inputs }
     15         self.children = self.lookups.values()
     16 
     17     def _push_allocation_config(self):
     18         for (name, num, dim) in self.dim_embeddings:
     19             self.lookups[name].length = num
     20             self.lookups[name].dim = dim
     21 
     22     def _push_initialization_config(self):
     23         for name in self.inputs:
     24             self.lookups[name].weights_init = self.embed_weights_init
     25 
     26     @application
     27     def apply(self, **kwargs):
     28         return tuple(self.lookups[name].apply(kwargs[name]) for name in self.inputs)
     29 
     30     @apply.property('inputs')
     31     def apply_inputs(self):
     32         return self.inputs
     33 
     34     @apply.property('outputs')
     35     def apply_outputs(self):
     36         return self.outputs