commit 929eaf8dd0233f8423b24b93b78c99fc9df65343
parent 9adfe767010e23823089b4db94cb4dc53cc3c12a
Author: Alex Auvolat <alex.auvolat@ens.fr>
Date: Mon, 4 May 2015 13:20:50 -0400
Fixes
Diffstat:
3 files changed, 10 insertions(+), 10 deletions(-)
diff --git a/config/model_0.py b/config/model_0.py
@@ -9,7 +9,7 @@ n_valid = 1000
dim_embed = 10
dim_input = n_begin_end_pts * 2 * 2 + dim_embed + dim_embed
-dim_hidden = [100]
+dim_hidden = [200, 100]
dim_output = 2
learning_rate = 0.002
diff --git a/hdist.py b/hdist.py
@@ -2,15 +2,15 @@ from theano import tensor
import theano
import numpy
-rearth = const(6371)
-deg2rad = const(3.141592653589793 / 180)
-
def const(v):
if theano.config.floatX == 'float32':
return numpy.float32(v)
else:
return numpy.float64(v)
+rearth = const(6371)
+deg2rad = const(3.141592653589793 / 180)
+
def hdist(a, b):
lat1 = a[:, 0] * deg2rad
lon1 = a[:, 1] * deg2rad
diff --git a/transformers.py b/transformers.py
@@ -43,8 +43,8 @@ class TaxiGenerateSplits(Transformer):
raise ValueError
while self.isplit >= len(self.splits):
self.data = next(self.child_epoch_iterator)
- self.splits = range(len(self.data[self.id_polyline]))
- random.shuffle_array(self.splits)
+ self.splits = range(len(self.data[self.id_longitude]))
+ random.shuffle(self.splits)
if self.max_splits != -1 and len(self.splits) > self.max_splits:
self.splits = self.splits[:self.max_splits]
self.isplit = 0
@@ -55,11 +55,11 @@ class TaxiGenerateSplits(Transformer):
r = list(self.data)
- r[self.id_latitude] = r[self.id_latitude][:n]
- r[self.id_longitude] = r[self.id_longitude][:n]
+ r[self.id_latitude] = numpy.array(r[self.id_latitude][:n], dtype=theano.config.floatX)
+ r[self.id_longitude] = numpy.array(r[self.id_longitude][:n], dtype=theano.config.floatX)
- dlat = self.data[self.id_latitude][-1]
- dlon = self.data[self.id_longitude][-1]
+ dlat = numpy.float32(self.data[self.id_latitude][-1])
+ dlon = numpy.float32(self.data[self.id_longitude][-1])
return tuple(r + [dlat, dlon])