utils.py 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541
  1. """Functions to export models into the ONNX IR format.
  2. These models can be loaded with the ONNX library and then
  3. converted to models which run on other deep learning frameworks.
  4. """
  5. from __future__ import annotations
  6. import contextlib
  7. import copy
  8. import inspect
  9. import itertools
  10. import os
  11. import re
  12. import textwrap
  13. import typing
  14. import warnings
  15. import zipfile
  16. from typing import Any, Callable, Dict, List, Optional, Tuple, Union
  17. import torch
  18. import torch._C._onnx as _C_onnx
  19. import torch.jit._trace
  20. import torch.serialization
  21. from torch import _C
  22. from torch.onnx import ( # noqa: F401
  23. _constants,
  24. _patch_torch,
  25. symbolic_caffe2,
  26. symbolic_helper,
  27. symbolic_registry,
  28. )
  29. from torch.onnx._globals import GLOBALS
  30. # the flag to tell the user whether it's in the middle of ONNX export or not
  31. __IN_ONNX_EXPORT = False
  32. def is_in_onnx_export():
  33. global __IN_ONNX_EXPORT
  34. return __IN_ONNX_EXPORT
# TODO(justinchuby): Remove dependency to this global variable from constant_fold.cpp
# Skip check due to cannot import IValue from torch._C
# Module-level parameter dictionary that, per the TODO above, constant_fold.cpp
# depends on. NOTE(review): presumably populated elsewhere during export;
# `_get_named_param_dict` only binds a *local* of the same name — verify.
_params_dict = {}  # type: ignore[var-annotated]
  38. @contextlib.contextmanager
  39. def select_model_mode_for_export(model, mode):
  40. if not isinstance(model, torch.jit.ScriptFunction):
  41. is_originally_training = model.training
  42. if mode is None:
  43. mode = _C_onnx.TrainingMode.EVAL
  44. # if the model is in training mode but the user did not specify
  45. # to export the model in training mode, export the model in inference
  46. # mode (default) and warn them
  47. if is_originally_training:
  48. warnings.warn(
  49. "You are exporting the model to ONNX while in training mode with "
  50. "'train' parameter not specified. The model will default to inference mode export. "
  51. "If you wish to export a training amenable ONNX model, specify training=TrainingMode.TRAINING or "
  52. "training=TrainingMode.PRESERVE (to preserve the original model state) in torch.onnx.export()."
  53. )
  54. # if mode == TrainingMode.EVAL or (mode == TrainingMode.PRESERVE and not is_originally_training) => is_training = False
  55. is_export_training = False
  56. # ONNX opset 12 has better support for training amenable models, with updated
  57. # versions of the dropout and batch_norm operators
  58. if mode == _C_onnx.TrainingMode.TRAINING or (
  59. mode == _C_onnx.TrainingMode.PRESERVE and is_originally_training
  60. ):
  61. if GLOBALS.export_onnx_opset_version < 12:
  62. warnings.warn(
  63. "You are exporting the model in training mode with onnx opset version {}. "
  64. "Opset versions lower than opset 12 will not be able to export nodes such as "
  65. "Dropout and BatchNorm correctly.".format(
  66. GLOBALS.export_onnx_opset_version
  67. )
  68. )
  69. is_export_training = True
  70. symbolic_helper._set_training_mode(is_export_training)
  71. model.train(is_export_training)
  72. try:
  73. yield
  74. finally:
  75. if not isinstance(model, torch.jit.ScriptFunction):
  76. # FIXME(justinchuby): is_originally_training is possibly unbound
  77. model.train(is_originally_training)
  78. @contextlib.contextmanager
  79. def disable_apex_o2_state_dict_hook(model):
  80. # Apex O2 hook state_dict to return fp16 weights as fp32.
  81. # Exporter cannot identify them as same tensors.
  82. # Since this hook is only used by optimizer, it is safe to
  83. # remove this hook while exporting.
  84. if not isinstance(model, torch.jit.ScriptFunction):
  85. tmp_map = {} # type: ignore[var-annotated]
  86. for module in model.modules():
  87. for k, v in module._state_dict_hooks.items():
  88. if type(v).__name__ == "O2StateDictHook":
  89. if module not in tmp_map:
  90. tmp_map[module] = {}
  91. tmp_map[module][k] = v
  92. if module in tmp_map:
  93. for k in tmp_map[module].keys():
  94. module._state_dict_hooks.pop(k)
  95. try:
  96. yield
  97. finally:
  98. if not isinstance(model, torch.jit.ScriptFunction):
  99. # FIXME(justinchuby): tmp_map is possibly unbound
  100. for module, m_map in tmp_map.items():
  101. for k, v in m_map.items():
  102. module._state_dict_hooks[k] = v
  103. @contextlib.contextmanager
  104. def setup_onnx_logging(verbose):
  105. is_originally_enabled = torch.onnx.is_onnx_log_enabled()
  106. if is_originally_enabled or verbose:
  107. torch.onnx.enable_log()
  108. try:
  109. yield
  110. finally:
  111. if not is_originally_enabled:
  112. torch.onnx.disable_log()
  113. @contextlib.contextmanager
  114. def exporter_context(model, mode, verbose):
  115. with select_model_mode_for_export(
  116. model, mode
  117. ) as mode_ctx, disable_apex_o2_state_dict_hook(
  118. model
  119. ) as apex_ctx, setup_onnx_logging(
  120. verbose
  121. ) as log_ctx:
  122. yield (mode_ctx, apex_ctx, log_ctx)
  123. def export(
  124. model,
  125. args,
  126. f,
  127. export_params=True,
  128. verbose=False,
  129. training=None,
  130. input_names=None,
  131. output_names=None,
  132. operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
  133. opset_version=None,
  134. do_constant_folding=True,
  135. dynamic_axes=None,
  136. keep_initializers_as_inputs=None,
  137. custom_opsets=None,
  138. export_modules_as_functions=False,
  139. ):
  140. _export(
  141. model,
  142. args,
  143. f,
  144. export_params,
  145. verbose,
  146. training,
  147. input_names,
  148. output_names,
  149. operator_export_type=operator_export_type,
  150. opset_version=opset_version,
  151. do_constant_folding=do_constant_folding,
  152. dynamic_axes=dynamic_axes,
  153. keep_initializers_as_inputs=keep_initializers_as_inputs,
  154. custom_opsets=custom_opsets,
  155. export_modules_as_functions=export_modules_as_functions,
  156. )
  157. def _is_constant_tensor_list(node):
  158. if node.kind() != "prim::Constant":
  159. return False
  160. output_type = node.output().type()
  161. if output_type.isSubtypeOf(_C.ListType.ofTensors()):
  162. return True
  163. if output_type.isSubtypeOf(_C.ListType(_C.OptionalType.ofTensor())):
  164. return True
  165. # ONNX can't handle constants that are lists of tensors, which can
  166. # get generated in constant prop. So we split them back into prim::ListConstructs
  167. def _split_tensor_list_constants(g, block):
  168. for node in block.nodes():
  169. for subblock in node.blocks():
  170. _split_tensor_list_constants(g, subblock)
  171. if _is_constant_tensor_list(node):
  172. inputs = []
  173. for val in node.output().toIValue():
  174. input = g.insertConstant(val)
  175. input.node().moveBefore(node)
  176. input.node().copyMetadata(node)
  177. inputs.append(input)
  178. lc = (
  179. g.create("prim::ListConstruct", inputs)
  180. .insertBefore(node)
  181. .output()
  182. .setType(_C.ListType.ofTensors())
  183. )
  184. lc.node().copyMetadata(node)
  185. node.output().replaceAllUsesWith(lc)
