write_callback_test.cc 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452
  1. // Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
  2. // This source code is licensed under both the GPLv2 (found in the
  3. // COPYING file in the root directory) and Apache 2.0 License
  4. // (found in the LICENSE.Apache file in the root directory).
  5. #ifndef ROCKSDB_LITE
  6. #include <atomic>
  7. #include <functional>
  8. #include <string>
  9. #include <utility>
  10. #include <vector>
  11. #include "db/db_impl/db_impl.h"
  12. #include "db/write_callback.h"
  13. #include "port/port.h"
  14. #include "rocksdb/db.h"
  15. #include "rocksdb/write_batch.h"
  16. #include "test_util/sync_point.h"
  17. #include "test_util/testharness.h"
  18. #include "util/random.h"
  19. using std::string;
  20. namespace ROCKSDB_NAMESPACE {
  21. class WriteCallbackTest : public testing::Test {
  22. public:
  23. string dbname;
  24. WriteCallbackTest() {
  25. dbname = test::PerThreadDBPath("write_callback_testdb");
  26. }
  27. };
  28. class WriteCallbackTestWriteCallback1 : public WriteCallback {
  29. public:
  30. bool was_called = false;
  31. Status Callback(DB *db) override {
  32. was_called = true;
  33. // Make sure db is a DBImpl
  34. DBImpl* db_impl = dynamic_cast<DBImpl*> (db);
  35. if (db_impl == nullptr) {
  36. return Status::InvalidArgument("");
  37. }
  38. return Status::OK();
  39. }
  40. bool AllowWriteBatching() override { return true; }
  41. };
  42. class WriteCallbackTestWriteCallback2 : public WriteCallback {
  43. public:
  44. Status Callback(DB* /*db*/) override { return Status::Busy(); }
  45. bool AllowWriteBatching() override { return true; }
  46. };
  47. class MockWriteCallback : public WriteCallback {
  48. public:
  49. bool should_fail_ = false;
  50. bool allow_batching_ = false;
  51. std::atomic<bool> was_called_{false};
  52. MockWriteCallback() {}
  53. MockWriteCallback(const MockWriteCallback& other) {
  54. should_fail_ = other.should_fail_;
  55. allow_batching_ = other.allow_batching_;
  56. was_called_.store(other.was_called_.load());
  57. }
  58. Status Callback(DB* /*db*/) override {
  59. was_called_.store(true);
  60. if (should_fail_) {
  61. return Status::Busy();
  62. } else {
  63. return Status::OK();
  64. }
  65. }
  66. bool AllowWriteBatching() override { return allow_batching_; }
  67. };
  68. TEST_F(WriteCallbackTest, WriteWithCallbackTest) {
  69. struct WriteOP {
  70. WriteOP(bool should_fail = false) { callback_.should_fail_ = should_fail; }
  71. void Put(const string& key, const string& val) {
  72. kvs_.push_back(std::make_pair(key, val));
  73. write_batch_.Put(key, val);
  74. }
  75. void Clear() {
  76. kvs_.clear();
  77. write_batch_.Clear();
  78. callback_.was_called_.store(false);
  79. }
  80. MockWriteCallback callback_;
  81. WriteBatch write_batch_;
  82. std::vector<std::pair<string, string>> kvs_;
  83. };
  84. // In each scenario we'll launch multiple threads to write.
  85. // The size of each array equals to number of threads, and
  86. // each boolean in it denote whether callback of corresponding
  87. // thread should succeed or fail.
  88. std::vector<std::vector<WriteOP>> write_scenarios = {
  89. {true},
  90. {false},
  91. {false, false},
  92. {true, true},
  93. {true, false},
  94. {false, true},
  95. {false, false, false},
  96. {true, true, true},
  97. {false, true, false},
  98. {true, false, true},
  99. {true, false, false, false, false},
  100. {false, false, false, false, true},
  101. {false, false, true, false, true},
  102. };
  103. for (auto& unordered_write : {true, false}) {
  104. for (auto& seq_per_batch : {true, false}) {
  105. for (auto& two_queues : {true, false}) {
  106. for (auto& allow_parallel : {true, false}) {
  107. for (auto& allow_batching : {true, false}) {
  108. for (auto& enable_WAL : {true, false}) {
  109. for (auto& enable_pipelined_write : {true, false}) {
  110. for (auto& write_group : write_scenarios) {
  111. Options options;
  112. options.create_if_missing = true;
  113. options.unordered_write = unordered_write;
  114. options.allow_concurrent_memtable_write = allow_parallel;
  115. options.enable_pipelined_write = enable_pipelined_write;
  116. options.two_write_queues = two_queues;
  117. // Skip unsupported combinations
  118. if (options.enable_pipelined_write && seq_per_batch) {
  119. continue;
  120. }
  121. if (options.enable_pipelined_write && options.two_write_queues) {
  122. continue;
  123. }
  124. if (options.unordered_write &&
  125. !options.allow_concurrent_memtable_write) {
  126. continue;
  127. }
  128. if (options.unordered_write && options.enable_pipelined_write) {
  129. continue;
  130. }
  131. ReadOptions read_options;
  132. DB* db;
  133. DBImpl* db_impl;
  134. DestroyDB(dbname, options);
  135. DBOptions db_options(options);
  136. ColumnFamilyOptions cf_options(options);
  137. std::vector<ColumnFamilyDescriptor> column_families;
  138. column_families.push_back(
  139. ColumnFamilyDescriptor(kDefaultColumnFamilyName, cf_options));
  140. std::vector<ColumnFamilyHandle*> handles;
  141. auto open_s =
  142. DBImpl::Open(db_options, dbname, column_families, &handles,
  143. &db, seq_per_batch, true /* batch_per_txn */);
  144. ASSERT_OK(open_s);
  145. assert(handles.size() == 1);
  146. delete handles[0];
  147. db_impl = dynamic_cast<DBImpl*>(db);
  148. ASSERT_TRUE(db_impl);
  149. // Writers that have called JoinBatchGroup.
  150. std::atomic<uint64_t> threads_joining(0);
  151. // Writers that have linked to the queue
  152. std::atomic<uint64_t> threads_linked(0);
  153. // Writers that pass WriteThread::JoinBatchGroup:Wait sync-point.
  154. std::atomic<uint64_t> threads_verified(0);
  155. std::atomic<uint64_t> seq(db_impl->GetLatestSequenceNumber());
  156. ASSERT_EQ(db_impl->GetLatestSequenceNumber(), 0);
  157. ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
  158. "WriteThread::JoinBatchGroup:Start", [&](void*) {
  159. uint64_t cur_threads_joining = threads_joining.fetch_add(1);
  160. // Wait for the last joined writer to link to the queue.
  161. // In this way the writers link to the queue one by one.
  162. // This allows us to confidently detect the first writer
  163. // who increases threads_linked as the leader.
  164. while (threads_linked.load() < cur_threads_joining) {
  165. }
  166. });
  167. // Verification once writers call JoinBatchGroup.
  168. ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
  169. "WriteThread::JoinBatchGroup:Wait", [&](void* arg) {
  170. uint64_t cur_threads_linked = threads_linked.fetch_add(1);
  171. bool is_leader = false;
  172. bool is_last = false;
  173. // who am i
  174. is_leader = (cur_threads_linked == 0);
  175. is_last = (cur_threads_linked == write_group.size() - 1);
  176. // check my state
  177. auto* writer = reinterpret_cast<WriteThread::Writer*>(arg);
  178. if (is_leader) {
  179. ASSERT_TRUE(writer->state ==
  180. WriteThread::State::STATE_GROUP_LEADER);
  181. } else {
  182. ASSERT_TRUE(writer->state ==
  183. WriteThread::State::STATE_INIT);
  184. }
  185. // (meta test) the first WriteOP should indeed be the first
  186. // and the last should be the last (all others can be out of
  187. // order)
  188. if (is_leader) {
  189. ASSERT_TRUE(writer->callback->Callback(nullptr).ok() ==
  190. !write_group.front().callback_.should_fail_);
  191. } else if (is_last) {
  192. ASSERT_TRUE(writer->callback->Callback(nullptr).ok() ==
  193. !write_group.back().callback_.should_fail_);
  194. }
  195. threads_verified.fetch_add(1);
  196. // Wait here until all verification in this sync-point
  197. // callback finish for all writers.
  198. while (threads_verified.load() < write_group.size()) {
  199. }
  200. });
  201. ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
  202. "WriteThread::JoinBatchGroup:DoneWaiting", [&](void* arg) {
  203. // check my state
  204. auto* writer = reinterpret_cast<WriteThread::Writer*>(arg);
  205. if (!allow_batching) {
  206. // no batching so everyone should be a leader
  207. ASSERT_TRUE(writer->state ==
  208. WriteThread::State::STATE_GROUP_LEADER);
  209. } else if (!allow_parallel) {
  210. ASSERT_TRUE(writer->state ==
  211. WriteThread::State::STATE_COMPLETED ||
  212. (enable_pipelined_write &&
  213. writer->state ==
  214. WriteThread::State::
  215. STATE_MEMTABLE_WRITER_LEADER));
  216. }
  217. });
  218. std::atomic<uint32_t> thread_num(0);
  219. std::atomic<char> dummy_key(0);
  220. // Each write thread create a random write batch and write to DB
  221. // with a write callback.
  222. std::function<void()> write_with_callback_func = [&]() {
  223. uint32_t i = thread_num.fetch_add(1);
  224. Random rnd(i);
  225. // leaders gotta lead
  226. while (i > 0 && threads_verified.load() < 1) {
  227. }
  228. // loser has to lose
  229. while (i == write_group.size() - 1 &&
  230. threads_verified.load() < write_group.size() - 1) {
  231. }
  232. auto& write_op = write_group.at(i);
  233. write_op.Clear();
  234. write_op.callback_.allow_batching_ = allow_batching;
  235. // insert some keys
  236. for (uint32_t j = 0; j < rnd.Next() % 50; j++) {
  237. // grab unique key
  238. char my_key = dummy_key.fetch_add(1);
  239. string skey(5, my_key);
  240. string sval(10, my_key);
  241. write_op.Put(skey, sval);
  242. if (!write_op.callback_.should_fail_ && !seq_per_batch) {
  243. seq.fetch_add(1);
  244. }
  245. }
  246. if (!write_op.callback_.should_fail_ && seq_per_batch) {
  247. seq.fetch_add(1);
  248. }
  249. WriteOptions woptions;
  250. woptions.disableWAL = !enable_WAL;
  251. woptions.sync = enable_WAL;
  252. Status s;
  253. if (seq_per_batch) {
  254. class PublishSeqCallback : public PreReleaseCallback {
  255. public:
  256. PublishSeqCallback(DBImpl* db_impl_in)
  257. : db_impl_(db_impl_in) {}
  258. Status Callback(SequenceNumber last_seq, bool /*not used*/,
  259. uint64_t, size_t /*index*/,
  260. size_t /*total*/) override {
  261. db_impl_->SetLastPublishedSequence(last_seq);
  262. return Status::OK();
  263. }
  264. DBImpl* db_impl_;
  265. } publish_seq_callback(db_impl);
  266. // seq_per_batch requires a natural batch separator or Noop
  267. WriteBatchInternal::InsertNoop(&write_op.write_batch_);
  268. const size_t ONE_BATCH = 1;
  269. s = db_impl->WriteImpl(
  270. woptions, &write_op.write_batch_, &write_op.callback_,
  271. nullptr, 0, false, nullptr, ONE_BATCH,
  272. two_queues ? &publish_seq_callback : nullptr);
  273. } else {
  274. s = db_impl->WriteWithCallback(
  275. woptions, &write_op.write_batch_, &write_op.callback_);
  276. }
  277. if (write_op.callback_.should_fail_) {
  278. ASSERT_TRUE(s.IsBusy());
  279. } else {
  280. ASSERT_OK(s);
  281. }
  282. };
  283. ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
  284. // do all the writes
  285. std::vector<port::Thread> threads;
  286. for (uint32_t i = 0; i < write_group.size(); i++) {
  287. threads.emplace_back(write_with_callback_func);
  288. }
  289. for (auto& t : threads) {
  290. t.join();
  291. }
  292. ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
  293. // check for keys
  294. string value;
  295. for (auto& w : write_group) {
  296. ASSERT_TRUE(w.callback_.was_called_.load());
  297. for (auto& kvp : w.kvs_) {
  298. if (w.callback_.should_fail_) {
  299. ASSERT_TRUE(
  300. db->Get(read_options, kvp.first, &value).IsNotFound());
  301. } else {
  302. ASSERT_OK(db->Get(read_options, kvp.first, &value));
  303. ASSERT_EQ(value, kvp.second);
  304. }
  305. }
  306. }
  307. ASSERT_EQ(seq.load(), db_impl->TEST_GetLastVisibleSequence());
  308. delete db;
  309. DestroyDB(dbname, options);
  310. }
  311. }
  312. }
  313. }
  314. }
  315. }
  316. }
  317. }
  318. }
  319. TEST_F(WriteCallbackTest, WriteCallBackTest) {
  320. Options options;
  321. WriteOptions write_options;
  322. ReadOptions read_options;
  323. string value;
  324. DB* db;
  325. DBImpl* db_impl;
  326. DestroyDB(dbname, options);
  327. options.create_if_missing = true;
  328. Status s = DB::Open(options, dbname, &db);
  329. ASSERT_OK(s);
  330. db_impl = dynamic_cast<DBImpl*> (db);
  331. ASSERT_TRUE(db_impl);
  332. WriteBatch wb;
  333. wb.Put("a", "value.a");
  334. wb.Delete("x");
  335. // Test a simple Write
  336. s = db->Write(write_options, &wb);
  337. ASSERT_OK(s);
  338. s = db->Get(read_options, "a", &value);
  339. ASSERT_OK(s);
  340. ASSERT_EQ("value.a", value);
  341. // Test WriteWithCallback
  342. WriteCallbackTestWriteCallback1 callback1;
  343. WriteBatch wb2;
  344. wb2.Put("a", "value.a2");
  345. s = db_impl->WriteWithCallback(write_options, &wb2, &callback1);
  346. ASSERT_OK(s);
  347. ASSERT_TRUE(callback1.was_called);
  348. s = db->Get(read_options, "a", &value);
  349. ASSERT_OK(s);
  350. ASSERT_EQ("value.a2", value);
  351. // Test WriteWithCallback for a callback that fails
  352. WriteCallbackTestWriteCallback2 callback2;
  353. WriteBatch wb3;
  354. wb3.Put("a", "value.a3");
  355. s = db_impl->WriteWithCallback(write_options, &wb3, &callback2);
  356. ASSERT_NOK(s);
  357. s = db->Get(read_options, "a", &value);
  358. ASSERT_OK(s);
  359. ASSERT_EQ("value.a2", value);
  360. delete db;
  361. DestroyDB(dbname, options);
  362. }
  363. } // namespace ROCKSDB_NAMESPACE
  364. int main(int argc, char** argv) {
  365. ::testing::InitGoogleTest(&argc, argv);
  366. return RUN_ALL_TESTS();
  367. }
  368. #else
  369. #include <stdio.h>
  // Stand-in entry point for ROCKSDB_LITE builds, where the tests above are
  // compiled out: print a skip notice to stderr and exit with status 0.
  int main(int /*argc*/, char** /*argv*/) {
    fputs("SKIPPED as WriteWithCallback is not supported in ROCKSDB_LITE\n",
          stderr);
    return 0;
  }
  375. #endif // !ROCKSDB_LITE