net_drawer.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412
  1. ## @package net_drawer
  2. # Module caffe2.python.net_drawer
  3. import argparse
  4. import json
  5. import logging
  6. from collections import defaultdict
  7. from caffe2.python import utils
  8. from future.utils import viewitems
  9. logger = logging.getLogger(__name__)
  10. logger.setLevel(logging.INFO)
  11. try:
  12. import pydot
  13. except ImportError:
  14. logger.info(
  15. 'Cannot import pydot, which is required for drawing a network. This '
  16. 'can usually be installed in python with "pip install pydot". Also, '
  17. 'pydot requires graphviz to convert dot files to pdf: in ubuntu, this '
  18. 'can usually be installed with "sudo apt-get install graphviz".'
  19. )
  20. print(
  21. 'net_drawer will not run correctly. Please install the correct '
  22. 'dependencies.'
  23. )
  24. pydot = None
  25. from caffe2.proto import caffe2_pb2
  26. OP_STYLE = {
  27. 'shape': 'box',
  28. 'color': '#0F9D58',
  29. 'style': 'filled',
  30. 'fontcolor': '#FFFFFF'
  31. }
  32. BLOB_STYLE = {'shape': 'octagon'}
  33. def _rectify_operator_and_name(operators_or_net, name):
  34. """Gets the operators and name for the pydot graph."""
  35. if isinstance(operators_or_net, caffe2_pb2.NetDef):
  36. operators = operators_or_net.op
  37. if name is None:
  38. name = operators_or_net.name
  39. elif hasattr(operators_or_net, 'Proto'):
  40. net = operators_or_net.Proto()
  41. if not isinstance(net, caffe2_pb2.NetDef):
  42. raise RuntimeError(
  43. "Expecting NetDef, but got {}".format(type(net)))
  44. operators = net.op
  45. if name is None:
  46. name = net.name
  47. else:
  48. operators = operators_or_net
  49. if name is None:
  50. name = "unnamed"
  51. return operators, name
  52. def _escape_label(name):
  53. # json.dumps is poor man's escaping
  54. return json.dumps(name)
  55. def GetOpNodeProducer(append_output, **kwargs):
  56. def ReallyGetOpNode(op, op_id):
  57. if op.name:
  58. node_name = '%s/%s (op#%d)' % (op.name, op.type, op_id)
  59. else:
  60. node_name = '%s (op#%d)' % (op.type, op_id)
  61. if append_output:
  62. for output_name in op.output:
  63. node_name += '\n' + output_name
  64. return pydot.Node(node_name, **kwargs)
  65. return ReallyGetOpNode
  66. def GetBlobNodeProducer(**kwargs):
  67. def ReallyGetBlobNode(node_name, label):
  68. return pydot.Node(node_name, label=label, **kwargs)
  69. return ReallyGetBlobNode
  70. def GetPydotGraph(
  71. operators_or_net,
  72. name=None,
  73. rankdir='LR',
  74. op_node_producer=None,
  75. blob_node_producer=None
  76. ):
  77. if op_node_producer is None:
  78. op_node_producer = GetOpNodeProducer(False, **OP_STYLE)
  79. if blob_node_producer is None:
  80. blob_node_producer = GetBlobNodeProducer(**BLOB_STYLE)
  81. operators, name = _rectify_operator_and_name(operators_or_net, name)
  82. graph = pydot.Dot(name, rankdir=rankdir)
  83. pydot_nodes = {}
  84. pydot_node_counts = defaultdict(int)
  85. for op_id, op in enumerate(operators):
  86. op_node = op_node_producer(op, op_id)
  87. graph.add_node(op_node)
  88. # print 'Op: %s' % op.name
  89. # print 'inputs: %s' % str(op.input)
  90. # print 'outputs: %s' % str(op.output)
  91. for input_name in op.input:
  92. if input_name not in pydot_nodes:
  93. input_node = blob_node_producer(
  94. _escape_label(
  95. input_name + str(pydot_node_counts[input_name])),
  96. label=_escape_label(input_name),
  97. )
  98. pydot_nodes[input_name] = input_node
  99. else:
  100. input_node = pydot_nodes[input_name]
  101. graph.add_node(input_node)
  102. graph.add_edge(pydot.Edge(input_node, op_node))
  103. for output_name in op.output:
  104. if output_name in pydot_nodes:
  105. # we are overwriting an existing blob. need to update the count.
  106. pydot_node_counts[output_name] += 1
  107. output_node = blob_node_producer(
  108. _escape_label(
  109. output_name + str(pydot_node_counts[output_name])),
  110. label=_escape_label(output_name),
  111. )
  112. pydot_nodes[output_name] = output_node
  113. graph.add_node(output_node)
  114. graph.add_edge(pydot.Edge(op_node, output_node))
  115. return graph
  116. def GetPydotGraphMinimal(
  117. operators_or_net,
  118. name=None,
  119. rankdir='LR',
  120. minimal_dependency=False,
  121. op_node_producer=None,
  122. ):
  123. """Different from GetPydotGraph, hide all blob nodes and only show op nodes.
  124. If minimal_dependency is set as well, for each op, we will only draw the
  125. edges to the minimal necessary ancestors. For example, if op c depends on
  126. op a and b, and op b depends on a, then only the edge b->c will be drawn
  127. because a->c will be implied.
  128. """
  129. if op_node_producer is None:
  130. op_node_producer = GetOpNodeProducer(False, **OP_STYLE)
  131. operators, name = _rectify_operator_and_name(operators_or_net, name)
  132. graph = pydot.Dot(name, rankdir=rankdir)
  133. # blob_parents maps each blob name to its generating op.
  134. blob_parents = {}
  135. # op_ancestry records the ancestors of each op.
  136. op_ancestry = defaultdict(set)
  137. for op_id, op in enumerate(operators):
  138. op_node = op_node_producer(op, op_id)
  139. graph.add_node(op_node)
  140. # Get parents, and set up op ancestry.
  141. parents = [
  142. blob_parents[input_name] for input_name in op.input
  143. if input_name in blob_parents
  144. ]
  145. op_ancestry[op_node].update(parents)
  146. for node in parents:
  147. op_ancestry[op_node].update(op_ancestry[node])
  148. if minimal_dependency:
  149. # only add nodes that do not have transitive ancestry
  150. for node in parents:
  151. if all(
  152. [node not in op_ancestry[other_node]
  153. for other_node in parents]
  154. ):
  155. graph.add_edge(pydot.Edge(node, op_node))
  156. else:
  157. # Add all parents to the graph.
  158. for node in parents:
  159. graph.add_edge(pydot.Edge(node, op_node))
  160. # Update blob_parents to reflect that this op created the blobs.
  161. for output_name in op.output:
  162. blob_parents[output_name] = op_node
  163. return graph
  164. def GetOperatorMapForPlan(plan_def):
  165. operator_map = {}
  166. for net_id, net in enumerate(plan_def.network):
  167. if net.HasField('name'):
  168. operator_map[plan_def.name + "_" + net.name] = net.op
  169. else:
  170. operator_map[plan_def.name + "_network_%d" % net_id] = net.op
  171. return operator_map
  172. def _draw_nets(nets, g):
  173. nodes = []
  174. for i, net in enumerate(nets):
  175. nodes.append(pydot.Node(_escape_label(net)))
  176. g.add_node(nodes[-1])
  177. if i > 0:
  178. g.add_edge(pydot.Edge(nodes[-2], nodes[-1]))
  179. return nodes
  180. def _draw_steps(steps, g, skip_step_edges=False): # noqa
  181. kMaxParallelSteps = 3
  182. def get_label():
  183. label = [step.name + '\n']
  184. if step.report_net:
  185. label.append('Reporter: {}'.format(step.report_net))
  186. if step.should_stop_blob:
  187. label.append('Stopper: {}'.format(step.should_stop_blob))
  188. if step.concurrent_substeps:
  189. label.append('Concurrent')
  190. if step.only_once:
  191. label.append('Once')
  192. return '\n'.join(label)
  193. def substep_edge(start, end):
  194. return pydot.Edge(start, end, arrowhead='dot', style='dashed')
  195. nodes = []
  196. for i, step in enumerate(steps):
  197. parallel = step.concurrent_substeps
  198. nodes.append(pydot.Node(_escape_label(get_label()), **OP_STYLE))
  199. g.add_node(nodes[-1])
  200. if i > 0 and not skip_step_edges:
  201. g.add_edge(pydot.Edge(nodes[-2], nodes[-1]))
  202. if step.network:
  203. sub_nodes = _draw_nets(step.network, g)
  204. elif step.substep:
  205. if parallel:
  206. sub_nodes = _draw_steps(
  207. step.substep[:kMaxParallelSteps], g, skip_step_edges=True)
  208. else:
  209. sub_nodes = _draw_steps(step.substep, g)
  210. else:
  211. raise ValueError('invalid step')
  212. if parallel:
  213. for sn in sub_nodes:
  214. g.add_edge(substep_edge(nodes[-1], sn))
  215. if len(step.substep) > kMaxParallelSteps:
  216. ellipsis = pydot.Node('{} more steps'.format(
  217. len(step.substep) - kMaxParallelSteps), **OP_STYLE)
  218. g.add_node(ellipsis)
  219. g.add_edge(substep_edge(nodes[-1], ellipsis))
  220. else:
  221. g.add_edge(substep_edge(nodes[-1], sub_nodes[0]))
  222. return nodes
  223. def GetPlanGraph(plan_def, name=None, rankdir='TB'):
  224. graph = pydot.Dot(name, rankdir=rankdir)
  225. _draw_steps(plan_def.execution_step, graph)
  226. return graph
  227. def GetGraphInJson(operators_or_net, output_filepath):
  228. operators, _ = _rectify_operator_and_name(operators_or_net, None)
  229. blob_strid_to_node_id = {}
  230. node_name_counts = defaultdict(int)
  231. nodes = []
  232. edges = []
  233. for op_id, op in enumerate(operators):
  234. op_label = op.name + '/' + op.type if op.name else op.type
  235. op_node_id = len(nodes)
  236. nodes.append({
  237. 'id': op_node_id,
  238. 'label': op_label,
  239. 'op_id': op_id,
  240. 'type': 'op'
  241. })
  242. for input_name in op.input:
  243. strid = _escape_label(
  244. input_name + str(node_name_counts[input_name]))
  245. if strid not in blob_strid_to_node_id:
  246. input_node = {
  247. 'id': len(nodes),
  248. 'label': input_name,
  249. 'type': 'blob'
  250. }
  251. blob_strid_to_node_id[strid] = len(nodes)
  252. nodes.append(input_node)
  253. else:
  254. input_node = nodes[blob_strid_to_node_id[strid]]
  255. edges.append({
  256. 'source': blob_strid_to_node_id[strid],
  257. 'target': op_node_id
  258. })
  259. for output_name in op.output:
  260. strid = _escape_label(
  261. output_name + str(node_name_counts[output_name]))
  262. if strid in blob_strid_to_node_id:
  263. # we are overwriting an existing blob. need to update the count.
  264. node_name_counts[output_name] += 1
  265. strid = _escape_label(
  266. output_name + str(node_name_counts[output_name]))
  267. if strid not in blob_strid_to_node_id:
  268. output_node = {
  269. 'id': len(nodes),
  270. 'label': output_name,
  271. 'type': 'blob'
  272. }
  273. blob_strid_to_node_id[strid] = len(nodes)
  274. nodes.append(output_node)
  275. edges.append({
  276. 'source': op_node_id,
  277. 'target': blob_strid_to_node_id[strid]
  278. })
  279. with open(output_filepath, 'w') as f:
  280. json.dump({'nodes': nodes, 'edges': edges}, f)
  281. # A dummy minimal PNG image used by GetGraphPngSafe as a
  282. # placeholder when rendering fail to run.
  283. _DummyPngImage = (
  284. b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00'
  285. b'\x01\x01\x00\x00\x00\x007n\xf9$\x00\x00\x00\nIDATx\x9cc`\x00\x00'
  286. b'\x00\x02\x00\x01H\xaf\xa4q\x00\x00\x00\x00IEND\xaeB`\x82')
  287. def GetGraphPngSafe(func, *args, **kwargs):
  288. """
  289. Invokes `func` (e.g. GetPydotGraph) with args. If anything fails - returns
  290. and empty image instead of throwing Exception
  291. """
  292. try:
  293. graph = func(*args, **kwargs)
  294. if not isinstance(graph, pydot.Dot):
  295. raise ValueError("func is expected to return pydot.Dot")
  296. return graph.create_png()
  297. except Exception as e:
  298. logger.error("Failed to draw graph: {}".format(e))
  299. return _DummyPngImage
  300. def main():
  301. parser = argparse.ArgumentParser(description="Caffe2 net drawer.")
  302. parser.add_argument(
  303. "--input",
  304. type=str, required=True,
  305. help="The input protobuf file."
  306. )
  307. parser.add_argument(
  308. "--output_prefix",
  309. type=str, default="",
  310. help="The prefix to be added to the output filename."
  311. )
  312. parser.add_argument(
  313. "--minimal", action="store_true",
  314. help="If set, produce a minimal visualization."
  315. )
  316. parser.add_argument(
  317. "--minimal_dependency", action="store_true",
  318. help="If set, only draw minimal dependency."
  319. )
  320. parser.add_argument(
  321. "--append_output", action="store_true",
  322. help="If set, append the output blobs to the operator names.")
  323. parser.add_argument(
  324. "--rankdir", type=str, default="LR",
  325. help="The rank direction of the pydot graph."
  326. )
  327. args = parser.parse_args()
  328. with open(args.input, 'r') as fid:
  329. content = fid.read()
  330. graphs = utils.GetContentFromProtoString(
  331. content, {
  332. caffe2_pb2.PlanDef: lambda x: GetOperatorMapForPlan(x),
  333. caffe2_pb2.NetDef: lambda x: {x.name: x.op},
  334. }
  335. )
  336. for key, operators in viewitems(graphs):
  337. if args.minimal:
  338. graph = GetPydotGraphMinimal(
  339. operators,
  340. name=key,
  341. rankdir=args.rankdir,
  342. node_producer=GetOpNodeProducer(args.append_output, **OP_STYLE),
  343. minimal_dependency=args.minimal_dependency)
  344. else:
  345. graph = GetPydotGraph(
  346. operators,
  347. name=key,
  348. rankdir=args.rankdir,
  349. node_producer=GetOpNodeProducer(args.append_output, **OP_STYLE))
  350. filename = args.output_prefix + graph.get_name() + '.dot'
  351. graph.write(filename, format='raw')
  352. pdf_filename = filename[:-3] + 'pdf'
  353. try:
  354. graph.write_pdf(pdf_filename)
  355. except Exception:
  356. print(
  357. 'Error when writing out the pdf file. Pydot requires graphviz '
  358. 'to convert dot files to pdf, and you may not have installed '
  359. 'graphviz. On ubuntu this can usually be installed with "sudo '
  360. 'apt-get install graphviz". We have generated the .dot file '
  361. 'but will not be able to generate pdf file for now.'
  362. )
  363. if __name__ == '__main__':
  364. main()