diff --git a/src/plot_train.py b/src/plot_train.py index 32c821d..cf651a5 100644 --- a/src/plot_train.py +++ b/src/plot_train.py @@ -146,7 +146,7 @@ def main(): parser.add_argument('--labels', nargs='+', default=[], help='Labels for each model used in the plot. There should be as many ' 'labels as models.') - parser.add_argument('--metric', default='perplexity-diff', + parser.add_argument('--metric', default='cross-entropy-diff', help='Which metric to plot on the y-axis. Default is difference in ' 'cross entropy between the model and the lower bound on the ' 'validation set.')