profiler_legacy.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. import torch
  2. import torch.cuda
  3. from torch.autograd.profiler_util import (
  4. EventList, FunctionEvent, MEMORY_EVENT_NAME,
  5. _filter_name, _filter_stack_entry, _rewrite_name
  6. )
  7. from torch.autograd import (
  8. DeviceType, ProfilerConfig, ProfilerState,
  9. _disable_profiler_legacy, _enable_profiler_legacy,
  10. )
  11. import itertools
  12. from warnings import warn
  13. class profile(object):
  14. """DEPRECATED: use torch.profiler instead"""
  15. def __init__(
  16. self,
  17. enabled=True,
  18. *,
  19. use_cuda=False,
  20. record_shapes=False,
  21. with_flops=False,
  22. profile_memory=False,
  23. with_stack=False,
  24. with_modules=False):
  25. self.enabled: bool = enabled
  26. if not self.enabled:
  27. return
  28. self.use_cuda = use_cuda
  29. self.function_events = None
  30. self.entered = False
  31. self.record_shapes = record_shapes
  32. self.with_flops = with_flops
  33. self.record_shapes |= self.with_flops
  34. self.profile_memory = profile_memory
  35. self.with_stack = with_stack
  36. self.with_modules = with_modules
  37. if self.use_cuda and not torch.cuda.is_available():
  38. warn("CUDA is not available, disabling CUDA profiling")
  39. self.use_cuda = False
  40. if self.use_cuda:
  41. self.profiler_kind = ProfilerState.CUDA
  42. else:
  43. self.profiler_kind = ProfilerState.CPU
  44. def config(self):
  45. return ProfilerConfig(
  46. self.profiler_kind,
  47. self.record_shapes,
  48. self.profile_memory,
  49. self.with_stack,
  50. self.with_flops,
  51. self.with_modules,
  52. # avoid exposing _ExperimentalConfig this in legacy public API
  53. torch._C._autograd._ExperimentalConfig(),
  54. )
  55. def __enter__(self):
  56. if not self.enabled:
  57. return
  58. if self.entered:
  59. raise RuntimeError("Profiler context manager is not reentrant")
  60. self.entered = True
  61. self._start_trace()
  62. return self
  63. def _start_trace(self):
  64. _enable_profiler_legacy(self.config())
  65. def __exit__(self, exc_type, exc_val, exc_tb):
  66. if not self.enabled:
  67. return
  68. if self.use_cuda:
  69. torch.cuda.synchronize()
  70. records = _disable_profiler_legacy()
  71. parsed_results = _parse_legacy_records(records)
  72. self.function_events = EventList(
  73. parsed_results,
  74. use_cuda=self.use_cuda,
  75. profile_memory=self.profile_memory,
  76. with_flops=self.with_flops)
  77. self.function_events._build_tree()
  78. return False
  79. def __repr__(self):
  80. if self.function_events is None:
  81. return '<unfinished profiler_legacy.profile>'
  82. return repr(self.function_events)
  83. def __str__(self):
  84. if self.function_events is None:
  85. return '<unfinished profile.profiler_legacy.profile>'
  86. return str(self.function_events)
  87. def _check_finish(self):
  88. if self.function_events is None:
  89. raise RuntimeError("Profiler didn't finish running")
  90. def table(self, sort_by=None, row_limit=100, max_src_column_width=75, header=None, top_level_events_only=False):
  91. self._check_finish()
  92. assert self.function_events is not None
  93. return self.function_events.table(
  94. sort_by=sort_by, row_limit=row_limit, max_src_column_width=max_src_column_width, header=header,
  95. top_level_events_only=top_level_events_only
  96. )
  97. table.__doc__ = EventList.table.__doc__
  98. def export_chrome_trace(self, path):
  99. self._check_finish()
  100. assert self.function_events is not None
  101. return self.function_events.export_chrome_trace(path)
  102. export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__
  103. def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
  104. self._check_finish()
  105. assert self.function_events is not None, "Expected profiling results"
  106. assert self.with_stack, "export_stacks() requires with_stack=True"
  107. return self.function_events.export_stacks(path, metric)
  108. def key_averages(self, group_by_input_shape=False, group_by_stack_n=0):
  109. self._check_finish()
  110. assert self.function_events is not None, "Expected profiling results"
  111. return self.function_events.key_averages(group_by_input_shape, group_by_stack_n)
  112. key_averages.__doc__ = EventList.key_averages.__doc__
  113. def total_average(self):
  114. self._check_finish()
  115. assert self.function_events is not None, "Expected profiling results"
  116. return self.function_events.total_average()
  117. total_average.__doc__ = EventList.total_average.__doc__
  118. @property
  119. def self_cpu_time_total(self):
  120. """ Returns total time spent on CPU obtained as a sum of
  121. all self times across all the events.
  122. """
  123. self._check_finish()
  124. assert self.function_events is not None
  125. return self.function_events.self_cpu_time_total
  126. def _parse_legacy_records(thread_records):
  127. def _get_record_key(record):
  128. """
  129. Returns a tuple to be used by _parse_legacy_records for correlating start and
  130. end records.
  131. """
  132. return (record.handle(), record.node_id())
  133. next_id = 0
  134. start_record = None
  135. functions = []
  136. record_stack = []
  137. # '__start_profile' is not guaranteed to be first, so we must find it here
  138. for record in itertools.chain(*thread_records):
  139. name = record.name()
  140. if start_record is None and name == '__start_profile':
  141. start_record = record
  142. assert start_record is not None and not start_record.is_remote()
  143. for thread_record_list in thread_records:
  144. # accumulated memory allocations per handle
  145. cpu_memory_allocs = {}
  146. cuda_memory_allocs = {}
  147. # ranges per handle
  148. range_starts = {}
  149. filtered_handles = set()
  150. prev_record = None
  151. for record in thread_record_list:
  152. record_key = _get_record_key(record)
  153. if (_filter_name(record.name()) or
  154. record_key in filtered_handles):
  155. filtered_handles.add(record_key)
  156. continue
  157. if record.kind() == 'push':
  158. # workaround to reduce double logging from operator
  159. # wrappers and redispatch
  160. if prev_record is not None:
  161. duplicate = (
  162. prev_record.name() == record.name()
  163. and prev_record.kind() == record.kind()
  164. and prev_record.node_id() == record.node_id()
  165. )
  166. if duplicate:
  167. filtered_handles.add(record_key)
  168. continue
  169. range_starts[record_key] = record
  170. cpu_memory_allocs[record_key] = 0
  171. cuda_memory_allocs[record_key] = 0
  172. elif record.kind() == 'pop':
  173. assert (
  174. record_key in range_starts
  175. ), """Expected record with key {} to exist in range_starts.
  176. This means that the pop event did not have a corresponding push.""".format(
  177. record_key
  178. )
  179. start = range_starts[record_key]
  180. cpu_memory_usage = cpu_memory_allocs[record_key]
  181. cuda_memory_usage = cuda_memory_allocs[record_key]
  182. is_async = start.is_async() or (
  183. start.thread_id() != record.thread_id()
  184. )
  185. is_remote_event = record.is_remote()
  186. start_flops = start.flops()
  187. fe = FunctionEvent(
  188. id=record.handle(),
  189. node_id=record.node_id(),
  190. name=_rewrite_name(name=start.name(), with_wildcard=True),
  191. trace_name=_rewrite_name(name=start.name(), with_wildcard=False),
  192. thread=start.thread_id(),
  193. start_us=start_record.cpu_elapsed_us(start),
  194. end_us=start_record.cpu_elapsed_us(record),
  195. fwd_thread=start.fwd_thread_id(),
  196. input_shapes=start.shapes(),
  197. stack=[entry for entry in start.stack() if _filter_stack_entry(entry)],
  198. scope=start.scope(),
  199. cpu_memory_usage=cpu_memory_usage,
  200. cuda_memory_usage=cuda_memory_usage,
  201. is_async=is_async,
  202. is_remote=is_remote_event,
  203. sequence_nr=start.sequence_nr(),
  204. device_type=DeviceType.CPU,
  205. is_legacy=True,
  206. flops=start_flops,
  207. )
  208. # note: async events have only cpu total time
  209. if not is_async and start.has_cuda():
  210. duration = start.cuda_elapsed_us(record)
  211. if duration > 0:
  212. fe.append_kernel(
  213. start.name(),
  214. start.device(),
  215. duration)
  216. functions.append(fe)
  217. del range_starts[record_key]
  218. del cpu_memory_allocs[record_key]
  219. del cuda_memory_allocs[record_key]
  220. elif record.kind() == 'memory_alloc':
  221. num_open_handles_cpu = len(cpu_memory_allocs)
  222. num_open_handles_cuda = len(cuda_memory_allocs)
  223. assert num_open_handles_cpu == num_open_handles_cuda
  224. for handle in cpu_memory_allocs.keys():
  225. cpu_memory_allocs[handle] += record.cpu_memory_usage()
  226. for handle in cuda_memory_allocs.keys():
  227. cuda_memory_allocs[handle] += record.cuda_memory_usage()
  228. if num_open_handles_cpu == 0:
  229. # output event as a top-level memory event
  230. fe = FunctionEvent(
  231. id=0,
  232. name=MEMORY_EVENT_NAME,
  233. trace_name=None,
  234. thread=0,
  235. start_us=0,
  236. end_us=0,
  237. stack=[],
  238. cpu_memory_usage=record.cpu_memory_usage(),
  239. cuda_memory_usage=record.cuda_memory_usage(),
  240. is_legacy=True,
  241. )
  242. functions.append(fe)
  243. prev_record = record
  244. # Sort functions by start time then by end time ascending.
  245. # This ensures that--in the case of nested events which
  246. # have the same start time (which may happen due to the
  247. # granularity of the given clock tick)--we always show
  248. # the outermost nested call first. This adds stability
  249. # in how FunctionEvents appear
  250. functions.sort(key=lambda evt: [evt.time_range.start, -evt.time_range.end])
  251. return functions