make_valid_cut.py (2481B)
1 #!/usr/bin/env python2 2 # Make a valid dataset by cutting the training set at specified timestamps 3 4 import os 5 import sys 6 import importlib 7 8 import h5py 9 import numpy 10 11 import data 12 from data.hdf5 import taxi_it 13 14 15 _fields = ['trip_id', 'call_type', 'origin_call', 'origin_stand', 'taxi_id', 'timestamp', 'day_type', 'missing_data', 'latitude', 'longitude', 'destination_latitude', 'destination_longitude', 'travel_time'] 16 17 def make_valid(cutfile, outpath): 18 cuts = importlib.import_module('.%s' % cutfile, 'data.cuts').cuts 19 20 print "Number of cuts:", len(cuts) 21 22 valid = [] 23 24 for line in taxi_it('train'): 25 time = line['timestamp'] 26 latitude = line['latitude'] 27 longitude = line['longitude'] 28 29 if len(latitude) == 0: 30 continue 31 32 for ts in cuts: 33 if time <= ts and time + 15 * (len(latitude) - 1) >= ts: 34 # keep it 35 n = (ts - time) / 15 + 1 36 line.update({ 37 'latitude': latitude[:n], 38 'longitude': longitude[:n], 39 'destination_latitude': latitude[-1], 40 'destination_longitude': longitude[-1], 41 'travel_time': 15 * (len(latitude)-1) 42 }) 43 valid.append(line) 44 break 45 46 print "Number of trips in validation set:", len(valid) 47 48 file = h5py.File(outpath, 'a') 49 clen = file['trip_id'].shape[0] 50 alen = len(valid) 51 for field in _fields: 52 dset = file[field] 53 dset.resize((clen + alen,)) 54 for i in xrange(alen): 55 dset[clen + i] = valid[i][field] 56 57 splits = file.attrs['split'] 58 slen = splits.shape[0] 59 splits = numpy.resize(splits, (slen+len(_fields),)) 60 for (i, field) in enumerate(_fields): 61 splits[slen+i]['split'] = ('cuts/%s' % cutfile).encode('utf8') 62 splits[slen+i]['source'] = field.encode('utf8') 63 splits[slen+i]['start'] = clen 64 splits[slen+i]['stop'] = alen 65 splits[slen+i]['indices'] = None 66 splits[slen+i]['available'] = True 67 splits[slen+i]['comment'] = '.' 68 file.attrs['split'] = splits 69 70 file.flush() 71 file.close() 72 73 if __name__ == '__main__': 74 if len(sys.argv) < 2 or len(sys.argv) > 3: 75 print >> sys.stderr, 'Usage: %s cutfile [outfile]' % sys.argv[0] 76 sys.exit(1) 77 outpath = os.path.join(data.path, 'valid.hdf5') if len(sys.argv) < 3 else sys.argv[2] 78 make_valid(sys.argv[1], outpath)