faiss_ivf_index.cc 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. // Copyright (c) Meta Platforms, Inc. and affiliates.
  2. // This source code is licensed under both the GPLv2 (found in the
  3. // COPYING file in the root directory) and Apache 2.0 License
  4. // (found in the LICENSE.Apache file in the root directory).
  5. #include <cassert>
  6. #include <optional>
  7. #include <stdexcept>
  8. #include <utility>
  9. #include "faiss/IndexIVF.h"
  10. #include "faiss/invlists/InvertedLists.h"
  11. #include "rocksdb/utilities/secondary_index_faiss.h"
  12. #include "util/autovector.h"
  13. #include "util/coding.h"
  14. namespace ROCKSDB_NAMESPACE {
  15. namespace {
  16. std::string SerializeLabel(faiss::idx_t label) {
  17. std::string label_str;
  18. PutVarsignedint64(&label_str, label);
  19. return label_str;
  20. }
  21. faiss::idx_t DeserializeLabel(Slice label_slice) {
  22. faiss::idx_t label = -1;
  23. [[maybe_unused]] const bool ok = GetVarsignedint64(&label_slice, &label);
  24. assert(ok);
  25. return label;
  26. }
  27. } // namespace
  28. struct FaissIVFIndex::KNNContext {
  29. SecondaryIndexIterator* it;
  30. autovector<std::string> keys;
  31. };
  32. class FaissIVFIndex::Adapter : public faiss::InvertedLists {
  33. public:
  34. Adapter(size_t num_lists, size_t code_size)
  35. : faiss::InvertedLists(num_lists, code_size) {
  36. use_iterator = true;
  37. }
  38. // Non-iterator-based read interface; not implemented/used since use_iterator
  39. // is true
  40. size_t list_size(size_t /* list_no */) const override {
  41. assert(false);
  42. return 0;
  43. }
  44. const uint8_t* get_codes(size_t /* list_no */) const override {
  45. assert(false);
  46. return nullptr;
  47. }
  48. const faiss::idx_t* get_ids(size_t /* list_no */) const override {
  49. assert(false);
  50. return nullptr;
  51. }
  52. // Iterator-based read interface
  53. faiss::InvertedListsIterator* get_iterator(
  54. size_t list_no, void* inverted_list_context = nullptr) const override {
  55. KNNContext* const knn_context =
  56. static_cast<KNNContext*>(inverted_list_context);
  57. assert(knn_context);
  58. return new IteratorAdapter(knn_context, list_no, code_size);
  59. }
  60. // Write interface; only add_entry is implemented/required for now
  61. size_t add_entry(size_t /* list_no */, faiss::idx_t /* id */,
  62. const uint8_t* code,
  63. void* inverted_list_context = nullptr) override {
  64. std::string* const code_str =
  65. static_cast<std::string*>(inverted_list_context);
  66. assert(code_str);
  67. code_str->assign(reinterpret_cast<const char*>(code), code_size);
  68. return 0;
  69. }
  70. size_t add_entries(size_t /* list_no */, size_t /* num_entries */,
  71. const faiss::idx_t* /* ids */,
  72. const uint8_t* /* code */) override {
  73. assert(false);
  74. return 0;
  75. }
  76. void update_entry(size_t /* list_no */, size_t /* offset */,
  77. faiss::idx_t /* id */, const uint8_t* /* code */) override {
  78. assert(false);
  79. }
  80. void update_entries(size_t /* list_no */, size_t /* offset */,
  81. size_t /* num_entries */, const faiss::idx_t* /* ids */,
  82. const uint8_t* /* code */) override {
  83. assert(false);
  84. }
  85. void resize(size_t /* list_no */, size_t /* new_size */) override {
  86. assert(false);
  87. }
  88. private:
  89. class IteratorAdapter : public faiss::InvertedListsIterator {
  90. public:
  91. IteratorAdapter(KNNContext* knn_context, size_t list_no, size_t code_size)
  92. : knn_context_(knn_context),
  93. it_(knn_context_->it),
  94. code_size_(code_size) {
  95. assert(knn_context_);
  96. assert(it_);
  97. const std::string label = SerializeLabel(list_no);
  98. it_->Seek(label);
  99. Update();
  100. }
  101. bool is_available() const override { return id_and_codes_.has_value(); }
  102. void next() override {
  103. it_->Next();
  104. Update();
  105. }
  106. std::pair<faiss::idx_t, const uint8_t*> get_id_and_codes() override {
  107. assert(is_available());
  108. return *id_and_codes_;
  109. }
  110. private:
  111. void Update() {
  112. id_and_codes_.reset();
  113. const Status status = it_->status();
  114. if (!status.ok()) {
  115. throw std::runtime_error(status.ToString());
  116. }
  117. if (!it_->Valid()) {
  118. return;
  119. }
  120. if (!it_->PrepareValue()) {
  121. throw std::runtime_error(
  122. "Failed to prepare value during iteration in FaissIVFIndex");
  123. }
  124. const Slice value = it_->value();
  125. if (value.size() != code_size_) {
  126. throw std::runtime_error(
  127. "Code with unexpected size encountered during iteration in "
  128. "FaissIVFIndex");
  129. }
  130. const faiss::idx_t id = knn_context_->keys.size();
  131. knn_context_->keys.emplace_back(it_->key().ToString());
  132. id_and_codes_.emplace(id, reinterpret_cast<const uint8_t*>(value.data()));
  133. }
  134. KNNContext* knn_context_;
  135. SecondaryIndexIterator* it_;
  136. size_t code_size_;
  137. std::optional<std::pair<faiss::idx_t, const uint8_t*>> id_and_codes_;
  138. };
  139. };
  140. FaissIVFIndex::FaissIVFIndex(std::unique_ptr<faiss::IndexIVF>&& index,
  141. std::string primary_column_name)
  142. : adapter_(std::make_unique<Adapter>(index->nlist, index->code_size)),
  143. index_(std::move(index)),
  144. primary_column_name_(std::move(primary_column_name)) {
  145. assert(index_);
  146. assert(index_->quantizer);
  147. index_->parallel_mode = 0;
  148. index_->replace_invlists(adapter_.get());
  149. }
  150. FaissIVFIndex::~FaissIVFIndex() = default;
  151. void FaissIVFIndex::SetPrimaryColumnFamily(ColumnFamilyHandle* column_family) {
  152. assert(column_family);
  153. primary_column_family_ = column_family;
  154. }
  155. void FaissIVFIndex::SetSecondaryColumnFamily(
  156. ColumnFamilyHandle* column_family) {
  157. assert(column_family);
  158. secondary_column_family_ = column_family;
  159. }
  160. ColumnFamilyHandle* FaissIVFIndex::GetPrimaryColumnFamily() const {
  161. return primary_column_family_;
  162. }
  163. ColumnFamilyHandle* FaissIVFIndex::GetSecondaryColumnFamily() const {
  164. return secondary_column_family_;
  165. }
  166. Slice FaissIVFIndex::GetPrimaryColumnName() const {
  167. return primary_column_name_;
  168. }
  169. Status FaissIVFIndex::UpdatePrimaryColumnValue(
  170. const Slice& /* primary_key */, const Slice& primary_column_value,
  171. std::optional<std::variant<Slice, std::string>>* updated_column_value)
  172. const {
  173. assert(updated_column_value);
  174. const float* const embedding =
  175. ConvertSliceToFloats(primary_column_value, index_->d);
  176. if (!embedding) {
  177. return Status::InvalidArgument(
  178. "Incorrectly sized vector passed to FaissIVFIndex");
  179. }
  180. constexpr faiss::idx_t n = 1;
  181. faiss::idx_t label = -1;
  182. try {
  183. index_->quantizer->assign(n, embedding, &label);
  184. } catch (const std::exception& e) {
  185. return Status::InvalidArgument(e.what());
  186. }
  187. if (label < 0 || label >= index_->nlist) {
  188. return Status::InvalidArgument(
  189. "Unexpected label returned by coarse quantizer");
  190. }
  191. updated_column_value->emplace(SerializeLabel(label));
  192. return Status::OK();
  193. }
  194. Status FaissIVFIndex::GetSecondaryKeyPrefix(
  195. const Slice& /* primary_key */, const Slice& primary_column_value,
  196. std::variant<Slice, std::string>* secondary_key_prefix) const {
  197. assert(secondary_key_prefix);
  198. [[maybe_unused]] const faiss::idx_t label =
  199. DeserializeLabel(primary_column_value);
  200. assert(label >= 0);
  201. assert(label < index_->nlist);
  202. *secondary_key_prefix = primary_column_value;
  203. return Status::OK();
  204. }
  205. Status FaissIVFIndex::FinalizeSecondaryKeyPrefix(
  206. std::variant<Slice, std::string>* /* secondary_key_prefix */) const {
  207. return Status::OK();
  208. }
  209. Status FaissIVFIndex::GetSecondaryValue(
  210. const Slice& /* primary_key */, const Slice& primary_column_value,
  211. const Slice& original_column_value,
  212. std::optional<std::variant<Slice, std::string>>* secondary_value) const {
  213. assert(secondary_value);
  214. const faiss::idx_t label = DeserializeLabel(primary_column_value);
  215. assert(label >= 0);
  216. assert(label < index_->nlist);
  217. constexpr faiss::idx_t n = 1;
  218. const float* const embedding =
  219. ConvertSliceToFloats(original_column_value, index_->d);
  220. assert(embedding);
  221. constexpr faiss::idx_t* xids = nullptr;
  222. std::string code_str;
  223. try {
  224. index_->add_core(n, embedding, xids, &label, &code_str);
  225. } catch (const std::exception& e) {
  226. return Status::Corruption(e.what());
  227. }
  228. if (code_str.size() != index_->code_size) {
  229. return Status::Corruption(
  230. "Code with unexpected size returned by fine quantizer");
  231. }
  232. secondary_value->emplace(std::move(code_str));
  233. return Status::OK();
  234. }
  235. Status FaissIVFIndex::FindKNearestNeighbors(
  236. SecondaryIndexIterator* it, const Slice& target, size_t neighbors,
  237. size_t probes, std::vector<std::pair<std::string, float>>* result) const {
  238. if (!it) {
  239. return Status::InvalidArgument("Secondary index iterator must be provided");
  240. }
  241. const float* const embedding = ConvertSliceToFloats(target, index_->d);
  242. if (!embedding) {
  243. return Status::InvalidArgument(
  244. "Incorrectly sized vector passed to FaissIVFIndex");
  245. }
  246. if (!neighbors) {
  247. return Status::InvalidArgument("Invalid number of neighbors");
  248. }
  249. if (!probes) {
  250. return Status::InvalidArgument("Invalid number of probes");
  251. }
  252. if (!result) {
  253. return Status::InvalidArgument("Result parameter must be provided");
  254. }
  255. result->clear();
  256. std::vector<float> distances(neighbors, 0.0f);
  257. std::vector<faiss::idx_t> ids(neighbors, -1);
  258. KNNContext knn_context{it, {}};
  259. faiss::SearchParametersIVF params;
  260. params.nprobe = probes;
  261. params.inverted_list_context = &knn_context;
  262. constexpr faiss::idx_t n = 1;
  263. try {
  264. index_->search(n, embedding, neighbors, distances.data(), ids.data(),
  265. &params);
  266. } catch (const std::exception& e) {
  267. return Status::Corruption(e.what());
  268. }
  269. result->reserve(neighbors);
  270. for (size_t i = 0; i < neighbors; ++i) {
  271. if (ids[i] < 0) {
  272. break;
  273. }
  274. if (ids[i] >= knn_context.keys.size()) {
  275. result->clear();
  276. return Status::Corruption("Unexpected id returned by FAISS");
  277. }
  278. result->emplace_back(knn_context.keys[ids[i]], distances[i]);
  279. }
  280. return Status::OK();
  281. }
  282. } // namespace ROCKSDB_NAMESPACE