Skip to content

Commit

Permalink
Plotting functions (#16)
Browse files Browse the repository at this point in the history
* general plot utilities that will be useful for loading CI score files, concatenating JSON/score files into dataframes, and keeping the same plotting parameters across figures.
* plotting scripts to generate figures from paper

---------

Co-authored-by: aleto1999 <[email protected]>
  • Loading branch information
vyaivo and aleto1999 authored Oct 30, 2024
1 parent 211bfad commit 29d1a3c
Show file tree
Hide file tree
Showing 5 changed files with 412 additions and 0 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# Output figures
*.png
*.pdf
*.fig

# Ignore changes to jupyter notebooks
*/*.ipynb

Expand Down
74 changes: 74 additions & 0 deletions plots/plot_gold-search-recall.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import plot_utils


def _draw_em_recall(axis, frame, label):
    """Draw EM recall against ``condition`` on ``axis`` with CI error bars.

    ``frame`` must provide the columns ``condition``, ``em_rec_mean``,
    ``em_rec_ci_lower`` and ``em_rec_ci_upper``.
    """
    xs = np.array(frame['condition'])
    center = np.array(frame['em_rec_mean'])
    bounds = np.stack([frame['em_rec_ci_lower'], frame['em_rec_ci_upper']])
    # errorbar expects distances from the mean, not absolute CI bounds.
    axis.errorbar(xs, center, yerr=abs(bounds - center), label=label)


plot_utils.set_plot_params()
datasets = ['asqa', 'nq']
models = ['Llama', 'Mistral']

# Load the CI score files of both experiments into data frames.
print("Loading data from gold-recall experiments...")
file_list, field_list = plot_utils.load_metric_ci_files(
    datasets, models, conditions=[0.5, 0.7, 0.9, 1.0], subfolder="gold-recall"
)
goldrec_df = plot_utils.compile_metric_df(file_list, field_list)
print("Loading data from search-recall experiments...")
file_list, field_list = plot_utils.load_metric_ci_files(
    datasets, models, conditions=[0.7, 0.9, 0.95], subfolder="search-recall"
)
searchrec_df = plot_utils.compile_metric_df(file_list, field_list)

# One figure per model: gold-document recall on the left, search recall on the right.
df_datasets = ['ASQA', 'NQ', 'NQ-nocite']  # plotting order of the datasets
colors = plt.cm.tab10(np.linspace(0, 1, 10))
for model in models:
    fig, axes = plt.subplots(1, 2, figsize=(5.5, 2), width_ratios=[2, 1])
    gold_ax, search_ax = axes[0], axes[1]

    gold_rows = goldrec_df[goldrec_df['model'] == model]
    for idx, name in enumerate(df_datasets):
        _draw_em_recall(gold_ax, gold_rows[gold_rows['dataset'] == name], name)

        # Shade the CI obtained at recall = 1.0 across both panels.
        full_recall = gold_rows[(gold_rows['dataset'] == name) & (gold_rows['condition'] == 1.0)]
        if not len(full_recall):
            print(f"Could not find data for {model}, {name}, skipping...")
            continue
        ci_low = full_recall['em_rec_ci_lower'].item()
        ci_high = full_recall['em_rec_ci_upper'].item()
        band_color = colors[idx][:3]
        gold_ax.axhspan(ci_low, ci_high, color=band_color, alpha=0.15)
        search_ax.axhspan(ci_low, ci_high, color=band_color, alpha=0.15)

    gold_ax.set_ylim([20, 90])
    gold_ax.set_xlabel('Gold Document Recall')
    gold_ax.set_ylabel('EM Recall')
    gold_ax.xaxis.set_major_locator(matplotlib.ticker.MultipleLocator(0.1))
    gold_ax.xaxis.set_minor_locator(matplotlib.ticker.MultipleLocator(0.05))
    gold_ax.yaxis.set_minor_locator(matplotlib.ticker.MultipleLocator(5))
    gold_ax.grid(which='minor', alpha=0.5)
    gold_ax.invert_xaxis()
    gold_ax.legend(loc='lower left', ncols=3, frameon=True)

    search_rows = searchrec_df[searchrec_df['model'] == model]
    for name in df_datasets:
        _draw_em_recall(search_ax, search_rows[search_rows['dataset'] == name], name)

    search_ax.set_ylim([20, 90])
    search_ax.set_xlabel('Search Recall@10')
    search_ax.xaxis.set_minor_locator(matplotlib.ticker.MultipleLocator(0.05))
    search_ax.yaxis.set_minor_locator(matplotlib.ticker.MultipleLocator(5))
    search_ax.grid(which='minor', alpha=0.5)
    search_ax.invert_xaxis()
    plt.tight_layout()

    # Write the figure to disk and release it before the next model.
    figname = f'plots/gold-search-recall-{model}.png'
    plt.savefig(figname)
    print(f'Saved {figname}!')
    plt.close(fig)
    del fig, axes

94 changes: 94 additions & 0 deletions plots/plot_ndoc-recall.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import plot_utils
import seaborn as sns


def _baseline_em_recall(df, model, dataset, condition):
    """Return the first ``em_rec_mean`` matching a baseline condition.

    Returns ``None`` when ``df`` has no row for the given ``model``,
    ``dataset`` and ``condition``.
    """
    rows = df[(df['model'] == model) & (df['dataset'] == dataset) & (df['condition'] == condition)]
    if len(rows) == 0:
        return None
    return rows["em_rec_mean"].values[0]


def main():
    """Plot EM recall against k for every dataset, one figure per model.

    Each figure is a 2x2 grid with one panel per dataset. A panel holds one
    error-bar curve per retriever plus dashed horizontal lines for the
    'no-context' and 'gold' conditions. Figures are written to
    ``plots/ndoc-reader-acc-{model}.png`` (with a ``-shared-y`` suffix when
    ``shared_y`` is enabled).
    """
    plot_utils.set_plot_params()
    colors = sns.color_palette()

    datasets = ['asqa', 'qampari', 'nq', 'nq-nocite']
    retrievers = ['bge-base', 'colbert']
    models = ['Llama', 'Mistral']
    k_vals = [1, 2, 3, 4, 5, 10, 20, 50, 100]
    shared_y = False

    # (row, column) of each dataset's panel in the 2x2 grid
    subplot_indices = {
        'asqa': [0, 0],
        'qampari': [0, 1],
        'nq': [1, 0],
        'nq-nocite': [1, 1]
    }
    # (condition, line color, legend label) of the horizontal reference lines
    baselines = (
        ('no-context', colors[3], 'No Context'),
        ('gold', colors[2], 'Gold'),
    )

    # grab the files and generate data frames
    print("Loading data from ndoc experiments...")
    file_list, field_list = plot_utils.load_metric_ndoc_files(
        datasets, models, retrievers, conditions=k_vals
    )

    df = plot_utils.compile_metric_df(file_list, extra_fields=field_list, nested=False)
    print(df.describe())

    # a plot for each model, with all retrievers overlaid per dataset
    print('\n\nGenerating plots...')
    for model in models:
        fig, ax = plt.subplots(2, 2, figsize=(5.5, 4), width_ratios=[1, 1])
        for dataset in datasets:
            ax_x, ax_y = subplot_indices[dataset]
            panel = ax[ax_x][ax_y]
            axh_plotted = False  # draw the no-context and gold lines only once per dataset
            for retriever in retrievers:
                df_filtered = df[(df['model'] == model) & (df['retriever'] == retriever) & (df['dataset'] == dataset)]
                if len(df_filtered) == 0:
                    print(f"Could not find data for {dataset}, {model}, {retriever}, skipping...")
                    continue
                print(f"Plotting {model}, {retriever}, {dataset}...")

                if not axh_plotted:
                    for condition, color, label in baselines:
                        value = _baseline_em_recall(df, model, dataset, condition)
                        if value is None:
                            # Skip a missing baseline instead of crashing on an empty lookup.
                            print(f"Could not find {condition} data for {dataset}, {model}, skipping...")
                            continue
                        panel.axhline(y=value, color=color, linestyle='--', label=label)
                    axh_plotted = True

                # plot em rec mean with error bars
                x = np.array(df_filtered['condition'])
                m = np.array(df_filtered['em_rec_mean'])
                yerr = np.stack([df_filtered['em_rec_ci_lower'], df_filtered['em_rec_ci_upper']], axis=0)
                # errorbar expects distances from the mean, not absolute CI bounds
                panel.errorbar(x, m, yerr=abs(yerr - m), label=retriever)

                if dataset == 'nq-nocite':
                    panel.set_title("NQ (No Citations)")
                else:
                    panel.set_title(f'{dataset.upper()}')

                if ax_y == 0:  # only show y-axis label on the left column
                    panel.set_ylabel('EM Recall')

                if ax_x == 1:  # only show x-axis label on the bottom row
                    panel.set_xlabel('k')

                # the NQ panels always use the fixed range; the others only when sharing y
                if shared_y or 'nq' in dataset:
                    panel.set_ylim(0, 90)

        # legend with only labels from first plot (top left)
        handles, labels = ax[0][0].get_legend_handles_labels()
        fig.legend(handles, labels, loc='lower center', ncol=4)

        # Save the figure to an output file
        fig.tight_layout()
        fig.subplots_adjust(bottom=0.17)  # fix overlap with x axis label

        figname = f'plots/ndoc-reader-acc-{model}.png'
        if shared_y:
            figname = f'plots/ndoc-reader-acc-{model}-shared-y.png'
        fig.savefig(figname)
        # Release the figure so figures do not accumulate across models.
        plt.close(fig)


if __name__ == "__main__":
    main()
54 changes: 54 additions & 0 deletions plots/plot_noise_percentile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import numpy as np
import matplotlib.pyplot as plt
import plot_utils
import seaborn as sns

def main():
    """Plot EM recall against noise percentile, one figure per dataset/model.

    For every retriever the figure shows a solid curve with CI error bars
    for the noisy conditions and a dashed horizontal line for the
    retriever-only result (the row whose ``condition`` is 0). Figures are
    written to ``plots/rand_percentile_recall_{dataset}_{model}.png``.
    """
    plot_utils.set_plot_params()
    colors = sns.color_palette()

    datasets = ['asqa']
    retrievers = ['bge-base', 'gold']
    models = ['Mistral']
    percentiles = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]

    # grab the files and generate data frames
    print("Loading data from noise experiments...")
    file_list, field_list = plot_utils.load_metric_noise_files(
        datasets, models, retrievers, conditions=percentiles, subfolder="noise"
    )
    df = plot_utils.compile_metric_df(file_list, field_list)

    print("Generating plots...")
    for dataset in datasets:
        df_dataset = df[df['dataset'] == dataset]
        for model in models:
            # NOTE(review): the data is not filtered by `model`; that is harmless
            # while `models` holds a single entry, but confirm before adding more.
            fig, ax = plt.subplots(figsize=(5, 2))

            for idx, retriever in enumerate(retrievers):
                df_ret = df_dataset[df_dataset['retriever'] == retriever]
                ret_only = df_ret[df_ret['condition'] == 0]  # retriever only
                df_ret = df_ret[df_ret['condition'] != 0]  # remove retriever only

                # plot the data
                x = np.array(df_ret['condition'])
                m = np.array(df_ret['em_rec_mean'])
                yerr = np.stack([df_ret['em_rec_ci_lower'], df_ret['em_rec_ci_upper']], axis=0)
                ax.plot(x, m, label=f"{retriever} + noise", color=colors[idx])
                # errorbar expects distances from the mean, not absolute CI bounds
                ax.errorbar(x, m, yerr=abs(yerr - m), color=colors[idx])

                if len(ret_only) == 0:
                    # Skip the reference line instead of crashing on an empty lookup.
                    print(f"Could not find retriever-only data for {dataset}, {retriever}, skipping...")
                else:
                    ax.axhline(ret_only['em_rec_mean'].values[0], linestyle='--', label=retriever, color=colors[idx])

            ax.set_ylabel("EM Recall")
            ax.set_xlabel("Noise Percentile")
            ax.set_title(f"{dataset.upper()} ({model})")
            ax.legend(bbox_to_anchor=(1.0, 1.0), loc='upper left')
            ax.set_ylim([20, 55])
            fig.tight_layout()
            fig.savefig(f"plots/rand_percentile_recall_{dataset}_{model}.png", dpi=300)
            # Release the figure so figures do not accumulate across iterations.
            plt.close(fig)


if __name__ == "__main__":
    main()
Loading

0 comments on commit 29d1a3c

Please sign in to comment.