__init__.py (1286B)
"""Dataset constants for the taxi trajectory data.

All values are computed once, at import time.  Two command-line flags are
read from ``sys.argv`` when this module is first imported:

* ``--tvt``: use the split stored in ``tvt.hdf5``.  This is the only mode
  that defines ``valid_size``.
* ``--largevalid``: only consulted when ``--tvt`` is absent; selects
  ``cuts/large_valid`` instead of ``cuts/test_times_0`` as ``valid_set``.
"""
import os
import sys

import h5py
import numpy


# Root directory of the data files; override with the TAXI_PATH env variable.
path = os.environ.get('TAXI_PATH', '/data/lisatmp3/auvolat/taxikaggle')

# HDF5 variable-length float32 dtype (presumably for trip polylines).
# `vlen_dtype` is the current h5py API (h5py >= 2.10); `special_dtype(vlen=...)`
# is the legacy spelling it replaces, kept as a fallback for older h5py.
# Both produce the same dtype.
if hasattr(h5py, 'vlen_dtype'):
    Polyline = h5py.vlen_dtype(numpy.float32)
else:
    Polyline = h5py.special_dtype(vlen=numpy.float32)

# `wc -l metaData_taxistandsID_name_GPSlocation.csv`
stands_size = 64  # include 0 ("no origin_stands")

# `cut -d, -f 5 train.csv test.csv | sort -u | wc -l` - 1
taxi_id_size = 448

# Per-coordinate normalisation statistics.  Presumably (latitude, longitude),
# since the mean falls around Porto; confirm against the data loader.
# The std is the square root of the listed variances (result stays float32).
train_gps_mean = numpy.array([41.1573, -8.61612], dtype=numpy.float32)
train_gps_std = numpy.sqrt(numpy.array([0.00549598, 0.00333233], dtype=numpy.float32))

tvt = '--tvt' in sys.argv

if tvt:
    # Sizes of the tvt.hdf5 split.  test + valid + train = 1710670, the same
    # total as train_size in the default mode below.
    test_size = 19770
    valid_size = 19427
    train_size = 1671473

    origin_call_size = 57106
    origin_call_train_size = 57106

    # `--largevalid` is deliberately not consulted in this mode.
    valid_set = 'valid'
    valid_ds = 'tvt.hdf5'
    traintest_ds = 'tvt.hdf5'

else:
    # NOTE: valid_size is not defined in this mode.

    # `wc -l test.csv` - 1 # Minus 1 to ignore the header
    test_size = 320

    # `wc -l train.csv` - 1
    train_size = 1710670

    # `cut -d, -f 3 train.csv test.csv | sort -u | wc -l` - 2
    origin_call_size = 57125  # include 0 ("no origin_call")

    # As printed by csv_to_hdf5.py
    origin_call_train_size = 57106

    if '--largevalid' in sys.argv:
        valid_set = 'cuts/large_valid'
    else:
        valid_set = 'cuts/test_times_0'

    valid_ds = 'valid.hdf5'
    traintest_ds = 'data.hdf5'