ext_test.py (3705B)
#!/usr/bin/env python

import logging
import os
import csv

from blocks.model import Model
from blocks.extensions import SimpleExtension

logger = logging.getLogger(__name__)


class RunOnTest(SimpleExtension):
    """Write test-set predictions to CSV when a validation cost improves.

    Each time the extension fires, it reads the validation cost(s) from the
    current row of the main loop's log.  If a cost is lower than the best
    one seen so far by this extension (or none has been seen yet) and the
    model produces the matching output, the compiled prediction function is
    run over the test stream and a CSV file is written under ``output/``.

    Parameters
    ----------
    model_name : str
        Label embedded in the generated CSV file names.
    model : object
        Must expose ``predict`` (callable with the stream's inputs as
        keyword arguments) with ``predict.inputs`` and ``predict.outputs``
        name lists.
    stream : object
        Must expose ``inputs()`` and ``test(required_variable_names)``.
    **kwargs
        Forwarded unchanged to ``SimpleExtension.__init__``.
    """

    # Log keys probed, in priority order, for each validation cost.
    _DEST_COST_KEYS = ('valid_destination_cost',
                       'valid_model_cost_cost',
                       'valid_model_valid_cost_cost')
    _TIME_COST_KEYS = ('valid_time_cost',
                       'valid_model_cost_cost',
                       'valid_model_valid_cost_cost')

    def __init__(self, model_name, model, stream, **kwargs):
        super(RunOnTest, self).__init__(**kwargs)

        self.model_name = model_name

        cg = Model(model.predict(**stream.inputs()))

        self.inputs = cg.inputs
        self.outputs = model.predict.outputs

        # 'trip_id' is requested in addition to the model inputs because it
        # is written as the first column of every output row.
        req_vars_test = model.predict.inputs + ['trip_id']
        self.test_stream = stream.test(req_vars_test)

        self.function = cg.get_theano_function()

        # Lowest validation costs for which an output file was generated.
        self.best_dvc = None
        self.best_tvc = None

    @staticmethod
    def _first_logged_cost(row, keys):
        """Return ``row[k]`` for the first ``k`` in ``keys`` present in ``row``.

        Raises
        ------
        RuntimeError
            If none of ``keys`` is present in ``row``.
        """
        for key in keys:
            if key in row:
                return row[key]
        raise RuntimeError("Unknown model type")

    def do(self, which_callback, *args):
        """Generate test-set CSV output if a validation cost has improved.

        Raises
        ------
        RuntimeError
            If the current log row contains none of the known destination
            cost keys, or none of the known time cost keys.
        """
        log = self.main_loop.log
        iter_no = log.status['iterations_done']
        row = log.current_row

        # Both lookups are always performed (and may raise), regardless of
        # which outputs the model actually provides.
        dvc = self._first_logged_cost(row, self._DEST_COST_KEYS)
        tvc = self._first_logged_cost(row, self._TIME_COST_KEYS)

        output_dvc = ((self.best_dvc is None or dvc < self.best_dvc)
                      and 'destination' in self.outputs)
        output_tvc = ((self.best_tvc is None or tvc < self.best_tvc)
                      and 'duration' in self.outputs)

        if not output_dvc and not output_tvc:
            return

        dest_outfile = None
        time_outfile = None
        dest_outcsv = None
        time_outcsv = None
        try:
            if output_dvc:
                self.best_dvc = dvc
                dest_outname = 'test-dest-%s-it%09d-cost%.3f.csv' % (
                    self.model_name, iter_no, dvc)
                dest_outfile = open(os.path.join('output', dest_outname), 'w')
                dest_outcsv = csv.writer(dest_outfile)
                dest_outcsv.writerow(["TRIP_ID", "LATITUDE", "LONGITUDE"])
                logger.info("Generating output for test set: %s", dest_outname)
            if output_tvc:
                self.best_tvc = tvc
                time_outname = 'test-time-%s-it%09d-cost%.3f.csv' % (
                    self.model_name, iter_no, tvc)
                time_outfile = open(os.path.join('output', time_outname), 'w')
                time_outcsv = csv.writer(time_outfile)
                time_outcsv.writerow(["TRIP_ID", "TRAVEL_TIME"])
                logger.info("Generating output for test set: %s", time_outname)

            # Positions of the wanted outputs in the function's result list
            # do not change between batches, so resolve them once.
            dest_idx = self.outputs.index('destination') if output_dvc else None
            time_idx = self.outputs.index('duration') if output_tvc else None

            for d in self.test_stream.get_epoch_iterator(as_dict=True):
                input_values = [d[k.name] for k in self.inputs]
                output_values = self.function(*input_values)
                trip_ids = d['trip_id']

                destination = output_values[dest_idx] if output_dvc else None
                duration = output_values[time_idx] if output_tvc else None

                for i in range(trip_ids.shape[0]):
                    if output_dvc:
                        dest_outcsv.writerow(
                            [trip_ids[i], destination[i, 0], destination[i, 1]])
                    if output_tvc:
                        time_outcsv.writerow(
                            [trip_ids[i], int(round(duration[i]))])
        finally:
            # Close whatever was opened, even if prediction or writing
            # failed part-way through.
            if dest_outfile is not None:
                dest_outfile.close()
            if time_outfile is not None:
                time_outfile.close()