diff --git a/dpg_bench/compute_dpg_bench.py b/dpg_bench/compute_dpg_bench.py
index aa2014b..3166bbf 100644
--- a/dpg_bench/compute_dpg_bench.py
+++ b/dpg_bench/compute_dpg_bench.py
@@ -70,9 +70,6 @@ def prepare_dpg_data(args):
     # 'item_id', 'text', 'keywords', 'proposition_id', 'dependency', 'category_broad', 'category_detailed', 'tuple', 'question_natural_language'
     data = pd.read_csv(args.csv)
     for i, line in data.iterrows():
-        if i == 0:
-            continue
-
         current_id = line.item_id
         qid = int(line.proposition_id)
         dependency_list_str = line.dependency.split(',')
@@ -122,12 +119,12 @@ def compute_dpg_one_sample(args, question_dict, image_path, vqa_model, resolutio
         qid2question = value['qid2question']
         qid2dependency = value['qid2dependency']
-        qid2answer = dict()
-        qid2scores = dict()
-        qid2validity = dict()
-
         scores = []
+        all_crops_qid2scores_orig = []
 
         for crop_tuple in crop_tuples:
+            qid2answer = dict()
+            qid2scores = dict()
+            qid2validity = dict()
             cropped_image = crop_image(generated_image, crop_tuple)
             for id, question in qid2question.items():
                 answer = vqa_model.vqa(cropped_image, question)
@@ -136,6 +133,7 @@ def compute_dpg_one_sample(args, question_dict, image_path, vqa_model, resolutio
                 with open(args.res_path.replace('.txt', '_detail.txt'), 'a') as f:
                     f.write(image_path + ', ' + str(crop_tuple) + ', ' + question + ', ' + answer + '\n')
             qid2scores_orig = qid2scores.copy()
+            all_crops_qid2scores_orig.append(qid2scores_orig)
 
             for id, parent_ids in qid2dependency.items():
                 # zero-out scores if parent questions are answered 'no'
@@ -158,7 +156,16 @@ def compute_dpg_one_sample(args, question_dict, image_path, vqa_model, resolutio
         with open(args.res_path, 'a') as f:
             f.write(image_path + ', ' + ', '.join(str(i) for i in scores) + ', ' + str(average_score) + '\n')
 
-    return average_score, qid2tuple, qid2scores_orig
+    aggregated_qid2scores_orig = defaultdict(float)
+    num_crops = len(all_crops_qid2scores_orig)
+    if num_crops > 0:
+        for crop_scores in all_crops_qid2scores_orig:
+            for qid, s in crop_scores.items():
+                aggregated_qid2scores_orig[qid] += s
+        for qid in aggregated_qid2scores_orig:
+            aggregated_qid2scores_orig[qid] /= num_crops
+
+    return average_score, qid2tuple, dict(aggregated_qid2scores_orig)
 
 
 def main():
@@ -188,9 +195,7 @@ def main():
     vqa_model = getattr(vqa_model, 'module', vqa_model)
 
     filename_list = os.listdir(args.image_root_path)
-    num_each_rank = len(filename_list) / accelerator.num_processes
-    local_rank = accelerator.process_index
-    local_filename_list = filename_list[round(local_rank * num_each_rank) : round((local_rank + 1) * num_each_rank)]
+    local_filename_list = filename_list[accelerator.process_index::accelerator.num_processes]
 
     local_scores = []
     local_category2scores = defaultdict(list)