Skip to content

Commit

Permalink
Merge pull request BVLC#4048 from achalddave/python_plot_exit_properly
Browse files Browse the repository at this point in the history
(Minor) Exit on error and report argument error details
  • Loading branch information
longjon committed May 4, 2016
2 parents de8ac32 + c2656f0 commit 3d41c8a
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions tools/extra/plot_training_log.py.example
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ import matplotlib.legend as lgd
import matplotlib.markers as mks

def get_log_parsing_script():
dirname = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
dirname = os.path.dirname(os.path.abspath(inspect.getfile(
inspect.currentframe())))
return dirname + '/parse_log.sh'

def get_log_file_suffix():
Expand Down Expand Up @@ -61,16 +62,17 @@ def get_data_file_type(chart_type):
return data_file_type

def get_data_file(chart_type, path_to_log):
return os.path.basename(path_to_log) + '.' + get_data_file_type(chart_type).lower()
return (os.path.basename(path_to_log) + '.' +
get_data_file_type(chart_type).lower())

def get_field_descriptions(chart_type):
description = get_chart_type_description(chart_type).split(
get_chart_type_description_separator())
y_axis_field = description[0]
x_axis_field = description[1]
return x_axis_field, y_axis_field
return x_axis_field, y_axis_field

def get_field_indecies(x_axis_field, y_axis_field):
def get_field_indices(x_axis_field, y_axis_field):
data_file_type = get_data_file_type(chart_type)
fields = create_field_index()[0][data_file_type]
return fields[x_axis_field], fields[y_axis_field]
Expand Down Expand Up @@ -111,7 +113,7 @@ def plot_chart(chart_type, path_to_png, path_to_log_list):
os.system('%s %s' % (get_log_parsing_script(), path_to_log))
data_file = get_data_file(chart_type, path_to_log)
x_axis_field, y_axis_field = get_field_descriptions(chart_type)
x, y = get_field_indecies(x_axis_field, y_axis_field)
x, y = get_field_indices(x_axis_field, y_axis_field)
data = load_data(data_file, x, y)
## TODO: more systematic color cycle for lines
color = [random.random(), random.random(), random.random()]
Expand All @@ -138,8 +140,8 @@ def plot_chart(chart_type, path_to_png, path_to_log_list):
plt.legend(loc = legend_loc, ncol = 1) # ajust ncol to fit the space
plt.title(get_chart_type_description(chart_type))
plt.xlabel(x_axis_field)
plt.ylabel(y_axis_field)
plt.savefig(path_to_png)
plt.ylabel(y_axis_field)
plt.savefig(path_to_png)
plt.show()

def print_help():
Expand All @@ -160,28 +162,30 @@ Supported chart types:""" % (len(get_supported_chart_types()) - 1,
num = len(supported_chart_types)
for i in xrange(num):
print ' %d: %s' % (i, supported_chart_types[i])
exit
sys.exit()

def is_valid_chart_type(chart_type):
return chart_type >= 0 and chart_type < len(get_supported_chart_types())

if __name__ == '__main__':
if len(sys.argv) < 4:
print_help()
else:
chart_type = int(sys.argv[1])
if not is_valid_chart_type(chart_type):
print '%s is not a valid chart type.' % chart_type
print_help()
path_to_png = sys.argv[2]
if not path_to_png.endswith('.png'):
print 'Path must ends with png' % path_to_png
exit
sys.exit()
path_to_logs = sys.argv[3:]
for path_to_log in path_to_logs:
if not os.path.exists(path_to_log):
print 'Path does not exist: %s' % path_to_log
exit
sys.exit()
if not path_to_log.endswith(get_log_file_suffix()):
print 'Log file must end in %s.' % get_log_file_suffix()
print_help()
## plot_chart accpets multiple path_to_logs
plot_chart(chart_type, path_to_png, path_to_logs)

0 comments on commit 3d41c8a

Please sign in to comment.