| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580 |
- // Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
- // This source code is licensed under both the GPLv2 (found in the
- // COPYING file in the root directory) and Apache 2.0 License
- // (found in the LICENSE.Apache file in the root directory).
#include <atomic>
#include <memory>
#include <string>
#include <thread>

#include "port/port.h"
#include "rocksdb/env.h"
#include "test_util/sync_point.h"
#include "test_util/testharness.h"
#include "test_util/testutil.h"
#include "util/autovector.h"
#include "util/thread_local.h"
- namespace ROCKSDB_NAMESPACE {
- class ThreadLocalTest : public testing::Test {
- public:
- ThreadLocalTest() : env_(Env::Default()) {}
- Env* env_;
- };
- namespace {
- struct Params {
- Params(port::Mutex* m, port::CondVar* c, int* u, int n,
- UnrefHandler handler = nullptr)
- : mu(m),
- cv(c),
- unref(u),
- total(n),
- started(0),
- completed(0),
- doWrite(false),
- tls1(handler),
- tls2(nullptr) {}
- port::Mutex* mu;
- port::CondVar* cv;
- int* unref;
- int total;
- int started;
- int completed;
- bool doWrite;
- ThreadLocalPtr tls1;
- ThreadLocalPtr* tls2;
- };
- class IDChecker : public ThreadLocalPtr {
- public:
- static uint32_t PeekId() {
- return TEST_PeekId();
- }
- };
- } // anonymous namespace
- // Suppress false positive clang analyzer warnings.
- #ifndef __clang_analyzer__
- TEST_F(ThreadLocalTest, UniqueIdTest) {
- port::Mutex mu;
- port::CondVar cv(&mu);
- uint32_t base_id = IDChecker::PeekId();
- // New ThreadLocal instance bumps id by 1
- {
- // Id used 0
- Params p1(&mu, &cv, nullptr, 1u);
- ASSERT_EQ(IDChecker::PeekId(), base_id + 1u);
- // Id used 1
- Params p2(&mu, &cv, nullptr, 1u);
- ASSERT_EQ(IDChecker::PeekId(), base_id + 2u);
- // Id used 2
- Params p3(&mu, &cv, nullptr, 1u);
- ASSERT_EQ(IDChecker::PeekId(), base_id + 3u);
- // Id used 3
- Params p4(&mu, &cv, nullptr, 1u);
- ASSERT_EQ(IDChecker::PeekId(), base_id + 4u);
- }
- // id 3, 2, 1, 0 are in the free queue in order
- ASSERT_EQ(IDChecker::PeekId(), base_id + 0u);
- // pick up 0
- Params p1(&mu, &cv, nullptr, 1u);
- ASSERT_EQ(IDChecker::PeekId(), base_id + 1u);
- // pick up 1
- Params* p2 = new Params(&mu, &cv, nullptr, 1u);
- ASSERT_EQ(IDChecker::PeekId(), base_id + 2u);
- // pick up 2
- Params p3(&mu, &cv, nullptr, 1u);
- ASSERT_EQ(IDChecker::PeekId(), base_id + 3u);
- // return up 1
- delete p2;
- ASSERT_EQ(IDChecker::PeekId(), base_id + 1u);
- // Now we have 3, 1 in queue
- // pick up 1
- Params p4(&mu, &cv, nullptr, 1u);
- ASSERT_EQ(IDChecker::PeekId(), base_id + 3u);
- // pick up 3
- Params p5(&mu, &cv, nullptr, 1u);
- // next new id
- ASSERT_EQ(IDChecker::PeekId(), base_id + 4u);
- // After exit, id sequence in queue:
- // 3, 1, 2, 0
- }
- #endif // __clang_analyzer__
- TEST_F(ThreadLocalTest, SequentialReadWriteTest) {
- // global id list carries over 3, 1, 2, 0
- uint32_t base_id = IDChecker::PeekId();
- port::Mutex mu;
- port::CondVar cv(&mu);
- Params p(&mu, &cv, nullptr, 1);
- ThreadLocalPtr tls2;
- p.tls2 = &tls2;
- auto func = [](void* ptr) {
- auto& params = *static_cast<Params*>(ptr);
- ASSERT_TRUE(params.tls1.Get() == nullptr);
- params.tls1.Reset(reinterpret_cast<int*>(1));
- ASSERT_TRUE(params.tls1.Get() == reinterpret_cast<int*>(1));
- params.tls1.Reset(reinterpret_cast<int*>(2));
- ASSERT_TRUE(params.tls1.Get() == reinterpret_cast<int*>(2));
- ASSERT_TRUE(params.tls2->Get() == nullptr);
- params.tls2->Reset(reinterpret_cast<int*>(1));
- ASSERT_TRUE(params.tls2->Get() == reinterpret_cast<int*>(1));
- params.tls2->Reset(reinterpret_cast<int*>(2));
- ASSERT_TRUE(params.tls2->Get() == reinterpret_cast<int*>(2));
- params.mu->Lock();
- ++(params.completed);
- params.cv->SignalAll();
- params.mu->Unlock();
- };
- for (int iter = 0; iter < 1024; ++iter) {
- ASSERT_EQ(IDChecker::PeekId(), base_id + 1u);
- // Another new thread, read/write should not see value from previous thread
- env_->StartThread(func, static_cast<void*>(&p));
- mu.Lock();
- while (p.completed != iter + 1) {
- cv.Wait();
- }
- mu.Unlock();
- ASSERT_EQ(IDChecker::PeekId(), base_id + 1u);
- }
- }
- TEST_F(ThreadLocalTest, ConcurrentReadWriteTest) {
- // global id list carries over 3, 1, 2, 0
- uint32_t base_id = IDChecker::PeekId();
- ThreadLocalPtr tls2;
- port::Mutex mu1;
- port::CondVar cv1(&mu1);
- Params p1(&mu1, &cv1, nullptr, 16);
- p1.tls2 = &tls2;
- port::Mutex mu2;
- port::CondVar cv2(&mu2);
- Params p2(&mu2, &cv2, nullptr, 16);
- p2.doWrite = true;
- p2.tls2 = &tls2;
- auto func = [](void* ptr) {
- auto& p = *static_cast<Params*>(ptr);
- p.mu->Lock();
- // Size_T switches size along with the ptr size
- // we want to cast to.
- size_t own = ++(p.started);
- p.cv->SignalAll();
- while (p.started != p.total) {
- p.cv->Wait();
- }
- p.mu->Unlock();
- // Let write threads write a different value from the read threads
- if (p.doWrite) {
- own += 8192;
- }
- ASSERT_TRUE(p.tls1.Get() == nullptr);
- ASSERT_TRUE(p.tls2->Get() == nullptr);
- auto* env = Env::Default();
- auto start = env->NowMicros();
- p.tls1.Reset(reinterpret_cast<size_t*>(own));
- p.tls2->Reset(reinterpret_cast<size_t*>(own + 1));
- // Loop for 1 second
- while (env->NowMicros() - start < 1000 * 1000) {
- for (int iter = 0; iter < 100000; ++iter) {
- ASSERT_TRUE(p.tls1.Get() == reinterpret_cast<size_t*>(own));
- ASSERT_TRUE(p.tls2->Get() == reinterpret_cast<size_t*>(own + 1));
- if (p.doWrite) {
- p.tls1.Reset(reinterpret_cast<size_t*>(own));
- p.tls2->Reset(reinterpret_cast<size_t*>(own + 1));
- }
- }
- }
- p.mu->Lock();
- ++(p.completed);
- p.cv->SignalAll();
- p.mu->Unlock();
- };
- // Initiate 2 instnaces: one keeps writing and one keeps reading.
- // The read instance should not see data from the write instance.
- // Each thread local copy of the value are also different from each
- // other.
- for (int th = 0; th < p1.total; ++th) {
- env_->StartThread(func, static_cast<void*>(&p1));
- }
- for (int th = 0; th < p2.total; ++th) {
- env_->StartThread(func, static_cast<void*>(&p2));
- }
- mu1.Lock();
- while (p1.completed != p1.total) {
- cv1.Wait();
- }
- mu1.Unlock();
- mu2.Lock();
- while (p2.completed != p2.total) {
- cv2.Wait();
- }
- mu2.Unlock();
- ASSERT_EQ(IDChecker::PeekId(), base_id + 3u);
- }
- TEST_F(ThreadLocalTest, Unref) {
- auto unref = [](void* ptr) {
- auto& p = *static_cast<Params*>(ptr);
- p.mu->Lock();
- ++(*p.unref);
- p.mu->Unlock();
- };
- // Case 0: no unref triggered if ThreadLocalPtr is never accessed
- auto func0 = [](void* ptr) {
- auto& p = *static_cast<Params*>(ptr);
- p.mu->Lock();
- ++(p.started);
- p.cv->SignalAll();
- while (p.started != p.total) {
- p.cv->Wait();
- }
- p.mu->Unlock();
- };
- for (int th = 1; th <= 128; th += th) {
- port::Mutex mu;
- port::CondVar cv(&mu);
- int unref_count = 0;
- Params p(&mu, &cv, &unref_count, th, unref);
- for (int i = 0; i < p.total; ++i) {
- env_->StartThread(func0, static_cast<void*>(&p));
- }
- env_->WaitForJoin();
- ASSERT_EQ(unref_count, 0);
- }
- // Case 1: unref triggered by thread exit
- auto func1 = [](void* ptr) {
- auto& p = *static_cast<Params*>(ptr);
- p.mu->Lock();
- ++(p.started);
- p.cv->SignalAll();
- while (p.started != p.total) {
- p.cv->Wait();
- }
- p.mu->Unlock();
- ASSERT_TRUE(p.tls1.Get() == nullptr);
- ASSERT_TRUE(p.tls2->Get() == nullptr);
- p.tls1.Reset(ptr);
- p.tls2->Reset(ptr);
- p.tls1.Reset(ptr);
- p.tls2->Reset(ptr);
- };
- for (int th = 1; th <= 128; th += th) {
- port::Mutex mu;
- port::CondVar cv(&mu);
- int unref_count = 0;
- ThreadLocalPtr tls2(unref);
- Params p(&mu, &cv, &unref_count, th, unref);
- p.tls2 = &tls2;
- for (int i = 0; i < p.total; ++i) {
- env_->StartThread(func1, static_cast<void*>(&p));
- }
- env_->WaitForJoin();
- // N threads x 2 ThreadLocal instance cleanup on thread exit
- ASSERT_EQ(unref_count, 2 * p.total);
- }
- // Case 2: unref triggered by ThreadLocal instance destruction
- auto func2 = [](void* ptr) {
- auto& p = *static_cast<Params*>(ptr);
- p.mu->Lock();
- ++(p.started);
- p.cv->SignalAll();
- while (p.started != p.total) {
- p.cv->Wait();
- }
- p.mu->Unlock();
- ASSERT_TRUE(p.tls1.Get() == nullptr);
- ASSERT_TRUE(p.tls2->Get() == nullptr);
- p.tls1.Reset(ptr);
- p.tls2->Reset(ptr);
- p.tls1.Reset(ptr);
- p.tls2->Reset(ptr);
- p.mu->Lock();
- ++(p.completed);
- p.cv->SignalAll();
- // Waiting for instruction to exit thread
- while (p.completed != 0) {
- p.cv->Wait();
- }
- p.mu->Unlock();
- };
- for (int th = 1; th <= 128; th += th) {
- port::Mutex mu;
- port::CondVar cv(&mu);
- int unref_count = 0;
- Params p(&mu, &cv, &unref_count, th, unref);
- p.tls2 = new ThreadLocalPtr(unref);
- for (int i = 0; i < p.total; ++i) {
- env_->StartThread(func2, static_cast<void*>(&p));
- }
- // Wait for all threads to finish using Params
- mu.Lock();
- while (p.completed != p.total) {
- cv.Wait();
- }
- mu.Unlock();
- // Now destroy one ThreadLocal instance
- delete p.tls2;
- p.tls2 = nullptr;
- // instance destroy for N threads
- ASSERT_EQ(unref_count, p.total);
- // Signal to exit
- mu.Lock();
- p.completed = 0;
- cv.SignalAll();
- mu.Unlock();
- env_->WaitForJoin();
- // additional N threads exit unref for the left instance
- ASSERT_EQ(unref_count, 2 * p.total);
- }
- }
- TEST_F(ThreadLocalTest, Swap) {
- ThreadLocalPtr tls;
- tls.Reset(reinterpret_cast<void*>(1));
- ASSERT_EQ(reinterpret_cast<int64_t>(tls.Swap(nullptr)), 1);
- ASSERT_TRUE(tls.Swap(reinterpret_cast<void*>(2)) == nullptr);
- ASSERT_EQ(reinterpret_cast<int64_t>(tls.Get()), 2);
- ASSERT_EQ(reinterpret_cast<int64_t>(tls.Swap(reinterpret_cast<void*>(3))), 2);
- }
- TEST_F(ThreadLocalTest, Scrape) {
- auto unref = [](void* ptr) {
- auto& p = *static_cast<Params*>(ptr);
- p.mu->Lock();
- ++(*p.unref);
- p.mu->Unlock();
- };
- auto func = [](void* ptr) {
- auto& p = *static_cast<Params*>(ptr);
- ASSERT_TRUE(p.tls1.Get() == nullptr);
- ASSERT_TRUE(p.tls2->Get() == nullptr);
- p.tls1.Reset(ptr);
- p.tls2->Reset(ptr);
- p.tls1.Reset(ptr);
- p.tls2->Reset(ptr);
- p.mu->Lock();
- ++(p.completed);
- p.cv->SignalAll();
- // Waiting for instruction to exit thread
- while (p.completed != 0) {
- p.cv->Wait();
- }
- p.mu->Unlock();
- };
- for (int th = 1; th <= 128; th += th) {
- port::Mutex mu;
- port::CondVar cv(&mu);
- int unref_count = 0;
- Params p(&mu, &cv, &unref_count, th, unref);
- p.tls2 = new ThreadLocalPtr(unref);
- for (int i = 0; i < p.total; ++i) {
- env_->StartThread(func, static_cast<void*>(&p));
- }
- // Wait for all threads to finish using Params
- mu.Lock();
- while (p.completed != p.total) {
- cv.Wait();
- }
- mu.Unlock();
- ASSERT_EQ(unref_count, 0);
- // Scrape all thread local data. No unref at thread
- // exit or ThreadLocalPtr destruction
- autovector<void*> ptrs;
- p.tls1.Scrape(&ptrs, nullptr);
- p.tls2->Scrape(&ptrs, nullptr);
- delete p.tls2;
- // Signal to exit
- mu.Lock();
- p.completed = 0;
- cv.SignalAll();
- mu.Unlock();
- env_->WaitForJoin();
- ASSERT_EQ(unref_count, 0);
- }
- }
- TEST_F(ThreadLocalTest, Fold) {
- auto unref = [](void* ptr) {
- delete static_cast<std::atomic<int64_t>*>(ptr);
- };
- static const int kNumThreads = 16;
- static const int kItersPerThread = 10;
- port::Mutex mu;
- port::CondVar cv(&mu);
- Params params(&mu, &cv, nullptr, kNumThreads, unref);
- auto func = [](void* ptr) {
- auto& p = *static_cast<Params*>(ptr);
- ASSERT_TRUE(p.tls1.Get() == nullptr);
- p.tls1.Reset(new std::atomic<int64_t>(0));
- for (int i = 0; i < kItersPerThread; ++i) {
- static_cast<std::atomic<int64_t>*>(p.tls1.Get())->fetch_add(1);
- }
- p.mu->Lock();
- ++(p.completed);
- p.cv->SignalAll();
- // Waiting for instruction to exit thread
- while (p.completed != 0) {
- p.cv->Wait();
- }
- p.mu->Unlock();
- };
- for (int th = 0; th < params.total; ++th) {
- env_->StartThread(func, static_cast<void*>(¶ms));
- }
- // Wait for all threads to finish using Params
- mu.Lock();
- while (params.completed != params.total) {
- cv.Wait();
- }
- mu.Unlock();
- // Verify Fold() behavior
- int64_t sum = 0;
- params.tls1.Fold(
- [](void* ptr, void* res) {
- auto sum_ptr = static_cast<int64_t*>(res);
- *sum_ptr += static_cast<std::atomic<int64_t>*>(ptr)->load();
- },
- &sum);
- ASSERT_EQ(sum, kNumThreads * kItersPerThread);
- // Signal to exit
- mu.Lock();
- params.completed = 0;
- cv.SignalAll();
- mu.Unlock();
- env_->WaitForJoin();
- }
- TEST_F(ThreadLocalTest, CompareAndSwap) {
- ThreadLocalPtr tls;
- ASSERT_TRUE(tls.Swap(reinterpret_cast<void*>(1)) == nullptr);
- void* expected = reinterpret_cast<void*>(1);
- // Swap in 2
- ASSERT_TRUE(tls.CompareAndSwap(reinterpret_cast<void*>(2), expected));
- expected = reinterpret_cast<void*>(100);
- // Fail Swap, still 2
- ASSERT_TRUE(!tls.CompareAndSwap(reinterpret_cast<void*>(2), expected));
- ASSERT_EQ(expected, reinterpret_cast<void*>(2));
- // Swap in 3
- expected = reinterpret_cast<void*>(2);
- ASSERT_TRUE(tls.CompareAndSwap(reinterpret_cast<void*>(3), expected));
- ASSERT_EQ(tls.Get(), reinterpret_cast<void*>(3));
- }
- namespace {
- void* AccessThreadLocal(void* /*arg*/) {
- TEST_SYNC_POINT("AccessThreadLocal:Start");
- ThreadLocalPtr tlp;
- tlp.Reset(new std::string("hello RocksDB"));
- TEST_SYNC_POINT("AccessThreadLocal:End");
- return nullptr;
- }
- } // namespace
- // The following test is disabled as it requires manual steps to run it
- // correctly.
- //
- // Currently we have no way to acess SyncPoint w/o ASAN error when the
- // child thread dies after the main thread dies. So if you manually enable
- // this test and only see an ASAN error on SyncPoint, it means you pass the
- // test.
- TEST_F(ThreadLocalTest, DISABLED_MainThreadDiesFirst) {
- ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
- {{"AccessThreadLocal:Start", "MainThreadDiesFirst:End"},
- {"PosixEnv::~PosixEnv():End", "AccessThreadLocal:End"}});
- // Triggers the initialization of singletons.
- Env::Default();
- #ifndef ROCKSDB_LITE
- try {
- #endif // ROCKSDB_LITE
- ROCKSDB_NAMESPACE::port::Thread th(&AccessThreadLocal, nullptr);
- th.detach();
- TEST_SYNC_POINT("MainThreadDiesFirst:End");
- #ifndef ROCKSDB_LITE
- } catch (const std::system_error& ex) {
- std::cerr << "Start thread: " << ex.code() << std::endl;
- FAIL();
- }
- #endif // ROCKSDB_LITE
- }
- } // namespace ROCKSDB_NAMESPACE
- int main(int argc, char** argv) {
- ::testing::InitGoogleTest(&argc, argv);
- return RUN_ALL_TESTS();
- }
|