def _optimize_graph(
    graph: _C.Graph,
    operator_export_type: _C_onnx.OperatorExportTypes,
    _disable_torch_constant_prop: bool = False,
    fixed_batch_size: bool = False,
    params_dict=None,
    dynamic_axes=None,
    input_names=None,
    module=None,
):
    """Run the TorchScript-to-ONNX pass pipeline on ``graph`` and return the result.

    The early TorchScript passes (inlining, tuple lowering, constant
    propagation, peephole, ...) operate on ``graph`` in place. ``_C._jit_pass_onnx``
    and ``_C._jit_pass_canonicalize`` return a graph that replaces it, so
    callers must use the returned value.

    Args:
        graph: TorchScript graph to convert.
        operator_export_type: Forwarded to ``_C._jit_pass_onnx``.
        _disable_torch_constant_prop: When True, skip TorchScript constant
            propagation.
        fixed_batch_size: Forwarded to ``_C._jit_pass_onnx_peephole``.
        params_dict: Parameter values by name. It is used to unpack quantized
            weights and, when ``GLOBALS.onnx_shape_inference`` is set, for
            ONNX shape/type inference.
        dynamic_axes: Only used when ``GLOBALS.onnx_shape_inference`` is set;
            None is treated as ``{}``.
        input_names: Only used when ``GLOBALS.onnx_shape_inference`` is set;
            None is treated as ``[]``.
        module: Forwarded to ``_C._jit_pass_onnx_remove_inplace_ops_for_onnx``.

    Returns:
        The converted graph.

    Side effect: clears ``symbolic_helper._quantized_ops``.
    """
    # Inline everything
    _C._jit_pass_inline(graph)

    # Remove fork/wait nodes
    _C._jit_pass_inline_fork_wait(graph)
    _C._jit_pass_lint(graph)
    _C._jit_pass_lower_all_tuples(graph)

    # we now record some ops like ones/zeros
    # into a trace where we previously recorded constants.
    # use constant prop to maintain our current level of onnx support
    # without implementing symbolics for all of them
    if _disable_torch_constant_prop is False:
        _C._jit_pass_constant_propagation(graph)
    # Constant propagation may create tensor-list constants, which ONNX cannot
    # represent; split them back into prim::ListConstruct nodes.
    _split_tensor_list_constants(graph, graph)
    # run dce to eliminate dead parts of the graph that might have been
    # left behind by things like symbolic_override
    _C._jit_pass_dce(graph)
    _C._jit_pass_lint(graph)

    _C._jit_pass_canonicalize_graph_fuser_ops(graph)
    _C._jit_pass_lint(graph)
    _C._jit_pass_peephole(graph, True)
    _C._jit_pass_fuse_addmm(graph)
    _C._jit_pass_lint(graph)

    _C._jit_pass_peephole(graph, True)
    _C._jit_pass_lower_all_tuples(graph)
    # in _jit_pass_onnx, symbolic functions are called for each node for conversion.
    # However, there are nodes that cannot be converted without additional context.
    # For example, the number of outputs from split (and whether it is static or dynamic) is unknown
    # until the point where it is unpacked by listUnpack node.
    # This pass does a preprocess, and prepares the nodes such that enough context can be received
    # by the symbolic function.
    _C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, module)
    _C._jit_pass_onnx_preprocess(graph)

    # onnx does not support tuples, so try to remove them
    # (tuples were lowered above by _jit_pass_lower_all_tuples; this just lints)
    _C._jit_pass_lint(graph)

    # onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0
    _C._jit_pass_prepare_division_for_onnx(graph)

    _C._jit_pass_onnx_remove_print(graph)
    _C._jit_pass_onnx_preprocess_caffe2(graph)

    # Reset the quantized-op bookkeeping before unpacking weights for this graph.
    symbolic_helper._quantized_ops.clear()
    # Unpack quantized weights for conv and linear ops and insert into graph.
    _C._jit_pass_onnx_unpack_quantized_weights(
        graph, params_dict, symbolic_helper.is_caffe2_aten_fallback()
    )
    if symbolic_helper.is_caffe2_aten_fallback():
        # Insert permutes before and after each conv op to ensure correct order.
        _C._jit_pass_onnx_quantization_insert_permutes(graph, params_dict)

        # Find consecutive permutes that are no-ops and remove them.
        _C._jit_pass_custom_pattern_based_rewrite_graph(
            textwrap.dedent(
                """\
                graph(%Pi):
                    %Pq = quantized::nhwc2nchw(%Pi)
                    %Pr = quantized::nchw2nhwc(%Pq)
                    return (%Pr)"""
            ),
            textwrap.dedent(
                """\
                graph(%Ri):
                    return (%Ri)"""
            ),
            graph,
        )

    # onnx only supports tensors, so we turn all out number types into tensors
    _C._jit_pass_erase_number_types(graph)
    if GLOBALS.onnx_shape_inference:
        input_names = [] if input_names is None else input_names
        dynamic_axes = {} if dynamic_axes is None else dynamic_axes
        _C._jit_pass_onnx_set_dynamic_input_shape(graph, dynamic_axes, input_names)
    _C._jit_pass_onnx_lint(graph)
    # The actual TorchScript -> ONNX conversion; rebinds `graph`.
    graph = _C._jit_pass_onnx(graph, operator_export_type)
    _C._jit_pass_onnx_lint(graph)
    _C._jit_pass_lint(graph)

    _C._jit_pass_onnx_scalar_type_analysis(
        graph, True, GLOBALS.export_onnx_opset_version
    )
    _C._jit_pass_lint(graph)

    _C._jit_pass_onnx_peephole(
        graph, GLOBALS.export_onnx_opset_version, fixed_batch_size
    )
    _C._jit_pass_lint(graph)

    # graph is not a valid jit graph anymore because types have been replaced
    # (e.g. int with Tensor), so it now contains operators that don't actually
    # exist. We can't run normal dead code elimination because it'd fail trying
    # to look up if an operator has side effects, but we can run a dead code
    # elimination variant that doesn't need to look up if an op has side effects.
    _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
    _C._jit_pass_lint(graph)
    graph = _C._jit_pass_canonicalize(graph)
    _C._jit_pass_lint(graph)
    if GLOBALS.onnx_shape_inference:
        _C._jit_pass_onnx_graph_shape_type_inference(
            graph, params_dict, GLOBALS.export_onnx_opset_version
        )
    return graph
  290. def warn_on_static_input_change(input_states):
  291. """Warns that changes to input dictionaries and strings won't take effect in the traced ONNX graph.
  292. We accept dictionaries and strings as ONNX inputs, but they should be only for
  293. configuration use. we detect here if these inputs are modified, and if so we warn
  294. the user that the changes won't take effect in the traced ONNX graph.
  295. """
  296. for input, traced_input in zip(input_states[0], input_states[1]):
  297. if isinstance(input, dict):
  298. if list(input.keys()) != list(traced_input.keys()):
  299. warning = (
  300. "We detected that you are modifying a dictionary that is an input to your "
  301. "model. "
  302. "Note that dictionaries are allowed as inputs in ONNX but they should be "
  303. "handled with care. "
  304. "Usages of dictionaries is not recommended, and should not be used except "
  305. "for configuration use. "
  306. "Also note that the order and values of the keys must remain the same. "
  307. )
  308. warnings.warn(warning)
  309. elif isinstance(input, str):
  310. if input != traced_input:
  311. warning = (
  312. "The model seems to have string inputs/outputs. "
  313. "Note that strings will not appear as inputs/outputs of the ONNX graph. "
  314. )
  315. warnings.warn(warning)
  316. def _resolve_args_by_export_type(arg_name, arg_value, operator_export_type):
  317. """Resolves the arguments that are ignored when export_type != operator_export_type.ONNX."""
  318. if (
  319. operator_export_type is not operator_export_type.ONNX
  320. and _C_onnx._CAFFE2_ATEN_FALLBACK
  321. ):
  322. if arg_value is True:
  323. warnings.warn(
  324. "`{}' can be set to True only when 'operator_export_type' is "
  325. "`ONNX`. Since 'operator_export_type' is not set to 'ONNX', "
  326. "`{}` argument will be ignored.".format(arg_name, arg_name)
  327. )
  328. arg_value = False
  329. return arg_value
  330. def _decide_keep_init_as_input(
  331. keep_initializers_as_inputs: Optional[bool],
  332. operator_export_type: _C_onnx.OperatorExportTypes,
  333. opset_version: int,
  334. ):
  335. """Decides whether the initializers in the graph should be listed as ONNX graph inputs.
  336. This method encapsulates the logic to decide whether the initializers in the graph
  337. should be listed as ONNX graph inputs (i.e., whether to choose ONNX IR v3 or v4).
  338. If keep_initializers_as_inputs is not specified (None), then we decide whether to keep
  339. initializers as graph inputs (val_keep_init_as_ip) based on export type. If export type
  340. is ONNX, then do not keep initializers as input (val_keep_init_as_ip=False). For all other
  341. export types keep initializers as input (val_keep_init_as_ip=True).
  342. If keep_initializers_as_inputs is specified, then respect it. Unless opset version <= 8,
  343. in which case it must be ignored because for opset version <= 8, all initializers MUST be
  344. part of graph input (only ONNX IR v3 is allowed), i.e. val_keep_init_as_ip=True.
  345. Special handling is needed for opset version 8 or lower, because irrespective
  346. of user input for keep_initializers_as_inputs, the graph must follow ONNX IR v3
  347. semantics, i.e. all initializers must be listed as ONNX graph input.
  348. """
  349. if opset_version < 9:
  350. if keep_initializers_as_inputs is False:
  351. warnings.warn(
  352. "Setting 'keep_initializers_as_inputs=False' for opset version"
  353. "8 or lower would lead to an invalid ONNX graph. Therefore, "
  354. "'keep_initializers_as_inputs=False' is ignored during export."
  355. "Exported model will have initializers as graph inputs (compliant "
  356. " to ONNX IR v3)."
  357. )
  358. return True # i.e. True == initializers are part of graph input (ONNX IR v3)
  359. val_keep_init_as_ip = (
  360. True if keep_initializers_as_inputs is None else keep_initializers_as_inputs
  361. )
  362. if (
  363. keep_initializers_as_inputs is None
  364. and operator_export_type is _C_onnx.OperatorExportTypes.ONNX
  365. ):
  366. val_keep_init_as_ip = False
  367. return val_keep_init_as_ip
  368. def _decide_add_node_names(add_node_names, operator_export_type):
  369. return _resolve_args_by_export_type(
  370. "add_node_names", add_node_names, operator_export_type
  371. )
  372. def _decide_constant_folding(do_constant_folding, operator_export_type, training):
  373. do_constant_folding = _resolve_args_by_export_type(
  374. "do_constant_folding", do_constant_folding, operator_export_type
  375. )
  376. if do_constant_folding and (
  377. training is not None and training is not _C_onnx.TrainingMode.EVAL
  378. ):
  379. warnings.warn(
  380. "It is recommended that constant folding be turned off ('do_constant_folding=False') "
  381. "when exporting the model in training-amenable mode, i.e. with 'training=TrainingMode.TRAIN' "
  382. "or 'training=TrainingMode.PRESERVE' (when model is in training mode). Otherwise, some "
  383. "learnable model parameters may not translate correctly in the exported ONNX model "
  384. "because constant folding mutates model parameters. Please consider "
  385. "turning off constant folding or setting the training=TrainingMode.EVAL."
  386. )
  387. return do_constant_folding
  388. def _signature(model) -> inspect.Signature:
  389. should_be_callable = getattr(model, "forward", model)
  390. if callable(should_be_callable):
  391. return inspect.signature(should_be_callable)
  392. raise ValueError("model has no forward method and is not callable")
  393. def _decide_input_format(model, args):
  394. try:
  395. sig = _signature(model)
  396. except ValueError as e:
  397. warnings.warn("%s, skipping _decide_input_format" % e)
  398. return args
  399. try:
  400. ordered_list_keys = list(sig.parameters.keys())
  401. if ordered_list_keys[0] == "self":
  402. ordered_list_keys = ordered_list_keys[1:]
  403. args_dict: Dict = {}
  404. if isinstance(args, list):
  405. args_list = args
  406. elif isinstance(args, tuple):
  407. args_list = list(args)
  408. else:
  409. args_list = [args]
  410. if isinstance(args_list[-1], dict):
  411. args_dict = args_list[-1]
  412. args_list = args_list[:-1]
  413. n_nonkeyword = len(args_list)
  414. for optional_arg in ordered_list_keys[n_nonkeyword:]:
  415. if optional_arg in args_dict:
  416. args_list.append(args_dict[optional_arg])
  417. # Check if this arg has a default value
  418. else:
  419. param = sig.parameters[optional_arg]
  420. if param.default != param.empty:
  421. args_list.append(param.default)
  422. args = args_list if isinstance(args, list) else tuple(args_list)
  423. # Cases of models with no input args
  424. except IndexError:
  425. warnings.warn("No input args, skipping _decide_input_format")
  426. except Exception as e:
  427. warnings.warn("Skipping _decide_input_format\n {}".format(e.args[0]))
  428. return args
  429. def _trace(func, args, operator_export_type, return_outs=False):
  430. # Special case for common case of passing a single Tensor
  431. if isinstance(args, torch.Tensor):
  432. args = (args,)
  433. trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  434. func, args, strict=False, _force_outplace=False, _return_inputs_states=True
  435. )
  436. warn_on_static_input_change(inputs_states)
  437. trace_graph = _optimize_graph(trace_graph, operator_export_type, params_dict={})
  438. if return_outs:
  439. return trace_graph, torch_out
  440. return trace_graph
  441. def _trace_and_get_graph_from_model(model, args):
  442. # A basic sanity check: make sure the state_dict keys are the same
  443. # before and after running the model. Fail fast!
  444. orig_state_dict_keys = torch.jit._unique_state_dict(model).keys()
  445. trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  446. model, args, strict=False, _force_outplace=False, _return_inputs_states=True
  447. )
  448. warn_on_static_input_change(inputs_states)
  449. if orig_state_dict_keys != torch.jit._unique_state_dict(model).keys():
  450. raise RuntimeError(
  451. "state_dict changed after running the tracer; "
  452. "something weird is happening in your model!"
  453. )
  454. return trace_graph, torch_out
  455. def _get_param_count_list(method_graph, args_params):
  456. param_count_list = []
  457. for input_, arg_params_ in zip(method_graph.inputs(), args_params):
  458. if "PackedParams" in str(input_.type()):
  459. in_vars, _ = torch.jit._flatten(arg_params_)
  460. param_count_list.append(len(in_vars))
  461. else:
  462. param_count_list.append(arg_params_ is not None)
  463. return param_count_list
  464. def _check_flatten_did_not_remove(original, jit_flattened):
  465. """torch.jit._flatten removes None. Check if it did so in this case."""
  466. def flatten(x):
  467. if isinstance(x, (list, tuple)):
  468. for inner in x:
  469. for y in flatten(inner):
  470. yield y
  471. elif isinstance(x, dict):
  472. for inner in x.values():
  473. for y in flatten(inner):
  474. yield y
  475. else:
  476. yield x
  477. flattened_with_none = list(flatten(original))
  478. num_none = len(flattened_with_none) - len(jit_flattened)
  479. assert num_none >= 0
  480. if num_none:
  481. raise ValueError(
  482. f"args contained {num_none} None's after flattening. "
  483. "When exporting a ScriptModule or ScriptFunction, no args may "
  484. "be None because that breaks type propagation."
  485. )
