generator.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. ## @package generator
  2. # Module caffe2.python.docs.generator
  3. import argparse
  4. import os
  5. from caffe2.python import core, workspace
  6. from caffe2.python.docs.formatter import Markdown
  7. from future.utils import viewitems, viewvalues
# Module-level alias for the OpSchema binding exposed by workspace.C; used
# below for schema lookups (OpSchema.get) and per-device implementation
# lookups (OpSchema.get_cpu_impl / OpSchema.get_cuda_impl).
OpSchema = workspace.C.OpSchema
  9. class DocUploader(object):
  10. def __init__(self):
  11. pass
  12. def upload(self, text):
  13. pass
  14. class DocGenerator(object):
  15. def __init__(self, formatter, uploader):
  16. self.formatter = formatter
  17. self.uploader = uploader
  18. self.content_body = ""
  19. def create_body(self):
  20. pass
  21. def update(self):
  22. self.uploader.upload(self.content_body)
  23. class OpDocGenerator(DocGenerator):
  24. def getOperatorDoc(self, name, schema, priority):
  25. return OperatorDoc(name, schema, priority)
  26. def getOperatorEngine(self, name):
  27. return OperatorEngine(name)
  28. def getOperators(self):
  29. # map: op_name -> operator
  30. self.operators = {}
  31. # map: op_name -> [engine, engine]
  32. self.engines = {}
  33. def filePriority(x):
  34. if x == "caffe2/caffe2/operators":
  35. return 0
  36. if 'contrib' in x.split('/'):
  37. return 2
  38. if 'experiments' in x.split('/'):
  39. return 3
  40. return 1
  41. for name in core._GetRegisteredOperators():
  42. schema = OpSchema.get(name)
  43. if schema:
  44. priority = filePriority(os.path.dirname(schema.file))
  45. operator = self.getOperatorDoc(name, schema, priority)
  46. self.operators[name] = operator
  47. # Engine
  48. elif name.find("_ENGINE_") != -1:
  49. engine = self.getOperatorEngine(name)
  50. if engine.base_op_name in self.engines:
  51. self.engines[engine.base_op_name].append(engine)
  52. else:
  53. self.engines[engine.base_op_name] = [engine]
  54. # No schema
  55. else:
  56. priority = 4
  57. self.operators[name] = self.getOperatorDoc(name, schema, priority)
  58. for name, engines in viewitems(self.engines):
  59. if name in self.operators:
  60. self.operators[name].addEngines(engines)
  61. # Generate a sorted list of operators
  62. return sorted(
  63. viewvalues(self.operators),
  64. key=lambda op: (op.priority, op.name)
  65. )
  66. def createBody(self):
  67. operators = self.getOperators()
  68. for operator in operators:
  69. operator.generateSchema(self.formatter)
  70. self.content_body += self.formatter.dump()
  71. class OperatorEngine(object):
  72. def __init__(self, name):
  73. self.op_name = name
  74. self.base_op_name, self.engine = name.split("_ENGINE_", 1)
  75. def getDeviceImpl(self):
  76. deviceImplList = []
  77. for device, impl in [('CPU', OpSchema.get_cpu_impl(self.op_name)),
  78. ('CUDA', OpSchema.get_cuda_impl(self.op_name))]:
  79. if not impl:
  80. continue
  81. deviceImplList.append((device, impl))
  82. return deviceImplList
  83. def generateDoc(self, formatter):
  84. for device, impl in self.getDeviceImpl():
  85. formatter.addLine(
  86. '{engine} on {device}: {impl}'.format(engine=self.engine,
  87. device=device,
  88. impl=impl))
  89. class OperatorDoc(object):
  90. def __init__(self, name, schema, priority):
  91. self.name = name
  92. self.schema = schema
  93. self.priority = priority
  94. print("Gathering docs for {}...".format(self.name))
  95. self.engines = []
  96. def addEngines(self, engines):
  97. self.engines = engines
  98. def generateDoc(self, formatter):
  99. if self.schema.doc:
  100. formatter.parseAndAdd(self.schema.doc)
  101. formatter.addLinebreak()
  102. else:
  103. formatter.addLine("No documentation yet.")
  104. def generateTable(self, formatter, tuples, title_row, title):
  105. if tuples:
  106. if title:
  107. formatter.addHeader(title, 3)
  108. table = []
  109. if title_row:
  110. table = [title_row]
  111. for name, doc in tuples:
  112. table.append([name, doc or ''])
  113. formatter.addTable(table, (table == []))
  114. def generateInterface(self, formatter):
  115. def makeDesc(title, args):
  116. f = formatter.clone()
  117. f.addEmphasis(title, 1)
  118. out = [(f.dump(), '')]
  119. for arg in args:
  120. f = formatter.clone()
  121. if isinstance(arg, tuple):
  122. name = arg[0]
  123. if len(arg) > 1:
  124. description = arg[1] or ''
  125. else:
  126. description = ''
  127. else:
  128. name = arg.name
  129. description = arg.description or ''
  130. f.addCode(name, inline=True)
  131. out.append((f.dump(), description or ''))
  132. return out
  133. tuples = []
  134. if self.schema.args:
  135. tuples += makeDesc('Arguments', self.schema.args)
  136. if self.schema.input_desc:
  137. tuples += makeDesc('Inputs', self.schema.input_desc)
  138. if self.schema.output_desc:
  139. tuples += makeDesc('Outputs', self.schema.output_desc)
  140. self.generateTable(formatter, tuples, None, 'Interface')
  141. print("Generated interface for {}".format(self.name))
  142. def generateCodeLink(self, formatter):
  143. formatter.addHeader("Code", 3)
  144. formatter.addLinebreak()
  145. formatter.addCodeLink(self.schema.file)
  146. def getInfo(self, formatter, name, impl):
  147. pass
  148. def generateDevices(self, formatter):
  149. formatter.addHeader("Devices", 3)
  150. devices = [
  151. self.getInfo(formatter,
  152. 'CPU', OpSchema.get_cpu_impl(self.name)),
  153. self.getInfo(formatter,
  154. 'GPU', OpSchema.get_cuda_impl(self.name)),
  155. ]
  156. formatter.addList([i for i in devices if i])
  157. def generateEngines(self, formatter):
  158. if not len(self.engines):
  159. return
  160. formatter.addHeader("Engines", 3)
  161. for engine in self.engines:
  162. engine.generateDoc(formatter)
  163. def generateSchema(self, formatter):
  164. formatter.addHeader(self.name, 2)
  165. if self.schema:
  166. self.generateDoc(formatter)
  167. self.generateInterface(formatter)
  168. self.generateCodeLink(formatter)
  169. self.generateDevices(formatter)
  170. self.generateEngines(formatter)
  171. formatter.addBreak()
  172. else:
  173. formatter.addLine("No schema documented yet.")
  174. self.generateDevices(formatter)
  175. if __name__ == "__main__":
  176. parser = argparse.ArgumentParser(description="Operators catalog generator.")
  177. parser.add_argument('catalog_path', type=str,
  178. help='operators-catalogue.md to write out to')
  179. args = parser.parse_args()
  180. with open(args.catalog_path, 'w') as fp:
  181. ops = OpDocGenerator(Markdown(), DocUploader())
  182. ops.createBody()
  183. fp.write(ops.content_body)