| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398 |
- import contextlib
- import functools
- import hashlib
- import os
- import re
- import textwrap
- import sys
- from argparse import Namespace
- from dataclasses import (
- fields,
- is_dataclass,
- )
- from typing import (
- Tuple,
- List,
- Iterable,
- Iterator,
- Callable,
- Sequence,
- TypeVar,
- Optional,
- Dict,
- Any,
- Union,
- Set,
- NoReturn,
- )
- from enum import Enum
- from torchgen.code_template import CodeTemplate
- # Safely load fast C Yaml loader/dumper if they are available
- try:
- from yaml import CSafeLoader as Loader
- except ImportError:
- from yaml import SafeLoader as Loader # type: ignore[misc]
- try:
- from yaml import CSafeDumper as Dumper
- except ImportError:
- from yaml import SafeDumper as Dumper # type: ignore[misc]
- YamlDumper = Dumper
- # A custom loader for YAML that errors on duplicate keys.
- # This doesn't happen by default: see https://github.com/yaml/pyyaml/issues/165
class YamlLoader(Loader):
    """YAML loader that fails loudly when a mapping repeats a key.

    The base loader silently keeps the last value of a duplicated key; this
    subclass checks every key of a mapping node before delegating the actual
    construction to the parent class.
    """

    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        # Keys are collected in a list (not a set) so that unhashable keys
        # are still reported by the parent implementation, not here.
        seen_keys = []
        for key_node, _value_node in node.value:
            key = self.construct_object(key_node, deep=deep)  # type: ignore[no-untyped-call]
            assert (
                key not in seen_keys
            ), f"Found a duplicate key in the yaml. key={key}, line={node.start_mark.line}"
            seen_keys.append(key)
        # No duplicates found: let the parent build the real mapping.
        return super().construct_mapping(node, deep=deep)  # type: ignore[no-untyped-call]
- # Many of these functions share logic for defining both the definition
- # and declaration (for example, the function signature is the same), so
- # we organize them into one function that takes a Target to say which
- # code we want.
- #
- # This is an OPEN enum (we may add more cases to it in the future), so be sure
- # to explicitly specify with Union[Literal[Target.XXX]] what targets are valid
- # for your use.
class Target(Enum):
    # Selects which flavour of code a shared generator function should emit.
    # Values are numbered from 1 in declaration order.

    # top level namespace (not including at)
    DEFINITION = 1
    DECLARATION = 2
    # TORCH_LIBRARY(...) { ... }
    REGISTRATION = 3
    # namespace { ... }
    ANONYMOUS_DEFINITION = 4
    # namespace cpu { ... }
    NAMESPACED_DEFINITION = 5
    NAMESPACED_DECLARATION = 6
# Matches "foo" in "foo, bar" but not "foobar". Used to search for the
# occurrence of a parameter in the derivative formula.
#
# This is a pattern *template*: the "{}" placeholder has to be filled in with
# the identifier to look for before the pattern is compiled (presumably via
# str.format -- the call site is not visible in this part of the file).
# The two groups require a string boundary or a non-word character on either
# side of the identifier.
IDENT_REGEX = r"(^|\W){}($|\W)"
- # TODO: Use a real parser here; this regex-and-split approach will get bamboozled by any parameter whose type or default value itself contains ", "
def split_name_params(schema: str) -> Tuple[str, List[str]]:
    """Split a function schema string into its base name and parameter list.

    The schema is expected to look like ``name(params)`` or
    ``name.overload(params)``; the overload suffix is discarded. Parameters
    are obtained by splitting the text between the parentheses on ``", "``,
    so an empty parameter list yields ``[""]``.

    Raises:
        RuntimeError: if the schema does not start with a recognisable
            ``name(...)`` form.
    """
    parsed = re.match(r"(\w+)(\.\w+)?\((.*)\)", schema)
    if parsed is None:
        raise RuntimeError(f"Unsupported function schema: {schema}")
    base_name = parsed.group(1)
    param_text = parsed.group(3)
    return base_name, param_text.split(", ")
# Generic type variables shared by the helper signatures in this file
# (mapMaybe and concatMap use both; FileManager.write_sharded uses T).
T = TypeVar("T")
S = TypeVar("S")
- # These two functions purposely return generators in analogy to map()
- # so that you don't mix up when you need to list() them
- # Map over function that may return None; omit Nones from output sequence
def mapMaybe(func: Callable[[T], Optional[S]], xs: Iterable[T]) -> Iterator[S]:
    """Lazily apply ``func`` to each element of ``xs``, dropping ``None`` results.

    Only ``None`` is filtered out; other falsy results (``0``, ``""``,
    ``False``) are yielded. Returns a generator, like ``map()``.
    """
    for result in map(func, xs):
        if result is None:
            continue
        yield result
