// logger.hpp
  1. #include <c10/util/Logging.h>
  2. #include <c10d/reducer.hpp>
  3. #include <mutex>
  4. namespace c10d {
  5. class TORCH_API Logger {
  6. public:
  7. explicit Logger(std::shared_ptr<c10d::Reducer> reducer);
  8. // Set logging data that can be got during DistributedDataParallel
  9. // construction time.
  10. void set_construction_data_and_log(
  11. const std::string& module_name,
  12. const std::vector<int>& device_ids,
  13. int output_device,
  14. bool broadcast_buffers,
  15. bool has_sync_bn,
  16. bool static_graph
  17. );
  18. void set_static_graph();
  19. // An interface for users to get DDPLoggingData and log them
  20. // in the applications. Explanation of logging fields are in
  21. // "struct DDPLoggingData" of "torch/c10/util/Logging.h".
  22. at::DDPLoggingData get_ddp_logging_data();
  23. // Stream insertion operator for logging data to stream under
  24. // TORCH_DISTRIBUTED_DEBUG.
  25. friend std::ostream& operator<<(std::ostream& output, const Logger& logger);
  26. ~Logger() noexcept(false) {
  27. // Log if DDP graph is static in Logger dtor instead of Reducer dtor since
  28. // Logger is deleted before Reducer.
  29. log_if_graph_static(reducer_->ddp_graph_static());
  30. }
  31. // Set environment variables.
  32. void set_env_variables();
  33. // Set parameters stats.
  34. void set_parameter_stats();
  35. // Get size of each bucket (Bytes).
  36. std::vector<int64_t> get_bucket_sizes();
  37. // Get variable indices for each bucket.
  38. std::vector<std::vector<size_t>> get_per_bucket_variable_indices();
  39. // Set comm. hook, if used
  40. void set_comm_hook(const std::string& hook);
  41. // Set running with uneven input detection (model.join() context manager)
  42. void set_uneven_input_join();
  43. // Reset performance stats at current iteration
  44. void reset_performance_stats();
  45. // Calculate avg stats using cpu timer and gpu timer
  46. // that has been recorded in reducer.
  47. void calculate_avg_time(
  48. int64_t& avg_time,
  49. int64_t& time_duration,
  50. Timer& timer,
  51. Timer::Event start_event,
  52. Timer::Event end_event);
  53. // Set the absolute time of the event that has been recorded in reducer.
  54. void set_event_time(
  55. int64_t& event_time,
  56. Timer& timer,
  57. Timer::Event event
  58. );
  59. // Set stats that can be collected only during
  60. // training loop. It is called at the beginning of forward call
  61. // to record the run time stats of sampled iterations that previouly ran.
  62. // GPU performance stats are collected only for single process
  63. // single device program and single device module right now.
  64. // TODO to support single process multiple devices and multi device modules,
  65. // events need to be created and recorded on multiple devices.
  66. void set_runtime_stats_and_log();
  67. // Called when DDP/reducer is failing with an error. The
  68. // logging data structure will have two fields filled: "has_error" indicating
  69. // that this iteration encountered an error and other fields are not valid,
  70. // and "error", a string which contains the error message that DDP failed
  71. // with.
  72. template <typename... Args>
  73. void set_error_and_log(const std::string& ddp_error, const Args&... args) {
  74. ddp_logging_data_->ints_map["has_error"] = 1;
  75. auto err = c10::str(ddp_error, args...);
  76. ddp_logging_data_->strs_map["error"] = err;
  77. // Report the iteration we are erroring at so user knows how many examples
  78. // successfully processed before this error was hit.
  79. ddp_logging_data_->ints_map["iteration"] = reducer_->num_iterations_;
  80. at::LogPyTorchDDPUsage(*ddp_logging_data_);
  81. }
  82. // When running without static graph, called when reducer is destroyed to log
  83. // if graph was actually static and is a candidate for static graph
  84. // optimization.
  85. void log_if_graph_static(bool is_static);
  86. private:
  87. // ddp_logging_data_ is used to hold all the ddp related logging
  88. // data fields.
  89. std::unique_ptr<at::DDPLoggingData> ddp_logging_data_;
  90. std::shared_ptr<c10d::Reducer> reducer_;
  91. // track the number of iterations when runtime stats are collected so far.
  92. long num_iterations_stats_recorded_ = 0;
  93. };
  94. } // namespace c10d