commit 712035b88be1816d3fbd58ce69ae6464767c780e
parent 66159d9fce0129116e82e74cf3eb1d9e048b253d
Author: Alex Auvolat <alex.auvolat@ens.fr>
Date: Tue, 5 May 2015 13:11:18 -0400
Add day type and taxi id
Diffstat:
5 files changed, 75 insertions(+), 9 deletions(-)
diff --git a/config/simple_mlp_2_cswdt.py b/config/simple_mlp_2_cswdt.py
@@ -0,0 +1,25 @@
+import model.simple_mlp as model
+
+import data
+
+n_begin_end_pts = 5 # how many points we consider at the beginning and end of the known trajectory
+n_end_pts = 5  # NOTE(review): not referenced in this file; presumably read by the model or data pipeline -- confirm
+
+n_valid = 1000  # presumably the number of held-out validation trips -- confirm against the training script
+
+dim_embeddings = [  # (feature name, presumably number of distinct values, size added to dim_input) -- confirm against model.simple_mlp
+ ('origin_call', data.n_train_clients+1, 10),
+ ('origin_stand', data.n_stands+1, 10),
+ ('week_of_year', 52, 10),  # TaxiAddDateTime folds week index 52 into 51, so indices stay in 0..51
+ ('day_of_week', 7, 10),
+ ('qhour_of_day', 24 * 4, 10),  # quarter-hour slot: hour * 4 + minute / 15
+ ('day_type', 3, 10),  # day_type is encoded as ord(letter) - ord('A'); presumably only 'A'..'C' occur -- confirm
+]
+
+dim_input = n_begin_end_pts * 2 * 2 + sum(x for (_, _, x) in dim_embeddings)  # trajectory points plus the third field of every embedding; the two factors of 2 are presumably (begin, end) and (lat, lon) -- confirm
+dim_hidden = [200, 100]
+dim_output = 2
+
+learning_rate = 0.0001
+momentum = 0.99
+batch_size = 32
diff --git a/config/simple_mlp_tgtcls_1_cswdt.py b/config/simple_mlp_tgtcls_1_cswdt.py
@@ -14,9 +14,10 @@ with open(data.DATA_PATH + "/arrival-clusters.pkl") as f: tgtcls = cPickle.load(
dim_embeddings = [
('origin_call', data.n_train_clients+1, 10),
('origin_stand', data.n_stands+1, 10),
- ('week_of_year', 53, 10),
+ ('week_of_year', 52, 10),
('day_of_week', 7, 10),
- ('qhour_of_day', 24 * 4, 10)
+ ('qhour_of_day', 24 * 4, 10),
+ ('day_type', 3, 10),
]
dim_input = n_begin_end_pts * 2 * 2 + sum(x for (_, _, x) in dim_embeddings)
diff --git a/config/simple_mlp_tgtcls_1_cswdtx.py b/config/simple_mlp_tgtcls_1_cswdtx.py
@@ -0,0 +1,30 @@
+import cPickle
+
+import data
+
+import model.simple_mlp_tgtcls as model
+
+n_begin_end_pts = 5 # how many points we consider at the beginning and end of the known trajectory
+n_end_pts = 5  # NOTE(review): not referenced in this file; presumably read by the model or data pipeline -- confirm
+
+n_valid = 1000  # presumably the number of held-out validation trips -- confirm against the training script
+
+with open(data.DATA_PATH + "/arrival-clusters.pkl") as f: tgtcls = cPickle.load(f)
+
+dim_embeddings = [  # (feature name, presumably number of distinct values, size added to dim_input) -- confirm against model.simple_mlp_tgtcls
+    ('origin_call', data.n_train_clients+1, 10),
+    ('origin_stand', data.n_stands+1, 10),
+    ('week_of_year', 52, 10),  # TaxiAddDateTime folds week index 52 into 51, so indices stay in 0..51
+    ('day_of_week', 7, 10),
+    ('qhour_of_day', 24 * 4, 10),  # quarter-hour slot: hour * 4 + minute / 15
+    ('day_type', 3, 10),  # day_type is encoded as ord(letter) - ord('A'); presumably only 'A'..'C' occur -- confirm
+    ('taxi_id', len(data.taxi_ids), 10),  # sized from the taxi lookup table in data.py instead of a hard-coded 448, so it cannot drift from the dataset
+]
+
+dim_input = n_begin_end_pts * 2 * 2 + sum(x for (_, _, x) in dim_embeddings)  # trajectory points plus the third field of every embedding; the two factors of 2 are presumably (begin, end) and (lat, lon) -- confirm
+dim_hidden = [500]
+dim_output = tgtcls.shape[0]  # one output per row of the unpickled tgtcls object
+
+learning_rate = 0.0001
+momentum = 0.99
+batch_size = 32
diff --git a/data.py b/data.py
@@ -30,9 +30,7 @@ dataset_size = 1710670
def make_client_ids():
f = h5py.File(H5DATA_PATH, "r")
l = f['unique_origin_call']
- r = {}
- for i in range(l.shape[0]):
- r[l[i]] = i
+ r = {l[i]: i for i in range(l.shape[0])}  # reverse lookup: each stored value -> its index in 'unique_origin_call'
return r
client_ids = make_client_ids()
@@ -43,6 +41,18 @@ def get_client_id(n):
else:
return 0
+# ---- Read taxi IDs and create reverse dictionary
+
+def make_taxi_ids():
+    """Build the reverse lookup for taxi IDs.
+
+    Reads the 'unique_taxi_id' dataset from the HDF5 file at H5DATA_PATH and
+    returns a dict mapping each stored value to its index in that dataset.
+    """
+    # Close the HDF5 handle as soon as the table is built, and read the
+    # dataset in one slice rather than one HDF5 access per element.
+    with h5py.File(H5DATA_PATH, "r") as f:
+        unique_ids = f['unique_taxi_id'][:]
+    return {taxi_id: index for index, taxi_id in enumerate(unique_ids)}
+
+taxi_ids = make_taxi_ids()
+
+# ---- Enum types
+
class CallType(Enum):
CENTRAL = 0
STAND = 1
@@ -154,9 +164,9 @@ taxi_columns = [
("call_type", lambda l: CallType.from_data(l[1])),
("origin_call", lambda l: 0 if l[2] == '' or l[2] == 'NA' else get_client_id(int(l[2]))),
("origin_stand", lambda l: 0 if l[3] == '' or l[3] == 'NA' else int(l[3])),
- ("taxi_id", lambda l: int(l[4])),
+ ("taxi_id", lambda l: taxi_ids[int(l[4])]),
("timestamp", lambda l: int(l[5])),
- ("day_type", lambda l: DayType.from_data(l[6])),
+ ("day_type", lambda l: ord(l[6])-ord('A')),
("missing_data", lambda l: l[7][0] == 'T'),
("polyline", lambda l: map(tuple, ast.literal_eval(l[8]))),
("longitude", lambda l: map(lambda p: p[0], ast.literal_eval(l[8]))),
diff --git a/transformers.py b/transformers.py
@@ -107,7 +107,8 @@ class TaxiAddDateTime(Transformer):
data = next(self.child_epoch_iterator)
ts = data[self.id_timestamp]
date = datetime.datetime.utcfromtimestamp(ts)
- info = (date.isocalendar()[1] - 1, date.weekday(), date.hour * 4 + date.minute / 15)
+ yearweek = date.isocalendar()[1] - 1
+ info = ((51 if yearweek == 52 else yearweek), date.weekday(), date.hour * 4 + date.minute / 15)
return data + info
class TaxiExcludeTrips(Transformer):
@@ -122,4 +123,3 @@ class TaxiExcludeTrips(Transformer):
if not data[self.id_trip_id] in self.exclude: break
return data
-