gen_op.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. #!/bin/env python3
  2. # Copyright (c) 2016-present, Facebook, Inc.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. ##############################################################################
  16. import sys
  17. import yaml
  18. import argparse
  19. import os
  20. from copy import deepcopy
  21. from typing import Dict, List, Set
  22. parser = argparse.ArgumentParser()
  23. parser.add_argument("--template_dir", default=".", help="where template.h is")
  24. parser.add_argument("--yaml_dir", default="aten/src/ATen/ATen",
  25. help="where ATen yaml files are")
  26. parser.add_argument("--output_prefix", default="", help="")
  27. parser.add_argument(
  28. "--install_dir", default=".", help="where to put generated file")
  29. parser.add_argument("--aten_root", default="", help="root directory of aten")
  30. args, _ = parser.parse_known_args()
  31. if args.aten_root:
  32. if not os.path.exists(args.aten_root):
  33. raise ValueError('aten_root ({}) does not exist'.format(
  34. args.aten_root))
  35. sys.path.insert(0, os.path.join(args.aten_root, '..'))
  36. from torchgen.code_template import CodeTemplate as CT
  37. else:
  38. from torchgen.code_template import CodeTemplate as CT
  39. OP_TEMPLATE = CT.from_file(
  40. os.path.join(args.template_dir, 'aten_op_template.h'))
  41. try:
  42. # use faster C loader if available
  43. from yaml import CSafeLoader as Loader
  44. except ImportError:
  45. from yaml import SafeLoader as Loader # type: ignore[misc]
  46. def write(filename, s):
  47. with open(filename, "w") as f:
  48. f.write(s)
  49. def read(filename):
  50. with open(filename, "r") as f:
  51. return f.read()
  52. def value_has_tensors(v):
  53. # Sparse shouldn't appear in public API, seems to be temporary bug
  54. return "Tensor" in v['dynamic_type'] and "Sparse" not in v['dynamic_type']
  55. def value_is_tensor_type(v):
  56. return value_has_tensors(v) and v['dynamic_type'] not in ['at::TensorList', 'const c10::List<c10::optional<at::Tensor>> &']
