agg_merge.cc 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. // Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
  2. // This source code is licensed under both the GPLv2 (found in the
  3. // COPYING file in the root directory) and Apache 2.0 License
  4. // (found in the LICENSE.Apache file in the root directory).
  5. #include "rocksdb/utilities/agg_merge.h"
  6. #include <cassert>
  7. #include <deque>
  8. #include <memory>
  9. #include <type_traits>
  10. #include <utility>
  11. #include <vector>
  12. #include "port/lang.h"
  13. #include "port/likely.h"
  14. #include "rocksdb/merge_operator.h"
  15. #include "rocksdb/slice.h"
  16. #include "rocksdb/utilities/options_type.h"
  17. #include "util/coding.h"
  18. #include "utilities/agg_merge/agg_merge_impl.h"
  19. #include "utilities/merge_operators.h"
  20. namespace ROCKSDB_NAMESPACE {
  21. static std::unordered_map<std::string, std::unique_ptr<Aggregator>> func_map;
  22. const std::string kUnnamedFuncName;
  23. const std::string kErrorFuncName = "kErrorFuncName";
  24. Status AddAggregator(const std::string& function_name,
  25. std::unique_ptr<Aggregator>&& agg) {
  26. if (function_name == kErrorFuncName) {
  27. return Status::InvalidArgument(
  28. "Cannot register function name kErrorFuncName");
  29. }
  30. func_map.emplace(function_name, std::move(agg));
  31. return Status::OK();
  32. }
  33. AggMergeOperator::AggMergeOperator() = default;
  34. std::string EncodeAggFuncAndPayloadNoCheck(const Slice& function_name,
  35. const Slice& value) {
  36. std::string result;
  37. PutLengthPrefixedSlice(&result, function_name);
  38. result += value.ToString();
  39. return result;
  40. }
  41. Status EncodeAggFuncAndPayload(const Slice& function_name, const Slice& payload,
  42. std::string& output) {
  43. if (function_name == kErrorFuncName) {
  44. return Status::InvalidArgument("Cannot use error function name");
  45. }
  46. if (function_name != kUnnamedFuncName &&
  47. func_map.find(function_name.ToString()) == func_map.end()) {
  48. return Status::InvalidArgument("Function name not registered");
  49. }
  50. output = EncodeAggFuncAndPayloadNoCheck(function_name, payload);
  51. return Status::OK();
  52. }
  53. bool ExtractAggFuncAndValue(const Slice& op, Slice& func, Slice& value) {
  54. value = op;
  55. return GetLengthPrefixedSlice(&value, &func);
  56. }
  57. bool ExtractList(const Slice& encoded_list, std::vector<Slice>& decoded_list) {
  58. decoded_list.clear();
  59. Slice list_slice = encoded_list;
  60. Slice item;
  61. while (GetLengthPrefixedSlice(&list_slice, &item)) {
  62. decoded_list.push_back(item);
  63. }
  64. return list_slice.empty();
  65. }
  66. class AggMergeOperator::Accumulator {
  67. public:
  68. bool Add(const Slice& op, bool is_partial_aggregation) {
  69. if (ignore_operands_) {
  70. return true;
  71. }
  72. Slice my_func;
  73. Slice my_value;
  74. bool ret = ExtractAggFuncAndValue(op, my_func, my_value);
  75. if (!ret) {
  76. ignore_operands_ = true;
  77. return true;
  78. }
  79. // Determine whether we need to do partial merge.
  80. if (is_partial_aggregation && !my_func.empty()) {
  81. auto f = func_map.find(my_func.ToString());
  82. if (f == func_map.end() || !f->second->DoPartialAggregate()) {
  83. return false;
  84. }
  85. }
  86. if (!func_valid_) {
  87. if (my_func != kUnnamedFuncName) {
  88. func_ = my_func;
  89. func_valid_ = true;
  90. }
  91. } else if (func_ != my_func) {
  92. // User switched aggregation function. Need to aggregate the older
  93. // one first.
  94. // Previous aggreagion can't be done in partial merge
  95. if (is_partial_aggregation) {
  96. func_valid_ = false;
  97. ignore_operands_ = true;
  98. return false;
  99. }
  100. // We could consider stashing an iterator into the hash of aggregators
  101. // to avoid repeated lookups when the aggregator doesn't change.
  102. auto f = func_map.find(func_.ToString());
  103. if (f == func_map.end() || !f->second->Aggregate(values_, scratch_)) {
  104. func_valid_ = false;
  105. ignore_operands_ = true;
  106. return true;
  107. }
  108. std::swap(scratch_, aggregated_);
  109. values_.clear();
  110. values_.emplace_back(aggregated_);
  111. func_ = my_func;
  112. }
  113. values_.push_back(my_value);
  114. return true;
  115. }
  116. // Return false if aggregation fails.
  117. // One possible reason
  118. bool GetResult(std::string& result) {
  119. if (!func_valid_) {
  120. return false;
  121. }
  122. auto f = func_map.find(func_.ToString());
  123. if (f == func_map.end()) {
  124. return false;
  125. }
  126. if (!f->second->Aggregate(values_, scratch_)) {
  127. return false;
  128. }
  129. result = EncodeAggFuncAndPayloadNoCheck(func_, scratch_);
  130. return true;
  131. }
  132. void Clear() {
  133. func_.clear();
  134. values_.clear();
  135. aggregated_.clear();
  136. scratch_.clear();
  137. ignore_operands_ = false;
  138. func_valid_ = false;
  139. }
  140. private:
  141. Slice func_;
  142. std::vector<Slice> values_;
  143. std::string aggregated_;
  144. std::string scratch_;
  145. bool ignore_operands_ = false;
  146. bool func_valid_ = false;
  147. };
  148. // Creating and using a new Accumulator might invoke multiple malloc and is
  149. // expensive if it needs to be done when processing each merge operation.
  150. // AggMergeOperator's merge operators can be invoked concurrently by multiple
  151. // threads so we cannot simply create one Aggregator and reuse.
  152. // We use thread local instances instead.
  153. AggMergeOperator::Accumulator& AggMergeOperator::GetTLSAccumulator() {
  154. static thread_local Accumulator tls_acc;
  155. tls_acc.Clear();
  156. return tls_acc;
  157. }
  158. void AggMergeOperator::PackAllMergeOperands(const MergeOperationInput& merge_in,
  159. MergeOperationOutput& merge_out) {
  160. merge_out.new_value = "";
  161. PutLengthPrefixedSlice(&merge_out.new_value, kErrorFuncName);
  162. if (merge_in.existing_value != nullptr) {
  163. PutLengthPrefixedSlice(&merge_out.new_value, *merge_in.existing_value);
  164. }
  165. for (const Slice& op : merge_in.operand_list) {
  166. PutLengthPrefixedSlice(&merge_out.new_value, op);
  167. }
  168. }
  169. bool AggMergeOperator::FullMergeV2(const MergeOperationInput& merge_in,
  170. MergeOperationOutput* merge_out) const {
  171. Accumulator& agg = GetTLSAccumulator();
  172. if (merge_in.existing_value != nullptr) {
  173. agg.Add(*merge_in.existing_value, /*is_partial_aggregation=*/false);
  174. }
  175. for (const Slice& e : merge_in.operand_list) {
  176. agg.Add(e, /*is_partial_aggregation=*/false);
  177. }
  178. bool succ = agg.GetResult(merge_out->new_value);
  179. if (!succ) {
  180. // If aggregation can't happen, pack all merge operands. In contrast to
  181. // merge operator, we don't want to fail the DB. If users insert wrong
  182. // format or call unregistered an aggregation function, we still hope
  183. // the DB can continue functioning with other keys.
  184. PackAllMergeOperands(merge_in, *merge_out);
  185. }
  186. agg.Clear();
  187. return true;
  188. }
  189. bool AggMergeOperator::PartialMergeMulti(const Slice& /*key*/,
  190. const std::deque<Slice>& operand_list,
  191. std::string* new_value,
  192. Logger* /*logger*/) const {
  193. Accumulator& agg = GetTLSAccumulator();
  194. bool do_aggregation = true;
  195. for (const Slice& item : operand_list) {
  196. do_aggregation = agg.Add(item, /*is_partial_aggregation=*/true);
  197. if (!do_aggregation) {
  198. break;
  199. }
  200. }
  201. if (do_aggregation) {
  202. do_aggregation = agg.GetResult(*new_value);
  203. }
  204. agg.Clear();
  205. return do_aggregation;
  206. }
  207. std::shared_ptr<MergeOperator> GetAggMergeOperator() {
  208. STATIC_AVOID_DESTRUCTION(std::shared_ptr<MergeOperator>, instance)
  209. (std::make_shared<AggMergeOperator>());
  210. assert(instance);
  211. return instance;
  212. }
  213. } // namespace ROCKSDB_NAMESPACE