-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* general plot utilities that will be useful for loading CI score files, concatenating JSON/score files into dataframes, and keeping the same plotting parameters across figures. * plotting scripts to generate figures from paper --------- Co-authored-by: aleto1999 <[email protected]>
- Loading branch information
Showing
5 changed files
with
412 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,8 @@ | ||
# Output figures | ||
*.png | ||
*.fig | ||
|
||
# Ignore changes to jupyter notebooks | ||
*/*.ipynb | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import plot_utils
|
||
|
||
# Figure script: EM recall versus gold-document recall (left panel) and versus
# search recall@10 (right panel). One PNG is written per reader model.
plot_utils.set_plot_params()
datasets = ['asqa', 'nq']
models = ['Llama', 'Mistral']

# Grab the files and generate data frames
print("Loading data from gold-recall experiments...")
file_list, field_list = plot_utils.load_metric_ci_files(
    datasets, models, conditions=[0.5, 0.7, 0.9, 1.0], subfolder="gold-recall"
)
goldrec_df = plot_utils.compile_metric_df(file_list, field_list)
print("Loading data from search-recall experiments...")
file_list, field_list = plot_utils.load_metric_ci_files(
    datasets, models, conditions=[0.7, 0.9, 0.95], subfolder="search-recall"
)
searchrec_df = plot_utils.compile_metric_df(file_list, field_list)

# Now generate the plot itself!
df_datasets = ['ASQA', 'NQ', 'NQ-nocite']  # order that I want to plot the datasets
# NOTE(review): colors[i] is assumed to match the colour of the i-th errorbar
# line (default tab10 cycle) -- confirm against plot_utils.set_plot_params().
colors = plt.cm.tab10(np.linspace(0, 1, 10))
for model in models:
    f, ax = plt.subplots(1, 2, figsize=(5.5, 2), width_ratios=[2, 1])

    # Left panel: EM recall per gold-recall condition, one line per dataset.
    p0 = goldrec_df[goldrec_df['model'] == model]
    for i, d in enumerate(df_datasets):
        this_plot = p0[p0['dataset'] == d]
        x = np.array(this_plot['condition'])
        m = np.array(this_plot['em_rec_mean'])
        # errorbar expects distances from the mean, not absolute CI bounds.
        yerr = np.stack([this_plot['em_rec_ci_lower'], this_plot['em_rec_ci_upper']], axis=0)
        ax[0].errorbar(x, m, yerr=abs(yerr - m), label=d)

        # Plot shaded bar for recall = 1.0 across both panels.
        rdf = this_plot[this_plot['condition'] == 1.0]
        if len(rdf) == 0:
            print(f"Could not find data for {model}, {d}, skipping...")
            continue
        ci_lower = rdf['em_rec_ci_lower'].item()
        ci_upper = rdf['em_rec_ci_upper'].item()
        for panel in ax:
            panel.axhspan(ci_lower, ci_upper, color=colors[i, :3], alpha=0.15)
    ax[0].set_ylim([20, 90])
    ax[0].set_xlabel('Gold Document Recall')
    ax[0].set_ylabel('EM Recall')
    ax[0].xaxis.set_major_locator(matplotlib.ticker.MultipleLocator(0.1))
    ax[0].xaxis.set_minor_locator(matplotlib.ticker.MultipleLocator(0.05))
    ax[0].yaxis.set_minor_locator(matplotlib.ticker.MultipleLocator(5))
    ax[0].grid(which='minor', alpha=0.5)
    ax[0].invert_xaxis()
    # Add legend
    ax[0].legend(loc='lower left', ncols=3, frameon=True)

    # Right panel: EM recall per search-recall condition, same dataset order.
    p1 = searchrec_df[searchrec_df['model'] == model]
    for d in df_datasets:
        this_plot = p1[p1['dataset'] == d]
        x = np.array(this_plot['condition'])
        m = np.array(this_plot['em_rec_mean'])
        yerr = np.stack([this_plot['em_rec_ci_lower'], this_plot['em_rec_ci_upper']], axis=0)
        ax[1].errorbar(x, m, yerr=abs(yerr - m), label=d)
    ax[1].set_ylim([20, 90])
    ax[1].set_xlabel('Search Recall@10')
    ax[1].xaxis.set_minor_locator(matplotlib.ticker.MultipleLocator(0.05))
    ax[1].yaxis.set_minor_locator(matplotlib.ticker.MultipleLocator(5))
    ax[1].grid(which='minor', alpha=0.5)
    ax[1].invert_xaxis()
    f.tight_layout()

    # Save the figure to an output file. savefig does not create missing
    # directories, so make sure the target folder exists first.
    figname = f'plots/gold-search-recall-{model}.png'
    os.makedirs(os.path.dirname(figname), exist_ok=True)
    f.savefig(figname)
    print(f'Saved {figname}!')
    plt.close(f)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import matplotlib | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import plot_utils | ||
import seaborn as sns | ||
|
||
|
||
def _baseline_em_recall(df, model, dataset, condition):
    """Return the mean EM recall of a baseline condition, or None if absent.

    Uses the first row of ``df`` whose ``model``, ``dataset`` and
    ``condition`` columns equal the given values.
    """
    rows = df[(df['model'] == model) & (df['dataset'] == dataset) & (df['condition'] == condition)]
    if len(rows) == 0:
        return None
    return rows['em_rec_mean'].values[0]


def main():
    """Plot EM recall against the number of retrieved documents (k).

    Writes one 2x2 figure per model to ``plots/``, with one panel per dataset.
    Each panel shows one error-bar line per retriever (bars span
    ``em_rec_ci_lower``..``em_rec_ci_upper``) plus dashed horizontal lines for
    the ``no-context`` and ``gold`` conditions when those rows exist.
    """
    import os  # function-scope import: only needed to create the output folder

    plot_utils.set_plot_params()
    colors = sns.color_palette()

    datasets = ['asqa', 'qampari', 'nq', 'nq-nocite']
    retrievers = ['bge-base', 'colbert']
    models = ['Llama', 'Mistral']
    k_vals = [1, 2, 3, 4, 5, 10, 20, 50, 100]
    shared_y = False
    acc_label = 'EM Recall'

    # (row, column) of each dataset's panel in the 2x2 grid
    subplot_indices = {
        'asqa': [0, 0],
        'qampari': [0, 1],
        'nq': [1, 0],
        'nq-nocite': [1, 1],
    }
    # Baseline conditions drawn as horizontal lines: (condition, colour, label)
    baselines = [
        ('no-context', colors[3], 'No Context'),
        ('gold', colors[2], 'Gold'),
    ]

    # grab the files and generate data frames
    print("Loading data from ndoc experiments...")
    file_list, field_list = plot_utils.load_metric_ndoc_files(
        datasets, models, retrievers, conditions=k_vals
    )
    df = plot_utils.compile_metric_df(file_list, extra_fields=field_list, nested=False)
    print(df.describe())

    # a plot for each model, with one line per retriever in every panel
    print('\n\nGenerating plots...')
    for model in models:
        fig, ax = plt.subplots(2, 2, figsize=(5.5, 4), width_ratios=[1, 1])
        for dataset in datasets:
            ax_x, ax_y = subplot_indices[dataset]
            panel = ax[ax_x][ax_y]
            axh_plotted = False  # baseline lines are drawn only once per dataset
            for retriever in retrievers:
                df_filtered = df[
                    (df['model'] == model)
                    & (df['retriever'] == retriever)
                    & (df['dataset'] == dataset)
                ]
                if len(df_filtered) == 0:
                    print(f"Could not find data for {dataset}, {model}, {retriever}, skipping...")
                    continue
                print(f"Plotting {model}, {retriever}, {dataset}...")

                if not axh_plotted:  # plot the no-context and gold lines
                    for condition, color, label in baselines:
                        value = _baseline_em_recall(df, model, dataset, condition)
                        if value is None:
                            # A missing baseline should not abort the whole figure.
                            print(f"Could not find {condition} data for {dataset}, {model}, skipping...")
                            continue
                        panel.axhline(y=value, color=color, linestyle='--', label=label)
                    axh_plotted = True

                # plot em rec mean with error bars; errorbar expects distances
                # from the mean rather than absolute CI bounds
                x = np.array(df_filtered['condition'])
                m = np.array(df_filtered['em_rec_mean'])
                yerr = np.stack([df_filtered['em_rec_ci_lower'], df_filtered['em_rec_ci_upper']], axis=0)
                panel.errorbar(x, m, yerr=abs(yerr - m), label=retriever)

                if dataset == 'nq-nocite':
                    panel.set_title("NQ (No Citations)")
                else:
                    panel.set_title(f'{dataset.upper()}')

                if ax_y == 0:  # only show y-axis label on the left column
                    panel.set_ylabel(acc_label)

                if ax_x == 1:  # only show x-axis label on the bottom row
                    panel.set_xlabel('k')

                # NQ panels always use the fixed range; others only when shared
                if shared_y or 'nq' in dataset:
                    panel.set_ylim(0, 90)

        # legend with only labels from first plot (top left)
        handles, labels = ax[0][0].get_legend_handles_labels()
        fig.legend(handles, labels, loc='lower center', ncol=4)

        fig.tight_layout()
        fig.subplots_adjust(bottom=0.17)  # fix overlap with x axis label

        # Save the figure to an output file. savefig does not create missing
        # directories, so make sure the target folder exists first.
        figname = f'plots/ndoc-reader-acc-{model}.png'
        if shared_y:
            figname = f'plots/ndoc-reader-acc-{model}-shared-y.png'
        os.makedirs(os.path.dirname(figname), exist_ok=True)
        fig.savefig(figname)
        print(f'Saved {figname}!')
        plt.close(fig)  # release the figure before building the next model's
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import plot_utils | ||
import seaborn as sns | ||
|
||
def main():
    """Plot EM recall against the noise percentile for each retriever.

    Writes one figure per (dataset, model) pair to ``plots/``. For every
    retriever it draws the EM recall of the non-zero conditions with error
    bars (spanning ``em_rec_ci_lower``..``em_rec_ci_upper``), plus a dashed
    horizontal line at the EM recall of the condition-0 (retriever only) row.
    """
    import os  # function-scope import: only needed to create the output folder

    plot_utils.set_plot_params()
    colors = sns.color_palette()

    datasets = ['asqa']
    retrievers = ['bge-base', 'gold']
    models = ['Mistral']
    percentiles = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]

    # grab the files and generate data frames
    print("Loading data from noise experiments...")
    file_list, field_list = plot_utils.load_metric_noise_files(
        datasets, models, retrievers, conditions=percentiles, subfolder="noise"
    )
    df = plot_utils.compile_metric_df(file_list, field_list)

    print("Generating plots...")
    for dataset in datasets:
        for model in models:
            fig, ax = plt.subplots(figsize=(5, 2))

            # Restrict to this dataset and model so a figure never mixes
            # results from several models.
            # NOTE(review): assumes compile_metric_df emits a 'model' column
            # here too; the filter is skipped when that column is absent.
            selected = df['dataset'] == dataset
            if 'model' in df.columns:
                selected = selected & (df['model'] == model)
            df_subset = df[selected]

            for idx, retriever in enumerate(retrievers):
                df_ret = df_subset[df_subset['retriever'] == retriever]
                if len(df_ret) == 0:
                    print(f"Could not find data for {dataset}, {model}, {retriever}, skipping...")
                    continue
                ret_only = df_ret[df_ret['condition'] == 0]  # retriever only
                df_noise = df_ret[df_ret['condition'] != 0]  # remove retriever only

                # plot the data; errorbar expects distances from the mean
                # rather than absolute CI bounds
                x = np.array(df_noise['condition'])
                m = np.array(df_noise['em_rec_mean'])
                yerr = np.stack([df_noise['em_rec_ci_lower'], df_noise['em_rec_ci_upper']], axis=0)
                ax.plot(x, m, label=f"{retriever} + noise", color=colors[idx])
                ax.errorbar(x, m, yerr=abs(yerr - m), color=colors[idx])

                if len(ret_only) == 0:
                    # A missing baseline should not abort the whole figure.
                    print(f"Could not find retriever-only data for {dataset}, {model}, {retriever}, skipping...")
                    continue
                ax.axhline(
                    ret_only['em_rec_mean'].values[0],
                    linestyle='--', label=retriever, color=colors[idx],
                )

            ax.set_ylabel("EM Recall")
            ax.set_xlabel("Noise Percentile")
            ax.set_title(f"{dataset.upper()} ({model})")
            ax.legend(bbox_to_anchor=(1.0, 1.0), loc='upper left')
            ax.set_ylim([20, 55])
            fig.tight_layout()

            # savefig does not create missing directories, so make sure the
            # target folder exists first.
            figname = f"plots/rand_percentile_recall_{dataset}_{model}.png"
            os.makedirs(os.path.dirname(figname), exist_ok=True)
            fig.savefig(figname, dpi=300)
            plt.close(fig)  # release the figure before building the next one
|
||
|
||
if __name__ == "__main__": | ||
main() |
Oops, something went wrong.