# For each ATen return type: the C++ statement template that copies a result
# into Caffe2 output number ${offset}. ${output} is the C++ expression that
# names the result (see get_output). Plain tensor returns use the
# 'at::Tensor' entry regardless of their exact C++ type.
# NOTE(review): assignTo / assignToValue / assignListStartingAt are not defined
# here; presumably they are helpers in aten_op_template.h.
RETURN_MAP = {
    'at::Tensor': 'assignTo(Output(${offset}),${output});',
    'at::Scalar': 'assignTo(Output(${offset}),${output}.type(), ${output});',
    # bool results are stored through the int64_t value path.
    'bool': 'assignToValue<int64_t>(Output(${offset}),${output});',
    'int64_t': 'assignToValue<int64_t>(Output(${offset}),${output});',
    '::std::vector<at::Tensor>': 'assignListStartingAt(${offset}, ${output});',
}
# For each non-Tensor ATen argument type: how do we read it from caffe2's
# attribute list? Each entry is a C++ statement template that declares a
# local named ${arg} and initializes it from the attribute of the same name.
# Most of these call runtime functions defined in the template class.
# Argument types missing from this map make an overload unsupported.
ARGUMENT_MAP = {
    'const at::Scalar &': 'at::Scalar ${arg} = readScalarAttribute("${arg}");',
    # bool and int are both read from an int64_t attribute.
    'bool': 'bool ${arg} = readAttribute<int64_t>("${arg}");',
    'int': 'int ${arg} = readAttribute<int64_t>("${arg}");',
    # NOTE(review): doubles are read via a float attribute, so values are
    # presumably limited to single precision — confirm this is intended.
    'double': 'double ${arg} = readAttribute<float>("${arg}");',
    'int64_t': 'int64_t ${arg} = readAttribute<int64_t>("${arg}");',
    'at::IntArrayRef': 'auto ${arg} = readIntArrayRef("${arg}");',
    '::std::array<bool,2>': 'auto ${arg} = readBoolMask<2>("${arg}");',
    '::std::array<bool,3>': 'auto ${arg} = readBoolMask<3>("${arg}");',
}
# For backward-compatibility reasons, some ATen ops are routed to a different
# implementation. For a name listed here, the generated code calls the mapped
# C++ function with all of the op's arguments, instead of at::<name>(...) or
# self.<name>(...).
# NOTE(review): internal::index_with_uint8_handling is not defined here; the
# name suggests special handling of uint8 index tensors — confirm in the
# template.
SPECIAL_IMPLEMENTATIONS = {
    'index': 'internal::index_with_uint8_handling',
}
  83. def expand(o):
  84. num_defaults = sum(1 if 'default' in arg else 0 for arg in o['arguments'])
  85. results = [o]
  86. for i in range(0, num_defaults):
  87. # last num_default values should be default
  88. assert('default' in o['arguments'][-(i + 1)])
  89. v = deepcopy(o)
  90. v['arguments'] = v['arguments'][:-(i + 1)]
  91. results.append(v)
  92. return results
  93. # filter the list of declarations removing things we cannot support
  94. def supports(o, factory_methods):
  95. # Ignore all families (!) of functions that have TensorOptions (i.e. tensor factory methods).
  96. if o['name'] in factory_methods:
  97. if factory_methods[o['name']] == 0:
  98. print("Skipping {} because it is a factory method".format(o['name']))
  99. factory_methods[o['name']] += 1
  100. return False
  101. # skip all in-place operators for now since aten cannot Resize
  102. # caffe2 memory inside an operator
  103. if o['inplace']:
  104. return False
  105. # _out variants also work in-place on arguments taken as destinations
  106. # we also cannot handle these because aten cannot resize caffe2 Tensors
  107. if "_out" in o['name']:
  108. return False
  109. # skip if no return, previously it is 'void'
  110. if len(o['returns']) == 0:
  111. return False
  112. # skip return types we cannot handle
  113. for ret in o['returns']:
  114. if not value_has_tensors(ret) and ret['type'] not in RETURN_MAP:
  115. print("Skipping {} Because of Ret: {} ({})".format(
  116. o['name'], ret['type'], ret['dynamic_type']))
  117. return False
  118. # skip arguments we cannot handle
  119. for arg in o['arguments']:
  120. if not value_has_tensors(arg) and arg['type'] not in ARGUMENT_MAP:
  121. print("Skipping {} Because of Arg: {} ({}) ".format(
  122. o['name'], arg['type'], arg['dynamic_type']))
  123. return False
  124. return True
