commit 9a60f6c4e39c09187710608a9e225b6024b34364
parent 107b3798cca35472e158d94f36a0bd08f3fe1fe8
Author: Alex Auvolat <alex.auvolat@ens.fr>
Date: Mon, 27 Apr 2015 17:27:43 -0400
Add validation set ; fix lat/lon
Diffstat:
5 files changed, 62 insertions(+), 7 deletions(-)
diff --git a/data.py b/data.py
@@ -131,10 +131,16 @@ taxi_columns = [
("polyline", lambda x: map(tuple, ast.literal_eval(x))),
]
+taxi_columns_valid = taxi_columns + [
+ ("destination_x", float),
+ ("destination_y", float),
+ ("time", int),
+]
+
train_files=["%s/split/train-%02d.csv" % (DATA_PATH, i) for i in range(100)]
valid_files=["%s/split/valid.csv" % (DATA_PATH,)]
train_data=TaxiData(train_files, taxi_columns)
-valid_data=TaxiData(valid_files, taxi_columns)
+valid_data=TaxiData(valid_files, taxi_columns_valid)
def train_it():
return DataIterator(DataStream(train_data))
diff --git a/hdist.py b/hdist.py
@@ -6,10 +6,10 @@ def hdist(a, b):
rearth = numpy.float32(6371)
deg2rad = numpy.float32(3.14159265358979 / 180)
- lat1 = a[:, 0] * deg2rad
- lon1 = a[:, 1] * deg2rad
- lat2 = b[:, 0] * deg2rad
- lon2 = b[:, 1] * deg2rad
+ lat1 = a[:, 1] * deg2rad
+ lon1 = a[:, 0] * deg2rad
+ lat2 = b[:, 1] * deg2rad
+ lon2 = b[:, 0] * deg2rad
dlat = abs(lat1-lat2)
dlon = abs(lon1-lon2)
diff --git a/make_valid.py b/make_valid.py
@@ -0,0 +1,32 @@
+# Takes valid-full.csv which is a subset of the lines of train.csv, formatted in the
+# exact same way
+# Outputs valid.csv which contains the polylines cut at an arbitrary location, and three
+# new columns containing the destination point and the length in seconds of the original polyline
+# (see contest definition for the time taken by a taxi along a polyline)
+
+import random
+import csv
+import ast
+
+with open("valid-full.csv") as f:
+ vlines = [l for l in csv.reader(f)]
+
+def make_valid_item(l):
+ polyline = ast.literal_eval(l[-1])
+ print len(polyline)
+ last = polyline[-1]
+ cut_idx = random.randrange(len(polyline)-5)
+ cut = polyline[:cut_idx+6]
+ return l[:-1] + [
+ cut.__str__(),
+ last[0],
+ last[1],
+ 15 * (len(polyline)-1),
+ ]
+
+vlines = map(make_valid_item, filter(lambda l: (len(ast.literal_eval(l[-1])) > 5), vlines))
+
+with open("valid.csv", "w") as f:
+ wr = csv.writer(f)
+ for r in vlines:
+ wr.writerow(r)
diff --git a/model.py b/model.py
@@ -2,6 +2,9 @@ import logging
import os
from argparse import ArgumentParser
+import numpy
+
+import theano
from theano import tensor
from theano.ifelse import ifelse
@@ -69,6 +72,7 @@ def main():
# Calculate the cost
# cost = (outputs - y).norm(2, axis=1).mean()
+ # outputs = numpy.array([[ -8.621953, 41.162142]], dtype='float32') + 0 * outputs
cost = hdist.hdist(outputs, y).mean()
cost.name = 'cost'
@@ -95,8 +99,8 @@ def main():
valid = data.valid_data
valid = DataStream(valid)
valid = transformers.add_first_k(n_begin_end_pts, valid)
- valid = transformers.add_random_k(n_begin_end_pts, valid)
- valid = transformers.add_destination(valid)
+ valid = transformers.add_last_k(n_begin_end_pts, valid)
+ valid = transformers.concat_destination_xy(valid)
valid = transformers.Select(valid, ('origin_stand', 'origin_call', 'first_k', 'last_k', 'destination'))
valid_stream = Batch(valid, iteration_scheme=ConstantScheme(batch_size))
diff --git a/transformers.py b/transformers.py
@@ -32,6 +32,19 @@ def add_random_k(k, stream):
stream = Mapping(stream, random_k, ('last_k',))
return stream
+def add_last_k(k, stream):
+ id_polyline=stream.sources.index('polyline')
+ def last_k(x):
+ return (numpy.array(x[id_polyline][-k:], dtype=theano.config.floatX).flatten(),)
+ stream = Filter(stream, lambda x: len(x[id_polyline])>=k)
+ stream = Mapping(stream, last_k, ('last_k',))
+ return stream
+
def add_destination(stream):
id_polyline=stream.sources.index('polyline')
return Mapping(stream, lambda x: (numpy.array(x[id_polyline][-1], dtype=theano.config.floatX),), ('destination',))
+
+def concat_destination_xy(stream):
+ id_dx=stream.sources.index('destination_x')
+ id_dy=stream.sources.index('destination_y')
+ return Mapping(stream, lambda x: (numpy.array([x[id_dx], x[id_dy]], dtype=theano.config.floatX),), ('destination',))