commit fe608831c62c7dba60a3bf57433d97b999e567c8
parent e7aba08e6b209ac7f091eb9f08b49a2c90b070ed
Author: Étienne Simon <esimon@esimon.eu>
Date: Thu, 23 Jul 2015 18:34:51 -0400
Fix tvt hdf5
Diffstat:
1 file changed, 12 insertions(+), 0 deletions(-)
diff --git a/data/make_tvt.py b/data/make_tvt.py
@@ -31,6 +31,9 @@ native_fields = {
all_fields = {
'path_len': numpy.int16,
'cluster': numpy.int16,
+ 'destination_latitude': numpy.float32,
+ 'destination_longitude': numpy.float32,
+ 'travel_time': numpy.int32,
}
all_fields.update(native_fields)
@@ -125,6 +128,15 @@ def make_tvt(test_cuts_name, valid_cuts_name, outpath):
i = train_i
train_i += 1
+ trajlen = len(traindata['latitude'][idtraj])
+ if trajlen == 0:
+ hdata['destination_latitude'] = data.train_gps_mean[0]
+ hdata['destination_longitude'] = data.train_gps_mean[1]
+ else:
+ hdata['destination_latitude'] = traindata['latitude'][idtraj][-1]
+ hdata['destination_longitude'] = traindata['longitude'][idtraj][-1]
+ hdata['travel_time'] = trajlen
+
for field in native_fields:
val = traindata[field][idtraj]
if field in ['latitude', 'longitude']: