Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions dpg_bench/compute_dpg_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,6 @@ def prepare_dpg_data(args):
# 'item_id', 'text', 'keywords', 'proposition_id', 'dependency', 'category_broad', 'category_detailed', 'tuple', 'question_natural_language'
data = pd.read_csv(args.csv)
for i, line in data.iterrows():
if i == 0:
continue

current_id = line.item_id
qid = int(line.proposition_id)
dependency_list_str = line.dependency.split(',')
Expand Down Expand Up @@ -122,12 +119,12 @@ def compute_dpg_one_sample(args, question_dict, image_path, vqa_model, resolutio
qid2question = value['qid2question']
qid2dependency = value['qid2dependency']

qid2answer = dict()
qid2scores = dict()
qid2validity = dict()

scores = []
all_crops_qid2scores_orig = []
for crop_tuple in crop_tuples:
qid2answer = dict()
qid2scores = dict()
qid2validity = dict()
cropped_image = crop_image(generated_image, crop_tuple)
for id, question in qid2question.items():
answer = vqa_model.vqa(cropped_image, question)
Expand All @@ -136,6 +133,7 @@ def compute_dpg_one_sample(args, question_dict, image_path, vqa_model, resolutio
with open(args.res_path.replace('.txt', '_detail.txt'), 'a') as f:
f.write(image_path + ', ' + str(crop_tuple) + ', ' + question + ', ' + answer + '\n')
qid2scores_orig = qid2scores.copy()
all_crops_qid2scores_orig.append(qid2scores_orig)

for id, parent_ids in qid2dependency.items():
# zero-out scores if parent questions are answered 'no'
Expand All @@ -158,7 +156,16 @@ def compute_dpg_one_sample(args, question_dict, image_path, vqa_model, resolutio
with open(args.res_path, 'a') as f:
f.write(image_path + ', ' + ', '.join(str(i) for i in scores) + ', ' + str(average_score) + '\n')

return average_score, qid2tuple, qid2scores_orig
aggregated_qid2scores_orig = defaultdict(float)
num_crops = len(all_crops_qid2scores_orig)
if num_crops > 0:
for crop_scores in all_crops_qid2scores_orig:
for qid, s in crop_scores.items():
aggregated_qid2scores_orig[qid] += s
for qid in aggregated_qid2scores_orig:
aggregated_qid2scores_orig[qid] /= num_crops

return average_score, qid2tuple, dict(aggregated_qid2scores_orig)


def main():
Expand Down Expand Up @@ -188,9 +195,7 @@ def main():
vqa_model = getattr(vqa_model, 'module', vqa_model)

filename_list = os.listdir(args.image_root_path)
num_each_rank = len(filename_list) / accelerator.num_processes
local_rank = accelerator.process_index
local_filename_list = filename_list[round(local_rank * num_each_rank) : round((local_rank + 1) * num_each_rank)]
local_filename_list = filename_list[accelerator.process_index::accelerator.num_processes]

local_scores = []
local_category2scores = defaultdict(list)
Expand Down