#pragma once #include #include #include #include #include #include #include #include namespace c10d { inline std::string getTraceStartKey(const std::string& pgName, int rank) { return pgName + "_" + std::to_string(rank) + "_trace_start"; } inline std::string getTraceEndKey(const std::string& pgName, int rank) { return pgName + "_" + std::to_string(rank) + "_trace_end"; } inline bool traceUpdate( c10::intrusive_ptr& store, const std::string& key, uint64_t seq, const std::string& col) { std::vector value(col.size() + sizeof(seq) + 1); memcpy(value.data(), &seq, sizeof(seq)); memcpy(value.data() + sizeof(seq), col.data(), col.size()); try { store->set(key, value); return true; } catch (...) { LOG(ERROR) << "Store is down while updating #" << seq << " with key " << key; return false; } return true; } enum TraceDebugEvent { kEventStart, kEventEnd, }; // >> using TraceMap = std::map>>; inline std::string ranksToString(const std::vector& ranks) { std::string str; for (int rank : ranks) { if (str.empty()) { str = std::to_string(rank); } else { str += ", " + std::to_string(rank); } } return str; } inline std::string ranksFromTrace( const std::vector>& items) { std::string ranks; for (auto& p : items) { if (ranks.empty()) { ranks = std::to_string(p.first); } else { ranks += ", " + std::to_string(p.first); } } return ranks; } inline std::string analyzeMissingRanks(const std::vector& missingRanks) { return c10::str( "\n\t - To our best knowledge, ranks [", ranksToString(missingRanks), "] are the lagging ranks that caused this timeout. " "They never joined any collectives"); } inline std::string analyzeLaggingRanks(const TraceMap& traceMap) { uint64_t lagSeq = traceMap.begin()->first; std::vector startRanks; std::vector endRanks; for (auto& p : traceMap.begin()->second) { if (p.second.second == kEventStart) { startRanks.push_back(p.first); } else { endRanks.push_back(p.first); } } std::string report = "\n\t - To our best knowledge, the lagging/dead/mismatched ranks " "that caused the desync are:"; if (startRanks.size()) { report += c10::str( "\n\t - [", ranksToString(startRanks), "] joined but didn't finish collective #", lagSeq, " (count from 1)"); } if (endRanks.size()) { report += c10::str( "\n\t [", ranksToString(endRanks), "] finished collective #", lagSeq, ", but didn't join collective #", lagSeq + 1, " (count from 1)"); } return report; } inline std::string dumpSnapshot(TraceMap& traceMap) { std::string report = "\n\t - Snapshot of ranks' latest states:"; for (auto& tracePair : traceMap) { uint64_t seq = tracePair.first; std::map>& subMap = tracePair.second; std::unordered_map> collectivesStart; std::unordered_map> collectivesEnd; for (auto& p : subMap) { int rank = p.first; const std::string& col = p.second.first; if (p.second.second == kEventStart) { collectivesStart[col].push_back(rank); } else { collectivesEnd[col].push_back(rank); } } if (collectivesStart.size()) { report += c10::str("\n\t #", seq, " started ranks:"); for (auto& mapPair : collectivesStart) { report += c10::str( "\n\t [", ranksToString(mapPair.second), "] started ", mapPair.first); } } if (collectivesEnd.size()) { report += c10::str("\n\t #", seq, " finished ranks:"); for (auto& mapPair : collectivesEnd) { report += c10::str( "\n\t [", ranksToString(mapPair.second), "] finished ", mapPair.first); } } } return report; } inline bool parseTraceValue( c10::intrusive_ptr& store, const std::string& key, uint64_t& seq, std::string& col) { try { std::vector traceValue = store->get(key); memcpy(&seq, traceValue.data(), sizeof(seq)); std::string colName((char*)traceValue.data() + sizeof(seq)); col = colName; return true; } catch (...) { LOG(ERROR) << "Store is down while getting key " << key; return false; } return true; } inline std::string retrieveDesyncReport( c10::intrusive_ptr& store, const std::string& pgName, int myRank, int worldSize) { std::string report; uint64_t thisSeq; std::string thisCol; std::vector missingRanks; TraceMap traceMap; for (const auto rank : c10::irange(worldSize)) { // Build traceMapStart. uint64_t seqStart; { std::string traceKeyStart = getTraceStartKey(pgName, rank); if (!store->check({traceKeyStart})) { missingRanks.push_back(rank); continue; } std::string col; if (!parseTraceValue(store, traceKeyStart, seqStart, col)) { return report; } traceMap[seqStart].emplace(rank, std::make_pair(col, kEventStart)); if (rank == myRank) { thisSeq = seqStart; thisCol = std::move(col); } } // Build traceMapEnd. { std::string traceKeyEnd = getTraceEndKey(pgName, rank); if (!store->check({traceKeyEnd})) { continue; } uint64_t seq; std::string col; if (!parseTraceValue(store, traceKeyEnd, seq, col)) { return report; } if (seq == seqStart) { traceMap[seq][rank].second = kEventEnd; } } } TORCH_INTERNAL_ASSERT( !missingRanks.empty() || !traceMap.empty(), "Trace shouldn't be empty while enabled GLOO_ASYNC_TIMEOUT_DEBUG"); TORCH_INTERNAL_ASSERT( !thisCol.empty(), "Timeout rank [", myRank, "] must have collective tracking iteam in c10::Store trace"); TORCH_INTERNAL_ASSERT( traceMap[thisSeq][myRank].second == kEventStart, "Timeout rank [", myRank, "] last trace item must be kEventStart. thisSeq = ", thisSeq, ", col = ", thisCol); report += c10::str( "\n\t - [", myRank, "] Timeout at collective: ", thisCol, ", #", thisSeq); if (!missingRanks.empty()) { report += analyzeMissingRanks(missingRanks); } else { report += analyzeLaggingRanks(traceMap); report += dumpSnapshot(traceMap); } return report; } } // namespace c10d