commit 0be3ebaa19f2cf8a630565434e785e5c24929a14
parent 5589a8af8967cfc73d3b6fda8f86acc0d08172b8
Author: Étienne Simon <esimon@esimon.eu>
Date: Fri, 24 Apr 2015 16:37:49 -0400
Make TaxiData accept multiple files
Diffstat:
M | data.py | | | 48 | +++++++++++++++++++++++++++++++++++++----------- |
1 file changed, 37 insertions(+), 11 deletions(-)
diff --git a/data.py b/data.py
@@ -55,31 +55,57 @@ class DayType(Enum):
class TaxiData(Dataset):
provides_sources= ("trip_id","call_type","origin_call","origin_stand","taxi_id","timestamp","day_type","missing_data","polyline")
-
example_iteration_scheme=None
- def __init__(self, path):
- self.path=path
+ class State:
+ __slots__ = ('file', 'index', 'reader')
+
+ def __init__(self, pathes, has_header=False):
+ if not isinstance(pathes, list):
+ pathes=[pathes]
+ assert len(pathes)
+ self.pathes=pathes
+ self.has_header=has_header
super(TaxiData, self).__init__()
def open(self):
- file=open(self.path)
- reader=csv.reader(file)
- reader.next() # Skip header
- return (file, reader)
+ state=self.State()
+ state.file=open(self.pathes[0])
+ state.index=0
+ state.reader=csv.reader(state.file)
+ if self.has_header:
+ state.reader.next()
+ return state
def close(self, state):
- state[0].close()
+ state.file.close()
def reset(self, state):
- state[0].seek(0)
- state[1]=csv.reader(state[0])
+ if state.index==0:
+ state.file.seek(0)
+ else:
+ state.index=0
+ state.file.close()
+ state.file=open(self.pathes[0])
+ state.reader=csv.reader(state.file)
return state
def get_data(self, state, request=None):
if request is not None:
raise ValueError
- line=state[1].next()
+ try:
+ line=state.reader.next()
+ except StopIteration:
+ state.file.close()
+ state.index+=1
+ if state.index>=len(self.pathes):
+ raise
+ state.file=open(self.pathes[state.index])
+ state.reader=csv.reader(state.file)
+ if self.has_header:
+ state.reader.next()
+ line=state.reader.next()
+
line[1]=CallType.from_data(line[1]) # call_type
line[2]=0 if line[2]=='' or line[2]=='NA' else int(line[2]) # origin_call
line[3]=0 if line[3]=='' or line[3]=='NA' else int(line[3]) # origin_stand