AUTH's THMMY "Parallel and distributed systems" course assignments.
Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

v0.hpp 3.5 KiB

il y a 5 jours
il y a 5 jours
il y a 5 jours
il y a 5 jours
il y a 5 jours
il y a 5 jours
il y a 5 jours
il y a 5 jours
il y a 5 jours
il y a 5 jours
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. /**
  2. * \file v0.hpp
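  3. * \brief Version 0 (brute-force) k-nearest-neighbor search: the full distance matrix is computed with BLAS, then the k closest corpus points are selected per query.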
  4. *
  5. * \author
  6. * Christos Choutouridis AEM:8997
  7. * <cchoutou@ece.auth.gr>
  8. */
  9. #ifndef V0_HPP_
  10. #define V0_HPP_
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

#include <cblas.h>

#include "matrix.hpp"
#include "config.h"
  17. namespace v0 {
  18. /*!
  19. * Function to compute squared Euclidean distances
  20. *
  21. * \fn void pdist2(const double*, const double*, double*, int, int, int)
  22. * \param X m x d matrix (Column major)
  23. * \param Y n x d matrix (Column major)
  24. * \param D2 m x n matrix to store distances (Column major)
  25. * \param m number of rows in X
  26. * \param n number of rows in Y
  27. * \param d number of columns in both X and Y
  28. */
  29. template<typename Matrix>
  30. void pdist2(const Matrix& X, const Matrix& Y, Matrix& D2) {
  31. using DataType = typename Matrix::dataType;
  32. int M = X.rows();
  33. int N = Y.rows();
  34. int d = X.columns();
  35. // Compute the squared norms of each row in X and Y
  36. std::vector<DataType> X_norms(M), Y_norms(N);
  37. for (int i = 0; i < M ; ++i) {
  38. X_norms[i] = cblas_ddot(d, X.data() + i * d, 1, X.data() + i * d, 1);
  39. }
  40. for (int j = 0; j < N ; ++j) {
  41. Y_norms[j] = cblas_ddot(d, Y.data() + j * d, 1, Y.data() + j * d, 1);
  42. }
  43. // Compute -2 * X * Y'
  44. cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, M, N, d, -2.0, X.data(), d, Y.data(), d, 0.0, D2.data(), N);
  45. // Step 3: Add the squared norms to each entry in D2
  46. for (int i = 0; i < M ; ++i) {
  47. for (int j = 0; j < N; ++j) {
  48. D2.set(D2.get(i, j) + X_norms[i] + Y_norms[j], i, j);
  49. //D2.set(std::max(D2.get(i, j), 0.0), i, j); // Ensure non-negative
  50. D2.set(std::sqrt(D2.get(i, j)), i, j); // Take the square root of each
  51. }
  52. }
  53. }
  54. template<typename DataType, typename IndexType>
  55. void quickselect(std::vector<std::pair<DataType, IndexType>>& vec, int k) {
  56. std::nth_element(
  57. vec.begin(),
  58. vec.begin() + k,
  59. vec.end(),
  60. [](const std::pair<DataType, IndexType>& a, const std::pair<DataType, IndexType>& b) {
  61. return a.first < b.first;
  62. });
  63. vec.resize(k); // Keep only the k smallest elements
  64. }
  65. /*!
  66. * \param C Is a MxD matrix (Corpus)
  67. * \param Q Is a NxD matrix (Query)
  68. * \param k The number of nearest neighbors needed
  69. * \param idx Is the Nxk matrix with the k indexes of the C points, that are
  70. * neighbors of the nth point of Q
  71. * \param dst Is the Nxk matrix with the k distances to the C points of the nth
  72. * point of Q
  73. */
  74. template<typename MatrixD, typename MatrixI>
  75. void knnsearch(const MatrixD& C, const MatrixD& Q, size_t idx_offset, size_t k, size_t m, MatrixI& idx, MatrixD& dst) {
  76. using DstType = typename MatrixD::dataType;
  77. using IdxType = typename MatrixI::dataType;
  78. size_t M = C.rows();
  79. size_t N = Q.rows();
  80. mtx::Matrix<DstType> D(M, N);
  81. pdist2(C, Q, D);
  82. for (size_t j = 0; j < N; ++j) {
  83. // Create a vector of pairs (distance, index) for the j-th query
  84. std::vector<std::pair<DstType, IdxType>> dst_idx(M);
  85. for (size_t i = 0; i < M; ++i) {
  86. dst_idx[i] = {D.data()[i * N + j], i};
  87. }
  88. // Find the k smallest distances using quickSelectKSmallest
  89. quickselect(dst_idx, k);
  90. // Sort the k smallest results by distance for consistency
  91. std::sort(dst_idx.begin(), dst_idx.end());
  92. // Store the indices and distances
  93. for (size_t i = 0; i < k; ++i) {
  94. dst.set(dst_idx[i].first, j, i);
  95. idx.set(dst_idx[i].second + idx_offset, j, i);
  96. }
  97. }
  98. }
  99. }
  100. #endif /* V0_HPP_ */