rfc4180.py (3460B)
"""Read the taxi-trip CSV files and expose them as Fuel datasets."""
import ast
import csv
import numpy
import os

from fuel.datasets import Dataset
from fuel.streams import DataStream
from fuel.iterator import DataIterator

import data
from data.hdf5 import origin_call_normalize, taxi_id_normalize


class TaxiData(Dataset):
    """Fuel dataset that streams rows from one or more CSV files, in order.

    Every row is converted to a tuple by applying the constructor of each
    ``(source_name, constructor)`` pair in ``columns`` to the list of raw
    string fields produced by ``csv.reader``.
    """
    example_iteration_scheme = None

    class State(object):
        """Iteration state: the open file, its index in ``pathes``, its reader."""
        # Deriving from object makes __slots__ effective under Python 2
        # (old-style classes silently ignore it).
        __slots__ = ('file', 'index', 'reader')

    def __init__(self, pathes, columns, has_header=False):
        """
        :param pathes: a single CSV path, or a list/tuple of paths read in order.
        :param columns: sequence of ``(source_name, constructor)`` pairs; each
            constructor maps the list of raw CSV fields of a row to one value.
        :param has_header: if true, the first row of every file is skipped.
        :raises ValueError: if no path is given.
        """
        if isinstance(pathes, (list, tuple)):
            pathes = list(pathes)
        else:
            pathes = [pathes]
        if not pathes:
            raise ValueError('TaxiData needs at least one CSV path')
        self.columns = columns
        self.provides_sources = tuple(name for name, _ in columns)
        self.pathes = pathes
        self.has_header = has_header
        super(TaxiData, self).__init__()

    def _open_file(self, state, index):
        """Point ``state`` at file number ``index``, past its header row."""
        state.index = index
        state.file = open(self.pathes[index])
        state.reader = csv.reader(state.file)
        if self.has_header:
            # Default of None: an empty file must not leak StopIteration here;
            # the next read will hit EOF and move on to the following file.
            next(state.reader, None)

    def open(self):
        """Return a new state positioned on the first data row of file 0."""
        state = self.State()
        self._open_file(state, 0)
        return state

    def close(self, state):
        """Close the file currently held by ``state``."""
        state.file.close()

    def reset(self, state):
        """Rewind ``state`` to the first data row of the first file.

        The first file is always reopened so that its header row is skipped
        again, exactly as in :meth:`open`.
        """
        state.file.close()
        self._open_file(state, 0)
        return state

    def get_data(self, state, request=None):
        """Return the next row as a tuple, moving on to the next file at EOF.

        :raises ValueError: if ``request`` is not None (only sequential
            access is supported).
        :raises StopIteration: once every file has been exhausted.
        """
        if request is not None:
            raise ValueError('TaxiData only supports request=None')
        while True:
            try:
                line = next(state.reader)
                break
            except (ValueError, StopIteration):
                # StopIteration: current file exhausted. ValueError: the file
                # is already closed, e.g. when called again after the last
                # file was exhausted, so iteration keeps reporting the end.
                state.file.close()
                state.index += 1
                if state.index >= len(self.pathes):
                    raise StopIteration
                self._open_file(state, state.index)
        return tuple(constructor(line) for _, constructor in self.columns)


# Each constructor receives ``l``, the list of raw string fields of one row.
taxi_columns = [
    ("trip_id", lambda l: l[0]),
    ("call_type", lambda l: ord(l[1]) - ord('A')),
    ("origin_call",
     lambda l: 0 if l[2] in ('', 'NA')
     else origin_call_normalize(int(l[2]))),
    ("origin_stand", lambda l: 0 if l[3] in ('', 'NA') else int(l[3])),
    ("taxi_id", lambda l: taxi_id_normalize(int(l[4]))),
    ("timestamp", lambda l: int(l[5])),
    ("day_type", lambda l: ord(l[6]) - ord('A')),
    ("missing_data", lambda l: l[7].startswith('T')),
    # Field 8 holds a Python-literal list of points; literal_eval (not eval)
    # only accepts literals.
    ("polyline", lambda l: [tuple(p) for p in ast.literal_eval(l[8])]),
    ("longitude", lambda l: [p[0] for p in ast.literal_eval(l[8])]),
    ("latitude", lambda l: [p[1] for p in ast.literal_eval(l[8])]),
]

# The validation file carries three extra trailing columns.
taxi_columns_valid = taxi_columns + [
    ("destination_longitude", lambda l: numpy.float32(float(l[9]))),
    ("destination_latitude", lambda l: numpy.float32(float(l[10]))),
    ("time", lambda l: int(l[11])),
]

train_file = os.path.join(data.path, 'train.csv')
valid_file = os.path.join(data.path, 'valid2-cut.csv')
test_file = os.path.join(data.path, 'test.csv')

train_data = TaxiData(train_file, taxi_columns, has_header=True)
valid_data = TaxiData(valid_file, taxi_columns_valid)
test_data = TaxiData(test_file, taxi_columns, has_header=True)

with open(os.path.join(data.path, 'valid2-cut-ids.txt')) as f:
    # NOTE(review): each entry keeps its trailing newline; confirm callers
    # expect that before comparing against trip_id values.
    valid_trips = list(f)


def train_it():
    """Return a DataIterator over the training set."""
    return DataIterator(DataStream(train_data))


def test_it():
    """Return a DataIterator over ``valid_data``.

    NOTE(review): despite the name, this iterates the validation set, not
    ``test_data`` -- confirm this is intended.
    """
    return DataIterator(DataStream(valid_data))