thread_local_test.cc 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580
  1. // Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
  2. // This source code is licensed under both the GPLv2 (found in the
  3. // COPYING file in the root directory) and Apache 2.0 License
  4. // (found in the LICENSE.Apache file in the root directory).
  #include <atomic>
  #include <iostream>
  #include <memory>
  #include <string>
  #include <system_error>
  #include <thread>

  #include "port/port.h"
  #include "rocksdb/env.h"
  #include "test_util/sync_point.h"
  #include "test_util/testharness.h"
  #include "test_util/testutil.h"
  #include "util/autovector.h"
  #include "util/thread_local.h"
  15. namespace ROCKSDB_NAMESPACE {
  16. class ThreadLocalTest : public testing::Test {
  17. public:
  18. ThreadLocalTest() : env_(Env::Default()) {}
  19. Env* env_;
  20. };
  21. namespace {
  22. struct Params {
  23. Params(port::Mutex* m, port::CondVar* c, int* u, int n,
  24. UnrefHandler handler = nullptr)
  25. : mu(m),
  26. cv(c),
  27. unref(u),
  28. total(n),
  29. started(0),
  30. completed(0),
  31. doWrite(false),
  32. tls1(handler),
  33. tls2(nullptr) {}
  34. port::Mutex* mu;
  35. port::CondVar* cv;
  36. int* unref;
  37. int total;
  38. int started;
  39. int completed;
  40. bool doWrite;
  41. ThreadLocalPtr tls1;
  42. ThreadLocalPtr* tls2;
  43. };
  44. class IDChecker : public ThreadLocalPtr {
  45. public:
  46. static uint32_t PeekId() {
  47. return TEST_PeekId();
  48. }
  49. };
  50. } // anonymous namespace
  51. // Suppress false positive clang analyzer warnings.
  52. #ifndef __clang_analyzer__
  53. TEST_F(ThreadLocalTest, UniqueIdTest) {
  54. port::Mutex mu;
  55. port::CondVar cv(&mu);
  56. uint32_t base_id = IDChecker::PeekId();
  57. // New ThreadLocal instance bumps id by 1
  58. {
  59. // Id used 0
  60. Params p1(&mu, &cv, nullptr, 1u);
  61. ASSERT_EQ(IDChecker::PeekId(), base_id + 1u);
  62. // Id used 1
  63. Params p2(&mu, &cv, nullptr, 1u);
  64. ASSERT_EQ(IDChecker::PeekId(), base_id + 2u);
  65. // Id used 2
  66. Params p3(&mu, &cv, nullptr, 1u);
  67. ASSERT_EQ(IDChecker::PeekId(), base_id + 3u);
  68. // Id used 3
  69. Params p4(&mu, &cv, nullptr, 1u);
  70. ASSERT_EQ(IDChecker::PeekId(), base_id + 4u);
  71. }
  72. // id 3, 2, 1, 0 are in the free queue in order
  73. ASSERT_EQ(IDChecker::PeekId(), base_id + 0u);
  74. // pick up 0
  75. Params p1(&mu, &cv, nullptr, 1u);
  76. ASSERT_EQ(IDChecker::PeekId(), base_id + 1u);
  77. // pick up 1
  78. Params* p2 = new Params(&mu, &cv, nullptr, 1u);
  79. ASSERT_EQ(IDChecker::PeekId(), base_id + 2u);
  80. // pick up 2
  81. Params p3(&mu, &cv, nullptr, 1u);
  82. ASSERT_EQ(IDChecker::PeekId(), base_id + 3u);
  83. // return up 1
  84. delete p2;
  85. ASSERT_EQ(IDChecker::PeekId(), base_id + 1u);
  86. // Now we have 3, 1 in queue
  87. // pick up 1
  88. Params p4(&mu, &cv, nullptr, 1u);
  89. ASSERT_EQ(IDChecker::PeekId(), base_id + 3u);
  90. // pick up 3
  91. Params p5(&mu, &cv, nullptr, 1u);
  92. // next new id
  93. ASSERT_EQ(IDChecker::PeekId(), base_id + 4u);
  94. // After exit, id sequence in queue:
  95. // 3, 1, 2, 0
  96. }
  97. #endif // __clang_analyzer__
  98. TEST_F(ThreadLocalTest, SequentialReadWriteTest) {
  99. // global id list carries over 3, 1, 2, 0
  100. uint32_t base_id = IDChecker::PeekId();
  101. port::Mutex mu;
  102. port::CondVar cv(&mu);
  103. Params p(&mu, &cv, nullptr, 1);
  104. ThreadLocalPtr tls2;
  105. p.tls2 = &tls2;
  106. auto func = [](void* ptr) {
  107. auto& params = *static_cast<Params*>(ptr);
  108. ASSERT_TRUE(params.tls1.Get() == nullptr);
  109. params.tls1.Reset(reinterpret_cast<int*>(1));
  110. ASSERT_TRUE(params.tls1.Get() == reinterpret_cast<int*>(1));
  111. params.tls1.Reset(reinterpret_cast<int*>(2));
  112. ASSERT_TRUE(params.tls1.Get() == reinterpret_cast<int*>(2));
  113. ASSERT_TRUE(params.tls2->Get() == nullptr);
  114. params.tls2->Reset(reinterpret_cast<int*>(1));
  115. ASSERT_TRUE(params.tls2->Get() == reinterpret_cast<int*>(1));
  116. params.tls2->Reset(reinterpret_cast<int*>(2));
  117. ASSERT_TRUE(params.tls2->Get() == reinterpret_cast<int*>(2));
  118. params.mu->Lock();
  119. ++(params.completed);
  120. params.cv->SignalAll();
  121. params.mu->Unlock();
  122. };
  123. for (int iter = 0; iter < 1024; ++iter) {
  124. ASSERT_EQ(IDChecker::PeekId(), base_id + 1u);
  125. // Another new thread, read/write should not see value from previous thread
  126. env_->StartThread(func, static_cast<void*>(&p));
  127. mu.Lock();
  128. while (p.completed != iter + 1) {
  129. cv.Wait();
  130. }
  131. mu.Unlock();
  132. ASSERT_EQ(IDChecker::PeekId(), base_id + 1u);
  133. }
  134. }
  135. TEST_F(ThreadLocalTest, ConcurrentReadWriteTest) {
  136. // global id list carries over 3, 1, 2, 0
  137. uint32_t base_id = IDChecker::PeekId();
  138. ThreadLocalPtr tls2;
  139. port::Mutex mu1;
  140. port::CondVar cv1(&mu1);
  141. Params p1(&mu1, &cv1, nullptr, 16);
  142. p1.tls2 = &tls2;
  143. port::Mutex mu2;
  144. port::CondVar cv2(&mu2);
  145. Params p2(&mu2, &cv2, nullptr, 16);
  146. p2.doWrite = true;
  147. p2.tls2 = &tls2;
  148. auto func = [](void* ptr) {
  149. auto& p = *static_cast<Params*>(ptr);
  150. p.mu->Lock();
  151. // Size_T switches size along with the ptr size
  152. // we want to cast to.
  153. size_t own = ++(p.started);
  154. p.cv->SignalAll();
  155. while (p.started != p.total) {
  156. p.cv->Wait();
  157. }
  158. p.mu->Unlock();
  159. // Let write threads write a different value from the read threads
  160. if (p.doWrite) {
  161. own += 8192;
  162. }
  163. ASSERT_TRUE(p.tls1.Get() == nullptr);
  164. ASSERT_TRUE(p.tls2->Get() == nullptr);
  165. auto* env = Env::Default();
  166. auto start = env->NowMicros();
  167. p.tls1.Reset(reinterpret_cast<size_t*>(own));
  168. p.tls2->Reset(reinterpret_cast<size_t*>(own + 1));
  169. // Loop for 1 second
  170. while (env->NowMicros() - start < 1000 * 1000) {
  171. for (int iter = 0; iter < 100000; ++iter) {
  172. ASSERT_TRUE(p.tls1.Get() == reinterpret_cast<size_t*>(own));
  173. ASSERT_TRUE(p.tls2->Get() == reinterpret_cast<size_t*>(own + 1));
  174. if (p.doWrite) {
  175. p.tls1.Reset(reinterpret_cast<size_t*>(own));
  176. p.tls2->Reset(reinterpret_cast<size_t*>(own + 1));
  177. }
  178. }
  179. }
  180. p.mu->Lock();
  181. ++(p.completed);
  182. p.cv->SignalAll();
  183. p.mu->Unlock();
  184. };
  185. // Initiate 2 instnaces: one keeps writing and one keeps reading.
  186. // The read instance should not see data from the write instance.
  187. // Each thread local copy of the value are also different from each
  188. // other.
  189. for (int th = 0; th < p1.total; ++th) {
  190. env_->StartThread(func, static_cast<void*>(&p1));
  191. }
  192. for (int th = 0; th < p2.total; ++th) {
  193. env_->StartThread(func, static_cast<void*>(&p2));
  194. }
  195. mu1.Lock();
  196. while (p1.completed != p1.total) {
  197. cv1.Wait();
  198. }
  199. mu1.Unlock();
  200. mu2.Lock();
  201. while (p2.completed != p2.total) {
  202. cv2.Wait();
  203. }
  204. mu2.Unlock();
  205. ASSERT_EQ(IDChecker::PeekId(), base_id + 3u);
  206. }
  207. TEST_F(ThreadLocalTest, Unref) {
  208. auto unref = [](void* ptr) {
  209. auto& p = *static_cast<Params*>(ptr);
  210. p.mu->Lock();
  211. ++(*p.unref);
  212. p.mu->Unlock();
  213. };
  214. // Case 0: no unref triggered if ThreadLocalPtr is never accessed
  215. auto func0 = [](void* ptr) {
  216. auto& p = *static_cast<Params*>(ptr);
  217. p.mu->Lock();
  218. ++(p.started);
  219. p.cv->SignalAll();
  220. while (p.started != p.total) {
  221. p.cv->Wait();
  222. }
  223. p.mu->Unlock();
  224. };
  225. for (int th = 1; th <= 128; th += th) {
  226. port::Mutex mu;
  227. port::CondVar cv(&mu);
  228. int unref_count = 0;
  229. Params p(&mu, &cv, &unref_count, th, unref);
  230. for (int i = 0; i < p.total; ++i) {
  231. env_->StartThread(func0, static_cast<void*>(&p));
  232. }
  233. env_->WaitForJoin();
  234. ASSERT_EQ(unref_count, 0);
  235. }
  236. // Case 1: unref triggered by thread exit
  237. auto func1 = [](void* ptr) {
  238. auto& p = *static_cast<Params*>(ptr);
  239. p.mu->Lock();
  240. ++(p.started);
  241. p.cv->SignalAll();
  242. while (p.started != p.total) {
  243. p.cv->Wait();
  244. }
  245. p.mu->Unlock();
  246. ASSERT_TRUE(p.tls1.Get() == nullptr);
  247. ASSERT_TRUE(p.tls2->Get() == nullptr);
  248. p.tls1.Reset(ptr);
  249. p.tls2->Reset(ptr);
  250. p.tls1.Reset(ptr);
  251. p.tls2->Reset(ptr);
  252. };
  253. for (int th = 1; th <= 128; th += th) {
  254. port::Mutex mu;
  255. port::CondVar cv(&mu);
  256. int unref_count = 0;
  257. ThreadLocalPtr tls2(unref);
  258. Params p(&mu, &cv, &unref_count, th, unref);
  259. p.tls2 = &tls2;
  260. for (int i = 0; i < p.total; ++i) {
  261. env_->StartThread(func1, static_cast<void*>(&p));
  262. }
  263. env_->WaitForJoin();
  264. // N threads x 2 ThreadLocal instance cleanup on thread exit
  265. ASSERT_EQ(unref_count, 2 * p.total);
  266. }
  267. // Case 2: unref triggered by ThreadLocal instance destruction
  268. auto func2 = [](void* ptr) {
  269. auto& p = *static_cast<Params*>(ptr);
  270. p.mu->Lock();
  271. ++(p.started);
  272. p.cv->SignalAll();
  273. while (p.started != p.total) {
  274. p.cv->Wait();
  275. }
  276. p.mu->Unlock();
  277. ASSERT_TRUE(p.tls1.Get() == nullptr);
  278. ASSERT_TRUE(p.tls2->Get() == nullptr);
  279. p.tls1.Reset(ptr);
  280. p.tls2->Reset(ptr);
  281. p.tls1.Reset(ptr);
  282. p.tls2->Reset(ptr);
  283. p.mu->Lock();
  284. ++(p.completed);
  285. p.cv->SignalAll();
  286. // Waiting for instruction to exit thread
  287. while (p.completed != 0) {
  288. p.cv->Wait();
  289. }
  290. p.mu->Unlock();
  291. };
  292. for (int th = 1; th <= 128; th += th) {
  293. port::Mutex mu;
  294. port::CondVar cv(&mu);
  295. int unref_count = 0;
  296. Params p(&mu, &cv, &unref_count, th, unref);
  297. p.tls2 = new ThreadLocalPtr(unref);
  298. for (int i = 0; i < p.total; ++i) {
  299. env_->StartThread(func2, static_cast<void*>(&p));
  300. }
  301. // Wait for all threads to finish using Params
  302. mu.Lock();
  303. while (p.completed != p.total) {
  304. cv.Wait();
  305. }
  306. mu.Unlock();
  307. // Now destroy one ThreadLocal instance
  308. delete p.tls2;
  309. p.tls2 = nullptr;
  310. // instance destroy for N threads
  311. ASSERT_EQ(unref_count, p.total);
  312. // Signal to exit
  313. mu.Lock();
  314. p.completed = 0;
  315. cv.SignalAll();
  316. mu.Unlock();
  317. env_->WaitForJoin();
  318. // additional N threads exit unref for the left instance
  319. ASSERT_EQ(unref_count, 2 * p.total);
  320. }
  321. }
  322. TEST_F(ThreadLocalTest, Swap) {
  323. ThreadLocalPtr tls;
  324. tls.Reset(reinterpret_cast<void*>(1));
  325. ASSERT_EQ(reinterpret_cast<int64_t>(tls.Swap(nullptr)), 1);
  326. ASSERT_TRUE(tls.Swap(reinterpret_cast<void*>(2)) == nullptr);
  327. ASSERT_EQ(reinterpret_cast<int64_t>(tls.Get()), 2);
  328. ASSERT_EQ(reinterpret_cast<int64_t>(tls.Swap(reinterpret_cast<void*>(3))), 2);
  329. }
  330. TEST_F(ThreadLocalTest, Scrape) {
  331. auto unref = [](void* ptr) {
  332. auto& p = *static_cast<Params*>(ptr);
  333. p.mu->Lock();
  334. ++(*p.unref);
  335. p.mu->Unlock();
  336. };
  337. auto func = [](void* ptr) {
  338. auto& p = *static_cast<Params*>(ptr);
  339. ASSERT_TRUE(p.tls1.Get() == nullptr);
  340. ASSERT_TRUE(p.tls2->Get() == nullptr);
  341. p.tls1.Reset(ptr);
  342. p.tls2->Reset(ptr);
  343. p.tls1.Reset(ptr);
  344. p.tls2->Reset(ptr);
  345. p.mu->Lock();
  346. ++(p.completed);
  347. p.cv->SignalAll();
  348. // Waiting for instruction to exit thread
  349. while (p.completed != 0) {
  350. p.cv->Wait();
  351. }
  352. p.mu->Unlock();
  353. };
  354. for (int th = 1; th <= 128; th += th) {
  355. port::Mutex mu;
  356. port::CondVar cv(&mu);
  357. int unref_count = 0;
  358. Params p(&mu, &cv, &unref_count, th, unref);
  359. p.tls2 = new ThreadLocalPtr(unref);
  360. for (int i = 0; i < p.total; ++i) {
  361. env_->StartThread(func, static_cast<void*>(&p));
  362. }
  363. // Wait for all threads to finish using Params
  364. mu.Lock();
  365. while (p.completed != p.total) {
  366. cv.Wait();
  367. }
  368. mu.Unlock();
  369. ASSERT_EQ(unref_count, 0);
  370. // Scrape all thread local data. No unref at thread
  371. // exit or ThreadLocalPtr destruction
  372. autovector<void*> ptrs;
  373. p.tls1.Scrape(&ptrs, nullptr);
  374. p.tls2->Scrape(&ptrs, nullptr);
  375. delete p.tls2;
  376. // Signal to exit
  377. mu.Lock();
  378. p.completed = 0;
  379. cv.SignalAll();
  380. mu.Unlock();
  381. env_->WaitForJoin();
  382. ASSERT_EQ(unref_count, 0);
  383. }
  384. }
  385. TEST_F(ThreadLocalTest, Fold) {
  386. auto unref = [](void* ptr) {
  387. delete static_cast<std::atomic<int64_t>*>(ptr);
  388. };
  389. static const int kNumThreads = 16;
  390. static const int kItersPerThread = 10;
  391. port::Mutex mu;
  392. port::CondVar cv(&mu);
  393. Params params(&mu, &cv, nullptr, kNumThreads, unref);
  394. auto func = [](void* ptr) {
  395. auto& p = *static_cast<Params*>(ptr);
  396. ASSERT_TRUE(p.tls1.Get() == nullptr);
  397. p.tls1.Reset(new std::atomic<int64_t>(0));
  398. for (int i = 0; i < kItersPerThread; ++i) {
  399. static_cast<std::atomic<int64_t>*>(p.tls1.Get())->fetch_add(1);
  400. }
  401. p.mu->Lock();
  402. ++(p.completed);
  403. p.cv->SignalAll();
  404. // Waiting for instruction to exit thread
  405. while (p.completed != 0) {
  406. p.cv->Wait();
  407. }
  408. p.mu->Unlock();
  409. };
  410. for (int th = 0; th < params.total; ++th) {
  411. env_->StartThread(func, static_cast<void*>(&params));
  412. }
  413. // Wait for all threads to finish using Params
  414. mu.Lock();
  415. while (params.completed != params.total) {
  416. cv.Wait();
  417. }
  418. mu.Unlock();
  419. // Verify Fold() behavior
  420. int64_t sum = 0;
  421. params.tls1.Fold(
  422. [](void* ptr, void* res) {
  423. auto sum_ptr = static_cast<int64_t*>(res);
  424. *sum_ptr += static_cast<std::atomic<int64_t>*>(ptr)->load();
  425. },
  426. &sum);
  427. ASSERT_EQ(sum, kNumThreads * kItersPerThread);
  428. // Signal to exit
  429. mu.Lock();
  430. params.completed = 0;
  431. cv.SignalAll();
  432. mu.Unlock();
  433. env_->WaitForJoin();
  434. }
  435. TEST_F(ThreadLocalTest, CompareAndSwap) {
  436. ThreadLocalPtr tls;
  437. ASSERT_TRUE(tls.Swap(reinterpret_cast<void*>(1)) == nullptr);
  438. void* expected = reinterpret_cast<void*>(1);
  439. // Swap in 2
  440. ASSERT_TRUE(tls.CompareAndSwap(reinterpret_cast<void*>(2), expected));
  441. expected = reinterpret_cast<void*>(100);
  442. // Fail Swap, still 2
  443. ASSERT_TRUE(!tls.CompareAndSwap(reinterpret_cast<void*>(2), expected));
  444. ASSERT_EQ(expected, reinterpret_cast<void*>(2));
  445. // Swap in 3
  446. expected = reinterpret_cast<void*>(2);
  447. ASSERT_TRUE(tls.CompareAndSwap(reinterpret_cast<void*>(3), expected));
  448. ASSERT_EQ(tls.Get(), reinterpret_cast<void*>(3));
  449. }
  450. namespace {
  451. void* AccessThreadLocal(void* /*arg*/) {
  452. TEST_SYNC_POINT("AccessThreadLocal:Start");
  453. ThreadLocalPtr tlp;
  454. tlp.Reset(new std::string("hello RocksDB"));
  455. TEST_SYNC_POINT("AccessThreadLocal:End");
  456. return nullptr;
  457. }
  458. } // namespace
  459. // The following test is disabled as it requires manual steps to run it
  460. // correctly.
  461. //
  462. // Currently we have no way to acess SyncPoint w/o ASAN error when the
  463. // child thread dies after the main thread dies. So if you manually enable
  464. // this test and only see an ASAN error on SyncPoint, it means you pass the
  465. // test.
  466. TEST_F(ThreadLocalTest, DISABLED_MainThreadDiesFirst) {
  467. ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
  468. {{"AccessThreadLocal:Start", "MainThreadDiesFirst:End"},
  469. {"PosixEnv::~PosixEnv():End", "AccessThreadLocal:End"}});
  470. // Triggers the initialization of singletons.
  471. Env::Default();
  472. #ifndef ROCKSDB_LITE
  473. try {
  474. #endif // ROCKSDB_LITE
  475. ROCKSDB_NAMESPACE::port::Thread th(&AccessThreadLocal, nullptr);
  476. th.detach();
  477. TEST_SYNC_POINT("MainThreadDiesFirst:End");
  478. #ifndef ROCKSDB_LITE
  479. } catch (const std::system_error& ex) {
  480. std::cerr << "Start thread: " << ex.code() << std::endl;
  481. FAIL();
  482. }
  483. #endif // ROCKSDB_LITE
  484. }
  485. } // namespace ROCKSDB_NAMESPACE
  486. int main(int argc, char** argv) {
  487. ::testing::InitGoogleTest(&argc, argv);
  488. return RUN_ALL_TESTS();
  489. }