commit 107b3798cca35472e158d94f36a0bd08f3fe1fe8
parent a25d4fb6e92f203183de2d89e8c467a6b14e1730
Author: Alex Auvolat <alex.auvolat@ens.fr>
Date: Mon, 27 Apr 2015 16:38:25 -0400
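Make TaxiData more flexible: the column names and per-column parsers are now passed to the constructor instead of being hard-coded in the class.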
Diffstat:
M data.py | 34 +++++++++++++++++++++-------------
1 file changed, 21 insertions(+), 13 deletions(-)
diff --git a/data.py b/data.py
@@ -60,16 +60,17 @@ class DayType(Enum):
return 'C'
class TaxiData(Dataset):
- provides_sources= ("trip_id","call_type","origin_call","origin_stand","taxi_id","timestamp","day_type","missing_data","polyline")
example_iteration_scheme=None
class State:
__slots__ = ('file', 'index', 'reader')
- def __init__(self, pathes, has_header=False):
+ def __init__(self, pathes, columns, has_header=False):
if not isinstance(pathes, list):
pathes=[pathes]
assert len(pathes)>0
+ self.columns=columns
+ self.provides_sources = tuple(map(lambda x: x[0], columns))
self.pathes=pathes
self.has_header=has_header
super(TaxiData, self).__init__()
@@ -113,20 +114,27 @@ class TaxiData(Dataset):
state.reader.next()
return self.get_data(state)
- line[1]=CallType.from_data(line[1]) # call_type
- line[2]=0 if line[2]=='' or line[2]=='NA' else client_ids[int(line[2])] # origin_call
- line[3]=0 if line[3]=='' or line[3]=='NA' else int(line[3]) # origin_stand
- line[4]=int(line[4]) # taxi_id
- line[5]=int(line[5]) # timestamp
- line[6]=DayType.from_data(line[6]) # day_type
- line[7]=line[7][0]=='T' # missing_data
- line[8]=map(tuple, ast.literal_eval(line[8])) # polyline
- return tuple(line)
+ values = []
+ for idx, (_, constructor) in enumerate(self.columns):
+ values.append(constructor(line[idx]))
+ return tuple(values)
+
+taxi_columns = [
+ ("trip_id", lambda x: x),
+ ("call_type", CallType.from_data),
+ ("origin_call", lambda x: 0 if x == '' or x == 'NA' else client_ids[int(x)]),
+ ("origin_stand", lambda x: 0 if x == '' or x == 'NA' else int(x)),
+ ("taxi_id", int),
+ ("timestamp", int),
+ ("day_type", DayType.from_data),
+ ("missing_data", lambda x: x[0] == 'T'),
+ ("polyline", lambda x: map(tuple, ast.literal_eval(x))),
+]
train_files=["%s/split/train-%02d.csv" % (DATA_PATH, i) for i in range(100)]
valid_files=["%s/split/valid.csv" % (DATA_PATH,)]
-train_data=TaxiData(train_files)
-valid_data=TaxiData(valid_files)
+train_data=TaxiData(train_files, taxi_columns)
+valid_data=TaxiData(valid_files, taxi_columns)
def train_it():
return DataIterator(DataStream(train_data))