annotations.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442
  1. import ast
  2. import enum
  3. import inspect
  4. import re
  5. import builtins
  6. import torch
  7. import warnings
  8. from .._jit_internal import List, Tuple, is_tuple, is_list, Dict, is_dict, Optional, \
  9. is_optional, _qualified_name, Any, Future, is_future, is_ignored_fn, Union, is_union
  10. from .._jit_internal import BroadcastingList1, BroadcastingList2, BroadcastingList3 # type: ignore[attr-defined]
  11. from ._state import _get_script_class
  12. from torch._C import TensorType, TupleType, FloatType, IntType, ComplexType, \
  13. ListType, StringType, DictType, BoolType, OptionalType, InterfaceType, AnyType, \
  14. NoneType, DeviceObjType, StreamObjType, FutureType, EnumType, UnionType, NumberType
  15. from textwrap import dedent
  16. from torch._sources import get_source_lines_and_file
  17. from typing import Type
  18. if torch.distributed.rpc.is_available():
  19. from .._jit_internal import RRef, is_rref
  20. from torch._C import RRefType
  21. from torch._ops import OpOverloadPacket
  22. class Module(object):
  23. def __init__(self, name, members):
  24. self.name = name
  25. self.members = members
  26. def __getattr__(self, name):
  27. try:
  28. return self.members[name]
  29. except KeyError:
  30. raise RuntimeError(f"Module {self.name} has no member called {name}") from None
  31. class EvalEnv(object):
  32. env = {
  33. 'torch': Module('torch', {'Tensor': torch.Tensor}),
  34. 'Tensor': torch.Tensor,
  35. 'typing': Module('typing', {'Tuple': Tuple}),
  36. 'Tuple': Tuple,
  37. 'List': List,
  38. 'Dict': Dict,
  39. 'Optional': Optional,
  40. 'Union': Union,
  41. 'Future': Future
  42. }
  43. def __init__(self, rcb):
  44. self.rcb = rcb
  45. if torch.distributed.rpc.is_available():
  46. self.env['RRef'] = RRef
  47. def __getitem__(self, name):
  48. if name in self.env:
  49. return self.env[name]
  50. if self.rcb is not None:
  51. return self.rcb(name)
  52. return getattr(builtins, name, None)
  53. def get_signature(fn, rcb, loc, is_method):
  54. if isinstance(fn, OpOverloadPacket):
  55. signature = try_real_annotations(fn.op, loc)
  56. else:
  57. signature = try_real_annotations(fn, loc)
  58. if signature is not None and is_method:
  59. # If this is a method, then the signature will include a type for
  60. # `self`, but type comments do not contain a `self`. So strip it
  61. # away here so everything is consistent (`inspect.ismethod` does
  62. # not work here since `fn` is unbound at this point)
  63. param_types, return_type = signature
  64. param_types = param_types[1:]
  65. signature = (param_types, return_type)
  66. if signature is None:
  67. type_line, source = None, None
  68. try:
  69. source = dedent(''.join(get_source_lines_and_file(fn)[0]))
  70. type_line = get_type_line(source)
  71. except TypeError:
  72. pass
  73. # This might happen both because we failed to get the source of fn, or
  74. # because it didn't have any annotations.
  75. if type_line is not None:
  76. signature = parse_type_line(type_line, rcb, loc)
  77. return signature
  78. def is_function_or_method(the_callable):
  79. # A stricter version of `inspect.isroutine` that does not pass for built-in
  80. # functions
  81. return inspect.isfunction(the_callable) or inspect.ismethod(the_callable)
  82. def is_vararg(the_callable):
  83. if not is_function_or_method(the_callable) and hasattr(the_callable, '__call__'): # noqa: B004
  84. # If `the_callable` is a class, de-sugar the call so we can still get
  85. # the signature
  86. the_callable = the_callable.__call__
  87. if is_function_or_method(the_callable):
  88. return inspect.getfullargspec(the_callable).varargs is not None
  89. else:
  90. return False
  91. def get_param_names(fn, n_args):
  92. if isinstance(fn, OpOverloadPacket):
  93. fn = fn.op
  94. if not is_function_or_method(fn) and hasattr(fn, '__call__') and is_function_or_method(fn.__call__): # noqa: B004
  95. # De-sugar calls to classes
  96. fn = fn.__call__
  97. if is_function_or_method(fn):
  98. if is_ignored_fn(fn):
  99. fn = inspect.unwrap(fn)
  100. return inspect.getfullargspec(fn).args
  101. else:
  102. # The `fn` was not a method or function (maybe a class with a __call__
  103. # method, so use a default param name list)
  104. return [str(i) for i in range(n_args)]
  105. def check_fn(fn, loc):
  106. # Make sure the function definition is not a class instantiation
  107. try:
  108. source = dedent(''.join(get_source_lines_and_file(fn)[0]))
  109. except (TypeError, IOError):
  110. return
  111. if source is None:
  112. return
  113. py_ast = ast.parse(source)
  114. if len(py_ast.body) == 1 and isinstance(py_ast.body[0], ast.ClassDef):
  115. raise torch.jit.frontend.FrontendError(
  116. loc, f"Cannot instantiate class '{py_ast.body[0].name}' in a script function")
  117. if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
  118. raise torch.jit.frontend.FrontendError(loc, "Expected a single top-level function")
  119. def parse_type_line(type_line, rcb, loc):
  120. """Parses a type annotation specified as a comment.
  121. Example inputs:
  122. # type: (Tensor, torch.Tensor) -> Tuple[Tensor]
  123. # type: (Tensor, Tuple[Tensor, Tensor]) -> Tensor
  124. """
  125. arg_ann_str, ret_ann_str = split_type_line(type_line)
  126. try:
  127. arg_ann = eval(arg_ann_str, {}, EvalEnv(rcb)) # type: ignore[arg-type] # noqa: P204
  128. except (NameError, SyntaxError) as e:
  129. raise RuntimeError("Failed to parse the argument list of a type annotation") from e
  130. if not isinstance(arg_ann, tuple):
  131. arg_ann = (arg_ann,)
  132. try:
  133. ret_ann = eval(ret_ann_str, {}, EvalEnv(rcb)) # type: ignore[arg-type] # noqa: P204
  134. except (NameError, SyntaxError) as e:
  135. raise RuntimeError("Failed to parse the return type of a type annotation") from e
  136. arg_types = [ann_to_type(ann, loc) for ann in arg_ann]
  137. return arg_types, ann_to_type(ret_ann, loc)
  138. def get_type_line(source):
  139. """Tries to find the line containing a comment with the type annotation."""
  140. type_comment = '# type:'
  141. lines = source.split('\n')
  142. lines = [(line_num, line) for line_num, line in enumerate(lines)]
  143. type_lines = list(filter(lambda line: type_comment in line[1], lines))
  144. # `type: ignore` comments may be needed in JIT'ed functions for mypy, due
  145. # to the hack in torch/_VF.py.
  146. # An ignore type comment can be of following format:
  147. # 1) type: ignore
  148. # 2) type: ignore[rule-code]
  149. # This ignore statement must be at the end of the line
  150. # adding an extra backslash before the space, to avoid triggering
  151. # one of the checks in .github/workflows/lint.yml
  152. type_pattern = re.compile("# type:\\ ignore(\\[[a-zA-Z-]+\\])?$")
  153. type_lines = list(filter(lambda line: not type_pattern.search(line[1]),
  154. type_lines))
  155. if len(type_lines) == 0:
  156. # Catch common typo patterns like extra spaces, typo in 'ignore', etc.
  157. wrong_type_pattern = re.compile("#[\t ]*type[\t ]*(?!: ignore(\\[.*\\])?$):")
  158. wrong_type_lines = list(filter(lambda line: wrong_type_pattern.search(line[1]), lines))
  159. if len(wrong_type_lines) > 0:
  160. raise RuntimeError("The annotation prefix in line " + str(wrong_type_lines[0][0])
  161. + " is probably invalid.\nIt must be '# type:'"
  162. + "\nSee PEP 484 (https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)" # noqa: B950
  163. + "\nfor examples")
  164. return None
  165. elif len(type_lines) == 1:
  166. # Only 1 type line, quit now
  167. return type_lines[0][1].strip()
  168. # Parse split up argument types according to PEP 484
  169. # https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code
  170. return_line = None
  171. parameter_type_lines = []
  172. for line_num, line in type_lines:
  173. if '# type: (...) -> ' in line:
  174. return_line = (line_num, line)
  175. break
  176. elif type_comment in line:
  177. parameter_type_lines.append(line)
  178. if return_line is None:
  179. raise RuntimeError(
  180. "Return type line '# type: (...) -> ...' not found on multiline "
  181. "type annotation\nfor type lines:\n" +
  182. '\n'.join([line[1] for line in type_lines]) +
  183. "\n(See PEP 484 https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)")
  184. def get_parameter_type(line):
  185. item_type = line[line.find(type_comment) + len(type_comment):]
  186. return item_type.strip()
  187. types = map(get_parameter_type, parameter_type_lines)
  188. parameter_types = ", ".join(types)
  189. return return_line[1].replace("...", parameter_types)
  190. def split_type_line(type_line):
  191. """Splits the comment with the type annotation into parts for argument and return types.
  192. For example, for an input of:
  193. # type: (Tensor, torch.Tensor) -> Tuple[Tensor, Tensor]
  194. This function will return:
  195. ("(Tensor, torch.Tensor)", "Tuple[Tensor, Tensor]")
  196. """
  197. start_offset = len('# type:')
  198. try:
  199. arrow_pos = type_line.index('->')
  200. except ValueError:
  201. raise RuntimeError("Syntax error in type annotation (cound't find `->`)") from None
  202. return type_line[start_offset:arrow_pos].strip(), type_line[arrow_pos + 2:].strip()
  203. def try_real_annotations(fn, loc):
  204. """Tries to use the Py3.5+ annotation syntax to get the type."""
  205. try:
  206. # Note: anything annotated as `Optional[T]` will automatically
  207. # be returned as `Union[T, None]` per
  208. # https://github.com/python/typing/blob/master/src/typing.py#L850
  209. sig = inspect.signature(fn)
  210. except ValueError:
  211. return None
  212. all_annots = [sig.return_annotation] + [p.annotation for p in sig.parameters.values()]
  213. if all(ann is sig.empty for ann in all_annots):
  214. return None
  215. arg_types = [ann_to_type(p.annotation, loc)
  216. for p in sig.parameters.values()]
  217. return_type = ann_to_type(sig.return_annotation, loc)
  218. return arg_types, return_type
  219. # Finds common type for enum values belonging to an Enum class. If not all
  220. # values have the same type, AnyType is returned.
  221. def get_enum_value_type(e: Type[enum.Enum], loc):
  222. enum_values: List[enum.Enum] = list(e)
  223. if not enum_values:
  224. raise ValueError(f"No enum values defined for: '{e.__class__}'")
  225. types = {type(v.value) for v in enum_values}
  226. ir_types = [try_ann_to_type(t, loc) for t in types]
  227. # If Enum values are of different types, an exception will be raised here.
  228. # Even though Python supports this case, we chose to not implement it to
  229. # avoid overcomplicate logic here for a rare use case. Please report a
  230. # feature request if you find it necessary.
  231. return torch._C.unify_type_list(ir_types)
  232. def is_tensor(ann):
  233. if issubclass(ann, torch.Tensor):
  234. return True
  235. if issubclass(ann, (torch.LongTensor, torch.DoubleTensor, torch.FloatTensor,
  236. torch.IntTensor, torch.ShortTensor, torch.HalfTensor,
  237. torch.CharTensor, torch.ByteTensor, torch.BoolTensor)):
  238. warnings.warn("TorchScript will treat type annotations of Tensor "
  239. "dtype-specific subtypes as if they are normal Tensors. "
  240. "dtype constraints are not enforced in compilation either.")
  241. return True
  242. return False
  243. def try_ann_to_type(ann, loc):
  244. if ann is inspect.Signature.empty:
  245. return TensorType.getInferred()
  246. if ann is None:
  247. return NoneType.get()
  248. if inspect.isclass(ann) and is_tensor(ann):
  249. return TensorType.get()
  250. if is_tuple(ann):
  251. # Special case for the empty Tuple type annotation `Tuple[()]`
  252. if len(ann.__args__) == 1 and ann.__args__[0] == ():
  253. return TupleType([])
  254. return TupleType([try_ann_to_type(a, loc) for a in ann.__args__])
  255. if is_list(ann):
  256. elem_type = try_ann_to_type(ann.__args__[0], loc)
  257. if elem_type:
  258. return ListType(elem_type)
  259. if is_dict(ann):
  260. key = try_ann_to_type(ann.__args__[0], loc)
  261. value = try_ann_to_type(ann.__args__[1], loc)
  262. # Raise error if key or value is None
  263. if key is None:
  264. raise ValueError(f"Unknown type annotation: '{ann.__args__[0]}' at {loc.highlight()}")
  265. if value is None:
  266. raise ValueError(f"Unknown type annotation: '{ann.__args__[1]}' at {loc.highlight()}")
  267. return DictType(key, value)
  268. if is_optional(ann):
  269. if issubclass(ann.__args__[1], type(None)):
  270. contained = ann.__args__[0]
  271. else:
  272. contained = ann.__args__[1]
  273. valid_type = try_ann_to_type(contained, loc)
  274. msg = "Unsupported annotation {} could not be resolved because {} could not be resolved."
  275. assert valid_type, msg.format(repr(ann), repr(contained))
  276. return OptionalType(valid_type)
  277. if is_union(ann):
  278. # TODO: this is hack to recognize NumberType
  279. if set(ann.__args__) == set([int, float, complex]):
  280. return NumberType.get()
  281. inner: List = []
  282. # We need these extra checks because both `None` and invalid
  283. # values will return `None`
  284. # TODO: Determine if the other cases need to be fixed as well
  285. for a in ann.__args__:
  286. if a is None:
  287. inner.append(NoneType.get())
  288. maybe_type = try_ann_to_type(a, loc)
  289. msg = "Unsupported annotation {} could not be resolved because {} could not be resolved."
  290. assert maybe_type, msg.format(repr(ann), repr(maybe_type))
  291. inner.append(maybe_type)
  292. return UnionType(inner) # type: ignore[arg-type]
  293. if torch.distributed.rpc.is_available() and is_rref(ann):
  294. return RRefType(try_ann_to_type(ann.__args__[0], loc))
  295. if is_future(ann):
  296. return FutureType(try_ann_to_type(ann.__args__[0], loc))
  297. if ann is float:
  298. return FloatType.get()
  299. if ann is complex:
  300. return ComplexType.get()
  301. if ann is int:
  302. return IntType.get()
  303. if ann is str:
  304. return StringType.get()
  305. if ann is bool:
  306. return BoolType.get()
  307. if ann is Any:
  308. return AnyType.get()
  309. if ann is type(None):
  310. return NoneType.get()
  311. if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
  312. return InterfaceType(ann.__torch_script_interface__)
  313. if ann is torch.device:
  314. return DeviceObjType.get()
  315. if ann is torch.Stream:
  316. return StreamObjType.get()
  317. if ann is torch.dtype:
  318. return IntType.get() # dtype not yet bound in as its own type
  319. if inspect.isclass(ann) and issubclass(ann, enum.Enum):
  320. if _get_script_class(ann) is None:
  321. scripted_class = torch.jit._script._recursive_compile_class(ann, loc)
  322. name = scripted_class.qualified_name()
  323. else:
  324. name = _qualified_name(ann)
  325. return EnumType(name, get_enum_value_type(ann, loc), list(ann))
  326. if inspect.isclass(ann):
  327. maybe_script_class = _get_script_class(ann)
  328. if maybe_script_class is not None:
  329. return maybe_script_class
  330. if torch._jit_internal.can_compile_class(ann):
  331. return torch.jit._script._recursive_compile_class(ann, loc)
  332. # Maybe resolve a NamedTuple to a Tuple Type
  333. def fake_rcb(key):
  334. return None
  335. return torch._C._resolve_type_from_object(ann, loc, fake_rcb)
  336. def ann_to_type(ann, loc):
  337. the_type = try_ann_to_type(ann, loc)
  338. if the_type is not None:
  339. return the_type
  340. raise ValueError(f"Unknown type annotation: '{ann}' at {loc.highlight()}")
  341. __all__ = [
  342. 'Any',
  343. 'List',
  344. 'BroadcastingList1',
  345. 'BroadcastingList2',
  346. 'BroadcastingList3',
  347. 'Tuple',
  348. 'is_tuple',
  349. 'is_list',
  350. 'Dict',
  351. 'is_dict',
  352. 'is_optional',
  353. 'is_union',
  354. 'TensorType',
  355. 'TupleType',
  356. 'FloatType',
  357. 'ComplexType',
  358. 'IntType',
  359. 'ListType',
  360. 'StringType',
  361. 'DictType',
  362. 'AnyType',
  363. 'Module',
  364. # TODO: Consider not exporting these during wildcard import (reserve
  365. # that for the types; for idiomatic typing code.)
  366. 'get_signature',
  367. 'check_fn',
  368. 'get_param_names',
  369. 'parse_type_line',
  370. 'get_type_line',
  371. 'split_type_line',
  372. 'try_real_annotations',
  373. 'try_ann_to_type',
  374. 'ann_to_type',
  375. ]