- # Map over function that returns sequences and cat them all together
def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
    """Lazily apply ``func`` to each element of ``xs`` and flatten the results.

    ``func`` must return a sequence for every element; the elements of those
    sequences are yielded one after another, in order. Returns a generator,
    like ``map()``.
    """
    for item in xs:
        yield from func(item)
- # Conveniently add error context to exceptions raised. Lets us
- # easily say that an error occurred while processing a specific
- # context.
@contextlib.contextmanager
def context(msg_fn: Callable[[], str]) -> Iterator[None]:
    """Context manager that appends extra context to any exception raised inside it.

    ``msg_fn`` is only called when an exception actually occurs. Its result
    is indented and appended (on a new line) to the exception's first
    argument; if the exception has no arguments the indented message becomes
    its only argument. The same exception object is then re-raised.
    """
    try:
        yield
    except Exception as err:
        # TODO: this does the wrong thing with KeyError
        note = textwrap.indent(msg_fn(), " ")
        if err.args:
            note = f"{err.args[0]}\n{note}"
        # Replace only the first argument; keep any remaining ones intact.
        err.args = (note, *err.args[1:])
        raise
- # A little trick from https://github.com/python/mypy/issues/6366
- # for getting mypy to do exhaustiveness checking
- # TODO: put this somewhere else, maybe
def assert_never(x: NoReturn) -> NoReturn:
    """Always raise ``AssertionError`` naming the type of ``x``.

    Intended for the final ``else`` of an exhaustive type/enum dispatch so
    that mypy can verify every case was handled.
    """
    type_name = type(x).__name__
    raise AssertionError(f"Unhandled type: {type_name}")
@functools.lru_cache(maxsize=None)
def _read_template(template_fn: str) -> CodeTemplate:
    """Load the template stored at ``template_fn``.

    Results are memoised per path (unbounded cache), so each template file
    is loaded at most once per process.
    """
    template = CodeTemplate.from_file(template_fn)
    return template
- # String hash that's stable across different executions, unlike builtin hash
def string_stable_hash(s: str) -> int:
    """Return an integer hash of ``s`` that is identical across interpreter runs.

    The value is the SHA-1 digest of the Latin-1 encoding of ``s``,
    interpreted as a little-endian unsigned integer.
    """
    encoded = s.encode("latin1")
    digest = hashlib.sha1(encoded).digest()
    return int.from_bytes(digest, byteorder="little")
