optimization.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405
  1. import torch.fx as fx
  2. from torch.fx.node import Argument, Target
  3. from torch.nn.utils.fusion import fuse_conv_bn_eval
  4. from typing import Type, Dict, Any, Tuple, Iterable, Optional, List, cast
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. from torch.fx.passes.shape_prop import ShapeProp
  9. import copy
  10. from collections import defaultdict
  11. import torch.utils.mkldnn as th_mkldnn
  12. import operator
  13. import time
  14. import logging
  15. from enum import Enum
  16. def _parent_name(target : str) -> Tuple[str, str]:
  17. """
  18. Splits a qualname into parent path and last atom.
  19. For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
  20. """
  21. *parent, name = target.rsplit('.', 1)
  22. return parent[0] if parent else '', name
  # Only supports patterns of length 2, both of which are modules.
  def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]):
      """
      Check whether `node` and its first argument form a chain of modules of the given types.

      The pattern is compared against the pair (node.args[0], node): the first
      entry against the producer node, the second against `node` itself. Every
      participant must be a `call_module` fx.Node whose string target is a key
      of `modules` and whose module type is exactly (not a subclass of) the
      expected type. Returns False if `node` has no positional arguments.
      """
      if not node.args:
          return False
      candidates = (node.args[0], node)

      def is_expected(expected_type, candidate) -> bool:
          return (
              isinstance(candidate, fx.Node)
              and candidate.op == 'call_module'
              and isinstance(candidate.target, str)
              and candidate.target in modules
              and type(modules[candidate.target]) is expected_type
          )

      return all(is_expected(kind, cand) for kind, cand in zip(pattern, candidates))
  def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
      """
      Point a `call_module` node at a different module.

      Updates both the `modules` lookup table (keyed by qualified name) and the
      attribute on the parent module, so the swap is visible through either.
      """
      assert(isinstance(node.target, str))
      qualname = node.target
      owner_path, attr_name = _parent_name(qualname)
      modules[qualname] = new_module
      owner = modules[owner_path]
      setattr(owner, attr_name, new_module)
  def fuse(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
      """
      Fold BatchNorm layers into the convolution that feeds them, for inference.

      Looks for Conv{1,2,3}d -> BatchNorm{1,2,3}d chains where the convolution's
      output is consumed only by the BatchNorm and the BatchNorm tracks running
      statistics. Each match has its convolution replaced by the module returned
      from `fuse_conv_bn_eval` and its BatchNorm node removed from the graph.
      The model is deep-copied first unless `inplace` is True.
      """
      conv_bn_pairs = [
          (nn.Conv1d, nn.BatchNorm1d),
          (nn.Conv2d, nn.BatchNorm2d),
          (nn.Conv3d, nn.BatchNorm3d),
      ]
      if not inplace:
          model = copy.deepcopy(model)
      traced = fx.symbolic_trace(model)
      modules = dict(traced.named_modules())
      new_graph = copy.deepcopy(traced.graph)

      for pattern in conv_bn_pairs:
          for bn_node in new_graph.nodes:
              if not matches_module_pattern(pattern, bn_node, modules):
                  continue
              conv_node = bn_node.args[0]
              # The conv output is used elsewhere too; fusing would change those users.
              if len(conv_node.users) > 1:
                  continue
              conv_module = modules[conv_node.target]
              bn_module = modules[bn_node.target]
              # Without running statistics there is nothing to fold into the conv.
              if not bn_module.track_running_stats:
                  continue
              fused_conv = fuse_conv_bn_eval(conv_module, bn_module)
              replace_node_module(conv_node, modules, fused_conv)
              bn_node.replace_all_uses_with(conv_node)
              new_graph.erase_node(bn_node)

      return fx.GraphModule(traced, new_graph)
  def remove_dropout(model: nn.Module) -> nn.Module:
      """
      Removes all dropout layers from the module.

      Every `call_module` node whose module is a dropout layer -- ``nn.Dropout``
      and its siblings ``nn.Dropout1d/2d/3d``, ``nn.AlphaDropout`` and
      ``nn.FeatureAlphaDropout``, all subclasses of ``_DropoutNd`` -- is replaced
      by its single input, i.e. treated as an identity. Dropout applied through
      ``torch.nn.functional`` calls is not touched.

      Returns a new ``fx.GraphModule`` produced by an ``fx.Transformer`` pass
      over the symbolically traced model.
      """
      fx_model = fx.symbolic_trace(model)
      # `_DropoutNd` is the common base class of every dropout module in torch.nn;
      # checking only `nn.Dropout` would silently keep e.g. `nn.Dropout2d`.
      dropout_types = (nn.modules.dropout._DropoutNd,)

      class DropoutRemover(torch.fx.Transformer):
          def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
              if isinstance(self.submodules[target], dropout_types):
                  # Dropout modules take exactly one input; forward it unchanged.
                  assert len(args) == 1
                  return args[0]
              return super().call_module(target, args, kwargs)

      return DropoutRemover(fx_model).transform()
  def extract_subgraph(orig_module: nn.Module, nodes: List[fx.Node], inputs: List[fx.Node], outputs: List[fx.Node]):
      """
      Build a standalone GraphModule that executes part of an existing graph.

      Each node in `inputs` becomes a placeholder with the same name. The nodes
      in `nodes` are copied in the given order, with their arguments remapped to
      the copies. The resulting module returns a list of the values of
      `outputs`. Submodules and attributes are resolved against `orig_module`.
      """
      sub_graph = fx.Graph()
      old_to_new: Dict[fx.Node, fx.Node] = {
          src: sub_graph.placeholder(src.name) for src in inputs
      }
      for old_node in nodes:
          old_to_new[old_node] = sub_graph.node_copy(old_node, old_to_new.__getitem__)
      sub_graph.output([old_to_new[result] for result in outputs])
      sub_graph.lint()
      return fx.GraphModule(orig_module, sub_graph)
  # Module types and functions that are always treated as MKLDNN nodes.
  mkldnn_supported = [
      # nn.Module classes
      nn.Conv2d,
      nn.Linear,
      nn.BatchNorm2d,
      nn.ReLU,
      nn.MaxPool2d,
      nn.AvgPool2d,
      nn.AdaptiveAvgPool2d,
      # torch.* functions
      torch.relu,
      torch.transpose,
      torch.sigmoid,
      # torch.nn.functional functions
      F.relu,
      F.avg_pool2d,
      F.adaptive_avg_pool2d,
  ]
  # Operators that cannot always be turned into MKLDNN ops (for example when an
  # argument is a plain scalar). They are only pulled into an MKLDNN subgraph
  # when their inputs are already MKLDNN tensors.
  # TODO: Determine whether this can be removed after type inference.
  mkldnn_supported_unknown = [operator.add, operator.mul]
  # Maps a module type to a factory invoked as `factory(module, dtype)` that
  # builds the MKLDNN version of that module.
  mkldnn_map = {
      nn.Conv2d: th_mkldnn.MkldnnConv2d,
      nn.Linear: th_mkldnn.MkldnnLinear,
      # The dtype argument is ignored for BatchNorm.
      nn.BatchNorm2d: lambda a, _: th_mkldnn.MkldnnBatchNorm(a),
  }
  def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]):
      """
      Pre-convert every convertible module referenced by `nodes` to MKLDNN.

      For each `call_module` node whose module type has an entry in
      `mkldnn_map`, the module is replaced (in `modules` and on its parent) by
      the MKLDNN version built with dtype `torch.float`. Returns a dict mapping
      each new MKLDNN module to a deep copy of the module it replaced, which is
      what `reset_modules` uses to undo the conversion.
      """
      originals: Dict[nn.Module, nn.Module] = {}
      for node in nodes:
          if node.op != 'call_module':
              continue
          assert(isinstance(node.target, str))
          module = modules[node.target]
          converter = mkldnn_map.get(type(module))
          if converter is None:
              continue
          converted = converter(module, torch.float)
          assert(isinstance(converted, nn.Module))
          originals[converted] = copy.deepcopy(module)
          replace_node_module(node, modules, converted)
      return originals
  def reset_modules(nodes: List[fx.Node], modules: Dict[str, nn.Module], old_modules: Dict[nn.Module, nn.Module]):
      """
      Undo `modules_to_mkldnn` for the given nodes.

      Every `call_module` node whose current module is a key of `old_modules`
      is pointed back at the corresponding original module.
      """
      module_calls = (n for n in nodes if n.op == 'call_module')
      for node in module_calls:
          assert(isinstance(node.target, str))
          current = modules[node.target]
          if current not in old_modules:
              continue
          replace_node_module(node, modules, old_modules[current])
  class MklSubgraph:
      """
      A connected group of nodes in `fx_graph` that runs on MKLDNN-layout tensors.

      `optimize_for_inference` fills these in and passes each one to a heuristic
      that decides whether the group stays in MKLDNN layout.
      """

      def __init__(self, fx_graph: fx.Graph):
          # Nodes that compute on MKLDNN tensors.
          self.nodes: List[fx.Node] = []
          # `to_mkldnn` conversion nodes that feed the subgraph.
          self.start_nodes: List[fx.Node] = []
          # `to_dense` conversion nodes that leave the subgraph.
          self.end_nodes: List[fx.Node] = []
          # The graph all of the nodes above belong to.
          self.fx_graph = fx_graph
  def gen_mkl_autotuner(example_inputs, iters=10, warmup=1):
      """
      This generates a heuristic that can be passed into `optimize_for_inference` that
      determines whether a subgraph should be run in MKL by running it with the example_inputs.

      When the heuristic first sees a graph, it reads that graph's owning module
      and the `old_modules` mapping attached to it, and runs `ShapeProp` on the
      module with `example_inputs`. For each subgraph it then:

      - times the extracted subgraph `iters` times (after `warmup` untimed runs)
        on MKLDNN copies of random inputs shaped like the subgraph's inputs;
      - swaps the extracted submodule's modules back to their originals and
        times it again on dense inputs;
      - returns True if the MKLDNN run was faster.

      Example usage:
          heuristic = gen_mkl_autotuner(example_inputs, iters=10)
          fast_model = optimization.optimize_for_inference(
              model, {"mkldnn_layout_optimize": {"heuristic": heuristic}})
      """
      fx_model = None
      old_modules = None
      # Graph that `fx_model`, `old_modules` and the propagated shapes belong to.
      cached_graph = None

      def use_mkl_heuristic(graph: MklSubgraph) -> bool:
          nonlocal fx_model, old_modules, cached_graph
          input_nodes = graph.start_nodes
          # Shape propagation runs once per graph. Re-keying on the graph (instead
          # of "first call only") keeps a reused heuristic from consulting a
          # previous model's module and shapes.
          if cached_graph is not graph.fx_graph:
              cached_graph = graph.fx_graph
              fx_model = graph.fx_graph.owning_module
              old_modules = graph.fx_graph.old_modules  # type: ignore[attr-defined]
              # NOTE(review): example_inputs is passed as a single positional
              # argument; presumably a single tensor -- confirm for multi-input models.
              ShapeProp(fx_model).propagate(example_inputs)
          # NOTE(review): relies on ShapeProp storing `.shape` directly on each node;
          # some ShapeProp versions record it in node.meta['tensor_meta'] -- confirm.
          sample_inputs = [torch.randn(node.shape) for node in input_nodes]  # type: ignore[attr-defined]
          output_args = cast(List[fx.Node], [node.args[0] for node in graph.end_nodes])
          submodule = extract_subgraph(fx_model, graph.nodes, input_nodes, output_args)

          def benchmark(f):
              for _ in range(warmup):
                  f()
              # perf_counter is monotonic and high-resolution; time.time() is
              # wall-clock and can jump or be too coarse for short runs.
              begin = time.perf_counter()
              for _ in range(iters):
                  f()
              return time.perf_counter() - begin

          mkl_time = benchmark(lambda: [i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs])])
          # Point the extracted submodule's modules back at the originals recorded
          # in `old_modules` to obtain the dense baseline.
          reset_modules(submodule.graph.nodes, dict(submodule.named_modules()), old_modules)
          no_mkl_time = benchmark(lambda: submodule(*sample_inputs))
          return mkl_time < no_mkl_time

      return use_mkl_heuristic
  def use_mkl_length(graph: MklSubgraph) -> bool:
      """
      Size-based heuristic for `optimize_for_inference`.

      Keeps a subgraph in MKL layout only when it contains more than 2 nodes,
      presumably so that the to_mkldnn/to_dense conversions around it are
      amortized over enough work.
      """
      node_count = len(graph.nodes)
      return node_count >= 3
  class UnionFind:
      """
      Disjoint-set union over the integers 0 .. n-1 with path compression and
      union by size.

      Elements start unregistered; call `make_set(v)` before passing `v` to
      `find` or `join` (an unregistered element trips an assertion).
      """

      def __init__(self, n):
          # parent[v] is None until make_set(v); a root is its own parent.
          self.parent: List[Optional[int]] = [None] * n
          # size[r] is the number of elements in the set rooted at r (valid for roots only).
          self.size: List[int] = [0] * n

      def make_set(self, v: int):
          """Register `v` as a singleton set."""
          self.parent[v] = v
          self.size[v] = 1

      def find(self, v: int) -> int:
          """Return the root of `v`'s set, compressing the path walked to reach it."""
          # Iterative walk: a recursive one can exceed the interpreter's
          # recursion limit on long parent chains.
          root = v
          while True:
              par = self.parent[root]
              assert(par is not None)
              if par == root:
                  break
              root = par
          # Path compression: point every node on the walked path straight at root.
          while v != root:
              nxt = cast(int, self.parent[v])
              self.parent[v] = root
              v = nxt
          return root

      def join(self, a: int, b: int) -> int:
          """
          Merge the sets containing `a` and `b` and return the root of the result.

          The larger set's root becomes the merged root. If `a` and `b` are already
          in the same set, that set's root is returned unchanged.
          """
          a, b = self.find(a), self.find(b)
          if a == b:
              return a
          if self.size[a] < self.size[b]:
              a, b = b, a
          self.parent[b] = a
          self.size[a] += self.size[b]
          # Previously only the already-joined branch returned the root.
          return a
  def optimize_for_inference(
      model: torch.nn.Module,
      pass_config: Optional[Dict[str, Any]] = None,
      tracer: Type[fx.Tracer] = fx.Tracer
  ) -> torch.nn.Module:
      """
      Performs a set of optimization passes to optimize a model for the
      purposes of inference. Specifically, the passes that are run are:
      1. Conv/BN fusion
      2. Dropout removal
      3. MKL layout optimizations

      The third optimization takes a function `use_mkl_heuristic` that's used
      to determine whether a subgraph should be explicitly run in MKL layout.

      Args:
          model: the module to optimize. The MKL layout pass works on a traced
              deep copy, so `model`'s own submodules are never swapped for
              MKLDNN versions.
          pass_config: optional overrides for the default config. Keys are
              ``"conv_bn_fuse"`` (bool), ``"remove_dropout"`` (bool) and
              ``"mkldnn_layout_optimize"``. The last one is either ``False`` to
              skip the MKL pass, or a dict whose ``"heuristic"`` entry is a
              callable that takes an `MklSubgraph` and returns whether to keep it
              in MKLDNN layout. It defaults to `use_mkl_length`.
          tracer: the `fx.Tracer` subclass used to trace the model for the MKL pass.

      Returns:
          The optimized module (an `fx.GraphModule` unless every pass is disabled,
          in which case `model` is returned as-is).

      Raises:
          RuntimeError: if ``"mkldnn_layout_optimize"`` is neither ``False`` nor a
              dict, or if that dict has no ``"heuristic"`` entry.

      Note: As FX does not currently handle aliasing, this pass currently
      assumes nothing aliases. If that isn't true, use at your own risk.
      """
      default_pass_config = {
          "conv_bn_fuse": True,
          "remove_dropout": True,
          "mkldnn_layout_optimize": {'heuristic': use_mkl_length},
      }
      if pass_config is None:
          pass_config = {}
      default_pass_config.update(pass_config)

      if default_pass_config["conv_bn_fuse"]:
          model = fuse(model)
      if default_pass_config["remove_dropout"]:
          model = remove_dropout(model)
      if default_pass_config["mkldnn_layout_optimize"] is False:
          return model
      if not isinstance(default_pass_config["mkldnn_layout_optimize"], dict):
          raise RuntimeError("mkldnn_layout_optimize config is not a dict")
      if "heuristic" not in default_pass_config["mkldnn_layout_optimize"]:
          raise RuntimeError("Heuristic not found in mkldnn_layout_optimize config")
      use_mkl_heuristic = default_pass_config["mkldnn_layout_optimize"]["heuristic"]

      cur_tracer = tracer()
      fx_graph = cur_tracer.trace(copy.deepcopy(model))
      fx_model = fx.GraphModule(cur_tracer.root, fx_graph)
      # Take modules from the traced deep copy, not from `model`. The MKLDNN
      # pre-conversion below swaps modules in place, which would otherwise mutate
      # the caller's model whenever the earlier passes are disabled. This also
      # keeps `fx_graph.owning_module` (read by heuristics such as
      # `gen_mkl_autotuner`) in sync with the converted modules.
      modules: Dict[str, nn.Module] = dict(fx_model.named_modules())

      class MklSupport(Enum):
          NO = 1
          YES = 2
          UNKNOWN = 3

      # Inserts to_mkldnn and to_dense around every node we want to be a MKLDNN node.
      # If the op is in `mkldnn_supported` then we always treat it as a MKLDNN node.
      # However, if it's in `mkldnn_supported_unknown`, then we only treat it as
      # a MKLDNN node if its inputs are MKLDNN nodes.
      for node in list(fx_graph.nodes):
          supports_mkldnn = MklSupport.NO
          if node.op == 'call_module':
              cur_module = modules[node.target]
              if type(cur_module) in mkldnn_supported:
                  supports_mkldnn = MklSupport.YES
                  sample_parameter = next(cur_module.parameters(), None)
                  if sample_parameter is not None:
                      assert(sample_parameter.dtype == torch.float), "this pass is only for torch.float modules"
                      assert(sample_parameter.device == torch.device('cpu')), "this pass is only for CPU modules"
          elif node.op == 'call_function':
              if node.target in mkldnn_supported:
                  supports_mkldnn = MklSupport.YES
              elif node.target in mkldnn_supported_unknown:
                  supports_mkldnn = MklSupport.UNKNOWN

          if supports_mkldnn != MklSupport.NO:
              if supports_mkldnn == MklSupport.UNKNOWN:
                  # Non-Node arguments (e.g. the scalar in `x + 1`) have no
                  # `.target` and may not be convertible to MKLDNN, so such calls
                  # stay dense (previously this raised AttributeError).
                  if not all(isinstance(arg, fx.Node) for arg in node.args):
                      continue
                  # Only convert when some input already leaves an MKLDNN node.
                  if not any(arg.target == 'to_dense' for arg in node.args):
                      continue
              with fx_graph.inserting_before(node):
                  mkldnn_args = fx.map_arg(node.args, lambda n: fx_graph.call_method('to_mkldnn', (n, )))
              node.args = cast(Tuple[fx.node.Argument], mkldnn_args)

              with fx_graph.inserting_after(node):
                  dense_x = fx_graph.create_node('call_method', 'to_dense', (node,))
                  node.replace_all_uses_with(dense_x)
                  # replace_all_uses_with also rewired dense_x's own input; restore it.
                  dense_x.args = (node,)

      # Does pre-conversion of all modules into MKLDNN (when possible)
      old_modules = modules_to_mkldnn(list(fx_graph.nodes), modules)
      fx_graph.old_modules = old_modules  # type: ignore[attr-defined]

      # optimizes all a -> to_dense -> to_mkldnn -> b patterns into a -> b
      for node in fx_graph.nodes:
          if node.op == 'call_method' and node.target == 'to_dense':
              prv_node = node.args[0]
              users = list(node.users)
              for user in users:
                  if user.op == 'call_method' and user.target == 'to_mkldnn':
                      user.replace_all_uses_with(prv_node)
                      fx_graph.erase_node(user)
              if len(node.users) == 0:
                  fx_graph.erase_node(node)

      num_nodes = len(fx_graph.nodes)
      uf = UnionFind(num_nodes)

      def get_color(n):
          if hasattr(n, 'color'):  # Current node is part of a MKL subgraph
              return uf.find(n.color)
          if hasattr(n, 'start_color'):  # Current node is input to MKL subgraph
              return uf.find(n.start_color)
          return None

      # This code is to find each MKLDNN subgraph. Each MKLDNN subgraph consists
      # of input nodes (which are only `to_mkldnn` calls), output nodes
      # (`to_dense` calls), and intermediate nodes, which are run entirely on
      # MKLDNN layout tensors.
      #
      # Specifically, this code does a flood fill on a directed acyclic graph
      # (DAG), starting from each possible "start node" (i.e: `to_mkldnn` nodes).
      # If every node only had one input, this would be sufficient. However, in
      # the case that a node has multiple inputs coming from different start
      # nodes (i.e. colors), we need to join these 2 colors into 1. That's done
      # using a Disjoint Set Union.
      for cur_idx, node in enumerate(fx_graph.nodes):
          if node.op == 'call_method' and node.target == 'to_mkldnn':
              node.start_color = cur_idx
              uf.make_set(cur_idx)
          elif node.op == 'call_method' and node.target == 'to_dense':
              assert(get_color(node.args[0]) is not None)
              node.end_color = get_color(node.args[0])
          else:
              cur_colors = [get_color(i) for i in node.all_input_nodes if isinstance(i, fx.Node) if get_color(i) is not None]

              if len(cur_colors) == 0:
                  continue
              assert(not any(i is None for i in cur_colors))
              cur_colors = sorted(cur_colors)
              node.color = cur_colors[0]
              for other_color in cur_colors[1:]:
                  uf.join(cur_colors[0], other_color)

      mkldnn_graphs: Dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph))
      for node in fx_graph.nodes:
          if hasattr(node, 'color'):
              mkldnn_graphs[uf.find(node.color)].nodes.append(node)
          if hasattr(node, 'start_color'):
              mkldnn_graphs[uf.find(node.start_color)].start_nodes.append(node)
          if hasattr(node, 'end_color'):
              mkldnn_graphs[uf.find(node.end_color)].end_nodes.append(node)

      # Now that we have all the subgraphs, we need to decide which MKLDNN
      # subgraphs we actually want to keep in MKLDNN.
      for graph in mkldnn_graphs.values():
          if not use_mkl_heuristic(graph):
              for node in graph.start_nodes + graph.end_nodes:
                  prv = node.args[0]
                  node.replace_all_uses_with(prv)
                  fx_graph.erase_node(node)
              reset_modules(graph.nodes, modules, old_modules)

      mkldnn_conversions = 0
      for node in fx_graph.nodes:
          if node.target == 'to_mkldnn' or node.target == 'to_dense':
              mkldnn_conversions += 1

      logging.getLogger(__name__).info("mkldnn conversions: %d", mkldnn_conversions)
      fx_graph.lint()
      # Build the result from the deep copy whose modules were actually converted.
      result = fx.GraphModule(fx_model, fx_graph)
      return result