@@ -82,11 +82,23 @@ def val(eval_set, model, encoder_dim, device, opt, config, writer, epoch_num=0,
8282 tqdm .write ('====> Calculating recall @ N' )
8383 n_values = [1 , 5 , 10 , 20 , 50 , 100 ]
8484
85- _ , predictions = faiss_index .search (qFeat , max (n_values ))
86-
8785 # for each query get those within threshold distance
8886 gt = eval_set .all_pos_indices
8987
88+ # any combination of mapillary cities will work as a val set
89+ qEndPosTot = 0
90+ dbEndPosTot = 0
91+ for cityNum , (qEndPos , dbEndPos ) in enumerate (zip (eval_set .qEndPosList , eval_set .dbEndPosList )):
92+ faiss_index = faiss .IndexFlatL2 (pool_size )
93+ faiss_index .add (dbFeat [dbEndPosTot :dbEndPosTot + dbEndPos , :])
94+ _ , preds = faiss_index .search (qFeat [qEndPosTot :qEndPosTot + qEndPos , :], max (n_values ))
95+ if cityNum == 0 :
96+ predictions = preds
97+ else :
98+ predictions = np .vstack ((predictions , preds ))
99+ qEndPosTot += qEndPos
100+ dbEndPosTot += dbEndPos
101+
90102 correct_at_n = np .zeros (len (n_values ))
91103 # TODO can we do this on the matrix in one go?
92104 for qIx , pred in enumerate (predictions ):
@@ -104,4 +116,4 @@ def val(eval_set, model, encoder_dim, device, opt, config, writer, epoch_num=0,
104116 if write_tboard :
105117 writer .add_scalar ('Val/Recall@' + str (n ), recall_at_n [i ], epoch_num )
106118
107- return all_recalls
119+ return all_recalls
0 commit comments