- # A small abstraction for writing out generated files and keeping track
- # of what files have been written (so you can write out a list of output
- # files)
class FileManager:
    """Writes generated files and remembers which files were produced.

    Files are written below ``install_dir``; templates are looked up below
    ``template_dir``. In dry-run mode nothing is written (and environments
    are never evaluated), but the set of output filenames is still recorded
    so it can be emitted with ``write_outputs``.
    """

    # Directory that every generated file path is prefixed with.
    install_dir: str
    # Directory that template filenames are resolved against.
    template_dir: str
    # When true, only record filenames; do not evaluate envs or write files.
    dry_run: bool
    # Full paths ("<install_dir>/<filename>") of all files registered so far.
    filenames: Set[str]

    def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
        self.install_dir = install_dir
        self.template_dir = template_dir
        self.filenames = set()
        self.dry_run = dry_run

    def _write_if_changed(self, filename: str, contents: str) -> None:
        """Write ``contents`` to ``filename`` unless the file already holds them.

        Leaving unchanged files untouched preserves their timestamps. A file
        that cannot be read is treated as having no previous contents.
        """
        old_contents: Optional[str]
        try:
            with open(filename, "r") as f:
                old_contents = f.read()
        except OSError:
            old_contents = None
        if contents == old_contents:
            return
        # Create output directory if it doesn't exist. A bare filename has an
        # empty dirname, and os.makedirs("") would raise, so skip it then.
        out_dir = os.path.dirname(filename)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(filename, "w") as f:
            f.write(contents)

    def write_with_template(
        self,
        filename: str,
        template_fn: str,
        env_callable: Callable[[], Union[str, Dict[str, Any]]],
    ) -> None:
        """Register ``filename`` and (unless dry-run) generate it.

        ``env_callable`` is only invoked when not in dry-run mode. If it
        returns a dict, the template ``template_fn`` is substituted with it
        (a ``generated_comment`` entry is added to the dict when missing);
        if it returns a string, that string is written verbatim.

        Raises:
            AssertionError: if the same output file is registered twice.
        """
        filename = f"{self.install_dir}/{filename}"
        assert filename not in self.filenames, f"duplicate file write {filename}"
        self.filenames.add(filename)
        if self.dry_run:
            return
        env = env_callable()
        if isinstance(env, dict):
            # TODO: Update the comment reference to the correct location
            if "generated_comment" not in env:
                # The marker is split in two so that this source file itself
                # does not contain the contiguous marker string.
                comment = "@" + "generated by torchgen/gen.py"
                comment += f" from {os.path.basename(template_fn)}"
                env["generated_comment"] = comment
            template = _read_template(os.path.join(self.template_dir, template_fn))
            self._write_if_changed(filename, template.substitute(env))
        elif isinstance(env, str):
            self._write_if_changed(filename, env)
        else:
            assert_never(env)

    def write(
        self,
        filename: str,
        env_callable: Callable[[], Union[str, Dict[str, Any]]],
    ) -> None:
        """Generate ``filename`` from the template that has the same name."""
        self.write_with_template(filename, filename, env_callable)

    def write_sharded(
        self,
        filename: str,
        items: Iterable[T],
        *,
        key_fn: Callable[[T], str],
        env_callable: Callable[[T], Dict[str, List[str]]],
        num_shards: int,
        base_env: Optional[Dict[str, Any]] = None,
        sharded_keys: Set[str],
    ) -> None:
        """Distribute ``items`` over ``num_shards`` files generated from one template.

        Each item is assigned to a shard by a stable hash of ``key_fn(item)``
        and its ``env_callable(item)`` lists are appended to that shard's
        environment as well as to an extra "Everything" environment. For a
        template ``Foo.cpp`` the outputs are ``Foo_0.cpp`` ... and
        ``FooEverything.cpp``; the latter is not recorded in ``filenames``.

        ``base_env`` entries are copied into every environment. Every key in
        ``sharded_keys`` is list-valued and is the only kind of key
        ``env_callable`` may return.
        """
        everything: Dict[str, Any] = {"shard_id": "Everything"}
        shards: List[Dict[str, Any]] = [
            {"shard_id": f"_{i}"} for i in range(num_shards)
        ]
        all_shards = [everything] + shards

        if base_env is not None:
            for shard in all_shards:
                shard.update(base_env)

        for key in sharded_keys:
            for shard in all_shards:
                if key in shard:
                    assert isinstance(
                        shard[key], list
                    ), "sharded keys in base_env must be a list"
                    # Copy so shards do not share (and mutate) base_env's list.
                    shard[key] = shard[key].copy()
                else:
                    shard[key] = []

        def merge_env(into: Dict[str, List[str]], from_: Dict[str, List[str]]) -> None:
            for k, v in from_.items():
                assert k in sharded_keys, f"undeclared sharded key {k}"
                into[k] += v

        if self.dry_run:
            # Dry runs don't write any templates, so incomplete environments are fine
            items = ()

        for item in items:
            key = key_fn(item)
            sid = string_stable_hash(key) % num_shards
            env = env_callable(item)

            merge_env(shards[sid], env)
            merge_env(everything, env)

        # Split "Foo.cpp" into ("Foo", ".cpp"); no dot means empty extension.
        dot_pos = filename.rfind(".")
        if dot_pos == -1:
            dot_pos = len(filename)
        base_filename = filename[:dot_pos]
        extension = filename[dot_pos:]

        for shard in all_shards:
            shard_id = shard["shard_id"]
            # Bind the current shard explicitly instead of relying on the
            # callable being invoked before the loop variable changes.
            self.write_with_template(
                f"{base_filename}{shard_id}{extension}",
                filename,
                lambda shard=shard: shard,
            )

        # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
        self.filenames.discard(
            f"{self.install_dir}/{base_filename}Everything{extension}"
        )

    def write_outputs(self, variable_name: str, filename: str) -> None:
        """Write a file containing the list of all outputs which are
        generated by this script.

        The content has the form ``set(<variable_name> "<file>" ...)`` with
        the recorded filenames sorted and one per line.
        """
        content = "set({}\n {})".format(
            variable_name,
            "\n ".join('"' + name + '"' for name in sorted(self.filenames)),
        )
        self._write_if_changed(filename, content)
