hdf5.py (2300B)
import os

import h5py
from fuel.datasets import H5PYDataset
from fuel.iterator import DataIterator
from fuel.schemes import SequentialExampleScheme
from fuel.streams import DataStream

import data


class TaxiDataset(H5PYDataset):
    """H5PYDataset whose HDF5 file is looked up under ``data.path``.

    Parameters
    ----------
    which_set : str
        Name of the split to load; forwarded to ``H5PYDataset`` as the
        one-element tuple ``(which_set,)``.
    filename : str
        File name, relative to ``data.path``.
    **kwargs
        Forwarded to ``H5PYDataset``. ``load_in_memory`` defaults to
        ``True`` unless the caller sets it explicitly.
    """

    def __init__(self, which_set, filename='data.hdf5', **kwargs):
        # `filename` must be set before calling the parent constructor
        # because the `data_path` property reads it.
        self.filename = filename
        kwargs.setdefault('load_in_memory', True)
        super(TaxiDataset, self).__init__(self.data_path, (which_set,), **kwargs)

    @property
    def data_path(self):
        """Full path of the HDF5 file (``data.path`` joined with ``filename``)."""
        return os.path.join(data.path, self.filename)

    def extract(self, request):
        """Return the requested data as a dict keyed by source name.

        ``request`` is passed unchanged to ``get_data``.

        Raises
        ------
        ValueError
            If the dataset was not loaded in memory.
        """
        if not self.load_in_memory:
            raise ValueError('extract called on a dataset not loaded in memory')
        return dict(zip(self.sources, self.get_data(None, request)))


class TaxiStream(DataStream):
    """DataStream over a :class:`TaxiDataset`.

    When no ``iteration_scheme`` is given, examples are visited one by one
    with a ``SequentialExampleScheme`` over ``dataset.num_examples``.
    Extra keyword arguments are forwarded to :class:`TaxiDataset`.
    """

    def __init__(self, which_set, filename='data.hdf5', iteration_scheme=None, **kwargs):
        dataset = TaxiDataset(which_set, filename, **kwargs)
        if iteration_scheme is None:
            iteration_scheme = SequentialExampleScheme(dataset.num_examples)
        super(TaxiStream, self).__init__(dataset, iteration_scheme=iteration_scheme)


def _read_unique_values(name):
    """Read the whole HDF5 dataset ``name`` from ``data.hdf5`` into memory.

    The file is closed before returning, so no handle is kept open for the
    lifetime of the process.
    """
    with h5py.File(os.path.join(data.path, 'data.hdf5'), 'r') as hdf5_file:
        return hdf5_file[name][...]


# Lazily filled caches: index -> original value, and original value -> index.
_origin_calls = None
_reverse_origin_calls = None


def _get_origin_calls():
    """Return the cached ``unique_origin_call`` table, loading it on first use."""
    # `global` is required: without it the assignment below would make the
    # name local and the `is None` check would raise UnboundLocalError.
    global _origin_calls
    if _origin_calls is None:
        _origin_calls = _read_unique_values('unique_origin_call')
    return _origin_calls


def origin_call_unnormalize(x):
    """Map a normalized index ``x`` to its entry in ``unique_origin_call``."""
    return _get_origin_calls()[x]


def origin_call_normalize(x):
    """Map a ``unique_origin_call`` value ``x`` back to its index.

    Raises ``KeyError`` if ``x`` is not present in the table.
    """
    global _reverse_origin_calls
    if _reverse_origin_calls is None:
        _reverse_origin_calls = {
            value: index for index, value in enumerate(_get_origin_calls())
        }
    return _reverse_origin_calls[x]


# Same lazy caches for taxi ids.
_taxi_ids = None
_reverse_taxi_ids = None


def _get_taxi_ids():
    """Return the cached ``unique_taxi_id`` table, loading it on first use."""
    global _taxi_ids
    if _taxi_ids is None:
        _taxi_ids = _read_unique_values('unique_taxi_id')
    return _taxi_ids


def taxi_id_unnormalize(x):
    """Map a normalized index ``x`` to its entry in ``unique_taxi_id``."""
    return _get_taxi_ids()[x]


def taxi_id_normalize(x):
    """Map a ``unique_taxi_id`` value ``x`` back to its index.

    Raises ``KeyError`` if ``x`` is not present in the table.
    """
    global _reverse_taxi_ids
    if _reverse_taxi_ids is None:
        _reverse_taxi_ids = {
            value: index for index, value in enumerate(_get_taxi_ids())
        }
    return _reverse_taxi_ids[x]


def taxi_it(which_set, filename='data.hdf5', sub=None, as_dict=True):
    """Return a ``DataIterator`` over the examples of ``which_set``.

    ``sub`` is an iterable of requests handed to the iterator; when omitted,
    every index from 0 to ``dataset.num_examples - 1`` is used in order.
    ``as_dict`` is forwarded to ``DataIterator``.
    """
    dataset = TaxiDataset(which_set, filename)
    if sub is None:
        # `range` rather than `xrange` so this also runs on Python 3.
        sub = range(dataset.num_examples)
    return DataIterator(DataStream(dataset), iter(sub), as_dict)