faiss_ivf_index_test.cc 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
  1. // Copyright (c) Meta Platforms, Inc. and affiliates.
  2. // This source code is licensed under both the GPLv2 (found in the
  3. // COPYING file in the root directory) and Apache 2.0 License
  4. // (found in the LICENSE.Apache file in the root directory).
  5. #include <charconv>
  6. #include <memory>
  7. #include <string>
  8. #include <vector>
  9. #include "faiss/IndexFlat.h"
  10. #include "faiss/IndexIVFFlat.h"
  11. #include "faiss/utils/random.h"
  12. #include "rocksdb/utilities/secondary_index_faiss.h"
  13. #include "rocksdb/utilities/transaction_db.h"
  14. #include "test_util/testharness.h"
  15. #include "util/coding.h"
  16. namespace ROCKSDB_NAMESPACE {
  17. TEST(FaissIVFIndexTest, Basic) {
  18. constexpr size_t dim = 128;
  19. auto quantizer = std::make_unique<faiss::IndexFlatL2>(dim);
  20. constexpr size_t num_lists = 16;
  21. auto index =
  22. std::make_unique<faiss::IndexIVFFlat>(quantizer.get(), dim, num_lists);
  23. constexpr faiss::idx_t num_vectors = 1024;
  24. std::vector<float> embeddings(dim * num_vectors);
  25. faiss::float_rand(embeddings.data(), dim * num_vectors, 42);
  26. index->train(num_vectors, embeddings.data());
  27. const std::string primary_column_name = "embedding";
  28. auto faiss_ivf_index =
  29. std::make_shared<FaissIVFIndex>(std::move(index), primary_column_name);
  30. const std::string db_name = test::PerThreadDBPath("faiss_ivf_index_test");
  31. EXPECT_OK(DestroyDB(db_name, Options()));
  32. Options options;
  33. options.create_if_missing = true;
  34. TransactionDBOptions txn_db_options;
  35. txn_db_options.secondary_indices.emplace_back(faiss_ivf_index);
  36. TransactionDB* db = nullptr;
  37. ASSERT_OK(TransactionDB::Open(options, txn_db_options, db_name, &db));
  38. std::unique_ptr<TransactionDB> db_guard(db);
  39. ColumnFamilyOptions cf1_opts;
  40. ColumnFamilyHandle* cfh1 = nullptr;
  41. ASSERT_OK(db->CreateColumnFamily(cf1_opts, "cf1", &cfh1));
  42. std::unique_ptr<ColumnFamilyHandle> cfh1_guard(cfh1);
  43. ColumnFamilyOptions cf2_opts;
  44. ColumnFamilyHandle* cfh2 = nullptr;
  45. ASSERT_OK(db->CreateColumnFamily(cf2_opts, "cf2", &cfh2));
  46. std::unique_ptr<ColumnFamilyHandle> cfh2_guard(cfh2);
  47. const auto& secondary_index = txn_db_options.secondary_indices.back();
  48. secondary_index->SetPrimaryColumnFamily(cfh1);
  49. secondary_index->SetSecondaryColumnFamily(cfh2);
  50. // Write the embeddings to the primary column family, indexing them in the
  51. // process
  52. {
  53. std::unique_ptr<Transaction> txn(db->BeginTransaction(WriteOptions()));
  54. for (faiss::idx_t i = 0; i < num_vectors; ++i) {
  55. const std::string primary_key = std::to_string(i);
  56. ASSERT_OK(txn->PutEntity(
  57. cfh1, primary_key,
  58. WideColumns{
  59. {primary_column_name,
  60. ConvertFloatsToSlice(embeddings.data() + i * dim, dim)}}));
  61. }
  62. ASSERT_OK(txn->Commit());
  63. }
  64. // Verify the raw index data in the secondary column family
  65. {
  66. size_t num_found = 0;
  67. std::unique_ptr<Iterator> it(db->NewIterator(ReadOptions(), cfh2));
  68. for (it->SeekToFirst(); it->Valid(); it->Next()) {
  69. Slice key = it->key();
  70. faiss::idx_t label = -1;
  71. ASSERT_TRUE(GetVarsignedint64(&key, &label));
  72. ASSERT_GE(label, 0);
  73. ASSERT_LT(label, num_lists);
  74. faiss::idx_t id = -1;
  75. ASSERT_EQ(std::from_chars(key.data(), key.data() + key.size(), id).ec,
  76. std::errc());
  77. ASSERT_GE(id, 0);
  78. ASSERT_LT(id, num_vectors);
  79. // Since we use IndexIVFFlat, there is no fine quantization, so the code
  80. // is actually just the original embedding
  81. ASSERT_EQ(it->value(),
  82. ConvertFloatsToSlice(embeddings.data() + id * dim, dim));
  83. ++num_found;
  84. }
  85. ASSERT_OK(it->status());
  86. ASSERT_EQ(num_found, num_vectors);
  87. }
  88. // Query the index with some of the original embeddings
  89. std::unique_ptr<Iterator> underlying_it(db->NewIterator(ReadOptions(), cfh2));
  90. auto secondary_it = std::make_unique<SecondaryIndexIterator>(
  91. faiss_ivf_index.get(), std::move(underlying_it));
  92. auto get_id = [](const Slice& key) -> faiss::idx_t {
  93. faiss::idx_t id = -1;
  94. if (std::from_chars(key.data(), key.data() + key.size(), id).ec !=
  95. std::errc()) {
  96. return -1;
  97. }
  98. return id;
  99. };
  100. constexpr size_t neighbors = 8;
  101. auto verify = [&](faiss::idx_t id) {
  102. // Search for a vector from the original set; we expect to find the vector
  103. // itself as the closest match, since we're performing an exhaustive search
  104. std::vector<std::pair<std::string, float>> result;
  105. ASSERT_OK(faiss_ivf_index->FindKNearestNeighbors(
  106. secondary_it.get(),
  107. ConvertFloatsToSlice(embeddings.data() + id * dim, dim), neighbors,
  108. num_lists, &result));
  109. ASSERT_EQ(result.size(), neighbors);
  110. const faiss::idx_t first_id = get_id(result[0].first);
  111. ASSERT_GE(first_id, 0);
  112. ASSERT_LT(first_id, num_vectors);
  113. ASSERT_EQ(first_id, id);
  114. ASSERT_EQ(result[0].second, 0.0f);
  115. // Iterate over the rest of the results
  116. for (size_t i = 1; i < neighbors; ++i) {
  117. const faiss::idx_t other_id = get_id(result[i].first);
  118. ASSERT_GE(other_id, 0);
  119. ASSERT_LT(other_id, num_vectors);
  120. ASSERT_NE(other_id, id);
  121. ASSERT_GE(result[i].second, result[i - 1].second);
  122. }
  123. };
  124. verify(0);
  125. verify(16);
  126. verify(32);
  127. verify(64);
  128. // Sanity checks
  129. {
  130. // No secondary index iterator
  131. constexpr SecondaryIndexIterator* bad_secondary_it = nullptr;
  132. std::vector<std::pair<std::string, float>> result;
  133. ASSERT_TRUE(faiss_ivf_index
  134. ->FindKNearestNeighbors(
  135. bad_secondary_it,
  136. ConvertFloatsToSlice(embeddings.data(), dim), neighbors,
  137. num_lists, &result)
  138. .IsInvalidArgument());
  139. }
  140. {
  141. // Invalid target
  142. std::vector<std::pair<std::string, float>> result;
  143. ASSERT_TRUE(faiss_ivf_index
  144. ->FindKNearestNeighbors(secondary_it.get(), "foo",
  145. neighbors, num_lists, &result)
  146. .IsInvalidArgument());
  147. }
  148. {
  149. // Invalid value for neighbors
  150. constexpr size_t bad_neighbors = 0;
  151. std::vector<std::pair<std::string, float>> result;
  152. ASSERT_TRUE(faiss_ivf_index
  153. ->FindKNearestNeighbors(
  154. secondary_it.get(),
  155. ConvertFloatsToSlice(embeddings.data(), dim),
  156. bad_neighbors, num_lists, &result)
  157. .IsInvalidArgument());
  158. }
  159. {
  160. // Invalid value for neighbors
  161. constexpr size_t bad_probes = 0;
  162. std::vector<std::pair<std::string, float>> result;
  163. ASSERT_TRUE(faiss_ivf_index
  164. ->FindKNearestNeighbors(
  165. secondary_it.get(),
  166. ConvertFloatsToSlice(embeddings.data(), dim), neighbors,
  167. bad_probes, &result)
  168. .IsInvalidArgument());
  169. }
  170. {
  171. // No result parameter
  172. constexpr std::vector<std::pair<std::string, float>>* bad_result = nullptr;
  173. ASSERT_TRUE(faiss_ivf_index
  174. ->FindKNearestNeighbors(
  175. secondary_it.get(),
  176. ConvertFloatsToSlice(embeddings.data(), dim), neighbors,
  177. num_lists, bad_result)
  178. .IsInvalidArgument());
  179. }
  180. }
  181. TEST(FaissIVFIndexTest, Compare) {
  182. // Train two copies of the same index; hand over one to FaissIVFIndex and use
  183. // the other one as a baseline for comparison
  184. constexpr size_t dim = 128;
  185. auto quantizer_cmp = std::make_unique<faiss::IndexFlatL2>(dim);
  186. auto quantizer = std::make_unique<faiss::IndexFlatL2>(dim);
  187. constexpr size_t num_lists = 16;
  188. auto index_cmp = std::make_unique<faiss::IndexIVFFlat>(quantizer_cmp.get(),
  189. dim, num_lists);
  190. auto index =
  191. std::make_unique<faiss::IndexIVFFlat>(quantizer.get(), dim, num_lists);
  192. {
  193. constexpr faiss::idx_t num_train = 1024;
  194. std::vector<float> embeddings_train(dim * num_train);
  195. faiss::float_rand(embeddings_train.data(), dim * num_train, 42);
  196. index_cmp->train(num_train, embeddings_train.data());
  197. index->train(num_train, embeddings_train.data());
  198. }
  199. auto faiss_ivf_index = std::make_shared<FaissIVFIndex>(
  200. std::move(index), kDefaultWideColumnName.ToString());
  201. const std::string db_name = test::PerThreadDBPath("faiss_ivf_index_test");
  202. EXPECT_OK(DestroyDB(db_name, Options()));
  203. Options options;
  204. options.create_if_missing = true;
  205. TransactionDBOptions txn_db_options;
  206. txn_db_options.secondary_indices.emplace_back(faiss_ivf_index);
  207. TransactionDB* db = nullptr;
  208. ASSERT_OK(TransactionDB::Open(options, txn_db_options, db_name, &db));
  209. std::unique_ptr<TransactionDB> db_guard(db);
  210. ColumnFamilyOptions cf1_opts;
  211. ColumnFamilyHandle* cfh1 = nullptr;
  212. ASSERT_OK(db->CreateColumnFamily(cf1_opts, "cf1", &cfh1));
  213. std::unique_ptr<ColumnFamilyHandle> cfh1_guard(cfh1);
  214. ColumnFamilyOptions cf2_opts;
  215. ColumnFamilyHandle* cfh2 = nullptr;
  216. ASSERT_OK(db->CreateColumnFamily(cf2_opts, "cf2", &cfh2));
  217. std::unique_ptr<ColumnFamilyHandle> cfh2_guard(cfh2);
  218. const auto& secondary_index = txn_db_options.secondary_indices.back();
  219. secondary_index->SetPrimaryColumnFamily(cfh1);
  220. secondary_index->SetSecondaryColumnFamily(cfh2);
  221. // Add the same set of database vectors to both indices
  222. constexpr faiss::idx_t num_db = 4096;
  223. {
  224. std::vector<float> embeddings_db(dim * num_db);
  225. faiss::float_rand(embeddings_db.data(), dim * num_db, 123);
  226. for (faiss::idx_t i = 0; i < num_db; ++i) {
  227. const float* const embedding = embeddings_db.data() + i * dim;
  228. index_cmp->add(1, embedding);
  229. const std::string primary_key = std::to_string(i);
  230. ASSERT_OK(db->Put(WriteOptions(), cfh1, primary_key,
  231. ConvertFloatsToSlice(embedding, dim)));
  232. }
  233. }
  234. // Search both indices with the same set of query vectors and make sure the
  235. // results match
  236. {
  237. constexpr faiss::idx_t num_query = 32;
  238. std::vector<float> embeddings_query(dim * num_query);
  239. faiss::float_rand(embeddings_query.data(), dim * num_query, 456);
  240. std::unique_ptr<Iterator> underlying_it(
  241. db->NewIterator(ReadOptions(), cfh2));
  242. auto secondary_it = std::make_unique<SecondaryIndexIterator>(
  243. faiss_ivf_index.get(), std::move(underlying_it));
  244. auto get_id = [](const Slice& key) -> faiss::idx_t {
  245. faiss::idx_t id = -1;
  246. if (std::from_chars(key.data(), key.data() + key.size(), id).ec !=
  247. std::errc()) {
  248. return -1;
  249. }
  250. return id;
  251. };
  252. for (size_t neighbors : {1, 2, 4}) {
  253. for (size_t probes : {1, 2, 4}) {
  254. for (faiss::idx_t i = 0; i < num_query; ++i) {
  255. const float* const embedding = embeddings_query.data() + i * dim;
  256. std::vector<float> distances(neighbors, 0.0f);
  257. std::vector<faiss::idx_t> ids(neighbors, -1);
  258. faiss::SearchParametersIVF params;
  259. params.nprobe = probes;
  260. index_cmp->search(1, embedding, neighbors, distances.data(),
  261. ids.data(), &params);
  262. size_t result_size_cmp = 0;
  263. for (faiss::idx_t id_cmp : ids) {
  264. if (id_cmp < 0) {
  265. break;
  266. }
  267. ++result_size_cmp;
  268. }
  269. std::vector<std::pair<std::string, float>> result;
  270. ASSERT_OK(faiss_ivf_index->FindKNearestNeighbors(
  271. secondary_it.get(), ConvertFloatsToSlice(embedding, dim),
  272. neighbors, probes, &result));
  273. ASSERT_EQ(result.size(), result_size_cmp);
  274. for (size_t j = 0; j < result.size(); ++j) {
  275. const faiss::idx_t id = get_id(result[j].first);
  276. ASSERT_GE(id, 0);
  277. ASSERT_LT(id, num_db);
  278. ASSERT_EQ(id, ids[j]);
  279. ASSERT_EQ(result[j].second, distances[j]);
  280. }
  281. }
  282. }
  283. }
  284. }
  285. }
  286. } // namespace ROCKSDB_NAMESPACE
  287. int main(int argc, char** argv) {
  288. ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
  289. ::testing::InitGoogleTest(&argc, argv);
  290. return RUN_ALL_TESTS();
  291. }