utils.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
  1. import contextlib
  2. import functools
  3. import hashlib
  4. import os
  5. import re
  6. import textwrap
  7. import sys
  8. from argparse import Namespace
  9. from dataclasses import (
  10. fields,
  11. is_dataclass,
  12. )
  13. from typing import (
  14. Tuple,
  15. List,
  16. Iterable,
  17. Iterator,
  18. Callable,
  19. Sequence,
  20. TypeVar,
  21. Optional,
  22. Dict,
  23. Any,
  24. Union,
  25. Set,
  26. NoReturn,
  27. )
  28. from enum import Enum
  29. from torchgen.code_template import CodeTemplate
  30. # Safely load fast C Yaml loader/dumper if they are available
  31. try:
  32. from yaml import CSafeLoader as Loader
  33. except ImportError:
  34. from yaml import SafeLoader as Loader # type: ignore[misc]
  35. try:
  36. from yaml import CSafeDumper as Dumper
  37. except ImportError:
  38. from yaml import SafeDumper as Dumper # type: ignore[misc]
  39. YamlDumper = Dumper
  40. # A custom loader for YAML that errors on duplicate keys.
  41. # This doesn't happen by default: see https://github.com/yaml/pyyaml/issues/165
  42. class YamlLoader(Loader):
  43. def construct_mapping(self, node, deep=False): # type: ignore[no-untyped-def]
  44. mapping = []
  45. for key_node, value_node in node.value:
  46. key = self.construct_object(key_node, deep=deep) # type: ignore[no-untyped-call]
  47. assert (
  48. key not in mapping
  49. ), f"Found a duplicate key in the yaml. key={key}, line={node.start_mark.line}"
  50. mapping.append(key)
  51. mapping = super().construct_mapping(node, deep=deep) # type: ignore[no-untyped-call]
  52. return mapping
  53. # Many of these functions share logic for defining both the definition
  54. # and declaration (for example, the function signature is the same), so
  55. # we organize them into one function that takes a Target to say which
  56. # code we want.
  57. #
  58. # This is an OPEN enum (we may add more cases to it in the future), so be sure
  59. # to explicitly specify with Union[Literal[Target.XXX]] what targets are valid
  60. # for your use.
  61. Target = Enum(
  62. "Target",
  63. (
  64. # top level namespace (not including at)
  65. "DEFINITION",
  66. "DECLARATION",
  67. # TORCH_LIBRARY(...) { ... }
  68. "REGISTRATION",
  69. # namespace { ... }
  70. "ANONYMOUS_DEFINITION",
  71. # namespace cpu { ... }
  72. "NAMESPACED_DEFINITION",
  73. "NAMESPACED_DECLARATION",
  74. ),
  75. )
  76. # Matches "foo" in "foo, bar" but not "foobar". Used to search for the
  77. # occurrence of a parameter in the derivative formula
  78. IDENT_REGEX = r"(^|\W){}($|\W)"
  79. # TODO: Use a real parser here; this will get bamboozled
  80. def split_name_params(schema: str) -> Tuple[str, List[str]]:
  81. m = re.match(r"(\w+)(\.\w+)?\((.*)\)", schema)
  82. if m is None:
  83. raise RuntimeError(f"Unsupported function schema: {schema}")
  84. name, _, params = m.groups()
  85. return name, params.split(", ")
  86. T = TypeVar("T")
  87. S = TypeVar("S")
  88. # These two functions purposely return generators in analogy to map()
  89. # so that you don't mix up when you need to list() them
  90. # Map over function that may return None; omit Nones from output sequence
  91. def mapMaybe(func: Callable[[T], Optional[S]], xs: Iterable[T]) -> Iterator[S]:
  92. for x in xs:
  93. r = func(x)
  94. if r is not None:
  95. yield r
  96. # Map over function that returns sequences and cat them all together
  97. def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
  98. for x in xs:
  99. for r in func(x):
  100. yield r
  101. # Conveniently add error context to exceptions raised. Lets us
  102. # easily say that an error occurred while processing a specific
  103. # context.
  104. @contextlib.contextmanager
  105. def context(msg_fn: Callable[[], str]) -> Iterator[None]:
  106. try:
  107. yield
  108. except Exception as e:
  109. # TODO: this does the wrong thing with KeyError
  110. msg = msg_fn()
  111. msg = textwrap.indent(msg, " ")
  112. msg = f"{e.args[0]}\n{msg}" if e.args else msg
  113. e.args = (msg,) + e.args[1:]
  114. raise
  115. # A little trick from https://github.com/python/mypy/issues/6366
  116. # for getting mypy to do exhaustiveness checking
  117. # TODO: put this somewhere else, maybe
  118. def assert_never(x: NoReturn) -> NoReturn:
  119. raise AssertionError("Unhandled type: {}".format(type(x).__name__))
  120. @functools.lru_cache(maxsize=None)
  121. def _read_template(template_fn: str) -> CodeTemplate:
  122. return CodeTemplate.from_file(template_fn)
  123. # String hash that's stable across different executions, unlike builtin hash
  124. def string_stable_hash(s: str) -> int:
  125. sha1 = hashlib.sha1(s.encode("latin1")).digest()
  126. return int.from_bytes(sha1, byteorder="little")
  127. # A small abstraction for writing out generated files and keeping track
  128. # of what files have been written (so you can write out a list of output
  129. # files)
  130. class FileManager:
  131. install_dir: str
  132. template_dir: str
  133. dry_run: bool
  134. filenames: Set[str]
  135. def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
  136. self.install_dir = install_dir
  137. self.template_dir = template_dir
  138. self.filenames = set()
  139. self.dry_run = dry_run
  140. def _write_if_changed(self, filename: str, contents: str) -> None:
  141. old_contents: Optional[str]
  142. try:
  143. with open(filename, "r") as f:
  144. old_contents = f.read()
  145. except IOError:
  146. old_contents = None
  147. if contents != old_contents:
  148. # Create output directory if it doesn't exist
  149. os.makedirs(os.path.dirname(filename), exist_ok=True)
  150. with open(filename, "w") as f:
  151. f.write(contents)
  152. def write_with_template(
  153. self,
  154. filename: str,
  155. template_fn: str,
  156. env_callable: Callable[[], Union[str, Dict[str, Any]]],
  157. ) -> None:
  158. filename = "{}/{}".format(self.install_dir, filename)
  159. assert filename not in self.filenames, "duplicate file write {filename}"
  160. self.filenames.add(filename)
  161. if not self.dry_run:
  162. env = env_callable()
  163. if isinstance(env, dict):
  164. # TODO: Update the comment reference to the correct location
  165. if "generated_comment" not in env:
  166. comment = "@" + "generated by torchgen/gen.py"
  167. comment += " from {}".format(os.path.basename(template_fn))
  168. env["generated_comment"] = comment
  169. template = _read_template(os.path.join(self.template_dir, template_fn))
  170. self._write_if_changed(filename, template.substitute(env))
  171. elif isinstance(env, str):
  172. self._write_if_changed(filename, env)
  173. else:
  174. assert_never(env)
  175. def write(
  176. self,
  177. filename: str,
  178. env_callable: Callable[[], Union[str, Union[str, Dict[str, Any]]]],
  179. ) -> None:
  180. self.write_with_template(filename, filename, env_callable)
  181. def write_sharded(
  182. self,
  183. filename: str,
  184. items: Iterable[T],
  185. *,
  186. key_fn: Callable[[T], str],
  187. env_callable: Callable[[T], Dict[str, List[str]]],
  188. num_shards: int,
  189. base_env: Optional[Dict[str, Any]] = None,
  190. sharded_keys: Set[str],
  191. ) -> None:
  192. everything: Dict[str, Any] = {"shard_id": "Everything"}
  193. shards: List[Dict[str, Any]] = [
  194. {"shard_id": f"_{i}"} for i in range(num_shards)
  195. ]
  196. all_shards = [everything] + shards
  197. if base_env is not None:
  198. for shard in all_shards:
  199. shard.update(base_env)
  200. for key in sharded_keys:
  201. for shard in all_shards:
  202. if key in shard:
  203. assert isinstance(
  204. shard[key], list
  205. ), "sharded keys in base_env must be a list"
  206. shard[key] = shard[key].copy()
  207. else:
  208. shard[key] = []
  209. def merge_env(into: Dict[str, List[str]], from_: Dict[str, List[str]]) -> None:
  210. for k, v in from_.items():
  211. assert k in sharded_keys, f"undeclared sharded key {k}"
  212. into[k] += v
  213. if self.dry_run:
  214. # Dry runs don't write any templates, so incomplete environments are fine
  215. items = ()
  216. for item in items:
  217. key = key_fn(item)
  218. sid = string_stable_hash(key) % num_shards
  219. env = env_callable(item)
  220. merge_env(shards[sid], env)
  221. merge_env(everything, env)
  222. dot_pos = filename.rfind(".")
  223. if dot_pos == -1:
  224. dot_pos = len(filename)
  225. base_filename = filename[:dot_pos]
  226. extension = filename[dot_pos:]
  227. for shard in all_shards:
  228. shard_id = shard["shard_id"]
  229. self.write_with_template(
  230. f"{base_filename}{shard_id}{extension}", filename, lambda: shard
  231. )
  232. # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
  233. self.filenames.discard(
  234. f"{self.install_dir}/{base_filename}Everything{extension}"
  235. )
  236. def write_outputs(self, variable_name: str, filename: str) -> None:
  237. """Write a file containing the list of all outputs which are
  238. generated by this script."""
  239. content = "set({}\n {})".format(
  240. variable_name,
  241. "\n ".join('"' + name + '"' for name in sorted(self.filenames)),
  242. )
  243. self._write_if_changed(filename, content)
  244. # Helper function to generate file manager
  245. def make_file_manager(
  246. options: Namespace, install_dir: Optional[str] = None
  247. ) -> FileManager:
  248. template_dir = os.path.join(options.source_path, "templates")
  249. install_dir = install_dir if install_dir else options.install_dir
  250. return FileManager(
  251. install_dir=install_dir, template_dir=template_dir, dry_run=options.dry_run
  252. )
  253. # Helper function to create a pretty representation for dataclasses
  254. def dataclass_repr(
  255. obj: Any,
  256. indent: int = 0,
  257. width: int = 80,
  258. ) -> str:
  259. # built-in pprint module support dataclasses from python 3.10
  260. if sys.version_info >= (3, 10):
  261. from pprint import pformat
  262. return pformat(obj, indent, width)
  263. return _pformat(obj, indent=indent, width=width)
  264. def _pformat(
  265. obj: Any,
  266. indent: int,
  267. width: int,
  268. curr_indent: int = 0,
  269. ) -> str:
  270. assert is_dataclass(obj), f"obj should be a dataclass, received: {type(obj)}"
  271. class_name = obj.__class__.__name__
  272. # update current indentation level with class name
  273. curr_indent += len(class_name) + 1
  274. fields_list = [(f.name, getattr(obj, f.name)) for f in fields(obj) if f.repr]
  275. fields_str = []
  276. for name, attr in fields_list:
  277. # update the current indent level with the field name
  278. # dict, list, set and tuple also add indent as done in pprint
  279. _curr_indent = curr_indent + len(name) + 1
  280. if is_dataclass(attr):
  281. str_repr = _pformat(attr, indent, width, _curr_indent)
  282. elif isinstance(attr, dict):
  283. str_repr = _format_dict(attr, indent, width, _curr_indent)
  284. elif isinstance(attr, (list, set, tuple)):
  285. str_repr = _format_list(attr, indent, width, _curr_indent)
  286. else:
  287. str_repr = repr(attr)
  288. fields_str.append(f"{name}={str_repr}")
  289. indent_str = curr_indent * " "
  290. body = f",\n{indent_str}".join(fields_str)
  291. return f"{class_name}({body})"
  292. def _format_dict(
  293. attr: Dict[Any, Any],
  294. indent: int,
  295. width: int,
  296. curr_indent: int,
  297. ) -> str:
  298. curr_indent += indent + 3
  299. dict_repr = []
  300. for k, v in attr.items():
  301. k_repr = repr(k)
  302. v_str = (
  303. _pformat(v, indent, width, curr_indent + len(k_repr))
  304. if is_dataclass(v)
  305. else repr(v)
  306. )
  307. dict_repr.append(f"{k_repr}: {v_str}")
  308. return _format(dict_repr, indent, width, curr_indent, "{", "}")
  309. def _format_list(
  310. attr: Union[List[Any], Set[Any], Tuple[Any, ...]],
  311. indent: int,
  312. width: int,
  313. curr_indent: int,
  314. ) -> str:
  315. curr_indent += indent + 1
  316. list_repr = [
  317. _pformat(l, indent, width, curr_indent) if is_dataclass(l) else repr(l)
  318. for l in attr
  319. ]
  320. start, end = ("[", "]") if isinstance(attr, list) else ("(", ")")
  321. return _format(list_repr, indent, width, curr_indent, start, end)
  322. def _format(
  323. fields_str: List[str],
  324. indent: int,
  325. width: int,
  326. curr_indent: int,
  327. start: str,
  328. end: str,
  329. ) -> str:
  330. delimiter, curr_indent_str = "", ""
  331. # if it exceed the max width then we place one element per line
  332. if len(repr(fields_str)) >= width:
  333. delimiter = "\n"
  334. curr_indent_str = " " * curr_indent
  335. indent_str = " " * indent
  336. body = f", {delimiter}{curr_indent_str}".join(fields_str)
  337. return f"{start}{indent_str}{body}{end}"