make_tvt.py (6962B)
#!/usr/bin/env python2
# Separate the training set into a Training Valid and Test set
#
# Reads the full taxi training set, simulates "in progress" trips by cutting
# trajectories at a list of cut timestamps, and writes one HDF5 file whose
# rows are laid out [train | valid | test], with split metadata stored in
# the file's 'split' attribute (fuel H5PYDataset-style split array).

import os
import sys
import importlib
import cPickle

import h5py
import numpy
import theano

import data
from data.hdf5 import TaxiDataset
from error import hdist


# dtypes of the fields copied verbatim from the source dataset.
# data.Polyline is a project-declared dtype for variable-length GPS traces
# -- see the data package for its definition.
native_fields = {
    'trip_id': 'S19',
    'call_type': numpy.int8,
    'origin_call': numpy.int32,
    'origin_stand': numpy.int8,
    'taxi_id': numpy.int16,
    'timestamp': numpy.int32,
    'day_type': numpy.int8,
    'missing_data': numpy.bool,
    'latitude': data.Polyline,
    'longitude': data.Polyline,
}

# additional fields computed by this script and stored alongside the native ones
all_fields = {
    'path_len': numpy.int16,
    'cluster': numpy.int16,
    'destination_latitude': numpy.float32,
    'destination_longitude': numpy.float32,
    'travel_time': numpy.int32,
}

all_fields.update(native_fields)

def cut_me_baby(train, cuts, excl={}):
    """Select the trips that are in progress at one of the cut timestamps.

    :param train: dict of column arrays sorted by ascending 'timestamp'
        (the caller sorts; the two-pointer scan below relies on that order).
    :param cuts: list of cut timestamps; sorted in place here.
    :param excl: trip indices to skip (e.g. trips already claimed by the
        test split).  NOTE(review): mutable default argument -- harmless
        here because it is only read, never mutated.
    :return: dict mapping trip index -> number of GPS points recorded up to
        the cut time, i.e. the prefix length to keep when truncating.
    """
    dset = {}
    cuts.sort()
    cut_id = 0
    for i in xrange(data.train_size):
        if i%10000==0 and i!=0:
            print >> sys.stderr, 'cut: {:d} done'.format(i)
        if i in excl:
            continue
        time = train['timestamp'][i]
        latitude = train['latitude'][i]
        longitude = train['longitude'][i]

        # trips with no GPS points cannot be cut
        if len(latitude) == 0:
            continue

        # points are sampled every 15 seconds, so the trip ends
        # (len - 1) * 15 seconds after its start timestamp
        end_time = time + 15 * (len(latitude) - 1)

        # advance to the first cut not strictly before this trip's start;
        # trips arrive in timestamp order, so cut_id only moves forward.
        # NOTE(review): assumes cuts is non-empty, else IndexError here.
        while cuts[cut_id] < time:
            if cut_id >= len(cuts)-1:
                # no remaining cut can intersect any later trip
                return dset
            cut_id += 1

        if end_time < cuts[cut_id]:
            # trip already finished at the cut time: not in progress
            continue
        else:
            # number of points observed by the cut time
            # (Python 2 integer division on int timestamps)
            dset[i] = (cuts[cut_id] - time) / 15 + 1

    return dset

