profiler.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684
  1. from torch.autograd.profiler_util import (
  2. EventList, FunctionEvent, MemRecordsAcc, MEMORY_EVENT_NAME,
  3. _filter_name, _filter_stack_entry, _rewrite_name
  4. )
  5. from torch.autograd import (
  6. DeviceType, ProfilerActivity, ProfilerConfig, ProfilerState,
  7. kineto_available, _ProfilerResult, _disable_profiler, _enable_profiler,
  8. _prepare_profiler, _supported_activities, _kineto_step,
  9. )
  10. from torch._C._autograd import _ExperimentalConfig
  11. import torch
  12. import torch.cuda
  13. from torch.futures import Future
  14. from typing import Any, Dict, List, Optional
  15. from warnings import warn
  try:
      # ContextDecorator ships with contextlib on every Python >= 3.2.
      from contextlib import ContextDecorator
  except ImportError:
      import functools

      class ContextDecorator(object):  # type: ignore[no-redef]
          """Minimal stand-in for contextlib.ContextDecorator on very old Pythons.

          Subclasses provide ``__enter__``/``__exit__``; calling an instance on a
          function returns a wrapper that runs the function inside ``with self:``.
          """

          def __enter__(self):
              raise NotImplementedError

          def __exit__(self, exc_type, exc_val, exc_tb):
              raise NotImplementedError

          def __call__(self, func):
              @functools.wraps(func)
              def _run_inside_context(*args, **kwargs):
                  with self:
                      return func(*args, **kwargs)

              return _run_inside_context
  class profile(object):
      """Context manager that manages autograd profiler state and holds a summary of results.

      Under the hood it just records events of functions being executed in C++ and
      exposes those events to Python. You can wrap any code into it and it will
      only report runtime of PyTorch functions.
      Note: profiler is thread local and is automatically propagated into the async tasks

      Args:
          enabled (bool, optional): Setting this to False makes this context manager a no-op.
              A disabled profiler still supports ``repr``/``str``; its result accessors
              (``table``, ``key_averages``, ...) raise ``RuntimeError`` since no results exist.

          use_cuda (bool, optional): Enables timing of CUDA events as well using the cudaEvent API.
              Adds approximately 4us of overhead to each tensor operation.

          record_shapes (bool, optional): If shapes recording is set, information
              about input dimensions will be collected. This allows one to see which
              dimensions have been used under the hood and further group by them
              using prof.key_averages(group_by_input_shape=True). Please note that
              shape recording might skew your profiling data. It is recommended to
              use separate runs with and without shape recording to validate the timing.
              Most likely the skew will be negligible for bottom most events (in a case
              of nested function calls). But for higher level functions the total
              self cpu time might be artificially increased because of the shape
              collection.

          with_flops (bool, optional): If with_flops is set, the profiler will estimate
              the FLOPs (floating point operations) value using the operator's input shape.
              This allows one to estimate the hardware performance. Currently,
              this option only works for the matrix multiplication and 2D convolution operators.
              Setting it also turns on ``record_shapes``.

          profile_memory (bool, optional): track tensor memory allocation/deallocation.

          with_stack (bool, optional): record source information (file and line number) for the ops.

          with_modules (bool): record module hierarchy (including function names)
              corresponding to the callstack of the op. e.g. If module A's forward call's
              module B's forward which contains an aten::add op,
              then aten::add's module hierarchy is A.B
              Note that this support exists, at the moment, only for TorchScript models
              and not eager mode models.

          use_kineto (bool, optional): experimental, enable profiling with Kineto profiler.

          use_cpu (bool, optional): profile CPU events; setting to ``False`` requires
              ``use_kineto=True`` and can be used to lower the overhead for GPU-only profiling.

          experimental_config (_ExperimentalConfig) : A set of experimental options
              used by profiler libraries like Kineto. Note, backward compatibility is not guaranteed.

      .. warning::
          Enabling memory profiling or source attribution incurs additional profiler
          overhead

      .. warning::
          This context manager should not be called recursively, i.e. no nested
          instances are allowed

      .. warning::
          Due to some CUDA multiprocessing limitations (multiprocessing-cuda-note_),
          one cannot use the profiler with ``use_cuda = True`` to benchmark
          DataLoaders with ``num_workers > 0``. If you wish to benchmark data loading,
          please use ``use_cuda = False`` or ``num_workers = 0``.

      Example:
          >>> x = torch.randn((1, 1), requires_grad=True)
          >>> with torch.autograd.profiler.profile() as prof:
          >>>     for _ in range(100):  # any normal python code, really!
          >>>         y = x ** 2
          >>>         y.backward()
          >>> # NOTE: some columns were removed for brevity
          >>> print(prof.key_averages().table(sort_by="self_cpu_time_total"))
          -----------------------------------  ---------------  ---------------  ---------------
          Name                                 Self CPU total   CPU time avg     Number of Calls
          -----------------------------------  ---------------  ---------------  ---------------
          mul                                  32.048ms         32.048ms         200
          pow                                  27.041ms         27.041ms         200
          PowBackward0                         9.727ms          55.483ms         100
          torch::autograd::AccumulateGrad      9.148ms          9.148ms          100
          torch::autograd::GraphRoot           691.816us        691.816us        100
          -----------------------------------  ---------------  ---------------  ---------------

      """

      def __init__(
              self,
              enabled=True,
              *,
              use_cuda=False,
              record_shapes=False,
              with_flops=False,
              profile_memory=False,
              with_stack=False,
              with_modules=False,
              use_kineto=False,
              use_cpu=True,
              experimental_config=None):
          self.enabled: bool = enabled
          # Set the state/result attributes BEFORE the early return, so a disabled
          # profiler still has them: repr()/str() then work, and the accessors raise
          # "Profiler didn't finish running" instead of AttributeError.
          self.function_events: Optional[EventList] = None
          self.entered = False
          self.kineto_results: Optional[_ProfilerResult] = None
          if not self.enabled:
              return
          self.use_cuda = use_cuda
          self.record_shapes = record_shapes
          self.with_flops = with_flops
          # FLOP estimation works from input shapes, so it forces shape recording on.
          self.record_shapes |= self.with_flops
          self.profile_memory = profile_memory
          self.with_stack = with_stack
          self.with_modules = with_modules
          self.use_cpu = use_cpu
          if experimental_config is None:
              experimental_config = _ExperimentalConfig()
          self.experimental_config = experimental_config

          if not self.use_cpu:
              assert use_kineto, \
                  "Device-only events supported only with Kineto (use_kineto=True)"

          if self.use_cuda and not torch.cuda.is_available():
              warn("CUDA is not available, disabling CUDA profiling")
              self.use_cuda = False

          self.kineto_activities = set()
          if self.use_cpu:
              self.kineto_activities.add(ProfilerActivity.CPU)

          self.profiler_kind = ProfilerState.KINETO
          if self.use_cuda:
              if (not use_kineto or ProfilerActivity.CUDA not in
                      _supported_activities()):
                  # Kineto cannot trace CUDA here: fall back to legacy CUDA timing,
                  # which is attached to CPU events and so needs CPU profiling.
                  assert self.use_cpu, "Legacy CUDA profiling requires use_cpu=True"
                  self.profiler_kind = ProfilerState.KINETO_GPU_FALLBACK
              else:
                  self.kineto_activities.add(ProfilerActivity.CUDA)

          assert len(self.kineto_activities) > 0, \
              "No activities specified for the profiler"

      def config(self):
          """Build the ProfilerConfig passed to the profiler prepare/enable calls."""
          return ProfilerConfig(
              self.profiler_kind,
              self.record_shapes,
              self.profile_memory,
              self.with_stack,
              self.with_flops,
              self.with_modules,
              self.experimental_config)

      def __enter__(self):
          if not self.enabled:
              return
          if self.entered:
              raise RuntimeError("Profiler context manager is not reentrant")
          self._prepare_trace()
          self._start_trace()
          return self

      def _prepare_trace(self):
          self.entered = True
          _prepare_profiler(self.config(), self.kineto_activities)

      def _start_trace(self):
          self.entered = True
          _enable_profiler(self.config(), self.kineto_activities)

      def __exit__(self, exc_type, exc_val, exc_tb):
          if not self.enabled:
              return
          if self.use_cuda:
              # Let queued CUDA work finish before the profiler is stopped.
              torch.cuda.synchronize()
          self.kineto_results = _disable_profiler()
          parsed_results = self._parse_kineto_results(self.kineto_results)
          self.function_events = EventList(
              parsed_results,
              use_cuda=self.use_cuda,
              profile_memory=self.profile_memory,
              with_flops=self.with_flops)
          self.function_events._build_tree()
          return False

      def __repr__(self):
          if self.function_events is None:
              return '<unfinished torch.autograd.profile>'
          return repr(self.function_events)

      def __str__(self):
          if self.function_events is None:
              return '<unfinished torch.autograd.profile>'
          return str(self.function_events)

      def _check_finish(self):
          """Raise RuntimeError unless profiling results are available."""
          if self.function_events is None:
              raise RuntimeError("Profiler didn't finish running")

      def table(self, sort_by=None, row_limit=100, max_src_column_width=75, header=None, top_level_events_only=False):
          self._check_finish()
          assert self.function_events is not None
          return self.function_events.table(
              sort_by=sort_by, row_limit=row_limit, max_src_column_width=max_src_column_width, header=header,
              top_level_events_only=top_level_events_only
          )
      table.__doc__ = EventList.table.__doc__

      def export_chrome_trace(self, path):
          self._check_finish()
          if kineto_available():
              self.kineto_results.save(path)  # type: ignore[union-attr]
          else:
              return self.function_events.export_chrome_trace(path)  # type: ignore[union-attr]
      export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__

      def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
          self._check_finish()
          assert self.function_events is not None, "Expected profiling results"
          assert self.with_stack, "export_stacks() requires with_stack=True"
          return self.function_events.export_stacks(path, metric)

      def key_averages(self, group_by_input_shape=False, group_by_stack_n=0):
          self._check_finish()
          assert self.function_events is not None, "Expected profiling results"
          return self.function_events.key_averages(group_by_input_shape, group_by_stack_n)
      key_averages.__doc__ = EventList.key_averages.__doc__

      def total_average(self):
          self._check_finish()
          assert self.function_events is not None, "Expected profiling results"
          return self.function_events.total_average()
      total_average.__doc__ = EventList.total_average.__doc__

      @property
      def self_cpu_time_total(self):
          """ Returns total time spent on CPU obtained as a sum of
          all self times across all the events.
          """
          self._check_finish()
          assert self.function_events is not None
          return self.function_events.self_cpu_time_total

      def _parse_kineto_results(self, result):
          """Convert a Kineto result into a list of FunctionEvents sorted by start time.

          Memory records that ``mem_records_acc.in_interval`` returns for a CPU
          event's time span are charged to that event. The remaining ones become
          standalone zero-duration ``MEMORY_EVENT_NAME`` events.

          Events sharing a linked correlation id with a synchronous CPU event are
          attached to it: CUDA events as kernels, CPU events by adopting its thread.
          """
          # result.events() has most of the events - PyTorch op-level and device-level events
          trace_start_us = result.trace_start_us()
          # Each entry is [record, consumed]; ``consumed`` flips to True once the
          # record has been charged to an enclosing CPU event.
          mem_records = [[evt, False] for evt in result.events() if evt.name() == MEMORY_EVENT_NAME]
          mem_records_acc = MemRecordsAcc(mem_records)

          def _cpu_memory_usage(mem_record):
              return mem_record.nbytes() if \
                  mem_record.device_type() in [DeviceType.CPU, DeviceType.MKLDNN, DeviceType.IDEEP] \
                  else 0

          def _cuda_memory_usage(mem_record):
              return mem_record.nbytes() if \
                  mem_record.device_type() in [DeviceType.CUDA, DeviceType.HIP] \
                  else 0

          # Create and return FunctionEvent list
          function_events = []
          cuda_corr_map: Dict[int, List[FunctionEvent]] = {}
          max_evt_id = 0
          for kineto_event in result.events():
              if _filter_name(kineto_event.name()):
                  continue
              rel_start_us = kineto_event.start_us() - trace_start_us
              rel_end_us = rel_start_us + kineto_event.duration_us()
              abs_end_us = kineto_event.start_us() + kineto_event.duration_us()

              cpu_memory_usage = 0
              cuda_memory_usage = 0
              if kineto_event.device_type() == DeviceType.CPU:
                  # find the corresponding memory allocation events
                  for mem_record in mem_records_acc.in_interval(kineto_event.start_us(), abs_end_us):
                      cpu_memory_usage += _cpu_memory_usage(mem_record[0])
                      cuda_memory_usage += _cuda_memory_usage(mem_record[0])
                      mem_record[1] = True

              # Events that finish on a different thread than they started are async.
              is_async = kineto_event.is_async() or (
                  kineto_event.start_thread_id() != kineto_event.end_thread_id()
              )

              fe = FunctionEvent(
                  id=kineto_event.correlation_id(),
                  name=_rewrite_name(name=kineto_event.name(), with_wildcard=True),
                  trace_name=_rewrite_name(name=kineto_event.name(), with_wildcard=False),
                  thread=kineto_event.start_thread_id(),
                  start_us=rel_start_us,
                  end_us=rel_end_us,
                  fwd_thread=kineto_event.fwd_thread_id(),
                  input_shapes=kineto_event.shapes(),
                  stack=[entry for entry in kineto_event.stack() if _filter_stack_entry(entry)],
                  scope=kineto_event.scope(),
                  cpu_memory_usage=cpu_memory_usage,
                  cuda_memory_usage=cuda_memory_usage,
                  is_async=is_async,
                  sequence_nr=kineto_event.sequence_nr(),
                  device_type=kineto_event.device_type(),
                  device_index=kineto_event.device_index(),
                  flops=kineto_event.flops(),
              )
              # Track the largest id seen; synthetic memory events are numbered after it.
              max_evt_id = max(max_evt_id, fe.id)
              if fe.device_type == DeviceType.CPU and not fe.is_async:
                  # Check if we have CUDA time as a fallback
                  cuda_time = kineto_event.cuda_elapsed_us()
                  if cuda_time > 0:
                      fe.append_kernel(
                          fe.name,
                          fe.device_index,
                          cuda_time)
                      fe.is_legacy = True
              function_events.append(fe)
              corr_id = kineto_event.linked_correlation_id()
              if corr_id > 0:
                  cuda_corr_map.setdefault(corr_id, []).append(fe)

          # associate CUDA kernels and CUDA runtime (CPU) with CPU events
          for fe in function_events:
              if (fe.device_type == DeviceType.CPU and not fe.is_async and
                      fe.id in cuda_corr_map):
                  for f_evt in cuda_corr_map[fe.id]:
                      if f_evt.device_type == DeviceType.CUDA:
                          fe.append_kernel(
                              f_evt.name,
                              f_evt.device_index,
                              f_evt.time_range.end - f_evt.time_range.start)
                      elif f_evt.device_type == DeviceType.CPU:
                          # make sure that 'thread' of a CPU Kineto (e.g. CUDA Runtime) event is associated
                          # with the 'thread' of the corresponding linked PyTorch event to properly track
                          # parents and children
                          f_evt.thread = fe.thread

          # output top-level memory events
          for mem_record in mem_records:
              if not mem_record[1]:
                  rel_start_us = mem_record[0].start_us() - trace_start_us
                  max_evt_id += 1
                  fe = FunctionEvent(
                      id=max_evt_id,
                      name=MEMORY_EVENT_NAME,
                      trace_name=None,  # not outputting in the trace
                      thread=mem_record[0].start_thread_id(),
                      start_us=rel_start_us,
                      end_us=rel_start_us,  # no duration
                      fwd_thread=mem_record[0].start_thread_id(),
                      input_shapes=[],
                      stack=[],
                      scope=0,  # RecordScope::FUNCTION
                      cpu_memory_usage=_cpu_memory_usage(mem_record[0]),
                      cuda_memory_usage=_cuda_memory_usage(mem_record[0]),
                      is_async=False,
                      sequence_nr=-1,
                      device_type=DeviceType.CPU,
                      device_index=0,
                  )
                  function_events.append(fe)

          # Order by start time; for equal starts the longer (enclosing) event comes first.
          function_events.sort(key=lambda evt: [evt.time_range.start, -evt.time_range.end])
          return function_events
  class record_function(ContextDecorator):
      """Context manager / function decorator that labels a block of Python code
      (or a whole function) in the autograd profiler output.

      This is useful when tracing the code profile.

      Args:
          name (str): Label assigned to the block of code.
          args (str, optional): Extra string forwarded together with ``name`` to
              ``torch.ops.profiler._record_function_enter``. Defaults to ``None``.

      Example:
          >>> x = torch.randn((1, 1), requires_grad=True)
          >>> with torch.autograd.profiler.profile() as prof:
          ...     y = x ** 2
          ...     with torch.autograd.profiler.record_function("label-z"):  # label the block
          ...         z = y ** 3
          ...     y.backward()
          ...
          >>> # NOTE: some columns were removed for brevity
          >>> print(prof.key_averages().table(sort_by="self_cpu_time_total"))
          -----------------------------------  ----------------  ---------------  ---------------
          Name                                 Self CPU total %  CPU time avg     Number of Calls
          -----------------------------------  ----------------  ---------------  ---------------
          pow                                  60.77%            47.470us         3
          mul                                  21.73%            25.465us         2
          PowBackward0                         12.03%            121.891us        1
          torch::autograd::AccumulateGrad      2.70%             6.324us          1
          label-z                              2.13%             12.421us         1
          torch::autograd::GraphRoot           0.64%             1.503us          1
          -----------------------------------  ----------------  ---------------  ---------------
          Self CPU time total: 234.344us
          CUDA time total: 0.000us

      """

      def __init__(self, name: str, args: Optional[str] = None):
          self.name: str = name
          self.args: Optional[str] = args
          # True: __exit__ runs the end callbacks itself.
          # False: _call_end_callbacks_on_future has handed them off to a future.
          self.run_callbacks_on_exit: bool = True
          # Placeholder until __enter__ stores the handle returned by the profiler op.
          # TODO: move to custom class (https://github.com/pytorch/pytorch/issues/35026).
          self.handle: torch.Tensor = torch.zeros(1)

      def __enter__(self):
          enter_op = torch.ops.profiler._record_function_enter
          self.handle = enter_op(self.name, self.args)
          return self

      def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
          if not self.run_callbacks_on_exit:
              # End callbacks were deferred to a future; nothing to do here.
              return
          torch.ops.profiler._record_function_exit(self.handle)

      def _call_end_callbacks_on_future(self, fut: Future[Any]) -> Future[Any]:
          """Defer this record function's end callbacks until ``fut`` completes.

          This profiles the end-to-end time of asynchronous calls that return a
          future: recording extends beyond the ``with`` scope until the future is
          satisfied. It may be called at most once per instance.

          Args:
              fut: (torch._C.Future): future on which to schedule the callback.

          Returns:
              A future that completes with the value of the passed in future once
              the profiling callbacks have run.

          Raises:
              RuntimeError: if a callback was already attached by a previous call.
          """
          already_attached = not self.run_callbacks_on_exit
          if already_attached:
              raise RuntimeError("_call_end_callbacks_on_future can only be called once.")
          # The future now owns the end callbacks, so __exit__ must not fire them too.
          self.run_callbacks_on_exit = False
          return torch.ops.profiler._call_end_callbacks_on_jit_fut(self.handle, fut)
  class emit_nvtx(object):
      """Context manager that makes every autograd operation emit an NVTX range.

      It is useful when running the program under nvprof::

          nvprof --profile-from-start off -o trace_name.prof -- <regular command here>

      Unfortunately, there's no way to force nvprof to flush the data it collected
      to disk. For CUDA profiling, use this context manager to annotate nvprof
      traces, then wait for the process to exit before inspecting them.
      Then, either NVIDIA Visual Profiler (nvvp) can be used to visualize the timeline, or
      :func:`torch.autograd.profiler.load_nvprof` can load the results for inspection
      e.g. in Python REPL.

      .. warning::
          This context manager should not be called recursively, i.e. at most one
          instance should be enabled at any given time.

      Args:
          enabled (bool, optional): Setting ``enabled=False`` makes this context manager a no-op.
              Default: ``True``.
          record_shapes (bool, optional): If ``record_shapes=True``, the nvtx range wrapping
              each autograd op will append information about the sizes of Tensor arguments received
              by that op, in the following format:
              ``[[arg0.size(0), arg0.size(1), ...], [arg1.size(0), arg1.size(1), ...], ...]``
              Non-tensor arguments will be represented by ``[]``.
              Arguments will be listed in the order they are received by the backend op.
              Please note that this order may not match the order in which those arguments were passed
              on the Python side. Also note that shape recording may increase the overhead of nvtx range creation.
              Default: ``False``.

      Example:
          >>> with torch.cuda.profiler.profile():
          ...     model(x)  # Warmup CUDA memory allocator and profiler
          ...     with torch.autograd.profiler.emit_nvtx():
          ...         model(x)

      **Forward-backward correlation**

      When viewing a profile created using :class:`emit_nvtx` in the Nvidia Visual Profiler,
      correlating each backward-pass op with the corresponding forward-pass op can be difficult.
      To ease this task, :class:`emit_nvtx` appends sequence number information to the ranges it
      generates.

      During the forward pass, each function range is decorated with ``seq=<N>``. ``seq`` is a running
      counter, incremented each time a new backward Function object is created and stashed for backward.
      Thus, the ``seq=<N>`` annotation associated with each forward function range tells you that
      if a backward Function object is created by this forward function,
      the backward object will receive sequence number N.
      During the backward pass, the top-level range wrapping each C++ backward Function's
      ``apply()`` call is decorated with ``stashed seq=<M>``. ``M`` is the sequence number that
      the backward object was created with. By comparing ``stashed seq`` numbers in backward with ``seq``
      numbers in forward, you can track down which forward op created each backward Function.

      Any functions executed during the backward pass are also decorated with ``seq=<N>``. During
      default backward (with ``create_graph=False``) this information is irrelevant, and in fact,
      ``N`` may simply be 0 for all such functions. Only the top-level ranges associated with
      backward Function objects' ``apply()`` methods are useful, as a way to correlate these Function
      objects with the earlier forward pass.

      **Double-backward**

      If, on the other hand, a backward pass with ``create_graph=True`` is underway (in other words,
      if you are setting up for a double-backward), each function's execution during backward
      is given a nonzero, useful ``seq=<N>``. Those functions may themselves create Function objects
      to be executed later during double-backward, just as the original functions in the forward pass did.
      The relationship between backward and double-backward is conceptually the same as the relationship
      between forward and backward: The functions still emit current-sequence-number-tagged ranges,
      the Function objects they create still stash those sequence numbers, and during the eventual
      double-backward, the Function objects' ``apply()`` ranges are still tagged with ``stashed seq``
      numbers, which can be compared to `seq` numbers from the backward pass.

      .. warning::
          The sequence number is thread-local, and some forward functions don't create an associated
          backward Function object (instead delegating that to sub-functions further down the call chain).
          For these reasons, the correspondence of stashed sequence numbers in
          backward Function ``apply()`` ranges with `seq` numbers in forward-pass ranges is
          not guaranteed to be 1 to 1. The sequence numbers alone may not be enough to fully
          disambiguate which forward function created which
          backward Function object. You may need to make a judgment based on analytic knowledge of what
          the expected correspondence should be.
      """

      def __init__(self, enabled=True, record_shapes=False):
          self.enabled = enabled
          # Guards against entering the same instance twice.
          self.entered = False
          self.record_shapes = record_shapes

      def __enter__(self):
          if self.enabled:
              if self.entered:
                  raise RuntimeError("NVTX annotation context manager is not reentrant")
              self.entered = True
              torch.cuda.synchronize()
              # Same positional layout as profile.config(): state, record_shapes,
              # profile_memory, with_stack, with_flops, with_modules, experimental config.
              nvtx_config = ProfilerConfig(
                  ProfilerState.NVTX,
                  self.record_shapes,
                  False,
                  False,
                  False,
                  False,
                  _ExperimentalConfig(),
              )
              _enable_profiler(nvtx_config, set())
              return self

      def __exit__(self, exc_type, exc_val, exc_tb):
          if self.enabled:
              torch.cuda.synchronize()
              _disable_profiler()
              return False
  def load_nvprof(path):
      """Open an nvprof trace file and parse its autograd annotations.

      Args:
          path (str): path to nvprof trace

      Returns:
          EventList: the events produced by ``parse_nvprof_trace(path)``.
      """
      parsed_events = parse_nvprof_trace(path)
      return EventList(parsed_events)
  class EnforceUnique(object):
      """Raises an error if a key is seen more than once."""

      def __init__(self):
          # Every key tuple passed to see() so far.
          self.seen = set()

      def see(self, *key):
          """Record ``key`` (the positional args as a tuple).

          Raises:
              RuntimeError: if the same tuple was already recorded.
          """
          if key not in self.seen:
              self.seen.add(key)
              return
          raise RuntimeError('duplicate key: ' + str(key))
  def parse_nvprof_trace(path):
      """Parse an nvprof SQLite trace into a list of FunctionEvents sorted by start time.

      Each NVTX range becomes a FunctionEvent. A range is a marker row with a
      non-zero name joined to the row with the same id whose name is 0. CUDA
      kernels are attached to a range via ``append_kernel`` when the runtime call
      sharing their correlationId lies strictly inside the range's timestamps.

      Args:
          path (str): path to the nvprof trace (an SQLite database).

      Returns:
          list[FunctionEvent]: the parsed events, sorted by start time.

      Raises:
          RuntimeError: if a marker id, or a (marker id, runtime id) pair, repeats.
      """
      import sqlite3
      conn = sqlite3.connect(path)
      try:
          conn.row_factory = sqlite3.Row

          # Parse strings table: string id -> demangled string
          strings = {}
          for r in conn.execute("SELECT _id_ as id, value FROM StringTable"):
              strings[r["id"]] = torch._C._demangle(r["value"])

          # First, find all functions and create FunctionEvents for them
          marker_query = """
          SELECT
              start.id AS marker_id, start.name, start.timestamp AS start_time, end.timestamp AS end_time
          FROM
              CUPTI_ACTIVITY_KIND_MARKER AS start INNER JOIN CUPTI_ACTIVITY_KIND_MARKER AS end
              ON start.id = end.id
          WHERE
              start.name != 0 AND end.name = 0
          """
          functions = []
          functions_map = {}
          unique = EnforceUnique()
          for row in conn.execute(marker_query):
              unique.see(row['marker_id'])
              evt = FunctionEvent(id=row['marker_id'],
                                  node_id=0,  # missing a node_id when calling FunctionEvent. This is just to ensure
                                              # that pytorch doesn't crash when creating a FunctionEvent() object
                                  name=strings[row['name']],
                                  start_us=row['start_time'],
                                  end_us=row['end_time'],
                                  thread=0)  # TODO: find in sqlite database
              functions.append(evt)
              functions_map[evt.id] = evt

          # Now, correlate all kernels with FunctionEvents
          kernel_query = """
          SELECT
              start.id AS marker_id, start.name, start.timestamp, end.timestamp,
              runtime._id_ AS runtime_id, runtime.cbid, runtime.start AS runtime_start, runtime.end AS runtime_end,
              kernel.start AS kernel_start, kernel.end AS kernel_end, kernel.name AS kernel_name
          FROM
              CUPTI_ACTIVITY_KIND_MARKER AS start
              INNER JOIN CUPTI_ACTIVITY_KIND_MARKER AS end
                  ON start.id = end.id
              INNER JOIN CUPTI_ACTIVITY_KIND_RUNTIME as runtime
                  ON (start.timestamp < runtime.start AND runtime.end < end.timestamp)
              INNER JOIN CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL AS kernel
                  ON kernel.correlationId = runtime.correlationId
          """
          unique = EnforceUnique()
          for row in conn.execute(kernel_query):
              unique.see(row['marker_id'], row['runtime_id'])
              # 211 is cudaKernelLaunch for cuda >= 9.2
              assert (row['cbid'] == 211)
              evt = functions_map[row['marker_id']]
              # NOTE(review): kernel_name is passed through as stored; unlike marker
              # names it is not looked up in `strings` — confirm against the nvprof schema.
              evt.append_kernel(row['kernel_name'],
                                0,
                                row['kernel_end'] - row['kernel_start'])
      finally:
          # Always release the database handle, including when parsing raises.
          conn.close()

      functions.sort(key=lambda evt: evt.time_range.start)
      return functions
  def kineto_step():
      """Tell Kineto that an iteration boundary has been reached.

      Kineto uses these notifications to find iteration boundaries when
      servicing asynchronous trace requests.
      """
      _kineto_step()