// ParamCommsUtils.hpp
  1. #pragma once
  2. #include <string>
  3. #include <vector>
  4. #include <c10/macros/Macros.h>
  5. #include <c10/util/ThreadLocalDebugInfo.h>
  6. #include <ATen/record_function.h>
  7. #include <ATen/core/ivalue.h>
  8. namespace torch {
  9. extern TORCH_API const std::string kParamCommsCallName;
  10. class TORCH_API ParamCommsDebugInfo
  11. : public c10::DebugInfoBase {
  12. public:
  13. ParamCommsDebugInfo() = default;
  14. ParamCommsDebugInfo(
  15. int rank,
  16. std::string&& colName,
  17. int inSize,
  18. int outSize,
  19. at::ScalarType dType,
  20. std::vector<int64_t> inSplitSizes,
  21. std::vector<int64_t> outSplitSizes);
  22. ~ParamCommsDebugInfo() override = default;
  23. int getRank() const {
  24. return rank_;
  25. }
  26. const std::string getColumnName() const {
  27. return columnName_;
  28. }
  29. int getInMessageSize() const {
  30. return inMessageSize_;
  31. }
  32. int getOutMessageSize() const {
  33. return outMessageSize_;
  34. }
  35. at::ScalarType getDType() const {
  36. return dType_;
  37. }
  38. const std::vector<int64_t>& getInputSplitSizes() const {
  39. return inputSplitSizes_;
  40. }
  41. const std::vector<int64_t>& getOutputSplitSizes() const {
  42. return outputSplitSizes_;
  43. }
  44. private:
  45. int rank_{};
  46. std::string columnName_;
  47. int inMessageSize_{};
  48. int outMessageSize_{};
  49. at::ScalarType dType_ = at::kByte;
  50. std::vector<int64_t> inputSplitSizes_;
  51. std::vector<int64_t> outputSplitSizes_;
  52. };
  53. #define RECORD_PARAM_COMMS(rank, colName, inSize, outSize, dType, inSplitSizes, outSplitSizes) \
  54. auto paramCommsInfo = std::make_shared<torch::ParamCommsDebugInfo>( \
  55. rank, \
  56. colName, \
  57. inSize, \
  58. outSize, \
  59. dType, \
  60. inSplitSizes, \
  61. outSplitSizes); \
  62. c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \
  63. RECORD_FUNCTION(torch::kParamCommsCallName, std::vector<c10::IValue>());
  64. } // namespace torch