@@ -73,7 +73,9 @@ def compute_recall(gt, predictions, numQ, n_values, recall_str=''):
7373def  write_kapture_output (opt , eval_set , predictions , outfile_name ):
7474    if  not  exists (opt .result_save_folder ):
7575        os .mkdir (opt .result_save_folder )
76-     with  open (join (opt .result_save_folder , outfile_name ), 'w' ) as  kap_out :
76+     outfile  =  join (opt .result_save_folder , outfile_name )
77+     print ('Writing results to' , outfile )
78+     with  open (outfile , 'w' ) as  kap_out :
7779        kap_out .write ('# kapture format: 1.0\n ' )
7880        kap_out .write ('# query_image, map_image\n ' )
7981        image_list_array  =  np .array (eval_set .images )
@@ -87,7 +89,9 @@ def write_kapture_output(opt, eval_set, predictions, outfile_name):
8789def  write_recalls_output (opt , recalls_netvlad , recalls_patchnetvlad , n_values ):
8890    if  not  exists (opt .result_save_folder ):
8991        os .mkdir (opt .result_save_folder )
90-     with  open (join (opt .result_save_folder , 'recalls.txt' ), 'w' ) as  rec_out :
92+     outfile  =  join (opt .result_save_folder , 'recalls.txt' )
93+     print ('Writing recalls to' , outfile )
94+     with  open (outfile , 'w' ) as  rec_out :
9195        for  n  in  n_values :
9296            rec_out .write ("Recall {}@{}: {:.4f}\n " .format ('NetVLAD' , n , recalls_netvlad [n ]))
9397        for  n  in  n_values :
@@ -120,7 +124,8 @@ def feature_match(eval_set, device, opt, config):
120124    if  config ['feature_match' ]['pred_input_path' ] !=  'None' :
121125        predictions  =  np .load (config ['feature_match' ]['pred_input_path' ])  # optionally load predictions from a np file 
122126    else :
123-         if  opt .ground_truth_path .split ('/' )[1 ][:- 4 ] ==  'tokyo247' :
127+         if  opt .ground_truth_path  and  'tokyo247'  in  opt .ground_truth_path :
128+             print ('Tokyo24/7: Selecting only one of the 12 cutouts per panorama' )
124129            # followed nnSearchPostprocess in https://github.com/Relja/netvlad/blob/master/datasets/dbTokyo247.m 
125130            # noinspection PyArgumentList 
126131            _ , predictions  =  faiss_index .search (qFeat , max (n_values ) *  12 )  # 12 cutouts per panorama 
@@ -133,7 +138,7 @@ def feature_match(eval_set, device, opt, config):
133138            predictions  =  np .array (predictions_new )
134139        else :
135140            # noinspection PyArgumentList 
136-             _ , predictions  =  faiss_index .search (qFeat , max (n_values ))
141+             _ , predictions  =  faiss_index .search (qFeat , min ( len ( qFeat ),  max (n_values ) ))
137142
138143    reranked_predictions  =  local_matcher (predictions , eval_set , input_query_local_features_prefix ,
139144                                         input_index_local_features_prefix , config , device )
@@ -142,16 +147,19 @@ def feature_match(eval_set, device, opt, config):
142147    write_kapture_output (opt , eval_set , predictions , 'NetVLAD_predictions.txt' )
143148    write_kapture_output (opt , eval_set , reranked_predictions , 'PatchNetVLAD_predictions.txt' )
144149
145-     print ('Finished matching features. About to eval GT if GT was provided ' )
150+     print ('Finished matching features.' )
146151
147152    # for each query get those within threshold distance 
148153    if  opt .ground_truth_path  is  not   None :
154+         print ('Calculating recalls using ground truth.' )
149155        gt  =  eval_set .get_positives ()
150156
151157        global_recalls  =  compute_recall (gt , predictions , eval_set .numQ , n_values , 'NetVLAD' )
152158        local_recalls  =  compute_recall (gt , reranked_predictions , eval_set .numQ , n_values , 'PatchNetVLAD' )
153159
154160        write_recalls_output (opt , global_recalls , local_recalls , n_values )
161+     else :
162+         print ('No ground truth was provided; not calculating recalls.' )
155163
156164
157165def  main ():
0 commit comments