AUTH's THMMY "Parallel and distributed systems" course assignments.
您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
 
 
 
 
 

130 行
3.9 KiB

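/**
 * \file v0.hpp
 * \brief Brute-force k-nearest-neighbor search (namespace v0): BLAS-based
 *        pairwise distances (pdist2), k-smallest selection (quickselect)
 *        and per-query neighbor extraction (knnsearch).
 *
 * \author
 *    Christos Choutouridis AEM:8997
 *    <cchoutou@ece.auth.gr>
 */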
  9. #ifndef V0_HPP_
  10. #define V0_HPP_
  11. #include <cblas.h>
  12. #include <cmath>
  13. #include <vector>
  14. #include <algorithm>
  15. #include "matrix.hpp"
  16. #include "config.h"
  17. namespace v0 {
  18. /*!
  19. * Function to compute squared Euclidean distances
  20. *
  21. * \fn void pdist2(const double*, const double*, double*, int, int, int)
  22. * \param X m x d matrix (Column major)
  23. * \param Y n x d matrix (Column major)
  24. * \param D2 m x n matrix to store distances (Column major)
  25. * \param m number of rows in X
  26. * \param n number of rows in Y
  27. * \param d number of columns in both X and Y
  28. */
  29. template<typename Matrix>
  30. void pdist2(const Matrix& X, const Matrix& Y, Matrix& D2) {
  31. using DataType = typename Matrix::dataType;
  32. int M = X.rows();
  33. int N = Y.rows();
  34. int d = X.columns();
  35. // Compute the squared norms of each row in X and Y
  36. std::vector<DataType> X_norms(M), Y_norms(N);
  37. for (int i = 0; i < M ; ++i) {
  38. X_norms[i] = cblas_ddot(d, X.data() + i * d, 1, X.data() + i * d, 1);
  39. }
  40. for (int j = 0; j < N ; ++j) {
  41. Y_norms[j] = cblas_ddot(d, Y.data() + j * d, 1, Y.data() + j * d, 1);
  42. }
  43. // Compute -2 * X * Y'
  44. cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, M, N, d, -2.0, X.data(), d, Y.data(), d, 0.0, D2.data(), N);
  45. // Step 3: Add the squared norms to each entry in D2
  46. for (int i = 0; i < M ; ++i) {
  47. for (int j = 0; j < N; ++j) {
  48. D2.set(D2.get(i, j) + X_norms[i] + Y_norms[j], i, j);
  49. D2.set(std::max(D2.get(i, j), 0.0), i, j); // Ensure non-negative
  50. D2.set(std::sqrt(D2.get(i, j)), i, j); // Take the square root of each
  51. }
  52. }
  53. M++;
  54. }
  /*!
   * Keep only the k pairs with the smallest distance (pair::first) in \p vec.
   *
   * Uses std::nth_element, so the kept pairs are NOT sorted among themselves.
   *
   * \fn void quickselect(std::vector<std::pair<DataType,IndexType>>&, int)
   * \tparam DataType  Distance type (compared with operator<)
   * \tparam IndexType Index type
   * \param vec Vector of pairs (distance, index). On return it holds the
   *            min(k, vec.size()) pairs with the smallest distances.
   * \param k   The number of elements to select. k <= 0 empties the vector;
   *            k >= vec.size() leaves it unchanged.
   */
  template<typename DataType, typename IndexType>
  void quickselect(std::vector<std::pair<DataType, IndexType>>& vec, int k) {
    using Entry = std::pair<DataType, IndexType>;

    if (k <= 0) {
      vec.clear();
      return;
    }
    // Everything is already selected; also prevents iterating past end()
    if (static_cast<std::size_t>(k) >= vec.size())
      return;

    std::nth_element(
      vec.begin(),
      vec.begin() + k,
      vec.end(),
      [](const Entry& a, const Entry& b) {
        return a.first < b.first;
      });
    vec.resize(static_cast<std::size_t>(k));  // Drop everything past the k-th
  }
  74. /*!
  75. * \param C Is a MxD matrix (Corpus)
  76. * \param Q Is a NxD matrix (Query)
  77. * \param idx_offset The offset of the indexes for output (to match with the actual Corpus indexes)
  78. * \param k The number of nearest neighbors needed
  79. * \param idx Is the Nxk matrix with the k indexes of the C points, that are
  80. * neighbors of the nth point of Q
  81. * \param dst Is the Nxk matrix with the k distances to the C points of the nth
  82. * point of Q
  83. */
  84. template<typename MatrixD, typename MatrixI>
  85. void knnsearch(MatrixD& C, MatrixD& Q, size_t idx_offset, size_t k, size_t m, MatrixI& idx, MatrixD& dst) {
  86. using DstType = typename MatrixD::dataType;
  87. using IdxType = typename MatrixI::dataType;
  88. size_t M = C.rows();
  89. size_t N = Q.rows();
  90. mtx::Matrix<DstType> D(M, N);
  91. pdist2(C, Q, D);
  92. for (size_t j = 0; j < N; ++j) {
  93. // Create a vector of pairs (distance, index) for the j-th query
  94. std::vector<std::pair<DstType, IdxType>> dst_idx(M);
  95. for (size_t i = 0; i < M; ++i) {
  96. dst_idx[i] = {D.data()[i * N + j], i};
  97. }
  98. // Find the k smallest distances using quickSelectKSmallest
  99. quickselect(dst_idx, k);
  100. // Sort the k smallest results by distance for consistency
  101. std::sort(dst_idx.begin(), dst_idx.end());
  102. // Store the indices and distances
  103. for (size_t i = 0; i < k; ++i) {
  104. dst.set(dst_idx[i].first, j, i);
  105. idx.set(dst_idx[i].second + idx_offset, j, i);
  106. }
  107. }
  108. }
  109. }
  110. #endif /* V0_HPP_ */