profiler.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507
  1. import gzip
  2. import json
  3. import os
  4. import tempfile
  5. from enum import Enum
  6. from functools import partial
  7. from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
  8. from warnings import warn
  9. import torch
  10. import torch.autograd.profiler as prof
  11. from torch.autograd import ProfilerActivity, kineto_available
  12. from torch._C._autograd import _ExperimentalConfig
  13. def supported_activities():
  14. """
  15. Returns a set of supported profiler tracing activities.
  16. Note: profiler uses CUPTI library to trace on-device CUDA kernels.
  17. In case when CUDA is enabled but CUPTI is not available, passing
  18. ``ProfilerActivity.CUDA`` to profiler results in using the legacy CUDA
  19. profiling code (same as in the legacy ``torch.autograd.profiler``).
  20. This, in turn, results in including CUDA time in the profiler table output,
  21. but not in the JSON trace.
  22. """
  23. return torch.autograd._supported_activities()
  24. class _KinetoProfile(object):
  25. """Low-level profiler wrap the autograd profile
  26. Args:
  27. activities (iterable): list of activity groups (CPU, CUDA) to use in profiling, supported values:
  28. ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``.
  29. Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA.
  30. record_shapes (bool): save information about operator's input shapes.
  31. profile_memory (bool): track tensor memory allocation/deallocation.
  32. with_stack (bool): record source information (file and line number) for the ops.
  33. with_flops (bool): use formula to estimate the FLOPS of specific operators
  34. (matrix multiplication and 2D convolution).
  35. with_modules (bool): record module hierarchy (including function names)
  36. corresponding to the callstack of the op. e.g. If module A's forward call's
  37. module B's forward which contains an aten::add op,
  38. then aten::add's module hierarchy is A.B
  39. Note that this support exist, at the moment, only for TorchScript models
  40. and not eager mode models.
  41. experimental_config (_ExperimentalConfig) : A set of experimental options
  42. used by profiler libraries like Kineto. Note, backward compatibility is not guaranteed.
  43. .. note::
  44. This API is an experimental and subject to change in future.
  45. Enabling shape and stack tracing results in additional overhead.
  46. When record_shapes=True is specified, profiler will temporarily hold references to the tensors;
  47. that may further prevent certain optimizations that depend on the reference count and introduce
  48. extra tensor copies.
  49. """
  50. def __init__(
  51. self,
  52. *,
  53. activities: Optional[Iterable[ProfilerActivity]] = None,
  54. record_shapes: bool = False,
  55. profile_memory: bool = False,
  56. with_stack: bool = False,
  57. with_flops: bool = False,
  58. with_modules: bool = False,
  59. experimental_config: Optional[_ExperimentalConfig] = None):
  60. self.activities = set(activities) if activities else supported_activities()
  61. self.record_shapes = record_shapes
  62. self.with_flops = with_flops
  63. self.profile_memory = profile_memory
  64. self.with_stack = with_stack
  65. self.with_modules = with_modules
  66. self.experimental_config = experimental_config
  67. self.profiler: Optional[prof.profile] = None
  68. def start(self):
  69. self.prepare_trace()
  70. self.start_trace()
  71. def stop(self):
  72. self.stop_trace()
  73. def prepare_trace(self):
  74. self.profiler = prof.profile(
  75. use_cuda=(ProfilerActivity.CUDA in self.activities),
  76. use_cpu=(ProfilerActivity.CPU in self.activities),
  77. record_shapes=self.record_shapes,
  78. with_flops=self.with_flops,
  79. profile_memory=self.profile_memory,
  80. with_stack=self.with_stack,
  81. with_modules=self.with_modules,
  82. use_kineto=True,
  83. experimental_config=self.experimental_config,
  84. )
  85. self.profiler._prepare_trace()
  86. def start_trace(self):
  87. assert self.profiler is not None
  88. self.profiler._start_trace()
  89. if kineto_available():
  90. dist_info = self._get_distributed_info()
  91. if dist_info:
  92. self.add_metadata_json("distributedInfo", json.dumps(dist_info))
  93. def stop_trace(self):
  94. assert self.profiler is not None
  95. self.profiler.__exit__(None, None, None)
  96. def export_chrome_trace(self, path: str):
  97. """
  98. Exports the collected trace in Chrome JSON format.
  99. """
  100. assert self.profiler
  101. if path.endswith('.gz'):
  102. fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False)
  103. fp.close()
  104. retvalue = self.profiler.export_chrome_trace(fp.name)
  105. with open(fp.name) as fin:
  106. with gzip.open(path, 'wt') as fout:
  107. fout.writelines(fin)
  108. os.remove(fp.name)
  109. return retvalue
  110. else:
  111. return self.profiler.export_chrome_trace(path)
  112. def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
  113. """Save stack traces in a file in a format suitable for visualization.
  114. Args:
  115. path (str): save stacks file to this location;
  116. metric (str): metric to use: "self_cpu_time_total" or "self_cuda_time_total"
  117. .. note::
  118. Example of using FlameGraph tool:
  119. - git clone https://github.com/brendangregg/FlameGraph
  120. - cd FlameGraph
  121. - ./flamegraph.pl --title "CPU time" --countname "us." profiler.stacks > perf_viz.svg
  122. """
  123. assert self.profiler
  124. return self.profiler.export_stacks(path, metric)
  125. def key_averages(self, group_by_input_shape: bool = False, group_by_stack_n: int = 0):
  126. """Averages events, grouping them by operator name and (optionally) input shapes and
  127. stack.
  128. .. note::
  129. To use shape/stack functionality make sure to set record_shapes/with_stack
  130. when creating profiler context manager.
  131. """
  132. assert self.profiler
  133. return self.profiler.key_averages(group_by_input_shape, group_by_stack_n)
  134. def events(self):
  135. """
  136. Returns the list of unaggregated profiler events,
  137. to be used in the trace callback or after the profiling is finished
  138. """
  139. assert self.profiler
  140. return self.profiler.function_events
  141. def add_metadata(self, key: str, value: str):
  142. """
  143. Adds a user defined metadata with a string key and a string value
  144. into the trace file
  145. """
  146. wrapped_value = "\"" + value.replace('"', '\\"') + "\""
  147. torch.autograd._add_metadata_json(key, wrapped_value)
  148. def add_metadata_json(self, key: str, value: str):
  149. """
  150. Adds a user defined metadata with a string key and a valid json value
  151. into the trace file
  152. """
  153. torch.autograd._add_metadata_json(key, value)
  154. def _get_distributed_info(self):
  155. import torch.distributed as dist
  156. if not dist.is_available() or not dist.is_initialized():
  157. return None
  158. return {
  159. "backend": dist.get_backend(),
  160. "rank": dist.get_rank(),
  161. "world_size": dist.get_world_size()
  162. }
  163. class ProfilerAction(Enum):
  164. """
  165. Profiler actions that can be taken at the specified intervals
  166. """
  167. NONE = 0
  168. WARMUP = 1
  169. RECORD = 2
  170. RECORD_AND_SAVE = 3
  171. def schedule(*, wait: int, warmup: int, active: int, repeat: int = 0, skip_first: int = 0) -> Callable:
  172. """
  173. Returns a callable that can be used as profiler ``schedule`` argument. The profiler will skip
  174. the first ``skip_first`` steps, then wait for ``wait`` steps, then do the warmup for the next ``warmup`` steps,
  175. then do the active recording for the next ``active`` steps and then repeat the cycle starting with ``wait`` steps.
  176. The optional number of cycles is specified with the ``repeat`` parameter, the zero value means that
  177. the cycles will continue until the profiling is finished.
  178. """
  179. def schedule_fn(step: int) -> ProfilerAction:
  180. assert step >= 0
  181. if step < skip_first:
  182. return ProfilerAction.NONE
  183. else:
  184. step -= skip_first
  185. num_steps = wait + warmup + active
  186. if repeat > 0 and step / num_steps >= repeat:
  187. return ProfilerAction.NONE
  188. mod_step = step % num_steps
  189. if mod_step < wait:
  190. return ProfilerAction.NONE
  191. elif mod_step < wait + warmup:
  192. return ProfilerAction.WARMUP
  193. else:
  194. return ProfilerAction.RECORD if mod_step < num_steps - 1 \
  195. else ProfilerAction.RECORD_AND_SAVE
  196. assert wait >= 0 and warmup >= 0 and active > 0 and \
  197. repeat >= 0 and skip_first >= 0, "Invalid profiler schedule arguments"
  198. if warmup == 0:
  199. warn("Profiler won't be using warmup, this can skew profiler results")
  200. return schedule_fn
  201. def _default_schedule_fn(_: int) -> ProfilerAction:
  202. """
  203. Default profiler behavior - immediately starts recording the events,
  204. keeps doing it on every profiler step.
  205. """
  206. return ProfilerAction.RECORD
  207. def tensorboard_trace_handler(dir_name: str, worker_name: Optional[str] = None, use_gzip: bool = False):
  208. """
  209. Outputs tracing files to directory of ``dir_name``, then that directory can be
  210. directly delivered to tensorboard as logdir.
  211. ``worker_name`` should be unique for each worker in distributed scenario,
  212. it will be set to '[hostname]_[pid]' by default.
  213. """
  214. import os
  215. import socket
  216. import time
  217. def handler_fn(prof) -> None:
  218. nonlocal worker_name
  219. if not os.path.isdir(dir_name):
  220. try:
  221. os.makedirs(dir_name, exist_ok=True)
  222. except Exception:
  223. raise RuntimeError("Can't create directory: " + dir_name)
  224. if not worker_name:
  225. worker_name = "{}_{}".format(socket.gethostname(), str(os.getpid()))
  226. file_name = "{}.{}.pt.trace.json".format(worker_name, int(time.time() * 1000))
  227. if use_gzip:
  228. file_name = file_name + '.gz'
  229. prof.export_chrome_trace(os.path.join(dir_name, file_name))
  230. return handler_fn
  231. class profile(_KinetoProfile):
  232. """Profiler context manager.
  233. Args:
  234. activities (iterable): list of activity groups (CPU, CUDA) to use in profiling, supported values:
  235. ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``.
  236. Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA.
  237. schedule (callable): callable that takes step (int) as a single parameter and returns
  238. ``ProfilerAction`` value that specifies the profiler action to perform at each step.
  239. on_trace_ready (callable): callable that is called at each step when ``schedule``
  240. returns ``ProfilerAction.RECORD_AND_SAVE`` during the profiling.
  241. record_shapes (bool): save information about operator's input shapes.
  242. profile_memory (bool): track tensor memory allocation/deallocation.
  243. with_stack (bool): record source information (file and line number) for the ops.
  244. with_flops (bool): use formula to estimate the FLOPs (floating point operations) of specific operators
  245. (matrix multiplication and 2D convolution).
  246. with_modules (bool): record module hierarchy (including function names)
  247. corresponding to the callstack of the op. e.g. If module A's forward call's
  248. module B's forward which contains an aten::add op,
  249. then aten::add's module hierarchy is A.B
  250. Note that this support exist, at the moment, only for TorchScript models
  251. and not eager mode models.
  252. experimental_config (_ExperimentalConfig) : A set of experimental options
  253. used for Kineto library features. Note, backward compatibility is not guaranteed.
  254. use_cuda (bool):
  255. .. deprecated:: 1.8.1
  256. use ``activities`` instead.
  257. .. note::
  258. Use :func:`~torch.profiler.schedule` to generate the callable schedule.
  259. Non-default schedules are useful when profiling long training jobs
  260. and allow the user to obtain multiple traces at the different iterations
  261. of the training process.
  262. The default schedule simply records all the events continuously for the
  263. duration of the context manager.
  264. .. note::
  265. Use :func:`~torch.profiler.tensorboard_trace_handler` to generate result files for TensorBoard:
  266. ``on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name)``
  267. After profiling, result files can be found in the specified directory. Use the command:
  268. ``tensorboard --logdir dir_name``
  269. to see the results in TensorBoard.
  270. For more information, see
  271. `PyTorch Profiler TensorBoard Plugin <https://github.com/pytorch/kineto/tree/master/tb_plugin>`__
  272. .. note::
  273. Enabling shape and stack tracing results in additional overhead.
  274. When record_shapes=True is specified, profiler will temporarily hold references to the tensors;
  275. that may further prevent certain optimizations that depend on the reference count and introduce
  276. extra tensor copies.
  277. Examples:
  278. .. code-block:: python
  279. with torch.profiler.profile(
  280. activities=[
  281. torch.profiler.ProfilerActivity.CPU,
  282. torch.profiler.ProfilerActivity.CUDA,
  283. ]
  284. ) as p:
  285. code_to_profile()
  286. print(p.key_averages().table(
  287. sort_by="self_cuda_time_total", row_limit=-1))
  288. Using the profiler's ``schedule``, ``on_trace_ready`` and ``step`` functions:
  289. .. code-block:: python
  290. # Non-default profiler schedule allows user to turn profiler on and off
  291. # on different iterations of the training loop;
  292. # trace_handler is called every time a new trace becomes available
  293. def trace_handler(prof):
  294. print(prof.key_averages().table(
  295. sort_by="self_cuda_time_total", row_limit=-1))
  296. # prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json")
  297. with torch.profiler.profile(
  298. activities=[
  299. torch.profiler.ProfilerActivity.CPU,
  300. torch.profiler.ProfilerActivity.CUDA,
  301. ],
  302. # In this example with wait=1, warmup=1, active=2,
  303. # profiler will skip the first step/iteration,
  304. # start warming up on the second, record
  305. # the third and the forth iterations,
  306. # after which the trace will become available
  307. # and on_trace_ready (when set) is called;
  308. # the cycle repeats starting with the next step
  309. schedule=torch.profiler.schedule(
  310. wait=1,
  311. warmup=1,
  312. active=2),
  313. on_trace_ready=trace_handler
  314. # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log')
  315. # used when outputting for tensorboard
  316. ) as p:
  317. for iter in range(N):
  318. code_iteration_to_profile(iter)
  319. # send a signal to the profiler that the next iteration has started
  320. p.step()
  321. """
  322. def __init__(
  323. self,
  324. *,
  325. activities: Optional[Iterable[ProfilerActivity]] = None,
  326. schedule: Optional[Callable[[int], ProfilerAction]] = None,
  327. on_trace_ready: Optional[Callable[..., Any]] = None,
  328. record_shapes: bool = False,
  329. profile_memory: bool = False,
  330. with_stack: bool = False,
  331. with_flops: bool = False,
  332. with_modules: bool = False,
  333. experimental_config: Optional[_ExperimentalConfig] = None,
  334. # deprecated:
  335. use_cuda: Optional[bool] = None):
  336. activities_set = set(activities) if activities else supported_activities()
  337. if use_cuda is not None:
  338. warn("use_cuda is deprecated, use activities argument instead")
  339. if use_cuda:
  340. activities_set.add(ProfilerActivity.CUDA)
  341. elif ProfilerActivity.CUDA in activities_set:
  342. activities_set.remove(ProfilerActivity.CUDA)
  343. assert len(activities_set) > 0, "No valid profiler activities found"
  344. super().__init__(
  345. activities=activities,
  346. record_shapes=record_shapes,
  347. profile_memory=profile_memory,
  348. with_stack=with_stack,
  349. with_flops=with_flops,
  350. with_modules=with_modules,
  351. experimental_config=experimental_config,
  352. )
  353. if schedule:
  354. self.schedule = schedule
  355. # add step markers into the trace and table view
  356. self.record_steps = True
  357. else:
  358. self.schedule = _default_schedule_fn
  359. self.record_steps = False
  360. self.on_trace_ready = on_trace_ready
  361. self.step_num = 0
  362. self.current_action = self.schedule(self.step_num)
  363. self.step_rec_fn: Optional[prof.record_function] = None
  364. self.action_map: Dict[Tuple[ProfilerAction, Optional[ProfilerAction]], List[Any]] = {
  365. # key is (prev_action, current_action), value is action list corresponding to the state pair.
  366. (ProfilerAction.NONE, ProfilerAction.NONE): [],
  367. (ProfilerAction.NONE, ProfilerAction.WARMUP): [self.prepare_trace],
  368. (ProfilerAction.NONE, ProfilerAction.RECORD): [self.prepare_trace, self.start_trace],
  369. (ProfilerAction.NONE, ProfilerAction.RECORD_AND_SAVE): [self.prepare_trace, self.start_trace],
  370. (ProfilerAction.WARMUP, ProfilerAction.NONE): [
  371. partial(warn, "Incorrect schedule: WARMUP followed by NONE"),
  372. self.start_trace,
  373. self.stop_trace],
  374. (ProfilerAction.WARMUP, ProfilerAction.WARMUP): [],
  375. (ProfilerAction.WARMUP, ProfilerAction.RECORD): [self.start_trace],
  376. (ProfilerAction.WARMUP, ProfilerAction.RECORD_AND_SAVE): [self.start_trace],
  377. (ProfilerAction.RECORD, ProfilerAction.NONE): [
  378. partial(warn, "Incorrect schedule: RECORD followed by NONE"),
  379. self.stop_trace],
  380. (ProfilerAction.RECORD, ProfilerAction.WARMUP): [
  381. partial(warn, "Incorrect schedule: RECORD followed by WARMUP"),
  382. self.stop_trace],
  383. (ProfilerAction.RECORD, ProfilerAction.RECORD): [],
  384. (ProfilerAction.RECORD, ProfilerAction.RECORD_AND_SAVE): [],
  385. (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.NONE): [self.stop_trace, self._trace_ready],
  386. (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.WARMUP): [self.stop_trace, self._trace_ready, self.prepare_trace],
  387. (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.RECORD): [
  388. self.stop_trace,
  389. self._trace_ready,
  390. self.prepare_trace,
  391. self.start_trace],
  392. (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.RECORD_AND_SAVE): [
  393. self.stop_trace,
  394. self._trace_ready,
  395. self.prepare_trace,
  396. self.start_trace],
  397. # used for exit action
  398. (ProfilerAction.WARMUP, None): [self.start_trace, self.stop_trace],
  399. (ProfilerAction.RECORD, None): [self.stop_trace, self._trace_ready],
  400. (ProfilerAction.RECORD_AND_SAVE, None): [self.stop_trace, self._trace_ready]
  401. }
  402. def __enter__(self):
  403. self.start()
  404. return self
  405. def __exit__(self, exc_type, exc_val, exc_tb):
  406. self.stop()
  407. def start(self):
  408. self._transit_action(ProfilerAction.NONE, self.current_action)
  409. if self.record_steps:
  410. self.step_rec_fn = prof.record_function("ProfilerStep#" + str(self.step_num))
  411. self.step_rec_fn.__enter__()
  412. def stop(self):
  413. if self.record_steps and self.step_rec_fn:
  414. self.step_rec_fn.__exit__(None, None, None)
  415. self._transit_action(self.current_action, None)
  416. def step(self):
  417. """
  418. Signals the profiler that the next profiling step has started.
  419. """
  420. if self.record_steps and self.step_rec_fn:
  421. self.step_rec_fn.__exit__(None, None, None)
  422. prev_action = self.current_action
  423. cur_step = self.step_num
  424. self.step_num += 1
  425. self.current_action = self.schedule(self.step_num)
  426. self._transit_action(prev_action, self.current_action)
  427. prof.kineto_step()
  428. if self.record_steps:
  429. self.step_rec_fn = prof.record_function("ProfilerStep#" + str(cur_step))
  430. self.step_rec_fn.__enter__()
  431. def _trace_ready(self):
  432. if self.on_trace_ready:
  433. self.on_trace_ready(self)
  434. def _transit_action(self, prev_action, current_action):
  435. action_list = self.action_map.get((prev_action, current_action))
  436. if action_list:
  437. for action in action_list:
  438. action()