_sources.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. import ast
  2. import functools
  3. import inspect
  4. from textwrap import dedent
  5. from typing import Any, Optional, Tuple, List, NamedTuple
  6. from torch._C import ErrorReport
  7. from torch._C._jit_tree_views import SourceRangeFactory
  8. def get_source_lines_and_file(
  9. obj: Any,
  10. error_msg: Optional[str] = None,
  11. ) -> Tuple[List[str], int, Optional[str]]:
  12. """
  13. Wrapper around inspect.getsourcelines and inspect.getsourcefile.
  14. Returns: (sourcelines, file_lino, filename)
  15. """
  16. filename = None # in case getsourcefile throws
  17. try:
  18. filename = inspect.getsourcefile(obj)
  19. sourcelines, file_lineno = inspect.getsourcelines(obj)
  20. except OSError as e:
  21. msg = (f"Can't get source for {obj}. TorchScript requires source access in "
  22. "order to carry out compilation, make sure original .py files are "
  23. "available.")
  24. if error_msg:
  25. msg += '\n' + error_msg
  26. raise OSError(msg) from e
  27. return sourcelines, file_lineno, filename
  28. def normalize_source_lines(sourcelines: List[str]) -> List[str]:
  29. """
  30. This helper function accepts a list of source lines. It finds the
  31. indentation level of the function definition (`def`), then it indents
  32. all lines in the function body to a point at or greater than that
  33. level. This allows for comments and continued string literals that
  34. are at a lower indentation than the rest of the code.
  35. Args:
  36. sourcelines: function source code, separated into lines by
  37. the '\n' character
  38. Returns:
  39. A list of source lines that have been correctly aligned
  40. """
  41. def remove_prefix(text, prefix):
  42. return text[text.startswith(prefix) and len(prefix):]
  43. # Find the line and line number containing the function definition
  44. for i, l in enumerate(sourcelines):
  45. if l.lstrip().startswith("def"):
  46. idx = i
  47. break
  48. fn_def = sourcelines[idx]
  49. # Get a string representing the amount of leading whitespace
  50. whitespace = fn_def.split("def")[0]
  51. # Add this leading whitespace to all lines before and after the `def`
  52. aligned_prefix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[:idx]]
  53. aligned_suffix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[idx + 1:]]
  54. # Put it together again
  55. aligned_prefix.append(fn_def)
  56. return aligned_prefix + aligned_suffix
  57. # Thin wrapper around SourceRangeFactory to store extra metadata
  58. # about the function-to-be-compiled.
  59. class SourceContext(SourceRangeFactory):
  60. def __init__(self, source, filename, file_lineno, leading_whitespace_len, uses_true_division=True, funcname=None):
  61. super(SourceContext, self).__init__(source, filename, file_lineno, leading_whitespace_len)
  62. self.uses_true_division = uses_true_division
  63. self.filename = filename
  64. self.funcname = funcname
  65. @functools.lru_cache(maxsize=None)
  66. def make_source_context(*args):
  67. return SourceContext(*args)
  68. def fake_range():
  69. return SourceContext('', None, 0, 0).make_raw_range(0, 1)
  70. class ParsedDef(NamedTuple):
  71. ast: ast.Module
  72. ctx: SourceContext
  73. source: str
  74. filename: Optional[str]
  75. file_lineno: int
  76. def parse_def(fn):
  77. sourcelines, file_lineno, filename = get_source_lines_and_file(fn, ErrorReport.call_stack())
  78. sourcelines = normalize_source_lines(sourcelines)
  79. source = ''.join(sourcelines)
  80. dedent_src = dedent(source)
  81. py_ast = ast.parse(dedent_src)
  82. if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
  83. raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}")
  84. leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0])
  85. ctx = make_source_context(source, filename, file_lineno, leading_whitespace_len, True, fn.__name__)
  86. return ParsedDef(py_ast, ctx, source, filename, file_lineno)