// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
// This source code is licensed under both the GPLv2 (found in the
// COPYING file in the root directory) and Apache 2.0 License
// (found in the LICENSE.Apache file in the root directory).

#include "db/write_callback.h"

#include <atomic>
#include <functional>
#include <string>
#include <utility>
#include <vector>

#include "db/db_impl/db_impl.h"
#include "db/write_batch_internal.h"
#include "port/port.h"
#include "rocksdb/db.h"
#include "rocksdb/user_write_callback.h"
#include "rocksdb/write_batch.h"
#include "test_util/sync_point.h"
#include "test_util/testharness.h"
#include "util/random.h"

using std::string;

namespace ROCKSDB_NAMESPACE {
class WriteCallbackTest : public testing::Test {
 public:
  string dbname;

  WriteCallbackTest() {
    dbname = test::PerThreadDBPath("write_callback_testdb");
  }
};
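
// A simple callback that records that it was invoked and verifies that the
// DB* passed to Callback() is actually a DBImpl. It always allows the write
// to proceed.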
class WriteCallbackTestWriteCallback1 : public WriteCallback {
 public:
  bool was_called = false;

  Status Callback(DB* db) override {
    was_called = true;

    // Make sure db is a DBImpl
    DBImpl* db_impl = dynamic_cast<DBImpl*>(db);
    if (db_impl == nullptr) {
      return Status::InvalidArgument("");
    }

    return Status::OK();
  }

  bool AllowWriteBatching() override { return true; }
};
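
// A callback that unconditionally rejects the write with Status::Busy().
// Writes gated by this callback must not become visible in the DB.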
class WriteCallbackTestWriteCallback2 : public WriteCallback {
 public:
  Status Callback(DB* /*db*/) override { return Status::Busy(); }
  bool AllowWriteBatching() override { return true; }
};
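
// A configurable callback for the parameterized test below: should_fail_
// selects whether Callback() returns Busy, and allow_batching_ controls
// whether this write may be grouped with other writers.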
class MockWriteCallback : public WriteCallback {
 public:
  bool should_fail_ = false;
  bool allow_batching_ = false;
  std::atomic<bool> was_called_{false};

  MockWriteCallback() = default;

  MockWriteCallback(const MockWriteCallback& other) {
    should_fail_ = other.should_fail_;
    allow_batching_ = other.allow_batching_;
    was_called_.store(other.was_called_.load());
  }

  Status Callback(DB* /*db*/) override {
    was_called_.store(true);
    if (should_fail_) {
      return Status::Busy();
    } else {
      return Status::OK();
    }
  }

  bool AllowWriteBatching() override { return allow_batching_; }
};
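
// Records the two UserWriteCallback notifications: OnWriteEnqueued() fires
// when the write joins the write queue, and OnWalWriteFinish() fires once
// the write's WAL write has completed.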
class MockUserWriteCallback : public UserWriteCallback {
 public:
  std::atomic<bool> write_enqueued_{false};
  std::atomic<bool> wal_write_done_{false};

  MockUserWriteCallback() = default;

  MockUserWriteCallback(const MockUserWriteCallback& other) {
    write_enqueued_.store(other.write_enqueued_.load());
    wal_write_done_.store(other.wal_write_done_.load());
  }

  void OnWriteEnqueued() override { write_enqueued_.store(true); }

  void OnWalWriteFinish() override { wal_write_done_.store(true); }

  void Reset() {
    write_enqueued_.store(false);
    wal_write_done_.store(false);
  }
};
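
// The parameterized sweep below runs 2^7 option combinations over thirteen
// write scenarios, so it is compiled out under plain valgrind runs and only
// kept when ROCKSDB_FULL_VALGRIND_RUN is defined.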
#if !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
class WriteCallbackPTest
    : public WriteCallbackTest,
      public ::testing::WithParamInterface<
          std::tuple<bool, bool, bool, bool, bool, bool, bool>> {
 public:
  WriteCallbackPTest() {
    std::tie(unordered_write_, seq_per_batch_, two_queues_, allow_parallel_,
             allow_batching_, enable_WAL_, enable_pipelined_write_) =
        GetParam();
  }

 protected:
  bool unordered_write_;
  bool seq_per_batch_;
  bool two_queues_;
  bool allow_parallel_;
  bool allow_batching_;
  bool enable_WAL_;
  bool enable_pipelined_write_;
};
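
// Launches one writer thread per WriteOP in a scenario, with sync points
// choreographing which writer becomes the write-group leader, and then
// verifies that writes whose callbacks fail leave no trace in the DB while
// the others become readable and the final sequence number matches.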
TEST_P(WriteCallbackPTest, WriteWithCallbackTest) {
  struct WriteOP {
    WriteOP(bool should_fail = false) {
      callback_.should_fail_ = should_fail;
    }

    void Put(const string& key, const string& val) {
      kvs_.emplace_back(key, val);
      ASSERT_OK(write_batch_.Put(key, val));
    }

    void Clear() {
      kvs_.clear();
      write_batch_.Clear();
      callback_.was_called_.store(false);
      user_write_cb_.Reset();
    }

    MockWriteCallback callback_;
    MockUserWriteCallback user_write_cb_;
    WriteBatch write_batch_;
    std::vector<std::pair<string, string>> kvs_;
  };
  // In each scenario we'll launch multiple threads to write.
  // The size of each vector equals the number of threads, and each
  // boolean in it denotes whether the callback of the corresponding
  // thread should fail (true) or succeed (false).
  std::vector<std::vector<WriteOP>> write_scenarios = {
      {true},
      {false},
      {false, false},
      {true, true},
      {true, false},
      {false, true},
      {false, false, false},
      {true, true, true},
      {false, true, false},
      {true, false, true},
      {true, false, false, false, false},
      {false, false, false, false, true},
      {false, false, true, false, true},
  };
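
  // For example, {true, false} launches two writer threads: thread 0's
  // callback fails, so its keys must not be found afterwards, while
  // thread 1's callback succeeds and its keys must be readable.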
  for (auto& write_group : write_scenarios) {
    Options options;
    options.create_if_missing = true;
    options.unordered_write = unordered_write_;
    options.allow_concurrent_memtable_write = allow_parallel_;
    options.enable_pipelined_write = enable_pipelined_write_;
    options.two_write_queues = two_queues_;

    // Skip unsupported combinations
    if (options.enable_pipelined_write && seq_per_batch_) {
      continue;
    }
    if (options.enable_pipelined_write && options.two_write_queues) {
      continue;
    }
    if (options.unordered_write && !options.allow_concurrent_memtable_write) {
      continue;
    }
    if (options.unordered_write && options.enable_pipelined_write) {
      continue;
    }
    ReadOptions read_options;
    std::unique_ptr<DB> db;
    DBImpl* db_impl;

    ASSERT_OK(DestroyDB(dbname, options));

    DBOptions db_options(options);
    ColumnFamilyOptions cf_options(options);
    std::vector<ColumnFamilyDescriptor> column_families;
    column_families.emplace_back(kDefaultColumnFamilyName, cf_options);
    std::vector<ColumnFamilyHandle*> handles;
    // Open through DBImpl::Open so that seq_per_batch can be specified.
    auto open_s = DBImpl::Open(db_options, dbname, column_families, &handles,
                               &db, seq_per_batch_, true /* batch_per_txn */,
                               false /* is_retry */, nullptr /* can_retry */);
    ASSERT_OK(open_s);
    assert(handles.size() == 1);
    delete handles[0];

    db_impl = dynamic_cast<DBImpl*>(db.get());
    ASSERT_TRUE(db_impl);
    // Writers that have called JoinBatchGroup.
    std::atomic<uint64_t> threads_joining(0);
    // Writers that have linked to the queue.
    std::atomic<uint64_t> threads_linked(0);
    // Writers that pass the WriteThread::JoinBatchGroup:Wait sync-point.
    std::atomic<uint64_t> threads_verified(0);

    std::atomic<uint64_t> seq(db_impl->GetLatestSequenceNumber());
    ASSERT_EQ(db_impl->GetLatestSequenceNumber(), 0);

    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
        "WriteThread::JoinBatchGroup:Start", [&](void*) {
          uint64_t cur_threads_joining = threads_joining.fetch_add(1);
          // Wait for the last joined writer to link to the queue.
          // In this way the writers link to the queue one by one.
          // This allows us to confidently detect the first writer
          // who increases threads_linked as the leader.
          while (threads_linked.load() < cur_threads_joining) {
          }
        });
    // Verification once writers call JoinBatchGroup.
    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
        "WriteThread::JoinBatchGroup:Wait", [&](void* arg) {
          uint64_t cur_threads_linked = threads_linked.fetch_add(1);
          bool is_leader = false;
          bool is_last = false;

          // who am i
          is_leader = (cur_threads_linked == 0);
          is_last = (cur_threads_linked == write_group.size() - 1);

          // check my state
          auto* writer = static_cast<WriteThread::Writer*>(arg);

          if (is_leader) {
            ASSERT_TRUE(writer->state ==
                        WriteThread::State::STATE_GROUP_LEADER);
          } else {
            ASSERT_TRUE(writer->state == WriteThread::State::STATE_INIT);
          }

          // (meta test) the first WriteOP should indeed be the first
          // and the last should be the last (all others can be out of
          // order)
          if (is_leader) {
            ASSERT_TRUE(writer->callback->Callback(nullptr).ok() ==
                        !write_group.front().callback_.should_fail_);
          } else if (is_last) {
            ASSERT_TRUE(writer->callback->Callback(nullptr).ok() ==
                        !write_group.back().callback_.should_fail_);
          }

          threads_verified.fetch_add(1);
          // Wait here until all verification in this sync-point
          // callback finishes for all writers.
          while (threads_verified.load() < write_group.size()) {
          }
        });
    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
        "WriteThread::JoinBatchGroup:DoneWaiting", [&](void* arg) {
          // check my state
          auto* writer = static_cast<WriteThread::Writer*>(arg);

          if (!allow_batching_) {
            // no batching, so everyone should be a leader
            ASSERT_TRUE(writer->state ==
                        WriteThread::State::STATE_GROUP_LEADER);
          } else if (!allow_parallel_) {
            ASSERT_TRUE(writer->state == WriteThread::State::STATE_COMPLETED ||
                        (enable_pipelined_write_ &&
                         writer->state ==
                             WriteThread::State::STATE_MEMTABLE_WRITER_LEADER));
          }
        });
    std::atomic<uint32_t> thread_num(0);
    std::atomic<char> dummy_key(0);

    // Each write thread creates a random write batch and writes it to
    // the DB with a write callback.
    std::function<void()> write_with_callback_func = [&]() {
      uint32_t i = thread_num.fetch_add(1);
      Random rnd(i);

      // leaders gotta lead
      while (i > 0 && threads_verified.load() < 1) {
      }

      // loser has to lose
      while (i == write_group.size() - 1 &&
             threads_verified.load() < write_group.size() - 1) {
      }

      auto& write_op = write_group.at(i);
      write_op.Clear();
      write_op.callback_.allow_batching_ = allow_batching_;

      // insert some keys
      for (uint32_t j = 0; j < rnd.Next() % 50; j++) {
        // grab unique key
        char my_key = dummy_key.fetch_add(1);

        string skey(5, my_key);
        string sval(10, my_key);
        write_op.Put(skey, sval);

        // Without seq_per_batch, every successful key consumes one
        // sequence number.
        if (!write_op.callback_.should_fail_ && !seq_per_batch_) {
          seq.fetch_add(1);
        }
      }
      // With seq_per_batch, a successful batch consumes exactly one
      // sequence number regardless of its size.
      if (!write_op.callback_.should_fail_ && seq_per_batch_) {
        seq.fetch_add(1);
      }
      WriteOptions woptions;
      woptions.disableWAL = !enable_WAL_;
      woptions.sync = enable_WAL_;
      if (woptions.protection_bytes_per_key > 0) {
        ASSERT_OK(WriteBatchInternal::UpdateProtectionInfo(
            &write_op.write_batch_, woptions.protection_bytes_per_key));
      }
      Status s;
      if (seq_per_batch_) {
        class PublishSeqCallback : public PreReleaseCallback {
         public:
          PublishSeqCallback(DBImpl* db_impl_in) : db_impl_(db_impl_in) {}
          Status Callback(SequenceNumber last_seq, bool /*not used*/, uint64_t,
                          size_t /*index*/, size_t /*total*/) override {
            db_impl_->SetLastPublishedSequence(last_seq);
            return Status::OK();
          }
          DBImpl* db_impl_;
        } publish_seq_callback(db_impl);
        // seq_per_batch_ requires a natural batch separator or Noop
        ASSERT_OK(WriteBatchInternal::InsertNoop(&write_op.write_batch_));
        const size_t ONE_BATCH = 1;
        s = db_impl->WriteImpl(woptions, &write_op.write_batch_,
                               &write_op.callback_, &write_op.user_write_cb_,
                               nullptr, 0, false, nullptr, ONE_BATCH,
                               two_queues_ ? &publish_seq_callback : nullptr);
      } else {
        s = db_impl->WriteWithCallback(woptions, &write_op.write_batch_,
                                       &write_op.callback_,
                                       &write_op.user_write_cb_);
      }
      // The enqueued notification fires for every write; the WAL-finish
      // notification fires only for successful writes that went to the WAL.
      ASSERT_TRUE(write_op.user_write_cb_.write_enqueued_.load());
      if (write_op.callback_.should_fail_) {
        ASSERT_TRUE(s.IsBusy());
        ASSERT_FALSE(write_op.user_write_cb_.wal_write_done_.load());
      } else {
        ASSERT_OK(s);
        if (enable_WAL_) {
          ASSERT_TRUE(write_op.user_write_cb_.wal_write_done_.load());
        } else {
          ASSERT_FALSE(write_op.user_write_cb_.wal_write_done_.load());
        }
      }
    };
    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();

    // do all the writes
    std::vector<port::Thread> threads;
    for (uint32_t i = 0; i < write_group.size(); i++) {
      threads.emplace_back(write_with_callback_func);
    }
    for (auto& t : threads) {
      t.join();
    }

    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
    // check for keys
    string value;
    for (auto& w : write_group) {
      ASSERT_TRUE(w.callback_.was_called_.load());
      for (auto& kvp : w.kvs_) {
        if (w.callback_.should_fail_) {
          ASSERT_TRUE(db->Get(read_options, kvp.first, &value).IsNotFound());
        } else {
          ASSERT_OK(db->Get(read_options, kvp.first, &value));
          ASSERT_EQ(value, kvp.second);
        }
      }
    }

    ASSERT_EQ(seq.load(), db_impl->TEST_GetLastVisibleSequence());

    db.reset();
    ASSERT_OK(DestroyDB(dbname, options));
  }
}
INSTANTIATE_TEST_CASE_P(WriteCallbackPTest, WriteCallbackPTest,
                        ::testing::Combine(::testing::Bool(), ::testing::Bool(),
                                           ::testing::Bool(), ::testing::Bool(),
                                           ::testing::Bool(), ::testing::Bool(),
                                           ::testing::Bool()));
#endif  // !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
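
// Single-threaded smoke test: a plain Write, a WriteWithCallback whose
// callback succeeds, one whose callback fails (the write must not apply),
// and the user write callback notifications on the public API.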
TEST_F(WriteCallbackTest, WriteCallBackTest) {
  Options options;
  WriteOptions write_options;
  ReadOptions read_options;
  string value;
  DB* db;
  DBImpl* db_impl;

  ASSERT_OK(DestroyDB(dbname, options));

  options.create_if_missing = true;
  Status s = DB::Open(options, dbname, &db);
  ASSERT_OK(s);

  db_impl = dynamic_cast<DBImpl*>(db);
  ASSERT_TRUE(db_impl);

  WriteBatch wb;
  ASSERT_OK(wb.Put("a", "value.a"));
  ASSERT_OK(wb.Delete("x"));

  // Test a simple Write
  s = db->Write(write_options, &wb);
  ASSERT_OK(s);

  s = db->Get(read_options, "a", &value);
  ASSERT_OK(s);
  ASSERT_EQ("value.a", value);

  // Test WriteWithCallback
  WriteCallbackTestWriteCallback1 callback1;
  WriteBatch wb2;
  ASSERT_OK(wb2.Put("a", "value.a2"));

  s = db_impl->WriteWithCallback(write_options, &wb2, &callback1);
  ASSERT_OK(s);
  ASSERT_TRUE(callback1.was_called);

  s = db->Get(read_options, "a", &value);
  ASSERT_OK(s);
  ASSERT_EQ("value.a2", value);

  // Test WriteWithCallback for a callback that fails
  WriteCallbackTestWriteCallback2 callback2;
  WriteBatch wb3;
  ASSERT_OK(wb3.Put("a", "value.a3"));

  s = db_impl->WriteWithCallback(write_options, &wb3, &callback2);
  ASSERT_NOK(s);

  s = db->Get(read_options, "a", &value);
  ASSERT_OK(s);
  ASSERT_EQ("value.a2", value);

  // Test user write callback notifications on the public API: both should
  // fire for a successful write that goes through the WAL.
  MockUserWriteCallback user_write_cb;
  WriteBatch wb4;
  ASSERT_OK(wb4.Put("a", "value.a4"));
  ASSERT_OK(db->WriteWithCallback(write_options, &wb4, &user_write_cb));
  ASSERT_OK(db->Get(read_options, "a", &value));
  ASSERT_EQ(value, "value.a4");
  ASSERT_TRUE(user_write_cb.write_enqueued_.load());
  ASSERT_TRUE(user_write_cb.wal_write_done_.load());

  delete db;
  ASSERT_OK(DestroyDB(dbname, options));
}
}  // namespace ROCKSDB_NAMESPACE

int main(int argc, char** argv) {
  ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}