/** * \file v0.hpp * \brief * * \author * Christos Choutouridis AEM:8997 * */ #ifndef V1_HPP_ #define V1_HPP_ #include #include #include "matrix.hpp" #include "v0.hpp" #include "config.h" namespace v1 { template void mergeResultsWithM(mtx::Matrix& N1, mtx::Matrix& D1, mtx::Matrix& N2, mtx::Matrix& D2, size_t k, size_t m, mtx::Matrix& N, mtx::Matrix& D) { size_t numQueries = N1.rows(); size_t maxCandidates = std::min((IndexType)m, (IndexType)(N1.columns() + N2.columns())); for (size_t q = 0; q < numQueries; ++q) { // Combine distances and neighbors std::vector> candidates(N1.columns() + N2.columns()); // Concatenate N1 and N2 rows for (size_t i = 0; i < N1.columns(); ++i) { candidates[i] = {D1.get(q, i), N1.get(q, i)}; } for (size_t i = 0; i < N2.columns(); ++i) { candidates[i + N1.columns()] = {D2.get(q, i), N2.get(q, i)}; } // Keep only the top-m candidates v0::quickselect(candidates, maxCandidates); // Sort the top-m candidates std::sort(candidates.begin(), candidates.begin() + maxCandidates); // If m < k, pad the remaining slots with invalid values for (size_t i = 0; i < k; ++i) { if (i < maxCandidates) { D.set(candidates[i].first, q, i); N.set(candidates[i].second, q, i); } else { D.set(std::numeric_limits::infinity(), q, i); N.set(static_cast(-1), q, i); // Invalid index (end) } } } } template void knnsearch(const MatrixD& C, const MatrixD& Q, size_t idx_offset, size_t k, size_t m, MatrixI& idx, MatrixD& dst) { using DstType = typename MatrixD::dataType; using IdxType = typename MatrixI::dataType; if (C.rows() <= 8 || Q.rows() <= 4) { // Base case: Call knnsearch directly v0::knnsearch(C, Q, idx_offset, k, m, idx, dst); return; } // Divide Corpus and Query into subsets IdxType midC = C.rows() / 2; IdxType midQ = Q.rows() / 2; // Slice corpus and query matrixes MatrixD C1((DstType*)C.data(), 0, midC, C.columns()); MatrixD C2((DstType*)C.data(), midC, midC, C.columns()); MatrixD Q1((DstType*)Q.data(), 0, midQ, Q.columns()); MatrixD Q2((DstType*)Q.data(), midQ, midQ, Q.columns()); // Allocate temporary matrixes for all permutations MatrixI N1_1(midQ, k), N1_2(midQ, k), N2_1(midQ, k), N2_2(midQ, k); MatrixD D1_1(midQ, k), D1_2(midQ, k), D2_1(midQ, k), D2_2(midQ, k); // Recursive calls knnsearch(C1, Q1, idx_offset, k, m, N1_1, D1_1); knnsearch(C2, Q1, idx_offset + midC, k, m, N1_2, D1_2); knnsearch(C1, Q2, idx_offset, k, m, N2_1, D2_1); knnsearch(C2, Q2, idx_offset + midC, k, m, N2_2, D2_2); // slice output matrixes MatrixI N1((IdxType*)idx.data(), 0, midQ, k); MatrixI N2((IdxType*)idx.data(), midQ, midQ, k); MatrixD D1((DstType*)dst.data(), 0, midQ, k); MatrixD D2((DstType*)dst.data(), midQ, midQ, k); // Merge results in place mergeResultsWithM(N1_1, D1_1, N1_2, D1_2, k, m, N1, D1); mergeResultsWithM(N2_1, D2_1, N2_2, D2_2, k, m, N2, D2); } } // namespace v1 #endif /* V1_HPP_ */