def _create_jit_graph(model, args):
    """Build a TorchScript graph for ``model`` and collect its parameters.

    Returns:
        A tuple ``(graph, params, torch_out, module)``, filled as follows.

        * ``torch.jit.ScriptModule``: ``model`` is frozen and its parameters are
          listed via ``_C._jit_onnx_list_model_parameters``. ``graph`` is the
          listed module's ``forward`` graph with input shapes propagated from
          ``args`` plus ``params``. ``torch_out`` is None, and ``module`` is the
          module returned by the listing call.
        * ``torch.jit.ScriptFunction``: ``graph`` is the function's graph with
          shapes propagated from the flattened ``args``. ``params`` is ``()``,
          and ``torch_out`` and ``module`` are None.
        * Anything else: ``model`` is traced. ``params`` are the values of
          ``torch.jit._unique_state_dict(model)``, and the trailing graph
          inputs are renamed after its keys. ``torch_out`` is the traced
          output and ``module`` is None.

    Raises:
        ValueError: From ``_check_flatten_did_not_remove``, when a scripted
            model receives None args.
        RuntimeError: If a ScriptModule's ``forward`` has no ``graph`` (not a
            script method), or if tracing changed the state dict.
    """
    torch_out = None
    params: Union[List, Tuple]
    if isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)):
        flattened_args = tuple(torch.jit._flatten(tuple(args))[0])
        _check_flatten_did_not_remove(args, flattened_args)
    if isinstance(model, torch.jit.ScriptModule):
        try:
            graph = model.forward.graph
        except AttributeError as e:
            raise RuntimeError("'forward' method must be a script method") from e
        # Applied to the original forward graph before freezing.
        _C._jit_pass_onnx_function_substitution(graph)
        freezed_m = _C._freeze_module(model._c, preserveParameters=True)
        module, params = _C._jit_onnx_list_model_parameters(freezed_m)
        method_graph = module._get_method("forward").graph
        args_params = tuple(args) + tuple(params)
        param_count_list = _get_param_count_list(method_graph, args_params)
        in_vars, _ = torch.jit._flatten(args_params)
        graph = _C._propagate_and_assign_input_shapes(
            method_graph, tuple(in_vars), param_count_list, False, False
        )
        return graph, params, torch_out, module
    elif isinstance(model, torch.jit.ScriptFunction):
        params = ()
        graph = model.graph
        _C._jit_pass_onnx_function_substitution(graph)
        param_count_list = _get_param_count_list(graph, args)
        # NOTE: `flattened_args` is always bound here, because a ScriptFunction
        # also satisfies the first isinstance check above. Static checkers
        # cannot see that and report it as "possibly unbound".
        graph = _C._propagate_and_assign_input_shapes(
            graph, flattened_args, param_count_list, False, False
        )
        return graph, params, torch_out, None
    else:
        graph, torch_out = _trace_and_get_graph_from_model(model, args)
        _C._jit_pass_onnx_lint(graph)
        state_dict = torch.jit._unique_state_dict(model)
        params = list(state_dict.values())
        graph_inputs = list(graph.inputs())
        # Trailing graph inputs correspond to state-dict entries; earlier ones
        # are user inputs.
        user_input_num = len(graph_inputs) - len(state_dict)
        param_names = list(state_dict.keys())
        for i, inp in enumerate(graph_inputs):
            if i >= user_input_num:
                inp.setDebugName(param_names[i - user_input_num])
        _C._jit_pass_onnx_function_substitution(graph)
        return graph, params, torch_out, None
  531. def _get_named_param_dict(graph, params):
  532. input_and_param_names = [val.debugName() for val in graph.inputs()]
  533. param_names = input_and_param_names[len(input_and_param_names) - len(params) :]
  534. _params_dict = dict(zip(param_names, params))
  535. return _params_dict
  536. def _get_example_outputs(model, args):
  537. input_args = copy.deepcopy(args)
  538. input_kwargs = {}
  539. if input_args and isinstance(input_args[-1], dict):
  540. input_kwargs = input_args[-1]
  541. input_args = input_args[:-1]
  542. example_outputs = model(*input_args, **input_kwargs)
  543. if isinstance(example_outputs, list):
  544. example_outputs = [example_outputs]
  545. elif not isinstance(example_outputs, tuple):
  546. example_outputs = (example_outputs,)
  547. return example_outputs
  548. _qtype_vtype_map = {
  549. torch.quint8: torch.uint8,
  550. torch.qint8: torch.int8,
  551. torch.qint32: torch.int32,
  552. torch.quint4x2: torch.int8,
  553. }
  554. def unpack_quantized_tensor(value):
  555. if isinstance(value, torch.Tensor) and value.dtype in _qtype_vtype_map:
  556. q_value_dequantize = value.dequantize()
  557. q_scale = torch.tensor(value.q_scale(), dtype=torch.double)
  558. q_zero_point = torch.tensor(value.q_zero_point(), dtype=torch.int64)
  559. q_value = q_value_dequantize / q_scale + q_zero_point
  560. q_value = q_value.to(dtype=_qtype_vtype_map[value.dtype])
  561. return q_value, q_scale, q_zero_point
  562. else:
  563. return (value,)
  564. def _pre_trace_quant_model(model, args):
  565. r"""Returns `torch.jit.trace(model, args)` if model is quantized. Otherwise do nothing and return
  566. original model.
  567. This is due to https://github.com/pytorch/pytorch/issues/75761.
  568. """
  569. if any(
  570. hasattr(m, "_packed_params") for m in getattr(model, "modules", lambda: [])()
  571. ) or any(getattr(arg, "is_quantized", False) for arg in args):
  572. return torch.jit.trace(model, args)
  573. return model
  574. def _assign_onnx_node_name(graph, node_names):
  575. """Takes in ONNX graph, and mapping from _C.Node to node name in exported ONNX ModelProto.
  576. Returns:
  577. graph (_C.Graph): A TorchScript IR Graph with ONNX nodes, where each _C.Node gets its name
  578. in exported ONNX ModelProto assigned as attribute ``onnx_name``.
  579. """
  580. def n_fn(n, b_fn, node_names):
  581. for b in n.blocks():
  582. b_fn(b, node_names)
  583. if n in node_names:
  584. n.s_("onnx_name", node_names[n])
  585. def b_fn(b, node_names):
  586. for n in b.nodes():
  587. n_fn(n, b_fn, node_names)
  588. b_fn(graph, node_names)
  589. return graph
  def _model_to_graph(
      model,
      args,
      verbose=False,
      input_names=None,
      output_names=None,
      operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
      do_constant_folding=True,
      _disable_torch_constant_prop=False,
      fixed_batch_size=False,
      training=None,
      dynamic_axes=None,
  ) -> Tuple[
      _C.Graph,
      Dict[str, torch.Tensor],
      Optional[Union[torch.Tensor, Tuple[torch.Tensor], List[torch.Tensor]]],
  ]:
      """Converts model into an ONNX graph.

      Args:
          model: The model to convert (traced, or used as-is when scripted).
          args: Example inputs. A single Tensor/int/float/bool is wrapped into a
              1-tuple.
          verbose: Accepted for signature compatibility; not read here.
          input_names: Optional names applied to the graph inputs.
          output_names: Optional names applied to the graph outputs.
          operator_export_type: Forwarded to ``_optimize_graph``.
          do_constant_folding: Run ONNX constant folding when the current opset
              is listed in ``_constants.onnx_constant_folding_opsets``.
          _disable_torch_constant_prop: Forwarded to ``_optimize_graph``.
          fixed_batch_size: Forwarded to ``_optimize_graph``.
          training: ``None`` or ``TrainingMode.EVAL`` enables the eval peephole
              pass.
          dynamic_axes: Forwarded to ``_optimize_graph``.

      Returns:
          graph: A TorchScript IR Graph with ONNX nodes.
          params_dict: Dict from input param name to param value.
          torch_out: The output tensors resulting from the trace of ``model``.
              If ``model`` is a :class:`torch.jit.ScriptModule` or :class:`torch.jit.ScriptFunction`,
              this will be None, since we are not doing any tracing.
      """
      # TODO: can we simplify this to always return a tuple of Tensor or None?

      # Special case for common case of passing a single Tensor
      if isinstance(args, (torch.Tensor, int, float, bool)):
          args = (args,)

      model = _pre_trace_quant_model(model, args)
      graph, params, torch_out, module = _create_jit_graph(model, args)
      params_dict = _get_named_param_dict(graph, params)

      try:
          graph = _optimize_graph(
              graph,
              operator_export_type,
              _disable_torch_constant_prop=_disable_torch_constant_prop,
              fixed_batch_size=fixed_batch_size,
              params_dict=params_dict,
              dynamic_axes=dynamic_axes,
              input_names=input_names,
              module=module,
          )
      except Exception:
          torch.onnx.log("Torch IR graph at exception: ", graph)
          raise

      is_script = isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule))
      # NB: ONNX requires complete information about output types, which might be
      # erased by some optimizations, so we need to set it explicitly again.
      if is_script:
          example_outputs = _get_example_outputs(model, args)
          # Flatten every example output through unpack_quantized_tensor in a
          # single pass (repeated tuple concatenation would be quadratic).
          example_outputs_final = tuple(
              unpacked
              for example_output in example_outputs
              for unpacked in unpack_quantized_tensor(example_output)
          )
          out_vars, desc = torch.jit._flatten(example_outputs_final)
          _C._jit_pass_onnx_assign_output_shape(
              graph, out_vars, desc, GLOBALS.onnx_shape_inference, is_script
          )
      else:
          if not isinstance(torch_out, (list, tuple)):
              output_wrapped = [torch_out]
          else:
              output_wrapped = torch_out  # type: ignore[assignment]

          output_tensors, out_desc = _C._jit_flatten(tuple(output_wrapped))
          # assign_output_shape pass is not compatible with quantized outputs.
          # Quantized outputs are flattened to 3 values in ONNX, while packed as
          # single value in PyTorch.
          if not any(getattr(out, "is_quantized", False) for out in output_tensors):
              _C._jit_pass_onnx_assign_output_shape(
                  graph,
                  output_tensors,
                  out_desc,
                  GLOBALS.onnx_shape_inference,
                  is_script,
              )

      _set_input_and_output_names(graph, input_names, output_names)
      # The graph was rewritten above, so the name -> param mapping is rebuilt.
      params_dict = _get_named_param_dict(graph, params)

      if training is None or training == _C_onnx.TrainingMode.EVAL:
          params_dict = _C._jit_pass_onnx_eval_peephole(graph, params_dict)

      if (
          do_constant_folding
          and GLOBALS.export_onnx_opset_version in _constants.onnx_constant_folding_opsets
      ):
          params_dict = _C._jit_pass_onnx_constant_fold(
              graph, params_dict, GLOBALS.export_onnx_opset_version
          )
          _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)

      if GLOBALS.onnx_shape_inference:
          _C._jit_pass_onnx_graph_shape_type_inference(
              graph, params_dict, GLOBALS.export_onnx_opset_version
          )

      params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)

      # For ONNX opset < 9, constants only have three data types: float16, float, double.
      # In this pass transform constants of other data types to float/double + cast operator.
      if GLOBALS.export_onnx_opset_version < 9:
          _C._jit_pass_onnx_cast_all_constant_to_floating(graph)

      params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict)
      _C._jit_decay_packed_param_input_types(graph)

      # If output names lack a proper name and are identified only by their unique
      # id, give them a legible name for debugging purposes.
      _apply_friendly_debug_names(graph, params_dict)

      return graph, params_dict, torch_out
  def export_to_pretty_string(
      model,
      args,
      export_params=True,
      verbose=False,
      training=None,
      input_names=None,
      output_names=None,
      operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
      export_type=None,
      google_printer=False,
      opset_version=None,
      keep_initializers_as_inputs=None,
      custom_opsets=None,
      add_node_names=True,
      do_constant_folding=True,
      dynamic_axes=None,
  ):
      """Convert ``model`` to ONNX and render the graph with ``_pretty_print_onnx``.

      Sets the opset version and operator export type on ``symbolic_helper``,
      enables ONNX shape inference, converts the model with ``_model_to_graph``
      inside ``exporter_context`` and passes the result to
      ``graph._pretty_print_onnx``.

      ``export_params`` and ``export_type`` are accepted for signature
      compatibility but are not read by this function.

      Returns:
          Whatever ``graph._pretty_print_onnx`` produces (presumably the textual
          form of the exported model).
      """
      opset_version = (
          _constants.onnx_default_opset if opset_version is None else opset_version
      )
      custom_opsets = {} if custom_opsets is None else custom_opsets

      symbolic_helper._set_opset_version(opset_version)
      symbolic_helper._set_operator_export_type(operator_export_type)
      symbolic_helper._set_onnx_shape_inference(True)

      with exporter_context(model, training, verbose):
          keep_init_as_input = _decide_keep_init_as_input(
              keep_initializers_as_inputs, operator_export_type, opset_version
          )
          add_names = _decide_add_node_names(add_node_names, operator_export_type)
          fold_constants = _decide_constant_folding(
              do_constant_folding, operator_export_type, training
          )
          args = _decide_input_format(model, args)
          graph, params_dict, _ = _model_to_graph(
              model,
              args,
              verbose=verbose,
              input_names=input_names,
              output_names=output_names,
              operator_export_type=operator_export_type,
              do_constant_folding=fold_constants,
              training=training,
              dynamic_axes=dynamic_axes,
          )

          return graph._pretty_print_onnx(  # type: ignore[attr-defined]
              params_dict,
              opset_version,
              False,
              operator_export_type,
              google_printer,
              keep_init_as_input,
              custom_opsets,
              add_names,
          )
  def unconvertible_ops(
      model, args, training=_C_onnx.TrainingMode.EVAL, opset_version=None
  ):
      r"""
      Converts the model with operator_export_type set to
      torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH once in order to get a list of
      all the ops that are not supported/implemented by the exporter.

      Nodes nested inside sub-blocks (e.g. the bodies of control-flow nodes) are
      inspected as well as the top-level nodes of the graph.

      Args:
          model: Same as corresponding arg to torch.onnx.export.
          args: Same as corresponding arg to torch.onnx.export.
          training: Same as corresponding arg to torch.onnx.export.
          opset_version: Same as corresponding arg to torch.onnx.export.

      Returns:
          Tuple[torch._C.Graph, List[str]], where the list includes the names
          of the unconvertible ops, i.e. every node kind whose namespace is not
          ``onnx``, ``prim`` or ``quantized``.
      """
      opset_version = opset_version or _constants.onnx_default_opset
      symbolic_helper._set_opset_version(opset_version)
      # operator_export_type is set to ONNX_FALLTHROUGH by default so that if an op is not supported
      # in ONNX, fall through will occur and export the operator as is, as a custom ONNX op.
      with exporter_context(model, training, False):
          args = _decide_input_format(model, args)
          graph, _, _ = _model_to_graph(
              model,
              args,
              # So that if an op cannot be converted to ONNX, it will be kept
              # as-is rather than cause a failure.
              operator_export_type=_C_onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
          )

      supported_namespaces = ("onnx", "prim", "quantized")
      unsupported_ops = []

      def _collect_unsupported(block):
          # Record unsupported node kinds in `block`, descending into sub-blocks
          # so ops inside control-flow bodies are not missed.
          for node in block.nodes():  # type: ignore[attr-defined]
              kind = node.kind()
              if kind.split(":")[0] not in supported_namespaces:
                  unsupported_ops.append(kind)
              for sub_block in node.blocks():
                  _collect_unsupported(sub_block)

      _collect_unsupported(graph)
      return graph, unsupported_ops
  784. def _setup_trace_module_map(model, export_modules_as_functions):
  785. def __setup_trace_module_map():
  786. trace_module_map = {_m: torch.typename(type(_m)) for _m in model.modules()}
  787. torch.jit._trace._trace_module_map = trace_module_map
  788. return trace_module_map
  789. def __register_attribute_hook():
  790. attr_name = "_onnx_attrs"
  791. def _track_module_attributes_forward_pre_hook(module, input):
  792. setattr(module, attr_name, _get_module_attributes(module))
  793. def _track_module_attributes_forward_hook(module, input, output):
  794. tracing_state = _C._get_tracing_state()
  795. if not tracing_state:
  796. return
  797. graph = tracing_state.graph()
  798. onnx_attrs = {}
  799. if hasattr(module, attr_name):
  800. onnx_attrs = getattr(module, attr_name)
  801. delattr(module, attr_name)
  802. _C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs)
  803. for m in model.modules():
  804. m.register_forward_hook(_track_module_attributes_forward_hook)
  805. m.register_forward_pre_hook(_track_module_attributes_forward_pre_hook)
  806. if isinstance(export_modules_as_functions, bool) and export_modules_as_functions:
  807. trace_module_map = __setup_trace_module_map()
  808. export_modules_as_functions = {v for k, v in trace_module_map.items()}
  809. elif (
  810. isinstance(export_modules_as_functions, set)
  811. and len(export_modules_as_functions) > 0
  812. ):
  813. def _find_typename(v):
  814. if isinstance(v, type):
  815. return torch.typename(v)
  816. else:
  817. raise RuntimeError(
  818. "Only type of the `nn.Module` should be "
  819. "passed in the set for argument `export_modules_as_functions`. "
  820. "Got `%s`." % (type(v).__name__)
  821. )
  822. trace_module_map = __setup_trace_module_map()
  823. module_typenames = {_find_typename(v) for v in export_modules_as_functions}
  824. export_modules_as_functions = module_typenames
  825. else:
  826. export_modules_as_functions = None
  827. if export_modules_as_functions:
  828. __register_attribute_hook()
  829. return export_modules_as_functions
  830. def _reset_trace_module_map():
  831. torch.jit._trace._trace_module_map = None
  832. _C._jit_pass_onnx_clear_scope_records()
  833. def _get_module_attributes(module):
  834. annotations = typing.get_type_hints(type(module))
  835. base_m_annotations = typing.get_type_hints(torch.nn.Module)
  836. [annotations.pop(k, None) for k in base_m_annotations]
  837. return {k: getattr(module, k) for k in annotations}
  def _export(
      model,
      args,
      f,
      export_params=True,
      verbose=False,
      training=None,
      input_names=None,
      output_names=None,
      operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
      export_type=None,
      opset_version=None,
      do_constant_folding=True,
      dynamic_axes=None,
      keep_initializers_as_inputs=None,
      fixed_batch_size=False,
      custom_opsets=None,
      add_node_names=True,
      onnx_shape_inference=True,
      export_modules_as_functions=False,
  ):
      """Export ``model`` to ONNX and write the result to ``f``.

      Converts the model with ``_model_to_graph``, serializes it with
      ``graph._export_onnx`` and writes the output according to ``export_type``:

      * ``PROTOBUF_FILE`` (default): the model proto is written to ``f``.
      * ``ZIP_ARCHIVE`` / ``COMPRESSED_ZIP_ARCHIVE``: a zip file at ``f`` holding
        the model proto plus every entry of the export map.
      * ``DIRECTORY``: the model proto and each export-map entry are written as
        separate files inside directory ``f`` (created if missing).

      When ``operator_export_type`` is ``ONNX`` and the external data format is
      not used, the serialized proto is validated with ``_C._check_onnx_proto``.
      The module-level ``__IN_ONNX_EXPORT`` flag is held for the duration of the
      call and the trace module map is always reset on exit.

      Returns:
          The ``torch_out`` produced by ``_model_to_graph``.

      Raises:
          ValueError: If ``model`` is a ``torch.nn.DataParallel``, or if
              ``export_modules_as_functions`` is requested with
              ``opset_version`` < 15.
          RuntimeError: For an unknown ``export_type``.
          torch.onnx.CheckerError: If the ONNX checker rejects the proto.
      """
      if export_type is None:
          export_type = torch.onnx.ExportTypes.PROTOBUF_FILE

      if isinstance(model, torch.nn.DataParallel):
          raise ValueError(
              "torch.nn.DataParallel is not supported by ONNX "
              "exporter, please use 'attribute' module to "
              "unwrap model from torch.nn.DataParallel. Try "
              "torch.onnx.export(model.module, ...)"
          )
      global __IN_ONNX_EXPORT
      assert __IN_ONNX_EXPORT is False
      __IN_ONNX_EXPORT = True
      try:
          symbolic_helper._set_onnx_shape_inference(onnx_shape_inference)

          if opset_version is None:
              opset_version = _constants.onnx_default_opset

          if export_modules_as_functions and opset_version < 15:
              raise ValueError(
                  "`export_modules_as_functions` is not supported for `opset_version` < 15. "
                  "This is because `opset_version` < 15 implies IR version < 8, which means "
                  "no local function support. "
              )
          export_modules_as_functions = _setup_trace_module_map(
              model, export_modules_as_functions
          )

          if not operator_export_type:
              if _C_onnx._CAFFE2_ATEN_FALLBACK:
                  operator_export_type = _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
              else:
                  operator_export_type = _C_onnx.OperatorExportTypes.ONNX

          # By default, training=None, (which defaults to TrainingMode.EVAL),
          # which is good because running a model in training mode could result in
          # internal buffers getting updated, dropout getting applied, etc.
          # If you really know what you're doing, you can turn
          # training=TrainingMode.TRAINING or training=TrainingMode.PRESERVE,
          # (to preserve whatever the original training mode was.)
          symbolic_helper._set_opset_version(opset_version)
          symbolic_helper._set_operator_export_type(operator_export_type)
          with exporter_context(model, training, verbose):
              val_keep_init_as_ip = _decide_keep_init_as_input(
                  keep_initializers_as_inputs, operator_export_type, opset_version
              )
              val_add_node_names = _decide_add_node_names(
                  add_node_names, operator_export_type
              )
              val_do_constant_folding = _decide_constant_folding(
                  do_constant_folding, operator_export_type, training
              )
              # Normally f can be a file-like object, but for large models, the external data format requires a
              # valid `model_file_location`. Code in export.cpp will enforce this.
              model_file_location = f if isinstance(f, str) else ""
              args = _decide_input_format(model, args)
              if dynamic_axes is None:
                  dynamic_axes = {}
              _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)

              graph, params_dict, torch_out = _model_to_graph(
                  model,
                  args,
                  verbose,
                  input_names,
                  output_names,
                  operator_export_type,
                  val_do_constant_folding,
                  fixed_batch_size=fixed_batch_size,
                  training=training,
                  dynamic_axes=dynamic_axes,
              )

              # TODO: Don't allocate a in-memory string for the protobuf
              defer_weight_export = (
                  export_type is not torch.onnx.ExportTypes.PROTOBUF_FILE
              )
              if custom_opsets is None:
                  custom_opsets = {}

              _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
              node_attr_to_name = {}  # type: ignore[var-annotated]
              if export_modules_as_functions:
                  # NOTE: cannot call DCE after this pass. DCE will remove function definition nodes.
                  node_attr_to_name = _C._jit_pass_onnx_function_extraction(
                      graph, export_modules_as_functions, list(params_dict.keys())
                  )
              params_dict = _C._jit_pass_onnx_deduplicate_initializers(  # type: ignore[assignment]
                  graph, params_dict, getattr(model, "training", False)  # type: ignore[arg-type]
              )

              # Without export_params an empty params dict is serialized and
              # weight export is never deferred.
              if export_params:
                  exported_params, defer_weights = params_dict, defer_weight_export
              else:
                  exported_params, defer_weights = {}, False
              (
                  proto,
                  export_map,
                  val_use_external_data_format,
                  node_names,
              ) = graph._export_onnx(  # type: ignore[attr-defined]
                  exported_params,
                  opset_version,
                  dynamic_axes,
                  defer_weights,
                  operator_export_type,
                  not verbose,
                  val_keep_init_as_ip,
                  custom_opsets,
                  val_add_node_names,
                  model_file_location,
                  node_attr_to_name,
              )

              if verbose:
                  torch.onnx.log(
                      "Exported graph: ", _assign_onnx_node_name(graph, node_names)
                  )
              if export_type == torch.onnx.ExportTypes.PROTOBUF_FILE:
                  assert len(export_map) == 0
                  with torch.serialization._open_file_like(f, "wb") as opened_file:
                      opened_file.write(proto)
              elif export_type in [
                  torch.onnx.ExportTypes.ZIP_ARCHIVE,
                  torch.onnx.ExportTypes.COMPRESSED_ZIP_ARCHIVE,
              ]:
                  compression = (
                      zipfile.ZIP_DEFLATED
                      if export_type == torch.onnx.ExportTypes.COMPRESSED_ZIP_ARCHIVE
                      else zipfile.ZIP_STORED
                  )
                  with zipfile.ZipFile(f, "w", compression=compression) as z:
                      z.writestr(torch.onnx.ONNX_ARCHIVE_MODEL_PROTO_NAME, proto)
                      for k, v in export_map.items():
                          z.writestr(k, v)
              elif export_type == torch.onnx.ExportTypes.DIRECTORY:
                  if os.path.exists(f):
                      assert os.path.isdir(f)
                  else:
                      os.makedirs(f)

                  model_proto_file = os.path.join(
                      f, torch.onnx.ONNX_ARCHIVE_MODEL_PROTO_NAME
                  )
                  with torch.serialization._open_file_like(
                      model_proto_file, "wb"
                  ) as opened_file:
                      opened_file.write(proto)

                  for k, v in export_map.items():
                      weight_proto_file = os.path.join(f, k)
                      with torch.serialization._open_file_like(
                          weight_proto_file, "wb"
                      ) as opened_file:
                          opened_file.write(v)
              else:
                  raise RuntimeError("Unknown export type")

              # The ONNX checker only works for ONNX graph. So if the operator_export_type is not ONNX,
              # we can skip this check.
              # If large model format export is enabled, proto will only contain data location instead of
              # raw data and _check_onnx_proto() will fail because it can only handle the raw ONNX proto
              # string in memory.
              # `==` rather than `is`: pybind11 enum values are compared by value.
              if (operator_export_type == _C_onnx.OperatorExportTypes.ONNX) and (
                  not val_use_external_data_format
              ):
                  try:
                      _C._check_onnx_proto(proto, full_check=True)
                  except RuntimeError as e:
                      raise torch.onnx.CheckerError(e) from e
      finally:
          assert __IN_ONNX_EXPORT
          __IN_ONNX_EXPORT = False
          _reset_trace_module_map()
      return torch_out
  1041. def _apply_friendly_debug_names(graph, params):
  1042. for n in graph.nodes():
  1043. for v in n.inputs():
  1044. old_name = v.debugName()
  1045. if old_name != str(v.unique()):
  1046. continue
  1047. new_name = f"{n.kind()}_{v.unique()}"
  1048. v.setDebugName(new_name)
  1049. if old_name in params:
  1050. params[new_name] = params.pop(old_name)
  1051. def _set_input_and_output_names(graph, input_names, output_names):
  1052. def set_names(node_list, name_list, descriptor):
  1053. if name_list is None:
  1054. return
  1055. if len(name_list) > len(node_list):
  1056. raise RuntimeError(
  1057. "number of %s names provided (%d) exceeded number of %ss (%d)"
  1058. % (descriptor, len(name_list), descriptor, len(node_list))
  1059. )
  1060. # Mark if the output node DebugName is set before.
  1061. output_node_set = set()
  1062. for i, (name, node) in enumerate(zip(name_list, node_list)):
  1063. # Duplicated output node, insert onnx::Identity to avoid setting the same DebugName after setDebugName().
  1064. if descriptor == "output":
  1065. if node in output_node_set:
  1066. identity_node = graph.create("onnx::Identity")
  1067. identity_node.insertAfter(node.node())
  1068. identity_node.addInput(node)
  1069. identity_node.output().setType(node.type())
  1070. graph.return_node().replaceInput(i, identity_node.output())
  1071. node = identity_node.output()
  1072. output_node_set.add(node)
  1073. if node.debugName() != name:
  1074. node.setDebugName(name)
  1075. set_names(list(graph.inputs()), input_names, "input")
  1076. set_names(list(graph.outputs()), output_names, "output")
  1077. def _run_symbolic_method(g, op_name, symbolic_fn, args):
  1078. r"""
  1079. This trampoline function gets invoked for every symbolic method
  1080. call from C++.
  1081. """
  1082. try:
  1083. return symbolic_fn(g, *args)
  1084. except TypeError as e:
  1085. # Handle the specific case where we didn't successfully dispatch
  1086. # to symbolic_fn. Otherwise, the backtrace will have the clues
  1087. # you need.
  1088. e.args = ("{} (occurred when translating {})".format(e.args[0], op_name),)
  1089. raise
  1090. def _add_block(node: _C.Node):
  1091. return node.addBlock() # type: ignore[attr-defined]
  1092. def _add_input_to_block(block: _C.Block):
  1093. return block.addInputToBlock() # type: ignore[attr-defined]
  1094. def _add_output_to_block(block: _C.Block, value: _C.Value):
  1095. new_output = block.registerOutput(value) # type: ignore[attr-defined]
  1096. return new_output
  # Note [Export inplace]
  # ~~~~~~~~~~~~~~~~~~~~~
  # In the abstract, it would be better for us to export inplace annotations
  # than to drop them, since they carry useful information that can help the
  # consumer of an ONNX export execute it more efficiently. However, ONNX
  # doesn't currently formalize inplace. Fortunately, it's sound to drop
  # inplace annotations, but we lose information this way.
  def _find_symbolic_in_registry(
      domain: str,
      op_name: str,
      opset_version: int,
      operator_export_type: _C_onnx.OperatorExportTypes,
  ) -> Optional[Callable]:
      """Look up the symbolic function for ``op_name`` in ``domain``.

      Args:
          domain: The domain of the symbolic function.
          op_name: The name of the op.
          opset_version: The opset version currently being exported.
          operator_export_type: An enum in ``_C_onnx.OperatorExportTypes``.

      Returns:
          The registered symbolic function. ``None`` is returned only when the op
          is not registered and ``operator_export_type`` is ``ONNX_FALLTHROUGH``,
          meaning the original node is used directly. For an unregistered op
          under any other export type, the result is whatever
          ``symbolic_registry.get_registered_op`` does for a missing op.
      """
      is_known = symbolic_registry.is_registered_op(op_name, domain, opset_version)
      if (
          not is_known
          and operator_export_type == _C_onnx.OperatorExportTypes.ONNX_FALLTHROUGH
      ):
          # Use the original node directly
          return None
      return symbolic_registry.get_registered_op(op_name, domain, opset_version)
  def _should_aten_fallback(ns, op_name, opset_version, operator_export_type):
      """Decide whether ``op_name`` should be exported as an ATen op.

      Returns True when ``operator_export_type`` is ``ONNX_ATEN``, or when it is
      ``ONNX_ATEN_FALLBACK`` and no symbolic for ``op_name`` is registered in the
      default (``""``) domain for ``opset_version``. ``ns`` is accepted but not
      used.
      """
      registered_in_default_domain = symbolic_registry.is_registered_op(
          op_name, "", opset_version
      )
      if operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN:
          return True
      return (
          operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
          and not registered_in_default_domain
      )
  1135. def _need_symbolic_context(symbolic_fn) -> bool:
  1136. """Checks if the first argument to symbolic_fn is annotated as type `torch.onnx.SymbolicContext`."""
  1137. params = tuple(inspect.signature(symbolic_fn).parameters.values())
  1138. # When the annotation is postpone-evaluated, the annotation is a string
  1139. # and not a type. We need to use get_type_hints to get the real type.
  1140. if not params:
  1141. return False
  1142. first_param_name = params[0].name
  1143. type_hints = typing.get_type_hints(symbolic_fn)
  1144. if first_param_name not in type_hints:
  1145. return False
  1146. param_type = type_hints[first_param_name]
  1147. return issubclass(param_type, torch.onnx.SymbolicContext)
  def _get_aten_op_overload_name(n: _C.Node) -> str:
      """Return the ``overload_name`` parsed from node ``n``'s ATen schema.

      The empty string is returned when the schema does not start with
      ``"aten::"`` or when ``symbolic_helper.is_caffe2_aten_fallback()`` is true,
      so the overload name is only populated on non-Caffe2 builds.
      """
      schema = n.schema()
      if schema.startswith("aten::") and not symbolic_helper.is_caffe2_aten_fallback():
          return _C.parse_schema(schema).overload_name
      return ""
  def _run_symbolic_function(
      g: _C.Graph,
      block: _C.Block,
      n: _C.Node,
      inputs: Any,
      env: Dict[_C.Value, _C.Value],
      operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
  ) -> Optional[Union[_C.Value, Tuple[_C.Value, ...]]]:
      """Runs a symbolic function.

      The function is used in C++ to export the node to ONNX.

      Args:
          g: The graph the exported nodes are emitted into.
          block: The block passed to ``torch.onnx.SymbolicContext``.
          n: The TorchScript node being exported.
          inputs: The inputs forwarded to the symbolic function / ``g.op`` / ``g.at``.
          env: Passed to ``torch.onnx.SymbolicContext`` (presumably mapping
              original values to exported ones; not verified here).
          operator_export_type: An enum in ``_C_onnx.OperatorExportTypes``.

      Returns:
          A single or a tuple of Values.
          None when the node gets cloned as is into the new graph.

      Raises:
          symbolic_registry.UnsupportedOperatorError: When no symbolic applies
              and no fallback is enabled. It is raised inside the ``try`` and
              re-raised by the ``RuntimeError`` handler, presumably because it
              subclasses RuntimeError (confirm).
      """
      opset_version = GLOBALS.export_onnx_opset_version

      def _suffixed_attrs():
          # Attribute dict keyed as "<name>_<first char of n.kindOf(name)>",
          # the form passed to g.op / g.at.
          return {k + "_" + n.kindOf(k)[0]: n[k] for k in n.attributeNames()}  # type: ignore[attr-defined]

      # See Note [Export inplace]
      # TODO(ezyang): I think this is not necessary anymore
      if n.kind().endswith("_"):  # type: ignore[attr-defined]
          ns_op_name = n.kind()[:-1]  # type: ignore[attr-defined]
      else:
          ns_op_name = n.kind()  # type: ignore[attr-defined]
      ns, op_name = ns_op_name.split("::")

      try:
          symbolic_registry.register_version("", opset_version)

          # Caffe2-specific: Quantized op symbolics are registered for opset 9 only.
          if symbolic_helper.is_caffe2_aten_fallback() and opset_version == 9:
              symbolic_caffe2.register_quantized_ops("caffe2", opset_version)

          if ns == "aten":
              domain = ""
          elif ns == "quantized" and symbolic_helper.is_caffe2_aten_fallback():
              domain = "caffe2"
          else:
              domain = ns

          if symbolic_registry.is_registered_op(op_name, domain, opset_version):
              symbolic_fn = _find_symbolic_in_registry(
                  domain, op_name, opset_version, operator_export_type
              )
              assert symbolic_fn is not None

              attrs = {k: n[k] for k in n.attributeNames()}  # type: ignore[attr-defined]
              if _need_symbolic_context(symbolic_fn):
                  ctx = torch.onnx.SymbolicContext(_params_dict, env, n, block)
                  return symbolic_fn(ctx, g, *inputs, **attrs)
              # PythonOp symbolic need access to the node to resolve the name conflict,
              # this is inconsistent with regular op symbolic.
              if op_name == "PythonOp":
                  inputs = (n, *inputs)
              return symbolic_fn(g, *inputs, **attrs)
          elif ns == "onnx":
              # Clone node to trigger ONNX shape inference
              attrs = _suffixed_attrs()
              return g.op(op_name, *inputs, **attrs, outputs=n.outputsSize())  # type: ignore[attr-defined]
          elif _should_aten_fallback(ns, op_name, opset_version, operator_export_type):
              # Direct ATen export requested
              attrs = _suffixed_attrs()
              attrs["outputs"] = n.outputsSize()
              # `overload_name` is set for non-Caffe2 builds only
              return g.at(  # type: ignore[attr-defined]
                  op_name, *inputs, overload_name=_get_aten_op_overload_name(n), **attrs
              )
          else:
              raise symbolic_registry.UnsupportedOperatorError(
                  domain, op_name, opset_version
              )
      except RuntimeError:
          if operator_export_type == _C_onnx.OperatorExportTypes.ONNX_FALLTHROUGH:
              return None
          elif (
              operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
              and not symbolic_helper.is_caffe2_aten_fallback()
          ):
              # Emit ATen op for non-Caffe2 builds when `operator_export_type==ONNX_ATEN_FALLBACK`
              attrs = _suffixed_attrs()
              # Preserve the node's output arity, as the direct ATen branch does.
              attrs["outputs"] = n.outputsSize()
              return g.at(  # type: ignore[attr-defined]
                  op_name, *inputs, overload_name=_get_aten_op_overload_name(n), **attrs
              )
          raise
      except TypeError as e:
          # Handle the specific case where we didn't successfully dispatch.
          # Otherwise, the backtrace will have the clues you need.
          note = f"\n(Occurred when translating {op_name})."
          # A bare TypeError() has no args; avoid masking it with IndexError.
          e.args = (f"{e.args[0]} {note}",) if e.args else (note.lstrip(),)
          raise
  def get_ns_op_name_from_custom_op(symbolic_name):
      """Split a custom-op symbolic name of the form ``"Domain::Name"``.

      Args:
          symbolic_name (str): The qualified operator name.

      Returns:
          tuple: ``(ns, op_name)``; the ``"aten"`` domain is mapped to ``""``.

      Raises:
          ValueError: If ``symbolic_name`` does not match the ``Domain::Name``
              format, or if it targets the reserved ``onnx`` domain.
      """
      # fullmatch instead of match + "$": "$" also matches before a trailing
      # newline, which would otherwise leak into the op name.
      if not re.fullmatch(r"[a-zA-Z0-9-_]*::[a-zA-Z-_]+[a-zA-Z0-9-_]*", symbolic_name):
          raise ValueError(
              f"Failed to register operator {symbolic_name}. "
              "The symbolic name must match the format Domain::Name, "
              "and should start with a letter and contain only "
              "alphanumerical characters"
          )

      ns, op_name = symbolic_name.split("::")
      if ns == "onnx":
          raise ValueError(
              f"Failed to register operator {symbolic_name}. {ns} domain cannot be modified."
          )

      if ns == "aten":
          ns = ""

      return ns, op_name
  def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version):
      """Registers a symbolic function for a custom operator.

      The symbolic is registered for every opset in
      ``_constants.onnx_stable_opsets`` plus ``_constants.onnx_main_opset`` that
      is greater than or equal to ``opset_version``.

      When the user registers symbolic for custom/contrib ops,
      it is highly recommended to add shape inference for that operator via setType API,
      otherwise the exported graph may have incorrect shape inference in some extreme cases.
      An example of setType is `test_aten_embedding_2` in `test_operators.py`.
      """
      ns, op_name = get_ns_op_name_from_custom_op(symbolic_name)
      candidate_versions = itertools.chain(
          _constants.onnx_stable_opsets, [_constants.onnx_main_opset]
      )
      for version in (v for v in candidate_versions if v >= opset_version):
          symbolic_registry.register_op(op_name, symbolic_fn, ns, version)
  def unregister_custom_op_symbolic(symbolic_name, opset_version):
      """Unregisters the symbolic function of a custom operator.

      Removes the registration for every opset in
      ``_constants.onnx_stable_opsets`` plus ``_constants.onnx_main_opset`` that
      is greater than or equal to ``opset_version``.
      """
      ns, op_name = get_ns_op_name_from_custom_op(symbolic_name)
      candidate_versions = itertools.chain(
          _constants.onnx_stable_opsets, [_constants.onnx_main_opset]
      )
      for version in (v for v in candidate_versions if v >= opset_version):
          symbolic_registry.unregister_op(op_name, ns, version)
  1275. def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
  1276. """Ensures dynamic axes argument is follows the expected format."""
  1277. if len(dynamic_axes) == 0:
  1278. return
  1279. if hasattr(model, "graph"):
  1280. # Extracting set of valid input/output names that shall be used for dynamic_axes
  1281. if (input_names is None) or len(input_names) == 0:
  1282. input_names = [x.debugName() for x in model.graph.inputs()]
  1283. if (output_names is None) or len(output_names) == 0:
  1284. output_names = [y.debugName() for y in model.graph.outputs()]
  1285. valid_names = set((input_names or []) + (output_names or []))
  1286. # If dynamic axes are provided as a list rather than dictionary, they should
  1287. # first get converted to a dictionary in expected format. If desired axes names
  1288. # are not provided for dynamic axes, automatic names shall be generated for
  1289. # provided dynamic axes of specified input/output
  1290. for key, value in dynamic_axes.items():
  1291. if key not in valid_names:
  1292. warnings.warn(
  1293. "Provided key {} for dynamic axes is not a valid input/output name".format(
  1294. key
  1295. )
  1296. )
  1297. if isinstance(value, list):
  1298. warnings.warn(
  1299. "No names were found for specified dynamic axes of provided input."
  1300. "Automatically generated names will be applied to each dynamic axes of input {}".format(
  1301. key
  1302. )
  1303. )
  1304. value_dict = {}
  1305. for i, x in enumerate(value):
  1306. if not isinstance(x, int):
  1307. raise ValueError(
  1308. "The type of axis index is expected to be an integer"
  1309. )
  1310. if x in value_dict:
  1311. warnings.warn(
  1312. "Duplicate dynamic axis index {} was provided for input {}.".format(
  1313. x, key
  1314. )
  1315. )
  1316. else:
  1317. value_dict[x] = str(key) + "_dynamic_axes_" + str(i + 1)
  1318. dynamic_axes[key] = value_dict