# Template for each potential operator.
# Each operator has an integer 'key' associated with it (${key}), and a
# lambda (assigned to run_op) that defines the operator.
# Non-tensor attributes are read in ${initialization}, outside the lambda,
# and captured by value ([=]) as the lambda's arguments.
# Inputs are read (${statements}) and outputs written (${assignments}) inside
# the lambda, around the ATen call ${invocation}, whose result is bound to
# the_result.
#
# Each implementation is defined in a separate method annotated with
# C10_NOINLINE, so it is not inlined into the ATenOp constructor; inlining
# would trigger pathological compile times.
IMPLEMENTATION_TEMPLATE = CT("""\
C10_NOINLINE void implementation_${key}() { // ${name}
${initialization}
run_op = [=] {
at::AutoDispatchBelowAutograd guard;
${statements}
auto the_result = ${invocation};
${assignments}
return true;
};
}
""")
# One `case` of the C++ switch that maps an operator's integer key to the
# implementation_${key}() method generated from IMPLEMENTATION_TEMPLATE.
CASE_TEMPLATE = CT("""\
case ${key}: // ${name}
implementation_${key}();
break;
""")
# Wraps an output assignment so that it only runs when the Caffe2 op was
# given more than ${offset} outputs; results beyond OutputSize() are dropped.
ASSIGN_CHECK_SIZE_TEMPLATE = CT("""\
if(OutputSize() > ${offset}) {${assignment}}
""")
  155. def get_output(o, i):
  156. if len(o['returns']) == 1:
  157. return 'the_result'
  158. else:
  159. return '::std::get<{}>(the_result)'.format(i)
  160. def attribute_names(o):
  161. return sorted([a['name'] for a in o['arguments'] if not value_has_tensors(a)])
  162. def required_attribute_names(o):
  163. return sorted([a['name'] for a in o['arguments'] if not value_has_tensors(a) and 'default' not in a])
  164. def self_as_first_argument(arguments):
  165. return ([a for a in arguments if a['name'] == 'self'] +
  166. [a for a in arguments if a['name'] != 'self'])
  167. def get_num_inputs(o):
  168. args = 0
  169. for a in o['arguments']:
  170. if a['type'] in ['at::TensorList', 'const c10::List<c10::optional<at::Tensor>> &']:
  171. return '*'
  172. elif value_has_tensors(a):
  173. args += 1
  174. return str(args)
  175. def find_factory_methods(decls):
  176. factory_methods = {}
  177. for o in decls:
  178. if any(arg['dynamic_type'] == 'at::TensorOptions' for arg in o['arguments']):
  179. factory_methods[o['name']] = 0
  180. return factory_methods
  181. def emit_assignments(o, env):
  182. for i, r in enumerate(o['returns']):
  183. t = RETURN_MAP[r['type'] if not value_is_tensor_type(r) else 'at::Tensor']
  184. assignment = CT(t).substitute(env, offset=i, output=get_output(o, i))
  185. check_size_assignment = ASSIGN_CHECK_SIZE_TEMPLATE.substitute(env, offset=i, assignment=assignment)
  186. env['assignments'].append(check_size_assignment)