- # Helper function to generate file manager
def make_file_manager(
    options: Namespace, install_dir: Optional[str] = None
) -> FileManager:
    """Build a ``FileManager`` from parsed command-line options.

    Templates are read from ``<options.source_path>/templates``. Output goes
    to ``install_dir`` when a non-empty value is given, otherwise to
    ``options.install_dir``. Dry-run mode follows ``options.dry_run``.
    """
    templates = os.path.join(options.source_path, "templates")
    if not install_dir:
        install_dir = options.install_dir
    return FileManager(
        install_dir=install_dir,
        template_dir=templates,
        dry_run=options.dry_run,
    )
- # Helper function to create a pretty representation for dataclasses
def dataclass_repr(
    obj: Any,
    indent: int = 0,
    width: int = 80,
) -> str:
    """Return a pretty, multi-line representation of a dataclass instance.

    On Python 3.10+ this defers to ``pprint.pformat``, which supports
    dataclasses natively; older interpreters use the local fallback
    formatter.
    """
    if sys.version_info < (3, 10):
        return _pformat(obj, indent=indent, width=width)

    # built-in pprint module support dataclasses from python 3.10
    import pprint

    return pprint.pformat(obj, indent, width)
- def _pformat(
- obj: Any,
- indent: int,
- width: int,
- curr_indent: int = 0,
- ) -> str:
- assert is_dataclass(obj), f"obj should be a dataclass, received: {type(obj)}"
- class_name = obj.__class__.__name__
- # update current indentation level with class name
- curr_indent += len(class_name) + 1
- fields_list = [(f.name, getattr(obj, f.name)) for f in fields(obj) if f.repr]
- fields_str = []
- for name, attr in fields_list:
- # update the current indent level with the field name
- # dict, list, set and tuple also add indent as done in pprint
- _curr_indent = curr_indent + len(name) + 1
- if is_dataclass(attr):
- str_repr = _pformat(attr, indent, width, _curr_indent)
- elif isinstance(attr, dict):
- str_repr = _format_dict(attr, indent, width, _curr_indent)
- elif isinstance(attr, (list, set, tuple)):
- str_repr = _format_list(attr, indent, width, _curr_indent)
- else:
- str_repr = repr(attr)
- fields_str.append(f"{name}={str_repr}")
- indent_str = curr_indent * " "
- body = f",\n{indent_str}".join(fields_str)
- return f"{class_name}({body})"
def _format_dict(
    attr: Dict[Any, Any],
    indent: int,
    width: int,
    curr_indent: int,
) -> str:
    """Format a dict as ``{key: value, ...}``, pretty-printing dataclass values.

    Keys and non-dataclass values are rendered with ``repr``. Entries are
    placed one per line when the combined representation reaches ``width``.
    """
    curr_indent += indent + 3

    entries = []
    for key, value in attr.items():
        key_repr = repr(key)
        if is_dataclass(value):
            # A nested dataclass starts after its key, so shift it right by
            # the key's width.
            value_repr = _pformat(value, indent, width, curr_indent + len(key_repr))
        else:
            value_repr = repr(value)
        entries.append(f"{key_repr}: {value_repr}")

    return _format(entries, indent, width, curr_indent, "{", "}")
def _format_list(
    attr: Union[List[Any], Set[Any], Tuple[Any, ...]],
    indent: int,
    width: int,
    curr_indent: int,
) -> str:
    """Format a list, set or tuple, pretty-printing dataclass elements.

    Lists are wrapped in ``[...]``; sets and tuples are both wrapped in
    ``(...)``. Non-dataclass elements are rendered with ``repr``.
    """
    curr_indent += indent + 1

    rendered = []
    for element in attr:
        if is_dataclass(element):
            rendered.append(_pformat(element, indent, width, curr_indent))
        else:
            rendered.append(repr(element))

    if isinstance(attr, list):
        opener, closer = "[", "]"
    else:
        opener, closer = "(", ")"
    return _format(rendered, indent, width, curr_indent, opener, closer)
- def _format(
- fields_str: List[str],
- indent: int,
- width: int,
- curr_indent: int,
- start: str,
- end: str,
- ) -> str:
- delimiter, curr_indent_str = "", ""
- # if it exceed the max width then we place one element per line
- if len(repr(fields_str)) >= width:
- delimiter = "\n"
- curr_indent_str = " " * curr_indent
- indent_str = " " * indent
- body = f", {delimiter}{curr_indent_str}".join(fields_str)
- return f"{start}{indent_str}{body}{end}"
|