| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
- #include <c10/util/Logging.h>
- #include <c10d/reducer.hpp>
- #include <mutex>
- namespace c10d {
- class TORCH_API Logger {
- public:
- explicit Logger(std::shared_ptr<c10d::Reducer> reducer);
- // Set logging data that can be got during DistributedDataParallel
- // construction time.
- void set_construction_data_and_log(
- const std::string& module_name,
- const std::vector<int>& device_ids,
- int output_device,
- bool broadcast_buffers,
- bool has_sync_bn,
- bool static_graph
- );
- void set_static_graph();
- // An interface for users to get DDPLoggingData and log them
- // in the applications. Explanation of logging fields are in
- // "struct DDPLoggingData" of "torch/c10/util/Logging.h".
- at::DDPLoggingData get_ddp_logging_data();
- // Stream insertion operator for logging data to stream under
- // TORCH_DISTRIBUTED_DEBUG.
- friend std::ostream& operator<<(std::ostream& output, const Logger& logger);
- ~Logger() noexcept(false) {
- // Log if DDP graph is static in Logger dtor instead of Reducer dtor since
- // Logger is deleted before Reducer.
- log_if_graph_static(reducer_->ddp_graph_static());
- }
- // Set environment variables.
- void set_env_variables();
- // Set parameters stats.
- void set_parameter_stats();
- // Get size of each bucket (Bytes).
- std::vector<int64_t> get_bucket_sizes();
- // Get variable indices for each bucket.
- std::vector<std::vector<size_t>> get_per_bucket_variable_indices();
- // Set comm. hook, if used
- void set_comm_hook(const std::string& hook);
- // Set running with uneven input detection (model.join() context manager)
- void set_uneven_input_join();
- // Reset performance stats at current iteration
- void reset_performance_stats();
- // Calculate avg stats using cpu timer and gpu timer
- // that has been recorded in reducer.
- void calculate_avg_time(
- int64_t& avg_time,
- int64_t& time_duration,
- Timer& timer,
- Timer::Event start_event,
- Timer::Event end_event);
- // Set the absolute time of the event that has been recorded in reducer.
- void set_event_time(
- int64_t& event_time,
- Timer& timer,
- Timer::Event event
- );
- // Set stats that can be collected only during
- // training loop. It is called at the beginning of forward call
- // to record the run time stats of sampled iterations that previouly ran.
- // GPU performance stats are collected only for single process
- // single device program and single device module right now.
- // TODO to support single process multiple devices and multi device modules,
- // events need to be created and recorded on multiple devices.
- void set_runtime_stats_and_log();
- // Called when DDP/reducer is failing with an error. The
- // logging data structure will have two fields filled: "has_error" indicating
- // that this iteration encountered an error and other fields are not valid,
- // and "error", a string which contains the error message that DDP failed
- // with.
- template <typename... Args>
- void set_error_and_log(const std::string& ddp_error, const Args&... args) {
- ddp_logging_data_->ints_map["has_error"] = 1;
- auto err = c10::str(ddp_error, args...);
- ddp_logging_data_->strs_map["error"] = err;
- // Report the iteration we are erroring at so user knows how many examples
- // successfully processed before this error was hit.
- ddp_logging_data_->ints_map["iteration"] = reducer_->num_iterations_;
- at::LogPyTorchDDPUsage(*ddp_logging_data_);
- }
- // When running without static graph, called when reducer is destroyed to log
- // if graph was actually static and is a candidate for static graph
- // optimization.
- void log_if_graph_static(bool is_static);
- private:
- // ddp_logging_data_ is used to hold all the ddp related logging
- // data fields.
- std::unique_ptr<at::DDPLoggingData> ddp_logging_data_;
- std::shared_ptr<c10d::Reducer> reducer_;
- // track the number of iterations when runtime stats are collected so far.
- long num_iterations_stats_recorded_ = 0;
- };
- } // namespace c10d
|