AUTH's THMMY "Parallel and distributed systems" course assignments.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

v1.hpp 3.4 KiB

5 dagen geleden
5 dagen geleden
5 dagen geleden
5 dagen geleden
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. /**
* \file v1.hpp
* \brief Recursive divide-and-conquer k-NN search (namespace v1).
  4. *
  5. * \author
  6. * Christos Choutouridis AEM:8997
  7. * <cchoutou@ece.auth.gr>
  8. */
  9. #ifndef V1_HPP_
  10. #define V1_HPP_
#include <algorithm>
#include <cstddef>
#include <limits>
#include <utility>
#include <vector>

#include "matrix.hpp"
#include "v0.hpp"
#include "config.h"
  16. namespace v1 {
  17. template <typename DataType, typename IndexType>
  18. void mergeResultsWithM(mtx::Matrix<IndexType>& N1, mtx::Matrix<DataType>& D1,
  19. mtx::Matrix<IndexType>& N2, mtx::Matrix<DataType>& D2,
  20. size_t k, size_t m,
  21. mtx::Matrix<IndexType>& N, mtx::Matrix<DataType>& D) {
  22. size_t numQueries = N1.rows();
  23. size_t maxCandidates = std::min((IndexType)m, (IndexType)(N1.columns() + N2.columns()));
  24. for (size_t q = 0; q < numQueries; ++q) {
  25. // Combine distances and neighbors
  26. std::vector<std::pair<DataType, IndexType>> candidates(N1.columns() + N2.columns());
  27. // Concatenate N1 and N2 rows
  28. for (size_t i = 0; i < N1.columns(); ++i) {
  29. candidates[i] = {D1.get(q, i), N1.get(q, i)};
  30. }
  31. for (size_t i = 0; i < N2.columns(); ++i) {
  32. candidates[i + N1.columns()] = {D2.get(q, i), N2.get(q, i)};
  33. }
  34. // Keep only the top-m candidates
  35. v0::quickselect(candidates, maxCandidates);
  36. // Sort the top-m candidates
  37. std::sort(candidates.begin(), candidates.begin() + maxCandidates);
  38. // If m < k, pad the remaining slots with invalid values
  39. for (size_t i = 0; i < k; ++i) {
  40. if (i < maxCandidates) {
  41. D.set(candidates[i].first, q, i);
  42. N.set(candidates[i].second, q, i);
  43. } else {
  44. D.set(std::numeric_limits<DataType>::infinity(), q, i);
  45. N.set(static_cast<IndexType>(-1), q, i); // Invalid index (end)
  46. }
  47. }
  48. }
  49. }
  50. template<typename MatrixD, typename MatrixI>
  51. void knnsearch(const MatrixD& C, const MatrixD& Q, size_t idx_offset, size_t k, size_t m, MatrixI& idx, MatrixD& dst) {
  52. using DstType = typename MatrixD::dataType;
  53. using IdxType = typename MatrixI::dataType;
  54. if (C.rows() <= 8 || Q.rows() <= 4) {
  55. // Base case: Call knnsearch directly
  56. v0::knnsearch(C, Q, idx_offset, k, m, idx, dst);
  57. return;
  58. }
  59. // Divide Corpus and Query into subsets
  60. IdxType midC = C.rows() / 2;
  61. IdxType midQ = Q.rows() / 2;
  62. // Slice corpus and query matrixes
  63. MatrixD C1((DstType*)C.data(), 0, midC, C.columns());
  64. MatrixD C2((DstType*)C.data(), midC, midC, C.columns());
  65. MatrixD Q1((DstType*)Q.data(), 0, midQ, Q.columns());
  66. MatrixD Q2((DstType*)Q.data(), midQ, midQ, Q.columns());
  67. // Allocate temporary matrixes for all permutations
  68. MatrixI N1_1(midQ, k), N1_2(midQ, k), N2_1(midQ, k), N2_2(midQ, k);
  69. MatrixD D1_1(midQ, k), D1_2(midQ, k), D2_1(midQ, k), D2_2(midQ, k);
  70. // Recursive calls
  71. knnsearch(C1, Q1, idx_offset, k, m, N1_1, D1_1);
  72. knnsearch(C2, Q1, idx_offset + midC, k, m, N1_2, D1_2);
  73. knnsearch(C1, Q2, idx_offset, k, m, N2_1, D2_1);
  74. knnsearch(C2, Q2, idx_offset + midC, k, m, N2_2, D2_2);
  75. // slice output matrixes
  76. MatrixI N1((IdxType*)idx.data(), 0, midQ, k);
  77. MatrixI N2((IdxType*)idx.data(), midQ, midQ, k);
  78. MatrixD D1((DstType*)dst.data(), 0, midQ, k);
  79. MatrixD D2((DstType*)dst.data(), midQ, midQ, k);
  80. // Merge results in place
  81. mergeResultsWithM(N1_1, D1_1, N1_2, D1_2, k, m, N1, D1);
  82. mergeResultsWithM(N2_1, D2_1, N2_2, D2_2, k, m, N2, D2);
  83. }
  84. } // namespace v1
  85. #endif /* V1_HPP_ */