#pragma once #include #include #include #include #include #include namespace torch { extern TORCH_API const std::string kParamCommsCallName; class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase { public: ParamCommsDebugInfo() = default; ParamCommsDebugInfo( int rank, std::string&& colName, int inSize, int outSize, at::ScalarType dType, std::vector inSplitSizes, std::vector outSplitSizes); ~ParamCommsDebugInfo() override = default; int getRank() const { return rank_; } const std::string getColumnName() const { return columnName_; } int getInMessageSize() const { return inMessageSize_; } int getOutMessageSize() const { return outMessageSize_; } at::ScalarType getDType() const { return dType_; } const std::vector& getInputSplitSizes() const { return inputSplitSizes_; } const std::vector& getOutputSplitSizes() const { return outputSplitSizes_; } private: int rank_{}; std::string columnName_; int inMessageSize_{}; int outMessageSize_{}; at::ScalarType dType_ = at::kByte; std::vector inputSplitSizes_; std::vector outputSplitSizes_; }; #define RECORD_PARAM_COMMS(rank, colName, inSize, outSize, dType, inSplitSizes, outSplitSizes) \ auto paramCommsInfo = std::make_shared( \ rank, \ colName, \ inSize, \ outSize, \ dType, \ inSplitSizes, \ outSplitSizes); \ c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \ RECORD_FUNCTION(torch::kParamCommsCallName, std::vector()); } // namespace torch