cut.py (1051B)
"""Fuel iteration scheme that yields trips in progress at random time cuts."""
from contextlib import closing
import os
import random
import sqlite3

from fuel.schemes import IterationScheme
from picklable_itertools import iter_

import data

# Range from which random cut times are drawn (randrange: last_time exclusive).
# Values are compared against the `begin`/`end` columns of the trip_times
# table, so they share its time unit -- presumably Unix timestamps
# (~July 2013 .. June 2014); TODO confirm.
first_time = 1372636853
last_time = 1404172787


class TaxiTimeCutScheme(IterationScheme):
    """Request trip ids of trips that are underway at given time cuts.

    For each cut time ``t`` the scheme selects, from the ``trip_times`` table
    of an SQLite database, every trip with
    ``t - max_trip_duration <= begin <= t`` and ``end >= t``: trips that
    started within the look-back window and had not finished by ``t``.
    The ids for all cuts are pooled (a trip may appear once per cut it
    spans), shuffled, and returned as the request iterator.

    Parameters
    ----------
    num_cuts : int
        Number of random cut times drawn per epoch when ``use_cuts`` is None.
    dbfile : str or None
        Path to the SQLite index; defaults to ``<data.path>/time_index.db``.
    use_cuts : iterable of int or None
        Fixed cut times to use instead of drawing random ones.
    max_trip_duration : int
        Look-back window: only trips whose ``begin`` lies at most this far
        before the cut are selected. Defaults to 40000, the value that was
        previously hard-coded.
    """

    # Class-level default so instances pickled before this attribute existed
    # still resolve it after unpickling.
    max_trip_duration = 40000

    def __init__(self, num_cuts=100, dbfile=None, use_cuts=None,
                 max_trip_duration=40000):
        self.num_cuts = num_cuts
        if dbfile is None:
            dbfile = os.path.join(data.path, 'time_index.db')
        self.dbfile = dbfile
        self.use_cuts = use_cuts
        self.max_trip_duration = max_trip_duration

    def get_request_iterator(self):
        """Return a picklable iterator over shuffled trip ids for this epoch."""
        cuts = self.use_cuts
        if cuts is None:
            # Fresh random cuts on every call, i.e. every epoch.
            cuts = [random.randrange(first_time, last_time)
                    for _ in range(self.num_cuts)]

        trips = []
        # sqlite3's own context manager only commits/rolls back; closing()
        # makes sure the connection is actually released.
        with closing(sqlite3.connect(self.dbfile)) as db:
            cursor = db.cursor()
            for cut in cuts:
                rows = cursor.execute(
                    'SELECT trip FROM trip_times '
                    'WHERE begin >= ? AND begin <= ? AND end >= ?',
                    (cut - self.max_trip_duration, cut, cut))
                # extend() appends in place; the old `l = l + part` copied
                # the whole accumulated list on every cut.
                trips.extend(trip for (trip,) in rows)

        random.shuffle(trips)
        return iter_(trips)