taxi

Winning entry to the Kaggle taxi competition
git clone https://esimon.eu/repos/taxi.git
Log | Files | Refs | README

hdf5.py (2300B)


      1 import os
      2 
      3 import h5py
      4 from fuel.datasets import H5PYDataset
      5 from fuel.iterator import DataIterator
      6 from fuel.schemes import SequentialExampleScheme
      7 from fuel.streams import DataStream
      8 
      9 import data
     10 
     11 
     12 class TaxiDataset(H5PYDataset):
     13     def __init__(self, which_set, filename='data.hdf5', **kwargs):
     14         self.filename = filename
     15         kwargs.setdefault('load_in_memory', True)
     16         super(TaxiDataset, self).__init__(self.data_path, (which_set,), **kwargs)
     17 
     18     @property
     19     def data_path(self):
     20         return os.path.join(data.path, self.filename)
     21 
     22     def extract(self, request):
     23         if not self.load_in_memory:
     24             raise ValueError('extract called on a dataset not loaded in memory')
     25         return dict(zip(self.sources, self.get_data(None, request)))
     26 
     27 class TaxiStream(DataStream):
     28     def __init__(self, which_set, filename='data.hdf5', iteration_scheme=None, **kwargs):
     29         dataset = TaxiDataset(which_set, filename, **kwargs)
     30         if iteration_scheme is None:
     31             iteration_scheme = SequentialExampleScheme(dataset.num_examples)
     32         super(TaxiStream, self).__init__(dataset, iteration_scheme=iteration_scheme)
     33 
     34 _origin_calls = None
     35 _reverse_origin_calls = None
     36 
     37 def origin_call_unnormalize(x):
     38     if _origin_calls is None:
     39         _origin_calls = h5py.File(os.path.join(data.path, 'data.hdf5'), 'r')['unique_origin_call']
     40     return _origin_calls[x]
     41 
     42 def origin_call_normalize(x):
     43     if _reverse_origin_calls is None:
     44         origin_call_unnormalize(0)
     45         _reverse_origin_calls = { _origin_calls[i]: i for i in range(_origin_calls.shape[0]) }
     46     return _reverse_origin_calls[x]
     47 
     48 _taxi_ids = None
     49 _reverse_taxi_ids = None
     50 
     51 def taxi_id_unnormalize(x):
     52     if _taxi_ids is None:
     53         _taxi_ids = h5py.File(os.path.join(data.path, 'data.hdf5'), 'r')['unique_taxi_id']
     54     return _taxi_ids[x]
     55 
     56 def taxi_id_normalize(x):
     57     if _reverse_taxi_ids is None:
     58         taxi_id_unnormalize(0)
     59         _reverse_taxi_ids = { _taxi_ids[i]: i for i in range(_taxi_ids.shape[0]) }
     60     return _reverse_taxi_ids[x]
     61 
     62 def taxi_it(which_set, filename='data.hdf5', sub=None, as_dict=True):
     63     dataset = TaxiDataset(which_set, filename)
     64     if sub is None:
     65         sub = xrange(dataset.num_examples)
     66     return DataIterator(DataStream(dataset), iter(sub), as_dict)