| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237 |
- // Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
- // This source code is licensed under both the GPLv2 (found in the
- // COPYING file in the root directory) and Apache 2.0 License
- // (found in the LICENSE.Apache file in the root directory).
- #include "rocksdb/utilities/agg_merge.h"
- #include <cassert>
- #include <deque>
- #include <memory>
- #include <type_traits>
- #include <utility>
- #include <vector>
- #include "port/lang.h"
- #include "port/likely.h"
- #include "rocksdb/merge_operator.h"
- #include "rocksdb/slice.h"
- #include "rocksdb/utilities/options_type.h"
- #include "util/coding.h"
- #include "utilities/agg_merge/agg_merge_impl.h"
- #include "utilities/merge_operators.h"
- namespace ROCKSDB_NAMESPACE {
// Registry of aggregation functions keyed by function name. It is filled by
// AddAggregator() and consulted by EncodeAggFuncAndPayload() and by the merge
// Accumulator.
// NOTE(review): accessed without any lock; presumably all registration happens
// before merges run concurrently — confirm.
static std::unordered_map<std::string, std::unique_ptr<Aggregator>> func_map;
// Function name used for operands that do not name an aggregation function
// (a default-constructed, i.e. empty, string).
const std::string kUnnamedFuncName;
// Reserved function name that tags values whose inputs could not be
// aggregated and were packed as-is (see PackAllMergeOperands). Both
// AddAggregator() and EncodeAggFuncAndPayload() refuse it.
const std::string kErrorFuncName = "kErrorFuncName";
- Status AddAggregator(const std::string& function_name,
- std::unique_ptr<Aggregator>&& agg) {
- if (function_name == kErrorFuncName) {
- return Status::InvalidArgument(
- "Cannot register function name kErrorFuncName");
- }
- func_map.emplace(function_name, std::move(agg));
- return Status::OK();
- }
// Defaulted constructor, defined out of line. Per-merge working state lives in
// the thread-local Accumulator obtained by the merge methods, not here.
AggMergeOperator::AggMergeOperator() = default;
- std::string EncodeAggFuncAndPayloadNoCheck(const Slice& function_name,
- const Slice& value) {
- std::string result;
- PutLengthPrefixedSlice(&result, function_name);
- result += value.ToString();
- return result;
- }
- Status EncodeAggFuncAndPayload(const Slice& function_name, const Slice& payload,
- std::string& output) {
- if (function_name == kErrorFuncName) {
- return Status::InvalidArgument("Cannot use error function name");
- }
- if (function_name != kUnnamedFuncName &&
- func_map.find(function_name.ToString()) == func_map.end()) {
- return Status::InvalidArgument("Function name not registered");
- }
- output = EncodeAggFuncAndPayloadNoCheck(function_name, payload);
- return Status::OK();
- }
- bool ExtractAggFuncAndValue(const Slice& op, Slice& func, Slice& value) {
- value = op;
- return GetLengthPrefixedSlice(&value, &func);
- }
- bool ExtractList(const Slice& encoded_list, std::vector<Slice>& decoded_list) {
- decoded_list.clear();
- Slice list_slice = encoded_list;
- Slice item;
- while (GetLengthPrefixedSlice(&list_slice, &item)) {
- decoded_list.push_back(item);
- }
- return list_slice.empty();
- }
- class AggMergeOperator::Accumulator {
- public:
- bool Add(const Slice& op, bool is_partial_aggregation) {
- if (ignore_operands_) {
- return true;
- }
- Slice my_func;
- Slice my_value;
- bool ret = ExtractAggFuncAndValue(op, my_func, my_value);
- if (!ret) {
- ignore_operands_ = true;
- return true;
- }
- // Determine whether we need to do partial merge.
- if (is_partial_aggregation && !my_func.empty()) {
- auto f = func_map.find(my_func.ToString());
- if (f == func_map.end() || !f->second->DoPartialAggregate()) {
- return false;
- }
- }
- if (!func_valid_) {
- if (my_func != kUnnamedFuncName) {
- func_ = my_func;
- func_valid_ = true;
- }
- } else if (func_ != my_func) {
- // User switched aggregation function. Need to aggregate the older
- // one first.
- // Previous aggreagion can't be done in partial merge
- if (is_partial_aggregation) {
- func_valid_ = false;
- ignore_operands_ = true;
- return false;
- }
- // We could consider stashing an iterator into the hash of aggregators
- // to avoid repeated lookups when the aggregator doesn't change.
- auto f = func_map.find(func_.ToString());
- if (f == func_map.end() || !f->second->Aggregate(values_, scratch_)) {
- func_valid_ = false;
- ignore_operands_ = true;
- return true;
- }
- std::swap(scratch_, aggregated_);
- values_.clear();
- values_.emplace_back(aggregated_);
- func_ = my_func;
- }
- values_.push_back(my_value);
- return true;
- }
- // Return false if aggregation fails.
- // One possible reason
- bool GetResult(std::string& result) {
- if (!func_valid_) {
- return false;
- }
- auto f = func_map.find(func_.ToString());
- if (f == func_map.end()) {
- return false;
- }
- if (!f->second->Aggregate(values_, scratch_)) {
- return false;
- }
- result = EncodeAggFuncAndPayloadNoCheck(func_, scratch_);
- return true;
- }
- void Clear() {
- func_.clear();
- values_.clear();
- aggregated_.clear();
- scratch_.clear();
- ignore_operands_ = false;
- func_valid_ = false;
- }
- private:
- Slice func_;
- std::vector<Slice> values_;
- std::string aggregated_;
- std::string scratch_;
- bool ignore_operands_ = false;
- bool func_valid_ = false;
- };
- // Creating and using a new Accumulator might invoke multiple malloc and is
- // expensive if it needs to be done when processing each merge operation.
- // AggMergeOperator's merge operators can be invoked concurrently by multiple
- // threads so we cannot simply create one Aggregator and reuse.
- // We use thread local instances instead.
- AggMergeOperator::Accumulator& AggMergeOperator::GetTLSAccumulator() {
- static thread_local Accumulator tls_acc;
- tls_acc.Clear();
- return tls_acc;
- }
- void AggMergeOperator::PackAllMergeOperands(const MergeOperationInput& merge_in,
- MergeOperationOutput& merge_out) {
- merge_out.new_value = "";
- PutLengthPrefixedSlice(&merge_out.new_value, kErrorFuncName);
- if (merge_in.existing_value != nullptr) {
- PutLengthPrefixedSlice(&merge_out.new_value, *merge_in.existing_value);
- }
- for (const Slice& op : merge_in.operand_list) {
- PutLengthPrefixedSlice(&merge_out.new_value, op);
- }
- }
- bool AggMergeOperator::FullMergeV2(const MergeOperationInput& merge_in,
- MergeOperationOutput* merge_out) const {
- Accumulator& agg = GetTLSAccumulator();
- if (merge_in.existing_value != nullptr) {
- agg.Add(*merge_in.existing_value, /*is_partial_aggregation=*/false);
- }
- for (const Slice& e : merge_in.operand_list) {
- agg.Add(e, /*is_partial_aggregation=*/false);
- }
- bool succ = agg.GetResult(merge_out->new_value);
- if (!succ) {
- // If aggregation can't happen, pack all merge operands. In contrast to
- // merge operator, we don't want to fail the DB. If users insert wrong
- // format or call unregistered an aggregation function, we still hope
- // the DB can continue functioning with other keys.
- PackAllMergeOperands(merge_in, *merge_out);
- }
- agg.Clear();
- return true;
- }
- bool AggMergeOperator::PartialMergeMulti(const Slice& /*key*/,
- const std::deque<Slice>& operand_list,
- std::string* new_value,
- Logger* /*logger*/) const {
- Accumulator& agg = GetTLSAccumulator();
- bool do_aggregation = true;
- for (const Slice& item : operand_list) {
- do_aggregation = agg.Add(item, /*is_partial_aggregation=*/true);
- if (!do_aggregation) {
- break;
- }
- }
- if (do_aggregation) {
- do_aggregation = agg.GetResult(*new_value);
- }
- agg.Clear();
- return do_aggregation;
- }
- std::shared_ptr<MergeOperator> GetAggMergeOperator() {
- STATIC_AVOID_DESTRUCTION(std::shared_ptr<MergeOperator>, instance)
- (std::make_shared<AggMergeOperator>());
- assert(instance);
- return instance;
- }
- } // namespace ROCKSDB_NAMESPACE
|