app.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. ## @package app
  2. # Module caffe2.python.mint.app
  3. import argparse
  4. import flask
  5. import glob
  6. import numpy as np
  7. import nvd3
  8. import os
  9. import sys
  10. # pyre-fixme[21]: Could not find module `tornado.httpserver`.
  11. import tornado.httpserver
  12. # pyre-fixme[21]: Could not find a module corresponding to import `tornado.wsgi`
  13. import tornado.wsgi
  14. __folder__ = os.path.abspath(os.path.dirname(__file__))
  15. app = flask.Flask(
  16. __name__,
  17. template_folder=os.path.join(__folder__, "templates"),
  18. static_folder=os.path.join(__folder__, "static")
  19. )
  20. args = None
  21. def jsonify_nvd3(chart):
  22. chart.buildcontent()
  23. # Note(Yangqing): python-nvd3 does not seem to separate the built HTML part
  24. # and the script part. Luckily, it seems to be the case that the HTML part is
  25. # only a <div>, which can be accessed by chart.container; the script part,
  26. # while the script part occupies the rest of the html content, which we can
  27. # then find by chart.htmlcontent.find['<script>'].
  28. script_start = chart.htmlcontent.find('<script>') + 8
  29. script_end = chart.htmlcontent.find('</script>')
  30. return flask.jsonify(
  31. result=chart.container,
  32. script=chart.htmlcontent[script_start:script_end].strip()
  33. )
  34. def visualize_summary(filename):
  35. try:
  36. data = np.loadtxt(filename)
  37. except Exception as e:
  38. return 'Cannot load file {}: {}'.format(filename, str(e))
  39. chart_name = os.path.splitext(os.path.basename(filename))[0]
  40. chart = nvd3.lineChart(
  41. name=chart_name + '_summary_chart',
  42. height=args.chart_height,
  43. y_axis_format='.03g'
  44. )
  45. if args.sample < 0:
  46. step = max(data.shape[0] / -args.sample, 1)
  47. else:
  48. step = args.sample
  49. xdata = np.arange(0, data.shape[0], step)
  50. # data should have 4 dimensions.
  51. chart.add_serie(x=xdata, y=data[xdata, 0], name='min')
  52. chart.add_serie(x=xdata, y=data[xdata, 1], name='max')
  53. chart.add_serie(x=xdata, y=data[xdata, 2], name='mean')
  54. chart.add_serie(x=xdata, y=data[xdata, 2] + data[xdata, 3], name='m+std')
  55. chart.add_serie(x=xdata, y=data[xdata, 2] - data[xdata, 3], name='m-std')
  56. return jsonify_nvd3(chart)
  57. def visualize_print_log(filename):
  58. try:
  59. data = np.loadtxt(filename)
  60. if data.ndim == 1:
  61. data = data[:, np.newaxis]
  62. except Exception as e:
  63. return 'Cannot load file {}: {}'.format(filename, str(e))
  64. chart_name = os.path.splitext(os.path.basename(filename))[0]
  65. chart = nvd3.lineChart(
  66. name=chart_name + '_log_chart',
  67. height=args.chart_height,
  68. y_axis_format='.03g'
  69. )
  70. if args.sample < 0:
  71. step = max(data.shape[0] / -args.sample, 1)
  72. else:
  73. step = args.sample
  74. xdata = np.arange(0, data.shape[0], step)
  75. # if there is only one curve, we also show the running min and max
  76. if data.shape[1] == 1:
  77. # We also print the running min and max for the steps.
  78. trunc_size = data.shape[0] / step
  79. running_mat = data[:trunc_size * step].reshape((trunc_size, step))
  80. chart.add_serie(
  81. x=xdata[:trunc_size],
  82. y=running_mat.min(axis=1),
  83. name='running_min'
  84. )
  85. chart.add_serie(
  86. x=xdata[:trunc_size],
  87. y=running_mat.max(axis=1),
  88. name='running_max'
  89. )
  90. chart.add_serie(x=xdata, y=data[xdata, 0], name=chart_name)
  91. else:
  92. for i in range(0, min(data.shape[1], args.max_curves)):
  93. # data should have 4 dimensions.
  94. chart.add_serie(
  95. x=xdata,
  96. y=data[xdata, i],
  97. name='{}[{}]'.format(chart_name, i)
  98. )
  99. return jsonify_nvd3(chart)
  100. def visualize_file(filename):
  101. fullname = os.path.join(args.root, filename)
  102. if filename.endswith('summary'):
  103. return visualize_summary(fullname)
  104. elif filename.endswith('log'):
  105. return visualize_print_log(fullname)
  106. else:
  107. return flask.jsonify(
  108. result='Unsupport file: {}'.format(filename),
  109. script=''
  110. )
  111. @app.route('/')
  112. def index():
  113. files = glob.glob(os.path.join(args.root, "*.*"))
  114. files.sort()
  115. names = [os.path.basename(f) for f in files]
  116. return flask.render_template(
  117. 'index.html',
  118. root=args.root,
  119. names=names,
  120. debug_messages=names
  121. )
  122. @app.route('/visualization/<string:name>')
  123. def visualization(name):
  124. ret = visualize_file(name)
  125. return ret
  126. def main(argv):
  127. parser = argparse.ArgumentParser("The mint visualizer.")
  128. parser.add_argument(
  129. '-p',
  130. '--port',
  131. type=int,
  132. default=5000,
  133. help="The flask port to use."
  134. )
  135. parser.add_argument(
  136. '-r',
  137. '--root',
  138. type=str,
  139. default='.',
  140. help="The root folder to read files for visualization."
  141. )
  142. parser.add_argument(
  143. '--max_curves',
  144. type=int,
  145. default=5,
  146. help="The max number of curves to show in a dump tensor."
  147. )
  148. parser.add_argument(
  149. '--chart_height',
  150. type=int,
  151. default=300,
  152. help="The chart height for nvd3."
  153. )
  154. parser.add_argument(
  155. '-s',
  156. '--sample',
  157. type=int,
  158. default=-200,
  159. help="Sample every given number of data points. A negative "
  160. "number means the total points we will sample on the "
  161. "whole curve. Default 100 points."
  162. )
  163. global args
  164. args = parser.parse_args(argv)
  165. server = tornado.httpserver.HTTPServer(tornado.wsgi.WSGIContainer(app))
  166. server.listen(args.port)
  167. print("Tornado server starting on port {}.".format(args.port))
  168. tornado.ioloop.IOLoop.instance().start()
  169. if __name__ == '__main__':
  170. main(sys.argv[1:])