AUTH's THMMY "Parallel and distributed systems" course assignments.
Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

v0.hpp 3.5 KiB

Line numbers 1–119
  1. /**
  2. * \file v0.hpp
  3. * \brief Brute-force (v0) k-nearest-neighbour search: CBLAS pairwise distances plus per-query selection.
  4. *
  5. * \author
  6. * Christos Choutouridis AEM:8997
  7. * <cchoutou@ece.auth.gr>
  8. */
  9. #ifndef V0_HPP_
  10. #define V0_HPP_
  11. #include <cblas.h>
  12. #include <cmath>
  13. #include <vector>
  14. #include <algorithm>
  15. #include <matrix.hpp>
  16. #include <config.h>
  17. namespace v0 {
  18. /*!
  19. * Function to compute squared Euclidean distances
  20. *
  21. * \fn void pdist2(const double*, const double*, double*, int, int, int)
  22. * \param X m x d matrix (Column major)
  23. * \param Y n x d matrix (Column major)
  24. * \param D2 m x n matrix to store distances (Column major)
  25. * \param m number of rows in X
  26. * \param n number of rows in Y
  27. * \param d number of columns in both X and Y
  28. */
  29. template<typename DataType>
  30. void pdist2(const mtx::Matrix<DataType>& X, const mtx::Matrix<DataType>& Y, mtx::Matrix<DataType>& D2) {
  31. int M = X.rows();
  32. int N = Y.rows();
  33. int d = X.columns();
  34. // Compute the squared norms of each row in X and Y
  35. std::vector<DataType> X_norms(M), Y_norms(N);
  36. for (int i = 0; i < M ; ++i) {
  37. X_norms[i] = cblas_ddot(d, X.data() + i * d, 1, X.data() + i * d, 1);
  38. }
  39. for (int j = 0; j < N ; ++j) {
  40. Y_norms[j] = cblas_ddot(d, Y.data() + j * d, 1, Y.data() + j * d, 1);
  41. }
  42. // Compute -2 * X * Y'
  43. cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, M, N, d, -2.0, X.data(), d, Y.data(), d, 0.0, D2.data(), N);
  44. // Step 3: Add the squared norms to each entry in D2
  45. for (int i = 0; i < M ; ++i) {
  46. for (int j = 0; j < N; ++j) {
  47. D2.set(D2.get(i, j) + X_norms[i] + Y_norms[j], i, j);
  48. //D2.set(std::max(D2.get(i, j), 0.0), i, j); // Ensure non-negative
  49. D2.set(std::sqrt(D2.get(i, j)), i, j); // Take the square root of each
  50. }
  51. }
  52. }
  53. template<typename DataType, typename IndexType>
  54. void quickselect(std::vector<std::pair<DataType, IndexType>>& vec, int k) {
  55. std::nth_element(
  56. vec.begin(),
  57. vec.begin() + k,
  58. vec.end(),
  59. [](const std::pair<DataType, IndexType>& a, const std::pair<DataType, IndexType>& b) {
  60. return a.first < b.first;
  61. });
  62. vec.resize(k); // Keep only the k smallest elements
  63. }
  64. /*!
  65. * \param C Is a MxD matrix (Corpus)
  66. * \param Q Is a NxD matrix (Query)
  67. * \param k The number of nearest neighbors needed
  68. * \param idx Is the Nxk matrix with the k indexes of the C points, that are
  69. * neighbors of the nth point of Q
  70. * \param dst Is the Nxk matrix with the k distances to the C points of the nth
  71. * point of Q
  72. */
  73. template<typename DataType, typename IndexType>
  74. void knnsearch(const mtx::Matrix<DataType>& C, const mtx::Matrix<DataType>& Q, int k,
  75. mtx::Matrix<IndexType>& idx,
  76. mtx::Matrix<DataType>& dst) {
  77. int M = C.rows();
  78. int N = Q.rows();
  79. mtx::Matrix<DataType> D(M, N);
  80. pdist2(C, Q, D);
  81. idx.resize(N, k);
  82. dst.resize(N, k);
  83. for (int j = 0; j < N; ++j) {
  84. // Create a vector of pairs (distance, index) for the j-th query
  85. std::vector<std::pair<DataType, IndexType>> dst_idx(M);
  86. for (int i = 0; i < M; ++i) {
  87. dst_idx[i] = {D.data()[i * N + j], i};
  88. }
  89. // Find the k smallest distances using quickSelectKSmallest
  90. quickselect(dst_idx, k);
  91. // Sort the k smallest results by distance for consistency
  92. std::sort(dst_idx.begin(), dst_idx.end());
  93. // Store the indices and distances
  94. for (int i = 0; i < k; ++i) {
  95. idx(j, i) = dst_idx[i].second;
  96. dst(j, i) = dst_idx[i].first;
  97. }
  98. }
  99. }
  100. }
  101. #endif /* V0_HPP_ */