csv_to_hdf5.py (5189B)
1 #!/usr/bin/env python2 2 3 import ast 4 import csv 5 import os 6 import sys 7 8 import h5py 9 import numpy 10 from fuel.converters.base import fill_hdf5_file 11 12 import data 13 14 15 taxi_id_dict = {} 16 origin_call_dict = {0: 0} 17 18 def get_unique_taxi_id(val): 19 if val in taxi_id_dict: 20 return taxi_id_dict[val] 21 else: 22 taxi_id_dict[val] = len(taxi_id_dict) 23 return len(taxi_id_dict) - 1 24 25 def get_unique_origin_call(val): 26 if val in origin_call_dict: 27 return origin_call_dict[val] 28 else: 29 origin_call_dict[val] = len(origin_call_dict) 30 return len(origin_call_dict) - 1 31 32 def read_stands(input_directory, h5file): 33 stands_name = numpy.empty(shape=(data.stands_size,), dtype=('a', 24)) 34 stands_latitude = numpy.empty(shape=(data.stands_size,), dtype=numpy.float32) 35 stands_longitude = numpy.empty(shape=(data.stands_size,), dtype=numpy.float32) 36 stands_name[0] = 'None' 37 stands_latitude[0] = stands_longitude[0] = 0 38 with open(os.path.join(input_directory, 'metaData_taxistandsID_name_GPSlocation.csv'), 'r') as f: 39 reader = csv.reader(f) 40 reader.next() # header 41 for line in reader: 42 id = int(line[0]) 43 stands_name[id] = line[1] 44 stands_latitude[id] = float(line[2]) 45 stands_longitude[id] = float(line[3]) 46 return (('stands', 'stands_name', stands_name), 47 ('stands', 'stands_latitude', stands_latitude), 48 ('stands', 'stands_longitude', stands_longitude)) 49 50 def read_taxis(input_directory, h5file, dataset): 51 print >> sys.stderr, 'read %s: begin' % dataset 52 size=getattr(data, '%s_size'%dataset) 53 trip_id = numpy.empty(shape=(size,), dtype='S19') 54 call_type = numpy.empty(shape=(size,), dtype=numpy.int8) 55 origin_call = numpy.empty(shape=(size,), dtype=numpy.int32) 56 origin_stand = numpy.empty(shape=(size,), dtype=numpy.int8) 57 taxi_id = numpy.empty(shape=(size,), dtype=numpy.int16) 58 timestamp = numpy.empty(shape=(size,), dtype=numpy.int32) 59 day_type = numpy.empty(shape=(size,), dtype=numpy.int8) 60 missing_data = numpy.empty(shape=(size,), dtype=numpy.bool) 61 latitude = numpy.empty(shape=(size,), dtype=data.Polyline) 62 longitude = numpy.empty(shape=(size,), dtype=data.Polyline) 63 with open(os.path.join(input_directory, '%s.csv'%dataset), 'r') as f: 64 reader = csv.reader(f) 65 reader.next() # header 66 id=0 67 for line in reader: 68 if id%10000==0 and id!=0: 69 print >> sys.stderr, 'read %s: %d done' % (dataset, id) 70 trip_id[id] = line[0] 71 call_type[id] = ord(line[1][0]) - ord('A') 72 origin_call[id] = 0 if line[2]=='NA' or line[2]=='' else get_unique_origin_call(int(line[2])) 73 origin_stand[id] = 0 if line[3]=='NA' or line[3]=='' else int(line[3]) 74 taxi_id[id] = get_unique_taxi_id(int(line[4])) 75 timestamp[id] = int(line[5]) 76 day_type[id] = ord(line[6][0]) - ord('A') 77 missing_data[id] = line[7][0] == 'T' 78 polyline = ast.literal_eval(line[8]) 79 latitude[id] = numpy.array([point[1] for point in polyline], dtype=numpy.float32) 80 longitude[id] = numpy.array([point[0] for point in polyline], dtype=numpy.float32) 81 id+=1 82 splits = () 83 print >> sys.stderr, 'read %s: writing' % dataset 84 for name in ['trip_id', 'call_type', 'origin_call', 'origin_stand', 'taxi_id', 'timestamp', 'day_type', 'missing_data', 'latitude', 'longitude']: 85 splits += ((dataset, name, locals()[name]),) 86 print >> sys.stderr, 'read %s: end' % dataset 87 return splits 88 89 def unique(h5file): 90 unique_taxi_id = numpy.empty(shape=(data.taxi_id_size,), dtype=numpy.int32) 91 assert len(taxi_id_dict) == data.taxi_id_size 92 for k, v in taxi_id_dict.items(): 93 unique_taxi_id[v] = k 94 95 unique_origin_call = numpy.empty(shape=(data.origin_call_size,), dtype=numpy.int32) 96 assert len(origin_call_dict) == data.origin_call_size 97 for k, v in origin_call_dict.items(): 98 unique_origin_call[v] = k 99 100 return (('unique_taxi_id', 'unique_taxi_id', unique_taxi_id), 101 ('unique_origin_call', 'unique_origin_call', unique_origin_call)) 102 103 def convert(input_directory, save_path): 104 h5file = h5py.File(save_path, 'w') 105 split = () 106 split += read_stands(input_directory, h5file) 107 split += read_taxis(input_directory, h5file, 'train') 108 print 'First origin_call not present in training set: ', len(origin_call_dict) 109 split += read_taxis(input_directory, h5file, 'test') 110 split += unique(h5file) 111 112 fill_hdf5_file(h5file, split) 113 114 for name in ['stands_name', 'stands_latitude', 'stands_longitude', 'unique_taxi_id', 'unique_origin_call']: 115 h5file[name].dims[0].label = 'index' 116 for name in ['trip_id', 'call_type', 'origin_call', 'origin_stand', 'taxi_id', 'timestamp', 'day_type', 'missing_data', 'latitude', 'longitude']: 117 h5file[name].dims[0].label = 'batch' 118 119 h5file.flush() 120 h5file.close() 121 122 if __name__ == '__main__': 123 if len(sys.argv) != 3: 124 print >> sys.stderr, 'Usage: %s download_dir output_file' % sys.argv[0] 125 sys.exit(1) 126 convert(sys.argv[1], sys.argv[2])