|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108 |
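/**
 * \file    v1.hpp
 * \brief   Recursive (divide-and-conquer) k-nearest-neighbour search that
 *          splits the corpus and the query set and merges the partial results.
 *
 * \author
 *    Christos Choutouridis AEM:8997
 *    <cchoutou@ece.auth.gr>
 */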
- #ifndef V1_HPP_
- #define V1_HPP_
-
#include <algorithm>
#include <cstddef>
#include <limits>
#include <utility>
#include <vector>

#include "matrix.hpp"
#include "v0.hpp"
#include "config.h"
-
- namespace v1 {
-
- template <typename DataType, typename IndexType>
- void mergeResultsWithM(mtx::Matrix<IndexType>& N1, mtx::Matrix<DataType>& D1,
- mtx::Matrix<IndexType>& N2, mtx::Matrix<DataType>& D2,
- size_t k, size_t m,
- mtx::Matrix<IndexType>& N, mtx::Matrix<DataType>& D) {
- size_t numQueries = N1.rows();
- size_t maxCandidates = std::min((IndexType)m, (IndexType)(N1.columns() + N2.columns()));
-
- for (size_t q = 0; q < numQueries; ++q) {
- // Combine distances and neighbors
- std::vector<std::pair<DataType, IndexType>> candidates(N1.columns() + N2.columns());
-
- // Concatenate N1 and N2 rows
- for (size_t i = 0; i < N1.columns(); ++i) {
- candidates[i] = {D1.get(q, i), N1.get(q, i)};
- }
- for (size_t i = 0; i < N2.columns(); ++i) {
- candidates[i + N1.columns()] = {D2.get(q, i), N2.get(q, i)};
- }
-
- // Keep only the top-m candidates
- v0::quickselect(candidates, maxCandidates);
-
- // Sort the top-m candidates
- std::sort(candidates.begin(), candidates.begin() + maxCandidates);
-
- // If m < k, pad the remaining slots with invalid values
- for (size_t i = 0; i < k; ++i) {
- if (i < maxCandidates) {
- D.set(candidates[i].first, q, i);
- N.set(candidates[i].second, q, i);
- } else {
- D.set(std::numeric_limits<DataType>::infinity(), q, i);
- N.set(static_cast<IndexType>(-1), q, i); // Invalid index (end)
- }
- }
- }
- }
-
-
-
- template<typename MatrixD, typename MatrixI>
- void knnsearch(const MatrixD& C, const MatrixD& Q, size_t idx_offset, size_t k, size_t m, MatrixI& idx, MatrixD& dst) {
-
- using DstType = typename MatrixD::dataType;
- using IdxType = typename MatrixI::dataType;
-
- if (C.rows() <= 8 || Q.rows() <= 4) {
- // Base case: Call knnsearch directly
- v0::knnsearch(C, Q, idx_offset, k, m, idx, dst);
- return;
- }
-
- // Divide Corpus and Query into subsets
- IdxType midC = C.rows() / 2;
- IdxType midQ = Q.rows() / 2;
-
- // Slice corpus and query matrixes
- MatrixD C1((DstType*)C.data(), 0, midC, C.columns());
- MatrixD C2((DstType*)C.data(), midC, midC, C.columns());
- MatrixD Q1((DstType*)Q.data(), 0, midQ, Q.columns());
- MatrixD Q2((DstType*)Q.data(), midQ, midQ, Q.columns());
-
- // Allocate temporary matrixes for all permutations
- MatrixI N1_1(midQ, k), N1_2(midQ, k), N2_1(midQ, k), N2_2(midQ, k);
- MatrixD D1_1(midQ, k), D1_2(midQ, k), D2_1(midQ, k), D2_2(midQ, k);
-
- // Recursive calls
- knnsearch(C1, Q1, idx_offset, k, m, N1_1, D1_1);
- knnsearch(C2, Q1, idx_offset + midC, k, m, N1_2, D1_2);
- knnsearch(C1, Q2, idx_offset, k, m, N2_1, D2_1);
- knnsearch(C2, Q2, idx_offset + midC, k, m, N2_2, D2_2);
-
- // slice output matrixes
- MatrixI N1((IdxType*)idx.data(), 0, midQ, k);
- MatrixI N2((IdxType*)idx.data(), midQ, midQ, k);
- MatrixD D1((DstType*)dst.data(), 0, midQ, k);
- MatrixD D2((DstType*)dst.data(), midQ, midQ, k);
-
- // Merge results in place
- mergeResultsWithM(N1_1, D1_1, N1_2, D1_2, k, m, N1, D1);
- mergeResultsWithM(N2_1, D2_1, N2_2, D2_2, k, m, N2, D2);
-
- }
-
- } // namespace v1
-
- #endif /* V1_HPP_ */
|