_check.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. import ast
  2. import inspect
  3. import sys
  4. import textwrap
  5. import torch
  6. import warnings
  7. class AttributeTypeIsSupportedChecker(ast.NodeVisitor):
  8. """
  9. Checks the ``__init__`` method of a given ``nn.Module`` to ensure
  10. that all instance-level attributes can be properly initialized.
  11. Specifically, we do type inference based on attribute values...even
  12. if the attribute in question has already been typed using
  13. Python3-style annotations or ``torch.jit.annotate``. This means that
  14. setting an instance-level attribute to ``[]`` (for ``List``),
  15. ``{}`` for ``Dict``), or ``None`` (for ``Optional``) isn't enough
  16. information for us to properly initialize that attribute.
  17. An object of this class can walk a given ``nn.Module``'s AST and
  18. determine if it meets our requirements or not.
  19. Known limitations
  20. 1. We can only check the AST nodes for certain constructs; we can't
  21. ``eval`` arbitrary expressions. This means that function calls,
  22. class instantiations, and complex expressions that resolve to one of
  23. the "empty" values specified above will NOT be flagged as
  24. problematic.
  25. 2. We match on string literals, so if the user decides to use a
  26. non-standard import (e.g. `from typing import List as foo`), we
  27. won't catch it.
  28. Example:
  29. .. code-block:: python
  30. class M(torch.nn.Module):
  31. def fn(self):
  32. return []
  33. def __init__(self):
  34. super().__init__()
  35. self.x: List[int] = []
  36. def forward(self, x: List[int]):
  37. self.x = x
  38. return 1
  39. The above code will pass the ``AttributeTypeIsSupportedChecker``
  40. check since we have a function call in ``__init__``. However,
  41. it will still fail later with the ``RuntimeError`` "Tried to set
  42. nonexistent attribute: x. Did you forget to initialize it in
  43. __init__()?".
  44. Args:
  45. nn_module - The instance of ``torch.nn.Module`` whose
  46. ``__init__`` method we wish to check
  47. """
  48. def check(self, nn_module: torch.nn.Module) -> None:
  49. # Check if we have a Python version <3.8
  50. self.using_deprecated_ast: bool = sys.version_info < (3, 8)
  51. source_lines = inspect.getsource(nn_module.__class__.__init__)
  52. # Ignore comments no matter the indentation
  53. def is_useless_comment(line):
  54. line = line.strip()
  55. return line.startswith("#") and not line.startswith("# type:")
  56. source_lines = "\n".join([l for l in source_lines.split("\n") if not is_useless_comment(l)])
  57. # This AST only contains the `__init__` method of the nn.Module
  58. init_ast = ast.parse(textwrap.dedent(source_lines))
  59. # Get items annotated in the class body
  60. self.class_level_annotations = list(nn_module.__annotations__.keys())
  61. # Flag for later
  62. self.visiting_class_level_ann = False
  63. self.visit(init_ast)
  64. def _is_empty_container(self, node: ast.AST, ann_type: str) -> bool:
  65. if ann_type == "List":
  66. # Assigning `[]` to a `List` type gives you a Node where
  67. # value=List(elts=[], ctx=Load())
  68. if not isinstance(node, ast.List):
  69. return False
  70. if node.elts:
  71. return False
  72. elif ann_type == "Dict":
  73. # Assigning `{}` to a `Dict` type gives you a Node where
  74. # value=Dict(keys=[], values=[])
  75. if not isinstance(node, ast.Dict):
  76. return False
  77. if node.keys:
  78. return False
  79. elif ann_type == "Optional":
  80. # Assigning `None` to an `Optional` type gives you a
  81. # Node where value=Constant(value=None, kind=None)
  82. # or, in Python <3.8, value=NameConstant(value=None)
  83. if (not self.using_deprecated_ast
  84. and not isinstance(node, ast.Constant)):
  85. return False
  86. if (self.using_deprecated_ast
  87. and not isinstance(node, ast.NameConstant)):
  88. return False
  89. if node.value: # type: ignore[attr-defined]
  90. return False
  91. return True
  92. def visit_Assign(self, node):
  93. """
  94. If we're visiting a Call Node (the right-hand side of an
  95. assignment statement), we won't be able to check the variable
  96. that we're assigning to (the left-hand side of an assignment).
  97. Because of this, we need to store this state in visitAssign.
  98. (Luckily, we only have to do this if we're assigning to a Call
  99. Node, i.e. ``torch.jit.annotate``. If we're using normal Python
  100. annotations, we'll be visiting an AnnAssign Node, which has its
  101. target built in.)
  102. """
  103. try:
  104. if (isinstance(node.value, ast.Call)
  105. and node.targets[0].attr in self.class_level_annotations):
  106. self.visiting_class_level_ann = True
  107. except AttributeError:
  108. return
  109. self.generic_visit(node)
  110. self.visiting_class_level_ann = False
  111. def visit_AnnAssign(self, node):
  112. """
  113. Visit an AnnAssign node in an ``nn.Module``'s ``__init__``
  114. method and see if it conforms to our attribute annotation rules.
  115. """
  116. # If we have a local variable
  117. try:
  118. if node.target.value.id != "self":
  119. return
  120. except AttributeError:
  121. return
  122. # If we have an attribute that's already been annotated at the
  123. # class level
  124. if node.target.attr in self.class_level_annotations:
  125. return
  126. # TODO @ansley: add `Union` once landed
  127. # NB: Even though `Tuple` is a "container", we don't want to
  128. # check for it here. `Tuple` functions as an type with an
  129. # "infinite" number of subtypes, in the sense that you can have
  130. # `Tuple[())]`, `Tuple[T1]`, `Tuple[T2]`, `Tuple[T1, T2]`,
  131. # `Tuple[T2, T1]` and so on, and none of these subtypes can be
  132. # used in place of the other. Therefore, assigning an empty
  133. # tuple in `__init__` CORRECTLY means that that variable
  134. # cannot be reassigned later to a non-empty tuple. Same
  135. # deal with `NamedTuple`
  136. containers = {"List", "Dict", "Optional"}
  137. # If we're not evaluating one of the specified problem types
  138. try:
  139. if node.annotation.value.id not in containers:
  140. return
  141. except AttributeError:
  142. # To evaluate a base type (`str`, `int`, etc.), we would
  143. # have needed to get the name through `node.annotation.id`
  144. # instead of `node.annotation.value.id`. Seems that we're
  145. # not evaluating one of our "containers"
  146. return
  147. # Check if the assigned variable is empty
  148. ann_type = node.annotation.value.id
  149. if not self._is_empty_container(node.value, ann_type):
  150. return
  151. warnings.warn("The TorchScript type system doesn't support "
  152. "instance-level annotations on empty non-base "
  153. "types in `__init__`. Instead, either 1) use a "
  154. "type annotation in the class body, or 2) wrap "
  155. "the type in `torch.jit.Attribute`.")
  156. def visit_Call(self, node):
  157. """
  158. Visit a Call node in an ``nn.Module``'s ``__init__``
  159. method and determine if it's ``torch.jit.annotate``. If so,
  160. see if it conforms to our attribute annotation rules.
  161. """
  162. # If we have an attribute that's already been annotated at the
  163. # class level
  164. if self.visiting_class_level_ann:
  165. return
  166. # If this isn't a call to `torch.jit.annotate`
  167. try:
  168. if (node.func.value.value.id != "torch"
  169. or node.func.value.attr != "jit"
  170. or node.func.attr != "annotate"):
  171. self.generic_visit(node)
  172. elif (node.func.value.value.id != "jit"
  173. or node.func.value.attr != "annotate"):
  174. self.generic_visit(node)
  175. except AttributeError:
  176. # Looks like we didn't even have the right node structure
  177. # to check for `torch.jit.annotate` in the first place
  178. self.generic_visit(node)
  179. # Invariant: we have a `torch.jit.annotate` or a
  180. # `torch.annotate` call
  181. # A Call Node for `torch.jit.annotate` should have an `args`
  182. # list of length 2 where args[0] represents the annotation and
  183. # args[1] represents the actual value
  184. if len(node.args) != 2:
  185. return
  186. if not isinstance(node.args[0], ast.Subscript):
  187. return
  188. # See notes in `visit_AnnAssign` r.e. containers
  189. containers = {"List", "Dict", "Optional"}
  190. try:
  191. ann_type = node.args[0].value.id # type: ignore[attr-defined]
  192. except AttributeError:
  193. return
  194. if ann_type not in containers:
  195. return
  196. # Check if the assigned variable is empty
  197. if not self._is_empty_container(node.args[1], ann_type):
  198. return
  199. warnings.warn("The TorchScript type system doesn't support "
  200. "instance-level annotations on empty non-base "
  201. "types in `__init__`. Instead, either 1) use a "
  202. "type annotation in the class body, or 2) wrap "
  203. "the type in `torch.jit.Attribute`.")