if __name__ == '__main__':
    # Driver: load the ATen declarations, keep the (default-expanded) overloads
    # we can wrap, generate one C++ implementation + switch case per unique
    # descriptor, and write the filled-in OP_TEMPLATE to <install_dir>/<prefix>aten_op.h.
    decls = yaml.load(read(os.path.join(args.yaml_dir, 'Declarations.yaml')), Loader=Loader)
    factory_methods = find_factory_methods(decls)
    filtered = [expanded for o in decls for expanded in expand(o) if supports(expanded, factory_methods)]
    top_env: Dict[str, List] = {
        'mappings': [],
        'implementations': [],
        'cases': [],
    }
    # Descriptors already emitted; a later overload with the same descriptor
    # is skipped (the first one wins).
    seen: Set[str] = set()
    key = 0
    for o in filtered:
        # [DESCRIPTORS]
        # each overload is associated with a descriptor string that is used
        # to figure out which version of an op is being used.
        # The format (as built below) is:
        # opname-attribute_1-attribute_2-...-num_inputs
        # with attribute names sorted, and num_inputs '*' when the op takes a
        # tensor list.
        # Example:
        # lerp-weight-2
        # the operator lerp takes 2 tensor inputs and has the attribute weight
        attr_names = attribute_names(o)
        num_inputs = get_num_inputs(o)
        descriptor = '-'.join([o['name']] + attr_names + [num_inputs])
        if descriptor in seen:
            continue
        seen.add(descriptor)
        # map from descriptor string to the integer key in the switch statements
        # that initializes the operators
        top_env['mappings'].append('{{ "{}", {} }},'.format(descriptor, key))
        # Per-operator substitution environment for IMPLEMENTATION_TEMPLATE /
        # CASE_TEMPLATE; the list entries are filled in below.
        env = {
            'name': o['name'],
            'statements': [],
            'arguments': [],
            'assignments': [],
            'initialization': [],
            'key': str(key),
        }
        if 'namespace' not in o['method_of'] and 'Tensor' not in o['method_of']:
            # methods on type like 'ones' or 'zeros' always take a
            # string attribute that is translated into the at::Type object
            # e.g. "Float" is at::kFloat
            # NOTE(review): this branch only asserts; no such attribute handling
            # is generated here.
            assert('Type' in o['method_of'])
        # Number of single-tensor arguments (tensor lists excluded).
        static_tensor_inputs = sum(arg['type'] not in ['at::TensorList', 'const c10::List<c10::optional<at::Tensor>> &'] and value_is_tensor_type(arg) for arg in o['arguments'])
        has_tensorlist = any(arg['type'] in ['at::TensorList', 'const c10::List<c10::optional<at::Tensor>> &'] for arg in o['arguments'])
        if has_tensorlist:
            # Position of the first tensor-list argument.
            tensorlist_idx = [i for i, arg in enumerate(o['arguments']) if arg['type'] in ['at::TensorList', 'const c10::List<c10::optional<at::Tensor>> &']][0]
        real_inputs = 0
        for i, arg in enumerate(o['arguments']):
            env['arguments'].append(arg['name'])
            # Pretend the flat argument list is a stack where the end is the top.
            # Tensor arguments before the tensor list are peeked with view length
            # InputSize(); everything else uses static_tensor_inputs.
            # NOTE(review): presumably peek(i, N) reads Input(InputSize() - N + i),
            # so tensors after the list are indexed from the end of the inputs —
            # confirm in aten_op_template.h.
            view_length = 'InputSize()' if has_tensorlist and i < tensorlist_idx else static_tensor_inputs
            if arg['type'] == 'at::TensorList':
                # NOTE: do not advance real_inputs here. After this we will
                # switch to indexing the "stack" from the end
                env['statements'].append(
                    'auto {} = peekSlice({}, InputSize() - {}, InputSize());'
                    .format(arg['name'], real_inputs, static_tensor_inputs))
            elif arg['type'] == 'const c10::List<c10::optional<at::Tensor>> &':
                # NOTE: do not advance real_inputs here. After this we will
                # switch to indexing the "stack" from the end
                env['statements'].append(
                    'auto {} = peekSliceOptionals({}, InputSize() - {}, InputSize());'
                    .format(arg['name'], real_inputs, static_tensor_inputs))
            elif value_is_tensor_type(arg):
                # load tensor inputs from Caffe2
                env['statements'].append(
                    'auto {} = peek({}, {});'.format(arg['name'], real_inputs, view_length))
                real_inputs += 1
            else:
                # Non-tensor argument: read from a Caffe2 attribute of the same
                # name, via ARGUMENT_MAP.
                init = CT(ARGUMENT_MAP[arg['type']]).substitute(env, arg=arg['name'])
                env['initialization'].append(init)
        emit_assignments(o, env)
        # Choose how the ATen function is called: a BC-routed implementation,
        # the at:: namespace function, or a Tensor method on `self` (assumed to
        # be the first argument, hence arguments[1:]).
        if o['name'] in SPECIAL_IMPLEMENTATIONS:
            env['invocation'] = "{}({})".format(SPECIAL_IMPLEMENTATIONS[o['name']], ','.join(env['arguments']))
        elif 'namespace' in o['method_of']:
            env['invocation'] = CT("at::${name}(${arguments})").substitute(env)
        else:
            assert('Tensor' in o['method_of'])
            env['invocation'] = "self.{}({})".format(
                o['name'], ', '.join(env['arguments'][1:]))
        top_env['implementations'].append(IMPLEMENTATION_TEMPLATE.substitute(env))
        top_env['cases'].append(CASE_TEMPLATE.substitute(env))
        key += 1
    write(os.path.join(args.install_dir, args.output_prefix + "aten_op.h"), OP_TEMPLATE.substitute(top_env))