def make_tvt(test_cuts_name, valid_cuts_name, outpath):
    """Build the train/valid/test HDF5 file at ``outpath``.

    :param test_cuts_name: module name under ``data.cuts`` exposing the
        ``cuts`` timestamp list used to carve out the test split.
    :param valid_cuts_name: same, for the validation split.
    :param outpath: path of the HDF5 file to create (overwritten).
    """
    trainset = TaxiDataset('train')
    traindata = trainset.get_data(None, slice(0, trainset.num_examples))
    idsort = traindata[trainset.sources.index('timestamp')].argsort()

    # reorder every column by ascending timestamp (cut_me_baby relies on it)
    traindata = dict(zip(trainset.sources, (t[idsort] for t in traindata)))

    print >> sys.stderr, 'test cut begin'
    test_cuts = importlib.import_module('.%s' % test_cuts_name, 'data.cuts').cuts
    test = cut_me_baby(traindata, test_cuts)

    print >> sys.stderr, 'valid cut begin'
    valid_cuts = importlib.import_module('.%s' % valid_cuts_name, 'data.cuts').cuts
    # pass the test assignments as exclusions so no trip lands in both splits
    valid = cut_me_baby(traindata, valid_cuts, test)

    test_size = len(test)
    valid_size = len(valid)
    train_size = data.train_size - test_size - valid_size

    # human-readable summary of the resulting split sizes
    print ' set | size | ratio'
    print ' ----- | ------- | -----'
    print ' train | {:>7d} | {:>5.3f}'.format(train_size, float(train_size)/data.train_size)
    print ' valid | {:>7d} | {:>5.3f}'.format(valid_size, float(valid_size)/data.train_size)
    print ' test | {:>7d} | {:>5.3f}'.format(test_size , float(test_size )/data.train_size)

    # precomputed destination clusters; presumably an array of (lat, lon)
    # centroids usable by hdist below -- verify against the producer script
    with open(os.path.join(data.path, 'arrival-clusters.pkl'), 'r') as f:
        clusters = cPickle.load(f)

    print >> sys.stderr, 'compiling cluster assignment function'
    # theano function mapping one (latitude, longitude) point to the index
    # of the nearest cluster under hdist (distance helper from error module)
    latitude = theano.tensor.scalar('latitude')
    longitude = theano.tensor.scalar('longitude')
    coords = theano.tensor.stack(latitude, longitude).dimshuffle('x', 0)
    parent = theano.tensor.argmin(hdist(clusters, coords))
    cluster = theano.function([latitude, longitude], parent)

    # origin_call ids seen in the training split (filled in the loop below)
    train_clients = set()

    print >> sys.stderr, 'preparing hdf5 data'
    # one in-memory column per output field, written to HDF5 at the end
    hdata = {k: numpy.empty(shape=(data.train_size,), dtype=v) for k, v in all_fields.iteritems()}

    # output rows are laid out [train | valid | test]; each split fills
    # its own contiguous range via its own cursor
    train_i = 0
    valid_i = train_size
    test_i = train_size + valid_size

    print >> sys.stderr, 'write: begin'
    for idtraj in xrange(data.train_size):
        if idtraj%10000==0 and idtraj!=0:
            print >> sys.stderr, 'write: {:d} done'.format(idtraj)
        # test takes precedence over valid (valid was built excluding test)
        in_test = idtraj in test
        in_valid = not in_test and idtraj in valid
        in_train = not in_test and not in_valid  # NOTE(review): unused

        # pick the destination row for this trip and advance that cursor
        if idtraj in test:
            i = test_i
            test_i += 1
        elif idtraj in valid:
            i = valid_i
            valid_i += 1
        else:
            train_clients.add(traindata['origin_call'][idtraj])
            i = train_i
            train_i += 1

        # destination = last recorded GPS point; empty trajectories fall
        # back to the global mean position
        trajlen = len(traindata['latitude'][idtraj])
        if trajlen == 0:
            hdata['destination_latitude'][i] = data.train_gps_mean[0]
            hdata['destination_longitude'][i] = data.train_gps_mean[1]
        else:
            hdata['destination_latitude'][i] = traindata['latitude'][idtraj][-1]
            hdata['destination_longitude'][i] = traindata['longitude'][idtraj][-1]
        # full (untruncated) trajectory length, in number of GPS points
        hdata['travel_time'][i] = trajlen

        for field in native_fields:
            val = traindata[field][idtraj]
            if field in ['latitude', 'longitude']:
                # valid/test trajectories are truncated at their cut point,
                # keeping only the prefix known at prediction time
                if in_test:
                    val = val[:test[idtraj]]
                elif in_valid:
                    val = val[:valid[idtraj]]
            hdata[field][i] = val

        # cluster of the first point of the (possibly truncated) path;
        # -1 flags an empty path
        plen = len(hdata['latitude'][i])
        hdata['path_len'][i] = plen
        hdata['cluster'][i] = -1 if plen==0 else cluster(hdata['latitude'][i][0], hdata['longitude'][i][0])

    print >> sys.stderr, 'write: end'

    # zero out origin_call ids in valid/test rows (indices >= train_size)
    # that never occur in the training split, so downstream consumers only
    # ever see client ids learnable from training data
    print >> sys.stderr, 'removing useless origin_call'
    for i in xrange(train_size, data.train_size):
        if hdata['origin_call'][i] not in train_clients:
            hdata['origin_call'][i] = 0

    print >> sys.stderr, 'preparing split array'

    # split metadata: one row per (split, source) pair, in the structured
    # dtype expected by fuel's H5PYDataset 'split' attribute -- confirm
    # against the fuel version in use
    split_array = numpy.empty(len(all_fields)*3, dtype=numpy.dtype([
        ('split', 'a', 64),
        ('source', 'a', 21),
        ('start', numpy.int64, 1),
        ('stop', numpy.int64, 1),
        ('indices', h5py.special_dtype(ref=h5py.Reference)),
        ('available', numpy.bool, 1),
        ('comment', 'a', 1)]))

    flen = len(all_fields)
    for i, field in enumerate(all_fields):
        # rows [0, flen) describe train, [flen, 2*flen) valid,
        # [2*flen, 3*flen) test; ranges match the row layout built above
        split_array[i]['split'] = 'train'.encode('utf8')
        split_array[i+flen]['split'] = 'valid'.encode('utf8')
        split_array[i+2*flen]['split'] = 'test'.encode('utf8')
        split_array[i]['start'] = 0
        split_array[i]['stop'] = train_size
        split_array[i+flen]['start'] = train_size
        split_array[i+flen]['stop'] = train_size + valid_size
        split_array[i+2*flen]['start'] = train_size + valid_size
        split_array[i+2*flen]['stop'] = train_size + valid_size + test_size

        for d in [0, flen, 2*flen]:
            split_array[i+d]['source'] = field.encode('utf8')

    # contiguous ranges only: no index-based (fancy) splits are used
    split_array[:]['indices'] = None
    split_array[:]['available'] = True
    split_array[:]['comment'] = '.'.encode('utf8')

    print >> sys.stderr, 'writing hdf5 file'
    file = h5py.File(outpath, 'w')
    for k in all_fields.keys():
        file.create_dataset(k, data=hdata[k], maxshape=(data.train_size,))

    file.attrs['split'] = split_array

    file.flush()
    file.close()

if __name__ == '__main__':
    if len(sys.argv) < 3 or len(sys.argv) > 4:
        print >> sys.stderr, 'Usage: %s test_cutfile valid_cutfile [outfile]' % sys.argv[0]
        sys.exit(1)
    # default output location lives next to the rest of the project data
    outpath = os.path.join(data.path, 'tvt.hdf5') if len(sys.argv) < 4 else sys.argv[3]
    make_tvt(sys.argv[1], sys.argv[2], outpath)