cluster_arrival.py (1052B)
1 #!/usr/bin/env python2 2 import numpy 3 import cPickle 4 import scipy.misc 5 import os 6 7 from sklearn.cluster import MeanShift, estimate_bandwidth 8 from sklearn.datasets.samples_generator import make_blobs 9 from itertools import cycle 10 11 import data 12 from data.hdf5 import taxi_it 13 from data.transformers import add_destination 14 15 print "Generating arrival point list" 16 dests = [] 17 for v in taxi_it("train"): 18 if len(v['latitude']) == 0: continue 19 dests.append([v['latitude'][-1], v['longitude'][-1]]) 20 pts = numpy.array(dests) 21 22 with open(os.path.join(data.path, "arrivals.pkl"), "w") as f: 23 cPickle.dump(pts, f, protocol=cPickle.HIGHEST_PROTOCOL) 24 25 print "Doing clustering" 26 bw = estimate_bandwidth(pts, quantile=.1, n_samples=1000) 27 print bw 28 bw = 0.001 # ( 29 30 ms = MeanShift(bandwidth=bw, bin_seeding=True, min_bin_freq=5) 31 ms.fit(pts) 32 cluster_centers = ms.cluster_centers_ 33 34 print "Clusters shape: ", cluster_centers.shape 35 36 with open(os.path.join(data.path, "arrival-clusters.pkl"), "w") as f: 37 cPickle.dump(cluster_centers, f, protocol=cPickle.HIGHEST_PROTOCOL) 38